diff --git a/lightx2v/common/ops/mm/mm_weight.py b/lightx2v/common/ops/mm/mm_weight.py
index 6cbf1469..ea8abfd6 100755
--- a/lightx2v/common/ops/mm/mm_weight.py
+++ b/lightx2v/common/ops/mm/mm_weight.py
@@ -41,9 +41,15 @@
 scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm = None, None
 
 try:
-    from vllm import _custom_ops as ops
+    from vllm import _custom_ops as vllm_ops
 except ImportError:
-    ops = None
+    vllm_ops = None
+
+
+try:
+    from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 as sglang_int8_act_quant
+except ImportError:
+    sglang_int8_act_quant = None
 
 try:
     import sgl_kernel
@@ -527,6 +533,8 @@ def per_block_cast_to_fp8(self, x):
     # =========================
     def act_quant_int8_perchannel_sym_torchao(self, x):
         input_tensor_quant, input_tensor_scale = torchao_int8_quant(x)
+        if self.scale_force_fp32:
+            input_tensor_scale = input_tensor_scale.to(torch.float32)
         return input_tensor_quant, input_tensor_scale
 
     def act_quant_fp8_perchannel_sym_torchao(self, x):
@@ -537,7 +545,7 @@ def act_quant_fp8_perchannel_sym_torchao(self, x):
         return quantized, scale.float()
 
     def act_quant_fp8_perchannel_sym_vllm(self, x):
-        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
+        input_tensor_quant, input_tensor_scale = vllm_ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
         return input_tensor_quant, input_tensor_scale
 
     def act_quant_fp8_perchannel_sym_sgl(self, x):
@@ -548,7 +556,7 @@ def act_quant_fp8_perchannel_sym_sgl(self, x):
         return input_tensor_quant, input_tensor_scale
 
     def act_quant_int8_perchannel_sym_vllm(self, x):
-        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
+        input_tensor_quant, input_tensor_scale, _ = vllm_ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
         return input_tensor_quant, input_tensor_scale
 
     def act_quant_nvfp4(self, x):
@@ -1336,7 +1344,7 @@ def __init__(
         self.weight_need_transpose = False
         self.bias_force_fp32 = True
         self.scale_force_fp32 = True
-        if ops is not None:
+        if vllm_ops is not None:
             self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
         else:
             self.act_quant_func = fp8_quantize_triton
@@ -1394,7 +1402,7 @@ def __init__(
         self.weight_need_transpose = False
         self.bias_force_fp32 = True
         self.scale_force_fp32 = True
-        if ops is not None:
+        if vllm_ops is not None:
             self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
         else:
             self.act_quant_func = int8_quantize_triton
@@ -1661,7 +1669,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
     Quant MM:
         Weight: int8 perchannel sym
         Act: int8 perchannel dynamic sym
-        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
+        Kernel: quant-mm using Sgl-kernel; act dynamic quant using sglang, falling back to the first available backend with priority vllm > torchao > triton
     """
 
     def __init__(
@@ -1688,7 +1696,15 @@ def __init__(
             lora_path,
         )
         self.load_func = self.load_int8_perchannel_sym
-        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
+        # Priority: sglang > vllm > torchao > triton
+        if sglang_int8_act_quant:
+            self.act_quant_func = sglang_int8_act_quant
+        elif vllm_ops:
+            self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
+        elif torchao_int8_quant:
+            self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
+        else:
+            self.act_quant_func = int8_quantize_triton
         self.weight_need_transpose = True
         self.scale_force_fp32 = True
 
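For reference, the backend selection introduced in `MMWeightWint8channelAint8channeldynamicSglActVllm.__init__` follows the same guarded-import pattern used at the top of the module: each optional dependency is imported inside try/except and left as `None` when unavailable, and the constructor then binds `self.act_quant_func` to the first backend that imported successfully. Below is a minimal standalone sketch of that pattern; the sglang and vllm calls mirror the diff, while `triton_int8_fallback` is a hypothetical pure-torch stand-in for the repo's `int8_quantize_triton` kernel.

```python
# Minimal sketch of the guarded-import / priority-fallback pattern added by the diff.
import torch

try:
    from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 as sglang_int8_act_quant
except ImportError:
    sglang_int8_act_quant = None

try:
    from vllm import _custom_ops as vllm_ops
except ImportError:
    vllm_ops = None


def vllm_int8_act_quant(x: torch.Tensor):
    # vllm returns (quantized, scale, azp); azp is unused for symmetric quant.
    quant, scale, _ = vllm_ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
    return quant, scale


def triton_int8_fallback(x: torch.Tensor):
    # Hypothetical stand-in: per-token symmetric int8 quantization in plain torch.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    quant = (x / scale).round().clamp(-128, 127).to(torch.int8)
    return quant, scale.to(torch.float32)


# Pick the first available backend: sglang > vllm > fallback
# (the real __init__ also tries torchao before the Triton kernel).
if sglang_int8_act_quant is not None:
    act_quant_func = sglang_int8_act_quant
elif vllm_ops is not None:
    act_quant_func = vllm_int8_act_quant
else:
    act_quant_func = triton_int8_fallback
```

Binding the quant function once at construction time keeps the per-call hot path free of backend availability checks.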