diff --git a/fastgen/networks/Flux/network.py b/fastgen/networks/Flux/network.py index efc04b2..7244955 100644 --- a/fastgen/networks/Flux/network.py +++ b/fastgen/networks/Flux/network.py @@ -793,4 +793,4 @@ def sample( # Euler step latents = scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] - return latents + return latents \ No newline at end of file diff --git a/fastgen/networks/LTX2/Data/dual_pipe_video_generate.py b/fastgen/networks/LTX2/Data/dual_pipe_video_generate.py new file mode 100644 index 0000000..74de57d --- /dev/null +++ b/fastgen/networks/LTX2/Data/dual_pipe_video_generate.py @@ -0,0 +1,77 @@ +import torch +from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps +from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline +from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number +from ltx_pipelines.utils.media_io import encode_video +from ltx_pipelines.utils.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_NEGATIVE_PROMPT, + DEFAULT_VIDEO_GUIDER_PARAMS, + DEFAULT_AUDIO_GUIDER_PARAMS, + DEFAULT_NUM_FRAMES, + DEFAULT_FRAME_RATE, + DEFAULT_NUM_INFERENCE_STEPS, +) + +# ── Paths ──────────────────────────────────────────────────────────────────── +CHECKPOINT_PATH = "./ltx-2-19b-dev.safetensors" +DISTILLED_LORA_PATH = "./ltx-2-19b-distilled-lora-384.safetensors" +SPATIAL_UPSAMPLER_PATH = "./ltx-2-spatial-upscaler-x2-1.0.safetensors" +GEMMA_ROOT = "./gemma-3-12b-local" +OUTPUT_PATH = "./output_1080p.mp4" + +# ── Generation settings ────────────────────────────────────────────────────── +PROMPT = "The camera opens in a calm, sunlit frog yoga studio. Warm morning light washes over the wooden floor as incense smoke drifts lazily in the air. The senior frog instructor sits cross-legged at the center, eyes closed, voice deep and calm. “We are one with the pond.” All the frogs answer softly: “Ommm…” “We are one with the mud.” “Ommm…” He smiles faintly. “We are one with the flies.” A pause. 
The camera pans to the side towards one frog who twitches, eyes darting. Suddenly its tongue snaps out, catching a fly mid-air and pulling it into its mouth. The master exhales slowly, still serene. “But we do not chase the flies…” Beat. “not during class.” The guilty frog lowers its head in shame, folding its hands back into a meditative pose. The other frogs resume their chant: “Ommm…” Camera holds for a moment on the embarrassed frog, eyes closed too tightly, pretending nothing happened." +#"EXT. SMALL TOWN STREET – MORNING – LIVE NEWS BROADCAST The shot opens on a news reporter standing in front of a row of cordoned-off cars, yellow caution tape fluttering behind him. The light is warm, early sun reflecting off the camera lens. The faint hum of chatter and distant drilling fills the air. The reporter, composed but visibly excited, looks directly into the camera, microphone in hand. Reporter (live): “Thank you, Sylvia. And yes — this is a sentence I never thought I’d say on live television — but this morning, here in the quiet town of New Castle, Vermont… black gold has been found!” He gestures slightly toward the field behind him. Reporter (grinning): “If my cameraman can pan over, you’ll see what all the excitement’s about.” The camera pans right, slowly revealing a construction site surrounded by workers in hard hats. A beat of silence — then, with a sudden roar, a geyser of oil erupts from the ground, blasting upward in a violent plume. Workers cheer and scramble, the black stream glistening in the morning light. The camera shakes slightly, trying to stay focused through the chaos. Reporter (off-screen, shouting over the noise): “There it is, folks — the moment New Castle will never forget!” The camera catches the sunlight gleaming off the oil mist before pulling back, revealing the entire scene — the small-town skyline silhouetted against the wild fountain of oil." 
+SEED = 20173261 +HEIGHT = 1088 # nearest multiple of 64 to 1080 (required for two-stage) +WIDTH = 1920 + + +@torch.inference_mode() +def main() -> None: + # Build pipeline + pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=CHECKPOINT_PATH, + distilled_lora=[ + LoraPathStrengthAndSDOps( + DISTILLED_LORA_PATH, 1.0, LTXV_LORA_COMFY_RENAMING_MAP + ) + ], + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH, + gemma_root=GEMMA_ROOT, + loras=[], + ) + + tiling_config = TilingConfig.default() + + # Generate + video, audio = pipeline( + prompt=PROMPT, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=SEED, + height=HEIGHT, + width=WIDTH, + num_frames=482, # DEFAULT_NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS, + audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS, + images=[], + tiling_config=tiling_config, + ) + + # Encode and save + encode_video( + video=video, + fps=DEFAULT_FRAME_RATE, + audio=audio, + audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=OUTPUT_PATH, + video_chunks_number=get_video_chunks_number(DEFAULT_NUM_FRAMES, tiling_config), + ) + print(f"Video saved to {OUTPUT_PATH}") + + +if __name__ == "__main__": + main() diff --git a/fastgen/networks/LTX2/Data/orig_upsample_comp_generate.py b/fastgen/networks/LTX2/Data/orig_upsample_comp_generate.py new file mode 100644 index 0000000..290db2f --- /dev/null +++ b/fastgen/networks/LTX2/Data/orig_upsample_comp_generate.py @@ -0,0 +1,169 @@ +""" +Batch video generation comparing `prompt` vs `upsampled_prompt` at 5s. 
+Output files: + {stem}_prompt_5s.mp4 + {stem}_upsampled_5s.mp4 +""" + +import json +import logging +import sys +from pathlib import Path + +import torch + +from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps +from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline +from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number +from ltx_pipelines.utils.media_io import encode_video +from ltx_pipelines.utils.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_NEGATIVE_PROMPT, + DEFAULT_VIDEO_GUIDER_PARAMS, + DEFAULT_AUDIO_GUIDER_PARAMS, + DEFAULT_FRAME_RATE, + DEFAULT_NUM_INFERENCE_STEPS, +) + +# ── Logging ─────────────────────────────────────────────────────────────────── +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) + +# ── Paths ───────────────────────────────────────────────────────────────────── +CHECKPOINT_PATH = "./ltx-2-19b-dev.safetensors" +DISTILLED_LORA_PATH = "./ltx-2-19b-distilled-lora-384.safetensors" +SPATIAL_UPSAMPLER_PATH = "./ltx-2-spatial-upscaler-x2-1.0.safetensors" +GEMMA_ROOT = "./gemma-3-12b-local" +PROMPTS_FILE = "./prompts.txt" +OUTPUT_DIR = "./outputs_comparison" + +# ── Generation settings ─────────────────────────────────────────────────────── +SEED = 42 +HEIGHT = 1088 # nearest multiple of 64 to 1080 +WIDTH = 1920 +NUM_FRAMES = 121 # 5s @ 24fps: (8 × 15) + 1 = 121 + +# Prompt variants to compare: output_suffix -> JSON key in prompts.txt +PROMPT_VARIANTS = { + "prompt_5s": "prompt", + "upsampled_5s": "upsampled_prompt", +} + + +def load_prompts(path: str) -> list[dict]: + """Read prompts.txt — one JSON object per line.""" + prompts = [] + with open(path, "r", encoding="utf-8") as f: + for line_no, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + prompts.append(json.loads(line)) + except json.JSONDecodeError as e: + 
logger.warning(f"Skipping line {line_no} — invalid JSON: {e}") + logger.info(f"Loaded {len(prompts)} prompts from {path}") + return prompts + + +def output_path_for(source_file: str, suffix: str, output_dir: str) -> str: + """e.g. '000431.json' + 'prompt_5s' -> './outputs_comparison/000431_prompt_5s.mp4'""" + stem = Path(source_file).stem + return str(Path(output_dir) / f"{stem}_{suffix}.mp4") + + +@torch.inference_mode() +def main() -> None: + Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + + prompts = load_prompts(PROMPTS_FILE) + if not prompts: + logger.error("No prompts found — exiting.") + return + + logger.info("Loading pipeline...") + pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=CHECKPOINT_PATH, + distilled_lora=[ + LoraPathStrengthAndSDOps( + DISTILLED_LORA_PATH, 1.0, LTXV_LORA_COMFY_RENAMING_MAP + ) + ], + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH, + gemma_root=GEMMA_ROOT, + loras=[], + ) + logger.info("Pipeline loaded.") + + tiling_config = TilingConfig.default() + num_chunks = get_video_chunks_number(NUM_FRAMES, tiling_config) + total = len(prompts) + failed = [] # list of (source_file, suffix) + + for idx, entry in enumerate(prompts, 1): + source_file = entry.get("source_file", f"unknown_{idx}.json") + logger.info(f"[{idx}/{total}] {source_file}") + + for suffix, json_key in PROMPT_VARIANTS.items(): + prompt_text = entry.get(json_key, "").strip() + out_path = output_path_for(source_file, suffix, OUTPUT_DIR) + + logger.info(f" [{suffix}] Prompt ({json_key}): {prompt_text[:100]}...") + + if Path(out_path).exists(): + logger.info(f" [{suffix}] Already exists, skipping.") + continue + + if not prompt_text: + logger.error(f" [{suffix}] Key '{json_key}' missing or empty, skipping.") + failed.append((source_file, suffix)) + continue + + try: + video, audio = pipeline( + prompt=prompt_text, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=SEED, + height=HEIGHT, + width=WIDTH, + num_frames=NUM_FRAMES, + frame_rate=DEFAULT_FRAME_RATE, + 
num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS, + audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS, + images=[], + tiling_config=tiling_config, + ) + + encode_video( + video=video, + fps=DEFAULT_FRAME_RATE, + audio=audio, + audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=out_path, + video_chunks_number=num_chunks, + ) + logger.info(f" [{suffix}] Saved: {out_path}") + + except Exception as e: + logger.error(f" [{suffix}] FAILED {source_file}: {e}", exc_info=True) + failed.append((source_file, suffix)) + + # Summary + logger.info("=" * 60) + total_variants = total * len(PROMPT_VARIANTS) + logger.info( + f"Done. {total_variants - len(failed)}/{total_variants} videos generated successfully." + ) + if failed: + logger.warning(f"Failed ({len(failed)}):") + for source_file, suffix in failed: + logger.warning(f" {source_file} [{suffix}]") + + +if __name__ == "__main__": + main() diff --git a/fastgen/networks/LTX2/Data/various_durations_generate.py b/fastgen/networks/LTX2/Data/various_durations_generate.py new file mode 100644 index 0000000..e3d6b2e --- /dev/null +++ b/fastgen/networks/LTX2/Data/various_durations_generate.py @@ -0,0 +1,181 @@ +""" +Batch video generation from prompts.txt using LTX-2 TI2VidTwoStagesPipeline. +Generates each prompt at 3 durations: 5s, 8s, 10s with the same seed. 
+Output files: {stem}_5s.mp4, {stem}_8s.mp4, {stem}_10s.mp4 +""" + +import json +import logging +import sys +from pathlib import Path + +import torch + +from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps +from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline +from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number +from ltx_pipelines.utils.media_io import encode_video +from ltx_pipelines.utils.constants import ( + AUDIO_SAMPLE_RATE, + DEFAULT_NEGATIVE_PROMPT, + DEFAULT_VIDEO_GUIDER_PARAMS, + DEFAULT_AUDIO_GUIDER_PARAMS, + DEFAULT_FRAME_RATE, + DEFAULT_NUM_INFERENCE_STEPS, +) + +# ── Logging ─────────────────────────────────────────────────────────────────── +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) + +# ── Paths ───────────────────────────────────────────────────────────────────── +CHECKPOINT_PATH = "./ltx-2-19b-dev.safetensors" +DISTILLED_LORA_PATH = "./ltx-2-19b-distilled-lora-384.safetensors" +SPATIAL_UPSAMPLER_PATH = "./ltx-2-spatial-upscaler-x2-1.0.safetensors" +GEMMA_ROOT = "./gemma-3-12b-local" +PROMPTS_FILE = "./prompts.txt" +OUTPUT_DIR = "./outputs" + +# ── Generation settings ─────────────────────────────────────────────────────── +SEED = 42 +HEIGHT = 1088 # nearest multiple of 64 to 1080 (required for two-stage) +WIDTH = 1920 + +# Duration variants: label -> num_frames +# Formula: num_frames = (8 × K) + 1 +# 5s @ 24fps: 120 frames → K=15 → 121 +# 8s @ 24fps: 192 frames → K=24 → 193 +# 10s @ 24fps: 240 frames → K=30 → 241 +DURATION_VARIANTS = { + "5s": 121, + "8s": 193, + "10s": 241, +} + + +def load_prompts(path: str) -> list[dict]: + """Read prompts.txt — one JSON object per line.""" + prompts = [] + with open(path, "r", encoding="utf-8") as f: + for line_no, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + 
prompts.append(json.loads(line)) + except json.JSONDecodeError as e: + logger.warning(f"Skipping line {line_no} — invalid JSON: {e}") + logger.info(f"Loaded {len(prompts)} prompts from {path}") + return prompts + + +def output_path_for(source_file: str, suffix: str, output_dir: str) -> str: + """Convert e.g. '000431.json' + '5s' -> './outputs/000431_5s.mp4'""" + stem = Path(source_file).stem + return str(Path(output_dir) / f"{stem}_{suffix}.mp4") + + +@torch.inference_mode() +def main() -> None: + Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True) + + prompts = load_prompts(PROMPTS_FILE) + if not prompts: + logger.error("No prompts found — exiting.") + return + + logger.info("Loading pipeline...") + pipeline = TI2VidTwoStagesPipeline( + checkpoint_path=CHECKPOINT_PATH, + distilled_lora=[ + LoraPathStrengthAndSDOps( + DISTILLED_LORA_PATH, 1.0, LTXV_LORA_COMFY_RENAMING_MAP + ) + ], + spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH, + gemma_root=GEMMA_ROOT, + loras=[], + ) + logger.info("Pipeline loaded.") + + tiling_config = TilingConfig.default() + total = len(prompts) + failed = [] # list of (source_file, suffix) tuples + + for idx, entry in enumerate(prompts, 1): + source_file = entry.get("source_file", f"unknown_{idx}.json") + upsampled = entry.get("upsampled_prompt", "").strip() + original = entry.get("prompt", "") + + logger.info(f"[{idx}/{total}] {source_file}") + logger.info(f" Prompt: {upsampled[:120]}...") + + if not upsampled: + logger.warning(" No upsampled_prompt, falling back to prompt.") + upsampled = original + + if not upsampled: + logger.error(f" No prompt available for {source_file}, skipping all durations.") + for suffix in DURATION_VARIANTS: + failed.append((source_file, suffix)) + continue + + # Generate each duration variant for this prompt + for suffix, num_frames in DURATION_VARIANTS.items(): + out_path = output_path_for(source_file, suffix, OUTPUT_DIR) + + logger.info(f" [{suffix}] {num_frames} frames → {out_path}") + + if 
Path(out_path).exists(): + logger.info(f" [{suffix}] Already exists, skipping.") + continue + + try: + video, audio = pipeline( + prompt=upsampled, + negative_prompt=DEFAULT_NEGATIVE_PROMPT, + seed=SEED, # same seed across all durations + height=HEIGHT, + width=WIDTH, + num_frames=num_frames, + frame_rate=DEFAULT_FRAME_RATE, + num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS, + video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS, + audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS, + images=[], + tiling_config=tiling_config, + ) + + encode_video( + video=video, + fps=DEFAULT_FRAME_RATE, + audio=audio, + audio_sample_rate=AUDIO_SAMPLE_RATE, + output_path=out_path, + video_chunks_number=get_video_chunks_number(num_frames, tiling_config), + ) + logger.info(f" [{suffix}] Saved: {out_path}") + + except Exception as e: + logger.error(f" [{suffix}] FAILED {source_file}: {e}", exc_info=True) + failed.append((source_file, suffix)) + + # Summary + logger.info("=" * 60) + total_variants = total * len(DURATION_VARIANTS) + logger.info( + f"Done. {total_variants - len(failed)}/{total_variants} videos generated successfully." + ) + if failed: + logger.warning(f"Failed ({len(failed)}):") + for source_file, suffix in failed: + logger.warning(f" {source_file} [{suffix}]") + + +if __name__ == "__main__": + main() + diff --git a/fastgen/networks/LTX2/network.py b/fastgen/networks/LTX2/network.py new file mode 100644 index 0000000..38b5d96 --- /dev/null +++ b/fastgen/networks/LTX2/network.py @@ -0,0 +1,799 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +LTX-2 FastGen network implementation (video-only). + +Uses the local customized pipeline_ltx2.py and transformer_ltx2.py which +support audio_enabled=False, so no audio weights are allocated and no audio +ops run during training or inference. 
+ +Follows the FastGen network pattern established by Flux and Wan: + - Inherits from FastGenNetwork + - Monkey-patches classify_forward onto self.transformer + - forward() operates entirely in video latent space [B, C, F, H, W] + - sample() calls self() (i.e. forward()) — NOT self.transformer directly + - feature_indices extracts video hidden_states for the discriminator +""" + +import types +from typing import Any, List, Optional, Set, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.distributed.fsdp import fully_shard +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizerFast + +# ---- Local customized modules (same folder as this file) ---- +from .pipeline_ltx2 import LTX2Pipeline +from .transformer_ltx2 import LTX2VideoTransformer3DModel, LTX2VideoTransformerBlock + +from diffusers.models.autoencoders import AutoencoderKLLTX2Video +from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler + +from fastgen.networks.network import FastGenNetwork +from fastgen.networks.noise_schedule import NET_PRED_TYPES +from fastgen.utils.distributed.fsdp import apply_fsdp_checkpointing +import fastgen.utils.logging_utils as logger + + +# --------------------------------------------------------------------------- +# Latent pack / unpack helpers (video only) +# --------------------------------------------------------------------------- + +def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + """[B, C, F, H, W] → [B, F//pt * H//p * W//p, C*pt*p*p]""" + B, C, F, H, W = latents.shape + latents = latents.reshape(B, C, F // patch_size_t, patch_size_t, + H // patch_size, patch_size, + W // patch_size, patch_size) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + +def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, + 
patch_size: int = 1, patch_size_t: int = 1, +) -> torch.Tensor: + """[B, T, D] → [B, C, F, H, W]""" + B = latents.size(0) + latents = latents.reshape(B, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + +def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, + scaling_factor: float = 1.0, +) -> torch.Tensor: + mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + return (latents - mean) * scaling_factor / std + + +def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, + scaling_factor: float = 1.0, +) -> torch.Tensor: + mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + return latents * std / scaling_factor + mean + + +def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, +) -> torch.Tensor: + B, T, H, L = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + token_indices = torch.arange(T, device=device).unsqueeze(0) + if padding_side == "right": + mask = token_indices < sequence_lengths[:, None] + else: + start = T - sequence_lengths[:, None] + mask = token_indices >= start + mask = mask[:, :, None, None] + + masked = text_hidden_states.masked_fill(~mask, 0.0) + num_valid = (sequence_lengths * H).view(B, 1, 1, 1) + mean = masked.sum(dim=(1, 2), keepdim=True) / (num_valid + eps) + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + normed = 
(text_hidden_states - mean) / (x_max - x_min + eps) * scale_factor + normed = normed.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, H * L) + normed = normed.masked_fill(~mask_flat, 0.0) + return normed.to(original_dtype) + + +def _calculate_shift( + image_seq_len: int, + base_seq_len: int = 1024, + max_seq_len: int = 4096, + base_shift: float = 0.95, + max_shift: float = 2.05, +) -> float: + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + return image_seq_len * m + b + + +def _retrieve_timesteps(scheduler, num_inference_steps, device, sigmas=None, mu=None): + kwargs = {} + if mu is not None: + kwargs["mu"] = mu + if sigmas is not None: + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + return scheduler.timesteps, len(scheduler.timesteps) + + +# --------------------------------------------------------------------------- +# classify_forward — monkey-patched onto self.transformer (video-only) +# --------------------------------------------------------------------------- + +def classify_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_attention_mask: Optional[torch.Tensor] = None, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + fps: float = 24.0, + video_coords: Optional[torch.Tensor] = None, + attention_kwargs: Optional[dict] = None, + return_dict: bool = False, + # FastGen distillation kwargs + return_features_early: bool = False, + feature_indices: Optional[Set[int]] = None, + return_logvar: bool = False, +) -> Union[ + torch.Tensor, # video_output only + Tuple[torch.Tensor, List[torch.Tensor]], # (video_output, features) + List[torch.Tensor], # features only (early exit) +]: + """ + Video-only classify_forward monkey-patched onto LTX2VideoTransformer3DModel. 
+ + Since the transformer is built with audio_enabled=False, all audio arguments + are absent. Only video hidden_states are processed and stored as features. + + Returns + ------- + Normal mode (feature_indices empty, return_features_early False): + video_output [B, T_v, C_out] + + Feature mode (feature_indices non-empty, return_features_early False): + (video_output, List[video_feature_tensors]) + + Early-exit mode (return_features_early True): + List[video_feature_tensors] + """ + from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + if USE_PEFT_BACKEND: + scale_lora_layers(self, lora_scale) + + if feature_indices is None: + feature_indices = set() + + if return_features_early and len(feature_indices) == 0: + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) + return [] + + # -- Attention mask conversion -- + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + + # 1. RoPE + if video_coords is None: + video_coords = self.rope.prepare_video_coords( + batch_size, num_frames, height, width, hidden_states.device, fps=fps + ) + video_rotary_emb = self.rope(video_coords, device=hidden_states.device) + + # 2. Patchify + hidden_states = self.proj_in(hidden_states) + + # 3. Timestep embeddings + temb, embedded_timestep = self.time_embed( + timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + # 4. 
Prompt embeddings + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + # 5. Block loop — video only, with feature extraction + features: List[torch.Tensor] = [] + + for idx, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states, _ = self._gradient_checkpointing_func( + block, + hidden_states, + None, # audio_hidden_states — not used + encoder_hidden_states, + None, # audio_encoder_hidden_states — not used + temb, + None, # temb_audio + None, # video_cross_attn_scale_shift + None, # audio_cross_attn_scale_shift + None, # video_cross_attn_a2v_gate + None, # audio_cross_attn_v2a_gate + video_rotary_emb, + None, # audio_rotary_emb + None, # video_cross_attn_rotary_emb + None, # audio_cross_attn_rotary_emb + encoder_attention_mask, + None, # audio_encoder_attention_mask + ) + else: + hidden_states, _ = block( + hidden_states=hidden_states, + audio_hidden_states=None, + encoder_hidden_states=encoder_hidden_states, + audio_encoder_hidden_states=None, + temb=temb, + temb_audio=None, + temb_ca_scale_shift=None, + temb_ca_audio_scale_shift=None, + temb_ca_gate=None, + temb_ca_audio_gate=None, + video_rotary_emb=video_rotary_emb, + audio_rotary_emb=None, + ca_video_rotary_emb=None, + ca_audio_rotary_emb=None, + encoder_attention_mask=encoder_attention_mask, + audio_encoder_attention_mask=None, + audio_enabled=False, + ) + + if idx in feature_indices: + features.append(hidden_states.clone()) + + if return_features_early and len(features) == len(feature_indices): + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) + return features + + # 6. 
Output layer + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + video_output = self.proj_out(hidden_states) + + # -- Logvar (optional) -- + logvar = None + if return_logvar: + assert hasattr(self, "logvar_linear"), ( + "logvar_linear must exist on transformer. It is added by LTX2.__init__." + ) + logvar = self.logvar_linear(temb.mean(dim=1)) # [B, 1] + + # -- Assemble output -- + if len(feature_indices) == 0: + out = video_output + else: + out = (video_output, features) + + if USE_PEFT_BACKEND: + unscale_lora_layers(self, lora_scale) + + if return_logvar: + return out, logvar + return out + + +# --------------------------------------------------------------------------- +# Text encoder wrapper +# --------------------------------------------------------------------------- + +class LTX2TextEncoder(nn.Module): + """Wraps Gemma3 text encoder for LTX-2 conditioning.""" + + def __init__(self, model_id: str): + super().__init__() + self.tokenizer = GemmaTokenizerFast.from_pretrained(model_id, subfolder="tokenizer") + self.tokenizer.padding_side = "left" + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.text_encoder = Gemma3ForConditionalGeneration.from_pretrained( + model_id, subfolder="text_encoder" + ) + self.text_encoder.eval().requires_grad_(False) + + @torch.no_grad() + def encode( + self, + prompt: Union[str, List[str]], + precision: torch.dtype = torch.bfloat16, + max_sequence_length: int = 1024, + scale_factor: int = 8, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(prompt, str): + prompt = [prompt] + + device = next(self.text_encoder.parameters()).device + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + 
add_special_tokens=True, + return_tensors="pt", + ) + input_ids = text_inputs.input_ids.to(device) + attention_mask = text_inputs.attention_mask.to(device) + + outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + ) + hidden_states = torch.stack(outputs.hidden_states, dim=-1) + sequence_lengths = attention_mask.sum(dim=-1) + + prompt_embeds = _pack_text_embeds( + hidden_states, + sequence_lengths, + device=device, + padding_side=self.tokenizer.padding_side, + scale_factor=scale_factor, + ).to(precision) + + return prompt_embeds, attention_mask + + def to(self, *args, **kwargs): + self.text_encoder.to(*args, **kwargs) + return self + + +# --------------------------------------------------------------------------- +# Main LTX-2 network (video-only) +# --------------------------------------------------------------------------- + +class LTX2(FastGenNetwork): + """ + FastGen wrapper for LTX-2, video-only distillation. + + Uses the local customized transformer_ltx2.py (audio_enabled=False) and + pipeline_ltx2.py (generate_audio=False) so no audio weights are allocated + and no audio ops run at any point. 
+ + Distillation targets video only: + - forward() receives and returns video latents [B, C, F, H, W] + - classify_forward extracts video hidden_states at requested block indices + - sample() calls self() (forward()) — the pipeline is used only for its + helper utilities (latent prep, scheduler config) + """ + + MODEL_ID = "Lightricks/LTX-2" + + def __init__( + self, + model_id: str = MODEL_ID, + net_pred_type: str = "flow", + schedule_type: str = "rf", + disable_grad_ckpt: bool = False, + load_pretrained: bool = True, + **model_kwargs, + ): + super().__init__(net_pred_type=net_pred_type, schedule_type=schedule_type, **model_kwargs) + + self.model_id = model_id + self._disable_grad_ckpt = disable_grad_ckpt + + self._initialize_network(model_id, load_pretrained) + + # Monkey-patch classify_forward (video-only version) + self.transformer.forward = types.MethodType(classify_forward, self.transformer) + + if disable_grad_ckpt: + self.transformer.disable_gradient_checkpointing() + else: + self.transformer.enable_gradient_checkpointing() + + torch.cuda.empty_cache() + + # ------------------------------------------------------------------ + # Initialisation + # ------------------------------------------------------------------ + + def _initialize_network(self, model_id: str, load_pretrained: bool) -> None: + in_meta_context = self._is_in_meta_context() + should_load_weights = load_pretrained and not in_meta_context + + # -- Transformer: audio_enabled=False → no audio weights allocated -- + if should_load_weights: + logger.info("Loading LTX-2 transformer from pretrained (audio_enabled=False)") + # Load pretrained AV checkpoint then drop audio keys via strict=False + av_transformer = LTX2VideoTransformer3DModel.from_pretrained( + model_id, subfolder="transformer" + ) + config = av_transformer.config + # Build video-only transformer + self.transformer = LTX2VideoTransformer3DModel.from_config(config, audio_enabled=False) + missing, unexpected = 
self.transformer.load_state_dict( + av_transformer.state_dict(), strict=False + ) + assert len(missing) == 0, f"Missing video keys: {missing}" + logger.info(f"Dropped {len(unexpected)} audio keys from pretrained checkpoint") + del av_transformer + else: + config = LTX2VideoTransformer3DModel.load_config(model_id, subfolder="transformer") + if in_meta_context: + logger.info("Initializing LTX-2 transformer on meta device (audio_enabled=False)") + else: + logger.warning("LTX-2 transformer initialized from config only — no pretrained weights!") + self.transformer = LTX2VideoTransformer3DModel.from_config(config, audio_enabled=False) + + # inner_dim for logvar_linear + inner_dim = ( + self.transformer.config.num_attention_heads + * self.transformer.config.attention_head_dim + ) + self.transformer.logvar_linear = nn.Linear(inner_dim, 1) + logger.info(f"Added logvar_linear ({inner_dim} → 1) to transformer") + + # -- Connectors -- + if should_load_weights: + self.connectors: LTX2TextConnectors = LTX2TextConnectors.from_pretrained( + model_id, subfolder="connectors" + ) + else: + logger.warning("Skipping connector pretrained load") + self.connectors = None + + # -- VAE (video only — no audio_vae, no vocoder) -- + if should_load_weights: + self.vae: AutoencoderKLLTX2Video = AutoencoderKLLTX2Video.from_pretrained( + model_id, subfolder="vae" + ) + self.vae.eval().requires_grad_(False) + self._cache_vae_constants() + + # -- Scheduler -- + self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + model_id, subfolder="scheduler" + ) + + def _cache_vae_constants(self) -> None: + self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio + self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio + self.transformer_spatial_patch_size = self.transformer.config.patch_size + self.transformer_temporal_patch_size = self.transformer.config.patch_size_t + + # ------------------------------------------------------------------ + # Preprocessors (lazy) 
    # ------------------------------------------------------------------

    def init_preprocessors(self):
        """Lazily build the preprocessing components (text encoder + connectors).

        Safe to call more than once: each component is only created if absent.
        """
        if not hasattr(self, "text_encoder") or self.text_encoder is None:
            self.init_text_encoder()
        if self.connectors is None:
            # Connectors were skipped at construction time (no pretrained load);
            # fetch them now from the same model id.
            self.connectors = LTX2TextConnectors.from_pretrained(
                self.model_id, subfolder="connectors"
            )

    def init_text_encoder(self):
        """Instantiate the LTX-2 text encoder for this model id."""
        self.text_encoder = LTX2TextEncoder(model_id=self.model_id)

    # ------------------------------------------------------------------
    # Device movement
    # ------------------------------------------------------------------

    def to(self, *args, **kwargs):
        """Move the module and its auxiliary components to a device/dtype.

        The text encoder, connectors, and VAE may not be registered as
        submodules, so they are moved explicitly when present.
        """
        super().to(*args, **kwargs)
        for attr in ("text_encoder", "connectors", "vae"):
            obj = getattr(self, attr, None)
            if obj is not None:
                obj.to(*args, **kwargs)
        return self

    # ------------------------------------------------------------------
    # FSDP
    # ------------------------------------------------------------------

    def fully_shard(self, **kwargs):
        """Apply FSDP2 sharding to the transformer (per-block, then root).

        If plain gradient checkpointing is enabled it is replaced with
        FSDP-compatible activation checkpointing on each transformer block.
        NOTE(review): the original (whitespace-mangled) source is ambiguous
        about whether apply_fsdp_checkpointing runs unconditionally or only
        when checkpointing was enabled — reconstructed as conditional; confirm.
        """
        if self.transformer.gradient_checkpointing:
            self.transformer.disable_gradient_checkpointing()
            apply_fsdp_checkpointing(
                self.transformer,
                check_fn=lambda b: isinstance(b, LTX2VideoTransformerBlock),
            )
            logger.info("Applied FSDP activation checkpointing to transformer blocks")

        # Shard each block individually before the root module, so parameters
        # are grouped per block rather than in one flat unit.
        for block in self.transformer.transformer_blocks:
            fully_shard(block, **kwargs)
        fully_shard(self.transformer, **kwargs)

    # ------------------------------------------------------------------
    # reset_parameters (FSDP meta device)
    # ------------------------------------------------------------------

    def reset_parameters(self):
        """Re-initialize parameters after FSDP materializes meta-device tensors.

        Linear layers get Xavier-uniform weights / zero bias; embeddings get
        N(0, 0.02). Other module types are left to the parent implementation.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)
        # NOTE(review): plain nn.Module has no reset_parameters — this assumes
        # the (not visible) base class defines one; confirm against it.
        super().reset_parameters()

    # ------------------------------------------------------------------
    # forward() — video-only distillation interface
    # ------------------------------------------------------------------

    def forward(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        r: Optional[torch.Tensor] = None,
        fps: float = 24.0,
        return_features_early: bool = False,
        feature_indices: Optional[Set[int]] = None,
        return_logvar: bool = False,
        fwd_pred_type: Optional[str] = None,
        **fwd_kwargs,
    ) -> Union[torch.Tensor, List[torch.Tensor], Tuple]:
        """
        Training forward pass: video latents [B, C, F, H, W] → video latents.

        Args:
            x_t: Noisy video latents [B, C, F, H, W].
            t: Timesteps [B].
            condition: (prompt_embeds [B, T, D], attention_mask [B, T]).
            r: Unused in this body; accepted for interface compatibility.
            fps: Frames per second for RoPE coords.
            return_features_early: Exit once all feature_indices are collected.
            feature_indices: Block indices to extract hidden states from.
            return_logvar: Return log-variance alongside output.
            fwd_pred_type: Override prediction type.

        Returns:
            Normal: video_out [B, C, F, H, W]
            With features: (video_out, List[feature_tensors])
            Early exit: List[feature_tensors]
            With logvar: (above, logvar [B, 1])
        """
        if feature_indices is None:
            feature_indices = set()
        if return_features_early and len(feature_indices) == 0:
            # Nothing requested: the early-exit contract returns an empty list.
            return []

        if fwd_pred_type is None:
            fwd_pred_type = self.net_pred_type
        else:
            assert fwd_pred_type in NET_PRED_TYPES, f"Unsupported pred type: {fwd_pred_type}"

        batch_size = x_t.shape[0]
        _, _, latent_f, latent_h, latent_w = x_t.shape

        # -- Text conditioning --
        prompt_embeds, attention_mask = condition
        # Turn the 0/1 mask into a large-negative additive attention bias
        # (0 where attended, -1e6 where padded).
        additive_mask = (1 - attention_mask.to(prompt_embeds.dtype)) * -1_000_000.0
        # Connectors return (video_embeds, audio_embeds, attn_mask);
        # we only use the video branch since audio_enabled=False.
        connector_video_embeds, _, connector_attn_mask = self.connectors(
            prompt_embeds, additive_mask, additive_mask=True
        )

        # -- Timestep --
        timestep = t.to(x_t.dtype).expand(batch_size)

        # -- Pack video latents --
        hidden_states = _pack_latents(
            x_t, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
        )

        # -- RoPE video coords --
        video_coords = self.transformer.rope.prepare_video_coords(
            batch_size, latent_f, latent_h, latent_w, x_t.device, fps=fps
        )

        # -- Transformer forward (our patched classify_forward, video-only) --
        model_outputs = self.transformer(
            hidden_states=hidden_states,
            encoder_hidden_states=connector_video_embeds,
            encoder_attention_mask=connector_attn_mask,
            timestep=timestep,
            num_frames=latent_f,
            height=latent_h,
            width=latent_w,
            fps=fps,
            video_coords=video_coords,
            return_features_early=return_features_early,
            feature_indices=feature_indices,
            return_logvar=return_logvar,
        )

        # -- Early exit --
        if return_features_early:
            return model_outputs  # List[Tensor]

        # -- Unpack logvar --
        if return_logvar:
            out, logvar = model_outputs[0], model_outputs[1]
        else:
            out = model_outputs

        # -- Separate video output from features --
        if len(feature_indices) == 0:
            video_packed = out  # [B, T_v, C]
            features = None
        else:
            # When features were requested the transformer returns a pair.
            video_packed, features = out[0], out[1]

        # -- Unpack → [B, C, F, H, W] --
        video_out = _unpack_latents(
            video_packed, latent_f, latent_h, latent_w,
            self.transformer_spatial_patch_size, self.transformer_temporal_patch_size,
        )

        # -- Prediction type conversion --
        # Convert the network's native prediction (net_pred_type) into the
        # caller-requested parametrization (e.g. "flow").
        video_out = self.noise_scheduler.convert_model_output(
            x_t, video_out, t,
            src_pred_type=self.net_pred_type,
            target_pred_type=fwd_pred_type,
        )

        # -- Assemble final output --
        if features is not None:
            out = (video_out, features)
        else:
            out = video_out

        if return_logvar:
            return out, logvar
        return out
    # ------------------------------------------------------------------
    # sample() — full denoising loop; calls self() (forward())
    # Works entirely in unpacked [B, C, F, H, W] space so forward() is
    # called with the correct input shape at every step.
    # ------------------------------------------------------------------

    @torch.no_grad()
    def sample(
        self,
        noise: torch.Tensor,
        condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        neg_condition: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        guidance_scale: float = 4.0,
        num_steps: int = 40,
        fps: float = 24.0,
        frame_rate: Optional[float] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, None]:
        """
        Denoising loop for video generation. Calls self() at each step.

        noise must be unpacked video latents [B, C, F, H, W].

        Returns
        -------
        (video_latents [B, C, F, H, W], None)
            Denormalised video latents ready for VAE decode. Audio is always None.
        """
        # frame_rate is an alias for fps and wins when both are given.
        fps = frame_rate if frame_rate is not None else fps
        do_cfg = neg_condition is not None and guidance_scale > 1.0

        transformer_dtype = self.transformer.dtype
        transformer_device = next(self.transformer.parameters()).device

        # noise must arrive as [B, C, F, H, W] (unpacked)
        assert noise.ndim == 5, "sample() expects unpacked latents [B, C, F, H, W]"
        video_latents = noise.to(device=transformer_device, dtype=transformer_dtype)

        # -- Build combined CFG condition (processed once, reused every step) --
        if do_cfg:
            neg_embeds, neg_mask = neg_condition
            cond_embeds, cond_mask = condition
            # Negative first, conditional second — the chunk(2) below relies
            # on this ordering.
            combined_condition = (
                torch.cat([neg_embeds, cond_embeds], dim=0).to(
                    device=transformer_device, dtype=transformer_dtype
                ),
                torch.cat([neg_mask, cond_mask], dim=0).to(device=transformer_device),
            )
        else:
            embeds, mask = condition
            combined_condition = (
                embeds.to(device=transformer_device, dtype=transformer_dtype),
                mask.to(device=transformer_device),
            )

        # -- Scheduler timesteps with mu shift --
        B, C, latent_f, latent_h, latent_w = video_latents.shape
        sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)
        # mu shifts the sigma schedule based on token count (longer videos
        # get a stronger shift), mirroring the Flux/LTX pipeline behaviour.
        video_seq_len = latent_f * latent_h * latent_w
        mu = _calculate_shift(
            video_seq_len,
            self.scheduler.config.get("base_image_seq_len", 1024),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.95),
            self.scheduler.config.get("max_shift", 2.05),
        )
        timesteps, num_steps = _retrieve_timesteps(
            self.scheduler, num_steps, transformer_device, sigmas=sigmas, mu=mu
        )

        # -- Denoising loop (unpacked latents throughout) --
        for t in timesteps:
            # Duplicate along batch for CFG
            latent_input = torch.cat([video_latents] * 2) if do_cfg else video_latents
            t_input = (
                t.to(dtype=transformer_dtype, device=transformer_device)
                .expand(latent_input.shape[0])
            )

            # self() → forward() — expects and returns [B, C, F, H, W]
            noise_pred = self(
                latent_input,
                t_input,
                condition=combined_condition,
                fps=fps,
                fwd_pred_type="flow",
            )

            if do_cfg:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )

            # Scheduler step — operates on packed tokens internally but we keep
            # video_latents unpacked; use the pipeline's _pack/_unpack helpers.
            video_packed = _pack_latents(
                video_latents,
                self.transformer_spatial_patch_size,
                self.transformer_temporal_patch_size,
            )
            noise_pred_packed = _pack_latents(
                noise_pred,
                self.transformer_spatial_patch_size,
                self.transformer_temporal_patch_size,
            )
            stepped_packed = self.scheduler.step(
                noise_pred_packed, t, video_packed, return_dict=False
            )[0]
            video_latents = _unpack_latents(
                stepped_packed, latent_f, latent_h, latent_w,
                self.transformer_spatial_patch_size,
                self.transformer_temporal_patch_size,
            )

        # -- Denormalise --
        # Undo the VAE latent normalization so the result decodes directly.
        video_latents = _denormalize_latents(
            video_latents,
            self.vae.latents_mean,
            self.vae.latents_std,
            self.vae.config.scaling_factor,
        )

        return video_latents, None
