diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md
new file mode 100644
index 000000000000..b78219f8b214
--- /dev/null
+++ b/docs/source/en/api/pipelines/photon.md
@@ -0,0 +1,172 @@
+# PhotonPipeline
+
+Photon is a text-to-image diffusion model that pairs a simplified MMDiT architecture with flow matching for efficient, high-quality image generation. It uses T5Gemma as the text encoder and supports either the Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression.
+
+Key features:
+
+- **Simplified MMDiT architecture**: A streamlined MMDiT in which text tokens are not updated through the transformer blocks
+- **Flow Matching**: Employs flow matching with discrete scheduling for efficient sampling
+- **Flexible VAE Support**: Compatible with both Flux VAE (8x compression, 16 latent channels) and DC-AE (32x compression, 32 latent channels)
+- **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding, offering multilingual support
+- **Efficient Architecture**: ~1.3B parameters in the transformer, enabling fast inference while maintaining quality
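
The flow-matching sampling loop is easy to picture with a minimal Euler integrator. The sketch below is illustrative only — it is not the `FlowMatchEulerDiscreteScheduler` implementation, and `euler_flow_sample` is a hypothetical helper:

```py
# Minimal Euler integration of a flow-matching ODE dx/dt = v(x, t),
# stepping from t=1 (pure noise) to t=0 (data). Illustrative sketch only.
def euler_flow_sample(velocity_fn, x, num_steps=28):
    timesteps = [1.0 - i / num_steps for i in range(num_steps + 1)]
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        v = velocity_fn(x, t)
        x = [xi + (t_next - t) * vi for xi, vi in zip(x, v)]  # dt < 0: noise -> data
    return x

# With the rectified-flow interpolation x_t = (1 - t) * x0 + t * noise,
# the target velocity v = noise - x0 is constant, so Euler recovers x0 exactly.
x0, noise = [1.0, -2.0], [0.5, 0.5]
velocity = lambda x, t: [n - d for n, d in zip(noise, x0)]
print(euler_flow_sample(velocity, list(noise), num_steps=4))  # -> [1.0, -2.0]
```

Real image velocities are not constant along the path, which is why more inference steps generally improve quality.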
+
+## Available models
+We offer a range of **Photon models** with different **VAE configurations**, each optimized for a specific output resolution.
+Both **fine-tuned** and **non-fine-tuned** versions are available:
+
+- **Non-fine-tuned models** perform best with **highly detailed prompts**, capturing fine nuances and complex compositions.
+- **Fine-tuned models**, trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist), enhance the **aesthetic quality** of the base models—especially when prompts are **less detailed**.
+
+
+| Model | Recommended dtype | Resolution | Fine-tuned |
+|:-----:|:-----------------:|:----------:|:----------:|
+| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i) | `torch.bfloat16` | 256x256 | No |
+| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft) | `torch.bfloat16` | 256x256 | Yes |
+| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i) | `torch.bfloat16` | 512x512 | No |
+| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | `torch.bfloat16` | 512x512 | Yes |
+| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae) | `torch.bfloat16` | 512x512 | No |
+| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | `torch.bfloat16` | 512x512 | Yes |
+
+Refer to [this collection](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) for more information.
+
+## Loading the Pipeline
+
+Photon checkpoints store only the transformer and scheduler weights. The VAE and text encoder are downloaded automatically from the Hugging Face Hub during pipeline initialization:
+
+```py
+from diffusers.pipelines.photon import PhotonPipeline
+
+# Load pipeline - VAE and text encoder will be loaded from HuggingFace
+pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i")
+pipe.to("cuda")
+
+prompt = "A highly detailed 3D animated scene of a cute, intelligent duck scientist in a futuristic laboratory. The duck stands on a shiny metallic floor surrounded by glowing glass tubes filled with colorful liquids—blue, green, and purple—connected by translucent hoses emitting soft light. The duck wears a tiny white lab coat, safety goggles, and has a curious, determined expression while conducting an experiment. Sparks of energy and soft particle effects fill the air as scientific instruments hum with power. In the background, holographic screens display molecular diagrams and equations. Above the duck’s head, the word “PHOTON” glows vividly in midair as if made of pure light, illuminating the scene with a warm golden glow. The lighting is cinematic, with rich reflections and subtle depth of field, emphasizing a Pixar-like, ultra-polished 3D animation style. Rendered in ultra high resolution, realistic subsurface scattering on the duck’s feathers, and vibrant color grading that gives a sense of wonder and scientific discovery."
+image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
+image.save("photon_output.png")
+```
+
+### Manual Component Loading
+
+You can also load components individually:
+
+```py
+from diffusers.pipelines.photon import PhotonPipeline
+from diffusers.models import AutoencoderKL
+from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from transformers import T5GemmaModel, GemmaTokenizerFast
+
+# Load transformer
+transformer = PhotonTransformer2DModel.from_pretrained(
+ "Photoroom/photon-512-t2i", subfolder="transformer"
+)
+
+# Load scheduler
+scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ "Photoroom/photon-512-t2i", subfolder="scheduler"
+)
+
+# Load T5Gemma text encoder
+t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
+text_encoder = t5gemma_model.encoder
+tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
+
+# Load VAE - choose either Flux VAE or DC-AE
+# Flux VAE (16 latent channels):
+vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
+# Or DC-AE (32 latent channels):
+# vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")
+
+pipe = PhotonPipeline(
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae
+)
+pipe.to("cuda")
+```
+
+## VAE Variants
+
+Photon supports two VAE configurations:
+
+### Flux VAE (AutoencoderKL)
+- **Compression**: 8x spatial compression
+- **Latent channels**: 16
+- **Model**: `black-forest-labs/FLUX.1-dev` (subfolder: "vae")
+- **Use case**: Balanced quality and speed
+
+### DC-AE (AutoencoderDC)
+- **Compression**: 32x spatial compression
+- **Latent channels**: 32
+- **Model**: `mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers`
+- **Use case**: Higher compression for faster processing
+
+The VAE type is automatically determined from the checkpoint's `model_index.json` configuration.
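
To make the trade-off concrete, the latent-shape arithmetic for a 512x512 input works out as follows (a simple sketch; `latent_shape` is a hypothetical helper, not a diffusers API):

```py
# Latent tensor shape (C, H, W) a VAE with the given spatial compression
# factor and channel count produces for an RGB image of the given size.
def latent_shape(height, width, spatial_compression, latent_channels):
    return (latent_channels, height // spatial_compression, width // spatial_compression)

print(latent_shape(512, 512, 8, 16))   # Flux VAE -> (16, 64, 64)
print(latent_shape(512, 512, 32, 32))  # DC-AE    -> (32, 16, 16)
```

Combined with the transformer patch sizes used in the conversion script (2 for the Flux VAE, 1 for DC-AE), both configurations yield the same order of token count: 32x32 = 1024 image tokens for Flux and 16x16 = 256 for DC-AE.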
+
+## Generation Parameters
+
+Key parameters for image generation:
+
+- **num_inference_steps**: Number of denoising steps (default: 28). More steps generally improve quality at the cost of speed.
+- **guidance_scale**: Classifier-free guidance strength (default: 4.0). Higher values produce images more closely aligned with the prompt.
+- **height/width**: Output image dimensions (default: 512x512, taken from the checkpoint configuration).
+
+```py
+# Example with custom parameters
+import torch
+from diffusers.pipelines.photon import PhotonPipeline
+
+pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+image = pipe(
+    prompt="A cute duck scientist conducting a glowing experiment in a futuristic laboratory",
+    num_inference_steps=28,
+    guidance_scale=4.0,
+    height=512,
+    width=512,
+    generator=torch.Generator("cuda").manual_seed(42),
+).images[0]
+image.save("photon_custom.png")
+```
+
+## Memory Optimization
+
+For memory-constrained environments:
+
+```py
+import torch
+from diffusers.pipelines.photon import PhotonPipeline
+
+pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i", torch_dtype=torch.bfloat16)
+pipe.enable_model_cpu_offload() # Offload components to CPU when not in use
+
+# Or use sequential CPU offload for even lower memory
+pipe.enable_sequential_cpu_offload()
+```
+
+## PhotonPipeline
+
+[[autodoc]] PhotonPipeline
+ - all
+ - __call__
+
+## PhotonPipelineOutput
+
+[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput
diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py
new file mode 100644
index 000000000000..0dd114a68997
--- /dev/null
+++ b/scripts/convert_photon_to_diffusers.py
@@ -0,0 +1,355 @@
+#!/usr/bin/env python3
+"""
+Script to convert Photon checkpoint from original codebase to diffusers format.
+"""
+
+import argparse
+import json
+import os
+import sys
+from dataclasses import asdict, dataclass
+from typing import Dict, Tuple
+
+import torch
+from safetensors.torch import save_file
+
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
+
+from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
+from diffusers.pipelines.photon import PhotonPipeline
+
+
+DEFAULT_RESOLUTION = 512
+
+
+@dataclass(frozen=True)
+class PhotonBase:
+ context_in_dim: int = 2304
+ hidden_size: int = 1792
+ mlp_ratio: float = 3.5
+ num_heads: int = 28
+ depth: int = 16
+ axes_dim: Tuple[int, int] = (32, 32)
+ theta: int = 10_000
+ time_factor: float = 1000.0
+ time_max_period: int = 10_000
+
+
+@dataclass(frozen=True)
+class PhotonFlux(PhotonBase):
+ in_channels: int = 16
+ patch_size: int = 2
+
+
+@dataclass(frozen=True)
+class PhotonDCAE(PhotonBase):
+ in_channels: int = 32
+ patch_size: int = 1
+
+
+def build_config(vae_type: str, resolution: int = DEFAULT_RESOLUTION) -> dict:
+ if vae_type == "flux":
+ cfg = PhotonFlux()
+ sample_size = resolution // 8
+ elif vae_type == "dc-ae":
+ cfg = PhotonDCAE()
+ sample_size = resolution // 32
+ else:
+ raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
+
+ config_dict = asdict(cfg)
+ config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
+ config_dict["sample_size"] = sample_size
+ return config_dict
+
+
+def create_parameter_mapping(depth: int) -> dict:
+ """Create mapping from old parameter names to new diffusers names."""
+
+ # Key mappings for structural changes
+ mapping = {}
+
+ # RMSNorm: scale -> weight
+ for i in range(depth):
+ mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.qk_norm.query_norm.weight"
+ mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.qk_norm.key_norm.weight"
+ mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.k_norm.weight"
+
+ # Attention: attn_out -> attention.to_out.0
+ mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
+
+ return mapping
+
+
+def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
+ """Convert old checkpoint parameters to new diffusers format."""
+
+ print("Converting checkpoint parameters...")
+
+ mapping = create_parameter_mapping(depth)
+ converted_state_dict = {}
+
+ for key, value in old_state_dict.items():
+ new_key = key
+
+ # Apply specific mappings if needed
+ if key in mapping:
+ new_key = mapping[key]
+ print(f" Mapped: {key} -> {new_key}")
+
+ # Handle img_qkv_proj -> split to to_q, to_k, to_v
+ if "img_qkv_proj.weight" in key:
+ print(f" Found QKV projection: {key}")
+ # Split QKV weight into separate Q, K, V projections
+ qkv_weight = value
+ q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0)
+
+ # Extract layer number from key (e.g., blocks.0.img_qkv_proj.weight -> 0)
+ parts = key.split(".")
+ layer_idx = None
+ for i, part in enumerate(parts):
+ if part == "blocks" and i + 1 < len(parts) and parts[i + 1].isdigit():
+ layer_idx = parts[i + 1]
+ break
+
+ if layer_idx is not None:
+ converted_state_dict[f"blocks.{layer_idx}.attention.to_q.weight"] = q_weight
+ converted_state_dict[f"blocks.{layer_idx}.attention.to_k.weight"] = k_weight
+ converted_state_dict[f"blocks.{layer_idx}.attention.to_v.weight"] = v_weight
+ print(f" Split QKV for layer {layer_idx}")
+
+ # Also keep the original img_qkv_proj for backward compatibility
+ converted_state_dict[new_key] = value
+ else:
+ converted_state_dict[new_key] = value
+
+ print(f"✓ Converted {len(converted_state_dict)} parameters")
+ return converted_state_dict
+
+
+def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel:
+ """Create and load PhotonTransformer2DModel from old checkpoint."""
+
+ print(f"Loading checkpoint from: {checkpoint_path}")
+
+ # Load old checkpoint
+ if not os.path.exists(checkpoint_path):
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+ old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
+ # Handle different checkpoint formats
+ if isinstance(old_checkpoint, dict):
+ if "model" in old_checkpoint:
+ state_dict = old_checkpoint["model"]
+ elif "state_dict" in old_checkpoint:
+ state_dict = old_checkpoint["state_dict"]
+ else:
+ state_dict = old_checkpoint
+ else:
+ state_dict = old_checkpoint
+
+ print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
+
+ # Convert parameter names if needed
+ model_depth = int(config.get("depth", 16))
+ converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
+
+ # Create transformer with config
+ print("Creating PhotonTransformer2DModel...")
+ transformer = PhotonTransformer2DModel(**config)
+
+ # Load state dict
+ print("Loading converted parameters...")
+ missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
+
+ if missing_keys:
+ print(f"⚠ Missing keys: {missing_keys}")
+ if unexpected_keys:
+ print(f"⚠ Unexpected keys: {unexpected_keys}")
+
+ if not missing_keys and not unexpected_keys:
+ print("✓ All parameters loaded successfully!")
+
+ return transformer
+
+
+def create_scheduler_config(output_path: str):
+ """Create FlowMatchEulerDiscreteScheduler config."""
+
+ scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": 1.0}
+
+ scheduler_path = os.path.join(output_path, "scheduler")
+ os.makedirs(scheduler_path, exist_ok=True)
+
+ with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
+ json.dump(scheduler_config, f, indent=2)
+
+ print("✓ Created scheduler config")
+
+
+def download_and_save_vae(vae_type: str, output_path: str):
+ """Download and save VAE to local directory."""
+ from diffusers import AutoencoderDC, AutoencoderKL
+
+ vae_path = os.path.join(output_path, "vae")
+ os.makedirs(vae_path, exist_ok=True)
+
+ if vae_type == "flux":
+ print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
+ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
+ else: # dc-ae
+ print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers...")
+ vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")
+
+ vae.save_pretrained(vae_path)
+ print(f"✓ Saved VAE to {vae_path}")
+
+
+def download_and_save_text_encoder(output_path: str):
+ """Download and save T5Gemma text encoder and tokenizer."""
+ from transformers import GemmaTokenizerFast
+ from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
+
+ text_encoder_path = os.path.join(output_path, "text_encoder")
+ tokenizer_path = os.path.join(output_path, "tokenizer")
+ os.makedirs(text_encoder_path, exist_ok=True)
+ os.makedirs(tokenizer_path, exist_ok=True)
+
+ print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
+ t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
+
+ t5gemma_model.save_pretrained(text_encoder_path)
+ print(f"✓ Saved T5Gemma model to {text_encoder_path}")
+
+ print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
+ tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
+ tokenizer.model_max_length = 256
+ tokenizer.save_pretrained(tokenizer_path)
+ print(f"✓ Saved tokenizer to {tokenizer_path}")
+
+
+def create_model_index(vae_type: str, output_path: str):
+ """Create model_index.json for the pipeline."""
+
+ if vae_type == "flux":
+ vae_class = "AutoencoderKL"
+ else: # dc-ae
+ vae_class = "AutoencoderDC"
+
+ model_index = {
+ "_class_name": "PhotonPipeline",
+ "_diffusers_version": "0.31.0.dev0",
+ "_name_or_path": os.path.basename(output_path),
+ "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
+ "text_encoder": ["transformers", "T5GemmaModel"],
+ "tokenizer": ["transformers", "GemmaTokenizerFast"],
+ "transformer": ["diffusers", "PhotonTransformer2DModel"],
+ "vae": ["diffusers", vae_class],
+ }
+
+ model_index_path = os.path.join(output_path, "model_index.json")
+ with open(model_index_path, "w") as f:
+ json.dump(model_index, f, indent=2)
+
+
+def main(args):
+ # Validate inputs
+ if not os.path.exists(args.checkpoint_path):
+ raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
+
+ config = build_config(args.vae_type, args.resolution)
+
+ # Create output directory
+ os.makedirs(args.output_path, exist_ok=True)
+ print(f"✓ Output directory: {args.output_path}")
+
+ # Create transformer from checkpoint
+ transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
+
+ # Save transformer
+ transformer_path = os.path.join(args.output_path, "transformer")
+ os.makedirs(transformer_path, exist_ok=True)
+
+ # Save config
+ with open(os.path.join(transformer_path, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ # Save model weights as safetensors
+ state_dict = transformer.state_dict()
+ save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
+ print(f"✓ Saved transformer to {transformer_path}")
+
+ # Create scheduler config
+ create_scheduler_config(args.output_path)
+
+ download_and_save_vae(args.vae_type, args.output_path)
+ download_and_save_text_encoder(args.output_path)
+
+ # Create model_index.json
+ create_model_index(args.vae_type, args.output_path)
+
+ # Verify the pipeline can be loaded
+ try:
+ pipeline = PhotonPipeline.from_pretrained(args.output_path)
+ print("Pipeline loaded successfully!")
+ print(f"Transformer: {type(pipeline.transformer).__name__}")
+ print(f"VAE: {type(pipeline.vae).__name__}")
+ print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
+ print(f"Scheduler: {type(pipeline.scheduler).__name__}")
+
+ # Display model info
+ num_params = sum(p.numel() for p in pipeline.transformer.parameters())
+ print(f"✓ Transformer parameters: {num_params:,}")
+
+ except Exception as e:
+ print(f"Pipeline verification failed: {e}")
+ return False
+
+ print("Conversion completed successfully!")
+ print(f"Converted pipeline saved to: {args.output_path}")
+ print(f"VAE type: {args.vae_type}")
+
+ return True
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format")
+
+ parser.add_argument(
+ "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)"
+ )
+
+ parser.add_argument(
+ "--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
+ )
+
+ parser.add_argument(
+ "--vae_type",
+ type=str,
+ choices=["flux", "dc-ae"],
+ required=True,
+ help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
+ )
+
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ choices=[256, 512, 1024],
+ default=DEFAULT_RESOLUTION,
+ help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
+ )
+
+ args = parser.parse_args()
+
+ try:
+ success = main(args)
+ if not success:
+ sys.exit(1)
+ except Exception as e:
+ print(f"Conversion failed: {e}")
+ import traceback
+
+ traceback.print_exc()
+ sys.exit(1)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 686e8d99dabf..4eff8a27ff40 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -224,6 +224,7 @@
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
+ "PhotonTransformer2DModel",
"MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 457f70448af3..86e32c1eec3e 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -93,6 +93,7 @@
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
+ _import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 66455d733aee..47cf4fab4a5e 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -5605,6 +5605,63 @@ def __new__(cls, *args, **kwargs):
return processor
+class PhotonAttnProcessor2_0:
+ r"""
+ Processor for implementing Photon-style attention with multi-source tokens and RoPE.
+ Properly integrates with diffusers Attention module while handling Photon-specific logic.
+ """
+
+ def __init__(self):
+ if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+ raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0. Please upgrade your PyTorch installation.")
+
+ def __call__(
+ self,
+ attn: "Attention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Apply Photon attention using standard diffusers interface.
+
+ Expected tensor formats from PhotonBlock.attn_forward():
+ - hidden_states: Image queries with RoPE applied [B, H, L_img, D]
+ - encoder_hidden_states: Packed key+value tensors [B, H, L_all, 2*D]
+ (concatenated keys and values from text + image + spatial conditioning)
+ - attention_mask: Custom attention mask [B, H, L_img, L_all] or None
+ """
+
+ if encoder_hidden_states is None:
+ raise ValueError(
+ "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. "
+ "This should be provided by PhotonBlock.attn_forward()."
+ )
+
+ # Unpack the combined key+value tensor
+ # encoder_hidden_states is [B, H, L_all, 2*D] containing [keys, values]
+ key, value = encoder_hidden_states.chunk(2, dim=-1) # Each [B, H, L_all, D]
+
+ # Apply scaled dot-product attention with Photon's processed tensors
+ # hidden_states is image queries [B, H, L_img, D]
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ hidden_states.contiguous(), key.contiguous(), value.contiguous(), attn_mask=attention_mask
+ )
+
+ # Reshape from [B, H, L_img, D] to [B, L_img, H*D]
+ batch_size, num_heads, seq_len, head_dim = attn_output.shape
+ attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim)
+
+ # Apply output projection using the diffusers Attention module
+ attn_output = attn.to_out[0](attn_output)
+ if len(attn.to_out) > 1:
+ attn_output = attn.to_out[1](attn_output) # dropout if present
+
+ return attn_output
+
+
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
@@ -5653,6 +5710,7 @@ def __new__(cls, *args, **kwargs):
PAGHunyuanAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
LuminaAttnProcessor2_0,
+ PhotonAttnProcessor2_0,
FusedAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index b60f0636e6dc..7fdab560a702 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -31,6 +31,7 @@
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
+ from .transformer_photon import PhotonTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py
new file mode 100644
index 000000000000..452be9d2e6df
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_photon.py
@@ -0,0 +1,833 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from einops import rearrange
+from einops.layers.torch import Rearrange
+from torch import Tensor, nn
+from torch.nn.functional import fold, unfold
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention_processor import Attention, AttentionProcessor, PhotonAttnProcessor2_0
+from ..embeddings import get_timestep_embedding
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import RMSNorm
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> Tensor:
+ r"""
+ Generates 2D patch coordinate indices for a batch of images.
+
+ Parameters:
+ batch_size (`int`):
+ Number of images in the batch.
+ height (`int`):
+ Height of the input images (in pixels).
+ width (`int`):
+ Width of the input images (in pixels).
+ patch_size (`int`):
+ Size of the square patches that the image is divided into.
+ device (`torch.device`):
+ The device on which to create the tensor.
+
+ Returns:
+ `torch.Tensor`:
+ Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col)
+ coordinates of each patch in the image grid.
+ """
+
+ img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device)
+ img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None]
+ img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :]
+ return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)
+
+
+def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
+ r"""
+ Applies rotary positional embeddings (RoPE) to a query tensor.
+
+ Parameters:
+ xq (`torch.Tensor`):
+ Input tensor of shape `(..., dim)` representing the queries.
+ freqs_cis (`torch.Tensor`):
+ Precomputed rotary frequency components of shape `(..., dim/2, 2)`
+ containing cosine and sine pairs.
+
+ Returns:
+ `torch.Tensor`:
+ Tensor of the same shape as `xq` with rotary embeddings applied.
+ """
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq)
+
+
+class EmbedND(nn.Module):
+ r"""
+ N-dimensional rotary positional embedding.
+
+ This module creates rotary embeddings (RoPE) across multiple axes, where each
+ axis can have its own embedding dimension. The embeddings are combined and
+ returned as a single tensor.
+
+ Parameters:
+ dim (int):
+ Base embedding dimension (must be even).
+ theta (int):
+ Scaling factor that controls the frequency spectrum of the rotary embeddings.
+ axes_dim (list[int]):
+ List of embedding dimensions for each axis (each must be even).
+ """
+
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+ self.rope_rearrange = Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
+
+ def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor:
+ assert dim % 2 == 0
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+ out = pos.unsqueeze(-1) * omega.unsqueeze(0)
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+ out = self.rope_rearrange(out)
+ return out.float()
+
+ def forward(self, ids: Tensor) -> Tensor:
+ n_axes = ids.shape[-1]
+ emb = torch.cat(
+ [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+ return emb.unsqueeze(1)
+
+
+class MLPEmbedder(nn.Module):
+ r"""
+ A simple 2-layer MLP used for embedding inputs.
+
+ Parameters:
+ in_dim (`int`):
+ Dimensionality of the input features.
+ hidden_dim (`int`):
+ Dimensionality of the hidden and output embedding space.
+
+ Returns:
+ `torch.Tensor`:
+ Tensor of shape `(..., hidden_dim)` containing the embedded representations.
+ """
+
+ def __init__(self, in_dim: int, hidden_dim: int):
+ super().__init__()
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+ self.silu = nn.SiLU()
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class QKNorm(torch.nn.Module):
+ r"""
+ Applies RMS normalization to query and key tensors separately before attention,
+ which can help stabilize training and improve numerical precision.
+
+ Parameters:
+ dim (`int`):
+ Dimensionality of the query and key vectors.
+
+ Returns:
+ (`torch.Tensor`, `torch.Tensor`):
+ A tuple `(q, k)` where both are normalized and cast to the same dtype
+ as the value tensor `v`.
+ """
+
+ def __init__(self, dim: int):
+ super().__init__()
+ self.query_norm = RMSNorm(dim, eps=1e-6)
+ self.key_norm = RMSNorm(dim, eps=1e-6)
+
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
+ q = self.query_norm(q)
+ k = self.key_norm(k)
+ return q.to(v), k.to(v)
+
+
+@dataclass
+class ModulationOut:
+ shift: Tensor
+ scale: Tensor
+ gate: Tensor
+
+
+class Modulation(nn.Module):
+ r"""
+ Modulation network that generates scale, shift, and gating parameters.
+
+ Given an input vector, the module projects it through a linear layer to
+ produce six chunks, which are grouped into two `ModulationOut` objects.
+
+ Parameters:
+ dim (`int`):
+ Dimensionality of the input vector. The output will have `6 * dim`
+ features internally.
+
+ Returns:
+ (`ModulationOut`, `ModulationOut`):
+ A tuple of two modulation outputs. Each `ModulationOut` contains
+ three components (e.g., scale, shift, gate).
+ """
+
+ def __init__(self, dim: int):
+ super().__init__()
+ self.lin = nn.Linear(dim, 6 * dim, bias=True)
+ nn.init.constant_(self.lin.weight, 0)
+ nn.init.constant_(self.lin.bias, 0)
+
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
+ return ModulationOut(*out[:3]), ModulationOut(*out[3:])
+
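As a standalone sketch of the chunking in `Modulation.forward` (with the zero initialization used in `__init__`, the modulation is an identity at the start of training; the names outside the class are illustrative, not the library API):

```python
import torch
import torch.nn as nn

dim = 8
lin = nn.Linear(dim, 6 * dim, bias=True)
nn.init.constant_(lin.weight, 0)
nn.init.constant_(lin.bias, 0)

vec = torch.randn(2, dim)
# Project the conditioning vector to 6 * dim features and split into six
# chunks, grouped into two (shift, scale, gate) triples.
out = lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
shift_a, scale_a, gate_a = out[:3]
shift_b, scale_b, gate_b = out[3:]

# With zero init every chunk is zero, so (1 + scale) * x + shift leaves x
# unchanged and the gated residual update is a no-op at initialization.
x = torch.randn(2, 4, dim)
print(torch.equal((1 + scale_a) * x + shift_a, x))  # True
```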
+
+class PhotonBlock(nn.Module):
+ r"""
+ Multimodal transformer block with text–image cross-attention, modulation, and MLP.
+
+ Parameters:
+ hidden_size (`int`):
+ Dimension of the hidden representations.
+ num_heads (`int`):
+ Number of attention heads.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ Expansion ratio for the hidden dimension inside the MLP.
+ qk_scale (`float`, *optional*):
+ Scale factor for queries and keys. If not provided, defaults to
+ `head_dim**-0.5`.
+
+ Attributes:
+ img_pre_norm (`nn.LayerNorm`):
+ Pre-normalization applied to image tokens before QKV projection.
+ img_qkv_proj (`nn.Linear`):
+ Linear projection to produce image queries, keys, and values.
+ qk_norm (`QKNorm`):
+ RMS normalization applied separately to image queries and keys.
+ txt_kv_proj (`nn.Linear`):
+ Linear projection to produce text keys and values.
+ k_norm (`RMSNorm`):
+ RMS normalization applied to text keys.
+ attention (`Attention`):
+ Multi-head attention module for cross-attention between image, text,
+ and optional spatial conditioning tokens.
+ post_attention_layernorm (`nn.LayerNorm`):
+ Normalization applied after attention.
+ gate_proj / up_proj / down_proj (`nn.Linear`):
+ Feedforward layers forming the gated MLP.
+ mlp_act (`nn.GELU`):
+ Nonlinear activation used in the MLP.
+ modulation (`Modulation`):
+ Produces scale/shift/gating parameters for modulated layers.
+ spatial_cond_kv_proj (`nn.Linear`, *optional*):
+ Projection for optional spatial conditioning tokens.
+
+ Methods:
+ _attn_forward(img, txt, pe, modulation, spatial_conditioning=None, attention_mask=None):
+ Compute cross-attention between image and text tokens, with optional
+ spatial conditioning and attention masking.
+
+ Parameters:
+ img (`torch.Tensor`):
+ Image tokens of shape `(B, L_img, hidden_size)`.
+ txt (`torch.Tensor`):
+ Text tokens of shape `(B, L_txt, hidden_size)`.
+ pe (`torch.Tensor`):
+ Rotary positional embeddings to apply to queries and keys.
+ modulation (`ModulationOut`):
+ Scale and shift parameters for modulating image tokens.
+ spatial_conditioning (`torch.Tensor`, *optional*):
+ Extra conditioning tokens of shape `(B, L_cond, hidden_size)`.
+ attention_mask (`torch.Tensor`, *optional*):
+ Boolean mask of shape `(B, L_txt)` where 0 marks padding.
+
+ Returns:
+ `torch.Tensor`:
+ Attention output of shape `(B, L_img, hidden_size)`.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qk_scale: float | None = None,
+ ):
+ super().__init__()
+
+ self._fsdp_wrap = True
+ self._activation_checkpointing = True
+
+ self.hidden_dim = hidden_size
+ self.num_heads = num_heads
+ self.head_dim = hidden_size // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.hidden_size = hidden_size
+
+ # img qkv
+ self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.img_qkv_proj = nn.Linear(hidden_size, hidden_size * 3, bias=False)
+ self.qk_norm = QKNorm(self.head_dim)
+
+ # txt kv
+ self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False)
+ self.k_norm = RMSNorm(self.head_dim, eps=1e-6)
+
+ self.attention = Attention(
+ query_dim=hidden_size,
+ heads=num_heads,
+ dim_head=self.head_dim,
+ bias=False,
+ out_bias=False,
+ processor=PhotonAttnProcessor2_0(),
+ )
+
+ # mlp
+ self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
+ self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
+ self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
+ self.mlp_act = nn.GELU(approximate="tanh")
+
+ self.modulation = Modulation(hidden_size)
+ self.spatial_cond_kv_proj: None | nn.Linear = None
+
+ def _attn_forward(
+ self,
+ img: Tensor,
+ txt: Tensor,
+ pe: Tensor,
+ modulation: ModulationOut,
+ spatial_conditioning: None | Tensor = None,
+ attention_mask: None | Tensor = None,
+ ) -> Tensor:
+ # image tokens proj and norm
+ img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift
+
+ img_qkv = self.img_qkv_proj(img_mod)
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+ img_q, img_k = self.qk_norm(img_q, img_k, img_v)
+
+ # txt tokens proj and norm
+ txt_kv = self.txt_kv_proj(txt)
+ txt_k, txt_v = rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
+ txt_k = self.k_norm(txt_k)
+
+ # compute attention
+ img_q, img_k = apply_rope(img_q, pe), apply_rope(img_k, pe)
+ k = torch.cat((txt_k, img_k), dim=2)
+ v = torch.cat((txt_v, img_v), dim=2)
+
+ # optional spatial conditioning tokens
+ cond_len = 0
+ if self.spatial_cond_kv_proj is not None:
+ assert spatial_conditioning is not None
+ cond_kv = self.spatial_cond_kv_proj(spatial_conditioning)
+ cond_k, cond_v = rearrange(cond_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
+ cond_k = apply_rope(cond_k, pe)
+ cond_len = cond_k.shape[2]
+ k = torch.cat((cond_k, k), dim=2)
+ v = torch.cat((cond_v, v), dim=2)
+
+ # build multiplicative 0/1 mask for provided attention_mask over [cond?, text, image] keys
+ attn_mask: Tensor | None = None
+ if attention_mask is not None:
+ bs, _, l_img, _ = img_q.shape
+ l_txt = txt_k.shape[2]
+
+ if attention_mask.dim() != 2:
+ raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
+ if attention_mask.shape[-1] != l_txt:
+ raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}")
+
+ device = img_q.device
+
+ ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
+ cond_mask = torch.ones((bs, cond_len), dtype=torch.bool, device=device)
+
+ mask_parts = [
+ cond_mask,
+ attention_mask.to(torch.bool),
+ ones_img,
+ ]
+ joint_mask = torch.cat(mask_parts, dim=-1) # (B, L_all)
+
+ # repeat across heads and query positions
+ attn_mask = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1) # (B,H,L_img,L_all)
+
+ kv_packed = torch.cat([k, v], dim=-1)
+
+ attn = self.attention(
+ hidden_states=img_q,
+ encoder_hidden_states=kv_packed,
+ attention_mask=attn_mask,
+ )
+
+ return attn
+
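The mask construction above can be exercised in isolation (shapes are illustrative; `cond_len` is zero when no spatial conditioning is used):

```python
import torch

bs, num_heads, l_img, l_txt, cond_len = 2, 4, 6, 3, 0
attention_mask = torch.tensor([[1, 1, 0], [1, 0, 0]])  # (B, L_txt), 0 marks padding

# Keys are ordered [cond?, text, image]; only text positions can be padded.
ones_img = torch.ones((bs, l_img), dtype=torch.bool)
cond_mask = torch.ones((bs, cond_len), dtype=torch.bool)
joint_mask = torch.cat([cond_mask, attention_mask.to(torch.bool), ones_img], dim=-1)

# Broadcast across heads and query positions: (B, H, L_img, L_all)
attn_mask = joint_mask[:, None, None, :].expand(-1, num_heads, l_img, -1)
print(attn_mask.shape)  # torch.Size([2, 4, 6, 9])
```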
+ def _ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor:
+ x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift
+ return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))
+
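`_ffn_forward` applies a GEGLU-style gated MLP after modulation; a minimal standalone sketch with illustrative dimensions:

```python
import torch
import torch.nn as nn

hidden_size, mlp_hidden_dim = 8, 28
gate_proj = nn.Linear(hidden_size, mlp_hidden_dim, bias=False)
up_proj = nn.Linear(hidden_size, mlp_hidden_dim, bias=False)
down_proj = nn.Linear(mlp_hidden_dim, hidden_size, bias=False)
mlp_act = nn.GELU(approximate="tanh")

# GELU(gate_proj(x)) gates up_proj(x) elementwise before projecting back down.
x = torch.randn(2, 5, hidden_size)
out = down_proj(mlp_act(gate_proj(x)) * up_proj(x))
print(out.shape)  # torch.Size([2, 5, 8])
```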
+ def forward(
+ self,
+ img: Tensor,
+ txt: Tensor,
+ vec: Tensor,
+ pe: Tensor,
+ spatial_conditioning: Tensor | None = None,
+ attention_mask: Tensor | None = None,
+ **_: dict[str, Any],
+ ) -> Tensor:
+ r"""
+ Runs modulation-gated cross-attention and MLP, with residual connections.
+
+ Parameters:
+ img (`torch.Tensor`):
+ Image tokens of shape `(B, L_img, hidden_size)`.
+ txt (`torch.Tensor`):
+ Text tokens of shape `(B, L_txt, hidden_size)`.
+ vec (`torch.Tensor`):
+ Conditioning vector used by `Modulation` to produce scale/shift/gates,
+ shape `(B, hidden_size)` (or broadcastable).
+ pe (`torch.Tensor`):
+ Rotary positional embeddings applied inside attention.
+ spatial_conditioning (`torch.Tensor`, *optional*):
+ Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. Used only
+ if spatial conditioning is enabled in the block.
+ attention_mask (`torch.Tensor`, *optional*):
+ Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
+ **_:
+ Ignored additional keyword arguments for API compatibility.
+
+ Returns:
+ `torch.Tensor`:
+ Updated image tokens of shape `(B, L_img, hidden_size)`.
+ """
+
+ mod_attn, mod_mlp = self.modulation(vec)
+
+ img = img + mod_attn.gate * self._attn_forward(
+ img,
+ txt,
+ pe,
+ mod_attn,
+ spatial_conditioning=spatial_conditioning,
+ attention_mask=attention_mask,
+ )
+ img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp)
+ return img
+
+
+class LastLayer(nn.Module):
+ r"""
+ Final projection layer with adaptive LayerNorm modulation.
+
+ This layer applies a normalized and modulated transformation to input tokens
+ and projects them into patch-level outputs.
+
+ Parameters:
+ hidden_size (`int`):
+ Dimensionality of the input tokens.
+ patch_size (`int`):
+ Size of the square image patches.
+ out_channels (`int`):
+ Number of output channels per pixel (e.g. RGB = 3).
+
+ Forward Inputs:
+ x (`torch.Tensor`):
+ Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches.
+ vec (`torch.Tensor`):
+ Conditioning vector of shape `(B, hidden_size)` used to generate
+ shift and scale parameters for adaptive LayerNorm.
+
+ Returns:
+ `torch.Tensor`:
+ Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`.
+ """
+
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+ nn.init.constant_(self.adaLN_modulation[1].weight, 0)
+ nn.init.constant_(self.adaLN_modulation[1].bias, 0)
+ nn.init.constant_(self.linear.weight, 0)
+ nn.init.constant_(self.linear.bias, 0)
+
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+ x = self.linear(x)
+ return x
+
+
+def img2seq(img: Tensor, patch_size: int) -> Tensor:
+ r"""
+ Flattens an image tensor into a sequence of non-overlapping patches.
+
+ Parameters:
+ img (`torch.Tensor`):
+ Input image tensor of shape `(B, C, H, W)`.
+ patch_size (`int`):
+ Size of each square patch. Must evenly divide both `H` and `W`.
+
+ Returns:
+ `torch.Tensor`:
+ Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`,
+ where `L = (H // patch_size) * (W // patch_size)` is the number of patches.
+ """
+ return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
+
+
+def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
+ r"""
+ Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
+
+ Parameters:
+ seq (`torch.Tensor`):
+ Patch sequence of shape `(B, L, C * patch_size * patch_size)`,
+ where `L = (H // patch_size) * (W // patch_size)`.
+ patch_size (`int`):
+ Size of each square patch.
+ shape (`tuple` or `torch.Tensor`):
+ The original image shape. For a tuple (e.g. `image.shape`), the last two
+ values are used as `(H, W)`; for a tensor, the first two values are
+ interpreted as height and width.
+
+ Returns:
+ `torch.Tensor`:
+ Reconstructed image tensor of shape `(B, C, H, W)`.
+ """
+ if isinstance(shape, tuple):
+ shape = shape[-2:]
+ elif isinstance(shape, torch.Tensor):
+ shape = (int(shape[0]), int(shape[1]))
+ else:
+ raise NotImplementedError(f"shape type {type(shape)} not supported")
+ return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
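`img2seq` and `seq2img` are exact inverses for non-overlapping patches, since `fold` with `stride == kernel_size` places each value exactly once; for example:

```python
import torch
from torch.nn.functional import fold, unfold

patch_size = 2
img = torch.randn(2, 16, 32, 32)  # (B, C, H, W)

# Patchify: unfold yields (B, C*p*p, L); transpose to (B, L, C*p*p).
seq = unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
print(seq.shape)  # torch.Size([2, 256, 64]), L = (32 // 2) * (32 // 2)

# Unpatchify restores the original tensor exactly.
restored = fold(seq.transpose(1, 2), (32, 32), kernel_size=patch_size, stride=patch_size)
print(torch.equal(restored, img))  # True
```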
+
+
+class PhotonTransformer2DModel(ModelMixin, ConfigMixin):
+ r"""
+ Transformer-based 2D model for text-to-image generation.
+ It supports attention processor injection and LoRA scaling.
+
+ Parameters:
+ in_channels (`int`, *optional*, defaults to 16):
+ Number of input channels in the latent image.
+ patch_size (`int`, *optional*, defaults to 2):
+ Size of the square patches used to flatten the input image.
+ context_in_dim (`int`, *optional*, defaults to 2304):
+ Dimensionality of the text conditioning input.
+ hidden_size (`int`, *optional*, defaults to 1792):
+ Dimension of the hidden representation.
+ mlp_ratio (`float`, *optional*, defaults to 3.5):
+ Expansion ratio for the hidden dimension inside MLP blocks.
+ num_heads (`int`, *optional*, defaults to 28):
+ Number of attention heads.
+ depth (`int`, *optional*, defaults to 16):
+ Number of transformer blocks.
+ axes_dim (`list[int]`, *optional*):
+ List of dimensions for each positional embedding axis. Defaults to `[32, 32]`.
+ theta (`int`, *optional*, defaults to 10000):
+ Frequency scaling factor for rotary embeddings.
+ time_factor (`float`, *optional*, defaults to 1000.0):
+ Scaling factor applied in timestep embeddings.
+ time_max_period (`int`, *optional*, defaults to 10000):
+ Maximum frequency period for timestep embeddings.
+ conditioning_block_ids (`list[int]`, *optional*):
+ Indices of blocks that receive conditioning. Defaults to all blocks.
+ **kwargs:
+ Additional keyword arguments forwarded to the config.
+
+ Attributes:
+ pe_embedder (`EmbedND`):
+ Multi-axis rotary embedding generator for positional encodings.
+ img_in (`nn.Linear`):
+ Projection layer for image patch tokens.
+ time_in (`MLPEmbedder`):
+ Embedding layer for timestep embeddings.
+ txt_in (`nn.Linear`):
+ Projection layer for text conditioning.
+ blocks (`nn.ModuleList`):
+ Stack of transformer blocks (`PhotonBlock`).
+ final_layer (`LastLayer`):
+ Projection layer mapping hidden tokens back to patch outputs.
+
+ Methods:
+ attn_processors:
+ Returns a dictionary of all attention processors in the model.
+ set_attn_processor(processor):
+ Replaces attention processors across all attention layers.
+ _process_inputs(image_latent, txt):
+ Converts inputs into patch tokens, encodes text, and produces positional encodings.
+ _compute_timestep_embedding(timestep, dtype):
+ Creates a timestep embedding of dimension 256, scaled and projected.
+ _forward_transformers(image_latent, cross_attn_conditioning, timestep, time_embedding, attention_mask, **block_kwargs):
+ Runs the sequence of transformer blocks over image and text tokens.
+ forward(image_latent, timestep, cross_attn_conditioning, micro_conditioning, cross_attn_mask=None, attention_kwargs=None, return_dict=True):
+ Full forward pass from latent input to reconstructed output image.
+
+ Returns:
+ `Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing:
+ - `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`.
+ """
+
+ config_name = "config.json"
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 16,
+ patch_size: int = 2,
+ context_in_dim: int = 2304,
+ hidden_size: int = 1792,
+ mlp_ratio: float = 3.5,
+ num_heads: int = 28,
+ depth: int = 16,
+ axes_dim: list = None,
+ theta: int = 10000,
+ time_factor: float = 1000.0,
+ time_max_period: int = 10000,
+ conditioning_block_ids: list = None,
+ **kwargs,
+ ):
+ super().__init__()
+
+ if axes_dim is None:
+ axes_dim = [32, 32]
+
+ # Store parameters directly
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.out_channels = self.in_channels * self.patch_size**2
+
+ self.time_factor = time_factor
+ self.time_max_period = time_max_period
+
+ if hidden_size % num_heads != 0:
+ raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
+
+ pe_dim = hidden_size // num_heads
+
+ if sum(axes_dim) != pe_dim:
+ raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
+
+ self.hidden_size = hidden_size
+ self.num_heads = num_heads
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
+ self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
+
+ conditioning_block_ids: list[int] = conditioning_block_ids or list(range(depth))
+
+ self.blocks = nn.ModuleList(
+ [
+ PhotonBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=mlp_ratio,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
+ txt = self.txt_in(txt)
+ img = img2seq(image_latent, self.patch_size)
+ bs, _, h, w = image_latent.shape
+ img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
+ pe = self.pe_embedder(img_ids)
+ return img, txt, pe
+
+ def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
+ return self.time_in(
+ get_timestep_embedding(
+ timesteps=timestep,
+ embedding_dim=256,
+ max_period=self.time_max_period,
+ scale=self.time_factor,
+ flip_sin_to_cos=True, # Match original cos, sin order
+ ).to(dtype)
+ )
+
+ def _forward_transformers(
+ self,
+ image_latent: Tensor,
+ cross_attn_conditioning: Tensor,
+ timestep: Optional[Tensor] = None,
+ time_embedding: Optional[Tensor] = None,
+ attention_mask: Optional[Tensor] = None,
+ **block_kwargs: Any,
+ ) -> Tensor:
+ img = self.img_in(image_latent)
+
+ if time_embedding is not None:
+ vec = time_embedding
+ else:
+ if timestep is None:
+ raise ValueError("Please provide either `timestep` or `time_embedding`")
+ vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
+
+ for block in self.blocks:
+ img = block(img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs)
+
+ img = self.final_layer(img, vec)
+ return img
+
+ def forward(
+ self,
+ image_latent: Tensor,
+ timestep: Tensor,
+ cross_attn_conditioning: Tensor,
+ micro_conditioning: Tensor,
+ cross_attn_mask: None | Tensor = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+ r"""
+ Forward pass of the PhotonTransformer2DModel.
+
+ The latent image is split into patch tokens, combined with text conditioning,
+ and processed through a stack of transformer blocks modulated by the timestep.
+ The output is reconstructed into the latent image space.
+
+ Parameters:
+ image_latent (`torch.Tensor`):
+ Input latent image tensor of shape `(B, C, H, W)`.
+ timestep (`torch.Tensor`):
+ Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
+ cross_attn_conditioning (`torch.Tensor`):
+ Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
+ micro_conditioning (`torch.Tensor`):
+ Extra conditioning vector (currently unused, reserved for future use).
+ cross_attn_mask (`torch.Tensor`, *optional*):
+ Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
+ attention_kwargs (`dict`, *optional*):
+ Additional arguments passed to attention layers. If using the PEFT backend,
+ the key `"scale"` controls LoRA scaling (default: 1.0).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a `Transformer2DModelOutput` or a tuple.
+
+ Returns:
+ `Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple:
+
+ - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
+ """
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ img_seq, txt, pe = self._process_inputs(image_latent, cross_attn_conditioning)
+ img_seq = self._forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask)
+ output = seq2img(img_seq, self.patch_size, image_latent.shape)
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 190c7871d270..ae0d90c48c63 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -144,6 +144,7 @@
"FluxKontextPipeline",
"FluxKontextInpaintPipeline",
]
+ _import_structure["photon"] = ["PhotonPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py
new file mode 100644
index 000000000000..559c9d0b1d2d
--- /dev/null
+++ b/src/diffusers/pipelines/photon/__init__.py
@@ -0,0 +1,5 @@
+from .pipeline_output import PhotonPipelineOutput
+from .pipeline_photon import PhotonPipeline
+
+
+__all__ = ["PhotonPipeline", "PhotonPipelineOutput"]
diff --git a/src/diffusers/pipelines/photon/pipeline_output.py b/src/diffusers/pipelines/photon/pipeline_output.py
new file mode 100644
index 000000000000..ca0674d94b6c
--- /dev/null
+++ b/src/diffusers/pipelines/photon/pipeline_output.py
@@ -0,0 +1,35 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class PhotonPipelineOutput(BaseOutput):
+ """
+ Output class for Photon pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`):
+ List of denoised PIL images of length `batch_size` or a NumPy array of shape `(batch_size, height, width,
+ num_channels)` representing the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py
new file mode 100644
index 000000000000..0fc926261517
--- /dev/null
+++ b/src/diffusers/pipelines/photon/pipeline_photon.py
@@ -0,0 +1,614 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import ftfy
+import torch
+from transformers import (
+ AutoTokenizer,
+ GemmaTokenizerFast,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderDC, AutoencoderKL
+from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel, seq2img
+from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+DEFAULT_HEIGHT = 512
+DEFAULT_WIDTH = 512
+
+logger = logging.get_logger(__name__)
+
+
+class TextPreprocessor:
+ """Text preprocessing utility for PhotonPipeline."""
+
+ def __init__(self):
+ """Initialize text preprocessor."""
+ self.bad_punct_regex = re.compile(
+ r"["
+ + "#®•©™&@·º½¾¿¡§~"
+ + r"\)"
+ + r"\("
+ + r"\]"
+ + r"\["
+ + r"\}"
+ + r"\{"
+ + r"\|"
+ + r"\\"
+ + r"\/"
+ + r"\*"
+ + r"]{1,}"
+ )
+
+ def clean_text(self, text: str) -> str:
+ """Clean text using comprehensive text processing logic."""
+ # See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py
+ text = str(text)
+ text = ul.unquote_plus(text)
+ text = text.strip().lower()
+ text = re.sub("<person>", "person", text)
+
+ # Remove all urls:
+ text = re.sub(
+ r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))",
+ "",
+ text,
+ ) # regex for urls
+
+ # @
+ text = re.sub(r"@[\w\d]+\b", "", text)
+
+ # 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs
+ text = re.sub(r"[\u31c0-\u31ef]+", "", text)
+ text = re.sub(r"[\u31f0-\u31ff]+", "", text)
+ text = re.sub(r"[\u3200-\u32ff]+", "", text)
+ text = re.sub(r"[\u3300-\u33ff]+", "", text)
+ text = re.sub(r"[\u3400-\u4dbf]+", "", text)
+ text = re.sub(r"[\u4dc0-\u4dff]+", "", text)
+ text = re.sub(r"[\u4e00-\u9fff]+", "", text)
+
+ # all types of dash --> "-"
+ text = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
+ "-",
+ text,
+ )
+
+ # normalize all quote styles to one standard
+ text = re.sub(r"[`´«»“”¨]", '"', text)
+ text = re.sub(r"[‘’]", "'", text)
+
+ # &quot; and &amp;
+ text = re.sub(r"&quot;?", "", text)
+ text = re.sub(r"&amp", "", text)
+
+ # ip addresses:
+ text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text)
+
+ # article ids:
+ text = re.sub(r"\d:\d\d\s+$", "", text)
+
+ # \n
+ text = re.sub(r"\\n", " ", text)
+
+ # "#123", "#12345..", "123456.."
+ text = re.sub(r"#\d{1,3}\b", "", text)
+ text = re.sub(r"#\d{5,}\b", "", text)
+ text = re.sub(r"\b\d{6,}\b", "", text)
+
+ # filenames:
+ text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text)
+
+ # Clean punctuation
+ text = re.sub(r"[\"\']{2,}", r'"', text) # """AUSVERKAUFT"""
+ text = re.sub(r"[\.]{2,}", r" ", text)
+
+ text = re.sub(self.bad_punct_regex, r" ", text) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ text = re.sub(r"\s+\.\s+", r" ", text) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, text)) > 3:
+ text = re.sub(regex2, " ", text)
+
+ # Basic cleaning
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ text = text.strip()
+
+ # Clean alphanumeric patterns
+ text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text) # jc6640
+ text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text) # jc6640vc
+ text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text) # 6640vc231
+
+ # Common spam patterns
+ text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text)
+ text = re.sub(r"(free\s)?download(\sfree)?", "", text)
+ text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text)
+ text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text)
+ text = re.sub(r"\bpage\s+\d+\b", "", text)
+
+ text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text) # j2d1a2a...
+ text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text)
+
+ # Final cleanup
+ text = re.sub(r"\b\s+\:\s+", r": ", text)
+ text = re.sub(r"(\D[,\./])\b", r"\1 ", text)
+ text = re.sub(r"\s+", " ", text)
+
+ text = text.strip()
+
+ text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text)
+ text = re.sub(r"^[\'\_,\-\:;]", r"", text)
+ text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text)
+ text = re.sub(r"^\.\S+$", "", text)
+
+ return text.strip()
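Two of the steps above (URL and @-handle removal, with the regexes copied verbatim from the method) applied to a sample string:

```python
import re

text = "check https://example.com/page now @user123 hello"
# Remove URLs (scheme prefixes and bare domains), then @-handles, then
# collapse whitespace, as in clean_text.
text = re.sub(
    r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))",
    "",
    text,
)
text = re.sub(r"@[\w\d]+\b", "", text)
text = re.sub(r"\s+", " ", text).strip()
print(text)  # check now hello
```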
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import PhotonPipeline
+
+ >>> # Load pipeline with from_pretrained
+ >>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-256-t2i", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
+ >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
+ >>> image.save("photon_output.png")
+ ```
+"""
+
+
+class PhotonPipeline(
+ DiffusionPipeline,
+ LoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Photon Transformer.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ transformer ([`PhotonTransformer2DModel`]):
+ The Photon transformer model to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ text_encoder ([`T5EncoderModel`]):
+ Standard text encoder model for encoding prompts.
+ tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]):
+ Tokenizer for the text encoder.
+ vae ([`AutoencoderKL`] or [`AutoencoderDC`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents"]
+ _optional_components = []
+
+ def __init__(
+ self,
+ transformer: PhotonTransformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ text_encoder: Union[T5EncoderModel, Any],
+ tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
+ vae: Union[AutoencoderKL, AutoencoderDC],
+ ):
+ super().__init__()
+
+ if PhotonTransformer2DModel is None:
+ raise ImportError(
+ "PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed."
+ )
+
+ # Extract encoder if text_encoder is T5GemmaModel
+ if hasattr(text_encoder, "encoder"):
+ text_encoder = text_encoder.encoder
+
+ self.text_encoder = text_encoder
+ self.tokenizer = tokenizer
+ self.text_preprocessor = TextPreprocessor()
+
+ self.register_modules(
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ )
+
+ # Enhance VAE with universal properties for both AutoencoderKL and AutoencoderDC
+ self._enhance_vae_properties()
+
+ # Set image processor using vae_scale_factor property
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ # Set default dimensions from transformer config
+ self.default_sample_size = (
+ self.transformer.config.sample_size
+ if hasattr(self, "transformer")
+ and self.transformer is not None
+ and hasattr(self.transformer.config, "sample_size")
+ else 64
+ )
+
+ def _enhance_vae_properties(self):
+ """Add universal properties to VAE for consistent interface across AutoencoderKL and AutoencoderDC."""
+ if not hasattr(self, "vae") or self.vae is None:
+ return
+
+ # Set spatial_compression_ratio property
+ if hasattr(self.vae, "spatial_compression_ratio"):
+ # AutoencoderDC already has this property
+ pass
+ elif hasattr(self.vae, "config") and hasattr(self.vae.config, "block_out_channels"):
+ # AutoencoderKL: calculate from block_out_channels
+ self.vae.spatial_compression_ratio = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ else:
+ # Fallback
+ self.vae.spatial_compression_ratio = 8
+
+ # Set scaling_factor property with safe defaults
+ if hasattr(self.vae, "config"):
+ self.vae.scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
+ else:
+ self.vae.scaling_factor = 0.18215
+
+ # Set shift_factor property with safe defaults (0.0 for AutoencoderDC)
+ if hasattr(self.vae, "config"):
+ shift_factor = getattr(self.vae.config, "shift_factor", None)
+ if shift_factor is None: # AutoencoderDC case
+ self.vae.shift_factor = 0.0
+ else:
+ self.vae.shift_factor = shift_factor
+ else:
+ self.vae.shift_factor = 0.0
+
+        # Set latent_channels property for a consistent interface across VAE types
+ if hasattr(self.vae, "config") and hasattr(self.vae.config, "latent_channels"):
+ # AutoencoderDC has latent_channels in config
+ self.vae.latent_channels = int(self.vae.config.latent_channels)
+ elif hasattr(self.vae, "config") and hasattr(self.vae.config, "in_channels"):
+ # AutoencoderKL has in_channels in config
+ self.vae.latent_channels = int(self.vae.config.in_channels)
+ else:
+ # Fallback based on VAE type - DC-AE typically has 32, AutoencoderKL has 4/16
+ if hasattr(self.vae, "spatial_compression_ratio") and self.vae.spatial_compression_ratio == 32:
+ self.vae.latent_channels = 32 # DC-AE default
+ else:
+ self.vae.latent_channels = 4 # AutoencoderKL default
+
+ @property
+ def vae_scale_factor(self):
+ """Compatibility property that returns spatial compression ratio."""
+ return getattr(self.vae, "spatial_compression_ratio", 8)
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ """Prepare initial latents for the diffusion process."""
+ if latents is None:
+ latent_height, latent_width = (
+ height // self.vae.spatial_compression_ratio,
+ width // self.vae.spatial_compression_ratio,
+ )
+ shape = (batch_size, num_channels_latents, latent_height, latent_width)
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # FlowMatchEulerDiscreteScheduler doesn't use init_noise_sigma scaling
+ return latents
+
+ def encode_prompt(self, prompt: Union[str, List[str]], device: torch.device):
+ """Encode text prompt using standard text encoder and tokenizer."""
+ if isinstance(prompt, str):
+ prompt = [prompt]
+
+ return self._encode_prompt_standard(prompt, device)
+
+ def _encode_prompt_standard(self, prompt: List[str], device: torch.device):
+ """Encode prompt using standard text encoder and tokenizer with batch processing."""
+ # Clean text using modular preprocessor
+ cleaned_prompts = [self.text_preprocessor.clean_text(text) for text in prompt]
+ cleaned_uncond_prompts = [self.text_preprocessor.clean_text("") for _ in prompt]
+
+ # Batch conditional and unconditional prompts together for efficiency
+ all_prompts = cleaned_prompts + cleaned_uncond_prompts
+
+ # Tokenize all prompts in one batch
+ tokens = self.tokenizer(
+ all_prompts,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+
+ input_ids = tokens["input_ids"].to(device)
+ attention_mask = tokens["attention_mask"].bool().to(device)
+
+ # Encode all prompts in one batch
+ with torch.no_grad():
+            # Disable autocast so the text encoder runs in its native precision
+ with torch.autocast("cuda", enabled=False):
+ emb = self.text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )
+
+ # Use last hidden state (matching TextTower's use_last_hidden_state=True default)
+ all_embeddings = emb["last_hidden_state"]
+
+ # Split back into conditional and unconditional
+ batch_size = len(prompt)
+ text_embeddings = all_embeddings[:batch_size]
+ uncond_text_embeddings = all_embeddings[batch_size:]
+
+ cross_attn_mask = attention_mask[:batch_size]
+ uncond_cross_attn_mask = attention_mask[batch_size:]
+
+ return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask
+
+ def check_inputs(
+ self,
+ prompt: Union[str, List[str]],
+ height: int,
+ width: int,
+ guidance_scale: float,
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
+ ):
+ """Check that all inputs are in correct format."""
+ if height % self.vae.spatial_compression_ratio != 0 or width % self.vae.spatial_compression_ratio != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {self.vae.spatial_compression_ratio} but are {height} and {width}."
+ )
+
+ if guidance_scale < 1.0:
+ raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}")
+
+ if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}"
+ )
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. Must be provided as a string or a list of
+                strings.
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`.
+ `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed
+ in the `._callback_tensor_inputs` attribute.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ # 0. Default height and width to transformer config
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ guidance_scale,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError("prompt must be provided as a string or list of strings")
+
+ device = self._execution_device
+
+ # 2. Encode input prompt
+ text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
+ prompt, device
+ )
+
+ # 3. Prepare timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 4. Prepare latent variables
+        num_channels_latents = self.vae.latent_channels  # set in _enhance_vae_properties from the VAE config
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare extra step kwargs
+ extra_step_kwargs = {}
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_eta:
+ extra_step_kwargs["eta"] = 0.0
+
+ # 6. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Duplicate latents for CFG
+ latents_in = torch.cat([latents, latents], dim=0)
+
+ # Cross-attention batch (uncond, cond)
+ ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
+ ca_mask = None
+ if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
+ ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
+
+ # Normalize timestep for the transformer
+ t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
+
+ # Process inputs for transformer
+ img_seq, txt, pe = self.transformer._process_inputs(latents_in, ca_embed)
+
+ # Forward through transformer layers
+ img_seq = self.transformer._forward_transformers(
+ img_seq,
+ txt,
+ time_embedding=self.transformer._compute_timestep_embedding(t_cont, img_seq.dtype),
+ pe=pe,
+ attention_mask=ca_mask,
+ )
+
+ # Convert back to image format
+ noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape)
+
+ # Apply CFG
+ noise_uncond, noise_text = noise_both.chunk(2, dim=0)
+ noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
+
+ # Compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_on_step_end(self, i, t, callback_kwargs)
+
+                # Update the progress bar
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+        # 7. Post-processing
+ if output_type == "latent":
+ image = latents
+ else:
+ # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC)
+ latents = (latents / self.vae.scaling_factor) + self.vae.shift_factor
+ # Decode using VAE (AutoencoderKL or AutoencoderDC)
+ image = self.vae.decode(latents, return_dict=False)[0]
+ # Use standard image processor for post-processing
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return PhotonPipelineOutput(images=image)
diff --git a/tests/models/transformers/test_models_transformer_photon.py b/tests/models/transformers/test_models_transformer_photon.py
new file mode 100644
index 000000000000..2f08484d230c
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_photon.py
@@ -0,0 +1,244 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = PhotonTransformer2DModel
+ main_input_name = "image_latent"
+
+ @property
+ def dummy_input(self):
+ return self.prepare_dummy_input()
+
+ @property
+ def input_shape(self):
+ return (16, 4, 4)
+
+ @property
+ def output_shape(self):
+ return (16, 4, 4)
+
+ def prepare_dummy_input(self, height=32, width=32):
+ batch_size = 1
+ num_latent_channels = 16
+ sequence_length = 16
+ embedding_dim = 1792
+
+ image_latent = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device)
+ cross_attn_conditioning = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ micro_conditioning = torch.randn((batch_size, embedding_dim)).to(torch_device)
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+ return {
+ "image_latent": image_latent,
+ "timestep": timestep,
+ "cross_attn_conditioning": cross_attn_conditioning,
+ "micro_conditioning": micro_conditioning,
+ }
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 16,
+ "patch_size": 2,
+ "context_in_dim": 1792,
+ "hidden_size": 1792,
+ "mlp_ratio": 3.5,
+ "num_heads": 28,
+ "depth": 4, # Smaller depth for testing
+ "axes_dim": [32, 32],
+ "theta": 10_000,
+ }
+ inputs_dict = self.prepare_dummy_input()
+ return init_dict, inputs_dict
+
+ def test_forward_signature(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ # Test forward
+ outputs = model(**inputs_dict)
+
+ self.assertIsNotNone(outputs)
+ expected_shape = inputs_dict["image_latent"].shape
+ self.assertEqual(outputs.sample.shape, expected_shape)
+
+ def test_model_initialization(self):
+ # Test model initialization
+ model = PhotonTransformer2DModel(
+ in_channels=16,
+ patch_size=2,
+ context_in_dim=1792,
+ hidden_size=1792,
+ mlp_ratio=3.5,
+ num_heads=28,
+ depth=4,
+ axes_dim=[32, 32],
+ theta=10_000,
+ )
+ self.assertEqual(model.config.in_channels, 16)
+ self.assertEqual(model.config.hidden_size, 1792)
+ self.assertEqual(model.config.num_heads, 28)
+
+ def test_model_with_dict_config(self):
+ # Test model initialization with from_config
+ config_dict = {
+ "in_channels": 16,
+ "patch_size": 2,
+ "context_in_dim": 1792,
+ "hidden_size": 1792,
+ "mlp_ratio": 3.5,
+ "num_heads": 28,
+ "depth": 4,
+ "axes_dim": [32, 32],
+ "theta": 10_000,
+ }
+
+ model = PhotonTransformer2DModel.from_config(config_dict)
+ self.assertEqual(model.config.in_channels, 16)
+ self.assertEqual(model.config.hidden_size, 1792)
+
+ def test_process_inputs(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ img_seq, txt, pe = model._process_inputs(
+ inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"]
+ )
+
+ # Check shapes
+ batch_size = inputs_dict["image_latent"].shape[0]
+ height, width = inputs_dict["image_latent"].shape[2:]
+ patch_size = init_dict["patch_size"]
+ expected_seq_len = (height // patch_size) * (width // patch_size)
+
+ self.assertEqual(img_seq.shape, (batch_size, expected_seq_len, init_dict["in_channels"] * patch_size**2))
+ self.assertEqual(
+ txt.shape, (batch_size, inputs_dict["cross_attn_conditioning"].shape[1], init_dict["hidden_size"])
+ )
+ # Check that pe has the correct batch size, sequence length and some embedding dimension
+ self.assertEqual(pe.shape[0], batch_size) # batch size
+ self.assertEqual(pe.shape[1], 1) # unsqueeze(1) in EmbedND
+ self.assertEqual(pe.shape[2], expected_seq_len) # sequence length
+ self.assertEqual(pe.shape[-2:], (2, 2)) # rope rearrange output
+
+ def test_forward_transformers(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ # Process inputs first
+ img_seq, txt, pe = model._process_inputs(
+ inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"]
+ )
+
+ # Test forward_transformers
+ output_seq = model._forward_transformers(img_seq, txt, timestep=inputs_dict["timestep"], pe=pe)
+
+ # Check output shape
+ expected_out_channels = init_dict["in_channels"] * init_dict["patch_size"] ** 2
+ self.assertEqual(output_seq.shape, (img_seq.shape[0], img_seq.shape[1], expected_out_channels))
+
+ def test_attention_mask(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ # Create attention mask
+ batch_size = inputs_dict["cross_attn_conditioning"].shape[0]
+ seq_len = inputs_dict["cross_attn_conditioning"].shape[1]
+ attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool).to(torch_device)
+ attention_mask[:, seq_len // 2 :] = False # Mask second half
+
+ with torch.no_grad():
+ outputs = model(**inputs_dict, cross_attn_mask=attention_mask)
+
+ self.assertIsNotNone(outputs)
+ expected_shape = inputs_dict["image_latent"].shape
+ self.assertEqual(outputs.sample.shape, expected_shape)
+
+ def test_invalid_config(self):
+ # Test invalid configuration - hidden_size not divisible by num_heads
+ with self.assertRaises(ValueError):
+ PhotonTransformer2DModel(
+ in_channels=16,
+ patch_size=2,
+ context_in_dim=1792,
+ hidden_size=1793, # Not divisible by 28
+ mlp_ratio=3.5,
+ num_heads=28,
+ depth=4,
+ axes_dim=[32, 32],
+ theta=10_000,
+ )
+
+ # Test invalid axes_dim that doesn't sum to pe_dim
+ with self.assertRaises(ValueError):
+ PhotonTransformer2DModel(
+ in_channels=16,
+ patch_size=2,
+ context_in_dim=1792,
+ hidden_size=1792,
+ mlp_ratio=3.5,
+ num_heads=28,
+ depth=4,
+ axes_dim=[30, 30], # Sum = 60, but pe_dim = 1792/28 = 64
+ theta=10_000,
+ )
+
+ def test_gradient_checkpointing_enable(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+
+ # Enable gradient checkpointing
+ model.enable_gradient_checkpointing()
+
+ # Check that _activation_checkpointing is set
+ for block in model.blocks:
+ self.assertTrue(hasattr(block, "_activation_checkpointing"))
+
+ def test_from_config(self):
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+
+ # Create model from config
+ model = self.model_class.from_config(init_dict)
+ self.assertIsInstance(model, self.model_class)
+ self.assertEqual(model.config.in_channels, init_dict["in_channels"])
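+
+    def test_axes_dim_matches_pe_dim(self):
+        # Hypothetical sanity check, not part of the original PR: per the invalid-config
+        # case above (axes_dim summing to 60 vs. pe_dim = 1792 / 28 = 64), the RoPE axes
+        # dims of a valid config are assumed to sum to hidden_size // num_heads.
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        pe_dim = init_dict["hidden_size"] // init_dict["num_heads"]
+        self.assertEqual(sum(init_dict["axes_dim"]), pe_dim)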
+
+
+if __name__ == "__main__":
+ unittest.main()