diff --git a/defuser/__init__.py b/defuser/__init__.py
index 4de60ab..5417a08 100644
--- a/defuser/__init__.py
+++ b/defuser/__init__.py
@@ -3,6 +3,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
+from defuser.utils.hf import env_flag
+
+DEBUG_ON = env_flag("DEBUG")
+
 def convert_model(*args, **kwargs):
     """Lazily import conversion entrypoint to avoid import-time cycles."""
     from .defuser import convert_model as _convert_model
diff --git a/defuser/defuser.py b/defuser/defuser.py
index 55d1fff..5985755 100644
--- a/defuser/defuser.py
+++ b/defuser/defuser.py
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
-
 from torch import nn
 
 from defuser.model_registry import MODEL_CONFIG
diff --git a/defuser/modeling/moe_experts_interface.py b/defuser/modeling/moe_experts_interface.py
index ee9d746..7145a89 100644
--- a/defuser/modeling/moe_experts_interface.py
+++ b/defuser/modeling/moe_experts_interface.py
@@ -31,6 +31,8 @@
 
 from defuser.utils.device import clear_memory, to_meta
 
+from defuser import DEBUG_ON
+
 logger = LogBar(__name__)
 
 try:
@@ -108,7 +110,7 @@ def linear_loop_experts_forward(
     Returns:
         final_hidden_states: Output tensor of shape (num_tokens, hidden_dim)
     """
-    logger.debug(f"Using {LINEAR_LOOP_IMPL} experts forward for {self.__class__.__name__}")
+    if DEBUG_ON: logger.debug(f"Using {LINEAR_LOOP_IMPL} experts forward for {self.__class__.__name__}")
 
     # Handle [batch_size, seq_len, hidden_dim] input format
     if hidden_states.dim() == 3:
@@ -194,7 +196,7 @@ def register_linear_loop_experts() -> bool:
 
     if LINEAR_LOOP_IMPL not in ALL_EXPERTS_FUNCTIONS._global_mapping:
         ALL_EXPERTS_FUNCTIONS._global_mapping[LINEAR_LOOP_IMPL] = linear_loop_experts_forward
-        logger.debug(f"Registered '{LINEAR_LOOP_IMPL}' experts implementation")
+        if DEBUG_ON: logger.debug(f"Registered '{LINEAR_LOOP_IMPL}' experts implementation")
 
     return True
 
@@ -225,7 +227,7 @@ def _detect_expert_projections(module: nn.Module) -> dict[str, dict]:
         param = getattr(module, attr_name, None)
         if param is not None and isinstance(param, nn.Parameter) and param.dim() == 3:
             # Use default config for unknown projections
-            logger.debug(f"Discovered unknown 3D projection: {attr_name}")
+            if DEBUG_ON: logger.debug(f"Discovered unknown 3D projection: {attr_name}")
             detected[attr_name] = {"is_input_proj": True, "output_multiplier": 1}
 
     return detected
@@ -359,7 +361,7 @@ def _unfuse_single_projection(
             setattr(module, proj_name, to_meta(param))
             if has_bias:
                 setattr(module, bias_name, to_meta(bias_param))
-            logger.debug(f"Released memory for {proj_name} using to_meta()")
+            if DEBUG_ON: logger.debug(f"Released memory for {proj_name} using to_meta()")
         except Exception:
             pass
 
@@ -483,7 +485,7 @@ def _unfuse_experts_weights_inplace(
 
     # Only unfuse if the module supports the decorator (unless check_decorator is False)
     if check_decorator and not _experts_supports_decorator(module):
-        logger.debug(f"Skipping unfuse for {module.__class__.__name__}: does not support @use_experts_implementation")
+        if DEBUG_ON: logger.debug(f"Skipping unfuse for {module.__class__.__name__}: does not support @use_experts_implementation")
         return False
 
     # Get first projection to determine num_experts and layout
@@ -535,7 +537,7 @@ def _unfuse_experts_weights_inplace(
         if bias_param is not None:
             delattr(module, bias_name)
         fused_to_remove.append(proj_name)
-        logger.debug(f"Split {proj_name} -> {split_into}: {num_experts} experts")
+        if DEBUG_ON: logger.debug(f"Split {proj_name} -> {split_into}: {num_experts} experts")
 
     # Remove fused entries and add split entries
     for name in fused_to_remove:
@@ -616,13 +618,13 @@ def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = L
     for name, module in model.named_modules():
         if _unfuse_experts_weights_inplace(module):
             unfused_modules.append(name)
-            logger.debug(f"[MoE Prep] Unfused '{name}'")
+            if DEBUG_ON: logger.debug(f"[MoE Prep] Unfused '{name}'")
 
     # Only set config if we actually unfused something
     # Models that don't support the decorator (like Llama4) won't have anything unfused
     # and should use full module replacement instead
     if unfused_modules:
-        logger.info(f"[MoE Prep] Unfused {len(unfused_modules)} MOE experts modules")
+        if DEBUG_ON: logger.info(f"[MoE Prep] Unfused {len(unfused_modules)} MOE experts modules")
         clear_memory()
 
     # Set config for linear_loop forward
@@ -630,6 +632,6 @@ def prepare_model_for_moe_quantization(model: nn.Module, implementation: str = L
     saved_impl = getattr(model.config, "experts_implementation", None)
     impl_to_set = saved_impl if saved_impl else implementation
     model.config._experts_implementation = impl_to_set
-    logger.debug(f"Set model.config._experts_implementation = '{impl_to_set}'")
+    if DEBUG_ON: logger.debug(f"Set model.config._experts_implementation = '{impl_to_set}'")
 
     return unfused_modules
diff --git a/defuser/modeling/replace_modules.py b/defuser/modeling/replace_modules.py
index 5d92f38..8fb19b2 100644
--- a/defuser/modeling/replace_modules.py
+++ b/defuser/modeling/replace_modules.py
@@ -18,6 +18,8 @@
 
 from defuser.utils.common import is_within_max_layers, is_transformers_version_greater_or_equal_5
 
+from defuser import DEBUG_ON
+
 logger = LogBar(__name__)
 
 
@@ -42,7 +44,7 @@ def _materialize_module(module: torch.nn.Module) -> None:
             found_meta = True
 
     if not found_meta:
-        logger.debug("All parameters and buffers have been materialized from meta device.")
+        if DEBUG_ON: logger.debug("All parameters and buffers have been materialized from meta device.")
 
     release_original_module_(model)
 
@@ -90,7 +92,7 @@ def __init_subclass__(cls, **kwargs):
             )
 
         cls._replacement_registry[cls.original_module_class()] = cls
-        logger.debug(f"Registered {cls.__name__} for replacing {cls.original_module_class()}")
+        if DEBUG_ON: logger.debug(f"Registered {cls.__name__} for replacing {cls.original_module_class()}")
 
     def __init__(self, original: torch.nn.Module):
         super().__init__()
@@ -261,7 +263,7 @@ def _apply_custom_replacements(
     replaced = []
 
     # Step 1: Collect all modules that need replacement
-    logger.debug("Scanning for modules to replace")
+    if DEBUG_ON: logger.debug("Scanning for modules to replace")
     modules_to_replace = []
     for name, module in model.named_modules():
         # skip replaced modules
@@ -283,13 +285,14 @@ def _apply_custom_replacements(
             # The module might have been replaced earlier in the loop (parent-first replacement).
             # Skip if the class has changed or it no longer matches replacement criteria.
             if module.__class__.__name__ != class_name:
-                logger.debug(
-                    f"Skipping replacement for {name}: class changed from {class_name} to {module.__class__.__name__}"
-                )
+                if DEBUG_ON:
+                    logger.debug(
+                        f"Skipping replacement for {name}: class changed from {class_name} to {module.__class__.__name__}"
+                    )
                 continue
             replacement_cls = ReplacementModuleBase.get_replacement_class(class_name)
             if not replacement_cls.is_to_be_replaced(module):
-                logger.debug(f"Skipping replacement for {name}: no longer matches replacement criteria")
+                if DEBUG_ON: logger.debug(f"Skipping replacement for {name}: no longer matches replacement criteria")
                 continue
             orig_dtype = next(module.parameters()).dtype
             replacement = replacement_cls.from_original(
@@ -299,7 +302,7 @@ def _apply_custom_replacements(
             model.set_submodule(name, replacement)
             replaced.append((name, replacement_cls))
     else:
-        logger.debug("No modules found for replacement")
+        if DEBUG_ON: logger.debug("No modules found for replacement")
 
     # Log what was replaced
     if replaced:
@@ -362,7 +365,7 @@ def register_replacement(
             replacement_module_class=replacement.__class__.__name__,
             replacement_module_ref=weakref.ref(replacement),
         )
-        logger.debug(f"Registered replacement for module: {name}")
+        if DEBUG_ON: logger.debug(f"Registered replacement for module: {name}")
 
     def get_original(self, replacement: ReplacementModuleBase) -> torch.nn.Module | None:
         """Get the original module for a given replacement module."""
@@ -388,7 +391,7 @@ def release_original(self, replacement: ReplacementModuleBase) -> None:
             replacement_ref = info.replacement_module_ref()
             if replacement_ref is None or replacement_ref is replacement:
                 del self._name_to_info[name]
-                logger.debug(f"Released original module for replacement {replacement_id}")
+                if DEBUG_ON: logger.debug(f"Released original module for replacement {replacement_id}")
 
     def release_all_originals(self) -> None:
         """Release all tracked original modules."""
@@ -402,13 +405,13 @@ def release_all_originals(self) -> None:
         self._replacement_to_name.clear()
         self._name_to_info.clear()
         if count > 0:
-            logger.debug(f"Released {count} original modules from tracker")
+            if DEBUG_ON: logger.debug(f"Released {count} original modules from tracker")
 
     def clear(self) -> None:
         """Clear all tracked information."""
         self._replacement_to_name.clear()
         self._name_to_info.clear()
-        logger.debug("Cleared module replacement tracker")
+        if DEBUG_ON: logger.debug("Cleared module replacement tracker")
 
 
 _global_tracker = ModuleReplacementTracker()
diff --git a/defuser/utils/common.py b/defuser/utils/common.py
index 870671b..8d49a4c 100644
--- a/defuser/utils/common.py
+++ b/defuser/utils/common.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
-
+import os
 
 # Adapted from intel/auto-round
 # at https://github.com/intel/auto-round/blob/main/auto_round/utils/common.py
@@ -13,6 +13,21 @@
 # Match module paths like "...layers.0..." and capture the numeric layer index.
 _LAYER_NAME_RE = re.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)")
 
+TRUTHFUL = {"1", "true", "yes", "on", "y"}
+
+
+def env_flag(name: str, default: str | bool | None = "0") -> bool:
+    """Return ``True`` when an env var is set to a truthy value."""
+
+    value = os.getenv(name)
+    if value is None:
+        if default is None:
+            return False
+        if isinstance(default, bool):
+            return default
+        value = default
+    return str(value).strip().lower() in TRUTHFUL
+
 
 @lru_cache(None)
 def is_transformers_version_greater_or_equal_5():
diff --git a/pyproject.toml b/pyproject.toml
index 71a1dd3..8c0dcd2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "Defuser"
-version = "0.0.6"
+version = "0.0.7"
 description = "Model defuser helper for HF Transformers."
 readme = "README.md"
 requires-python = ">=3.9"
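Usage note (not part of the patch): DEBUG_ON is computed once, at defuser import time, via env_flag("DEBUG"), so the DEBUG environment variable must be set before the package is imported. The standalone sketch below copies the env_flag body from the defuser/utils/common.py hunk above so its behaviour can be exercised in isolation; the __future__ import and the __main__ section are illustrative additions only, not part of the diff.

# Standalone sketch: env_flag semantics as added in defuser/utils/common.py.
from __future__ import annotations  # keeps the `str | bool | None` annotation importable on Python 3.9 (requires-python >= 3.9)

import os

TRUTHFUL = {"1", "true", "yes", "on", "y"}


def env_flag(name: str, default: str | bool | None = "0") -> bool:
    """Return ``True`` when an env var is set to a truthy value."""

    value = os.getenv(name)
    if value is None:
        if default is None:
            return False
        if isinstance(default, bool):
            return default
        value = default
    return str(value).strip().lower() in TRUTHFUL


if __name__ == "__main__":
    os.environ["DEBUG"] = "YES"
    assert env_flag("DEBUG") is True               # truthy values are matched case-insensitively

    os.environ["DEBUG"] = "0"
    assert env_flag("DEBUG") is False              # explicit falsy value

    del os.environ["DEBUG"]
    assert env_flag("DEBUG") is False              # unset -> string default "0" -> False
    assert env_flag("DEBUG", default=True) is True     # bool default is returned unchanged
    assert env_flag("DEBUG", default=None) is False    # None default -> False

The `if DEBUG_ON:` guards added throughout the diff skip both the f-string formatting and the logger call entirely when DEBUG is off, rather than relying on the logger's own level check.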