diff --git a/src/neuronx_distributed_inference/models/config.py b/src/neuronx_distributed_inference/models/config.py index 61d60afe..9354d806 100644 --- a/src/neuronx_distributed_inference/models/config.py +++ b/src/neuronx_distributed_inference/models/config.py @@ -699,6 +699,7 @@ def __init__( self.moe_tp_degree = kwargs.pop("moe_tp_degree", 1) self.moe_ep_degree = kwargs.pop("moe_ep_degree", 1) + self.expert_wise_scale = kwargs.pop("expert_wise_scale", False) self.transpose_shared_experts_weights = kwargs.pop("transpose_shared_experts_weights", False) self.blockwise_matmul_config = kwargs.pop("blockwise_matmul_config", {}) diff --git a/src/neuronx_distributed_inference/models/model_wrapper.py b/src/neuronx_distributed_inference/models/model_wrapper.py index 20bab283..ca2ba45d 100644 --- a/src/neuronx_distributed_inference/models/model_wrapper.py +++ b/src/neuronx_distributed_inference/models/model_wrapper.py @@ -1653,13 +1653,54 @@ def load_module(self): models_to_convert.append(float_model) for model in models_to_convert: - convert( - model, - q_config=q_config, - inplace=True, - mapping=None, - modules_to_not_convert=get_modules_to_not_convert(model.config.neuron_config), - ) + user_modules_to_not_convert = get_modules_to_not_convert(model.config.neuron_config) + + # Read expert_wise_scale per-model (not from self.neuron_config) so that + # fused speculation with a non-MoE draft + MoE target is handled correctly. + use_expert_wise_scale = getattr(model.config.neuron_config, "expert_wise_scale", False) + + if use_expert_wise_scale and quantization_type == QuantizationType.PER_CHANNEL_SYMMETRIC: + # Two-pass conversion: + # Pass 1: Convert non-expert modules with per_channel_symmetric, + # skip expert MoE modules (expert_mlps) + pass1_skip = list(user_modules_to_not_convert) if user_modules_to_not_convert else [] + pass1_skip.append("expert_mlps") + convert( + model, + q_config=q_config, + inplace=True, + mapping=None, + modules_to_not_convert=pass1_skip, + ) + + # Pass 2: Convert expert MoE modules with expert_wise_per_channel_symmetric + expert_q_config = get_default_expert_wise_per_channel_custom_qconfig_dict() + if isinstance(self.neuron_config.quantization_dtype, str): + expert_q_config["quantized_dtype"] = QuantizedDtype.get_dtype( + self.neuron_config.quantization_dtype + ) + elif isinstance(self.neuron_config.quantization_dtype, QuantizedDtype): + expert_q_config["quantized_dtype"] = self.neuron_config.quantization_dtype + expert_q_config["activation_quantization_type"] = ActivationQuantizationType( + self.neuron_config.activation_quantization_type + ) + expert_q_config["clamp_bound"] = self.neuron_config.quantize_clamp_bound + convert( + model, + q_config=expert_q_config, + inplace=True, + mapping=None, + include=["*expert_mlps.mlp_op*"], + ) + else: + # Standard single-pass conversion + convert( + model, + q_config=q_config, + inplace=True, + mapping=None, + modules_to_not_convert=user_modules_to_not_convert, + ) self.module = float_model else: diff --git a/src/neuronx_distributed_inference/models/qwen3_moe/modeling_qwen3_moe.py b/src/neuronx_distributed_inference/models/qwen3_moe/modeling_qwen3_moe.py index 97e4c160..05d4925e 100644 --- a/src/neuronx_distributed_inference/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/neuronx_distributed_inference/models/qwen3_moe/modeling_qwen3_moe.py @@ -203,6 +203,76 @@ def convert_qwen3_moe_hf_to_neuron_state_dict(neuron_state_dict, config): gate_up_proj = gate_up_proj.reshape(config.num_experts, hidden_size, -1) neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + # Fuse expert scales for gate_up_proj if quantized + is_expert_quantized = ( + getattr(config.neuron_config, "quantized", False) + or getattr(config.neuron_config, "quantized_mlp_kernel_enabled", False) + ) and f"layers.{l}.mlp.experts.0.gate_proj.scale" in neuron_state_dict + + use_expert_wise_scale = is_expert_quantized and getattr( + config.neuron_config, "expert_wise_scale", False + ) + + if is_expert_quantized: + if use_expert_wise_scale: + # Per-expert scales: [num_experts, 1, 2*intermediate_size] + gate_scales = [] + up_scales = [] + for e in range(config.num_experts): + gate_scales.append( + neuron_state_dict[f"layers.{l}.mlp.experts.{e}.gate_proj.scale"] + .detach().clone().to(torch.float32) + ) + up_scales.append( + neuron_state_dict[f"layers.{l}.mlp.experts.{e}.up_proj.scale"] + .detach().clone().to(torch.float32) + ) + # Each scale is [intermediate_size, 1] -> transpose to [1, intermediate_size] + gate_scale = torch.stack([s.T for s in gate_scales], dim=0) # [E, 1, intermediate] + up_scale = torch.stack([s.T for s in up_scales], dim=0) # [E, 1, intermediate] + gate_up_proj_scale = torch.zeros( + config.num_experts, 1, 2 * intermediate_size, + dtype=torch.float32, device=device, + ) + torch.narrow(gate_up_proj_scale, 2, 0, intermediate_size).copy_(gate_scale) + torch.narrow(gate_up_proj_scale, 2, intermediate_size, intermediate_size).copy_(up_scale) + else: + # Averaged scale: [1, 1, 2*intermediate_size] + gate_scale_sum = torch.zeros(intermediate_size, 1, dtype=torch.float32, device=device) + up_scale_sum = torch.zeros(intermediate_size, 1, dtype=torch.float32, device=device) + for e in range(config.num_experts): + gate_scale_sum += ( + neuron_state_dict[f"layers.{l}.mlp.experts.{e}.gate_proj.scale"] + .detach().clone().to(torch.float32) + ) + up_scale_sum += ( + neuron_state_dict[f"layers.{l}.mlp.experts.{e}.up_proj.scale"] + .detach().clone().to(torch.float32) + ) + gate_scale = (gate_scale_sum / config.num_experts).T.unsqueeze(0) # [1, 1, intermediate] + up_scale = (up_scale_sum / config.num_experts).T.unsqueeze(0) # [1, 1, intermediate] + gate_up_proj_scale = torch.zeros( + 1, 1, 2 * intermediate_size, + dtype=torch.float32, device=device, + ) + torch.narrow(gate_up_proj_scale, 2, 0, intermediate_size).copy_(gate_scale) + torch.narrow(gate_up_proj_scale, 2, intermediate_size, intermediate_size).copy_(up_scale) + + for e in range(config.num_experts): + del neuron_state_dict[f"layers.{l}.mlp.experts.{e}.gate_proj.scale"] + del neuron_state_dict[f"layers.{l}.mlp.experts.{e}.up_proj.scale"] + + if pad_size > 0: + if use_expert_wise_scale: + gate_up_proj_scale = gate_up_proj_scale.reshape(config.num_experts, 1, 2, -1) + gate_up_proj_scale = torch.nn.functional.pad(gate_up_proj_scale, (0, pad_size)) + gate_up_proj_scale = gate_up_proj_scale.reshape(config.num_experts, 1, -1) + else: + gate_up_proj_scale = gate_up_proj_scale.reshape(1, 1, 2, -1) + gate_up_proj_scale = torch.nn.functional.pad(gate_up_proj_scale, (0, pad_size)) + gate_up_proj_scale = gate_up_proj_scale.reshape(1, 1, -1) + neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.scale"] = gate_up_proj_scale + down_proj = torch.empty( config.num_experts, intermediate_size, @@ -227,6 +297,32 @@ def convert_qwen3_moe_hf_to_neuron_state_dict(neuron_state_dict, config): down_proj = torch.nn.functional.pad(down_proj, (0, 0, 0, pad_size)) neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj + # Fuse expert scales for down_proj if quantized + if is_expert_quantized: + if use_expert_wise_scale: + # Per-expert scales: [num_experts, 1, hidden_size] + down_scales = [] + for e in range(config.num_experts): + down_scales.append( + neuron_state_dict[f"layers.{l}.mlp.experts.{e}.down_proj.scale"] + .detach().clone().to(torch.float32) + ) + # Each scale is [hidden_size, 1] -> transpose to [1, hidden_size] + down_proj_scale = torch.stack([s.T for s in down_scales], dim=0) # [E, 1, hidden] + else: + # Averaged scale: [1, 1, hidden_size] + down_proj_scale_sum = torch.zeros(hidden_size, 1, dtype=torch.float32, device=device) + for e in range(config.num_experts): + down_proj_scale_sum += ( + neuron_state_dict[f"layers.{l}.mlp.experts.{e}.down_proj.scale"] + .detach().clone().to(torch.float32) + ) + down_proj_scale = (down_proj_scale_sum / config.num_experts).T.unsqueeze(0) + + for e in range(config.num_experts): + del neuron_state_dict[f"layers.{l}.mlp.experts.{e}.down_proj.scale"] + neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.scale"] = down_proj_scale + gc.collect() if config.neuron_config.fused_qkv: