diff --git a/src/pruna/algorithms/deepcache.py b/src/pruna/algorithms/deepcache.py
index cf2a9ad8..6f404224 100644
--- a/src/pruna/algorithms/deepcache.py
+++ b/src/pruna/algorithms/deepcache.py
@@ -43,7 +43,14 @@ class DeepCache(PrunaAlgorithmBase):
     processor_required: bool = False
     dataset_required: bool = False
     runs_on: list[str] = ["cpu", "cuda", "accelerate"]
-    compatible_before: Iterable[str] = ["qkv_diffusers", "half", "hqq_diffusers", "diffusers_int8", "quanto"]
+    compatible_before: Iterable[str] = [
+        "qkv_diffusers",
+        "half",
+        "hqq_diffusers",
+        "diffusers_int8",
+        "quanto",
+        "sage_attn",
+    ]
     compatible_after: Iterable[str] = ["stable_fast", "torch_compile"]
 
     def get_hyperparameters(self) -> list:
diff --git a/src/pruna/algorithms/fastercache.py b/src/pruna/algorithms/fastercache.py
index 856401e7..41dece15 100644
--- a/src/pruna/algorithms/fastercache.py
+++ b/src/pruna/algorithms/fastercache.py
@@ -57,7 +57,7 @@ class FasterCache(PrunaAlgorithmBase):
     processor_required: bool = False
    dataset_required: bool = False
     runs_on: list[str] = ["cpu", "cuda", "accelerate"]
-    compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8"]
+    compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8", "sage_attn"]
 
     def get_hyperparameters(self) -> list:
         """
diff --git a/src/pruna/algorithms/fora.py b/src/pruna/algorithms/fora.py
index 00697ea4..fb9f08e6 100644
--- a/src/pruna/algorithms/fora.py
+++ b/src/pruna/algorithms/fora.py
@@ -44,7 +44,14 @@ class FORA(PrunaAlgorithmBase):
     processor_required: bool = False
     runs_on: list[str] = ["cpu", "cuda", "accelerate"]
     dataset_required: bool = False
-    compatible_before: Iterable[str] = ["qkv_diffusers", "diffusers_int8", "hqq_diffusers", "torchao", "flash_attn3"]
+    compatible_before: Iterable[str] = [
+        "qkv_diffusers",
+        "diffusers_int8",
+        "hqq_diffusers",
+        "torchao",
+        "flash_attn3",
+        "sage_attn"
+    ]
     compatible_after: Iterable[str] = ["stable_fast", "torch_compile"]
 
     def get_hyperparameters(self) -> list:
diff --git a/src/pruna/algorithms/gptq_model.py b/src/pruna/algorithms/gptq_model.py
index edfbcda7..1b2c78e7 100644
--- a/src/pruna/algorithms/gptq_model.py
+++ b/src/pruna/algorithms/gptq_model.py
@@ -45,7 +45,7 @@ class GPTQ(PrunaAlgorithmBase):
     processor_required: bool = False
     runs_on: list[str] = ["cuda"]
     dataset_required: bool = True
-    compatible_after: Iterable[str] = ["torch_compile"]
+    compatible_after: Iterable[str] = ["torch_compile", "sage_attn"]
     required_install: str = (
         "You must first install the base package with ``pip install pruna`` "
         "before installing the GPTQ extension with ``pip install pruna[gptq] --extra-index-url https://prunaai.pythonanywhere.com/``"
diff --git a/src/pruna/algorithms/half.py b/src/pruna/algorithms/half.py
index b416de27..3b0afeae 100644
--- a/src/pruna/algorithms/half.py
+++ b/src/pruna/algorithms/half.py
@@ -51,6 +51,7 @@ class Half(PrunaAlgorithmBase):
         "torch_compile",
         "ifw",
         "whisper_s2t",
+        "sage_attn",
     ]
 
     def model_check_fn(self, model: Any) -> bool:
diff --git a/src/pruna/algorithms/hqq.py b/src/pruna/algorithms/hqq.py
index 4ccc6804..88d3f680 100644
--- a/src/pruna/algorithms/hqq.py
+++ b/src/pruna/algorithms/hqq.py
@@ -52,7 +52,7 @@ class HQQ(PrunaAlgorithmBase):
     runs_on: list[str] = ["cuda"]
     dataset_required: bool = False
     compatible_before: Iterable[str] = ["torch_structured"]
-    compatible_after: Iterable[str] = ["torch_compile"]
+    compatible_after: Iterable[str] = ["torch_compile", "sage_attn"]
 
     def get_hyperparameters(self) -> list:
         """
diff --git a/src/pruna/algorithms/hqq_diffusers.py b/src/pruna/algorithms/hqq_diffusers.py
index 99f1c051..544c10ef 100644
--- a/src/pruna/algorithms/hqq_diffusers.py
+++ b/src/pruna/algorithms/hqq_diffusers.py
@@ -54,7 +54,7 @@ class HQQDiffusers(PrunaAlgorithmBase):
     runs_on: list[str] = ["cuda"]
     dataset_required: bool = False
     compatible_before: Iterable[str] = ["qkv_diffusers"]
-    compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile"]
+    compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile", "sage_attn"]
 
     def get_hyperparameters(self) -> list:
         """
diff --git a/src/pruna/algorithms/huggingface_diffusers_int8.py b/src/pruna/algorithms/huggingface_diffusers_int8.py
index 02b9e8d6..8a431e93 100644
--- a/src/pruna/algorithms/huggingface_diffusers_int8.py
+++ b/src/pruna/algorithms/huggingface_diffusers_int8.py
@@ -60,7 +60,7 @@ class DiffusersInt8(PrunaAlgorithmBase):
     runs_on: list[str] = ["cuda", "accelerate"]
     save_fn: None = None
     compatible_before: Iterable[str] = ["qkv_diffusers"]
-    compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile"]
+    compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile", "sage_attn"]
 
     def get_hyperparameters(self) -> list:
         """
diff --git a/src/pruna/algorithms/huggingface_llm_int8.py b/src/pruna/algorithms/huggingface_llm_int8.py
index eb381204..9704f909 100644
--- a/src/pruna/algorithms/huggingface_llm_int8.py
+++ b/src/pruna/algorithms/huggingface_llm_int8.py
@@ -56,7 +56,7 @@ class LLMInt8(PrunaAlgorithmBase):
     dataset_required: bool = False
     runs_on: list[str] = ["cuda", "accelerate"]
     save_fn: None = None
-    compatible_after: Iterable[str] = ["torch_compile"]
+    compatible_after: Iterable[str] = ["torch_compile", "sage_attn"]
 
     def get_hyperparameters(self) -> list:
         """
diff --git a/src/pruna/algorithms/pab.py b/src/pruna/algorithms/pab.py
index 2472217d..a71d0c12 100644
--- a/src/pruna/algorithms/pab.py
+++ b/src/pruna/algorithms/pab.py
@@ -53,7 +53,7 @@ class PAB(PrunaAlgorithmBase):
     processor_required: bool = False
     dataset_required: bool = False
     runs_on: list[str] = ["cpu", "cuda", "accelerate"]
-    compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8"]
+    compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8", "sage_attn"]
     compatible_after: Iterable[str] = []
 
     def get_hyperparameters(self) -> list:
diff --git a/src/pruna/algorithms/quanto.py b/src/pruna/algorithms/quanto.py
index c60f6ad8..7e95d93b 100644
--- a/src/pruna/algorithms/quanto.py
+++ b/src/pruna/algorithms/quanto.py
@@ -51,7 +51,7 @@ class Quanto(PrunaAlgorithmBase):
     dataset_required: bool = False
     runs_on: list[str] = ["cuda"]
     compatible_before: Iterable[str] = ["qkv_diffusers"]
-    compatible_after: Iterable[str] = ["deepcache"]
+    compatible_after: Iterable[str] = ["deepcache", "sage_attn"]
 
     def get_hyperparameters(self) -> list:
         """
diff --git a/src/pruna/algorithms/sage_attn.py b/src/pruna/algorithms/sage_attn.py
new file mode 100644
index 00000000..b6f58138
--- /dev/null
+++ b/src/pruna/algorithms/sage_attn.py
@@ -0,0 +1,184 @@
+# Copyright 2025 - Pruna AI GmbH. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Iterable
+from typing import Any, List
+
+import torch
+from diffusers import DiffusionPipeline
+
+from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
+from pruna.algorithms.base.tags import AlgorithmTag as tags
+from pruna.config.smash_config import SmashConfigPrefixWrapper
+from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules, map_targeted_nn_roots
+from pruna.engine.save import SAVE_FUNCTIONS
+from pruna.logging.logger import pruna_logger
+
+
+class SageAttn(PrunaAlgorithmBase):
+    """
+    Replace torch.nn.functional.scaled_dot_product_attention with sage_attn.
+
+    SageAttention is a fast and memory-efficient attention mechanism. It applies the flash attention mechanism
+    in combination with quantization and smoothing to speed up attention computations.
+    """
+
+    algorithm_name: str = "sage_attn"
+    group_tags: list[str] = [tags.KERNEL]
+    save_fn = SAVE_FUNCTIONS.reapply
+    references: dict[str, str] = {
+        "Paper (SA2++)": "https://arxiv.org/pdf/2505.21136v3",
+        "GitHub": "https://github.com/thu-ml/SageAttention",
+        "Kernel Hub": "https://huggingface.co/kernels-community/sage_attention",
+    }
+    tokenizer_required: bool = False
+    processor_required: bool = False
+    runs_on: list[str] = ["cuda", "accelerate"]
+    dataset_required: bool = False
+    compatible_before: Iterable[str] = [tags.QUANTIZER]
+    compatible_after: Iterable[str] = ["torch_compile", tags.CACHER]
+
+    def model_check_fn(self, model: Any) -> bool:
+        """
+        Check if the model has an attention mechanism that can be replaced with sage_attn.
+
+        Parameters
+        ----------
+        model : Any
+            The model to check.
+
+        Returns
+        -------
+        bool
+            True if the model is a valid model for the algorithm, False otherwise.
+        """
+        if not isinstance(model, DiffusionPipeline) or not hasattr(model, "components"):
+            return False
+
+        return any(
+            hasattr(component, "set_attention_backend") and component.dtype in (torch.bfloat16, torch.float16)
+            for component in model.components.values()
+        )
+
+    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
+        """
+        Wrap the model to use SageAttention where possible.
+
+        Parameters
+        ----------
+        model : Any
+            The model to wrap.
+        smash_config : SmashConfigPrefixWrapper
+            The configuration for the application of the algorithm.
+
+        Returns
+        -------
+        Any
+            The wrapped model.
+        """
+        target_modules = smash_config["target_modules"]
+
+        if target_modules is None:
+            target_modules = self.get_model_dependent_hyperparameter_defaults(
+                model,
+                smash_config
+            )  # for consistency, not used yet
+
+        def apply_sage_attn(
+            root_name: str | None,
+            root_nn_module: torch.nn.Module,
+            relative_target_paths: List[str],
+        ) -> torch.nn.Module:
+            """
+            Apply the SageAttention backend to targeted submodules of a root module.
+
+            For each relative submodule path, this function retrieves the corresponding
+            submodule from ``root_nn_module`` and applies
+            ``set_attention_backend("sage_hub")`` if the method is available.
+
+            Parameters
+            ----------
+            root_name : str or None
+                The attribute name of the root module within the model (used for identification).
+                May be ``None`` if the model itself is a ``torch.nn.Module``.
+            root_nn_module : torch.nn.Module
+                The root ``torch.nn.Module`` containing the targeted submodules.
+            relative_target_paths : List[str]
+                Relative paths of submodules (with respect to ``root_nn_module``) to consider.
+
+            Returns
+            -------
+            torch.nn.Module
+                The root ``torch.nn.Module`` with the SageAttention backend applied where supported.
+            """
+            for rel_path in relative_target_paths:
+                try:
+                    sub_module = root_nn_module.get_submodule(rel_path)
+                except AttributeError:
+                    # safety net: should not happen,
+                    # since the paths come from named_modules()
+                    continue
+                if hasattr(sub_module, "set_attention_backend"):
+                    sub_module.set_attention_backend("sage_hub")
+                else:
+                    pruna_logger.warning(f"Module {root_name}.{rel_path} does not have a set_attention_backend method "
+                                         "and will not be replaced with SageAttention.")
+            return root_nn_module
+
+        return map_targeted_nn_roots(apply_sage_attn, model, target_modules)
+
+    def get_hyperparameters(self) -> list:
+        """
+        Get the list of configurable hyperparameters for this algorithm.
+
+        Returns
+        -------
+        list
+            A list of hyperparameter objects (here a single ``TargetModules`` parameter) used by the
+            configuration system.
+        """
+        return [
+            TargetModules(name="target_modules", default_value=None),
+        ]
+
+    def get_model_dependent_hyperparameter_defaults(
+        self,
+        model: Any,
+        smash_config: SmashConfigPrefixWrapper,
+    ) -> TARGET_MODULES_TYPE:
+        """
+        Get model-dependent default hyperparameters for this algorithm.
+
+        Parameters
+        ----------
+        model : Any
+            The model/pipeline instance for which defaults should be computed.
+        smash_config : SmashConfigPrefixWrapper
+            The configuration wrapper passed to the algorithm. It can be used to read other
+            algorithm settings when selecting defaults.
+
+        Returns
+        -------
+        TARGET_MODULES_TYPE
+            A dictionary with keys "include" and "exclude" defining which modules should be
+            targeted by default.
+        """
+        # So far, everything is included and nothing is excluded
+        # Filtering is done in the _apply method by the set_attention_backend method
+        include = ["*"]
+        exclude = []
+
+        return {"include": include, "exclude": exclude}
diff --git a/src/pruna/algorithms/torch_compile/torch_compile.py b/src/pruna/algorithms/torch_compile/torch_compile.py
index 62c4d086..fe09bbd2 100644
--- a/src/pruna/algorithms/torch_compile/torch_compile.py
+++ b/src/pruna/algorithms/torch_compile/torch_compile.py
@@ -69,6 +69,7 @@ class TorchCompile(PrunaAlgorithmBase):
         "flash_attn3",
         "deepcache",
         "fora",
+        "sage_attn",
     ]
 
     def get_hyperparameters(self) -> list:
diff --git a/src/pruna/algorithms/torch_dynamic.py b/src/pruna/algorithms/torch_dynamic.py
index fc97ca90..e3b1ac3f 100644
--- a/src/pruna/algorithms/torch_dynamic.py
+++ b/src/pruna/algorithms/torch_dynamic.py
@@ -43,7 +43,7 @@ class TorchDynamic(PrunaAlgorithmBase):
     runs_on: list[str] = ["cpu", "cuda"]
     dataset_required: bool = False
     compatible_before: Iterable[str] = []
-    compatible_after: Iterable[str] = []
+    compatible_after: Iterable[str] = ["sage_attn"]
 
     def get_hyperparameters(self) -> list:
         """
diff --git a/src/pruna/algorithms/torchao.py b/src/pruna/algorithms/torchao.py
index d0a10ee1..63550972 100644
--- a/src/pruna/algorithms/torchao.py
+++ b/src/pruna/algorithms/torchao.py
@@ -90,7 +90,7 @@ class Torchao(PrunaAlgorithmBase):
     runs_on: list[str] = ["cpu", "cuda", "accelerate"]
     dataset_required: bool = False
     compatible_before: Iterable[str] = ["qkv_diffusers", "torch_structured"]
-    compatible_after: Iterable[str] = ["flash_attn3", "fora", "torch_compile"]
+    compatible_after: Iterable[str] = ["flash_attn3", "fora", "torch_compile", "sage_attn"]
 
     def get_hyperparameters(self) -> list:
         """
diff --git a/tests/algorithms/testers/sage_attn.py b/tests/algorithms/testers/sage_attn.py
new file mode 100644
index 00000000..8fdabd30
--- /dev/null
+++ b/tests/algorithms/testers/sage_attn.py
@@ -0,0 +1,16 @@
+import pytest
+
+from pruna.algorithms.sage_attn import SageAttn
+
+from .base_tester import AlgorithmTesterBase
+
+
+@pytest.mark.high
+class TestSageAttn(AlgorithmTesterBase):
+    """Test the sage attention kernel."""
+
+    models = ["flux_tiny", "wan_tiny_random"]
+    reject_models = ["opt_tiny_random"]
+    allow_pickle_files = False
+    algorithm_class = SageAttn
+    metrics = ["latency"]
\ No newline at end of file
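Reviewer note: a minimal end-to-end sketch of how the new kernel could be enabled through the existing smash API, for context only. The "kernel" config key is inferred from group_tags = [tags.KERNEL], the prefixed "sage_attn_target_modules" name follows pruna's usual "<algorithm>_<hyperparameter>" convention, and the checkpoint name is arbitrary; none of these are defined by this diff.

# Hypothetical usage sketch (not part of this PR); config key and model name are assumptions.
import torch
from diffusers import DiffusionPipeline

from pruna import SmashConfig, smash

# model_check_fn requires a diffusers pipeline whose attention-capable components
# run in fp16/bf16, and the algorithm runs on CUDA.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

smash_config = SmashConfig()
smash_config["kernel"] = "sage_attn"  # assumed group key, mirroring group_tags = [tags.KERNEL]
# Optionally restrict which submodules are touched; the default computed by
# get_model_dependent_hyperparameter_defaults is {"include": ["*"], "exclude": []}:
# smash_config["sage_attn_target_modules"] = {"include": ["transformer*"], "exclude": []}

smashed_pipe = smash(model=pipe, smash_config=smash_config)
image = smashed_pipe("a photo of a red panda").images[0]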