From 016316a57a8a0df9cacca66af2fa290862b8d5c4 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 10:20:19 +0200 Subject: [PATCH 01/38] mirage pipeline first commit --- src/diffusers/__init__.py | 1 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_mirage.py | 489 ++++++++++++++ src/diffusers/pipelines/__init__.py | 1 + src/diffusers/pipelines/mirage/__init__.py | 4 + .../pipelines/mirage/pipeline_mirage.py | 629 ++++++++++++++++++ .../pipelines/mirage/pipeline_output.py | 35 + .../test_models_transformer_mirage.py | 252 +++++++ 9 files changed, 1413 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_mirage.py create mode 100644 src/diffusers/pipelines/mirage/__init__.py create mode 100644 src/diffusers/pipelines/mirage/pipeline_mirage.py create mode 100644 src/diffusers/pipelines/mirage/pipeline_output.py create mode 100644 tests/models/transformers/test_models_transformer_mirage.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8867250deda8..6fc6ac5f3ebd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -224,6 +224,7 @@ "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", + "MirageTransformer2DModel", "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 457f70448af3..279e69216b1b 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -93,6 +93,7 @@ _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] + _import_structure["transformers.transformer_mirage"] = ["MirageTransformer2DModel"] 
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6dc..ebe0d0c9b8e1 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -29,6 +29,7 @@ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel + from .transformer_mirage import MirageTransformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py new file mode 100644 index 000000000000..39c569cbb26b --- /dev/null +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -0,0 +1,489 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import math

import torch
from torch import Tensor, nn
from torch.nn.functional import fold, unfold

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin


logger = logging.get_logger(__name__)


# Mirage Layer Components
def get_image_ids(bs: int, h: int, w: int, patch_size: int, device: torch.device) -> Tensor:
    """Return per-patch (row, col) position ids of shape (bs, num_patches, 2)."""
    img_ids = torch.zeros(h // patch_size, w // patch_size, 2, device=device)
    img_ids[..., 0] = torch.arange(h // patch_size, device=device)[:, None]
    img_ids[..., 1] = torch.arange(w // patch_size, device=device)[None, :]
    return img_ids.reshape((h // patch_size) * (w // patch_size), 2).unsqueeze(0).repeat(bs, 1, 1)


def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
    """Apply rotary position embeddings (2x2 rotation matrices in `freqs_cis`) to `xq`."""
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq)


class EmbedND(nn.Module):
    """Multi-axis rotary embedding: one RoPE table per positional axis, concatenated."""

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor:
        """Build (..., dim/2, 2, 2) rotation matrices for positions `pos`."""
        if dim % 2 != 0:
            raise ValueError(f"RoPE dimension must be even, got {dim}")
        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
        omega = 1.0 / (theta**scale)
        out = pos.unsqueeze(-1) * omega.unsqueeze(0)
        out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
        # (..., d, 4) -> (..., d, 2, 2): the last axis becomes a 2x2 rotation matrix
        return out.reshape(*out.shape[:-1], 2, 2).float()

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        # add a broadcast head dimension: (B, 1, L, dim/2, 2, 2)
        return emb.unsqueeze(1)


def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, time_factor: float = 1000.0) -> Tensor:
    """Sinusoidal timestep embedding of width `dim`; `t` is scaled by `time_factor` first."""
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32, device=t.device) / half)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # odd widths get a single zero pad column
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) used for the timestep embedding."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class RMSNorm(torch.nn.Module):
    """RMS normalization with a learned scale, computed in float32 for stability."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
        return (x * rrms * self.scale).to(dtype=x_dtype)


class QKNorm(torch.nn.Module):
    """Independent RMS norms for queries and keys; results cast to `v`'s dtype."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


@dataclass
class ModulationOut:
    # adaLN-style modulation triple applied around a sub-layer
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    """Project the timestep vector into two (shift, scale, gate) triples (attention, MLP)."""

    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, 6 * dim, bias=True)
        # zero-init so modulation starts as identity (shift=scale=gate=0)
        nn.init.constant_(self.lin.weight, 0)
        nn.init.constant_(self.lin.bias, 0)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
        return ModulationOut(*out[:3]), ModulationOut(*out[3:])


