diff --git a/pyproject.toml b/pyproject.toml index 44df84bd..55930b1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,7 +136,9 @@ dependencies = [ "aenum", "vbench-pruna; sys_platform != 'darwin'", "imageio-ffmpeg", - "jaxtyping" + "jaxtyping", + "basicsr>=1.4.2", + "realesrgan>=0.3.0", ] [project.optional-dependencies] diff --git a/src/pruna/algorithms/base/tags.py b/src/pruna/algorithms/base/tags.py index f2e37e6f..c155a8cf 100644 --- a/src/pruna/algorithms/base/tags.py +++ b/src/pruna/algorithms/base/tags.py @@ -64,6 +64,10 @@ class AlgorithmTag(Enum): "batcher", "Batching groups multiple inputs together to be processed simultaneously, improving computational efficiency and reducing overall processing time.", ) + ENHANCER = ( + "enhancer", + "Enhancers improve the quality of the model's output. Enhancers can range from post-processing to test time compute algorithms.", + ) def __init__(self, name: str, description: str): """ diff --git a/src/pruna/algorithms/denoise.py b/src/pruna/algorithms/denoise.py new file mode 100644 index 00000000..cba15a2e --- /dev/null +++ b/src/pruna/algorithms/denoise.py @@ -0,0 +1,277 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

from collections.abc import Iterable
from typing import Any

import torch
from ConfigSpace import UniformFloatHyperparameter
from diffusers import AutoPipelineForImage2Image, DiffusionPipeline

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.engine.save import SAVE_FUNCTIONS
from pruna.engine.utils import determine_dtype


class Img2ImgDenoise(PrunaAlgorithmBase):
    """
    Refines images using the model's own image-to-image capabilities.

    This enhancer takes the output images from a diffusion pipeline and refines them
    by running them through an image-to-image pipeline built from the same components.
    This assumes the base model is a diffusers pipeline supporting image-to-image.

    Attributes
    ----------
    algorithm_name : str
        The name identifier for this algorithm.
    group_tags : list[AlgorithmTag]
        The algorithm groups this algorithm belongs to (``ENHANCER``).
    save_fn : SAVE_FUNCTIONS
        The save strategy; the enhancer is re-applied after loading.
    references : dict[str, str]
        Dictionary containing references.
    tokenizer_required : bool
        Whether a tokenizer is required.
    processor_required : bool
        Whether a processor is required.
    runs_on : list[str]
        The devices this enhancer is declared to run on.
    dataset_required : bool
        Whether a dataset is required for this enhancer.
    compatible_before : Iterable[str | AlgorithmTag]
        Algorithms (or algorithm groups) that may be applied before this enhancer.
    """

    algorithm_name: str = "img2img_denoise"
    group_tags: list[AlgorithmTag] = [AlgorithmTag.ENHANCER]  # type: ignore[attr-defined]
    save_fn = SAVE_FUNCTIONS.reapply
    references: dict[str, str] = {
        "Diffusers": "https://huggingface.co/docs/diffusers/using-diffusers/img2img",
    }
    tokenizer_required: bool = False
    processor_required: bool = False
    runs_on: list[str] = ["cpu", "cuda"]
    dataset_required: bool = False
    compatible_before: Iterable[str | AlgorithmTag] = [
        AlgorithmTag.CACHER,
        "torch_compile",
        "stable_fast",
        "hqq_diffusers",
        "diffusers_int8",
        "torchao",
        "qkv_diffusers",
        "ring_attn",
    ]

    def get_hyperparameters(self) -> list:
        """
        Get the hyperparameters for the Img2Img Denoise enhancer.

        Returns
        -------
        list
            A list of hyperparameters, including:
            - strength: Controls how much noise is added to the input image,
                        influencing how much it changes (0.0-1.0). Lower values
                        mean less change/more refinement.
        """
        return [
            UniformFloatHyperparameter(
                "strength",
                lower=0.0,
                upper=1.0,
                default_value=0.02,
                log=False,
                meta={"desc": "Strength of the denoising/refinement. Lower values mean less change/more refinement."},
            ),
        ]

    @staticmethod
    def _build_refiner(model: Any, torch_dtype: Any) -> Any:
        """
        Build an image-to-image pipeline that reuses the components of ``model``.

        Parameters
        ----------
        model : Any
            The diffusers pipeline whose components are shared with the refiner.
        torch_dtype : Any
            The dtype forwarded to ``from_pretrained``.

        Returns
        -------
        Any
            The image-to-image pipeline.
        """
        return AutoPipelineForImage2Image.from_pretrained(
            model._name_or_path,
            transformer=getattr(model, "transformer", None),
            unet=getattr(model, "unet", None),
            vae=getattr(model, "vae", None),
            text_encoder=getattr(model, "text_encoder", None),
            torch_dtype=torch_dtype,
            scheduler=getattr(model, "scheduler", None),
        )

    def model_check_fn(self, model: Any) -> bool:
        """
        Check if the model is a diffusers pipeline usable as an image-to-image pipeline.

        Parameters
        ----------
        model : Any
            The model instance to check.

        Returns
        -------
        bool
            True if an image-to-image pipeline can be built from the model, False otherwise.
        """
        if not isinstance(model, DiffusionPipeline) or not hasattr(model, "_name_or_path"):
            return False

        model_dtype = determine_dtype(model)

        # Compatibility probe: any failure to build the img2img pipeline means "not supported".
        # The broad except is deliberate here, the probe must never raise.
        try:
            self._build_refiner(model, model_dtype)
        except Exception:
            return False
        return True

    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
        """
        Apply image-to-image denoising/refinement to the model's output.

        Parameters
        ----------
        model : Any
            The diffusers pipeline model to enhance.
        smash_config : SmashConfigPrefixWrapper
            The configuration containing hyperparameters like 'strength'.

        Returns
        -------
        Any
            The model with its output generation wrapped for refinement.
        """
        refiner = self._build_refiner(model, determine_dtype(model)).to(smash_config.device)

        model.denoise_helper = DenoiseHelper(
            model,
            refiner,
            strength=smash_config["strength"],
        )
        model.denoise_helper.enable()

        return model

    def import_algorithm_packages(self) -> dict[str, Any]:
        """
        Import necessary algorithm packages.

        Returns
        -------
        dict
            An empty dictionary as no packages are imported in this implementation.
        """
        return {}


class DenoiseHelper:
    """
    Helper class to wrap a pipeline's call for image-to-image refinement.

    Intercepts the output images and runs them through the refiner pipeline
    using image-to-image mode with a specified strength.

    Parameters
    ----------
    model : Any
        The diffusers pipeline model being wrapped.
    refiner : Any
        The separate pipeline used for the refinement step.
    strength : float
        The strength parameter for the img2img refinement step.

    Raises
    ------
    TypeError
        If ``model`` has no callable ``__call__``.
    """

    # Keyword arguments that must not be forwarded to the refiner:
    # - num_images_per_prompt: every image is refined individually, one output per call.
    # - image / strength: supplied by the helper itself; forwarding the caller's values
    #   would raise "got multiple values for keyword argument".
    _NOT_FORWARDED = frozenset({"num_images_per_prompt", "image", "strength"})

    def __init__(self, model: Any, refiner: Any, strength: float) -> None:
        if not (hasattr(model, "__call__") and callable(model.__call__)):
            raise TypeError("Model must have a callable __call__ method to be wrapped.")
        self.model = model
        self.refiner = refiner
        self.refiner.set_progress_bar_config(disable=True)
        self.original_pipe_call = self.model.__call__
        self.strength = strength
        # Store device for placing tensors if needed
        self.device = getattr(model, "device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    def _wrapped_pipe_call(self, *args, **kwargs) -> Any:
        """
        Wrap the pipeline call to apply img2img refinement to the output.

        Runs the original call, then refines every output image individually with the
        refiner pipeline. All positional and keyword arguments of the original call are
        forwarded to the refiner, except ``num_images_per_prompt``, ``image`` and
        ``strength``.

        Parameters
        ----------
        *args : tuple
            Positional arguments for the original pipeline call (e.g., prompt).
        **kwargs : dict
            Keyword arguments for the original pipeline call.

        Returns
        -------
        Any
            The pipeline output containing images refined via img2img. The original
            output is returned untouched when it carries no images.
        """
        # Execute the original call (e.g., text-to-image)
        output = self.original_pipe_call(*args, **kwargs)

        # `len()` instead of truthiness: array/tensor outputs raise on `not images`.
        images = getattr(output, "images", None)
        if images is None or len(images) == 0:
            return output

        refiner_kwargs = {key: value for key, value in kwargs.items() if key not in self._NOT_FORWARDED}

        # Disable cache helper if it exists during refinement
        cache_helper = getattr(self.model, "cache_helper", None)
        if cache_helper is not None and hasattr(cache_helper, "disable"):
            cache_helper.disable()

        refined_images = []
        try:
            for img in images:
                refined = self.refiner(*args, image=img, strength=self.strength, **refiner_kwargs)
                if refined is not None and hasattr(refined, "images"):
                    refined_images.extend(refined.images)
                else:
                    # Keep the original image if the refiner produced nothing usable
                    refined_images.append(img)
        finally:
            # Re-enable the cache helper even when the refiner raises, so a failed
            # refinement does not leave the wrapped model with caching switched off.
            if cache_helper is not None and hasattr(cache_helper, "enable"):
                cache_helper.enable()

        # Replace original images with refined ones in the output object
        output.images = refined_images

        return output

    def enable(self) -> None:
        """
        Enable the img2img refinement by replacing the pipeline call.

        The replacement is an instance attribute, so it only affects callers that look
        up ``model.__call__`` explicitly; implicit ``model(...)`` resolves ``__call__``
        on the type. NOTE(review): confirm the surrounding framework calls
        ``model.__call__(...)`` explicitly.
        """
        self.model.__call__ = self._wrapped_pipe_call

    def disable(self) -> None:
        """Disable refinement by restoring the original pipeline call."""
        self.model.__call__ = self.original_pipe_call


# --- patch file boundary (header retained from the original diff) ---
# diff --git a/src/pruna/algorithms/upscale.py b/src/pruna/algorithms/upscale.py
# new file mode 100644
# index 00000000..1c2b1f0f
# --- /dev/null
# +++ b/src/pruna/algorithms/upscale.py
# @@ -0,0 +1,281 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import sys
import tempfile
import types
from collections.abc import Iterable
from typing import Any

import numpy as np
from ConfigSpace import Constant
from PIL import Image

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.engine.model_checks import is_diffusers_pipeline
from pruna.engine.save import SAVE_FUNCTIONS


class RealESRGAN(PrunaAlgorithmBase):
    """
    Implement Real-ESRGAN upscaling for images.

    This enhancer applies the Real-ESRGAN model to upscale images produced by
    diffusion models or other image generation pipelines.

    Attributes
    ----------
    algorithm_name : str
        The name identifier for this algorithm.
    group_tags : list[AlgorithmTag]
        The algorithm groups this algorithm belongs to (``ENHANCER``).
    save_fn : SAVE_FUNCTIONS
        The save strategy; the enhancer is re-applied after loading.
    references : dict[str, str]
        Dictionary containing references to the original paper and implementation.
    tokenizer_required : bool
        Whether a tokenizer is required for this enhancer.
    processor_required : bool
        Whether a processor is required for this enhancer.
    runs_on : list[str]
        The devices this enhancer is declared to run on.
    dataset_required : bool
        Whether a dataset is required for this enhancer.
    compatible_before : Iterable[str | AlgorithmTag]
        Algorithms (or algorithm groups) that may be applied before this enhancer.
    """

    algorithm_name: str = "realesrgan_upscale"
    group_tags: list[AlgorithmTag] = [AlgorithmTag.ENHANCER]  # type: ignore[attr-defined]
    save_fn = SAVE_FUNCTIONS.reapply
    references: dict[str, str] = {
        "Paper": "https://arxiv.org/abs/2107.10833",
        "GitHub": "https://github.com/xinntao/Real-ESRGAN",
    }
    tokenizer_required: bool = False
    processor_required: bool = False
    runs_on: list[str] = ["cpu", "cuda", "accelerate"]
    dataset_required: bool = False
    compatible_before: Iterable[str | AlgorithmTag] = [
        AlgorithmTag.CACHER,
        "torch_compile",
        "stable_fast",
        "hqq_diffusers",
        "diffusers_int8",
        "torchao",
        "qkv_diffusers",
        "ring_attn",
    ]

    def get_hyperparameters(self) -> list:
        """
        Get the hyperparameters for the RealESRGAN enhancer.

        All parameters are declared as constants.

        Returns
        -------
        list
            A list of hyperparameters for the RealESRGAN model, including:
            - outscale: Output scaling factor
            - tile: Tile size for processing large images (0 means no tiling)
            - tile_pad: Padding size for tiles to avoid boundary artifacts
            - pre_pad: Padding before processing
            - face_enhance: Whether to enhance faces specifically
              (declared here, but not read anywhere in this module)
            - fp32: Whether to use FP32 precision instead of FP16
            - netscale: Network scaling factor
        """
        return [
            Constant("outscale", value=4),
            Constant("tile", value=0),
            Constant("tile_pad", value=10),
            Constant("pre_pad", value=0),
            Constant("face_enhance", value=False),
            Constant("fp32", value=False),
            Constant("netscale", value=4),
        ]

    def model_check_fn(self, model: Any) -> bool:
        """
        Check if the model is a (non-video) diffusers pipeline.

        Parameters
        ----------
        model : Any
            The model instance to check.

        Returns
        -------
        bool
            The result of ``is_diffusers_pipeline(model, include_video=False)``.
        """
        return is_diffusers_pipeline(model, include_video=False)

    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
        """
        Apply RealESRGAN upscaling to the model.

        This method sets up the RealESRGAN model and integrates it with the
        provided model by wrapping the model's output function to apply
        upscaling automatically.

        Parameters
        ----------
        model : Any
            The model to enhance with RealESRGAN upscaling. This is typically
            an image generation model like a diffusion model.
        smash_config : SmashConfigPrefixWrapper
            The configuration for enhancement, containing parameters like
            tile size, padding values, and precision options.

        Returns
        -------
        Any
            The model with RealESRGAN upscaling applied, which will automatically
            upscale any images produced by the model.
        """
        imported_modules = self.import_algorithm_packages()

        # The downloaded checkpoint is the x4 model, hence the fixed scale=4 architecture.
        rrdb_net = imported_modules["RRDBNet"](
            num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4
        )
        model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"

        # NOTE(review): `prefix=` is kept as in the original. With an absolute cache path this
        # presumably creates a sibling directory named "<cache_dir><random>", not a directory
        # inside the cache; confirm whether `dir=` was intended.
        with tempfile.TemporaryDirectory(prefix=str(smash_config["cache_dir"])) as temp_dir:
            model_path = imported_modules["load_file_from_url"](model_url, model_dir=temp_dir, progress=True)

            # Build the upsampler while the temporary directory still exists: the checkpoint
            # at `model_path` is deleted as soon as the `with` block exits.
            upsampler = imported_modules["RealESRGANer"](
                scale=smash_config["netscale"],
                model_path=model_path,
                model=rrdb_net,
                tile=smash_config["tile"],
                tile_pad=smash_config["tile_pad"],
                pre_pad=smash_config["pre_pad"],
                half=not smash_config["fp32"],
                gpu_id=0,
            )

        model.upscale_helper = UpscaleHelper(model, upsampler, outscale=smash_config["outscale"])
        model.upscale_helper.enable()

        return model

    def import_algorithm_packages(self) -> dict[str, Any]:
        """
        Import the necessary packages for the RealESRGAN algorithm.

        This method imports the RRDBNet architecture, the download utility and the
        RealESRGANer implementation. When ``torchvision.transforms.functional_tensor``
        cannot be imported, a minimal stand-in module exposing ``rgb_to_grayscale`` is
        registered first so that the imports below can resolve it.

        Returns
        -------
        dict[str, Any]
            Dictionary containing the imported modules, with keys:
            - 'RealESRGANer': The main RealESRGAN implementation
            - 'RRDBNet': The neural network architecture used by RealESRGAN
            - 'load_file_from_url': Utility function to download model weights
        """
        shim_name = "torchvision.transforms.functional_tensor"
        try:
            # Prefer the real module when the installed torchvision still ships it, so an
            # existing full module is never replaced by the one-function stand-in.
            import torchvision.transforms.functional_tensor  # noqa: F401
        except ImportError:
            from torchvision.transforms.functional import rgb_to_grayscale

            functional_tensor = types.ModuleType(shim_name)
            functional_tensor.rgb_to_grayscale = rgb_to_grayscale  # type: ignore
            # Register the stand-in so later imports of the missing module succeed
            sys.modules[shim_name] = functional_tensor

        from basicsr.archs.rrdbnet_arch import RRDBNet
        from basicsr.utils.download_util import load_file_from_url
        from realesrgan import RealESRGANer

        return {"RealESRGANer": RealESRGANer, "RRDBNet": RRDBNet, "load_file_from_url": load_file_from_url}


class UpscaleHelper:
    """
    Helper class for RealESRGAN upscaling.

    This class provides functionality to integrate RealESRGAN upscaling with
    any model by wrapping the model's call method to automatically apply
    upscaling to the output images.

    Parameters
    ----------
    model : Any
        The model to enhance with upscaling.
    upsampler : Any
        The RealESRGAN upsampler instance that performs the actual upscaling.
    outscale : float | None
        Final output scale forwarded to ``upsampler.enhance``. ``None`` (default)
        lets the upsampler use its own scale.
    """

    def __init__(self, model: Any, upsampler: Any, outscale: float | None = None) -> None:
        self.model = model
        self.upsampler = upsampler
        self.outscale = outscale
        self.original_pipe_call = self.model.__call__

    def _upscale(self, image: Any) -> Image.Image:
        """
        Upscale a single image with the upsampler.

        Parameters
        ----------
        image : Any
            The image to upscale (anything ``np.array`` can convert, e.g. a PIL image).

        Returns
        -------
        Image.Image
            The upscaled image.
        """
        pixels = np.array(image)
        # RealESRGANer.enhance follows the OpenCV convention (BGR in, BGR out), whereas PIL
        # arrays are RGB. Swap the channels of 3-channel images on the way in and out so the
        # network sees the channel order it expects. Other layouts are passed through as-is.
        is_three_channel = pixels.ndim == 3 and pixels.shape[2] == 3
        if is_three_channel:
            pixels = np.ascontiguousarray(pixels[:, :, ::-1])
        upscaled = self.upsampler.enhance(pixels, outscale=self.outscale)[0]
        if is_three_channel:
            upscaled = np.ascontiguousarray(upscaled[:, :, ::-1])
        return Image.fromarray(upscaled)

    def _wrapped_pipe_call(self, *args, **kwargs) -> Any:
        """
        Wrap the pipeline call to apply upscaling to the output.

        This method intercepts calls to the model, runs the original call,
        and then applies RealESRGAN upscaling to the output images before
        returning them.

        Parameters
        ----------
        *args : tuple
            Positional arguments to pass to the original pipeline call.
        **kwargs : dict
            Keyword arguments to pass to the original pipeline call.

        Returns
        -------
        Any
            The pipeline output with its images replaced by their upscaled versions.
            The original output is returned untouched when it carries no images.
        """
        output = self.original_pipe_call(*args, **kwargs)

        images = getattr(output, "images", None)
        if images is None or len(images) == 0:
            return output

        output.images = [self._upscale(image) for image in images]
        return output

    def enable(self) -> None:
        """
        Enable the RealESRGAN upscaling by replacing the pipeline call.

        The replacement is an instance attribute, so it only affects callers that look
        up ``model.__call__`` explicitly; implicit ``model(...)`` resolves ``__call__``
        on the type. NOTE(review): confirm the surrounding framework calls
        ``model.__call__(...)`` explicitly.
        """
        self.model.__call__ = self._wrapped_pipe_call

    def disable(self) -> None:
        """
        Disable the RealESRGAN upscaling by restoring the original pipeline call.

        This method restores the model's original __call__ method, effectively
        disabling the automatic upscaling of outputs from the model.
        """
        self.model.__call__ = self.original_pipe_call


# --- patch file boundary (header retained from the original diff) ---
# diff --git a/tests/algorithms/testers/denoise.py b/tests/algorithms/testers/denoise.py
# new file mode 100644
# index 00000000..3d30ce47
# --- /dev/null
# +++ b/tests/algorithms/testers/denoise.py
# @@ -0,0 +1,13 @@
from pruna.algorithms.denoise import Img2ImgDenoise

from .base_tester import AlgorithmTesterBase


class TestDenoise(AlgorithmTesterBase):
    """Test the Denoise algorithm."""

    models = ["sd_tiny_random"]
    reject_models = ["opt_tiny_random"]
    allow_pickle_files = False
    algorithm_class = Img2ImgDenoise
    metrics = ["lpips"]


# --- patch file boundary (header retained from the original diff) ---
# diff --git a/tests/algorithms/testers/upscale.py b/tests/algorithms/testers/upscale.py
# new file mode 100644
# index 00000000..a1020fed
# --- /dev/null
# +++ b/tests/algorithms/testers/upscale.py
# @@ -0,0 +1,16 @@
import pytest

from pruna.algorithms.upscale import RealESRGAN

from .base_tester import AlgorithmTesterBase


@pytest.mark.cuda
class TestUpscale(AlgorithmTesterBase):
    """Test the Upscale algorithm."""

    models = ["sd_tiny_random"]
    reject_models = ["opt_tiny_random"]
    allow_pickle_files = False
    algorithm_class = RealESRGAN
    metrics = ["cmmd"]