Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 76 additions & 53 deletions defuser/defuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,37 @@

from torch import nn

from defuser.modeling.fused_moe.update_module import update_module
from defuser.utils.hf import patch
from defuser.model_registry import MODEL_CONFIG
from defuser.modeling.update_module import update_module
from packaging import version
import transformers
from logbar import LogBar

logger = LogBar(__name__)

def convert_hf_model(
model: nn.Module,
cleanup_original: bool = False,
max_layers: int | None = None,
def check_model_compatibility(model: nn.Module) -> bool:
    """Validate model type and transformers version compatibility.

    Raises ValueError for a model_type that is not registered in
    MODEL_CONFIG; returns False (after logging a warning) when the
    installed transformers version is older than the registered
    minimum, True otherwise.
    """
    model_type = getattr(getattr(model, "config", None), "model_type", None)
    if model_type not in MODEL_CONFIG:
        raise ValueError(f"Unsupported model_type: {model_type}")

    # An absent/empty minimum means any transformers version is acceptable.
    min_ver = MODEL_CONFIG[model_type].get("min_transformers_version")
    if not min_ver:
        return True

    if version.parse(transformers.__version__) >= version.parse(min_ver):
        return True

    logger.warn(
        f"Skip conversion for model_type={model_type}: "
        f"requires transformers>={min_ver}, current version is {transformers.__version__}."
    )
    return False


def convert_model(
model: nn.Module,
cleanup_original: bool = False,
max_layers: int | None = None,
) -> nn.Module:
if max_layers is not None and max_layers < 1:
raise ValueError("max_layers must be >= 1 when provided")
Expand Down Expand Up @@ -48,53 +71,53 @@ def convert_hf_model(
#
# If this patch succeeds, it means the model is in the Qwen3 MoE format and
# no further tensor transformation is required.
is_applied = patch(model, max_layers=max_layers)
if not is_applied:
# -----------------------------------------------------------------------
# Step 2: Handle Qwen3.5 MoE checkpoints
# -----------------------------------------------------------------------
#
# If `apply_modeling_patch` fails, we assume the checkpoint corresponds to
# **Qwen3.5 MoE**.
#
# In Qwen3.5 MoE, the expert MLP weights are stored in a **fused format**.
# Specifically, the checkpoint keeps tensors such as:
#
# gate_up_proj
# down_proj
#
# where `gate_proj` and `up_proj` are fused together.
#
# Because our runtime modeling expects **defused tensors**, simply replacing
# the module structure is not enough. We must also convert the stored
# parameters.
#
# `update_module()` performs two tasks:
#
# 1) Replace the modeling structure so that it matches the expected
# defused MoE implementation.
#
# 2) Prepare the module for **tensor defusion** of the expert weights.
#
# After the structure update, `materialize_model_()` will be invoked to
# actually split the fused tensors:
#
# gate_up_proj --> gate_proj + up_proj
#
# and ensure the module finally contains the expected parameters:
#
# gate_proj
# up_proj
# down_proj
#
# This ensures compatibility between the Qwen3.5 fused checkpoint format
# and the runtime model implementation that operates on defused weights.
model = update_module(
model,
cleanup_original=cleanup_original,
max_layers=max_layers,
)
return model

# -----------------------------------------------------------------------
# Step 2: Handle Qwen3.5 MoE checkpoints
# -----------------------------------------------------------------------
#
# If `apply_modeling_patch` fails, we assume the checkpoint corresponds to
# **Qwen3.5 MoE**.
#
# In Qwen3.5 MoE, the expert MLP weights are stored in a **fused format**.
# Specifically, the checkpoint keeps tensors such as:
#
# gate_up_proj
# down_proj
#
# where `gate_proj` and `up_proj` are fused together.
#
# Because our runtime modeling expects **defused tensors**, simply replacing
# the module structure is not enough. We must also convert the stored
# parameters.
#
# `update_module()` performs two tasks:
#
# 1) Replace the modeling structure so that it matches the expected
# defused MoE implementation.
#
# 2) Prepare the module for **tensor defusion** of the expert weights.
#
# After the structure update, `materialize_model_()` will be invoked to
# actually split the fused tensors:
#
# gate_up_proj --> gate_proj + up_proj
#
# and ensure the module finally contains the expected parameters:
#
# gate_proj
# up_proj
# down_proj
#
# This ensures compatibility between the Qwen3.5 fused checkpoint format
# and the runtime model implementation that operates on defused weights.

check_model_compatibility(model)

return update_module(
model,
cleanup_original=cleanup_original,
max_layers=max_layers,
)

__all__ = ["convert_hf_model"]
19 changes: 0 additions & 19 deletions defuser/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,14 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from enum import Enum


class PATCH(str, Enum):
DEFUSE = "defuse"
REPLACE_MODULE = "replace_module"


MODEL_CONFIG = {
"qwen3_moe": {
"min_transformers_version": "5.0.0",
# structure path only replaces modeling structure
PATCH.REPLACE_MODULE: [
(
"transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock",
"defuser.modeling.unfused_moe.qwen3_moe.LinearQwen3MoeSparseMoeBlock",
)
],
},
"qwen3_5_moe": {
"min_transformers_version": "5.2.0",
# Replacement module path imported only when the defuse workflow runs
PATCH.DEFUSE: "defuser.modeling.fused_moe.qwen3_5_moe",
},
"qwen3_5_moe_text": {
"min_transformers_version": "5.2.0",
# Replacement module path imported only when the defuse workflow runs
PATCH.DEFUSE: "defuser.modeling.fused_moe.qwen3_5_moe",
},
}
Empty file.
116 changes: 0 additions & 116 deletions defuser/modeling/fused_moe/qwen3_5_moe.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,6 @@
"gate_proj": {"is_input_proj": True, "output_multiplier": 1}, # hidden -> intermediate (gate)
"up_proj": {"is_input_proj": True, "output_multiplier": 1}, # hidden -> intermediate (up)
"down_proj": {"is_input_proj": False, "output_multiplier": 1}, # intermediate -> hidden
# Mixtral-style
"w1": {"is_input_proj": True, "output_multiplier": 1}, # gate: hidden -> intermediate
"w2": {"is_input_proj": False, "output_multiplier": 1}, # down: intermediate -> hidden
"w3": {"is_input_proj": True, "output_multiplier": 1}, # up: hidden -> intermediate
# DBRX-style
"v1": {"is_input_proj": True, "output_multiplier": 1},
"w1_proj": {"is_input_proj": True, "output_multiplier": 1},
"w2_proj": {"is_input_proj": False, "output_multiplier": 1},
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,11 @@
from logbar import LogBar
from tqdm import tqdm

from defuser.model_registry import MODEL_CONFIG, PATCH
from defuser.utils.common import is_within_max_layers
from defuser.utils.common import is_within_max_layers, is_transformers_version_greater_or_equal_5

logger = LogBar(__name__)


def is_model_patchable(model: torch.nn.Module) -> bool:
"""Check if the model has a custom replacement registered via MODEL_CONFIG.

Returns True if the model's model_type matches a key in MODEL_CONFIG.
"""
if hasattr(model, "config") and hasattr(model.config, "model_type"):
return model.config.model_type in MODEL_CONFIG
return False


def _import_required_replacements(model: torch.nn.Module) -> None:
"""Import replacement modules required for the model's defuse workflow."""
if not is_model_patchable(model):
return
model_type = model.config.model_type
module_path = MODEL_CONFIG[model_type].get(PATCH.DEFUSE)
if not module_path:
return
importlib.import_module(module_path)
logger.debug(f"Loaded replacement module for {model_type}: {module_path}")


def materialize_model(model: torch.nn.Module) -> None:
def _materialize_module(module: torch.nn.Module) -> None:
if isinstance(module, ReplacementModuleBase):
Expand Down Expand Up @@ -216,7 +193,7 @@ def _handle_moe_modules(model: torch.nn.Module) -> list[str]:
Returns:
List of module names that were processed
"""
from defuser.modeling.fused_moe.moe_experts_interface import (
from defuser.modeling.moe_experts_interface import (
is_linear_loop_available,
prepare_model_for_moe_quantization,
)
Expand Down Expand Up @@ -259,15 +236,10 @@ def apply_replacements(
Returns:
The model with modules replaced.
"""
_import_required_replacements(model)

_log_first_moe_block(model, "before replacement")

# Custom replacements first
if is_model_patchable(model):
_apply_custom_replacements(model, max_layers=max_layers)
# if auto_detect_moe and is_transformers_version_greater_or_equal_5():
# _handle_moe_modules(model)
if auto_detect_moe and is_transformers_version_greater_or_equal_5():
_handle_moe_modules(model)

_log_first_moe_block(model, "after replacement")

Expand Down
Empty file.
Loading