From d22bfeca1e7d1882b57a313685e545dc9ec31788 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Fri, 20 Feb 2026 16:37:42 +0800 Subject: [PATCH 01/13] update --- fastgen/networks/LTX2/network.py | 357 +++++++++++++++++++++++++++++++ tests/test_ltx2.py | 72 +++++++ 2 files changed, 429 insertions(+) create mode 100644 fastgen/networks/LTX2/network.py create mode 100644 tests/test_ltx2.py diff --git a/fastgen/networks/LTX2/network.py b/fastgen/networks/LTX2/network.py new file mode 100644 index 0000000..cfeed80 --- /dev/null +++ b/fastgen/networks/LTX2/network.py @@ -0,0 +1,357 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Any, Optional, List, Set, Union, Tuple +import types + +import torch +import torch.nn as nn +from torch import dtype +from torch.distributed.fsdp import fully_shard + +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.models import LTXVideoTransformer3DModel, AutoencoderKLLTX2Video +from diffusers.models.transformers.transformer_ltx2 import LTX2VideoTransformerBlock +from diffusers.pipelines.ltx2 import LTX2TextConnectors +from transformers import GemmaTokenizer, Gemma3ForConditionalGeneration + +from fastgen.networks.network import FastGenNetwork +from fastgen.networks.noise_schedule import NET_PRED_TYPES +from fastgen.utils.basic_utils import str2bool +from fastgen.utils.distributed.fsdp import apply_fsdp_checkpointing +import fastgen.utils.logging_utils as logger + + +class LTX2TextEncoder: + """Text encoder for LTX-2 using Gemma 3.""" + + def __init__(self, model_id: str): + self.tokenizer = GemmaTokenizer.from_pretrained( + model_id, + cache_dir=os.environ["HF_HOME"], + subfolder="tokenizer", + local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + ) + self.text_encoder = Gemma3ForConditionalGeneration.from_pretrained( + model_id, + cache_dir=os.environ["HF_HOME"], + 
subfolder="text_encoder", + local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + ) + self.text_encoder.eval().requires_grad_(False) + + def encode( + self, + conditioning: Optional[Any] = None, + precision: dtype = torch.float32, + max_sequence_length: int = 512, + ) -> torch.Tensor: + """Encode text prompts to raw Gemma 3 embeddings.""" + if isinstance(conditioning, str): + conditioning = [conditioning] + + text_inputs = self.tokenizer( + conditioning, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + with torch.no_grad(): + outputs = self.text_encoder( + input_ids=text_inputs.input_ids.to(self.text_encoder.device), + attention_mask=text_inputs.attention_mask.to(self.text_encoder.device), + output_hidden_states=True, + return_dict=True, + ) + # Raw Gemma hidden states to be projected by connectors + prompt_embeds = outputs.hidden_states[-1].to(precision) + + return prompt_embeds + + def to(self, *args, **kwargs): + self.text_encoder.to(*args, **kwargs) + return self + + +class LTX2VideoEncoder: + """Spatio-temporal VAE encoder/decoder for LTX-2.""" + + def __init__(self, model_id: str): + self.vae: AutoencoderKLLTX2Video = AutoencoderKLLTX2Video.from_pretrained( + model_id, + cache_dir=os.environ["HF_HOME"], + subfolder="vae", + local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + ) + self.vae.eval().requires_grad_(False) + + # LTX-2 normalization constants + self.scaling_factor = getattr(self.vae.config, "scaling_factor", 1.5305) + self.shift_factor = getattr(self.vae.config, "shift_factor", 0.0609) + + def encode(self, real_video: torch.Tensor) -> torch.Tensor: + """Encode videos to 3D latent space.""" + latents = self.vae.encode(real_video, return_dict=False)[0].sample() + latents = (latents - self.shift_factor) * self.scaling_factor + return latents + + def decode(self, latents: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor: + """Decode latents to 
video frames with optional timestep conditioning.""" + latents = (latents / self.scaling_factor) + self.shift_factor + video = self.vae.decode(latents, timestep=timestep, return_dict=False)[0] + return video.clip(-1.0, 1.0) + + def to(self, *args, **kwargs): + self.vae.to(*args, **kwargs) + return self + + +def classify_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timesteps: torch.Tensor, + indices: torch.Tensor, + encoder_attention_mask: Optional[torch.Tensor] = None, + audio_hidden_states: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, + return_features_early: bool = False, + feature_indices: Optional[Set[int]] = None, + return_logvar: bool = False, + **kwargs, +) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + """Patched forward pass for LTXVideoTransformer3DModel.""" + if feature_indices is None: + feature_indices = set() + + # 1. Text Connector Projection + # Projects Gemma embeddings for video and audio streams + encoder_hidden_states, audio_encoder_hidden_states = self.connectors( + encoder_hidden_states + ) + + # 2. Time & Text Embeddings + temb = self.time_text_embed(timesteps, encoder_hidden_states) + + # LTX-2 also prepares a separate audio time embedding + # Here we simplify or derive it if necessary + temb_audio = temb + + # 3. Derive Modulation Parameters for Cross-Attention + # These are typically linear projections of temb/temb_audio + # Based on the user's Turn 10 snippet, these must be passed to the blocks + video_ca_scale_shift = self.video_ca_modulation(temb) + audio_ca_scale_shift = self.audio_ca_modulation(temb_audio) + video_ca_a2v_gate = self.video_ca_gate(temb) + audio_ca_v2a_gate = self.audio_ca_gate(temb_audio) + + # 4. 
Positional Embeddings + video_rotary_emb = self.pos_embed(indices) + audio_rotary_emb = None # Placeholder if audio is zeroed + + if audio_hidden_states is None: + audio_hidden_states = torch.zeros_like(hidden_states[:, :0]) + + idx, features = 0, [] + + # 5. Dual-Stream Transformer Blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, audio_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + audio_hidden_states, + encoder_hidden_states, + audio_encoder_hidden_states, + temb, + temb_audio, + video_ca_scale_shift, + audio_ca_scale_shift, + video_ca_a2v_gate, + audio_ca_v2a_gate, + video_rotary_emb, + audio_rotary_emb, + None, # ca_video_rotary_emb + None, # ca_audio_rotary_emb + encoder_attention_mask, + audio_encoder_attention_mask, + ) + else: + hidden_states, audio_hidden_states = block( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=video_ca_scale_shift, + temb_ca_audio_scale_shift=audio_ca_scale_shift, + temb_ca_gate=video_ca_a2v_gate, + temb_ca_audio_gate=audio_ca_v2a_gate, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=audio_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, + ) + + if idx in feature_indices: + features.append(hidden_states.clone()) + if return_features_early and len(features) == len(feature_indices): + return features + idx += 1 + + # 6. 
Final Projection + output = self.proj_out(hidden_states, temb) + + if return_features_early: + return features + out = output if len(feature_indices) == 0 else [output, features] + + if return_logvar: + # temb_dim = 4096 + logvar = self.logvar_linear(temb) + return out, logvar + + return out + + +class LTX2(FastGenNetwork): + """LTX-2 network for text-to-video generation.""" + + MODEL_ID = "Lightricks/LTX-2" + + def __init__( + self, + model_id: str = MODEL_ID, + net_pred_type: str = "flow", + schedule_type: str = "rf", + disable_grad_ckpt: bool = False, + load_pretrained: bool = True, + **model_kwargs, + ): + super().__init__(net_pred_type=net_pred_type, schedule_type=schedule_type, **model_kwargs) + self.model_id = model_id + self._initialize_network(model_id, load_pretrained) + self.transformer.forward = types.MethodType(classify_forward, self.transformer) + + if disable_grad_ckpt: + self.transformer.disable_gradient_checkpointing() + else: + self.transformer.enable_gradient_checkpointing() + + def _initialize_network(self, model_id: str, load_pretrained: bool) -> None: + in_meta_context = self._is_in_meta_context() + + if load_pretrained and not in_meta_context: + logger.info(f"Loading LTX-2 transformer and connectors from {model_id}") + self.transformer = LTXVideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer" + ) + # Load LTX2TextConnectors + self.transformer.connectors = LTX2TextConnectors.from_pretrained( + model_id, subfolder="connectors" + ) + else: + config = LTXVideoTransformer3DModel.load_config(model_id, subfolder="transformer") + self.transformer = LTXVideoTransformer3DModel.from_config(config) + conn_config = LTX2TextConnectors.load_config(model_id, subfolder="connectors") + self.transformer.connectors = LTX2TextConnectors.from_config(conn_config) + + self.transformer.logvar_linear = nn.Linear(4096, 1) + + def _calculate_shift(self, sequence_length: int) -> float: + """Resolution-dependent shift.""" + base_seq_len, 
max_seq_len = 1024, 4096 + base_shift, max_shift = 0.5, 1.15 + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + return sequence_length * m + b + + @torch.no_grad() + def sample( + self, + noise: torch.Tensor, + condition: Optional[torch.Tensor] = None, + neg_condition: Optional[torch.Tensor] = None, + guidance_scale: Optional[float] = 4.0, + num_steps: int = 40, + **kwargs, + ) -> torch.Tensor: + """Generate video samples using Euler flow matching.""" + batch_size, _, frames, height, width = noise.shape + mu = self._calculate_shift(frames * height * width) + + scheduler = FlowMatchEulerDiscreteScheduler(shift=mu) + scheduler.set_timesteps(num_steps, device=noise.device) + timesteps, latents = scheduler.timesteps, noise.clone() + + for timestep in timesteps: + t = (timestep / 1000.0).expand(batch_size).to(dtype=noise.dtype, device=noise.device) + t = self.noise_scheduler.safe_clamp(t, min=self.noise_scheduler.min_t, max=self.noise_scheduler.max_t) + + if neg_condition is not None and guidance_scale is not None: + noise_pred = self( + torch.cat([latents, latents], dim=0), + torch.cat([t, t], dim=0), + condition=torch.cat([neg_condition, condition], dim=0), + fwd_pred_type="flow" + ) + uncond, cond = noise_pred.chunk(2) + noise_pred = uncond + guidance_scale * (cond - uncond) + else: + noise_pred = self(latents, t, condition=condition, fwd_pred_type="flow") + + latents = scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + + return latents + + def fully_shard(self, **kwargs): + if self.transformer.gradient_checkpointing: + self.transformer.disable_gradient_checkpointing() + apply_fsdp_checkpointing( + self.transformer, check_fn=lambda b: isinstance(b, LTX2VideoTransformerBlock) + ) + for block in self.transformer.transformer_blocks: + fully_shard(block, **kwargs) + fully_shard(self.transformer, **kwargs) + + def init_preprocessors(self): + self.text_encoder = LTX2TextEncoder(self.model_id) + self.vae = 
LTX2VideoEncoder(self.model_id) + + def _prepare_video_indices(self, F, H, W, device, dtype): + t_coords = torch.arange(F, device=device, dtype=dtype) + h_coords = torch.arange(H, device=device, dtype=dtype) + w_coords = torch.arange(W, device=device, dtype=dtype) + return torch.stack(torch.meshgrid(t_coords, h_coords, w_coords, indexing="ij"), dim=-1).reshape(-1, 3) + + def _pack_latents(self, x): + b, c, f, h, w = x.shape + return x.permute(0, 2, 3, 4, 1).reshape(b, -1, c) + + def _unpack_latents(self, x, f, h, w): + b, _, c = x.shape + return x.reshape(b, f, h, w, c).permute(0, 4, 1, 2, 3) + + def forward(self, x_t, t, condition=None, **kwargs): + b, c, f, h, w = x_t.shape + indices = self._prepare_video_indices(f, h, w, x_t.device, x_t.dtype) + model_outputs = self.transformer( + hidden_states=self._pack_latents(x_t), + encoder_hidden_states=condition, + timesteps=t, + indices=indices, + **kwargs + ) + # Logvar handling + if kwargs.get("return_logvar", False): + out, logvar = model_outputs + out = self._unpack_latents(out, f, h, w) + return out, logvar + + # Standard unpack + out = model_outputs if not isinstance(model_outputs, list) else model_outputs[0] + return self._unpack_latents(out, f, h, w) \ No newline at end of file diff --git a/tests/test_ltx2.py b/tests/test_ltx2.py new file mode 100644 index 0000000..491b4ac --- /dev/null +++ b/tests/test_ltx2.py @@ -0,0 +1,72 @@ +import torch +import numpy as np +from PIL import Image +from diffusers.pipelines.ltx2.export_utils import encode_video +from fastgen.networks.LTX2.network import LTX2 + +def test_ltx2_generation(): + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 # LTX-2 is optimized for bfloat16 + + # 1. Initialize the LTX2 model + print("Initializing LTX-2 model...") + model = LTX2(model_id="Lightricks/LTX-2", load_pretrained=True).to(device, dtype=dtype) + model.init_preprocessors() + model.eval() + + # 2. 
Prepare Prompts + prompt = "A majestic dragon flying over a snowy mountain range, cinematic lighting, 4k" + negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + print(f"Encoding prompt: {prompt}") + # Gemma 3 encoding + condition = model.text_encoder.encode(prompt, precision=dtype).to(device) + neg_condition = model.text_encoder.encode(negative_prompt, precision=dtype).to(device) + + # 3. Define Video Parameters + # LTX-2 video dimensions must be divisible by 32 spatially and 8+1 temporally + height, width = 480, 704 + num_frames = 81 # (8 * 10) + 1 + batch_size = 1 + + # Calculate latent dimensions (VAE compresses 8x spatially and 8x temporally) + latent_f = (num_frames - 1) // 8 + 1 + latent_h = height // 32 + latent_w = width // 32 + + # 4. Generate Initial Noise + # LTX-2 VAE latent channels is typically 128 + latent_channels = model.vae.vae.config.latent_channels + noise = torch.randn(batch_size, latent_channels, latent_f, latent_h, latent_w, device=device, dtype=dtype) + + # 5. Run Sampling (Inference) + print("Starting sampling process...") + with torch.no_grad(): + latents = model.sample( + noise=noise, + condition=condition, + neg_condition=neg_condition, + guidance_scale=4.0, + num_steps=40 + ) + + # 6. Decode Latents to Video + print("Decoding latents to video...") + with torch.no_grad(): + # LTX-2 uses flow-matching, so we can pass a zero/final timestep if needed by the VAE + video_tensor = model.vae.decode(latents) + + # 7. 
Post-process and Save + # Video tensor is [B, C, F, H, W] in range [-1, 1] + video_np = ((video_tensor[0].cpu().permute(1, 2, 3, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8) + + print("Saving video to ltx2_test.mp4...") + encode_video( + video_np, + fps=24, + output_path="ltx2_test.mp4" + ) + print("Done!") + +if __name__ == "__main__": + test_ltx2_generation() \ No newline at end of file From 2a588aa8bfc088b27ebf89ae0299428b5d0008a4 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Mon, 23 Feb 2026 12:22:11 +0800 Subject: [PATCH 02/13] fix the error, now video generated correctly --- fastgen/networks/LTX2/network.py | 845 +++++++++++++++++++------------ tests/test_ltx2.py | 88 ++-- 2 files changed, 586 insertions(+), 347 deletions(-) diff --git a/fastgen/networks/LTX2/network.py b/fastgen/networks/LTX2/network.py index cfeed80..9c84d18 100644 --- a/fastgen/networks/LTX2/network.py +++ b/fastgen/networks/LTX2/network.py @@ -1,357 +1,568 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 +""" +LTX-2 FastGen network implementation. 
-import os -from typing import Any, Optional, List, Set, Union, Tuple -import types +Architecture verified against: + - diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py + - diffusers/src/diffusers/pipelines/ltx2/connectors.py + - diffusers/src/diffusers/models/transformers/transformer_ltx2.py +""" +import copy + +import numpy as np import torch import torch.nn as nn -from torch import dtype -from torch.distributed.fsdp import fully_shard - -from diffusers import FlowMatchEulerDiscreteScheduler -from diffusers.models import LTXVideoTransformer3DModel, AutoencoderKLLTX2Video -from diffusers.models.transformers.transformer_ltx2 import LTX2VideoTransformerBlock -from diffusers.pipelines.ltx2 import LTX2TextConnectors -from transformers import GemmaTokenizer, Gemma3ForConditionalGeneration - -from fastgen.networks.network import FastGenNetwork -from fastgen.networks.noise_schedule import NET_PRED_TYPES -from fastgen.utils.basic_utils import str2bool -from fastgen.utils.distributed.fsdp import apply_fsdp_checkpointing -import fastgen.utils.logging_utils as logger - - -class LTX2TextEncoder: - """Text encoder for LTX-2 using Gemma 3.""" +from diffusers.models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from diffusers.models.transformers import LTX2VideoTransformer3DModel +from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast + + +# --------------------------------------------------------------------------- +# Helpers (mirrors of diffusers pipeline static methods) +# --------------------------------------------------------------------------- + +def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + """Pack video latents [B, C, F, H, W] → [B, F//pt * H//p * W//p, C*pt*p*p].""" + B, C, 
F, H, W = latents.shape + pF = F // patch_size_t + pH = H // patch_size + pW = W // patch_size + latents = latents.reshape(B, C, pF, patch_size_t, pH, patch_size, pW, patch_size) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + +def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, + patch_size: int = 1, patch_size_t: int = 1 +) -> torch.Tensor: + """Unpack video latents [B, T, D] → [B, C, F, H, W].""" + B = latents.size(0) + latents = latents.reshape(B, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + +def _pack_audio_latents(latents: torch.Tensor) -> torch.Tensor: + """Pack audio latents [B, C, L, M] → [B, L, C*M].""" + return latents.transpose(1, 2).flatten(2, 3) + + +def _unpack_audio_latents(latents: torch.Tensor, latent_length: int, num_mel_bins: int) -> torch.Tensor: + """Unpack audio latents [B, L, C*M] → [B, C, L, M].""" + B = latents.size(0) + latents = latents.reshape(B, latent_length, num_mel_bins, -1) + return latents.permute(0, 3, 1, 2) + + +def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, + scaling_factor: float = 1.0, +) -> torch.Tensor: + mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + return (latents - mean) * scaling_factor / std + + +def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, + scaling_factor: float = 1.0, +) -> torch.Tensor: + mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + return latents * std / scaling_factor + mean + + +def _normalize_audio_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, 
latents_std: torch.Tensor +) -> torch.Tensor: + mean = latents_mean.to(latents.device, latents.dtype) + std = latents_std.to(latents.device, latents.dtype) + return (latents - mean) / std + + +def _denormalize_audio_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor +) -> torch.Tensor: + mean = latents_mean.to(latents.device, latents.dtype) + std = latents_std.to(latents.device, latents.dtype) + return latents * std + mean + + +def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, +) -> torch.Tensor: + """ + Stack all Gemma hidden-state layers, normalize per-batch/per-layer over + non-padded positions, and pack into [B, T, H * num_layers]. + + Args: + text_hidden_states: [B, T, H, num_layers] (stacked output from Gemma) + sequence_lengths: [B] (number of non-padded tokens) + Returns: + [B, T, H * num_layers] + """ + B, T, H, L = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + token_indices = torch.arange(T, device=device).unsqueeze(0) # [1, T] + if padding_side == "right": + mask = token_indices < sequence_lengths[:, None] + else: # left + start = T - sequence_lengths[:, None] + mask = token_indices >= start + mask = mask[:, :, None, None] # [B, T, 1, 1] + + masked = text_hidden_states.masked_fill(~mask, 0.0) + num_valid = (sequence_lengths * H).view(B, 1, 1, 1) + mean = masked.sum(dim=(1, 2), keepdim=True) / (num_valid + eps) + + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + normed = (text_hidden_states - mean) / (x_max - x_min + eps) * scale_factor + normed = normed.flatten(2) # [B, T, H*L] + mask_flat = mask.squeeze(-1).expand(-1, -1, H * L) + normed = normed.masked_fill(~mask_flat, 0.0) + return normed.to(original_dtype) + 
+ +def _calculate_shift( + image_seq_len: int, + base_seq_len: int = 1024, + max_seq_len: int = 4096, + base_shift: float = 0.95, + max_shift: float = 2.05, +) -> float: + """Mirrors the pipeline's calculate_shift — defaults match LTX-2 scheduler config.""" + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + return image_seq_len * m + b + + +def _retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=None, mu=None): + """Call scheduler.set_timesteps, forwarding mu when dynamic shifting is enabled.""" + kwargs = {} + if mu is not None: + kwargs["mu"] = mu + if sigmas is not None: + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + return scheduler.timesteps, len(scheduler.timesteps) + + +# --------------------------------------------------------------------------- +# Text encoder wrapper +# --------------------------------------------------------------------------- + +class LTX2TextEncoder(nn.Module): + """ + Wraps Gemma3ForConditionalGeneration for LTX-2 text conditioning. + + Returns both the packed prompt embeddings AND the tokenizer attention mask, + which is required by LTX2TextConnectors. 
+ """ def __init__(self, model_id: str): - self.tokenizer = GemmaTokenizer.from_pretrained( - model_id, - cache_dir=os.environ["HF_HOME"], - subfolder="tokenizer", - local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), - ) + super().__init__() + self.tokenizer = GemmaTokenizerFast.from_pretrained(model_id, subfolder="tokenizer") + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.text_encoder = Gemma3ForConditionalGeneration.from_pretrained( - model_id, - cache_dir=os.environ["HF_HOME"], - subfolder="text_encoder", - local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + model_id, subfolder="text_encoder" ) - self.text_encoder.eval().requires_grad_(False) + @torch.no_grad() def encode( self, - conditioning: Optional[Any] = None, - precision: dtype = torch.float32, - max_sequence_length: int = 512, - ) -> torch.Tensor: - """Encode text prompts to raw Gemma 3 embeddings.""" - if isinstance(conditioning, str): - conditioning = [conditioning] + prompt: str | list[str], + precision: torch.dtype = torch.bfloat16, + max_sequence_length: int = 1024, + scale_factor: int = 8, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Encode text prompt(s) into packed Gemma hidden states. + + Returns + ------- + prompt_embeds : torch.Tensor [B, T, H * num_layers] + Normalised, packed text embeddings ready for LTX2TextConnectors. + attention_mask : torch.Tensor [B, T] + Binary padding mask (1 = real token, 0 = pad) from the tokenizer. 
+ """ + if isinstance(prompt, str): + prompt = [prompt] + + device = next(self.text_encoder.parameters()).device text_inputs = self.tokenizer( - conditioning, + prompt, padding="max_length", max_length=max_sequence_length, truncation=True, + add_special_tokens=True, return_tensors="pt", ) + input_ids = text_inputs.input_ids.to(device) + attention_mask = text_inputs.attention_mask.to(device) + + outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + ) - with torch.no_grad(): - outputs = self.text_encoder( - input_ids=text_inputs.input_ids.to(self.text_encoder.device), - attention_mask=text_inputs.attention_mask.to(self.text_encoder.device), - output_hidden_states=True, - return_dict=True, - ) - # Raw Gemma hidden states to be projected by connectors - prompt_embeds = outputs.hidden_states[-1].to(precision) + # Stack all hidden states: [B, T, H, num_layers] + hidden_states = torch.stack(outputs.hidden_states, dim=-1) + sequence_lengths = attention_mask.sum(dim=-1) + + prompt_embeds = _pack_text_embeds( + hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ).to(precision) + + # Bug 1 fix: return BOTH embeds and mask + return prompt_embeds, attention_mask + + +# --------------------------------------------------------------------------- +# Main LTX-2 network +# --------------------------------------------------------------------------- + +class LTX2(nn.Module): + """ + FastGen wrapper for LTX-2 audio-video generation. + + Component layout (mirrors the diffusers model_cpu_offload_seq): + text_encoder → connectors → transformer → vae → audio_vae → vocoder + + Notes + ----- + * ``connectors`` lives as a sibling of ``transformer``, NOT nested inside + it (Bug 5 fix). Connectors process the text embeddings once before the + denoising loop and produce separate video and audio encoder hidden states. 
+ * No monkey-patching of the transformer is needed: LTX2VideoTransformer3DModel + handles its own block loop with the correct audio/video dual-branch logic + internally (Bug 3 fix). + """ + + def __init__(self, model_id: str = "Lightricks/LTX-2", load_pretrained: bool = True): + super().__init__() + self.model_id = model_id + self._initialized = False + if load_pretrained: + self._initialize_network() - return prompt_embeds + # ------------------------------------------------------------------ + # Initialisation + # ------------------------------------------------------------------ - def to(self, *args, **kwargs): - self.text_encoder.to(*args, **kwargs) - return self + def _initialize_network(self): + model_id = self.model_id + # Text encoder (Gemma3) + self.text_encoder = LTX2TextEncoder(model_id) -class LTX2VideoEncoder: - """Spatio-temporal VAE encoder/decoder for LTX-2.""" + # Bug 5 fix: connectors is a TOP-LEVEL sibling of transformer, + # NOT attached to self.transformer. + self.connectors = LTX2TextConnectors.from_pretrained( + model_id, subfolder="connectors" + ) - def __init__(self, model_id: str): - self.vae: AutoencoderKLLTX2Video = AutoencoderKLLTX2Video.from_pretrained( - model_id, - cache_dir=os.environ["HF_HOME"], - subfolder="vae", - local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + # Transformer (LTX2VideoTransformer3DModel handles all block logic) + self.transformer = LTX2VideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer" ) - self.vae.eval().requires_grad_(False) - - # LTX-2 normalization constants - self.scaling_factor = getattr(self.vae.config, "scaling_factor", 1.5305) - self.shift_factor = getattr(self.vae.config, "shift_factor", 0.0609) - - def encode(self, real_video: torch.Tensor) -> torch.Tensor: - """Encode videos to 3D latent space.""" - latents = self.vae.encode(real_video, return_dict=False)[0].sample() - latents = (latents - self.shift_factor) * self.scaling_factor - return latents - - def 
decode(self, latents: torch.Tensor, timestep: Optional[torch.Tensor] = None) -> torch.Tensor: - """Decode latents to video frames with optional timestep conditioning.""" - latents = (latents / self.scaling_factor) + self.shift_factor - video = self.vae.decode(latents, timestep=timestep, return_dict=False)[0] - return video.clip(-1.0, 1.0) - - def to(self, *args, **kwargs): - self.vae.to(*args, **kwargs) - return self - - -def classify_forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timesteps: torch.Tensor, - indices: torch.Tensor, - encoder_attention_mask: Optional[torch.Tensor] = None, - audio_hidden_states: Optional[torch.Tensor] = None, - audio_encoder_attention_mask: Optional[torch.Tensor] = None, - return_features_early: bool = False, - feature_indices: Optional[Set[int]] = None, - return_logvar: bool = False, - **kwargs, -) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - """Patched forward pass for LTXVideoTransformer3DModel.""" - if feature_indices is None: - feature_indices = set() - - # 1. Text Connector Projection - # Projects Gemma embeddings for video and audio streams - encoder_hidden_states, audio_encoder_hidden_states = self.connectors( - encoder_hidden_states - ) - - # 2. Time & Text Embeddings - temb = self.time_text_embed(timesteps, encoder_hidden_states) - - # LTX-2 also prepares a separate audio time embedding - # Here we simplify or derive it if necessary - temb_audio = temb - - # 3. Derive Modulation Parameters for Cross-Attention - # These are typically linear projections of temb/temb_audio - # Based on the user's Turn 10 snippet, these must be passed to the blocks - video_ca_scale_shift = self.video_ca_modulation(temb) - audio_ca_scale_shift = self.audio_ca_modulation(temb_audio) - video_ca_a2v_gate = self.video_ca_gate(temb) - audio_ca_v2a_gate = self.audio_ca_gate(temb_audio) - - # 4. 
Positional Embeddings - video_rotary_emb = self.pos_embed(indices) - audio_rotary_emb = None # Placeholder if audio is zeroed - - if audio_hidden_states is None: - audio_hidden_states = torch.zeros_like(hidden_states[:, :0]) - - idx, features = 0, [] - - # 5. Dual-Stream Transformer Blocks - for block in self.transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states, audio_hidden_states = self._gradient_checkpointing_func( - block, - hidden_states, - audio_hidden_states, - encoder_hidden_states, - audio_encoder_hidden_states, - temb, - temb_audio, - video_ca_scale_shift, - audio_ca_scale_shift, - video_ca_a2v_gate, - audio_ca_v2a_gate, - video_rotary_emb, - audio_rotary_emb, - None, # ca_video_rotary_emb - None, # ca_audio_rotary_emb - encoder_attention_mask, - audio_encoder_attention_mask, - ) - else: - hidden_states, audio_hidden_states = block( - hidden_states=hidden_states, - audio_hidden_states=audio_hidden_states, - encoder_hidden_states=encoder_hidden_states, - audio_encoder_hidden_states=audio_encoder_hidden_states, - temb=temb, - temb_audio=temb_audio, - temb_ca_scale_shift=video_ca_scale_shift, - temb_ca_audio_scale_shift=audio_ca_scale_shift, - temb_ca_gate=video_ca_a2v_gate, - temb_ca_audio_gate=audio_ca_v2a_gate, - video_rotary_emb=video_rotary_emb, - audio_rotary_emb=audio_rotary_emb, - encoder_attention_mask=encoder_attention_mask, - audio_encoder_attention_mask=audio_encoder_attention_mask, - ) - - if idx in feature_indices: - features.append(hidden_states.clone()) - if return_features_early and len(features) == len(feature_indices): - return features - idx += 1 - - # 6. 
Final Projection - output = self.proj_out(hidden_states, temb) - - if return_features_early: - return features - out = output if len(feature_indices) == 0 else [output, features] - - if return_logvar: - # temb_dim = 4096 - logvar = self.logvar_linear(temb) - return out, logvar - - return out - - -class LTX2(FastGenNetwork): - """LTX-2 network for text-to-video generation.""" - - MODEL_ID = "Lightricks/LTX-2" - - def __init__( - self, - model_id: str = MODEL_ID, - net_pred_type: str = "flow", - schedule_type: str = "rf", - disable_grad_ckpt: bool = False, - load_pretrained: bool = True, - **model_kwargs, - ): - super().__init__(net_pred_type=net_pred_type, schedule_type=schedule_type, **model_kwargs) - self.model_id = model_id - self._initialize_network(model_id, load_pretrained) - self.transformer.forward = types.MethodType(classify_forward, self.transformer) - if disable_grad_ckpt: - self.transformer.disable_gradient_checkpointing() - else: - self.transformer.enable_gradient_checkpointing() - - def _initialize_network(self, model_id: str, load_pretrained: bool) -> None: - in_meta_context = self._is_in_meta_context() - - if load_pretrained and not in_meta_context: - logger.info(f"Loading LTX-2 transformer and connectors from {model_id}") - self.transformer = LTXVideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer" - ) - # Load LTX2TextConnectors - self.transformer.connectors = LTX2TextConnectors.from_pretrained( - model_id, subfolder="connectors" - ) - else: - config = LTXVideoTransformer3DModel.load_config(model_id, subfolder="transformer") - self.transformer = LTXVideoTransformer3DModel.from_config(config) - conn_config = LTX2TextConnectors.load_config(model_id, subfolder="connectors") - self.transformer.connectors = LTX2TextConnectors.from_config(conn_config) + # VAEs + self.vae = AutoencoderKLLTX2Video.from_pretrained( + model_id, subfolder="vae" + ) + self.audio_vae = AutoencoderKLLTX2Audio.from_pretrained( + model_id, 
subfolder="audio_vae" + ) + + # Vocoder (mel spectrogram → waveform) + self.vocoder = LTX2Vocoder.from_pretrained( + model_id, subfolder="vocoder" + ) + + # Scheduler + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model_id, subfolder="scheduler" + ) + + # Cache compression ratios + self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio + self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio + self.transformer_spatial_patch_size = self.transformer.config.patch_size + self.transformer_temporal_patch_size = self.transformer.config.patch_size_t - self.transformer.logvar_linear = nn.Linear(4096, 1) + # Audio VAE constants. + # sample_rate / mel_hop_length live in config; the compression ratios are + # instance attributes (LATENT_DOWNSAMPLE_FACTOR = 4) set in __init__, not in config. + self.audio_sampling_rate = self.audio_vae.config.sample_rate # 16000 + self.audio_hop_length = self.audio_vae.config.mel_hop_length # 160 + self.audio_vae_temporal_compression = self.audio_vae.temporal_compression_ratio # 4 + self.audio_vae_mel_compression_ratio = self.audio_vae.mel_compression_ratio # 4 - def _calculate_shift(self, sequence_length: int) -> float: - """Resolution-dependent shift.""" - base_seq_len, max_seq_len = 1024, 4096 - base_shift, max_shift = 0.5, 1.15 - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - return sequence_length * m + b + self._initialized = True + + def init_preprocessors(self): + """No-op placeholder (preprocessors are initialised in _initialize_network).""" + pass + + # ------------------------------------------------------------------ + # Forward pass (single denoising step) + # ------------------------------------------------------------------ + + def forward( + self, + # packed video latents [B, T_v, C_v] + hidden_states: torch.Tensor, + # packed audio latents [B, T_a, C_a] + audio_hidden_states: torch.Tensor, + # already-projected by 
connectors + encoder_hidden_states: torch.Tensor, # video text embeds [B, T_t, D_v] + audio_encoder_hidden_states: torch.Tensor, # audio text embeds [B, T_t, D_a] + encoder_attention_mask: torch.Tensor, # [B, T_t] + timestep: torch.Tensor, + # spatial metadata for RoPE + num_frames: int, + height: int, + width: int, + audio_num_frames: int, + fps: float = 24.0, + video_coords: torch.Tensor | None = None, + audio_coords: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Single denoising step through the transformer. + + Returns + ------- + (noise_pred_video, noise_pred_audio) both in packed token format. + """ + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + timestep=timestep, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=encoder_attention_mask, # same mask for audio + num_frames=num_frames, + height=height, + width=width, + fps=fps, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, + return_dict=False, + ) + return noise_pred_video, noise_pred_audio + + # ------------------------------------------------------------------ + # Sampling loop + # ------------------------------------------------------------------ @torch.no_grad() def sample( self, noise: torch.Tensor, - condition: Optional[torch.Tensor] = None, - neg_condition: Optional[torch.Tensor] = None, - guidance_scale: Optional[float] = 4.0, + condition: tuple[torch.Tensor, torch.Tensor], + neg_condition: tuple[torch.Tensor, torch.Tensor] | None = None, + guidance_scale: float = 4.0, num_steps: int = 40, - **kwargs, - ) -> torch.Tensor: - """Generate video samples using Euler flow matching.""" - batch_size, _, frames, height, width = noise.shape - mu = self._calculate_shift(frames * height * width) - - scheduler = 
FlowMatchEulerDiscreteScheduler(shift=mu) - scheduler.set_timesteps(num_steps, device=noise.device) - timesteps, latents = scheduler.timesteps, noise.clone() - - for timestep in timesteps: - t = (timestep / 1000.0).expand(batch_size).to(dtype=noise.dtype, device=noise.device) - t = self.noise_scheduler.safe_clamp(t, min=self.noise_scheduler.min_t, max=self.noise_scheduler.max_t) - - if neg_condition is not None and guidance_scale is not None: - noise_pred = self( - torch.cat([latents, latents], dim=0), - torch.cat([t, t], dim=0), - condition=torch.cat([neg_condition, condition], dim=0), - fwd_pred_type="flow" + fps: float = 24.0, + frame_rate: float | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Run the full denoising loop for text-to-video+audio generation. + + Follows pipeline_ltx2.py exactly: + - latents kept in float32 throughout; only cast to prompt dtype for transformer + - noise predictions cast to float32 before CFG and scheduler step + - connectors called ONCE on combined [uncond, cond] batch (not twice) + - audio duration derived from pixel-frame count, not latent frames + - transformer wrapped in cache_context("cond_uncond") + + Returns + ------- + (video_latents, audio_latents): + video: [B, C, F, H, W] denormalised, ready for vae.decode() + audio: [B, C, L, M] denormalised, ready for audio_vae.decode() + """ + fps = frame_rate if frame_rate is not None else fps + do_cfg = neg_condition is not None and guidance_scale > 1.0 + + device = noise.device + B, C, latent_f, latent_h, latent_w = noise.shape + + # ---------------------------------------------------------------- + # Latent dimensions + # ---------------------------------------------------------------- + # Audio duration uses pixel-frame count (pipeline line 1003): + # duration_s = num_frames / frame_rate + # where num_frames is the original pixel count. 
For a causal VAE: + # pixel_frames = (latent_f - 1) * temporal_compression + 1 + pixel_frames = (latent_f - 1) * self.vae_temporal_compression_ratio + 1 + duration_s = pixel_frames / fps + + audio_latents_per_second = ( + self.audio_sampling_rate + / self.audio_hop_length + / float(self.audio_vae_temporal_compression) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + num_mel_bins = self.audio_vae.config.mel_bins + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_audio_ch = self.audio_vae.config.latent_channels + + # ---------------------------------------------------------------- + # Pack latents — keep in float32 to match pipeline + # ---------------------------------------------------------------- + video_latents = _pack_latents( + noise.float(), self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + audio_shape = (B, num_audio_ch, audio_num_frames, latent_mel_bins) + audio_latents = torch.randn(audio_shape, device=device, dtype=torch.float32) + audio_latents = _pack_audio_latents(audio_latents) + + # ---------------------------------------------------------------- + # Text conditioning — run connectors ONCE on combined [uncond, cond] + # batch, exactly as the pipeline does (lines 959-966) + # ---------------------------------------------------------------- + prompt_embeds, attention_mask = condition + + if do_cfg: + neg_embeds, neg_mask = neg_condition + # Stack [uncond, cond] before connectors — single forward pass + combined_embeds = torch.cat([neg_embeds, prompt_embeds], dim=0) + combined_mask = torch.cat([neg_mask, attention_mask], dim=0) + else: + combined_embeds = prompt_embeds + combined_mask = attention_mask + + additive_mask = (1 - combined_mask.to(combined_embeds.dtype)) * -1_000_000.0 + connector_video_embeds, connector_audio_embeds, connector_attn_mask = self.connectors( + combined_embeds, additive_mask, additive_mask=True + ) + + # 
---------------------------------------------------------------- + # Pre-compute RoPE coordinates (pipeline lines 1078-1087) + # Compute for single batch, then repeat for CFG + # ---------------------------------------------------------------- + video_coords = self.transformer.rope.prepare_video_coords( + B, latent_f, latent_h, latent_w, device, fps=fps + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + B, audio_num_frames, device + ) + if do_cfg: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + + # ---------------------------------------------------------------- + # Scheduler timesteps (pipeline lines 1042-1067) + # ---------------------------------------------------------------- + sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps) + + video_seq_len = latent_f * latent_h * latent_w + mu = _calculate_shift( + video_seq_len, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + audio_scheduler = copy.deepcopy(self.scheduler) + _retrieve_timesteps(audio_scheduler, num_steps, device, sigmas=sigmas, mu=mu) + timesteps, num_steps = _retrieve_timesteps(self.scheduler, num_steps, device, sigmas=sigmas, mu=mu) + + # ---------------------------------------------------------------- + # Denoising loop (pipeline lines 1091-1154) + # ---------------------------------------------------------------- + # prompt_embeds.dtype drives the cast for transformer input + prompt_dtype = connector_video_embeds.dtype + + for t in timesteps: + # Cast latents to prompt dtype for transformer; keep float32 for scheduler + latent_input = torch.cat([video_latents] * 2) if do_cfg else video_latents + audio_latent_input = torch.cat([audio_latents] * 2) if do_cfg else audio_latents + latent_input = 
latent_input.to(prompt_dtype) + audio_latent_input = audio_latent_input.to(prompt_dtype) + + timestep = t.expand(latent_input.shape[0]) + + # Wrap in cache_context as the pipeline does (CacheMixin optimisation) + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.forward( + hidden_states=latent_input, + audio_hidden_states=audio_latent_input, + encoder_hidden_states=connector_video_embeds, + audio_encoder_hidden_states=connector_audio_embeds, + encoder_attention_mask=connector_attn_mask, + timestep=timestep, + num_frames=latent_f, + height=latent_h, + width=latent_w, + audio_num_frames=audio_num_frames, + fps=fps, + video_coords=video_coords, + audio_coords=audio_coords, ) - uncond, cond = noise_pred.chunk(2) - noise_pred = uncond + guidance_scale * (cond - uncond) - else: - noise_pred = self(latents, t, condition=condition, fwd_pred_type="flow") - latents = scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + # Cast noise preds to float32 before CFG and scheduler step + # (pipeline lines 1127-1128) + noise_pred_video = noise_pred_video.float() + noise_pred_audio = noise_pred_audio.float() - return latents + if do_cfg: + video_uncond, video_cond = noise_pred_video.chunk(2) + noise_pred_video = video_uncond + guidance_scale * (video_cond - video_uncond) - def fully_shard(self, **kwargs): - if self.transformer.gradient_checkpointing: - self.transformer.disable_gradient_checkpointing() - apply_fsdp_checkpointing( - self.transformer, check_fn=lambda b: isinstance(b, LTX2VideoTransformerBlock) - ) - for block in self.transformer.transformer_blocks: - fully_shard(block, **kwargs) - fully_shard(self.transformer, **kwargs) + audio_uncond, audio_cond = noise_pred_audio.chunk(2) + noise_pred_audio = audio_uncond + guidance_scale * (audio_cond - audio_uncond) - def init_preprocessors(self): - self.text_encoder = LTX2TextEncoder(self.model_id) - self.vae = LTX2VideoEncoder(self.model_id) - - def 
_prepare_video_indices(self, F, H, W, device, dtype): - t_coords = torch.arange(F, device=device, dtype=dtype) - h_coords = torch.arange(H, device=device, dtype=dtype) - w_coords = torch.arange(W, device=device, dtype=dtype) - return torch.stack(torch.meshgrid(t_coords, h_coords, w_coords, indexing="ij"), dim=-1).reshape(-1, 3) - - def _pack_latents(self, x): - b, c, f, h, w = x.shape - return x.permute(0, 2, 3, 4, 1).reshape(b, -1, c) - - def _unpack_latents(self, x, f, h, w): - b, _, c = x.shape - return x.reshape(b, f, h, w, c).permute(0, 4, 1, 2, 3) - - def forward(self, x_t, t, condition=None, **kwargs): - b, c, f, h, w = x_t.shape - indices = self._prepare_video_indices(f, h, w, x_t.device, x_t.dtype) - model_outputs = self.transformer( - hidden_states=self._pack_latents(x_t), - encoder_hidden_states=condition, - timesteps=t, - indices=indices, - **kwargs + # Scheduler steps operate on float32 latents + video_latents = self.scheduler.step(noise_pred_video, t, video_latents, return_dict=False)[0] + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + # ---------------------------------------------------------------- + # Unpack and denormalise (pipeline lines 1172-1187) + # ---------------------------------------------------------------- + video_latents = _unpack_latents( + video_latents, latent_f, latent_h, latent_w, + self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, + ) + video_latents = _denormalize_latents( + video_latents, self.vae.latents_mean, self.vae.latents_std, + self.vae.config.scaling_factor, ) - # Logvar handling - if kwargs.get("return_logvar", False): - out, logvar = model_outputs - out = self._unpack_latents(out, f, h, w) - return out, logvar - - # Standard unpack - out = model_outputs if not isinstance(model_outputs, list) else model_outputs[0] - return self._unpack_latents(out, f, h, w) \ No newline at end of file + + # Denormalise while still packed [B, L, 128], then 
unpack to [B, C, L, M] + # (pipeline lines 1184-1187: _denormalize then _unpack) + audio_latents = _denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = _unpack_audio_latents(audio_latents, audio_num_frames, latent_mel_bins) + + return video_latents, audio_latents \ No newline at end of file diff --git a/tests/test_ltx2.py b/tests/test_ltx2.py index 491b4ac..1a7432f 100644 --- a/tests/test_ltx2.py +++ b/tests/test_ltx2.py @@ -1,13 +1,13 @@ import torch import numpy as np -from PIL import Image from diffusers.pipelines.ltx2.export_utils import encode_video from fastgen.networks.LTX2.network import LTX2 + def test_ltx2_generation(): device = "cuda" if torch.cuda.is_available() else "cpu" - dtype = torch.bfloat16 # LTX-2 is optimized for bfloat16 - + dtype = torch.bfloat16 # LTX-2 is optimized for bfloat16 + # 1. Initialize the LTX2 model print("Initializing LTX-2 model...") model = LTX2(model_id="Lightricks/LTX-2", load_pretrained=True).to(device, dtype=dtype) @@ -15,58 +15,86 @@ def test_ltx2_generation(): model.eval() # 2. Prepare Prompts - prompt = "A majestic dragon flying over a snowy mountain range, cinematic lighting, 4k" + prompt = "A high-performance sports car racing through a city street at night, neon lights reflecting off wet asphalt, motion blur streaking past the camera. The camera tracks low and close to the car as it accelerates aggressively, tires gripping the road, exhaust heat shimmering. Realistic lighting, cinematic depth of field, ultra-detailed textures, dynamic reflections, dramatic shadows, 4K realism, film-grade color grading." 
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" - + print(f"Encoding prompt: {prompt}") - # Gemma 3 encoding - condition = model.text_encoder.encode(prompt, precision=dtype).to(device) - neg_condition = model.text_encoder.encode(negative_prompt, precision=dtype).to(device) + # FIX 1: encode() returns (embeds, mask) tuple — unpack and move each tensor separately + embeds, mask = model.text_encoder.encode(prompt, precision=dtype) + condition = (embeds.to(device), mask.to(device)) + neg_embeds, neg_mask = model.text_encoder.encode(negative_prompt, precision=dtype) + neg_condition = (neg_embeds.to(device), neg_mask.to(device)) # 3. Define Video Parameters - # LTX-2 video dimensions must be divisible by 32 spatially and 8+1 temporally + # LTX-2 video dimensions must be divisible by 32 spatially and (8n+1) temporally height, width = 480, 704 - num_frames = 81 # (8 * 10) + 1 + num_frames = 81 # (8 * 10) + 1 batch_size = 1 - - # Calculate latent dimensions (VAE compresses 8x spatially and 8x temporally) - latent_f = (num_frames - 1) // 8 + 1 - latent_h = height // 32 - latent_w = width // 32 - + + # Calculate latent dimensions + latent_f = (num_frames - 1) // model.vae_temporal_compression_ratio + 1 + latent_h = height // model.vae_spatial_compression_ratio + latent_w = width // model.vae_spatial_compression_ratio + # 4. Generate Initial Noise - # LTX-2 VAE latent channels is typically 128 - latent_channels = model.vae.vae.config.latent_channels - noise = torch.randn(batch_size, latent_channels, latent_f, latent_h, latent_w, device=device, dtype=dtype) + # FIX 2: model.vae IS AutoencoderKLLTX2Video directly — no .vae sub-attribute. + # Use transformer.config.in_channels as the pipeline does (line 989). + latent_channels = model.transformer.config.in_channels # 128 + + # FIX 3: noise must be float32 — the pipeline creates latents in torch.float32 + # (prepare_latents line 997) and keeps them float32 through the scheduler. 
+ # Only cast to bfloat16 right before the transformer call. + noise = torch.randn( + batch_size, latent_channels, latent_f, latent_h, latent_w, + device=device, dtype=torch.float32 + ) # 5. Run Sampling (Inference) print("Starting sampling process...") with torch.no_grad(): - latents = model.sample( + latents, audio_latents = model.sample( noise=noise, condition=condition, neg_condition=neg_condition, guidance_scale=4.0, - num_steps=40 + num_steps=40, + fps=24.0, ) + # latents: [B, C, F, H, W] denormalised video latents (float32) + # audio_latents: [B, C, L, M] denormalised audio latents (float32) - # 6. Decode Latents to Video - print("Decoding latents to video...") + # 6. Decode Latents to Video + Audio + print("Decoding latents to video and audio...") with torch.no_grad(): - # LTX-2 uses flow-matching, so we can pass a zero/final timestep if needed by the VAE - video_tensor = model.vae.decode(latents) + # vae.decode() signature: decode(z, temb=None, causal=None, return_dict=True) + # timestep_conditioning is False for LTX-2, so no temb needed. + # Use return_dict=False to get the tensor directly. + video_tensor = model.vae.decode(latents.to(model.vae.dtype), return_dict=False)[0] + # video_tensor: [B, C, F, H, W] in ~[-1, 1] + + # Decode mel spectrogram -> waveform via audio_vae + vocoder + mel_spectrograms = model.audio_vae.decode( + audio_latents.to(model.audio_vae.dtype), return_dict=False + )[0] + audio_waveform = model.vocoder(mel_spectrograms) + # audio_waveform: [B, channels, samples] at vocoder.config.output_sampling_rate Hz # 7. 
Post-process and Save - # Video tensor is [B, C, F, H, W] in range [-1, 1] - video_np = ((video_tensor[0].cpu().permute(1, 2, 3, 0).numpy() + 1.0) * 127.5).clip(0, 255).astype(np.uint8) - + # Convert [B, C, F, H, W] -> [F, H, W, C] uint8 + video_np = ( + (video_tensor[0].cpu().float().permute(1, 2, 3, 0).numpy() + 1.0) * 127.5 + ).clip(0, 255).astype(np.uint8) + print("Saving video to ltx2_test.mp4...") encode_video( video_np, fps=24, - output_path="ltx2_test.mp4" + audio=audio_waveform[0].float().cpu(), + audio_sample_rate=model.vocoder.config.output_sampling_rate, # 24000 Hz + output_path="ltx2_test.mp4", ) print("Done!") + if __name__ == "__main__": - test_ltx2_generation() \ No newline at end of file + test_ltx2_generation() From 919cdb171fc564e93f1c9c62155c9d32e6591ff3 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Mon, 23 Feb 2026 12:25:05 +0800 Subject: [PATCH 03/13] add fastgen header --- fastgen/networks/LTX2/network.py | 3 +++ tests/test_ltx2.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/fastgen/networks/LTX2/network.py b/fastgen/networks/LTX2/network.py index 9c84d18..bd7d910 100644 --- a/fastgen/networks/LTX2/network.py +++ b/fastgen/networks/LTX2/network.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + """ LTX-2 FastGen network implementation. diff --git a/tests/test_ltx2.py b/tests/test_ltx2.py index 1a7432f..d5769a9 100644 --- a/tests/test_ltx2.py +++ b/tests/test_ltx2.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + import torch import numpy as np from diffusers.pipelines.ltx2.export_utils import encode_video From f3a2133f8aa640133d41ccf8f99ba30090ee26c5 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Mon, 23 Feb 2026 19:05:18 +0800 Subject: [PATCH 04/13] fix audio issue --- fastgen/networks/LTX2/network.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/fastgen/networks/LTX2/network.py b/fastgen/networks/LTX2/network.py index bd7d910..db59bbc 100644 --- a/fastgen/networks/LTX2/network.py +++ b/fastgen/networks/LTX2/network.py @@ -55,10 +55,12 @@ def _pack_audio_latents(latents: torch.Tensor) -> torch.Tensor: def _unpack_audio_latents(latents: torch.Tensor, latent_length: int, num_mel_bins: int) -> torch.Tensor: - """Unpack audio latents [B, L, C*M] → [B, C, L, M].""" - B = latents.size(0) - latents = latents.reshape(B, latent_length, num_mel_bins, -1) - return latents.permute(0, 3, 1, 2) + """Unpack audio latents [B, L, C*M] -> [B, C, L, M]. + Mirrors pipeline: latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + unflatten splits last dim 128 -> [C=8, M=16], giving [B, L, C, M], + then transpose(1, 2) -> [B, C, L, M]. 
+ """ + return latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) def _normalize_latents( From c56093e80645e9ad8bb444b9c58409d6caf1b017 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Tue, 24 Feb 2026 15:33:02 +0800 Subject: [PATCH 05/13] add classify_forward --- fastgen/networks/LTX2/network.py | 847 ++++++++++++++++++++++++------- tests/test_ltx2.py | 6 +- 2 files changed, 681 insertions(+), 172 deletions(-) diff --git a/fastgen/networks/LTX2/network.py b/fastgen/networks/LTX2/network.py index db59bbc..c995357 100644 --- a/fastgen/networks/LTX2/network.py +++ b/fastgen/networks/LTX2/network.py @@ -8,20 +8,35 @@ - diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py - diffusers/src/diffusers/pipelines/ltx2/connectors.py - diffusers/src/diffusers/models/transformers/transformer_ltx2.py + +Follows the FastGen network pattern established by Flux and Wan: + - Inherits from FastGenNetwork + - Monkey-patches classify_forward onto self.transformer + - forward() handles video-only latent for distillation (audio flows through but is ignored for loss) + - feature_indices extracts video hidden_states only """ import copy +import types +from typing import Any, List, Optional, Set, Tuple, Union import numpy as np import torch import torch.nn as nn from diffusers.models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video from diffusers.models.transformers import LTX2VideoTransformer3DModel +from diffusers.models.transformers.transformer_ltx2 import LTX2VideoTransformerBlock from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from torch.distributed.fsdp import fully_shard from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast +from fastgen.networks.network import FastGenNetwork +from fastgen.networks.noise_schedule import NET_PRED_TYPES +from fastgen.utils.distributed.fsdp import 
apply_fsdp_checkpointing +import fastgen.utils.logging_utils as logger + # --------------------------------------------------------------------------- # Helpers (mirrors of diffusers pipeline static methods) @@ -55,11 +70,7 @@ def _pack_audio_latents(latents: torch.Tensor) -> torch.Tensor: def _unpack_audio_latents(latents: torch.Tensor, latent_length: int, num_mel_bins: int) -> torch.Tensor: - """Unpack audio latents [B, L, C*M] -> [B, C, L, M]. - Mirrors pipeline: latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) - unflatten splits last dim 128 -> [C=8, M=16], giving [B, L, C, M], - then transpose(1, 2) -> [B, C, L, M]. - """ + """Unpack audio latents [B, L, C*M] -> [B, C, L, M].""" return latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) @@ -108,12 +119,6 @@ def _pack_text_embeds( """ Stack all Gemma hidden-state layers, normalize per-batch/per-layer over non-padded positions, and pack into [B, T, H * num_layers]. - - Args: - text_hidden_states: [B, T, H, num_layers] (stacked output from Gemma) - sequence_lengths: [B] (number of non-padded tokens) - Returns: - [B, T, H * num_layers] """ B, T, H, L = text_hidden_states.shape original_dtype = text_hidden_states.dtype @@ -165,6 +170,270 @@ def _retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=None, mu= return scheduler.timesteps, len(scheduler.timesteps) +# --------------------------------------------------------------------------- +# classify_forward — monkey-patched onto self.transformer +# --------------------------------------------------------------------------- + +def classify_forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + audio_timestep: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + audio_encoder_attention_mask: Optional[torch.Tensor] = None, + num_frames: Optional[int] = None, + height: 
Optional[int] = None, + width: Optional[int] = None, + fps: float = 24.0, + audio_num_frames: Optional[int] = None, + video_coords: Optional[torch.Tensor] = None, + audio_coords: Optional[torch.Tensor] = None, + attention_kwargs: Optional[dict] = None, + return_dict: bool = True, # accepted for API compatibility; always ignored + # FastGen distillation kwargs + return_features_early: bool = False, + feature_indices: Optional[Set[int]] = None, + return_logvar: bool = False, +) -> Union[ + Tuple[torch.Tensor, torch.Tensor], # (video_out, audio_out) + Tuple[Tuple[torch.Tensor, torch.Tensor], List[torch.Tensor]], # ((video_out, audio_out), features) + List[torch.Tensor], # features only (early exit) +]: + """ + Drop-in replacement for LTX2VideoTransformer3DModel.forward that adds FastGen + distillation support (feature extraction, early exit, logvar). + + Audio always flows through every block unchanged — we never short-circuit it — + but only video hidden_states are stored as features for the discriminator. 
+ + Returns + ------- + Normal mode (feature_indices empty, return_features_early False): + (video_output, audio_output) — identical to the original forward + + Feature mode (feature_indices non-empty, return_features_early False): + ((video_output, audio_output), List[video_feature_tensors]) + + Early-exit mode (return_features_early True): + List[video_feature_tensors] — forward stops as soon as all features collected + """ + # LoRA scale handling — mirrors the @apply_lora_scale decorator in upstream + from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + if USE_PEFT_BACKEND: + scale_lora_layers(self, lora_scale) + print("calling classfiy forward:") + if feature_indices is None: + feature_indices = set() + + if return_features_early and len(feature_indices) == 0: + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) + return [] + + # ------------------------------------------------------------------ # + # Steps 1-4: identical to the original forward (no changes) + # ------------------------------------------------------------------ # + audio_timestep = audio_timestep if audio_timestep is not None else timestep + + # Convert attention masks to additive bias form + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2: + audio_encoder_attention_mask = ( + 1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype) + ) * -10000.0 + audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # 1. 
RoPE positional embeddings + if video_coords is None: + video_coords = self.rope.prepare_video_coords( + batch_size, num_frames, height, width, hidden_states.device, fps=fps + ) + if audio_coords is None: + audio_coords = self.audio_rope.prepare_audio_coords( + batch_size, audio_num_frames, audio_hidden_states.device + ) + + video_rotary_emb = self.rope(video_coords, device=hidden_states.device) + audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) + + video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) + audio_cross_attn_rotary_emb = self.cross_attn_audio_rope( + audio_coords[:, 0:1, :], device=audio_hidden_states.device + ) + + # 2. Patchify input projections + hidden_states = self.proj_in(hidden_states) + audio_hidden_states = self.audio_proj_in(audio_hidden_states) + + # 3. Timestep embeddings and modulation parameters + timestep_cross_attn_gate_scale_factor = ( + self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier + ) + + temb, embedded_timestep = self.time_embed( + timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + temb_audio, audio_embedded_timestep = self.audio_time_embed( + audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, + ) + temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) + audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) + + video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( + timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( + timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, 
hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_scale_shift = video_cross_attn_scale_shift.view( + batch_size, -1, video_cross_attn_scale_shift.shape[-1] + ) + video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view( + batch_size, -1, video_cross_attn_a2v_gate.shape[-1] + ) + + audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( + audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( + audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view( + batch_size, -1, audio_cross_attn_scale_shift.shape[-1] + ) + audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view( + batch_size, -1, audio_cross_attn_v2a_gate.shape[-1] + ) + + # 4. Prompt embeddings + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view( + batch_size, -1, audio_hidden_states.size(-1) + ) + + # ------------------------------------------------------------------ # + # Step 5: Block loop with video-only feature extraction + # Audio always flows through every block — we never skip it. 
+ # ------------------------------------------------------------------ # + features: List[torch.Tensor] = [] + + for idx, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, audio_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + audio_hidden_states, + encoder_hidden_states, + audio_encoder_hidden_states, + temb, + temb_audio, + video_cross_attn_scale_shift, + audio_cross_attn_scale_shift, + video_cross_attn_a2v_gate, + audio_cross_attn_v2a_gate, + video_rotary_emb, + audio_rotary_emb, + video_cross_attn_rotary_emb, + audio_cross_attn_rotary_emb, + encoder_attention_mask, + audio_encoder_attention_mask, + ) + else: + hidden_states, audio_hidden_states = block( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=video_cross_attn_scale_shift, + temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, + temb_ca_gate=video_cross_attn_a2v_gate, + temb_ca_audio_gate=audio_cross_attn_v2a_gate, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=audio_rotary_emb, + ca_video_rotary_emb=video_cross_attn_rotary_emb, + ca_audio_rotary_emb=audio_cross_attn_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, + ) + + # Video-only feature extraction at requested block indices + # TODO: we only extract the video feature for now + if idx in feature_indices: + features.append(hidden_states.clone()) # [B, T_v, D_v] — packed video tokens + + # Early exit once all requested features are collected + if return_features_early and len(features) == len(feature_indices): + return features + + # ------------------------------------------------------------------ # + # Step 6: Output layers (video + audio) — unchanged from original + # 
------------------------------------------------------------------ # + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + video_output = self.proj_out(hidden_states) + + audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] + audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] + audio_hidden_states = self.audio_norm_out(audio_hidden_states) + audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift + audio_output = self.audio_proj_out(audio_hidden_states) + + # ------------------------------------------------------------------ # + # Assemble output following FastGen convention + # ------------------------------------------------------------------ # + if return_features_early: + # Should have been caught above; guard for safety + assert len(features) == len(feature_indices), f"{len(features)} != {len(feature_indices)}" + return features + + # Logvar (optional — requires logvar_linear to be added to the transformer) + logvar = None + if return_logvar: + assert hasattr(self, "logvar_linear"), ( + "logvar_linear is required when return_logvar=True. " + "It is added by LTX2.__init__." 
+ ) + # temb has shape [B, T_tokens, inner_dim]; take mean over tokens for a scalar logvar per sample + logvar = self.logvar_linear(temb.mean(dim=1)) # [B, 1] + + if len(feature_indices) == 0: + out = (video_output, audio_output) + else: + out = [(video_output, audio_output), features] + + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) + + if return_logvar: + return out, logvar + return out + + # --------------------------------------------------------------------------- # Text encoder wrapper # --------------------------------------------------------------------------- @@ -187,24 +456,23 @@ def __init__(self, model_id: str): self.text_encoder = Gemma3ForConditionalGeneration.from_pretrained( model_id, subfolder="text_encoder" ) + self.text_encoder.eval().requires_grad_(False) @torch.no_grad() def encode( self, - prompt: str | list[str], + prompt: Union[str, List[str]], precision: torch.dtype = torch.bfloat16, max_sequence_length: int = 1024, scale_factor: int = 8, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Encode text prompt(s) into packed Gemma hidden states. Returns ------- prompt_embeds : torch.Tensor [B, T, H * num_layers] - Normalised, packed text embeddings ready for LTX2TextConnectors. attention_mask : torch.Tensor [B, T] - Binary padding mask (1 = real token, 0 = pad) from the tokenizer. 
""" if isinstance(prompt, str): prompt = [prompt] @@ -241,169 +509,432 @@ def encode( scale_factor=scale_factor, ).to(precision) - # Bug 1 fix: return BOTH embeds and mask return prompt_embeds, attention_mask + def to(self, *args, **kwargs): + self.text_encoder.to(*args, **kwargs) + return self + # --------------------------------------------------------------------------- -# Main LTX-2 network +# Main LTX-2 network — follows FastGen pattern (Flux / Wan) # --------------------------------------------------------------------------- -class LTX2(nn.Module): +class LTX2(FastGenNetwork): """ FastGen wrapper for LTX-2 audio-video generation. - Component layout (mirrors the diffusers model_cpu_offload_seq): - text_encoder → connectors → transformer → vae → audio_vae → vocoder - - Notes - ----- - * ``connectors`` lives as a sibling of ``transformer``, NOT nested inside - it (Bug 5 fix). Connectors process the text embeddings once before the - denoising loop and produce separate video and audio encoder hidden states. - * No monkey-patching of the transformer is needed: LTX2VideoTransformer3DModel - handles its own block loop with the correct audio/video dual-branch logic - internally (Bug 3 fix). + Distillation targets video only: + - forward() receives and returns video latents [B, C, F, H, W] + - Audio is generated internally but not used for the distillation loss + - classify_forward extracts video hidden_states at requested block indices + + Component layout: + text_encoder → connectors → transformer (patched) → vae → audio_vae → vocoder """ - def __init__(self, model_id: str = "Lightricks/LTX-2", load_pretrained: bool = True): - super().__init__() + MODEL_ID = "Lightricks/LTX-2" + + def __init__( + self, + model_id: str = MODEL_ID, + net_pred_type: str = "flow", + schedule_type: str = "rf", + disable_grad_ckpt: bool = False, + load_pretrained: bool = True, + **model_kwargs, + ): + """ + LTX-2 constructor. + + Args: + model_id: HuggingFace model ID or local path. 
Defaults to "Lightricks/LTX-2". + net_pred_type: Prediction type. Defaults to "flow" (flow matching). + schedule_type: Schedule type. Defaults to "rf" (rectified flow). + disable_grad_ckpt: Disable gradient checkpointing during training. + Set True when using FSDP to avoid memory access errors. + load_pretrained: Load pretrained weights. If False, initialises from config only. + """ + super().__init__(net_pred_type=net_pred_type, schedule_type=schedule_type, **model_kwargs) + self.model_id = model_id - self._initialized = False - if load_pretrained: - self._initialize_network() + self._disable_grad_ckpt = disable_grad_ckpt - # ------------------------------------------------------------------ - # Initialisation - # ------------------------------------------------------------------ + self._initialize_network(model_id, load_pretrained) - def _initialize_network(self): - model_id = self.model_id + # Monkey-patch classify_forward onto self.transformer (same pattern as Flux / Wan) + self.transformer.forward = types.MethodType(classify_forward, self.transformer) - # Text encoder (Gemma3) - self.text_encoder = LTX2TextEncoder(model_id) + # Gradient checkpointing + if disable_grad_ckpt: + self.transformer.disable_gradient_checkpointing() + else: + self.transformer.enable_gradient_checkpointing() - # Bug 5 fix: connectors is a TOP-LEVEL sibling of transformer, - # NOT attached to self.transformer. 
- self.connectors = LTX2TextConnectors.from_pretrained( - model_id, subfolder="connectors" - ) + torch.cuda.empty_cache() - # Transformer (LTX2VideoTransformer3DModel handles all block logic) - self.transformer = LTX2VideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer" - ) + # ------------------------------------------------------------------ + # Initialisation + # ------------------------------------------------------------------ - # VAEs - self.vae = AutoencoderKLLTX2Video.from_pretrained( - model_id, subfolder="vae" - ) - self.audio_vae = AutoencoderKLLTX2Audio.from_pretrained( - model_id, subfolder="audio_vae" - ) + def _initialize_network(self, model_id: str, load_pretrained: bool) -> None: + """Initialize the transformer and supporting modules.""" + in_meta_context = self._is_in_meta_context() + should_load_weights = load_pretrained and (not in_meta_context) - # Vocoder (mel spectrogram → waveform) - self.vocoder = LTX2Vocoder.from_pretrained( - model_id, subfolder="vocoder" + if should_load_weights: + logger.info("Loading LTX-2 transformer from pretrained") + self.transformer: LTX2VideoTransformer3DModel = LTX2VideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer" + ) + else: + config = LTX2VideoTransformer3DModel.load_config(model_id, subfolder="transformer") + if in_meta_context: + logger.info( + "Initializing LTX-2 transformer on meta device " + "(zero memory, will receive weights via FSDP sync)" + ) + else: + logger.info("Initializing LTX-2 transformer from config (no pretrained weights)") + logger.warning("LTX-2 transformer being initialized from config. 
No weights are loaded!") + self.transformer: LTX2VideoTransformer3DModel = LTX2VideoTransformer3DModel.from_config(config) + + # inner_dim = num_attention_heads * attention_head_dim + inner_dim = ( + self.transformer.config.num_attention_heads + * self.transformer.config.attention_head_dim ) - # Scheduler + # Add logvar_linear for uncertainty weighting (DMD2 / f-distill) + # temb mean has shape [B, inner_dim] → logvar scalar per sample + self.transformer.logvar_linear = nn.Linear(inner_dim, 1) + logger.info(f"Added logvar_linear ({inner_dim} → 1) to LTX-2 transformer") + + # Connectors: top-level sibling of transformer (NOT nested inside it) + if should_load_weights: + self.connectors: LTX2TextConnectors = LTX2TextConnectors.from_pretrained( + model_id, subfolder="connectors" + ) + else: + # Connectors are lightweight; always load if pretrained is skipped for the transformer + logger.warning("Skipping connector pretrained load (meta context or load_pretrained=False)") + self.connectors = None # will be loaded lazily via init_preprocessors + + # Cache compression ratios used by forward() and sample() + if should_load_weights: + # VAEs (needed for sample(); not for the training forward pass) + self.vae: AutoencoderKLLTX2Video = AutoencoderKLLTX2Video.from_pretrained( + model_id, subfolder="vae" + ) + self.vae.eval().requires_grad_(False) + + self.audio_vae: AutoencoderKLLTX2Audio = AutoencoderKLLTX2Audio.from_pretrained( + model_id, subfolder="audio_vae" + ) + self.audio_vae.eval().requires_grad_(False) + + self.vocoder: LTX2Vocoder = LTX2Vocoder.from_pretrained( + model_id, subfolder="vocoder" + ) + self.vocoder.eval().requires_grad_(False) + + self._cache_vae_constants() + + # Scheduler (used in sample()) self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( model_id, subfolder="scheduler" ) - # Cache compression ratios + def _cache_vae_constants(self) -> None: + """Cache VAE spatial/temporal compression constants for use in forward() / sample().""" 
self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio self.transformer_spatial_patch_size = self.transformer.config.patch_size self.transformer_temporal_patch_size = self.transformer.config.patch_size_t - # Audio VAE constants. - # sample_rate / mel_hop_length live in config; the compression ratios are - # instance attributes (LATENT_DOWNSAMPLE_FACTOR = 4) set in __init__, not in config. - self.audio_sampling_rate = self.audio_vae.config.sample_rate # 16000 - self.audio_hop_length = self.audio_vae.config.mel_hop_length # 160 - self.audio_vae_temporal_compression = self.audio_vae.temporal_compression_ratio # 4 - self.audio_vae_mel_compression_ratio = self.audio_vae.mel_compression_ratio # 4 + self.audio_sampling_rate = self.audio_vae.config.sample_rate + self.audio_hop_length = self.audio_vae.config.mel_hop_length + self.audio_vae_temporal_compression = self.audio_vae.temporal_compression_ratio + self.audio_vae_mel_compression_ratio = self.audio_vae.mel_compression_ratio - self._initialized = True + # ------------------------------------------------------------------ + # Preprocessor initialisation (lazy, matches Flux / Wan pattern) + # ------------------------------------------------------------------ def init_preprocessors(self): - """No-op placeholder (preprocessors are initialised in _initialize_network).""" - pass + """Initialize text encoder and connectors.""" + if not hasattr(self, "text_encoder") or self.text_encoder is None: + self.init_text_encoder() + if self.connectors is None: + self.connectors = LTX2TextConnectors.from_pretrained( + self.model_id, subfolder="connectors" + ) + + def init_text_encoder(self): + """Initialize the Gemma3 text encoder for LTX-2.""" + self.text_encoder = LTX2TextEncoder(model_id=self.model_id) # ------------------------------------------------------------------ - # Forward pass (single denoising step) + # Device movement + # 
------------------------------------------------------------------ + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + if hasattr(self, "text_encoder") and self.text_encoder is not None: + self.text_encoder.to(*args, **kwargs) + if hasattr(self, "connectors") and self.connectors is not None: + self.connectors.to(*args, **kwargs) + if hasattr(self, "vae") and self.vae is not None: + self.vae.to(*args, **kwargs) + if hasattr(self, "audio_vae") and self.audio_vae is not None: + self.audio_vae.to(*args, **kwargs) + if hasattr(self, "vocoder") and self.vocoder is not None: + self.vocoder.to(*args, **kwargs) + return self + + # ------------------------------------------------------------------ + # FSDP + # ------------------------------------------------------------------ + + def fully_shard(self, **kwargs): + """Fully shard the LTX-2 transformer for FSDP2. + + Shards self.transformer (not self) to avoid ABC __class__ assignment issues. + """ + if self.transformer.gradient_checkpointing: + self.transformer.disable_gradient_checkpointing() + apply_fsdp_checkpointing( + self.transformer, + check_fn=lambda block: isinstance(block, LTX2VideoTransformerBlock), + ) + logger.info("Applied FSDP activation checkpointing to LTX-2 transformer blocks") + + for block in self.transformer.transformer_blocks: + fully_shard(block, **kwargs) + fully_shard(self.transformer, **kwargs) + + # ------------------------------------------------------------------ + # reset_parameters (required for FSDP meta device init) + # ------------------------------------------------------------------ + + def reset_parameters(self): + """Reinitialise parameters after meta device materialisation (FSDP2).""" + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) + + super().reset_parameters() + logger.debug("Reinitialized LTX-2 
parameters") + + # ------------------------------------------------------------------ + # Audio latent sizing helper (shared by forward and sample) + # ------------------------------------------------------------------ + + def _compute_audio_shape( + self, latent_f: int, fps: float, device: torch.device, dtype: torch.dtype + ) -> Tuple[int, int, int]: + """ + Compute audio latent dimensions from video latent frame count. + + Returns (audio_num_frames, latent_mel_bins, num_audio_ch). + """ + pixel_frames = (latent_f - 1) * self.vae_temporal_compression_ratio + 1 + duration_s = pixel_frames / fps + + audio_latents_per_second = ( + self.audio_sampling_rate + / self.audio_hop_length + / float(self.audio_vae_temporal_compression) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + num_mel_bins = self.audio_vae.config.mel_bins + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_audio_ch = self.audio_vae.config.latent_channels + return audio_num_frames, latent_mel_bins, num_audio_ch + + # ------------------------------------------------------------------ + # forward() — video-only distillation interface # ------------------------------------------------------------------ def forward( self, - # packed video latents [B, T_v, C_v] - hidden_states: torch.Tensor, - # packed audio latents [B, T_a, C_a] - audio_hidden_states: torch.Tensor, - # already-projected by connectors - encoder_hidden_states: torch.Tensor, # video text embeds [B, T_t, D_v] - audio_encoder_hidden_states: torch.Tensor, # audio text embeds [B, T_t, D_a] - encoder_attention_mask: torch.Tensor, # [B, T_t] - timestep: torch.Tensor, - # spatial metadata for RoPE - num_frames: int, - height: int, - width: int, - audio_num_frames: int, + x_t: torch.Tensor, + t: torch.Tensor, + condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + r: Optional[torch.Tensor] = None, # unused, kept for API compatibility fps: float = 24.0, - video_coords: torch.Tensor | None = 
None, - audio_coords: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + return_features_early: bool = False, + feature_indices: Optional[Set[int]] = None, + return_logvar: bool = False, + fwd_pred_type: Optional[str] = None, + **fwd_kwargs, + ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + """Forward pass for distillation — video latents in, video latents out. + + Audio latents are generated as random noise internally so the joint + audio-video transformer runs normally, but only the video prediction is + returned and used for loss computation. + + Args: + x_t: Video latents [B, C, F, H, W]. + t: Timestep [B] in [0, 1]. + condition: Tuple of (prompt_embeds [B, T, D], attention_mask [B, T]) + from LTX2TextEncoder.encode(). + r: Unused (kept for FastGen API compatibility). + fps: Frames per second (needed for RoPE coordinate computation). + return_features_early: Return video features as soon as collected. + feature_indices: Set of transformer block indices to extract video features from. + return_logvar: Return log-variance estimate alongside the output. + fwd_pred_type: Override prediction type. + + Returns: + Normal: video_out [B, C, F, H, W] + With features: (video_out, List[video_feature_tensors]) + Early exit: List[video_feature_tensors] + With logvar: (above, logvar [B, 1]) """ - Single denoising step through the transformer. + if feature_indices is None: + feature_indices = set() + if return_features_early and len(feature_indices) == 0: + return [] - Returns - ------- - (noise_pred_video, noise_pred_audio) both in packed token format. 
- """ - noise_pred_video, noise_pred_audio = self.transformer( + if fwd_pred_type is None: + fwd_pred_type = self.net_pred_type + else: + assert fwd_pred_type in NET_PRED_TYPES, f"{fwd_pred_type} is not supported" + + batch_size = x_t.shape[0] + _, _, latent_f, latent_h, latent_w = x_t.shape + + # Unpack text conditioning + prompt_embeds, attention_mask = condition + + # ---- Run connectors to get per-modality encoder hidden states ---- + # attention_mask from tokenizer is binary [B, T]; convert to additive bias + additive_mask = (1 - attention_mask.to(prompt_embeds.dtype)) * -1_000_000.0 + connector_video_embeds, connector_audio_embeds, connector_attn_mask = self.connectors( + prompt_embeds, additive_mask, additive_mask=True + ) + + # ---- Timestep: [B] scalar per sample, matching pipeline_ltx2.py ---- + # time_embed calls .flatten() then views back to [B, 1, D] internally. + # Do NOT expand to per-token here — that is handled inside time_embed. + timestep = t.to(x_t.dtype).expand(batch_size) # [B] + + # ---- Pack video latents ---- + hidden_states = _pack_latents( + x_t, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + # ---- Audio latents: random noise (not trained, just needed to run the joint transformer) ---- + audio_num_frames, latent_mel_bins, num_audio_ch = self._compute_audio_shape( + latent_f, fps, x_t.device, x_t.dtype + ) + audio_latents = torch.randn( + batch_size, num_audio_ch, audio_num_frames, latent_mel_bins, + device=x_t.device, dtype=x_t.dtype, + ) + audio_hidden_states = _pack_audio_latents(audio_latents) + + # ---- RoPE coordinates (pre-computed once, reused in transformer) ---- + video_coords = self.transformer.rope.prepare_video_coords( + batch_size, latent_f, latent_h, latent_w, x_t.device, fps=fps + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + batch_size, audio_num_frames, x_t.device + ) + + # ---- Transformer forward (our patched classify_forward) ---- + model_outputs = 
self.transformer( hidden_states=hidden_states, audio_hidden_states=audio_hidden_states, - encoder_hidden_states=encoder_hidden_states, - audio_encoder_hidden_states=audio_encoder_hidden_states, + encoder_hidden_states=connector_video_embeds, + audio_encoder_hidden_states=connector_audio_embeds, + encoder_attention_mask=connector_attn_mask, + audio_encoder_attention_mask=connector_attn_mask, timestep=timestep, - encoder_attention_mask=encoder_attention_mask, - audio_encoder_attention_mask=encoder_attention_mask, # same mask for audio - num_frames=num_frames, - height=height, - width=width, + num_frames=latent_f, + height=latent_h, + width=latent_w, fps=fps, audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, - return_dict=False, + return_features_early=return_features_early, + feature_indices=feature_indices, + return_logvar=return_logvar, + ) + + # ---- Early exit: list of video feature tensors ---- + if return_features_early: + return model_outputs # List[Tensor], each [B, T_v, D_v] + + # ---- Unpack logvar if requested ---- + if return_logvar: + out, logvar = model_outputs[0], model_outputs[1] + else: + out = model_outputs + + # ---- Extract video prediction only; discard audio ---- + if len(feature_indices) == 0: + # out is (video_output, audio_output) + video_packed = out[0] # [B, T_v, C_packed] + features = None + else: + # out is [(video_output, audio_output), features] + video_packed = out[0][0] # [B, T_v, C_packed] + features = out[1] # List[Tensor] + + # ---- Unpack video tokens → [B, C, F, H, W] ---- + video_out = _unpack_latents( + video_packed, latent_f, latent_h, latent_w, + self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) - return noise_pred_video, noise_pred_audio + + # ---- Convert model output to requested prediction type ---- + video_out = self.noise_scheduler.convert_model_output( + x_t, video_out, t, + src_pred_type=self.net_pred_type, + target_pred_type=fwd_pred_type, + ) + + # 
---- Re-pack output following FastGen convention ---- + if features is not None: + out = [video_out, features] + else: + out = video_out + + if return_logvar: + return out, logvar + return out # ------------------------------------------------------------------ - # Sampling loop + # sample() — full denoising loop for inference + # Follows pipeline_ltx2.py exactly (verified working logic preserved) # ------------------------------------------------------------------ @torch.no_grad() def sample( self, noise: torch.Tensor, - condition: tuple[torch.Tensor, torch.Tensor], - neg_condition: tuple[torch.Tensor, torch.Tensor] | None = None, + condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + neg_condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, guidance_scale: float = 4.0, num_steps: int = 40, fps: float = 24.0, - frame_rate: float | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + frame_rate: Optional[float] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Run the full denoising loop for text-to-video+audio generation. Follows pipeline_ltx2.py exactly: - - latents kept in float32 throughout; only cast to prompt dtype for transformer - - noise predictions cast to float32 before CFG and scheduler step - - connectors called ONCE on combined [uncond, cond] batch (not twice) + - latents kept in float32 throughout + - connectors called ONCE on combined [uncond, cond] batch - audio duration derived from pixel-frame count, not latent frames - transformer wrapped in cache_context("cond_uncond") @@ -419,46 +950,26 @@ def sample( device = noise.device B, C, latent_f, latent_h, latent_w = noise.shape - # ---------------------------------------------------------------- - # Latent dimensions - # ---------------------------------------------------------------- - # Audio duration uses pixel-frame count (pipeline line 1003): - # duration_s = num_frames / frame_rate - # where num_frames is the original pixel count. 
For a causal VAE: - # pixel_frames = (latent_f - 1) * temporal_compression + 1 - pixel_frames = (latent_f - 1) * self.vae_temporal_compression_ratio + 1 - duration_s = pixel_frames / fps - - audio_latents_per_second = ( - self.audio_sampling_rate - / self.audio_hop_length - / float(self.audio_vae_temporal_compression) + # ---- Audio shape ---- + audio_num_frames, latent_mel_bins, num_audio_ch = self._compute_audio_shape( + latent_f, fps, device, torch.float32 ) - audio_num_frames = round(duration_s * audio_latents_per_second) - num_mel_bins = self.audio_vae.config.mel_bins - latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_audio_ch = self.audio_vae.config.latent_channels + num_mel_bins = self.audio_vae.config.mel_bins - # ---------------------------------------------------------------- - # Pack latents — keep in float32 to match pipeline - # ---------------------------------------------------------------- + # ---- Pack latents (float32 throughout, matching pipeline) ---- video_latents = _pack_latents( noise.float(), self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) - - audio_shape = (B, num_audio_ch, audio_num_frames, latent_mel_bins) - audio_latents = torch.randn(audio_shape, device=device, dtype=torch.float32) + audio_latents = torch.randn( + B, num_audio_ch, audio_num_frames, latent_mel_bins, + device=device, dtype=torch.float32 + ) audio_latents = _pack_audio_latents(audio_latents) - # ---------------------------------------------------------------- - # Text conditioning — run connectors ONCE on combined [uncond, cond] - # batch, exactly as the pipeline does (lines 959-966) - # ---------------------------------------------------------------- + # ---- Text conditioning — connectors called ONCE on combined [uncond, cond] ---- prompt_embeds, attention_mask = condition - if do_cfg: neg_embeds, neg_mask = neg_condition - # Stack [uncond, cond] before connectors — single forward pass combined_embeds = 
torch.cat([neg_embeds, prompt_embeds], dim=0) combined_mask = torch.cat([neg_mask, attention_mask], dim=0) else: @@ -470,10 +981,7 @@ def sample( combined_embeds, additive_mask, additive_mask=True ) - # ---------------------------------------------------------------- - # Pre-compute RoPE coordinates (pipeline lines 1078-1087) - # Compute for single batch, then repeat for CFG - # ---------------------------------------------------------------- + # ---- Pre-compute RoPE coordinates ---- video_coords = self.transformer.rope.prepare_video_coords( B, latent_f, latent_h, latent_w, device, fps=fps ) @@ -484,11 +992,8 @@ def sample( video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) - # ---------------------------------------------------------------- - # Scheduler timesteps (pipeline lines 1042-1067) - # ---------------------------------------------------------------- + # ---- Scheduler timesteps ---- sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps) - video_seq_len = latent_f * latent_h * latent_w mu = _calculate_shift( video_seq_len, @@ -497,63 +1002,67 @@ def sample( self.scheduler.config.get("base_shift", 0.95), self.scheduler.config.get("max_shift", 2.05), ) - audio_scheduler = copy.deepcopy(self.scheduler) _retrieve_timesteps(audio_scheduler, num_steps, device, sigmas=sigmas, mu=mu) timesteps, num_steps = _retrieve_timesteps(self.scheduler, num_steps, device, sigmas=sigmas, mu=mu) - # ---------------------------------------------------------------- - # Denoising loop (pipeline lines 1091-1154) - # ---------------------------------------------------------------- - # prompt_embeds.dtype drives the cast for transformer input prompt_dtype = connector_video_embeds.dtype + # ---- Token counts after packing ---- + num_video_tokens = video_latents.shape[1] # [B, T_v, C] + num_audio_tokens = audio_latents.shape[1] # [B, T_a, C] + + # ---- Denoising loop ---- for t in 
timesteps: - # Cast latents to prompt dtype for transformer; keep float32 for scheduler latent_input = torch.cat([video_latents] * 2) if do_cfg else video_latents audio_latent_input = torch.cat([audio_latents] * 2) if do_cfg else audio_latents latent_input = latent_input.to(prompt_dtype) audio_latent_input = audio_latent_input.to(prompt_dtype) - timestep = t.expand(latent_input.shape[0]) + # Scale timestep and expand to per-token shape. + # The scheduler yields sigmas/timesteps that time_embed expects directly — + # LTX2AdaLayerNormSingle multiplies by timestep_scale_multiplier internally. + bs_input = latent_input.shape[0] + t_base = t.to(prompt_dtype).unsqueeze(0).expand(bs_input) # [B] + timestep = t_base.unsqueeze(1).expand(bs_input, num_video_tokens) # [B, T_v] + audio_timestep = t_base.unsqueeze(1).expand(bs_input, num_audio_tokens)# [B, T_a] - # Wrap in cache_context as the pipeline does (CacheMixin optimisation) with self.transformer.cache_context("cond_uncond"): - noise_pred_video, noise_pred_audio = self.forward( + # classify_forward returns (video_output, audio_output) when + # feature_indices is empty and return_features_early is False. + # Note: no return_dict kwarg — classify_forward does not accept it. 
+ model_out = self.transformer( hidden_states=latent_input, audio_hidden_states=audio_latent_input, encoder_hidden_states=connector_video_embeds, audio_encoder_hidden_states=connector_audio_embeds, encoder_attention_mask=connector_attn_mask, + audio_encoder_attention_mask=connector_attn_mask, timestep=timestep, + audio_timestep=audio_timestep, num_frames=latent_f, height=latent_h, width=latent_w, - audio_num_frames=audio_num_frames, fps=fps, + audio_num_frames=audio_num_frames, video_coords=video_coords, audio_coords=audio_coords, ) + noise_pred_video, noise_pred_audio = model_out - # Cast noise preds to float32 before CFG and scheduler step - # (pipeline lines 1127-1128) noise_pred_video = noise_pred_video.float() noise_pred_audio = noise_pred_audio.float() if do_cfg: video_uncond, video_cond = noise_pred_video.chunk(2) noise_pred_video = video_uncond + guidance_scale * (video_cond - video_uncond) - audio_uncond, audio_cond = noise_pred_audio.chunk(2) noise_pred_audio = audio_uncond + guidance_scale * (audio_cond - audio_uncond) - # Scheduler steps operate on float32 latents video_latents = self.scheduler.step(noise_pred_video, t, video_latents, return_dict=False)[0] audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] - # ---------------------------------------------------------------- - # Unpack and denormalise (pipeline lines 1172-1187) - # ---------------------------------------------------------------- + # ---- Unpack and denormalise ---- video_latents = _unpack_latents( video_latents, latent_f, latent_h, latent_w, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, @@ -563,8 +1072,6 @@ def sample( self.vae.config.scaling_factor, ) - # Denormalise while still packed [B, L, 128], then unpack to [B, C, L, M] - # (pipeline lines 1184-1187: _denormalize then _unpack) audio_latents = _denormalize_audio_latents( audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std ) diff --git 
a/tests/test_ltx2.py b/tests/test_ltx2.py index d5769a9..2d53c44 100644 --- a/tests/test_ltx2.py +++ b/tests/test_ltx2.py @@ -18,7 +18,9 @@ def test_ltx2_generation(): model.eval() # 2. Prepare Prompts - prompt = "A high-performance sports car racing through a city street at night, neon lights reflecting off wet asphalt, motion blur streaking past the camera. The camera tracks low and close to the car as it accelerates aggressively, tires gripping the road, exhaust heat shimmering. Realistic lighting, cinematic depth of field, ultra-detailed textures, dynamic reflections, dramatic shadows, 4K realism, film-grade color grading." + # prompt = "A high-performance sports car racing through a city street at night, neon lights reflecting off wet asphalt, motion blur streaking past the camera. The camera tracks low and close to the car as it accelerates aggressively, tires gripping the road, exhaust heat shimmering. Realistic lighting, cinematic depth of field, ultra-detailed textures, dynamic reflections, dramatic shadows, 4K realism, film-grade color grading." + prompt = "A tight close-up shot of a musician's hands playing a grand piano. The audio is a fast-paced, uplifting classical piano sonata. The pressing of the black and white keys visually syncs with the rapid flurry of high-pitched musical notes. There is a slight echo, suggesting the piano is in a large, empty concert hall." + prompt = "A street performer sitting on a brick stoop, strumming an acoustic guitar. The audio features a warm, indie-folk guitar chord progression. Over the guitar, a smooth, soulful human voice sings a slow, bluesy melody without words, just melodic humming and 'oohs'. The rhythmic strumming of the guitar perfectly matches the tempo of the vocal melody. Faint city traffic can be heard quietly in the deep background." 
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" print(f"Encoding prompt: {prompt}") @@ -100,4 +102,4 @@ def test_ltx2_generation(): if __name__ == "__main__": - test_ltx2_generation() + test_ltx2_generation() \ No newline at end of file From 740564ff4ac285c898b258cd928185c68090be12 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Wed, 25 Feb 2026 23:15:10 +0800 Subject: [PATCH 06/13] add LTX2 pipeline with audio branch optional --- fastgen/networks/Flux/pipeline_ltx2.py | 1083 +++++++++++++++++ fastgen/networks/Flux/test_ltx2_pipeline.py | 60 + fastgen/networks/Flux/transformer_ltx2.py | 1203 +++++++++++++++++++ 3 files changed, 2346 insertions(+) create mode 100644 fastgen/networks/Flux/pipeline_ltx2.py create mode 100644 fastgen/networks/Flux/test_ltx2_pipeline.py create mode 100644 fastgen/networks/Flux/transformer_ltx2.py diff --git a/fastgen/networks/Flux/pipeline_ltx2.py b/fastgen/networks/Flux/pipeline_ltx2.py new file mode 100644 index 0000000..28f3fa8 --- /dev/null +++ b/fastgen/networks/Flux/pipeline_ltx2.py @@ -0,0 +1,1083 @@ +# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.loaders import FromSingleFileMixin, LTX2LoraLoaderMixin +from diffusers.models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from diffusers.models.transformers import LTX2VideoTransformer3DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors +from diffusers.pipelines.ltx2.pipeline_output import LTX2PipelineOutput +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2Pipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + + >>> pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. 
The scene appears to be real-life footage" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> frame_rate = 24.0 + >>> video, audio = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=768, + ... height=512, + ... num_frames=121, + ... frame_rate=frame_rate, + ... num_inference_steps=40, + ... guidance_scale=4.0, + ... output_type="np", + ... return_dict=False, + ... ) + + >>> encode_video( + ... video[0], + ... fps=frame_rate, + ... audio=audio[0].float().cpu(), + ... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000 + ... output_path="video.mp4", + ... ) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin): + r""" + Pipeline for text-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTX2VideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Gemma3ForConditionalGeneration`]): + Gemma3 text encoder. 
+ tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Tokenizer for the text encoder. + connectors ([`LTX2TextConnectors`]): + Text connector stack used to adapt text encoder hidden states for the video and audio branches. + vocoder ([`LTX2Vocoder`]): + Vocoder to convert mel spectrograms to waveforms. + """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate 
if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + mask = token_indices < sequence_lengths[:, None] + elif padding_side == "left": + start_indices = seq_len - sequence_lengths[:, None] + mask = token_indices >= start_indices + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] + + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + 
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype) + return normalized_hidden_states + + def _get_gemma_prompt_embeds( + self, + prompt: str | list[str], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + prompt = [p.strip() for p in prompt] + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + text_input_ids = text_input_ids.to(device) + prompt_attention_mask = prompt_attention_mask.to(device) + + text_encoder_outputs = self.text_encoder( + input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + text_encoder_hidden_states = text_encoder_outputs.hidden_states + text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1) + sequence_lengths = prompt_attention_mask.sum(dim=-1) + + prompt_embeds = self._pack_text_embeds( + text_encoder_hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ) + prompt_embeds = prompt_embeds.to(dtype=dtype) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = 
prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + def encode_prompt( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + max_sequence_length: int = 1024, + scale_factor: int = 8, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + scale_factor=scale_factor, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." 
+ ) + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + 
+        latents_mean = latents_mean.to(latents.device, latents.dtype)
+        latents_std = latents_std.to(latents.device, latents.dtype)
+        return (latents - latents_mean) / latents_std
+
+    @staticmethod
+    def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
+        latents_mean = latents_mean.to(latents.device, latents.dtype)
+        latents_std = latents_std.to(latents.device, latents.dtype)
+        return (latents * latents_std) + latents_mean
+
+    @staticmethod
+    def _create_noised_state(
+        latents: torch.Tensor, noise_scale: float | torch.Tensor, generator: torch.Generator | None = None
+    ):
+        noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
+        noised_latents = noise_scale * noise + (1 - noise_scale) * latents
+        return noised_latents
+
+    @staticmethod
+    def _pack_audio_latents(
+        latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None
+    ) -> torch.Tensor:
+        if patch_size is not None and patch_size_t is not None:
+            batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
+            post_patch_latent_length = latent_length // patch_size_t
+            post_patch_mel_bins = latent_mel_bins // patch_size
+            latents = latents.reshape(
+                batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
+            )
+            latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
+        else:
+            latents = latents.transpose(1, 2).flatten(2, 3)
+        return latents
+
+    @staticmethod
+    def _unpack_audio_latents(
+        latents: torch.Tensor,
+        latent_length: int,
+        num_mel_bins: int,
+        patch_size: int | None = None,
+        patch_size_t: int | None = None,
+    ) -> torch.Tensor:
+        if patch_size is not None and patch_size_t is not None:
+            batch_size = latents.size(0)
+            latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
+            latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
+        else:
+            latents =
latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents + + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 4: + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + sigmas: list[float] | None = None, + timesteps: list[int] = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + noise_scale: float = 0.0, + num_videos_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + generate_audio: bool = True, # NEW: set False to skip all audio ops + ): + r""" + 
Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `list[str]`, *optional*): + The prompt or prompts to guide image generation. + height (`int`, *optional*, defaults to `512`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `768`): + The width in pixels of the generated video. + num_frames (`int`, *optional*, defaults to `121`): + The number of video frames to generate. + frame_rate (`float`, *optional*, defaults to `24.0`): + The frames per second (FPS) of the generated video. + num_inference_steps (`int`, *optional*, defaults to 40): + The number of denoising steps. + sigmas (`List[float]`, *optional*): + Custom sigmas for the denoising process. + timesteps (`list[int]`, *optional*): + Custom timesteps for the denoising process. + guidance_scale (`float`, *optional*, defaults to `4.0`): + Classifier-free guidance scale. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor. + noise_scale (`float`, *optional*, defaults to `0.0`): + Noise interpolation factor applied to latents before denoising. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + Generator(s) for deterministic generation. + latents (`torch.Tensor`, *optional*): + Pre-generated video latents. + audio_latents (`torch.Tensor`, *optional*): + Pre-generated audio latents. Ignored when ``generate_audio=False``. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. 
+ decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + Noise scale at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + Output format for the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a structured output. + attention_kwargs (`dict`, *optional*): + Extra kwargs passed to the attention processor. + callback_on_step_end (`Callable`, *optional*): + Callback called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`): + Tensor inputs for the step-end callback. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length for the prompt. + generate_audio (`bool`, *optional*, defaults to `True`): + Whether to generate audio alongside the video. Set to ``False`` to skip all audio + operations (latent preparation, denoising, decoding). When ``False``, the returned + ``audio`` value will be ``None``. Automatically forced to ``False`` when the + transformer was built with ``audio_enabled=False``. + + Examples: + + Returns: + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is the generated video frames + and the second element is the generated audio (or ``None``). + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # Honour the transformer's construction-time gate — if it has no audio modules, + # we cannot generate audio regardless of what the caller requests. + run_audio = generate_audio and self.transformer.config.audio_enabled + + # 1. 
Check inputs + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. 
Prepare latent variables + + # --- Video latents (always prepared) --- + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], " + "`latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either " + "[batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." 
+ ) + video_sequence_length = latent_num_frames * latent_height * latent_width + + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + noise_scale, + torch.float32, + device, + generator, + latents, + ) + + # --- Audio latents (only prepared when run_audio) --- + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + + if run_audio: + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], " + "`audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent " + f"dims cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape " + "is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." 
+ ) + + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + else: + audio_latents = None + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + if run_audio: + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps(audio_scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu) + else: + audio_scheduler = None + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions (RoPE coords) + # Video coords — always computed + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + if self.do_classifier_free_guidance: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) + + # Audio coords — only computed when run_audio + if run_audio: + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + if self.do_classifier_free_guidance: + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + else: + audio_coords = None + + # 7. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + # Video input (always) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # Audio input (only when run_audio) + if run_audio: + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + else: + audio_latent_model_input = None + + timestep = t.expand(latent_model_input.shape[0]) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds if run_audio else None, + timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask if run_audio else None, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames if run_audio else None, + video_coords=video_coords, + audio_coords=audio_coords, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + + noise_pred_video = noise_pred_video.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + if self.guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + + latents = self.scheduler.step(noise_pred_video, t, latents, 
return_dict=False)[0] + + # Audio denoising step — only when run_audio + if run_audio: + noise_pred_audio = noise_pred_audio.float() + if self.do_classifier_free_guidance: + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + if self.guidance_rescale > 0: + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Decode video (always) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + # 9. 
Decode audio (only when run_audio) + if run_audio: + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents if run_audio else None + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + if run_audio: + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + else: + audio = None + + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) \ No newline at end of file diff --git a/fastgen/networks/Flux/test_ltx2_pipeline.py b/fastgen/networks/Flux/test_ltx2_pipeline.py new file mode 100644 index 0000000..7463eb3 --- /dev/null +++ b/fastgen/networks/Flux/test_ltx2_pipeline.py @@ -0,0 +1,60 @@ +import torch 
+import numpy as np
+import imageio
+
+# Local modules from this patch: the modified pipeline and the audio-gated transformer.
+from pipeline_ltx2 import LTX2Pipeline
+from transformer_ltx2 import LTX2VideoTransformer3DModel
+
+# Smoke test: run LTX-2 video generation end-to-end with the audio branch disabled.
+device = "cuda:0"
+width = 768
+height = 512
+
+# 1. Load the full pipeline (vae, scheduler, text_encoder, etc.)
+pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
+
+# 2. Build a video-only transformer — no audio weights allocated
+config = pipe.transformer.config
+transformer = LTX2VideoTransformer3DModel.from_config(config, audio_enabled=False)
+transformer = transformer.to(torch.bfloat16)
+
+# 3. Copy video weights from the pretrained AV transformer, skip audio keys
+# strict=False: audio-only checkpoint keys land in `unexpected`; any key in `missing`
+# would mean a *video* weight failed to load, which is treated as a hard error below.
+state_dict = pipe.transformer.state_dict()
+missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
+assert len(missing) == 0, f"Missing video keys: {missing}"
+print(f"Skipped {len(unexpected)} audio keys from checkpoint")
+
+# 4. Swap in the video-only transformer and free the original AV one
+del pipe.transformer
+pipe.transformer = transformer
+
+pipe.to(device)  # fully on GPU — fastest, needs ~24GB+ VRAM
+# pipe.enable_model_cpu_offload(device=device)  # offloads whole models between steps, needs ~12GB VRAM
+
+prompt = "A beautiful sunset over the ocean"
+negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
+
+frame_rate = 24.0
+# generate_audio=False exercises the runtime gate: the pipeline must return audio=None
+# and skip all audio latent preparation / denoising / decoding.
+video, audio = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=width,
+    height=height,
+    num_frames=121,
+    frame_rate=frame_rate,
+    num_inference_steps=40,
+    guidance_scale=4.0,
+    output_type="np",
+    return_dict=False,
+    generate_audio=False,  # audio is None, no audio ops run
+)
+print(audio)
+
+# video[0] is (num_frames, height, width, 3) float in [0, 1]
+frames = (video[0] * 255).clip(0, 255).astype(np.uint8)
+
+output_path = "ltx2_video_only.mp4"
+with imageio.get_writer(output_path, fps=frame_rate, codec="libx264", quality=8) as writer:
+    for frame in frames:
+        writer.append_data(frame)
+
+print(f"Saved to {output_path}")
\ No newline at end of file
diff --git a/fastgen/networks/Flux/transformer_ltx2.py b/fastgen/networks/Flux/transformer_ltx2.py
new file mode 100644
index 0000000..9d331b1
--- /dev/null
+++ b/fastgen/networks/Flux/transformer_ltx2.py
@@ -0,0 +1,1203 @@
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
+# NOTE(review): confirm `apply_lora_scale` is exported by diffusers.utils in the
+# pinned diffusers version — it is not present in older releases.
+from diffusers.utils import BaseOutput, apply_lora_scale, is_torch_version, logging
+from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput
+from diffusers.models.attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+from diffusers.models.cache_utils import CacheMixin
+from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import RMSNorm
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+    """Apply rotary position embeddings with interleaved (real, imag) channel pairing.
+
+    Adjacent channel pairs of ``x`` are treated as complex (real, imag) components and
+    rotated by ``freqs = (cos, sin)``. Math is done in float32, cast back to ``x.dtype``.
+    """
+    cos, sin = freqs
+    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, C // 2]
+    # Build the 90°-rotated companion (-imag, real) and re-interleave the pairs.
+    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
+    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+    return out
+
+
+def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
+    """Apply rotary position embeddings with split (first-half / second-half) pairing.
+
+    The last dimension of ``x`` is split into two halves which are rotated against each
+    other via in-place ``addcmul_`` using ``freqs = (cos, sin)``. When ``x`` is not 4-D
+    but ``cos`` is, ``x`` is temporarily reshaped to a per-head layout and restored on
+    exit.
+    """
+    cos, sin = freqs
+
+    x_dtype = x.dtype
+    needs_reshape = False
+    if x.ndim != 4 and cos.ndim == 4:
+        # assumes x is (b, t, h*d) when cos is (b, h, t, d/2) — TODO confirm with callers
+        b, h, t, _ = cos.shape
+        x = x.reshape(b, t, h, -1).swapaxes(1, 2)
+        needs_reshape = True
+
+    last = x.shape[-1]
+    if last % 2 != 0:
+        raise ValueError(f"Expected x.shape[-1] to be even for split rotary, got {last}.")
+    r = last // 2
+
+    # Work in float32; split the channel dim into two halves of size r.
+    split_x = x.reshape(*x.shape[:-1], 2, r).float()
+    first_x = split_x[..., :1, :]
+    second_x = split_x[..., 1:, :]
+
+    cos_u = cos.unsqueeze(-2)
+    sin_u = sin.unsqueeze(-2)
+
+    # out = x*cos, then in-place: first -= sin*second, second += sin*first
+    # (first_x/second_x still reference the pre-rotation values in split_x).
+    out = split_x * cos_u
+    first_out = out[..., :1, :]
+    second_out = out[..., 1:, :]
+
+    first_out.addcmul_(-sin_u, second_x)
+    second_out.addcmul_(sin_u, first_x)
+
+    out = out.reshape(*out.shape[:-2], last)
+
+    if needs_reshape:
+        # Restore the original flattened [B, T, H*D] layout.
+        out = out.swapaxes(1, 2).reshape(b, t, -1)
+
+    out = out.to(dtype=x_dtype)
+    return out
+
+
+@dataclass
+class AudioVisualModelOutput(BaseOutput):
+    r"""
+    Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs.
+
+    Args:
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
+            The hidden states output conditioned on the `encoder_hidden_states` input, representing the visual output
+            of the model.
+        audio_sample (`torch.Tensor` or `None`):
+            The audio output of the audiovisual model. None when audio is disabled.
+    """
+
+    sample: "torch.Tensor"  # noqa: F821
+    audio_sample: "torch.Tensor | None"  # noqa: F821
+
+
+class LTX2AdaLayerNormSingle(nn.Module):
+    r"""
+    Norm layer adaptive layer norm single (adaLN-single).
+
+    As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0
+    model.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_mod_params (`int`, *optional*, defaults to `6`):
+            The number of modulation parameters.
+        use_additional_conditions (`bool`, *optional*, defaults to `False`):
+            Whether to use additional conditions for normalization or not.
+    """
+
+    def __init__(self, embedding_dim: int, num_mod_params: int = 6, use_additional_conditions: bool = False):
+        super().__init__()
+        # Number of stacked modulation vectors emitted (shift/scale/gate groups).
+        self.num_mod_params = num_mod_params
+
+        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
+            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+        )
+
+        self.silu = nn.SiLU()
+        # Projects the timestep embedding to num_mod_params modulation vectors at once.
+        self.linear = nn.Linear(embedding_dim, self.num_mod_params * embedding_dim, bias=True)
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        added_cond_kwargs: dict[str, torch.Tensor] | None = None,
+        batch_size: int | None = None,
+        hidden_dtype: torch.dtype | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Return (modulation parameters, embedded timestep) for adaLN-single conditioning."""
+        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
+        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
+        return self.linear(self.silu(embedded_timestep)), embedded_timestep
+
+
+class LTX2AudioVideoAttnProcessor:
+    r"""
+    Processor for implementing attention (SDPA) for the LTX-2.0 model.
+    Supports separate RoPE embeddings for queries and keys (a2v / v2a cross attention).
+    """
+
+    _attention_backend = None
+    _parallel_config = None
+
+    def __init__(self):
+        # SDPA (torch.nn.functional.scaled_dot_product_attention) requires PyTorch >= 2.0.
+        if is_torch_version("<", "2.0"):
+            raise ValueError(
+                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
+            )
+
+    def __call__(
+        self,
+        attn: "LTX2Attention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+    ) -> torch.Tensor:
+        # sequence_length is the key/value sequence length (encoder sequence for
+        # cross-attention); it is only used to prepare the attention mask.
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        # Self-attention when no encoder states are provided.
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        # QK RMSNorm is applied across the full projection width, before head split.
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
+        if query_rotary_emb is not None:
+            # Keys fall back to the query RoPE when no separate key RoPE is given
+            # (plain self-attention); a2v/v2a cross-attention supplies both.
+            if attn.rope_type == "interleaved":
+                query = apply_interleaved_rotary_emb(query, query_rotary_emb)
+                key = apply_interleaved_rotary_emb(
+                    key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
+                )
+            elif attn.rope_type == "split":
+                query = apply_split_rotary_emb(query, query_rotary_emb)
+                key = apply_split_rotary_emb(
+                    key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
+                )
+
+        # [B, S, H*D] -> [B, S, H, D] for the attention dispatcher.
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
+            parallel_config=self._parallel_config,
+        )
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Output projection followed by dropout.
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
+class
 LTX2Attention(torch.nn.Module, AttentionModuleMixin):
+    r"""
+    Attention class for all LTX-2.0 attention layers.
+    Supports separate query and key RoPE embeddings for a2v / v2a cross-attention.
+    """
+
+    _default_processor_cls = LTX2AudioVideoAttnProcessor
+    _available_processors = [LTX2AudioVideoAttnProcessor]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        kv_heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = True,
+        cross_attention_dim: int | None = None,
+        out_bias: bool = True,
+        qk_norm: str = "rms_norm_across_heads",
+        norm_eps: float = 1e-6,
+        norm_elementwise_affine: bool = True,
+        rope_type: str = "interleaved",
+        processor=None,
+    ):
+        super().__init__()
+        if qk_norm != "rms_norm_across_heads":
+            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+        self.head_dim = dim_head
+        self.inner_dim = dim_head * heads
+        # K/V projections may use fewer heads than Q (grouped-query style).
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        # Self-attention when cross_attention_dim is None.
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = query_dim
+        self.heads = heads
+        # "interleaved" or "split" — selects the RoPE application scheme in the processor.
+        self.rope_type = rope_type
+
+        # QK RMSNorm across the full (heads * head_dim) projection width.
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_out = torch.nn.ModuleList([])
+        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(torch.nn.Dropout(dropout))
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        # Filter kwargs down to what the active processor accepts; warn (don't crash)
+        # on unexpected extras so attention_kwargs stay forward-compatible.
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        hidden_states = self.processor(
+            self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs
+        )
+        return hidden_states
+
+
+class LTX2VideoTransformerBlock(nn.Module):
+    r"""
+    Transformer block used in LTX-2.0.
+
+    Supports two-level audio gating (Option C):
+    - Construction-time: pass ``audio_dim=None`` to skip all audio module allocation.
+    - Runtime: pass ``audio_enabled=False`` in ``forward()`` to skip audio ops this step.
+
+    The a2v cross-attention (video attending to audio) is intentionally decoupled from
+    ``audio_enabled`` — video can still attend to audio as conditioning even when the
+    audio update branch is disabled, matching the original BasicAVTransformerBlock behaviour.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        cross_attention_dim: int,
+        # Audio args — all optional. Pass None to build a video-only block.
+        audio_dim: int | None = None,
+        audio_num_attention_heads: int | None = None,
+        audio_attention_head_dim: int | None = None,
+        audio_cross_attention_dim: int | None = None,
+        qk_norm: str = "rms_norm_across_heads",
+        activation_fn: str = "gelu-approximate",
+        attention_bias: bool = True,
+        attention_out_bias: bool = True,
+        eps: float = 1e-6,
+        elementwise_affine: bool = False,
+        rope_type: str = "interleaved",
+    ):
+        super().__init__()
+
+        # Construction-time gate
+        self.has_audio = audio_dim is not None
+
+        if self.has_audio:
+            # All audio hyper-parameters must be supplied together.
+            assert audio_num_attention_heads is not None
+            assert audio_attention_head_dim is not None
+            assert audio_cross_attention_dim is not None
+
+        # --- 1. Video Self-Attention (always built) ---
+        self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.attn1 = LTX2Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            kv_heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            bias=attention_bias,
+            cross_attention_dim=None,
+            out_bias=attention_out_bias,
+            qk_norm=qk_norm,
+            rope_type=rope_type,
+        )
+
+        # --- 1b. Audio Self-Attention (conditional) ---
+        if self.has_audio:
+            self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
+            self.audio_attn1 = LTX2Attention(
+                query_dim=audio_dim,
+                heads=audio_num_attention_heads,
+                kv_heads=audio_num_attention_heads,
+                dim_head=audio_attention_head_dim,
+                bias=attention_bias,
+                cross_attention_dim=None,
+                out_bias=attention_out_bias,
+                qk_norm=qk_norm,
+                rope_type=rope_type,
+            )
+
+        # --- 2. Video Prompt Cross-Attention (always built) ---
+        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.attn2 = LTX2Attention(
+            query_dim=dim,
+            cross_attention_dim=cross_attention_dim,
+            heads=num_attention_heads,
+            kv_heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            bias=attention_bias,
+            out_bias=attention_out_bias,
+            qk_norm=qk_norm,
+            rope_type=rope_type,
+        )
+
+        # --- 2b. Audio Prompt Cross-Attention (conditional) ---
+        if self.has_audio:
+            self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
+            self.audio_attn2 = LTX2Attention(
+                query_dim=audio_dim,
+                cross_attention_dim=audio_cross_attention_dim,
+                heads=audio_num_attention_heads,
+                kv_heads=audio_num_attention_heads,
+                dim_head=audio_attention_head_dim,
+                bias=attention_bias,
+                out_bias=attention_out_bias,
+                qk_norm=qk_norm,
+                rope_type=rope_type,
+            )
+
+        # --- 3. Audio-Video Cross-Attention (conditional — both modalities) ---
+        if self.has_audio:
+            # a2v: Q=Video, K/V=Audio
+            self.audio_to_video_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+            self.audio_to_video_attn = LTX2Attention(
+                query_dim=dim,
+                cross_attention_dim=audio_dim,
+                heads=audio_num_attention_heads,
+                kv_heads=audio_num_attention_heads,
+                dim_head=audio_attention_head_dim,
+                bias=attention_bias,
+                out_bias=attention_out_bias,
+                qk_norm=qk_norm,
+                rope_type=rope_type,
+            )
+
+            # v2a: Q=Audio, K/V=Video
+            self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
+            self.video_to_audio_attn = LTX2Attention(
+                query_dim=audio_dim,
+                cross_attention_dim=dim,
+                heads=audio_num_attention_heads,
+                kv_heads=audio_num_attention_heads,
+                dim_head=audio_attention_head_dim,
+                bias=attention_bias,
+                out_bias=attention_out_bias,
+                qk_norm=qk_norm,
+                rope_type=rope_type,
+            )
+
+            # Per-layer cross-attention modulation params
+            # NOTE(review): initialized with plain randn, unlike scale_shift_table below
+            # which uses randn / dim**0.5 — confirm this matches the reference checkpoint.
+            self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
+            self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
+
+        # --- 4. Feedforward (video always built, audio conditional) ---
+        self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+        self.ff = FeedForward(dim, activation_fn=activation_fn)
+
+        if self.has_audio:
+            self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
+            self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
+
+        # --- 5. AdaLN modulation params (video always, audio conditional) ---
+        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+        if self.has_audio:
+            self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        audio_hidden_states: torch.Tensor | None,
+        encoder_hidden_states: torch.Tensor,
+        audio_encoder_hidden_states: torch.Tensor | None,
+        temb: torch.Tensor,
+        temb_audio: torch.Tensor | None,
+        temb_ca_scale_shift: torch.Tensor | None,
+        temb_ca_audio_scale_shift: torch.Tensor | None,
+        temb_ca_gate: torch.Tensor | None,
+        temb_ca_audio_gate: torch.Tensor | None,
+        video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
+        encoder_attention_mask: torch.Tensor | None = None,
+        audio_encoder_attention_mask: torch.Tensor | None = None,
+        a2v_cross_attention_mask: torch.Tensor | None = None,
+        v2a_cross_attention_mask: torch.Tensor | None = None,
+        audio_enabled: bool = True,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        batch_size = hidden_states.size(0)
+
+        # Runtime gates — mirrors BasicAVTransformerBlock exactly
+        # run_ax: audio updates (self-attn, cross-attn, ffn, v2a)
+        # run_a2v: video attending to audio — asymmetrically decoupled from audio_enabled
+        #     so video can still use audio as conditioning even when audio_enabled=False
+        # run_v2a: audio updates from video
— tied to run_ax + run_ax = self.has_audio and audio_enabled and audio_hidden_states is not None + run_a2v = self.has_audio and audio_hidden_states is not None + run_v2a = run_ax + + # 1. Video Self-Attention (always runs) + norm_hidden_states = self.norm1(hidden_states) + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.size(1), num_ada_params, -1 + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=video_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa + + # 1b. Audio Self-Attention (conditional) + if run_ax: + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) + num_audio_ada_params = self.audio_scale_shift_table.shape[0] + audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( + batch_size, temb_audio.size(1), num_audio_ada_params, -1 + ) + audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( + audio_ada_values.unbind(dim=2) + ) + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa + attn_audio_hidden_states = self.audio_attn1( + hidden_states=norm_audio_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=audio_rotary_emb, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa + + # 2. 
Video Cross-Attention with text (always runs) + norm_hidden_states = self.norm2(hidden_states) + attn_hidden_states = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + # 2b. Audio Cross-Attention with text (conditional) + if run_ax: + norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + attn_audio_hidden_states = self.audio_attn2( + norm_audio_hidden_states, + encoder_hidden_states=audio_encoder_hidden_states, + query_rotary_emb=None, + attention_mask=audio_encoder_attention_mask, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states + + # 3. Audio-Video Cross-Attention + if run_a2v or run_v2a: + norm_hidden_states_av = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states_av = self.video_to_audio_norm(audio_hidden_states) + + # Video modulation params + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] + video_ca_scale_shift_table = ( + video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) + + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + video_ca_gate = ( + video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) + + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) + ).unbind(dim=2) + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table + a2v_gate = video_ca_gate[0].squeeze(2) + + # Audio modulation params + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + audio_ca_scale_shift_table = ( + audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) + + 
temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + audio_ca_gate = ( + audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) + + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) + ).unbind(dim=2) + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table + v2a_gate = audio_ca_gate[0].squeeze(2) + + # 3a. a2v: Q=Video, K/V=Audio (runs even when audio_enabled=False) + if run_a2v: + mod_norm_hidden_states = ( + norm_hidden_states_av * (1 + video_a2v_ca_scale.squeeze(2)) + + video_a2v_ca_shift.squeeze(2) + ) + mod_norm_audio_hidden_states = ( + norm_audio_hidden_states_av * (1 + audio_a2v_ca_scale.squeeze(2)) + + audio_a2v_ca_shift.squeeze(2) + ) + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # 3b. v2a: Q=Audio, K/V=Video (only when audio branch is fully active) + if run_v2a: + mod_norm_hidden_states = ( + norm_hidden_states_av * (1 + video_v2a_ca_scale.squeeze(2)) + + video_v2a_ca_shift.squeeze(2) + ) + mod_norm_audio_hidden_states = ( + norm_audio_hidden_states_av * (1 + audio_v2a_ca_scale.squeeze(2)) + + audio_v2a_ca_shift.squeeze(2) + ) + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + query_rotary_emb=ca_audio_rotary_emb, + key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + + # 4. 
Video Feedforward (always runs) + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp + hidden_states = hidden_states + self.ff(norm_hidden_states) * gate_mlp + + # 4b. Audio Feedforward (conditional) + if run_ax: + norm_audio_hidden_states = ( + self.audio_norm3(audio_hidden_states) * (1 + audio_scale_mlp) + audio_shift_mlp + ) + audio_hidden_states = audio_hidden_states + self.audio_ff(norm_audio_hidden_states) * audio_gate_mlp + + # Return None for audio when it didn't run — prevents stale tensor propagation + return hidden_states, audio_hidden_states if run_ax else None + + +class LTX2AudioVideoRotaryPosEmbed(nn.Module): + """ + Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model. + """ + + def __init__( + self, + dim: int, + patch_size: int = 1, + patch_size_t: int = 1, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + sampling_rate: int = 16000, + hop_length: int = 160, + scale_factors: tuple[int, ...] = (8, 32, 32), + theta: float = 10000.0, + causal_offset: int = 1, + modality: str = "video", + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ) -> None: + super().__init__() + + self.dim = dim + self.patch_size = patch_size + self.patch_size_t = patch_size_t + + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. 
Choose between 'interleaved' and 'split'.") + self.rope_type = rope_type + + self.base_num_frames = base_num_frames + self.num_attention_heads = num_attention_heads + + self.base_height = base_height + self.base_width = base_width + + self.sampling_rate = sampling_rate + self.hop_length = hop_length + self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0]) + + self.scale_factors = scale_factors + self.theta = theta + self.causal_offset = causal_offset + + self.modality = modality + if self.modality not in ["video", "audio"]: + raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.") + self.double_precision = double_precision + + def prepare_video_coords( + self, + batch_size: int, + num_frames: int, + height: int, + width: int, + device: torch.device, + fps: float = 24.0, + ) -> torch.Tensor: + grid_f = torch.arange(start=0, end=num_frames, step=self.patch_size_t, dtype=torch.float32, device=device) + grid_h = torch.arange(start=0, end=height, step=self.patch_size, dtype=torch.float32, device=device) + grid_w = torch.arange(start=0, end=width, step=self.patch_size, dtype=torch.float32, device=device) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) + + patch_size = (self.patch_size_t, self.patch_size, self.patch_size) + patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + + latent_coords = torch.stack([grid, patch_ends], dim=-1) + latent_coords = latent_coords.flatten(1, 3) + latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) + + scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device) + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 + pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape) + + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] 
+ self.causal_offset - self.scale_factors[0]).clamp(min=0) + pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps + + return pixel_coords + + def prepare_audio_coords( + self, + batch_size: int, + num_frames: int, + device: torch.device, + shift: int = 0, + ) -> torch.Tensor: + grid_f = torch.arange( + start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + ) + + audio_scale_factor = self.scale_factors[0] + grid_start_mel = grid_f * audio_scale_factor + grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0) + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate + + grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor + grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(min=0) + grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate + + audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1) + audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1) + audio_coords = audio_coords.unsqueeze(1) + return audio_coords + + def prepare_coords(self, *args, **kwargs): + if self.modality == "video": + return self.prepare_video_coords(*args, **kwargs) + elif self.modality == "audio": + return self.prepare_audio_coords(*args, **kwargs) + + def forward( + self, coords: torch.Tensor, device: str | torch.device | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or coords.device + + num_pos_dims = coords.shape[1] + + if coords.ndim == 4: + coords_start, coords_end = coords.chunk(2, dim=-1) + coords = (coords_start + coords_end) / 2.0 + coords = coords.squeeze(-1) + + if self.modality == "video": + max_positions = (self.base_num_frames, self.base_height, self.base_width) + elif self.modality == "audio": + max_positions = (self.base_num_frames,) + grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device) + num_rope_elems = num_pos_dims * 2 + + freqs_dtype = 
torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs + freqs = freqs.transpose(-1, -2).flatten(2) + + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + + cos_freqs = torch.swapaxes(cos_freq, 1, 2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) + + return cos_freqs, sin_freqs + + +class LTX2VideoTransformer3DModel( + ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin +): + r""" + A Transformer model for video-like data used in LTX-2.0. + + Supports two-level audio gating (Option C): + - Construction-time: set ``audio_enabled=False`` to build a video-only model with + no audio weights allocated at all. 
Ideal for finetuning video-only. + - Runtime: pass ``audio_hidden_states=None`` to skip all audio ops for a given + forward pass on a full AV checkpoint. + + When loading a pretrained AV checkpoint into a video-only model, use ``strict=False``: + missing, unexpected = transformer.load_state_dict(state_dict, strict=False) + # unexpected = all audio.* keys — expected, safe to ignore + # missing = should be [] — all video keys present + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["LTX2VideoTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "rope": { + 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + @register_to_config + def __init__( + self, + in_channels: int = 128, + out_channels: int | None = 128, + patch_size: int = 1, + patch_size_t: int = 1, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + cross_attention_dim: int = 4096, + vae_scale_factors: tuple[int, int, int] = (8, 32, 32), + pos_embed_max_pos: int = 20, + base_height: int = 2048, + base_width: int = 2048, + audio_in_channels: int = 128, + audio_out_channels: int | None = 128, + audio_patch_size: int = 1, + audio_patch_size_t: int = 1, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_cross_attention_dim: int = 2048, + audio_scale_factor: int = 4, + audio_pos_embed_max_pos: int = 20, + audio_sampling_rate: int = 16000, + audio_hop_length: int = 160, + num_layers: int = 48, + activation_fn: str = "gelu-approximate", + 
qk_norm: str = "rms_norm_across_heads", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 3840, + attention_bias: bool = True, + attention_out_bias: bool = True, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + causal_offset: int = 1, + timestep_scale_multiplier: int = 1000, + cross_attn_timestep_scale_multiplier: int = 1000, + rope_type: str = "interleaved", + audio_enabled: bool = True, # <-- NEW: construction-time gate + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + audio_out_channels = audio_out_channels or audio_in_channels + inner_dim = num_attention_heads * attention_head_dim + audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + + # 1. Patchification input projections + self.proj_in = nn.Linear(in_channels, inner_dim) + if audio_enabled: + self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim) + + # 2. Prompt embeddings + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + if audio_enabled: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=audio_inner_dim + ) + + # 3. 
Timestep modulation + self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False) + if audio_enabled: + self.audio_time_embed = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=6, use_additional_conditions=False + ) + self.av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=4, use_additional_conditions=False + ) + self.av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=4, use_additional_conditions=False + ) + self.av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=1, use_additional_conditions=False + ) + self.av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=1, use_additional_conditions=False + ) + + # 3.3. Output layer modulation params + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + if audio_enabled: + self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + + # 4. 
RoPE — video always built, audio ropes only when audio_enabled + self.rope = LTX2AudioVideoRotaryPosEmbed( + dim=inner_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + scale_factors=vae_scale_factors, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + if audio_enabled: + self.audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_inner_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=audio_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + scale_factors=[audio_scale_factor], + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos) + self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + + # 5. 
Transformer blocks — pass None dims when audio disabled + self.transformer_blocks = nn.ModuleList( + [ + LTX2VideoTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + audio_dim=audio_inner_dim if audio_enabled else None, + audio_num_attention_heads=audio_num_attention_heads if audio_enabled else None, + audio_attention_head_dim=audio_attention_head_dim if audio_enabled else None, + audio_cross_attention_dim=audio_cross_attention_dim if audio_enabled else None, + qk_norm=qk_norm, + activation_fn=activation_fn, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + rope_type=rope_type, + ) + for _ in range(num_layers) + ] + ) + + # 6. Output layers + self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels) + if audio_enabled: + self.audio_norm_out = nn.LayerNorm(audio_inner_dim, eps=1e-6, elementwise_affine=False) + self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels) + + self.gradient_checkpointing = False + + @apply_lora_scale("attention_kwargs") + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor | None = None, + encoder_hidden_states: torch.Tensor = None, + audio_encoder_hidden_states: torch.Tensor | None = None, + timestep: torch.LongTensor = None, + audio_timestep: torch.LongTensor | None = None, + encoder_attention_mask: torch.Tensor | None = None, + audio_encoder_attention_mask: torch.Tensor | None = None, + num_frames: int | None = None, + height: int | None = None, + width: int | None = None, + fps: float = 24.0, + audio_num_frames: int | None = None, + video_coords: torch.Tensor | None = None, + audio_coords: torch.Tensor | None = None, + attention_kwargs: dict[str, Any] | None = None, + return_dict: bool = True, + ) -> tuple[torch.Tensor, 
torch.Tensor | None]: + # Single runtime gate — all audio ops key off this. + # Requires: modules exist (construction-time) AND caller supplied a tensor. + run_audio = self.config.audio_enabled and audio_hidden_states is not None + + audio_timestep = audio_timestep if audio_timestep is not None else timestep + + # Attention mask conversion + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if run_audio and audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2: + audio_encoder_attention_mask = ( + 1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype) + ) * -10000.0 + audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # 1. RoPE — video always, audio only when run_audio + if video_coords is None: + video_coords = self.rope.prepare_video_coords( + batch_size, num_frames, height, width, hidden_states.device, fps=fps + ) + video_rotary_emb = self.rope(video_coords, device=hidden_states.device) + + if run_audio: + if audio_coords is None: + audio_coords = self.audio_rope.prepare_audio_coords( + batch_size, audio_num_frames, audio_hidden_states.device + ) + audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) + video_cross_attn_rotary_emb = self.cross_attn_rope( + video_coords[:, 0:1, :], device=hidden_states.device + ) + audio_cross_attn_rotary_emb = self.cross_attn_audio_rope( + audio_coords[:, 0:1, :], device=audio_hidden_states.device + ) + else: + audio_rotary_emb = None + video_cross_attn_rotary_emb = None + audio_cross_attn_rotary_emb = None + + # 2. Patchify + hidden_states = self.proj_in(hidden_states) + if run_audio: + audio_hidden_states = self.audio_proj_in(audio_hidden_states) + + # 3. 
Timestep embeddings and modulation params + timestep_cross_attn_gate_scale_factor = ( + self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier + ) + + temb, embedded_timestep = self.time_embed( + timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + if run_audio: + temb_audio, audio_embedded_timestep = self.audio_time_embed( + audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, + ) + temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) + audio_embedded_timestep = audio_embedded_timestep.view( + batch_size, -1, audio_embedded_timestep.size(-1) + ) + + video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( + timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( + timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, hidden_dtype=hidden_states.dtype, + ) + video_cross_attn_scale_shift = video_cross_attn_scale_shift.view( + batch_size, -1, video_cross_attn_scale_shift.shape[-1] + ) + video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view( + batch_size, -1, video_cross_attn_a2v_gate.shape[-1] + ) + + audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( + audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( + audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, + batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, + ) + audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view( + batch_size, -1, audio_cross_attn_scale_shift.shape[-1] + ) + audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view( + batch_size, -1, 
audio_cross_attn_v2a_gate.shape[-1] + ) + else: + temb_audio = None + audio_embedded_timestep = None + video_cross_attn_scale_shift = None + video_cross_attn_a2v_gate = None + audio_cross_attn_scale_shift = None + audio_cross_attn_v2a_gate = None + + # 4. Prompt embeddings + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + if run_audio: + audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) + audio_encoder_hidden_states = audio_encoder_hidden_states.view( + batch_size, -1, audio_hidden_states.size(-1) + ) + + # 5. Transformer blocks + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + # Note: _gradient_checkpointing_func only accepts positional tensor args. + # audio_enabled is not passed explicitly here — the block derives + # run_ax from audio_hidden_states being None when run_audio=False. 
+ hidden_states, audio_hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + audio_hidden_states, + encoder_hidden_states, + audio_encoder_hidden_states, + temb, + temb_audio, + video_cross_attn_scale_shift, + audio_cross_attn_scale_shift, + video_cross_attn_a2v_gate, + audio_cross_attn_v2a_gate, + video_rotary_emb, + audio_rotary_emb, + video_cross_attn_rotary_emb, + audio_cross_attn_rotary_emb, + encoder_attention_mask, + audio_encoder_attention_mask, + ) + else: + hidden_states, audio_hidden_states = block( + hidden_states=hidden_states, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=audio_encoder_hidden_states, + temb=temb, + temb_audio=temb_audio, + temb_ca_scale_shift=video_cross_attn_scale_shift, + temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, + temb_ca_gate=video_cross_attn_a2v_gate, + temb_ca_audio_gate=audio_cross_attn_v2a_gate, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=audio_rotary_emb, + ca_video_rotary_emb=video_cross_attn_rotary_emb, + ca_audio_rotary_emb=audio_cross_attn_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=audio_encoder_attention_mask, + audio_enabled=run_audio, + ) + + # 6. 
Output layers + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + output = self.proj_out(hidden_states) + + if run_audio: + audio_scale_shift_values = ( + self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] + ) + audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] + audio_hidden_states = self.audio_norm_out(audio_hidden_states) + audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift + audio_output = self.audio_proj_out(audio_hidden_states) + else: + audio_output = None + + if not return_dict: + return (output, audio_output) + return AudioVisualModelOutput(sample=output, audio_sample=audio_output) \ No newline at end of file From 6107ab0345ca4946a44aec92b0968854c8ac92e1 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Thu, 26 Feb 2026 16:30:07 +0800 Subject: [PATCH 07/13] update --- fastgen/networks/Flux/network.py | 1384 ++++++++++++++---------- fastgen/networks/Flux/pipeline_ltx2.py | 2 +- 2 files changed, 807 insertions(+), 579 deletions(-) diff --git a/fastgen/networks/Flux/network.py b/fastgen/networks/Flux/network.py index efc04b2..9a627a4 100644 --- a/fastgen/networks/Flux/network.py +++ b/fastgen/networks/Flux/network.py @@ -1,336 +1,538 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import os -from typing import Any, Optional, List, Set, Union, Tuple +""" +LTX-2 FastGen network implementation. 
+ +Architecture verified against: + - diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py + - diffusers/src/diffusers/pipelines/ltx2/connectors.py + - diffusers/src/diffusers/models/transformers/transformer_ltx2.py + +Follows the FastGen network pattern established by Flux and Wan: + - Inherits from FastGenNetwork + - Monkey-patches classify_forward onto self.transformer + - forward() handles video-only latent for distillation (audio flows through but is ignored for loss) + - feature_indices extracts video hidden_states only +""" + +import copy import types +from typing import Any, List, Optional, Set, Tuple, Union +import numpy as np import torch -import torch.utils.checkpoint -from torch import dtype +import torch.nn as nn +from diffusers.models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from diffusers.models.transformers import LTX2VideoTransformer3DModel +from diffusers.models.transformers.transformer_ltx2 import LTX2VideoTransformerBlock +from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from torch.distributed.fsdp import fully_shard - -from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL -from diffusers.models import FluxTransformer2DModel -from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock -from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast from fastgen.networks.network import FastGenNetwork from fastgen.networks.noise_schedule import NET_PRED_TYPES -from fastgen.utils.basic_utils import str2bool from fastgen.utils.distributed.fsdp import apply_fsdp_checkpointing import fastgen.utils.logging_utils as logger -class FluxTextEncoder: - """Text encoder for Flux using CLIP and T5 models.""" - - def __init__(self, 
model_id: str): - # CLIP text encoder - self.tokenizer = CLIPTokenizer.from_pretrained( - model_id, - cache_dir=os.environ["HF_HOME"], - subfolder="tokenizer", - local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), - ) - self.text_encoder = CLIPTextModel.from_pretrained( - model_id, - cache_dir=os.environ["HF_HOME"], - subfolder="text_encoder", - local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), - ) - self.text_encoder.eval().requires_grad_(False) - - # T5 text encoder - self.tokenizer_2 = T5TokenizerFast.from_pretrained( - model_id, - cache_dir=os.environ["HF_HOME"], - subfolder="tokenizer_2", - local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), - ) - self.text_encoder_2 = T5EncoderModel.from_pretrained( - model_id, - cache_dir=os.environ["HF_HOME"], - subfolder="text_encoder_2", - local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), - ) - self.text_encoder_2.eval().requires_grad_(False) - - def encode( - self, - conditioning: Optional[Any] = None, - precision: dtype = torch.float32, - max_sequence_length: int = 512, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Encode text prompts to embeddings. - - Args: - conditioning: Text prompt(s) to encode. - precision: Data type for the output embeddings. - max_sequence_length: Maximum sequence length for T5 tokenization. - - Returns: - Tuple of (pooled_prompt_embeds, prompt_embeds) tensors. 
- """ - if isinstance(conditioning, str): - conditioning = [conditioning] - - # CLIP encoding for pooled embeddings - text_inputs = self.tokenizer( - conditioning, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - with torch.no_grad(): - text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) - prompt_embeds = self.text_encoder( - text_input_ids, - output_hidden_states=False, - ) - pooled_prompt_embeds = prompt_embeds.pooler_output.to(precision) - - # T5 encoding for text embeddings - text_inputs_2 = self.tokenizer_2( - conditioning, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", - ) - - with torch.no_grad(): - text_input_ids_2 = text_inputs_2.input_ids.to(self.text_encoder_2.device) - prompt_embeds_2 = self.text_encoder_2( - text_input_ids_2, - output_hidden_states=False, - )[0].to(precision) - - return pooled_prompt_embeds, prompt_embeds_2 - - def to(self, *args, **kwargs): - """Moves the model to the specified device.""" - self.text_encoder.to(*args, **kwargs) - self.text_encoder_2.to(*args, **kwargs) - return self - - -class FluxImageEncoder: - """VAE encoder/decoder for Flux. - - Flux VAE uses both scaling_factor and shift_factor for latent normalization. 
# ---------------------------------------------------------------------------
# Helpers (mirrors of diffusers pipeline static methods)
# ---------------------------------------------------------------------------