class MirageBlock(nn.Module):
    """Transformer block: image queries attend over [conditioning?, text, image] keys/values,
    followed by a gated (SwiGLU-style) MLP; both sub-layers are adaLN-modulated by `vec`."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()

        self._fsdp_wrap = True
        self._activation_checkpointing = True

        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.hidden_size = hidden_size

        # image stream: fused qkv projection + output projection
        self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_qkv_proj = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.attn_out = nn.Linear(hidden_size, hidden_size, bias=False)
        self.qk_norm = QKNorm(self.head_dim)

        # text stream contributes keys/values only (no text queries)
        self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False)
        self.k_norm = RMSNorm(self.head_dim)

        # gated MLP
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
        self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
        self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
        self.mlp_act = nn.GELU(approximate="tanh")

        self.modulation = Modulation(hidden_size)
        # assigned externally when spatial conditioning is enabled for this block
        self.spatial_cond_kv_proj: None | nn.Linear = None

    def _split_heads(self, x: Tensor, n: int) -> Tensor:
        """(B, L, n*H*D) -> n tensors of shape (B, H, L, D)."""
        bs, seq_len, _ = x.shape
        return x.view(bs, seq_len, n, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)

    def attn_forward(
        self,
        img: Tensor,
        txt: Tensor,
        pe: Tensor,
        modulation: ModulationOut,
        spatial_conditioning: None | Tensor = None,
        attention_mask: None | Tensor = None,
    ) -> Tensor:
        """Cross/self attention: image queries, [conditioning?, text, image] keys/values.

        `attention_mask` is a (B, L_txt) boolean/int mask over the text tokens;
        conditioning and image keys are always attended to.
        """
        bs, img_len, _ = img.shape

        # image tokens: modulated pre-norm, fused qkv projection
        img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift
        img_q, img_k, img_v = self._split_heads(self.img_qkv_proj(img_mod), 3).unbind(0)
        img_q, img_k = self.qk_norm(img_q, img_k, img_v)

        # text tokens: fused kv projection
        txt_k, txt_v = self._split_heads(self.txt_kv_proj(txt), 2).unbind(0)
        txt_k = self.k_norm(txt_k)

        # rotary embedding on image q/k, then joint [text, image] keys/values
        img_q, img_k = apply_rope(img_q, pe), apply_rope(img_k, pe)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        # optional spatial conditioning tokens are prepended to the key/value sequence
        cond_len = 0
        if self.spatial_cond_kv_proj is not None:
            if spatial_conditioning is None:
                raise ValueError("spatial_conditioning must be provided when spatial_cond_kv_proj is set")
            cond_k, cond_v = self._split_heads(self.spatial_cond_kv_proj(spatial_conditioning), 2).unbind(0)
            # NOTE(review): conditioning shares the image positional embedding — assumes same token grid
            cond_k = apply_rope(cond_k, pe)
            cond_len = cond_k.shape[2]
            k = torch.cat((cond_k, k), dim=2)
            v = torch.cat((cond_v, v), dim=2)

        # build multiplicative 0/1 mask for provided attention_mask over [cond?, text, image] keys
        attn_mask: Tensor | None = None
        if attention_mask is not None:
            txt_len = txt_k.shape[2]
            if attention_mask.dim() != 2:
                raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
            if attention_mask.shape[-1] != txt_len:
                raise ValueError(
                    f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {txt_len}"
                )

            device = img_q.device
            joint_mask = torch.cat(
                [
                    torch.ones((bs, cond_len), dtype=torch.bool, device=device),
                    attention_mask.to(torch.bool),
                    torch.ones((bs, img_len), dtype=torch.bool, device=device),
                ],
                dim=-1,
            )  # (B, L_all)
            # repeat across heads and query positions: (B, H, L_img, L_all)
            attn_mask = joint_mask[:, None, None, :].expand(-1, self.num_heads, img_len, -1)

        attn = torch.nn.functional.scaled_dot_product_attention(
            img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask
        )
        # (B, H, L, D) -> (B, L, H*D)
        attn = attn.transpose(1, 2).reshape(bs, img_len, -1)
        return self.attn_out(attn)
+ def ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor: + x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift + return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)) + + def forward( + self, + img: Tensor, + txt: Tensor, + vec: Tensor, + pe: Tensor, + spatial_conditioning: Tensor | None = None, + attention_mask: Tensor | None = None, + **_: dict[str, Any], + ) -> Tensor: + mod_attn, mod_mlp = self.modulation(vec) + + img = img + mod_attn.gate * self.attn_forward( + img, + txt, + pe, + mod_attn, + spatial_conditioning=spatial_conditioning, + attention_mask=attention_mask, + ) + img = img + mod_mlp.gate * self.ffn_forward(img, mod_mlp) + return img + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + nn.init.constant_(self.adaLN_modulation[1].weight, 0) + nn.init.constant_(self.adaLN_modulation[1].bias, 0) + nn.init.constant_(self.linear.weight, 0) + nn.init.constant_(self.linear.bias, 0) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +@dataclass +class MirageParams: + in_channels: int + patch_size: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + axes_dim: list[int] + theta: int + time_factor: float = 1000.0 + time_max_period: int = 10_000 + conditioning_block_ids: list[int] | None = None + + +def img2seq(img: Tensor, patch_size: int) -> Tensor: + """Flatten an image into a sequence of patches""" + return unfold(img, 
kernel_size=patch_size, stride=patch_size).transpose(1, 2) + + +def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: + """Revert img2seq""" + if isinstance(shape, tuple): + shape = shape[-2:] + elif isinstance(shape, torch.Tensor): + shape = (int(shape[0]), int(shape[1])) + else: + raise NotImplementedError(f"shape type {type(shape)} not supported") + return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) + + +class MirageTransformer2DModel(ModelMixin, ConfigMixin): + """Mirage Transformer model with IP-Adapter support.""" + + config_name = "config.json" + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 16, + patch_size: int = 2, + context_in_dim: int = 2304, + hidden_size: int = 1792, + mlp_ratio: float = 3.5, + num_heads: int = 28, + depth: int = 16, + axes_dim: list = None, + theta: int = 10000, + time_factor: float = 1000.0, + time_max_period: int = 10000, + conditioning_block_ids: list = None, + **kwargs + ): + super().__init__() + + if axes_dim is None: + axes_dim = [32, 32] + + # Create MirageParams from the provided arguments + params = MirageParams( + in_channels=in_channels, + patch_size=patch_size, + context_in_dim=context_in_dim, + hidden_size=hidden_size, + mlp_ratio=mlp_ratio, + num_heads=num_heads, + depth=depth, + axes_dim=axes_dim, + theta=theta, + time_factor=time_factor, + time_max_period=time_max_period, + conditioning_block_ids=conditioning_block_ids, + ) + + self.params = params + self.in_channels = params.in_channels + self.patch_size = params.patch_size + self.out_channels = self.in_channels * self.patch_size**2 + + self.time_factor = params.time_factor + self.time_max_period = params.time_max_period + + if params.hidden_size % params.num_heads != 0: + raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + + pe_dim = params.hidden_size // params.num_heads + + if sum(params.axes_dim) != 
pe_dim: + raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + conditioning_block_ids: list[int] = params.conditioning_block_ids or list(range(params.depth)) + + self.blocks = nn.ModuleList( + [ + MirageBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + ) + for i in range(params.depth) + ] + ) + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: + """Timestep independent stuff""" + txt = self.txt_in(txt) + img = img2seq(image_latent, self.patch_size) + bs, _, h, w = image_latent.shape + img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device) + pe = self.pe_embedder(img_ids) + return img, txt, pe + + def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: + return self.time_in( + timestep_embedding( + t=timestep, dim=256, max_period=self.time_max_period, time_factor=self.time_factor + ).to(dtype) + ) + + def forward_transformers( + self, + image_latent: Tensor, + cross_attn_conditioning: Tensor, + timestep: Optional[Tensor] = None, + time_embedding: Optional[Tensor] = None, + attention_mask: Optional[Tensor] = None, + **block_kwargs: Any, + ) -> Tensor: + img = self.img_in(image_latent) + + if time_embedding is not None: + vec = time_embedding + else: + if timestep is None: + raise ValueError("Please provide either a timestep or a timestep_embedding") + vec = self.compute_timestep_embedding(timestep, dtype=img.dtype) + + for block in 
self.blocks: + img = block( + img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs + ) + + img = self.final_layer(img, vec) + return img + + def forward( + self, + image_latent: Tensor, + timestep: Tensor, + cross_attn_conditioning: Tensor, + micro_conditioning: Tensor, + cross_attn_mask: None | Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + img_seq, txt, pe = self.process_inputs(image_latent, cross_attn_conditioning) + img_seq = self.forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask) + output = seq2img(img_seq, self.patch_size, image_latent.shape) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 190c7871d270..7b7ebb633c3b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -144,6 +144,7 @@ "FluxKontextPipeline", "FluxKontextInpaintPipeline", ] + _import_structure["mirage"] = ["MiragePipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", diff --git a/src/diffusers/pipelines/mirage/__init__.py b/src/diffusers/pipelines/mirage/__init__.py new file 
mode 100644 index 000000000000..4fd8ad191b3f --- /dev/null +++ b/src/diffusers/pipelines/mirage/__init__.py @@ -0,0 +1,4 @@ +from .pipeline_mirage import MiragePipeline +from .pipeline_output import MiragePipelineOutput + +__all__ = ["MiragePipeline", "MiragePipelineOutput"] \ No newline at end of file diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py new file mode 100644 index 000000000000..126eab07977c --- /dev/null +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -0,0 +1,629 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import os +from typing import Any, Callable, Dict, List, Optional, Union + +import html +import re +import urllib.parse as ul + +import ftfy +import torch +from transformers import ( + AutoTokenizer, + GemmaTokenizerFast, + T5EncoderModel, + T5TokenizerFast, +) + +from ...image_processor import VaeImageProcessor +from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, AutoencoderDC +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import MiragePipelineOutput + +try: + from ...models.transformers.transformer_mirage import MirageTransformer2DModel +except ImportError: + MirageTransformer2DModel = None + +logger = logging.get_logger(__name__) + + +class TextPreprocessor: + """Text preprocessing utility for MiragePipeline.""" + + def __init__(self): + """Initialize text preprocessor.""" + self.bad_punct_regex = re.compile( + r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + r"\\" + r"\/" + r"\*" + r"]{1,}" + ) + + def clean_text(self, text: str) -> str: + """Clean text using comprehensive text processing logic.""" + # See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py + text = str(text) + text = ul.unquote_plus(text) + text = text.strip().lower() + text = re.sub("", "person", text) + + # Remove all urls: + text = re.sub( + r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))", + "", + text, + ) # regex for urls + + # @ + text = re.sub(r"@[\w\d]+\b", "", text) + + # 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs + text = re.sub(r"[\u31c0-\u31ef]+", "", text) + text = re.sub(r"[\u31f0-\u31ff]+", "", text) + text = re.sub(r"[\u3200-\u32ff]+", "", 
text) + text = re.sub(r"[\u3300-\u33ff]+", "", text) + text = re.sub(r"[\u3400-\u4dbf]+", "", text) + text = re.sub(r"[\u4dc0-\u4dff]+", "", text) + text = re.sub(r"[\u4e00-\u9fff]+", "", text) + + # все виды тире / all types of dash --> "-" + text = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", + "-", + text, + ) + + # кавычки к одному стандарту + text = re.sub(r"[`´«»""¨]", '"', text) + text = re.sub(r"['']", "'", text) + + # " and & + text = re.sub(r""?", "", text) + text = re.sub(r"&", "", text) + + # ip addresses: + text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text) + + # article ids: + text = re.sub(r"\d:\d\d\s+$", "", text) + + # \n + text = re.sub(r"\\n", " ", text) + + # "#123", "#12345..", "123456.." + text = re.sub(r"#\d{1,3}\b", "", text) + text = re.sub(r"#\d{5,}\b", "", text) + text = re.sub(r"\b\d{6,}\b", "", text) + + # filenames: + text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text) + + # Clean punctuation + text = re.sub(r"[\"\']{2,}", r'"', text) # """AUSVERKAUFT""" + text = re.sub(r"[\.]{2,}", r" ", text) + + text = re.sub(self.bad_punct_regex, r" ", text) # ***AUSVERKAUFT***, #AUSVERKAUFT + text = re.sub(r"\s+\.\s+", r" ", text) # " . 
" + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, text)) > 3: + text = re.sub(regex2, " ", text) + + # Basic cleaning + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + text = text.strip() + + # Clean alphanumeric patterns + text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text) # jc6640 + text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text) # jc6640vc + text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text) # 6640vc231 + + # Common spam patterns + text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text) + text = re.sub(r"(free\s)?download(\sfree)?", "", text) + text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text) + text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text) + text = re.sub(r"\bpage\s+\d+\b", "", text) + + text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text) # j2d1a2a... + text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text) + + # Final cleanup + text = re.sub(r"\b\s+\:\s+", r": ", text) + text = re.sub(r"(\D[,\./])\b", r"\1 ", text) + text = re.sub(r"\s+", " ", text) + + text.strip() + + text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text) + text = re.sub(r"^[\'\_,\-\:;]", r"", text) + text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text) + text = re.sub(r"^\.\S+$", "", text) + + return text.strip() + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import MiragePipeline + >>> from diffusers.models import AutoencoderKL, AutoencoderDC + >>> from transformers import T5GemmaModel, GemmaTokenizerFast + + >>> # Load pipeline directly with from_pretrained + >>> pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint") + + >>> # Or initialize pipeline components manually + >>> transformer = MirageTransformer2DModel.from_pretrained("path/to/transformer") + >>> scheduler = FlowMatchEulerDiscreteScheduler() + >>> # Load T5Gemma encoder + >>> t5gemma_model = 
T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") + >>> text_encoder = t5gemma_model.encoder + >>> tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") + >>> vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") + + >>> pipe = MiragePipeline( + ... transformer=transformer, + ... scheduler=scheduler, + ... text_encoder=text_encoder, + ... tokenizer=tokenizer, + ... vae=vae + ... ) + >>> pipe.to("cuda") + >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach" + >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] + >>> image.save("mirage_output.png") + ``` +""" + + +class MiragePipeline( + DiffusionPipeline, + LoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + Pipeline for text-to-image generation using Mirage Transformer. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + transformer ([`MirageTransformer2DModel`]): + The Mirage transformer model to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + text_encoder ([`T5EncoderModel`]): + Standard text encoder model for encoding prompts. + tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]): + Tokenizer for the text encoder. + vae ([`AutoencoderKL`] or [`AutoencoderDC`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression). 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents"] + _optional_components = [] + + # Component configurations for automatic loading + config_name = "model_index.json" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + """ + Override from_pretrained to ensure T5GemmaEncoder is available for loading. + + This ensures that T5GemmaEncoder from transformers is accessible in the module namespace + during component loading, which is required for MiragePipeline checkpoints that use + T5GemmaEncoder as the text encoder. + """ + # Ensure T5GemmaEncoder is available for loading + import transformers + if not hasattr(transformers, 'T5GemmaEncoder'): + try: + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + transformers.T5GemmaEncoder = T5GemmaEncoder + except ImportError: + # T5GemmaEncoder not available in this transformers version + pass + + # Proceed with standard loading + return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + + + def __init__( + self, + transformer: MirageTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder: Union[T5EncoderModel, Any], + tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer], + vae: Union[AutoencoderKL, AutoencoderDC], + ): + super().__init__() + + if MirageTransformer2DModel is None: + raise ImportError( + "MirageTransformer2DModel is not available. Please ensure the transformer_mirage module is properly installed." 
def _enhance_vae_properties(self):
    """Attach a uniform latent-space interface to ``self.vae``.

    ``AutoencoderKL`` and ``AutoencoderDC`` expose different attributes, so this
    normalizes ``spatial_compression_ratio``, ``scaling_factor``, ``shift_factor``
    and ``latent_channels`` directly onto the VAE instance. No-op when the
    pipeline has no VAE attached.
    """
    if not hasattr(self, "vae") or self.vae is None:
        return

    vae = self.vae
    config = getattr(vae, "config", None)

    # spatial_compression_ratio: AutoencoderDC already ships it; AutoencoderKL
    # derives it from the number of down blocks; otherwise fall back to 8.
    if not hasattr(vae, "spatial_compression_ratio"):
        if config is not None and hasattr(config, "block_out_channels"):
            vae.spatial_compression_ratio = 2 ** (len(config.block_out_channels) - 1)
        else:
            vae.spatial_compression_ratio = 8

    # scaling_factor: default to the classic SD value when the config is silent.
    vae.scaling_factor = getattr(config, "scaling_factor", 0.18215) if config is not None else 0.18215

    # shift_factor: AutoencoderDC configs carry None here, which means "no shift".
    shift = getattr(config, "shift_factor", None) if config is not None else None
    vae.shift_factor = 0.0 if shift is None else shift

    # latent_channels: prefer explicit config fields, then guess from the
    # compression ratio (DC-AE style VAEs typically use 32, KL VAEs 4).
    if config is not None and hasattr(config, "latent_channels"):
        vae.latent_channels = int(config.latent_channels)
    elif config is not None and hasattr(config, "in_channels"):
        vae.latent_channels = int(config.in_channels)
    elif getattr(vae, "spatial_compression_ratio", None) == 32:
        vae.latent_channels = 32
    else:
        vae.latent_channels = 4

@property
def vae_scale_factor(self):
    """Compatibility alias: the VAE's spatial compression ratio (8 when unknown)."""
    return getattr(self.vae, "spatial_compression_ratio", 8)

def prepare_latents(
    self,
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[torch.Generator] = None,
    latents: Optional[torch.Tensor] = None,
):
    """Draw the initial latent noise, or move user-supplied latents to ``device``.

    FlowMatchEulerDiscreteScheduler needs no ``init_noise_sigma`` scaling, so the
    sampled tensor is returned as-is.
    """
    if latents is not None:
        return latents.to(device)

    ratio = self.vae.spatial_compression_ratio
    shape = (batch_size, num_channels_latents, height // ratio, width // ratio)
    return randn_tensor(shape, generator=generator, device=device, dtype=dtype)

def encode_prompt(self, prompt: Union[str, List[str]], device: torch.device):
    """Encode a prompt (plus an empty negative prompt) with the text encoder."""
    prompts = [prompt] if isinstance(prompt, str) else prompt
    return self._encode_prompt_standard(prompts, device)

def _encode_prompt_standard(self, prompt: List[str], device: torch.device):
    """Tokenize and encode conditional + unconditional prompts in one batch.

    Returns ``(text_embeddings, cross_attn_mask, uncond_text_embeddings,
    uncond_cross_attn_mask)``, each with leading dimension ``len(prompt)``.
    """
    cond_texts = [self.text_preprocessor.clean_text(text) for text in prompt]
    uncond_texts = [self.text_preprocessor.clean_text("") for _ in prompt]

    # One tokenizer/encoder pass over [cond..., uncond...] halves the encoder calls.
    tokens = self.tokenizer(
        cond_texts + uncond_texts,
        padding="max_length",
        max_length=self.tokenizer.model_max_length,
        truncation=True,
        return_attention_mask=True,
        return_tensors="pt",
    )
    input_ids = tokens["input_ids"].to(device)
    attention_mask = tokens["attention_mask"].bool().to(device)

    # Autocast disabled to run the text encoder in full precision (matches TextTower).
    with torch.no_grad(), torch.autocast("cuda", enabled=False):
        emb = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )

    # Last hidden state mirrors TextTower's use_last_hidden_state=True default.
    embeddings = emb["last_hidden_state"]
    n = len(prompt)
    return embeddings[:n], attention_mask[:n], embeddings[n:], attention_mask[n:]

def check_inputs(
    self,
    prompt: Union[str, List[str]],
    height: int,
    width: int,
    guidance_scale: float,
    callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
):
    """Validate user-facing call arguments, raising ``ValueError`` on bad input."""
    ratio = self.vae.spatial_compression_ratio
    if height % ratio != 0 or width % ratio != 0:
        raise ValueError(
            f"`height` and `width` have to be divisible by {ratio} but are {height} and {width}."
        )

    if guidance_scale < 1.0:
        raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}")

    if callback_on_step_end_tensor_inputs is not None and not isinstance(
        callback_on_step_end_tensor_inputs, list
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}"
        )
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 4.0,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation. Required; this pipeline does not accept
            precomputed prompt embeddings.
        height (`int`, *optional*, defaults to 256):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to 256):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 28):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. Must be in descending order. Overrides `num_inference_steps`.
        guidance_scale (`float`, *optional*, defaults to 4.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            Enabled by setting `guidance_scale > 1`. Higher values encourage images closely linked to the text
            `prompt`, usually at the expense of lower image quality.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution. If not provided, a latents tensor
            is sampled using the supplied random `generator`.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.mirage.MiragePipelineOutput`] instead of a plain tuple.
        callback_on_step_end (`Callable`, *optional*):
            A function called at the end of each denoising step as
            `callback_on_step_end(self, step, timestep, callback_kwargs)`, where `callback_kwargs` contains the
            tensors named in `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor names passed to `callback_on_step_end`.

    Examples:

    Returns:
        [`~pipelines.mirage.MiragePipelineOutput`] or `tuple`: [`~pipelines.mirage.MiragePipelineOutput`] if
        `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
        generated images.
    """

    # 0. Default height and width
    height = height or 256
    width = width or 256

    # 1. Validate inputs early
    self.check_inputs(
        prompt,
        height,
        width,
        guidance_scale,
        callback_on_step_end_tensor_inputs,
    )

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        raise ValueError("prompt must be provided as a string or list of strings")

    device = self._execution_device

    # 2. Encode conditional and unconditional prompts in one pass
    text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
        prompt, device
    )

    # 3. Prepare timesteps (custom schedule overrides num_inference_steps)
    if timesteps is not None:
        self.scheduler.set_timesteps(timesteps=timesteps, device=device)
        timesteps = self.scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    # 4. Prepare latent variables.
    # Fix: the latent channel count comes from the VAE (set by
    # _enhance_vae_properties), not from the transformer config.
    num_channels_latents = self.vae.latent_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        text_embeddings.dtype,
        device,
        generator,
        latents,
    )

    # 5. Scheduler-specific kwargs (flow-matching schedulers take no eta)
    extra_step_kwargs = {}
    if "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()):
        extra_step_kwargs["eta"] = 0.0

    # Hoisted out of the denoising loop: resolve the helper import once.
    from ...models.transformers.transformer_mirage import seq2img

    # 6. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Classifier-free guidance: run (uncond, cond) in a single batch.
            latents_in = torch.cat([latents, latents], dim=0)

            ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
            ca_mask = None
            if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
                ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)

            # Normalize the timestep to [0, 1] for the transformer.
            # NOTE(review): `.repeat(2)` yields a length-2 timestep while the latent
            # batch is 2 * batch_size — assumes the transformer broadcasts this;
            # confirm for batch_size > 1.
            t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)

            # Timestep-independent projections, then the transformer stack.
            img_seq, txt, pe = self.transformer.process_inputs(latents_in, ca_embed)
            img_seq = self.transformer.forward_transformers(
                img_seq,
                txt,
                time_embedding=self.transformer.compute_timestep_embedding(t_cont, img_seq.dtype),
                pe=pe,
                attention_mask=ca_mask,
            )

            # Back from patch sequence to image layout.
            noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape)

            # Apply classifier-free guidance.
            noise_uncond, noise_text = noise_both.chunk(2, dim=0)
            noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)

            # x_t -> x_{t-1}
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            if callback_on_step_end is not None:
                # Snapshot locals once; a dict comprehension's own scope would
                # otherwise shadow them.
                step_locals = locals()
                callback_kwargs = {k: step_locals[k] for k in callback_on_step_end_tensor_inputs}
                callback_on_step_end(self, i, t, callback_kwargs)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

    # 7. Post-processing
    if output_type == "latent":
        image = latents
    else:
        # Unscale latents for the VAE (supports AutoencoderKL and AutoencoderDC).
        latents = (latents / self.vae.scaling_factor) + self.vae.shift_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return MiragePipelineOutput(images=image)
PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_mirage.py new file mode 100644 index 000000000000..11accdaecbee --- /dev/null +++ b/tests/models/transformers/test_models_transformer_mirage.py @@ -0,0 +1,252 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class MirageTransformerTests(ModelTesterMixin, unittest.TestCase):
    """Model tests for ``MirageTransformer2DModel`` using a small 4-layer config."""

    model_class = MirageTransformer2DModel
    main_input_name = "image_latent"

    @property
    def dummy_input(self):
        return self.prepare_dummy_input()

    @property
    def input_shape(self):
        return (16, 4, 4)

    @property
    def output_shape(self):
        return (16, 4, 4)

    def prepare_dummy_input(self, height=32, width=32):
        """Build random latents plus text/micro conditioning sized for the test config."""
        batch_size = 1
        num_latent_channels = 16
        sequence_length = 16
        embedding_dim = 1792

        return {
            "image_latent": torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device),
            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
            "cross_attn_conditioning": torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device),
            "micro_conditioning": torch.randn((batch_size, embedding_dim)).to(torch_device),
        }

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 16,
            "patch_size": 2,
            "context_in_dim": 1792,
            "hidden_size": 1792,
            "mlp_ratio": 3.5,
            "num_heads": 28,
            "depth": 4,  # Smaller depth for testing
            "axes_dim": [32, 32],
            "theta": 10_000,
        }
        return init_dict, self.prepare_dummy_input()

    def test_forward_signature(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            outputs = model(**inputs_dict)

        self.assertIsNotNone(outputs)
        # The denoised output keeps the latent layout of the input.
        self.assertEqual(outputs.shape, inputs_dict["image_latent"].shape)

    def test_mirage_params_initialization(self):
        # Direct construction exposes the config on the instance.
        model = MirageTransformer2DModel(
            in_channels=16,
            patch_size=2,
            context_in_dim=1792,
            hidden_size=1792,
            mlp_ratio=3.5,
            num_heads=28,
            depth=4,
            axes_dim=[32, 32],
            theta=10_000,
        )
        self.assertEqual(model.config.in_channels, 16)
        self.assertEqual(model.config.hidden_size, 1792)
        self.assertEqual(model.config.num_heads, 28)

    def test_model_with_dict_config(self):
        # from_config accepts a plain dict.
        config_dict = {
            "in_channels": 16,
            "patch_size": 2,
            "context_in_dim": 1792,
            "hidden_size": 1792,
            "mlp_ratio": 3.5,
            "num_heads": 28,
            "depth": 4,
            "axes_dim": [32, 32],
            "theta": 10_000,
        }

        model = MirageTransformer2DModel.from_config(config_dict)
        self.assertEqual(model.config.in_channels, 16)
        self.assertEqual(model.config.hidden_size, 1792)

    def test_process_inputs(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            img_seq, txt, pe = model.process_inputs(
                inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"]
            )

        batch_size = inputs_dict["image_latent"].shape[0]
        height, width = inputs_dict["image_latent"].shape[2:]
        patch_size = init_dict["patch_size"]
        expected_seq_len = (height // patch_size) * (width // patch_size)

        self.assertEqual(img_seq.shape, (batch_size, expected_seq_len, init_dict["in_channels"] * patch_size**2))
        self.assertEqual(
            txt.shape, (batch_size, inputs_dict["cross_attn_conditioning"].shape[1], init_dict["hidden_size"])
        )
        # Positional embedding layout: batch, broadcast dim from EmbedND's
        # unsqueeze(1), sequence length, then the rope 2x2 rotation matrices.
        self.assertEqual(pe.shape[0], batch_size)
        self.assertEqual(pe.shape[1], 1)
        self.assertEqual(pe.shape[2], expected_seq_len)
        self.assertEqual(pe.shape[-2:], (2, 2))

    def test_forward_transformers(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            img_seq, txt, pe = model.process_inputs(
                inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"]
            )
            output_seq = model.forward_transformers(img_seq, txt, timestep=inputs_dict["timestep"], pe=pe)

        expected_out_channels = init_dict["in_channels"] * init_dict["patch_size"] ** 2
        self.assertEqual(output_seq.shape, (img_seq.shape[0], img_seq.shape[1], expected_out_channels))

    def test_attention_mask(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        # Mask out the second half of the text tokens.
        batch_size = inputs_dict["cross_attn_conditioning"].shape[0]
        seq_len = inputs_dict["cross_attn_conditioning"].shape[1]
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool).to(torch_device)
        attention_mask[:, seq_len // 2 :] = False

        with torch.no_grad():
            outputs = model(**inputs_dict, cross_attn_mask=attention_mask)

        self.assertIsNotNone(outputs)
        self.assertEqual(outputs.shape, inputs_dict["image_latent"].shape)

    def test_invalid_config(self):
        # hidden_size must be divisible by num_heads.
        with self.assertRaises(ValueError):
            MirageTransformer2DModel(
                in_channels=16,
                patch_size=2,
                context_in_dim=1792,
                hidden_size=1793,  # Not divisible by 28
                mlp_ratio=3.5,
                num_heads=28,
                depth=4,
                axes_dim=[32, 32],
                theta=10_000,
            )

        # axes_dim must sum to pe_dim = hidden_size / num_heads.
        with self.assertRaises(ValueError):
            MirageTransformer2DModel(
                in_channels=16,
                patch_size=2,
                context_in_dim=1792,
                hidden_size=1792,
                mlp_ratio=3.5,
                num_heads=28,
                depth=4,
                axes_dim=[30, 30],  # Sum = 60, but pe_dim = 1792/28 = 64
                theta=10_000,
            )

    def test_gradient_checkpointing_enable(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        model.enable_gradient_checkpointing()

        # Every block must carry the checkpointing flag afterwards.
        for block in model.blocks:
            self.assertTrue(hasattr(block, "_activation_checkpointing"))

    def test_from_config(self):
        init_dict, _ = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class.from_config(init_dict)
        self.assertIsInstance(model, self.model_class)
        self.assertEqual(model.config.in_channels, init_dict["in_channels"])


if __name__ == "__main__":
    unittest.main()
+ """ + + def __init__(self): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise ImportError("MirageAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Apply Mirage attention using standard diffusers interface. + + Expected tensor formats from MirageBlock.attn_forward(): + - hidden_states: Image queries with RoPE applied [B, H, L_img, D] + - encoder_hidden_states: Packed key+value tensors [B, H, L_all, 2*D] + (concatenated keys and values from text + image + spatial conditioning) + - attention_mask: Custom attention mask [B, H, L_img, L_all] or None + """ + + if encoder_hidden_states is None: + raise ValueError( + "MirageAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. " + "This should be provided by MirageBlock.attn_forward()." 
+ ) + + # Unpack the combined key+value tensor + # encoder_hidden_states is [B, H, L_all, 2*D] containing [keys, values] + key, value = encoder_hidden_states.chunk(2, dim=-1) # Each [B, H, L_all, D] + + # Apply scaled dot-product attention with Mirage's processed tensors + # hidden_states is image queries [B, H, L_img, D] + attn_output = torch.nn.functional.scaled_dot_product_attention( + hidden_states.contiguous(), key.contiguous(), value.contiguous(), attn_mask=attention_mask + ) + + # Reshape from [B, H, L_img, D] to [B, L_img, H*D] + batch_size, num_heads, seq_len, head_dim = attn_output.shape + attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim) + + # Apply output projection using the diffusers Attention module + attn_output = attn.to_out[0](attn_output) + if len(attn.to_out) > 1: + attn_output = attn.to_out[1](attn_output) # dropout if present + + return attn_output + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, @@ -5657,6 +5714,7 @@ def __new__(cls, *args, **kwargs): PAGHunyuanAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, LuminaAttnProcessor2_0, + MirageAttnProcessor2_0, FusedAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 39c569cbb26b..0225b9532aff 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..modeling_outputs import Transformer2DModelOutput +from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers @@ -159,13 +160,21 @@ def __init__( # img qkv self.img_pre_norm = 
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> "Dict[str, AttentionProcessor]":
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model with
        indexed by its weight name.
    """
    processors: "Dict[str, AttentionProcessor]" = {}

    def collect(prefix: str, module: torch.nn.Module) -> None:
        # Any submodule exposing `get_processor` counts as an attention layer.
        if hasattr(module, "get_processor"):
            processors[f"{prefix}.processor"] = module.get_processor()
        for child_name, child in module.named_children():
            collect(f"{prefix}.{child_name}", child)

    for name, module in self.named_children():
        collect(name, module)

    return processors

def set_attn_processor(self, processor: "Union[AttentionProcessor, Dict[str, AttentionProcessor]]"):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    def assign(prefix: str, module: torch.nn.Module) -> None:
        if hasattr(module, "set_processor"):
            # Dicts are keyed by the dotted module path; otherwise share one instance.
            if isinstance(processor, dict):
                module.set_processor(processor.pop(f"{prefix}.processor"))
            else:
                module.set_processor(processor)
        for child_name, child in module.named_children():
            assign(f"{prefix}.{child_name}", child)

    for name, module in self.named_children():
        assign(name, module)
dim: int): - super().__init__() - self.scale = nn.Parameter(torch.ones(dim)) - - def forward(self, x: Tensor) -> Tensor: - x_dtype = x.dtype - x = x.float() - rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) - return (x * rrms * self.scale).to(dtype=x_dtype) class QKNorm(torch.nn.Module): def __init__(self, dim: int): super().__init__() - self.query_norm = RMSNorm(dim) - self.key_norm = RMSNorm(dim) + self.query_norm = RMSNorm(dim, eps=1e-6) + self.key_norm = RMSNorm(dim, eps=1e-6) def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: q = self.query_norm(q) @@ -164,7 +154,7 @@ def __init__( # txt kv self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False) - self.k_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim, eps=1e-6) self.attention = Attention( query_dim=hidden_size, From 122115adb1305834b298e677ae30fcef65c4fd35 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 14:26:50 +0200 Subject: [PATCH 04/38] use diffusers timestep embedding method --- .../models/transformers/transformer_mirage.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index f4199da1edcc..916559eb47ac 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -27,6 +27,7 @@ from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..normalization import RMSNorm +from ..embeddings import get_timestep_embedding logger = logging.get_logger(__name__) @@ -71,15 +72,6 @@ def forward(self, ids: Tensor) -> Tensor: return emb.unsqueeze(1) -def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, time_factor: float = 1000.0) -> Tensor: - t = time_factor * t - 
def compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Map continuous timesteps to the transformer's time-conditioning vector.

    Uses the shared diffusers sinusoidal embedding (256-dim; ``flip_sin_to_cos=True``
    reproduces the original cos-then-sin ordering) followed by the learned
    ``time_in`` MLP projection.
    """
    sinusoid = get_timestep_embedding(
        timesteps=timestep,
        embedding_dim=256,
        max_period=self.time_max_period,
        scale=self.time_factor,
        flip_sin_to_cos=True,
    )
    return self.time_in(sinusoid.to(dtype))
patch_size: int) -> Tensor: @@ -348,55 +334,39 @@ def __init__( if axes_dim is None: axes_dim = [32, 32] - # Create MirageParams from the provided arguments - params = MirageParams( - in_channels=in_channels, - patch_size=patch_size, - context_in_dim=context_in_dim, - hidden_size=hidden_size, - mlp_ratio=mlp_ratio, - num_heads=num_heads, - depth=depth, - axes_dim=axes_dim, - theta=theta, - time_factor=time_factor, - time_max_period=time_max_period, - conditioning_block_ids=conditioning_block_ids, - ) - - self.params = params - self.in_channels = params.in_channels - self.patch_size = params.patch_size + # Store parameters directly + self.in_channels = in_channels + self.patch_size = patch_size self.out_channels = self.in_channels * self.patch_size**2 - self.time_factor = params.time_factor - self.time_max_period = params.time_max_period + self.time_factor = time_factor + self.time_max_period = time_max_period - if params.hidden_size % params.num_heads != 0: - raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + if hidden_size % num_heads != 0: + raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}") - pe_dim = params.hidden_size // params.num_heads + pe_dim = hidden_size // num_heads - if sum(params.axes_dim) != pe_dim: - raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + if sum(axes_dim) != pe_dim: + raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}") - self.hidden_size = params.hidden_size - self.num_heads = params.num_heads - self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.hidden_size = hidden_size + self.num_heads = num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) - self.txt_in = 
nn.Linear(params.context_in_dim, self.hidden_size) + self.txt_in = nn.Linear(context_in_dim, self.hidden_size) - conditioning_block_ids: list[int] = params.conditioning_block_ids or list(range(params.depth)) + conditioning_block_ids: list[int] = conditioning_block_ids or list(range(depth)) self.blocks = nn.ModuleList( [ MirageBlock( self.hidden_size, self.num_heads, - mlp_ratio=params.mlp_ratio, + mlp_ratio=mlp_ratio, ) - for i in range(params.depth) + for i in range(depth) ] ) diff --git a/src/diffusers/pipelines/mirage/pipeline_output.py b/src/diffusers/pipelines/mirage/pipeline_output.py index e5cdb2a40924..dfb55821d142 100644 --- a/src/diffusers/pipelines/mirage/pipeline_output.py +++ b/src/diffusers/pipelines/mirage/pipeline_output.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List, Union import numpy as np import PIL.Image diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_mirage.py index 11accdaecbee..5e7b0bd165a6 100644 --- a/tests/models/transformers/test_models_transformer_mirage.py +++ b/tests/models/transformers/test_models_transformer_mirage.py @@ -17,7 +17,7 @@ import torch -from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel, MirageParams +from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin @@ -88,9 +88,9 @@ def test_forward_signature(self): self.assertIsNotNone(outputs) expected_shape = inputs_dict["image_latent"].shape - self.assertEqual(outputs.shape, expected_shape) + self.assertEqual(outputs.sample.shape, expected_shape) - def test_mirage_params_initialization(self): + def test_model_initialization(self): # Test model initialization model = MirageTransformer2DModel( in_channels=16, @@ 
-196,7 +196,7 @@ def test_attention_mask(self): self.assertIsNotNone(outputs) expected_shape = inputs_dict["image_latent"].shape - self.assertEqual(outputs.shape, expected_shape) + self.assertEqual(outputs.sample.shape, expected_shape) def test_invalid_config(self): # Test invalid configuration - hidden_size not divisible by num_heads From 85ae87b9311a1432f43c2928389c8eafc86c0991 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 16:35:56 +0200 Subject: [PATCH 06/38] checkpoint conversion script --- scripts/convert_mirage_to_diffusers.py | 312 +++++++++++++++++++++++++ 1 file changed, 312 insertions(+) create mode 100644 scripts/convert_mirage_to_diffusers.py diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py new file mode 100644 index 000000000000..85716e69ff92 --- /dev/null +++ b/scripts/convert_mirage_to_diffusers.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +""" +Script to convert Mirage checkpoint from original codebase to diffusers format. 
+""" + +import argparse +import json +import os +import shutil +import sys +import torch +from safetensors.torch import save_file +from transformers import GemmaTokenizerFast + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.pipelines.mirage import MiragePipeline + +def load_reference_config(vae_type: str) -> dict: + """Load transformer config from existing pipeline checkpoint.""" + + if vae_type == "flux": + config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated/transformer/config.json" + elif vae_type == "dc-ae": + config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated/transformer/config.json" + else: + raise ValueError(f"Unsupported VAE type: {vae_type}. 
Use 'flux' or 'dc-ae'") + + if not os.path.exists(config_path): + raise FileNotFoundError(f"Reference config not found: {config_path}") + + with open(config_path, 'r') as f: + config = json.load(f) + + print(f"✓ Loaded {vae_type} config: in_channels={config['in_channels']}") + return config + +def create_parameter_mapping() -> dict: + """Create mapping from old parameter names to new diffusers names.""" + + # Key mappings for structural changes + mapping = {} + + # RMSNorm: scale -> weight + for i in range(16): # 16 layers + mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.qk_norm.query_norm.weight" + mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.qk_norm.key_norm.weight" + mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.k_norm.weight" + + # Attention: attn_out -> attention.to_out.0 + mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight" + + return mapping + +def convert_checkpoint_parameters(old_state_dict: dict) -> dict: + """Convert old checkpoint parameters to new diffusers format.""" + + print("Converting checkpoint parameters...") + + mapping = create_parameter_mapping() + converted_state_dict = {} + + # First, print available keys to understand structure + print("Available keys in checkpoint:") + for key in sorted(old_state_dict.keys())[:10]: # Show first 10 keys + print(f" {key}") + if len(old_state_dict) > 10: + print(f" ... 
and {len(old_state_dict) - 10} more") + + for key, value in old_state_dict.items(): + new_key = key + + # Apply specific mappings if needed + if key in mapping: + new_key = mapping[key] + print(f" Mapped: {key} -> {new_key}") + + # Handle img_qkv_proj -> split to to_q, to_k, to_v + if "img_qkv_proj.weight" in key: + print(f" Found QKV projection: {key}") + # Split QKV weight into separate Q, K, V projections + qkv_weight = value + hidden_size = qkv_weight.shape[1] + q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) + + # Extract layer number from key (e.g., blocks.0.img_qkv_proj.weight -> 0) + parts = key.split('.') + layer_idx = None + for i, part in enumerate(parts): + if part == 'blocks' and i + 1 < len(parts) and parts[i+1].isdigit(): + layer_idx = parts[i+1] + break + + if layer_idx is not None: + converted_state_dict[f"blocks.{layer_idx}.attention.to_q.weight"] = q_weight + converted_state_dict[f"blocks.{layer_idx}.attention.to_k.weight"] = k_weight + converted_state_dict[f"blocks.{layer_idx}.attention.to_v.weight"] = v_weight + print(f" Split QKV for layer {layer_idx}") + + # Also keep the original img_qkv_proj for backward compatibility + converted_state_dict[new_key] = value + else: + converted_state_dict[new_key] = value + + print(f"✓ Converted {len(converted_state_dict)} parameters") + return converted_state_dict + + +def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> MirageTransformer2DModel: + """Create and load MirageTransformer2DModel from old checkpoint.""" + + print(f"Loading checkpoint from: {checkpoint_path}") + + # Load old checkpoint + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + old_checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # Handle different checkpoint formats + if isinstance(old_checkpoint, dict): + if 'model' in old_checkpoint: + state_dict = old_checkpoint['model'] + elif 'state_dict' in old_checkpoint: + 
state_dict = old_checkpoint['state_dict'] + else: + state_dict = old_checkpoint + else: + state_dict = old_checkpoint + + print(f"✓ Loaded checkpoint with {len(state_dict)} parameters") + + # Convert parameter names if needed + converted_state_dict = convert_checkpoint_parameters(state_dict) + + # Create transformer with config + print("Creating MirageTransformer2DModel...") + transformer = MirageTransformer2DModel(**config) + + # Load state dict + print("Loading converted parameters...") + missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False) + + if missing_keys: + print(f"⚠ Missing keys: {missing_keys}") + if unexpected_keys: + print(f"⚠ Unexpected keys: {unexpected_keys}") + + if not missing_keys and not unexpected_keys: + print("✓ All parameters loaded successfully!") + + return transformer + +def copy_pipeline_components(vae_type: str, output_path: str): + """Copy VAE, scheduler, text encoder, and tokenizer from reference pipeline.""" + + if vae_type == "flux": + ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated" + else: # dc-ae + ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated" + + components = ["vae", "scheduler", "text_encoder", "tokenizer"] + + for component in components: + src_path = os.path.join(ref_pipeline, component) + dst_path = os.path.join(output_path, component) + + if os.path.exists(src_path): + if os.path.isdir(src_path): + shutil.copytree(src_path, dst_path, dirs_exist_ok=True) + else: + shutil.copy2(src_path, dst_path) + print(f"✓ Copied {component}") + else: + print(f"⚠ Component not found: {src_path}") + +def create_model_index(vae_type: str, output_path: str): + """Create model_index.json for the pipeline.""" + + if vae_type == "flux": + vae_class = "AutoencoderKL" + else: # dc-ae + vae_class = "AutoencoderDC" + + model_index = { + 
"_class_name": "MiragePipeline", + "_diffusers_version": "0.31.0.dev0", + "_name_or_path": os.path.basename(output_path), + "scheduler": [ + "diffusers", + "FlowMatchEulerDiscreteScheduler" + ], + "text_encoder": [ + "transformers", + "T5GemmaEncoder" + ], + "tokenizer": [ + "transformers", + "GemmaTokenizerFast" + ], + "transformer": [ + "diffusers", + "MirageTransformer2DModel" + ], + "vae": [ + "diffusers", + vae_class + ] + } + + model_index_path = os.path.join(output_path, "model_index.json") + with open(model_index_path, 'w') as f: + json.dump(model_index, f, indent=2) + + print(f"✓ Created model_index.json") + +def main(args): + # Validate inputs + if not os.path.exists(args.checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}") + + # Load reference config based on VAE type + config = load_reference_config(args.vae_type) + + # Create output directory + os.makedirs(args.output_path, exist_ok=True) + print(f"✓ Output directory: {args.output_path}") + + # Create transformer from checkpoint + transformer = create_transformer_from_checkpoint(args.checkpoint_path, config) + + # Save transformer + transformer_path = os.path.join(args.output_path, "transformer") + os.makedirs(transformer_path, exist_ok=True) + + # Save config + with open(os.path.join(transformer_path, "config.json"), 'w') as f: + json.dump(config, f, indent=2) + + # Save model weights as safetensors + state_dict = transformer.state_dict() + save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors")) + print(f"✓ Saved transformer to {transformer_path}") + + # Copy other pipeline components + copy_pipeline_components(args.vae_type, args.output_path) + + # Create model index + create_model_index(args.vae_type, args.output_path) + + # Verify the pipeline can be loaded + try: + pipeline = MiragePipeline.from_pretrained(args.output_path) + print(f"Pipeline loaded successfully!") + print(f"Transformer: 
{type(pipeline.transformer).__name__}") + print(f"VAE: {type(pipeline.vae).__name__}") + print(f"Text Encoder: {type(pipeline.text_encoder).__name__}") + print(f"Scheduler: {type(pipeline.scheduler).__name__}") + + # Display model info + num_params = sum(p.numel() for p in pipeline.transformer.parameters()) + print(f"✓ Transformer parameters: {num_params:,}") + + except Exception as e: + print(f"Pipeline verification failed: {e}") + return False + + print("Conversion completed successfully!") + print(f"Converted pipeline saved to: {args.output_path}") + print(f"VAE type: {args.vae_type}") + + + return True + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Mirage checkpoint to diffusers format") + + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, + help="Path to the original Mirage checkpoint (.pth file)" + ) + + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Output directory for the converted diffusers pipeline" + ) + + parser.add_argument( + "--vae_type", + type=str, + choices=["flux", "dc-ae"], + required=True, + help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)" + ) + + args = parser.parse_args() + + try: + success = main(args) + if not success: + sys.exit(1) + except Exception as e: + print(f"❌ Conversion failed: {e}") + import traceback + traceback.print_exc() + sys.exit(1) \ No newline at end of file From 9a697d06b70eaa4e0c9f1f1b5bca6209c65b005b Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 17:00:55 +0200 Subject: [PATCH 07/38] ruff formating --- scripts/convert_mirage_to_diffusers.py | 83 ++++++++----------- .../models/transformers/transformer_mirage.py | 41 ++++----- src/diffusers/pipelines/mirage/__init__.py | 3 +- .../pipelines/mirage/pipeline_mirage.py | 50 +++++++---- .../pipelines/mirage/pipeline_output.py | 2 +- .../test_models_transformer_mirage.py | 30 +++---- 6 files changed, 100 
insertions(+), 109 deletions(-) diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py index 85716e69ff92..5e2a2ff768f4 100644 --- a/scripts/convert_mirage_to_diffusers.py +++ b/scripts/convert_mirage_to_diffusers.py @@ -8,16 +8,17 @@ import os import shutil import sys + import torch from safetensors.torch import save_file -from transformers import GemmaTokenizerFast -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.pipelines.mirage import MiragePipeline + def load_reference_config(vae_type: str) -> dict: """Load transformer config from existing pipeline checkpoint.""" @@ -31,12 +32,13 @@ def load_reference_config(vae_type: str) -> dict: if not os.path.exists(config_path): raise FileNotFoundError(f"Reference config not found: {config_path}") - with open(config_path, 'r') as f: + with open(config_path, "r") as f: config = json.load(f) print(f"✓ Loaded {vae_type} config: in_channels={config['in_channels']}") return config + def create_parameter_mapping() -> dict: """Create mapping from old parameter names to new diffusers names.""" @@ -54,6 +56,7 @@ def create_parameter_mapping() -> dict: return mapping + def convert_checkpoint_parameters(old_state_dict: dict) -> dict: """Convert old checkpoint parameters to new diffusers format.""" @@ -82,15 +85,14 @@ def convert_checkpoint_parameters(old_state_dict: dict) -> dict: print(f" Found QKV projection: {key}") # Split QKV weight into separate Q, K, V projections qkv_weight = value - hidden_size = qkv_weight.shape[1] q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) # Extract layer number from key (e.g., blocks.0.img_qkv_proj.weight -> 0) - parts = key.split('.') + parts = key.split(".") layer_idx = None for i, part 
in enumerate(parts): - if part == 'blocks' and i + 1 < len(parts) and parts[i+1].isdigit(): - layer_idx = parts[i+1] + if part == "blocks" and i + 1 < len(parts) and parts[i + 1].isdigit(): + layer_idx = parts[i + 1] break if layer_idx is not None: @@ -117,14 +119,14 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - old_checkpoint = torch.load(checkpoint_path, map_location='cpu') + old_checkpoint = torch.load(checkpoint_path, map_location="cpu") # Handle different checkpoint formats if isinstance(old_checkpoint, dict): - if 'model' in old_checkpoint: - state_dict = old_checkpoint['model'] - elif 'state_dict' in old_checkpoint: - state_dict = old_checkpoint['state_dict'] + if "model" in old_checkpoint: + state_dict = old_checkpoint["model"] + elif "state_dict" in old_checkpoint: + state_dict = old_checkpoint["state_dict"] else: state_dict = old_checkpoint else: @@ -153,6 +155,7 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi return transformer + def copy_pipeline_components(vae_type: str, output_path: str): """Copy VAE, scheduler, text encoder, and tokenizer from reference pipeline.""" @@ -176,6 +179,7 @@ def copy_pipeline_components(vae_type: str, output_path: str): else: print(f"⚠ Component not found: {src_path}") + def create_model_index(vae_type: str, output_path: str): """Create model_index.json for the pipeline.""" @@ -188,33 +192,19 @@ def create_model_index(vae_type: str, output_path: str): "_class_name": "MiragePipeline", "_diffusers_version": "0.31.0.dev0", "_name_or_path": os.path.basename(output_path), - "scheduler": [ - "diffusers", - "FlowMatchEulerDiscreteScheduler" - ], - "text_encoder": [ - "transformers", - "T5GemmaEncoder" - ], - "tokenizer": [ - "transformers", - "GemmaTokenizerFast" - ], - "transformer": [ - "diffusers", - "MirageTransformer2DModel" - ], - "vae": [ - 
"diffusers", - vae_class - ] + "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], + "text_encoder": ["transformers", "T5GemmaEncoder"], + "tokenizer": ["transformers", "GemmaTokenizerFast"], + "transformer": ["diffusers", "MirageTransformer2DModel"], + "vae": ["diffusers", vae_class], } model_index_path = os.path.join(output_path, "model_index.json") - with open(model_index_path, 'w') as f: + with open(model_index_path, "w") as f: json.dump(model_index, f, indent=2) - print(f"✓ Created model_index.json") + print("✓ Created model_index.json") + def main(args): # Validate inputs @@ -236,7 +226,7 @@ def main(args): os.makedirs(transformer_path, exist_ok=True) # Save config - with open(os.path.join(transformer_path, "config.json"), 'w') as f: + with open(os.path.join(transformer_path, "config.json"), "w") as f: json.dump(config, f, indent=2) # Save model weights as safetensors @@ -253,7 +243,7 @@ def main(args): # Verify the pipeline can be loaded try: pipeline = MiragePipeline.from_pretrained(args.output_path) - print(f"Pipeline loaded successfully!") + print("Pipeline loaded successfully!") print(f"Transformer: {type(pipeline.transformer).__name__}") print(f"VAE: {type(pipeline.vae).__name__}") print(f"Text Encoder: {type(pipeline.text_encoder).__name__}") @@ -271,24 +261,18 @@ def main(args): print(f"Converted pipeline saved to: {args.output_path}") print(f"VAE type: {args.vae_type}") - return True + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Mirage checkpoint to diffusers format") parser.add_argument( - "--checkpoint_path", - type=str, - required=True, - help="Path to the original Mirage checkpoint (.pth file)" + "--checkpoint_path", type=str, required=True, help="Path to the original Mirage checkpoint (.pth file)" ) parser.add_argument( - "--output_path", - type=str, - required=True, - help="Output directory for the converted diffusers pipeline" + "--output_path", type=str, required=True, help="Output directory for 
the converted diffusers pipeline" ) parser.add_argument( @@ -296,7 +280,7 @@ def main(args): type=str, choices=["flux", "dc-ae"], required=True, - help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)" + help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)", ) args = parser.parse_args() @@ -306,7 +290,8 @@ def main(args): if not success: sys.exit(1) except Exception as e: - print(f"❌ Conversion failed: {e}") + print(f"Conversion failed: {e}") import traceback + traceback.print_exc() - sys.exit(1) \ No newline at end of file + sys.exit(1) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 396e000524ec..923d44d4f1ec 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -13,21 +13,21 @@ # limitations under the License. from dataclasses import dataclass -from typing import Any, Dict, Optional, Union, Tuple +from typing import Any, Dict, Optional, Tuple, Union + import torch -import math -from torch import Tensor, nn -from torch.nn.functional import fold, unfold from einops import rearrange from einops.layers.torch import Rearrange +from torch import Tensor, nn +from torch.nn.functional import fold, unfold from ...configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin -from ..modeling_outputs import Transformer2DModelOutput -from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..normalization import RMSNorm +from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ..embeddings import get_timestep_embedding +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin 
+from ..normalization import RMSNorm logger = logging.get_logger(__name__) @@ -72,8 +72,6 @@ def forward(self, ids: Tensor) -> Tensor: return emb.unsqueeze(1) - - class MLPEmbedder(nn.Module): def __init__(self, in_dim: int, hidden_dim: int): super().__init__() @@ -85,8 +83,6 @@ def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) - - class QKNorm(torch.nn.Module): def __init__(self, dim: int): super().__init__() @@ -157,7 +153,6 @@ def __init__( processor=MirageAttnProcessor2_0(), ) - # mlp self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False) @@ -212,9 +207,9 @@ def attn_forward( l_txt = txt_k.shape[2] assert attention_mask.dim() == 2, f"Unsupported attention_mask shape: {attention_mask.shape}" - assert ( - attention_mask.shape[-1] == l_txt - ), f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" + assert attention_mask.shape[-1] == l_txt, ( + f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" + ) device = img_q.device @@ -234,8 +229,8 @@ def attn_forward( kv_packed = torch.cat([k, v], dim=-1) attn = self.attention( - hidden_states=img_q, - encoder_hidden_states=kv_packed, + hidden_states=img_q, + encoder_hidden_states=kv_packed, attention_mask=attn_mask, ) @@ -288,8 +283,6 @@ def forward(self, x: Tensor, vec: Tensor) -> Tensor: return x - - def img2seq(img: Tensor, patch_size: int) -> Tensor: """Flatten an image into a sequence of patches""" return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) @@ -327,7 +320,7 @@ def __init__( time_factor: float = 1000.0, time_max_period: int = 10000, conditioning_block_ids: list = None, - **kwargs + **kwargs, ): super().__init__() @@ -447,7 +440,7 @@ def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Te embedding_dim=256, max_period=self.time_max_period, 
scale=self.time_factor, - flip_sin_to_cos=True # Match original cos, sin order + flip_sin_to_cos=True, # Match original cos, sin order ).to(dtype) ) @@ -470,9 +463,7 @@ def forward_transformers( vec = self.compute_timestep_embedding(timestep, dtype=img.dtype) for block in self.blocks: - img = block( - img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs - ) + img = block(img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs) img = self.final_layer(img, vec) return img diff --git a/src/diffusers/pipelines/mirage/__init__.py b/src/diffusers/pipelines/mirage/__init__.py index 4fd8ad191b3f..cba951057370 100644 --- a/src/diffusers/pipelines/mirage/__init__.py +++ b/src/diffusers/pipelines/mirage/__init__.py @@ -1,4 +1,5 @@ from .pipeline_mirage import MiragePipeline from .pipeline_output import MiragePipelineOutput -__all__ = ["MiragePipeline", "MiragePipelineOutput"] \ No newline at end of file + +__all__ = ["MiragePipeline", "MiragePipelineOutput"] diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py index 126eab07977c..c4a4783c5f38 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import html import inspect import os -from typing import Any, Callable, Dict, List, Optional, Union - -import html import re import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union import ftfy import torch @@ -31,7 +30,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, AutoencoderDC +from ...models import AutoencoderDC, AutoencoderKL from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( logging, @@ -41,6 +40,7 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import MiragePipelineOutput + try: from ...models.transformers.transformer_mirage import MirageTransformer2DModel except ImportError: @@ -55,7 +55,19 @@ class TextPreprocessor: def __init__(self): """Initialize text preprocessor.""" self.bad_punct_regex = re.compile( - r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + r"\\" + r"\/" + r"\*" + r"]{1,}" + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + r"\\" + + r"\/" + + r"\*" + + r"]{1,}" ) def clean_text(self, text: str) -> str: @@ -93,7 +105,7 @@ def clean_text(self, text: str) -> str: ) # кавычки к одному стандарту - text = re.sub(r"[`´«»""¨]", '"', text) + text = re.sub(r"[`´«»" "¨]", '"', text) text = re.sub(r"['']", "'", text) # " and & @@ -243,9 +255,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P """ # Ensure T5GemmaEncoder is available for loading import transformers - if not hasattr(transformers, 'T5GemmaEncoder'): + + if not hasattr(transformers, "T5GemmaEncoder"): try: from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + transformers.T5GemmaEncoder = T5GemmaEncoder except ImportError: # T5GemmaEncoder not available in this transformers version @@ -254,7 +268,6 @@ def from_pretrained(cls, 
pretrained_model_name_or_path: Optional[Union[str, os.P # Proceed with standard loading return super().from_pretrained(pretrained_model_name_or_path, **kwargs) - def __init__( self, transformer: MirageTransformer2DModel, @@ -333,7 +346,7 @@ def _enhance_vae_properties(self): if hasattr(self.vae, "spatial_compression_ratio") and self.vae.spatial_compression_ratio == 32: self.vae.latent_channels = 32 # DC-AE default else: - self.vae.latent_channels = 4 # AutoencoderKL default + self.vae.latent_channels = 4 # AutoencoderKL default @property def vae_scale_factor(self): @@ -353,7 +366,10 @@ def prepare_latents( ): """Prepare initial latents for the diffusion process.""" if latents is None: - latent_height, latent_width = height // self.vae.spatial_compression_ratio, width // self.vae.spatial_compression_ratio + latent_height, latent_width = ( + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) shape = (batch_size, num_channels_latents, latent_height, latent_width) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -424,7 +440,9 @@ def check_inputs( ): """Check that all inputs are in correct format.""" if height % self.vae.spatial_compression_ratio != 0 or width % self.vae.spatial_compression_ratio != 0: - raise ValueError(f"`height` and `width` have to be divisible by {self.vae.spatial_compression_ratio} but are {height} and {width}.") + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae.spatial_compression_ratio} but are {height} and {width}." 
+ ) if guidance_scale < 1.0: raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}") @@ -584,12 +602,16 @@ def __call__( # Forward through transformer layers img_seq = self.transformer.forward_transformers( - img_seq, txt, time_embedding=self.transformer.compute_timestep_embedding(t_cont, img_seq.dtype), - pe=pe, attention_mask=ca_mask + img_seq, + txt, + time_embedding=self.transformer.compute_timestep_embedding(t_cont, img_seq.dtype), + pe=pe, + attention_mask=ca_mask, ) # Convert back to image format from ...models.transformers.transformer_mirage import seq2img + noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape) # Apply CFG @@ -626,4 +648,4 @@ def __call__( if not return_dict: return (image,) - return MiragePipelineOutput(images=image) \ No newline at end of file + return MiragePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/mirage/pipeline_output.py b/src/diffusers/pipelines/mirage/pipeline_output.py index dfb55821d142..e41c8e3bea00 100644 --- a/src/diffusers/pipelines/mirage/pipeline_output.py +++ b/src/diffusers/pipelines/mirage/pipeline_output.py @@ -32,4 +32,4 @@ class MiragePipelineOutput(BaseOutput): num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
""" - images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_mirage.py index 5e7b0bd165a6..0085627aa7e4 100644 --- a/tests/models/transformers/test_models_transformer_mirage.py +++ b/tests/models/transformers/test_models_transformer_mirage.py @@ -133,8 +133,7 @@ def test_process_inputs(self): with torch.no_grad(): img_seq, txt, pe = model.process_inputs( - inputs_dict["image_latent"], - inputs_dict["cross_attn_conditioning"] + inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"] ) # Check shapes @@ -144,7 +143,9 @@ def test_process_inputs(self): expected_seq_len = (height // patch_size) * (width // patch_size) self.assertEqual(img_seq.shape, (batch_size, expected_seq_len, init_dict["in_channels"] * patch_size**2)) - self.assertEqual(txt.shape, (batch_size, inputs_dict["cross_attn_conditioning"].shape[1], init_dict["hidden_size"])) + self.assertEqual( + txt.shape, (batch_size, inputs_dict["cross_attn_conditioning"].shape[1], init_dict["hidden_size"]) + ) # Check that pe has the correct batch size, sequence length and some embedding dimension self.assertEqual(pe.shape[0], batch_size) # batch size self.assertEqual(pe.shape[1], 1) # unsqueeze(1) in EmbedND @@ -160,20 +161,14 @@ def test_forward_transformers(self): with torch.no_grad(): # Process inputs first img_seq, txt, pe = model.process_inputs( - inputs_dict["image_latent"], - inputs_dict["cross_attn_conditioning"] + inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"] ) # Test forward_transformers - output_seq = model.forward_transformers( - img_seq, - txt, - timestep=inputs_dict["timestep"], - pe=pe - ) + output_seq = model.forward_transformers(img_seq, txt, timestep=inputs_dict["timestep"], pe=pe) # Check output shape - expected_out_channels = init_dict["in_channels"] * 
init_dict["patch_size"]**2 + expected_out_channels = init_dict["in_channels"] * init_dict["patch_size"] ** 2 self.assertEqual(output_seq.shape, (img_seq.shape[0], img_seq.shape[1], expected_out_channels)) def test_attention_mask(self): @@ -186,13 +181,10 @@ def test_attention_mask(self): batch_size = inputs_dict["cross_attn_conditioning"].shape[0] seq_len = inputs_dict["cross_attn_conditioning"].shape[1] attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool).to(torch_device) - attention_mask[:, seq_len//2:] = False # Mask second half + attention_mask[:, seq_len // 2 :] = False # Mask second half with torch.no_grad(): - outputs = model( - **inputs_dict, - cross_attn_mask=attention_mask - ) + outputs = model(**inputs_dict, cross_attn_mask=attention_mask) self.assertIsNotNone(outputs) expected_shape = inputs_dict["image_latent"].shape @@ -237,7 +229,7 @@ def test_gradient_checkpointing_enable(self): # Check that _activation_checkpointing is set for block in model.blocks: - self.assertTrue(hasattr(block, '_activation_checkpointing')) + self.assertTrue(hasattr(block, "_activation_checkpointing")) def test_from_config(self): init_dict, _ = self.prepare_init_args_and_inputs_for_common() @@ -249,4 +241,4 @@ def test_from_config(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 34fa9dd4c1a77523cb58039147b6c487b0308593 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 19:05:51 +0000 Subject: [PATCH 08/38] remove dependencies to old checkpoints --- scripts/convert_mirage_to_diffusers.py | 229 +++++++++++++++++++++---- 1 file changed, 192 insertions(+), 37 deletions(-) diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py index 5e2a2ff768f4..eb6de1a37481 100644 --- a/scripts/convert_mirage_to_diffusers.py +++ b/scripts/convert_mirage_to_diffusers.py @@ -6,11 +6,12 @@ import argparse import json import os -import shutil import sys import torch from 
safetensors.torch import save_file +from dataclasses import dataclass, asdict +from typing import Tuple, Dict sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) @@ -18,35 +19,53 @@ from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel from diffusers.pipelines.mirage import MiragePipeline +@dataclass(frozen=True) +class MirageBase: + context_in_dim: int = 2304 + hidden_size: int = 1792 + mlp_ratio: float = 3.5 + num_heads: int = 28 + depth: int = 16 + axes_dim: Tuple[int, int] = (32, 32) + theta: int = 10_000 + time_factor: float = 1000.0 + time_max_period: int = 10_000 -def load_reference_config(vae_type: str) -> dict: - """Load transformer config from existing pipeline checkpoint.""" +@dataclass(frozen=True) +class MirageFlux(MirageBase): + in_channels: int = 16 + patch_size: int = 2 + + +@dataclass(frozen=True) +class MirageDCAE(MirageBase): + in_channels: int = 32 + patch_size: int = 1 + + +def build_config(vae_type: str) -> dict: if vae_type == "flux": - config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated/transformer/config.json" + cfg = MirageFlux() elif vae_type == "dc-ae": - config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated/transformer/config.json" + cfg = MirageDCAE() else: raise ValueError(f"Unsupported VAE type: {vae_type}. 
Use 'flux' or 'dc-ae'") - if not os.path.exists(config_path): - raise FileNotFoundError(f"Reference config not found: {config_path}") - - with open(config_path, "r") as f: - config = json.load(f) + config_dict = asdict(cfg) + config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index] + return config_dict - print(f"✓ Loaded {vae_type} config: in_channels={config['in_channels']}") - return config -def create_parameter_mapping() -> dict: +def create_parameter_mapping(depth: int) -> dict: """Create mapping from old parameter names to new diffusers names.""" # Key mappings for structural changes mapping = {} # RMSNorm: scale -> weight - for i in range(16): # 16 layers + for i in range(depth): mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.qk_norm.query_norm.weight" mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.qk_norm.key_norm.weight" mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.k_norm.weight" @@ -57,12 +76,12 @@ def create_parameter_mapping() -> dict: return mapping -def convert_checkpoint_parameters(old_state_dict: dict) -> dict: +def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]: """Convert old checkpoint parameters to new diffusers format.""" print("Converting checkpoint parameters...") - mapping = create_parameter_mapping() + mapping = create_parameter_mapping(depth) converted_state_dict = {} # First, print available keys to understand structure @@ -135,7 +154,8 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi print(f"✓ Loaded checkpoint with {len(state_dict)} parameters") # Convert parameter names if needed - converted_state_dict = convert_checkpoint_parameters(state_dict) + model_depth = int(config.get("depth", 16)) + converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth) # Create transformer with config print("Creating MirageTransformer2DModel...") @@ -156,28 +176,164 @@ def 
create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi return transformer -def copy_pipeline_components(vae_type: str, output_path: str): - """Copy VAE, scheduler, text encoder, and tokenizer from reference pipeline.""" + + +def create_scheduler_config(output_path: str): + """Create FlowMatchEulerDiscreteScheduler config.""" + + scheduler_config = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "num_train_timesteps": 1000, + "shift": 1.0 + } + + scheduler_path = os.path.join(output_path, "scheduler") + os.makedirs(scheduler_path, exist_ok=True) + + with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f: + json.dump(scheduler_config, f, indent=2) + + print("✓ Created scheduler config") + + +def create_vae_config(vae_type: str, output_path: str): + """Create VAE config based on type.""" if vae_type == "flux": - ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated" + vae_config = { + "_class_name": "AutoencoderKL", + "latent_channels": 16, + "block_out_channels": [128, 256, 512, 512], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "scaling_factor": 0.3611, + "shift_factor": 0.1159, + "use_post_quant_conv": False, + "use_quant_conv": False + } else: # dc-ae - ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated" + vae_config = { + "_class_name": "AutoencoderDC", + "latent_channels": 32, + "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], + "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], + "encoder_block_types": [ + "ResBlock", + "ResBlock", + "ResBlock", + "EfficientViTBlock", + "EfficientViTBlock", + "EfficientViTBlock" + ], + 
"decoder_block_types": [ + "ResBlock", + "ResBlock", + "ResBlock", + "EfficientViTBlock", + "EfficientViTBlock", + "EfficientViTBlock" + ], + "encoder_layers_per_block": [2, 2, 2, 3, 3, 3], + "decoder_layers_per_block": [3, 3, 3, 3, 3, 3], + "encoder_qkv_multiscales": [[], [], [], [5], [5], [5]], + "decoder_qkv_multiscales": [[], [], [], [5], [5], [5]], + "scaling_factor": 0.41407, + "upsample_block_type": "interpolate" + } + + vae_path = os.path.join(output_path, "vae") + os.makedirs(vae_path, exist_ok=True) + + with open(os.path.join(vae_path, "config.json"), "w") as f: + json.dump(vae_config, f, indent=2) + + print("✓ Created VAE config") + + +def create_text_encoder_config(output_path: str): + """Create T5GemmaEncoder config.""" + + text_encoder_config = { + "model_name": "google/t5gemma-2b-2b-ul2", + "model_max_length": 256, + "use_attn_mask": True, + "use_last_hidden_state": True + } - components = ["vae", "scheduler", "text_encoder", "tokenizer"] + text_encoder_path = os.path.join(output_path, "text_encoder") + os.makedirs(text_encoder_path, exist_ok=True) + + with open(os.path.join(text_encoder_path, "config.json"), "w") as f: + json.dump(text_encoder_config, f, indent=2) + + print("✓ Created text encoder config") + + +def create_tokenizer_config(output_path: str): + """Create GemmaTokenizerFast config and files.""" + + tokenizer_config = { + "add_bos_token": False, + "add_eos_token": False, + "added_tokens_decoder": { + "0": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, + "1": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, + "2": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, + "3": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, + "106": {"content": "", "lstrip": False, "normalized": False, "rstrip": 
False, "single_word": False, "special": True}, + "107": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True} + }, + "additional_special_tokens": ["", ""], + "bos_token": "", + "clean_up_tokenization_spaces": False, + "eos_token": "", + "extra_special_tokens": {}, + "model_max_length": 256, + "pad_token": "", + "padding_side": "right", + "sp_model_kwargs": {}, + "spaces_between_special_tokens": False, + "tokenizer_class": "GemmaTokenizer", + "unk_token": "", + "use_default_system_prompt": False + } - for component in components: - src_path = os.path.join(ref_pipeline, component) - dst_path = os.path.join(output_path, component) + special_tokens_map = { + "bos_token": "", + "eos_token": "", + "pad_token": "", + "unk_token": "" + } - if os.path.exists(src_path): - if os.path.isdir(src_path): - shutil.copytree(src_path, dst_path, dirs_exist_ok=True) - else: - shutil.copy2(src_path, dst_path) - print(f"✓ Copied {component}") - else: - print(f"⚠ Component not found: {src_path}") + tokenizer_path = os.path.join(output_path, "tokenizer") + os.makedirs(tokenizer_path, exist_ok=True) + + with open(os.path.join(tokenizer_path, "tokenizer_config.json"), "w") as f: + json.dump(tokenizer_config, f, indent=2) + + with open(os.path.join(tokenizer_path, "special_tokens_map.json"), "w") as f: + json.dump(special_tokens_map, f, indent=2) + + print("✓ Created tokenizer config (Note: tokenizer.json and tokenizer.model files need to be provided separately)") + + +def create_pipeline_components(vae_type: str, output_path: str): + """Create all pipeline components with proper configs.""" + + create_scheduler_config(output_path) + create_vae_config(vae_type, output_path) + create_text_encoder_config(output_path) + create_tokenizer_config(output_path) def create_model_index(vae_type: str, output_path: str): @@ -211,8 +367,7 @@ def main(args): if not os.path.exists(args.checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: 
{args.checkpoint_path}") - # Load reference config based on VAE type - config = load_reference_config(args.vae_type) + config = build_config(args.vae_type) # Create output directory os.makedirs(args.output_path, exist_ok=True) @@ -234,8 +389,8 @@ def main(args): save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors")) print(f"✓ Saved transformer to {transformer_path}") - # Copy other pipeline components - copy_pipeline_components(args.vae_type, args.output_path) + # Create other pipeline components + create_pipeline_components(args.vae_type, args.output_path) # Create model index create_model_index(args.vae_type, args.output_path) From 5cc965a7570022b959bc38f8dd167e2eaed18254 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 22:30:58 +0200 Subject: [PATCH 09/38] remove old checkpoints dependency --- scripts/convert_mirage_to_diffusers.py | 170 ++---------------- .../pipelines/mirage/pipeline_mirage.py | 68 +++++-- 2 files changed, 63 insertions(+), 175 deletions(-) diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py index eb6de1a37481..2ddb708bc704 100644 --- a/scripts/convert_mirage_to_diffusers.py +++ b/scripts/convert_mirage_to_diffusers.py @@ -84,13 +84,6 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth mapping = create_parameter_mapping(depth) converted_state_dict = {} - # First, print available keys to understand structure - print("Available keys in checkpoint:") - for key in sorted(old_state_dict.keys())[:10]: # Show first 10 keys - print(f" {key}") - if len(old_state_dict) > 10: - print(f" ... 
and {len(old_state_dict) - 10} more") - for key, value in old_state_dict.items(): new_key = key @@ -196,172 +189,37 @@ def create_scheduler_config(output_path: str): print("✓ Created scheduler config") -def create_vae_config(vae_type: str, output_path: str): - """Create VAE config based on type.""" - - if vae_type == "flux": - vae_config = { - "_class_name": "AutoencoderKL", - "latent_channels": 16, - "block_out_channels": [128, 256, 512, 512], - "down_block_types": [ - "DownEncoderBlock2D", - "DownEncoderBlock2D", - "DownEncoderBlock2D", - "DownEncoderBlock2D" - ], - "up_block_types": [ - "UpDecoderBlock2D", - "UpDecoderBlock2D", - "UpDecoderBlock2D", - "UpDecoderBlock2D" - ], - "scaling_factor": 0.3611, - "shift_factor": 0.1159, - "use_post_quant_conv": False, - "use_quant_conv": False - } - else: # dc-ae - vae_config = { - "_class_name": "AutoencoderDC", - "latent_channels": 32, - "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], - "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], - "encoder_block_types": [ - "ResBlock", - "ResBlock", - "ResBlock", - "EfficientViTBlock", - "EfficientViTBlock", - "EfficientViTBlock" - ], - "decoder_block_types": [ - "ResBlock", - "ResBlock", - "ResBlock", - "EfficientViTBlock", - "EfficientViTBlock", - "EfficientViTBlock" - ], - "encoder_layers_per_block": [2, 2, 2, 3, 3, 3], - "decoder_layers_per_block": [3, 3, 3, 3, 3, 3], - "encoder_qkv_multiscales": [[], [], [], [5], [5], [5]], - "decoder_qkv_multiscales": [[], [], [], [5], [5], [5]], - "scaling_factor": 0.41407, - "upsample_block_type": "interpolate" - } - - vae_path = os.path.join(output_path, "vae") - os.makedirs(vae_path, exist_ok=True) - - with open(os.path.join(vae_path, "config.json"), "w") as f: - json.dump(vae_config, f, indent=2) - - print("✓ Created VAE config") - - -def create_text_encoder_config(output_path: str): - """Create T5GemmaEncoder config.""" - - text_encoder_config = { - "model_name": "google/t5gemma-2b-2b-ul2", - 
"model_max_length": 256, - "use_attn_mask": True, - "use_last_hidden_state": True - } - - text_encoder_path = os.path.join(output_path, "text_encoder") - os.makedirs(text_encoder_path, exist_ok=True) - - with open(os.path.join(text_encoder_path, "config.json"), "w") as f: - json.dump(text_encoder_config, f, indent=2) - - print("✓ Created text encoder config") - - -def create_tokenizer_config(output_path: str): - """Create GemmaTokenizerFast config and files.""" - - tokenizer_config = { - "add_bos_token": False, - "add_eos_token": False, - "added_tokens_decoder": { - "0": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, - "1": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, - "2": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, - "3": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, - "106": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, - "107": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True} - }, - "additional_special_tokens": ["", ""], - "bos_token": "", - "clean_up_tokenization_spaces": False, - "eos_token": "", - "extra_special_tokens": {}, - "model_max_length": 256, - "pad_token": "", - "padding_side": "right", - "sp_model_kwargs": {}, - "spaces_between_special_tokens": False, - "tokenizer_class": "GemmaTokenizer", - "unk_token": "", - "use_default_system_prompt": False - } - - special_tokens_map = { - "bos_token": "", - "eos_token": "", - "pad_token": "", - "unk_token": "" - } - - tokenizer_path = os.path.join(output_path, "tokenizer") - os.makedirs(tokenizer_path, exist_ok=True) - - with open(os.path.join(tokenizer_path, "tokenizer_config.json"), "w") as f: - json.dump(tokenizer_config, f, 
indent=2) - - with open(os.path.join(tokenizer_path, "special_tokens_map.json"), "w") as f: - json.dump(special_tokens_map, f, indent=2) - - print("✓ Created tokenizer config (Note: tokenizer.json and tokenizer.model files need to be provided separately)") - - -def create_pipeline_components(vae_type: str, output_path: str): - """Create all pipeline components with proper configs.""" - - create_scheduler_config(output_path) - create_vae_config(vae_type, output_path) - create_text_encoder_config(output_path) - create_tokenizer_config(output_path) def create_model_index(vae_type: str, output_path: str): - """Create model_index.json for the pipeline.""" + """Create model_index.json for the pipeline with HuggingFace model references.""" if vae_type == "flux": - vae_class = "AutoencoderKL" + vae_model_name = "black-forest-labs/FLUX.1-dev" + vae_subfolder = "vae" else: # dc-ae - vae_class = "AutoencoderDC" + vae_model_name = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers" + vae_subfolder = None + + # Text encoder and tokenizer always use T5Gemma + text_model_name = "google/t5gemma-2b-2b-ul2" model_index = { "_class_name": "MiragePipeline", "_diffusers_version": "0.31.0.dev0", "_name_or_path": os.path.basename(output_path), "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], - "text_encoder": ["transformers", "T5GemmaEncoder"], - "tokenizer": ["transformers", "GemmaTokenizerFast"], + "text_encoder": text_model_name, + "tokenizer": text_model_name, "transformer": ["diffusers", "MirageTransformer2DModel"], - "vae": ["diffusers", vae_class], + "vae": vae_model_name, + "vae_subfolder": vae_subfolder, } model_index_path = os.path.join(output_path, "model_index.json") with open(model_index_path, "w") as f: json.dump(model_index, f, indent=2) - print("✓ Created model_index.json") - - def main(args): # Validate inputs if not os.path.exists(args.checkpoint_path): @@ -389,10 +247,8 @@ def main(args): save_file(state_dict, os.path.join(transformer_path, 
"diffusion_pytorch_model.safetensors")) print(f"✓ Saved transformer to {transformer_path}") - # Create other pipeline components - create_pipeline_components(args.vae_type, args.output_path) + create_scheduler_config(args.output_path) - # Create model index create_model_index(args.vae_type, args.output_path) # Verify the pipeline can be loaded diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py index c4a4783c5f38..e6a13ff226cd 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -247,26 +247,61 @@ class MiragePipeline( @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): """ - Override from_pretrained to ensure T5GemmaEncoder is available for loading. + Override from_pretrained to load VAE and text encoder from HuggingFace models. - This ensures that T5GemmaEncoder from transformers is accessible in the module namespace - during component loading, which is required for MiragePipeline checkpoints that use - T5GemmaEncoder as the text encoder. + The MiragePipeline checkpoints only store transformer and scheduler locally. + VAE and text encoder are loaded from external HuggingFace models as specified + in model_index.json. 
""" - # Ensure T5GemmaEncoder is available for loading - import transformers + import json + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel + + model_index_path = os.path.join(pretrained_model_name_or_path, "model_index.json") + if not os.path.exists(model_index_path): + raise ValueError(f"model_index.json not found in {pretrained_model_name_or_path}") + + with open(model_index_path, "r") as f: + model_index = json.load(f) + + vae_model_name = model_index.get("vae") + vae_subfolder = model_index.get("vae_subfolder") + text_model_name = model_index.get("text_encoder") + tokenizer_model_name = model_index.get("tokenizer") + + logger.info(f"Loading VAE from {vae_model_name}...") + if "FLUX" in vae_model_name or "flux" in vae_model_name: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder=vae_subfolder) + else: # DC-AE + vae = AutoencoderDC.from_pretrained(vae_model_name) + + logger.info(f"Loading text encoder from {text_model_name}...") + t5gemma_model = T5GemmaModel.from_pretrained(text_model_name) + text_encoder = t5gemma_model.encoder + + logger.info(f"Loading tokenizer from {tokenizer_model_name}...") + tokenizer = GemmaTokenizerFast.from_pretrained(tokenizer_model_name) + tokenizer.model_max_length = 256 + + # Load transformer and scheduler from local checkpoint + logger.info(f"Loading transformer from {pretrained_model_name_or_path}...") + transformer = MirageTransformer2DModel.from_pretrained( + pretrained_model_name_or_path, subfolder="transformer" + ) - if not hasattr(transformers, "T5GemmaEncoder"): - try: - from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + logger.info(f"Loading scheduler from {pretrained_model_name_or_path}...") + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + pretrained_model_name_or_path, subfolder="scheduler" + ) - transformers.T5GemmaEncoder = T5GemmaEncoder - except ImportError: - # T5GemmaEncoder not available in this transformers version - pass + pipeline = 
cls( + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + ) - # Proceed with standard loading - return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + return pipeline def __init__( self, @@ -283,11 +318,8 @@ def __init__( "MirageTransformer2DModel is not available. Please ensure the transformer_mirage module is properly installed." ) - # Store standard components self.text_encoder = text_encoder self.tokenizer = tokenizer - - # Initialize text preprocessor self.text_preprocessor = TextPreprocessor() self.register_modules( From d79cd8fffb959ada9bfeb4d3929b7aa2ce69f993 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 20:56:51 +0000 Subject: [PATCH 10/38] move default height and width in checkpoint config --- scripts/convert_mirage_to_diffusers.py | 9 +++++++++ .../pipelines/mirage/pipeline_mirage.py | 16 ++++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py index 2ddb708bc704..37de253d1448 100644 --- a/scripts/convert_mirage_to_diffusers.py +++ b/scripts/convert_mirage_to_diffusers.py @@ -19,6 +19,9 @@ from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel from diffusers.pipelines.mirage import MiragePipeline +DEFAULT_HEIGHT = 512 +DEFAULT_WIDTH = 512 + @dataclass(frozen=True) class MirageBase: context_in_dim: int = 2304 @@ -197,9 +200,13 @@ def create_model_index(vae_type: str, output_path: str): if vae_type == "flux": vae_model_name = "black-forest-labs/FLUX.1-dev" vae_subfolder = "vae" + default_height = DEFAULT_HEIGHT + default_width = DEFAULT_WIDTH else: # dc-ae vae_model_name = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers" vae_subfolder = None + default_height = DEFAULT_HEIGHT + default_width = DEFAULT_WIDTH # Text encoder and tokenizer always use T5Gemma text_model_name = "google/t5gemma-2b-2b-ul2" @@ -214,6 +221,8 @@ def 
create_model_index(vae_type: str, output_path: str): "transformer": ["diffusers", "MirageTransformer2DModel"], "vae": vae_model_name, "vae_subfolder": vae_subfolder, + "default_height": default_height, + "default_width": default_width, } model_index_path = os.path.join(output_path, "model_index.json") diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py index e6a13ff226cd..9d247eecbd7f 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -31,6 +31,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderDC, AutoencoderKL +from ...models.transformers.transformer_mirage import seq2img from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( logging, @@ -46,6 +47,9 @@ except ImportError: MirageTransformer2DModel = None +DEFAULT_HEIGHT = 512 +DEFAULT_WIDTH = 512 + logger = logging.get_logger(__name__) @@ -267,6 +271,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P vae_subfolder = model_index.get("vae_subfolder") text_model_name = model_index.get("text_encoder") tokenizer_model_name = model_index.get("tokenizer") + default_height = model_index.get("default_height", DEFAULT_HEIGHT) + default_width = model_index.get("default_width", DEFAULT_WIDTH) logger.info(f"Loading VAE from {vae_model_name}...") if "FLUX" in vae_model_name or "flux" in vae_model_name: @@ -301,6 +307,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P vae=vae, ) + # Store default dimensions as pipeline attributes + pipeline.default_height = default_height + pipeline.default_width = default_width + return pipeline def __init__( @@ -558,8 +568,8 @@ def __call__( """ # 0. 
Default height and width to transformer config - height = height or 256 - width = width or 256 + height = height or getattr(self, 'default_height', DEFAULT_HEIGHT) + width = width or getattr(self, 'default_width', DEFAULT_WIDTH) # 1. Check inputs self.check_inputs( @@ -642,8 +652,6 @@ def __call__( ) # Convert back to image format - from ...models.transformers.transformer_mirage import seq2img - noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape) # Apply CFG From f2759fd0a8ea934ea0ecea9bfb68f43dffdca5f7 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 21:26:03 +0000 Subject: [PATCH 11/38] add docstrings --- .../models/transformers/transformer_mirage.py | 367 +++++++++++++++++- .../test_models_transformer_mirage.py | 6 +- 2 files changed, 351 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 923d44d4f1ec..c509f797fb8b 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -33,20 +33,70 @@ logger = logging.get_logger(__name__) -def get_image_ids(bs: int, h: int, w: int, patch_size: int, device: torch.device) -> Tensor: - img_ids = torch.zeros(h // patch_size, w // patch_size, 2, device=device) - img_ids[..., 0] = torch.arange(h // patch_size, device=device)[:, None] - img_ids[..., 1] = torch.arange(w // patch_size, device=device)[None, :] - return img_ids.reshape((h // patch_size) * (w // patch_size), 2).unsqueeze(0).repeat(bs, 1, 1) +def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> Tensor: + r""" + Generates 2D patch coordinate indices for a batch of images. + + Parameters: + batch_size (`int`): + Number of images in the batch. + height (`int`): + Height of the input images (in pixels). + width (`int`): + Width of the input images (in pixels). 
+ patch_size (`int`): + Size of the square patches that the image is divided into. + device (`torch.device`): + The device on which to create the tensor. + + Returns: + `torch.Tensor`: + Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) + coordinates of each patch in the image grid. + """ + + img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device) + img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None] + img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :] + return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1) def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: + r""" + Applies rotary positional embeddings (RoPE) to a query tensor. + + Parameters: + xq (`torch.Tensor`): + Input tensor of shape `(..., dim)` representing the queries. + freqs_cis (`torch.Tensor`): + Precomputed rotary frequency components of shape `(..., dim/2, 2)` + containing cosine and sine pairs. + + Returns: + `torch.Tensor`: + Tensor of the same shape as `xq` with rotary embeddings applied. + """ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq) class EmbedND(nn.Module): + r""" + N-dimensional rotary positional embedding. + + This module creates rotary embeddings (RoPE) across multiple axes, where each + axis can have its own embedding dimension. The embeddings are combined and + returned as a single tensor + + Parameters: + dim (int): + Base embedding dimension (must be even). + theta (int): + Scaling factor that controls the frequency spectrum of the rotary embeddings. + axes_dim (list[int]): + List of embedding dimensions for each axis (each must be even). 
+ """ def __init__(self, dim: int, theta: int, axes_dim: list[int]): super().__init__() self.dim = dim @@ -73,6 +123,19 @@ def forward(self, ids: Tensor) -> Tensor: class MLPEmbedder(nn.Module): + r""" + A simple 2-layer MLP used for embedding inputs. + + Parameters: + in_dim (`int`): + Dimensionality of the input features. + hidden_dim (`int`): + Dimensionality of the hidden and output embedding space. + + Returns: + `torch.Tensor`: + Tensor of shape `(..., hidden_dim)` containing the embedded representations. + """ def __init__(self, in_dim: int, hidden_dim: int): super().__init__() self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) @@ -84,6 +147,19 @@ def forward(self, x: Tensor) -> Tensor: class QKNorm(torch.nn.Module): + r""" + Applies RMS normalization to query and key tensors separately before attention + which can help stabilize training and improve numerical precision. + + Parameters: + dim (`int`): + Dimensionality of the query and key vectors. + + Returns: + (`torch.Tensor`, `torch.Tensor`): + A tuple `(q, k)` where both are normalized and cast to the same dtype + as the value tensor `v`. + """ def __init__(self, dim: int): super().__init__() self.query_norm = RMSNorm(dim, eps=1e-6) @@ -103,6 +179,22 @@ class ModulationOut: class Modulation(nn.Module): + r""" + Modulation network that generates scale, shift, and gating parameters. + + Given an input vector, the module projects it through a linear layer to + produce six chunks, which are grouped into two `ModulationOut` objects. + + Parameters: + dim (`int`): + Dimensionality of the input vector. The output will have `6 * dim` + features internally. + + Returns: + (`ModulationOut`, `ModulationOut`): + A tuple of two modulation outputs. Each `ModulationOut` contains + three components (e.g., scale, shift, gate). 
+ """ def __init__(self, dim: int): super().__init__() self.lin = nn.Linear(dim, 6 * dim, bias=True) @@ -115,6 +207,68 @@ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]: class MirageBlock(nn.Module): + r""" + Multimodal transformer block with text–image cross-attention, modulation, and MLP. + + Parameters: + hidden_size (`int`): + Dimension of the hidden representations. + num_heads (`int`): + Number of attention heads. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Expansion ratio for the hidden dimension inside the MLP. + qk_scale (`float`, *optional*): + Scale factor for queries and keys. If not provided, defaults to + ``head_dim**-0.5``. + + Attributes: + img_pre_norm (`nn.LayerNorm`): + Pre-normalization applied to image tokens before QKV projection. + img_qkv_proj (`nn.Linear`): + Linear projection to produce image queries, keys, and values. + qk_norm (`QKNorm`): + RMS normalization applied separately to image queries and keys. + txt_kv_proj (`nn.Linear`): + Linear projection to produce text keys and values. + k_norm (`RMSNorm`): + RMS normalization applied to text keys. + attention (`Attention`): + Multi-head attention module for cross-attention between image, text, + and optional spatial conditioning tokens. + post_attention_layernorm (`nn.LayerNorm`): + Normalization applied after attention. + gate_proj / up_proj / down_proj (`nn.Linear`): + Feedforward layers forming the gated MLP. + mlp_act (`nn.GELU`): + Nonlinear activation used in the MLP. + modulation (`Modulation`): + Produces scale/shift/gating parameters for modulated layers. + spatial_cond_kv_proj (`nn.Linear`, *optional*): + Projection for optional spatial conditioning tokens. + + Methods: + attn_forward(img, txt, pe, modulation, spatial_conditioning=None, attention_mask=None): + Compute cross-attention between image and text tokens, with optional + spatial conditioning and attention masking. 
+ + Parameters: + img (`torch.Tensor`): + Image tokens of shape `(B, L_img, hidden_size)`. + txt (`torch.Tensor`): + Text tokens of shape `(B, L_txt, hidden_size)`. + pe (`torch.Tensor`): + Rotary positional embeddings to apply to queries and keys. + modulation (`ModulationOut`): + Scale and shift parameters for modulating image tokens. + spatial_conditioning (`torch.Tensor`, *optional*): + Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. + attention_mask (`torch.Tensor`, *optional*): + Boolean mask of shape `(B, L_txt)` where 0 marks padding. + + Returns: + `torch.Tensor`: + Attention output of shape `(B, L_img, hidden_size)`. + """ def __init__( self, hidden_size: int, @@ -163,7 +317,7 @@ def __init__( self.modulation = Modulation(hidden_size) self.spatial_cond_kv_proj: None | nn.Linear = None - def attn_forward( + def _attn_forward( self, img: Tensor, txt: Tensor, @@ -236,7 +390,7 @@ def attn_forward( return attn - def ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor: + def _ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor: x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)) @@ -250,9 +404,36 @@ def forward( attention_mask: Tensor | None = None, **_: dict[str, Any], ) -> Tensor: + r""" + Runs modulation-gated cross-attention and MLP, with residual connections. + + Parameters: + img (`torch.Tensor`): + Image tokens of shape `(B, L_img, hidden_size)`. + txt (`torch.Tensor`): + Text tokens of shape `(B, L_txt, hidden_size)`. + vec (`torch.Tensor`): + Conditioning vector used by `Modulation` to produce scale/shift/gates, + shape `(B, hidden_size)` (or broadcastable). + pe (`torch.Tensor`): + Rotary positional embeddings applied inside attention. + spatial_conditioning (`torch.Tensor`, *optional*): + Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. 
Used only + if spatial conditioning is enabled in the block. + attention_mask (`torch.Tensor`, *optional*): + Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding. + **_: + Ignored additional keyword arguments for API compatibility. + + Returns: + `torch.Tensor`: + Updated image tokens of shape `(B, L_img, hidden_size)`. + """ + + mod_attn, mod_mlp = self.modulation(vec) - img = img + mod_attn.gate * self.attn_forward( + img = img + mod_attn.gate * self._attn_forward( img, txt, pe, @@ -260,12 +441,39 @@ def forward( spatial_conditioning=spatial_conditioning, attention_mask=attention_mask, ) - img = img + mod_mlp.gate * self.ffn_forward(img, mod_mlp) + img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp) return img class LastLayer(nn.Module): + r""" + Final projection layer with adaptive LayerNorm modulation. + + This layer applies a normalized and modulated transformation to input tokens + and projects them into patch-level outputs. + + Parameters: + hidden_size (`int`): + Dimensionality of the input tokens. + patch_size (`int`): + Size of the square image patches. + out_channels (`int`): + Number of output channels per pixel (e.g. RGB = 3). + + Forward Inputs: + x (`torch.Tensor`): + Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches. + vec (`torch.Tensor`): + Conditioning vector of shape `(B, hidden_size)` used to generate + shift and scale parameters for adaptive LayerNorm. + + Returns: + `torch.Tensor`: + Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`. 
+ """ + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) @@ -284,12 +492,41 @@ def forward(self, x: Tensor, vec: Tensor) -> Tensor: def img2seq(img: Tensor, patch_size: int) -> Tensor: - """Flatten an image into a sequence of patches""" + r""" + Flattens an image tensor into a sequence of non-overlapping patches. + + Parameters: + img (`torch.Tensor`): + Input image tensor of shape `(B, C, H, W)`. + patch_size (`int`): + Size of each square patch. Must evenly divide both `H` and `W`. + + Returns: + `torch.Tensor`: + Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, + where `L = (H // patch_size) * (W // patch_size)` is the number of patches. + """ return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: - """Revert img2seq""" + r""" + Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`). + + Parameters: + seq (`torch.Tensor`): + Patch sequence of shape `(B, L, C * patch_size * patch_size)`, + where `L = (H // patch_size) * (W // patch_size)`. + patch_size (`int`): + Size of each square patch. + shape (`tuple` or `torch.Tensor`): + The original image spatial shape `(H, W)`. If a tensor is provided, + the first two values are interpreted as height and width. + + Returns: + `torch.Tensor`: + Reconstructed image tensor of shape `(B, C, H, W)`. + """ if isinstance(shape, tuple): shape = shape[-2:] elif isinstance(shape, torch.Tensor): @@ -300,7 +537,70 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: class MirageTransformer2DModel(ModelMixin, ConfigMixin): - """Mirage Transformer model with IP-Adapter support.""" + r""" + Transformer-based 2D model for text to image generation. 
+ It supports attention processor injection and LoRA scaling. + + Parameters: + in_channels (`int`, *optional*, defaults to 16): + Number of input channels in the latent image. + patch_size (`int`, *optional*, defaults to 2): + Size of the square patches used to flatten the input image. + context_in_dim (`int`, *optional*, defaults to 2304): + Dimensionality of the text conditioning input. + hidden_size (`int`, *optional*, defaults to 1792): + Dimension of the hidden representation. + mlp_ratio (`float`, *optional*, defaults to 3.5): + Expansion ratio for the hidden dimension inside MLP blocks. + num_heads (`int`, *optional*, defaults to 28): + Number of attention heads. + depth (`int`, *optional*, defaults to 16): + Number of transformer blocks. + axes_dim (`list[int]`, *optional*): + List of dimensions for each positional embedding axis. Defaults to `[32, 32]`. + theta (`int`, *optional*, defaults to 10000): + Frequency scaling factor for rotary embeddings. + time_factor (`float`, *optional*, defaults to 1000.0): + Scaling factor applied in timestep embeddings. + time_max_period (`int`, *optional*, defaults to 10000): + Maximum frequency period for timestep embeddings. + conditioning_block_ids (`list[int]`, *optional*): + Indices of blocks that receive conditioning. Defaults to all blocks. + **kwargs: + Additional keyword arguments forwarded to the config. + + Attributes: + pe_embedder (`EmbedND`): + Multi-axis rotary embedding generator for positional encodings. + img_in (`nn.Linear`): + Projection layer for image patch tokens. + time_in (`MLPEmbedder`): + Embedding layer for timestep embeddings. + txt_in (`nn.Linear`): + Projection layer for text conditioning. + blocks (`nn.ModuleList`): + Stack of transformer blocks (`MirageBlock`). + final_layer (`LastLayer`): + Projection layer mapping hidden tokens back to patch outputs. + + Methods: + attn_processors: + Returns a dictionary of all attention processors in the model. 
+ set_attn_processor(processor): + Replaces attention processors across all attention layers. + process_inputs(image_latent, txt): + Converts inputs into patch tokens, encodes text, and produces positional encodings. + compute_timestep_embedding(timestep, dtype): + Creates a timestep embedding of dimension 256, scaled and projected. + forward_transformers(image_latent, cross_attn_conditioning, timestep, time_embedding, attention_mask, **block_kwargs): + Runs the sequence of transformer blocks over image and text tokens. + forward(image_latent, timestep, cross_attn_conditioning, micro_conditioning, cross_attn_mask=None, attention_kwargs=None, return_dict=True): + Full forward pass from latent input to reconstructed output image. + + Returns: + `Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing: + - `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`. + """ config_name = "config.json" _supports_gradient_checkpointing = True @@ -424,8 +724,8 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: - """Timestep independent stuff""" + def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: + txt = self.txt_in(txt) img = img2seq(image_latent, self.patch_size) bs, _, h, w = image_latent.shape @@ -433,7 +733,7 @@ def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[T pe = self.pe_embedder(img_ids) return img, txt, pe - def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: + def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: return self.time_in( get_timestep_embedding( timesteps=timestep, @@ -444,7 +744,7 @@ def compute_timestep_embedding(self, 
timestep: Tensor, dtype: torch.dtype) -> Te ).to(dtype) ) - def forward_transformers( + def _forward_transformers( self, image_latent: Tensor, cross_attn_conditioning: Tensor, @@ -460,7 +760,7 @@ def forward_transformers( else: if timestep is None: raise ValueError("Please provide either a timestep or a timestep_embedding") - vec = self.compute_timestep_embedding(timestep, dtype=img.dtype) + vec = self._compute_timestep_embedding(timestep, dtype=img.dtype) for block in self.blocks: img = block(img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs) @@ -478,6 +778,35 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: + r""" + Forward pass of the MirageTransformer2DModel. + + The latent image is split into patch tokens, combined with text conditioning, + and processed through a stack of transformer blocks modulated by the timestep. + The output is reconstructed into the latent image space. + + Parameters: + image_latent (`torch.Tensor`): + Input latent image tensor of shape `(B, C, H, W)`. + timestep (`torch.Tensor`): + Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning. + cross_attn_conditioning (`torch.Tensor`): + Text conditioning tensor of shape `(B, L_txt, context_in_dim)`. + micro_conditioning (`torch.Tensor`): + Extra conditioning vector (currently unused, reserved for future use). + cross_attn_mask (`torch.Tensor`, *optional*): + Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence. + attention_kwargs (`dict`, *optional*): + Additional arguments passed to attention layers. If using the PEFT backend, + the key `"scale"` controls LoRA scaling (default: 1.0). + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `Transformer2DModelOutput` or a tuple. 
+ + Returns: + `Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple: + + - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`. + """ if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -491,8 +820,8 @@ def forward( logger.warning( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) - img_seq, txt, pe = self.process_inputs(image_latent, cross_attn_conditioning) - img_seq = self.forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask) + img_seq, txt, pe = self._process_inputs(image_latent, cross_attn_conditioning) + img_seq = self._forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask) output = seq2img(img_seq, self.patch_size, image_latent.shape) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_mirage.py index 0085627aa7e4..fe7436debc4c 100644 --- a/tests/models/transformers/test_models_transformer_mirage.py +++ b/tests/models/transformers/test_models_transformer_mirage.py @@ -132,7 +132,7 @@ def test_process_inputs(self): model.eval() with torch.no_grad(): - img_seq, txt, pe = model.process_inputs( + img_seq, txt, pe = model._process_inputs( inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"] ) @@ -160,12 +160,12 @@ def test_forward_transformers(self): with torch.no_grad(): # Process inputs first - img_seq, txt, pe = model.process_inputs( + img_seq, txt, pe = model._process_inputs( inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"] ) # Test forward_transformers - output_seq = model.forward_transformers(img_seq, txt, timestep=inputs_dict["timestep"], pe=pe) + output_seq = model._forward_transformers(img_seq, txt, timestep=inputs_dict["timestep"], pe=pe) # Check output shape expected_out_channels 
= init_dict["in_channels"] * init_dict["patch_size"] ** 2 From 394f725139a57e88f5e0b0d6e458db774606a7d5 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 21:28:43 +0000 Subject: [PATCH 12/38] if conditions and raised as ValueError instead of asserts --- .../models/transformers/transformer_mirage.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index c509f797fb8b..90ba11fb2d24 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -360,10 +360,12 @@ def _attn_forward( bs, _, l_img, _ = img_q.shape l_txt = txt_k.shape[2] - assert attention_mask.dim() == 2, f"Unsupported attention_mask shape: {attention_mask.shape}" - assert attention_mask.shape[-1] == l_txt, ( - f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" - ) + if attention_mask.dim() != 2: + raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") + if attention_mask.shape[-1] != l_txt: + raise ValueError( + f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" + ) device = img_q.device From 54fb0632d80e8516cc9c716144a609f226cde4e9 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 21:33:28 +0000 Subject: [PATCH 13/38] small fix --- src/diffusers/pipelines/mirage/pipeline_mirage.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py index 9d247eecbd7f..50304ae1a3ad 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -640,13 +640,13 @@ def __call__( t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device) # Process inputs for transformer - 
img_seq, txt, pe = self.transformer.process_inputs(latents_in, ca_embed) + img_seq, txt, pe = self.transformer._process_inputs(latents_in, ca_embed) # Forward through transformer layers - img_seq = self.transformer.forward_transformers( + img_seq = self.transformer._forward_transformers( img_seq, txt, - time_embedding=self.transformer.compute_timestep_embedding(t_cont, img_seq.dtype), + time_embedding=self.transformer._compute_timestep_embedding(t_cont, img_seq.dtype), pe=pe, attention_mask=ca_mask, ) From c49fafbaaba17d5a9470af42d11ea1a30389fcb2 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 21:34:47 +0000 Subject: [PATCH 14/38] nit remove try block at import --- src/diffusers/pipelines/mirage/pipeline_mirage.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py index 50304ae1a3ad..ced78adec786 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -31,7 +31,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderDC, AutoencoderKL -from ...models.transformers.transformer_mirage import seq2img +from ...models.transformers.transformer_mirage import MirageTransformer2DModel, seq2img from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( logging, @@ -42,11 +42,6 @@ from .pipeline_output import MiragePipelineOutput -try: - from ...models.transformers.transformer_mirage import MirageTransformer2DModel -except ImportError: - MirageTransformer2DModel = None - DEFAULT_HEIGHT = 512 DEFAULT_WIDTH = 512 From 7e7df3569204a62f27797b60b7db2e1716c29a3b Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 21:35:16 +0000 Subject: [PATCH 15/38] mirage pipeline doc --- docs/source/en/api/pipelines/mirage.md | 158 
+++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 docs/source/en/api/pipelines/mirage.md diff --git a/docs/source/en/api/pipelines/mirage.md b/docs/source/en/api/pipelines/mirage.md new file mode 100644 index 000000000000..3383bbecae2a --- /dev/null +++ b/docs/source/en/api/pipelines/mirage.md @@ -0,0 +1,158 @@ + + +# MiragePipeline + +
+ LoRA +
+ +Mirage is a text-to-image diffusion model using a transformer-based architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports both Flux VAE (AutoencoderKL) and DC-AE (AutoencoderDC) for latent compression. + +Key features: + +- **Transformer Architecture**: Uses a modern transformer-based denoising model with attention mechanisms optimized for image generation +- **Flow Matching**: Employs flow matching with Euler discrete scheduling for efficient sampling +- **Flexible VAE Support**: Compatible with both Flux VAE (8x compression, 16 latent channels) and DC-AE (32x compression, 32 latent channels) +- **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding with strong text-image alignment +- **Efficient Architecture**: ~1.3B parameters in the transformer, enabling fast inference while maintaining quality +- **Modular Design**: Text encoder and VAE weights are loaded from HuggingFace, keeping checkpoint sizes small + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +## Loading the Pipeline + +Mirage checkpoints only store the transformer and scheduler weights locally. 
The VAE and text encoder are automatically loaded from HuggingFace during pipeline initialization: + +```py +from diffusers import MiragePipeline + +# Load pipeline - VAE and text encoder will be loaded from HuggingFace +pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint") +pipe.to("cuda") + +prompt = "A digital painting of a rusty, vintage tram on a sandy beach" +image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] +image.save("mirage_output.png") +``` + +### Manual Component Loading + +You can also load components individually: + +```py +import torch +from diffusers import MiragePipeline +from diffusers.models import AutoencoderKL, AutoencoderDC +from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from transformers import T5GemmaModel, GemmaTokenizerFast + +# Load transformer +transformer = MirageTransformer2DModel.from_pretrained( + "path/to/checkpoint", subfolder="transformer" +) + +# Load scheduler +scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + "path/to/checkpoint", subfolder="scheduler" +) + +# Load T5Gemma text encoder +t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") +text_encoder = t5gemma_model.encoder +tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") + +# Load VAE - choose either Flux VAE or DC-AE +# Flux VAE (16 latent channels): +vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") +# Or DC-AE (32 latent channels): +# vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers") + +pipe = MiragePipeline( + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae +) +pipe.to("cuda") +``` + +## VAE Variants + +Mirage supports two VAE configurations: + +### Flux VAE (AutoencoderKL) +- **Compression**: 8x spatial compression +- **Latent channels**: 16 
+- **Model**: `black-forest-labs/FLUX.1-dev` (subfolder: "vae") +- **Use case**: Balanced quality and speed + +### DC-AE (AutoencoderDC) +- **Compression**: 32x spatial compression +- **Latent channels**: 32 +- **Model**: `mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers` +- **Use case**: Higher compression for faster processing + +The VAE type is automatically determined from the checkpoint's `model_index.json` configuration. + +## Generation Parameters + +Key parameters for image generation: + +- **num_inference_steps**: Number of denoising steps (default: 28). More steps generally improve quality at the cost of speed. +- **guidance_scale**: Classifier-free guidance strength (default: 4.0). Higher values produce images more closely aligned with the prompt. +- **height/width**: Output image dimensions (default: 512x512). Can be customized in the checkpoint configuration. + +```py +# Example with custom parameters +image = pipe( + prompt="A serene mountain landscape at sunset", + num_inference_steps=28, + guidance_scale=4.0, + height=1024, + width=1024, + generator=torch.Generator("cuda").manual_seed(42) +).images[0] +``` + +## Memory Optimization + +For memory-constrained environments: + +```py +import torch +from diffusers import MiragePipeline + +pipe = MiragePipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.float16) +pipe.enable_model_cpu_offload() # Offload components to CPU when not in use + +# Or use sequential CPU offload for even lower memory +pipe.enable_sequential_cpu_offload() +``` + +## MiragePipeline + +[[autodoc]] MiragePipeline + - all + - __call__ + +## MiragePipelineOutput + +[[autodoc]] pipelines.mirage.pipeline_output.MiragePipelineOutput From de03851e2f9cfd75b797a8d271798f9fa59fccd7 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 7 Oct 2025 14:20:17 +0000 Subject: [PATCH 16/38] update doc --- docs/source/en/api/pipelines/mirage.md | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git 
a/docs/source/en/api/pipelines/mirage.md b/docs/source/en/api/pipelines/mirage.md index 3383bbecae2a..f0117795a989 100644 --- a/docs/source/en/api/pipelines/mirage.md +++ b/docs/source/en/api/pipelines/mirage.md @@ -22,18 +22,12 @@ Mirage is a text-to-image diffusion model using a transformer-based architecture Key features: -- **Transformer Architecture**: Uses a modern transformer-based denoising model with attention mechanisms optimized for image generation -- **Flow Matching**: Employs flow matching with Euler discrete scheduling for efficient sampling +- **Simplified MMDIT architecture**: Uses a simplified MMDIT architecture for image generation where text tokens are not updated through the transformer blocks +- **Flow Matching**: Employs flow matching with discrete scheduling for efficient sampling - **Flexible VAE Support**: Compatible with both Flux VAE (8x compression, 16 latent channels) and DC-AE (32x compression, 32 latent channels) -- **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding with strong text-image alignment +- **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding offering multiple language support - **Efficient Architecture**: ~1.3B parameters in the transformer, enabling fast inference while maintaining quality -- **Modular Design**: Text encoder and VAE weights are loaded from HuggingFace, keeping checkpoint sizes small - - -Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. 
- - ## Loading the Pipeline @@ -46,7 +40,7 @@ from diffusers import MiragePipeline pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint") pipe.to("cuda") -prompt = "A digital painting of a rusty, vintage tram on a sandy beach" +prompt = "A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “Photon” in bright, sparkling light" image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] image.save("mirage_output.png") ``` @@ -123,11 +117,11 @@ Key parameters for image generation: ```py # Example with custom parameters image = pipe( - prompt="A serene mountain landscape at sunset", + prompt="A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “Photon” in bright, sparkling light", num_inference_steps=28, guidance_scale=4.0, - height=1024, - width=1024, + height=512, + width=512, generator=torch.Generator("cuda").manual_seed(42) ).images[0] ``` From a69aa4bb5bd6d4b5d5a0c6611e7c014df683b4a4 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 7 Oct 2025 14:25:27 +0000 Subject: [PATCH 17/38] rename model to photon --- .../en/api/pipelines/{mirage.md => photon.md} | 34 ++++++++-------- ...sers.py => convert_photon_to_diffusers.py} | 34 ++++++++-------- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/attention_processor.py | 20 +++++----- src/diffusers/models/transformers/__init__.py | 2 +- ...former_mirage.py => transformer_photon.py} | 14 +++---- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/mirage/__init__.py | 5 --- src/diffusers/pipelines/photon/__init__.py | 5 +++ .../{mirage => photon}/pipeline_output.py | 4 +- .../pipeline_photon.py} | 40 +++++++++---------- ...e.py => test_models_transformer_photon.py} | 14 +++---- 13 files changed, 89 insertions(+), 89 deletions(-) rename docs/source/en/api/pipelines/{mirage.md => photon.md} (86%) rename 
scripts/{convert_mirage_to_diffusers.py => convert_photon_to_diffusers.py} (92%) rename src/diffusers/models/transformers/{transformer_mirage.py => transformer_photon.py} (99%) delete mode 100644 src/diffusers/pipelines/mirage/__init__.py create mode 100644 src/diffusers/pipelines/photon/__init__.py rename src/diffusers/pipelines/{mirage => photon}/pipeline_output.py (93%) rename src/diffusers/pipelines/{mirage/pipeline_mirage.py => photon/pipeline_photon.py} (95%) rename tests/models/transformers/{test_models_transformer_mirage.py => test_models_transformer_photon.py} (95%) diff --git a/docs/source/en/api/pipelines/mirage.md b/docs/source/en/api/pipelines/photon.md similarity index 86% rename from docs/source/en/api/pipelines/mirage.md rename to docs/source/en/api/pipelines/photon.md index f0117795a989..f8f7098545f8 100644 --- a/docs/source/en/api/pipelines/mirage.md +++ b/docs/source/en/api/pipelines/photon.md @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. --> -# MiragePipeline +# PhotonPipeline
LoRA
-Mirage is a text-to-image diffusion model using a transformer-based architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports both Flux VAE (AutoencoderKL) and DC-AE (AutoencoderDC) for latent compression. +Photon is a text-to-image diffusion model using a transformer-based architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports both Flux VAE (AutoencoderKL) and DC-AE (AutoencoderDC) for latent compression. Key features: @@ -31,18 +31,18 @@ Key features: ## Loading the Pipeline -Mirage checkpoints only store the transformer and scheduler weights locally. The VAE and text encoder are automatically loaded from HuggingFace during pipeline initialization: +Photon checkpoints only store the transformer and scheduler weights locally. The VAE and text encoder are automatically loaded from HuggingFace during pipeline initialization: ```py -from diffusers import MiragePipeline +from diffusers import PhotonPipeline # Load pipeline - VAE and text encoder will be loaded from HuggingFace -pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint") +pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint") pipe.to("cuda") prompt = "A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “Photon” in bright, sparkling light" image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] -image.save("mirage_output.png") +image.save("photon_output.png") ``` ### Manual Component Loading @@ -51,14 +51,14 @@ You can also load components individually: ```py import torch -from diffusers import MiragePipeline +from diffusers import PhotonPipeline from diffusers.models import AutoencoderKL, AutoencoderDC -from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel +from diffusers.models.transformers.transformer_photon import 
PhotonTransformer2DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from transformers import T5GemmaModel, GemmaTokenizerFast # Load transformer -transformer = MirageTransformer2DModel.from_pretrained( +transformer = PhotonTransformer2DModel.from_pretrained( "path/to/checkpoint", subfolder="transformer" ) @@ -78,7 +78,7 @@ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="v # Or DC-AE (32 latent channels): # vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers") -pipe = MiragePipeline( +pipe = PhotonPipeline( transformer=transformer, scheduler=scheduler, text_encoder=text_encoder, @@ -90,7 +90,7 @@ pipe.to("cuda") ## VAE Variants -Mirage supports two VAE configurations: +Photon supports two VAE configurations: ### Flux VAE (AutoencoderKL) - **Compression**: 8x spatial compression @@ -132,21 +132,21 @@ For memory-constrained environments: ```py import torch -from diffusers import MiragePipeline +from diffusers import PhotonPipeline -pipe = MiragePipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.float16) +pipe = PhotonPipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.float16) pipe.enable_model_cpu_offload() # Offload components to CPU when not in use # Or use sequential CPU offload for even lower memory pipe.enable_sequential_cpu_offload() ``` -## MiragePipeline +## PhotonPipeline -[[autodoc]] MiragePipeline +[[autodoc]] PhotonPipeline - all - __call__ -## MiragePipelineOutput +## PhotonPipelineOutput -[[autodoc]] pipelines.mirage.pipeline_output.MiragePipelineOutput +[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_photon_to_diffusers.py similarity index 92% rename from scripts/convert_mirage_to_diffusers.py rename to scripts/convert_photon_to_diffusers.py index 37de253d1448..ad04463e019f 100644 --- a/scripts/convert_mirage_to_diffusers.py +++ 
b/scripts/convert_photon_to_diffusers.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -Script to convert Mirage checkpoint from original codebase to diffusers format. +Script to convert Photon checkpoint from original codebase to diffusers format. """ import argparse @@ -16,14 +16,14 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) -from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel -from diffusers.pipelines.mirage import MiragePipeline +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.pipelines.photon import PhotonPipeline DEFAULT_HEIGHT = 512 DEFAULT_WIDTH = 512 @dataclass(frozen=True) -class MirageBase: +class PhotonBase: context_in_dim: int = 2304 hidden_size: int = 1792 mlp_ratio: float = 3.5 @@ -36,22 +36,22 @@ class MirageBase: @dataclass(frozen=True) -class MirageFlux(MirageBase): +class PhotonFlux(PhotonBase): in_channels: int = 16 patch_size: int = 2 @dataclass(frozen=True) -class MirageDCAE(MirageBase): +class PhotonDCAE(PhotonBase): in_channels: int = 32 patch_size: int = 1 def build_config(vae_type: str) -> dict: if vae_type == "flux": - cfg = MirageFlux() + cfg = PhotonFlux() elif vae_type == "dc-ae": - cfg = MirageDCAE() + cfg = PhotonDCAE() else: raise ValueError(f"Unsupported VAE type: {vae_type}. 
Use 'flux' or 'dc-ae'") @@ -125,8 +125,8 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth return converted_state_dict -def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> MirageTransformer2DModel: - """Create and load MirageTransformer2DModel from old checkpoint.""" +def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel: + """Create and load PhotonTransformer2DModel from old checkpoint.""" print(f"Loading checkpoint from: {checkpoint_path}") @@ -154,8 +154,8 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth) # Create transformer with config - print("Creating MirageTransformer2DModel...") - transformer = MirageTransformer2DModel(**config) + print("Creating PhotonTransformer2DModel...") + transformer = PhotonTransformer2DModel(**config) # Load state dict print("Loading converted parameters...") @@ -212,13 +212,13 @@ def create_model_index(vae_type: str, output_path: str): text_model_name = "google/t5gemma-2b-2b-ul2" model_index = { - "_class_name": "MiragePipeline", + "_class_name": "PhotonPipeline", "_diffusers_version": "0.31.0.dev0", "_name_or_path": os.path.basename(output_path), "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], "text_encoder": text_model_name, "tokenizer": text_model_name, - "transformer": ["diffusers", "MirageTransformer2DModel"], + "transformer": ["diffusers", "PhotonTransformer2DModel"], "vae": vae_model_name, "vae_subfolder": vae_subfolder, "default_height": default_height, @@ -262,7 +262,7 @@ def main(args): # Verify the pipeline can be loaded try: - pipeline = MiragePipeline.from_pretrained(args.output_path) + pipeline = PhotonPipeline.from_pretrained(args.output_path) print("Pipeline loaded successfully!") print(f"Transformer: {type(pipeline.transformer).__name__}") print(f"VAE: 
{type(pipeline.vae).__name__}") @@ -285,10 +285,10 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Convert Mirage checkpoint to diffusers format") + parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format") parser.add_argument( - "--checkpoint_path", type=str, required=True, help="Path to the original Mirage checkpoint (.pth file)" + "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)" ) parser.add_argument( diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6fc6ac5f3ebd..13b0ac8d64b0 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -224,7 +224,7 @@ "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", - "MirageTransformer2DModel", + "PhotonTransformer2DModel", "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 279e69216b1b..86e32c1eec3e 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -93,7 +93,7 @@ _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] - _import_structure["transformers.transformer_mirage"] = ["MirageTransformer2DModel"] + _import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 
08e80e4329ba..23ec72a8f657 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5609,15 +5609,15 @@ def __new__(cls, *args, **kwargs): return processor -class MirageAttnProcessor2_0: +class PhotonAttnProcessor2_0: r""" - Processor for implementing Mirage-style attention with multi-source tokens and RoPE. - Properly integrates with diffusers Attention module while handling Mirage-specific logic. + Processor for implementing Photon-style attention with multi-source tokens and RoPE. + Properly integrates with diffusers Attention module while handling Photon-specific logic. """ def __init__(self): if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): - raise ImportError("MirageAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") + raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") def __call__( self, @@ -5629,9 +5629,9 @@ def __call__( **kwargs, ) -> torch.Tensor: """ - Apply Mirage attention using standard diffusers interface. + Apply Photon attention using standard diffusers interface. - Expected tensor formats from MirageBlock.attn_forward(): + Expected tensor formats from PhotonBlock.attn_forward(): - hidden_states: Image queries with RoPE applied [B, H, L_img, D] - encoder_hidden_states: Packed key+value tensors [B, H, L_all, 2*D] (concatenated keys and values from text + image + spatial conditioning) @@ -5640,15 +5640,15 @@ def __call__( if encoder_hidden_states is None: raise ValueError( - "MirageAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. " - "This should be provided by MirageBlock.attn_forward()." + "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. " + "This should be provided by PhotonBlock.attn_forward()." 
) # Unpack the combined key+value tensor # encoder_hidden_states is [B, H, L_all, 2*D] containing [keys, values] key, value = encoder_hidden_states.chunk(2, dim=-1) # Each [B, H, L_all, D] - # Apply scaled dot-product attention with Mirage's processed tensors + # Apply scaled dot-product attention with Photon's processed tensors # hidden_states is image queries [B, H, L_img, D] attn_output = torch.nn.functional.scaled_dot_product_attention( hidden_states.contiguous(), key.contiguous(), value.contiguous(), attn_mask=attention_mask @@ -5714,7 +5714,7 @@ def __call__( PAGHunyuanAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, LuminaAttnProcessor2_0, - MirageAttnProcessor2_0, + PhotonAttnProcessor2_0, FusedAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index ebe0d0c9b8e1..652f6d811393 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -29,7 +29,7 @@ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel - from .transformer_mirage import MirageTransformer2DModel + from .transformer_photon import PhotonTransformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_photon.py similarity index 99% rename from src/diffusers/models/transformers/transformer_mirage.py rename to src/diffusers/models/transformers/transformer_photon.py index 90ba11fb2d24..9ec6e9756c20 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ 
b/src/diffusers/models/transformers/transformer_photon.py @@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 +from ..attention_processor import Attention, AttentionProcessor, PhotonAttnProcessor2_0 from ..embeddings import get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -206,7 +206,7 @@ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]: return ModulationOut(*out[:3]), ModulationOut(*out[3:]) -class MirageBlock(nn.Module): +class PhotonBlock(nn.Module): r""" Multimodal transformer block with text–image cross-attention, modulation, and MLP. @@ -304,7 +304,7 @@ def __init__( dim_head=self.head_dim, bias=False, out_bias=False, - processor=MirageAttnProcessor2_0(), + processor=PhotonAttnProcessor2_0(), ) # mlp @@ -538,7 +538,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) -class MirageTransformer2DModel(ModelMixin, ConfigMixin): +class PhotonTransformer2DModel(ModelMixin, ConfigMixin): r""" Transformer-based 2D model for text to image generation. It supports attention processor injection and LoRA scaling. @@ -581,7 +581,7 @@ class MirageTransformer2DModel(ModelMixin, ConfigMixin): txt_in (`nn.Linear`): Projection layer for text conditioning. blocks (`nn.ModuleList`): - Stack of transformer blocks (`MirageBlock`). + Stack of transformer blocks (`PhotonBlock`). final_layer (`LastLayer`): Projection layer mapping hidden tokens back to patch outputs. 
@@ -656,7 +656,7 @@ def __init__( self.blocks = nn.ModuleList( [ - MirageBlock( + PhotonBlock( self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, @@ -781,7 +781,7 @@ def forward( return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: r""" - Forward pass of the MirageTransformer2DModel. + Forward pass of the PhotonTransformer2DModel. The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of transformer blocks modulated by the timestep. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7b7ebb633c3b..ae0d90c48c63 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -144,7 +144,7 @@ "FluxKontextPipeline", "FluxKontextInpaintPipeline", ] - _import_structure["mirage"] = ["MiragePipeline"] + _import_structure["photon"] = ["PhotonPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", diff --git a/src/diffusers/pipelines/mirage/__init__.py b/src/diffusers/pipelines/mirage/__init__.py deleted file mode 100644 index cba951057370..000000000000 --- a/src/diffusers/pipelines/mirage/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .pipeline_mirage import MiragePipeline -from .pipeline_output import MiragePipelineOutput - - -__all__ = ["MiragePipeline", "MiragePipelineOutput"] diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py new file mode 100644 index 000000000000..d1dd5b2cbf53 --- /dev/null +++ b/src/diffusers/pipelines/photon/__init__.py @@ -0,0 +1,5 @@ +from .pipeline_photon import PhotonPipeline +from .pipeline_output import PhotonPipelineOutput + + +__all__ = ["PhotonPipeline", "PhotonPipelineOutput"] diff --git a/src/diffusers/pipelines/mirage/pipeline_output.py b/src/diffusers/pipelines/photon/pipeline_output.py similarity index 93% rename from 
src/diffusers/pipelines/mirage/pipeline_output.py rename to src/diffusers/pipelines/photon/pipeline_output.py index e41c8e3bea00..ca0674d94b6c 100644 --- a/src/diffusers/pipelines/mirage/pipeline_output.py +++ b/src/diffusers/pipelines/photon/pipeline_output.py @@ -22,9 +22,9 @@ @dataclass -class MiragePipelineOutput(BaseOutput): +class PhotonPipelineOutput(BaseOutput): """ - Output class for Mirage pipelines. + Output class for Photon pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/photon/pipeline_photon.py similarity index 95% rename from src/diffusers/pipelines/mirage/pipeline_mirage.py rename to src/diffusers/pipelines/photon/pipeline_photon.py index ced78adec786..ce3479fedcdd 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -31,7 +31,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderDC, AutoencoderKL -from ...models.transformers.transformer_mirage import MirageTransformer2DModel, seq2img +from ...models.transformers.transformer_photon import PhotonTransformer2DModel, seq2img from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( logging, @@ -39,7 +39,7 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import MiragePipelineOutput +from .pipeline_output import PhotonPipelineOutput DEFAULT_HEIGHT = 512 @@ -49,7 +49,7 @@ class TextPreprocessor: - """Text preprocessing utility for MiragePipeline.""" + """Text preprocessing utility for PhotonPipeline.""" def __init__(self): """Initialize text preprocessor.""" @@ -179,15 +179,15 @@ def clean_text(self, text: str) -> str: Examples: ```py >>> import torch - >>> from diffusers import MiragePipeline + >>> from diffusers 
import PhotonPipeline >>> from diffusers.models import AutoencoderKL, AutoencoderDC >>> from transformers import T5GemmaModel, GemmaTokenizerFast >>> # Load pipeline directly with from_pretrained - >>> pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint") + >>> pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint") >>> # Or initialize pipeline components manually - >>> transformer = MirageTransformer2DModel.from_pretrained("path/to/transformer") + >>> transformer = PhotonTransformer2DModel.from_pretrained("path/to/transformer") >>> scheduler = FlowMatchEulerDiscreteScheduler() >>> # Load T5Gemma encoder >>> t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") @@ -195,7 +195,7 @@ def clean_text(self, text: str) -> str: >>> tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") >>> vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") - >>> pipe = MiragePipeline( + >>> pipe = PhotonPipeline( ... transformer=transformer, ... scheduler=scheduler, ... text_encoder=text_encoder, @@ -205,26 +205,26 @@ def clean_text(self, text: str) -> str: >>> pipe.to("cuda") >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach" >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] - >>> image.save("mirage_output.png") + >>> image.save("photon_output.png") ``` """ -class MiragePipeline( +class PhotonPipeline( DiffusionPipeline, LoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, ): r""" - Pipeline for text-to-image generation using Mirage Transformer. + Pipeline for text-to-image generation using Photon Transformer. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
Args: - transformer ([`MirageTransformer2DModel`]): - The Mirage transformer model to denoise the encoded image latents. + transformer ([`PhotonTransformer2DModel`]): + The Photon transformer model to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. text_encoder ([`T5EncoderModel`]): @@ -248,7 +248,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P """ Override from_pretrained to load VAE and text encoder from HuggingFace models. - The MiragePipeline checkpoints only store transformer and scheduler locally. + The PhotonPipeline checkpoints only store transformer and scheduler locally. VAE and text encoder are loaded from external HuggingFace models as specified in model_index.json. """ @@ -285,7 +285,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Load transformer and scheduler from local checkpoint logger.info(f"Loading transformer from {pretrained_model_name_or_path}...") - transformer = MirageTransformer2DModel.from_pretrained( + transformer = PhotonTransformer2DModel.from_pretrained( pretrained_model_name_or_path, subfolder="transformer" ) @@ -310,7 +310,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P def __init__( self, - transformer: MirageTransformer2DModel, + transformer: PhotonTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, text_encoder: Union[T5EncoderModel, Any], tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer], @@ -318,9 +318,9 @@ def __init__( ): super().__init__() - if MirageTransformer2DModel is None: + if PhotonTransformer2DModel is None: raise ImportError( - "MirageTransformer2DModel is not available. Please ensure the transformer_mirage module is properly installed." + "PhotonTransformer2DModel is not available. 
Please ensure the transformer_photon module is properly installed." ) self.text_encoder = text_encoder @@ -544,7 +544,7 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.mirage.MiragePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`. @@ -557,7 +557,7 @@ def __call__( Examples: Returns: - [`~pipelines.mirage.MiragePipelineOutput`] or `tuple`: [`~pipelines.mirage.MiragePipelineOutput`] if + [`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
""" @@ -683,4 +683,4 @@ def __call__( if not return_dict: return (image,) - return MiragePipelineOutput(images=image) + return PhotonPipelineOutput(images=image) diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_photon.py similarity index 95% rename from tests/models/transformers/test_models_transformer_mirage.py rename to tests/models/transformers/test_models_transformer_photon.py index fe7436debc4c..2f08484d230c 100644 --- a/tests/models/transformers/test_models_transformer_mirage.py +++ b/tests/models/transformers/test_models_transformer_photon.py @@ -17,7 +17,7 @@ import torch -from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin @@ -26,8 +26,8 @@ enable_full_determinism() -class MirageTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = MirageTransformer2DModel +class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = PhotonTransformer2DModel main_input_name = "image_latent" @property @@ -92,7 +92,7 @@ def test_forward_signature(self): def test_model_initialization(self): # Test model initialization - model = MirageTransformer2DModel( + model = PhotonTransformer2DModel( in_channels=16, patch_size=2, context_in_dim=1792, @@ -121,7 +121,7 @@ def test_model_with_dict_config(self): "theta": 10_000, } - model = MirageTransformer2DModel.from_config(config_dict) + model = PhotonTransformer2DModel.from_config(config_dict) self.assertEqual(model.config.in_channels, 16) self.assertEqual(model.config.hidden_size, 1792) @@ -193,7 +193,7 @@ def test_attention_mask(self): def test_invalid_config(self): # Test invalid configuration - hidden_size not divisible by num_heads with self.assertRaises(ValueError): - 
MirageTransformer2DModel( + PhotonTransformer2DModel( in_channels=16, patch_size=2, context_in_dim=1792, @@ -207,7 +207,7 @@ def test_invalid_config(self): # Test invalid axes_dim that doesn't sum to pe_dim with self.assertRaises(ValueError): - MirageTransformer2DModel( + PhotonTransformer2DModel( in_channels=16, patch_size=2, context_in_dim=1792, From 9e099a7b45a2c39c3c9fea1eff1546e03702dab0 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 10:20:19 +0200 Subject: [PATCH 18/38] mirage pipeline first commit --- src/diffusers/__init__.py | 1 + src/diffusers/models/__init__.py | 1 + src/diffusers/models/transformers/__init__.py | 1 + .../models/transformers/transformer_mirage.py | 489 ++++++++++++++ src/diffusers/pipelines/__init__.py | 1 + src/diffusers/pipelines/mirage/__init__.py | 4 + .../pipelines/mirage/pipeline_mirage.py | 629 ++++++++++++++++++ .../pipelines/mirage/pipeline_output.py | 35 + .../test_models_transformer_mirage.py | 252 +++++++ 9 files changed, 1413 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_mirage.py create mode 100644 src/diffusers/pipelines/mirage/__init__.py create mode 100644 src/diffusers/pipelines/mirage/pipeline_mirage.py create mode 100644 src/diffusers/pipelines/mirage/pipeline_output.py create mode 100644 tests/models/transformers/test_models_transformer_mirage.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 686e8d99dabf..6c419b6e7ad1 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -224,6 +224,7 @@ "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", + "MirageTransformer2DModel", "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 457f70448af3..279e69216b1b 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -93,6 +93,7 @@ 
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] + _import_structure["transformers.transformer_mirage"] = ["MirageTransformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6dc..ebe0d0c9b8e1 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -29,6 +29,7 @@ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel + from .transformer_mirage import MirageTransformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py new file mode 100644 index 000000000000..39c569cbb26b --- /dev/null +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -0,0 +1,489 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union, Tuple +import torch +import math +from torch import Tensor, nn +from torch.nn.functional import fold, unfold +from einops import rearrange +from einops.layers.torch import Rearrange + +from ...configuration_utils import ConfigMixin, register_to_config +from ..modeling_utils import ModelMixin +from ..modeling_outputs import Transformer2DModelOutput +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers + + +logger = logging.get_logger(__name__) + + +# Mirage Layer Components +def get_image_ids(bs: int, h: int, w: int, patch_size: int, device: torch.device) -> Tensor: + img_ids = torch.zeros(h // patch_size, w // patch_size, 2, device=device) + img_ids[..., 0] = torch.arange(h // patch_size, device=device)[:, None] + img_ids[..., 1] = torch.arange(w // patch_size, device=device)[None, :] + return img_ids.reshape((h // patch_size) * (w // patch_size), 2).unsqueeze(0).repeat(bs, 1, 1) + + +def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq) + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: list[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + self.rope_rearrange = Rearrange("b n d (i j) -> b n d i j", i=2, j=2) + + def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim 
% 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = pos.unsqueeze(-1) * omega.unsqueeze(0) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + out = self.rope_rearrange(out) + return out.float() + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(1) + + +def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, time_factor: float = 1000.0) -> Tensor: + t = time_factor * t + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + x_dtype = x.dtype + x = x.float() + rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) + return (x * rrms * self.scale).to(dtype=x_dtype) + + +class QKNorm(torch.nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = RMSNorm(dim) + self.key_norm = RMSNorm(dim) + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: + q = self.query_norm(q) + k = self.key_norm(k) + 
@dataclass
class ModulationOut:
    # Adaptive-norm parameters for one branch (attention or MLP).
    # Each tensor has shape (B, 1, hidden_size) so it broadcasts over the sequence axis.
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    """Project a conditioning vector into two (shift, scale, gate) triples.

    The first triple modulates the attention branch of a `MirageBlock`, the second the
    MLP branch. The projection is zero-initialized so modulation starts as an identity
    (shift = scale = gate = 0), which makes early training stable.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, 6 * dim, bias=True)
        # Zero-init: the block initially contributes nothing through its gates.
        nn.init.constant_(self.lin.weight, 0)
        nn.init.constant_(self.lin.bias, 0)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
        # (B, dim) -> (B, 1, 6*dim) -> six (B, 1, dim) chunks broadcastable over tokens.
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
        return ModulationOut(*out[:3]), ModulationOut(*out[3:])


class MirageBlock(nn.Module):
    """Transformer block: image queries attend over [cond?, text, image] keys, then a gated MLP.

    Image tokens provide the queries; keys/values are the concatenation of optional
    spatial-conditioning tokens, text tokens and image tokens. Both sub-branches are
    modulated (AdaLN-style) by the outputs of `Modulation`.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()

        # Flags presumably consumed by external FSDP / activation-checkpointing
        # wrappers — TODO confirm against the training harness.
        self._fsdp_wrap = True
        self._activation_checkpointing = True

        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # NOTE(review): `scale` is never read below — scaled_dot_product_attention
        # applies its own 1/sqrt(head_dim) scaling.
        self.scale = qk_scale or self.head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.hidden_size = hidden_size

        # Image branch: fused QKV projection + per-head Q/K RMS normalization.
        self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_qkv_proj = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.attn_out = nn.Linear(hidden_size, hidden_size, bias=False)
        self.qk_norm = QKNorm(self.head_dim)

        # Text branch: keys/values only (text never queries).
        self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False)
        self.k_norm = RMSNorm(self.head_dim)

        # SwiGLU-style MLP: down(act(gate(x)) * up(x)).
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
        self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
        self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
        self.mlp_act = nn.GELU(approximate="tanh")

        self.modulation = Modulation(hidden_size)
        # Assigned externally when spatial conditioning is enabled; None disables that branch.
        self.spatial_cond_kv_proj: None | nn.Linear = None

    def attn_forward(
        self,
        img: Tensor,
        txt: Tensor,
        pe: Tensor,
        modulation: ModulationOut,
        spatial_conditioning: None | Tensor = None,
        attention_mask: None | Tensor = None,
    ) -> Tensor:
        """Cross/self attention: image queries over [cond?, text, image] keys.

        `attention_mask` is a 2D (B, L_txt) boolean/int mask over *text* keys only;
        conditioning and image keys are always attended to.
        """
        # Image tokens: AdaLN modulation, then fused QKV projection.
        img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift

        img_qkv = self.img_qkv_proj(img_mod)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.qk_norm(img_q, img_k, img_v)

        # Text tokens: KV projection with K normalization (values left untouched).
        txt_kv = self.txt_kv_proj(txt)
        txt_k, txt_v = rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
        txt_k = self.k_norm(txt_k)

        # RoPE on image queries/keys only; text keys carry no positional encoding.
        img_q, img_k = apply_rope(img_q, pe), apply_rope(img_k, pe)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        # Optional spatial conditioning tokens are prepended to the key/value sequence.
        cond_len = 0
        if self.spatial_cond_kv_proj is not None:
            assert spatial_conditioning is not None
            cond_kv = self.spatial_cond_kv_proj(spatial_conditioning)
            cond_k, cond_v = rearrange(cond_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
            # NOTE(review): image RoPE is applied to conditioning keys — assumes the
            # conditioning grid is aligned with the image grid; confirm upstream.
            cond_k = apply_rope(cond_k, pe)
            cond_len = cond_k.shape[2]
            k = torch.cat((cond_k, k), dim=2)
            v = torch.cat((cond_v, v), dim=2)

        # NOTE(review): attn_bias is declared but never populated or used below.
        attn_bias: Tensor | None = None
        attn_mask: Tensor | None = None

        # Expand the 2D text mask into a boolean (B, H, L_img, L_all) mask over
        # [cond?, text, image] keys; cond and image keys are always visible.
        if attention_mask is not None:
            bs, _, l_img, _ = img_q.shape
            l_txt = txt_k.shape[2]
            l_all = k.shape[2]

            assert attention_mask.dim() == 2, f"Unsupported attention_mask shape: {attention_mask.shape}"
            assert (
                attention_mask.shape[-1] == l_txt
            ), f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}"

            device = img_q.device

            ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
            cond_mask = torch.ones((bs, cond_len), dtype=torch.bool, device=device)

            mask_parts = [
                cond_mask,
                attention_mask.to(torch.bool),
                ones_img,
            ]
            joint_mask = torch.cat(mask_parts, dim=-1)  # (B, L_all)

            # repeat across heads and query positions
            attn_mask = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1)  # (B,H,L_img,L_all)

        attn = torch.nn.functional.scaled_dot_product_attention(
            img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask
        )
        attn = rearrange(attn, "B H L D -> B L (H D)")
        attn = self.attn_out(attn)

        return attn

    def ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor:
        """Gated (SwiGLU-style) MLP with AdaLN modulation on the input."""
        x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift
        return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        vec: Tensor,
        pe: Tensor,
        spatial_conditioning: Tensor | None = None,
        attention_mask: Tensor | None = None,
        **_: dict[str, Any],
    ) -> Tensor:
        """Residual block: gated attention update followed by gated MLP update."""
        mod_attn, mod_mlp = self.modulation(vec)

        img = img + mod_attn.gate * self.attn_forward(
            img,
            txt,
            pe,
            mod_attn,
            spatial_conditioning=spatial_conditioning,
            attention_mask=attention_mask,
        )
        img = img + mod_mlp.gate * self.ffn_forward(img, mod_mlp)
        return img


class LastLayer(nn.Module):
    """Final AdaLN-modulated projection from hidden states to patch pixels/latents.

    Zero-initialized (both the modulation and the output projection) so the model
    initially predicts zeros.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

        nn.init.constant_(self.adaLN_modulation[1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[1].bias, 0)
        nn.init.constant_(self.linear.weight, 0)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        # shift/scale come from the timestep/conditioning vector `vec`.
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
@dataclass
class MirageParams:
    """Static architecture configuration for `MirageTransformer2DModel`."""

    in_channels: int
    patch_size: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    axes_dim: list[int]
    theta: int
    time_factor: float = 1000.0
    time_max_period: int = 10_000
    conditioning_block_ids: list[int] | None = None


def img2seq(img: Tensor, patch_size: int) -> Tensor:
    """Flatten an image into a sequence of non-overlapping patches.

    (B, C, H, W) -> (B, L, C * p * p) with L = (H // p) * (W // p).
    """
    return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)


def seq2img(seq: Tensor, patch_size: int, shape: Union[tuple, Tensor]) -> Tensor:
    """Invert `img2seq`, reassembling patches into a (B, C, H, W) image.

    Args:
        seq: Patch sequence of shape (B, L, C * p * p).
        patch_size: Side length p of the square patches.
        shape: Either a tuple/`torch.Size` whose last two entries are (H, W)
            (this is what `forward` passes via `image_latent.shape`), or a
            length-2 tensor holding (H, W).
    """
    if isinstance(shape, tuple):
        # torch.Size is a tuple subclass, so latent `.shape` objects land here.
        shape = shape[-2:]
    elif isinstance(shape, torch.Tensor):
        # Assumes a length-2 tensor (H, W) — TODO confirm no caller passes more dims.
        shape = (int(shape[0]), int(shape[1]))
    else:
        raise NotImplementedError(f"shape type {type(shape)} not supported")
    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)


class MirageTransformer2DModel(ModelMixin, ConfigMixin):
    """Mirage 2D diffusion transformer.

    Denoises image latents conditioned on text embeddings (cross-attention) and a
    timestep embedding that drives AdaLN modulation in every `MirageBlock`.
    """

    config_name = "config.json"
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        patch_size: int = 2,
        context_in_dim: int = 2304,
        hidden_size: int = 1792,
        mlp_ratio: float = 3.5,
        num_heads: int = 28,
        depth: int = 16,
        axes_dim: Optional[list] = None,
        theta: int = 10000,
        time_factor: float = 1000.0,
        time_max_period: int = 10000,
        conditioning_block_ids: Optional[list] = None,
        **kwargs,
    ):
        super().__init__()

        if axes_dim is None:
            axes_dim = [32, 32]

        # Bundle the config into MirageParams for internal use.
        params = MirageParams(
            in_channels=in_channels,
            patch_size=patch_size,
            context_in_dim=context_in_dim,
            hidden_size=hidden_size,
            mlp_ratio=mlp_ratio,
            num_heads=num_heads,
            depth=depth,
            axes_dim=axes_dim,
            theta=theta,
            time_factor=time_factor,
            time_max_period=time_max_period,
            conditioning_block_ids=conditioning_block_ids,
        )

        self.params = params
        self.in_channels = params.in_channels
        self.patch_size = params.patch_size
        # Each output token carries a full patch of channels.
        self.out_channels = self.in_channels * self.patch_size**2

        self.time_factor = params.time_factor
        self.time_max_period = params.time_max_period

        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")

        pe_dim = params.hidden_size // params.num_heads

        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")

        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        # NOTE: `params.conditioning_block_ids` is kept in the config for checkpoints
        # that use spatial conditioning, but block construction does not consume it here
        # (a previous dead local computing it has been removed).
        self.blocks = nn.ModuleList(
            [
                MirageBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                )
                for _ in range(params.depth)
            ]
        )

        # patch_size=1 here because out_channels already includes the patch_size**2 factor.
        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
        """Timestep-independent preprocessing: project text, patchify image, build RoPE.

        Returns (img_seq, txt_proj, pe) where img_seq is (B, L, C*p*p).
        """
        txt = self.txt_in(txt)
        img = img2seq(image_latent, self.patch_size)
        bs, _, h, w = image_latent.shape
        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
        pe = self.pe_embedder(img_ids)
        return img, txt, pe

    def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
        """Sinusoidal timestep features (dim 256) projected to hidden_size via `time_in`."""
        return self.time_in(
            timestep_embedding(
                t=timestep, dim=256, max_period=self.time_max_period, time_factor=self.time_factor
            ).to(dtype)
        )

    def forward_transformers(
        self,
        image_latent: Tensor,
        cross_attn_conditioning: Tensor,
        timestep: Optional[Tensor] = None,
        time_embedding: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        **block_kwargs: Any,
    ) -> Tensor:
        """Run the block stack on an already-patchified latent sequence.

        Exactly one of `timestep` / `time_embedding` must be provided; a precomputed
        `time_embedding` lets callers amortize it across CFG branches.
        """
        img = self.img_in(image_latent)

        if time_embedding is not None:
            vec = time_embedding
        else:
            if timestep is None:
                raise ValueError("Please provide either a timestep or a timestep_embedding")
            vec = self.compute_timestep_embedding(timestep, dtype=img.dtype)

        for block in self.blocks:
            img = block(
                img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs
            )

        img = self.final_layer(img, vec)
        return img

    def forward(
        self,
        image_latent: Tensor,
        timestep: Tensor,
        cross_attn_conditioning: Tensor,
        micro_conditioning: Tensor,
        cross_attn_mask: None | Tensor = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
        """Full denoising forward pass.

        Args:
            image_latent: Noisy latent of shape (B, C, H, W).
            timestep: Continuous timestep(s), scaled by `time_factor` internally.
            cross_attn_conditioning: Text encoder states (B, L_txt, context_in_dim).
            micro_conditioning: Accepted for API compatibility but currently unused
                by this implementation.
            cross_attn_mask: Optional (B, L_txt) mask over text tokens.
            attention_kwargs: Optional dict; `scale` is consumed for PEFT/LoRA scaling.
            return_dict: Return a `Transformer2DModelOutput` instead of a tuple.
        """
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0
        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )
        img_seq, txt, pe = self.process_inputs(image_latent, cross_attn_conditioning)
        img_seq = self.forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask)
        output = seq2img(img_seq, self.patch_size, image_latent.shape)
        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
class TextPreprocessor:
    """Prompt-cleaning utility for `MiragePipeline`.

    A port of the DeepFloyd IF caption cleaning pipeline: strips URLs, handles,
    CJK ranges, HTML entities, spam patterns and stray punctuation before
    tokenization. Several regex literals here restore the original DeepFloyd
    patterns that had been corrupted by HTML-entity unescaping
    (`&quot;`, `&amp`, `<person>`, curly quotes).
    """

    def __init__(self):
        """Compile the punctuation-run pattern used by `clean_text`."""
        self.bad_punct_regex = re.compile(
            r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + r"\\" + r"\/" + r"\*" + r"]{1,}"
        )

    def clean_text(self, text: str) -> str:
        """Clean a caption/prompt string; returns the normalized, stripped text."""
        # See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py
        text = str(text)
        text = ul.unquote_plus(text)
        text = text.strip().lower()
        # Literal "<person>" placeholder -> "person". (Previously the pattern was the
        # empty string, which inserted "person" between every pair of characters.)
        text = re.sub("<person>", "person", text)

        # Remove all urls:
        text = re.sub(
            r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))",
            "",
            text,
        )  # regex for urls

        # @handles
        text = re.sub(r"@[\w\d]+\b", "", text)

        # 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs
        text = re.sub(r"[\u31c0-\u31ef]+", "", text)
        text = re.sub(r"[\u31f0-\u31ff]+", "", text)
        text = re.sub(r"[\u3200-\u32ff]+", "", text)
        text = re.sub(r"[\u3300-\u33ff]+", "", text)
        text = re.sub(r"[\u3400-\u4dbf]+", "", text)
        text = re.sub(r"[\u4dc0-\u4dff]+", "", text)
        text = re.sub(r"[\u4e00-\u9fff]+", "", text)

        # all types of dash --> "-"
        text = re.sub(
            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
            "-",
            text,
        )

        # Normalize quotation marks (restored curly-quote classes that the
        # HTML round-trip had collapsed into plain double quotes).
        text = re.sub(r"[`´«»\u201c\u201d¨]", '"', text)
        text = re.sub(r"[\u2018\u2019]", "'", text)

        # HTML entities for " and & (restored from the DeepFloyd original).
        text = re.sub(r"&quot;?", "", text)
        text = re.sub(r"&amp", "", text)

        # ip addresses:
        text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text)

        # article ids:
        text = re.sub(r"\d:\d\d\s+$", "", text)

        # literal "\n" sequences
        text = re.sub(r"\\n", " ", text)

        # "#123", "#12345..", "123456.."
        text = re.sub(r"#\d{1,3}\b", "", text)
        text = re.sub(r"#\d{5,}\b", "", text)
        text = re.sub(r"\b\d{6,}\b", "", text)

        # filenames:
        text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text)

        # Clean punctuation
        text = re.sub(r"[\"\']{2,}", r'"', text)  # """AUSVERKAUFT"""
        text = re.sub(r"[\.]{2,}", r" ", text)

        text = re.sub(self.bad_punct_regex, r" ", text)  # ***AUSVERKAUFT***, #AUSVERKAUFT
        text = re.sub(r"\s+\.\s+", r" ", text)  # " . "

        # this-is-my-cute-cat / this_is_my_cute_cat
        regex2 = re.compile(r"(?:\-|\_)")
        if len(re.findall(regex2, text)) > 3:
            text = re.sub(regex2, " ", text)

        # Basic cleaning
        text = ftfy.fix_text(text)
        text = html.unescape(html.unescape(text))
        text = text.strip()

        # Clean alphanumeric patterns
        text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text)  # jc6640
        text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text)  # jc6640vc
        text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text)  # 6640vc231

        # Common spam patterns
        text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text)
        text = re.sub(r"(free\s)?download(\sfree)?", "", text)
        text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text)
        text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text)
        text = re.sub(r"\bpage\s+\d+\b", "", text)

        text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text)  # j2d1a2a...
        text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text)

        # Final cleanup
        text = re.sub(r"\b\s+\:\s+", r": ", text)
        text = re.sub(r"(\D[,\./])\b", r"\1 ", text)
        text = re.sub(r"\s+", " ", text)

        # Fixed: the previous bare `text.strip()` discarded its result.
        text = text.strip()

        text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text)
        text = re.sub(r"^[\'\_,\-\:;]", r"", text)
        text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text)
        text = re.sub(r"^\.\S+$", "", text)

        return text.strip()
T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
    >>> text_encoder = t5gemma_model.encoder
    >>> tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
    >>> vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")

    >>> pipe = MiragePipeline(
    ...     transformer=transformer,
    ...     scheduler=scheduler,
    ...     text_encoder=text_encoder,
    ...     tokenizer=tokenizer,
    ...     vae=vae
    ... )
    >>> pipe.to("cuda")
    >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
    >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
    >>> image.save("mirage_output.png")
    ```
"""


class MiragePipeline(
    DiffusionPipeline,
    LoraLoaderMixin,
    FromSingleFileMixin,
    TextualInversionLoaderMixin,
):
    r"""
    Pipeline for text-to-image generation using Mirage Transformer.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        transformer ([`MirageTransformer2DModel`]):
            The Mirage transformer model to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        text_encoder ([`T5EncoderModel`]):
            Standard text encoder model for encoding prompts.
        tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]):
            Tokenizer for the text encoder.
        vae ([`AutoencoderKL`] or [`AutoencoderDC`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
            Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression).
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents"]
    _optional_components = []

    # Component configurations for automatic loading
    config_name = "model_index.json"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        """
        Override from_pretrained to ensure T5GemmaEncoder is available for loading.

        This ensures that T5GemmaEncoder from transformers is accessible in the module namespace
        during component loading, which is required for MiragePipeline checkpoints that use
        T5GemmaEncoder as the text encoder.
        """
        # Ensure T5GemmaEncoder is available for loading by monkey-patching it onto
        # the top-level `transformers` namespace if this transformers version has it.
        import transformers
        if not hasattr(transformers, 'T5GemmaEncoder'):
            try:
                from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
                transformers.T5GemmaEncoder = T5GemmaEncoder
            except ImportError:
                # T5GemmaEncoder not available in this transformers version
                pass

        # Proceed with standard loading
        return super().from_pretrained(pretrained_model_name_or_path, **kwargs)

    def __init__(
        self,
        transformer: MirageTransformer2DModel,
        scheduler: FlowMatchEulerDiscreteScheduler,
        text_encoder: Union[T5EncoderModel, Any],
        tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
        vae: Union[AutoencoderKL, AutoencoderDC],
    ):
        super().__init__()

        # Guard against the optional import at module top having failed.
        if MirageTransformer2DModel is None:
            raise ImportError(
                "MirageTransformer2DModel is not available. Please ensure the transformer_mirage module is properly installed."
            )

        # Store standard components (register_modules below also assigns these;
        # kept for parity with the original implementation).
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer

        # Initialize text preprocessor
        self.text_preprocessor = TextPreprocessor()

        self.register_modules(
            transformer=transformer,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            vae=vae,
        )

        # Enhance VAE with universal properties for both AutoencoderKL and AutoencoderDC
        self._enhance_vae_properties()
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio)

    def _enhance_vae_properties(self):
        """Add universal properties to VAE for consistent interface across AutoencoderKL and AutoencoderDC.

        Sets `spatial_compression_ratio`, `scaling_factor`, `shift_factor` and
        `latent_channels` directly on the VAE instance, deriving them from the VAE
        config where possible and falling back to conservative defaults otherwise.
        """
        if not hasattr(self, "vae") or self.vae is None:
            return

        # Set spatial_compression_ratio property
        if hasattr(self.vae, "spatial_compression_ratio"):
            # AutoencoderDC already has this property
            pass
        elif hasattr(self.vae, "config") and hasattr(self.vae.config, "block_out_channels"):
            # AutoencoderKL: calculate from block_out_channels
            self.vae.spatial_compression_ratio = 2 ** (len(self.vae.config.block_out_channels) - 1)
        else:
            # Fallback
            self.vae.spatial_compression_ratio = 8

        # Set scaling_factor property with safe defaults
        if hasattr(self.vae, "config"):
            self.vae.scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        else:
            self.vae.scaling_factor = 0.18215

        # Set shift_factor property with safe defaults (0.0 for AutoencoderDC)
        if hasattr(self.vae, "config"):
            shift_factor = getattr(self.vae.config, "shift_factor", None)
            if shift_factor is None:  # AutoencoderDC case
                self.vae.shift_factor = 0.0
            else:
                self.vae.shift_factor = shift_factor
        else:
            self.vae.shift_factor = 0.0

        # Set latent_channels property (like VaeTower does)
        if hasattr(self.vae, "config") and hasattr(self.vae.config, "latent_channels"):
            # AutoencoderDC has latent_channels in config
            self.vae.latent_channels = int(self.vae.config.latent_channels)
        elif hasattr(self.vae, "config") and hasattr(self.vae.config, "in_channels"):
            # AutoencoderKL has in_channels in config
            self.vae.latent_channels = int(self.vae.config.in_channels)
        else:
            # Fallback based on VAE type - DC-AE typically has 32, AutoencoderKL has 4/16
            if hasattr(self.vae, "spatial_compression_ratio") and self.vae.spatial_compression_ratio == 32:
                self.vae.latent_channels = 32  # DC-AE default
            else:
                self.vae.latent_channels = 4  # AutoencoderKL default

    @property
    def vae_scale_factor(self):
        """Compatibility property that returns spatial compression ratio."""
        return getattr(self.vae, "spatial_compression_ratio", 8)

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
    ):
        """Prepare initial latents for the diffusion process.

        Samples Gaussian noise at the VAE-compressed resolution unless `latents`
        is provided, in which case it is only moved to `device`.
        """
        if latents is None:
            latent_height, latent_width = height // self.vae.spatial_compression_ratio, width // self.vae.spatial_compression_ratio
            shape = (batch_size, num_channels_latents, latent_height, latent_width)
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # FlowMatchEulerDiscreteScheduler doesn't use init_noise_sigma scaling
        return latents

    def encode_prompt(self, prompt: Union[str, List[str]], device: torch.device):
        """Encode text prompt using standard text encoder and tokenizer.

        Returns (text_embeddings, cross_attn_mask, uncond_text_embeddings,
        uncond_cross_attn_mask) for classifier-free guidance.
        """
        if isinstance(prompt, str):
            prompt = [prompt]

        return self._encode_prompt_standard(prompt, device)

    def _encode_prompt_standard(self, prompt: List[str], device: torch.device):
        """Encode prompt using standard text encoder and tokenizer with batch processing."""
        # Clean text using modular preprocessor; the unconditional branch uses
        # cleaned empty strings (one per prompt).
        cleaned_prompts = [self.text_preprocessor.clean_text(text) for text in prompt]
        cleaned_uncond_prompts = [self.text_preprocessor.clean_text("") for _ in prompt]

        # Batch conditional and unconditional prompts together for efficiency
        all_prompts = cleaned_prompts + cleaned_uncond_prompts

        # Tokenize all prompts in one batch
        tokens = self.tokenizer(
            all_prompts,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        input_ids = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].bool().to(device)

        # Encode all prompts in one batch
        with torch.no_grad():
            # Disable autocast like in TextTower
            with torch.autocast("cuda", enabled=False):
                emb = self.text_encoder(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )

            # Use last hidden state (matching TextTower's use_last_hidden_state=True default)
            all_embeddings = emb["last_hidden_state"]

        # Split back into conditional and unconditional halves.
        batch_size = len(prompt)
        text_embeddings = all_embeddings[:batch_size]
        uncond_text_embeddings = all_embeddings[batch_size:]

        cross_attn_mask = attention_mask[:batch_size]
        uncond_cross_attn_mask = attention_mask[batch_size:]

        return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask

    def check_inputs(
        self,
        prompt: Union[str, List[str]],
        height: int,
        width: int,
        guidance_scale: float,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
    ):
        """Check that all inputs are in correct format.

        Raises:
            ValueError: If height/width are not divisible by the VAE compression
                ratio, guidance_scale < 1.0, or the callback tensor inputs are not a list.
        """
        if height % self.vae.spatial_compression_ratio != 0 or width % self.vae.spatial_compression_ratio != 0:
            raise ValueError(f"`height` and `width` have to be divisible by {self.vae.spatial_compression_ratio} but are {height} and {width}.")

        if guidance_scale < 1.0:
            raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}")

        if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}"
            )
list): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.mirage.MiragePipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`. + `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed + in the `._callback_tensor_inputs` attribute. 
+ + Examples: + + Returns: + [`~pipelines.mirage.MiragePipelineOutput`] or `tuple`: [`~pipelines.mirage.MiragePipelineOutput`] if + `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the + generated images. + """ + + # 0. Default height and width to transformer config + height = height or 256 + width = width or 256 + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + guidance_scale, + callback_on_step_end_tensor_inputs, + ) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError("prompt must be provided as a string or list of strings") + + device = self._execution_device + + # 2. Encode input prompt + text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt( + prompt, device + ) + + # 3. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 4. Prepare latent variables + num_channels_latents = self.vae.latent_channels # From your transformer config + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 5. Prepare extra step kwargs + extra_step_kwargs = {} + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_eta: + extra_step_kwargs["eta"] = 0.0 + + # 6. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Duplicate latents for CFG + latents_in = torch.cat([latents, latents], dim=0) + + # Cross-attention batch (uncond, cond) + ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0) + ca_mask = None + if cross_attn_mask is not None and uncond_cross_attn_mask is not None: + ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0) + + # Normalize timestep for the transformer + t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device) + + # Process inputs for transformer + img_seq, txt, pe = self.transformer.process_inputs(latents_in, ca_embed) + + # Forward through transformer layers + img_seq = self.transformer.forward_transformers( + img_seq, txt, time_embedding=self.transformer.compute_timestep_embedding(t_cont, img_seq.dtype), + pe=pe, attention_mask=ca_mask + ) + + # Convert back to image format + from ...models.transformers.transformer_mirage import seq2img + noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape) + + # Apply CFG + noise_uncond, noise_text = noise_both.chunk(2, dim=0) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_on_step_end(self, i, t, callback_kwargs) + + # Call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 8. 
Post-processing + if output_type == "latent": + image = latents + else: + # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC) + latents = (latents / self.vae.scaling_factor) + self.vae.shift_factor + # Decode using VAE (AutoencoderKL or AutoencoderDC) + image = self.vae.decode(latents, return_dict=False)[0] + # Use standard image processor for post-processing + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return MiragePipelineOutput(images=image) \ No newline at end of file diff --git a/src/diffusers/pipelines/mirage/pipeline_output.py b/src/diffusers/pipelines/mirage/pipeline_output.py new file mode 100644 index 000000000000..e5cdb2a40924 --- /dev/null +++ b/src/diffusers/pipelines/mirage/pipeline_output.py @@ -0,0 +1,35 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class MiragePipelineOutput(BaseOutput): + """ + Output class for Mirage pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. 
PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_mirage.py new file mode 100644 index 000000000000..11accdaecbee --- /dev/null +++ b/tests/models/transformers/test_models_transformer_mirage.py @@ -0,0 +1,252 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel, MirageParams + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class MirageTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = MirageTransformer2DModel + main_input_name = "image_latent" + + @property + def dummy_input(self): + return self.prepare_dummy_input() + + @property + def input_shape(self): + return (16, 4, 4) + + @property + def output_shape(self): + return (16, 4, 4) + + def prepare_dummy_input(self, height=32, width=32): + batch_size = 1 + num_latent_channels = 16 + sequence_length = 16 + embedding_dim = 1792 + + image_latent = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device) + cross_attn_conditioning = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + micro_conditioning = torch.randn((batch_size, embedding_dim)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "image_latent": image_latent, + "timestep": timestep, + "cross_attn_conditioning": cross_attn_conditioning, + "micro_conditioning": micro_conditioning, + } + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 16, + "patch_size": 2, + "context_in_dim": 1792, + "hidden_size": 1792, + "mlp_ratio": 3.5, + "num_heads": 28, + "depth": 4, # Smaller depth for testing + "axes_dim": [32, 32], + "theta": 10_000, + } + inputs_dict = self.prepare_dummy_input() + return init_dict, inputs_dict + + def test_forward_signature(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + # Test forward + outputs = model(**inputs_dict) + + self.assertIsNotNone(outputs) + expected_shape = 
inputs_dict["image_latent"].shape + self.assertEqual(outputs.shape, expected_shape) + + def test_mirage_params_initialization(self): + # Test model initialization + model = MirageTransformer2DModel( + in_channels=16, + patch_size=2, + context_in_dim=1792, + hidden_size=1792, + mlp_ratio=3.5, + num_heads=28, + depth=4, + axes_dim=[32, 32], + theta=10_000, + ) + self.assertEqual(model.config.in_channels, 16) + self.assertEqual(model.config.hidden_size, 1792) + self.assertEqual(model.config.num_heads, 28) + + def test_model_with_dict_config(self): + # Test model initialization with from_config + config_dict = { + "in_channels": 16, + "patch_size": 2, + "context_in_dim": 1792, + "hidden_size": 1792, + "mlp_ratio": 3.5, + "num_heads": 28, + "depth": 4, + "axes_dim": [32, 32], + "theta": 10_000, + } + + model = MirageTransformer2DModel.from_config(config_dict) + self.assertEqual(model.config.in_channels, 16) + self.assertEqual(model.config.hidden_size, 1792) + + def test_process_inputs(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + img_seq, txt, pe = model.process_inputs( + inputs_dict["image_latent"], + inputs_dict["cross_attn_conditioning"] + ) + + # Check shapes + batch_size = inputs_dict["image_latent"].shape[0] + height, width = inputs_dict["image_latent"].shape[2:] + patch_size = init_dict["patch_size"] + expected_seq_len = (height // patch_size) * (width // patch_size) + + self.assertEqual(img_seq.shape, (batch_size, expected_seq_len, init_dict["in_channels"] * patch_size**2)) + self.assertEqual(txt.shape, (batch_size, inputs_dict["cross_attn_conditioning"].shape[1], init_dict["hidden_size"])) + # Check that pe has the correct batch size, sequence length and some embedding dimension + self.assertEqual(pe.shape[0], batch_size) # batch size + self.assertEqual(pe.shape[1], 1) # unsqueeze(1) in EmbedND + 
self.assertEqual(pe.shape[2], expected_seq_len) # sequence length + self.assertEqual(pe.shape[-2:], (2, 2)) # rope rearrange output + + def test_forward_transformers(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + # Process inputs first + img_seq, txt, pe = model.process_inputs( + inputs_dict["image_latent"], + inputs_dict["cross_attn_conditioning"] + ) + + # Test forward_transformers + output_seq = model.forward_transformers( + img_seq, + txt, + timestep=inputs_dict["timestep"], + pe=pe + ) + + # Check output shape + expected_out_channels = init_dict["in_channels"] * init_dict["patch_size"]**2 + self.assertEqual(output_seq.shape, (img_seq.shape[0], img_seq.shape[1], expected_out_channels)) + + def test_attention_mask(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + # Create attention mask + batch_size = inputs_dict["cross_attn_conditioning"].shape[0] + seq_len = inputs_dict["cross_attn_conditioning"].shape[1] + attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool).to(torch_device) + attention_mask[:, seq_len//2:] = False # Mask second half + + with torch.no_grad(): + outputs = model( + **inputs_dict, + cross_attn_mask=attention_mask + ) + + self.assertIsNotNone(outputs) + expected_shape = inputs_dict["image_latent"].shape + self.assertEqual(outputs.shape, expected_shape) + + def test_invalid_config(self): + # Test invalid configuration - hidden_size not divisible by num_heads + with self.assertRaises(ValueError): + MirageTransformer2DModel( + in_channels=16, + patch_size=2, + context_in_dim=1792, + hidden_size=1793, # Not divisible by 28 + mlp_ratio=3.5, + num_heads=28, + depth=4, + axes_dim=[32, 32], + theta=10_000, + ) + + # Test invalid axes_dim that doesn't sum to pe_dim + with 
self.assertRaises(ValueError): + MirageTransformer2DModel( + in_channels=16, + patch_size=2, + context_in_dim=1792, + hidden_size=1792, + mlp_ratio=3.5, + num_heads=28, + depth=4, + axes_dim=[30, 30], # Sum = 60, but pe_dim = 1792/28 = 64 + theta=10_000, + ) + + def test_gradient_checkpointing_enable(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + # Enable gradient checkpointing + model.enable_gradient_checkpointing() + + # Check that _activation_checkpointing is set + for block in model.blocks: + self.assertTrue(hasattr(block, '_activation_checkpointing')) + + def test_from_config(self): + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + # Create model from config + model = self.model_class.from_config(init_dict) + self.assertIsInstance(model, self.model_class) + self.assertEqual(model.config.in_channels, init_dict["in_channels"]) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 6e10ed4938afe86d48d60dd8e97bbab38e737422 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 11:51:14 +0200 Subject: [PATCH 19/38] use attention processors --- src/diffusers/models/attention_processor.py | 58 +++++++++++++ .../models/transformers/transformer_mirage.py | 86 ++++++++++++++++--- 2 files changed, 134 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 66455d733aee..e4ab33be9784 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5605,6 +5605,63 @@ def __new__(cls, *args, **kwargs): return processor +class MirageAttnProcessor2_0: + r""" + Processor for implementing Mirage-style attention with multi-source tokens and RoPE. + Properly integrates with diffusers Attention module while handling Mirage-specific logic. 
+ """ + + def __init__(self): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise ImportError("MirageAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Apply Mirage attention using standard diffusers interface. + + Expected tensor formats from MirageBlock.attn_forward(): + - hidden_states: Image queries with RoPE applied [B, H, L_img, D] + - encoder_hidden_states: Packed key+value tensors [B, H, L_all, 2*D] + (concatenated keys and values from text + image + spatial conditioning) + - attention_mask: Custom attention mask [B, H, L_img, L_all] or None + """ + + if encoder_hidden_states is None: + raise ValueError( + "MirageAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. " + "This should be provided by MirageBlock.attn_forward()." 
+ ) + + # Unpack the combined key+value tensor + # encoder_hidden_states is [B, H, L_all, 2*D] containing [keys, values] + key, value = encoder_hidden_states.chunk(2, dim=-1) # Each [B, H, L_all, D] + + # Apply scaled dot-product attention with Mirage's processed tensors + # hidden_states is image queries [B, H, L_img, D] + attn_output = torch.nn.functional.scaled_dot_product_attention( + hidden_states.contiguous(), key.contiguous(), value.contiguous(), attn_mask=attention_mask + ) + + # Reshape from [B, H, L_img, D] to [B, L_img, H*D] + batch_size, num_heads, seq_len, head_dim = attn_output.shape + attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim) + + # Apply output projection using the diffusers Attention module + attn_output = attn.to_out[0](attn_output) + if len(attn.to_out) > 1: + attn_output = attn.to_out[1](attn_output) # dropout if present + + return attn_output + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, @@ -5653,6 +5710,7 @@ def __new__(cls, *args, **kwargs): PAGHunyuanAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, LuminaAttnProcessor2_0, + MirageAttnProcessor2_0, FusedAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 39c569cbb26b..0225b9532aff 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..modeling_outputs import Transformer2DModelOutput +from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers @@ -159,13 +160,21 @@ def __init__( # img qkv self.img_pre_norm = 
nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_qkv_proj = nn.Linear(hidden_size, hidden_size * 3, bias=False) - self.attn_out = nn.Linear(hidden_size, hidden_size, bias=False) self.qk_norm = QKNorm(self.head_dim) # txt kv self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False) self.k_norm = RMSNorm(self.head_dim) + self.attention = Attention( + query_dim=hidden_size, + heads=num_heads, + dim_head=self.head_dim, + bias=False, + out_bias=False, + processor=MirageAttnProcessor2_0(), + ) + # mlp self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) @@ -214,15 +223,11 @@ def attn_forward( k = torch.cat((cond_k, k), dim=2) v = torch.cat((cond_v, v), dim=2) - # build additive attention bias - attn_bias: Tensor | None = None - attn_mask: Tensor | None = None - # build multiplicative 0/1 mask for provided attention_mask over [cond?, text, image] keys + attn_mask: Tensor | None = None if attention_mask is not None: bs, _, l_img, _ = img_q.shape l_txt = txt_k.shape[2] - l_all = k.shape[2] assert attention_mask.dim() == 2, f"Unsupported attention_mask shape: {attention_mask.shape}" assert ( @@ -244,11 +249,13 @@ def attn_forward( # repeat across heads and query positions attn_mask = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1) # (B,H,L_img,L_all) - attn = torch.nn.functional.scaled_dot_product_attention( - img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask + kv_packed = torch.cat([k, v], dim=-1) + + attn = self.attention( + hidden_states=img_q, + encoder_hidden_states=kv_packed, + attention_mask=attn_mask, ) - attn = rearrange(attn, "B H L D -> B L (H D)") - attn = self.attn_out(attn) return attn @@ -413,6 +420,65 @@ def __init__( self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, 
AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: """Timestep independent stuff""" txt = self.txt_in(txt) From 866c6de0e3cde370ce134706cf50d1f342064a42 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 12:50:19 +0200 Subject: [PATCH 20/38] use diffusers rmsnorm --- .../models/transformers/transformer_mirage.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 0225b9532aff..f4199da1edcc 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -26,12 +26,12 @@ from ..modeling_outputs import Transformer2DModelOutput from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..normalization import RMSNorm logger = logging.get_logger(__name__) -# Mirage Layer Components def get_image_ids(bs: int, h: int, w: int, patch_size: int, device: torch.device) -> Tensor: img_ids = torch.zeros(h // patch_size, w // patch_size, 2, device=device) img_ids[..., 0] = torch.arange(h // patch_size, device=device)[:, None] @@ -93,23 +93,13 @@ def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) -class RMSNorm(torch.nn.Module): - def __init__(self, 
dim: int): - super().__init__() - self.scale = nn.Parameter(torch.ones(dim)) - - def forward(self, x: Tensor) -> Tensor: - x_dtype = x.dtype - x = x.float() - rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) - return (x * rrms * self.scale).to(dtype=x_dtype) class QKNorm(torch.nn.Module): def __init__(self, dim: int): super().__init__() - self.query_norm = RMSNorm(dim) - self.key_norm = RMSNorm(dim) + self.query_norm = RMSNorm(dim, eps=1e-6) + self.key_norm = RMSNorm(dim, eps=1e-6) def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: q = self.query_norm(q) @@ -164,7 +154,7 @@ def __init__( # txt kv self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False) - self.k_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim, eps=1e-6) self.attention = Attention( query_dim=hidden_size, From 4e8b647227d013816903271b233027cc8034d2d1 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 14:26:50 +0200 Subject: [PATCH 21/38] use diffusers timestep embedding method --- .../models/transformers/transformer_mirage.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index f4199da1edcc..916559eb47ac 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -27,6 +27,7 @@ from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..normalization import RMSNorm +from ..embeddings import get_timestep_embedding logger = logging.get_logger(__name__) @@ -71,15 +72,6 @@ def forward(self, ids: Tensor) -> Tensor: return emb.unsqueeze(1) -def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, time_factor: float = 1000.0) -> Tensor: - t = time_factor * t - 
half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding class MLPEmbedder(nn.Module): @@ -480,8 +472,12 @@ def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[T def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: return self.time_in( - timestep_embedding( - t=timestep, dim=256, max_period=self.time_max_period, time_factor=self.time_factor + get_timestep_embedding( + timesteps=timestep, + embedding_dim=256, + max_period=self.time_max_period, + scale=self.time_factor, + flip_sin_to_cos=True # Match original cos, sin order ).to(dtype) ) From 472ad97e410b9d8a46a002ee45583edb8f02061e Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 15:17:11 +0200 Subject: [PATCH 22/38] remove MirageParams --- .../models/transformers/transformer_mirage.py | 64 +++++-------------- .../pipelines/mirage/pipeline_output.py | 2 +- .../test_models_transformer_mirage.py | 8 +-- 3 files changed, 22 insertions(+), 52 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 916559eb47ac..396e000524ec 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -288,20 +288,6 @@ def forward(self, x: Tensor, vec: Tensor) -> Tensor: return x -@dataclass -class MirageParams: - in_channels: int - patch_size: int - context_in_dim: int - hidden_size: int - mlp_ratio: float - num_heads: int - depth: int - axes_dim: list[int] - theta: int - time_factor: float = 1000.0 - time_max_period: int = 10_000 - conditioning_block_ids: list[int] | None = None def img2seq(img: Tensor, 
patch_size: int) -> Tensor: @@ -348,55 +334,39 @@ def __init__( if axes_dim is None: axes_dim = [32, 32] - # Create MirageParams from the provided arguments - params = MirageParams( - in_channels=in_channels, - patch_size=patch_size, - context_in_dim=context_in_dim, - hidden_size=hidden_size, - mlp_ratio=mlp_ratio, - num_heads=num_heads, - depth=depth, - axes_dim=axes_dim, - theta=theta, - time_factor=time_factor, - time_max_period=time_max_period, - conditioning_block_ids=conditioning_block_ids, - ) - - self.params = params - self.in_channels = params.in_channels - self.patch_size = params.patch_size + # Store parameters directly + self.in_channels = in_channels + self.patch_size = patch_size self.out_channels = self.in_channels * self.patch_size**2 - self.time_factor = params.time_factor - self.time_max_period = params.time_max_period + self.time_factor = time_factor + self.time_max_period = time_max_period - if params.hidden_size % params.num_heads != 0: - raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}") + if hidden_size % num_heads != 0: + raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}") - pe_dim = params.hidden_size // params.num_heads + pe_dim = hidden_size // num_heads - if sum(params.axes_dim) != pe_dim: - raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") + if sum(axes_dim) != pe_dim: + raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}") - self.hidden_size = params.hidden_size - self.num_heads = params.num_heads - self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) + self.hidden_size = hidden_size + self.num_heads = num_heads + self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) - self.txt_in = 
nn.Linear(params.context_in_dim, self.hidden_size) + self.txt_in = nn.Linear(context_in_dim, self.hidden_size) - conditioning_block_ids: list[int] = params.conditioning_block_ids or list(range(params.depth)) + conditioning_block_ids: list[int] = conditioning_block_ids or list(range(depth)) self.blocks = nn.ModuleList( [ MirageBlock( self.hidden_size, self.num_heads, - mlp_ratio=params.mlp_ratio, + mlp_ratio=mlp_ratio, ) - for i in range(params.depth) + for i in range(depth) ] ) diff --git a/src/diffusers/pipelines/mirage/pipeline_output.py b/src/diffusers/pipelines/mirage/pipeline_output.py index e5cdb2a40924..dfb55821d142 100644 --- a/src/diffusers/pipelines/mirage/pipeline_output.py +++ b/src/diffusers/pipelines/mirage/pipeline_output.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import List, Optional, Union +from typing import List, Union import numpy as np import PIL.Image diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_mirage.py index 11accdaecbee..5e7b0bd165a6 100644 --- a/tests/models/transformers/test_models_transformer_mirage.py +++ b/tests/models/transformers/test_models_transformer_mirage.py @@ -17,7 +17,7 @@ import torch -from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel, MirageParams +from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin @@ -88,9 +88,9 @@ def test_forward_signature(self): self.assertIsNotNone(outputs) expected_shape = inputs_dict["image_latent"].shape - self.assertEqual(outputs.shape, expected_shape) + self.assertEqual(outputs.sample.shape, expected_shape) - def test_mirage_params_initialization(self): + def test_model_initialization(self): # Test model initialization model = MirageTransformer2DModel( in_channels=16, @@ 
-196,7 +196,7 @@ def test_attention_mask(self): self.assertIsNotNone(outputs) expected_shape = inputs_dict["image_latent"].shape - self.assertEqual(outputs.shape, expected_shape) + self.assertEqual(outputs.sample.shape, expected_shape) def test_invalid_config(self): # Test invalid configuration - hidden_size not divisible by num_heads From 97a231e3561e1817d46fb7a0c7840423dc76ec99 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 16:35:56 +0200 Subject: [PATCH 23/38] checkpoint conversion script --- scripts/convert_mirage_to_diffusers.py | 312 +++++++++++++++++++++++++ 1 file changed, 312 insertions(+) create mode 100644 scripts/convert_mirage_to_diffusers.py diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py new file mode 100644 index 000000000000..85716e69ff92 --- /dev/null +++ b/scripts/convert_mirage_to_diffusers.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 +""" +Script to convert Mirage checkpoint from original codebase to diffusers format. 
+""" + +import argparse +import json +import os +import shutil +import sys +import torch +from safetensors.torch import save_file +from transformers import GemmaTokenizerFast + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.pipelines.mirage import MiragePipeline + +def load_reference_config(vae_type: str) -> dict: + """Load transformer config from existing pipeline checkpoint.""" + + if vae_type == "flux": + config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated/transformer/config.json" + elif vae_type == "dc-ae": + config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated/transformer/config.json" + else: + raise ValueError(f"Unsupported VAE type: {vae_type}. 
Use 'flux' or 'dc-ae'") + + if not os.path.exists(config_path): + raise FileNotFoundError(f"Reference config not found: {config_path}") + + with open(config_path, 'r') as f: + config = json.load(f) + + print(f"✓ Loaded {vae_type} config: in_channels={config['in_channels']}") + return config + +def create_parameter_mapping() -> dict: + """Create mapping from old parameter names to new diffusers names.""" + + # Key mappings for structural changes + mapping = {} + + # RMSNorm: scale -> weight + for i in range(16): # 16 layers + mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.qk_norm.query_norm.weight" + mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.qk_norm.key_norm.weight" + mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.k_norm.weight" + + # Attention: attn_out -> attention.to_out.0 + mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight" + + return mapping + +def convert_checkpoint_parameters(old_state_dict: dict) -> dict: + """Convert old checkpoint parameters to new diffusers format.""" + + print("Converting checkpoint parameters...") + + mapping = create_parameter_mapping() + converted_state_dict = {} + + # First, print available keys to understand structure + print("Available keys in checkpoint:") + for key in sorted(old_state_dict.keys())[:10]: # Show first 10 keys + print(f" {key}") + if len(old_state_dict) > 10: + print(f" ... 
and {len(old_state_dict) - 10} more") + + for key, value in old_state_dict.items(): + new_key = key + + # Apply specific mappings if needed + if key in mapping: + new_key = mapping[key] + print(f" Mapped: {key} -> {new_key}") + + # Handle img_qkv_proj -> split to to_q, to_k, to_v + if "img_qkv_proj.weight" in key: + print(f" Found QKV projection: {key}") + # Split QKV weight into separate Q, K, V projections + qkv_weight = value + hidden_size = qkv_weight.shape[1] + q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) + + # Extract layer number from key (e.g., blocks.0.img_qkv_proj.weight -> 0) + parts = key.split('.') + layer_idx = None + for i, part in enumerate(parts): + if part == 'blocks' and i + 1 < len(parts) and parts[i+1].isdigit(): + layer_idx = parts[i+1] + break + + if layer_idx is not None: + converted_state_dict[f"blocks.{layer_idx}.attention.to_q.weight"] = q_weight + converted_state_dict[f"blocks.{layer_idx}.attention.to_k.weight"] = k_weight + converted_state_dict[f"blocks.{layer_idx}.attention.to_v.weight"] = v_weight + print(f" Split QKV for layer {layer_idx}") + + # Also keep the original img_qkv_proj for backward compatibility + converted_state_dict[new_key] = value + else: + converted_state_dict[new_key] = value + + print(f"✓ Converted {len(converted_state_dict)} parameters") + return converted_state_dict + + +def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> MirageTransformer2DModel: + """Create and load MirageTransformer2DModel from old checkpoint.""" + + print(f"Loading checkpoint from: {checkpoint_path}") + + # Load old checkpoint + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + old_checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # Handle different checkpoint formats + if isinstance(old_checkpoint, dict): + if 'model' in old_checkpoint: + state_dict = old_checkpoint['model'] + elif 'state_dict' in old_checkpoint: + 
state_dict = old_checkpoint['state_dict'] + else: + state_dict = old_checkpoint + else: + state_dict = old_checkpoint + + print(f"✓ Loaded checkpoint with {len(state_dict)} parameters") + + # Convert parameter names if needed + converted_state_dict = convert_checkpoint_parameters(state_dict) + + # Create transformer with config + print("Creating MirageTransformer2DModel...") + transformer = MirageTransformer2DModel(**config) + + # Load state dict + print("Loading converted parameters...") + missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False) + + if missing_keys: + print(f"⚠ Missing keys: {missing_keys}") + if unexpected_keys: + print(f"⚠ Unexpected keys: {unexpected_keys}") + + if not missing_keys and not unexpected_keys: + print("✓ All parameters loaded successfully!") + + return transformer + +def copy_pipeline_components(vae_type: str, output_path: str): + """Copy VAE, scheduler, text encoder, and tokenizer from reference pipeline.""" + + if vae_type == "flux": + ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated" + else: # dc-ae + ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated" + + components = ["vae", "scheduler", "text_encoder", "tokenizer"] + + for component in components: + src_path = os.path.join(ref_pipeline, component) + dst_path = os.path.join(output_path, component) + + if os.path.exists(src_path): + if os.path.isdir(src_path): + shutil.copytree(src_path, dst_path, dirs_exist_ok=True) + else: + shutil.copy2(src_path, dst_path) + print(f"✓ Copied {component}") + else: + print(f"⚠ Component not found: {src_path}") + +def create_model_index(vae_type: str, output_path: str): + """Create model_index.json for the pipeline.""" + + if vae_type == "flux": + vae_class = "AutoencoderKL" + else: # dc-ae + vae_class = "AutoencoderDC" + + model_index = { + 
"_class_name": "MiragePipeline", + "_diffusers_version": "0.31.0.dev0", + "_name_or_path": os.path.basename(output_path), + "scheduler": [ + "diffusers", + "FlowMatchEulerDiscreteScheduler" + ], + "text_encoder": [ + "transformers", + "T5GemmaEncoder" + ], + "tokenizer": [ + "transformers", + "GemmaTokenizerFast" + ], + "transformer": [ + "diffusers", + "MirageTransformer2DModel" + ], + "vae": [ + "diffusers", + vae_class + ] + } + + model_index_path = os.path.join(output_path, "model_index.json") + with open(model_index_path, 'w') as f: + json.dump(model_index, f, indent=2) + + print(f"✓ Created model_index.json") + +def main(args): + # Validate inputs + if not os.path.exists(args.checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}") + + # Load reference config based on VAE type + config = load_reference_config(args.vae_type) + + # Create output directory + os.makedirs(args.output_path, exist_ok=True) + print(f"✓ Output directory: {args.output_path}") + + # Create transformer from checkpoint + transformer = create_transformer_from_checkpoint(args.checkpoint_path, config) + + # Save transformer + transformer_path = os.path.join(args.output_path, "transformer") + os.makedirs(transformer_path, exist_ok=True) + + # Save config + with open(os.path.join(transformer_path, "config.json"), 'w') as f: + json.dump(config, f, indent=2) + + # Save model weights as safetensors + state_dict = transformer.state_dict() + save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors")) + print(f"✓ Saved transformer to {transformer_path}") + + # Copy other pipeline components + copy_pipeline_components(args.vae_type, args.output_path) + + # Create model index + create_model_index(args.vae_type, args.output_path) + + # Verify the pipeline can be loaded + try: + pipeline = MiragePipeline.from_pretrained(args.output_path) + print(f"Pipeline loaded successfully!") + print(f"Transformer: 
{type(pipeline.transformer).__name__}") + print(f"VAE: {type(pipeline.vae).__name__}") + print(f"Text Encoder: {type(pipeline.text_encoder).__name__}") + print(f"Scheduler: {type(pipeline.scheduler).__name__}") + + # Display model info + num_params = sum(p.numel() for p in pipeline.transformer.parameters()) + print(f"✓ Transformer parameters: {num_params:,}") + + except Exception as e: + print(f"Pipeline verification failed: {e}") + return False + + print("Conversion completed successfully!") + print(f"Converted pipeline saved to: {args.output_path}") + print(f"VAE type: {args.vae_type}") + + + return True + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Mirage checkpoint to diffusers format") + + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, + help="Path to the original Mirage checkpoint (.pth file)" + ) + + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Output directory for the converted diffusers pipeline" + ) + + parser.add_argument( + "--vae_type", + type=str, + choices=["flux", "dc-ae"], + required=True, + help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)" + ) + + args = parser.parse_args() + + try: + success = main(args) + if not success: + sys.exit(1) + except Exception as e: + print(f"❌ Conversion failed: {e}") + import traceback + traceback.print_exc() + sys.exit(1) \ No newline at end of file From 35d721f79bbb518ea5c1e209f71c0fd80cba9434 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Fri, 26 Sep 2025 17:00:55 +0200 Subject: [PATCH 24/38] ruff formating --- scripts/convert_mirage_to_diffusers.py | 83 ++++++++----------- .../models/transformers/transformer_mirage.py | 41 ++++----- src/diffusers/pipelines/mirage/__init__.py | 3 +- .../pipelines/mirage/pipeline_mirage.py | 50 +++++++---- .../pipelines/mirage/pipeline_output.py | 2 +- .../test_models_transformer_mirage.py | 30 +++---- 6 files changed, 100 
insertions(+), 109 deletions(-) diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py index 85716e69ff92..5e2a2ff768f4 100644 --- a/scripts/convert_mirage_to_diffusers.py +++ b/scripts/convert_mirage_to_diffusers.py @@ -8,16 +8,17 @@ import os import shutil import sys + import torch from safetensors.torch import save_file -from transformers import GemmaTokenizerFast -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.pipelines.mirage import MiragePipeline + def load_reference_config(vae_type: str) -> dict: """Load transformer config from existing pipeline checkpoint.""" @@ -31,12 +32,13 @@ def load_reference_config(vae_type: str) -> dict: if not os.path.exists(config_path): raise FileNotFoundError(f"Reference config not found: {config_path}") - with open(config_path, 'r') as f: + with open(config_path, "r") as f: config = json.load(f) print(f"✓ Loaded {vae_type} config: in_channels={config['in_channels']}") return config + def create_parameter_mapping() -> dict: """Create mapping from old parameter names to new diffusers names.""" @@ -54,6 +56,7 @@ def create_parameter_mapping() -> dict: return mapping + def convert_checkpoint_parameters(old_state_dict: dict) -> dict: """Convert old checkpoint parameters to new diffusers format.""" @@ -82,15 +85,14 @@ def convert_checkpoint_parameters(old_state_dict: dict) -> dict: print(f" Found QKV projection: {key}") # Split QKV weight into separate Q, K, V projections qkv_weight = value - hidden_size = qkv_weight.shape[1] q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) # Extract layer number from key (e.g., blocks.0.img_qkv_proj.weight -> 0) - parts = key.split('.') + parts = key.split(".") layer_idx = None for i, part 
in enumerate(parts): - if part == 'blocks' and i + 1 < len(parts) and parts[i+1].isdigit(): - layer_idx = parts[i+1] + if part == "blocks" and i + 1 < len(parts) and parts[i + 1].isdigit(): + layer_idx = parts[i + 1] break if layer_idx is not None: @@ -117,14 +119,14 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi if not os.path.exists(checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") - old_checkpoint = torch.load(checkpoint_path, map_location='cpu') + old_checkpoint = torch.load(checkpoint_path, map_location="cpu") # Handle different checkpoint formats if isinstance(old_checkpoint, dict): - if 'model' in old_checkpoint: - state_dict = old_checkpoint['model'] - elif 'state_dict' in old_checkpoint: - state_dict = old_checkpoint['state_dict'] + if "model" in old_checkpoint: + state_dict = old_checkpoint["model"] + elif "state_dict" in old_checkpoint: + state_dict = old_checkpoint["state_dict"] else: state_dict = old_checkpoint else: @@ -153,6 +155,7 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi return transformer + def copy_pipeline_components(vae_type: str, output_path: str): """Copy VAE, scheduler, text encoder, and tokenizer from reference pipeline.""" @@ -176,6 +179,7 @@ def copy_pipeline_components(vae_type: str, output_path: str): else: print(f"⚠ Component not found: {src_path}") + def create_model_index(vae_type: str, output_path: str): """Create model_index.json for the pipeline.""" @@ -188,33 +192,19 @@ def create_model_index(vae_type: str, output_path: str): "_class_name": "MiragePipeline", "_diffusers_version": "0.31.0.dev0", "_name_or_path": os.path.basename(output_path), - "scheduler": [ - "diffusers", - "FlowMatchEulerDiscreteScheduler" - ], - "text_encoder": [ - "transformers", - "T5GemmaEncoder" - ], - "tokenizer": [ - "transformers", - "GemmaTokenizerFast" - ], - "transformer": [ - "diffusers", - "MirageTransformer2DModel" - ], - "vae": [ - 
"diffusers", - vae_class - ] + "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], + "text_encoder": ["transformers", "T5GemmaEncoder"], + "tokenizer": ["transformers", "GemmaTokenizerFast"], + "transformer": ["diffusers", "MirageTransformer2DModel"], + "vae": ["diffusers", vae_class], } model_index_path = os.path.join(output_path, "model_index.json") - with open(model_index_path, 'w') as f: + with open(model_index_path, "w") as f: json.dump(model_index, f, indent=2) - print(f"✓ Created model_index.json") + print("✓ Created model_index.json") + def main(args): # Validate inputs @@ -236,7 +226,7 @@ def main(args): os.makedirs(transformer_path, exist_ok=True) # Save config - with open(os.path.join(transformer_path, "config.json"), 'w') as f: + with open(os.path.join(transformer_path, "config.json"), "w") as f: json.dump(config, f, indent=2) # Save model weights as safetensors @@ -253,7 +243,7 @@ def main(args): # Verify the pipeline can be loaded try: pipeline = MiragePipeline.from_pretrained(args.output_path) - print(f"Pipeline loaded successfully!") + print("Pipeline loaded successfully!") print(f"Transformer: {type(pipeline.transformer).__name__}") print(f"VAE: {type(pipeline.vae).__name__}") print(f"Text Encoder: {type(pipeline.text_encoder).__name__}") @@ -271,24 +261,18 @@ def main(args): print(f"Converted pipeline saved to: {args.output_path}") print(f"VAE type: {args.vae_type}") - return True + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Mirage checkpoint to diffusers format") parser.add_argument( - "--checkpoint_path", - type=str, - required=True, - help="Path to the original Mirage checkpoint (.pth file)" + "--checkpoint_path", type=str, required=True, help="Path to the original Mirage checkpoint (.pth file)" ) parser.add_argument( - "--output_path", - type=str, - required=True, - help="Output directory for the converted diffusers pipeline" + "--output_path", type=str, required=True, help="Output directory for 
the converted diffusers pipeline" ) parser.add_argument( @@ -296,7 +280,7 @@ def main(args): type=str, choices=["flux", "dc-ae"], required=True, - help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)" + help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)", ) args = parser.parse_args() @@ -306,7 +290,8 @@ def main(args): if not success: sys.exit(1) except Exception as e: - print(f"❌ Conversion failed: {e}") + print(f"Conversion failed: {e}") import traceback + traceback.print_exc() - sys.exit(1) \ No newline at end of file + sys.exit(1) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 396e000524ec..923d44d4f1ec 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -13,21 +13,21 @@ # limitations under the License. from dataclasses import dataclass -from typing import Any, Dict, Optional, Union, Tuple +from typing import Any, Dict, Optional, Tuple, Union + import torch -import math -from torch import Tensor, nn -from torch.nn.functional import fold, unfold from einops import rearrange from einops.layers.torch import Rearrange +from torch import Tensor, nn +from torch.nn.functional import fold, unfold from ...configuration_utils import ConfigMixin, register_to_config -from ..modeling_utils import ModelMixin -from ..modeling_outputs import Transformer2DModelOutput -from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..normalization import RMSNorm +from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 from ..embeddings import get_timestep_embedding +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin 
+from ..normalization import RMSNorm logger = logging.get_logger(__name__) @@ -72,8 +72,6 @@ def forward(self, ids: Tensor) -> Tensor: return emb.unsqueeze(1) - - class MLPEmbedder(nn.Module): def __init__(self, in_dim: int, hidden_dim: int): super().__init__() @@ -85,8 +83,6 @@ def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) - - class QKNorm(torch.nn.Module): def __init__(self, dim: int): super().__init__() @@ -157,7 +153,6 @@ def __init__( processor=MirageAttnProcessor2_0(), ) - # mlp self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False) @@ -212,9 +207,9 @@ def attn_forward( l_txt = txt_k.shape[2] assert attention_mask.dim() == 2, f"Unsupported attention_mask shape: {attention_mask.shape}" - assert ( - attention_mask.shape[-1] == l_txt - ), f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" + assert attention_mask.shape[-1] == l_txt, ( + f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" + ) device = img_q.device @@ -234,8 +229,8 @@ def attn_forward( kv_packed = torch.cat([k, v], dim=-1) attn = self.attention( - hidden_states=img_q, - encoder_hidden_states=kv_packed, + hidden_states=img_q, + encoder_hidden_states=kv_packed, attention_mask=attn_mask, ) @@ -288,8 +283,6 @@ def forward(self, x: Tensor, vec: Tensor) -> Tensor: return x - - def img2seq(img: Tensor, patch_size: int) -> Tensor: """Flatten an image into a sequence of patches""" return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) @@ -327,7 +320,7 @@ def __init__( time_factor: float = 1000.0, time_max_period: int = 10000, conditioning_block_ids: list = None, - **kwargs + **kwargs, ): super().__init__() @@ -447,7 +440,7 @@ def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Te embedding_dim=256, max_period=self.time_max_period, 
scale=self.time_factor, - flip_sin_to_cos=True # Match original cos, sin order + flip_sin_to_cos=True, # Match original cos, sin order ).to(dtype) ) @@ -470,9 +463,7 @@ def forward_transformers( vec = self.compute_timestep_embedding(timestep, dtype=img.dtype) for block in self.blocks: - img = block( - img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs - ) + img = block(img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs) img = self.final_layer(img, vec) return img diff --git a/src/diffusers/pipelines/mirage/__init__.py b/src/diffusers/pipelines/mirage/__init__.py index 4fd8ad191b3f..cba951057370 100644 --- a/src/diffusers/pipelines/mirage/__init__.py +++ b/src/diffusers/pipelines/mirage/__init__.py @@ -1,4 +1,5 @@ from .pipeline_mirage import MiragePipeline from .pipeline_output import MiragePipelineOutput -__all__ = ["MiragePipeline", "MiragePipelineOutput"] \ No newline at end of file + +__all__ = ["MiragePipeline", "MiragePipelineOutput"] diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py index 126eab07977c..c4a4783c5f38 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import html import inspect import os -from typing import Any, Callable, Dict, List, Optional, Union - -import html import re import urllib.parse as ul +from typing import Any, Callable, Dict, List, Optional, Union import ftfy import torch @@ -31,7 +30,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, AutoencoderDC +from ...models import AutoencoderDC, AutoencoderKL from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( logging, @@ -41,6 +40,7 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import MiragePipelineOutput + try: from ...models.transformers.transformer_mirage import MirageTransformer2DModel except ImportError: @@ -55,7 +55,19 @@ class TextPreprocessor: def __init__(self): """Initialize text preprocessor.""" self.bad_punct_regex = re.compile( - r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + r"\\" + r"\/" + r"\*" + r"]{1,}" + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + r"\\" + + r"\/" + + r"\*" + + r"]{1,}" ) def clean_text(self, text: str) -> str: @@ -93,7 +105,7 @@ def clean_text(self, text: str) -> str: ) # кавычки к одному стандарту - text = re.sub(r"[`´«»""¨]", '"', text) + text = re.sub(r"[`´«»" "¨]", '"', text) text = re.sub(r"['']", "'", text) # " and & @@ -243,9 +255,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P """ # Ensure T5GemmaEncoder is available for loading import transformers - if not hasattr(transformers, 'T5GemmaEncoder'): + + if not hasattr(transformers, "T5GemmaEncoder"): try: from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + transformers.T5GemmaEncoder = T5GemmaEncoder except ImportError: # T5GemmaEncoder not available in this transformers version @@ -254,7 +268,6 @@ def from_pretrained(cls, 
pretrained_model_name_or_path: Optional[Union[str, os.P # Proceed with standard loading return super().from_pretrained(pretrained_model_name_or_path, **kwargs) - def __init__( self, transformer: MirageTransformer2DModel, @@ -333,7 +346,7 @@ def _enhance_vae_properties(self): if hasattr(self.vae, "spatial_compression_ratio") and self.vae.spatial_compression_ratio == 32: self.vae.latent_channels = 32 # DC-AE default else: - self.vae.latent_channels = 4 # AutoencoderKL default + self.vae.latent_channels = 4 # AutoencoderKL default @property def vae_scale_factor(self): @@ -353,7 +366,10 @@ def prepare_latents( ): """Prepare initial latents for the diffusion process.""" if latents is None: - latent_height, latent_width = height // self.vae.spatial_compression_ratio, width // self.vae.spatial_compression_ratio + latent_height, latent_width = ( + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) shape = (batch_size, num_channels_latents, latent_height, latent_width) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: @@ -424,7 +440,9 @@ def check_inputs( ): """Check that all inputs are in correct format.""" if height % self.vae.spatial_compression_ratio != 0 or width % self.vae.spatial_compression_ratio != 0: - raise ValueError(f"`height` and `width` have to be divisible by {self.vae.spatial_compression_ratio} but are {height} and {width}.") + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae.spatial_compression_ratio} but are {height} and {width}." 
+ ) if guidance_scale < 1.0: raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}") @@ -584,12 +602,16 @@ def __call__( # Forward through transformer layers img_seq = self.transformer.forward_transformers( - img_seq, txt, time_embedding=self.transformer.compute_timestep_embedding(t_cont, img_seq.dtype), - pe=pe, attention_mask=ca_mask + img_seq, + txt, + time_embedding=self.transformer.compute_timestep_embedding(t_cont, img_seq.dtype), + pe=pe, + attention_mask=ca_mask, ) # Convert back to image format from ...models.transformers.transformer_mirage import seq2img + noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape) # Apply CFG @@ -626,4 +648,4 @@ def __call__( if not return_dict: return (image,) - return MiragePipelineOutput(images=image) \ No newline at end of file + return MiragePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/mirage/pipeline_output.py b/src/diffusers/pipelines/mirage/pipeline_output.py index dfb55821d142..e41c8e3bea00 100644 --- a/src/diffusers/pipelines/mirage/pipeline_output.py +++ b/src/diffusers/pipelines/mirage/pipeline_output.py @@ -32,4 +32,4 @@ class MiragePipelineOutput(BaseOutput): num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
""" - images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_mirage.py index 5e7b0bd165a6..0085627aa7e4 100644 --- a/tests/models/transformers/test_models_transformer_mirage.py +++ b/tests/models/transformers/test_models_transformer_mirage.py @@ -133,8 +133,7 @@ def test_process_inputs(self): with torch.no_grad(): img_seq, txt, pe = model.process_inputs( - inputs_dict["image_latent"], - inputs_dict["cross_attn_conditioning"] + inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"] ) # Check shapes @@ -144,7 +143,9 @@ def test_process_inputs(self): expected_seq_len = (height // patch_size) * (width // patch_size) self.assertEqual(img_seq.shape, (batch_size, expected_seq_len, init_dict["in_channels"] * patch_size**2)) - self.assertEqual(txt.shape, (batch_size, inputs_dict["cross_attn_conditioning"].shape[1], init_dict["hidden_size"])) + self.assertEqual( + txt.shape, (batch_size, inputs_dict["cross_attn_conditioning"].shape[1], init_dict["hidden_size"]) + ) # Check that pe has the correct batch size, sequence length and some embedding dimension self.assertEqual(pe.shape[0], batch_size) # batch size self.assertEqual(pe.shape[1], 1) # unsqueeze(1) in EmbedND @@ -160,20 +161,14 @@ def test_forward_transformers(self): with torch.no_grad(): # Process inputs first img_seq, txt, pe = model.process_inputs( - inputs_dict["image_latent"], - inputs_dict["cross_attn_conditioning"] + inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"] ) # Test forward_transformers - output_seq = model.forward_transformers( - img_seq, - txt, - timestep=inputs_dict["timestep"], - pe=pe - ) + output_seq = model.forward_transformers(img_seq, txt, timestep=inputs_dict["timestep"], pe=pe) # Check output shape - expected_out_channels = init_dict["in_channels"] * 
init_dict["patch_size"]**2 + expected_out_channels = init_dict["in_channels"] * init_dict["patch_size"] ** 2 self.assertEqual(output_seq.shape, (img_seq.shape[0], img_seq.shape[1], expected_out_channels)) def test_attention_mask(self): @@ -186,13 +181,10 @@ def test_attention_mask(self): batch_size = inputs_dict["cross_attn_conditioning"].shape[0] seq_len = inputs_dict["cross_attn_conditioning"].shape[1] attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool).to(torch_device) - attention_mask[:, seq_len//2:] = False # Mask second half + attention_mask[:, seq_len // 2 :] = False # Mask second half with torch.no_grad(): - outputs = model( - **inputs_dict, - cross_attn_mask=attention_mask - ) + outputs = model(**inputs_dict, cross_attn_mask=attention_mask) self.assertIsNotNone(outputs) expected_shape = inputs_dict["image_latent"].shape @@ -237,7 +229,7 @@ def test_gradient_checkpointing_enable(self): # Check that _activation_checkpointing is set for block in model.blocks: - self.assertTrue(hasattr(block, '_activation_checkpointing')) + self.assertTrue(hasattr(block, "_activation_checkpointing")) def test_from_config(self): init_dict, _ = self.prepare_init_args_and_inputs_for_common() @@ -249,4 +241,4 @@ def test_from_config(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 775a115dcc70e9683eb5eb2e07baa4ba3cc941af Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 19:05:51 +0000 Subject: [PATCH 25/38] remove dependencies to old checkpoints --- scripts/convert_mirage_to_diffusers.py | 229 +++++++++++++++++++++---- 1 file changed, 192 insertions(+), 37 deletions(-) diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py index 5e2a2ff768f4..eb6de1a37481 100644 --- a/scripts/convert_mirage_to_diffusers.py +++ b/scripts/convert_mirage_to_diffusers.py @@ -6,11 +6,12 @@ import argparse import json import os -import shutil import sys import torch from 
safetensors.torch import save_file +from dataclasses import dataclass, asdict +from typing import Tuple, Dict sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) @@ -18,35 +19,53 @@ from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel from diffusers.pipelines.mirage import MiragePipeline +@dataclass(frozen=True) +class MirageBase: + context_in_dim: int = 2304 + hidden_size: int = 1792 + mlp_ratio: float = 3.5 + num_heads: int = 28 + depth: int = 16 + axes_dim: Tuple[int, int] = (32, 32) + theta: int = 10_000 + time_factor: float = 1000.0 + time_max_period: int = 10_000 -def load_reference_config(vae_type: str) -> dict: - """Load transformer config from existing pipeline checkpoint.""" +@dataclass(frozen=True) +class MirageFlux(MirageBase): + in_channels: int = 16 + patch_size: int = 2 + + +@dataclass(frozen=True) +class MirageDCAE(MirageBase): + in_channels: int = 32 + patch_size: int = 1 + + +def build_config(vae_type: str) -> dict: if vae_type == "flux": - config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated/transformer/config.json" + cfg = MirageFlux() elif vae_type == "dc-ae": - config_path = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated/transformer/config.json" + cfg = MirageDCAE() else: raise ValueError(f"Unsupported VAE type: {vae_type}. 
Use 'flux' or 'dc-ae'") - if not os.path.exists(config_path): - raise FileNotFoundError(f"Reference config not found: {config_path}") - - with open(config_path, "r") as f: - config = json.load(f) + config_dict = asdict(cfg) + config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index] + return config_dict - print(f"✓ Loaded {vae_type} config: in_channels={config['in_channels']}") - return config -def create_parameter_mapping() -> dict: +def create_parameter_mapping(depth: int) -> dict: """Create mapping from old parameter names to new diffusers names.""" # Key mappings for structural changes mapping = {} # RMSNorm: scale -> weight - for i in range(16): # 16 layers + for i in range(depth): mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.qk_norm.query_norm.weight" mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.qk_norm.key_norm.weight" mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.k_norm.weight" @@ -57,12 +76,12 @@ def create_parameter_mapping() -> dict: return mapping -def convert_checkpoint_parameters(old_state_dict: dict) -> dict: +def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]: """Convert old checkpoint parameters to new diffusers format.""" print("Converting checkpoint parameters...") - mapping = create_parameter_mapping() + mapping = create_parameter_mapping(depth) converted_state_dict = {} # First, print available keys to understand structure @@ -135,7 +154,8 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi print(f"✓ Loaded checkpoint with {len(state_dict)} parameters") # Convert parameter names if needed - converted_state_dict = convert_checkpoint_parameters(state_dict) + model_depth = int(config.get("depth", 16)) + converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth) # Create transformer with config print("Creating MirageTransformer2DModel...") @@ -156,28 +176,164 @@ def 
create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi return transformer -def copy_pipeline_components(vae_type: str, output_path: str): - """Copy VAE, scheduler, text encoder, and tokenizer from reference pipeline.""" + + +def create_scheduler_config(output_path: str): + """Create FlowMatchEulerDiscreteScheduler config.""" + + scheduler_config = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "num_train_timesteps": 1000, + "shift": 1.0 + } + + scheduler_path = os.path.join(output_path, "scheduler") + os.makedirs(scheduler_path, exist_ok=True) + + with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f: + json.dump(scheduler_config, f, indent=2) + + print("✓ Created scheduler config") + + +def create_vae_config(vae_type: str, output_path: str): + """Create VAE config based on type.""" if vae_type == "flux": - ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_fluxvae_gemmaT5_updated" + vae_config = { + "_class_name": "AutoencoderKL", + "latent_channels": 16, + "block_out_channels": [128, 256, 512, 512], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "scaling_factor": 0.3611, + "shift_factor": 0.1159, + "use_post_quant_conv": False, + "use_quant_conv": False + } else: # dc-ae - ref_pipeline = "/raid/shared/storage/home/davidb/diffusers/diffusers_pipeline_checkpoints/pipeline_checkpoint_dcae_gemmaT5_updated" + vae_config = { + "_class_name": "AutoencoderDC", + "latent_channels": 32, + "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], + "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], + "encoder_block_types": [ + "ResBlock", + "ResBlock", + "ResBlock", + "EfficientViTBlock", + "EfficientViTBlock", + "EfficientViTBlock" + ], + 
"decoder_block_types": [ + "ResBlock", + "ResBlock", + "ResBlock", + "EfficientViTBlock", + "EfficientViTBlock", + "EfficientViTBlock" + ], + "encoder_layers_per_block": [2, 2, 2, 3, 3, 3], + "decoder_layers_per_block": [3, 3, 3, 3, 3, 3], + "encoder_qkv_multiscales": [[], [], [], [5], [5], [5]], + "decoder_qkv_multiscales": [[], [], [], [5], [5], [5]], + "scaling_factor": 0.41407, + "upsample_block_type": "interpolate" + } + + vae_path = os.path.join(output_path, "vae") + os.makedirs(vae_path, exist_ok=True) + + with open(os.path.join(vae_path, "config.json"), "w") as f: + json.dump(vae_config, f, indent=2) + + print("✓ Created VAE config") + + +def create_text_encoder_config(output_path: str): + """Create T5GemmaEncoder config.""" + + text_encoder_config = { + "model_name": "google/t5gemma-2b-2b-ul2", + "model_max_length": 256, + "use_attn_mask": True, + "use_last_hidden_state": True + } - components = ["vae", "scheduler", "text_encoder", "tokenizer"] + text_encoder_path = os.path.join(output_path, "text_encoder") + os.makedirs(text_encoder_path, exist_ok=True) + + with open(os.path.join(text_encoder_path, "config.json"), "w") as f: + json.dump(text_encoder_config, f, indent=2) + + print("✓ Created text encoder config") + + +def create_tokenizer_config(output_path: str): + """Create GemmaTokenizerFast config and files.""" + + tokenizer_config = { + "add_bos_token": False, + "add_eos_token": False, + "added_tokens_decoder": { + "0": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, + "1": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, + "2": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, + "3": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, + "106": {"content": "", "lstrip": False, "normalized": False, "rstrip": 
False, "single_word": False, "special": True}, + "107": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True} + }, + "additional_special_tokens": ["", ""], + "bos_token": "", + "clean_up_tokenization_spaces": False, + "eos_token": "", + "extra_special_tokens": {}, + "model_max_length": 256, + "pad_token": "", + "padding_side": "right", + "sp_model_kwargs": {}, + "spaces_between_special_tokens": False, + "tokenizer_class": "GemmaTokenizer", + "unk_token": "", + "use_default_system_prompt": False + } - for component in components: - src_path = os.path.join(ref_pipeline, component) - dst_path = os.path.join(output_path, component) + special_tokens_map = { + "bos_token": "", + "eos_token": "", + "pad_token": "", + "unk_token": "" + } - if os.path.exists(src_path): - if os.path.isdir(src_path): - shutil.copytree(src_path, dst_path, dirs_exist_ok=True) - else: - shutil.copy2(src_path, dst_path) - print(f"✓ Copied {component}") - else: - print(f"⚠ Component not found: {src_path}") + tokenizer_path = os.path.join(output_path, "tokenizer") + os.makedirs(tokenizer_path, exist_ok=True) + + with open(os.path.join(tokenizer_path, "tokenizer_config.json"), "w") as f: + json.dump(tokenizer_config, f, indent=2) + + with open(os.path.join(tokenizer_path, "special_tokens_map.json"), "w") as f: + json.dump(special_tokens_map, f, indent=2) + + print("✓ Created tokenizer config (Note: tokenizer.json and tokenizer.model files need to be provided separately)") + + +def create_pipeline_components(vae_type: str, output_path: str): + """Create all pipeline components with proper configs.""" + + create_scheduler_config(output_path) + create_vae_config(vae_type, output_path) + create_text_encoder_config(output_path) + create_tokenizer_config(output_path) def create_model_index(vae_type: str, output_path: str): @@ -211,8 +367,7 @@ def main(args): if not os.path.exists(args.checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: 
{args.checkpoint_path}") - # Load reference config based on VAE type - config = load_reference_config(args.vae_type) + config = build_config(args.vae_type) # Create output directory os.makedirs(args.output_path, exist_ok=True) @@ -234,8 +389,8 @@ def main(args): save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors")) print(f"✓ Saved transformer to {transformer_path}") - # Copy other pipeline components - copy_pipeline_components(args.vae_type, args.output_path) + # Create other pipeline components + create_pipeline_components(args.vae_type, args.output_path) # Create model index create_model_index(args.vae_type, args.output_path) From 1c6c25cf1d9588c25362a997a9dfb2d886bf389c Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 22:30:58 +0200 Subject: [PATCH 26/38] remove old checkpoints dependency --- scripts/convert_mirage_to_diffusers.py | 170 ++---------------- .../pipelines/mirage/pipeline_mirage.py | 68 +++++-- 2 files changed, 63 insertions(+), 175 deletions(-) diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py index eb6de1a37481..2ddb708bc704 100644 --- a/scripts/convert_mirage_to_diffusers.py +++ b/scripts/convert_mirage_to_diffusers.py @@ -84,13 +84,6 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth mapping = create_parameter_mapping(depth) converted_state_dict = {} - # First, print available keys to understand structure - print("Available keys in checkpoint:") - for key in sorted(old_state_dict.keys())[:10]: # Show first 10 keys - print(f" {key}") - if len(old_state_dict) > 10: - print(f" ... 
and {len(old_state_dict) - 10} more") - for key, value in old_state_dict.items(): new_key = key @@ -196,172 +189,37 @@ def create_scheduler_config(output_path: str): print("✓ Created scheduler config") -def create_vae_config(vae_type: str, output_path: str): - """Create VAE config based on type.""" - - if vae_type == "flux": - vae_config = { - "_class_name": "AutoencoderKL", - "latent_channels": 16, - "block_out_channels": [128, 256, 512, 512], - "down_block_types": [ - "DownEncoderBlock2D", - "DownEncoderBlock2D", - "DownEncoderBlock2D", - "DownEncoderBlock2D" - ], - "up_block_types": [ - "UpDecoderBlock2D", - "UpDecoderBlock2D", - "UpDecoderBlock2D", - "UpDecoderBlock2D" - ], - "scaling_factor": 0.3611, - "shift_factor": 0.1159, - "use_post_quant_conv": False, - "use_quant_conv": False - } - else: # dc-ae - vae_config = { - "_class_name": "AutoencoderDC", - "latent_channels": 32, - "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], - "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024], - "encoder_block_types": [ - "ResBlock", - "ResBlock", - "ResBlock", - "EfficientViTBlock", - "EfficientViTBlock", - "EfficientViTBlock" - ], - "decoder_block_types": [ - "ResBlock", - "ResBlock", - "ResBlock", - "EfficientViTBlock", - "EfficientViTBlock", - "EfficientViTBlock" - ], - "encoder_layers_per_block": [2, 2, 2, 3, 3, 3], - "decoder_layers_per_block": [3, 3, 3, 3, 3, 3], - "encoder_qkv_multiscales": [[], [], [], [5], [5], [5]], - "decoder_qkv_multiscales": [[], [], [], [5], [5], [5]], - "scaling_factor": 0.41407, - "upsample_block_type": "interpolate" - } - - vae_path = os.path.join(output_path, "vae") - os.makedirs(vae_path, exist_ok=True) - - with open(os.path.join(vae_path, "config.json"), "w") as f: - json.dump(vae_config, f, indent=2) - - print("✓ Created VAE config") - - -def create_text_encoder_config(output_path: str): - """Create T5GemmaEncoder config.""" - - text_encoder_config = { - "model_name": "google/t5gemma-2b-2b-ul2", - 
"model_max_length": 256, - "use_attn_mask": True, - "use_last_hidden_state": True - } - - text_encoder_path = os.path.join(output_path, "text_encoder") - os.makedirs(text_encoder_path, exist_ok=True) - - with open(os.path.join(text_encoder_path, "config.json"), "w") as f: - json.dump(text_encoder_config, f, indent=2) - - print("✓ Created text encoder config") - - -def create_tokenizer_config(output_path: str): - """Create GemmaTokenizerFast config and files.""" - - tokenizer_config = { - "add_bos_token": False, - "add_eos_token": False, - "added_tokens_decoder": { - "0": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, - "1": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, - "2": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, - "3": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, - "106": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True}, - "107": {"content": "", "lstrip": False, "normalized": False, "rstrip": False, "single_word": False, "special": True} - }, - "additional_special_tokens": ["", ""], - "bos_token": "", - "clean_up_tokenization_spaces": False, - "eos_token": "", - "extra_special_tokens": {}, - "model_max_length": 256, - "pad_token": "", - "padding_side": "right", - "sp_model_kwargs": {}, - "spaces_between_special_tokens": False, - "tokenizer_class": "GemmaTokenizer", - "unk_token": "", - "use_default_system_prompt": False - } - - special_tokens_map = { - "bos_token": "", - "eos_token": "", - "pad_token": "", - "unk_token": "" - } - - tokenizer_path = os.path.join(output_path, "tokenizer") - os.makedirs(tokenizer_path, exist_ok=True) - - with open(os.path.join(tokenizer_path, "tokenizer_config.json"), "w") as f: - json.dump(tokenizer_config, f, 
indent=2) - - with open(os.path.join(tokenizer_path, "special_tokens_map.json"), "w") as f: - json.dump(special_tokens_map, f, indent=2) - - print("✓ Created tokenizer config (Note: tokenizer.json and tokenizer.model files need to be provided separately)") - - -def create_pipeline_components(vae_type: str, output_path: str): - """Create all pipeline components with proper configs.""" - - create_scheduler_config(output_path) - create_vae_config(vae_type, output_path) - create_text_encoder_config(output_path) - create_tokenizer_config(output_path) def create_model_index(vae_type: str, output_path: str): - """Create model_index.json for the pipeline.""" + """Create model_index.json for the pipeline with HuggingFace model references.""" if vae_type == "flux": - vae_class = "AutoencoderKL" + vae_model_name = "black-forest-labs/FLUX.1-dev" + vae_subfolder = "vae" else: # dc-ae - vae_class = "AutoencoderDC" + vae_model_name = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers" + vae_subfolder = None + + # Text encoder and tokenizer always use T5Gemma + text_model_name = "google/t5gemma-2b-2b-ul2" model_index = { "_class_name": "MiragePipeline", "_diffusers_version": "0.31.0.dev0", "_name_or_path": os.path.basename(output_path), "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], - "text_encoder": ["transformers", "T5GemmaEncoder"], - "tokenizer": ["transformers", "GemmaTokenizerFast"], + "text_encoder": text_model_name, + "tokenizer": text_model_name, "transformer": ["diffusers", "MirageTransformer2DModel"], - "vae": ["diffusers", vae_class], + "vae": vae_model_name, + "vae_subfolder": vae_subfolder, } model_index_path = os.path.join(output_path, "model_index.json") with open(model_index_path, "w") as f: json.dump(model_index, f, indent=2) - print("✓ Created model_index.json") - - def main(args): # Validate inputs if not os.path.exists(args.checkpoint_path): @@ -389,10 +247,8 @@ def main(args): save_file(state_dict, os.path.join(transformer_path, 
"diffusion_pytorch_model.safetensors")) print(f"✓ Saved transformer to {transformer_path}") - # Create other pipeline components - create_pipeline_components(args.vae_type, args.output_path) + create_scheduler_config(args.output_path) - # Create model index create_model_index(args.vae_type, args.output_path) # Verify the pipeline can be loaded diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py index c4a4783c5f38..e6a13ff226cd 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -247,26 +247,61 @@ class MiragePipeline( @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): """ - Override from_pretrained to ensure T5GemmaEncoder is available for loading. + Override from_pretrained to load VAE and text encoder from HuggingFace models. - This ensures that T5GemmaEncoder from transformers is accessible in the module namespace - during component loading, which is required for MiragePipeline checkpoints that use - T5GemmaEncoder as the text encoder. + The MiragePipeline checkpoints only store transformer and scheduler locally. + VAE and text encoder are loaded from external HuggingFace models as specified + in model_index.json. 
""" - # Ensure T5GemmaEncoder is available for loading - import transformers + import json + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel + + model_index_path = os.path.join(pretrained_model_name_or_path, "model_index.json") + if not os.path.exists(model_index_path): + raise ValueError(f"model_index.json not found in {pretrained_model_name_or_path}") + + with open(model_index_path, "r") as f: + model_index = json.load(f) + + vae_model_name = model_index.get("vae") + vae_subfolder = model_index.get("vae_subfolder") + text_model_name = model_index.get("text_encoder") + tokenizer_model_name = model_index.get("tokenizer") + + logger.info(f"Loading VAE from {vae_model_name}...") + if "FLUX" in vae_model_name or "flux" in vae_model_name: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder=vae_subfolder) + else: # DC-AE + vae = AutoencoderDC.from_pretrained(vae_model_name) + + logger.info(f"Loading text encoder from {text_model_name}...") + t5gemma_model = T5GemmaModel.from_pretrained(text_model_name) + text_encoder = t5gemma_model.encoder + + logger.info(f"Loading tokenizer from {tokenizer_model_name}...") + tokenizer = GemmaTokenizerFast.from_pretrained(tokenizer_model_name) + tokenizer.model_max_length = 256 + + # Load transformer and scheduler from local checkpoint + logger.info(f"Loading transformer from {pretrained_model_name_or_path}...") + transformer = MirageTransformer2DModel.from_pretrained( + pretrained_model_name_or_path, subfolder="transformer" + ) - if not hasattr(transformers, "T5GemmaEncoder"): - try: - from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + logger.info(f"Loading scheduler from {pretrained_model_name_or_path}...") + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + pretrained_model_name_or_path, subfolder="scheduler" + ) - transformers.T5GemmaEncoder = T5GemmaEncoder - except ImportError: - # T5GemmaEncoder not available in this transformers version - pass + pipeline = 
cls( + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + ) - # Proceed with standard loading - return super().from_pretrained(pretrained_model_name_or_path, **kwargs) + return pipeline def __init__( self, @@ -283,11 +318,8 @@ def __init__( "MirageTransformer2DModel is not available. Please ensure the transformer_mirage module is properly installed." ) - # Store standard components self.text_encoder = text_encoder self.tokenizer = tokenizer - - # Initialize text preprocessor self.text_preprocessor = TextPreprocessor() self.register_modules( From b0d965cc508b447569912f0747dfc5b4746e2d6b Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 20:56:51 +0000 Subject: [PATCH 27/38] move default height and width in checkpoint config --- scripts/convert_mirage_to_diffusers.py | 9 +++++++++ .../pipelines/mirage/pipeline_mirage.py | 16 ++++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_mirage_to_diffusers.py index 2ddb708bc704..37de253d1448 100644 --- a/scripts/convert_mirage_to_diffusers.py +++ b/scripts/convert_mirage_to_diffusers.py @@ -19,6 +19,9 @@ from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel from diffusers.pipelines.mirage import MiragePipeline +DEFAULT_HEIGHT = 512 +DEFAULT_WIDTH = 512 + @dataclass(frozen=True) class MirageBase: context_in_dim: int = 2304 @@ -197,9 +200,13 @@ def create_model_index(vae_type: str, output_path: str): if vae_type == "flux": vae_model_name = "black-forest-labs/FLUX.1-dev" vae_subfolder = "vae" + default_height = DEFAULT_HEIGHT + default_width = DEFAULT_WIDTH else: # dc-ae vae_model_name = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers" vae_subfolder = None + default_height = DEFAULT_HEIGHT + default_width = DEFAULT_WIDTH # Text encoder and tokenizer always use T5Gemma text_model_name = "google/t5gemma-2b-2b-ul2" @@ -214,6 +221,8 @@ def 
create_model_index(vae_type: str, output_path: str): "transformer": ["diffusers", "MirageTransformer2DModel"], "vae": vae_model_name, "vae_subfolder": vae_subfolder, + "default_height": default_height, + "default_width": default_width, } model_index_path = os.path.join(output_path, "model_index.json") diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py index e6a13ff226cd..9d247eecbd7f 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -31,6 +31,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderDC, AutoencoderKL +from ...models.transformers.transformer_mirage import seq2img from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( logging, @@ -46,6 +47,9 @@ except ImportError: MirageTransformer2DModel = None +DEFAULT_HEIGHT = 512 +DEFAULT_WIDTH = 512 + logger = logging.get_logger(__name__) @@ -267,6 +271,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P vae_subfolder = model_index.get("vae_subfolder") text_model_name = model_index.get("text_encoder") tokenizer_model_name = model_index.get("tokenizer") + default_height = model_index.get("default_height", DEFAULT_HEIGHT) + default_width = model_index.get("default_width", DEFAULT_WIDTH) logger.info(f"Loading VAE from {vae_model_name}...") if "FLUX" in vae_model_name or "flux" in vae_model_name: @@ -301,6 +307,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P vae=vae, ) + # Store default dimensions as pipeline attributes + pipeline.default_height = default_height + pipeline.default_width = default_width + return pipeline def __init__( @@ -558,8 +568,8 @@ def __call__( """ # 0. 
Default height and width to transformer config - height = height or 256 - width = width or 256 + height = height or getattr(self, 'default_height', DEFAULT_HEIGHT) + width = width or getattr(self, 'default_width', DEFAULT_WIDTH) # 1. Check inputs self.check_inputs( @@ -642,8 +652,6 @@ def __call__( ) # Convert back to image format - from ...models.transformers.transformer_mirage import seq2img - noise_both = seq2img(img_seq, self.transformer.patch_size, latents_in.shape) # Apply CFG From 235fe491fe46dcfca7da299a57b48adbe9ad1c96 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 21:26:03 +0000 Subject: [PATCH 28/38] add docstrings --- .../models/transformers/transformer_mirage.py | 367 +++++++++++++++++- .../test_models_transformer_mirage.py | 6 +- 2 files changed, 351 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index 923d44d4f1ec..c509f797fb8b 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -33,20 +33,70 @@ logger = logging.get_logger(__name__) -def get_image_ids(bs: int, h: int, w: int, patch_size: int, device: torch.device) -> Tensor: - img_ids = torch.zeros(h // patch_size, w // patch_size, 2, device=device) - img_ids[..., 0] = torch.arange(h // patch_size, device=device)[:, None] - img_ids[..., 1] = torch.arange(w // patch_size, device=device)[None, :] - return img_ids.reshape((h // patch_size) * (w // patch_size), 2).unsqueeze(0).repeat(bs, 1, 1) +def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> Tensor: + r""" + Generates 2D patch coordinate indices for a batch of images. + + Parameters: + batch_size (`int`): + Number of images in the batch. + height (`int`): + Height of the input images (in pixels). + width (`int`): + Width of the input images (in pixels). 
+ patch_size (`int`): + Size of the square patches that the image is divided into. + device (`torch.device`): + The device on which to create the tensor. + + Returns: + `torch.Tensor`: + Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) + coordinates of each patch in the image grid. + """ + + img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device) + img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None] + img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :] + return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1) def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: + r""" + Applies rotary positional embeddings (RoPE) to a query tensor. + + Parameters: + xq (`torch.Tensor`): + Input tensor of shape `(..., dim)` representing the queries. + freqs_cis (`torch.Tensor`): + Precomputed rotary frequency components of shape `(..., dim/2, 2)` + containing cosine and sine pairs. + + Returns: + `torch.Tensor`: + Tensor of the same shape as `xq` with rotary embeddings applied. + """ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq) class EmbedND(nn.Module): + r""" + N-dimensional rotary positional embedding. + + This module creates rotary embeddings (RoPE) across multiple axes, where each + axis can have its own embedding dimension. The embeddings are combined and + returned as a single tensor + + Parameters: + dim (int): + Base embedding dimension (must be even). + theta (int): + Scaling factor that controls the frequency spectrum of the rotary embeddings. + axes_dim (list[int]): + List of embedding dimensions for each axis (each must be even). 
+ """ def __init__(self, dim: int, theta: int, axes_dim: list[int]): super().__init__() self.dim = dim @@ -73,6 +123,19 @@ def forward(self, ids: Tensor) -> Tensor: class MLPEmbedder(nn.Module): + r""" + A simple 2-layer MLP used for embedding inputs. + + Parameters: + in_dim (`int`): + Dimensionality of the input features. + hidden_dim (`int`): + Dimensionality of the hidden and output embedding space. + + Returns: + `torch.Tensor`: + Tensor of shape `(..., hidden_dim)` containing the embedded representations. + """ def __init__(self, in_dim: int, hidden_dim: int): super().__init__() self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) @@ -84,6 +147,19 @@ def forward(self, x: Tensor) -> Tensor: class QKNorm(torch.nn.Module): + r""" + Applies RMS normalization to query and key tensors separately before attention + which can help stabilize training and improve numerical precision. + + Parameters: + dim (`int`): + Dimensionality of the query and key vectors. + + Returns: + (`torch.Tensor`, `torch.Tensor`): + A tuple `(q, k)` where both are normalized and cast to the same dtype + as the value tensor `v`. + """ def __init__(self, dim: int): super().__init__() self.query_norm = RMSNorm(dim, eps=1e-6) @@ -103,6 +179,22 @@ class ModulationOut: class Modulation(nn.Module): + r""" + Modulation network that generates scale, shift, and gating parameters. + + Given an input vector, the module projects it through a linear layer to + produce six chunks, which are grouped into two `ModulationOut` objects. + + Parameters: + dim (`int`): + Dimensionality of the input vector. The output will have `6 * dim` + features internally. + + Returns: + (`ModulationOut`, `ModulationOut`): + A tuple of two modulation outputs. Each `ModulationOut` contains + three components (e.g., scale, shift, gate). 
+ """ def __init__(self, dim: int): super().__init__() self.lin = nn.Linear(dim, 6 * dim, bias=True) @@ -115,6 +207,68 @@ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]: class MirageBlock(nn.Module): + r""" + Multimodal transformer block with text–image cross-attention, modulation, and MLP. + + Parameters: + hidden_size (`int`): + Dimension of the hidden representations. + num_heads (`int`): + Number of attention heads. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Expansion ratio for the hidden dimension inside the MLP. + qk_scale (`float`, *optional*): + Scale factor for queries and keys. If not provided, defaults to + ``head_dim**-0.5``. + + Attributes: + img_pre_norm (`nn.LayerNorm`): + Pre-normalization applied to image tokens before QKV projection. + img_qkv_proj (`nn.Linear`): + Linear projection to produce image queries, keys, and values. + qk_norm (`QKNorm`): + RMS normalization applied separately to image queries and keys. + txt_kv_proj (`nn.Linear`): + Linear projection to produce text keys and values. + k_norm (`RMSNorm`): + RMS normalization applied to text keys. + attention (`Attention`): + Multi-head attention module for cross-attention between image, text, + and optional spatial conditioning tokens. + post_attention_layernorm (`nn.LayerNorm`): + Normalization applied after attention. + gate_proj / up_proj / down_proj (`nn.Linear`): + Feedforward layers forming the gated MLP. + mlp_act (`nn.GELU`): + Nonlinear activation used in the MLP. + modulation (`Modulation`): + Produces scale/shift/gating parameters for modulated layers. + spatial_cond_kv_proj (`nn.Linear`, *optional*): + Projection for optional spatial conditioning tokens. + + Methods: + attn_forward(img, txt, pe, modulation, spatial_conditioning=None, attention_mask=None): + Compute cross-attention between image and text tokens, with optional + spatial conditioning and attention masking. 
+ + Parameters: + img (`torch.Tensor`): + Image tokens of shape `(B, L_img, hidden_size)`. + txt (`torch.Tensor`): + Text tokens of shape `(B, L_txt, hidden_size)`. + pe (`torch.Tensor`): + Rotary positional embeddings to apply to queries and keys. + modulation (`ModulationOut`): + Scale and shift parameters for modulating image tokens. + spatial_conditioning (`torch.Tensor`, *optional*): + Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. + attention_mask (`torch.Tensor`, *optional*): + Boolean mask of shape `(B, L_txt)` where 0 marks padding. + + Returns: + `torch.Tensor`: + Attention output of shape `(B, L_img, hidden_size)`. + """ def __init__( self, hidden_size: int, @@ -163,7 +317,7 @@ def __init__( self.modulation = Modulation(hidden_size) self.spatial_cond_kv_proj: None | nn.Linear = None - def attn_forward( + def _attn_forward( self, img: Tensor, txt: Tensor, @@ -236,7 +390,7 @@ def attn_forward( return attn - def ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor: + def _ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor: x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)) @@ -250,9 +404,36 @@ def forward( attention_mask: Tensor | None = None, **_: dict[str, Any], ) -> Tensor: + r""" + Runs modulation-gated cross-attention and MLP, with residual connections. + + Parameters: + img (`torch.Tensor`): + Image tokens of shape `(B, L_img, hidden_size)`. + txt (`torch.Tensor`): + Text tokens of shape `(B, L_txt, hidden_size)`. + vec (`torch.Tensor`): + Conditioning vector used by `Modulation` to produce scale/shift/gates, + shape `(B, hidden_size)` (or broadcastable). + pe (`torch.Tensor`): + Rotary positional embeddings applied inside attention. + spatial_conditioning (`torch.Tensor`, *optional*): + Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. 
Used only + if spatial conditioning is enabled in the block. + attention_mask (`torch.Tensor`, *optional*): + Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding. + **_: + Ignored additional keyword arguments for API compatibility. + + Returns: + `torch.Tensor`: + Updated image tokens of shape `(B, L_img, hidden_size)`. + """ + + mod_attn, mod_mlp = self.modulation(vec) - img = img + mod_attn.gate * self.attn_forward( + img = img + mod_attn.gate * self._attn_forward( img, txt, pe, @@ -260,12 +441,39 @@ def forward( spatial_conditioning=spatial_conditioning, attention_mask=attention_mask, ) - img = img + mod_mlp.gate * self.ffn_forward(img, mod_mlp) + img = img + mod_mlp.gate * self._ffn_forward(img, mod_mlp) return img class LastLayer(nn.Module): + r""" + Final projection layer with adaptive LayerNorm modulation. + + This layer applies a normalized and modulated transformation to input tokens + and projects them into patch-level outputs. + + Parameters: + hidden_size (`int`): + Dimensionality of the input tokens. + patch_size (`int`): + Size of the square image patches. + out_channels (`int`): + Number of output channels per pixel (e.g. RGB = 3). + + Forward Inputs: + x (`torch.Tensor`): + Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches. + vec (`torch.Tensor`): + Conditioning vector of shape `(B, hidden_size)` used to generate + shift and scale parameters for adaptive LayerNorm. + + Returns: + `torch.Tensor`: + Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`. 
+ """ + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) @@ -284,12 +492,41 @@ def forward(self, x: Tensor, vec: Tensor) -> Tensor: def img2seq(img: Tensor, patch_size: int) -> Tensor: - """Flatten an image into a sequence of patches""" + r""" + Flattens an image tensor into a sequence of non-overlapping patches. + + Parameters: + img (`torch.Tensor`): + Input image tensor of shape `(B, C, H, W)`. + patch_size (`int`): + Size of each square patch. Must evenly divide both `H` and `W`. + + Returns: + `torch.Tensor`: + Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, + where `L = (H // patch_size) * (W // patch_size)` is the number of patches. + """ return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: - """Revert img2seq""" + r""" + Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`). + + Parameters: + seq (`torch.Tensor`): + Patch sequence of shape `(B, L, C * patch_size * patch_size)`, + where `L = (H // patch_size) * (W // patch_size)`. + patch_size (`int`): + Size of each square patch. + shape (`tuple` or `torch.Tensor`): + The original image spatial shape `(H, W)`. If a tensor is provided, + the first two values are interpreted as height and width. + + Returns: + `torch.Tensor`: + Reconstructed image tensor of shape `(B, C, H, W)`. + """ if isinstance(shape, tuple): shape = shape[-2:] elif isinstance(shape, torch.Tensor): @@ -300,7 +537,70 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: class MirageTransformer2DModel(ModelMixin, ConfigMixin): - """Mirage Transformer model with IP-Adapter support.""" + r""" + Transformer-based 2D model for text to image generation. 
+ It supports attention processor injection and LoRA scaling. + + Parameters: + in_channels (`int`, *optional*, defaults to 16): + Number of input channels in the latent image. + patch_size (`int`, *optional*, defaults to 2): + Size of the square patches used to flatten the input image. + context_in_dim (`int`, *optional*, defaults to 2304): + Dimensionality of the text conditioning input. + hidden_size (`int`, *optional*, defaults to 1792): + Dimension of the hidden representation. + mlp_ratio (`float`, *optional*, defaults to 3.5): + Expansion ratio for the hidden dimension inside MLP blocks. + num_heads (`int`, *optional*, defaults to 28): + Number of attention heads. + depth (`int`, *optional*, defaults to 16): + Number of transformer blocks. + axes_dim (`list[int]`, *optional*): + List of dimensions for each positional embedding axis. Defaults to `[32, 32]`. + theta (`int`, *optional*, defaults to 10000): + Frequency scaling factor for rotary embeddings. + time_factor (`float`, *optional*, defaults to 1000.0): + Scaling factor applied in timestep embeddings. + time_max_period (`int`, *optional*, defaults to 10000): + Maximum frequency period for timestep embeddings. + conditioning_block_ids (`list[int]`, *optional*): + Indices of blocks that receive conditioning. Defaults to all blocks. + **kwargs: + Additional keyword arguments forwarded to the config. + + Attributes: + pe_embedder (`EmbedND`): + Multi-axis rotary embedding generator for positional encodings. + img_in (`nn.Linear`): + Projection layer for image patch tokens. + time_in (`MLPEmbedder`): + Embedding layer for timestep embeddings. + txt_in (`nn.Linear`): + Projection layer for text conditioning. + blocks (`nn.ModuleList`): + Stack of transformer blocks (`MirageBlock`). + final_layer (`LastLayer`): + Projection layer mapping hidden tokens back to patch outputs. + + Methods: + attn_processors: + Returns a dictionary of all attention processors in the model. 
+ set_attn_processor(processor): + Replaces attention processors across all attention layers. + process_inputs(image_latent, txt): + Converts inputs into patch tokens, encodes text, and produces positional encodings. + compute_timestep_embedding(timestep, dtype): + Creates a timestep embedding of dimension 256, scaled and projected. + forward_transformers(image_latent, cross_attn_conditioning, timestep, time_embedding, attention_mask, **block_kwargs): + Runs the sequence of transformer blocks over image and text tokens. + forward(image_latent, timestep, cross_attn_conditioning, micro_conditioning, cross_attn_mask=None, attention_kwargs=None, return_dict=True): + Full forward pass from latent input to reconstructed output image. + + Returns: + `Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing: + - `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`. + """ config_name = "config.json" _supports_gradient_checkpointing = True @@ -424,8 +724,8 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) - def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: - """Timestep independent stuff""" + def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: + txt = self.txt_in(txt) img = img2seq(image_latent, self.patch_size) bs, _, h, w = image_latent.shape @@ -433,7 +733,7 @@ def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[T pe = self.pe_embedder(img_ids) return img, txt, pe - def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: + def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: return self.time_in( get_timestep_embedding( timesteps=timestep, @@ -444,7 +744,7 @@ def compute_timestep_embedding(self, 
timestep: Tensor, dtype: torch.dtype) -> Te ).to(dtype) ) - def forward_transformers( + def _forward_transformers( self, image_latent: Tensor, cross_attn_conditioning: Tensor, @@ -460,7 +760,7 @@ def forward_transformers( else: if timestep is None: raise ValueError("Please provide either a timestep or a timestep_embedding") - vec = self.compute_timestep_embedding(timestep, dtype=img.dtype) + vec = self._compute_timestep_embedding(timestep, dtype=img.dtype) for block in self.blocks: img = block(img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs) @@ -478,6 +778,35 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: + r""" + Forward pass of the MirageTransformer2DModel. + + The latent image is split into patch tokens, combined with text conditioning, + and processed through a stack of transformer blocks modulated by the timestep. + The output is reconstructed into the latent image space. + + Parameters: + image_latent (`torch.Tensor`): + Input latent image tensor of shape `(B, C, H, W)`. + timestep (`torch.Tensor`): + Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning. + cross_attn_conditioning (`torch.Tensor`): + Text conditioning tensor of shape `(B, L_txt, context_in_dim)`. + micro_conditioning (`torch.Tensor`): + Extra conditioning vector (currently unused, reserved for future use). + cross_attn_mask (`torch.Tensor`, *optional*): + Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence. + attention_kwargs (`dict`, *optional*): + Additional arguments passed to attention layers. If using the PEFT backend, + the key `"scale"` controls LoRA scaling (default: 1.0). + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `Transformer2DModelOutput` or a tuple. 
+ + Returns: + `Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple: + + - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`. + """ if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -491,8 +820,8 @@ def forward( logger.warning( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) - img_seq, txt, pe = self.process_inputs(image_latent, cross_attn_conditioning) - img_seq = self.forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask) + img_seq, txt, pe = self._process_inputs(image_latent, cross_attn_conditioning) + img_seq = self._forward_transformers(img_seq, txt, timestep, pe=pe, attention_mask=cross_attn_mask) output = seq2img(img_seq, self.patch_size, image_latent.shape) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_mirage.py index 0085627aa7e4..fe7436debc4c 100644 --- a/tests/models/transformers/test_models_transformer_mirage.py +++ b/tests/models/transformers/test_models_transformer_mirage.py @@ -132,7 +132,7 @@ def test_process_inputs(self): model.eval() with torch.no_grad(): - img_seq, txt, pe = model.process_inputs( + img_seq, txt, pe = model._process_inputs( inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"] ) @@ -160,12 +160,12 @@ def test_forward_transformers(self): with torch.no_grad(): # Process inputs first - img_seq, txt, pe = model.process_inputs( + img_seq, txt, pe = model._process_inputs( inputs_dict["image_latent"], inputs_dict["cross_attn_conditioning"] ) # Test forward_transformers - output_seq = model.forward_transformers(img_seq, txt, timestep=inputs_dict["timestep"], pe=pe) + output_seq = model._forward_transformers(img_seq, txt, timestep=inputs_dict["timestep"], pe=pe) # Check output shape expected_out_channels 
= init_dict["in_channels"] * init_dict["patch_size"] ** 2 From a6ff5799588b0ec3181a1f133e99bde4e31077cb Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 21:28:43 +0000 Subject: [PATCH 29/38] if conditions and raised as ValueError instead of asserts --- .../models/transformers/transformer_mirage.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_mirage.py index c509f797fb8b..90ba11fb2d24 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ b/src/diffusers/models/transformers/transformer_mirage.py @@ -360,10 +360,12 @@ def _attn_forward( bs, _, l_img, _ = img_q.shape l_txt = txt_k.shape[2] - assert attention_mask.dim() == 2, f"Unsupported attention_mask shape: {attention_mask.shape}" - assert attention_mask.shape[-1] == l_txt, ( - f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" - ) + if attention_mask.dim() != 2: + raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") + if attention_mask.shape[-1] != l_txt: + raise ValueError( + f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" + ) device = img_q.device From 3a915039a14ae793f2debaff06d9ab21b5e8a23d Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 21:33:28 +0000 Subject: [PATCH 30/38] small fix --- src/diffusers/pipelines/mirage/pipeline_mirage.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py index 9d247eecbd7f..50304ae1a3ad 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -640,13 +640,13 @@ def __call__( t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device) # Process inputs for transformer - 
img_seq, txt, pe = self.transformer.process_inputs(latents_in, ca_embed) + img_seq, txt, pe = self.transformer._process_inputs(latents_in, ca_embed) # Forward through transformer layers - img_seq = self.transformer.forward_transformers( + img_seq = self.transformer._forward_transformers( img_seq, txt, - time_embedding=self.transformer.compute_timestep_embedding(t_cont, img_seq.dtype), + time_embedding=self.transformer._compute_timestep_embedding(t_cont, img_seq.dtype), pe=pe, attention_mask=ca_mask, ) From e200cf64f4bc91dbe03d61149c7fd0d87b3c6659 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 21:34:47 +0000 Subject: [PATCH 31/38] nit remove try block at import --- src/diffusers/pipelines/mirage/pipeline_mirage.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/mirage/pipeline_mirage.py index 50304ae1a3ad..ced78adec786 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/mirage/pipeline_mirage.py @@ -31,7 +31,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderDC, AutoencoderKL -from ...models.transformers.transformer_mirage import seq2img +from ...models.transformers.transformer_mirage import MirageTransformer2DModel, seq2img from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( logging, @@ -42,11 +42,6 @@ from .pipeline_output import MiragePipelineOutput -try: - from ...models.transformers.transformer_mirage import MirageTransformer2DModel -except ImportError: - MirageTransformer2DModel = None - DEFAULT_HEIGHT = 512 DEFAULT_WIDTH = 512 From 2ea8976a82bf25b0bd4ac4d5b8a20b3bc2a24b41 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 30 Sep 2025 21:35:16 +0000 Subject: [PATCH 32/38] mirage pipeline doc --- docs/source/en/api/pipelines/mirage.md | 158 
+++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 docs/source/en/api/pipelines/mirage.md diff --git a/docs/source/en/api/pipelines/mirage.md b/docs/source/en/api/pipelines/mirage.md new file mode 100644 index 000000000000..3383bbecae2a --- /dev/null +++ b/docs/source/en/api/pipelines/mirage.md @@ -0,0 +1,158 @@ + + +# MiragePipeline + +
+ LoRA +
+ +Mirage is a text-to-image diffusion model using a transformer-based architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports both Flux VAE (AutoencoderKL) and DC-AE (AutoencoderDC) for latent compression. + +Key features: + +- **Transformer Architecture**: Uses a modern transformer-based denoising model with attention mechanisms optimized for image generation +- **Flow Matching**: Employs flow matching with Euler discrete scheduling for efficient sampling +- **Flexible VAE Support**: Compatible with both Flux VAE (8x compression, 16 latent channels) and DC-AE (32x compression, 32 latent channels) +- **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding with strong text-image alignment +- **Efficient Architecture**: ~1.3B parameters in the transformer, enabling fast inference while maintaining quality +- **Modular Design**: Text encoder and VAE weights are loaded from HuggingFace, keeping checkpoint sizes small + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +## Loading the Pipeline + +Mirage checkpoints only store the transformer and scheduler weights locally. 
The VAE and text encoder are automatically loaded from HuggingFace during pipeline initialization: + +```py +from diffusers import MiragePipeline + +# Load pipeline - VAE and text encoder will be loaded from HuggingFace +pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint") +pipe.to("cuda") + +prompt = "A digital painting of a rusty, vintage tram on a sandy beach" +image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] +image.save("mirage_output.png") +``` + +### Manual Component Loading + +You can also load components individually: + +```py +import torch +from diffusers import MiragePipeline +from diffusers.models import AutoencoderKL, AutoencoderDC +from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from transformers import T5GemmaModel, GemmaTokenizerFast + +# Load transformer +transformer = MirageTransformer2DModel.from_pretrained( + "path/to/checkpoint", subfolder="transformer" +) + +# Load scheduler +scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + "path/to/checkpoint", subfolder="scheduler" +) + +# Load T5Gemma text encoder +t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") +text_encoder = t5gemma_model.encoder +tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") + +# Load VAE - choose either Flux VAE or DC-AE +# Flux VAE (16 latent channels): +vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") +# Or DC-AE (32 latent channels): +# vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers") + +pipe = MiragePipeline( + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae +) +pipe.to("cuda") +``` + +## VAE Variants + +Mirage supports two VAE configurations: + +### Flux VAE (AutoencoderKL) +- **Compression**: 8x spatial compression +- **Latent channels**: 16 
+- **Model**: `black-forest-labs/FLUX.1-dev` (subfolder: "vae") +- **Use case**: Balanced quality and speed + +### DC-AE (AutoencoderDC) +- **Compression**: 32x spatial compression +- **Latent channels**: 32 +- **Model**: `mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers` +- **Use case**: Higher compression for faster processing + +The VAE type is automatically determined from the checkpoint's `model_index.json` configuration. + +## Generation Parameters + +Key parameters for image generation: + +- **num_inference_steps**: Number of denoising steps (default: 28). More steps generally improve quality at the cost of speed. +- **guidance_scale**: Classifier-free guidance strength (default: 4.0). Higher values produce images more closely aligned with the prompt. +- **height/width**: Output image dimensions (default: 512x512). Can be customized in the checkpoint configuration. + +```py +# Example with custom parameters +image = pipe( + prompt="A serene mountain landscape at sunset", + num_inference_steps=28, + guidance_scale=4.0, + height=1024, + width=1024, + generator=torch.Generator("cuda").manual_seed(42) +).images[0] +``` + +## Memory Optimization + +For memory-constrained environments: + +```py +import torch +from diffusers import MiragePipeline + +pipe = MiragePipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.float16) +pipe.enable_model_cpu_offload() # Offload components to CPU when not in use + +# Or use sequential CPU offload for even lower memory +pipe.enable_sequential_cpu_offload() +``` + +## MiragePipeline + +[[autodoc]] MiragePipeline + - all + - __call__ + +## MiragePipelineOutput + +[[autodoc]] pipelines.mirage.pipeline_output.MiragePipelineOutput From 26429a370a34384562d772e9011cafc13bab8c07 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 7 Oct 2025 14:20:17 +0000 Subject: [PATCH 33/38] update doc --- docs/source/en/api/pipelines/mirage.md | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git 
a/docs/source/en/api/pipelines/mirage.md b/docs/source/en/api/pipelines/mirage.md index 3383bbecae2a..f0117795a989 100644 --- a/docs/source/en/api/pipelines/mirage.md +++ b/docs/source/en/api/pipelines/mirage.md @@ -22,18 +22,12 @@ Mirage is a text-to-image diffusion model using a transformer-based architecture Key features: -- **Transformer Architecture**: Uses a modern transformer-based denoising model with attention mechanisms optimized for image generation -- **Flow Matching**: Employs flow matching with Euler discrete scheduling for efficient sampling +- **Simplified MMDIT architecture**: Uses a simplified MMDIT architecture for image generation where text tokens are not updated through the transformer blocks +- **Flow Matching**: Employs flow matching with discrete scheduling for efficient sampling - **Flexible VAE Support**: Compatible with both Flux VAE (8x compression, 16 latent channels) and DC-AE (32x compression, 32 latent channels) -- **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding with strong text-image alignment +- **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding offering multiple language support - **Efficient Architecture**: ~1.3B parameters in the transformer, enabling fast inference while maintaining quality -- **Modular Design**: Text encoder and VAE weights are loaded from HuggingFace, keeping checkpoint sizes small - - -Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. 
- - ## Loading the Pipeline @@ -46,7 +40,7 @@ from diffusers import MiragePipeline pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint") pipe.to("cuda") -prompt = "A digital painting of a rusty, vintage tram on a sandy beach" +prompt = "A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “Photon” in bright, sparkling light" image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] image.save("mirage_output.png") ``` @@ -123,11 +117,11 @@ Key parameters for image generation: ```py # Example with custom parameters image = pipe( - prompt="A serene mountain landscape at sunset", + prompt="A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “Photon” in bright, sparkling light", num_inference_steps=28, guidance_scale=4.0, - height=1024, - width=1024, + height=512, + width=512, generator=torch.Generator("cuda").manual_seed(42) ).images[0] ``` From 0abe136648c938b690ec49eb2aa09fec5a75f318 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Tue, 7 Oct 2025 14:25:27 +0000 Subject: [PATCH 34/38] rename model to photon --- .../en/api/pipelines/{mirage.md => photon.md} | 34 ++++++++-------- ...sers.py => convert_photon_to_diffusers.py} | 34 ++++++++-------- src/diffusers/__init__.py | 2 +- src/diffusers/models/__init__.py | 2 +- src/diffusers/models/attention_processor.py | 20 +++++----- src/diffusers/models/transformers/__init__.py | 2 +- ...former_mirage.py => transformer_photon.py} | 14 +++---- src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/mirage/__init__.py | 5 --- src/diffusers/pipelines/photon/__init__.py | 5 +++ .../{mirage => photon}/pipeline_output.py | 4 +- .../pipeline_photon.py} | 40 +++++++++---------- ...e.py => test_models_transformer_photon.py} | 14 +++---- 13 files changed, 89 insertions(+), 89 deletions(-) rename docs/source/en/api/pipelines/{mirage.md => photon.md} (86%) rename 
scripts/{convert_mirage_to_diffusers.py => convert_photon_to_diffusers.py} (92%) rename src/diffusers/models/transformers/{transformer_mirage.py => transformer_photon.py} (99%) delete mode 100644 src/diffusers/pipelines/mirage/__init__.py create mode 100644 src/diffusers/pipelines/photon/__init__.py rename src/diffusers/pipelines/{mirage => photon}/pipeline_output.py (93%) rename src/diffusers/pipelines/{mirage/pipeline_mirage.py => photon/pipeline_photon.py} (95%) rename tests/models/transformers/{test_models_transformer_mirage.py => test_models_transformer_photon.py} (95%) diff --git a/docs/source/en/api/pipelines/mirage.md b/docs/source/en/api/pipelines/photon.md similarity index 86% rename from docs/source/en/api/pipelines/mirage.md rename to docs/source/en/api/pipelines/photon.md index f0117795a989..f8f7098545f8 100644 --- a/docs/source/en/api/pipelines/mirage.md +++ b/docs/source/en/api/pipelines/photon.md @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. --> -# MiragePipeline +# PhotonPipeline
LoRA
-Mirage is a text-to-image diffusion model using a transformer-based architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports both Flux VAE (AutoencoderKL) and DC-AE (AutoencoderDC) for latent compression. +Photon is a text-to-image diffusion model using a transformer-based architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports both Flux VAE (AutoencoderKL) and DC-AE (AutoencoderDC) for latent compression. Key features: @@ -31,18 +31,18 @@ Key features: ## Loading the Pipeline -Mirage checkpoints only store the transformer and scheduler weights locally. The VAE and text encoder are automatically loaded from HuggingFace during pipeline initialization: +Photon checkpoints only store the transformer and scheduler weights locally. The VAE and text encoder are automatically loaded from HuggingFace during pipeline initialization: ```py -from diffusers import MiragePipeline +from diffusers import PhotonPipeline # Load pipeline - VAE and text encoder will be loaded from HuggingFace -pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint") +pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint") pipe.to("cuda") prompt = "A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “Photon” in bright, sparkling light" image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] -image.save("mirage_output.png") +image.save("photon_output.png") ``` ### Manual Component Loading @@ -51,14 +51,14 @@ You can also load components individually: ```py import torch -from diffusers import MiragePipeline +from diffusers import PhotonPipeline from diffusers.models import AutoencoderKL, AutoencoderDC -from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel +from diffusers.models.transformers.transformer_photon import 
PhotonTransformer2DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from transformers import T5GemmaModel, GemmaTokenizerFast # Load transformer -transformer = MirageTransformer2DModel.from_pretrained( +transformer = PhotonTransformer2DModel.from_pretrained( "path/to/checkpoint", subfolder="transformer" ) @@ -78,7 +78,7 @@ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="v # Or DC-AE (32 latent channels): # vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers") -pipe = MiragePipeline( +pipe = PhotonPipeline( transformer=transformer, scheduler=scheduler, text_encoder=text_encoder, @@ -90,7 +90,7 @@ pipe.to("cuda") ## VAE Variants -Mirage supports two VAE configurations: +Photon supports two VAE configurations: ### Flux VAE (AutoencoderKL) - **Compression**: 8x spatial compression @@ -132,21 +132,21 @@ For memory-constrained environments: ```py import torch -from diffusers import MiragePipeline +from diffusers import PhotonPipeline -pipe = MiragePipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.float16) +pipe = PhotonPipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.float16) pipe.enable_model_cpu_offload() # Offload components to CPU when not in use # Or use sequential CPU offload for even lower memory pipe.enable_sequential_cpu_offload() ``` -## MiragePipeline +## PhotonPipeline -[[autodoc]] MiragePipeline +[[autodoc]] PhotonPipeline - all - __call__ -## MiragePipelineOutput +## PhotonPipelineOutput -[[autodoc]] pipelines.mirage.pipeline_output.MiragePipelineOutput +[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput diff --git a/scripts/convert_mirage_to_diffusers.py b/scripts/convert_photon_to_diffusers.py similarity index 92% rename from scripts/convert_mirage_to_diffusers.py rename to scripts/convert_photon_to_diffusers.py index 37de253d1448..ad04463e019f 100644 --- a/scripts/convert_mirage_to_diffusers.py +++ 
b/scripts/convert_photon_to_diffusers.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """ -Script to convert Mirage checkpoint from original codebase to diffusers format. +Script to convert Photon checkpoint from original codebase to diffusers format. """ import argparse @@ -16,14 +16,14 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) -from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel -from diffusers.pipelines.mirage import MiragePipeline +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.pipelines.photon import PhotonPipeline DEFAULT_HEIGHT = 512 DEFAULT_WIDTH = 512 @dataclass(frozen=True) -class MirageBase: +class PhotonBase: context_in_dim: int = 2304 hidden_size: int = 1792 mlp_ratio: float = 3.5 @@ -36,22 +36,22 @@ class MirageBase: @dataclass(frozen=True) -class MirageFlux(MirageBase): +class PhotonFlux(PhotonBase): in_channels: int = 16 patch_size: int = 2 @dataclass(frozen=True) -class MirageDCAE(MirageBase): +class PhotonDCAE(PhotonBase): in_channels: int = 32 patch_size: int = 1 def build_config(vae_type: str) -> dict: if vae_type == "flux": - cfg = MirageFlux() + cfg = PhotonFlux() elif vae_type == "dc-ae": - cfg = MirageDCAE() + cfg = PhotonDCAE() else: raise ValueError(f"Unsupported VAE type: {vae_type}. 
Use 'flux' or 'dc-ae'") @@ -125,8 +125,8 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth return converted_state_dict -def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> MirageTransformer2DModel: - """Create and load MirageTransformer2DModel from old checkpoint.""" +def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel: + """Create and load PhotonTransformer2DModel from old checkpoint.""" print(f"Loading checkpoint from: {checkpoint_path}") @@ -154,8 +154,8 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Mi converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth) # Create transformer with config - print("Creating MirageTransformer2DModel...") - transformer = MirageTransformer2DModel(**config) + print("Creating PhotonTransformer2DModel...") + transformer = PhotonTransformer2DModel(**config) # Load state dict print("Loading converted parameters...") @@ -212,13 +212,13 @@ def create_model_index(vae_type: str, output_path: str): text_model_name = "google/t5gemma-2b-2b-ul2" model_index = { - "_class_name": "MiragePipeline", + "_class_name": "PhotonPipeline", "_diffusers_version": "0.31.0.dev0", "_name_or_path": os.path.basename(output_path), "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], "text_encoder": text_model_name, "tokenizer": text_model_name, - "transformer": ["diffusers", "MirageTransformer2DModel"], + "transformer": ["diffusers", "PhotonTransformer2DModel"], "vae": vae_model_name, "vae_subfolder": vae_subfolder, "default_height": default_height, @@ -262,7 +262,7 @@ def main(args): # Verify the pipeline can be loaded try: - pipeline = MiragePipeline.from_pretrained(args.output_path) + pipeline = PhotonPipeline.from_pretrained(args.output_path) print("Pipeline loaded successfully!") print(f"Transformer: {type(pipeline.transformer).__name__}") print(f"VAE: 
{type(pipeline.vae).__name__}") @@ -285,10 +285,10 @@ def main(args): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Convert Mirage checkpoint to diffusers format") + parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format") parser.add_argument( - "--checkpoint_path", type=str, required=True, help="Path to the original Mirage checkpoint (.pth file)" + "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)" ) parser.add_argument( diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6c419b6e7ad1..4eff8a27ff40 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -224,7 +224,7 @@ "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", "LuminaNextDiT2DModel", - "MirageTransformer2DModel", + "PhotonTransformer2DModel", "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 279e69216b1b..86e32c1eec3e 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -93,7 +93,7 @@ _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] - _import_structure["transformers.transformer_mirage"] = ["MirageTransformer2DModel"] + _import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 
e4ab33be9784..47cf4fab4a5e 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5605,15 +5605,15 @@ def __new__(cls, *args, **kwargs): return processor -class MirageAttnProcessor2_0: +class PhotonAttnProcessor2_0: r""" - Processor for implementing Mirage-style attention with multi-source tokens and RoPE. - Properly integrates with diffusers Attention module while handling Mirage-specific logic. + Processor for implementing Photon-style attention with multi-source tokens and RoPE. + Properly integrates with diffusers Attention module while handling Photon-specific logic. """ def __init__(self): if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): - raise ImportError("MirageAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") + raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") def __call__( self, @@ -5625,9 +5625,9 @@ def __call__( **kwargs, ) -> torch.Tensor: """ - Apply Mirage attention using standard diffusers interface. + Apply Photon attention using standard diffusers interface. - Expected tensor formats from MirageBlock.attn_forward(): + Expected tensor formats from PhotonBlock.attn_forward(): - hidden_states: Image queries with RoPE applied [B, H, L_img, D] - encoder_hidden_states: Packed key+value tensors [B, H, L_all, 2*D] (concatenated keys and values from text + image + spatial conditioning) @@ -5636,15 +5636,15 @@ def __call__( if encoder_hidden_states is None: raise ValueError( - "MirageAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. " - "This should be provided by MirageBlock.attn_forward()." + "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing packed key+value tensors. " + "This should be provided by PhotonBlock.attn_forward()." 
) # Unpack the combined key+value tensor # encoder_hidden_states is [B, H, L_all, 2*D] containing [keys, values] key, value = encoder_hidden_states.chunk(2, dim=-1) # Each [B, H, L_all, D] - # Apply scaled dot-product attention with Mirage's processed tensors + # Apply scaled dot-product attention with Photon's processed tensors # hidden_states is image queries [B, H, L_img, D] attn_output = torch.nn.functional.scaled_dot_product_attention( hidden_states.contiguous(), key.contiguous(), value.contiguous(), attn_mask=attention_mask @@ -5710,7 +5710,7 @@ def __call__( PAGHunyuanAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, LuminaAttnProcessor2_0, - MirageAttnProcessor2_0, + PhotonAttnProcessor2_0, FusedAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index ebe0d0c9b8e1..652f6d811393 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -29,7 +29,7 @@ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel - from .transformer_mirage import MirageTransformer2DModel + from .transformer_photon import PhotonTransformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_mirage.py b/src/diffusers/models/transformers/transformer_photon.py similarity index 99% rename from src/diffusers/models/transformers/transformer_mirage.py rename to src/diffusers/models/transformers/transformer_photon.py index 90ba11fb2d24..9ec6e9756c20 100644 --- a/src/diffusers/models/transformers/transformer_mirage.py +++ 
b/src/diffusers/models/transformers/transformer_photon.py @@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention_processor import Attention, AttentionProcessor, MirageAttnProcessor2_0 +from ..attention_processor import Attention, AttentionProcessor, PhotonAttnProcessor2_0 from ..embeddings import get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -206,7 +206,7 @@ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]: return ModulationOut(*out[:3]), ModulationOut(*out[3:]) -class MirageBlock(nn.Module): +class PhotonBlock(nn.Module): r""" Multimodal transformer block with text–image cross-attention, modulation, and MLP. @@ -304,7 +304,7 @@ def __init__( dim_head=self.head_dim, bias=False, out_bias=False, - processor=MirageAttnProcessor2_0(), + processor=PhotonAttnProcessor2_0(), ) # mlp @@ -538,7 +538,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) -class MirageTransformer2DModel(ModelMixin, ConfigMixin): +class PhotonTransformer2DModel(ModelMixin, ConfigMixin): r""" Transformer-based 2D model for text to image generation. It supports attention processor injection and LoRA scaling. @@ -581,7 +581,7 @@ class MirageTransformer2DModel(ModelMixin, ConfigMixin): txt_in (`nn.Linear`): Projection layer for text conditioning. blocks (`nn.ModuleList`): - Stack of transformer blocks (`MirageBlock`). + Stack of transformer blocks (`PhotonBlock`). final_layer (`LastLayer`): Projection layer mapping hidden tokens back to patch outputs. 
@@ -656,7 +656,7 @@ def __init__( self.blocks = nn.ModuleList( [ - MirageBlock( + PhotonBlock( self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, @@ -781,7 +781,7 @@ def forward( return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: r""" - Forward pass of the MirageTransformer2DModel. + Forward pass of the PhotonTransformer2DModel. The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of transformer blocks modulated by the timestep. diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7b7ebb633c3b..ae0d90c48c63 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -144,7 +144,7 @@ "FluxKontextPipeline", "FluxKontextInpaintPipeline", ] - _import_structure["mirage"] = ["MiragePipeline"] + _import_structure["photon"] = ["PhotonPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", diff --git a/src/diffusers/pipelines/mirage/__init__.py b/src/diffusers/pipelines/mirage/__init__.py deleted file mode 100644 index cba951057370..000000000000 --- a/src/diffusers/pipelines/mirage/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .pipeline_mirage import MiragePipeline -from .pipeline_output import MiragePipelineOutput - - -__all__ = ["MiragePipeline", "MiragePipelineOutput"] diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py new file mode 100644 index 000000000000..d1dd5b2cbf53 --- /dev/null +++ b/src/diffusers/pipelines/photon/__init__.py @@ -0,0 +1,5 @@ +from .pipeline_photon import PhotonPipeline +from .pipeline_output import PhotonPipelineOutput + + +__all__ = ["PhotonPipeline", "PhotonPipelineOutput"] diff --git a/src/diffusers/pipelines/mirage/pipeline_output.py b/src/diffusers/pipelines/photon/pipeline_output.py similarity index 93% rename from 
src/diffusers/pipelines/mirage/pipeline_output.py rename to src/diffusers/pipelines/photon/pipeline_output.py index e41c8e3bea00..ca0674d94b6c 100644 --- a/src/diffusers/pipelines/mirage/pipeline_output.py +++ b/src/diffusers/pipelines/photon/pipeline_output.py @@ -22,9 +22,9 @@ @dataclass -class MiragePipelineOutput(BaseOutput): +class PhotonPipelineOutput(BaseOutput): """ - Output class for Mirage pipelines. + Output class for Photon pipelines. Args: images (`List[PIL.Image.Image]` or `np.ndarray`) diff --git a/src/diffusers/pipelines/mirage/pipeline_mirage.py b/src/diffusers/pipelines/photon/pipeline_photon.py similarity index 95% rename from src/diffusers/pipelines/mirage/pipeline_mirage.py rename to src/diffusers/pipelines/photon/pipeline_photon.py index ced78adec786..ce3479fedcdd 100644 --- a/src/diffusers/pipelines/mirage/pipeline_mirage.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -31,7 +31,7 @@ from ...image_processor import VaeImageProcessor from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderDC, AutoencoderKL -from ...models.transformers.transformer_mirage import MirageTransformer2DModel, seq2img +from ...models.transformers.transformer_photon import PhotonTransformer2DModel, seq2img from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( logging, @@ -39,7 +39,7 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import MiragePipelineOutput +from .pipeline_output import PhotonPipelineOutput DEFAULT_HEIGHT = 512 @@ -49,7 +49,7 @@ class TextPreprocessor: - """Text preprocessing utility for MiragePipeline.""" + """Text preprocessing utility for PhotonPipeline.""" def __init__(self): """Initialize text preprocessor.""" @@ -179,15 +179,15 @@ def clean_text(self, text: str) -> str: Examples: ```py >>> import torch - >>> from diffusers import MiragePipeline + >>> from diffusers 
import PhotonPipeline >>> from diffusers.models import AutoencoderKL, AutoencoderDC >>> from transformers import T5GemmaModel, GemmaTokenizerFast >>> # Load pipeline directly with from_pretrained - >>> pipe = MiragePipeline.from_pretrained("path/to/mirage_checkpoint") + >>> pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint") >>> # Or initialize pipeline components manually - >>> transformer = MirageTransformer2DModel.from_pretrained("path/to/transformer") + >>> transformer = PhotonTransformer2DModel.from_pretrained("path/to/transformer") >>> scheduler = FlowMatchEulerDiscreteScheduler() >>> # Load T5Gemma encoder >>> t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") @@ -195,7 +195,7 @@ def clean_text(self, text: str) -> str: >>> tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") >>> vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") - >>> pipe = MiragePipeline( + >>> pipe = PhotonPipeline( ... transformer=transformer, ... scheduler=scheduler, ... text_encoder=text_encoder, @@ -205,26 +205,26 @@ def clean_text(self, text: str) -> str: >>> pipe.to("cuda") >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach" >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] - >>> image.save("mirage_output.png") + >>> image.save("photon_output.png") ``` """ -class MiragePipeline( +class PhotonPipeline( DiffusionPipeline, LoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin, ): r""" - Pipeline for text-to-image generation using Mirage Transformer. + Pipeline for text-to-image generation using Photon Transformer. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
Args: - transformer ([`MirageTransformer2DModel`]): - The Mirage transformer model to denoise the encoded image latents. + transformer ([`PhotonTransformer2DModel`]): + The Photon transformer model to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. text_encoder ([`T5EncoderModel`]): @@ -248,7 +248,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P """ Override from_pretrained to load VAE and text encoder from HuggingFace models. - The MiragePipeline checkpoints only store transformer and scheduler locally. + The PhotonPipeline checkpoints only store transformer and scheduler locally. VAE and text encoder are loaded from external HuggingFace models as specified in model_index.json. """ @@ -285,7 +285,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Load transformer and scheduler from local checkpoint logger.info(f"Loading transformer from {pretrained_model_name_or_path}...") - transformer = MirageTransformer2DModel.from_pretrained( + transformer = PhotonTransformer2DModel.from_pretrained( pretrained_model_name_or_path, subfolder="transformer" ) @@ -310,7 +310,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P def __init__( self, - transformer: MirageTransformer2DModel, + transformer: PhotonTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, text_encoder: Union[T5EncoderModel, Any], tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer], @@ -318,9 +318,9 @@ def __init__( ): super().__init__() - if MirageTransformer2DModel is None: + if PhotonTransformer2DModel is None: raise ImportError( - "MirageTransformer2DModel is not available. Please ensure the transformer_mirage module is properly installed." + "PhotonTransformer2DModel is not available. 
Please ensure the transformer_photon module is properly installed." ) self.text_encoder = text_encoder @@ -544,7 +544,7 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.mirage.MiragePipelineOutput`] instead of a plain tuple. + Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`. @@ -557,7 +557,7 @@ def __call__( Examples: Returns: - [`~pipelines.mirage.MiragePipelineOutput`] or `tuple`: [`~pipelines.mirage.MiragePipelineOutput`] if + [`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images. 
""" @@ -683,4 +683,4 @@ def __call__( if not return_dict: return (image,) - return MiragePipelineOutput(images=image) + return PhotonPipelineOutput(images=image) diff --git a/tests/models/transformers/test_models_transformer_mirage.py b/tests/models/transformers/test_models_transformer_photon.py similarity index 95% rename from tests/models/transformers/test_models_transformer_mirage.py rename to tests/models/transformers/test_models_transformer_photon.py index fe7436debc4c..2f08484d230c 100644 --- a/tests/models/transformers/test_models_transformer_mirage.py +++ b/tests/models/transformers/test_models_transformer_photon.py @@ -17,7 +17,7 @@ import torch -from diffusers.models.transformers.transformer_mirage import MirageTransformer2DModel +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin @@ -26,8 +26,8 @@ enable_full_determinism() -class MirageTransformerTests(ModelTesterMixin, unittest.TestCase): - model_class = MirageTransformer2DModel +class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = PhotonTransformer2DModel main_input_name = "image_latent" @property @@ -92,7 +92,7 @@ def test_forward_signature(self): def test_model_initialization(self): # Test model initialization - model = MirageTransformer2DModel( + model = PhotonTransformer2DModel( in_channels=16, patch_size=2, context_in_dim=1792, @@ -121,7 +121,7 @@ def test_model_with_dict_config(self): "theta": 10_000, } - model = MirageTransformer2DModel.from_config(config_dict) + model = PhotonTransformer2DModel.from_config(config_dict) self.assertEqual(model.config.in_channels, 16) self.assertEqual(model.config.hidden_size, 1792) @@ -193,7 +193,7 @@ def test_attention_mask(self): def test_invalid_config(self): # Test invalid configuration - hidden_size not divisible by num_heads with self.assertRaises(ValueError): - 
MirageTransformer2DModel( + PhotonTransformer2DModel( in_channels=16, patch_size=2, context_in_dim=1792, @@ -207,7 +207,7 @@ def test_invalid_config(self): # Test invalid axes_dim that doesn't sum to pe_dim with self.assertRaises(ValueError): - MirageTransformer2DModel( + PhotonTransformer2DModel( in_channels=16, patch_size=2, context_in_dim=1792, From fe0e3d5e6502f172b46761f396d47214452ef59b Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Wed, 8 Oct 2025 13:30:53 +0000 Subject: [PATCH 35/38] add text tower and vae in checkpoint --- scripts/convert_photon_to_diffusers.py | 88 +++++++++--- .../pipelines/photon/pipeline_photon.py | 127 ++++-------------- 2 files changed, 93 insertions(+), 122 deletions(-) diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index ad04463e019f..8e182bf182d0 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -19,8 +19,7 @@ from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel from diffusers.pipelines.photon import PhotonPipeline -DEFAULT_HEIGHT = 512 -DEFAULT_WIDTH = 512 +DEFAULT_RESOLUTION = 512 @dataclass(frozen=True) class PhotonBase: @@ -47,16 +46,19 @@ class PhotonDCAE(PhotonBase): patch_size: int = 1 -def build_config(vae_type: str) -> dict: +def build_config(vae_type: str, resolution: int = DEFAULT_RESOLUTION) -> dict: if vae_type == "flux": cfg = PhotonFlux() + sample_size = resolution // 8 elif vae_type == "dc-ae": cfg = PhotonDCAE() + sample_size = resolution // 32 else: raise ValueError(f"Unsupported VAE type: {vae_type}. 
Use 'flux' or 'dc-ae'") config_dict = asdict(cfg) config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index] + config_dict["sample_size"] = sample_size return config_dict @@ -194,35 +196,64 @@ def create_scheduler_config(output_path: str): -def create_model_index(vae_type: str, output_path: str): - """Create model_index.json for the pipeline with HuggingFace model references.""" +def download_and_save_vae(vae_type: str, output_path: str): + """Download and save VAE to local directory.""" + from diffusers import AutoencoderKL, AutoencoderDC + + vae_path = os.path.join(output_path, "vae") + os.makedirs(vae_path, exist_ok=True) if vae_type == "flux": - vae_model_name = "black-forest-labs/FLUX.1-dev" - vae_subfolder = "vae" - default_height = DEFAULT_HEIGHT - default_width = DEFAULT_WIDTH + print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...") + vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") else: # dc-ae - vae_model_name = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers" - vae_subfolder = None - default_height = DEFAULT_HEIGHT - default_width = DEFAULT_WIDTH + print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers...") + vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers") + + vae.save_pretrained(vae_path) + print(f"✓ Saved VAE to {vae_path}") + + +def download_and_save_text_encoder(output_path: str): + """Download and save T5Gemma text encoder and tokenizer.""" + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel + from transformers import GemmaTokenizerFast + + text_encoder_path = os.path.join(output_path, "text_encoder") + tokenizer_path = os.path.join(output_path, "tokenizer") + os.makedirs(text_encoder_path, exist_ok=True) + os.makedirs(tokenizer_path, exist_ok=True) + + print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...") + t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") + + 
t5gemma_model.save_pretrained(text_encoder_path) + print(f"✓ Saved T5Gemma model to {text_encoder_path}") + + print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...") + tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") + tokenizer.model_max_length = 256 + tokenizer.save_pretrained(tokenizer_path) + print(f"✓ Saved tokenizer to {tokenizer_path}") - # Text encoder and tokenizer always use T5Gemma - text_model_name = "google/t5gemma-2b-2b-ul2" + +def create_model_index(vae_type: str, output_path: str): + """Create model_index.json for the pipeline.""" + + if vae_type == "flux": + vae_class = "AutoencoderKL" + else: # dc-ae + vae_class = "AutoencoderDC" model_index = { "_class_name": "PhotonPipeline", "_diffusers_version": "0.31.0.dev0", "_name_or_path": os.path.basename(output_path), "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], - "text_encoder": text_model_name, - "tokenizer": text_model_name, + "text_encoder": ["transformers", "T5GemmaModel"], + "tokenizer": ["transformers", "GemmaTokenizerFast"], "transformer": ["diffusers", "PhotonTransformer2DModel"], - "vae": vae_model_name, - "vae_subfolder": vae_subfolder, - "default_height": default_height, - "default_width": default_width, + "vae": ["diffusers", vae_class], } model_index_path = os.path.join(output_path, "model_index.json") @@ -234,7 +265,7 @@ def main(args): if not os.path.exists(args.checkpoint_path): raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}") - config = build_config(args.vae_type) + config = build_config(args.vae_type, args.resolution) # Create output directory os.makedirs(args.output_path, exist_ok=True) @@ -256,8 +287,13 @@ def main(args): save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors")) print(f"✓ Saved transformer to {transformer_path}") + # Create scheduler config create_scheduler_config(args.output_path) + download_and_save_vae(args.vae_type, args.output_path) + 
download_and_save_text_encoder(args.output_path) + + # Create model_index.json create_model_index(args.vae_type, args.output_path) # Verify the pipeline can be loaded @@ -303,6 +339,14 @@ def main(args): help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)", ) + parser.add_argument( + "--resolution", + type=int, + choices=[256, 512, 1024], + default=DEFAULT_RESOLUTION, + help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.", + ) + args = parser.parse_args() try: diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index ce3479fedcdd..6272d7d8ae77 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -28,18 +28,18 @@ T5TokenizerFast, ) -from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderDC, AutoencoderKL -from ...models.transformers.transformer_photon import PhotonTransformer2DModel, seq2img -from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderDC, AutoencoderKL +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel, seq2img +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( logging, replace_example_docstring, ) -from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import PhotonPipelineOutput +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.photon.pipeline_output import 
PhotonPipelineOutput DEFAULT_HEIGHT = 512 @@ -180,29 +180,11 @@ def clean_text(self, text: str) -> str: ```py >>> import torch >>> from diffusers import PhotonPipeline - >>> from diffusers.models import AutoencoderKL, AutoencoderDC - >>> from transformers import T5GemmaModel, GemmaTokenizerFast - >>> # Load pipeline directly with from_pretrained + >>> # Load pipeline with from_pretrained >>> pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint") - - >>> # Or initialize pipeline components manually - >>> transformer = PhotonTransformer2DModel.from_pretrained("path/to/transformer") - >>> scheduler = FlowMatchEulerDiscreteScheduler() - >>> # Load T5Gemma encoder - >>> t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") - >>> text_encoder = t5gemma_model.encoder - >>> tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") - >>> vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") - - >>> pipe = PhotonPipeline( - ... transformer=transformer, - ... scheduler=scheduler, - ... text_encoder=text_encoder, - ... tokenizer=tokenizer, - ... vae=vae - ... ) >>> pipe.to("cuda") + >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach" >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] >>> image.save("photon_output.png") @@ -240,74 +222,6 @@ class PhotonPipeline( _callback_tensor_inputs = ["latents"] _optional_components = [] - # Component configurations for automatic loading - config_name = "model_index.json" - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): - """ - Override from_pretrained to load VAE and text encoder from HuggingFace models. - - The PhotonPipeline checkpoints only store transformer and scheduler locally. - VAE and text encoder are loaded from external HuggingFace models as specified - in model_index.json. 
- """ - import json - from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel - - model_index_path = os.path.join(pretrained_model_name_or_path, "model_index.json") - if not os.path.exists(model_index_path): - raise ValueError(f"model_index.json not found in {pretrained_model_name_or_path}") - - with open(model_index_path, "r") as f: - model_index = json.load(f) - - vae_model_name = model_index.get("vae") - vae_subfolder = model_index.get("vae_subfolder") - text_model_name = model_index.get("text_encoder") - tokenizer_model_name = model_index.get("tokenizer") - default_height = model_index.get("default_height", DEFAULT_HEIGHT) - default_width = model_index.get("default_width", DEFAULT_WIDTH) - - logger.info(f"Loading VAE from {vae_model_name}...") - if "FLUX" in vae_model_name or "flux" in vae_model_name: - vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder=vae_subfolder) - else: # DC-AE - vae = AutoencoderDC.from_pretrained(vae_model_name) - - logger.info(f"Loading text encoder from {text_model_name}...") - t5gemma_model = T5GemmaModel.from_pretrained(text_model_name) - text_encoder = t5gemma_model.encoder - - logger.info(f"Loading tokenizer from {tokenizer_model_name}...") - tokenizer = GemmaTokenizerFast.from_pretrained(tokenizer_model_name) - tokenizer.model_max_length = 256 - - # Load transformer and scheduler from local checkpoint - logger.info(f"Loading transformer from {pretrained_model_name_or_path}...") - transformer = PhotonTransformer2DModel.from_pretrained( - pretrained_model_name_or_path, subfolder="transformer" - ) - - logger.info(f"Loading scheduler from {pretrained_model_name_or_path}...") - scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - pretrained_model_name_or_path, subfolder="scheduler" - ) - - pipeline = cls( - transformer=transformer, - scheduler=scheduler, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae=vae, - ) - - # Store default dimensions as pipeline attributes - pipeline.default_height = 
default_height - pipeline.default_width = default_width - - return pipeline - def __init__( self, transformer: PhotonTransformer2DModel, @@ -323,6 +237,10 @@ def __init__( "PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed." ) + # Extract encoder if text_encoder is T5GemmaModel + if hasattr(text_encoder, 'encoder'): + text_encoder = text_encoder.encoder + self.text_encoder = text_encoder self.tokenizer = tokenizer self.text_preprocessor = TextPreprocessor() @@ -337,7 +255,16 @@ def __init__( # Enhance VAE with universal properties for both AutoencoderKL and AutoencoderDC self._enhance_vae_properties() - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + + # Set image processor using vae_scale_factor property + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Set default dimensions from transformer config + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None and hasattr(self.transformer.config, "sample_size") + else 64 + ) def _enhance_vae_properties(self): """Add universal properties to VAE for consistent interface across AutoencoderKL and AutoencoderDC.""" @@ -563,8 +490,8 @@ def __call__( """ # 0. Default height and width to transformer config - height = height or getattr(self, 'default_height', DEFAULT_HEIGHT) - width = width or getattr(self, 'default_width', DEFAULT_WIDTH) + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor # 1. 
Check inputs self.check_inputs( From 855b068997f965003071411b005fa981ae2a6d49 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Wed, 8 Oct 2025 13:31:15 +0000 Subject: [PATCH 36/38] update doc --- docs/source/en/api/pipelines/photon.md | 41 +++++++++++++++++++------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index f8f7098545f8..e0f963b148d2 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -18,7 +18,7 @@ LoRA -Photon is a text-to-image diffusion model using a transformer-based architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports both Flux VAE (AutoencoderKL) and DC-AE (AutoencoderDC) for latent compression. +Photon is a text-to-image diffusion model using simplified MMDIT architecture with flow matching for efficient high-quality image generation. The model uses T5Gemma as the text encoder and supports either Flux VAE (AutoencoderKL) or DC-AE (AutoencoderDC) for latent compression. Key features: @@ -28,19 +28,37 @@ Key features: - **T5Gemma Text Encoder**: Uses Google's T5Gemma-2B-2B-UL2 model for text encoding offering multiple language support - **Efficient Architecture**: ~1.3B parameters in the transformer, enabling fast inference while maintaining quality +## Available models: +We offer a range of **Photon models** featuring different **VAE configurations**, each optimized for generating images at various resolutions. +Both **fine-tuned** and **non-fine-tuned** versions are available: + +- **Non-fine-tuned models** perform best with **highly detailed prompts**, capturing fine nuances and complex compositions. +- **Fine-tuned models**, trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist), enhance the **aesthetic quality** of the base models—especially when prompts are **less detailed**. 
+ + +| Model | Recommended dtype | Resolution | Fine-tuned | +|:-----:|:-----------------:|:----------:|:----------:| +| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i) | `torch.bfloat16` | 256x256 | No | +| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft) | `torch.bfloat16` | 256x256 | Yes | +| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i) | `torch.bfloat16` | 512x512 | No | +| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | `torch.bfloat16` | 512x512 | Yes | +| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae) | `torch.bfloat16` | 512x512 | No | +| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | `torch.bfloat16` | 512x512 | Yes | + +Refer to [this](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) collection for more information. ## Loading the Pipeline Photon checkpoints only store the transformer and scheduler weights locally. The VAE and text encoder are automatically loaded from HuggingFace during pipeline initialization: ```py -from diffusers import PhotonPipeline +from diffusers.pipelines.photon import PhotonPipeline # Load pipeline - VAE and text encoder will be loaded from HuggingFace -pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint") +pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i") pipe.to("cuda") -prompt = "A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “Photon” in bright, sparkling light" +prompt = "A highly detailed 3D animated scene of a cute, intelligent duck scientist in a futuristic laboratory. The duck stands on a shiny metallic floor surrounded by glowing glass tubes filled with colorful liquids—blue, green, and purple—connected by translucent hoses emitting soft light. 
The duck wears a tiny white lab coat, safety goggles, and has a curious, determined expression while conducting an experiment. Sparks of energy and soft particle effects fill the air as scientific instruments hum with power. In the background, holographic screens display molecular diagrams and equations. Above the duck’s head, the word “PHOTON” glows vividly in midair as if made of pure light, illuminating the scene with a warm golden glow. The lighting is cinematic, with rich reflections and subtle depth of field, emphasizing a Pixar-like, ultra-polished 3D animation style. Rendered in ultra high resolution, realistic subsurface scattering on the duck’s feathers, and vibrant color grading that gives a sense of wonder and scientific discovery." image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] image.save("photon_output.png") ``` @@ -59,12 +77,12 @@ from transformers import T5GemmaModel, GemmaTokenizerFast # Load transformer transformer = PhotonTransformer2DModel.from_pretrained( - "path/to/checkpoint", subfolder="transformer" + "Photoroom/photon-512-t2i", subfolder="transformer" ) # Load scheduler scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - "path/to/checkpoint", subfolder="scheduler" + "Photoroom/photon-512-t2i", subfolder="scheduler" ) # Load T5Gemma text encoder @@ -116,8 +134,11 @@ Key parameters for image generation: ```py # Example with custom parameters -image = pipe( - prompt="A vibrant night sky filled with colorful fireworks, with one large firework burst forming the glowing text “Photon” in bright, sparkling light", +import torch +from diffusers.pipelines.photon import PhotonPipeline + +image = pipe( + prompt="A highly detailed 3D animated scene of a cute, intelligent duck scientist in a futuristic laboratory. The duck stands on a shiny metallic floor surrounded by glowing glass tubes filled with colorful liquids—blue, green, and purple—connected by translucent hoses emitting soft light. 
The duck wears a tiny white lab coat, safety goggles, and has a curious, determined expression while conducting an experiment. Sparks of energy and soft particle effects fill the air as scientific instruments hum with power. In the background, holographic screens display molecular diagrams and equations. Above the duck’s head, the word “PHOTON” glows vividly in midair as if made of pure light, illuminating the scene with a warm golden glow. The lighting is cinematic, with rich reflections and subtle depth of field, emphasizing a Pixar-like, ultra-polished 3D animation style. Rendered in ultra high resolution, realistic subsurface scattering on the duck’s feathers, and vibrant color grading that gives a sense of wonder and scientific discovery.", num_inference_steps=28, guidance_scale=4.0, height=512, @@ -132,9 +153,9 @@ For memory-constrained environments: ```py import torch -from diffusers import PhotonPipeline +from diffusers.pipelines.photon import PhotonPipeline -pipe = PhotonPipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.float16) +pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i", torch_dtype=torch.float16) pipe.enable_model_cpu_offload() # Offload components to CPU when not in use # Or use sequential CPU offload for even lower memory From 89beae8286774a257859d26bb316bcb1db4c9cf4 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Wed, 8 Oct 2025 16:15:52 +0200 Subject: [PATCH 37/38] update photon doc --- docs/source/en/api/pipelines/photon.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md index e0f963b148d2..b78219f8b214 100644 --- a/docs/source/en/api/pipelines/photon.md +++ b/docs/source/en/api/pipelines/photon.md @@ -68,9 +68,8 @@ image.save("photon_output.png") You can also load components individually: ```py -import torch -from diffusers import PhotonPipeline -from diffusers.models import AutoencoderKL, AutoencoderDC 
+from diffusers.pipelines.photon import PhotonPipeline +from diffusers.models import AutoencoderKL from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from transformers import T5GemmaModel, GemmaTokenizerFast From 2df0e2f74db7c0afd5df9bf57ec1ca9601635d52 Mon Sep 17 00:00:00 2001 From: David Bertoin Date: Wed, 8 Oct 2025 14:21:06 +0000 Subject: [PATCH 38/38] ruff fixes --- scripts/convert_photon_to_diffusers.py | 22 +++---- src/diffusers/models/transformers/__init__.py | 2 +- .../models/transformers/transformer_photon.py | 62 +++++++++---------- src/diffusers/pipelines/photon/__init__.py | 2 +- .../pipelines/photon/pipeline_photon.py | 11 ++-- 5 files changed, 47 insertions(+), 52 deletions(-) diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index 8e182bf182d0..0dd114a68997 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -7,11 +7,11 @@ import json import os import sys +from dataclasses import asdict, dataclass +from typing import Dict, Tuple import torch from safetensors.torch import save_file -from dataclasses import dataclass, asdict -from typing import Tuple, Dict sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) @@ -19,8 +19,10 @@ from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel from diffusers.pipelines.photon import PhotonPipeline + DEFAULT_RESOLUTION = 512 + @dataclass(frozen=True) class PhotonBase: context_in_dim: int = 2304 @@ -62,7 +64,6 @@ def build_config(vae_type: str, resolution: int = DEFAULT_RESOLUTION) -> dict: return config_dict - def create_parameter_mapping(depth: int) -> dict: """Create mapping from old parameter names to new diffusers names.""" @@ -174,16 +175,10 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Ph return transformer - - def 
create_scheduler_config(output_path: str): """Create FlowMatchEulerDiscreteScheduler config.""" - scheduler_config = { - "_class_name": "FlowMatchEulerDiscreteScheduler", - "num_train_timesteps": 1000, - "shift": 1.0 - } + scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": 1.0} scheduler_path = os.path.join(output_path, "scheduler") os.makedirs(scheduler_path, exist_ok=True) @@ -194,11 +189,9 @@ def create_scheduler_config(output_path: str): print("✓ Created scheduler config") - - def download_and_save_vae(vae_type: str, output_path: str): """Download and save VAE to local directory.""" - from diffusers import AutoencoderKL, AutoencoderDC + from diffusers import AutoencoderDC, AutoencoderKL vae_path = os.path.join(output_path, "vae") os.makedirs(vae_path, exist_ok=True) @@ -216,8 +209,8 @@ def download_and_save_vae(vae_type: str, output_path: str): def download_and_save_text_encoder(output_path: str): """Download and save T5Gemma text encoder and tokenizer.""" - from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel from transformers import GemmaTokenizerFast + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel text_encoder_path = os.path.join(output_path, "text_encoder") tokenizer_path = os.path.join(output_path, "tokenizer") @@ -260,6 +253,7 @@ def create_model_index(vae_type: str, output_path: str): with open(model_index_path, "w") as f: json.dump(model_index, f, indent=2) + def main(args): # Validate inputs if not os.path.exists(args.checkpoint_path): diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 652f6d811393..7fdab560a702 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -29,9 +29,9 @@ from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from 
.transformer_lumina2 import Lumina2Transformer2DModel - from .transformer_photon import PhotonTransformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel + from .transformer_photon import PhotonTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py index 9ec6e9756c20..452be9d2e6df 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -54,7 +54,7 @@ def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, dev Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the image grid. """ - + img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device) img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None] img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :] @@ -69,7 +69,7 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: xq (`torch.Tensor`): Input tensor of shape `(..., dim)` representing the queries. freqs_cis (`torch.Tensor`): - Precomputed rotary frequency components of shape `(..., dim/2, 2)` + Precomputed rotary frequency components of shape `(..., dim/2, 2)` containing cosine and sine pairs. Returns: @@ -88,7 +88,7 @@ class EmbedND(nn.Module): This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding dimension. The embeddings are combined and returned as a single tensor - + Parameters: dim (int): Base embedding dimension (must be even). 
@@ -97,6 +97,7 @@ class EmbedND(nn.Module): axes_dim (list[int]): List of embedding dimensions for each axis (each must be even). """ + def __init__(self, dim: int, theta: int, axes_dim: list[int]): super().__init__() self.dim = dim @@ -136,6 +137,7 @@ class MLPEmbedder(nn.Module): `torch.Tensor`: Tensor of shape `(..., hidden_dim)` containing the embedded representations. """ + def __init__(self, in_dim: int, hidden_dim: int): super().__init__() self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) @@ -160,6 +162,7 @@ class QKNorm(torch.nn.Module): A tuple `(q, k)` where both are normalized and cast to the same dtype as the value tensor `v`. """ + def __init__(self, dim: int): super().__init__() self.query_norm = RMSNorm(dim, eps=1e-6) @@ -195,6 +198,7 @@ class Modulation(nn.Module): A tuple of two modulation outputs. Each `ModulationOut` contains three components (e.g., scale, shift, gate). """ + def __init__(self, dim: int): super().__init__() self.lin = nn.Linear(dim, 6 * dim, bias=True) @@ -269,6 +273,7 @@ class PhotonBlock(nn.Module): `torch.Tensor`: Attention output of shape `(B, L_img, hidden_size)`. """ + def __init__( self, hidden_size: int, @@ -363,9 +368,7 @@ def _attn_forward( if attention_mask.dim() != 2: raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") if attention_mask.shape[-1] != l_txt: - raise ValueError( - f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}" - ) + raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}") device = img_q.device @@ -407,31 +410,30 @@ def forward( **_: dict[str, Any], ) -> Tensor: r""" - Runs modulation-gated cross-attention and MLP, with residual connections. - - Parameters: - img (`torch.Tensor`): - Image tokens of shape `(B, L_img, hidden_size)`. - txt (`torch.Tensor`): - Text tokens of shape `(B, L_txt, hidden_size)`. 
- vec (`torch.Tensor`): - Conditioning vector used by `Modulation` to produce scale/shift/gates, - shape `(B, hidden_size)` (or broadcastable). - pe (`torch.Tensor`): - Rotary positional embeddings applied inside attention. - spatial_conditioning (`torch.Tensor`, *optional*): - Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. Used only - if spatial conditioning is enabled in the block. - attention_mask (`torch.Tensor`, *optional*): - Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding. - **_: - Ignored additional keyword arguments for API compatibility. + Runs modulation-gated cross-attention and MLP, with residual connections. - Returns: - `torch.Tensor`: - Updated image tokens of shape `(B, L_img, hidden_size)`. - """ + Parameters: + img (`torch.Tensor`): + Image tokens of shape `(B, L_img, hidden_size)`. + txt (`torch.Tensor`): + Text tokens of shape `(B, L_txt, hidden_size)`. + vec (`torch.Tensor`): + Conditioning vector used by `Modulation` to produce scale/shift/gates, + shape `(B, hidden_size)` (or broadcastable). + pe (`torch.Tensor`): + Rotary positional embeddings applied inside attention. + spatial_conditioning (`torch.Tensor`, *optional*): + Extra conditioning tokens of shape `(B, L_cond, hidden_size)`. Used only + if spatial conditioning is enabled in the block. + attention_mask (`torch.Tensor`, *optional*): + Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding. + **_: + Ignored additional keyword arguments for API compatibility. + Returns: + `torch.Tensor`: + Updated image tokens of shape `(B, L_img, hidden_size)`. 
+ """ mod_attn, mod_mlp = self.modulation(vec) @@ -475,7 +477,6 @@ class LastLayer(nn.Module): """ def __init__(self, hidden_size: int, patch_size: int, out_channels: int): - super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) @@ -727,7 +728,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): fn_recursive_attn_processor(name, module, processor) def _process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]: - txt = self.txt_in(txt) img = img2seq(image_latent, self.patch_size) bs, _, h, w = image_latent.shape diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py index d1dd5b2cbf53..559c9d0b1d2d 100644 --- a/src/diffusers/pipelines/photon/__init__.py +++ b/src/diffusers/pipelines/photon/__init__.py @@ -1,5 +1,5 @@ -from .pipeline_photon import PhotonPipeline from .pipeline_output import PhotonPipelineOutput +from .pipeline_photon import PhotonPipeline __all__ = ["PhotonPipeline", "PhotonPipelineOutput"] diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index 6272d7d8ae77..0fc926261517 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -14,7 +14,6 @@ import html import inspect -import os import re import urllib.parse as ul from typing import Any, Callable, Dict, List, Optional, Union @@ -32,14 +31,14 @@ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderDC, AutoencoderKL from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel, seq2img +from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput +from diffusers.pipelines.pipeline_utils import 
DiffusionPipeline from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( logging, replace_example_docstring, ) from diffusers.utils.torch_utils import randn_tensor -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput DEFAULT_HEIGHT = 512 @@ -238,7 +237,7 @@ def __init__( ) # Extract encoder if text_encoder is T5GemmaModel - if hasattr(text_encoder, 'encoder'): + if hasattr(text_encoder, "encoder"): text_encoder = text_encoder.encoder self.text_encoder = text_encoder @@ -262,7 +261,9 @@ def __init__( # Set default dimensions from transformer config self.default_sample_size = ( self.transformer.config.sample_size - if hasattr(self, "transformer") and self.transformer is not None and hasattr(self.transformer.config, "sample_size") + if hasattr(self, "transformer") + and self.transformer is not None + and hasattr(self.transformer.config, "sample_size") else 64 )