From 13f16d58f8f6a09b8e793fa484d16843df2dd5e4 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Mon, 9 Mar 2026 03:39:33 +0000 Subject: [PATCH 1/5] cleanup Signed-off-by: ZX-ModelCloud --- defuser/defuser.py | 106 +++++++++------- defuser/model_registry.py | 8 ++ .../fused_moe/moe_experts_interface.py | 8 -- defuser/modeling/fused_moe/qwen3_5_moe.py | 116 ------------------ defuser/modeling/fused_moe/replace_modules.py | 33 +---- tests/test_convert_model.py | 15 ++- 6 files changed, 81 insertions(+), 205 deletions(-) delete mode 100644 defuser/modeling/fused_moe/qwen3_5_moe.py diff --git a/defuser/defuser.py b/defuser/defuser.py index 1bb2f95..6148ea5 100644 --- a/defuser/defuser.py +++ b/defuser/defuser.py @@ -5,14 +5,15 @@ from torch import nn +from defuser.model_registry import MODEL_CONFIG, CONVERSION_BEHAVIOR from defuser.modeling.fused_moe.update_module import update_module from defuser.utils.hf import patch def convert_hf_model( - model: nn.Module, - cleanup_original: bool = False, - max_layers: int | None = None, + model: nn.Module, + cleanup_original: bool = False, + max_layers: int | None = None, ) -> nn.Module: if max_layers is not None and max_layers < 1: raise ValueError("max_layers must be >= 1 when provided") @@ -48,53 +49,66 @@ def convert_hf_model( # # If this patch succeeds, it means the model is in the Qwen3 MoE format and # no further tensor transformation is required. - is_applied = patch(model, max_layers=max_layers) - if not is_applied: - # ----------------------------------------------------------------------- - # Step 2: Handle Qwen3.5 MoE checkpoints - # ----------------------------------------------------------------------- - # - # If `apply_modeling_patch` fails, we assume the checkpoint corresponds to - # **Qwen3.5 MoE**. - # - # In Qwen3.5 MoE, the expert MLP weights are stored in a **fused format**. - # Specifically, the checkpoint keeps tensors such as: - # - # gate_up_proj - # down_proj - # - # where `gate_proj` and `up_proj` are fused together. - # - # Because our runtime modeling expects **defused tensors**, simply replacing - # the module structure is not enough. We must also convert the stored - # parameters. - # - # `update_module()` performs two tasks: - # - # 1) Replace the modeling structure so that it matches the expected - # defused MoE implementation. - # - # 2) Prepare the module for **tensor defusion** of the expert weights. - # - # After the structure update, `materialize_model_()` will be invoked to - # actually split the fused tensors: - # - # gate_up_proj --> gate_proj + up_proj - # - # and ensure the module finally contains the expected parameters: - # - # gate_proj - # up_proj - # down_proj - # - # This ensures compatibility between the Qwen3.5 fused checkpoint format - # and the runtime model implementation that operates on defused weights. - model = update_module( + + # ----------------------------------------------------------------------- + # Step 2: Handle Qwen3.5 MoE checkpoints + # ----------------------------------------------------------------------- + # + # If `apply_modeling_patch` fails, we assume the checkpoint corresponds to + # **Qwen3.5 MoE**. + # + # In Qwen3.5 MoE, the expert MLP weights are stored in a **fused format**. + # Specifically, the checkpoint keeps tensors such as: + # + # gate_up_proj + # down_proj + # + # where `gate_proj` and `up_proj` are fused together. + # + # Because our runtime modeling expects **defused tensors**, simply replacing + # the module structure is not enough. We must also convert the stored + # parameters. + # + # `update_module()` performs two tasks: + # + # 1) Replace the modeling structure so that it matches the expected + # defused MoE implementation. + # + # 2) Prepare the module for **tensor defusion** of the expert weights. + # + # After the structure update, `materialize_model_()` will be invoked to + # actually split the fused tensors: + # + # gate_up_proj --> gate_proj + up_proj + # + # and ensure the module finally contains the expected parameters: + # + # gate_proj + # up_proj + # down_proj + # + # This ensures compatibility between the Qwen3.5 fused checkpoint format + # and the runtime model implementation that operates on defused weights. + + model_type = getattr(getattr(model, "config", None), "model_type", None) + if model_type not in MODEL_CONFIG: + raise ValueError(f"Unsupported model_type: {model_type}") + + behavior = MODEL_CONFIG[model_type]["behavior"] + + if behavior == CONVERSION_BEHAVIOR.REPLACE_ONLY: + if not patch(model, max_layers=max_layers): + raise RuntimeError(f"Failed to replace modules for {model_type}") + return model + + if behavior == CONVERSION_BEHAVIOR.REPLACE_AND_DEFUSE: + return update_module( model, cleanup_original=cleanup_original, max_layers=max_layers, ) - return model + + raise ValueError(f"Unknown conversion behavior for {model_type}: {behavior}") __all__ = ["convert_hf_model"] diff --git a/defuser/model_registry.py b/defuser/model_registry.py index e09e4a6..11dfdf5 100644 --- a/defuser/model_registry.py +++ b/defuser/model_registry.py @@ -11,9 +11,15 @@ class PATCH(str, Enum): REPLACE_MODULE = "replace_module" +class CONVERSION_BEHAVIOR(str, Enum): + REPLACE_ONLY = "replace_only" + REPLACE_AND_DEFUSE = "replace_and_defuse" + + MODEL_CONFIG = { "qwen3_moe": { "min_transformers_version": "5.0.0", + "behavior": CONVERSION_BEHAVIOR.REPLACE_ONLY, # structure path only replaces modeling structure PATCH.REPLACE_MODULE: [ ( @@ -26,10 +32,12 @@ class PATCH(str, Enum): "min_transformers_version": "5.2.0", # Replacement module path imported only when the defuse workflow runs PATCH.DEFUSE: "defuser.modeling.fused_moe.qwen3_5_moe", + "behavior": CONVERSION_BEHAVIOR.REPLACE_AND_DEFUSE, }, "qwen3_5_moe_text": { "min_transformers_version": "5.2.0", # Replacement module path imported only when the defuse workflow runs PATCH.DEFUSE: "defuser.modeling.fused_moe.qwen3_5_moe", + "behavior": CONVERSION_BEHAVIOR.REPLACE_AND_DEFUSE, }, } diff --git a/defuser/modeling/fused_moe/moe_experts_interface.py b/defuser/modeling/fused_moe/moe_experts_interface.py index 4f9bb77..ee9d746 100644 --- a/defuser/modeling/fused_moe/moe_experts_interface.py +++ b/defuser/modeling/fused_moe/moe_experts_interface.py @@ -56,14 +56,6 @@ "gate_proj": {"is_input_proj": True, "output_multiplier": 1}, # hidden -> intermediate (gate) "up_proj": {"is_input_proj": True, "output_multiplier": 1}, # hidden -> intermediate (up) "down_proj": {"is_input_proj": False, "output_multiplier": 1}, # intermediate -> hidden - # Mixtral-style - "w1": {"is_input_proj": True, "output_multiplier": 1}, # gate: hidden -> intermediate - "w2": {"is_input_proj": False, "output_multiplier": 1}, # down: intermediate -> hidden - "w3": {"is_input_proj": True, "output_multiplier": 1}, # up: hidden -> intermediate - # DBRX-style - "v1": {"is_input_proj": True, "output_multiplier": 1}, - "w1_proj": {"is_input_proj": True, "output_multiplier": 1}, - "w2_proj": {"is_input_proj": False, "output_multiplier": 1}, } diff --git a/defuser/modeling/fused_moe/qwen3_5_moe.py b/defuser/modeling/fused_moe/qwen3_5_moe.py deleted file mode 100644 index 256b86a..0000000 --- a/defuser/modeling/fused_moe/qwen3_5_moe.py +++ /dev/null @@ -1,116 +0,0 @@ -# SPDX-FileCopyrightText: 2026 ModelCloud.ai -# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -# Adapted from intel/auto-round -# at https://github.com/intel/auto-round/blob/main/auto_round/modeling/fused_moe/qwen3_5_moe.py - -import torch -from torch.nn import functional as F -from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeMLP -from transformers.utils.versions import require_version - -from defuser.modeling.fused_moe.replace_modules import ReplacementModuleBase -from defuser.utils.device import clear_memory, to_meta, unsupported_meta_device - -require_version("transformers>=5.2.0") - -from defuser.utils.model import _update_parameter - - -class LinearQwen3_5MoeSparseMoeBlock(ReplacementModuleBase): - def __init__(self, original, config): - super().__init__(original) - self.gate = original.gate - text_config = config.get_text_config() - self.shared_expert = original.shared_expert - self.experts = SequentialQwen3_5MoeExperts(text_config, original.experts) - self.shared_expert_gate = original.shared_expert_gate - self.num_experts = text_config.num_experts - - @classmethod - def original_module_class(cls) -> str: - """Return the class name of the module this replaces.""" - return "Qwen3_5MoeSparseMoeBlock" - - def _materialize_weights(self) -> None: - original = self._get_original_module() - self.experts._materialize_weights(original.experts) - clear_memory() - - def experts_forward( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, - ) -> torch.Tensor: - final_hidden_states = torch.zeros_like(hidden_states) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) - expert_mask = expert_mask.permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - for expert_idx in expert_hit: - expert_idx = expert_idx[0] - if expert_idx == self.num_experts: - continue - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - # gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) - # current_hidden_states = self.act_fn(gate) * up - # current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = self.experts[expert_idx](current_state) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] - final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) - - return final_hidden_states - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states_reshaped = hidden_states.view(-1, hidden_dim) - shared_expert_output = self.shared_expert(hidden_states_reshaped) - _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) - expert_output = self.experts_forward(hidden_states_reshaped, selected_experts, routing_weights) - - shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output - - expert_output += shared_expert_output - expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim) - return expert_output - - @classmethod - def from_original( - cls, - original, - config, - **kwargs, - ): - """Create an instance from the original module.""" - return cls(original, config) - - -class SequentialQwen3_5MoeExperts(torch.nn.ModuleList): - def __init__(self, config, original): - super().__init__() - self.num_experts = original.gate_up_proj.shape[0] - intermediate_size = config.moe_intermediate_size - - super().__init__([Qwen3_5MoeMLP(config, intermediate_size) for _ in range(self.num_experts)]) - - def _materialize_weights(self, original) -> None: - intermediate_size = original.down_proj.shape[-1] - if not unsupported_meta_device(original): - for i in range(self.num_experts): - gate_up = original.gate_up_proj[i] - down = original.down_proj[i] - - gate_proj = gate_up[:intermediate_size, :] - up_proj = gate_up[intermediate_size:, :] - - _update_parameter(self[i].gate_proj, "weight", gate_proj.contiguous()) - _update_parameter(self[i].up_proj, "weight", up_proj.contiguous()) - _update_parameter(self[i].down_proj, "weight", down.contiguous()) - del gate_up, down, gate_proj, up_proj - to_meta(original) # release original experts parameters - clear_memory() diff --git a/defuser/modeling/fused_moe/replace_modules.py b/defuser/modeling/fused_moe/replace_modules.py index 5d0723d..21d65fc 100644 --- a/defuser/modeling/fused_moe/replace_modules.py +++ b/defuser/modeling/fused_moe/replace_modules.py @@ -17,33 +17,11 @@ from tqdm import tqdm from defuser.model_registry import MODEL_CONFIG, PATCH -from defuser.utils.common import is_within_max_layers +from defuser.utils.common import is_within_max_layers, is_transformers_version_greater_or_equal_5 logger = LogBar(__name__) -def is_model_patchable(model: torch.nn.Module) -> bool: - """Check if the model has a custom replacement registered via MODEL_CONFIG. - - Returns True if the model's model_type matches a key in MODEL_CONFIG. - """ - if hasattr(model, "config") and hasattr(model.config, "model_type"): - return model.config.model_type in MODEL_CONFIG - return False - - -def _import_required_replacements(model: torch.nn.Module) -> None: - """Import replacement modules required for the model's defuse workflow.""" - if not is_model_patchable(model): - return - model_type = model.config.model_type - module_path = MODEL_CONFIG[model_type].get(PATCH.DEFUSE) - if not module_path: - return - importlib.import_module(module_path) - logger.debug(f"Loaded replacement module for {model_type}: {module_path}") - - def materialize_model(model: torch.nn.Module) -> None: def _materialize_module(module: torch.nn.Module) -> None: if isinstance(module, ReplacementModuleBase): @@ -259,15 +237,10 @@ def apply_replacements( Returns: The model with modules replaced. """ - _import_required_replacements(model) - _log_first_moe_block(model, "before replacement") - # Custom replacements first - if is_model_patchable(model): - _apply_custom_replacements(model, max_layers=max_layers) - # if auto_detect_moe and is_transformers_version_greater_or_equal_5(): - # _handle_moe_modules(model) + if auto_detect_moe and is_transformers_version_greater_or_equal_5(): + _handle_moe_modules(model) _log_first_moe_block(model, "after replacement") diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py index 58385be..4cc59ea 100644 --- a/tests/test_convert_model.py +++ b/tests/test_convert_model.py @@ -28,7 +28,6 @@ def test_qwen3_moe(): def test_qwen3_5_moe(): - from defuser.modeling.fused_moe.qwen3_5_moe import LinearQwen3_5MoeSparseMoeBlock from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock config = AutoConfig.from_pretrained("/monster/data/model/Qwen3.5-35B-A3B") @@ -54,10 +53,16 @@ def test_qwen3_5_moe(): assert converted moe_block = model.model.language_model.layers[0].mlp - assert isinstance(moe_block, LinearQwen3_5MoeSparseMoeBlock) + experts = moe_block.experts + + assert hasattr(experts, "0") + expert0 = getattr(experts, "0") + assert hasattr(expert0, "gate_proj") + assert hasattr(expert0, "up_proj") + assert hasattr(expert0, "down_proj") materialize_model(model.model.language_model.layers[0]) - torch.testing.assert_close(moe_block.experts[0].gate_proj.weight, expected_gate) - torch.testing.assert_close(moe_block.experts[0].up_proj.weight, expected_up) - torch.testing.assert_close(moe_block.experts[0].down_proj.weight, expected_down) + torch.testing.assert_close(expert0.gate_proj.weight, expected_gate) + torch.testing.assert_close(expert0.up_proj.weight, expected_up) + torch.testing.assert_close(expert0.down_proj.weight, expected_down) From 4c45df129e554765c464686acf677c3259c72447 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Mon, 9 Mar 2026 06:36:24 +0000 Subject: [PATCH 2/5] remove LinearQwen3MoeSparseMoeBlock Signed-off-by: ZX-ModelCloud --- defuser/model_registry.py | 8 +-- defuser/modeling/unfused_moe/qwen3_moe.py | 70 ----------------------- tests/test_convert_model.py | 15 +++-- 3 files changed, 11 insertions(+), 82 deletions(-) delete mode 100644 defuser/modeling/unfused_moe/qwen3_moe.py diff --git a/defuser/model_registry.py b/defuser/model_registry.py index 11dfdf5..3b9bf97 100644 --- a/defuser/model_registry.py +++ b/defuser/model_registry.py @@ -19,14 +19,8 @@ class CONVERSION_BEHAVIOR(str, Enum): MODEL_CONFIG = { "qwen3_moe": { "min_transformers_version": "5.0.0", - "behavior": CONVERSION_BEHAVIOR.REPLACE_ONLY, + "behavior": CONVERSION_BEHAVIOR.REPLACE_AND_DEFUSE, # structure path only replaces modeling structure - PATCH.REPLACE_MODULE: [ - ( - "transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock", - "defuser.modeling.unfused_moe.qwen3_moe.LinearQwen3MoeSparseMoeBlock", - ) - ], }, "qwen3_5_moe": { "min_transformers_version": "5.2.0", diff --git a/defuser/modeling/unfused_moe/qwen3_moe.py b/defuser/modeling/unfused_moe/qwen3_moe.py deleted file mode 100644 index 640e71c..0000000 --- a/defuser/modeling/unfused_moe/qwen3_moe.py +++ /dev/null @@ -1,70 +0,0 @@ -# SPDX-FileCopyrightText: 2026 ModelCloud.ai -# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -# Adapted from intel/auto-round -# at https://github.com/intel/auto-round/blob/main/auto_round/modeling/unfused_moe/qwen3_moe.py - -import torch -import torch.nn as nn -from torch.nn import functional as F - - -class LinearQwen3MoeSparseMoeBlock(nn.Module): - def __init__(self, config): - super().__init__() - from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeMLP - - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.norm_topk_prob = config.norm_topk_prob - - # This must be linear for vllm alignment - self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) - self.experts = nn.ModuleList( - [Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)] - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - if self.norm_topk_prob: # only diff with mixtral sparse moe block! - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be solicited - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit: - if expert_idx == self.num_experts: - continue - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py index 4cc59ea..5366264 100644 --- a/tests/test_convert_model.py +++ b/tests/test_convert_model.py @@ -10,12 +10,11 @@ def test_qwen3_moe(): - from defuser.modeling.unfused_moe.qwen3_moe import LinearQwen3MoeSparseMoeBlock - - config = AutoConfig.from_pretrained("/monster/data/model/Qwen3-30B-A3B") + model_id = "Qwen/Qwen3-30B-A3B" + config = AutoConfig.from_pretrained(model_id) config.num_hidden_layers = 1 model = AutoModelForCausalLM.from_pretrained( - "/monster/data/model/Qwen3-30B-A3B", + model_id, config=config, ignore_mismatched_sizes=True, ) @@ -24,7 +23,13 @@ def test_qwen3_moe(): converted = convert_hf_model(model, max_layers=1) assert converted - assert isinstance(model.model.layers[0].mlp, LinearQwen3MoeSparseMoeBlock) + + experts = model.model.layers[0].mlp.experts + assert hasattr(experts, "0") + expert0 = getattr(experts, "0") + assert hasattr(expert0, "gate_proj") + assert hasattr(expert0, "up_proj") + assert hasattr(expert0, "down_proj") def test_qwen3_5_moe(): From b12ee7a3426aa29f1491e8aee714dea117c7e261 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Mon, 9 Mar 2026 07:48:03 +0000 Subject: [PATCH 3/5] update version Signed-off-by: ZX-ModelCloud --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a9ba9aa..4a081d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "Defuser" -version = "0.0.4" +version = "0.0.5" description = "Model defuser helper for HF Transformers." readme = "README.md" requires-python = ">=3.9" From 6a3cb4e3dae88db85265c0cc3d4ab9985f496082 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Mon, 9 Mar 2026 08:15:05 +0000 Subject: [PATCH 4/5] cleanup Signed-off-by: ZX-ModelCloud --- defuser/defuser.py | 50 +++++++++++-------- defuser/model_registry.py | 21 -------- defuser/modeling/fused_moe/__init__.py | 0 .../{fused_moe => }/moe_experts_interface.py | 0 .../{fused_moe => }/replace_modules.py | 3 +- defuser/modeling/unfused_moe/__init__.py | 0 .../modeling/{fused_moe => }/update_module.py | 2 +- defuser/utils/hf.py | 34 +------------ tests/test_convert_model.py | 2 +- tests/test_replace_modules_tracker.py | 2 +- 10 files changed, 33 insertions(+), 81 deletions(-) delete mode 100644 defuser/modeling/fused_moe/__init__.py rename defuser/modeling/{fused_moe => }/moe_experts_interface.py (100%) rename defuser/modeling/{fused_moe => }/replace_modules.py (99%) delete mode 100644 defuser/modeling/unfused_moe/__init__.py rename defuser/modeling/{fused_moe => }/update_module.py (84%) diff --git a/defuser/defuser.py b/defuser/defuser.py index 6148ea5..559a2e3 100644 --- a/defuser/defuser.py +++ b/defuser/defuser.py @@ -5,9 +5,28 @@ from torch import nn -from defuser.model_registry import MODEL_CONFIG, CONVERSION_BEHAVIOR -from defuser.modeling.fused_moe.update_module import update_module -from defuser.utils.hf import patch +from defuser.model_registry import MODEL_CONFIG +from defuser.modeling.update_module import update_module +from packaging import version +import transformers + + +def check_model_compatibility(model: nn.Module) -> str: + """Validate model type and transformers version compatibility.""" + config = getattr(model, "config", None) + model_type = getattr(config, "model_type", None) + if model_type not in MODEL_CONFIG: + raise ValueError(f"Unsupported model_type: {model_type}") + + min_ver = MODEL_CONFIG[model_type].get("min_transformers_version") + current_ver = version.parse(transformers.__version__) + if min_ver and current_ver < version.parse(min_ver): + raise RuntimeError( + f"transformers>={min_ver} is required for model_type={model_type}, " + f"but found {transformers.__version__}" + ) + + return model_type def convert_hf_model( @@ -90,25 +109,12 @@ def convert_hf_model( # This ensures compatibility between the Qwen3.5 fused checkpoint format # and the runtime model implementation that operates on defused weights. - model_type = getattr(getattr(model, "config", None), "model_type", None) - if model_type not in MODEL_CONFIG: - raise ValueError(f"Unsupported model_type: {model_type}") - - behavior = MODEL_CONFIG[model_type]["behavior"] - - if behavior == CONVERSION_BEHAVIOR.REPLACE_ONLY: - if not patch(model, max_layers=max_layers): - raise RuntimeError(f"Failed to replace modules for {model_type}") - return model - - if behavior == CONVERSION_BEHAVIOR.REPLACE_AND_DEFUSE: - return update_module( - model, - cleanup_original=cleanup_original, - max_layers=max_layers, - ) - - raise ValueError(f"Unknown conversion behavior for {model_type}: {behavior}") + check_model_compatibility(model) + return update_module( + model, + cleanup_original=cleanup_original, + max_layers=max_layers, + ) __all__ = ["convert_hf_model"] diff --git a/defuser/model_registry.py b/defuser/model_registry.py index 3b9bf97..1e7a33d 100644 --- a/defuser/model_registry.py +++ b/defuser/model_registry.py @@ -3,35 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from enum import Enum - - -class PATCH(str, Enum): - DEFUSE = "defuse" - REPLACE_MODULE = "replace_module" - - -class CONVERSION_BEHAVIOR(str, Enum): - REPLACE_ONLY = "replace_only" - REPLACE_AND_DEFUSE = "replace_and_defuse" - - MODEL_CONFIG = { "qwen3_moe": { "min_transformers_version": "5.0.0", - "behavior": CONVERSION_BEHAVIOR.REPLACE_AND_DEFUSE, - # structure path only replaces modeling structure }, "qwen3_5_moe": { "min_transformers_version": "5.2.0", - # Replacement module path imported only when the defuse workflow runs - PATCH.DEFUSE: "defuser.modeling.fused_moe.qwen3_5_moe", - "behavior": CONVERSION_BEHAVIOR.REPLACE_AND_DEFUSE, }, "qwen3_5_moe_text": { "min_transformers_version": "5.2.0", - # Replacement module path imported only when the defuse workflow runs - PATCH.DEFUSE: "defuser.modeling.fused_moe.qwen3_5_moe", - "behavior": CONVERSION_BEHAVIOR.REPLACE_AND_DEFUSE, }, } diff --git a/defuser/modeling/fused_moe/__init__.py b/defuser/modeling/fused_moe/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/defuser/modeling/fused_moe/moe_experts_interface.py b/defuser/modeling/moe_experts_interface.py similarity index 100% rename from defuser/modeling/fused_moe/moe_experts_interface.py rename to defuser/modeling/moe_experts_interface.py diff --git a/defuser/modeling/fused_moe/replace_modules.py b/defuser/modeling/replace_modules.py similarity index 99% rename from defuser/modeling/fused_moe/replace_modules.py rename to defuser/modeling/replace_modules.py index 21d65fc..5d92f38 100644 --- a/defuser/modeling/fused_moe/replace_modules.py +++ b/defuser/modeling/replace_modules.py @@ -16,7 +16,6 @@ from logbar import LogBar from tqdm import tqdm -from defuser.model_registry import MODEL_CONFIG, PATCH from defuser.utils.common import is_within_max_layers, is_transformers_version_greater_or_equal_5 logger = LogBar(__name__) @@ -194,7 +193,7 @@ def _handle_moe_modules(model: torch.nn.Module) -> list[str]: Returns: List of module names that were processed """ - from defuser.modeling.fused_moe.moe_experts_interface import ( + from defuser.modeling.moe_experts_interface import ( is_linear_loop_available, prepare_model_for_moe_quantization, ) diff --git a/defuser/modeling/unfused_moe/__init__.py b/defuser/modeling/unfused_moe/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/defuser/modeling/fused_moe/update_module.py b/defuser/modeling/update_module.py similarity index 84% rename from defuser/modeling/fused_moe/update_module.py rename to defuser/modeling/update_module.py index 2ac82e3..c56aa63 100644 --- a/defuser/modeling/fused_moe/update_module.py +++ b/defuser/modeling/update_module.py @@ -6,7 +6,7 @@ # Adapted from intel/auto-round # at https://github.com/intel/auto-round/blob/main/auto_round/special_model_handler.py -from defuser.modeling.fused_moe.replace_modules import apply_replacements, release_original_module_ +from defuser.modeling.replace_modules import apply_replacements, release_original_module_ def update_module( diff --git a/defuser/utils/hf.py b/defuser/utils/hf.py index ee81b30..b4168f9 100644 --- a/defuser/utils/hf.py +++ b/defuser/utils/hf.py @@ -16,8 +16,7 @@ from packaging import version from transformers import AutoConfig -from defuser.model_registry import MODEL_CONFIG, PATCH -from defuser.utils.common import is_within_max_layers +from defuser.model_registry import MODEL_CONFIG logger = LogBar(__name__) @@ -105,34 +104,3 @@ def pre_check_config(model_name: str | torch.nn.Module): except: return True return True - - -def patch(model: torch.nn.Module, max_layers: int | None = None) -> bool: - res = pre_check_config(model) - if not res: - return False - model_type = getattr(model.config, "model_type") - cfg = MODEL_CONFIG[model_type] - # patch blocks - for orig_path, custom_path in cfg.get(PATCH.REPLACE_MODULE, []): - orig_module_path, orig_class_name = orig_path.rsplit(".", 1) - custom_module_path, custom_class_name = custom_path.rsplit(".", 1) - try: - orig_module = importlib.import_module(orig_module_path) - custom_module = importlib.import_module(custom_module_path) - custom_class = getattr(custom_module, custom_class_name) - orig_class = getattr(orig_module, orig_class_name) - names = [] - for n, m in model.named_modules(): - if isinstance(m, orig_class): - if not is_within_max_layers(n, max_layers): - continue - names.append((n, next(m.parameters()).dtype)) - for (n, orig_dtype) in names: - model.set_submodule(n, custom_class(model.config).to(orig_dtype), True) - logger.info(f"Patched module: {orig_path} -> {custom_path}") - return True - except Exception as e: - logger.warn(f"Failed to patch {orig_path}: {e}") - return False - return False diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py index 5366264..5c3e462 100644 --- a/tests/test_convert_model.py +++ b/tests/test_convert_model.py @@ -6,7 +6,7 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText from defuser import convert_hf_model -from defuser.modeling.fused_moe.replace_modules import materialize_model +from defuser.modeling.replace_modules import materialize_model def test_qwen3_moe(): diff --git a/tests/test_replace_modules_tracker.py b/tests/test_replace_modules_tracker.py index fd7d223..2aa33fe 100644 --- a/tests/test_replace_modules_tracker.py +++ b/tests/test_replace_modules_tracker.py @@ -9,7 +9,7 @@ import pytest import torch -from defuser.modeling.fused_moe.replace_modules import ( +from defuser.modeling.replace_modules import ( ModuleReplacementTracker, ReplacementModuleBase, release_original_module_, From b91713098a4c36e43dc7490cc36195b38f2abd91 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Mon, 9 Mar 2026 09:45:06 +0000 Subject: [PATCH 5/5] cleanup Signed-off-by: ZX-ModelCloud --- defuser/defuser.py | 15 +++++++++------ tests/test_convert_model.py | 6 +++--- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/defuser/defuser.py b/defuser/defuser.py index 559a2e3..02931c8 100644 --- a/defuser/defuser.py +++ b/defuser/defuser.py @@ -9,9 +9,11 @@ from defuser.modeling.update_module import update_module from packaging import version import transformers +from logbar import LogBar +logger = LogBar(__name__) -def check_model_compatibility(model: nn.Module) -> str: +def check_model_compatibility(model: nn.Module) -> bool: """Validate model type and transformers version compatibility.""" config = getattr(model, "config", None) model_type = getattr(config, "model_type", None) @@ -21,15 +23,16 @@ def check_model_compatibility(model: nn.Module) -> str: min_ver = MODEL_CONFIG[model_type].get("min_transformers_version") current_ver = version.parse(transformers.__version__) if min_ver and current_ver < version.parse(min_ver): - raise RuntimeError( - f"transformers>={min_ver} is required for model_type={model_type}, " - f"but found {transformers.__version__}" + logger.warn( + f"Skip conversion for model_type={model_type}: " + f"requires transformers>={min_ver}, current version is {transformers.__version__}." ) + return False - return model_type + return True -def convert_hf_model( +def convert_model( model: nn.Module, cleanup_original: bool = False, max_layers: int | None = None, diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py index 5c3e462..718d5f6 100644 --- a/tests/test_convert_model.py +++ b/tests/test_convert_model.py @@ -5,7 +5,7 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText -from defuser import convert_hf_model +from defuser import convert_model from defuser.modeling.replace_modules import materialize_model @@ -21,7 +21,7 @@ def test_qwen3_moe(): assert model.config.model_type == "qwen3_moe" - converted = convert_hf_model(model, max_layers=1) + converted = convert_model(model, max_layers=1) assert converted experts = model.model.layers[0].mlp.experts @@ -54,7 +54,7 @@ def test_qwen3_5_moe(): expected_up = original_moe_block.experts.gate_up_proj[0, intermediate_dim:, :hidden_dim].contiguous().clone() expected_down = original_moe_block.experts.down_proj[0, :hidden_dim, :intermediate_dim].contiguous().clone() - converted = convert_hf_model(model, cleanup_original=False, max_layers=1) + converted = convert_model(model, cleanup_original=False, max_layers=1) assert converted moe_block = model.model.language_model.layers[0].mlp