lightx2v/common/ops/mm/mm_weight.py (32 changes: 24 additions & 8 deletions)

@@ -41,9 +41,15 @@
scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm = None, None

try:
-    from vllm import _custom_ops as ops
+    from vllm import _custom_ops as vllm_ops
except ImportError:
-    ops = None
+    vllm_ops = None


+try:
+    from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 as sglang_int8_act_quant
+except ImportError:
+    sglang_int8_act_quant = None

try:
import sgl_kernel
@@ -527,6 +533,8 @@ def per_block_cast_to_fp8(self, x):
# =========================
def act_quant_int8_perchannel_sym_torchao(self, x):
input_tensor_quant, input_tensor_scale = torchao_int8_quant(x)
+        if self.scale_force_fp32:
+            input_tensor_scale = input_tensor_scale.to(torch.float32)
return input_tensor_quant, input_tensor_scale

def act_quant_fp8_perchannel_sym_torchao(self, x):
@@ -537,7 +545,7 @@ def act_quant_fp8_perchannel_sym_torchao(self, x):
return quantized, scale.float()

def act_quant_fp8_perchannel_sym_vllm(self, x):
-        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
+        input_tensor_quant, input_tensor_scale = vllm_ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
return input_tensor_quant, input_tensor_scale

def act_quant_fp8_perchannel_sym_sgl(self, x):
@@ -548,7 +556,7 @@ def act_quant_fp8_perchannel_sym_sgl(self, x):
return input_tensor_quant, input_tensor_scale

def act_quant_int8_perchannel_sym_vllm(self, x):
-        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
+        input_tensor_quant, input_tensor_scale, _ = vllm_ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
return input_tensor_quant, input_tensor_scale

def act_quant_nvfp4(self, x):
@@ -1336,7 +1344,7 @@ def __init__(
self.weight_need_transpose = False
self.bias_force_fp32 = True
self.scale_force_fp32 = True
-        if ops is not None:
+        if vllm_ops is not None:
self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
else:
self.act_quant_func = fp8_quantize_triton
@@ -1394,7 +1402,7 @@ def __init__(
self.weight_need_transpose = False
self.bias_force_fp32 = True
self.scale_force_fp32 = True
-        if ops is not None:
+        if vllm_ops is not None:
self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
else:
self.act_quant_func = int8_quantize_triton
@@ -1661,7 +1669,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
-    Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
+    Kernel: quant-mm using Sgl-kernel; act dynamic quant using sglang, falling back to the first available backend in priority order: vllm > torchao > triton
"""

def __init__(
@@ -1688,7 +1696,15 @@ def __init__(
lora_path,
)
self.load_func = self.load_int8_perchannel_sym
-        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
+        # Priority: sglang > vllm > torchao > triton
+        if sglang_int8_act_quant:
+            self.act_quant_func = sglang_int8_act_quant
+        elif vllm_ops:
+            self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
+        elif torchao_int8_quant:
+            self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
+        else:
+            self.act_quant_func = int8_quantize_triton
self.weight_need_transpose = True
self.scale_force_fp32 = True

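For reference, the backend-selection pattern this PR applies can be written as a standalone sketch. This is illustrative only: the guarded imports mirror the ones added above, but pick_int8_act_quant and its torchao/triton parameters are hypothetical stand-ins for the module-level helpers in mm_weight.py.

# Standalone sketch (assumed API, not the real lightx2v module) of the
# optional-backend pattern used above: guard each heavyweight import with
# try/except, then resolve the activation-quant function once, following
# the priority sglang > vllm > torchao > triton.

try:
    from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 as sglang_int8_act_quant
except ImportError:
    sglang_int8_act_quant = None

try:
    from vllm import _custom_ops as vllm_ops
except ImportError:
    vllm_ops = None


def pick_int8_act_quant(torchao_int8_quant=None, int8_quantize_triton=None):
    # Returns the highest-priority int8 per-token quant callable available.
    # torchao_int8_quant / int8_quantize_triton are hypothetical stand-ins,
    # passed in as arguments so the sketch stays self-contained.
    if sglang_int8_act_quant is not None:
        return sglang_int8_act_quant
    if vllm_ops is not None:
        # Same call the diff uses; scaled_int8_quant returns (quant, scale, azp),
        # so drop the azp term in the symmetric case.
        return lambda x: vllm_ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)[:2]
    if torchao_int8_quant is not None:
        return torchao_int8_quant
    if int8_quantize_triton is not None:
        return int8_quantize_triton
    raise RuntimeError("no int8 activation-quant backend available")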