diff --git a/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_kernel.py b/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_kernel.py
index a5cda505..fe5016e0 100644
--- a/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_kernel.py
+++ b/lightx2v/models/input_encoders/lightllm/qwen25_text_encoder_kernel.py
@@ -124,16 +124,16 @@ def _apply_kernel_optimizations(self):
         if self.use_flash_attention_kernel:
             logger.info(" ✓ Flash Attention 2 (loaded with model)")
 
-        if self.use_rmsnorm_kernel:
-            try:
-                from sgl_kernel.elementwise import rmsnorm
-
-                self._rmsnorm_kernel = rmsnorm
-                self._replace_rmsnorm_with_kernel()
-                logger.info(" ✓ RMSNorm kernel integrated (from sgl_kernel)")
-            except ImportError as e:
-                logger.warning(f" ✗ Failed to import sgl_kernel: {e}. RMSNorm optimization disabled.")
-                self.use_rmsnorm_kernel = False
+        if self.use_rmsnorm_kernel:
+            try:
+                from sgl_kernel.elementwise import rmsnorm
+
+                self._rmsnorm_kernel = rmsnorm
+                self._replace_rmsnorm_with_kernel()
+                logger.info(" ✓ RMSNorm kernel integrated (from sgl_kernel)")
+            except ImportError as e:
+                logger.warning(f" ✗ Failed to import sgl_kernel: {e}. RMSNorm optimization disabled.")
+                self.use_rmsnorm_kernel = False
 
     def _replace_rmsnorm_with_kernel(self):
         """Replace RMSNorm layers with fused kernel"""