29 changes: 3 additions & 26 deletions setup.py
@@ -20,33 +20,10 @@
ops_dir = Path(__file__).parent / "turbodiffusion" / "ops"
cutlass_dir = ops_dir / "cutlass"

nvcc_flags = [
"-O3",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=--verbose,--warn-on-local-memory-usage",
"-lineinfo",
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
"-DNDEBUG",
"-Xcompiler",
"-fPIC"
]
nvcc_flags = ["-gencode", "arch=compute_86,code=sm_86"]

cc_flag = [
"-gencode", "arch=compute_120a,code=sm_120a",
"-gencode", "arch=compute_100,code=sm_100",
"-gencode", "arch=compute_90,code=sm_90",
"-gencode", "arch=compute_89,code=sm_89",
"-gencode", "arch=compute_80,code=sm_80"
]
]

ext_modules = [
CUDAExtension(
@@ -60,7 +37,7 @@
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"],
"nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag + ["--threads", "4"],
"nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag + ["--threads", "1"],
},
include_dirs=[
cutlass_dir / "include",
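The trimmed nvcc_flags above emit device code for compute capability 8.6 only, and the nvcc --threads value drops from 4 to 1. A minimal sketch, not part of the PR and assuming PyTorch with CUDA is installed, to confirm the local GPU actually matches the sm_86 target before building:

import torch

# Pre-build check: the new nvcc_flags generate code for sm_86 only,
# so warn if the local GPU reports a different compute capability.
major, minor = torch.cuda.get_device_capability(0)
arch = f"sm_{major}{minor}"
print(f"Detected GPU architecture: {arch}")
if arch != "sm_86":
    print("Warning: this build targets sm_86 only; the compiled kernels may not load on this GPU.")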
62 changes: 62 additions & 0 deletions setup.py.bak_20260129_034900
@@ -0,0 +1,62 @@
"""
Copyright (c) 2025 by TurboDiffusion team.

Licensed under the Apache License, Version 2.0 (the "License");

Citation (please cite if you use this code):

@article{zhang2025turbodiffusion,
title={TurboDiffusion: Accelerating Video Diffusion Models by 100-200 Times},
author={Zhang, Jintao and Zheng, Kaiwen and Jiang, Kai and Wang, Haoxu and Stoica, Ion and Gonzalez, Joseph E and Chen, Jianfei and Zhu, Jun},
journal={arXiv preprint arXiv:2512.16093},
year={2025}
}
"""

from pathlib import Path
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

ops_dir = Path(__file__).parent / "turbodiffusion" / "ops"
cutlass_dir = ops_dir / "cutlass"

nvcc_flags = ["-gencode", "arch=compute_86,code=sm_86"]

cc_flag = [
"-gencode", "arch=compute_120a,code=sm_120a",
"-gencode", "arch=compute_100,code=sm_100",
"-gencode", "arch=compute_90,code=sm_90",
"-gencode", "arch=compute_89,code=sm_89",
"-gencode", "arch=compute_80,code=sm_80"
]

ext_modules = [
CUDAExtension(
name="turbo_diffusion_ops",
sources=[
"turbodiffusion/ops/bindings.cpp",
"turbodiffusion/ops/quant/quant.cu",
"turbodiffusion/ops/norm/rmsnorm.cu",
"turbodiffusion/ops/norm/layernorm.cu",
"turbodiffusion/ops/gemm/gemm.cu"
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"],
"nvcc": nvcc_flags + ["-DEXECMODE=0"] + cc_flag + ["--threads", "1"],
},
include_dirs=[
cutlass_dir / "include",
cutlass_dir / "tools" / "util" / "include",
ops_dir
],
libraries=["cuda"],
)
]

setup(
packages=find_packages(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks")
),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
)
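For reference, a small hypothetical helper (not present in the repository) that generates the same kind of -gencode pairs as the hard-coded nvcc_flags and cc_flag lists above; the architecture numbers passed in are illustrative:

# Hypothetical helper mirroring the hard-coded flag lists above: builds
# "-gencode arch=compute_XX,code=sm_XX" pairs for a list of SM versions.
def gencode_flags(archs):
    flags = []
    for arch in archs:
        flags += ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]
    return flags

# Example: gencode_flags(["80", "86", "89", "90"]) reproduces a multi-arch cc_flag-style list.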
258 changes: 258 additions & 0 deletions turbodiffusion/SLA/core.py.bak_20260129_070906
@@ -0,0 +1,258 @@
"""
Copyright (c) 2025 by SLA team.

Licensed under the Apache License, Version 2.0 (the "License");

Citation (please cite if you use this code):

@article{zhang2025sla,
title={SLA: Beyond Sparsity in Diffusion Transformers via Fine-Tunable Sparse-Linear Attention},
author={Jintao Zhang and Haoxu Wang and Kai Jiang and Shuo Yang and Kaiwen Zheng and Haocheng Xi and Ziteng Wang and Hongzhou Zhu and Min Zhao and Ion Stoica and Joseph E. Gonzalez and Jun Zhu and Jianfei Chen},
journal={arXiv preprint arXiv:2509.24006},
year={2025}
}
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

SAGESLA_ENABLED = True
try:
import spas_sage_attn._qattn as qattn
import spas_sage_attn._fused as fused
from spas_sage_attn.utils import get_vanilla_qk_quant, block_map_lut_triton
except ImportError:
SAGESLA_ENABLED = False

SAGE2PP_ENABLED = True
try:
from spas_sage_attn._qattn import qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold
except ImportError:
SAGE2PP_ENABLED = False

from .kernel import _attention
from .utils import get_block_map, get_cuda_arch


class SparseLinearAttention(nn.Module):
def __init__(self, head_dim, topk, feature_map='softmax', BLKQ=64, BLKK=64, use_bf16=True, tie_feature_map_qk=True):
R'''
Args:
head_dim: dimension of each head.
topk: ratio of keys selected for sparse attention, shared across all queries.
feature_map: feature map for linear attention, one of ['elu', 'relu', 'softmax'].
BLKQ: block size for query.
BLKK: block size for key.
use_bf16: whether to use bfloat16 (default) or float16 for computation. The conversion to bf16/fp16 is done inside the module.
tie_feature_map_qk: whether to use the same feature map for query and key.
'''
super().__init__()
self.dtype = torch.bfloat16 if use_bf16 else torch.float16
self.topk = topk
self.BLKQ = BLKQ
self.BLKK = BLKK
self.proj_l = nn.Linear(head_dim, head_dim, dtype=torch.float32)

if feature_map == 'elu':
def elu_feature_map(x):
return F.elu(x) + 1
self.feature_map_q = elu_feature_map
self.feature_map_k = elu_feature_map
elif feature_map == 'relu':
self.feature_map_q = nn.ReLU()
self.feature_map_k = nn.ReLU()
elif feature_map == 'softmax':
def softmax_feature_map(x):
return F.softmax(x, dim=-1)
self.feature_map_q = softmax_feature_map
self.feature_map_k = softmax_feature_map
else:
raise NotImplementedError(f'Not supported feature map {feature_map}.')

if tie_feature_map_qk:
self.feature_map_k = self.feature_map_q

self.init_weights_()

def init_weights_(self):
with torch.no_grad():
nn.init.zeros_(self.proj_l.weight)
nn.init.zeros_(self.proj_l.bias)

def forward(self, q, k, v, return_sparsity=False):
R'''
Args:
q: queries of shape (B, H, L, D).
k: keys of shape (B, H, L, D).
v: values of shape (B, H, L, D).
return_sparsity: whether to return the actual sparsity.
'''
dtype = q.dtype

q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()

sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=self.BLKQ, BLKK=self.BLKK)

q = q.to(self.dtype)
k = k.to(self.dtype)
v = v.to(self.dtype)
o_s = _attention.apply(q, k, v, sparse_map, lut, real_topk, self.BLKQ, self.BLKK)

q = self.feature_map_q(q).contiguous().to(self.dtype) # c_q
k = self.feature_map_k(k).contiguous().to(self.dtype) # c_k
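# Linear-attention branch: o_l = phi(q) @ (phi(k)^T @ v) / (phi(q) . sum(phi(k)) + eps),
# computed in O(L * D^2) rather than the O(L^2 * D) cost of dense attention.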
def calc_linear(q, k, v):
kvsum = k.transpose(-1, -2) @ v
ksum = torch.sum(k, dim=-2, keepdim=True)
return (q @ kvsum) / (1e-5 + (q * ksum).sum(dim=-1, keepdim=True))
o_l = calc_linear(q, k, v)

with torch.amp.autocast('cuda', dtype=self.dtype):
o_l = self.proj_l(o_l)
o = (o_s + o_l).to(dtype).transpose(1, 2)

if return_sparsity:
return o, real_topk / sparse_map.shape[-1]
else:
return o


class SageSparseLinearAttention(nn.Module):
def __init__(self, head_dim, topk, feature_map='softmax', use_bf16=True, tie_feature_map_qk=True):
R'''
Args:
head_dim: dimension of each head.
topk: ratio of keys selected for sparse attention, shared across all queries.
feature_map: feature map for linear attention, one of ['elu', 'relu', 'softmax'].
use_bf16: whether to use bfloat16 (default) or float16 for computation. The conversion to bf16/fp16 is done inside the module.
tie_feature_map_qk: whether to use the same feature map for query and key.
'''
assert SAGESLA_ENABLED, "Install SpargeAttn first to enable SageSLA."

super().__init__()
self.dtype = torch.bfloat16 if use_bf16 else torch.float16
self.topk = topk
self.proj_l = nn.Linear(head_dim, head_dim, dtype=torch.float32)

if feature_map == 'elu':
def elu_feature_map(x):
return F.elu(x) + 1
self.feature_map_q = elu_feature_map
self.feature_map_k = elu_feature_map
elif feature_map == 'relu':
self.feature_map_q = nn.ReLU()
self.feature_map_k = nn.ReLU()
elif feature_map == 'softmax':
def softmax_feature_map(x):
return F.softmax(x, dim=-1)
self.feature_map_q = softmax_feature_map
self.feature_map_k = softmax_feature_map
else:
raise NotImplementedError(f'Not supported feature map {feature_map}.')

if tie_feature_map_qk:
self.feature_map_k = self.feature_map_q

self.init_weights_()

def init_weights_(self):
with torch.no_grad():
nn.init.zeros_(self.proj_l.weight)
nn.init.zeros_(self.proj_l.bias)

def forward(self, q, k, v, return_sparsity=False):
R'''
Args:
q: queries of shape (B, H, L, D).
k: keys of shape (B, H, L, D).
v: values of shape (B, H, L, D).
return_sparsity: whether to return the actual sparsity.
'''

dtype = q.dtype

q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()

arch = get_cuda_arch(q.device.index)
if arch == "sm90":
sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=64, BLKK=128)
else:
sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=128, BLKK=64)

q = q.to(self.dtype)
k = k.to(self.dtype)
v = v.to(self.dtype)

########## SPARGE BEGIN ##########
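# SageAttention-style sparse path: Q and K are quantized to INT8 with scales, while V
# stays in FP16 on sm80/86/87 or is packed to FP8 (e4m3) with a per-channel scale
# otherwise, before dispatching to the block-sparse kernel matching the GPU architecture.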

km = k.mean(dim=-2, keepdim=True)
headdim = q.size(-1)

if arch == "sm90":
q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, 64, 128)
else:
q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, 128, 64)
lut, valid_block_num = block_map_lut_triton(sparse_map)
scale = 1.0 / (headdim ** 0.5)

assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."

o_s = torch.empty_like(q)

if arch in ("sm80", "sm86", "sm87"):
pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device)
v_fp16 = v.to(torch.float16)
qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(
q_int8, k_int8, v_fp16, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, 1, False, 1, scale, 0
)
else:
b, h_kv, kv_len, head_dim = v.shape
padded_len = (kv_len + 127) // 128 * 128
v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)

if arch == "sm90":
qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_sm90(
q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, q_scale, k_scale, v_scale, 1, False, 1, scale
)
else:
pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device)
if SAGE2PP_ENABLED:
qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(
q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale, 0
)
else:
qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(
q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale, 0
)

########## SPARGE END ##########

q = self.feature_map_q(q).contiguous().to(self.dtype) # c_q
k = self.feature_map_k(k).contiguous().to(self.dtype) # c_k
def calc_linear(q, k, v):
kvsum = k.transpose(-1, -2) @ v
ksum = torch.sum(k, dim=-2, keepdim=True)
return (q @ kvsum) / (1e-5 + (q * ksum).sum(dim=-1, keepdim=True))
o_l = calc_linear(q, k, v)

with torch.amp.autocast('cuda', dtype=self.dtype):
o_l = self.proj_l(o_l)
o = (o_s + o_l).to(dtype).transpose(1, 2)

if return_sparsity:
return o, real_topk / sparse_map.shape[-1]
else:
return o
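As a usage sketch only (shapes and topk are illustrative; it assumes the SLA kernels and a CUDA device are available, and SageSparseLinearAttention additionally requires SpargeAttn):

import torch

# Illustrative shapes: batch 1, 8 heads, 1024 tokens, head_dim 64.
B, H, L, D = 1, 8, 1024, 64
attn = SparseLinearAttention(head_dim=D, topk=0.1).cuda()

q = torch.randn(B, H, L, D, device="cuda")
k = torch.randn(B, H, L, D, device="cuda")
v = torch.randn(B, H, L, D, device="cuda")

# Output keeps the (B, H, L, D) layout; the second value is the measured block-sparsity ratio.
o, sparsity = attn(q, k, v, return_sparsity=True)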
2 changes: 1 addition & 1 deletion turbodiffusion/inference/modify_model.py
@@ -81,7 +81,7 @@ def replace_linear_norm(
return model


tensor_kwargs = {"device": "cuda", "dtype": torch.bfloat16}
tensor_kwargs = {"device": "cuda", "dtype": torch.float16}

def select_model(model_name: str) -> torch.nn.Module:
if model_name == "Wan2.1-1.3B":
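The one-line change above switches tensor_kwargs from bfloat16 to float16. A minimal illustration, outside the PR, of what allocations with the new kwargs look like:

import torch

tensor_kwargs = {"device": "cuda", "dtype": torch.float16}

# Tensors and parameters created with these kwargs now land on the GPU in fp16,
# which keeps the 16-bit storage of bf16 but trades exponent range for mantissa precision.
x = torch.zeros(4, 4, **tensor_kwargs)
print(x.dtype, x.device)  # torch.float16 cuda:0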