Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions defuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from defuser.utils.hf import env_flag

# Package-wide debug switch, evaluated once when the package is imported;
# other modules import it to guard their debug logging.
# NOTE(review): env_flag is imported from defuser.utils.hf, while an env_flag
# helper is defined in defuser/utils/common.py -- confirm that
# defuser.utils.hf really exports env_flag, otherwise importing the package
# fails with ImportError.
DEBUG_ON = env_flag("DEBUG")

def convert_model(*args, **kwargs):
"""Lazily import conversion entrypoint to avoid import-time cycles."""
from .defuser import convert_model as _convert_model
Expand Down
1 change: 0 additions & 1 deletion defuser/defuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from torch import nn

from defuser.model_registry import MODEL_CONFIG
Expand Down
20 changes: 11 additions & 9 deletions defuser/modeling/moe_experts_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

from defuser.utils.device import clear_memory, to_meta

from defuser import DEBUG_ON

logger = LogBar(__name__)

try:
Expand Down Expand Up @@ -108,7 +110,7 @@ def linear_loop_experts_forward(
Returns:
final_hidden_states: Output tensor of shape (num_tokens, hidden_dim)
"""
logger.debug(f"Using {LINEAR_LOOP_IMPL} experts forward for {self.__class__.__name__}")
if DEBUG_ON: logger.debug(f"Using {LINEAR_LOOP_IMPL} experts forward for {self.__class__.__name__}")

# Handle [batch_size, seq_len, hidden_dim] input format
if hidden_states.dim() == 3:
Expand Down Expand Up @@ -194,7 +196,7 @@ def register_linear_loop_experts() -> bool:

if LINEAR_LOOP_IMPL not in ALL_EXPERTS_FUNCTIONS._global_mapping:
ALL_EXPERTS_FUNCTIONS._global_mapping[LINEAR_LOOP_IMPL] = linear_loop_experts_forward
logger.debug(f"Registered '{LINEAR_LOOP_IMPL}' experts implementation")
if DEBUG_ON: logger.debug(f"Registered '{LINEAR_LOOP_IMPL}' experts implementation")

return True

Expand Down Expand Up @@ -225,7 +227,7 @@ def _detect_expert_projections(module: nn.Module) -> dict[str, dict]:
param = getattr(module, attr_name, None)
if param is not None and isinstance(param, nn.Parameter) and param.dim() == 3:
# Use default config for unknown projections
logger.debug(f"Discovered unknown 3D projection: {attr_name}")
if DEBUG_ON: logger.debug(f"Discovered unknown 3D projection: {attr_name}")
detected[attr_name] = {"is_input_proj": True, "output_multiplier": 1}

return detected
Expand Down Expand Up @@ -359,7 +361,7 @@ def _unfuse_single_projection(
setattr(module, proj_name, to_meta(param))
if has_bias:
setattr(module, bias_name, to_meta(bias_param))
logger.debug(f"Released memory for {proj_name} using to_meta()")
if DEBUG_ON: logger.debug(f"Released memory for {proj_name} using to_meta()")
except Exception:
pass

Expand Down Expand Up @@ -483,7 +485,7 @@ def _unfuse_experts_weights_inplace(

# Only unfuse if the module supports the decorator (unless check_decorator is False)
if check_decorator and not _experts_supports_decorator(module):
logger.debug(f"Skipping unfuse for {module.__class__.__name__}: does not support @use_experts_implementation")
if DEBUG_ON: logger.debug(f"Skipping unfuse for {module.__class__.__name__}: does not support @use_experts_implementation")
return False

# Get first projection to determine num_experts and layout
Expand Down Expand Up @@ -535,7 +537,7 @@ def _unfuse_experts_weights_inplace(
if bias_param is not None:
delattr(module, bias_name)
fused_to_remove.append(proj_name)
logger.debug(f"Split {proj_name} -> {split_into}: {num_experts} experts")
if DEBUG_ON: logger.debug(f"Split {proj_name} -> {split_into}: {num_experts} experts")

# Remove fused entries and add split entries
for name in fused_to_remove:
Expand Down Expand Up @@ -616,20 +618,20 @@ def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = L
for name, module in model.named_modules():
if _unfuse_experts_weights_inplace(module):
unfused_modules.append(name)
logger.debug(f"[MoE Prep] Unfused '{name}'")
if DEBUG_ON: logger.debug(f"[MoE Prep] Unfused '{name}'")

# Only set config if we actually unfused something
# Models that don't support the decorator (like Llama4) won't have anything unfused
# and should use full module replacement instead
if unfused_modules:
logger.info(f"[MoE Prep] Unfused {len(unfused_modules)} MOE experts modules")
if DEBUG_ON: logger.info(f"[MoE Prep] Unfused {len(unfused_modules)} MOE experts modules")
clear_memory()

# Set config for linear_loop forward
if hasattr(model, "config"):
saved_impl = getattr(model.config, "experts_implementation", None)
impl_to_set = saved_impl if saved_impl else implementation
model.config._experts_implementation = impl_to_set
logger.debug(f"Set model.config._experts_implementation = '{impl_to_set}'")
if DEBUG_ON: logger.debug(f"Set model.config._experts_implementation = '{impl_to_set}'")

return unfused_modules
27 changes: 15 additions & 12 deletions defuser/modeling/replace_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from defuser.utils.common import is_within_max_layers, is_transformers_version_greater_or_equal_5

from defuser import DEBUG_ON

logger = LogBar(__name__)


Expand All @@ -42,7 +44,7 @@ def _materialize_module(module: torch.nn.Module) -> None:
found_meta = True

if not found_meta:
logger.debug("All parameters and buffers have been materialized from meta device.")
if DEBUG_ON: logger.debug("All parameters and buffers have been materialized from meta device.")
release_original_module_(model)


Expand Down Expand Up @@ -90,7 +92,7 @@ def __init_subclass__(cls, **kwargs):
)

cls._replacement_registry[cls.original_module_class()] = cls
logger.debug(f"Registered {cls.__name__} for replacing {cls.original_module_class()}")
if DEBUG_ON: logger.debug(f"Registered {cls.__name__} for replacing {cls.original_module_class()}")

def __init__(self, original: torch.nn.Module):
super().__init__()
Expand Down Expand Up @@ -261,7 +263,7 @@ def _apply_custom_replacements(
replaced = []

# Step 1: Collect all modules that need replacement
logger.debug("Scanning for modules to replace")
if DEBUG_ON: logger.debug("Scanning for modules to replace")
modules_to_replace = []
for name, module in model.named_modules():
# skip replaced modules
Expand All @@ -283,13 +285,14 @@ def _apply_custom_replacements(
# The module might have been replaced earlier in the loop (parent-first replacement).
# Skip if the class has changed or it no longer matches replacement criteria.
if module.__class__.__name__ != class_name:
logger.debug(
f"Skipping replacement for {name}: class changed from {class_name} to {module.__class__.__name__}"
)
if DEBUG_ON:
logger.debug(
f"Skipping replacement for {name}: class changed from {class_name} to {module.__class__.__name__}"
)
continue
replacement_cls = ReplacementModuleBase.get_replacement_class(class_name)
if not replacement_cls.is_to_be_replaced(module):
logger.debug(f"Skipping replacement for {name}: no longer matches replacement criteria")
if DEBUG_ON: logger.debug(f"Skipping replacement for {name}: no longer matches replacement criteria")
continue
orig_dtype = next(module.parameters()).dtype
replacement = replacement_cls.from_original(
Expand All @@ -299,7 +302,7 @@ def _apply_custom_replacements(
model.set_submodule(name, replacement)
replaced.append((name, replacement_cls))
else:
logger.debug("No modules found for replacement")
if DEBUG_ON: logger.debug("No modules found for replacement")

# Log what was replaced
if replaced:
Expand Down Expand Up @@ -362,7 +365,7 @@ def register_replacement(
replacement_module_class=replacement.__class__.__name__,
replacement_module_ref=weakref.ref(replacement),
)
logger.debug(f"Registered replacement for module: {name}")
if DEBUG_ON: logger.debug(f"Registered replacement for module: {name}")

def get_original(self, replacement: ReplacementModuleBase) -> torch.nn.Module | None:
"""Get the original module for a given replacement module."""
Expand All @@ -388,7 +391,7 @@ def release_original(self, replacement: ReplacementModuleBase) -> None:
replacement_ref = info.replacement_module_ref()
if replacement_ref is None or replacement_ref is replacement:
del self._name_to_info[name]
logger.debug(f"Released original module for replacement {replacement_id}")
if DEBUG_ON: logger.debug(f"Released original module for replacement {replacement_id}")

def release_all_originals(self) -> None:
"""Release all tracked original modules."""
Expand All @@ -402,13 +405,13 @@ def release_all_originals(self) -> None:
self._replacement_to_name.clear()
self._name_to_info.clear()
if count > 0:
logger.debug(f"Released {count} original modules from tracker")
if DEBUG_ON: logger.debug(f"Released {count} original modules from tracker")

def clear(self) -> None:
"""Clear all tracked information."""
self._replacement_to_name.clear()
self._name_to_info.clear()
logger.debug("Cleared module replacement tracker")
if DEBUG_ON: logger.debug("Cleared module replacement tracker")


_global_tracker = ModuleReplacementTracker()
17 changes: 16 additions & 1 deletion defuser/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import os
# Adapted from intel/auto-round
# at https://github.com/intel/auto-round/blob/main/auto_round/utils/common.py

Expand All @@ -13,6 +13,21 @@
# Match module paths like "...layers.0..." and capture the numeric layer index.
_LAYER_NAME_RE = re.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)")

# Normalized (stripped, lower-cased) strings that count as "enabled".
TRUTHFUL = {"1", "true", "yes", "on", "y"}


# The ``default`` annotation is quoted on purpose: an unquoted PEP 604 union
# (``str | bool | None``) is evaluated when the function is defined and raises
# TypeError on Python 3.9, which the package metadata declares as supported.
def env_flag(name: str, default: "str | bool | None" = "0") -> bool:
    """Return ``True`` when an env var is set to a truthy value.

    Args:
        name: Name of the environment variable to read.
        default: Fallback used only when the variable is not set at all.
            ``None`` yields ``False``, a ``bool`` is returned unchanged, and
            any other value is interpreted the same way as an environment
            string.

    Returns:
        ``True`` if the resolved value, after ``str()``, stripping whitespace
        and lower-casing, is a member of ``TRUTHFUL``; ``False`` otherwise.
        A variable that is set but empty is therefore ``False`` regardless of
        ``default``.
    """
    value = os.getenv(name)
    if value is None:
        # Variable is unset: resolve the result from the caller's default.
        if default is None:
            return False
        if isinstance(default, bool):
            return default
        value = default
    return str(value).strip().lower() in TRUTHFUL


@lru_cache(None)
def is_transformers_version_greater_or_equal_5():
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "Defuser"
version = "0.0.6"
version = "0.0.7"
description = "Model defuser helper for HF Transformers."
readme = "README.md"
requires-python = ">=3.9"
Expand Down