From 7fe412e8a190f378980fda13700ac479768f2e68 Mon Sep 17 00:00:00 2001
From: lifelongeeek
Date: Fri, 13 Feb 2026 07:16:32 +0000
Subject: [PATCH 1/2] feat: add expert_wise_scale support for per-expert FP8
 quantization in MoE models

Add per-expert scale quantization support for MoE expert MLPs,
preserving each expert's individual scale factors instead of averaging
them across experts. This significantly improves FP8 quantization
accuracy for MoE models such as Qwen3-30B-A3B.

Changes:
- config.py: Add expert_wise_scale config option to MoENeuronConfig
- modeling_qwen3_moe.py: Fuse per-expert scales during HF->Neuron
  state_dict conversion (gate_up_proj and down_proj), with fallback to
  averaged scales when expert_wise_scale=False
- model_wrapper.py: Two-pass quantization conversion when
  expert_wise_scale=True (Pass 1: per_channel_symmetric for non-expert
  modules, Pass 2: expert_wise_per_channel_symmetric for expert MLPs)

Validated on Qwen3-30B-A3B (trn2.3xlarge):
- IFEval inst_strict: 0.405 (vs 0.00 with averaged scale)
- TruthfulQA bleu_acc: 0.45 (vs 0.03 with averaged scale)
- Throughput: 58.7 tok/s (unchanged from baseline)
---
 .../models/config.py                       |  1 +
 .../models/model_wrapper.py                | 54 +++++++++--
 .../models/qwen3_moe/modeling_qwen3_moe.py | 96 +++++++++++++++++++
 3 files changed, 144 insertions(+), 7 deletions(-)

diff --git a/src/neuronx_distributed_inference/models/config.py b/src/neuronx_distributed_inference/models/config.py
index 61d60afe..9354d806 100644
--- a/src/neuronx_distributed_inference/models/config.py
+++ b/src/neuronx_distributed_inference/models/config.py
@@ -699,6 +699,7 @@ def __init__(
         self.moe_tp_degree = kwargs.pop("moe_tp_degree", 1)
         self.moe_ep_degree = kwargs.pop("moe_ep_degree", 1)
+        self.expert_wise_scale = kwargs.pop("expert_wise_scale", False)
         self.transpose_shared_experts_weights = kwargs.pop("transpose_shared_experts_weights", False)
         self.blockwise_matmul_config = kwargs.pop("blockwise_matmul_config", {})

diff --git a/src/neuronx_distributed_inference/models/model_wrapper.py b/src/neuronx_distributed_inference/models/model_wrapper.py
index 20bab283..d188d776 100644
--- a/src/neuronx_distributed_inference/models/model_wrapper.py
+++ b/src/neuronx_distributed_inference/models/model_wrapper.py
@@ -1652,14 +1652,54 @@ def load_module(self):
         else:
             models_to_convert.append(float_model)

+        # Check if expert_wise_scale is enabled for two-pass conversion
+        use_expert_wise_scale = getattr(self.neuron_config, "expert_wise_scale", False)
+
         for model in models_to_convert:
-            convert(
-                model,
-                q_config=q_config,
-                inplace=True,
-                mapping=None,
-                modules_to_not_convert=get_modules_to_not_convert(model.config.neuron_config),
-            )
+            user_modules_to_not_convert = get_modules_to_not_convert(model.config.neuron_config)
+
+            if use_expert_wise_scale and quantization_type == QuantizationType.PER_CHANNEL_SYMMETRIC:
+                # Two-pass conversion:
+                # Pass 1: Convert non-expert modules with per_channel_symmetric,
+                #         skip expert MoE modules (expert_mlps)
+                pass1_skip = list(user_modules_to_not_convert) if user_modules_to_not_convert else []
+                pass1_skip.append("expert_mlps")
+                convert(
+                    model,
+                    q_config=q_config,
+                    inplace=True,
+                    mapping=None,
+                    modules_to_not_convert=pass1_skip,
+                )
+
+                # Pass 2: Convert expert MoE modules with expert_wise_per_channel_symmetric
+                expert_q_config = get_default_expert_wise_per_channel_custom_qconfig_dict()
+                if isinstance(self.neuron_config.quantization_dtype, str):
+                    expert_q_config["quantized_dtype"] = QuantizedDtype.get_dtype(
+                        self.neuron_config.quantization_dtype
+                    )
+                elif isinstance(self.neuron_config.quantization_dtype, QuantizedDtype):
+                    expert_q_config["quantized_dtype"] = self.neuron_config.quantization_dtype
+                expert_q_config["activation_quantization_type"] = ActivationQuantizationType(
+                    self.neuron_config.activation_quantization_type
+                )
+                expert_q_config["clamp_bound"] = self.neuron_config.quantize_clamp_bound
+                convert(
+                    model,
+                    q_config=expert_q_config,
+                    inplace=True,
+                    mapping=None,
+                    include=["*expert_mlps.mlp_op*"],
+                )
+            else:
+                # Standard single-pass conversion
+                convert(
+                    model,
+                    q_config=q_config,
+                    inplace=True,
+                    mapping=None,
+                    modules_to_not_convert=user_modules_to_not_convert,
+                )

         self.module = float_model
     else:
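Reviewer note (not part of the patch): the accuracy gap between averaged and per-expert scales can be reproduced in isolation. The sketch below is illustrative only: it uses int8-style symmetric fake quantization rather than FP8, made-up tensor sizes, and none of these names come from the repo.

```python
import torch

# Two toy "experts" whose weight magnitudes differ by ~100x, a common
# situation in MoE layers.
w_small = torch.randn(4, 8) * 0.01
w_large = torch.randn(4, 8) * 1.00

def fake_quant(w, scale):
    # Symmetric fake quantization: round onto 127 integer steps, dequantize.
    q = torch.clamp(torch.round(w / scale * 127), -127, 127)
    return q / 127 * scale

per_expert_scales = [w_small.abs().max(), w_large.abs().max()]
avg_scale = sum(per_expert_scales) / 2  # what averaging across experts does

for w, s in zip((w_small, w_large), per_expert_scales):
    err_avg = (fake_quant(w, avg_scale) - w).abs().max()
    err_own = (fake_quant(w, s) - w).abs().max()
    print(f"avg-scale err={err_avg:.5f}  per-expert err={err_own:.5f}")
# The small-magnitude expert is rounded onto a grid ~100x too coarse under
# the averaged scale, consistent with the IFEval/TruthfulQA collapse
# reported above for the averaged-scale path.
```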
diff --git a/src/neuronx_distributed_inference/models/qwen3_moe/modeling_qwen3_moe.py b/src/neuronx_distributed_inference/models/qwen3_moe/modeling_qwen3_moe.py
index 97e4c160..05d4925e 100644
--- a/src/neuronx_distributed_inference/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/src/neuronx_distributed_inference/models/qwen3_moe/modeling_qwen3_moe.py
@@ -203,6 +203,76 @@ def convert_qwen3_moe_hf_to_neuron_state_dict(neuron_state_dict, config):
         gate_up_proj = gate_up_proj.reshape(config.num_experts, hidden_size, -1)
         neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj

+        # Fuse expert scales for gate_up_proj if quantized
+        is_expert_quantized = (
+            getattr(config.neuron_config, "quantized", False)
+            or getattr(config.neuron_config, "quantized_mlp_kernel_enabled", False)
+        ) and f"layers.{l}.mlp.experts.0.gate_proj.scale" in neuron_state_dict
+
+        use_expert_wise_scale = is_expert_quantized and getattr(
+            config.neuron_config, "expert_wise_scale", False
+        )
+
+        if is_expert_quantized:
+            if use_expert_wise_scale:
+                # Per-expert scales: [num_experts, 1, 2*intermediate_size]
+                gate_scales = []
+                up_scales = []
+                for e in range(config.num_experts):
+                    gate_scales.append(
+                        neuron_state_dict[f"layers.{l}.mlp.experts.{e}.gate_proj.scale"]
+                        .detach().clone().to(torch.float32)
+                    )
+                    up_scales.append(
+                        neuron_state_dict[f"layers.{l}.mlp.experts.{e}.up_proj.scale"]
+                        .detach().clone().to(torch.float32)
+                    )
+                # Each scale is [intermediate_size, 1] -> transpose to [1, intermediate_size]
+                gate_scale = torch.stack([s.T for s in gate_scales], dim=0)  # [E, 1, intermediate]
+                up_scale = torch.stack([s.T for s in up_scales], dim=0)  # [E, 1, intermediate]
+                gate_up_proj_scale = torch.zeros(
+                    config.num_experts, 1, 2 * intermediate_size,
+                    dtype=torch.float32, device=device,
+                )
+                torch.narrow(gate_up_proj_scale, 2, 0, intermediate_size).copy_(gate_scale)
+                torch.narrow(gate_up_proj_scale, 2, intermediate_size, intermediate_size).copy_(up_scale)
+            else:
+                # Averaged scale: [1, 1, 2*intermediate_size]
+                gate_scale_sum = torch.zeros(intermediate_size, 1, dtype=torch.float32, device=device)
+                up_scale_sum = torch.zeros(intermediate_size, 1, dtype=torch.float32, device=device)
+                for e in range(config.num_experts):
+                    gate_scale_sum += (
+                        neuron_state_dict[f"layers.{l}.mlp.experts.{e}.gate_proj.scale"]
+                        .detach().clone().to(torch.float32)
+                    )
+                    up_scale_sum += (
+                        neuron_state_dict[f"layers.{l}.mlp.experts.{e}.up_proj.scale"]
+                        .detach().clone().to(torch.float32)
+                    )
+                gate_scale = (gate_scale_sum / config.num_experts).T.unsqueeze(0)  # [1, 1, intermediate]
+                up_scale = (up_scale_sum / config.num_experts).T.unsqueeze(0)  # [1, 1, intermediate]
+                gate_up_proj_scale = torch.zeros(
+                    1, 1, 2 * intermediate_size,
+                    dtype=torch.float32, device=device,
+                )
+                torch.narrow(gate_up_proj_scale, 2, 0, intermediate_size).copy_(gate_scale)
+                torch.narrow(gate_up_proj_scale, 2, intermediate_size, intermediate_size).copy_(up_scale)
+
+            for e in range(config.num_experts):
+                del neuron_state_dict[f"layers.{l}.mlp.experts.{e}.gate_proj.scale"]
+                del neuron_state_dict[f"layers.{l}.mlp.experts.{e}.up_proj.scale"]
+
+            if pad_size > 0:
+                if use_expert_wise_scale:
+                    gate_up_proj_scale = gate_up_proj_scale.reshape(config.num_experts, 1, 2, -1)
+                    gate_up_proj_scale = torch.nn.functional.pad(gate_up_proj_scale, (0, pad_size))
+                    gate_up_proj_scale = gate_up_proj_scale.reshape(config.num_experts, 1, -1)
+                else:
+                    gate_up_proj_scale = gate_up_proj_scale.reshape(1, 1, 2, -1)
+                    gate_up_proj_scale = torch.nn.functional.pad(gate_up_proj_scale, (0, pad_size))
+                    gate_up_proj_scale = gate_up_proj_scale.reshape(1, 1, -1)
+            neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.scale"] = gate_up_proj_scale
+
         down_proj = torch.empty(
             config.num_experts,
             intermediate_size,
@@ -227,6 +297,32 @@ def convert_qwen3_moe_hf_to_neuron_state_dict(neuron_state_dict, config):
             down_proj = torch.nn.functional.pad(down_proj, (0, 0, 0, pad_size))
         neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj

+        # Fuse expert scales for down_proj if quantized
+        if is_expert_quantized:
+            if use_expert_wise_scale:
+                # Per-expert scales: [num_experts, 1, hidden_size]
+                down_scales = []
+                for e in range(config.num_experts):
+                    down_scales.append(
+                        neuron_state_dict[f"layers.{l}.mlp.experts.{e}.down_proj.scale"]
+                        .detach().clone().to(torch.float32)
+                    )
+                # Each scale is [hidden_size, 1] -> transpose to [1, hidden_size]
+                down_proj_scale = torch.stack([s.T for s in down_scales], dim=0)  # [E, 1, hidden]
+            else:
+                # Averaged scale: [1, 1, hidden_size]
+                down_proj_scale_sum = torch.zeros(hidden_size, 1, dtype=torch.float32, device=device)
+                for e in range(config.num_experts):
+                    down_proj_scale_sum += (
+                        neuron_state_dict[f"layers.{l}.mlp.experts.{e}.down_proj.scale"]
+                        .detach().clone().to(torch.float32)
+                    )
+                down_proj_scale = (down_proj_scale_sum / config.num_experts).T.unsqueeze(0)
+
+            for e in range(config.num_experts):
+                del neuron_state_dict[f"layers.{l}.mlp.experts.{e}.down_proj.scale"]
+            neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.scale"] = down_proj_scale
+
         gc.collect()

     if config.neuron_config.fused_qkv:

From 417fe1b7e7b92903cd93132792915882db9ce285 Mon Sep 17 00:00:00 2001
From: lifelongeeek
Date: Fri, 13 Feb 2026 09:52:30 +0000
Subject: [PATCH 2/2] fix: read expert_wise_scale per-model instead of from
 global wrapper config
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Move use_expert_wise_scale into the per-model loop and read it from
model.config.neuron_config instead of self.neuron_config. This ensures
correct behavior in fused speculation mode, where a non-MoE draft model
and a MoE target model are processed in the same loop: previously the
global flag would incorrectly apply two-pass expert quantization to a
draft model that has no expert_mlps module.
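For reviewers, a minimal hypothetical sketch of the failure mode, with stand-in config objects rather than the real NeuronConfig classes:

```python
class Cfg:
    def __init__(self, expert_wise_scale):
        self.expert_wise_scale = expert_wise_scale

draft_cfg = Cfg(expert_wise_scale=False)   # non-MoE draft model
target_cfg = Cfg(expert_wise_scale=True)   # MoE target model
wrapper_cfg = target_cfg                   # wrapper-level config mirrors the target

for model_cfg in (draft_cfg, target_cfg):
    global_read = getattr(wrapper_cfg, "expert_wise_scale", False)   # True for both models
    per_model_read = getattr(model_cfg, "expert_wise_scale", False)  # False for the draft
    print(f"global: {global_read}, per-model: {per_model_read}")
# With the global read, the draft model is routed into the two-pass expert
# conversion even though it has no expert_mlps module; the per-model read
# leaves it on the standard single-pass path.
```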
---
 src/neuronx_distributed_inference/models/model_wrapper.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/neuronx_distributed_inference/models/model_wrapper.py b/src/neuronx_distributed_inference/models/model_wrapper.py
index d188d776..ca2ba45d 100644
--- a/src/neuronx_distributed_inference/models/model_wrapper.py
+++ b/src/neuronx_distributed_inference/models/model_wrapper.py
@@ -1652,12 +1652,13 @@ def load_module(self):
         else:
             models_to_convert.append(float_model)

-        # Check if expert_wise_scale is enabled for two-pass conversion
-        use_expert_wise_scale = getattr(self.neuron_config, "expert_wise_scale", False)
-
         for model in models_to_convert:
             user_modules_to_not_convert = get_modules_to_not_convert(model.config.neuron_config)

+            # Read expert_wise_scale per-model (not from self.neuron_config) so that
+            # fused speculation with a non-MoE draft + MoE target is handled correctly.
+            use_expert_wise_scale = getattr(model.config.neuron_config, "expert_wise_scale", False)
+
             if use_expert_wise_scale and quantization_type == QuantizationType.PER_CHANNEL_SYMMETRIC:
                 # Two-pass conversion:
                 # Pass 1: Convert non-expert modules with per_channel_symmetric,
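Appendix for reviewers (not part of either patch): the shape bookkeeping in the gate_up_proj scale fusion from PATCH 1/2 can be checked standalone. Everything below is illustrative; E, I, and pad_size are made-up toy values.

```python
import torch

E, I = 4, 6  # toy num_experts and intermediate_size
# HF-style per-expert scales, each [intermediate_size, 1], matching the
# layout the converter reads from layers.{l}.mlp.experts.{e}.*.scale.
gate_scales = [torch.rand(I, 1) for _ in range(E)]
up_scales = [torch.rand(I, 1) for _ in range(E)]

# Stack the transposed scales into [E, 1, I] each.
gate_scale = torch.stack([s.T for s in gate_scales], dim=0)
up_scale = torch.stack([s.T for s in up_scales], dim=0)

# Fused layout: [E, 1, 2*I], gate scales in the first half of the last
# dim and up scales in the second half, as in the converter above.
fused = torch.zeros(E, 1, 2 * I)
torch.narrow(fused, 2, 0, I).copy_(gate_scale)
torch.narrow(fused, 2, I, I).copy_(up_scale)
assert fused.shape == (E, 1, 2 * I)

# Padding splits the fused dim back into its two halves so gate and up
# are padded independently and stay aligned, mirroring the diff.
pad_size = 2
padded = fused.reshape(E, 1, 2, -1)
padded = torch.nn.functional.pad(padded, (0, pad_size))
padded = padded.reshape(E, 1, -1)
assert padded.shape == (E, 1, 2 * (I + pad_size))
```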