164 changes: 164 additions & 0 deletions src/pruna/algorithms/torchao_autoquant.py
@@ -0,0 +1,164 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict

import torch

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag
from pruna.config.hyperparameters import Boolean
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.engine.model_checks import get_diffusers_transformer_models
from pruna.engine.save import SAVE_FUNCTIONS


class Autoquant(PrunaAlgorithmBase):
"""
Implement autoquantization using the torchao library.

This algorithm uses torchao's autoquant to automatically select a quantization strategy per layer,
and can optionally compile the model for inference. It is specifically adapted for image-generation models.
"""

algorithm_name: str = "torchao_autoquant"
group_tags: list[AlgorithmTag] = [AlgorithmTag.QUANTIZER]
references: dict[str, str] = {"Documentation": "https://huggingface.co/docs/diffusers/quantization/torchao"}
save_fn: SAVE_FUNCTIONS = SAVE_FUNCTIONS.save_before_apply
tokenizer_required: bool = False
processor_required: bool = False
runs_on: list[str] = ["cpu", "cuda"]
dataset_required: bool = False

def get_hyperparameters(self) -> list:
"""
Configure all algorithm-specific hyperparameters with ConfigSpace.

Returns
-------
list
The hyperparameters.
"""
return [
Boolean("compile", default=True, meta=dict(desc="Whether to compile the model after quantization or not.")),
]
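
For context, the `compile` flag is typically set through Pruna's user-facing config; a minimal sketch, assuming the public `SmashConfig`/`smash` API and the `torchao_autoquant_` hyperparameter prefix (neither appears in this diff):

from pruna import SmashConfig, smash

# Hypothetical usage: pick this quantizer and turn off the optional compile step.
smash_config = SmashConfig()
smash_config["quantizer"] = "torchao_autoquant"
smash_config["torchao_autoquant_compile"] = False

# base_model would be e.g. a diffusers pipeline (hypothetical variable).
smashed_model = smash(model=base_model, smash_config=smash_config)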

def model_check_fn(self, model: Any) -> bool:
"""
Check if the model is a diffusers transformer model, contains one, or is a torch.nn.Module.

Parameters
----------
model : Any
The model to check.

Returns
-------
bool
True if the model is (or contains) a diffusers transformer model or is a torch.nn.Module, False otherwise.
"""
transformer_models = get_diffusers_transformer_models()

if isinstance(model, tuple(transformer_models)):
return True

for _, attr_value in inspect.getmembers(model):
if isinstance(attr_value, tuple(transformer_models)):
return True
return isinstance(model, torch.nn.Module)
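
To illustrate what this check accepts and rejects, a small hypothetical sketch (assuming Autoquant can be instantiated directly, which this diff does not show):

quantizer = Autoquant()

quantizer.model_check_fn(torch.nn.Linear(4, 4))  # True: a plain torch.nn.Module
quantizer.model_check_fn(flux_pipeline)          # True: a hypothetical pipeline exposing a diffusers transformer member
quantizer.model_check_fn("not a model")          # False: neither a transformer nor an nn.Module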

def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
"""
Quantize the model.

Parameters
----------
model : Any
The model to quantize.
smash_config : SmashConfigPrefixWrapper
The configuration for the quantization.

Returns
-------
Any
The quantized model.
"""
transformer_models = get_diffusers_transformer_models()

working_model = model.transformer if isinstance(model.transformer, tuple(transformer_models)) else model

if smash_config["compile"]:
working_model.torch_compiler = TorchCompiler(working_model)

# Now we can compile the model
working_model = working_model.torch_compiler.compile()
working_model = self.import_algorithm_packages()["autoquant"](working_model, error_on_unseen=False)

if isinstance(model.transformer, tuple(transformer_models)):
model.transformer = working_model
else:
model = working_model

Bug: Autoquant assumes model.transformer always exists

model_check_fn returns True for any torch.nn.Module, but _apply unconditionally reads model.transformer. For valid inputs that are torch.nn.Module instances without a transformer attribute (or when the input is itself a diffusers transformer module), this raises an AttributeError and breaks applying torchao_autoquant. A guarded lookup, sketched below, would avoid this.

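A minimal sketch of such a guard, reusing only names that already appear in this diff:

# Unwrap model.transformer only when the attribute exists and is a diffusers transformer.
transformer = getattr(model, "transformer", None)
has_transformer = isinstance(transformer, tuple(transformer_models))
working_model = transformer if has_transformer else model

# ... compile and autoquant working_model as above ...

if has_transformer:
    model.transformer = working_model
else:
    model = working_model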

return model

def import_algorithm_packages(self) -> Dict[str, Any]:
"""
Provide the packages required by the algorithm.

Returns
-------
Dict[str, Any]
The algorithm packages.
"""
from torchao.quantization import autoquant

return dict(autoquant=autoquant)
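
For reference, the torchao entry point imported here can also be used standalone, as in the diffusers torchao documentation; a minimal sketch where `my_module` is a hypothetical stand-in for any torch.nn.Module:

import torch
from torchao.quantization import autoquant

# Wrap the module so torchao can pick a quantization strategy per layer.
my_module = torch.nn.Linear(1024, 1024).cuda()
quantized = autoquant(my_module, error_on_unseen=False)

# Running a representative input lets autoquant benchmark and select kernels.
quantized(torch.randn(8, 1024, device="cuda"))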


class TorchCompiler(object):
"""
A class that compiles a PyTorch model using the pre-defined compilation options.

Parameters
----------
model : torch.nn.Module
The PyTorch model to be compiled.
"""

def __init__(self, model: Callable[..., Any]) -> None:
"""
Initialize the TorchCompiler with a model.

Parameters
----------
model : Callable[..., Any]
The PyTorch model to be compiled.
"""
self.model = model

def compile(self) -> Callable[..., Any]:
"""
Compile the PyTorch model using the pre-defined compilation options.

Returns
-------
torch.nn.Module
The compiled PyTorch model.
"""
self.model = torch.compile(self.model, mode="max-autotune-no-cudagraphs", fullgraph=True)

return self.model
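
A quick usage sketch of the wrapper above (`my_module` being any hypothetical torch.nn.Module):

# Compiles with the fixed max-autotune-no-cudagraphs options defined in compile().
compiled = TorchCompiler(my_module).compile()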
18 changes: 18 additions & 0 deletions tests/algorithms/testers/autoquant.py
@@ -0,0 +1,18 @@
import pytest

from pruna.algorithms.torchao_autoquant import Autoquant

from .base_tester import AlgorithmTesterBase


# This test is classified as high because autoquant requires INT8 tensor cores, which are only
# available on GPUs with the Turing architecture or newer.
@pytest.mark.high
class TestAutoquant(AlgorithmTesterBase):
"""Test the Autoquant quantizer."""

models = ["flux_tiny_random"]
reject_models = ["dummy_lambda"]
allow_pickle_files = False
algorithm_class = Autoquant
metrics = ["cmmd"]