def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    """Pack video latents [B, C, F, H, W] -> [B, F//pt * H//p * W//p, C*pt*p*p]."""
    bsz, channels, frames, height, width = latents.shape
    grid_f = frames // patch_size_t
    grid_h = height // patch_size
    grid_w = width // patch_size
    # Split each spatio-temporal axis into (grid, patch), bring the grid axes
    # forward, then fold (channels + patch voxel) into the last dimension.
    patched = latents.reshape(
        bsz, channels, grid_f, patch_size_t, grid_h, patch_size, grid_w, patch_size
    )
    patched = patched.permute(0, 2, 4, 6, 1, 3, 5, 7)
    return patched.flatten(4, 7).flatten(1, 3)


def _unpack_latents(
    latents: torch.Tensor, num_frames: int, height: int, width: int,
    patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
    """Unpack video latents [B, T, D] -> [B, C, F, H, W] (inverse of ``_pack_latents``)."""
    bsz = latents.size(0)
    expanded = latents.reshape(
        bsz, num_frames, height, width, -1, patch_size_t, patch_size, patch_size
    )
    expanded = expanded.permute(0, 4, 1, 5, 2, 6, 3, 7)
    # Merge each (grid, patch) pair back into a full axis: W first, then H, then F.
    return expanded.flatten(6, 7).flatten(4, 5).flatten(2, 3)


def _pack_audio_latents(latents: torch.Tensor) -> torch.Tensor:
    """Pack audio latents [B, C, L, M] -> [B, L, C*M]."""
    return latents.permute(0, 2, 1, 3).flatten(2, 3)


def _unpack_audio_latents(latents: torch.Tensor, latent_length: int, num_mel_bins: int) -> torch.Tensor:
    """Unpack audio latents [B, L, C*M] -> [B, C, L, M]."""
    # ``latent_length`` is kept for signature parity; the length axis is already L.
    return latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)


