From ddaf423bf3c48e1450ddd9eca8f0f2f7f7bb8534 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sita=20B=C3=A9r=C3=A9t=C3=A9?=
Date: Sat, 7 Feb 2026 17:38:00 +0000
Subject: [PATCH 1/3] Make vllm optional when using int8-sgl

---
 lightx2v/common/ops/mm/mm_weight.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/lightx2v/common/ops/mm/mm_weight.py b/lightx2v/common/ops/mm/mm_weight.py
index 6cbf1469..c93d8a68 100755
--- a/lightx2v/common/ops/mm/mm_weight.py
+++ b/lightx2v/common/ops/mm/mm_weight.py
@@ -41,9 +41,15 @@
 scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm = None, None
 
 try:
-    from vllm import _custom_ops as ops
+    from vllm import _custom_ops as vllm_ops
 except ImportError:
-    ops = None
+    vllm_ops = None
+
+
+try:
+    from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 as sglang_int8_act_quant
+except ImportError:
+    sglang_int8_act_quant = None
 
 try:
     import sgl_kernel
@@ -527,6 +533,8 @@ def per_block_cast_to_fp8(self, x):
     # =========================
     def act_quant_int8_perchannel_sym_torchao(self, x):
         input_tensor_quant, input_tensor_scale = torchao_int8_quant(x)
+        if self.scale_force_fp32:
+            input_tensor_scale = input_tensor_scale.to(torch.float32)
         return input_tensor_quant, input_tensor_scale
 
     def act_quant_fp8_perchannel_sym_torchao(self, x):
@@ -537,7 +545,7 @@ def act_quant_fp8_perchannel_sym_torchao(self, x):
         return quantized, scale.float()
 
     def act_quant_fp8_perchannel_sym_vllm(self, x):
-        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
+        input_tensor_quant, input_tensor_scale = vllm_ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
         return input_tensor_quant, input_tensor_scale
 
     def act_quant_fp8_perchannel_sym_sgl(self, x):
@@ -548,7 +556,7 @@ def act_quant_fp8_perchannel_sym_sgl(self, x):
         return input_tensor_quant, input_tensor_scale
 
     def act_quant_int8_perchannel_sym_vllm(self, x):
-        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
+        input_tensor_quant, input_tensor_scale, _ = vllm_ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
         return input_tensor_quant, input_tensor_scale
 
     def act_quant_nvfp4(self, x):
@@ -1336,7 +1344,7 @@ def __init__(
         self.weight_need_transpose = False
         self.bias_force_fp32 = True
         self.scale_force_fp32 = True
-        if ops is not None:
+        if vllm_ops is not None:
             self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
         else:
             self.act_quant_func = fp8_quantize_triton
@@ -1394,7 +1402,7 @@ def __init__(
         self.weight_need_transpose = False
         self.bias_force_fp32 = True
         self.scale_force_fp32 = True
-        if ops is not None:
+        if vllm_ops is not None:
             self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
         else:
             self.act_quant_func = int8_quantize_triton
@@ -1661,7 +1669,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
     Quant MM:
         Weight: int8 perchannel sym
         Act: int8 perchannel dynamic sym
-        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
+        Kernel: quant-mm using Sgl-kernel, act dynamic quant using sglang and fallbacke to whathever quant backend available with this priority: vllm > toarchao > triton
     """
 
     def __init__(
@@ -1688,7 +1696,10 @@ def __init__(
             lora_path,
         )
         self.load_func = self.load_int8_perchannel_sym
-        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
+        # Priority sglang > vllm > torchao > triton
+        self.act_quant_func = sglang_int8_act_quant or (
+            self.act_quant_int8_perchannel_sym_vllm if vllm_ops else self.act_quant_int8_perchannel_sym_torchao if torchao_int8_quant else int8_quantize_triton
+        )
         self.weight_need_transpose = True
         self.scale_force_fp32 = True
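Note: the `or` chain added in PATCH 1/3 works because each optional import resolves to `None` when its package is missing, and `None` is falsy. Below is a minimal, self-contained sketch of the same selection pattern with a pure-PyTorch stand-in for the triton fallback; all names here are illustrative, not the module's actual kernels, and the backends are assumed absent so the chain falls through to the reference implementation.

    import torch

    # Stand-ins for the module's optional imports; in mm_weight.py each
    # comes from a try/except ImportError block (assumed missing here).
    sglang_int8_act_quant = None
    vllm_int8_act_quant = None
    torchao_int8_act_quant = None

    def int8_quantize_reference(x):
        # Per-token symmetric int8 quant: one scale per row, absmax / 127.
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        quant = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
        return quant, scale.to(torch.float32)

    # Same priority as the patch: sglang > vllm > torchao > triton.
    act_quant_func = (
        sglang_int8_act_quant
        or vllm_int8_act_quant
        or torchao_int8_act_quant
        or int8_quantize_reference
    )

    quant, scale = act_quant_func(torch.randn(4, 16))
    assert quant.dtype == torch.int8 and scale.shape == (4, 1)

The patch itself uses conditional expressions rather than a pure `or` chain, but the effect is the same: the first backend whose import succeeded wins.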
From e21c3cca711f5c826ea2317e8ecb0397467f37f5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sita=20B=C3%A9r=C3=A9t=C3=A9?=
Date: Sat, 7 Feb 2026 19:06:23 +0000
Subject: [PATCH 2/3] Update lightx2v/common/ops/mm/mm_weight.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 lightx2v/common/ops/mm/mm_weight.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/lightx2v/common/ops/mm/mm_weight.py b/lightx2v/common/ops/mm/mm_weight.py
index c93d8a68..41eac4bc 100755
--- a/lightx2v/common/ops/mm/mm_weight.py
+++ b/lightx2v/common/ops/mm/mm_weight.py
@@ -1697,9 +1697,14 @@ def __init__(
         )
         self.load_func = self.load_int8_perchannel_sym
         # Priority sglang > vllm > torchao > triton
-        self.act_quant_func = sglang_int8_act_quant or (
-            self.act_quant_int8_perchannel_sym_vllm if vllm_ops else self.act_quant_int8_perchannel_sym_torchao if torchao_int8_quant else int8_quantize_triton
-        )
+        if sglang_int8_act_quant:
+            self.act_quant_func = sglang_int8_act_quant
+        elif vllm_ops:
+            self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
+        elif torchao_int8_quant:
+            self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
+        else:
+            self.act_quant_func = int8_quantize_triton
         self.weight_need_transpose = True
         self.scale_force_fp32 = True

From 27de0ef15ce6fb05ea497e81027110060a06a366 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sita=20B=C3%A9r=C3=A9t=C3=A9?=
Date: Sat, 7 Feb 2026 19:06:35 +0000
Subject: [PATCH 3/3] Update lightx2v/common/ops/mm/mm_weight.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 lightx2v/common/ops/mm/mm_weight.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightx2v/common/ops/mm/mm_weight.py b/lightx2v/common/ops/mm/mm_weight.py
index 41eac4bc..ea8abfd6 100755
--- a/lightx2v/common/ops/mm/mm_weight.py
+++ b/lightx2v/common/ops/mm/mm_weight.py
@@ -1669,7 +1669,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
     Quant MM:
         Weight: int8 perchannel sym
         Act: int8 perchannel dynamic sym
-        Kernel: quant-mm using Sgl-kernel, act dynamic quant using sglang and fallbacke to whathever quant backend available with this priority: vllm > toarchao > triton
+        Kernel: quant-mm using Sgl-kernel, act dynamic quant using sglang with fallback to whichever quant backend is available, in priority order: vllm > torchao > triton
     """
 
     def __init__(
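The explicit if/elif chain from PATCH 2/3 also makes it easy to surface which backend was selected at runtime. A hedged sketch of how that selection could be factored into a helper; the function name and the logging are illustrative, not part of the patch.

    import logging

    logger = logging.getLogger(__name__)

    def pick_int8_act_quant(sglang_fn, vllm_fn, torchao_fn, triton_fn):
        # Mirror of the patch's priority: sglang > vllm > torchao > triton.
        for name, fn in (
            ("sglang", sglang_fn),
            ("vllm", vllm_fn),
            ("torchao", torchao_fn),
            ("triton", triton_fn),
        ):
            if fn is not None:
                logger.debug("int8 activation quant backend: %s", name)
                return fn
        raise RuntimeError("no int8 activation quant backend is available")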