diff --git a/fastgen/networks/LTX2/pipeline_ltx2.py b/fastgen/networks/LTX2/pipeline_ltx2.py
new file mode 100644
index 0000000..e644c64
--- /dev/null
+++ b/fastgen/networks/LTX2/pipeline_ltx2.py
@@ -0,0 +1,1083 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +import copy +import inspect +from typing import Any, Callable + +import numpy as np +import torch +from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast + +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.loaders import FromSingleFileMixin, LTX2LoraLoaderMixin +from diffusers.models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video +from diffusers.models.transformers import LTX2VideoTransformer3DModel +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.ltx2.connectors import LTX2TextConnectors +from diffusers.pipelines.ltx2.pipeline_output import LTX2PipelineOutput +from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTX2Pipeline + >>> from diffusers.pipelines.ltx2.export_utils import encode_video + + >>> pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. 
The scene appears to be real-life footage"
    >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

    >>> frame_rate = 24.0
    >>> video, audio = pipe(
    ...     prompt=prompt,
    ...     negative_prompt=negative_prompt,
    ...     width=768,
    ...     height=512,
    ...     num_frames=121,
    ...     frame_rate=frame_rate,
    ...     num_inference_steps=40,
    ...     guidance_scale=4.0,
    ...     output_type="np",
    ...     return_dict=False,
    ... )

    >>> encode_video(
    ...     video[0],
    ...     fps=frame_rate,
    ...     audio=audio[0].float().cpu(),
    ...     audio_sample_rate=pipe.vocoder.config.output_sampling_rate,  # should be 24000
    ...     output_path="video.mp4",
    ... )
    ```
"""


# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the scheduler shift `mu` from the token sequence length."""
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: int | None = None,
    device: str | torch.device | None = None,
    timesteps: list[int] | None = None,
    sigmas: list[float] | None = None,
    **kwargs,
):
    """Configure the scheduler (from a step count, custom timesteps, or custom
    sigmas — mutually exclusive) and return (timesteps, num_inference_steps)."""
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        # Custom timestep schedule: only usable if set_timesteps accepts it.
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        # Custom sigma schedule: only usable if set_timesteps accepts it.
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Rescale the CFG output toward the text prediction's per-sample std,
    blending by `guidance_rescale` (0 = unchanged) to counter over-saturation."""
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
    r"""
    Pipeline for text-to-video generation.

    Reference: https://github.com/Lightricks/LTX-Video

    Args:
        transformer ([`LTX2VideoTransformer3DModel`]):
            Conditional Transformer architecture to denoise the encoded video latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLLTXVideo`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`Gemma3ForConditionalGeneration`]):
            Gemma3 text encoder.
+ tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Tokenizer for the text encoder. + connectors ([`LTX2TextConnectors`]): + Text connector stack used to adapt text encoder hidden states for the video and audio branches. + vocoder ([`LTX2Vocoder`]): + Vocoder to convert mel spectrograms to waveforms. + """ + + model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTX2Video, + audio_vae: AutoencoderKLLTX2Audio, + text_encoder: Gemma3ForConditionalGeneration, + tokenizer: GemmaTokenizer | GemmaTokenizerFast, + connectors: LTX2TextConnectors, + transformer: LTX2VideoTransformer3DModel, + vocoder: LTX2Vocoder, + ): + super().__init__() + + self.register_modules( + vae=vae, + audio_vae=audio_vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + connectors=connectors, + transformer=transformer, + vocoder=vocoder, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = ( + self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32 + ) + self.vae_temporal_compression_ratio = ( + self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8 + ) + self.audio_vae_mel_compression_ratio = ( + self.audio_vae.mel_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.audio_vae_temporal_compression_ratio = ( + self.audio_vae.temporal_compression_ratio if getattr(self, "audio_vae", None) is not None else 4 + ) + self.transformer_spatial_patch_size = ( + self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1 + ) + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1 + ) + + self.audio_sampling_rate = ( + self.audio_vae.config.sample_rate 
if getattr(self, "audio_vae", None) is not None else 16000 + ) + self.audio_hop_length = ( + self.audio_vae.config.mel_hop_length if getattr(self, "audio_vae", None) is not None else 160 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024 + ) + + @staticmethod + def _pack_text_embeds( + text_hidden_states: torch.Tensor, + sequence_lengths: torch.Tensor, + device: str | torch.device, + padding_side: str = "left", + scale_factor: int = 8, + eps: float = 1e-6, + ) -> torch.Tensor: + batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape + original_dtype = text_hidden_states.dtype + + token_indices = torch.arange(seq_len, device=device).unsqueeze(0) + if padding_side == "right": + mask = token_indices < sequence_lengths[:, None] + elif padding_side == "left": + start_indices = seq_len - sequence_lengths[:, None] + mask = token_indices >= start_indices + else: + raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}") + mask = mask[:, :, None, None] + + masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0) + num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1) + masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps) + + x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True) + x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True) + + normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps) + normalized_hidden_states = normalized_hidden_states * scale_factor + + normalized_hidden_states = normalized_hidden_states.flatten(2) + mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers) + normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0) + 
        normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
        return normalized_hidden_states

    def _get_gemma_prompt_embeds(
        self,
        prompt: str | list[str],
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 1024,
        scale_factor: int = 8,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Tokenize and encode `prompt` with the Gemma text encoder and return
        (packed multi-layer prompt embeddings, attention mask), repeated
        `num_videos_per_prompt` times along the batch dimension.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if getattr(self, "tokenizer", None) is not None:
            # Left-pad so the final tokens line up; reuse EOS when no pad token.
            self.tokenizer.padding_side = "left"
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

        prompt = [p.strip() for p in prompt]
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask
        text_input_ids = text_input_ids.to(device)
        prompt_attention_mask = prompt_attention_mask.to(device)

        text_encoder_outputs = self.text_encoder(
            input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
        )
        # Stack every encoder layer's hidden states along a trailing axis:
        # tuple of [B, T, D] → [B, T, D, L].
        text_encoder_hidden_states = text_encoder_outputs.hidden_states
        text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
        sequence_lengths = prompt_attention_mask.sum(dim=-1)

        prompt_embeds = self._pack_text_embeds(
            text_encoder_hidden_states,
            sequence_lengths,
            device=device,
            padding_side=self.tokenizer.padding_side,
            scale_factor=scale_factor,
        )
        prompt_embeds = prompt_embeds.to(dtype=dtype)

        # Duplicate per requested video, then fold into the batch dimension.
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)

        return prompt_embeds, prompt_attention_mask

    def encode_prompt(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] | None = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: torch.Tensor | None = None,
        negative_prompt_embeds: torch.Tensor | None = None,
        prompt_attention_mask: torch.Tensor | None = None,
        negative_prompt_attention_mask: torch.Tensor | None = None,
        max_sequence_length: int = 1024,
        scale_factor: int = 8,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Encode positive (and, under CFG, negative) prompts into embedding
        tensors and attention masks, computing only what was not passed in.
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                scale_factor=scale_factor,
                device=device,
                dtype=dtype,
            )

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            # An empty negative prompt is valid: broadcast it to the batch.
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                scale_factor=scale_factor,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_on_step_end_tensor_inputs=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
    ):
        """Validate call arguments before generation, raising ValueError on
        conflicting or malformed prompt/embedding inputs.
        """
        # Spatial dims must align with the VAE's 32x spatial compression.
        if height % 32 != 0 or width % 32 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # `prompt` and `prompt_embeds` are mutually exclusive; exactly one is required.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # Embeddings always travel with their attention masks.
        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

        # Positive/negative pairs must agree in shape for CFG batching.
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
            if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
                raise ValueError(
                    "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
                    f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
                    f" {negative_prompt_attention_mask.shape}."
+ ) + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + @staticmethod + def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + 
latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents - latents_mean) / latents_std + + @staticmethod + def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor): + latents_mean = latents_mean.to(latents.device, latents.dtype) + latents_std = latents_std.to(latents.device, latents.dtype) + return (latents * latents_std) + latents_mean + + @staticmethod + def _create_noised_state( + latents: torch.Tensor, noise_scale: float | torch.Tensor, generator: torch.Generator | None = None + ): + noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype) + noised_latents = noise_scale * noise + (1 - noise_scale) * latents + return noised_latents + + @staticmethod + def _pack_audio_latents( + latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None + ) -> torch.Tensor: + if patch_size is not None and patch_size_t is not None: + batch_size, num_channels, latent_length, latent_mel_bins = latents.shape + post_patch_latent_length = latent_length / patch_size_t + post_patch_mel_bins = latent_mel_bins / patch_size + latents = latents.reshape( + batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size + ) + latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2) + else: + latents = latents.transpose(1, 2).flatten(2, 3) + return latents + + @staticmethod + def _unpack_audio_latents( + latents: torch.Tensor, + latent_length: int, + num_mel_bins: int, + patch_size: int | None = None, + patch_size_t: int | None = None, + ) -> torch.Tensor: + if patch_size is not None and patch_size_t is not None: + batch_size = latents.size(0) + latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size) + latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3) + else: + latents = 
latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2) + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 768, + num_frames: int = 121, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 5: + latents = self._normalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents + + def prepare_audio_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 8, + audio_latent_length: int = 1, + num_mel_bins: int = 64, + noise_scale: float = 0.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + generator: torch.Generator | None = None, + latents: torch.Tensor | None = None, + ) -> torch.Tensor: + if latents is not None: + if latents.ndim == 4: + latents = self._pack_audio_latents(latents) + if latents.ndim != 3: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is [batch_size, num_seq, num_features]." + ) + latents = self._normalize_audio_latents(latents, self.audio_vae.latents_mean, self.audio_vae.latents_std) + latents = self._create_noised_state(latents, noise_scale, generator) + return latents.to(device=device, dtype=dtype) + + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_audio_latents(latents) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: str | list[str] = None, + negative_prompt: str | list[str] | None = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + frame_rate: float = 24.0, + num_inference_steps: int = 40, + sigmas: list[float] | None = None, + timesteps: list[int] = None, + guidance_scale: float = 4.0, + guidance_rescale: float = 0.0, + noise_scale: float = 0.0, + num_videos_per_prompt: int = 1, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + audio_latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + prompt_attention_mask: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + negative_prompt_attention_mask: torch.Tensor | None = None, + decode_timestep: float | list[float] = 0.0, + decode_noise_scale: float | list[float] | None = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], + max_sequence_length: int = 1024, + generate_audio: bool = True, # NEW: set False to skip all audio ops + ): + r""" + 
Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `list[str]`, *optional*):
+                The prompt or prompts to guide video generation.
+            height (`int`, *optional*, defaults to `512`):
+                The height in pixels of the generated video.
+            width (`int`, *optional*, defaults to `768`):
+                The width in pixels of the generated video.
+            num_frames (`int`, *optional*, defaults to `121`):
+                The number of video frames to generate.
+            frame_rate (`float`, *optional*, defaults to `24.0`):
+                The frames per second (FPS) of the generated video.
+            num_inference_steps (`int`, *optional*, defaults to 40):
+                The number of denoising steps.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas for the denoising process.
+            timesteps (`list[int]`, *optional*):
+                Custom timesteps for the denoising process.
+            guidance_scale (`float`, *optional*, defaults to `4.0`):
+                Classifier-free guidance scale.
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor.
+            noise_scale (`float`, *optional*, defaults to `0.0`):
+                Noise interpolation factor applied to latents before denoising.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
+                Generator(s) for deterministic generation.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated video latents.
+            audio_latents (`torch.Tensor`, *optional*):
+                Pre-generated audio latents. Ignored when ``generate_audio=False``.
+            prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings.
+            prompt_attention_mask (`torch.Tensor`, *optional*):
+                Pre-generated attention mask for text embeddings.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings.
+            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
+                Pre-generated attention mask for negative text embeddings. 
+ decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + Noise scale at the decode timestep. + output_type (`str`, *optional*, defaults to `"pil"`): + Output format for the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a structured output. + attention_kwargs (`dict`, *optional*): + Extra kwargs passed to the attention processor. + callback_on_step_end (`Callable`, *optional*): + Callback called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`): + Tensor inputs for the step-end callback. + max_sequence_length (`int`, *optional*, defaults to `1024`): + Maximum sequence length for the prompt. + generate_audio (`bool`, *optional*, defaults to `True`): + Whether to generate audio alongside the video. Set to ``False`` to skip all audio + operations (latent preparation, denoising, decoding). When ``False``, the returned + ``audio`` value will be ``None``. Automatically forced to ``False`` when the + transformer was built with ``audio_enabled=False``. + + Examples: + + Returns: + [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is the generated video frames + and the second element is the generated audio (or ``None``). + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # Honour the transformer's construction-time gate — if it has no audio modules, + # we cannot generate audio regardless of what the caller requests. + run_audio = generate_audio and getattr(self.transformer.config, "audio_enabled", True) + + # 1. 
Check inputs + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._attention_kwargs = attention_kwargs + self._interrupt = False + self._current_timestep = None + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0 + connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors( + prompt_embeds, additive_attention_mask, additive_mask=True + ) + + # 4. 
Prepare latent variables + + # --- Video latents (always prepared) --- + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + if latents is not None: + if latents.ndim == 5: + logger.info( + "Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], " + "`latent_num_frames`, `latent_height`, `latent_width` will be inferred." + ) + _, _, latent_num_frames, latent_height, latent_width = latents.shape + elif latents.ndim == 3: + logger.warning( + f"You have supplied packed `latents` of shape {latents.shape}, so the latent dims cannot be" + f" inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct." + ) + else: + raise ValueError( + f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either " + "[batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]." 
+ ) + video_sequence_length = latent_num_frames * latent_height * latent_width + + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + noise_scale, + torch.float32, + device, + generator, + latents, + ) + + # --- Audio latents (only prepared when run_audio) --- + duration_s = num_frames / frame_rate + audio_latents_per_second = ( + self.audio_sampling_rate / self.audio_hop_length / float(self.audio_vae_temporal_compression_ratio) + ) + audio_num_frames = round(duration_s * audio_latents_per_second) + + num_mel_bins = self.audio_vae.config.mel_bins if getattr(self, "audio_vae", None) is not None else 64 + latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio + num_channels_latents_audio = ( + self.audio_vae.config.latent_channels if getattr(self, "audio_vae", None) is not None else 8 + ) + + if run_audio: + if audio_latents is not None: + if audio_latents.ndim == 4: + logger.info( + "Got audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], " + "`audio_num_frames` will be inferred." + ) + _, _, audio_num_frames, _ = audio_latents.shape + elif audio_latents.ndim == 3: + logger.warning( + f"You have supplied packed `audio_latents` of shape {audio_latents.shape}, so the latent " + f"dims cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct." + ) + else: + raise ValueError( + f"Provided `audio_latents` tensor has shape {audio_latents.shape}, but the expected shape " + "is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins]." 
+ ) + + audio_latents = self.prepare_audio_latents( + batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents_audio, + audio_latent_length=audio_num_frames, + num_mel_bins=num_mel_bins, + noise_scale=noise_scale, + dtype=torch.float32, + device=device, + generator=generator, + latents=audio_latents, + ) + else: + audio_latents = None + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.get("base_image_seq_len", 1024), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.95), + self.scheduler.config.get("max_shift", 2.05), + ) + + if run_audio: + audio_scheduler = copy.deepcopy(self.scheduler) + _, _ = retrieve_timesteps(audio_scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu) + else: + audio_scheduler = None + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas, mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions (RoPE coords) + # Video coords — always computed + video_coords = self.transformer.rope.prepare_video_coords( + latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate + ) + if self.do_classifier_free_guidance: + video_coords = video_coords.repeat((2,) + (1,) * (video_coords.ndim - 1)) + + # Audio coords — only computed when run_audio + if run_audio: + audio_coords = self.transformer.audio_rope.prepare_audio_coords( + audio_latents.shape[0], audio_num_frames, audio_latents.device + ) + if self.do_classifier_free_guidance: + audio_coords = audio_coords.repeat((2,) + (1,) * (audio_coords.ndim - 1)) + else: + audio_coords = None + + # 7. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + # Video input (always) + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # Audio input (only when run_audio) + if run_audio: + audio_latent_model_input = ( + torch.cat([audio_latents] * 2) if self.do_classifier_free_guidance else audio_latents + ) + audio_latent_model_input = audio_latent_model_input.to(prompt_embeds.dtype) + else: + audio_latent_model_input = None + + timestep = t.expand(latent_model_input.shape[0]) + + with self.transformer.cache_context("cond_uncond"): + noise_pred_video, noise_pred_audio = self.transformer( + hidden_states=latent_model_input, + audio_hidden_states=audio_latent_model_input, + encoder_hidden_states=connector_prompt_embeds, + audio_encoder_hidden_states=connector_audio_prompt_embeds if run_audio else None, + timestep=timestep, + encoder_attention_mask=connector_attention_mask, + audio_encoder_attention_mask=connector_attention_mask if run_audio else None, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + fps=frame_rate, + audio_num_frames=audio_num_frames if run_audio else None, + video_coords=video_coords, + audio_coords=audio_coords, + attention_kwargs=attention_kwargs, + return_dict=False, + ) + + noise_pred_video = noise_pred_video.float() + + if self.do_classifier_free_guidance: + noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2) + noise_pred_video = noise_pred_video_uncond + self.guidance_scale * ( + noise_pred_video_text - noise_pred_video_uncond + ) + if self.guidance_rescale > 0: + noise_pred_video = rescale_noise_cfg( + noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale + ) + + latents = self.scheduler.step(noise_pred_video, t, latents, 
return_dict=False)[0] + + # Audio denoising step — only when run_audio + if run_audio: + noise_pred_audio = noise_pred_audio.float() + if self.do_classifier_free_guidance: + noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2) + noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * ( + noise_pred_audio_text - noise_pred_audio_uncond + ) + if self.guidance_rescale > 0: + noise_pred_audio = rescale_noise_cfg( + noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale + ) + audio_latents = audio_scheduler.step(noise_pred_audio, t, audio_latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Decode video (always) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + + # 9. 
Decode audio (only when run_audio) + if run_audio: + audio_latents = self._denormalize_audio_latents( + audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std + ) + audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins) + + if output_type == "latent": + video = latents + audio = audio_latents if run_audio else None + else: + latents = latents.to(prompt_embeds.dtype) + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(self.vae.dtype) + video = self.vae.decode(latents, timestep, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + if run_audio: + audio_latents = audio_latents.to(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + audio = self.vocoder(generated_mel_spectrograms) + else: + audio = None + + self.maybe_free_model_hooks() + + if not return_dict: + return (video, audio) + + return LTX2PipelineOutput(frames=video, audio=audio) diff --git a/fastgen/networks/LTX2/test_ltx2_pipeline.py b/fastgen/networks/LTX2/test_ltx2_pipeline.py new file mode 100644 index 0000000..7463eb3 --- /dev/null +++ b/fastgen/networks/LTX2/test_ltx2_pipeline.py @@ -0,0 +1,60 @@ +import torch +import numpy as np +import 
imageio + +from pipeline_ltx2 import LTX2Pipeline +from transformer_ltx2 import LTX2VideoTransformer3DModel + +device = "cuda:0" +width = 768 +height = 512 + +# 1. Load the full pipeline (vae, scheduler, text_encoder, etc.) +pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16) + +# 2. Build a video-only transformer — no audio weights allocated +config = pipe.transformer.config +transformer = LTX2VideoTransformer3DModel.from_config(config, audio_enabled=False) +transformer = transformer.to(torch.bfloat16) + +# 3. Copy video weights from the pretrained AV transformer, skip audio keys +state_dict = pipe.transformer.state_dict() +missing, unexpected = transformer.load_state_dict(state_dict, strict=False) +assert len(missing) == 0, f"Missing video keys: {missing}" +print(f"Skipped {len(unexpected)} audio keys from checkpoint") + +# 4. Swap in the video-only transformer and free the original AV one +del pipe.transformer +pipe.transformer = transformer + +pipe.to(device) # fully on GPU — fastest, needs ~24GB+ VRAM +# pipe.enable_model_cpu_offload(device=device) # offloads whole models between steps, needs ~12GB VRAM + +prompt = "A beautiful sunset over the ocean" +negative_prompt = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static." 
+ +frame_rate = 24.0 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=width, + height=height, + num_frames=121, + frame_rate=frame_rate, + num_inference_steps=40, + guidance_scale=4.0, + output_type="np", + return_dict=False, + generate_audio=False, # audio is None, no audio ops run +) +print(audio) + +# video[0] is (num_frames, height, width, 3) float in [0, 1] +frames = (video[0] * 255).clip(0, 255).astype(np.uint8) + +output_path = "ltx2_video_only.mp4" +with imageio.get_writer(output_path, fps=frame_rate, codec="libx264", quality=8) as writer: + for frame in frames: + writer.append_data(frame) + +print(f"Saved to {output_path}") \ No newline at end of file diff --git a/fastgen/networks/LTX2/test_ltx_network.py b/fastgen/networks/LTX2/test_ltx_network.py new file mode 100644 index 0000000..5528b26 --- /dev/null +++ b/fastgen/networks/LTX2/test_ltx_network.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from diffusers.pipelines.ltx2.export_utils import encode_video +from fastgen.networks.LTX2.network import LTX2 + + +def test_ltx2_generation(): + device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.bfloat16 # LTX-2 is optimized for bfloat16 + + # 1. Initialize the LTX2 model + print("Initializing LTX-2 model...") + model = LTX2(model_id="Lightricks/LTX-2", load_pretrained=True).to(device, dtype=dtype) + model.init_preprocessors() + model.eval() + + # 2. Prepare Prompts + # prompt = "A high-performance sports car racing through a city street at night, neon lights reflecting off wet asphalt, motion blur streaking past the camera. The camera tracks low and close to the car as it accelerates aggressively, tires gripping the road, exhaust heat shimmering. 
Realistic lighting, cinematic depth of field, ultra-detailed textures, dynamic reflections, dramatic shadows, 4K realism, film-grade color grading." + prompt = "A tight close-up shot of a musician's hands playing a grand piano. The audio is a fast-paced, uplifting classical piano sonata. The pressing of the black and white keys visually syncs with the rapid flurry of high-pitched musical notes. There is a slight echo, suggesting the piano is in a large, empty concert hall." + prompt = "A street performer sitting on a brick stoop, strumming an acoustic guitar. The audio features a warm, indie-folk guitar chord progression. Over the guitar, a smooth, soulful human voice sings a slow, bluesy melody without words, just melodic humming and 'oohs'. The rhythmic strumming of the guitar perfectly matches the tempo of the vocal melody. Faint city traffic can be heard quietly in the deep background." + negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + print(f"Encoding prompt: {prompt}") + # FIX 1: encode() returns (embeds, mask) tuple — unpack and move each tensor separately + embeds, mask = model.text_encoder.encode(prompt, precision=dtype) + condition = (embeds.to(device), mask.to(device)) + neg_embeds, neg_mask = model.text_encoder.encode(negative_prompt, precision=dtype) + neg_condition = (neg_embeds.to(device), neg_mask.to(device)) + + # 3. Define Video Parameters + # LTX-2 video dimensions must be divisible by 32 spatially and (8n+1) temporally + height, width = 480, 704 + num_frames = 81 # (8 * 10) + 1 + batch_size = 1 + + # Calculate latent dimensions + latent_f = (num_frames - 1) // model.vae_temporal_compression_ratio + 1 + latent_h = height // model.vae_spatial_compression_ratio + latent_w = width // model.vae_spatial_compression_ratio + + # 4. Generate Initial Noise + # FIX 2: model.vae IS AutoencoderKLLTX2Video directly — no .vae sub-attribute. + # Use transformer.config.in_channels as the pipeline does (line 989). 
+ latent_channels = model.transformer.config.in_channels # 128 + + # FIX 3: noise must be float32 — the pipeline creates latents in torch.float32 + # (prepare_latents line 997) and keeps them float32 through the scheduler. + # Only cast to bfloat16 right before the transformer call. + noise = torch.randn( + batch_size, latent_channels, latent_f, latent_h, latent_w, + device=device, dtype=torch.float32 + ) + + # 5. Run Sampling (Inference) + print("Starting sampling process...") + with torch.no_grad(): + latents, audio_latents = model.sample( + noise=noise, + condition=condition, + neg_condition=neg_condition, + guidance_scale=4.0, + num_steps=40, + fps=24.0, + ) + # latents: [B, C, F, H, W] denormalised video latents (float32) + # audio_latents: [B, C, L, M] denormalised audio latents (float32) + + # 6. Decode Latents to Video + print("Decoding latents to video...") + with torch.no_grad(): + video_tensor = model.vae.decode(latents.to(model.vae.dtype)) + # model.vae.decode(latents.to(model.vae.dtype), return_dict=False)[0] + # video_tensor: [B, C, F, H, W] in ~[-1, 1] + + # 7. Post-process and Save + # Convert [B, C, F, H, W] -> [F, H, W, C] uint8 + video_np = ( + (video_tensor[0][0].cpu().float().permute(1, 2, 3, 0).numpy() + 1.0) * 127.5 + ).clip(0, 255).astype(np.uint8) + + + print("Saving video to ltx2_test.mp4...") + import imageio + with imageio.get_writer("ltx2_test.mp4", fps=24, codec="libx264", quality=8) as writer: + for frame in video_np: + writer.append_data(frame) + print("Done!") + +if __name__ == "__main__": + test_ltx2_generation() \ No newline at end of file diff --git a/fastgen/networks/LTX2/transformer_ltx2.py b/fastgen/networks/LTX2/transformer_ltx2.py new file mode 100644 index 0000000..1d66730 --- /dev/null +++ b/fastgen/networks/LTX2/transformer_ltx2.py @@ -0,0 +1,1203 @@ +# Copyright 2025 The Lightricks team and The HuggingFace Team. +# All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.utils import BaseOutput, apply_lora_scale, is_torch_version, logging +from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput +from diffusers.models.attention import AttentionMixin, AttentionModuleMixin, FeedForward +from diffusers.models.attention_dispatch import dispatch_attention_fn +from diffusers.models.cache_utils import CacheMixin +from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings, PixArtAlphaTextProjection +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out + + +def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = freqs + + x_dtype = x.dtype + 
needs_reshape = False + if x.ndim != 4 and cos.ndim == 4: + b, h, t, _ = cos.shape + x = x.reshape(b, t, h, -1).swapaxes(1, 2) + needs_reshape = True + + last = x.shape[-1] + if last % 2 != 0: + raise ValueError(f"Expected x.shape[-1] to be even for split rotary, got {last}.") + r = last // 2 + + split_x = x.reshape(*x.shape[:-1], 2, r).float() + first_x = split_x[..., :1, :] + second_x = split_x[..., 1:, :] + + cos_u = cos.unsqueeze(-2) + sin_u = sin.unsqueeze(-2) + + out = split_x * cos_u + first_out = out[..., :1, :] + second_out = out[..., 1:, :] + + first_out.addcmul_(-sin_u, second_x) + second_out.addcmul_(sin_u, first_x) + + out = out.reshape(*out.shape[:-2], last) + + if needs_reshape: + out = out.swapaxes(1, 2).reshape(b, t, -1) + + out = out.to(dtype=x_dtype) + return out + + +@dataclass +class AudioVisualModelOutput(BaseOutput): + r""" + Holds the output of an audiovisual model which produces both visual (e.g. video) and audio outputs. + + Args: + sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on the `encoder_hidden_states` input, representing the visual output + of the model. + audio_sample (`torch.Tensor` or `None`): + The audio output of the audiovisual model. None when audio is disabled. + """ + + sample: "torch.Tensor" # noqa: F821 + audio_sample: "torch.Tensor | None" # noqa: F821 + + +class LTX2AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://huggingface.co/papers/2310.00426; Section 2.3) and adapted by the LTX-2.0 + model. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_mod_params (`int`, *optional*, defaults to `6`): + The number of modulation parameters. + use_additional_conditions (`bool`, *optional*, defaults to `False`): + Whether to use additional conditions for normalization or not. 
+ """ + + def __init__(self, embedding_dim: int, num_mod_params: int = 6, use_additional_conditions: bool = False): + super().__init__() + self.num_mod_params = num_mod_params + + self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, self.num_mod_params * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: dict[str, torch.Tensor] | None = None, + batch_size: int | None = None, + hidden_dtype: torch.dtype | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +class LTX2AudioVideoAttnProcessor: + r""" + Processor for implementing attention (SDPA) for the LTX-2.0 model. + Supports separate RoPE embeddings for queries and keys (a2v / v2a cross attention). + """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if is_torch_version("<", "2.0"): + raise ValueError( + "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation." 
+ ) + + def __call__( + self, + attn: "LTX2Attention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if query_rotary_emb is not None: + if attn.rope_type == "interleaved": + query = apply_interleaved_rotary_emb(query, query_rotary_emb) + key = apply_interleaved_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + elif attn.rope_type == "split": + query = apply_split_rotary_emb(query, query_rotary_emb) + key = apply_split_rotary_emb( + key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb + ) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class 
LTX2Attention(torch.nn.Module, AttentionModuleMixin): + r""" + Attention class for all LTX-2.0 attention layers. + Supports separate query and key RoPE embeddings for a2v / v2a cross-attention. + """ + + _default_processor_cls = LTX2AudioVideoAttnProcessor + _available_processors = [LTX2AudioVideoAttnProcessor] + + def __init__( + self, + query_dim: int, + heads: int = 8, + kv_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attention_dim: int | None = None, + out_bias: bool = True, + qk_norm: str = "rms_norm_across_heads", + norm_eps: float = 1e-6, + norm_elementwise_affine: bool = True, + rope_type: str = "interleaved", + processor=None, + ): + super().__init__() + if qk_norm != "rms_norm_across_heads": + raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.") + + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = query_dim + self.heads = heads + self.rope_type = rope_type + + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_out = torch.nn.ModuleList([]) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + 
+ def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + hidden_states = self.processor( + self, hidden_states, encoder_hidden_states, attention_mask, query_rotary_emb, key_rotary_emb, **kwargs + ) + return hidden_states + + +class LTX2VideoTransformerBlock(nn.Module): + r""" + Transformer block used in LTX-2.0. + + Supports two-level audio gating (Option C): + - Construction-time: pass ``audio_dim=None`` to skip all audio module allocation. + - Runtime: pass ``audio_enabled=False`` in ``forward()`` to skip audio ops this step. + + The a2v cross-attention (video attending to audio) is intentionally decoupled from + ``audio_enabled`` — video can still attend to audio as conditioning even when the + audio update branch is disabled, matching the original BasicAVTransformerBlock behaviour. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + # Audio args — all optional. Pass None to build a video-only block. 
+ audio_dim: int | None = None, + audio_num_attention_heads: int | None = None, + audio_attention_head_dim: int | None = None, + audio_cross_attention_dim: int | None = None, + qk_norm: str = "rms_norm_across_heads", + activation_fn: str = "gelu-approximate", + attention_bias: bool = True, + attention_out_bias: bool = True, + eps: float = 1e-6, + elementwise_affine: bool = False, + rope_type: str = "interleaved", + ): + super().__init__() + + # Construction-time gate + self.has_audio = audio_dim is not None + + if self.has_audio: + assert audio_num_attention_heads is not None + assert audio_attention_head_dim is not None + assert audio_cross_attention_dim is not None + + # --- 1. Video Self-Attention (always built) --- + self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn1 = LTX2Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # --- 1b. Audio Self-Attention (conditional) --- + if self.has_audio: + self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn1 = LTX2Attention( + query_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # --- 2. Video Prompt Cross-Attention (always built) --- + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn2 = LTX2Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # --- 2b. 
Audio Prompt Cross-Attention (conditional) --- + if self.has_audio: + self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_attn2 = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=audio_cross_attention_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # --- 3. Audio-Video Cross-Attention (conditional — both modalities) --- + if self.has_audio: + # a2v: Q=Video, K/V=Audio + self.audio_to_video_norm = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_to_video_attn = LTX2Attention( + query_dim=dim, + cross_attention_dim=audio_dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # v2a: Q=Audio, K/V=Video + self.video_to_audio_norm = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.video_to_audio_attn = LTX2Attention( + query_dim=audio_dim, + cross_attention_dim=dim, + heads=audio_num_attention_heads, + kv_heads=audio_num_attention_heads, + dim_head=audio_attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + rope_type=rope_type, + ) + + # Per-layer cross-attention modulation params + self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim)) + self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim)) + + # --- 4. 
Feedforward (video always built, audio conditional) --- + self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.ff = FeedForward(dim, activation_fn=activation_fn) + + if self.has_audio: + self.audio_norm3 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine) + self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn) + + # --- 5. AdaLN modulation params (video always, audio conditional) --- + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + if self.has_audio: + self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + audio_hidden_states: torch.Tensor | None, + encoder_hidden_states: torch.Tensor, + audio_encoder_hidden_states: torch.Tensor | None, + temb: torch.Tensor, + temb_audio: torch.Tensor | None, + temb_ca_scale_shift: torch.Tensor | None, + temb_ca_audio_scale_shift: torch.Tensor | None, + temb_ca_gate: torch.Tensor | None, + temb_ca_audio_gate: torch.Tensor | None, + video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + audio_encoder_attention_mask: torch.Tensor | None = None, + a2v_cross_attention_mask: torch.Tensor | None = None, + v2a_cross_attention_mask: torch.Tensor | None = None, + audio_enabled: bool = True, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + batch_size = hidden_states.size(0) + + # Runtime gates — mirrors BasicAVTransformerBlock exactly + # run_ax: audio updates (self-attn, cross-attn, ffn, v2a) + # run_a2v: video attending to audio — asymmetrically decoupled from audio_enabled + # so video can still use audio as conditioning even when audio_enabled=False + # run_v2a: audio updates from video 
— tied to run_ax + run_ax = self.has_audio and audio_enabled and audio_hidden_states is not None + run_a2v = self.has_audio and audio_hidden_states is not None + run_v2a = run_ax + + # 1. Video Self-Attention (always runs) + norm_hidden_states = self.norm1(hidden_states) + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape( + batch_size, temb.size(1), num_ada_params, -1 + ) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=video_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa + + # 1b. Audio Self-Attention (conditional) + if run_ax: + norm_audio_hidden_states = self.audio_norm1(audio_hidden_states) + num_audio_ada_params = self.audio_scale_shift_table.shape[0] + audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape( + batch_size, temb_audio.size(1), num_audio_ada_params, -1 + ) + audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = ( + audio_ada_values.unbind(dim=2) + ) + norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa + attn_audio_hidden_states = self.audio_attn1( + hidden_states=norm_audio_hidden_states, + encoder_hidden_states=None, + query_rotary_emb=audio_rotary_emb, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa + + # 2. 
Video Cross-Attention with text (always runs) + norm_hidden_states = self.norm2(hidden_states) + attn_hidden_states = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + + # 2b. Audio Cross-Attention with text (conditional) + if run_ax: + norm_audio_hidden_states = self.audio_norm2(audio_hidden_states) + attn_audio_hidden_states = self.audio_attn2( + norm_audio_hidden_states, + encoder_hidden_states=audio_encoder_hidden_states, + query_rotary_emb=None, + attention_mask=audio_encoder_attention_mask, + ) + audio_hidden_states = audio_hidden_states + attn_audio_hidden_states + + # 3. Audio-Video Cross-Attention + if run_a2v or run_v2a: + norm_hidden_states_av = self.audio_to_video_norm(hidden_states) + norm_audio_hidden_states_av = self.video_to_audio_norm(audio_hidden_states) + + # Video modulation params + video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :] + video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :] + video_ca_scale_shift_table = ( + video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype) + + temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + video_ca_gate = ( + video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype) + + temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1) + ).unbind(dim=2) + video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table + a2v_gate = video_ca_gate[0].squeeze(2) + + # Audio modulation params + audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :] + audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :] + audio_ca_scale_shift_table = ( + audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype) + + 
temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1) + ).unbind(dim=2) + audio_ca_gate = ( + audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype) + + temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1) + ).unbind(dim=2) + audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table + v2a_gate = audio_ca_gate[0].squeeze(2) + + # 3a. a2v: Q=Video, K/V=Audio (runs even when audio_enabled=False) + if run_a2v: + mod_norm_hidden_states = ( + norm_hidden_states_av * (1 + video_a2v_ca_scale.squeeze(2)) + + video_a2v_ca_shift.squeeze(2) + ) + mod_norm_audio_hidden_states = ( + norm_audio_hidden_states_av * (1 + audio_a2v_ca_scale.squeeze(2)) + + audio_a2v_ca_shift.squeeze(2) + ) + a2v_attn_hidden_states = self.audio_to_video_attn( + mod_norm_hidden_states, + encoder_hidden_states=mod_norm_audio_hidden_states, + query_rotary_emb=ca_video_rotary_emb, + key_rotary_emb=ca_audio_rotary_emb, + attention_mask=a2v_cross_attention_mask, + ) + hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states + + # 3b. v2a: Q=Audio, K/V=Video (only when audio branch is fully active) + if run_v2a: + mod_norm_hidden_states = ( + norm_hidden_states_av * (1 + video_v2a_ca_scale.squeeze(2)) + + video_v2a_ca_shift.squeeze(2) + ) + mod_norm_audio_hidden_states = ( + norm_audio_hidden_states_av * (1 + audio_v2a_ca_scale.squeeze(2)) + + audio_v2a_ca_shift.squeeze(2) + ) + v2a_attn_hidden_states = self.video_to_audio_attn( + mod_norm_audio_hidden_states, + encoder_hidden_states=mod_norm_hidden_states, + query_rotary_emb=ca_audio_rotary_emb, + key_rotary_emb=ca_video_rotary_emb, + attention_mask=v2a_cross_attention_mask, + ) + audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states + + # 4. 
Video Feedforward (always runs) + norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp + hidden_states = hidden_states + self.ff(norm_hidden_states) * gate_mlp + + # 4b. Audio Feedforward (conditional) + if run_ax: + norm_audio_hidden_states = ( + self.audio_norm3(audio_hidden_states) * (1 + audio_scale_mlp) + audio_shift_mlp + ) + audio_hidden_states = audio_hidden_states + self.audio_ff(norm_audio_hidden_states) * audio_gate_mlp + + # Return None for audio when it didn't run — prevents stale tensor propagation + return hidden_states, audio_hidden_states if run_ax else None + + +class LTX2AudioVideoRotaryPosEmbed(nn.Module): + """ + Video and audio rotary positional embeddings (RoPE) for the LTX-2.0 model. + """ + + def __init__( + self, + dim: int, + patch_size: int = 1, + patch_size_t: int = 1, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + sampling_rate: int = 16000, + hop_length: int = 160, + scale_factors: tuple[int, ...] = (8, 32, 32), + theta: float = 10000.0, + causal_offset: int = 1, + modality: str = "video", + double_precision: bool = True, + rope_type: str = "interleaved", + num_attention_heads: int = 32, + ) -> None: + super().__init__() + + self.dim = dim + self.patch_size = patch_size + self.patch_size_t = patch_size_t + + if rope_type not in ["interleaved", "split"]: + raise ValueError(f"{rope_type=} not supported. 
Choose between 'interleaved' and 'split'.") + self.rope_type = rope_type + + self.base_num_frames = base_num_frames + self.num_attention_heads = num_attention_heads + + self.base_height = base_height + self.base_width = base_width + + self.sampling_rate = sampling_rate + self.hop_length = hop_length + self.audio_latents_per_second = float(sampling_rate) / float(hop_length) / float(scale_factors[0]) + + self.scale_factors = scale_factors + self.theta = theta + self.causal_offset = causal_offset + + self.modality = modality + if self.modality not in ["video", "audio"]: + raise ValueError(f"Modality {modality} is not supported. Supported modalities are `video` and `audio`.") + self.double_precision = double_precision + + def prepare_video_coords( + self, + batch_size: int, + num_frames: int, + height: int, + width: int, + device: torch.device, + fps: float = 24.0, + ) -> torch.Tensor: + grid_f = torch.arange(start=0, end=num_frames, step=self.patch_size_t, dtype=torch.float32, device=device) + grid_h = torch.arange(start=0, end=height, step=self.patch_size, dtype=torch.float32, device=device) + grid_w = torch.arange(start=0, end=width, step=self.patch_size, dtype=torch.float32, device=device) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) + + patch_size = (self.patch_size_t, self.patch_size, self.patch_size) + patch_size_delta = torch.tensor(patch_size, dtype=grid.dtype, device=grid.device) + patch_ends = grid + patch_size_delta.view(3, 1, 1, 1) + + latent_coords = torch.stack([grid, patch_ends], dim=-1) + latent_coords = latent_coords.flatten(1, 3) + latent_coords = latent_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1) + + scale_tensor = torch.tensor(self.scale_factors, device=latent_coords.device) + broadcast_shape = [1] * latent_coords.ndim + broadcast_shape[1] = -1 + pixel_coords = latent_coords * scale_tensor.view(*broadcast_shape) + + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] 
+ self.causal_offset - self.scale_factors[0]).clamp(min=0) + pixel_coords[:, 0, ...] = pixel_coords[:, 0, ...] / fps + + return pixel_coords + + def prepare_audio_coords( + self, + batch_size: int, + num_frames: int, + device: torch.device, + shift: int = 0, + ) -> torch.Tensor: + grid_f = torch.arange( + start=shift, end=num_frames + shift, step=self.patch_size_t, dtype=torch.float32, device=device + ) + + audio_scale_factor = self.scale_factors[0] + grid_start_mel = grid_f * audio_scale_factor + grid_start_mel = (grid_start_mel + self.causal_offset - audio_scale_factor).clip(min=0) + grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate + + grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor + grid_end_mel = (grid_end_mel + self.causal_offset - audio_scale_factor).clip(min=0) + grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate + + audio_coords = torch.stack([grid_start_s, grid_end_s], dim=-1) + audio_coords = audio_coords.unsqueeze(0).expand(batch_size, -1, -1) + audio_coords = audio_coords.unsqueeze(1) + return audio_coords + + def prepare_coords(self, *args, **kwargs): + if self.modality == "video": + return self.prepare_video_coords(*args, **kwargs) + elif self.modality == "audio": + return self.prepare_audio_coords(*args, **kwargs) + + def forward( + self, coords: torch.Tensor, device: str | torch.device | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + device = device or coords.device + + num_pos_dims = coords.shape[1] + + if coords.ndim == 4: + coords_start, coords_end = coords.chunk(2, dim=-1) + coords = (coords_start + coords_end) / 2.0 + coords = coords.squeeze(-1) + + if self.modality == "video": + max_positions = (self.base_num_frames, self.base_height, self.base_width) + elif self.modality == "audio": + max_positions = (self.base_num_frames,) + grid = torch.stack([coords[:, i] / max_positions[i] for i in range(num_pos_dims)], dim=-1).to(device) + num_rope_elems = num_pos_dims * 2 + + freqs_dtype = 
torch.float64 if self.double_precision else torch.float32 + pow_indices = torch.pow( + self.theta, + torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device), + ) + freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32) + + freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs + freqs = freqs.transpose(-1, -2).flatten(2) + + if self.rope_type == "interleaved": + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % num_rope_elems != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % num_rope_elems]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + elif self.rope_type == "split": + expected_freqs = self.dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq = freqs.cos() + sin_freq = freqs.sin() + + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + + b = cos_freq.shape[0] + t = cos_freq.shape[1] + + cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1) + sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1) + + cos_freqs = torch.swapaxes(cos_freq, 1, 2) + sin_freqs = torch.swapaxes(sin_freq, 1, 2) + + return cos_freqs, sin_freqs + + +class LTX2VideoTransformer3DModel( + ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin +): + r""" + A Transformer model for video-like data used in LTX-2.0. + + Supports two-level audio gating (Option C): + - Construction-time: set ``audio_enabled=False`` to build a video-only model with + no audio weights allocated at all. 
Ideal for finetuning video-only. + - Runtime: pass ``audio_hidden_states=None`` to skip all audio ops for a given + forward pass on a full AV checkpoint. + + When loading a pretrained AV checkpoint into a video-only model, use ``strict=False``: + missing, unexpected = transformer.load_state_dict(state_dict, strict=False) + # unexpected = all audio.* keys — expected, safe to ignore + # missing = should be [] — all video keys present + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["norm"] + _repeated_blocks = ["LTX2VideoTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "rope": { + 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } + + @register_to_config + def __init__( + self, + in_channels: int = 128, + out_channels: int | None = 128, + patch_size: int = 1, + patch_size_t: int = 1, + num_attention_heads: int = 32, + attention_head_dim: int = 128, + cross_attention_dim: int = 4096, + vae_scale_factors: tuple[int, int, int] = (8, 32, 32), + pos_embed_max_pos: int = 20, + base_height: int = 2048, + base_width: int = 2048, + audio_in_channels: int = 128, + audio_out_channels: int | None = 128, + audio_patch_size: int = 1, + audio_patch_size_t: int = 1, + audio_num_attention_heads: int = 32, + audio_attention_head_dim: int = 64, + audio_cross_attention_dim: int = 2048, + audio_scale_factor: int = 4, + audio_pos_embed_max_pos: int = 20, + audio_sampling_rate: int = 16000, + audio_hop_length: int = 160, + num_layers: int = 48, + activation_fn: str = "gelu-approximate", + 
qk_norm: str = "rms_norm_across_heads", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 3840, + attention_bias: bool = True, + attention_out_bias: bool = True, + rope_theta: float = 10000.0, + rope_double_precision: bool = True, + causal_offset: int = 1, + timestep_scale_multiplier: int = 1000, + cross_attn_timestep_scale_multiplier: int = 1000, + rope_type: str = "interleaved", + audio_enabled: bool = True, # <-- NEW: construction-time gate + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + audio_out_channels = audio_out_channels or audio_in_channels + inner_dim = num_attention_heads * attention_head_dim + audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + + # 1. Patchification input projections + self.proj_in = nn.Linear(in_channels, inner_dim) + if audio_enabled: + self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim) + + # 2. Prompt embeddings + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + if audio_enabled: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=audio_inner_dim + ) + + # 3. 
Timestep modulation + self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False) + if audio_enabled: + self.audio_time_embed = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=6, use_additional_conditions=False + ) + self.av_cross_attn_video_scale_shift = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=4, use_additional_conditions=False + ) + self.av_cross_attn_audio_scale_shift = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=4, use_additional_conditions=False + ) + self.av_cross_attn_video_a2v_gate = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=1, use_additional_conditions=False + ) + self.av_cross_attn_audio_v2a_gate = LTX2AdaLayerNormSingle( + audio_inner_dim, num_mod_params=1, use_additional_conditions=False + ) + + # 3.3. Output layer modulation params + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + if audio_enabled: + self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + + # 4. 
RoPE — video always built, audio ropes only when audio_enabled + self.rope = LTX2AudioVideoRotaryPosEmbed( + dim=inner_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + scale_factors=vae_scale_factors, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + if audio_enabled: + self.audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_inner_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=audio_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + scale_factors=[audio_scale_factor], + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + cross_attn_pos_embed_max_pos = max(pos_embed_max_pos, audio_pos_embed_max_pos) + self.cross_attn_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=patch_size, + patch_size_t=patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + base_height=base_height, + base_width=base_width, + theta=rope_theta, + causal_offset=causal_offset, + modality="video", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=num_attention_heads, + ) + self.cross_attn_audio_rope = LTX2AudioVideoRotaryPosEmbed( + dim=audio_cross_attention_dim, + patch_size=audio_patch_size, + patch_size_t=audio_patch_size_t, + base_num_frames=cross_attn_pos_embed_max_pos, + sampling_rate=audio_sampling_rate, + hop_length=audio_hop_length, + theta=rope_theta, + causal_offset=causal_offset, + modality="audio", + double_precision=rope_double_precision, + rope_type=rope_type, + num_attention_heads=audio_num_attention_heads, + ) + + # 5. 
Transformer blocks — pass None dims when audio disabled
+        self.transformer_blocks = nn.ModuleList(
+            [
+                LTX2VideoTransformerBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    audio_dim=audio_inner_dim if audio_enabled else None,
+                    audio_num_attention_heads=audio_num_attention_heads if audio_enabled else None,
+                    audio_attention_head_dim=audio_attention_head_dim if audio_enabled else None,
+                    audio_cross_attention_dim=audio_cross_attention_dim if audio_enabled else None,
+                    qk_norm=qk_norm,
+                    activation_fn=activation_fn,
+                    attention_bias=attention_bias,
+                    attention_out_bias=attention_out_bias,
+                    eps=norm_eps,
+                    elementwise_affine=norm_elementwise_affine,
+                    rope_type=rope_type,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 6. Output layers
+        self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
+        self.proj_out = nn.Linear(inner_dim, out_channels)
+        if audio_enabled:
+            self.audio_norm_out = nn.LayerNorm(audio_inner_dim, eps=1e-6, elementwise_affine=False)
+            self.audio_proj_out = nn.Linear(audio_inner_dim, audio_out_channels)
+
+        self.gradient_checkpointing = False
+
+    @apply_lora_scale("attention_kwargs")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        audio_hidden_states: torch.Tensor | None = None,
+        encoder_hidden_states: torch.Tensor = None,
+        audio_encoder_hidden_states: torch.Tensor | None = None,
+        timestep: torch.LongTensor = None,
+        audio_timestep: torch.LongTensor | None = None,
+        encoder_attention_mask: torch.Tensor | None = None,
+        audio_encoder_attention_mask: torch.Tensor | None = None,
+        num_frames: int | None = None,
+        height: int | None = None,
+        width: int | None = None,
+        fps: float = 24.0,
+        audio_num_frames: int | None = None,
+        video_coords: torch.Tensor | None = None,
+        audio_coords: torch.Tensor | None = None,
+        attention_kwargs: dict[str, Any] | None = None,
+        return_dict: bool = True,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        # Single runtime gate — all audio ops key off this.
+        # Requires: modules exist (construction-time) AND caller supplied a tensor.
+        run_audio = self.config.audio_enabled and audio_hidden_states is not None
+
+        audio_timestep = audio_timestep if audio_timestep is not None else timestep
+
+        # Attention mask conversion
+        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        if run_audio and audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2:
+            audio_encoder_attention_mask = (
+                1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)
+            ) * -10000.0
+            audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)
+
+        batch_size = hidden_states.size(0)
+
+        # 1. RoPE — video always, audio only when run_audio
+        if video_coords is None:
+            video_coords = self.rope.prepare_video_coords(
+                batch_size, num_frames, height, width, hidden_states.device, fps=fps
+            )
+        video_rotary_emb = self.rope(video_coords, device=hidden_states.device)
+
+        if run_audio:
+            if audio_coords is None:
+                audio_coords = self.audio_rope.prepare_audio_coords(
+                    batch_size, audio_num_frames, audio_hidden_states.device
+                )
+            audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)
+            video_cross_attn_rotary_emb = self.cross_attn_rope(
+                video_coords[:, 0:1, :], device=hidden_states.device
+            )
+            audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(
+                audio_coords[:, 0:1, :], device=audio_hidden_states.device
+            )
+        else:
+            audio_rotary_emb = None
+            video_cross_attn_rotary_emb = None
+            audio_cross_attn_rotary_emb = None
+
+        # 2. Patchify
+        hidden_states = self.proj_in(hidden_states)
+        if run_audio:
+            audio_hidden_states = self.audio_proj_in(audio_hidden_states)
+
+        # 3. Timestep embeddings and modulation params
+        timestep_cross_attn_gate_scale_factor = (
+            self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
+        )
+
+        temb, embedded_timestep = self.time_embed(
+            timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype,
+        )
+        temb = temb.view(batch_size, -1, temb.size(-1))
+        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
+
+        if run_audio:
+            temb_audio, audio_embedded_timestep = self.audio_time_embed(
+                audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype,
+            )
+            temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
+            audio_embedded_timestep = audio_embedded_timestep.view(
+                batch_size, -1, audio_embedded_timestep.size(-1)
+            )
+
+            video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
+                timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype,
+            )
+            video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
+                timestep.flatten() * timestep_cross_attn_gate_scale_factor,
+                batch_size=batch_size, hidden_dtype=hidden_states.dtype,
+            )
+            video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(
+                batch_size, -1, video_cross_attn_scale_shift.shape[-1]
+            )
+            video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(
+                batch_size, -1, video_cross_attn_a2v_gate.shape[-1]
+            )
+
+            audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
+                audio_timestep.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype,
+            )
+            audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
+                audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
+                batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype,
+            )
+            audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(
+                batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
+            )
+            audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(
+                batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]
+            )
+        else:
+            temb_audio = None
+            audio_embedded_timestep = None
+            video_cross_attn_scale_shift = None
+            video_cross_attn_a2v_gate = None
+            audio_cross_attn_scale_shift = None
+            audio_cross_attn_v2a_gate = None
+
+        # 4. Prompt embeddings
+        encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+        encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
+        if run_audio:
+            audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
+            audio_encoder_hidden_states = audio_encoder_hidden_states.view(
+                batch_size, -1, audio_hidden_states.size(-1)
+            )
+
+        # 5. Transformer blocks
+        for block in self.transformer_blocks:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                # Note: _gradient_checkpointing_func only accepts positional tensor args.
+                # audio_enabled is not passed explicitly here — the block derives
+                # run_ax from audio_hidden_states being None when run_audio=False.
+                hidden_states, audio_hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    audio_hidden_states,
+                    encoder_hidden_states,
+                    audio_encoder_hidden_states,
+                    temb,
+                    temb_audio,
+                    video_cross_attn_scale_shift,
+                    audio_cross_attn_scale_shift,
+                    video_cross_attn_a2v_gate,
+                    audio_cross_attn_v2a_gate,
+                    video_rotary_emb,
+                    audio_rotary_emb,
+                    video_cross_attn_rotary_emb,
+                    audio_cross_attn_rotary_emb,
+                    encoder_attention_mask,
+                    audio_encoder_attention_mask,
+                )
+            else:
+                hidden_states, audio_hidden_states = block(
+                    hidden_states=hidden_states,
+                    audio_hidden_states=audio_hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    audio_encoder_hidden_states=audio_encoder_hidden_states,
+                    temb=temb,
+                    temb_audio=temb_audio,
+                    temb_ca_scale_shift=video_cross_attn_scale_shift,
+                    temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
+                    temb_ca_gate=video_cross_attn_a2v_gate,
+                    temb_ca_audio_gate=audio_cross_attn_v2a_gate,
+                    video_rotary_emb=video_rotary_emb,
+                    audio_rotary_emb=audio_rotary_emb,
+                    ca_video_rotary_emb=video_cross_attn_rotary_emb,
+                    ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
+                    encoder_attention_mask=encoder_attention_mask,
+                    audio_encoder_attention_mask=audio_encoder_attention_mask,
+                    audio_enabled=run_audio,
+                )
+
+        # 6. Output layers
+        scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
+        hidden_states = self.norm_out(hidden_states)
+        hidden_states = hidden_states * (1 + scale) + shift
+        output = self.proj_out(hidden_states)
+
+        if run_audio:
+            audio_scale_shift_values = (
+                self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None]
+            )
+            audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1]
+            audio_hidden_states = self.audio_norm_out(audio_hidden_states)
+            audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
+            audio_output = self.audio_proj_out(audio_hidden_states)
+        else:
+            audio_output = None
+
+        if not return_dict:
+            return (output, audio_output)
+        return AudioVisualModelOutput(sample=output, audio_sample=audio_output)
diff --git a/tests/test_ltx2.py b/tests/test_ltx2.py
new file mode 100644
index 0000000..93b538a
--- /dev/null
+++ b/tests/test_ltx2.py
@@ -0,0 +1,91 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import numpy as np
+from diffusers.pipelines.ltx2.export_utils import encode_video
+from fastgen.networks.LTX2.network import LTX2
+
+
+def test_ltx2_generation():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    dtype = torch.bfloat16  # LTX-2 is optimized for bfloat16
+
+    # 1. Initialize the LTX2 model
+    print("Initializing LTX-2 model...")
+    model = LTX2(model_id="Lightricks/LTX-2", load_pretrained=True).to(device, dtype=dtype)
+    model.init_preprocessors()
+    model.eval()
+
+    # 2. Prepare Prompts
+    # prompt = "A high-performance sports car racing through a city street at night, neon lights reflecting off wet asphalt, motion blur streaking past the camera. The camera tracks low and close to the car as it accelerates aggressively, tires gripping the road, exhaust heat shimmering. Realistic lighting, cinematic depth of field, ultra-detailed textures, dynamic reflections, dramatic shadows, 4K realism, film-grade color grading."
+    prompt = "A tight close-up shot of a musician's hands playing a grand piano. The audio is a fast-paced, uplifting classical piano sonata. The pressing of the black and white keys visually syncs with the rapid flurry of high-pitched musical notes. There is a slight echo, suggesting the piano is in a large, empty concert hall."
+    prompt = "A street performer sitting on a brick stoop, strumming an acoustic guitar. The audio features a warm, indie-folk guitar chord progression. Over the guitar, a smooth, soulful human voice sings a slow, bluesy melody without words, just melodic humming and 'oohs'. The rhythmic strumming of the guitar perfectly matches the tempo of the vocal melody. Faint city traffic can be heard quietly in the deep background."
+    negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
+
+    print(f"Encoding prompt: {prompt}")
+    # FIX 1: encode() returns (embeds, mask) tuple — unpack and move each tensor separately
+    embeds, mask = model.text_encoder.encode(prompt, precision=dtype)
+    condition = (embeds.to(device), mask.to(device))
+    neg_embeds, neg_mask = model.text_encoder.encode(negative_prompt, precision=dtype)
+    neg_condition = (neg_embeds.to(device), neg_mask.to(device))
+
+    # 3. Define Video Parameters
+    # LTX-2 video dimensions must be divisible by 32 spatially and (8n+1) temporally
+    height, width = 480, 704
+    num_frames = 81  # (8 * 10) + 1
+    batch_size = 1
+
+    # Calculate latent dimensions
+    latent_f = (num_frames - 1) // model.vae_temporal_compression_ratio + 1
+    latent_h = height // model.vae_spatial_compression_ratio
+    latent_w = width // model.vae_spatial_compression_ratio
+
+    # 4. Generate Initial Noise
+    # FIX 2: model.vae IS AutoencoderKLLTX2Video directly — no .vae sub-attribute.
+    # Use transformer.config.in_channels as the pipeline does (line 989).
+    latent_channels = model.transformer.config.in_channels  # 128
+
+    # FIX 3: noise must be float32 — the pipeline creates latents in torch.float32
+    # (prepare_latents line 997) and keeps them float32 through the scheduler.
+    # Only cast to bfloat16 right before the transformer call.
+    noise = torch.randn(
+        batch_size, latent_channels, latent_f, latent_h, latent_w,
+        device=device, dtype=torch.float32
+    )
+
+    # 5. Run Sampling (Inference)
+    print("Starting sampling process...")
+    with torch.no_grad():
+        latents, audio_latents = model.sample(
+            noise=noise,
+            condition=condition,
+            neg_condition=neg_condition,
+            guidance_scale=4.0,
+            num_steps=40,
+            fps=24.0,
+        )
+    # latents: [B, C, F, H, W] denormalised video latents (float32)
+    # audio_latents: [B, C, L, M] denormalised audio latents (float32)
+
+    # 6. Decode Latents to Video
+    print("Decoding latents to video...")
+    with torch.no_grad():
+        video_tensor = model.vae.decode(latents.to(model.vae.dtype), return_dict=False)[0]
+    # video_tensor: [B, C, F, H, W] in ~[-1, 1]
+
+    # 7. Post-process and Save
+    # Convert [B, C, F, H, W] -> [F, H, W, C] uint8
+    video_np = (
+        (video_tensor[0].cpu().float().permute(1, 2, 3, 0).numpy() + 1.0) * 127.5
+    ).clip(0, 255).astype(np.uint8)
+
+    print("Saving video to ltx2_test.mp4...")
+    import imageio
+    with imageio.get_writer("ltx2_test.mp4", fps=24, codec="libx264", quality=8) as writer:
+        for frame in video_np:
+            writer.append_data(frame)
+    print("Done!")
+
+if __name__ == "__main__":
+    test_ltx2_generation()
\ No newline at end of file