def _normalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor,
    scaling_factor: float = 1.0,
) -> torch.Tensor:
    """Shift/scale per-channel video latents into the transformer's normalized space."""
    bc_shape = (1, -1, 1, 1, 1)
    mu = latents_mean.view(bc_shape).to(latents.device, latents.dtype)
    sigma = latents_std.view(bc_shape).to(latents.device, latents.dtype)
    return (latents - mu) * scaling_factor / sigma


def _denormalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor,
    scaling_factor: float = 1.0,
) -> torch.Tensor:
    """Inverse of ``_normalize_latents``: map normalized latents back to VAE space."""
    bc_shape = (1, -1, 1, 1, 1)
    mu = latents_mean.view(bc_shape).to(latents.device, latents.dtype)
    sigma = latents_std.view(bc_shape).to(latents.device, latents.dtype)
    return latents * sigma / scaling_factor + mu


def _normalize_audio_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
) -> torch.Tensor:
    """Standardize audio latents with the audio VAE statistics."""
    mu = latents_mean.to(latents.device, latents.dtype)
    sigma = latents_std.to(latents.device, latents.dtype)
    return (latents - mu) / sigma


def _denormalize_audio_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
) -> torch.Tensor:
    """Inverse of ``_normalize_audio_latents``."""
    mu = latents_mean.to(latents.device, latents.dtype)
    sigma = latents_std.to(latents.device, latents.dtype)
    return latents * sigma + mu


