From 7894df2429e4ccc58516d5dcbdaead0b4aab8c31 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Mon, 8 Dec 2025 13:57:02 +0000 Subject: [PATCH 1/8] Add sage attention algorithm to pruna framework by using diffusers attention backend --- src/pruna/algorithms/sage_attn.py | 100 ++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 src/pruna/algorithms/sage_attn.py diff --git a/src/pruna/algorithms/sage_attn.py b/src/pruna/algorithms/sage_attn.py new file mode 100644 index 00000000..1e9b6ca9 --- /dev/null +++ b/src/pruna/algorithms/sage_attn.py @@ -0,0 +1,100 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +import torch +from diffusers import DiffusionPipeline + +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase +from pruna.algorithms.base.tags import AlgorithmTag as tags +from pruna.engine.save import SAVE_FUNCTIONS +from pruna.config.smash_config import SmashConfigPrefixWrapper + + +class SageAttn(PrunaAlgorithmBase): + + """ + Replace torch.nn.functional.scaled_dot_product_attention with sage_attn. + + SageAttention is a fast and memory-efficient attention mechanism. It applies the flash attention mechanism + in combination with quantization and smoothing to speed up attention computations. + """ + + algorithm_name: str = "sage_attn" + group_tags: list[str] = [tags.KERNEL] + save_fn = SAVE_FUNCTIONS.reapply + references: dict[str, str] = { + "GitHub": "https://github.com/thu-ml/SageAttention", + "Kernel Hub": "https://huggingface.co/kernels-community/sage_attention", + } + tokenizer_required: bool = False + processor_required: bool = False + runs_on: list[str] = ["cuda", "accelerate"] + dataset_required: bool = False + compatible_before: Iterable[str] = ["torchao"] + compatible_after: Iterable[str] = ["fora", "torch_compile"] + + def model_check_fn(self, model: Any) -> bool: + """ + Check if the model has an attention mechanism that can be replaced with sage_attn. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is a valid model for the algorithm, False otherwise. + """ + if not isinstance(model, DiffusionPipeline) or not hasattr(model, "components"): + return False + + return any( + hasattr(component, "set_attention_backend") and component.dtype in [torch.bfloat16, torch.float16] + for component in model.components.values() + ) + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + """ + Wrap the model to use SageAttention where possible. + + Parameters + ---------- + model : Any + The model to wrap. + smash_config : SmashConfigPrefixWrapper + The configuration for the application of the algorithm. + + Returns + ------- + Any + The wrapped model. 
+ """ + + # We simply apply the sage attention backend from diffusers + # Furthermore, we use the sage attention kernel from the hub as the default sageattn function + # is broken (at least at the moment) + for component in model.components.values(): + if hasattr(component, "set_attention_backend") and component.dtype in [ + torch.bfloat16, + torch.float16, + ]: + component.set_attention_backend("sage_hub") + return model \ No newline at end of file From 0c597827f17b94c6e0d6a3dfdcf80b4bab13b0f3 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Mon, 8 Dec 2025 14:49:20 +0000 Subject: [PATCH 2/8] Add compatibility for sage-attn with torch-compile --- src/pruna/algorithms/sage_attn.py | 4 ++-- src/pruna/algorithms/torch_compile/torch_compile.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pruna/algorithms/sage_attn.py b/src/pruna/algorithms/sage_attn.py index 1e9b6ca9..1897211e 100644 --- a/src/pruna/algorithms/sage_attn.py +++ b/src/pruna/algorithms/sage_attn.py @@ -46,8 +46,8 @@ class SageAttn(PrunaAlgorithmBase): processor_required: bool = False runs_on: list[str] = ["cuda", "accelerate"] dataset_required: bool = False - compatible_before: Iterable[str] = ["torchao"] - compatible_after: Iterable[str] = ["fora", "torch_compile"] + compatible_before: Iterable[str] = [] + compatible_after: Iterable[str] = ["torch_compile"] def model_check_fn(self, model: Any) -> bool: """ diff --git a/src/pruna/algorithms/torch_compile/torch_compile.py b/src/pruna/algorithms/torch_compile/torch_compile.py index 62c4d086..fe09bbd2 100644 --- a/src/pruna/algorithms/torch_compile/torch_compile.py +++ b/src/pruna/algorithms/torch_compile/torch_compile.py @@ -69,6 +69,7 @@ class TorchCompile(PrunaAlgorithmBase): "flash_attn3", "deepcache", "fora", + "sage_attn", ] def get_hyperparameters(self) -> list: From c8eda60b1da2423ae82c731bd5751d16a7dcf45e Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Mon, 8 Dec 2025 15:42:25 +0000 Subject: [PATCH 3/8] =?UTF-8?q?Add=20tests=20f=C3=BCr=20sage=20attn=20algo?= =?UTF-8?q?rithm?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/algorithms/testers/sage_attn.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 tests/algorithms/testers/sage_attn.py diff --git a/tests/algorithms/testers/sage_attn.py b/tests/algorithms/testers/sage_attn.py new file mode 100644 index 00000000..8fdabd30 --- /dev/null +++ b/tests/algorithms/testers/sage_attn.py @@ -0,0 +1,16 @@ +import pytest + +from pruna.algorithms.sage_attn import SageAttn + +from .base_tester import AlgorithmTesterBase + + +@pytest.mark.high +class TestSageAttn(AlgorithmTesterBase): + """Test the sage attention kernel.""" + + models = ["flux_tiny", "wan_tiny_random"] + reject_models = ["opt_tiny_random"] + allow_pickle_files = False + algorithm_class = SageAttn + metrics = ["latency"] \ No newline at end of file From 69c9679722bfec78a3ccac65235d515cced88e88 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Mon, 8 Dec 2025 15:55:34 +0000 Subject: [PATCH 4/8] Change formatting using ruff --- src/pruna/algorithms/sage_attn.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/pruna/algorithms/sage_attn.py b/src/pruna/algorithms/sage_attn.py index 1897211e..d7165b1a 100644 --- a/src/pruna/algorithms/sage_attn.py +++ b/src/pruna/algorithms/sage_attn.py @@ -22,12 +22,11 @@ from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags -from 
pruna.engine.save import SAVE_FUNCTIONS from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.save import SAVE_FUNCTIONS class SageAttn(PrunaAlgorithmBase): - """ Replace torch.nn.functional.scaled_dot_product_attention with sage_attn. @@ -87,7 +86,6 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: Any The wrapped model. """ - # We simply apply the sage attention backend from diffusers # Furthermore, we use the sage attention kernel from the hub as the default sageattn function # is broken (at least at the moment) @@ -97,4 +95,4 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: torch.float16, ]: component.set_attention_backend("sage_hub") - return model \ No newline at end of file + return model From 03ed96b068edba55301d315e15c35b71d9740195 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 17 Dec 2025 12:42:38 +0000 Subject: [PATCH 5/8] Quick commit, add sage_attn2++ as reference paper, add cachers and quantizers as compatible after and before, add sage_attn in corresponding cachers and quantizers algorithms as compatible, add dtype check as sage_attn only works for float/bfloat16 (double checked), add target modules (but not fully finished yet) --- src/pruna/algorithms/deepcache.py | 2 +- src/pruna/algorithms/fastercache.py | 2 +- src/pruna/algorithms/fora.py | 2 +- src/pruna/algorithms/gptq_model.py | 2 +- src/pruna/algorithms/half.py | 1 + src/pruna/algorithms/hqq.py | 2 +- src/pruna/algorithms/hqq_diffusers.py | 2 +- .../algorithms/huggingface_diffusers_int8.py | 2 +- src/pruna/algorithms/huggingface_llm_int8.py | 2 +- src/pruna/algorithms/pab.py | 2 +- src/pruna/algorithms/quanto.py | 2 +- src/pruna/algorithms/sage_attn.py | 213 +++++++++++++++++- src/pruna/algorithms/torch_dynamic.py | 2 +- src/pruna/algorithms/torchao.py | 2 +- 14 files changed, 214 insertions(+), 24 deletions(-) diff --git a/src/pruna/algorithms/deepcache.py b/src/pruna/algorithms/deepcache.py index cf2a9ad8..6b817674 100644 --- a/src/pruna/algorithms/deepcache.py +++ b/src/pruna/algorithms/deepcache.py @@ -43,7 +43,7 @@ class DeepCache(PrunaAlgorithmBase): processor_required: bool = False dataset_required: bool = False runs_on: list[str] = ["cpu", "cuda", "accelerate"] - compatible_before: Iterable[str] = ["qkv_diffusers", "half", "hqq_diffusers", "diffusers_int8", "quanto"] + compatible_before: Iterable[str] = ["qkv_diffusers", "half", "hqq_diffusers", "diffusers_int8", "quanto", "sage_attn"] compatible_after: Iterable[str] = ["stable_fast", "torch_compile"] def get_hyperparameters(self) -> list: diff --git a/src/pruna/algorithms/fastercache.py b/src/pruna/algorithms/fastercache.py index 856401e7..41dece15 100644 --- a/src/pruna/algorithms/fastercache.py +++ b/src/pruna/algorithms/fastercache.py @@ -57,7 +57,7 @@ class FasterCache(PrunaAlgorithmBase): processor_required: bool = False dataset_required: bool = False runs_on: list[str] = ["cpu", "cuda", "accelerate"] - compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8"] + compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8", "sage_attn"] def get_hyperparameters(self) -> list: """ diff --git a/src/pruna/algorithms/fora.py b/src/pruna/algorithms/fora.py index 00697ea4..2194c87c 100644 --- a/src/pruna/algorithms/fora.py +++ b/src/pruna/algorithms/fora.py @@ -44,7 +44,7 @@ class FORA(PrunaAlgorithmBase): processor_required: bool = False runs_on: list[str] = ["cpu", "cuda", "accelerate"] dataset_required: bool = False - compatible_before: 
Iterable[str] = ["qkv_diffusers", "diffusers_int8", "hqq_diffusers", "torchao", "flash_attn3"] + compatible_before: Iterable[str] = ["qkv_diffusers", "diffusers_int8", "hqq_diffusers", "torchao", "flash_attn3", "sage_attn"] compatible_after: Iterable[str] = ["stable_fast", "torch_compile"] def get_hyperparameters(self) -> list: diff --git a/src/pruna/algorithms/gptq_model.py b/src/pruna/algorithms/gptq_model.py index edfbcda7..1b2c78e7 100644 --- a/src/pruna/algorithms/gptq_model.py +++ b/src/pruna/algorithms/gptq_model.py @@ -45,7 +45,7 @@ class GPTQ(PrunaAlgorithmBase): processor_required: bool = False runs_on: list[str] = ["cuda"] dataset_required: bool = True - compatible_after: Iterable[str] = ["torch_compile"] + compatible_after: Iterable[str] = ["torch_compile", "sage_attn"] required_install: str = ( "You must first install the base package with ``pip install pruna`` " "before installing the GPTQ extension with ``pip install pruna[gptq] --extra-index-url https://prunaai.pythonanywhere.com/``" diff --git a/src/pruna/algorithms/half.py b/src/pruna/algorithms/half.py index b416de27..3b0afeae 100644 --- a/src/pruna/algorithms/half.py +++ b/src/pruna/algorithms/half.py @@ -51,6 +51,7 @@ class Half(PrunaAlgorithmBase): "torch_compile", "ifw", "whisper_s2t", + "sage_attn", ] def model_check_fn(self, model: Any) -> bool: diff --git a/src/pruna/algorithms/hqq.py b/src/pruna/algorithms/hqq.py index 4ccc6804..88d3f680 100644 --- a/src/pruna/algorithms/hqq.py +++ b/src/pruna/algorithms/hqq.py @@ -52,7 +52,7 @@ class HQQ(PrunaAlgorithmBase): runs_on: list[str] = ["cuda"] dataset_required: bool = False compatible_before: Iterable[str] = ["torch_structured"] - compatible_after: Iterable[str] = ["torch_compile"] + compatible_after: Iterable[str] = ["torch_compile", "sage_attn"] def get_hyperparameters(self) -> list: """ diff --git a/src/pruna/algorithms/hqq_diffusers.py b/src/pruna/algorithms/hqq_diffusers.py index 99f1c051..544c10ef 100644 --- a/src/pruna/algorithms/hqq_diffusers.py +++ b/src/pruna/algorithms/hqq_diffusers.py @@ -54,7 +54,7 @@ class HQQDiffusers(PrunaAlgorithmBase): runs_on: list[str] = ["cuda"] dataset_required: bool = False compatible_before: Iterable[str] = ["qkv_diffusers"] - compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile"] + compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile", "sage_attn"] def get_hyperparameters(self) -> list: """ diff --git a/src/pruna/algorithms/huggingface_diffusers_int8.py b/src/pruna/algorithms/huggingface_diffusers_int8.py index 02b9e8d6..8a431e93 100644 --- a/src/pruna/algorithms/huggingface_diffusers_int8.py +++ b/src/pruna/algorithms/huggingface_diffusers_int8.py @@ -60,7 +60,7 @@ class DiffusersInt8(PrunaAlgorithmBase): runs_on: list[str] = ["cuda", "accelerate"] save_fn: None = None compatible_before: Iterable[str] = ["qkv_diffusers"] - compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile"] + compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile", "sage_attn"] def get_hyperparameters(self) -> list: """ diff --git a/src/pruna/algorithms/huggingface_llm_int8.py b/src/pruna/algorithms/huggingface_llm_int8.py index eb381204..9704f909 100644 --- a/src/pruna/algorithms/huggingface_llm_int8.py +++ b/src/pruna/algorithms/huggingface_llm_int8.py @@ -56,7 +56,7 @@ class LLMInt8(PrunaAlgorithmBase): dataset_required: bool = False runs_on: list[str] = ["cuda", "accelerate"] save_fn: None = None 
- compatible_after: Iterable[str] = ["torch_compile"] + compatible_after: Iterable[str] = ["torch_compile", "sage_attn"] def get_hyperparameters(self) -> list: """ diff --git a/src/pruna/algorithms/pab.py b/src/pruna/algorithms/pab.py index 2472217d..a71d0c12 100644 --- a/src/pruna/algorithms/pab.py +++ b/src/pruna/algorithms/pab.py @@ -53,7 +53,7 @@ class PAB(PrunaAlgorithmBase): processor_required: bool = False dataset_required: bool = False runs_on: list[str] = ["cpu", "cuda", "accelerate"] - compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8"] + compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8", "sage_attn"] compatible_after: Iterable[str] = [] def get_hyperparameters(self) -> list: diff --git a/src/pruna/algorithms/quanto.py b/src/pruna/algorithms/quanto.py index c60f6ad8..7e95d93b 100644 --- a/src/pruna/algorithms/quanto.py +++ b/src/pruna/algorithms/quanto.py @@ -51,7 +51,7 @@ class Quanto(PrunaAlgorithmBase): dataset_required: bool = False runs_on: list[str] = ["cuda"] compatible_before: Iterable[str] = ["qkv_diffusers"] - compatible_after: Iterable[str] = ["deepcache"] + compatible_after: Iterable[str] = ["deepcache", "sage_attn"] def get_hyperparameters(self) -> list: """ diff --git a/src/pruna/algorithms/sage_attn.py b/src/pruna/algorithms/sage_attn.py index d7165b1a..3b5dccc7 100644 --- a/src/pruna/algorithms/sage_attn.py +++ b/src/pruna/algorithms/sage_attn.py @@ -16,15 +16,21 @@ from collections.abc import Iterable from typing import Any +import re +import fnmatch +from collections import OrderedDict import torch from diffusers import DiffusionPipeline +from pruna import SmashConfig from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags from pruna.config.smash_config import SmashConfigPrefixWrapper from pruna.engine.save import SAVE_FUNCTIONS - +from pruna.config.hyperparameters import Boolean +from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules +from pruna.logging.logger import pruna_logger class SageAttn(PrunaAlgorithmBase): """ @@ -38,6 +44,7 @@ class SageAttn(PrunaAlgorithmBase): group_tags: list[str] = [tags.KERNEL] save_fn = SAVE_FUNCTIONS.reapply references: dict[str, str] = { + "Paper (SA2++)": "https://arxiv.org/pdf/2505.21136v3", "GitHub": "https://github.com/thu-ml/SageAttention", "Kernel Hub": "https://huggingface.co/kernels-community/sage_attention", } @@ -45,8 +52,9 @@ class SageAttn(PrunaAlgorithmBase): processor_required: bool = False runs_on: list[str] = ["cuda", "accelerate"] dataset_required: bool = False - compatible_before: Iterable[str] = [] - compatible_after: Iterable[str] = ["torch_compile"] + compatible_before: Iterable[str] = [tags.QUANTIZER] + compatible_after: Iterable[str] = ["torch_compile", tags.CACHER] + def model_check_fn(self, model: Any) -> bool: """ @@ -66,7 +74,7 @@ def model_check_fn(self, model: Any) -> bool: return False return any( - hasattr(component, "set_attention_backend") and component.dtype in [torch.bfloat16, torch.float16] + hasattr(component, "set_attention_backend") and component.dtype in (torch.bfloat16, torch.float16) for component in model.components.values() ) @@ -86,13 +94,194 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: Any The wrapped model. 
""" - # We simply apply the sage attention backend from diffusers - # Furthermore, we use the sage attention kernel from the hub as the default sageattn function - # is broken (at least at the moment) - for component in model.components.values(): - if hasattr(component, "set_attention_backend") and component.dtype in [ - torch.bfloat16, - torch.float16, - ]: + target_modules = smash_config.get("target_modules", None) + exclude_first_and_last_transformer_blocks = smash_config.get("exclude_first_and_last_transformer_blocks", False) + + if exclude_first_and_last_transformer_blocks: + extra_excludes = self._get_transformer_sub_excludes(model) + + if target_modules is None: + target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config) + + include_patterns = target_modules.get("include", []) + exclude_patterns = target_modules.get("exclude", []) + exclude_patterns.extend(extra_excludes) + + # Heuristic: if any pattern contains a dot, there are nested rules + def has_nested_rules(comp_name: str) -> bool: + prefix = comp_name + "." + return any(p.startswith(prefix) for p in (include_patterns + exclude_patterns)) + + def is_relevant_component_by_include(comp_name: str) -> bool: + """If includes are set, only Components to touch, that are either directly + included or have any include below them.""" + if not include_patterns: + return True + prefix = comp_name + "." + return any(p == comp_name or p.startswith(prefix) for p in include_patterns) + + def should_apply(name: str) -> bool: + """Excludes > Includes; if includes are empty => everything (except excludes).""" + if exclude_patterns and _matches_any(name, exclude_patterns): + return False + if include_patterns and not _matches_any(name, include_patterns): + return False + return True + + for comp_name, component in model.components.items(): + + # --- Check component level --- + + # 1) Component-level filter (e.g. 
exclude "vae"), exclude if in exclude_patterns and not in include_patterns + if exclude_patterns and _matches_any(comp_name, exclude_patterns): + continue + + # 2) Pick only relevant components + if not is_relevant_component_by_include(comp_name): + continue + + # 2) dtype guard, as sage attn is only applicable for bfloat16 and float16 + if not hasattr(component, "dtype") or component.dtype not in (torch.bfloat16, torch.float16): + continue + + # 3) If there are no nested rules for the current component, make a faster global call otherwise go to submodule level + if hasattr(component, "set_attention_backend") and not has_nested_rules(comp_name) and should_apply(comp_name): component.set_attention_backend("sage_hub") + continue + + # --- Check submodule level --- + + # 1) Check for named_modules method for step 2) to work + if component is None or not hasattr(component, "named_modules"): + continue + + # 2) Nested rules: iterate over submodules and match full_name + for sub_name, sub_module in component.named_modules(): + if not sub_name: + continue + + full_name = f"{comp_name}.{sub_name}" # e.g., transformer.blocks.0.attn1 + + if not should_apply(full_name): + continue + + if hasattr(sub_module, "set_attention_backend") and sub_module.dtype in (torch.bfloat16, torch.float16): + sub_module.set_attention_backend("sage_hub") + return model + + + def get_hyperparameters(self) -> list: + return [ + Boolean( + "exclude_first_and_last_transformer_blocks", + default=False, + meta=dict(desc="If True, do NOT apply SageAttention to the first and last transformer blocks for each transformer component."), + ), + TargetModules(name="target_modules", default_value=None), + ] + + # def get_model_dependent_hyperparameter_defaults( + # self, + # model: Any, + # smash_config: SmashConfigPrefixWrapper, + # ) -> TARGET_MODULES_TYPE: + # # So far we just exclude, per default everything is included + # # Filtering is done in the _apply method by the set_attention_backend method + # include = ["*"] + # exclude = [] + + # if smash_config["exclude_first_and_last_transformer_blocks"]: + # exclude = self._get_transformer_sub_excludes(model) + + # if not exclude: + # print( + # "exclude_first_and_last_transformer_blocks enabled, " + # "but no transformer blocks were found for exclusion." + # ) + + # return {"include": include, "exclude": exclude} + + def get_model_dependent_hyperparameter_defaults( + self, + model: Any, + smash_config: SmashConfigPrefixWrapper, + ) -> TARGET_MODULES_TYPE: + # So far, everything is included and nothing + # Filtering is done in the _apply method by the set_attention_backend method + include = ["*"] + exclude = [] + + return {"include": include, "exclude": exclude} + + def _get_transformer_sub_excludes( + self, + model: Any, + ) -> list[str]: + """ + Returns a flat list of glob patterns, e.g. + [ + "transformer.blocks.0*", + "transformer.blocks.39*", + "transformer_2.blocks.0*", + "transformer_2.blocks.39*", + ] + """ + excludes: list[str] = [] + + roots = self._get_transformer_roots(model) + + for root in roots: + # get the component + comp = model.components.get(root, None) + # if the component is None/missing (e.g. 
the case for transformer_2 in Wan2.2-TI2V-5B-Diffusers), skip it + if comp is None: + print(f"skip {root}: component is None/missing") + continue + # get the attention names + attn_names = [ + name + for name, module in model.components[root].named_modules() + if name and hasattr(module, "set_attention_backend") + ] + # if there are no attention names, skip it + if not attn_names: + continue + + # get the block paths + block_paths = _unique_in_order([n.rsplit(".", 1)[0] for n in attn_names]) + + # if there are less than 3 block paths, skip it + if len(block_paths) < 3: + pruna_logger.warning(f"Root {root} has less than 3 transformer blocks. Thus its first and last blocks are not excluded for sage_attn.") + continue + + # We just want to exclude the first and last blocks of the transformer components + excludes.extend([ + f"{root}.{block_paths[0]}*", + f"{root}.{block_paths[-1]}*", + ]) + + return excludes + + def _get_transformer_roots(self, model: Any) -> list[str]: + roots = [] + for name, _ in model.components.items(): + if name == "transformer" or name.startswith("transformer_"): + roots.append(name) + + # Sort the roots by the number of the transformer component, just to be sure + def key(n: str) -> int: + # transformer -> 0, transformer_10 -> 10 + if n == "transformer": + return 0 + m = re.match(r"transformer_(\d+)$", n) + return int(m.group(1)) if m else 10**9 # unknown suffix goes to end + + return sorted(roots, key=key) + +def _unique_in_order(items: list[str]) -> list[str]: + return list(OrderedDict.fromkeys(items)) + +def _matches_any(name: str, patterns: list[str]) -> bool: + return any(fnmatch.fnmatch(name, pat) for pat in (patterns or [])) \ No newline at end of file diff --git a/src/pruna/algorithms/torch_dynamic.py b/src/pruna/algorithms/torch_dynamic.py index fc97ca90..e3b1ac3f 100644 --- a/src/pruna/algorithms/torch_dynamic.py +++ b/src/pruna/algorithms/torch_dynamic.py @@ -43,7 +43,7 @@ class TorchDynamic(PrunaAlgorithmBase): runs_on: list[str] = ["cpu", "cuda"] dataset_required: bool = False compatible_before: Iterable[str] = [] - compatible_after: Iterable[str] = [] + compatible_after: Iterable[str] = ["sage_attn"] def get_hyperparameters(self) -> list: """ diff --git a/src/pruna/algorithms/torchao.py b/src/pruna/algorithms/torchao.py index d0a10ee1..63550972 100644 --- a/src/pruna/algorithms/torchao.py +++ b/src/pruna/algorithms/torchao.py @@ -90,7 +90,7 @@ class Torchao(PrunaAlgorithmBase): runs_on: list[str] = ["cpu", "cuda", "accelerate"] dataset_required: bool = False compatible_before: Iterable[str] = ["qkv_diffusers", "torch_structured"] - compatible_after: Iterable[str] = ["flash_attn3", "fora", "torch_compile"] + compatible_after: Iterable[str] = ["flash_attn3", "fora", "torch_compile", "sage_attn"] def get_hyperparameters(self) -> list: """ From e7414aa89ae46621e243dd0f7e21484255e29a3d Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 17 Dec 2025 16:55:24 +0000 Subject: [PATCH 6/8] Add target modules including hyperparameter for excluding first and last attention block per attention component. Remove dtype gaurd as dtypes of q, k, and v per attn module is implicitly checked by sage attention kernel. 
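
For illustration, the new target_modules hyperparameter takes fnmatch-style glob
patterns matched against component and submodule names; the prefixed key names in the
sketch below are assumptions and may differ in the released SmashConfig API:

    # hypothetical configuration sketch for the hyperparameters added in this patch
    smash_config["sage_attn_target_modules"] = {
        "include": ["transformer"],            # whole component -> one global set_attention_backend call
        "exclude": ["transformer.blocks.0*"],  # nested rule -> resolved per submodule via named_modules()
    }
    smash_config["sage_attn_exclude_first_and_last_transformer_blocks"] = True

Component-level patterns trigger a single set_attention_backend("sage_hub") call on the
component, while nested patterns are matched against full paths such as
transformer.blocks.0.attn1.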
--- src/pruna/algorithms/deepcache.py | 9 +- src/pruna/algorithms/fora.py | 9 +- src/pruna/algorithms/sage_attn.py | 205 ++++++++++++++---------------- 3 files changed, 111 insertions(+), 112 deletions(-) diff --git a/src/pruna/algorithms/deepcache.py b/src/pruna/algorithms/deepcache.py index 6b817674..6f404224 100644 --- a/src/pruna/algorithms/deepcache.py +++ b/src/pruna/algorithms/deepcache.py @@ -43,7 +43,14 @@ class DeepCache(PrunaAlgorithmBase): processor_required: bool = False dataset_required: bool = False runs_on: list[str] = ["cpu", "cuda", "accelerate"] - compatible_before: Iterable[str] = ["qkv_diffusers", "half", "hqq_diffusers", "diffusers_int8", "quanto", "sage_attn"] + compatible_before: Iterable[str] = [ + "qkv_diffusers", + "half", + "hqq_diffusers", + "diffusers_int8", + "quanto", + "sage_attn", + ] compatible_after: Iterable[str] = ["stable_fast", "torch_compile"] def get_hyperparameters(self) -> list: diff --git a/src/pruna/algorithms/fora.py b/src/pruna/algorithms/fora.py index 2194c87c..fb9f08e6 100644 --- a/src/pruna/algorithms/fora.py +++ b/src/pruna/algorithms/fora.py @@ -44,7 +44,14 @@ class FORA(PrunaAlgorithmBase): processor_required: bool = False runs_on: list[str] = ["cpu", "cuda", "accelerate"] dataset_required: bool = False - compatible_before: Iterable[str] = ["qkv_diffusers", "diffusers_int8", "hqq_diffusers", "torchao", "flash_attn3", "sage_attn"] + compatible_before: Iterable[str] = [ + "qkv_diffusers", + "diffusers_int8", + "hqq_diffusers", + "torchao", + "flash_attn3", + "sage_attn" + ] compatible_after: Iterable[str] = ["stable_fast", "torch_compile"] def get_hyperparameters(self) -> list: diff --git a/src/pruna/algorithms/sage_attn.py b/src/pruna/algorithms/sage_attn.py index 3b5dccc7..327db67c 100644 --- a/src/pruna/algorithms/sage_attn.py +++ b/src/pruna/algorithms/sage_attn.py @@ -14,24 +14,24 @@ from __future__ import annotations -from collections.abc import Iterable -from typing import Any -import re import fnmatch +import re from collections import OrderedDict +from collections.abc import Iterable +from typing import Any import torch from diffusers import DiffusionPipeline -from pruna import SmashConfig from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags -from pruna.config.smash_config import SmashConfigPrefixWrapper -from pruna.engine.save import SAVE_FUNCTIONS from pruna.config.hyperparameters import Boolean +from pruna.config.smash_config import SmashConfigPrefixWrapper from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules +from pruna.engine.save import SAVE_FUNCTIONS from pruna.logging.logger import pruna_logger + class SageAttn(PrunaAlgorithmBase): """ Replace torch.nn.functional.scaled_dot_product_attention with sage_attn. @@ -55,7 +55,6 @@ class SageAttn(PrunaAlgorithmBase): compatible_before: Iterable[str] = [tags.QUANTIZER] compatible_after: Iterable[str] = ["torch_compile", tags.CACHER] - def model_check_fn(self, model: Any) -> bool: """ Check if the model has an attention mechanism that can be replaced with sage_attn. @@ -94,14 +93,16 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: Any The wrapped model. 
""" - target_modules = smash_config.get("target_modules", None) - exclude_first_and_last_transformer_blocks = smash_config.get("exclude_first_and_last_transformer_blocks", False) - - if exclude_first_and_last_transformer_blocks: - extra_excludes = self._get_transformer_sub_excludes(model) - + target_modules = smash_config["target_modules"] + exclude_first_and_last_transformer_blocks = smash_config["exclude_first_and_last_transformer_blocks"] + if target_modules is None: - target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config) + target_modules = self.get_model_dependent_hyperparameter_defaults( + model, + smash_config + ) # for consistency, not used yet + + extra_excludes = (_get_transformer_sub_excludes(model) if exclude_first_and_last_transformer_blocks else []) include_patterns = target_modules.get("include", []) exclude_patterns = target_modules.get("exclude", []) @@ -111,22 +112,19 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: def has_nested_rules(comp_name: str) -> bool: prefix = comp_name + "." return any(p.startswith(prefix) for p in (include_patterns + exclude_patterns)) - + def is_relevant_component_by_include(comp_name: str) -> bool: - """If includes are set, only Components to touch, that are either directly - included or have any include below them.""" if not include_patterns: return True + if _matches_any(comp_name, include_patterns): # "*", "transformer", ... + return True prefix = comp_name + "." - return any(p == comp_name or p.startswith(prefix) for p in include_patterns) + return any(p.startswith(prefix) for p in include_patterns) # "transformer.*", "transformer.blocks.*" def should_apply(name: str) -> bool: - """Excludes > Includes; if includes are empty => everything (except excludes).""" if exclude_patterns and _matches_any(name, exclude_patterns): return False - if include_patterns and not _matches_any(name, include_patterns): - return False - return True + return not include_patterns or _matches_any(name, include_patterns) for comp_name, component in model.components.items(): @@ -140,12 +138,11 @@ def should_apply(name: str) -> bool: if not is_relevant_component_by_include(comp_name): continue - # 2) dtype guard, as sage attn is only applicable for bfloat16 and float16 - if not hasattr(component, "dtype") or component.dtype not in (torch.bfloat16, torch.float16): - continue - - # 3) If there are no nested rules for the current component, make a faster global call otherwise go to submodule level - if hasattr(component, "set_attention_backend") and not has_nested_rules(comp_name) and should_apply(comp_name): + # 3) If there are no nested rules for the current component, + # make a faster global call otherwise go to submodule level + if (hasattr(component, "set_attention_backend") + and not has_nested_rules(comp_name) + and should_apply(comp_name)): component.set_attention_backend("sage_hub") continue @@ -165,123 +162,111 @@ def should_apply(name: str) -> bool: if not should_apply(full_name): continue - if hasattr(sub_module, "set_attention_backend") and sub_module.dtype in (torch.bfloat16, torch.float16): + if hasattr(sub_module, "set_attention_backend"): sub_module.set_attention_backend("sage_hub") return model - def get_hyperparameters(self) -> list: + """Return hyperparameters for this algorithm.""" return [ Boolean( "exclude_first_and_last_transformer_blocks", default=False, - meta=dict(desc="If True, do NOT apply SageAttention to the first and last transformer blocks for each transformer 
component."), + meta=dict(desc="If True, do NOT apply SageAttention to the first and last" + "transformer blocks for each transformer component."), ), TargetModules(name="target_modules", default_value=None), ] - # def get_model_dependent_hyperparameter_defaults( - # self, - # model: Any, - # smash_config: SmashConfigPrefixWrapper, - # ) -> TARGET_MODULES_TYPE: - # # So far we just exclude, per default everything is included - # # Filtering is done in the _apply method by the set_attention_backend method - # include = ["*"] - # exclude = [] - - # if smash_config["exclude_first_and_last_transformer_blocks"]: - # exclude = self._get_transformer_sub_excludes(model) - - # if not exclude: - # print( - # "exclude_first_and_last_transformer_blocks enabled, " - # "but no transformer blocks were found for exclusion." - # ) - - # return {"include": include, "exclude": exclude} - def get_model_dependent_hyperparameter_defaults( self, model: Any, smash_config: SmashConfigPrefixWrapper, - ) -> TARGET_MODULES_TYPE: - # So far, everything is included and nothing + ) -> TARGET_MODULES_TYPE: + """Return model-dependent default target_modules.""" + # So far, everything is included and nothing is excluded # Filtering is done in the _apply method by the set_attention_backend method include = ["*"] exclude = [] return {"include": include, "exclude": exclude} - def _get_transformer_sub_excludes( - self, - model: Any, - ) -> list[str]: - """ - Returns a flat list of glob patterns, e.g. - [ - "transformer.blocks.0*", - "transformer.blocks.39*", - "transformer_2.blocks.0*", - "transformer_2.blocks.39*", + +def _get_transformer_sub_excludes( + model: Any, +) -> list[str]: + """ + Returns a flat list of glob patterns. + + Example: + [ + "transformer.blocks.0*", + "transformer.blocks.39*", + "transformer_2.blocks.0*", + "transformer_2.blocks.39*", + ] + """ + excludes: list[str] = [] + + roots = _get_transformer_roots(model) + + for root in roots: + # get the component + comp = model.components.get(root, None) + # if the component is None/missing (e.g. the case for transformer_2 in Wan2.2-TI2V-5B-Diffusers), skip it + if comp is None: + pruna_logger.warning("skip %s for excludes: component is None", root) + continue + # get the attention names + attn_names = [ + name + for name, module in model.components[root].named_modules() + if name and hasattr(module, "set_attention_backend") ] - """ - excludes: list[str] = [] + # if there are no attention names, skip it + if not attn_names: + continue - roots = self._get_transformer_roots(model) + # get the block paths + block_paths = _unique_in_order([n.rsplit(".", 1)[0] for n in attn_names]) - for root in roots: - # get the component - comp = model.components.get(root, None) - # if the component is None/missing (e.g. the case for transformer_2 in Wan2.2-TI2V-5B-Diffusers), skip it - if comp is None: - print(f"skip {root}: component is None/missing") - continue - # get the attention names - attn_names = [ - name - for name, module in model.components[root].named_modules() - if name and hasattr(module, "set_attention_backend") - ] - # if there are no attention names, skip it - if not attn_names: - continue + # if there are less than 3 block paths, skip it + if len(block_paths) < 3: + pruna_logger.warning(f"Root {root} has less than 3 transformer blocks." 
+ "Thus its first and last blocks are not excluded for sage_attn.") + continue - # get the block paths - block_paths = _unique_in_order([n.rsplit(".", 1)[0] for n in attn_names]) + # We just want to exclude the first and last blocks of the transformer components + excludes.extend([ + f"{root}.{block_paths[0]}*", + f"{root}.{block_paths[-1]}*", + ]) - # if there are less than 3 block paths, skip it - if len(block_paths) < 3: - pruna_logger.warning(f"Root {root} has less than 3 transformer blocks. Thus its first and last blocks are not excluded for sage_attn.") - continue + return excludes - # We just want to exclude the first and last blocks of the transformer components - excludes.extend([ - f"{root}.{block_paths[0]}*", - f"{root}.{block_paths[-1]}*", - ]) - return excludes +def _get_transformer_roots(model: Any) -> list[str]: + """Get the roots of the transformer components.""" + roots = [] + for name, _ in model.components.items(): + if name == "transformer" or name.startswith("transformer_"): + roots.append(name) - def _get_transformer_roots(self, model: Any) -> list[str]: - roots = [] - for name, _ in model.components.items(): - if name == "transformer" or name.startswith("transformer_"): - roots.append(name) + # Sort the roots by the number of the transformer component, just to be sure + def key(n: str) -> int: + # transformer -> 0, transformer_10 -> 10 + if n == "transformer": + return 0 + m = re.match(r"transformer_(\d+)$", n) + return int(m.group(1)) if m else 10**9 # unknown suffix goes to end - # Sort the roots by the number of the transformer component, just to be sure - def key(n: str) -> int: - # transformer -> 0, transformer_10 -> 10 - if n == "transformer": - return 0 - m = re.match(r"transformer_(\d+)$", n) - return int(m.group(1)) if m else 10**9 # unknown suffix goes to end + return sorted(roots, key=key) - return sorted(roots, key=key) def _unique_in_order(items: list[str]) -> list[str]: return list(OrderedDict.fromkeys(items)) + def _matches_any(name: str, patterns: list[str]) -> bool: - return any(fnmatch.fnmatch(name, pat) for pat in (patterns or [])) \ No newline at end of file + return any(fnmatch.fnmatch(name, pat) for pat in (patterns or [])) From 16aa7f49bc9d6d3938bdf92c09d27d87f2f62577 Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 17 Dec 2025 17:04:50 +0000 Subject: [PATCH 7/8] Add doc strings to functions and methods --- src/pruna/algorithms/sage_attn.py | 103 ++++++++++++++++++++++++++---- 1 file changed, 91 insertions(+), 12 deletions(-) diff --git a/src/pruna/algorithms/sage_attn.py b/src/pruna/algorithms/sage_attn.py index 327db67c..e83abc2c 100644 --- a/src/pruna/algorithms/sage_attn.py +++ b/src/pruna/algorithms/sage_attn.py @@ -168,7 +168,15 @@ def should_apply(name: str) -> bool: return model def get_hyperparameters(self) -> list: - """Return hyperparameters for this algorithm.""" + """ + Get the list of configurable hyperparameters for this algorithm. + + Returns + ------- + list + A list of hyperparameter objects (e.g., Boolean, TargetModules) used by the + configuration system. + """ return [ Boolean( "exclude_first_and_last_transformer_blocks", @@ -184,7 +192,23 @@ def get_model_dependent_hyperparameter_defaults( model: Any, smash_config: SmashConfigPrefixWrapper, ) -> TARGET_MODULES_TYPE: - """Return model-dependent default target_modules.""" + """ + Get model-dependent default hyperparameters for this algorithm. + + Parameters + ---------- + model : Any + The model/pipeline instance for which defaults should be computed. 
+ smash_config : SmashConfigPrefixWrapper + The configuration wrapper passed to the algorithm. It can be used to read other + algorithm settings when selecting defaults. + + Returns + ------- + TARGET_MODULES_TYPE + A dictionary with keys "include" and "exclude" defining which modules should be + targeted by default. + """ # So far, everything is included and nothing is excluded # Filtering is done in the _apply method by the set_attention_backend method include = ["*"] @@ -197,15 +221,23 @@ def _get_transformer_sub_excludes( model: Any, ) -> list[str]: """ - Returns a flat list of glob patterns. - - Example: - [ - "transformer.blocks.0*", - "transformer.blocks.39*", - "transformer_2.blocks.0*", - "transformer_2.blocks.39*", - ] + Build a list of glob patterns excluding the first and last transformer blocks. + + This inspects transformer-like components (e.g. "transformer", "transformer_2") and + derives glob patterns that exclude the first and last block paths containing modules + that support ``set_attention_backend``. + + Parameters + ---------- + model : Any + A Diffusers pipeline-like object with a ``components`` mapping containing + transformer components. + + Returns + ------- + list[str] + A flat list of glob patterns (e.g. "transformer.blocks.0*") to be added to the + exclude patterns for targeting. """ excludes: list[str] = [] @@ -247,7 +279,23 @@ def _get_transformer_sub_excludes( def _get_transformer_roots(model: Any) -> list[str]: - """Get the roots of the transformer components.""" + """ + Get transformer component root names from a Diffusers pipeline. + + A "transformer root" is any entry in ``model.components`` named "transformer" or + starting with "transformer_". Roots are returned in numeric order such that + "transformer" comes first, then "transformer_2", "transformer_10", etc. + + Parameters + ---------- + model : Any + A Diffusers pipeline-like object exposing a ``components`` mapping. + + Returns + ------- + list[str] + Sorted list of transformer component names found in ``model.components``. + """ roots = [] for name, _ in model.components.items(): if name == "transformer" or name.startswith("transformer_"): @@ -265,8 +313,39 @@ def key(n: str) -> int: def _unique_in_order(items: list[str]) -> list[str]: + """ + Remove duplicates while preserving the original order. + + Parameters + ---------- + items : list[str] + Input items. + + Returns + ------- + list[str] + A list with the same relative order as the input, with duplicates removed. + """ return list(OrderedDict.fromkeys(items)) def _matches_any(name: str, patterns: list[str]) -> bool: + """ + Check whether a name matches any glob pattern in a list. + + Uses Unix shell-style wildcards via ``fnmatch`` (e.g., "*", "transformer.*", + "transformer.blocks.0*"). + + Parameters + ---------- + name : str + The string to test (e.g., "transformer.blocks.0.attn1"). + patterns : list[str] + List of glob patterns to match against. + + Returns + ------- + bool + True if ``name`` matches at least one pattern, otherwise False. 
+ """ return any(fnmatch.fnmatch(name, pat) for pat in (patterns or [])) From b6cb7c227f826f88a5181c924a40dc3085e917ab Mon Sep 17 00:00:00 2001 From: Marius Graml Date: Wed, 24 Dec 2025 14:55:12 +0000 Subject: [PATCH 8/8] Refactor sage_attn to use target_modules utilities --- src/pruna/algorithms/sage_attn.py | 247 +++++------------------------- 1 file changed, 40 insertions(+), 207 deletions(-) diff --git a/src/pruna/algorithms/sage_attn.py b/src/pruna/algorithms/sage_attn.py index e83abc2c..b6f58138 100644 --- a/src/pruna/algorithms/sage_attn.py +++ b/src/pruna/algorithms/sage_attn.py @@ -14,20 +14,16 @@ from __future__ import annotations -import fnmatch -import re -from collections import OrderedDict from collections.abc import Iterable -from typing import Any +from typing import Any, List import torch from diffusers import DiffusionPipeline from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase from pruna.algorithms.base.tags import AlgorithmTag as tags -from pruna.config.hyperparameters import Boolean from pruna.config.smash_config import SmashConfigPrefixWrapper -from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules +from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules, map_targeted_nn_roots from pruna.engine.save import SAVE_FUNCTIONS from pruna.logging.logger import pruna_logger @@ -94,7 +90,6 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: The wrapped model. """ target_modules = smash_config["target_modules"] - exclude_first_and_last_transformer_blocks = smash_config["exclude_first_and_last_transformer_blocks"] if target_modules is None: target_modules = self.get_model_dependent_hyperparameter_defaults( @@ -102,70 +97,48 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: smash_config ) # for consistency, not used yet - extra_excludes = (_get_transformer_sub_excludes(model) if exclude_first_and_last_transformer_blocks else []) - - include_patterns = target_modules.get("include", []) - exclude_patterns = target_modules.get("exclude", []) - exclude_patterns.extend(extra_excludes) - - # Heuristic: if any pattern contains a dot, there are nested rules - def has_nested_rules(comp_name: str) -> bool: - prefix = comp_name + "." - return any(p.startswith(prefix) for p in (include_patterns + exclude_patterns)) - - def is_relevant_component_by_include(comp_name: str) -> bool: - if not include_patterns: - return True - if _matches_any(comp_name, include_patterns): # "*", "transformer", ... - return True - prefix = comp_name + "." - return any(p.startswith(prefix) for p in include_patterns) # "transformer.*", "transformer.blocks.*" - - def should_apply(name: str) -> bool: - if exclude_patterns and _matches_any(name, exclude_patterns): - return False - return not include_patterns or _matches_any(name, include_patterns) - - for comp_name, component in model.components.items(): - - # --- Check component level --- - - # 1) Component-level filter (e.g. 
exclude "vae"), exclude if in exclude_patterns and not in include_patterns - if exclude_patterns and _matches_any(comp_name, exclude_patterns): - continue - - # 2) Pick only relevant components - if not is_relevant_component_by_include(comp_name): - continue - - # 3) If there are no nested rules for the current component, - # make a faster global call otherwise go to submodule level - if (hasattr(component, "set_attention_backend") - and not has_nested_rules(comp_name) - and should_apply(comp_name)): - component.set_attention_backend("sage_hub") - continue - - # --- Check submodule level --- - - # 1) Check for named_modules method for step 2) to work - if component is None or not hasattr(component, "named_modules"): - continue - - # 2) Nested rules: iterate over submodules and match full_name - for sub_name, sub_module in component.named_modules(): - if not sub_name: - continue - - full_name = f"{comp_name}.{sub_name}" # e.g., transformer.blocks.0.attn1 - - if not should_apply(full_name): + def apply_sage_attn( + root_name: str | None, + root_nn_module: torch.nn.Module, + relative_target_paths: List[str], + ) -> torch.nn.Module: + """ + Apply the SageAttention backend to targeted submodules of a root module. + + For each relative submodule path, this function retrieves the corresponding + submodule from ``root_nn_module`` and applies + ``set_attention_backend("sage_hub")`` if the method is available. + + Parameters + ---------- + root_name : str or None + The attribute name of the root module within the model (used for identification). + May be ``None`` if the model itself is a ``torch.nn.Module``. + root_nn_module : torch.nn.Module + The root torch.nn.module containing the targeted submodules. + relative_target_paths : List[str] + Relative paths of submodules (with respect to ``root_nn_module``) to consider. + + Returns + ------- + torch.nn.Module + The root ntorch.nn.module with the SageAttention backend applied where supported. + """ + for rel_path in relative_target_paths: + try: + sub_module = root_nn_module.get_submodule(rel_path) + except AttributeError: + # safety net: should not happen, + # since the paths come from named_modules() continue - if hasattr(sub_module, "set_attention_backend"): sub_module.set_attention_backend("sage_hub") + else: + pruna_logger.warning(f"Module {root_name}.{rel_path} does not have a set_attention_backend method" + "and will not be replaced with SageAttention") + return root_nn_module - return model + return map_targeted_nn_roots(apply_sage_attn, model, target_modules) def get_hyperparameters(self) -> list: """ @@ -178,12 +151,6 @@ def get_hyperparameters(self) -> list: configuration system. """ return [ - Boolean( - "exclude_first_and_last_transformer_blocks", - default=False, - meta=dict(desc="If True, do NOT apply SageAttention to the first and last" - "transformer blocks for each transformer component."), - ), TargetModules(name="target_modules", default_value=None), ] @@ -215,137 +182,3 @@ def get_model_dependent_hyperparameter_defaults( exclude = [] return {"include": include, "exclude": exclude} - - -def _get_transformer_sub_excludes( - model: Any, -) -> list[str]: - """ - Build a list of glob patterns excluding the first and last transformer blocks. - - This inspects transformer-like components (e.g. "transformer", "transformer_2") and - derives glob patterns that exclude the first and last block paths containing modules - that support ``set_attention_backend``. 
- - Parameters - ---------- - model : Any - A Diffusers pipeline-like object with a ``components`` mapping containing - transformer components. - - Returns - ------- - list[str] - A flat list of glob patterns (e.g. "transformer.blocks.0*") to be added to the - exclude patterns for targeting. - """ - excludes: list[str] = [] - - roots = _get_transformer_roots(model) - - for root in roots: - # get the component - comp = model.components.get(root, None) - # if the component is None/missing (e.g. the case for transformer_2 in Wan2.2-TI2V-5B-Diffusers), skip it - if comp is None: - pruna_logger.warning("skip %s for excludes: component is None", root) - continue - # get the attention names - attn_names = [ - name - for name, module in model.components[root].named_modules() - if name and hasattr(module, "set_attention_backend") - ] - # if there are no attention names, skip it - if not attn_names: - continue - - # get the block paths - block_paths = _unique_in_order([n.rsplit(".", 1)[0] for n in attn_names]) - - # if there are less than 3 block paths, skip it - if len(block_paths) < 3: - pruna_logger.warning(f"Root {root} has less than 3 transformer blocks." - "Thus its first and last blocks are not excluded for sage_attn.") - continue - - # We just want to exclude the first and last blocks of the transformer components - excludes.extend([ - f"{root}.{block_paths[0]}*", - f"{root}.{block_paths[-1]}*", - ]) - - return excludes - - -def _get_transformer_roots(model: Any) -> list[str]: - """ - Get transformer component root names from a Diffusers pipeline. - - A "transformer root" is any entry in ``model.components`` named "transformer" or - starting with "transformer_". Roots are returned in numeric order such that - "transformer" comes first, then "transformer_2", "transformer_10", etc. - - Parameters - ---------- - model : Any - A Diffusers pipeline-like object exposing a ``components`` mapping. - - Returns - ------- - list[str] - Sorted list of transformer component names found in ``model.components``. - """ - roots = [] - for name, _ in model.components.items(): - if name == "transformer" or name.startswith("transformer_"): - roots.append(name) - - # Sort the roots by the number of the transformer component, just to be sure - def key(n: str) -> int: - # transformer -> 0, transformer_10 -> 10 - if n == "transformer": - return 0 - m = re.match(r"transformer_(\d+)$", n) - return int(m.group(1)) if m else 10**9 # unknown suffix goes to end - - return sorted(roots, key=key) - - -def _unique_in_order(items: list[str]) -> list[str]: - """ - Remove duplicates while preserving the original order. - - Parameters - ---------- - items : list[str] - Input items. - - Returns - ------- - list[str] - A list with the same relative order as the input, with duplicates removed. - """ - return list(OrderedDict.fromkeys(items)) - - -def _matches_any(name: str, patterns: list[str]) -> bool: - """ - Check whether a name matches any glob pattern in a list. - - Uses Unix shell-style wildcards via ``fnmatch`` (e.g., "*", "transformer.*", - "transformer.blocks.0*"). - - Parameters - ---------- - name : str - The string to test (e.g., "transformer.blocks.0.attn1"). - patterns : list[str] - List of glob patterns to match against. - - Returns - ------- - bool - True if ``name`` matches at least one pattern, otherwise False. - """ - return any(fnmatch.fnmatch(name, pat) for pat in (patterns or []))
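
End-to-end usage sketch for the final state of this series (not part of the patches
above; the smash/SmashConfig entry points are pruna's public API, but the "kernel"
group key, the prefixed "sage_attn_target_modules" hyperparameter name, and the example
checkpoint are assumptions):

    import torch
    from diffusers import DiffusionPipeline
    from pruna import SmashConfig, smash

    # SageAttention targets fp16/bf16 diffusers components that expose
    # set_attention_backend, so load the pipeline in bfloat16 on CUDA.
    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")

    smash_config = SmashConfig()
    smash_config["kernel"] = "sage_attn"  # assumed config key for the KERNEL group
    # Optional targeting via include/exclude glob patterns (resolved through
    # map_targeted_nn_roots in the final patch); key name is hypothetical.
    smash_config["sage_attn_target_modules"] = {"include": ["*"], "exclude": ["vae*"]}
    smashed = smash(model=pipe, smash_config=smash_config)

    image = smashed("an astronaut riding a horse on the moon").images[0]

Since torch_compile is listed in compatible_after, the same SmashConfig could also set
a compiler (e.g. smash_config["compiler"] = "torch_compile") to compile on top of the
SageAttention backend, subject to the usual group-key naming in the released API.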