From dfc15b5c2bdc8824e4eece2036a89418ff98ca5c Mon Sep 17 00:00:00 2001 From: Media Studio Dev Date: Fri, 30 Jan 2026 01:32:02 +0000 Subject: [PATCH 1/2] chore: reapply and keep local modifications after pulling upstream --- setup.py | 29 +- setup.py.bak_20260129_034900 | 62 ++ .../SLA/core.py.bak_20260129_070906 | 258 +++++ ...> wan2.2_i2v_infer.py.bak_20260129_221228} | 0 .../inference/wan2.2_i2v_infer_memopt.py | 309 ++++++ turbodiffusion/rcm/networks/wan2pt2.py | 136 +++ .../networks/wan2pt2.py.bak_20260129_062409 | 773 +++++++++++++++ .../networks/wan2pt2.py.bak_20260129_062642 | 882 ++++++++++++++++++ 8 files changed, 2423 insertions(+), 26 deletions(-) create mode 100644 setup.py.bak_20260129_034900 create mode 100755 turbodiffusion/SLA/core.py.bak_20260129_070906 rename turbodiffusion/inference/{wan2.2_i2v_infer.py => wan2.2_i2v_infer.py.bak_20260129_221228} (100%) create mode 100644 turbodiffusion/inference/wan2.2_i2v_infer_memopt.py create mode 100644 turbodiffusion/rcm/networks/wan2pt2.py.bak_20260129_062409 create mode 100644 turbodiffusion/rcm/networks/wan2pt2.py.bak_20260129_062642 diff --git a/setup.py b/setup.py index 2ca06f9..947347c 100644 --- a/setup.py +++ b/setup.py @@ -20,33 +20,10 @@ ops_dir = Path(__file__).parent / "turbodiffusion" / "ops" cutlass_dir = ops_dir / "cutlass" -nvcc_flags = [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=--verbose,--warn-on-local-memory-usage", - "-lineinfo", - "-DCUTLASS_DEBUG_TRACE_LEVEL=0", - "-DNDEBUG", - "-Xcompiler", - "-fPIC" -] +nvcc_flags = ["-gencode", "arch=compute_86,code=sm_86"] cc_flag = [ - "-gencode", "arch=compute_120a,code=sm_120a", - "-gencode", "arch=compute_100,code=sm_100", - "-gencode", "arch=compute_90,code=sm_90", - "-gencode", "arch=compute_89,code=sm_89", - "-gencode", "arch=compute_80,code=sm_80" -] + ] ext_modules = [ CUDAExtension( @@ -60,7 +37,7 @@ ], extra_compile_args={ "cxx": ["-O3", "-std=c++17"], - "nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag + ["--threads", "4"], + "nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag + ["--threads", "1"], }, include_dirs=[ cutlass_dir / "include", diff --git a/setup.py.bak_20260129_034900 b/setup.py.bak_20260129_034900 new file mode 100644 index 0000000..4e82a1e --- /dev/null +++ b/setup.py.bak_20260129_034900 @@ -0,0 +1,62 @@ +""" +Copyright (c) 2025 by TurboDiffusion team. + +Licensed under the Apache License, Version 2.0 (the "License"); + +Citation (please cite if you use this code): + +@article{zhang2025turbodiffusion, + title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times}, + author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun}, + journal={arXiv preprint arXiv:2512.16093}, + year={2025} +} +""" + +from pathlib import Path +from setuptools import setup, find_packages +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +ops_dir = Path(__file__).parent / "turbodiffusion" / "ops" +cutlass_dir = ops_dir / "cutlass" + +nvcc_flags = ["-gencode", "arch=compute_86,code=sm_86"] + +cc_flag = [ + "-gencode", "arch=compute_120a,code=sm_120a", + "-gencode", "arch=compute_100,code=sm_100", + "-gencode", "arch=compute_90,code=sm_90", + "-gencode", "arch=compute_89,code=sm_89", + "-gencode", "arch=compute_80,code=sm_80" +] + +ext_modules = [ + CUDAExtension( + name="turbo_diffusion_ops", + sources=[ + "turbodiffusion/ops/bindings.cpp", + "turbodiffusion/ops/quant/quant.cu", + "turbodiffusion/ops/norm/rmsnorm.cu", + "turbodiffusion/ops/norm/layernorm.cu", + "turbodiffusion/ops/gemm/gemm.cu" + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"], + "nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag + ["--threads", "1"], + }, + include_dirs=[ + cutlass_dir / "include", + cutlass_dir / "tools" / "util" / "include", + ops_dir + ], + libraries=["cuda"], + ) +] + +setup( + packages=find_packages( + exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks") + ), + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, +) diff --git a/turbodiffusion/SLA/core.py.bak_20260129_070906 b/turbodiffusion/SLA/core.py.bak_20260129_070906 new file mode 100755 index 0000000..430bfe0 --- /dev/null +++ b/turbodiffusion/SLA/core.py.bak_20260129_070906 @@ -0,0 +1,258 @@ +""" +Copyright (c) 2025 by SLA team. + +Licensed under the Apache License, Version 2.0 (the "License"); + +Citation (please cite if you use this code): + +@article{zhang2025sla, + title={SLA: Beyond Sparsity in Diffusion Transformers via Fine-Tunable Sparse-Linear Attention}, + author={Jintao Zhang and Haoxu Wang and Kai Jiang and Shuo Yang and Kaiwen Zheng and Haocheng Xi and Ziteng Wang and Hongzhou Zhu and Min Zhao and Ion Stoica and Joseph E. Gonzalez and Jun Zhu and Jianfei Chen}, + journal={arXiv preprint arXiv:2509.24006}, + year={2025} +} +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +SAGESLA_ENABLED = True +try: + import spas_sage_attn._qattn as qattn + import spas_sage_attn._fused as fused + from spas_sage_attn.utils import get_vanilla_qk_quant, block_map_lut_triton +except ImportError: + SAGESLA_ENABLED = False + +SAGE2PP_ENABLED = True +try: + from spas_sage_attn._qattn import qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold +except ImportError: + SAGE2PP_ENABLED = False + +from .kernel import _attention +from .utils import get_block_map, get_cuda_arch + + +class SparseLinearAttention(nn.Module): + def __init__(self, head_dim, topk, feature_map='softmax', BLKQ=64, BLKK=64, use_bf16=True, tie_feature_map_qk=True): + R''' + Args: + head_dim: dimension of each head. + topk: ratio of keys selected for sparse attention, shared across all queries. + feature_map: feature map for linear attention, one of ['hedgehog', 'elu', 'relu', 'softmax']. + BLKQ: block size for query. + BLKK: block size for key. + use_bf16: whether to use bfloat16 (default) or float16 for computation. The conversion to bf16/fp16 is done inside the module. + tie_feature_map_qk: whether to use the same feature map for query and key. + ''' + super().__init__() + self.dtype = torch.bfloat16 if use_bf16 else torch.float16 + self.topk = topk + self.BLKQ = BLKQ + self.BLKK = BLKK + self.proj_l = nn.Linear(head_dim, head_dim, dtype=torch.float32) + + if feature_map == 'elu': + def elu_feature_map(x): + return F.elu(x) + 1 + self.feature_map_q = elu_feature_map + self.feature_map_k = elu_feature_map + elif feature_map == 'relu': + self.feature_map_q = nn.ReLU() + self.feature_map_k = nn.ReLU() + elif feature_map == 'softmax': + def softmax_feature_map(x): + return F.softmax(x, dim=-1) + self.feature_map_q = softmax_feature_map + self.feature_map_k = softmax_feature_map + else: + raise NotImplementedError(f'Not supported feature map {feature_map}.') + + if tie_feature_map_qk: + self.feature_map_k = self.feature_map_q + + self.init_weights_() + + def init_weights_(self): + with torch.no_grad(): + nn.init.zeros_(self.proj_l.weight) + nn.init.zeros_(self.proj_l.bias) + + def forward(self, q, k, v, return_sparsity=False): + R''' + Args: + q: queries of shape (B, H, L, D). + k: keys of shape (B, H, L, D). + v: values of shape (B, H, L, D). + return_sparsity: whether to return the actual sparsity. + ''' + dtype = q.dtype + + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + + sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=self.BLKQ, BLKK=self.BLKK) + + q = q.to(self.dtype) + k = k.to(self.dtype) + v = v.to(self.dtype) + o_s = _attention.apply(q, k, v, sparse_map, lut, real_topk, self.BLKQ, self.BLKK) + + q = self.feature_map_q(q).contiguous().to(self.dtype) # c_q + k = self.feature_map_k(k).contiguous().to(self.dtype) # c_k + def calc_linear(q, k, v): + kvsum = k.transpose(-1, -2) @ v + ksum = torch.sum(k, dim=-2, keepdim=True) + return (q @ kvsum) / (1e-5 + (q * ksum).sum(dim=-1, keepdim=True)) + o_l = calc_linear(q, k, v) + + with torch.amp.autocast('cuda', dtype=self.dtype): + o_l = self.proj_l(o_l) + o = (o_s + o_l).to(dtype).transpose(1, 2) + + if return_sparsity: + return o, real_topk / sparse_map.shape[-1] + else: + return o + + +class SageSparseLinearAttention(nn.Module): + def __init__(self, head_dim, topk, feature_map='softmax', use_bf16=True, tie_feature_map_qk=True): + R''' + Args: + head_dim: dimension of each head. + topk: ratio of keys selected for sparse attention, shared across all queries. + feature_map: feature map for linear attention, one of ['hedgehog', 'elu', 'relu', 'softmax']. + BLKQ: block size for query. + BLKK: block size for key. + use_bf16: whether to use bfloat16 (default) or float16 for computation. The conversion to bf16/fp16 is done inside the module. + tie_feature_map_qk: whether to use the same feature map for query and key. + timestep_adaptive_topk: whether to adaptively adjust topk during diffusion. + ''' + assert SAGESLA_ENABLED, "Install SpargeAttn first to enable SageSLA." + + super().__init__() + self.dtype = torch.bfloat16 if use_bf16 else torch.float16 + self.topk = topk + self.proj_l = nn.Linear(head_dim, head_dim, dtype=torch.float32) + + if feature_map == 'elu': + def elu_feature_map(x): + return F.elu(x) + 1 + self.feature_map_q = elu_feature_map + self.feature_map_k = elu_feature_map + elif feature_map == 'relu': + self.feature_map_q = nn.ReLU() + self.feature_map_k = nn.ReLU() + elif feature_map == 'softmax': + def softmax_feature_map(x): + return F.softmax(x, dim=-1) + self.feature_map_q = softmax_feature_map + self.feature_map_k = softmax_feature_map + else: + raise NotImplementedError(f'Not supported feature map {feature_map}.') + + if tie_feature_map_qk: + self.feature_map_k = self.feature_map_q + + self.init_weights_() + + def init_weights_(self): + with torch.no_grad(): + nn.init.zeros_(self.proj_l.weight) + nn.init.zeros_(self.proj_l.bias) + + def forward(self, q, k, v, return_sparsity=False): + R''' + Args: + q: queries of shape (B, H, L, D). + k: keys of shape (B, H, L, D). + v: values of shape (B, H, L, D). + return_sparsity: whether to return the actual sparsity. + timestep: current timestep for diffusion models. + total_timesteps: total timesteps for diffusion models. + ''' + + dtype = q.dtype + + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + + arch = get_cuda_arch(q.device.index) + if arch == "sm90": + sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=64, BLKK=128) + else: + sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=128, BLKK=64) + + q = q.to(self.dtype) + k = k.to(self.dtype) + v = v.to(self.dtype) + + ########## SPARGE BEGIN ########## + + km = k.mean(dim=-2, keepdim=True) + headdim = q.size(-1) + + if arch == "sm90": + q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, 64, 128) + else: + q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, 128, 64) + lut, valid_block_num = block_map_lut_triton(sparse_map) + scale = 1.0 / (headdim ** 0.5) + + assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale." + + o_s = torch.empty_like(q) + + if arch in ("sm80", "sm86", "sm87"): + pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device) + v_fp16 = v.to(torch.float16) + qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold( + q_int8, k_int8, v_fp16, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, 1, False, 1, scale, 0 + ) + else: + b, h_kv, kv_len, head_dim = v.shape + padded_len = (kv_len + 127) // 128 * 128 + v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device) + fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1) + v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device) + v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device) + fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1) + + if arch == "sm90": + qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_sm90( + q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, q_scale, k_scale, v_scale, 1, False, 1, scale + ) + else: + pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device) + if SAGE2PP_ENABLED: + qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold( + q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale, 0 + ) + else: + qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold( + q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale, 0 + ) + + ########## SPARGE END ########## + + q = self.feature_map_q(q).contiguous().to(self.dtype) # c_q + k = self.feature_map_k(k).contiguous().to(self.dtype) # c_k + def calc_linear(q, k, v): + kvsum = k.transpose(-1, -2) @ v + ksum = torch.sum(k, dim=-2, keepdim=True) + return (q @ kvsum) / (1e-5 + (q * ksum).sum(dim=-1, keepdim=True)) + o_l = calc_linear(q, k, v) + + with torch.amp.autocast('cuda', dtype=self.dtype): + o_l = self.proj_l(o_l) + o = (o_s + o_l).to(dtype).transpose(1, 2) + + if return_sparsity: + return o, real_topk / sparse_map.shape[-1] + else: + return o \ No newline at end of file diff --git a/turbodiffusion/inference/wan2.2_i2v_infer.py b/turbodiffusion/inference/wan2.2_i2v_infer.py.bak_20260129_221228 similarity index 100% rename from turbodiffusion/inference/wan2.2_i2v_infer.py rename to turbodiffusion/inference/wan2.2_i2v_infer.py.bak_20260129_221228 diff --git a/turbodiffusion/inference/wan2.2_i2v_infer_memopt.py b/turbodiffusion/inference/wan2.2_i2v_infer_memopt.py new file mode 100644 index 0000000..193b7eb --- /dev/null +++ b/turbodiffusion/inference/wan2.2_i2v_infer_memopt.py @@ -0,0 +1,309 @@ +# Modified: 2026-01-30 | Fix OOM via MMAP Loading (Zero-Copy) & Aggressive Cleanup +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import math +import gc +import time +import torch +import ctypes +import os +from einops import rearrange, repeat +from tqdm import tqdm +from PIL import Image +import torchvision.transforms.v2 as T +import numpy as np +import logging + +from imaginaire.utils.io import save_image_or_video +from rcm.datasets.utils import VIDEO_RES_SIZE_INFO +from rcm.utils.umt5 import clear_umt5_memory, get_umt5_embedding +from rcm.tokenizers.wan2pt1 import Wan2pt1VAEInterface + +# modify_modelから必要な関数をインポート +from modify_model import tensor_kwargs, select_model, replace_attention, replace_linear_norm + +# ロギング設定 +logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s: %(message)s') +log = logging.getLogger(__name__) + +torch._dynamo.config.suppress_errors = True + +# libc for aggressive memory trimming +try: + libc = ctypes.CDLL("libc.so.6") +except Exception: + libc = None + +def cleanup_all(): + """強制的にガベージコレクション、VRAM解放、およびOSへのメモリ返却を行う""" + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if libc: + try: + libc.malloc_trim(0) + except: + pass + +def create_model_gpu(dit_path: str, args: argparse.Namespace) -> torch.nn.Module: + """ + mmap=True を使用して、CPU RAMを消費せずにモデルをロードする。 + これにより cgroup memory limit (46GB) を回避する。 + """ + log.info(f"Loading DiT (MMAP -> GPU): {dit_path}") + cleanup_all() + + # 1. Init Shell on Meta + with torch.device("meta"): + net = select_model(args.model) + + # 2. Patch + if args.attention_type in ['sla', 'sagesla']: + net = replace_attention(net, attention_type=args.attention_type, sla_topk=args.sla_topk) + replace_linear_norm(net, replace_linear=args.quant_linear, replace_norm=not args.default_norm, quantize=False) + + # 3. Load State Dict with MMAP + # mmap=Trueにより、ファイル内容をRAMにコピーせず、仮想メモリとしてマッピングする。 + # OSが必要な部分だけをページインし、不要になれば即座に破棄できるため、OOMを防ぐ最強の手段。 + log.info(" Mapping state_dict from disk (mmap)...") + try: + # map_location="cpu" + mmap=True が重要 + state_dict = torch.load(dit_path, map_location="cpu", mmap=True) + except Exception as e: + log.warning(f" mmap failed ({e}), falling back to standard load.") + state_dict = torch.load(dit_path, map_location="cpu") + + # 4. Clean keys + new_state_dict = {} + for k, v in state_dict.items(): + new_key = k.replace("_checkpoint_wrapped_module.", "") + new_state_dict[new_key] = v + del state_dict # 元の参照を削除 + + # 5. Load into model + # assign=True により、モデル内のMetaテンソルをmmapされたCPUテンソルに置き換える + log.info(" Assigning weights to model...") + net.load_state_dict(new_state_dict, strict=False, assign=True) + del new_state_dict + + # 6. Move to CUDA + # ここで初めてVRAMへ転送される。転送済みmmapページはOSが勝手に捨ててくれる。 + log.info(" Moving model to CUDA...") + net = net.to("cuda") + + # 7. Eval + net.eval() + return net + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="TurboDiffusion inference script (Memory Optimized)") + parser.add_argument("--image_path", type=str, default=None) + parser.add_argument("--high_noise_model_path", type=str, required=True) + parser.add_argument("--low_noise_model_path", type=str, required=True) + parser.add_argument("--boundary", type=float, default=0.9) + parser.add_argument("--model", choices=["Wan2.2-A14B"], default="Wan2.2-A14B") + parser.add_argument("--num_samples", type=int, default=1) + parser.add_argument("--num_steps", type=int, choices=[1, 2, 3, 4], default=4) + parser.add_argument("--sigma_max", type=float, default=200) + parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth") + parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth") + parser.add_argument("--num_frames", type=int, default=81) + parser.add_argument("--prompt", type=str, default=None) + parser.add_argument("--resolution", default="720p", type=str) + parser.add_argument("--aspect_ratio", default="16:9", type=str) + parser.add_argument("--adaptive_resolution", action="store_true") + parser.add_argument("--ode", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--save_path", type=str, default="output/generated_video.mp4") + parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla") + parser.add_argument("--sla_topk", type=float, default=0.1) + parser.add_argument("--quant_linear", action="store_true") + parser.add_argument("--default_norm", action="store_true") + parser.add_argument("--serve", action="store_true") + return parser.parse_args() + +if __name__ == "__main__": + args = parse_arguments() + + if args.serve: + log.error("Serve mode is not supported in this memory-optimized script.") + exit(1) + + if args.prompt is None or args.image_path is None: + log.error("--prompt and --image_path are required") + exit(1) + + cleanup_all() + + # 1. Text Encoder (T5) + log.info(f"Computing embedding for prompt: {args.prompt}") + with torch.no_grad(): + text_emb = get_umt5_embedding( + checkpoint_path=args.text_encoder_path, + prompts=args.prompt + ).to(**tensor_kwargs) + + clear_umt5_memory() + cleanup_all() + log.info("Text encoder unloaded.") + + # 2. VAE & Image Preprocessing + log.info(f"Loading and preprocessing image from: {args.image_path}") + input_image = Image.open(args.image_path).convert("RGB") + + tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) + + if args.adaptive_resolution: + base_w, base_h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] + max_resolution_area = base_w * base_h + orig_w, orig_h = input_image.size + image_aspect_ratio = orig_h / orig_w + ideal_w = np.sqrt(max_resolution_area / image_aspect_ratio) + ideal_h = np.sqrt(max_resolution_area * image_aspect_ratio) + stride = tokenizer.spatial_compression_factor * 2 + lat_h = round(ideal_h / stride) + lat_w = round(ideal_w / stride) + h = lat_h * stride + w = lat_w * stride + log.info(f"Adaptive resolution set to: {w}x{h}") + else: + w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] + log.info(f"Fixed resolution set to: {w}x{h}") + + F = args.num_frames + lat_h = h // tokenizer.spatial_compression_factor + lat_w = w // tokenizer.spatial_compression_factor + lat_t = tokenizer.get_latent_num_frames(F) + + image_transforms = T.Compose([ + T.ToImage(), + T.Resize(size=(h, w), antialias=True), + T.ToDtype(torch.float32, scale=True), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + image_tensor = image_transforms(input_image).unsqueeze(0).to(device=tensor_kwargs["device"], dtype=torch.float32) + + log.info("Encoding image latents...") + with torch.no_grad(): + frames_to_encode = torch.cat( + [image_tensor.unsqueeze(2), torch.zeros(1, 3, F - 1, h, w, device=image_tensor.device)], dim=2 + ) + encoded_latents = tokenizer.encode(frames_to_encode) + del frames_to_encode + + del image_tensor, input_image + cleanup_all() + + # VAE Offload + if hasattr(tokenizer, "vae"): + log.info("Offloading VAE to CPU to save VRAM/RAM...") + tokenizer.vae.cpu() + cleanup_all() + + # Latent Setup + msk = torch.zeros(1, 4, lat_t, lat_h, lat_w, device=tensor_kwargs["device"], dtype=tensor_kwargs["dtype"]) + msk[:, :, 0, :, :] = 1.0 + y = torch.cat([msk, encoded_latents.to(**tensor_kwargs)], dim=1) + y = y.repeat(args.num_samples, 1, 1, 1, 1) + + del msk + + condition = { + "crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples), + "y_B_C_T_H_W": y + } + + # Noise Schedule + state_shape = [tokenizer.latent_ch, lat_t, lat_h, lat_w] + generator = torch.Generator(device=tensor_kwargs["device"]) + generator.manual_seed(args.seed) + init_noise = torch.randn(args.num_samples, *state_shape, dtype=torch.float32, device=tensor_kwargs["device"], generator=generator) + + mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1] + t_steps = torch.tensor([math.atan(args.sigma_max), *mid_t, 0], dtype=torch.float64, device=init_noise.device) + t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps)) + + x = init_noise.to(torch.float64) * t_steps[0] + ones = torch.ones(x.size(0), 1, device=x.device, dtype=x.dtype) + total_steps = t_steps.shape[0] - 1 + + del init_noise + + # 3. Sequential Loading Strategy (High Model) + model = create_model_gpu(dit_path=args.high_noise_model_path, args=args) + + switched = False + + with torch.inference_mode(): + for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), desc="Sampling", total=total_steps)): + + # Switch Logic: High -> Low + if (t_cur.item() < args.boundary) and (not switched): + log.info("Boundary reached. Switching to Low Noise Model...") + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + # Release High Model + del model + cleanup_all() + log.info("Waiting for memory release...") + time.sleep(5) + + # Load Low Model + model = create_model_gpu(dit_path=args.low_noise_model_path, args=args) + switched = True + + v_pred = model( + x_B_C_T_H_W=x.to(**tensor_kwargs), + timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), + **condition + ).to(torch.float64) + + if args.ode: + x = x - (t_cur - t_next) * v_pred + else: + x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn( + *x.shape, dtype=torch.float32, device=tensor_kwargs["device"], generator=generator, + ) + + del v_pred + + # Lowモデルも削除 + if torch.cuda.is_available(): + torch.cuda.synchronize() + del model + cleanup_all() + log.info("Inference done. All DiT models unloaded.") + + # 4. Decode Video + log.info("Cleaning up before decoding...") + del condition, y, encoded_latents, text_emb + cleanup_all() + + if hasattr(tokenizer, "vae"): + log.info("Moving VAE back to GPU for decoding...") + tokenizer.vae.to(tensor_kwargs["device"]) + + log.info("Decoding video...") + samples = x.float() + del x + cleanup_all() + + with torch.inference_mode(): + video = tokenizer.decode(samples) + del samples + cleanup_all() + + to_show = (1.0 + video.float().cpu().clamp(-1, 1)) / 2.0 + del video + cleanup_all() + + # Save + video_tensor = rearrange(to_show[0], "c t h w -> c t h w") + save_image_or_video(video_tensor, args.save_path, fps=16) + + log.info(f"Saved video to {args.save_path}") \ No newline at end of file diff --git a/turbodiffusion/rcm/networks/wan2pt2.py b/turbodiffusion/rcm/networks/wan2pt2.py index 6463110..bce80f7 100644 --- a/turbodiffusion/rcm/networks/wan2pt2.py +++ b/turbodiffusion/rcm/networks/wan2pt2.py @@ -27,6 +27,142 @@ from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb except ImportError: flash_apply_rotary_emb = None + +# ---- Fallback RoPE when flash_attn is unavailable ---- +def _torch_apply_rotary_emb(x, cos, sin, interleaved=True, inplace=False): + """Torch fallback for flash_attn.layers.rotary.apply_rotary_emb. + Expects x float32; cos/sin broadcastable. + """ + import torch + x = x.to(torch.float32) + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + if interleaved: + x1 = x[..., ::2] + x2 = x[..., 1::2] + out1 = x1 * cos - x2 * sin + out2 = x1 * sin + x2 * cos + return torch.stack((out1, out2), dim=-1).flatten(-2) + else: + d = x.shape[-1] + x1 = x[..., : d // 2] + x2 = x[..., d // 2 :] + return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1) + +# If flash-attn isnt installed, flash_apply_rotary_emb is None -> make it callable. +if "flash_apply_rotary_emb" in globals() and flash_apply_rotary_emb is None: + flash_apply_rotary_emb = _torch_apply_rotary_emb + +# ---- Improved Fallback RoPE (handles transposed cos/sin) ---- +def _torch_apply_rotary_emb_v2(x, cos, sin, interleaved=True, inplace=False): + """ + Torch fallback for flash_attn.layers.rotary.apply_rotary_emb. + Handles common layouts: + - cos/sin: (seqlen, half) OR (half, seqlen) (the latter needs transpose) + - cos/sin with leading singleton dims: (1,1,seqlen,half) etc. + """ + import torch + + x_f = x.to(torch.float32) + cos_f = cos.to(torch.float32) + sin_f = sin.to(torch.float32) + + if interleaved: + # x: (..., D) where D is even + x1 = x_f[..., ::2] + x2 = x_f[..., 1::2] + half = x1.shape[-1] + + # If cos/sin are (half, seqlen), transpose last two dims -> (seqlen, half) + if cos_f.ndim >= 2 and cos_f.shape[-1] != half and cos_f.shape[-2] == half: + cos_f = cos_f.transpose(-1, -2).contiguous() + sin_f = sin_f.transpose(-1, -2).contiguous() + + # If cos/sin have only leading singleton dims, squeeze them to 2D + if cos_f.ndim > 2 and all(d == 1 for d in cos_f.shape[:-2]): + cos2 = cos_f.view(cos_f.shape[-2], cos_f.shape[-1]) + sin2 = sin_f.view(sin_f.shape[-2], sin_f.shape[-1]) + else: + cos2, sin2 = cos_f, sin_f + + # If still not matching, sometimes cos is provided for full dim -> take even indices + if cos2.shape[-1] != half: + if cos2.shape[-1] == half * 2: + cos2 = cos2[..., ::2].contiguous() + sin2 = sin2[..., ::2].contiguous() + else: + raise RuntimeError(f"RoPE fallback mismatch: x={tuple(x_f.shape)} cos={tuple(cos.shape)} (after={tuple(cos2.shape)}) half={half}") + + # Broadcast cos/sin to x1 + if cos2.ndim == 2: + seqlen = cos2.shape[0] + # find which dim of x1 matches seqlen (excluding last dim) + candidates = [i for i,s in enumerate(x1.shape[:-1]) if s == seqlen] + seq_dim = candidates[-1] if candidates else (x1.ndim - 2) + shape = [1] * x1.ndim + shape[seq_dim] = seqlen + shape[-1] = half + cos_b = cos2.view(shape) + sin_b = sin2.view(shape) + else: + cos_b, sin_b = cos2, sin2 + while cos_b.ndim < x1.ndim: + cos_b = cos_b.unsqueeze(0) + sin_b = sin_b.unsqueeze(0) + + out1 = x1 * cos_b - x2 * sin_b + out2 = x1 * sin_b + x2 * cos_b + out = torch.stack((out1, out2), dim=-1).flatten(-2) + return out.to(dtype=x.dtype) + + else: + # non-interleaved: first half vs second half + d = x_f.shape[-1] + half = d // 2 + x1 = x_f[..., :half] + x2 = x_f[..., half:2*half] + rest = x_f[..., 2*half:] + + if cos_f.ndim >= 2 and cos_f.shape[-1] != half and cos_f.shape[-2] == half: + cos_f = cos_f.transpose(-1, -2).contiguous() + sin_f = sin_f.transpose(-1, -2).contiguous() + + if cos_f.ndim > 2 and all(d == 1 for d in cos_f.shape[:-2]): + cos2 = cos_f.view(cos_f.shape[-2], cos_f.shape[-1]) + sin2 = sin_f.view(sin_f.shape[-2], sin_f.shape[-1]) + else: + cos2, sin2 = cos_f, sin_f + + if cos2.shape[-1] != half: + raise RuntimeError(f"RoPE fallback mismatch (non-interleaved): x={tuple(x_f.shape)} cos={tuple(cos.shape)} (after={tuple(cos2.shape)}) half={half}") + + if cos2.ndim == 2: + seqlen = cos2.shape[0] + candidates = [i for i,s in enumerate(x1.shape[:-1]) if s == seqlen] + seq_dim = candidates[-1] if candidates else (x1.ndim - 2) + shape = [1] * x1.ndim + shape[seq_dim] = seqlen + shape[-1] = half + cos_b = cos2.view(shape) + sin_b = sin2.view(shape) + else: + cos_b, sin_b = cos2, sin2 + while cos_b.ndim < x1.ndim: + cos_b = cos_b.unsqueeze(0) + sin_b = sin_b.unsqueeze(0) + + y1 = x1 * cos_b - x2 * sin_b + y2 = x1 * sin_b + x2 * cos_b + out = torch.cat((y1, y2, rest), dim=-1) + return out.to(dtype=x.dtype) + +# prefer v2 when flash_attn is unavailable (or already using a torch fallback) +if "flash_apply_rotary_emb" in globals(): + if flash_apply_rotary_emb is None or getattr(flash_apply_rotary_emb, "__name__", "").startswith("_torch_apply_rotary_emb"): + flash_apply_rotary_emb = _torch_apply_rotary_emb_v2 + + print("flash_attn is not installed.") from torch.distributed import ProcessGroup, get_process_group_ranks diff --git a/turbodiffusion/rcm/networks/wan2pt2.py.bak_20260129_062409 b/turbodiffusion/rcm/networks/wan2pt2.py.bak_20260129_062409 new file mode 100644 index 0000000..061a21d --- /dev/null +++ b/turbodiffusion/rcm/networks/wan2pt2.py.bak_20260129_062409 @@ -0,0 +1,773 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import math +from typing import Optional + +import torch +import torch.amp as amp +import torch.nn as nn +from einops import rearrange, repeat + +try: + from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb +except ImportError: + flash_apply_rotary_emb = None + +# ---- Fallback RoPE when flash_attn is unavailable ---- +def _torch_apply_rotary_emb(x, cos, sin, interleaved=True, inplace=False): + """Torch fallback for flash_attn.layers.rotary.apply_rotary_emb. + Expects x float32; cos/sin broadcastable. + """ + import torch + x = x.to(torch.float32) + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + if interleaved: + x1 = x[..., ::2] + x2 = x[..., 1::2] + out1 = x1 * cos - x2 * sin + out2 = x1 * sin + x2 * cos + return torch.stack((out1, out2), dim=-1).flatten(-2) + else: + d = x.shape[-1] + x1 = x[..., : d // 2] + x2 = x[..., d // 2 :] + return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1) + +# If flash-attn isnt installed, flash_apply_rotary_emb is None -> make it callable. +if "flash_apply_rotary_emb" in globals() and flash_apply_rotary_emb is None: + flash_apply_rotary_emb = _torch_apply_rotary_emb + + print("flash_attn is not installed.") + +from torch.distributed import ProcessGroup, get_process_group_ranks +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper + +from imaginaire.utils import log +from rcm.utils.a2a_cp import MinimalA2AAttnOp +from rcm.utils.selective_activation_checkpoint import CheckpointMode, SACConfig +from rcm.utils.context_parallel import split_inputs_cp, cat_outputs_cp, cat_outputs_cp_with_grad, broadcast + +T5_CONTEXT_TOKEN_NUMBER = 512 +FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2 + + +class VideoRopePosition3DEmb(nn.Module): + def __init__( + self, + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + ): + super().__init__() + self.max_h = len_h + self.max_w = len_w + self.max_t = len_t + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self._dim_h = dim_h + self._dim_t = dim_t + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + self._is_initialized = False + + def cache_parameters(self) -> None: + if self._is_initialized: + return + + dim_h = self._dim_h + dim_t = self._dim_t + + self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().cuda() + self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h + self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t + self._is_initialized = True + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. + + Returns: + Not specified in the original code snippet. + """ + self.cache_parameters() + + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) + + B, T, H, W, _ = B_T_H_W_C + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" + freqs_h = torch.outer(self.seq[:H], h_spatial_freqs) + freqs_w = torch.outer(self.seq[:W], w_spatial_freqs) + + freqs_t = torch.outer(self.seq[:T], temporal_freqs) + + freqs_T_H_W_D = torch.cat( + [ + repeat(freqs_t, "t d -> t h w d", h=H, w=W), + repeat(freqs_h, "h d -> t h w d", t=T, w=W), + repeat(freqs_w, "w d -> t h w d", t=T, h=H), + ], + dim=-1, + ) + + return rearrange(freqs_T_H_W_D, "t h w d -> (t h w) d").float() + + @property + def seq_dim(self): + return 0 + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +def rope_apply(x, freqs): + """ + Optimized version of rope_apply using flash_attention's rotary embedding implementation. + This version processes the entire batch at once for efficiency. + + Args: + x (Tensor): Input tensor with shape [batch_size, seq_len, n_heads, head_dim] + freqs (Tensor): Complex frequencies with shape [max_seq_len, head_dim // 2] + + Returns: + Tensor: Rotary-embedded tensor with same shape as input + """ + batch_size, seq_len, n_heads, head_dim = x.shape + + # freqs is already sharded to local seq_len under flattened CP + freqs = freqs.view(seq_len, head_dim // 2) + cos = torch.cos(freqs).to(torch.float32) + sin = torch.sin(freqs).to(torch.float32) + + # Apply the rotation + rotated = flash_apply_rotary_emb(x.to(torch.float32), cos, sin, interleaved=True, inplace=False) + + return rotated.to(x.dtype) + + +class WanRMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def reset_parameters(self): + self.weight.data.fill_(1.0) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + with amp.autocast("cuda", dtype=torch.float32): + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + def __init__(self, dim, num_heads, qk_norm=True, eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + self.qk_norm = qk_norm + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.attn_op = MinimalA2AAttnOp() + + def init_weights(self): + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self.q.weight, std=std) + torch.nn.init.trunc_normal_(self.k.weight, std=std) + torch.nn.init.trunc_normal_(self.v.weight, std=std) + torch.nn.init.trunc_normal_(self.o.weight, std=std) + # zero out bias + self.q.bias.data.zero_() + self.k.bias.data.zero_() + self.v.bias.data.zero_() + self.o.bias.data.zero_() + # reset norm weights + if self.qk_norm: + self.norm_q.reset_parameters() + self.norm_k.reset_parameters() + + def forward(self, x, seq_lens, freqs): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + q = rope_apply(q, freqs) + k = rope_apply(k, freqs) + + x = self.attn_op(q, k, v) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + def set_context_parallel_group(self, process_group, ranks, stream): + self.attn_op.set_context_parallel_group(process_group, ranks, stream) + + +class WanCrossAttention(WanSelfAttention): + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = self.attn_op(q, k, v) + # output + x = x.flatten(2) + x = self.o(x) + return x + + +WAN_CROSSATTENTION_CLASSES = {"t2v_cross_attn": WanCrossAttention, "i2v_cross_attn": WanCrossAttention} + + +class WanAttentionBlock(nn.Module): + def __init__(self, cross_attn_type, dim, ffn_dim, num_heads, qk_norm=True, cross_attn_norm=False, eps=1e-6): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, qk_norm, eps) + self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, qk_norm, eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def init_weights(self): + self.self_attn.init_weights() + self.cross_attn.init_weights() + + self.norm1.reset_parameters() + self.norm2.reset_parameters() + self.norm3.reset_parameters() + + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self.modulation, std=std) + + def forward(self, x, e, seq_lens, freqs, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + with amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation + e).chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn((self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, freqs) + with amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[2].type_as(x) + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).type_as(x)) + with amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[5].type_as(x) + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class Head(nn.Module): + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def init_weights(self): + self.norm.reset_parameters() + + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self.modulation, std=std) + torch.nn.init.trunc_normal_(self.head.weight, std=std) + self.head.bias.data.zero_() + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + assert e.dtype == torch.float32 + with amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) + return x + + +class MLPProj(torch.nn.Module): + def __init__(self, in_dim, out_dim, flf_pos_emb=False): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), + torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), + torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim), + ) + + def init_weights(self): + self.proj[0].reset_parameters() + self.proj[1].reset_parameters() + self.proj[3].reset_parameters() + self.proj[4].reset_parameters() + + if hasattr(self, "emb_pos"): + self.emb_pos.data.zero_() + + def forward(self, image_embeds): + if hasattr(self, "emb_pos"): + bs, n, d = image_embeds.shape + image_embeds = image_embeds.view(-1, 2 * n, d) + image_embeds = image_embeds + self.emb_pos + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class WanModel(nn.Module): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__( + self, + model_type="t2v", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + sac_config: SACConfig = SACConfig(), + ): + r""" + Initialize the diffusion model backbone. + + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ["t2v", "i2v"] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + self.use_crossattn_projection = False + + # embeddings + self.patch_embedding = nn.Linear(in_dim * patch_size[0] * patch_size[1] * patch_size[2], dim) + + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn" + self.blocks = nn.ModuleList( + [WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps) for _ in range(num_layers)] + ) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + + d = dim // num_heads + + self.rope_position_embedding = VideoRopePosition3DEmb(head_dim=d, len_h=128, len_w=128, len_t=32) + + # initialize weights + self.init_weights() + + self.enable_selective_checkpoint(sac_config) + + def forward( + self, + x_B_C_T_H_W, + timesteps_B_T, + crossattn_emb, + y_B_C_T_H_W=None, + **kwargs, + ): + r""" + Forward pass through the diffusion model + + Args: + x_B_C_T_H_W (Tensor): + Input video tensor with shape [B, C_in, T, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + y_B_C_T_H_W (Tensor, *optional*): + Conditional video inputs for image-to-video mode, shape [B, C_in, T, H, W] + + Returns: + Tensor: + Denoised video tensor with shape [B, C_out, T, H / 8, W / 8] + """ + + cp_group = getattr(self, "_cp_group", None) + cp_enabled = (cp_group is not None) and (cp_group.size() > 1) + if cp_enabled: + x_B_C_T_H_W = broadcast(x_B_C_T_H_W, cp_group) + timesteps_B_T = broadcast(timesteps_B_T, cp_group) + crossattn_emb = broadcast(crossattn_emb, cp_group) + if y_B_C_T_H_W is not None: + y_B_C_T_H_W = broadcast(y_B_C_T_H_W, cp_group) + + assert timesteps_B_T.shape[1] == 1 + t_B = timesteps_B_T[:, 0] + del kwargs + if self.model_type == "i2v": + assert y_B_C_T_H_W is not None + + if y_B_C_T_H_W is not None: + x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, y_B_C_T_H_W], dim=1) + + kt, kh, kw = self.patch_size + B, _, T_in, H_in, W_in = x_B_C_T_H_W.shape + assert (T_in % kt) == 0 and (H_in % kh) == 0 and (W_in % kw) == 0 + T, H, W = T_in // kt, H_in // kh, W_in // kw + L = T * H * W + + # patchify and flatten + x_B_L_Din = rearrange( + x_B_C_T_H_W, + "b c (t kt) (h kh) (w kw) -> b (t h w) (c kt kh kw)", + kt=kt, + kh=kh, + kw=kw, + ).contiguous() + + if cp_enabled: + assert (L % cp_group.size()) == 0, f"L=T*H*W must be divisible by cp_size. Got L={L}, cp={cp_group.size()}." + x_B_L_Din = split_inputs_cp(x_B_L_Din, seq_dim=1, cp_group=cp_group) + + # embeddings + x_B_L_D = self.patch_embedding(x_B_L_Din) + seq_lens = torch.tensor([u.size(0) for u in x_B_L_D], dtype=torch.long) + + # time embeddings + with amp.autocast("cuda", dtype=torch.float32): + e_B_D = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t_B).float()) + e0_B_6_D = self.time_projection(e_B_D).unflatten(1, (6, self.dim)) + assert e_B_D.dtype == torch.float32 and e0_B_6_D.dtype == torch.float32 + + # context + context_lens = None + context_B_L_D = self.text_embedding(crossattn_emb) + + freqs = self.rope_position_embedding.generate_embeddings(torch.Size([B, T, H, W, self.dim])).contiguous() + if cp_enabled: + freqs = split_inputs_cp(freqs, seq_dim=self.rope_position_embedding.seq_dim, cp_group=cp_group) + + # arguments + kwargs = dict( + e=e0_B_6_D, + seq_lens=seq_lens, + freqs=freqs, + context=context_B_L_D, + context_lens=context_lens, + ) + + for block_idx, block in enumerate(self.blocks): + x_B_L_D = block(x_B_L_D, **kwargs) + + # head + x_B_L_Dout = self.head(x_B_L_D, e_B_D) + + if cp_enabled: + if torch.is_grad_enabled(): + x_B_L_Dout = cat_outputs_cp_with_grad(x_B_L_Dout, seq_dim=1, cp_group=cp_group) + else: + x_B_L_Dout = cat_outputs_cp(x_B_L_Dout, seq_dim=1, cp_group=cp_group) + + # unpatchify + x_B_C_T_H_W = rearrange( + x_B_L_Dout, + "b (t h w) (kt kh kw d) -> b d (t kt) (h kh) (w kw)", + kt=kt, + kh=kh, + kw=kw, + t=T, + h=H, + w=W, + d=self.out_dim, + ) + return x_B_C_T_H_W + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + for block in self.blocks: + block.init_weights() + self.head.init_weights() + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + nn.init.zeros_(self.patch_embedding.bias) + + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + for m in self.time_projection.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init output layer + nn.init.zeros_(self.head.head.weight) + if self.head.head.bias is not None: + nn.init.zeros_(self.head.head.bias) + + def fully_shard(self, mesh, mp_policy): + for i, block in enumerate(self.blocks): + fully_shard(block, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=True) + fully_shard(self.head, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=False) + fully_shard(self.text_embedding, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=True) + fully_shard(self.time_embedding, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=True) + fully_shard(self.patch_embedding, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=True) + + def disable_context_parallel(self): + # attention + for block in self.blocks: + block.self_attn.set_context_parallel_group( + process_group=None, + ranks=None, + stream=torch.cuda.Stream(), + ) + + self._is_context_parallel_enabled = False + self._cp_group = None + + def enable_context_parallel(self, process_group: Optional[ProcessGroup] = None): + cp_ranks = get_process_group_ranks(process_group) + for block in self.blocks: + block.self_attn.set_context_parallel_group(process_group=process_group, ranks=cp_ranks, stream=torch.cuda.Stream()) + + self._is_context_parallel_enabled = True + self._cp_group = process_group + + @property + def is_context_parallel_enabled(self): + return self._is_context_parallel_enabled + + def enable_selective_checkpoint(self, sac_config: SACConfig): + if sac_config.mode == CheckpointMode.NONE: + return self + + log.info(f"Enable selective checkpoint with mm_only, for every {sac_config.every_n_blocks} blocks. Total blocks: {len(self.blocks)}") + _context_fn = sac_config.get_context_fn() + for block_id, block in self.blocks.named_children(): + if int(block_id) % sac_config.every_n_blocks == 0: + block = ptd_checkpoint_wrapper(block, context_fn=_context_fn, preserve_rng_state=False) + self.blocks.register_module(block_id, block) + self.register_module("head", ptd_checkpoint_wrapper(self.head, context_fn=_context_fn, preserve_rng_state=False)) + + return self diff --git a/turbodiffusion/rcm/networks/wan2pt2.py.bak_20260129_062642 b/turbodiffusion/rcm/networks/wan2pt2.py.bak_20260129_062642 new file mode 100644 index 0000000..bce80f7 --- /dev/null +++ b/turbodiffusion/rcm/networks/wan2pt2.py.bak_20260129_062642 @@ -0,0 +1,882 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import math +from typing import Optional + +import torch +import torch.amp as amp +import torch.nn as nn +from einops import rearrange, repeat + +try: + from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb +except ImportError: + flash_apply_rotary_emb = None + +# ---- Fallback RoPE when flash_attn is unavailable ---- +def _torch_apply_rotary_emb(x, cos, sin, interleaved=True, inplace=False): + """Torch fallback for flash_attn.layers.rotary.apply_rotary_emb. + Expects x float32; cos/sin broadcastable. + """ + import torch + x = x.to(torch.float32) + cos = cos.to(torch.float32) + sin = sin.to(torch.float32) + + if interleaved: + x1 = x[..., ::2] + x2 = x[..., 1::2] + out1 = x1 * cos - x2 * sin + out2 = x1 * sin + x2 * cos + return torch.stack((out1, out2), dim=-1).flatten(-2) + else: + d = x.shape[-1] + x1 = x[..., : d // 2] + x2 = x[..., d // 2 :] + return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1) + +# If flash-attn isnt installed, flash_apply_rotary_emb is None -> make it callable. +if "flash_apply_rotary_emb" in globals() and flash_apply_rotary_emb is None: + flash_apply_rotary_emb = _torch_apply_rotary_emb + +# ---- Improved Fallback RoPE (handles transposed cos/sin) ---- +def _torch_apply_rotary_emb_v2(x, cos, sin, interleaved=True, inplace=False): + """ + Torch fallback for flash_attn.layers.rotary.apply_rotary_emb. + Handles common layouts: + - cos/sin: (seqlen, half) OR (half, seqlen) (the latter needs transpose) + - cos/sin with leading singleton dims: (1,1,seqlen,half) etc. + """ + import torch + + x_f = x.to(torch.float32) + cos_f = cos.to(torch.float32) + sin_f = sin.to(torch.float32) + + if interleaved: + # x: (..., D) where D is even + x1 = x_f[..., ::2] + x2 = x_f[..., 1::2] + half = x1.shape[-1] + + # If cos/sin are (half, seqlen), transpose last two dims -> (seqlen, half) + if cos_f.ndim >= 2 and cos_f.shape[-1] != half and cos_f.shape[-2] == half: + cos_f = cos_f.transpose(-1, -2).contiguous() + sin_f = sin_f.transpose(-1, -2).contiguous() + + # If cos/sin have only leading singleton dims, squeeze them to 2D + if cos_f.ndim > 2 and all(d == 1 for d in cos_f.shape[:-2]): + cos2 = cos_f.view(cos_f.shape[-2], cos_f.shape[-1]) + sin2 = sin_f.view(sin_f.shape[-2], sin_f.shape[-1]) + else: + cos2, sin2 = cos_f, sin_f + + # If still not matching, sometimes cos is provided for full dim -> take even indices + if cos2.shape[-1] != half: + if cos2.shape[-1] == half * 2: + cos2 = cos2[..., ::2].contiguous() + sin2 = sin2[..., ::2].contiguous() + else: + raise RuntimeError(f"RoPE fallback mismatch: x={tuple(x_f.shape)} cos={tuple(cos.shape)} (after={tuple(cos2.shape)}) half={half}") + + # Broadcast cos/sin to x1 + if cos2.ndim == 2: + seqlen = cos2.shape[0] + # find which dim of x1 matches seqlen (excluding last dim) + candidates = [i for i,s in enumerate(x1.shape[:-1]) if s == seqlen] + seq_dim = candidates[-1] if candidates else (x1.ndim - 2) + shape = [1] * x1.ndim + shape[seq_dim] = seqlen + shape[-1] = half + cos_b = cos2.view(shape) + sin_b = sin2.view(shape) + else: + cos_b, sin_b = cos2, sin2 + while cos_b.ndim < x1.ndim: + cos_b = cos_b.unsqueeze(0) + sin_b = sin_b.unsqueeze(0) + + out1 = x1 * cos_b - x2 * sin_b + out2 = x1 * sin_b + x2 * cos_b + out = torch.stack((out1, out2), dim=-1).flatten(-2) + return out.to(dtype=x.dtype) + + else: + # non-interleaved: first half vs second half + d = x_f.shape[-1] + half = d // 2 + x1 = x_f[..., :half] + x2 = x_f[..., half:2*half] + rest = x_f[..., 2*half:] + + if cos_f.ndim >= 2 and cos_f.shape[-1] != half and cos_f.shape[-2] == half: + cos_f = cos_f.transpose(-1, -2).contiguous() + sin_f = sin_f.transpose(-1, -2).contiguous() + + if cos_f.ndim > 2 and all(d == 1 for d in cos_f.shape[:-2]): + cos2 = cos_f.view(cos_f.shape[-2], cos_f.shape[-1]) + sin2 = sin_f.view(sin_f.shape[-2], sin_f.shape[-1]) + else: + cos2, sin2 = cos_f, sin_f + + if cos2.shape[-1] != half: + raise RuntimeError(f"RoPE fallback mismatch (non-interleaved): x={tuple(x_f.shape)} cos={tuple(cos.shape)} (after={tuple(cos2.shape)}) half={half}") + + if cos2.ndim == 2: + seqlen = cos2.shape[0] + candidates = [i for i,s in enumerate(x1.shape[:-1]) if s == seqlen] + seq_dim = candidates[-1] if candidates else (x1.ndim - 2) + shape = [1] * x1.ndim + shape[seq_dim] = seqlen + shape[-1] = half + cos_b = cos2.view(shape) + sin_b = sin2.view(shape) + else: + cos_b, sin_b = cos2, sin2 + while cos_b.ndim < x1.ndim: + cos_b = cos_b.unsqueeze(0) + sin_b = sin_b.unsqueeze(0) + + y1 = x1 * cos_b - x2 * sin_b + y2 = x1 * sin_b + x2 * cos_b + out = torch.cat((y1, y2, rest), dim=-1) + return out.to(dtype=x.dtype) + +# prefer v2 when flash_attn is unavailable (or already using a torch fallback) +if "flash_apply_rotary_emb" in globals(): + if flash_apply_rotary_emb is None or getattr(flash_apply_rotary_emb, "__name__", "").startswith("_torch_apply_rotary_emb"): + flash_apply_rotary_emb = _torch_apply_rotary_emb_v2 + + + print("flash_attn is not installed.") + +from torch.distributed import ProcessGroup, get_process_group_ranks +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper as ptd_checkpoint_wrapper + +from imaginaire.utils import log +from rcm.utils.a2a_cp import MinimalA2AAttnOp +from rcm.utils.selective_activation_checkpoint import CheckpointMode, SACConfig +from rcm.utils.context_parallel import split_inputs_cp, cat_outputs_cp, cat_outputs_cp_with_grad, broadcast + +T5_CONTEXT_TOKEN_NUMBER = 512 +FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2 + + +class VideoRopePosition3DEmb(nn.Module): + def __init__( + self, + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + ): + super().__init__() + self.max_h = len_h + self.max_w = len_w + self.max_t = len_t + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self._dim_h = dim_h + self._dim_t = dim_t + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + self._is_initialized = False + + def cache_parameters(self) -> None: + if self._is_initialized: + return + + dim_h = self._dim_h + dim_t = self._dim_t + + self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().cuda() + self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().cuda() / dim_h + self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().cuda() / dim_t + self._is_initialized = True + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + ): + """ + Generate embeddings for the given input size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. + + Returns: + Not specified in the original code snippet. + """ + self.cache_parameters() + + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range) + + B, T, H, W, _ = B_T_H_W_C + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" + freqs_h = torch.outer(self.seq[:H], h_spatial_freqs) + freqs_w = torch.outer(self.seq[:W], w_spatial_freqs) + + freqs_t = torch.outer(self.seq[:T], temporal_freqs) + + freqs_T_H_W_D = torch.cat( + [ + repeat(freqs_t, "t d -> t h w d", h=H, w=W), + repeat(freqs_h, "h d -> t h w d", t=T, w=W), + repeat(freqs_w, "w d -> t h w d", t=T, h=H), + ], + dim=-1, + ) + + return rearrange(freqs_T_H_W_D, "t h w d -> (t h w) d").float() + + @property + def seq_dim(self): + return 0 + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +def rope_apply(x, freqs): + """ + Optimized version of rope_apply using flash_attention's rotary embedding implementation. + This version processes the entire batch at once for efficiency. + + Args: + x (Tensor): Input tensor with shape [batch_size, seq_len, n_heads, head_dim] + freqs (Tensor): Complex frequencies with shape [max_seq_len, head_dim // 2] + + Returns: + Tensor: Rotary-embedded tensor with same shape as input + """ + batch_size, seq_len, n_heads, head_dim = x.shape + + # freqs is already sharded to local seq_len under flattened CP + freqs = freqs.view(seq_len, head_dim // 2) + cos = torch.cos(freqs).to(torch.float32) + sin = torch.sin(freqs).to(torch.float32) + + # Apply the rotation + rotated = flash_apply_rotary_emb(x.to(torch.float32), cos, sin, interleaved=True, inplace=False) + + return rotated.to(x.dtype) + + +class WanRMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def reset_parameters(self): + self.weight.data.fill_(1.0) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + with amp.autocast("cuda", dtype=torch.float32): + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + def __init__(self, dim, num_heads, qk_norm=True, eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + self.qk_norm = qk_norm + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.attn_op = MinimalA2AAttnOp() + + def init_weights(self): + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self.q.weight, std=std) + torch.nn.init.trunc_normal_(self.k.weight, std=std) + torch.nn.init.trunc_normal_(self.v.weight, std=std) + torch.nn.init.trunc_normal_(self.o.weight, std=std) + # zero out bias + self.q.bias.data.zero_() + self.k.bias.data.zero_() + self.v.bias.data.zero_() + self.o.bias.data.zero_() + # reset norm weights + if self.qk_norm: + self.norm_q.reset_parameters() + self.norm_k.reset_parameters() + + def forward(self, x, seq_lens, freqs): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + q = rope_apply(q, freqs) + k = rope_apply(k, freqs) + + x = self.attn_op(q, k, v) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + def set_context_parallel_group(self, process_group, ranks, stream): + self.attn_op.set_context_parallel_group(process_group, ranks, stream) + + +class WanCrossAttention(WanSelfAttention): + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = self.attn_op(q, k, v) + # output + x = x.flatten(2) + x = self.o(x) + return x + + +WAN_CROSSATTENTION_CLASSES = {"t2v_cross_attn": WanCrossAttention, "i2v_cross_attn": WanCrossAttention} + + +class WanAttentionBlock(nn.Module): + def __init__(self, cross_attn_type, dim, ffn_dim, num_heads, qk_norm=True, cross_attn_norm=False, eps=1e-6): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, qk_norm, eps) + self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, qk_norm, eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def init_weights(self): + self.self_attn.init_weights() + self.cross_attn.init_weights() + + self.norm1.reset_parameters() + self.norm2.reset_parameters() + self.norm3.reset_parameters() + + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self.modulation, std=std) + + def forward(self, x, e, seq_lens, freqs, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + with amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation + e).chunk(6, dim=1) + assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn((self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, freqs) + with amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[2].type_as(x) + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).type_as(x)) + with amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[5].type_as(x) + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class Head(nn.Module): + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def init_weights(self): + self.norm.reset_parameters() + + std = 1.0 / math.sqrt(self.dim) + torch.nn.init.trunc_normal_(self.modulation, std=std) + torch.nn.init.trunc_normal_(self.head.weight, std=std) + self.head.bias.data.zero_() + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + assert e.dtype == torch.float32 + with amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) + return x + + +class MLPProj(torch.nn.Module): + def __init__(self, in_dim, out_dim, flf_pos_emb=False): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), + torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), + torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim), + ) + + def init_weights(self): + self.proj[0].reset_parameters() + self.proj[1].reset_parameters() + self.proj[3].reset_parameters() + self.proj[4].reset_parameters() + + if hasattr(self, "emb_pos"): + self.emb_pos.data.zero_() + + def forward(self, image_embeds): + if hasattr(self, "emb_pos"): + bs, n, d = image_embeds.shape + image_embeds = image_embeds.view(-1, 2 * n, d) + image_embeds = image_embeds + self.emb_pos + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class WanModel(nn.Module): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + def __init__( + self, + model_type="t2v", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + sac_config: SACConfig = SACConfig(), + ): + r""" + Initialize the diffusion model backbone. + + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ["t2v", "i2v"] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + self.use_crossattn_projection = False + + # embeddings + self.patch_embedding = nn.Linear(in_dim * patch_size[0] * patch_size[1] * patch_size[2], dim) + + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn" + self.blocks = nn.ModuleList( + [WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps) for _ in range(num_layers)] + ) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + + d = dim // num_heads + + self.rope_position_embedding = VideoRopePosition3DEmb(head_dim=d, len_h=128, len_w=128, len_t=32) + + # initialize weights + self.init_weights() + + self.enable_selective_checkpoint(sac_config) + + def forward( + self, + x_B_C_T_H_W, + timesteps_B_T, + crossattn_emb, + y_B_C_T_H_W=None, + **kwargs, + ): + r""" + Forward pass through the diffusion model + + Args: + x_B_C_T_H_W (Tensor): + Input video tensor with shape [B, C_in, T, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + y_B_C_T_H_W (Tensor, *optional*): + Conditional video inputs for image-to-video mode, shape [B, C_in, T, H, W] + + Returns: + Tensor: + Denoised video tensor with shape [B, C_out, T, H / 8, W / 8] + """ + + cp_group = getattr(self, "_cp_group", None) + cp_enabled = (cp_group is not None) and (cp_group.size() > 1) + if cp_enabled: + x_B_C_T_H_W = broadcast(x_B_C_T_H_W, cp_group) + timesteps_B_T = broadcast(timesteps_B_T, cp_group) + crossattn_emb = broadcast(crossattn_emb, cp_group) + if y_B_C_T_H_W is not None: + y_B_C_T_H_W = broadcast(y_B_C_T_H_W, cp_group) + + assert timesteps_B_T.shape[1] == 1 + t_B = timesteps_B_T[:, 0] + del kwargs + if self.model_type == "i2v": + assert y_B_C_T_H_W is not None + + if y_B_C_T_H_W is not None: + x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, y_B_C_T_H_W], dim=1) + + kt, kh, kw = self.patch_size + B, _, T_in, H_in, W_in = x_B_C_T_H_W.shape + assert (T_in % kt) == 0 and (H_in % kh) == 0 and (W_in % kw) == 0 + T, H, W = T_in // kt, H_in // kh, W_in // kw + L = T * H * W + + # patchify and flatten + x_B_L_Din = rearrange( + x_B_C_T_H_W, + "b c (t kt) (h kh) (w kw) -> b (t h w) (c kt kh kw)", + kt=kt, + kh=kh, + kw=kw, + ).contiguous() + + if cp_enabled: + assert (L % cp_group.size()) == 0, f"L=T*H*W must be divisible by cp_size. Got L={L}, cp={cp_group.size()}." + x_B_L_Din = split_inputs_cp(x_B_L_Din, seq_dim=1, cp_group=cp_group) + + # embeddings + x_B_L_D = self.patch_embedding(x_B_L_Din) + seq_lens = torch.tensor([u.size(0) for u in x_B_L_D], dtype=torch.long) + + # time embeddings + with amp.autocast("cuda", dtype=torch.float32): + e_B_D = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t_B).float()) + e0_B_6_D = self.time_projection(e_B_D).unflatten(1, (6, self.dim)) + assert e_B_D.dtype == torch.float32 and e0_B_6_D.dtype == torch.float32 + + # context + context_lens = None + context_B_L_D = self.text_embedding(crossattn_emb) + + freqs = self.rope_position_embedding.generate_embeddings(torch.Size([B, T, H, W, self.dim])).contiguous() + if cp_enabled: + freqs = split_inputs_cp(freqs, seq_dim=self.rope_position_embedding.seq_dim, cp_group=cp_group) + + # arguments + kwargs = dict( + e=e0_B_6_D, + seq_lens=seq_lens, + freqs=freqs, + context=context_B_L_D, + context_lens=context_lens, + ) + + for block_idx, block in enumerate(self.blocks): + x_B_L_D = block(x_B_L_D, **kwargs) + + # head + x_B_L_Dout = self.head(x_B_L_D, e_B_D) + + if cp_enabled: + if torch.is_grad_enabled(): + x_B_L_Dout = cat_outputs_cp_with_grad(x_B_L_Dout, seq_dim=1, cp_group=cp_group) + else: + x_B_L_Dout = cat_outputs_cp(x_B_L_Dout, seq_dim=1, cp_group=cp_group) + + # unpatchify + x_B_C_T_H_W = rearrange( + x_B_L_Dout, + "b (t h w) (kt kh kw d) -> b d (t kt) (h kh) (w kw)", + kt=kt, + kh=kh, + kw=kw, + t=T, + h=H, + w=W, + d=self.out_dim, + ) + return x_B_C_T_H_W + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. + """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + for block in self.blocks: + block.init_weights() + self.head.init_weights() + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + nn.init.zeros_(self.patch_embedding.bias) + + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + for m in self.time_projection.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init output layer + nn.init.zeros_(self.head.head.weight) + if self.head.head.bias is not None: + nn.init.zeros_(self.head.head.bias) + + def fully_shard(self, mesh, mp_policy): + for i, block in enumerate(self.blocks): + fully_shard(block, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=True) + fully_shard(self.head, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=False) + fully_shard(self.text_embedding, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=True) + fully_shard(self.time_embedding, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=True) + fully_shard(self.patch_embedding, mesh=mesh, mp_policy=mp_policy, reshard_after_forward=True) + + def disable_context_parallel(self): + # attention + for block in self.blocks: + block.self_attn.set_context_parallel_group( + process_group=None, + ranks=None, + stream=torch.cuda.Stream(), + ) + + self._is_context_parallel_enabled = False + self._cp_group = None + + def enable_context_parallel(self, process_group: Optional[ProcessGroup] = None): + cp_ranks = get_process_group_ranks(process_group) + for block in self.blocks: + block.self_attn.set_context_parallel_group(process_group=process_group, ranks=cp_ranks, stream=torch.cuda.Stream()) + + self._is_context_parallel_enabled = True + self._cp_group = process_group + + @property + def is_context_parallel_enabled(self): + return self._is_context_parallel_enabled + + def enable_selective_checkpoint(self, sac_config: SACConfig): + if sac_config.mode == CheckpointMode.NONE: + return self + + log.info(f"Enable selective checkpoint with mm_only, for every {sac_config.every_n_blocks} blocks. Total blocks: {len(self.blocks)}") + _context_fn = sac_config.get_context_fn() + for block_id, block in self.blocks.named_children(): + if int(block_id) % sac_config.every_n_blocks == 0: + block = ptd_checkpoint_wrapper(block, context_fn=_context_fn, preserve_rng_state=False) + self.blocks.register_module(block_id, block) + self.register_module("head", ptd_checkpoint_wrapper(self.head, context_fn=_context_fn, preserve_rng_state=False)) + + return self From ab54a51be40ef352df95c63379e47dec8a3cbd49 Mon Sep 17 00:00:00 2001 From: Media Studio Dev Date: Fri, 30 Jan 2026 23:56:04 +0000 Subject: [PATCH 2/2] fix: int8 gemm dtype + tokenizer fallback + modify_model dtype fixes (submodule) --- turbodiffusion/inference/modify_model.py | 2 +- turbodiffusion/ops/core.py | 23 +++++++++- turbodiffusion/rcm/utils/umt5.py | 58 ++++++++++++++++++++++-- 3 files changed, 76 insertions(+), 7 deletions(-) diff --git a/turbodiffusion/inference/modify_model.py b/turbodiffusion/inference/modify_model.py index 43e35c5..9ce0513 100644 --- a/turbodiffusion/inference/modify_model.py +++ b/turbodiffusion/inference/modify_model.py @@ -81,7 +81,7 @@ def replace_linear_norm( return model -tensor_kwargs = {"device": "cuda", "dtype": torch.bfloat16} +tensor_kwargs = {"device": "cuda", "dtype": torch.float16} def select_model(model_name: str) -> torch.nn.Module: if model_name == "Wan2.1-1.3B": diff --git a/turbodiffusion/ops/core.py b/turbodiffusion/ops/core.py index 6d74fe5..412ec7d 100644 --- a/turbodiffusion/ops/core.py +++ b/turbodiffusion/ops/core.py @@ -50,10 +50,29 @@ def int8_linear( x = x.reshape(-1, shape[-1]) m = x.shape[0] n = w_q.shape[0] - y = torch.zeros(m, n, dtype=x.dtype, device=x.device) + # Allocate accumulator in a dtype accepted by the GEMM kernel (float16 or bfloat16) + # The CUDA kernel supports kHalf (float16) and kBFloat16 for the output C matrix. + out_dtype = torch.bfloat16 if x.dtype == torch.bfloat16 else torch.float16 + y = torch.zeros(m, n, dtype=out_dtype, device=x.device) x_q, x_s = int8_quant(x) - gemm_cuda(x_q, x_s, w_q, w_s, y) + + # Ensure scales are float32 as required by the gemm kernel + if x_s.dtype != torch.float32: + x_s = x_s.to(torch.float32) + if w_s.dtype != torch.float32: + w_s = w_s.to(torch.float32) + + # Debug check - raise informative error if kernel doesn't accept these dtypes + try: + gemm_cuda(x_q, x_s, w_q, w_s, y) + except RuntimeError as e: + raise RuntimeError( + f"Unsupported output data type for int8 gemm. dtypes -> x_q:{x_q.dtype}, x_s:{x_s.dtype}, w_q:{w_q.dtype}, w_s:{w_s.dtype}, y:{y.dtype}. Inner: {e}" + ) + + # Cast back to input dtype to preserve model dtype expectations + y = y.to(x.dtype) return y.reshape(*shape[:-1], n) def flatten_if_batched(*tensors): diff --git a/turbodiffusion/rcm/utils/umt5.py b/turbodiffusion/rcm/utils/umt5.py index 28b6f8b..ff2c054 100644 --- a/turbodiffusion/rcm/utils/umt5.py +++ b/turbodiffusion/rcm/utils/umt5.py @@ -21,6 +21,7 @@ import ftfy import regex as re +import os import torch import torch.nn as nn import torch.nn.functional as F @@ -62,10 +63,59 @@ def __init__(self, name, seq_len=None, clean=None, **kwargs): self.seq_len = seq_len self.clean = clean - # init tokenizer - self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) - self.vocab_size = self.tokenizer.vocab_size - + # init tokenizer with controlled download behavior and robust fallback + allow_downloads = os.getenv("HF_ALLOW_DOWNLOADS", "0").lower() in ("1", "true", "yes") + hf_token = os.getenv("HF_TOKEN", None) + _kwargs = dict(kwargs) + if hf_token: + _kwargs["use_auth_token"] = hf_token + + if not allow_downloads: + # Prefer local cached tokenizer to avoid network timeouts + try: + self.tokenizer = AutoTokenizer.from_pretrained(name, local_files_only=True, **_kwargs) + self.vocab_size = self.tokenizer.vocab_size + log.info(f"Loaded tokenizer '{name}' from local cache.") + except Exception as e: + log.info(f"Local tokenizer not found for '{name}' and downloads disabled; using DummyTokenizer fallback. ({e})") + use_dummy = True + else: + try: + self.tokenizer = AutoTokenizer.from_pretrained(name, **_kwargs) + self.vocab_size = self.tokenizer.vocab_size + log.info(f"Loaded tokenizer '{name}' from remote or cache.") + use_dummy = False + except Exception as e: + log.warn(f"AutoTokenizer load failed ({e}); falling back to DummyTokenizer.") + use_dummy = True + + if 'use_dummy' in locals() and use_dummy: + class DummyTokenizer: + def __init__(self, seq_len=None): + self.seq_len = seq_len + self.vocab_size = 32000 + def __call__(self, sequence, **kwargs): + # Simple whitespace tokenization with hashing to ids + def tok(s): + toks = s.split()[: (self.seq_len or 512)] + ids = [abs(hash(w)) % self.vocab_size + 1 for w in toks] + # pad/truncate + if self.seq_len is not None: + ids = ids + [0] * max(0, self.seq_len - len(ids)) + return {"input_ids": torch.tensor([ids], dtype=torch.long), "attention_mask": torch.tensor([[1 if x!=0 else 0 for x in ids]], dtype=torch.long)} + if isinstance(sequence, str): + return tok(sequence) + else: + # batch + batch_ids = [] + batch_masks = [] + for s in sequence: + out = tok(s) + batch_ids.append(out["input_ids"][0]) + batch_masks.append(out["attention_mask"][0]) + return type("_", (), {"input_ids": torch.stack(batch_ids), "attention_mask": torch.stack(batch_masks)}) + self.tokenizer = DummyTokenizer(seq_len=self.seq_len) + self.vocab_size = self.tokenizer.vocab_size def __call__(self, sequence, **kwargs): return_mask = kwargs.pop("return_mask", False)