def _pack_text_embeds(
    text_hidden_states: torch.Tensor,
    sequence_lengths: torch.Tensor,
    device: torch.device,
    padding_side: str = "left",
    scale_factor: int = 8,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Stack all Gemma hidden-state layers, normalize per-batch/per-layer over
    non-padded positions, and pack into [B, T, H * num_layers].
    """
    bsz, seq_len, hidden, num_layers = text_hidden_states.shape
    out_dtype = text_hidden_states.dtype

    # Valid-token mask: right padding keeps the prefix, left padding the suffix.
    positions = torch.arange(seq_len, device=device).unsqueeze(0)  # [1, T]
    if padding_side == "right":
        valid = positions < sequence_lengths[:, None]
    else:  # left
        valid = positions >= (seq_len - sequence_lengths[:, None])
    valid = valid[:, :, None, None]  # [B, T, 1, 1]

    # Mean over non-padded positions, per batch element and per layer.
    zeroed = text_hidden_states.masked_fill(~valid, 0.0)
    count = (sequence_lengths * hidden).view(bsz, 1, 1, 1)
    mean = zeroed.sum(dim=(1, 2), keepdim=True) / (count + eps)

    # Range over non-padded positions only (padding excluded via +/- inf fill).
    lo = text_hidden_states.masked_fill(~valid, float("inf")).amin(dim=(1, 2), keepdim=True)
    hi = text_hidden_states.masked_fill(~valid, float("-inf")).amax(dim=(1, 2), keepdim=True)

    packed = (text_hidden_states - mean) / (hi - lo + eps) * scale_factor
    packed = packed.flatten(2)  # [B, T, H*L]
    pad_mask = valid.squeeze(-1).expand(-1, -1, hidden * num_layers)
    packed = packed.masked_fill(~pad_mask, 0.0)
    return packed.to(out_dtype)


def _calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 1024,
    max_seq_len: int = 4096,
    base_shift: float = 0.95,
    max_shift: float = 2.05,
) -> float:
    """Linear timestep-shift interpolation in sequence length (pipeline's calculate_shift);
    defaults match the LTX-2 scheduler config."""
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept


def _retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=None, mu=None):
    """Call scheduler.set_timesteps, forwarding mu when dynamic shifting is enabled."""
    extra = {"mu": mu} if mu is not None else {}
    if sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **extra)
    else:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **extra)
    return scheduler.timesteps, len(scheduler.timesteps)
subfolder="vae", - local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), - ) - self.vae.eval().requires_grad_(False) - - # Flux VAE uses shift_factor in addition to scaling_factor - self.scaling_factor = getattr(self.vae.config, "scaling_factor", 0.3611) - self.shift_factor = getattr(self.vae.config, "shift_factor", 0.1159) - - def encode(self, real_images: torch.Tensor) -> torch.Tensor: - """Encode images to latent space. - - Args: - real_images: Input images in [-1, 1] range. - - Returns: - torch.Tensor: Latent representations (shifted and scaled). - """ - latent_images = self.vae.encode(real_images, return_dict=False)[0].sample() - # Apply Flux-specific shift and scale - latent_images = (latent_images - self.shift_factor) * self.scaling_factor - return latent_images - - def decode(self, latent_images: torch.Tensor) -> torch.Tensor: - """Decode latents to images. - - Args: - latent_images: Latent representations (shifted and scaled). - - Returns: - torch.Tensor: Decoded images in [-1, 1] range. 
# ---------------------------------------------------------------------------
# classify_forward — monkey-patched onto self.transformer
# ---------------------------------------------------------------------------

def classify_forward(
    self,
    hidden_states: torch.Tensor,
    audio_hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    audio_encoder_hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    audio_timestep: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    audio_encoder_attention_mask: Optional[torch.Tensor] = None,
    num_frames: Optional[int] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    fps: float = 24.0,
    audio_num_frames: Optional[int] = None,
    video_coords: Optional[torch.Tensor] = None,
    audio_coords: Optional[torch.Tensor] = None,
    attention_kwargs: Optional[dict] = None,
    return_dict: bool = True,  # accepted for API compatibility; always ignored
    # FastGen distillation kwargs
    return_features_early: bool = False,
    feature_indices: Optional[Set[int]] = None,
    return_logvar: bool = False,
) -> Union[
    Tuple[torch.Tensor, torch.Tensor],  # (video_out, audio_out)
    Tuple[Tuple[torch.Tensor, torch.Tensor], List[torch.Tensor]],  # ((video_out, audio_out), features)
    List[torch.Tensor],  # features only (early exit)
]:
    """
    Drop-in replacement for LTX2VideoTransformer3DModel.forward that adds FastGen
    distillation support (feature extraction, early exit, logvar).

    Audio always flows through every block unchanged — we never short-circuit it —
    but only video hidden_states are stored as features for the discriminator.

    Returns
    -------
    Normal mode (feature_indices empty, return_features_early False):
        (video_output, audio_output) — identical to the original forward
    Feature mode (feature_indices non-empty, return_features_early False):
        ((video_output, audio_output), List[video_feature_tensors])
    Early-exit mode (return_features_early True):
        List[video_feature_tensors] — forward stops as soon as all features collected
    """
    # LoRA scale handling — mirrors the @apply_lora_scale decorator in upstream
    from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers

    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    if USE_PEFT_BACKEND:
        scale_lora_layers(self, lora_scale)

    if feature_indices is None:
        feature_indices = set()

    if return_features_early and len(feature_indices) == 0:
        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)
        return []

    # ------------------------------------------------------------------ #
    # Steps 1-4: identical to the original forward (no changes)
    # ------------------------------------------------------------------ #
    audio_timestep = audio_timestep if audio_timestep is not None else timestep

    # Convert 2D attention masks to additive bias form
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2:
        audio_encoder_attention_mask = (
            1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)
        ) * -10000.0
        audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)

    batch_size = hidden_states.size(0)

    # 1. RoPE positional embeddings
    if video_coords is None:
        video_coords = self.rope.prepare_video_coords(
            batch_size, num_frames, height, width, hidden_states.device, fps=fps
        )
    if audio_coords is None:
        audio_coords = self.audio_rope.prepare_audio_coords(
            batch_size, audio_num_frames, audio_hidden_states.device
        )

    video_rotary_emb = self.rope(video_coords, device=hidden_states.device)
    audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)

    video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device)
    audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(
        audio_coords[:, 0:1, :], device=audio_hidden_states.device
    )

    # 2. Patchify input projections
    hidden_states = self.proj_in(hidden_states)
    audio_hidden_states = self.audio_proj_in(audio_hidden_states)

    # 3. Timestep embeddings and modulation parameters
    timestep_cross_attn_gate_scale_factor = (
        self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
    )

    temb, embedded_timestep = self.time_embed(
        timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype,
    )
    temb = temb.view(batch_size, -1, temb.size(-1))
    embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

    temb_audio, audio_embedded_timestep = self.audio_time_embed(
        audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype,
    )
    temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
    audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))

    video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
        timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype,
    )
    video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
        timestep.flatten() * timestep_cross_attn_gate_scale_factor,
        batch_size=batch_size, hidden_dtype=hidden_states.dtype,
    )
    video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(
        batch_size, -1, video_cross_attn_scale_shift.shape[-1]
    )
    video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(
        batch_size, -1, video_cross_attn_a2v_gate.shape[-1]
    )

    audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
        audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype,
    )
    audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
        audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
        batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype,
    )
    audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(
        batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
    )
    audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(
        batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]
    )

    # 4. Prompt embeddings
    encoder_hidden_states = self.caption_projection(encoder_hidden_states)
    encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

    audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
    audio_encoder_hidden_states = audio_encoder_hidden_states.view(
        batch_size, -1, audio_hidden_states.size(-1)
    )

    # ------------------------------------------------------------------ #
    # Step 5: Block loop with video-only feature extraction
    # Audio always flows through every block — we never skip it.
    # ------------------------------------------------------------------ #
    features: List[torch.Tensor] = []

    for idx, block in enumerate(self.transformer_blocks):
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states, audio_hidden_states = self._gradient_checkpointing_func(
                block,
                hidden_states,
                audio_hidden_states,
                encoder_hidden_states,
                audio_encoder_hidden_states,
                temb,
                temb_audio,
                video_cross_attn_scale_shift,
                audio_cross_attn_scale_shift,
                video_cross_attn_a2v_gate,
                audio_cross_attn_v2a_gate,
                video_rotary_emb,
                audio_rotary_emb,
                video_cross_attn_rotary_emb,
                audio_cross_attn_rotary_emb,
                encoder_attention_mask,
                audio_encoder_attention_mask,
            )
        else:
            hidden_states, audio_hidden_states = block(
                hidden_states=hidden_states,
                audio_hidden_states=audio_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                audio_encoder_hidden_states=audio_encoder_hidden_states,
                temb=temb,
                temb_audio=temb_audio,
                temb_ca_scale_shift=video_cross_attn_scale_shift,
                temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
                temb_ca_gate=video_cross_attn_a2v_gate,
                temb_ca_audio_gate=audio_cross_attn_v2a_gate,
                video_rotary_emb=video_rotary_emb,
                audio_rotary_emb=audio_rotary_emb,
                ca_video_rotary_emb=video_cross_attn_rotary_emb,
                ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
                encoder_attention_mask=encoder_attention_mask,
                audio_encoder_attention_mask=audio_encoder_attention_mask,
            )

        # Video-only feature extraction at requested block indices
        if idx in feature_indices:
            features.append(hidden_states.clone())  # [B, T_v, D_v] — packed video tokens

        # Early exit once all requested features are collected.
        # FIX: restore LoRA scaling before returning — the original early exit
        # skipped unscale_lora_layers, leaving LoRA layers permanently scaled.
        if return_features_early and len(features) == len(feature_indices):
            if USE_PEFT_BACKEND:
                unscale_lora_layers(self, lora_scale)
            return features

    # ------------------------------------------------------------------ #
    # Step 6: Output layers (video + audio) — unchanged from original
    # ------------------------------------------------------------------ #
    scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
    shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
    hidden_states = self.norm_out(hidden_states)
    hidden_states = hidden_states * (1 + scale) + shift
    video_output = self.proj_out(hidden_states)

    audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None]
    audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1]
    audio_hidden_states = self.audio_norm_out(audio_hidden_states)
    audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
    audio_output = self.audio_proj_out(audio_hidden_states)

    # ------------------------------------------------------------------ #
    # Assemble output following FastGen convention
    # ------------------------------------------------------------------ #
    if return_features_early:
        # Should have been caught above; guard for safety
        assert len(features) == len(feature_indices), f"{len(features)} != {len(feature_indices)}"
        # FIX (same defect as the in-loop early exit): unscale before returning.
        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)
        return features

    # Logvar (optional — requires logvar_linear to be added to the transformer)
    logvar = None
    if return_logvar:
        assert hasattr(self, "logvar_linear"), (
            "logvar_linear is required when return_logvar=True. "
            "It is added by LTX2.__init__."
        )
        # temb has shape [B, T_tokens, inner_dim]; take mean over tokens for a scalar logvar per sample
        logvar = self.logvar_linear(temb.mean(dim=1))  # [B, 1]

    if len(feature_indices) == 0:
        out = (video_output, audio_output)
    else:
        out = [(video_output, audio_output), features]

    if USE_PEFT_BACKEND:
        unscale_lora_layers(self, lora_scale)

    if return_logvar:
        return out, logvar
    return out


# ---------------------------------------------------------------------------
# Text encoder wrapper
# ---------------------------------------------------------------------------

class LTX2TextEncoder(nn.Module):
    """
    Wraps Gemma3ForConditionalGeneration for LTX-2 text conditioning.

    Returns both the packed prompt embeddings AND the tokenizer attention mask,
    which is required by LTX2TextConnectors.
    """

    def __init__(self, model_id: str):
        super().__init__()
        self.tokenizer = GemmaTokenizerFast.from_pretrained(model_id, subfolder="tokenizer")
        # LTX-2 packs left-padded sequences; ensure a pad token exists.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.text_encoder = Gemma3ForConditionalGeneration.from_pretrained(
            model_id, subfolder="text_encoder"
        )
        # Frozen: the text encoder is never trained during distillation.
        self.text_encoder.eval().requires_grad_(False)

    @torch.no_grad()
    def encode(
        self,
        prompt: Union[str, List[str]],
        precision: torch.dtype = torch.bfloat16,
        max_sequence_length: int = 1024,
        scale_factor: int = 8,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode text prompt(s) into packed Gemma hidden states.

        Returns
        -------
        prompt_embeds : torch.Tensor [B, T, H * num_layers]
        attention_mask : torch.Tensor [B, T]
        """
        if isinstance(prompt, str):
            prompt = [prompt]

        device = next(self.text_encoder.parameters()).device

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)

        outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # Stack all hidden states: [B, T, H, num_layers]
        hidden_states = torch.stack(outputs.hidden_states, dim=-1)
        sequence_lengths = attention_mask.sum(dim=-1)

        prompt_embeds = _pack_text_embeds(
            hidden_states,
            sequence_lengths,
            device=device,
            padding_side=self.tokenizer.padding_side,
            scale_factor=scale_factor,
        ).to(precision)

        return prompt_embeds, attention_mask

    def to(self, *args, **kwargs):
        # NOTE(review): moves only the wrapped encoder (its sole submodule) and
        # returns self for chaining, matching the other FastGen encoder wrappers.
        self.text_encoder.to(*args, **kwargs)
        return self
+ + Distillation targets video only: + - forward() receives and returns video latents [B, C, F, H, W] + - Audio is generated internally but not used for the distillation loss + - classify_forward extracts video hidden_states at requested block indices + + Component layout: + text_encoder → connectors → transformer (patched) → vae → audio_vae → vocoder + """ + + MODEL_ID = "Lightricks/LTX-2" def __init__( self, @@ -338,43 +540,31 @@ def __init__( net_pred_type: str = "flow", schedule_type: str = "rf", disable_grad_ckpt: bool = False, - guidance_scale: Optional[float] = 3.5, load_pretrained: bool = True, **model_kwargs, ): - """Flux.1 constructor. + """ + LTX-2 constructor. Args: - model_id: The HuggingFace model ID to load. - Defaults to "black-forest-labs/FLUX.1-dev". - net_pred_type: Prediction type. Defaults to "flow" for flow matching. + model_id: HuggingFace model ID or local path. Defaults to "Lightricks/LTX-2". + net_pred_type: Prediction type. Defaults to "flow" (flow matching). schedule_type: Schedule type. Defaults to "rf" (rectified flow). - disable_grad_ckpt: Whether to disable gradient checkpointing during training. - Defaults to False. Set to True when using FSDP to avoid memory access errors. - guidance_scale: Default guidance scale for Flux.1-dev guidance distillation. - None means no guidance. Defaults to 3.5 (recommended for Flux.1-dev). + disable_grad_ckpt: Disable gradient checkpointing during training. + Set True when using FSDP to avoid memory access errors. + load_pretrained: Load pretrained weights. If False, initialises from config only. 
""" super().__init__(net_pred_type=net_pred_type, schedule_type=schedule_type, **model_kwargs) self.model_id = model_id - self.guidance_scale = guidance_scale self._disable_grad_ckpt = disable_grad_ckpt - logger.debug(f"Embedded guidance scale: {guidance_scale}") - # Initialize the network (handles meta device and pretrained loading) self._initialize_network(model_id, load_pretrained) - # Override forward with classify_forward + # Monkey-patch classify_forward onto self.transformer (same pattern as Flux / Wan) self.transformer.forward = types.MethodType(classify_forward, self.transformer) - # Disable cuDNN SDPA backend to avoid mha_graph->execute errors during backward. - # This is a known issue with Flux transformer and cuDNN attention. - # Flash and mem_efficient backends still work; only cuDNN is problematic. - if torch.backends.cuda.is_built(): - torch.backends.cuda.enable_cudnn_sdp(False) - logger.info("Disabled cuDNN SDPA backend for Flux compatibility") - - # Gradient checkpointing configuration + # Gradient checkpointing if disable_grad_ckpt: self.transformer.disable_gradient_checkpointing() else: @@ -382,239 +572,239 @@ def __init__( torch.cuda.empty_cache() - def _initialize_network(self, model_id: str, load_pretrained: bool) -> None: - """Initialize the transformer network. + # ------------------------------------------------------------------ + # Initialisation + # ------------------------------------------------------------------ - Args: - model_id: The HuggingFace model ID or local path. - load_pretrained: Whether to load pretrained weights. 
- """ - # Check if we're in a meta context (for FSDP memory-efficient loading) + def _initialize_network(self, model_id: str, load_pretrained: bool) -> None: + """Initialize the transformer and supporting modules.""" in_meta_context = self._is_in_meta_context() should_load_weights = load_pretrained and (not in_meta_context) if should_load_weights: - logger.info("Loading Flux transformer from pretrained") - self.transformer: FluxTransformer2DModel = FluxTransformer2DModel.from_pretrained( - model_id, - cache_dir=os.environ["HF_HOME"], - subfolder="transformer", - local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + logger.info("Loading LTX-2 transformer from pretrained") + self.transformer: LTX2VideoTransformer3DModel = LTX2VideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer" ) else: - # Load config and create model structure - # If we're in a meta context, tensors will automatically be on meta device - config = FluxTransformer2DModel.load_config( - model_id, - cache_dir=os.environ["HF_HOME"], - subfolder="transformer", - local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), - ) + config = LTX2VideoTransformer3DModel.load_config(model_id, subfolder="transformer") if in_meta_context: logger.info( - "Initializing Flux transformer on meta device (zero memory, will receive weights via FSDP sync)" + "Initializing LTX-2 transformer on meta device " + "(zero memory, will receive weights via FSDP sync)" ) else: - logger.info("Initializing Flux transformer from config (no pretrained weights)") - logger.warning("Flux transformer being initialized from config. No weights are loaded!") - self.transformer: FluxTransformer2DModel = FluxTransformer2DModel.from_config(config) - - # Add logvar linear layer for variance estimation - Flux uses 3072-dim time embeddings - self.transformer.logvar_linear = torch.nn.Linear(3072, 1) - - def reset_parameters(self): - """Reinitialize parameters for FSDP meta device initialization. 
+ logger.info("Initializing LTX-2 transformer from config (no pretrained weights)") + logger.warning("LTX-2 transformer being initialized from config. No weights are loaded!") + self.transformer: LTX2VideoTransformer3DModel = LTX2VideoTransformer3DModel.from_config(config) + + # inner_dim = num_attention_heads * attention_head_dim + inner_dim = ( + self.transformer.config.num_attention_heads + * self.transformer.config.attention_head_dim + ) - This is required when using meta device initialization for FSDP2. - Reinitializes all linear layers and embeddings. - """ - import torch.nn as nn + # Add logvar_linear for uncertainty weighting (DMD2 / f-distill) + # temb mean has shape [B, inner_dim] → logvar scalar per sample + self.transformer.logvar_linear = nn.Linear(inner_dim, 1) + logger.info(f"Added logvar_linear ({inner_dim} → 1) to LTX-2 transformer") - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Embedding): - nn.init.normal_(m.weight, std=0.02) + # Connectors: top-level sibling of transformer (NOT nested inside it) + if should_load_weights: + self.connectors: LTX2TextConnectors = LTX2TextConnectors.from_pretrained( + model_id, subfolder="connectors" + ) + else: + # Connectors are lightweight; always load if pretrained is skipped for the transformer + logger.warning("Skipping connector pretrained load (meta context or load_pretrained=False)") + self.connectors = None # will be loaded lazily via init_preprocessors - super().reset_parameters() + # Cache compression ratios used by forward() and sample() + if should_load_weights: + # VAEs (needed for sample(); not for the training forward pass) + self.vae: AutoencoderKLLTX2Video = AutoencoderKLLTX2Video.from_pretrained( + model_id, subfolder="vae" + ) + self.vae.eval().requires_grad_(False) - logger.debug("Reinitialized Flux parameters") + self.audio_vae: AutoencoderKLLTX2Audio = 
AutoencoderKLLTX2Audio.from_pretrained( + model_id, subfolder="audio_vae" + ) + self.audio_vae.eval().requires_grad_(False) - def fully_shard(self, **kwargs): - """Fully shard the Flux network for FSDP. + self.vocoder: LTX2Vocoder = LTX2Vocoder.from_pretrained( + model_id, subfolder="vocoder" + ) + self.vocoder.eval().requires_grad_(False) - Note: Flux has two types of transformer blocks: - - transformer_blocks: Joint attention blocks for text-image interaction - - single_transformer_blocks: Single stream blocks for image processing + self._cache_vae_constants() - We shard `self.transformer` instead of `self` because the network wrapper - class may have complex multiple inheritance with ABC, which causes Python's - __class__ assignment to fail due to incompatible memory layouts. - """ - # Note: Checkpointing has to happen first, for proper casting during backward pass recomputation. - if self.transformer.gradient_checkpointing: - # Disable the built-in gradient checkpointing (which uses torch.utils.checkpoint) - self.transformer.disable_gradient_checkpointing() - # Apply FSDP-compatible activation checkpointing to both block types - apply_fsdp_checkpointing( - self.transformer, - check_fn=lambda block: isinstance(block, (FluxTransformerBlock, FluxSingleTransformerBlock)), - ) - logger.info("Applied FSDP activation checkpointing to Flux transformer blocks") + # Scheduler (used in sample()) + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model_id, subfolder="scheduler" + ) - # Apply FSDP sharding to joint transformer blocks - for block in self.transformer.transformer_blocks: - fully_shard(block, **kwargs) + def _cache_vae_constants(self) -> None: + """Cache VAE spatial/temporal compression constants for use in forward() / sample().""" + self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio + self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio + self.transformer_spatial_patch_size = 
self.transformer.config.patch_size + self.transformer_temporal_patch_size = self.transformer.config.patch_size_t - # Apply FSDP sharding to single transformer blocks - for block in self.transformer.single_transformer_blocks: - fully_shard(block, **kwargs) + self.audio_sampling_rate = self.audio_vae.config.sample_rate + self.audio_hop_length = self.audio_vae.config.mel_hop_length + self.audio_vae_temporal_compression = self.audio_vae.temporal_compression_ratio + self.audio_vae_mel_compression_ratio = self.audio_vae.mel_compression_ratio - fully_shard(self.transformer, **kwargs) + # ------------------------------------------------------------------ + # Preprocessor initialisation (lazy, matches Flux / Wan pattern) + # ------------------------------------------------------------------ def init_preprocessors(self): - """Initialize text and image encoders.""" - if not hasattr(self, "text_encoder"): + """Initialize text encoder and connectors.""" + if not hasattr(self, "text_encoder") or self.text_encoder is None: self.init_text_encoder() - if not hasattr(self, "vae"): - self.init_vae() + if self.connectors is None: + self.connectors = LTX2TextConnectors.from_pretrained( + self.model_id, subfolder="connectors" + ) def init_text_encoder(self): - """Initialize the text encoder for Flux.""" - self.text_encoder = FluxTextEncoder(model_id=self.model_id) + """Initialize the Gemma3 text encoder for LTX-2.""" + self.text_encoder = LTX2TextEncoder(model_id=self.model_id) - def init_vae(self): - """Initialize only the VAE for visualization.""" - self.vae = FluxImageEncoder(model_id=self.model_id) + # ------------------------------------------------------------------ + # Device movement + # ------------------------------------------------------------------ def to(self, *args, **kwargs): - """Moves the model to the specified device.""" super().to(*args, **kwargs) - if hasattr(self, "text_encoder"): + if hasattr(self, "text_encoder") and self.text_encoder is not None: 
self.text_encoder.to(*args, **kwargs) - if hasattr(self, "vae"): + if hasattr(self, "connectors") and self.connectors is not None: + self.connectors.to(*args, **kwargs) + if hasattr(self, "vae") and self.vae is not None: self.vae.to(*args, **kwargs) + if hasattr(self, "audio_vae") and self.audio_vae is not None: + self.audio_vae.to(*args, **kwargs) + if hasattr(self, "vocoder") and self.vocoder is not None: + self.vocoder.to(*args, **kwargs) return self - def _prepare_latent_image_ids( - self, - height: int, - width: int, - device: torch.device, - dtype: torch.dtype, - ) -> torch.Tensor: - """Prepare image position IDs for the transformer. + # ------------------------------------------------------------------ + # FSDP + # ------------------------------------------------------------------ - Args: - height: Latent height (before packing, will be divided by 2). - width: Latent width (before packing, will be divided by 2). - device: Target device. - dtype: Target dtype. + def fully_shard(self, **kwargs): + """Fully shard the LTX-2 transformer for FSDP2. - Returns: - torch.Tensor: Image position IDs [(H//2)*(W//2), 3] (2D, no batch dim). + Shards self.transformer (not self) to avoid ABC __class__ assignment issues. """ - # Use packed dimensions - packed_height = height // 2 - packed_width = width // 2 - latent_image_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype) - latent_image_ids[..., 1] = torch.arange(packed_height, device=device, dtype=dtype)[:, None] - latent_image_ids[..., 2] = torch.arange(packed_width, device=device, dtype=dtype)[None, :] - latent_image_ids = latent_image_ids.reshape(packed_height * packed_width, 3) - return latent_image_ids - - def _prepare_text_ids( - self, - seq_length: int, - device: torch.device, - dtype: torch.dtype, - ) -> torch.Tensor: - """Prepare text position IDs. 
+ if self.transformer.gradient_checkpointing: + self.transformer.disable_gradient_checkpointing() + apply_fsdp_checkpointing( + self.transformer, + check_fn=lambda block: isinstance(block, LTX2VideoTransformerBlock), + ) + logger.info("Applied FSDP activation checkpointing to LTX-2 transformer blocks") - Args: - seq_length: Text sequence length. - device: Target device. - dtype: Target dtype. + for block in self.transformer.transformer_blocks: + fully_shard(block, **kwargs) + fully_shard(self.transformer, **kwargs) - Returns: - torch.Tensor: Text position IDs [seq_length, 3] (2D, no batch dim). - """ - text_ids = torch.zeros(seq_length, 3, device=device, dtype=dtype) - return text_ids + # ------------------------------------------------------------------ + # reset_parameters (required for FSDP meta device init) + # ------------------------------------------------------------------ - def _pack_latents(self, latents: torch.Tensor) -> torch.Tensor: - """Pack latents from [B, C, H, W] to [B, (H//2)*(W//2), C*4] for Flux transformer. + def reset_parameters(self): + """Reinitialise parameters after meta device materialisation (FSDP2).""" + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) - Flux uses 2x2 patch packing where each 2x2 spatial block is flattened into channels. + super().reset_parameters() + logger.debug("Reinitialized LTX-2 parameters") - Args: - latents: Input latents [B, C, H, W]. + # ------------------------------------------------------------------ + # Audio latent sizing helper (shared by forward and sample) + # ------------------------------------------------------------------ - Returns: - Packed latents [B, (H//2)*(W//2), C*4]. 
+ def _compute_audio_shape( + self, latent_f: int, fps: float, device: torch.device, dtype: torch.dtype + ) -> Tuple[int, int, int]: """ - batch_size, channels, height, width = latents.shape - # Reshape to [B, C, H//2, 2, W//2, 2] - latents = latents.view(batch_size, channels, height // 2, 2, width // 2, 2) - # Permute to [B, H//2, W//2, C, 2, 2] - latents = latents.permute(0, 2, 4, 1, 3, 5) - # Reshape to [B, (H//2)*(W//2), C*4] - latents = latents.reshape(batch_size, (height // 2) * (width // 2), channels * 4) - return latents + Compute audio latent dimensions from video latent frame count. - def _unpack_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor: - """Unpack latents from [B, (H//2)*(W//2), C*4] to [B, C, H, W]. - - Reverses the 2x2 patch packing used by Flux. + Returns (audio_num_frames, latent_mel_bins, num_audio_ch). + """ + pixel_frames = (latent_f - 1) * self.vae_temporal_compression_ratio + 1 + duration_s = pixel_frames / fps - Args: - latents: Packed latents [B, (H//2)*(W//2), C*4]. - height: Target height (original H before packing). - width: Target width (original W before packing). + audio_latents_per_second = ( + self.audio_sampling_rate + / self.audio_hop_length + / float(self.audio_vae_temporal_compression) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + num_mel_bins = self.audio_vae.config.mel_bins + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_audio_ch = self.audio_vae.config.latent_channels + return audio_num_frames, latent_mel_bins, num_audio_ch - Returns: - Unpacked latents [B, C, H, W]. 
- """ - batch_size = latents.shape[0] - channels = latents.shape[2] // 4 # C*4 -> C - # Reshape to [B, H//2, W//2, C, 2, 2] - latents = latents.reshape(batch_size, height // 2, width // 2, channels, 2, 2) - # Permute to [B, C, H//2, 2, W//2, 2] - latents = latents.permute(0, 3, 1, 4, 2, 5) - # Reshape to [B, C, H, W] - latents = latents.reshape(batch_size, channels, height, width) - return latents + # ------------------------------------------------------------------ + # forward() — video-only distillation interface + # ------------------------------------------------------------------ def forward( self, x_t: torch.Tensor, t: torch.Tensor, condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - r: Optional[torch.Tensor] = None, # unused, kept for API compatibility - guidance: Optional[torch.Tensor] = None, + r: Optional[torch.Tensor] = None, # unused, kept for API compatibility + fps: float = 24.0, return_features_early: bool = False, feature_indices: Optional[Set[int]] = None, return_logvar: bool = False, fwd_pred_type: Optional[str] = None, + audio_latents: Optional[torch.Tensor] = None, + return_audio: bool = False, **fwd_kwargs, ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - """Forward pass of Flux diffusion model. + """Forward pass — video latents in, video latents out. + + Follows the Flux/Wan FastGen pattern. Audio is optional: + - Distillation (default): ``audio_latents=None, return_audio=False``. + Random noise is used for audio; only video prediction is returned. + - Inference via sample(): ``audio_latents=, return_audio=True``. + The current denoising audio latents are passed in and the audio + noise prediction is returned alongside the video prediction so that + sample() can step both schedulers. Args: - x_t: The diffused data sample [B, C, H, W]. - t: The current timestep. - condition: Tuple of (pooled_prompt_embeds, prompt_embeds) from text encoder. - r: Another timestep (for mean flow methods). 
- return_features_early: If True, return features once collected. - feature_indices: Set of block indices for feature extraction. - return_logvar: If True, return the logvar. - fwd_pred_type: Override network prediction type. - guidance: Optional guidance scale embedding. + x_t: Video latents [B, C, F, H, W]. + t: Timestep [B] — scheduler sigmas passed directly to time_embed. + condition: Tuple of (prompt_embeds [B, T, D], attention_mask [B, T]). + r: Unused (kept for FastGen API compatibility). + fps: Frames per second (needed for RoPE coordinate computation). + return_features_early: Return video features as soon as collected. + feature_indices: Set of transformer block indices to extract video features from. + return_logvar: Return log-variance estimate alongside the output. + fwd_pred_type: Override prediction type. + audio_latents: Optional packed audio latents [B, T_a, C_a] from sample(). + When None, fresh random noise is generated internally. + return_audio: When True, return (video_out, audio_packed) instead of + just video_out. Only used by sample(). Returns: - Model output tensor or tuple with logvar/features. 
+ Normal: video_out [B, C, F, H, W] + return_audio: (video_out, audio_packed [B, T_a, C_a]) + With features: (video_out, List[video_feature_tensors]) + Early exit: List[video_feature_tensors] + With logvar: (above, logvar [B, 1]) """ if feature_indices is None: feature_indices = set() @@ -627,78 +817,122 @@ def forward( assert fwd_pred_type in NET_PRED_TYPES, f"{fwd_pred_type} is not supported" batch_size = x_t.shape[0] - height, width = x_t.shape[2], x_t.shape[3] + _, _, latent_f, latent_h, latent_w = x_t.shape - # Unpack condition: (pooled_prompt_embeds, prompt_embeds) - pooled_prompt_embeds, prompt_embeds = condition + # Unpack text conditioning + prompt_embeds, attention_mask = condition - # Prepare position IDs (2D tensors, no batch dimension) - img_ids = self._prepare_latent_image_ids(height, width, x_t.device, x_t.dtype) - txt_ids = self._prepare_text_ids(prompt_embeds.shape[1], x_t.device, x_t.dtype) + # ---- Run connectors to get per-modality encoder hidden states ---- + additive_mask = (1 - attention_mask.to(prompt_embeds.dtype)) * -1_000_000.0 + connector_video_embeds, connector_audio_embeds, connector_attn_mask = self.connectors( + prompt_embeds, additive_mask, additive_mask=True + ) - # Pack latents for transformer: [B, C, H, W] -> [B, (H//2)*(W//2), C*4] - hidden_states = self._pack_latents(x_t) + # ---- Timestep: [B] — time_embed handles scale internally ---- + timestep = t.to(x_t.dtype).expand(batch_size) # [B] - # Note: Flux.1-dev (w/ guidance distillation) uses embedded guidance, so the default guidance is not None - if guidance is None: - guidance = torch.full( - (batch_size,), self.guidance_scale, device=hidden_states.device, dtype=hidden_states.dtype + # ---- Pack video latents ---- + hidden_states = _pack_latents( + x_t, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + # ---- Audio latents: use provided (from sample()) or generate random noise ---- + audio_num_frames, latent_mel_bins, num_audio_ch = 
self._compute_audio_shape( + latent_f, fps, x_t.device, x_t.dtype + ) + if audio_latents is not None: + # Already packed [B, T_a, C_a] — passed in by sample() + audio_hidden_states = audio_latents.to(x_t.dtype) + else: + audio_hidden_states = _pack_audio_latents( + torch.randn( + batch_size, num_audio_ch, audio_num_frames, latent_mel_bins, + device=x_t.device, dtype=x_t.dtype, + ) ) + # ---- RoPE coordinates ---- + video_coords = self.transformer.rope.prepare_video_coords( + batch_size, latent_f, latent_h, latent_w, x_t.device, fps=fps + ) + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + batch_size, audio_num_frames, x_t.device + ) + + # ---- Transformer forward (classify_forward, monkey-patched) ---- model_outputs = self.transformer( hidden_states=hidden_states, - encoder_hidden_states=prompt_embeds, - pooled_projections=pooled_prompt_embeds, - timestep=t, # Flux expects timestep in [0, 1] - img_ids=img_ids, - txt_ids=txt_ids, - guidance=guidance, + audio_hidden_states=audio_hidden_states, + encoder_hidden_states=connector_video_embeds, + audio_encoder_hidden_states=connector_audio_embeds, + encoder_attention_mask=connector_attn_mask, + audio_encoder_attention_mask=connector_attn_mask, + timestep=timestep, + num_frames=latent_f, + height=latent_h, + width=latent_w, + fps=fps, + audio_num_frames=audio_num_frames, + video_coords=video_coords, + audio_coords=audio_coords, return_features_early=return_features_early, feature_indices=feature_indices, return_logvar=return_logvar, ) + # ---- Early exit: list of video feature tensors ---- if return_features_early: - return model_outputs + return model_outputs # List[Tensor], each [B, T_v, D_v] + # ---- Unpack logvar if requested ---- if return_logvar: out, logvar = model_outputs[0], model_outputs[1] else: out = model_outputs - # Unpack output: [B, H*W, C] -> [B, C, H, W] - if isinstance(out, torch.Tensor): - out = self._unpack_latents(out, height, width) - out = 
self.noise_scheduler.convert_model_output( - x_t, out, t, src_pred_type=self.net_pred_type, target_pred_type=fwd_pred_type - ) + # ---- Extract video prediction; capture audio for sample() if requested ---- + if len(feature_indices) == 0: + # out is (video_output, audio_output) + video_packed = out[0] # [B, T_v, C_packed] + audio_packed = out[1] # [B, T_a, C_packed] — only used when return_audio=True + features = None else: - out[0] = self._unpack_latents(out[0], height, width) - out[0] = self.noise_scheduler.convert_model_output( - x_t, out[0], t, src_pred_type=self.net_pred_type, target_pred_type=fwd_pred_type - ) + # out is [(video_output, audio_output), features] + video_packed = out[0][0] # [B, T_v, C_packed] + audio_packed = out[0][1] # [B, T_a, C_packed] + features = out[1] # List[Tensor] + + # ---- Unpack video tokens → [B, C, F, H, W] ---- + video_out = _unpack_latents( + video_packed, latent_f, latent_h, latent_w, + self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, + ) + + # ---- Convert model output to requested prediction type ---- + video_out = self.noise_scheduler.convert_model_output( + x_t, video_out, t, + src_pred_type=self.net_pred_type, + target_pred_type=fwd_pred_type, + ) + + # ---- Re-pack output following FastGen convention ---- + if features is not None: + out = [video_out, features] + else: + out = video_out + + # Return audio noise pred alongside video when called from sample() + if return_audio: + out = (out, audio_packed.float()) if return_logvar: return out, logvar return out - def _calculate_shift( - self, - image_seq_len: int, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.16, - ) -> float: - """Calculate the shift value for the scheduler based on image resolution. - - This implements the resolution-dependent shift from the Flux paper. 
- """ - - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu + # ------------------------------------------------------------------ + # sample() — full denoising loop for inference + # Follows pipeline_ltx2.py exactly (verified working logic preserved) + # ------------------------------------------------------------------ @torch.no_grad() def sample( @@ -706,91 +940,85 @@ def sample( noise: torch.Tensor, condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, neg_condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - guidance_scale: Optional[float] = 3.5, - num_steps: int = 28, + guidance_scale: float = 4.0, + num_steps: int = 40, + fps: float = 24.0, + frame_rate: Optional[float] = None, **kwargs, - ) -> torch.Tensor: - """Generate samples using Euler flow matching. - - Args: - noise: Initial noise tensor [B, C, H, W]. - condition: Tuple of (pooled_prompt_embeds, prompt_embeds). - neg_condition: Optional negative condition tuple for CFG. - guidance_scale: Guidance scale (if not None, enables guidance via distillation). - num_steps: Number of sampling steps (default 28 for good quality/speed balance). - **kwargs: Additional keyword arguments - - Returns: - Generated latent samples. + ) -> Tuple[torch.Tensor, None]: """ - batch_size, channels, height, width = noise.shape - - # Calculate image sequence length for shift calculation - # After 2x2 packing: seq_len = (H // 2) * (W // 2) - image_seq_len = (height // 2) * (width // 2) + Run the full denoising loop for text-to-video generation (audio always None). 
- # Calculate resolution-dependent shift (mu) - mu = self._calculate_shift(image_seq_len) - - # Initialize scheduler with proper shift - scheduler = FlowMatchEulerDiscreteScheduler(shift=mu) - scheduler.set_timesteps(num_steps, device=noise.device) - timesteps = scheduler.timesteps + Returns + ------- + (video_latents, None): + video: [B, C, F, H, W] denormalised video latents + audio: always None + """ + fps = frame_rate if frame_rate is not None else fps + do_cfg = neg_condition is not None and guidance_scale > 1.0 + + transformer_dtype = self.transformer.dtype + transformer_device = next(self.transformer.parameters()).device + + # Move latents to transformer device and dtype + video_latents = noise.to(device=transformer_device, dtype=transformer_dtype) + + # Build combined condition for CFG so connectors are called once per step + if do_cfg: + neg_embeds, neg_mask = neg_condition + cond_embeds, cond_mask = condition + combined_condition = ( + torch.cat([neg_embeds, cond_embeds], dim=0).to(device=transformer_device, dtype=transformer_dtype), + torch.cat([neg_mask, cond_mask], dim=0).to(device=transformer_device), + ) + else: + embeds, mask = condition + combined_condition = ( + embeds.to(device=transformer_device, dtype=transformer_dtype), + mask.to(device=transformer_device), + ) - # Initialize latents with proper scaling based on the initial timestep - t_init = self.noise_scheduler.safe_clamp( - timesteps[0] / 1000.0, min=self.noise_scheduler.min_t, max=self.noise_scheduler.max_t + # ---- Scheduler timesteps ---- + B, C, latent_f, latent_h, latent_w = video_latents.shape + sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps) + video_seq_len = latent_f * latent_h * latent_w + mu = _calculate_shift( + video_seq_len, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + timesteps, num_steps = 
_retrieve_timesteps( + self.scheduler, num_steps, transformer_device, sigmas=sigmas, mu=mu ) - latents = self.noise_scheduler.latents(noise=noise, t_init=t_init) - - pooled_prompt_embeds, prompt_embeds = condition - - # Prepare guidance embedding for guidance distillation (Flux.1-dev mode) - # Note: Flux.1-dev uses embedded guidance, not traditional CFG - guidance_tensor = None - if guidance_scale is not None: - guidance_tensor = torch.full((batch_size,), guidance_scale, device=latents.device, dtype=latents.dtype) - - # Sampling loop - for timestep in timesteps: - # Scheduler timesteps are in [0, 1000], transformer expects [0, 1] - t = (timestep / 1000.0).expand(batch_size) - t = self.noise_scheduler.safe_clamp(t, min=self.noise_scheduler.min_t, max=self.noise_scheduler.max_t).to( - latents.dtype - ) - # Two guidance modes: - # 1. CFG mode: when neg_condition is provided (doubles batch, uses uncond/cond difference) - # 2. Guidance distillation mode: when neg_condition is None (single forward, guidance embedded) - if neg_condition is not None: - # Traditional CFG mode - neg_pooled, neg_prompt = neg_condition - latent_model_input = torch.cat([latents, latents], dim=0) - pooled_input = torch.cat([neg_pooled, pooled_prompt_embeds], dim=0) - prompt_input = torch.cat([neg_prompt, prompt_embeds], dim=0) - t_input = torch.cat([t, t], dim=0) - - noise_pred = self( - latent_model_input, - t_input, - (pooled_input, prompt_input), - fwd_pred_type="flow", - guidance=None, # No guidance embedding for CFG mode - ) + # ---- Denoising loop ---- + for t in timesteps: + latent_input = torch.cat([video_latents] * 2) if do_cfg else video_latents + t_input = t.to(dtype=transformer_dtype, device=transformer_device).expand(latent_input.shape[0]) + + noise_pred = self( + latent_input, + t_input, + condition=combined_condition, + fps=fps, + fwd_pred_type="flow", + ) + if do_cfg: noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * 
(noise_pred_cond - noise_pred_uncond) - else: - # Guidance distillation mode (recommended for Flux.1-dev) - noise_pred = self( - latents, - t, - condition, - fwd_pred_type="flow", - guidance=guidance_tensor, - ) - # Euler step - latents = scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + video_latents = self.scheduler.step(noise_pred, t, video_latents, return_dict=False)[0] + + # ---- Denormalise ---- + video_latents = _denormalize_latents( + video_latents, + self.vae.latents_mean, + self.vae.latents_std, + self.vae.config.scaling_factor, + ) - return latents + return video_latents, None \ No newline at end of file diff --git a/fastgen/networks/Flux/pipeline_ltx2.py b/fastgen/networks/Flux/pipeline_ltx2.py index 28f3fa8..9afa887 100644 --- a/fastgen/networks/Flux/pipeline_ltx2.py +++ b/fastgen/networks/Flux/pipeline_ltx2.py @@ -753,7 +753,7 @@ def __call__( # Honour the transformer's construction-time gate — if it has no audio modules, # we cannot generate audio regardless of what the caller requests. - run_audio = generate_audio and self.transformer.config.audio_enabled + run_audio = generate_audio and getattr(self.transformer.config, "audio_enabled", True) # 1. 
Check inputs self.check_inputs( From af9607a174c0905493dfaee6ef86818d0fa66e44 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Thu, 26 Feb 2026 16:44:51 +0800 Subject: [PATCH 08/13] revise the ltx pipeline loc --- fastgen/networks/{Flux => LTX2}/pipeline_ltx2.py | 0 fastgen/networks/{Flux => LTX2}/test_ltx2_pipeline.py | 0 fastgen/networks/{Flux => LTX2}/transformer_ltx2.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename fastgen/networks/{Flux => LTX2}/pipeline_ltx2.py (100%) rename fastgen/networks/{Flux => LTX2}/test_ltx2_pipeline.py (100%) rename fastgen/networks/{Flux => LTX2}/transformer_ltx2.py (100%) diff --git a/fastgen/networks/Flux/pipeline_ltx2.py b/fastgen/networks/LTX2/pipeline_ltx2.py similarity index 100% rename from fastgen/networks/Flux/pipeline_ltx2.py rename to fastgen/networks/LTX2/pipeline_ltx2.py diff --git a/fastgen/networks/Flux/test_ltx2_pipeline.py b/fastgen/networks/LTX2/test_ltx2_pipeline.py similarity index 100% rename from fastgen/networks/Flux/test_ltx2_pipeline.py rename to fastgen/networks/LTX2/test_ltx2_pipeline.py diff --git a/fastgen/networks/Flux/transformer_ltx2.py b/fastgen/networks/LTX2/transformer_ltx2.py similarity index 100% rename from fastgen/networks/Flux/transformer_ltx2.py rename to fastgen/networks/LTX2/transformer_ltx2.py From d95a383acbb4b66b20635c95297f903edd957eb8 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Thu, 26 Feb 2026 16:45:12 +0800 Subject: [PATCH 09/13] update --- tests/test_ltx2.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/tests/test_ltx2.py b/tests/test_ltx2.py index 2d53c44..93b538a 100644 --- a/tests/test_ltx2.py +++ b/tests/test_ltx2.py @@ -68,22 +68,12 @@ def test_ltx2_generation(): # latents: [B, C, F, H, W] denormalised video latents (float32) # audio_latents: [B, C, L, M] denormalised audio latents (float32) - # 6. Decode Latents to Video + Audio - print("Decoding latents to video and audio...") + # 6. 
Decode Latents to Video + print("Decoding latents to video...") with torch.no_grad(): - # vae.decode() signature: decode(z, temb=None, causal=None, return_dict=True) - # timestep_conditioning is False for LTX-2, so no temb needed. - # Use return_dict=False to get the tensor directly. video_tensor = model.vae.decode(latents.to(model.vae.dtype), return_dict=False)[0] # video_tensor: [B, C, F, H, W] in ~[-1, 1] - # Decode mel spectrogram -> waveform via audio_vae + vocoder - mel_spectrograms = model.audio_vae.decode( - audio_latents.to(model.audio_vae.dtype), return_dict=False - )[0] - audio_waveform = model.vocoder(mel_spectrograms) - # audio_waveform: [B, channels, samples] at vocoder.config.output_sampling_rate Hz - # 7. Post-process and Save # Convert [B, C, F, H, W] -> [F, H, W, C] uint8 video_np = ( @@ -91,15 +81,11 @@ def test_ltx2_generation(): ).clip(0, 255).astype(np.uint8) print("Saving video to ltx2_test.mp4...") - encode_video( - video_np, - fps=24, - audio=audio_waveform[0].float().cpu(), - audio_sample_rate=model.vocoder.config.output_sampling_rate, # 24000 Hz - output_path="ltx2_test.mp4", - ) + import imageio + with imageio.get_writer("ltx2_test.mp4", fps=24, codec="libx264", quality=8) as writer: + for frame in video_np: + writer.append_data(frame) print("Done!") - if __name__ == "__main__": test_ltx2_generation() \ No newline at end of file From 81976d7f94e2e379f630ab541ca78498c60c4ece Mon Sep 17 00:00:00 2001 From: linnan wang Date: Thu, 26 Feb 2026 16:47:53 +0800 Subject: [PATCH 10/13] udpate --- fastgen/networks/LTX2/network.py | 157 +++++++++---------------------- 1 file changed, 42 insertions(+), 115 deletions(-) diff --git a/fastgen/networks/LTX2/network.py b/fastgen/networks/LTX2/network.py index c995357..d0acd1d 100644 --- a/fastgen/networks/LTX2/network.py +++ b/fastgen/networks/LTX2/network.py @@ -928,71 +928,42 @@ def sample( fps: float = 24.0, frame_rate: Optional[float] = None, **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> 
Tuple[torch.Tensor, None]: """ - Run the full denoising loop for text-to-video+audio generation. - - Follows pipeline_ltx2.py exactly: - - latents kept in float32 throughout - - connectors called ONCE on combined [uncond, cond] batch - - audio duration derived from pixel-frame count, not latent frames - - transformer wrapped in cache_context("cond_uncond") + Run the full denoising loop for text-to-video generation (audio always None). Returns ------- - (video_latents, audio_latents): - video: [B, C, F, H, W] denormalised, ready for vae.decode() - audio: [B, C, L, M] denormalised, ready for audio_vae.decode() + (video_latents, None): + video: [B, C, F, H, W] denormalised video latents + audio: always None """ fps = frame_rate if frame_rate is not None else fps do_cfg = neg_condition is not None and guidance_scale > 1.0 - device = noise.device - B, C, latent_f, latent_h, latent_w = noise.shape - - # ---- Audio shape ---- - audio_num_frames, latent_mel_bins, num_audio_ch = self._compute_audio_shape( - latent_f, fps, device, torch.float32 - ) - num_mel_bins = self.audio_vae.config.mel_bins + transformer_dtype = self.transformer.dtype + transformer_device = next(self.transformer.parameters()).device - # ---- Pack latents (float32 throughout, matching pipeline) ---- - video_latents = _pack_latents( - noise.float(), self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ) - audio_latents = torch.randn( - B, num_audio_ch, audio_num_frames, latent_mel_bins, - device=device, dtype=torch.float32 - ) - audio_latents = _pack_audio_latents(audio_latents) + # Move latents to transformer device and dtype + video_latents = noise.to(device=transformer_device, dtype=transformer_dtype) - # ---- Text conditioning — connectors called ONCE on combined [uncond, cond] ---- - prompt_embeds, attention_mask = condition + # Build combined condition for CFG so connectors are called once per step if do_cfg: neg_embeds, neg_mask = neg_condition - combined_embeds = 
torch.cat([neg_embeds, prompt_embeds], dim=0) - combined_mask = torch.cat([neg_mask, attention_mask], dim=0) + cond_embeds, cond_mask = condition + combined_condition = ( + torch.cat([neg_embeds, cond_embeds], dim=0).to(device=transformer_device, dtype=transformer_dtype), + torch.cat([neg_mask, cond_mask], dim=0).to(device=transformer_device), + ) else: - combined_embeds = prompt_embeds - combined_mask = attention_mask - - additive_mask = (1 - combined_mask.to(combined_embeds.dtype)) * -1_000_000.0 - connector_video_embeds, connector_audio_embeds, connector_attn_mask = self.connectors( - combined_embeds, additive_mask, additive_mask=True - ) - - # ---- Pre-compute RoPE coordinates ---- - video_coords = self.transformer.rope.prepare_video_coords( - B, latent_f, latent_h, latent_w, device, fps=fps - ) - audio_coords = self.transformer.audio_rope.prepare_audio_coords( - B, audio_num_frames, device - ) - if do_cfg: - video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) - audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + embeds, mask = condition + combined_condition = ( + embeds.to(device=transformer_device, dtype=transformer_dtype), + mask.to(device=transformer_device), + ) # ---- Scheduler timesteps ---- + B, C, latent_f, latent_h, latent_w = video_latents.shape sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps) video_seq_len = latent_f * latent_h * latent_w mu = _calculate_shift( @@ -1002,79 +973,35 @@ def sample( self.scheduler.config.get("base_shift", 0.95), self.scheduler.config.get("max_shift", 2.05), ) - audio_scheduler = copy.deepcopy(self.scheduler) - _retrieve_timesteps(audio_scheduler, num_steps, device, sigmas=sigmas, mu=mu) - timesteps, num_steps = _retrieve_timesteps(self.scheduler, num_steps, device, sigmas=sigmas, mu=mu) - - prompt_dtype = connector_video_embeds.dtype - - # ---- Token counts after packing ---- - num_video_tokens = video_latents.shape[1] # [B, T_v, C] - num_audio_tokens = 
audio_latents.shape[1] # [B, T_a, C] + timesteps, num_steps = _retrieve_timesteps( + self.scheduler, num_steps, transformer_device, sigmas=sigmas, mu=mu + ) # ---- Denoising loop ---- for t in timesteps: - latent_input = torch.cat([video_latents] * 2) if do_cfg else video_latents - audio_latent_input = torch.cat([audio_latents] * 2) if do_cfg else audio_latents - latent_input = latent_input.to(prompt_dtype) - audio_latent_input = audio_latent_input.to(prompt_dtype) - - # Scale timestep and expand to per-token shape. - # The scheduler yields sigmas/timesteps that time_embed expects directly — - # LTX2AdaLayerNormSingle multiplies by timestep_scale_multiplier internally. - bs_input = latent_input.shape[0] - t_base = t.to(prompt_dtype).unsqueeze(0).expand(bs_input) # [B] - timestep = t_base.unsqueeze(1).expand(bs_input, num_video_tokens) # [B, T_v] - audio_timestep = t_base.unsqueeze(1).expand(bs_input, num_audio_tokens)# [B, T_a] - - with self.transformer.cache_context("cond_uncond"): - # classify_forward returns (video_output, audio_output) when - # feature_indices is empty and return_features_early is False. - # Note: no return_dict kwarg — classify_forward does not accept it. 
- model_out = self.transformer( - hidden_states=latent_input, - audio_hidden_states=audio_latent_input, - encoder_hidden_states=connector_video_embeds, - audio_encoder_hidden_states=connector_audio_embeds, - encoder_attention_mask=connector_attn_mask, - audio_encoder_attention_mask=connector_attn_mask, - timestep=timestep, - audio_timestep=audio_timestep, - num_frames=latent_f, - height=latent_h, - width=latent_w, - fps=fps, - audio_num_frames=audio_num_frames, - video_coords=video_coords, - audio_coords=audio_coords, - ) - noise_pred_video, noise_pred_audio = model_out - - noise_pred_video = noise_pred_video.float() - noise_pred_audio = noise_pred_audio.float() + latent_input = torch.cat([video_latents] * 2) if do_cfg else video_latents + t_input = t.to(dtype=transformer_dtype, device=transformer_device).expand(latent_input.shape[0]) + + noise_pred = self( + latent_input, + t_input, + condition=combined_condition, + fps=fps, + fwd_pred_type="flow", + ) if do_cfg: - video_uncond, video_cond = noise_pred_video.chunk(2) - noise_pred_video = video_uncond + guidance_scale * (video_cond - video_uncond) - audio_uncond, audio_cond = noise_pred_audio.chunk(2) - noise_pred_audio = audio_uncond + guidance_scale * (audio_cond - audio_uncond) + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) - video_latents = self.scheduler.step(noise_pred_video, t, video_latents, return_dict=False)[0] - audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + video_latents = self.scheduler.step(noise_pred, t, video_latents, return_dict=False)[0] - # ---- Unpack and denormalise ---- - video_latents = _unpack_latents( - video_latents, latent_f, latent_h, latent_w, - self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, - ) + # ---- Denormalise ---- video_latents = _denormalize_latents( - video_latents, self.vae.latents_mean, 
self.vae.latents_std, + video_latents, + self.vae.latents_mean, + self.vae.latents_std, self.vae.config.scaling_factor, ) - audio_latents = _denormalize_audio_latents( - audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std - ) - audio_latents = _unpack_audio_latents(audio_latents, audio_num_frames, latent_mel_bins) - - return video_latents, audio_latents \ No newline at end of file + return video_latents, None \ No newline at end of file From f72ce2b09dab90fe8850b5fe57968458cc189964 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Thu, 26 Feb 2026 18:05:43 +0800 Subject: [PATCH 11/13] now network.py should work --- fastgen/networks/LTX2/network.py | 616 +++++++--------------- fastgen/networks/LTX2/pipeline_ltx2.py | 2 +- fastgen/networks/LTX2/test_ltx_network.py | 93 ++++ fastgen/networks/LTX2/transformer_ltx2.py | 2 +- 4 files changed, 299 insertions(+), 414 deletions(-) create mode 100644 fastgen/networks/LTX2/test_ltx_network.py diff --git a/fastgen/networks/LTX2/network.py b/fastgen/networks/LTX2/network.py index d0acd1d..38b5d96 100644 --- a/fastgen/networks/LTX2/network.py +++ b/fastgen/networks/LTX2/network.py @@ -2,36 +2,37 @@ # SPDX-License-Identifier: Apache-2.0 """ -LTX-2 FastGen network implementation. +LTX-2 FastGen network implementation (video-only). -Architecture verified against: - - diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py - - diffusers/src/diffusers/pipelines/ltx2/connectors.py - - diffusers/src/diffusers/models/transformers/transformer_ltx2.py +Uses the local customized pipeline_ltx2.py and transformer_ltx2.py which +support audio_enabled=False, so no audio weights are allocated and no audio +ops run during training or inference. 
Follows the FastGen network pattern established by Flux and Wan: - Inherits from FastGenNetwork - Monkey-patches classify_forward onto self.transformer - - forward() handles video-only latent for distillation (audio flows through but is ignored for loss) - - feature_indices extracts video hidden_states only + - forward() operates entirely in video latent space [B, C, F, H, W] + - sample() calls self() (i.e. forward()) — NOT self.transformer directly + - feature_indices extracts video hidden_states for the discriminator """ -import copy import types from typing import Any, List, Optional, Set, Tuple, Union import numpy as np import torch import torch.nn as nn -from diffusers.models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video -from diffusers.models.transformers import LTX2VideoTransformer3DModel -from diffusers.models.transformers.transformer_ltx2 import LTX2VideoTransformerBlock -from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors -from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from torch.distributed.fsdp import fully_shard from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast +# ---- Local customized modules (same folder as this file) ---- +from .pipeline_ltx2 import LTX2Pipeline +from .transformer_ltx2 import LTX2VideoTransformer3DModel, LTX2VideoTransformerBlock + +from diffusers.models.autoencoders import AutoencoderKLLTX2Video +from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler + from fastgen.networks.network import FastGenNetwork from fastgen.networks.noise_schedule import NET_PRED_TYPES from fastgen.utils.distributed.fsdp import apply_fsdp_checkpointing @@ -39,41 +40,30 @@ # --------------------------------------------------------------------------- -# Helpers (mirrors of diffusers pipeline static methods) +# Latent pack / unpack helpers (video 
only) # --------------------------------------------------------------------------- def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: - """Pack video latents [B, C, F, H, W] → [B, F//pt * H//p * W//p, C*pt*p*p].""" + """[B, C, F, H, W] → [B, F//pt * H//p * W//p, C*pt*p*p]""" B, C, F, H, W = latents.shape - pF = F // patch_size_t - pH = H // patch_size - pW = W // patch_size - latents = latents.reshape(B, C, pF, patch_size_t, pH, patch_size, pW, patch_size) + latents = latents.reshape(B, C, F // patch_size_t, patch_size_t, + H // patch_size, patch_size, + W // patch_size, patch_size) latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) return latents def _unpack_latents( latents: torch.Tensor, num_frames: int, height: int, width: int, - patch_size: int = 1, patch_size_t: int = 1 + patch_size: int = 1, patch_size_t: int = 1, ) -> torch.Tensor: - """Unpack video latents [B, T, D] → [B, C, F, H, W].""" + """[B, T, D] → [B, C, F, H, W]""" B = latents.size(0) latents = latents.reshape(B, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) return latents -def _pack_audio_latents(latents: torch.Tensor) -> torch.Tensor: - """Pack audio latents [B, C, L, M] → [B, L, C*M].""" - return latents.transpose(1, 2).flatten(2, 3) - - -def _unpack_audio_latents(latents: torch.Tensor, latent_length: int, num_mel_bins: int) -> torch.Tensor: - """Unpack audio latents [B, L, C*M] -> [B, C, L, M].""" - return latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) - - def _normalize_latents( latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0, @@ -92,22 +82,6 @@ def _denormalize_latents( return latents * std / scaling_factor + mean -def _normalize_audio_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor -) -> 
torch.Tensor: - mean = latents_mean.to(latents.device, latents.dtype) - std = latents_std.to(latents.device, latents.dtype) - return (latents - mean) / std - - -def _denormalize_audio_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor -) -> torch.Tensor: - mean = latents_mean.to(latents.device, latents.dtype) - std = latents_std.to(latents.device, latents.dtype) - return latents * std + mean - - def _pack_text_embeds( text_hidden_states: torch.Tensor, sequence_lengths: torch.Tensor, @@ -116,30 +90,25 @@ def _pack_text_embeds( scale_factor: int = 8, eps: float = 1e-6, ) -> torch.Tensor: - """ - Stack all Gemma hidden-state layers, normalize per-batch/per-layer over - non-padded positions, and pack into [B, T, H * num_layers]. - """ B, T, H, L = text_hidden_states.shape original_dtype = text_hidden_states.dtype - token_indices = torch.arange(T, device=device).unsqueeze(0) # [1, T] + token_indices = torch.arange(T, device=device).unsqueeze(0) if padding_side == "right": mask = token_indices < sequence_lengths[:, None] - else: # left + else: start = T - sequence_lengths[:, None] mask = token_indices >= start - mask = mask[:, :, None, None] # [B, T, 1, 1] + mask = mask[:, :, None, None] masked = text_hidden_states.masked_fill(~mask, 0.0) num_valid = (sequence_lengths * H).view(B, 1, 1, 1) mean = masked.sum(dim=(1, 2), keepdim=True) / (num_valid + eps) - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) normed = (text_hidden_states - mean) / (x_max - x_min + eps) * scale_factor - normed = normed.flatten(2) # [B, T, H*L] + normed = normed.flatten(2) mask_flat = mask.squeeze(-1).expand(-1, -1, H * L) normed = normed.masked_fill(~mask_flat, 0.0) return normed.to(original_dtype) @@ -152,14 +121,12 @@ def _calculate_shift( base_shift: float = 0.95, max_shift: float = 2.05, ) -> float: - """Mirrors the 
pipeline's calculate_shift — defaults match LTX-2 scheduler config.""" m = (max_shift - base_shift) / (max_seq_len - base_seq_len) b = base_shift - m * base_seq_len return image_seq_len * m + b def _retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=None, mu=None): - """Call scheduler.set_timesteps, forwarding mu when dynamic shifting is enabled.""" kwargs = {} if mu is not None: kwargs["mu"] = mu @@ -171,56 +138,48 @@ def _retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=None, mu= # --------------------------------------------------------------------------- -# classify_forward — monkey-patched onto self.transformer +# classify_forward — monkey-patched onto self.transformer (video-only) # --------------------------------------------------------------------------- def classify_forward( self, hidden_states: torch.Tensor, - audio_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - audio_encoder_hidden_states: torch.Tensor, timestep: torch.Tensor, - audio_timestep: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - audio_encoder_attention_mask: Optional[torch.Tensor] = None, num_frames: Optional[int] = None, height: Optional[int] = None, width: Optional[int] = None, fps: float = 24.0, - audio_num_frames: Optional[int] = None, video_coords: Optional[torch.Tensor] = None, - audio_coords: Optional[torch.Tensor] = None, attention_kwargs: Optional[dict] = None, - return_dict: bool = True, # accepted for API compatibility; always ignored + return_dict: bool = False, # FastGen distillation kwargs return_features_early: bool = False, feature_indices: Optional[Set[int]] = None, return_logvar: bool = False, ) -> Union[ - Tuple[torch.Tensor, torch.Tensor], # (video_out, audio_out) - Tuple[Tuple[torch.Tensor, torch.Tensor], List[torch.Tensor]], # ((video_out, audio_out), features) - List[torch.Tensor], # features only (early exit) + torch.Tensor, # video_output only + Tuple[torch.Tensor, 
List[torch.Tensor]], # (video_output, features) + List[torch.Tensor], # features only (early exit) ]: """ - Drop-in replacement for LTX2VideoTransformer3DModel.forward that adds FastGen - distillation support (feature extraction, early exit, logvar). + Video-only classify_forward monkey-patched onto LTX2VideoTransformer3DModel. - Audio always flows through every block unchanged — we never short-circuit it — - but only video hidden_states are stored as features for the discriminator. + Since the transformer is built with audio_enabled=False, all audio arguments + are absent. Only video hidden_states are processed and stored as features. Returns ------- Normal mode (feature_indices empty, return_features_early False): - (video_output, audio_output) — identical to the original forward + video_output [B, T_v, C_out] Feature mode (feature_indices non-empty, return_features_early False): - ((video_output, audio_output), List[video_feature_tensors]) + (video_output, List[video_feature_tensors]) Early-exit mode (return_features_early True): - List[video_feature_tensors] — forward stops as soon as all features collected + List[video_feature_tensors] """ - # LoRA scale handling — mirrors the @apply_lora_scale decorator in upstream from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() @@ -229,7 +188,7 @@ def classify_forward( lora_scale = 1.0 if USE_PEFT_BACKEND: scale_lora_layers(self, lora_scale) - print("calling classfiy forward:") + if feature_indices is None: feature_indices = set() @@ -238,193 +197,107 @@ def classify_forward( unscale_lora_layers(self, lora_scale) return [] - # ------------------------------------------------------------------ # - # Steps 1-4: identical to the original forward (no changes) - # ------------------------------------------------------------------ # - audio_timestep = audio_timestep if audio_timestep is not None else timestep - - # 
Convert attention masks to additive bias form + # -- Attention mask conversion -- if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2: - audio_encoder_attention_mask = ( - 1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype) - ) * -10000.0 - audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) - batch_size = hidden_states.size(0) - # 1. RoPE positional embeddings + # 1. RoPE if video_coords is None: video_coords = self.rope.prepare_video_coords( batch_size, num_frames, height, width, hidden_states.device, fps=fps ) - if audio_coords is None: - audio_coords = self.audio_rope.prepare_audio_coords( - batch_size, audio_num_frames, audio_hidden_states.device - ) - video_rotary_emb = self.rope(video_coords, device=hidden_states.device) - audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) - - video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) - audio_cross_attn_rotary_emb = self.cross_attn_audio_rope( - audio_coords[:, 0:1, :], device=audio_hidden_states.device - ) - # 2. Patchify input projections + # 2. Patchify hidden_states = self.proj_in(hidden_states) - audio_hidden_states = self.audio_proj_in(audio_hidden_states) - - # 3. Timestep embeddings and modulation parameters - timestep_cross_attn_gate_scale_factor = ( - self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier - ) + # 3. 
Timestep embeddings temb, embedded_timestep = self.time_embed( timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, ) temb = temb.view(batch_size, -1, temb.size(-1)) embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) - temb_audio, audio_embedded_timestep = self.audio_time_embed( - audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, - ) - temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) - audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) - - video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( - timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, - ) - video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( - timestep.flatten() * timestep_cross_attn_gate_scale_factor, - batch_size=batch_size, hidden_dtype=hidden_states.dtype, - ) - video_cross_attn_scale_shift = video_cross_attn_scale_shift.view( - batch_size, -1, video_cross_attn_scale_shift.shape[-1] - ) - video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view( - batch_size, -1, video_cross_attn_a2v_gate.shape[-1] - ) - - audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( - audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, - ) - audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( - audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, - batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, - ) - audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view( - batch_size, -1, audio_cross_attn_scale_shift.shape[-1] - ) - audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view( - batch_size, -1, audio_cross_attn_v2a_gate.shape[-1] - ) - # 4. 
Prompt embeddings encoder_hidden_states = self.caption_projection(encoder_hidden_states) encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) - audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) - audio_encoder_hidden_states = audio_encoder_hidden_states.view( - batch_size, -1, audio_hidden_states.size(-1) - ) - - # ------------------------------------------------------------------ # - # Step 5: Block loop with video-only feature extraction - # Audio always flows through every block — we never skip it. - # ------------------------------------------------------------------ # + # 5. Block loop — video only, with feature extraction features: List[torch.Tensor] = [] for idx, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states, audio_hidden_states = self._gradient_checkpointing_func( + hidden_states, _ = self._gradient_checkpointing_func( block, hidden_states, - audio_hidden_states, + None, # audio_hidden_states — not used encoder_hidden_states, - audio_encoder_hidden_states, + None, # audio_encoder_hidden_states — not used temb, - temb_audio, - video_cross_attn_scale_shift, - audio_cross_attn_scale_shift, - video_cross_attn_a2v_gate, - audio_cross_attn_v2a_gate, + None, # temb_audio + None, # video_cross_attn_scale_shift + None, # audio_cross_attn_scale_shift + None, # video_cross_attn_a2v_gate + None, # audio_cross_attn_v2a_gate video_rotary_emb, - audio_rotary_emb, - video_cross_attn_rotary_emb, - audio_cross_attn_rotary_emb, + None, # audio_rotary_emb + None, # video_cross_attn_rotary_emb + None, # audio_cross_attn_rotary_emb encoder_attention_mask, - audio_encoder_attention_mask, + None, # audio_encoder_attention_mask ) else: - hidden_states, audio_hidden_states = block( + hidden_states, _ = block( hidden_states=hidden_states, - audio_hidden_states=audio_hidden_states, + audio_hidden_states=None, 
encoder_hidden_states=encoder_hidden_states, - audio_encoder_hidden_states=audio_encoder_hidden_states, + audio_encoder_hidden_states=None, temb=temb, - temb_audio=temb_audio, - temb_ca_scale_shift=video_cross_attn_scale_shift, - temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, - temb_ca_gate=video_cross_attn_a2v_gate, - temb_ca_audio_gate=audio_cross_attn_v2a_gate, + temb_audio=None, + temb_ca_scale_shift=None, + temb_ca_audio_scale_shift=None, + temb_ca_gate=None, + temb_ca_audio_gate=None, video_rotary_emb=video_rotary_emb, - audio_rotary_emb=audio_rotary_emb, - ca_video_rotary_emb=video_cross_attn_rotary_emb, - ca_audio_rotary_emb=audio_cross_attn_rotary_emb, + audio_rotary_emb=None, + ca_video_rotary_emb=None, + ca_audio_rotary_emb=None, encoder_attention_mask=encoder_attention_mask, - audio_encoder_attention_mask=audio_encoder_attention_mask, + audio_encoder_attention_mask=None, + audio_enabled=False, ) - # Video-only feature extraction at requested block indices - # TODO: we only extract the video feature for now if idx in feature_indices: - features.append(hidden_states.clone()) # [B, T_v, D_v] — packed video tokens + features.append(hidden_states.clone()) - # Early exit once all requested features are collected if return_features_early and len(features) == len(feature_indices): + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) return features - # ------------------------------------------------------------------ # - # Step 6: Output layers (video + audio) — unchanged from original - # ------------------------------------------------------------------ # + # 6. 
Output layer scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] hidden_states = self.norm_out(hidden_states) hidden_states = hidden_states * (1 + scale) + shift video_output = self.proj_out(hidden_states) - audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] - audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] - audio_hidden_states = self.audio_norm_out(audio_hidden_states) - audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift - audio_output = self.audio_proj_out(audio_hidden_states) - - # ------------------------------------------------------------------ # - # Assemble output following FastGen convention - # ------------------------------------------------------------------ # - if return_features_early: - # Should have been caught above; guard for safety - assert len(features) == len(feature_indices), f"{len(features)} != {len(feature_indices)}" - return features - - # Logvar (optional — requires logvar_linear to be added to the transformer) + # -- Logvar (optional) -- logvar = None if return_logvar: assert hasattr(self, "logvar_linear"), ( - "logvar_linear is required when return_logvar=True. " - "It is added by LTX2.__init__." + "logvar_linear must exist on transformer. It is added by LTX2.__init__." 
) - # temb has shape [B, T_tokens, inner_dim]; take mean over tokens for a scalar logvar per sample logvar = self.logvar_linear(temb.mean(dim=1)) # [B, 1] + # -- Assemble output -- if len(feature_indices) == 0: - out = (video_output, audio_output) + out = video_output else: - out = [(video_output, audio_output), features] + out = (video_output, features) if USE_PEFT_BACKEND: unscale_lora_layers(self, lora_scale) @@ -439,12 +312,7 @@ def classify_forward( # --------------------------------------------------------------------------- class LTX2TextEncoder(nn.Module): - """ - Wraps Gemma3ForConditionalGeneration for LTX-2 text conditioning. - - Returns both the packed prompt embeddings AND the tokenizer attention mask, - which is required by LTX2TextConnectors. - """ + """Wraps Gemma3 text encoder for LTX-2 conditioning.""" def __init__(self, model_id: str): super().__init__() @@ -466,19 +334,10 @@ def encode( max_sequence_length: int = 1024, scale_factor: int = 8, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Encode text prompt(s) into packed Gemma hidden states. - - Returns - ------- - prompt_embeds : torch.Tensor [B, T, H * num_layers] - attention_mask : torch.Tensor [B, T] - """ if isinstance(prompt, str): prompt = [prompt] device = next(self.text_encoder.parameters()).device - text_inputs = self.tokenizer( prompt, padding="max_length", @@ -496,8 +355,6 @@ def encode( output_hidden_states=True, return_dict=True, ) - - # Stack all hidden states: [B, T, H, num_layers] hidden_states = torch.stack(outputs.hidden_states, dim=-1) sequence_lengths = attention_mask.sum(dim=-1) @@ -517,20 +374,22 @@ def to(self, *args, **kwargs): # --------------------------------------------------------------------------- -# Main LTX-2 network — follows FastGen pattern (Flux / Wan) +# Main LTX-2 network (video-only) # --------------------------------------------------------------------------- class LTX2(FastGenNetwork): """ - FastGen wrapper for LTX-2 audio-video generation. 
+ FastGen wrapper for LTX-2, video-only distillation. + + Uses the local customized transformer_ltx2.py (audio_enabled=False) and + pipeline_ltx2.py (generate_audio=False) so no audio weights are allocated + and no audio ops run at any point. Distillation targets video only: - forward() receives and returns video latents [B, C, F, H, W] - - Audio is generated internally but not used for the distillation loss - classify_forward extracts video hidden_states at requested block indices - - Component layout: - text_encoder → connectors → transformer (patched) → vae → audio_vae → vocoder + - sample() calls self() (forward()) — the pipeline is used only for its + helper utilities (latent prep, scheduler config) """ MODEL_ID = "Lightricks/LTX-2" @@ -544,17 +403,6 @@ def __init__( load_pretrained: bool = True, **model_kwargs, ): - """ - LTX-2 constructor. - - Args: - model_id: HuggingFace model ID or local path. Defaults to "Lightricks/LTX-2". - net_pred_type: Prediction type. Defaults to "flow" (flow matching). - schedule_type: Schedule type. Defaults to "rf" (rectified flow). - disable_grad_ckpt: Disable gradient checkpointing during training. - Set True when using FSDP to avoid memory access errors. - load_pretrained: Load pretrained weights. If False, initialises from config only. 
- """ super().__init__(net_pred_type=net_pred_type, schedule_type=schedule_type, **model_kwargs) self.model_id = model_id @@ -562,10 +410,9 @@ def __init__( self._initialize_network(model_id, load_pretrained) - # Monkey-patch classify_forward onto self.transformer (same pattern as Flux / Wan) + # Monkey-patch classify_forward (video-only version) self.transformer.forward = types.MethodType(classify_forward, self.transformer) - # Gradient checkpointing if disable_grad_ckpt: self.transformer.disable_gradient_checkpointing() else: @@ -578,91 +425,74 @@ def __init__( # ------------------------------------------------------------------ def _initialize_network(self, model_id: str, load_pretrained: bool) -> None: - """Initialize the transformer and supporting modules.""" in_meta_context = self._is_in_meta_context() - should_load_weights = load_pretrained and (not in_meta_context) + should_load_weights = load_pretrained and not in_meta_context + # -- Transformer: audio_enabled=False → no audio weights allocated -- if should_load_weights: - logger.info("Loading LTX-2 transformer from pretrained") - self.transformer: LTX2VideoTransformer3DModel = LTX2VideoTransformer3DModel.from_pretrained( + logger.info("Loading LTX-2 transformer from pretrained (audio_enabled=False)") + # Load pretrained AV checkpoint then drop audio keys via strict=False + av_transformer = LTX2VideoTransformer3DModel.from_pretrained( model_id, subfolder="transformer" ) + config = av_transformer.config + # Build video-only transformer + self.transformer = LTX2VideoTransformer3DModel.from_config(config, audio_enabled=False) + missing, unexpected = self.transformer.load_state_dict( + av_transformer.state_dict(), strict=False + ) + assert len(missing) == 0, f"Missing video keys: {missing}" + logger.info(f"Dropped {len(unexpected)} audio keys from pretrained checkpoint") + del av_transformer else: config = LTX2VideoTransformer3DModel.load_config(model_id, subfolder="transformer") if in_meta_context: - 
logger.info( - "Initializing LTX-2 transformer on meta device " - "(zero memory, will receive weights via FSDP sync)" - ) + logger.info("Initializing LTX-2 transformer on meta device (audio_enabled=False)") else: - logger.info("Initializing LTX-2 transformer from config (no pretrained weights)") - logger.warning("LTX-2 transformer being initialized from config. No weights are loaded!") - self.transformer: LTX2VideoTransformer3DModel = LTX2VideoTransformer3DModel.from_config(config) + logger.warning("LTX-2 transformer initialized from config only — no pretrained weights!") + self.transformer = LTX2VideoTransformer3DModel.from_config(config, audio_enabled=False) - # inner_dim = num_attention_heads * attention_head_dim + # inner_dim for logvar_linear inner_dim = ( self.transformer.config.num_attention_heads * self.transformer.config.attention_head_dim ) - - # Add logvar_linear for uncertainty weighting (DMD2 / f-distill) - # temb mean has shape [B, inner_dim] → logvar scalar per sample self.transformer.logvar_linear = nn.Linear(inner_dim, 1) - logger.info(f"Added logvar_linear ({inner_dim} → 1) to LTX-2 transformer") + logger.info(f"Added logvar_linear ({inner_dim} → 1) to transformer") - # Connectors: top-level sibling of transformer (NOT nested inside it) + # -- Connectors -- if should_load_weights: self.connectors: LTX2TextConnectors = LTX2TextConnectors.from_pretrained( model_id, subfolder="connectors" ) else: - # Connectors are lightweight; always load if pretrained is skipped for the transformer - logger.warning("Skipping connector pretrained load (meta context or load_pretrained=False)") - self.connectors = None # will be loaded lazily via init_preprocessors + logger.warning("Skipping connector pretrained load") + self.connectors = None - # Cache compression ratios used by forward() and sample() + # -- VAE (video only — no audio_vae, no vocoder) -- if should_load_weights: - # VAEs (needed for sample(); not for the training forward pass) self.vae: 
AutoencoderKLLTX2Video = AutoencoderKLLTX2Video.from_pretrained( model_id, subfolder="vae" ) self.vae.eval().requires_grad_(False) - - self.audio_vae: AutoencoderKLLTX2Audio = AutoencoderKLLTX2Audio.from_pretrained( - model_id, subfolder="audio_vae" - ) - self.audio_vae.eval().requires_grad_(False) - - self.vocoder: LTX2Vocoder = LTX2Vocoder.from_pretrained( - model_id, subfolder="vocoder" - ) - self.vocoder.eval().requires_grad_(False) - self._cache_vae_constants() - # Scheduler (used in sample()) + # -- Scheduler -- self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( model_id, subfolder="scheduler" ) def _cache_vae_constants(self) -> None: - """Cache VAE spatial/temporal compression constants for use in forward() / sample().""" - self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio - self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio + self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio + self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio self.transformer_spatial_patch_size = self.transformer.config.patch_size self.transformer_temporal_patch_size = self.transformer.config.patch_size_t - self.audio_sampling_rate = self.audio_vae.config.sample_rate - self.audio_hop_length = self.audio_vae.config.mel_hop_length - self.audio_vae_temporal_compression = self.audio_vae.temporal_compression_ratio - self.audio_vae_mel_compression_ratio = self.audio_vae.mel_compression_ratio - # ------------------------------------------------------------------ - # Preprocessor initialisation (lazy, matches Flux / Wan pattern) + # Preprocessors (lazy) # ------------------------------------------------------------------ def init_preprocessors(self): - """Initialize text encoder and connectors.""" if not hasattr(self, "text_encoder") or self.text_encoder is None: self.init_text_encoder() if self.connectors is None: @@ -671,7 +501,6 @@ def init_preprocessors(self): ) def init_text_encoder(self): 
- """Initialize the Gemma3 text encoder for LTX-2.""" self.text_encoder = LTX2TextEncoder(model_id=self.model_id) # ------------------------------------------------------------------ @@ -680,16 +509,10 @@ def init_text_encoder(self): def to(self, *args, **kwargs): super().to(*args, **kwargs) - if hasattr(self, "text_encoder") and self.text_encoder is not None: - self.text_encoder.to(*args, **kwargs) - if hasattr(self, "connectors") and self.connectors is not None: - self.connectors.to(*args, **kwargs) - if hasattr(self, "vae") and self.vae is not None: - self.vae.to(*args, **kwargs) - if hasattr(self, "audio_vae") and self.audio_vae is not None: - self.audio_vae.to(*args, **kwargs) - if hasattr(self, "vocoder") and self.vocoder is not None: - self.vocoder.to(*args, **kwargs) + for attr in ("text_encoder", "connectors", "vae"): + obj = getattr(self, attr, None) + if obj is not None: + obj.to(*args, **kwargs) return self # ------------------------------------------------------------------ @@ -697,28 +520,23 @@ def to(self, *args, **kwargs): # ------------------------------------------------------------------ def fully_shard(self, **kwargs): - """Fully shard the LTX-2 transformer for FSDP2. - - Shards self.transformer (not self) to avoid ABC __class__ assignment issues. 
- """ if self.transformer.gradient_checkpointing: self.transformer.disable_gradient_checkpointing() apply_fsdp_checkpointing( self.transformer, - check_fn=lambda block: isinstance(block, LTX2VideoTransformerBlock), + check_fn=lambda b: isinstance(b, LTX2VideoTransformerBlock), ) - logger.info("Applied FSDP activation checkpointing to LTX-2 transformer blocks") + logger.info("Applied FSDP activation checkpointing to transformer blocks") for block in self.transformer.transformer_blocks: fully_shard(block, **kwargs) fully_shard(self.transformer, **kwargs) # ------------------------------------------------------------------ - # reset_parameters (required for FSDP meta device init) + # reset_parameters (FSDP meta device) # ------------------------------------------------------------------ def reset_parameters(self): - """Reinitialise parameters after meta device materialisation (FSDP2).""" for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) @@ -726,35 +544,7 @@ def reset_parameters(self): nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, std=0.02) - super().reset_parameters() - logger.debug("Reinitialized LTX-2 parameters") - - # ------------------------------------------------------------------ - # Audio latent sizing helper (shared by forward and sample) - # ------------------------------------------------------------------ - - def _compute_audio_shape( - self, latent_f: int, fps: float, device: torch.device, dtype: torch.dtype - ) -> Tuple[int, int, int]: - """ - Compute audio latent dimensions from video latent frame count. - - Returns (audio_num_frames, latent_mel_bins, num_audio_ch). 
- """ - pixel_frames = (latent_f - 1) * self.vae_temporal_compression_ratio + 1 - duration_s = pixel_frames / fps - - audio_latents_per_second = ( - self.audio_sampling_rate - / self.audio_hop_length - / float(self.audio_vae_temporal_compression) - ) - audio_num_frames = round(duration_s * audio_latents_per_second) - num_mel_bins = self.audio_vae.config.mel_bins - latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_audio_ch = self.audio_vae.config.latent_channels - return audio_num_frames, latent_mel_bins, num_audio_ch # ------------------------------------------------------------------ # forward() — video-only distillation interface @@ -765,37 +555,32 @@ def forward( x_t: torch.Tensor, t: torch.Tensor, condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - r: Optional[torch.Tensor] = None, # unused, kept for API compatibility + r: Optional[torch.Tensor] = None, fps: float = 24.0, return_features_early: bool = False, feature_indices: Optional[Set[int]] = None, return_logvar: bool = False, fwd_pred_type: Optional[str] = None, **fwd_kwargs, - ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - """Forward pass for distillation — video latents in, video latents out. - - Audio latents are generated as random noise internally so the joint - audio-video transformer runs normally, but only the video prediction is - returned and used for loss computation. + ) -> Union[torch.Tensor, List[torch.Tensor], Tuple]: + """ + Training forward pass: video latents [B, C, F, H, W] → video latents. Args: - x_t: Video latents [B, C, F, H, W]. - t: Timestep [B] in [0, 1]. - condition: Tuple of (prompt_embeds [B, T, D], attention_mask [B, T]) - from LTX2TextEncoder.encode(). - r: Unused (kept for FastGen API compatibility). - fps: Frames per second (needed for RoPE coordinate computation). - return_features_early: Return video features as soon as collected. 
- feature_indices: Set of transformer block indices to extract video features from. - return_logvar: Return log-variance estimate alongside the output. - fwd_pred_type: Override prediction type. + x_t: Noisy video latents [B, C, F, H, W]. + t: Timesteps [B]. + condition: (prompt_embeds [B, T, D], attention_mask [B, T]). + fps: Frames per second for RoPE coords. + return_features_early: Exit once all feature_indices are collected. + feature_indices: Block indices to extract hidden states from. + return_logvar: Return log-variance alongside output. + fwd_pred_type: Override prediction type. Returns: - Normal: video_out [B, C, F, H, W] - With features: (video_out, List[video_feature_tensors]) - Early exit: List[video_feature_tensors] - With logvar: (above, logvar [B, 1]) + Normal: video_out [B, C, F, H, W] + With features: (video_out, List[feature_tensors]) + Early exit: List[feature_tensors] + With logvar: (above, logvar [B, 1]) """ if feature_indices is None: feature_indices = set() @@ -805,106 +590,82 @@ def forward( if fwd_pred_type is None: fwd_pred_type = self.net_pred_type else: - assert fwd_pred_type in NET_PRED_TYPES, f"{fwd_pred_type} is not supported" + assert fwd_pred_type in NET_PRED_TYPES, f"Unsupported pred type: {fwd_pred_type}" batch_size = x_t.shape[0] _, _, latent_f, latent_h, latent_w = x_t.shape - # Unpack text conditioning + # -- Text conditioning -- prompt_embeds, attention_mask = condition - - # ---- Run connectors to get per-modality encoder hidden states ---- - # attention_mask from tokenizer is binary [B, T]; convert to additive bias additive_mask = (1 - attention_mask.to(prompt_embeds.dtype)) * -1_000_000.0 - connector_video_embeds, connector_audio_embeds, connector_attn_mask = self.connectors( + # Connectors return (video_embeds, audio_embeds, attn_mask); + # we only use the video branch since audio_enabled=False. 
+ connector_video_embeds, _, connector_attn_mask = self.connectors( prompt_embeds, additive_mask, additive_mask=True ) - # ---- Timestep: [B] scalar per sample, matching pipeline_ltx2.py ---- - # time_embed calls .flatten() then views back to [B, 1, D] internally. - # Do NOT expand to per-token here — that is handled inside time_embed. - timestep = t.to(x_t.dtype).expand(batch_size) # [B] + # -- Timestep -- + timestep = t.to(x_t.dtype).expand(batch_size) - # ---- Pack video latents ---- + # -- Pack video latents -- hidden_states = _pack_latents( x_t, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size ) - # ---- Audio latents: random noise (not trained, just needed to run the joint transformer) ---- - audio_num_frames, latent_mel_bins, num_audio_ch = self._compute_audio_shape( - latent_f, fps, x_t.device, x_t.dtype - ) - audio_latents = torch.randn( - batch_size, num_audio_ch, audio_num_frames, latent_mel_bins, - device=x_t.device, dtype=x_t.dtype, - ) - audio_hidden_states = _pack_audio_latents(audio_latents) - - # ---- RoPE coordinates (pre-computed once, reused in transformer) ---- + # -- RoPE video coords -- video_coords = self.transformer.rope.prepare_video_coords( batch_size, latent_f, latent_h, latent_w, x_t.device, fps=fps ) - audio_coords = self.transformer.audio_rope.prepare_audio_coords( - batch_size, audio_num_frames, x_t.device - ) - # ---- Transformer forward (our patched classify_forward) ---- + # -- Transformer forward (our patched classify_forward, video-only) -- model_outputs = self.transformer( hidden_states=hidden_states, - audio_hidden_states=audio_hidden_states, encoder_hidden_states=connector_video_embeds, - audio_encoder_hidden_states=connector_audio_embeds, encoder_attention_mask=connector_attn_mask, - audio_encoder_attention_mask=connector_attn_mask, timestep=timestep, num_frames=latent_f, height=latent_h, width=latent_w, fps=fps, - audio_num_frames=audio_num_frames, video_coords=video_coords, - 
audio_coords=audio_coords, return_features_early=return_features_early, feature_indices=feature_indices, return_logvar=return_logvar, ) - # ---- Early exit: list of video feature tensors ---- + # -- Early exit -- if return_features_early: - return model_outputs # List[Tensor], each [B, T_v, D_v] + return model_outputs # List[Tensor] - # ---- Unpack logvar if requested ---- + # -- Unpack logvar -- if return_logvar: out, logvar = model_outputs[0], model_outputs[1] else: out = model_outputs - # ---- Extract video prediction only; discard audio ---- + # -- Separate video output from features -- if len(feature_indices) == 0: - # out is (video_output, audio_output) - video_packed = out[0] # [B, T_v, C_packed] + video_packed = out # [B, T_v, C] features = None else: - # out is [(video_output, audio_output), features] - video_packed = out[0][0] # [B, T_v, C_packed] - features = out[1] # List[Tensor] + video_packed, features = out[0], out[1] - # ---- Unpack video tokens → [B, C, F, H, W] ---- + # -- Unpack → [B, C, F, H, W] -- video_out = _unpack_latents( video_packed, latent_f, latent_h, latent_w, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, ) - # ---- Convert model output to requested prediction type ---- + # -- Prediction type conversion -- video_out = self.noise_scheduler.convert_model_output( x_t, video_out, t, src_pred_type=self.net_pred_type, target_pred_type=fwd_pred_type, ) - # ---- Re-pack output following FastGen convention ---- + # -- Assemble final output -- if features is not None: - out = [video_out, features] + out = (video_out, features) else: out = video_out @@ -913,8 +674,9 @@ def forward( return out # ------------------------------------------------------------------ - # sample() — full denoising loop for inference - # Follows pipeline_ltx2.py exactly (verified working logic preserved) + # sample() — full denoising loop; calls self() (forward()) + # Works entirely in unpacked [B, C, F, H, W] space so forward() is + # called 
with the correct input shape at every step. # ------------------------------------------------------------------ @torch.no_grad() @@ -930,13 +692,14 @@ def sample( **kwargs, ) -> Tuple[torch.Tensor, None]: """ - Run the full denoising loop for text-to-video generation (audio always None). + Denoising loop for video generation. Calls self() at each step. + + noise must be unpacked video latents [B, C, F, H, W]. Returns ------- - (video_latents, None): - video: [B, C, F, H, W] denormalised video latents - audio: always None + (video_latents [B, C, F, H, W], None) + Denormalised video latents ready for VAE decode. Audio is always None. """ fps = frame_rate if frame_rate is not None else fps do_cfg = neg_condition is not None and guidance_scale > 1.0 @@ -944,15 +707,18 @@ def sample( transformer_dtype = self.transformer.dtype transformer_device = next(self.transformer.parameters()).device - # Move latents to transformer device and dtype + # noise must arrive as [B, C, F, H, W] (unpacked) + assert noise.ndim == 5, "sample() expects unpacked latents [B, C, F, H, W]" video_latents = noise.to(device=transformer_device, dtype=transformer_dtype) - # Build combined condition for CFG so connectors are called once per step + # -- Build combined CFG condition (processed once, reused every step) -- if do_cfg: neg_embeds, neg_mask = neg_condition cond_embeds, cond_mask = condition combined_condition = ( - torch.cat([neg_embeds, cond_embeds], dim=0).to(device=transformer_device, dtype=transformer_dtype), + torch.cat([neg_embeds, cond_embeds], dim=0).to( + device=transformer_device, dtype=transformer_dtype + ), torch.cat([neg_mask, cond_mask], dim=0).to(device=transformer_device), ) else: @@ -962,7 +728,7 @@ def sample( mask.to(device=transformer_device), ) - # ---- Scheduler timesteps ---- + # -- Scheduler timesteps with mu shift -- B, C, latent_f, latent_h, latent_w = video_latents.shape sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps) video_seq_len = latent_f * latent_h * 
latent_w @@ -977,11 +743,16 @@ def sample( self.scheduler, num_steps, transformer_device, sigmas=sigmas, mu=mu ) - # ---- Denoising loop ---- + # -- Denoising loop (unpacked latents throughout) -- for t in timesteps: + # Duplicate along batch for CFG latent_input = torch.cat([video_latents] * 2) if do_cfg else video_latents - t_input = t.to(dtype=transformer_dtype, device=transformer_device).expand(latent_input.shape[0]) + t_input = ( + t.to(dtype=transformer_dtype, device=transformer_device) + .expand(latent_input.shape[0]) + ) + # self() → forward() — expects and returns [B, C, F, H, W] noise_pred = self( latent_input, t_input, @@ -992,11 +763,32 @@ def sample( if do_cfg: noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) - video_latents = self.scheduler.step(noise_pred, t, video_latents, return_dict=False)[0] + # Scheduler step — operates on packed tokens internally but we keep + # video_latents unpacked; use the pipeline's _pack/_unpack helpers. 
+ video_packed = _pack_latents( + video_latents, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + noise_pred_packed = _pack_latents( + noise_pred, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + stepped_packed = self.scheduler.step( + noise_pred_packed, t, video_packed, return_dict=False + )[0] + video_latents = _unpack_latents( + stepped_packed, latent_f, latent_h, latent_w, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) - # ---- Denormalise ---- + # -- Denormalise -- video_latents = _denormalize_latents( video_latents, self.vae.latents_mean, @@ -1004,4 +796,4 @@ def sample( self.vae.config.scaling_factor, ) - return video_latents, None \ No newline at end of file + return video_latents, None diff --git a/fastgen/networks/LTX2/pipeline_ltx2.py b/fastgen/networks/LTX2/pipeline_ltx2.py index 9afa887..e644c64 100644 --- a/fastgen/networks/LTX2/pipeline_ltx2.py +++ b/fastgen/networks/LTX2/pipeline_ltx2.py @@ -1080,4 +1080,4 @@ def __call__( if not return_dict: return (video, audio) - return LTX2PipelineOutput(frames=video, audio=audio) \ No newline at end of file + return LTX2PipelineOutput(frames=video, audio=audio) diff --git a/fastgen/networks/LTX2/test_ltx_network.py b/fastgen/networks/LTX2/test_ltx_network.py new file mode 100644 index 0000000..5528b26 --- /dev/null +++ b/fastgen/networks/LTX2/test_ltx_network.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from diffusers.pipelines.ltx2.export_utils import encode_video +from fastgen.networks.LTX2.network import LTX2 + + +def test_ltx2_generation(): + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 # LTX-2 is optimized for bfloat16 + + # 1. 
Initialize the LTX2 model + print("Initializing LTX-2 model...") + model = LTX2(model_id="Lightricks/LTX-2", load_pretrained=True).to(device, dtype=dtype) + model.init_preprocessors() + model.eval() + + # 2. Prepare Prompts + # prompt = "A high-performance sports car racing through a city street at night, neon lights reflecting off wet asphalt, motion blur streaking past the camera. The camera tracks low and close to the car as it accelerates aggressively, tires gripping the road, exhaust heat shimmering. Realistic lighting, cinematic depth of field, ultra-detailed textures, dynamic reflections, dramatic shadows, 4K realism, film-grade color grading." + prompt = "A tight close-up shot of a musician's hands playing a grand piano. The audio is a fast-paced, uplifting classical piano sonata. The pressing of the black and white keys visually syncs with the rapid flurry of high-pitched musical notes. There is a slight echo, suggesting the piano is in a large, empty concert hall." + prompt = "A street performer sitting on a brick stoop, strumming an acoustic guitar. The audio features a warm, indie-folk guitar chord progression. Over the guitar, a smooth, soulful human voice sings a slow, bluesy melody without words, just melodic humming and 'oohs'. The rhythmic strumming of the guitar perfectly matches the tempo of the vocal melody. Faint city traffic can be heard quietly in the deep background." + negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + print(f"Encoding prompt: {prompt}") + # FIX 1: encode() returns (embeds, mask) tuple — unpack and move each tensor separately + embeds, mask = model.text_encoder.encode(prompt, precision=dtype) + condition = (embeds.to(device), mask.to(device)) + neg_embeds, neg_mask = model.text_encoder.encode(negative_prompt, precision=dtype) + neg_condition = (neg_embeds.to(device), neg_mask.to(device)) + + # 3. 
Define Video Parameters + # LTX-2 video dimensions must be divisible by 32 spatially and (8n+1) temporally + height, width = 480, 704 + num_frames = 81 # (8 * 10) + 1 + batch_size = 1 + + # Calculate latent dimensions + latent_f = (num_frames - 1) // model.vae_temporal_compression_ratio + 1 + latent_h = height // model.vae_spatial_compression_ratio + latent_w = width // model.vae_spatial_compression_ratio + + # 4. Generate Initial Noise + # FIX 2: model.vae IS AutoencoderKLLTX2Video directly — no .vae sub-attribute. + # Use transformer.config.in_channels as the pipeline does (line 989). + latent_channels = model.transformer.config.in_channels # 128 + + # FIX 3: noise must be float32 — the pipeline creates latents in torch.float32 + # (prepare_latents line 997) and keeps them float32 through the scheduler. + # Only cast to bfloat16 right before the transformer call. + noise = torch.randn( + batch_size, latent_channels, latent_f, latent_h, latent_w, + device=device, dtype=torch.float32 + ) + + # 5. Run Sampling (Inference) + print("Starting sampling process...") + with torch.no_grad(): + latents, audio_latents = model.sample( + noise=noise, + condition=condition, + neg_condition=neg_condition, + guidance_scale=4.0, + num_steps=40, + fps=24.0, + ) + # latents: [B, C, F, H, W] denormalised video latents (float32) + # audio_latents: [B, C, L, M] denormalised audio latents (float32) + + # 6. Decode Latents to Video + print("Decoding latents to video...") + with torch.no_grad(): + video_tensor = model.vae.decode(latents.to(model.vae.dtype)) + # model.vae.decode(latents.to(model.vae.dtype), return_dict=False)[0] + # video_tensor: [B, C, F, H, W] in ~[-1, 1] + + # 7. 
Post-process and Save + # Convert [B, C, F, H, W] -> [F, H, W, C] uint8 + video_np = ( + (video_tensor[0][0].cpu().float().permute(1, 2, 3, 0).numpy() + 1.0) * 127.5 + ).clip(0, 255).astype(np.uint8) + + + print("Saving video to ltx2_test.mp4...") + import imageio + with imageio.get_writer("ltx2_test.mp4", fps=24, codec="libx264", quality=8) as writer: + for frame in video_np: + writer.append_data(frame) + print("Done!") + +if __name__ == "__main__": + test_ltx2_generation() \ No newline at end of file diff --git a/fastgen/networks/LTX2/transformer_ltx2.py b/fastgen/networks/LTX2/transformer_ltx2.py index 9d331b1..1d66730 100644 --- a/fastgen/networks/LTX2/transformer_ltx2.py +++ b/fastgen/networks/LTX2/transformer_ltx2.py @@ -1200,4 +1200,4 @@ def forward( if not return_dict: return (output, audio_output) - return AudioVisualModelOutput(sample=output, audio_sample=audio_output) \ No newline at end of file + return AudioVisualModelOutput(sample=output, audio_sample=audio_output) From 756ac68fd328d18ed95e4091d0b2313e01a9d254 Mon Sep 17 00:00:00 2001 From: linnan wang Date: Thu, 26 Feb 2026 18:08:09 +0800 Subject: [PATCH 12/13] revert the changes to flux --- fastgen/networks/Flux/network.py | 1384 +++++++++++++----------------- 1 file changed, 578 insertions(+), 806 deletions(-) diff --git a/fastgen/networks/Flux/network.py b/fastgen/networks/Flux/network.py index 9a627a4..7244955 100644 --- a/fastgen/networks/Flux/network.py +++ b/fastgen/networks/Flux/network.py @@ -1,538 +1,336 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -""" -LTX-2 FastGen network implementation. 
- -Architecture verified against: - - diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py - - diffusers/src/diffusers/pipelines/ltx2/connectors.py - - diffusers/src/diffusers/models/transformers/transformer_ltx2.py - -Follows the FastGen network pattern established by Flux and Wan: - - Inherits from FastGenNetwork - - Monkey-patches classify_forward onto self.transformer - - forward() handles video-only latent for distillation (audio flows through but is ignored for loss) - - feature_indices extracts video hidden_states only -""" - -import copy +import os +from typing import Any, Optional, List, Set, Union, Tuple import types -from typing import Any, List, Optional, Set, Tuple, Union -import numpy as np import torch -import torch.nn as nn -from diffusers.models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video -from diffusers.models.transformers import LTX2VideoTransformer3DModel -from diffusers.models.transformers.transformer_ltx2 import LTX2VideoTransformerBlock -from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors -from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder -from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +import torch.utils.checkpoint +from torch import dtype from torch.distributed.fsdp import fully_shard -from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast + +from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL +from diffusers.models import FluxTransformer2DModel +from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from fastgen.networks.network import FastGenNetwork from fastgen.networks.noise_schedule import NET_PRED_TYPES +from fastgen.utils.basic_utils import str2bool from fastgen.utils.distributed.fsdp import apply_fsdp_checkpointing import fastgen.utils.logging_utils as logger -# 
--------------------------------------------------------------------------- -# Helpers (mirrors of diffusers pipeline static methods) -# --------------------------------------------------------------------------- - -def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: - """Pack video latents [B, C, F, H, W] → [B, F//pt * H//p * W//p, C*pt*p*p].""" - B, C, F, H, W = latents.shape - pF = F // patch_size_t - pH = H // patch_size - pW = W // patch_size - latents = latents.reshape(B, C, pF, patch_size_t, pH, patch_size, pW, patch_size) - latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) - return latents - - -def _unpack_latents( - latents: torch.Tensor, num_frames: int, height: int, width: int, - patch_size: int = 1, patch_size_t: int = 1 -) -> torch.Tensor: - """Unpack video latents [B, T, D] → [B, C, F, H, W].""" - B = latents.size(0) - latents = latents.reshape(B, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) - latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) - return latents - - -def _pack_audio_latents(latents: torch.Tensor) -> torch.Tensor: - """Pack audio latents [B, C, L, M] → [B, L, C*M].""" - return latents.transpose(1, 2).flatten(2, 3) - - -def _unpack_audio_latents(latents: torch.Tensor, latent_length: int, num_mel_bins: int) -> torch.Tensor: - """Unpack audio latents [B, L, C*M] -> [B, C, L, M].""" - return latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) - - -def _normalize_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, - scaling_factor: float = 1.0, -) -> torch.Tensor: - mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - return (latents - mean) * scaling_factor / std - - -def _denormalize_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: 
torch.Tensor, - scaling_factor: float = 1.0, -) -> torch.Tensor: - mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) - return latents * std / scaling_factor + mean - - -def _normalize_audio_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor -) -> torch.Tensor: - mean = latents_mean.to(latents.device, latents.dtype) - std = latents_std.to(latents.device, latents.dtype) - return (latents - mean) / std - - -def _denormalize_audio_latents( - latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor -) -> torch.Tensor: - mean = latents_mean.to(latents.device, latents.dtype) - std = latents_std.to(latents.device, latents.dtype) - return latents * std + mean - - -def _pack_text_embeds( - text_hidden_states: torch.Tensor, - sequence_lengths: torch.Tensor, - device: torch.device, - padding_side: str = "left", - scale_factor: int = 8, - eps: float = 1e-6, -) -> torch.Tensor: - """ - Stack all Gemma hidden-state layers, normalize per-batch/per-layer over - non-padded positions, and pack into [B, T, H * num_layers]. 
+class FluxTextEncoder: + """Text encoder for Flux using CLIP and T5 models.""" + + def __init__(self, model_id: str): + # CLIP text encoder + self.tokenizer = CLIPTokenizer.from_pretrained( + model_id, + cache_dir=os.environ["HF_HOME"], + subfolder="tokenizer", + local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + ) + self.text_encoder = CLIPTextModel.from_pretrained( + model_id, + cache_dir=os.environ["HF_HOME"], + subfolder="text_encoder", + local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + ) + self.text_encoder.eval().requires_grad_(False) + + # T5 text encoder + self.tokenizer_2 = T5TokenizerFast.from_pretrained( + model_id, + cache_dir=os.environ["HF_HOME"], + subfolder="tokenizer_2", + local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + ) + self.text_encoder_2 = T5EncoderModel.from_pretrained( + model_id, + cache_dir=os.environ["HF_HOME"], + subfolder="text_encoder_2", + local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + ) + self.text_encoder_2.eval().requires_grad_(False) + + def encode( + self, + conditioning: Optional[Any] = None, + precision: dtype = torch.float32, + max_sequence_length: int = 512, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode text prompts to embeddings. + + Args: + conditioning: Text prompt(s) to encode. + precision: Data type for the output embeddings. + max_sequence_length: Maximum sequence length for T5 tokenization. + + Returns: + Tuple of (pooled_prompt_embeds, prompt_embeds) tensors. 
+ """ + if isinstance(conditioning, str): + conditioning = [conditioning] + + # CLIP encoding for pooled embeddings + text_inputs = self.tokenizer( + conditioning, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + with torch.no_grad(): + text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) + prompt_embeds = self.text_encoder( + text_input_ids, + output_hidden_states=False, + ) + pooled_prompt_embeds = prompt_embeds.pooler_output.to(precision) + + # T5 encoding for text embeddings + text_inputs_2 = self.tokenizer_2( + conditioning, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + with torch.no_grad(): + text_input_ids_2 = text_inputs_2.input_ids.to(self.text_encoder_2.device) + prompt_embeds_2 = self.text_encoder_2( + text_input_ids_2, + output_hidden_states=False, + )[0].to(precision) + + return pooled_prompt_embeds, prompt_embeds_2 + + def to(self, *args, **kwargs): + """Moves the model to the specified device.""" + self.text_encoder.to(*args, **kwargs) + self.text_encoder_2.to(*args, **kwargs) + return self + + +class FluxImageEncoder: + """VAE encoder/decoder for Flux. + + Flux VAE uses both scaling_factor and shift_factor for latent normalization. 
""" - B, T, H, L = text_hidden_states.shape - original_dtype = text_hidden_states.dtype - - token_indices = torch.arange(T, device=device).unsqueeze(0) # [1, T] - if padding_side == "right": - mask = token_indices < sequence_lengths[:, None] - else: # left - start = T - sequence_lengths[:, None] - mask = token_indices >= start - mask = mask[:, :, None, None] # [B, T, 1, 1] - - masked = text_hidden_states.masked_fill(~mask, 0.0) - num_valid = (sequence_lengths * H).view(B, 1, 1, 1) - mean = masked.sum(dim=(1, 2), keepdim=True) / (num_valid + eps) - - x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) - x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) - - normed = (text_hidden_states - mean) / (x_max - x_min + eps) * scale_factor - normed = normed.flatten(2) # [B, T, H*L] - mask_flat = mask.squeeze(-1).expand(-1, -1, H * L) - normed = normed.masked_fill(~mask_flat, 0.0) - return normed.to(original_dtype) - - -def _calculate_shift( - image_seq_len: int, - base_seq_len: int = 1024, - max_seq_len: int = 4096, - base_shift: float = 0.95, - max_shift: float = 2.05, -) -> float: - """Mirrors the pipeline's calculate_shift — defaults match LTX-2 scheduler config.""" - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - return image_seq_len * m + b - - -def _retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=None, mu=None): - """Call scheduler.set_timesteps, forwarding mu when dynamic shifting is enabled.""" - kwargs = {} - if mu is not None: - kwargs["mu"] = mu - if sigmas is not None: - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - return scheduler.timesteps, len(scheduler.timesteps) + def __init__(self, model_id: str): + self.vae: AutoencoderKL = AutoencoderKL.from_pretrained( + model_id, + cache_dir=os.environ["HF_HOME"], + 
subfolder="vae", + local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + ) + self.vae.eval().requires_grad_(False) + + # Flux VAE uses shift_factor in addition to scaling_factor + self.scaling_factor = getattr(self.vae.config, "scaling_factor", 0.3611) + self.shift_factor = getattr(self.vae.config, "shift_factor", 0.1159) + + def encode(self, real_images: torch.Tensor) -> torch.Tensor: + """Encode images to latent space. + + Args: + real_images: Input images in [-1, 1] range. + + Returns: + torch.Tensor: Latent representations (shifted and scaled). + """ + latent_images = self.vae.encode(real_images, return_dict=False)[0].sample() + # Apply Flux-specific shift and scale + latent_images = (latent_images - self.shift_factor) * self.scaling_factor + return latent_images + + def decode(self, latent_images: torch.Tensor) -> torch.Tensor: + """Decode latents to images. + + Args: + latent_images: Latent representations (shifted and scaled). + + Returns: + torch.Tensor: Decoded images in [-1, 1] range. 
+ """ + # Reverse Flux-specific shift and scale + latents = (latent_images / self.scaling_factor) + self.shift_factor + images = self.vae.decode(latents, return_dict=False)[0].clip(-1.0, 1.0) + return images + + def to(self, *args, **kwargs): + """Moves the model to the specified device.""" + self.vae.to(*args, **kwargs) + return self -# --------------------------------------------------------------------------- -# classify_forward — monkey-patched onto self.transformer -# --------------------------------------------------------------------------- def classify_forward( self, hidden_states: torch.Tensor, - audio_hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - audio_encoder_hidden_states: torch.Tensor, - timestep: torch.Tensor, - audio_timestep: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - audio_encoder_attention_mask: Optional[torch.Tensor] = None, - num_frames: Optional[int] = None, - height: Optional[int] = None, - width: Optional[int] = None, - fps: float = 24.0, - audio_num_frames: Optional[int] = None, - video_coords: Optional[torch.Tensor] = None, - audio_coords: Optional[torch.Tensor] = None, - attention_kwargs: Optional[dict] = None, - return_dict: bool = True, # accepted for API compatibility; always ignored - # FastGen distillation kwargs + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[dict] = None, return_features_early: bool = False, feature_indices: Optional[Set[int]] = None, return_logvar: bool = False, -) -> Union[ - Tuple[torch.Tensor, torch.Tensor], # (video_out, audio_out) - Tuple[Tuple[torch.Tensor, torch.Tensor], List[torch.Tensor]], # ((video_out, audio_out), features) - List[torch.Tensor], # features only (early exit) -]: +) -> Union[torch.Tensor, List[torch.Tensor], 
Tuple[torch.Tensor, torch.Tensor]]: """ - Drop-in replacement for LTX2VideoTransformer3DModel.forward that adds FastGen - distillation support (feature extraction, early exit, logvar). - - Audio always flows through every block unchanged — we never short-circuit it — - but only video hidden_states are stored as features for the discriminator. - - Returns - ------- - Normal mode (feature_indices empty, return_features_early False): - (video_output, audio_output) — identical to the original forward - - Feature mode (feature_indices non-empty, return_features_early False): - ((video_output, audio_output), List[video_feature_tensors]) - - Early-exit mode (return_features_early True): - List[video_feature_tensors] — forward stops as soon as all features collected + Modified forward pass for FluxTransformer2DModel with feature extraction support. + + Args: + hidden_states: Input latent states. + encoder_hidden_states: T5 text encoder hidden states. + pooled_projections: CLIP pooled text embeddings. + timestep: Current timestep. + img_ids: Image position IDs. + txt_ids: Text position IDs. + guidance: Guidance scale embedding. + joint_attention_kwargs: Additional attention kwargs. + return_features_early: If True, return features as soon as collected. + feature_indices: Set of block indices to extract features from. + return_logvar: If True, return log variance estimate. + + Returns: + Model output, optionally with features or logvar. 
""" - # LoRA scale handling — mirrors the @apply_lora_scale decorator in upstream - from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - if USE_PEFT_BACKEND: - scale_lora_layers(self, lora_scale) - if feature_indices is None: feature_indices = set() if return_features_early and len(feature_indices) == 0: - if USE_PEFT_BACKEND: - unscale_lora_layers(self, lora_scale) return [] - # ------------------------------------------------------------------ # - # Steps 1-4: identical to the original forward (no changes) - # ------------------------------------------------------------------ # - audio_timestep = audio_timestep if audio_timestep is not None else timestep + idx, features = 0, [] - # Convert attention masks to additive bias form - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # Store original sequence length to compute spatial dims for feature reshaping + # hidden_states: [B, seq_len, C*4] where seq_len = (H//2) * (W//2) + seq_len = hidden_states.shape[1] + spatial_size = int(seq_len**0.5) # Assuming square spatial dimensions - if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2: - audio_encoder_attention_mask = ( - 1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype) - ) * -10000.0 - audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1) + # 1. Patch embedding + hidden_states = self.x_embedder(hidden_states) - batch_size = hidden_states.size(0) + # 2. 
Time embedding + timestep_scaled = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance_scaled = guidance.to(hidden_states.dtype) * 1000 + temb = self.time_text_embed(timestep_scaled, guidance_scaled, pooled_projections) + else: + temb = self.time_text_embed(timestep_scaled, pooled_projections) - # 1. RoPE positional embeddings - if video_coords is None: - video_coords = self.rope.prepare_video_coords( - batch_size, num_frames, height, width, hidden_states.device, fps=fps - ) - if audio_coords is None: - audio_coords = self.audio_rope.prepare_audio_coords( - batch_size, audio_num_frames, audio_hidden_states.device - ) + # 3. Text embedding + encoder_hidden_states = self.context_embedder(encoder_hidden_states) - video_rotary_emb = self.rope(video_coords, device=hidden_states.device) - audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device) - - video_cross_attn_rotary_emb = self.cross_attn_rope(video_coords[:, 0:1, :], device=hidden_states.device) - audio_cross_attn_rotary_emb = self.cross_attn_audio_rope( - audio_coords[:, 0:1, :], device=audio_hidden_states.device - ) - - # 2. Patchify input projections - hidden_states = self.proj_in(hidden_states) - audio_hidden_states = self.audio_proj_in(audio_hidden_states) - - # 3. 
Timestep embeddings and modulation parameters - timestep_cross_attn_gate_scale_factor = ( - self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier - ) - - temb, embedded_timestep = self.time_embed( - timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, - ) - temb = temb.view(batch_size, -1, temb.size(-1)) - embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) - - temb_audio, audio_embedded_timestep = self.audio_time_embed( - audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, - ) - temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) - audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) - - video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( - timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, - ) - video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate( - timestep.flatten() * timestep_cross_attn_gate_scale_factor, - batch_size=batch_size, hidden_dtype=hidden_states.dtype, - ) - video_cross_attn_scale_shift = video_cross_attn_scale_shift.view( - batch_size, -1, video_cross_attn_scale_shift.shape[-1] - ) - video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view( - batch_size, -1, video_cross_attn_a2v_gate.shape[-1] - ) - - audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift( - audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, - ) - audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate( - audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor, - batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype, - ) - audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view( - batch_size, -1, audio_cross_attn_scale_shift.shape[-1] - ) - audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view( - batch_size, -1, 
audio_cross_attn_v2a_gate.shape[-1] - ) - - # 4. Prompt embeddings - encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) - - audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states) - audio_encoder_hidden_states = audio_encoder_hidden_states.view( - batch_size, -1, audio_hidden_states.size(-1) - ) - - # ------------------------------------------------------------------ # - # Step 5: Block loop with video-only feature extraction - # Audio always flows through every block — we never skip it. - # ------------------------------------------------------------------ # - features: List[torch.Tensor] = [] - - for idx, block in enumerate(self.transformer_blocks): + # 4. Prepare positional embeddings + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + # 5. Joint transformer blocks + for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states, audio_hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, - audio_hidden_states, encoder_hidden_states, - audio_encoder_hidden_states, temb, - temb_audio, - video_cross_attn_scale_shift, - audio_cross_attn_scale_shift, - video_cross_attn_a2v_gate, - audio_cross_attn_v2a_gate, - video_rotary_emb, - audio_rotary_emb, - video_cross_attn_rotary_emb, - audio_cross_attn_rotary_emb, - encoder_attention_mask, - audio_encoder_attention_mask, + image_rotary_emb, + joint_attention_kwargs, ) else: - hidden_states, audio_hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, - audio_hidden_states=audio_hidden_states, encoder_hidden_states=encoder_hidden_states, - audio_encoder_hidden_states=audio_encoder_hidden_states, temb=temb, - temb_audio=temb_audio, - 
temb_ca_scale_shift=video_cross_attn_scale_shift, - temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, - temb_ca_gate=video_cross_attn_a2v_gate, - temb_ca_audio_gate=audio_cross_attn_v2a_gate, - video_rotary_emb=video_rotary_emb, - audio_rotary_emb=audio_rotary_emb, - ca_video_rotary_emb=video_cross_attn_rotary_emb, - ca_audio_rotary_emb=audio_cross_attn_rotary_emb, - encoder_attention_mask=encoder_attention_mask, - audio_encoder_attention_mask=audio_encoder_attention_mask, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, ) - # Video-only feature extraction at requested block indices + # Check if we should extract features at this index if idx in feature_indices: - features.append(hidden_states.clone()) # [B, T_v, D_v] — packed video tokens + # Reshape from [B, seq_len, hidden_dim] to [B, hidden_dim, H, W] for discriminator + feat = hidden_states.clone() + B, S, C = feat.shape + feat = feat.permute(0, 2, 1).reshape(B, C, spatial_size, spatial_size) + features.append(feat) - # Early exit once all requested features are collected + # Early return if we have all features if return_features_early and len(features) == len(feature_indices): return features - # ------------------------------------------------------------------ # - # Step 6: Output layers (video + audio) — unchanged from original - # ------------------------------------------------------------------ # - scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] - shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] - hidden_states = self.norm_out(hidden_states) - hidden_states = hidden_states * (1 + scale) + shift - video_output = self.proj_out(hidden_states) - - audio_scale_shift_values = self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None] - audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1] - audio_hidden_states = 
self.audio_norm_out(audio_hidden_states) - audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift - audio_output = self.audio_proj_out(audio_hidden_states) - - # ------------------------------------------------------------------ # - # Assemble output following FastGen convention - # ------------------------------------------------------------------ # + idx += 1 + + # 6. Single transformer blocks + for block in self.single_transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + joint_attention_kwargs, + ) + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # Check if we should extract features at this index + if idx in feature_indices: + # Reshape from [B, seq_len, hidden_dim] to [B, hidden_dim, H, W] for discriminator + feat = hidden_states.clone() + B, S, C = feat.shape + feat = feat.permute(0, 2, 1).reshape(B, C, spatial_size, spatial_size) + features.append(feat) + + # Early return if we have all features + if return_features_early and len(features) == len(feature_indices): + return features + + idx += 1 + + # 7. 
Final projection - hidden_states is already image-only after single blocks + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + # If we have all the features, we can exit early if return_features_early: - # Should have been caught above; guard for safety assert len(features) == len(feature_indices), f"{len(features)} != {len(feature_indices)}" return features - # Logvar (optional — requires logvar_linear to be added to the transformer) - logvar = None - if return_logvar: - assert hasattr(self, "logvar_linear"), ( - "logvar_linear is required when return_logvar=True. " - "It is added by LTX2.__init__." - ) - # temb has shape [B, T_tokens, inner_dim]; take mean over tokens for a scalar logvar per sample - logvar = self.logvar_linear(temb.mean(dim=1)) # [B, 1] - + # Prepare output if len(feature_indices) == 0: - out = (video_output, audio_output) + out = output else: - out = [(video_output, audio_output), features] - - if USE_PEFT_BACKEND: - unscale_lora_layers(self, lora_scale) + out = [output, features] if return_logvar: + logvar = self.logvar_linear(temb) return out, logvar - return out - - -# --------------------------------------------------------------------------- -# Text encoder wrapper -# --------------------------------------------------------------------------- - -class LTX2TextEncoder(nn.Module): - """ - Wraps Gemma3ForConditionalGeneration for LTX-2 text conditioning. - - Returns both the packed prompt embeddings AND the tokenizer attention mask, - which is required by LTX2TextConnectors. 
- """ - - def __init__(self, model_id: str): - super().__init__() - self.tokenizer = GemmaTokenizerFast.from_pretrained(model_id, subfolder="tokenizer") - self.tokenizer.padding_side = "left" - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token - - self.text_encoder = Gemma3ForConditionalGeneration.from_pretrained( - model_id, subfolder="text_encoder" - ) - self.text_encoder.eval().requires_grad_(False) - - @torch.no_grad() - def encode( - self, - prompt: Union[str, List[str]], - precision: torch.dtype = torch.bfloat16, - max_sequence_length: int = 1024, - scale_factor: int = 8, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Encode text prompt(s) into packed Gemma hidden states. - - Returns - ------- - prompt_embeds : torch.Tensor [B, T, H * num_layers] - attention_mask : torch.Tensor [B, T] - """ - if isinstance(prompt, str): - prompt = [prompt] - - device = next(self.text_encoder.parameters()).device - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - input_ids = text_inputs.input_ids.to(device) - attention_mask = text_inputs.attention_mask.to(device) - - outputs = self.text_encoder( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - return_dict=True, - ) - # Stack all hidden states: [B, T, H, num_layers] - hidden_states = torch.stack(outputs.hidden_states, dim=-1) - sequence_lengths = attention_mask.sum(dim=-1) - - prompt_embeds = _pack_text_embeds( - hidden_states, - sequence_lengths, - device=device, - padding_side=self.tokenizer.padding_side, - scale_factor=scale_factor, - ).to(precision) - - return prompt_embeds, attention_mask - - def to(self, *args, **kwargs): - self.text_encoder.to(*args, **kwargs) - return self - - -# --------------------------------------------------------------------------- -# Main LTX-2 network — follows FastGen pattern (Flux / Wan) 
-# --------------------------------------------------------------------------- + return out -class LTX2(FastGenNetwork): - """ - FastGen wrapper for LTX-2 audio-video generation. - Distillation targets video only: - - forward() receives and returns video latents [B, C, F, H, W] - - Audio is generated internally but not used for the distillation loss - - classify_forward extracts video hidden_states at requested block indices +class Flux(FastGenNetwork): + """Flux.1 network for text-to-image generation. - Component layout: - text_encoder → connectors → transformer (patched) → vae → audio_vae → vocoder + Reference: https://huggingface.co/black-forest-labs/FLUX.1-dev """ - MODEL_ID = "Lightricks/LTX-2" + MODEL_ID = "black-forest-labs/FLUX.1-dev" def __init__( self, @@ -540,31 +338,43 @@ def __init__( net_pred_type: str = "flow", schedule_type: str = "rf", disable_grad_ckpt: bool = False, + guidance_scale: Optional[float] = 3.5, load_pretrained: bool = True, **model_kwargs, ): - """ - LTX-2 constructor. + """Flux.1 constructor. Args: - model_id: HuggingFace model ID or local path. Defaults to "Lightricks/LTX-2". - net_pred_type: Prediction type. Defaults to "flow" (flow matching). + model_id: The HuggingFace model ID to load. + Defaults to "black-forest-labs/FLUX.1-dev". + net_pred_type: Prediction type. Defaults to "flow" for flow matching. schedule_type: Schedule type. Defaults to "rf" (rectified flow). - disable_grad_ckpt: Disable gradient checkpointing during training. - Set True when using FSDP to avoid memory access errors. - load_pretrained: Load pretrained weights. If False, initialises from config only. + disable_grad_ckpt: Whether to disable gradient checkpointing during training. + Defaults to False. Set to True when using FSDP to avoid memory access errors. + guidance_scale: Default guidance scale for Flux.1-dev guidance distillation. + None means no guidance. Defaults to 3.5 (recommended for Flux.1-dev). 
""" super().__init__(net_pred_type=net_pred_type, schedule_type=schedule_type, **model_kwargs) self.model_id = model_id + self.guidance_scale = guidance_scale self._disable_grad_ckpt = disable_grad_ckpt + logger.debug(f"Embedded guidance scale: {guidance_scale}") + # Initialize the network (handles meta device and pretrained loading) self._initialize_network(model_id, load_pretrained) - # Monkey-patch classify_forward onto self.transformer (same pattern as Flux / Wan) + # Override forward with classify_forward self.transformer.forward = types.MethodType(classify_forward, self.transformer) - # Gradient checkpointing + # Disable cuDNN SDPA backend to avoid mha_graph->execute errors during backward. + # This is a known issue with Flux transformer and cuDNN attention. + # Flash and mem_efficient backends still work; only cuDNN is problematic. + if torch.backends.cuda.is_built(): + torch.backends.cuda.enable_cudnn_sdp(False) + logger.info("Disabled cuDNN SDPA backend for Flux compatibility") + + # Gradient checkpointing configuration if disable_grad_ckpt: self.transformer.disable_gradient_checkpointing() else: @@ -572,239 +382,239 @@ def __init__( torch.cuda.empty_cache() - # ------------------------------------------------------------------ - # Initialisation - # ------------------------------------------------------------------ - def _initialize_network(self, model_id: str, load_pretrained: bool) -> None: - """Initialize the transformer and supporting modules.""" + """Initialize the transformer network. + + Args: + model_id: The HuggingFace model ID or local path. + load_pretrained: Whether to load pretrained weights. 
+ """ + # Check if we're in a meta context (for FSDP memory-efficient loading) in_meta_context = self._is_in_meta_context() should_load_weights = load_pretrained and (not in_meta_context) if should_load_weights: - logger.info("Loading LTX-2 transformer from pretrained") - self.transformer: LTX2VideoTransformer3DModel = LTX2VideoTransformer3DModel.from_pretrained( - model_id, subfolder="transformer" + logger.info("Loading Flux transformer from pretrained") + self.transformer: FluxTransformer2DModel = FluxTransformer2DModel.from_pretrained( + model_id, + cache_dir=os.environ["HF_HOME"], + subfolder="transformer", + local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), ) else: - config = LTX2VideoTransformer3DModel.load_config(model_id, subfolder="transformer") + # Load config and create model structure + # If we're in a meta context, tensors will automatically be on meta device + config = FluxTransformer2DModel.load_config( + model_id, + cache_dir=os.environ["HF_HOME"], + subfolder="transformer", + local_files_only=str2bool(os.getenv("LOCAL_FILES_ONLY", "false")), + ) if in_meta_context: logger.info( - "Initializing LTX-2 transformer on meta device " - "(zero memory, will receive weights via FSDP sync)" + "Initializing Flux transformer on meta device (zero memory, will receive weights via FSDP sync)" ) else: - logger.info("Initializing LTX-2 transformer from config (no pretrained weights)") - logger.warning("LTX-2 transformer being initialized from config. No weights are loaded!") - self.transformer: LTX2VideoTransformer3DModel = LTX2VideoTransformer3DModel.from_config(config) - - # inner_dim = num_attention_heads * attention_head_dim - inner_dim = ( - self.transformer.config.num_attention_heads - * self.transformer.config.attention_head_dim - ) + logger.info("Initializing Flux transformer from config (no pretrained weights)") + logger.warning("Flux transformer being initialized from config. 
No weights are loaded!") + self.transformer: FluxTransformer2DModel = FluxTransformer2DModel.from_config(config) - # Add logvar_linear for uncertainty weighting (DMD2 / f-distill) - # temb mean has shape [B, inner_dim] → logvar scalar per sample - self.transformer.logvar_linear = nn.Linear(inner_dim, 1) - logger.info(f"Added logvar_linear ({inner_dim} → 1) to LTX-2 transformer") + # Add logvar linear layer for variance estimation - Flux uses 3072-dim time embeddings + self.transformer.logvar_linear = torch.nn.Linear(3072, 1) - # Connectors: top-level sibling of transformer (NOT nested inside it) - if should_load_weights: - self.connectors: LTX2TextConnectors = LTX2TextConnectors.from_pretrained( - model_id, subfolder="connectors" - ) - else: - # Connectors are lightweight; always load if pretrained is skipped for the transformer - logger.warning("Skipping connector pretrained load (meta context or load_pretrained=False)") - self.connectors = None # will be loaded lazily via init_preprocessors + def reset_parameters(self): + """Reinitialize parameters for FSDP meta device initialization. - # Cache compression ratios used by forward() and sample() - if should_load_weights: - # VAEs (needed for sample(); not for the training forward pass) - self.vae: AutoencoderKLLTX2Video = AutoencoderKLLTX2Video.from_pretrained( - model_id, subfolder="vae" - ) - self.vae.eval().requires_grad_(False) + This is required when using meta device initialization for FSDP2. + Reinitializes all linear layers and embeddings. 
+ """ + import torch.nn as nn - self.audio_vae: AutoencoderKLLTX2Audio = AutoencoderKLLTX2Audio.from_pretrained( - model_id, subfolder="audio_vae" - ) - self.audio_vae.eval().requires_grad_(False) + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Embedding): + nn.init.normal_(m.weight, std=0.02) - self.vocoder: LTX2Vocoder = LTX2Vocoder.from_pretrained( - model_id, subfolder="vocoder" - ) - self.vocoder.eval().requires_grad_(False) + super().reset_parameters() - self._cache_vae_constants() + logger.debug("Reinitialized Flux parameters") - # Scheduler (used in sample()) - self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( - model_id, subfolder="scheduler" - ) + def fully_shard(self, **kwargs): + """Fully shard the Flux network for FSDP. - def _cache_vae_constants(self) -> None: - """Cache VAE spatial/temporal compression constants for use in forward() / sample().""" - self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio - self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio - self.transformer_spatial_patch_size = self.transformer.config.patch_size - self.transformer_temporal_patch_size = self.transformer.config.patch_size_t + Note: Flux has two types of transformer blocks: + - transformer_blocks: Joint attention blocks for text-image interaction + - single_transformer_blocks: Single stream blocks for image processing + + We shard `self.transformer` instead of `self` because the network wrapper + class may have complex multiple inheritance with ABC, which causes Python's + __class__ assignment to fail due to incompatible memory layouts. + """ + # Note: Checkpointing has to happen first, for proper casting during backward pass recomputation. 
+ if self.transformer.gradient_checkpointing: + # Disable the built-in gradient checkpointing (which uses torch.utils.checkpoint) + self.transformer.disable_gradient_checkpointing() + # Apply FSDP-compatible activation checkpointing to both block types + apply_fsdp_checkpointing( + self.transformer, + check_fn=lambda block: isinstance(block, (FluxTransformerBlock, FluxSingleTransformerBlock)), + ) + logger.info("Applied FSDP activation checkpointing to Flux transformer blocks") + + # Apply FSDP sharding to joint transformer blocks + for block in self.transformer.transformer_blocks: + fully_shard(block, **kwargs) - self.audio_sampling_rate = self.audio_vae.config.sample_rate - self.audio_hop_length = self.audio_vae.config.mel_hop_length - self.audio_vae_temporal_compression = self.audio_vae.temporal_compression_ratio - self.audio_vae_mel_compression_ratio = self.audio_vae.mel_compression_ratio + # Apply FSDP sharding to single transformer blocks + for block in self.transformer.single_transformer_blocks: + fully_shard(block, **kwargs) - # ------------------------------------------------------------------ - # Preprocessor initialisation (lazy, matches Flux / Wan pattern) - # ------------------------------------------------------------------ + fully_shard(self.transformer, **kwargs) def init_preprocessors(self): - """Initialize text encoder and connectors.""" - if not hasattr(self, "text_encoder") or self.text_encoder is None: + """Initialize text and image encoders.""" + if not hasattr(self, "text_encoder"): self.init_text_encoder() - if self.connectors is None: - self.connectors = LTX2TextConnectors.from_pretrained( - self.model_id, subfolder="connectors" - ) + if not hasattr(self, "vae"): + self.init_vae() def init_text_encoder(self): - """Initialize the Gemma3 text encoder for LTX-2.""" - self.text_encoder = LTX2TextEncoder(model_id=self.model_id) + """Initialize the text encoder for Flux.""" + self.text_encoder = FluxTextEncoder(model_id=self.model_id) - # 
------------------------------------------------------------------ - # Device movement - # ------------------------------------------------------------------ + def init_vae(self): + """Initialize only the VAE for visualization.""" + self.vae = FluxImageEncoder(model_id=self.model_id) def to(self, *args, **kwargs): + """Moves the model to the specified device.""" super().to(*args, **kwargs) - if hasattr(self, "text_encoder") and self.text_encoder is not None: + if hasattr(self, "text_encoder"): self.text_encoder.to(*args, **kwargs) - if hasattr(self, "connectors") and self.connectors is not None: - self.connectors.to(*args, **kwargs) - if hasattr(self, "vae") and self.vae is not None: + if hasattr(self, "vae"): self.vae.to(*args, **kwargs) - if hasattr(self, "audio_vae") and self.audio_vae is not None: - self.audio_vae.to(*args, **kwargs) - if hasattr(self, "vocoder") and self.vocoder is not None: - self.vocoder.to(*args, **kwargs) return self - # ------------------------------------------------------------------ - # FSDP - # ------------------------------------------------------------------ + def _prepare_latent_image_ids( + self, + height: int, + width: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """Prepare image position IDs for the transformer. - def fully_shard(self, **kwargs): - """Fully shard the LTX-2 transformer for FSDP2. + Args: + height: Latent height (before packing, will be divided by 2). + width: Latent width (before packing, will be divided by 2). + device: Target device. + dtype: Target dtype. - Shards self.transformer (not self) to avoid ABC __class__ assignment issues. + Returns: + torch.Tensor: Image position IDs [(H//2)*(W//2), 3] (2D, no batch dim). 
""" - if self.transformer.gradient_checkpointing: - self.transformer.disable_gradient_checkpointing() - apply_fsdp_checkpointing( - self.transformer, - check_fn=lambda block: isinstance(block, LTX2VideoTransformerBlock), - ) - logger.info("Applied FSDP activation checkpointing to LTX-2 transformer blocks") + # Use packed dimensions + packed_height = height // 2 + packed_width = width // 2 + latent_image_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype) + latent_image_ids[..., 1] = torch.arange(packed_height, device=device, dtype=dtype)[:, None] + latent_image_ids[..., 2] = torch.arange(packed_width, device=device, dtype=dtype)[None, :] + latent_image_ids = latent_image_ids.reshape(packed_height * packed_width, 3) + return latent_image_ids + + def _prepare_text_ids( + self, + seq_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> torch.Tensor: + """Prepare text position IDs. - for block in self.transformer.transformer_blocks: - fully_shard(block, **kwargs) - fully_shard(self.transformer, **kwargs) + Args: + seq_length: Text sequence length. + device: Target device. + dtype: Target dtype. - # ------------------------------------------------------------------ - # reset_parameters (required for FSDP meta device init) - # ------------------------------------------------------------------ + Returns: + torch.Tensor: Text position IDs [seq_length, 3] (2D, no batch dim). + """ + text_ids = torch.zeros(seq_length, 3, device=device, dtype=dtype) + return text_ids - def reset_parameters(self): - """Reinitialise parameters after meta device materialisation (FSDP2).""" - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Embedding): - nn.init.normal_(m.weight, std=0.02) + def _pack_latents(self, latents: torch.Tensor) -> torch.Tensor: + """Pack latents from [B, C, H, W] to [B, (H//2)*(W//2), C*4] for Flux transformer. 
- super().reset_parameters() - logger.debug("Reinitialized LTX-2 parameters") + Flux uses 2x2 patch packing where each 2x2 spatial block is flattened into channels. - # ------------------------------------------------------------------ - # Audio latent sizing helper (shared by forward and sample) - # ------------------------------------------------------------------ + Args: + latents: Input latents [B, C, H, W]. - def _compute_audio_shape( - self, latent_f: int, fps: float, device: torch.device, dtype: torch.dtype - ) -> Tuple[int, int, int]: + Returns: + Packed latents [B, (H//2)*(W//2), C*4]. """ - Compute audio latent dimensions from video latent frame count. + batch_size, channels, height, width = latents.shape + # Reshape to [B, C, H//2, 2, W//2, 2] + latents = latents.view(batch_size, channels, height // 2, 2, width // 2, 2) + # Permute to [B, H//2, W//2, C, 2, 2] + latents = latents.permute(0, 2, 4, 1, 3, 5) + # Reshape to [B, (H//2)*(W//2), C*4] + latents = latents.reshape(batch_size, (height // 2) * (width // 2), channels * 4) + return latents - Returns (audio_num_frames, latent_mel_bins, num_audio_ch). - """ - pixel_frames = (latent_f - 1) * self.vae_temporal_compression_ratio + 1 - duration_s = pixel_frames / fps + def _unpack_latents(self, latents: torch.Tensor, height: int, width: int) -> torch.Tensor: + """Unpack latents from [B, (H//2)*(W//2), C*4] to [B, C, H, W]. - audio_latents_per_second = ( - self.audio_sampling_rate - / self.audio_hop_length - / float(self.audio_vae_temporal_compression) - ) - audio_num_frames = round(duration_s * audio_latents_per_second) - num_mel_bins = self.audio_vae.config.mel_bins - latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio - num_audio_ch = self.audio_vae.config.latent_channels - return audio_num_frames, latent_mel_bins, num_audio_ch + Reverses the 2x2 patch packing used by Flux. + + Args: + latents: Packed latents [B, (H//2)*(W//2), C*4]. + height: Target height (original H before packing). 
+ width: Target width (original W before packing). - # ------------------------------------------------------------------ - # forward() — video-only distillation interface - # ------------------------------------------------------------------ + Returns: + Unpacked latents [B, C, H, W]. + """ + batch_size = latents.shape[0] + channels = latents.shape[2] // 4 # C*4 -> C + # Reshape to [B, H//2, W//2, C, 2, 2] + latents = latents.reshape(batch_size, height // 2, width // 2, channels, 2, 2) + # Permute to [B, C, H//2, 2, W//2, 2] + latents = latents.permute(0, 3, 1, 4, 2, 5) + # Reshape to [B, C, H, W] + latents = latents.reshape(batch_size, channels, height, width) + return latents def forward( self, x_t: torch.Tensor, t: torch.Tensor, condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - r: Optional[torch.Tensor] = None, # unused, kept for API compatibility - fps: float = 24.0, + r: Optional[torch.Tensor] = None, # unused, kept for API compatibility + guidance: Optional[torch.Tensor] = None, return_features_early: bool = False, feature_indices: Optional[Set[int]] = None, return_logvar: bool = False, fwd_pred_type: Optional[str] = None, - audio_latents: Optional[torch.Tensor] = None, - return_audio: bool = False, **fwd_kwargs, ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: - """Forward pass — video latents in, video latents out. - - Follows the Flux/Wan FastGen pattern. Audio is optional: - - Distillation (default): ``audio_latents=None, return_audio=False``. - Random noise is used for audio; only video prediction is returned. - - Inference via sample(): ``audio_latents=, return_audio=True``. - The current denoising audio latents are passed in and the audio - noise prediction is returned alongside the video prediction so that - sample() can step both schedulers. + """Forward pass of Flux diffusion model. Args: - x_t: Video latents [B, C, F, H, W]. - t: Timestep [B] — scheduler sigmas passed directly to time_embed. 
- condition: Tuple of (prompt_embeds [B, T, D], attention_mask [B, T]). - r: Unused (kept for FastGen API compatibility). - fps: Frames per second (needed for RoPE coordinate computation). - return_features_early: Return video features as soon as collected. - feature_indices: Set of transformer block indices to extract video features from. - return_logvar: Return log-variance estimate alongside the output. - fwd_pred_type: Override prediction type. - audio_latents: Optional packed audio latents [B, T_a, C_a] from sample(). - When None, fresh random noise is generated internally. - return_audio: When True, return (video_out, audio_packed) instead of - just video_out. Only used by sample(). + x_t: The diffused data sample [B, C, H, W]. + t: The current timestep. + condition: Tuple of (pooled_prompt_embeds, prompt_embeds) from text encoder. + r: Another timestep (for mean flow methods). + return_features_early: If True, return features once collected. + feature_indices: Set of block indices for feature extraction. + return_logvar: If True, return the logvar. + fwd_pred_type: Override network prediction type. + guidance: Optional guidance scale embedding. Returns: - Normal: video_out [B, C, F, H, W] - return_audio: (video_out, audio_packed [B, T_a, C_a]) - With features: (video_out, List[video_feature_tensors]) - Early exit: List[video_feature_tensors] - With logvar: (above, logvar [B, 1]) + Model output tensor or tuple with logvar/features. 
""" if feature_indices is None: feature_indices = set() @@ -817,122 +627,78 @@ def forward( assert fwd_pred_type in NET_PRED_TYPES, f"{fwd_pred_type} is not supported" batch_size = x_t.shape[0] - _, _, latent_f, latent_h, latent_w = x_t.shape + height, width = x_t.shape[2], x_t.shape[3] - # Unpack text conditioning - prompt_embeds, attention_mask = condition + # Unpack condition: (pooled_prompt_embeds, prompt_embeds) + pooled_prompt_embeds, prompt_embeds = condition - # ---- Run connectors to get per-modality encoder hidden states ---- - additive_mask = (1 - attention_mask.to(prompt_embeds.dtype)) * -1_000_000.0 - connector_video_embeds, connector_audio_embeds, connector_attn_mask = self.connectors( - prompt_embeds, additive_mask, additive_mask=True - ) + # Prepare position IDs (2D tensors, no batch dimension) + img_ids = self._prepare_latent_image_ids(height, width, x_t.device, x_t.dtype) + txt_ids = self._prepare_text_ids(prompt_embeds.shape[1], x_t.device, x_t.dtype) - # ---- Timestep: [B] — time_embed handles scale internally ---- - timestep = t.to(x_t.dtype).expand(batch_size) # [B] + # Pack latents for transformer: [B, C, H, W] -> [B, (H//2)*(W//2), C*4] + hidden_states = self._pack_latents(x_t) - # ---- Pack video latents ---- - hidden_states = _pack_latents( - x_t, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size - ) - - # ---- Audio latents: use provided (from sample()) or generate random noise ---- - audio_num_frames, latent_mel_bins, num_audio_ch = self._compute_audio_shape( - latent_f, fps, x_t.device, x_t.dtype - ) - if audio_latents is not None: - # Already packed [B, T_a, C_a] — passed in by sample() - audio_hidden_states = audio_latents.to(x_t.dtype) - else: - audio_hidden_states = _pack_audio_latents( - torch.randn( - batch_size, num_audio_ch, audio_num_frames, latent_mel_bins, - device=x_t.device, dtype=x_t.dtype, - ) + # Note: Flux.1-dev (w/ guidance distillation) uses embedded guidance, so the default guidance is not 
None + if guidance is None: + guidance = torch.full( + (batch_size,), self.guidance_scale, device=hidden_states.device, dtype=hidden_states.dtype ) - # ---- RoPE coordinates ---- - video_coords = self.transformer.rope.prepare_video_coords( - batch_size, latent_f, latent_h, latent_w, x_t.device, fps=fps - ) - audio_coords = self.transformer.audio_rope.prepare_audio_coords( - batch_size, audio_num_frames, x_t.device - ) - - # ---- Transformer forward (classify_forward, monkey-patched) ---- model_outputs = self.transformer( hidden_states=hidden_states, - audio_hidden_states=audio_hidden_states, - encoder_hidden_states=connector_video_embeds, - audio_encoder_hidden_states=connector_audio_embeds, - encoder_attention_mask=connector_attn_mask, - audio_encoder_attention_mask=connector_attn_mask, - timestep=timestep, - num_frames=latent_f, - height=latent_h, - width=latent_w, - fps=fps, - audio_num_frames=audio_num_frames, - video_coords=video_coords, - audio_coords=audio_coords, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + timestep=t, # Flux expects timestep in [0, 1] + img_ids=img_ids, + txt_ids=txt_ids, + guidance=guidance, return_features_early=return_features_early, feature_indices=feature_indices, return_logvar=return_logvar, ) - # ---- Early exit: list of video feature tensors ---- if return_features_early: - return model_outputs # List[Tensor], each [B, T_v, D_v] + return model_outputs - # ---- Unpack logvar if requested ---- if return_logvar: out, logvar = model_outputs[0], model_outputs[1] else: out = model_outputs - # ---- Extract video prediction; capture audio for sample() if requested ---- - if len(feature_indices) == 0: - # out is (video_output, audio_output) - video_packed = out[0] # [B, T_v, C_packed] - audio_packed = out[1] # [B, T_a, C_packed] — only used when return_audio=True - features = None - else: - # out is [(video_output, audio_output), features] - video_packed = out[0][0] # [B, T_v, C_packed] - audio_packed 
= out[0][1] # [B, T_a, C_packed] - features = out[1] # List[Tensor] - - # ---- Unpack video tokens → [B, C, F, H, W] ---- - video_out = _unpack_latents( - video_packed, latent_f, latent_h, latent_w, - self.transformer_spatial_patch_size, self.transformer_temporal_patch_size, - ) - - # ---- Convert model output to requested prediction type ---- - video_out = self.noise_scheduler.convert_model_output( - x_t, video_out, t, - src_pred_type=self.net_pred_type, - target_pred_type=fwd_pred_type, - ) - - # ---- Re-pack output following FastGen convention ---- - if features is not None: - out = [video_out, features] + # Unpack output: [B, H*W, C] -> [B, C, H, W] + if isinstance(out, torch.Tensor): + out = self._unpack_latents(out, height, width) + out = self.noise_scheduler.convert_model_output( + x_t, out, t, src_pred_type=self.net_pred_type, target_pred_type=fwd_pred_type + ) else: - out = video_out - - # Return audio noise pred alongside video when called from sample() - if return_audio: - out = (out, audio_packed.float()) + out[0] = self._unpack_latents(out[0], height, width) + out[0] = self.noise_scheduler.convert_model_output( + x_t, out[0], t, src_pred_type=self.net_pred_type, target_pred_type=fwd_pred_type + ) if return_logvar: return out, logvar return out - # ------------------------------------------------------------------ - # sample() — full denoising loop for inference - # Follows pipeline_ltx2.py exactly (verified working logic preserved) - # ------------------------------------------------------------------ + def _calculate_shift( + self, + image_seq_len: int, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, + ) -> float: + """Calculate the shift value for the scheduler based on image resolution. + + This implements the resolution-dependent shift from the Flux paper. 
+ """ + + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu @torch.no_grad() def sample( @@ -940,85 +706,91 @@ def sample( noise: torch.Tensor, condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, neg_condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - guidance_scale: float = 4.0, - num_steps: int = 40, - fps: float = 24.0, - frame_rate: Optional[float] = None, + guidance_scale: Optional[float] = 3.5, + num_steps: int = 28, **kwargs, - ) -> Tuple[torch.Tensor, None]: - """ - Run the full denoising loop for text-to-video generation (audio always None). + ) -> torch.Tensor: + """Generate samples using Euler flow matching. + + Args: + noise: Initial noise tensor [B, C, H, W]. + condition: Tuple of (pooled_prompt_embeds, prompt_embeds). + neg_condition: Optional negative condition tuple for CFG. + guidance_scale: Guidance scale (if not None, enables guidance via distillation). + num_steps: Number of sampling steps (default 28 for good quality/speed balance). + **kwargs: Additional keyword arguments - Returns - ------- - (video_latents, None): - video: [B, C, F, H, W] denormalised video latents - audio: always None + Returns: + Generated latent samples. 
""" - fps = frame_rate if frame_rate is not None else fps - do_cfg = neg_condition is not None and guidance_scale > 1.0 - - transformer_dtype = self.transformer.dtype - transformer_device = next(self.transformer.parameters()).device - - # Move latents to transformer device and dtype - video_latents = noise.to(device=transformer_device, dtype=transformer_dtype) - - # Build combined condition for CFG so connectors are called once per step - if do_cfg: - neg_embeds, neg_mask = neg_condition - cond_embeds, cond_mask = condition - combined_condition = ( - torch.cat([neg_embeds, cond_embeds], dim=0).to(device=transformer_device, dtype=transformer_dtype), - torch.cat([neg_mask, cond_mask], dim=0).to(device=transformer_device), - ) - else: - embeds, mask = condition - combined_condition = ( - embeds.to(device=transformer_device, dtype=transformer_dtype), - mask.to(device=transformer_device), - ) + batch_size, channels, height, width = noise.shape - # ---- Scheduler timesteps ---- - B, C, latent_f, latent_h, latent_w = video_latents.shape - sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps) - video_seq_len = latent_f * latent_h * latent_w - mu = _calculate_shift( - video_seq_len, - self.scheduler.config.get("base_image_seq_len", 1024), - self.scheduler.config.get("max_image_seq_len", 4096), - self.scheduler.config.get("base_shift", 0.95), - self.scheduler.config.get("max_shift", 2.05), - ) - timesteps, num_steps = _retrieve_timesteps( - self.scheduler, num_steps, transformer_device, sigmas=sigmas, mu=mu - ) + # Calculate image sequence length for shift calculation + # After 2x2 packing: seq_len = (H // 2) * (W // 2) + image_seq_len = (height // 2) * (width // 2) + + # Calculate resolution-dependent shift (mu) + mu = self._calculate_shift(image_seq_len) + + # Initialize scheduler with proper shift + scheduler = FlowMatchEulerDiscreteScheduler(shift=mu) + scheduler.set_timesteps(num_steps, device=noise.device) + timesteps = scheduler.timesteps - # ---- Denoising loop ---- 
- for t in timesteps: - latent_input = torch.cat([video_latents] * 2) if do_cfg else video_latents - t_input = t.to(dtype=transformer_dtype, device=transformer_device).expand(latent_input.shape[0]) - - noise_pred = self( - latent_input, - t_input, - condition=combined_condition, - fps=fps, - fwd_pred_type="flow", + # Initialize latents with proper scaling based on the initial timestep + t_init = self.noise_scheduler.safe_clamp( + timesteps[0] / 1000.0, min=self.noise_scheduler.min_t, max=self.noise_scheduler.max_t + ) + latents = self.noise_scheduler.latents(noise=noise, t_init=t_init) + + pooled_prompt_embeds, prompt_embeds = condition + + # Prepare guidance embedding for guidance distillation (Flux.1-dev mode) + # Note: Flux.1-dev uses embedded guidance, not traditional CFG + guidance_tensor = None + if guidance_scale is not None: + guidance_tensor = torch.full((batch_size,), guidance_scale, device=latents.device, dtype=latents.dtype) + + # Sampling loop + for timestep in timesteps: + # Scheduler timesteps are in [0, 1000], transformer expects [0, 1] + t = (timestep / 1000.0).expand(batch_size) + t = self.noise_scheduler.safe_clamp(t, min=self.noise_scheduler.min_t, max=self.noise_scheduler.max_t).to( + latents.dtype ) - if do_cfg: + # Two guidance modes: + # 1. CFG mode: when neg_condition is provided (doubles batch, uses uncond/cond difference) + # 2. 
Guidance distillation mode: when neg_condition is None (single forward, guidance embedded) + if neg_condition is not None: + # Traditional CFG mode + neg_pooled, neg_prompt = neg_condition + latent_model_input = torch.cat([latents, latents], dim=0) + pooled_input = torch.cat([neg_pooled, pooled_prompt_embeds], dim=0) + prompt_input = torch.cat([neg_prompt, prompt_embeds], dim=0) + t_input = torch.cat([t, t], dim=0) + + noise_pred = self( + latent_model_input, + t_input, + (pooled_input, prompt_input), + fwd_pred_type="flow", + guidance=None, # No guidance embedding for CFG mode + ) + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + else: + # Guidance distillation mode (recommended for Flux.1-dev) + noise_pred = self( + latents, + t, + condition, + fwd_pred_type="flow", + guidance=guidance_tensor, + ) - video_latents = self.scheduler.step(noise_pred, t, video_latents, return_dict=False)[0] - - # ---- Denormalise ---- - video_latents = _denormalize_latents( - video_latents, - self.vae.latents_mean, - self.vae.latents_std, - self.vae.config.scaling_factor, - ) + # Euler step + latents = scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] - return video_latents, None \ No newline at end of file + return latents \ No newline at end of file From 868a9843feed59f49c06fad5bf09cdd607931e9e Mon Sep 17 00:00:00 2001 From: linnan wang Date: Wed, 4 Mar 2026 13:45:43 +0800 Subject: [PATCH 13/13] track the recent progress --- .../LTX2/Data/dual_pipe_video_generate.py | 77 ++++++++ .../LTX2/Data/orig_upsample_comp_generate.py | 169 ++++++++++++++++ .../LTX2/Data/various_durations_generate.py | 181 ++++++++++++++++++ 3 files changed, 427 insertions(+) create mode 100644 fastgen/networks/LTX2/Data/dual_pipe_video_generate.py create mode 100644 fastgen/networks/LTX2/Data/orig_upsample_comp_generate.py create mode 100644 
fastgen/networks/LTX2/Data/various_durations_generate.py diff --git a/fastgen/networks/LTX2/Data/dual_pipe_video_generate.py b/fastgen/networks/LTX2/Data/dual_pipe_video_generate.py new file mode 100644 index 0000000..74de57d --- /dev/null +++ b/fastgen/networks/LTX2/Data/dual_pipe_video_generate.py @@ -0,0 +1,77 @@ +import torch +from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps +from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline +from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number +from ltx_pipelines.utils.media_io import encode_video +from ltx_pipelines.utils.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_NEGATIVE_PROMPT, + DEFAULT_VIDEO_GUIDER_PARAMS, + DEFAULT_AUDIO_GUIDER_PARAMS, + DEFAULT_NUM_FRAMES, + DEFAULT_FRAME_RATE, + DEFAULT_NUM_INFERENCE_STEPS, +) + +# ── Paths ──────────────────────────────────────────────────────────────────── +CHECKPOINT_PATH = "./ltx-2-19b-dev.safetensors" +DISTILLED_LORA_PATH = "./ltx-2-19b-distilled-lora-384.safetensors" +SPATIAL_UPSAMPLER_PATH = "./ltx-2-spatial-upscaler-x2-1.0.safetensors" +GEMMA_ROOT = "./gemma-3-12b-local" +OUTPUT_PATH = "./output_1080p.mp4" + +# ── Generation settings ────────────────────────────────────────────────────── +PROMPT = "The camera opens in a calm, sunlit frog yoga studio. Warm morning light washes over the wooden floor as incense smoke drifts lazily in the air. The senior frog instructor sits cross-legged at the center, eyes closed, voice deep and calm. “We are one with the pond.” All the frogs answer softly: “Ommm…” “We are one with the mud.” “Ommm…” He smiles faintly. “We are one with the flies.” A pause. The camera pans to the side towards one frog who twitches, eyes darting. Suddenly its tongue snaps out, catching a fly mid-air and pulling it into its mouth. The master exhales slowly, still serene. “But we do not chase the flies…” Beat. 
“not during class.” The guilty frog lowers its head in shame, folding its hands back into a meditative pose. The other frogs resume their chant: “Ommm…” Camera holds for a moment on the embarrassed frog, eyes closed too tightly, pretending nothing happened." +#"EXT. SMALL TOWN STREET – MORNING – LIVE NEWS BROADCAST The shot opens on a news reporter standing in front of a row of cordoned-off cars, yellow caution tape fluttering behind him. The light is warm, early sun reflecting off the camera lens. The faint hum of chatter and distant drilling fills the air. The reporter, composed but visibly excited, looks directly into the camera, microphone in hand. Reporter (live): “Thank you, Sylvia. And yes — this is a sentence I never thought I’d say on live television — but this morning, here in the quiet town of New Castle, Vermont… black gold has been found!” He gestures slightly toward the field behind him. Reporter (grinning): “If my cameraman can pan over, you’ll see what all the excitement’s about.” The camera pans right, slowly revealing a construction site surrounded by workers in hard hats. A beat of silence — then, with a sudden roar, a geyser of oil erupts from the ground, blasting upward in a violent plume. Workers cheer and scramble, the black stream glistening in the morning light. The camera shakes slightly, trying to stay focused through the chaos. Reporter (off-screen, shouting over the noise): “There it is, folks — the moment New Castle will never forget!” The camera catches the sunlight gleaming off the oil mist before pulling back, revealing the entire scene — the small-town skyline silhouetted against the wild fountain of oil." 
+SEED = 20173261 +HEIGHT = 1088 # nearest multiple of 64 to 1080 (required for two-stage) +WIDTH = 1920 + + +@torch.inference_mode() +def main() -> None: + # Build pipeline + pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=CHECKPOINT_PATH, + distilled_lora=[ + LoraPathStrengthAndSDOps( + DISTILLED_LORA_PATH, 1.0, LTXV_LORA_COMFY_RENAMING_MAP + ) + ], + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH, + gemma_root=GEMMA_ROOT, + loras=[], + ) + + tiling_config = TilingConfig.default() + + # Generate + video, audio = pipeline( + prompt=PROMPT, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=SEED, + height=HEIGHT, + width=WIDTH, + num_frames=482, # DEFAULT_NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS, + audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS, + images=[], + tiling_config=tiling_config, + ) + + # Encode and save + encode_video( + video=video, + fps=DEFAULT_FRAME_RATE, + audio=audio, + audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=OUTPUT_PATH, + video_chunks_number=get_video_chunks_number(DEFAULT_NUM_FRAMES, tiling_config), + ) + print(f"Video saved to {OUTPUT_PATH}") + + +if __name__ == "__main__": + main() diff --git a/fastgen/networks/LTX2/Data/orig_upsample_comp_generate.py b/fastgen/networks/LTX2/Data/orig_upsample_comp_generate.py new file mode 100644 index 0000000..290db2f --- /dev/null +++ b/fastgen/networks/LTX2/Data/orig_upsample_comp_generate.py @@ -0,0 +1,169 @@ +""" +Batch video generation comparing `prompt` vs `upsampled_prompt` at 5s. 
+Output files: + {stem}_prompt_5s.mp4 + {stem}_upsampled_5s.mp4 +""" + +import json +import logging +import sys +from pathlib import Path + +import torch + +from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps +from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline +from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number +from ltx_pipelines.utils.media_io import encode_video +from ltx_pipelines.utils.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_NEGATIVE_PROMPT, + DEFAULT_VIDEO_GUIDER_PARAMS, + DEFAULT_AUDIO_GUIDER_PARAMS, + DEFAULT_FRAME_RATE, + DEFAULT_NUM_INFERENCE_STEPS, +) + +# ── Logging ─────────────────────────────────────────────────────────────────── +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) + +# ── Paths ───────────────────────────────────────────────────────────────────── +CHECKPOINT_PATH = "./ltx-2-19b-dev.safetensors" +DISTILLED_LORA_PATH = "./ltx-2-19b-distilled-lora-384.safetensors" +SPATIAL_UPSAMPLER_PATH = "./ltx-2-spatial-upscaler-x2-1.0.safetensors" +GEMMA_ROOT = "./gemma-3-12b-local" +PROMPTS_FILE = "./prompts.txt" +OUTPUT_DIR = "./outputs_comparison" + +# ── Generation settings ─────────────────────────────────────────────────────── +SEED = 42 +HEIGHT = 1088 # nearest multiple of 64 to 1080 +WIDTH = 1920 +NUM_FRAMES = 121 # 5s @ 24fps: (8 × 15) + 1 = 121 + +# Prompt variants to compare: output_suffix -> JSON key in prompts.txt +PROMPT_VARIANTS = { + "prompt_5s": "prompt", + "upsampled_5s": "upsampled_prompt", +} + + +def load_prompts(path: str) -> list[dict]: + """Read prompts.txt — one JSON object per line.""" + prompts = [] + with open(path, "r", encoding="utf-8") as f: + for line_no, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + prompts.append(json.loads(line)) + except json.JSONDecodeError as e: + 
logger.warning(f"Skipping line {line_no} — invalid JSON: {e}") + logger.info(f"Loaded {len(prompts)} prompts from {path}") + return prompts + + +def output_path_for(source_file: str, suffix: str, output_dir: str) -> str: + """e.g. '000431.json' + 'prompt_5s' -> './outputs_comparison/000431_prompt_5s.mp4'""" + stem = Path(source_file).stem + return str(Path(output_dir) / f"{stem}_{suffix}.mp4") + + +@torch.inference_mode() +def main() -> None: + Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + + prompts = load_prompts(PROMPTS_FILE) + if not prompts: + logger.error("No prompts found — exiting.") + return + + logger.info("Loading pipeline...") + pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=CHECKPOINT_PATH, + distilled_lora=[ + LoraPathStrengthAndSDOps( + DISTILLED_LORA_PATH, 1.0, LTXV_LORA_COMFY_RENAMING_MAP + ) + ], + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH, + gemma_root=GEMMA_ROOT, + loras=[], + ) + logger.info("Pipeline loaded.") + + tiling_config = TilingConfig.default() + num_chunks = get_video_chunks_number(NUM_FRAMES, tiling_config) + total = len(prompts) + failed = [] # list of (source_file, suffix) + + for idx, entry in enumerate(prompts, 1): + source_file = entry.get("source_file", f"unknown_{idx}.json") + logger.info(f"[{idx}/{total}] {source_file}") + + for suffix, json_key in PROMPT_VARIANTS.items(): + prompt_text = entry.get(json_key, "").strip() + out_path = output_path_for(source_file, suffix, OUTPUT_DIR) + + logger.info(f" [{suffix}] Prompt ({json_key}): {prompt_text[:100]}...") + + if Path(out_path).exists(): + logger.info(f" [{suffix}] Already exists, skipping.") + continue + + if not prompt_text: + logger.error(f" [{suffix}] Key '{json_key}' missing or empty, skipping.") + failed.append((source_file, suffix)) + continue + + try: + video, audio = pipeline( + prompt=prompt_text, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=SEED, + height=HEIGHT, + width=WIDTH, + num_frames=NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + 
num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS, + audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS, + images=[], + tiling_config=tiling_config, + ) + + encode_video( + video=video, + fps=DEFAULT_FRAME_RATE, + audio=audio, + audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=out_path, + video_chunks_number=num_chunks, + ) + logger.info(f" [{suffix}] Saved: {out_path}") + + except Exception as e: + logger.error(f" [{suffix}] FAILED {source_file}: {e}", exc_info=True) + failed.append((source_file, suffix)) + + # Summary + logger.info("=" * 60) + total_variants = total * len(PROMPT_VARIANTS) + logger.info( + f"Done. {total_variants - len(failed)}/{total_variants} videos generated successfully." + ) + if failed: + logger.warning(f"Failed ({len(failed)}):") + for source_file, suffix in failed: + logger.warning(f" {source_file} [{suffix}]") + + +if __name__ == "__main__": + main() diff --git a/fastgen/networks/LTX2/Data/various_durations_generate.py b/fastgen/networks/LTX2/Data/various_durations_generate.py new file mode 100644 index 0000000..e3d6b2e --- /dev/null +++ b/fastgen/networks/LTX2/Data/various_durations_generate.py @@ -0,0 +1,181 @@ +""" +Batch video generation from prompts.txt using LTX-2 TI2VidTwoStagesPipeline. +Generates each prompt at 3 durations: 5s, 8s, 10s with the same seed. 
+Output files: {stem}_5s.mp4, {stem}_8s.mp4, {stem}_10s.mp4 +""" + +import json +import logging +import sys +from pathlib import Path + +import torch + +from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps +from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline +from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number +from ltx_pipelines.utils.media_io import encode_video +from ltx_pipelines.utils.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_NEGATIVE_PROMPT, + DEFAULT_VIDEO_GUIDER_PARAMS, + DEFAULT_AUDIO_GUIDER_PARAMS, + DEFAULT_FRAME_RATE, + DEFAULT_NUM_INFERENCE_STEPS, +) + +# ── Logging ─────────────────────────────────────────────────────────────────── +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) + +# ── Paths ───────────────────────────────────────────────────────────────────── +CHECKPOINT_PATH = "./ltx-2-19b-dev.safetensors" +DISTILLED_LORA_PATH = "./ltx-2-19b-distilled-lora-384.safetensors" +SPATIAL_UPSAMPLER_PATH = "./ltx-2-spatial-upscaler-x2-1.0.safetensors" +GEMMA_ROOT = "./gemma-3-12b-local" +PROMPTS_FILE = "./prompts.txt" +OUTPUT_DIR = "./outputs" + +# ── Generation settings ─────────────────────────────────────────────────────── +SEED = 42 +HEIGHT = 1088 # nearest multiple of 64 to 1080 (required for two-stage) +WIDTH = 1920 + +# Duration variants: label -> num_frames +# Formula: num_frames = (8 × K) + 1 +# 5s @ 24fps: 120 frames → K=15 → 121 +# 8s @ 24fps: 192 frames → K=24 → 193 +# 10s @ 24fps: 240 frames → K=30 → 241 +DURATION_VARIANTS = { + "5s": 121, + "8s": 193, + "10s": 241, +} + + +def load_prompts(path: str) -> list[dict]: + """Read prompts.txt — one JSON object per line.""" + prompts = [] + with open(path, "r", encoding="utf-8") as f: + for line_no, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + 
prompts.append(json.loads(line)) + except json.JSONDecodeError as e: + logger.warning(f"Skipping line {line_no} — invalid JSON: {e}") + logger.info(f"Loaded {len(prompts)} prompts from {path}") + return prompts + + +def output_path_for(source_file: str, suffix: str, output_dir: str) -> str: + """Convert e.g. '000431.json' + '5s' -> './outputs/000431_5s.mp4'""" + stem = Path(source_file).stem + return str(Path(output_dir) / f"{stem}_{suffix}.mp4") + + +@torch.inference_mode() +def main() -> None: + Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + + prompts = load_prompts(PROMPTS_FILE) + if not prompts: + logger.error("No prompts found — exiting.") + return + + logger.info("Loading pipeline...") + pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=CHECKPOINT_PATH, + distilled_lora=[ + LoraPathStrengthAndSDOps( + DISTILLED_LORA_PATH, 1.0, LTXV_LORA_COMFY_RENAMING_MAP + ) + ], + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH, + gemma_root=GEMMA_ROOT, + loras=[], + ) + logger.info("Pipeline loaded.") + + tiling_config = TilingConfig.default() + total = len(prompts) + failed = [] # list of (source_file, suffix) tuples + + for idx, entry in enumerate(prompts, 1): + source_file = entry.get("source_file", f"unknown_{idx}.json") + upsampled = entry.get("upsampled_prompt", "").strip() + original = entry.get("prompt", "") + + logger.info(f"[{idx}/{total}] {source_file}") + logger.info(f" Prompt: {upsampled[:120]}...") + + if not upsampled: + logger.warning(" No upsampled_prompt, falling back to prompt.") + upsampled = original + + if not upsampled: + logger.error(f" No prompt available for {source_file}, skipping all durations.") + for suffix in DURATION_VARIANTS: + failed.append((source_file, suffix)) + continue + + # Generate each duration variant for this prompt + for suffix, num_frames in DURATION_VARIANTS.items(): + out_path = output_path_for(source_file, suffix, OUTPUT_DIR) + + logger.info(f" [{suffix}] {num_frames} frames → {out_path}") + + if 
Path(out_path).exists(): + logger.info(f" [{suffix}] Already exists, skipping.") + continue + + try: + video, audio = pipeline( + prompt=upsampled, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=SEED, # same seed across all durations + height=HEIGHT, + width=WIDTH, + num_frames=num_frames, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS, + audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS, + images=[], + tiling_config=tiling_config, + ) + + encode_video( + video=video, + fps=DEFAULT_FRAME_RATE, + audio=audio, + audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=out_path, + video_chunks_number=get_video_chunks_number(num_frames, tiling_config), + ) + logger.info(f" [{suffix}] Saved: {out_path}") + + except Exception as e: + logger.error(f" [{suffix}] FAILED {source_file}: {e}", exc_info=True) + failed.append((source_file, suffix)) + + # Summary + logger.info("=" * 60) + total_variants = total * len(DURATION_VARIANTS) + logger.info( + f"Done. {total_variants - len(failed)}/{total_variants} videos generated successfully." + ) + if failed: + logger.warning(f"Failed ({len(failed)}):") + for source_file, suffix in failed: + logger.warning(f" {source_file} [{suffix}]") + + +if __name__ == "__main__": + main() +