From 4bd69182be4165d9f441cb863e5bc0e709bc4383 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 12 Dec 2025 17:05:35 +0000 Subject: [PATCH 1/2] ring_attn added --- src/pruna/algorithms/qkv_diffusers.py | 1 + src/pruna/algorithms/ring_attn/__init__.py | 18 + src/pruna/algorithms/ring_attn/ring.py | 364 ++++++++++++++++++ .../algorithms/ring_attn/utils/__init__.py | 13 + .../algorithms/ring_attn/utils/ring_utils.py | 144 +++++++ .../ring_attn/utils/server_utils.py | 253 ++++++++++++ .../algorithms/torch_compile/torch_compile.py | 1 + tests/algorithms/testers/ring_distributer.py | 44 +++ 8 files changed, 838 insertions(+) create mode 100644 src/pruna/algorithms/ring_attn/__init__.py create mode 100644 src/pruna/algorithms/ring_attn/ring.py create mode 100644 src/pruna/algorithms/ring_attn/utils/__init__.py create mode 100644 src/pruna/algorithms/ring_attn/utils/ring_utils.py create mode 100644 src/pruna/algorithms/ring_attn/utils/server_utils.py create mode 100644 tests/algorithms/testers/ring_distributer.py diff --git a/src/pruna/algorithms/qkv_diffusers.py b/src/pruna/algorithms/qkv_diffusers.py index 5e5780a6..305a73e0 100644 --- a/src/pruna/algorithms/qkv_diffusers.py +++ b/src/pruna/algorithms/qkv_diffusers.py @@ -50,6 +50,7 @@ class QKVFusing(PrunaAlgorithmBase): "deepcache", "fora", "torch_compile", + "ring_attn", ] def model_check_fn(self, model: Any) -> bool: diff --git a/src/pruna/algorithms/ring_attn/__init__.py b/src/pruna/algorithms/ring_attn/__init__.py new file mode 100644 index 00000000..56ad877f --- /dev/null +++ b/src/pruna/algorithms/ring_attn/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pruna.algorithms.ring_attn.ring import RingAttn +from pruna.algorithms.ring_attn.utils.server_utils import DistributedServer + +__all__ = ["RingAttn", "DistributedServer"] diff --git a/src/pruna/algorithms/ring_attn/ring.py b/src/pruna/algorithms/ring_attn/ring.py new file mode 100644 index 00000000..f363d5d7 --- /dev/null +++ b/src/pruna/algorithms/ring_attn/ring.py @@ -0,0 +1,364 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import contextlib +import functools +from collections.abc import Iterable +from types import ModuleType +from typing import Any, List, Optional, Union + +import torch +import torch.distributed as dist +from ConfigSpace import CategoricalHyperparameter +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +from diffusers.models.transformers.transformer_wan import WanTransformer3DModel +from torch.distributed.tensor.device_mesh import DeviceMesh + +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase +from pruna.algorithms.base.tags import AlgorithmTag +from pruna.algorithms.ring_attn.utils.ring_utils import RingDistributedContext +from pruna.config.hyperparameters import Boolean +from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper +from pruna.engine.save import SAVE_FUNCTIONS + +ring_attention: ModuleType | None = None + +with contextlib.suppress(ImportError): + # see "import_algorithm_packages" for further explanation + import torch.distributed.tensor.experimental._attention as ring_attention + + +class RingAttn(PrunaAlgorithmBase): + """ + Distributed attention on multiple GPUs computation by using the torch native ring attention implementation. + + Each GPU stores only its own slice of Q/K/V and participates in a Ring Attention shuffle that lets every query + attend to every key/value. The result is lower KV-cache/activation memory per GPU and higher arithmetic intensity. + """ + + algorithm_name: str = "ring_attn" + group_tags: list[AlgorithmTag] = [AlgorithmTag.KERNEL] + save_fn = SAVE_FUNCTIONS.reapply + references = { + "Implementation": "https://docs.pytorch.org/tutorials/prototype/context_parallel.html", + "Paper": "https://arxiv.org/pdf/2310.01889", + } + tokenizer_required: bool = False + processor_required: bool = False + runs_on: list[str] = ["cuda"] + dataset_required: bool = False + compatible_before: Iterable[str | AlgorithmTag] = [ + "qkv_diffusers", + ] + compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile"] + + def get_hyperparameters(self) -> list: + """ + Get the hyperparameters for the RingAttn. + + Returns + ------- + list + A list of hyperparameters. + """ + return [ + Boolean( + "convert_to_f32", + default=True, + meta=dict(desc="Allowing intermediate computations in the attention mechanism to be upcast to 32-bit."), + ), + CategoricalHyperparameter( + "rotate_method", + default_value="ALL_TO_ALL", + meta=dict(desc="The method to use for rotating the computations."), + choices=["ALL_TO_ALL", "ALL_GATHER"], + ), + ] + + def model_check_fn(self, model: Any) -> bool: + """ + Check if the model is supported by the RingAttn. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is supported, False otherwise. 
+ """ + if torch.cuda.device_count() < 2: + raise ValueError("RingAttn requires at least 2 GPUs") + + return hasattr(model, "transformer") and isinstance( + model.transformer, (FluxTransformer2DModel, WanTransformer3DModel) + ) + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + + # configure the ring attention hyperparameters + _cp_options = ring_attention._cp_options # type: ignore + _cp_options.convert_to_f32 = smash_config["convert_to_f32"] + _cp_options.enable_load_balance = False + _cp_options.rotate_method = getattr(ring_attention._RotateMethod, smash_config["rotate_method"]) # type: ignore + + wrap_pipeline_call(model, torch.cuda.device_count()) + + mesh = dist.init_device_mesh("cuda", (torch.cuda.device_count(),), mesh_dim_names=("ring_dim",)) + rank = dist.get_rank() + world_size = torch.cuda.device_count() + + if isinstance(model.transformer, FluxTransformer2DModel): + wrap_flux2d_transformer_forward( + model.transformer, + world_size, + smash_config._base_config, + rank, + mesh, + cache_helper=getattr(model, "cache_helper", None), + ) + elif isinstance(model.transformer, WanTransformer3DModel): + wrap_wan3d_transformer_forward(model.transformer, world_size, smash_config._base_config, rank, mesh) + else: + raise ValueError(f"Unsupported transformer type: {type(model.transformer)}") + + return model + + def import_algorithm_packages(self) -> dict[str, Any]: + """ + Import the algorithm packages. + + Returns + ------- + dict[str, Any] + The algorithm packages. + """ + # even though it is a torch import we isolate it, as experimental modules can often change the interface + # we import the package even though we dont use it directly to make sure it is available + # additionally, we can not pass it as module to the distributed setup (not picklable) + # nor as a string (the import massively irritates torch.compile) + # we import it on the top of the file if available + import torch.distributed.tensor.experimental._attention as ring_attention # noqa: F401 + + return dict() + + +def wrap_wan3d_transformer_forward( + model: Any, + world_size: int, + smash_config: Union[SmashConfig, SmashConfigPrefixWrapper], + rank: int, + mesh: DeviceMesh, +) -> Any: + """ + Wrap the transformer forward pass to chunk the inputs and intercept the torch attention function. + + Parameters + ---------- + model : Any + The transformer model to wrap. + world_size : int + The number of GPUs to distribute the model on. + smash_config : SmashConfig + The SmashConfig to use. + rank : int + The rank of the current process. + mesh : DeviceMesh + The mesh to use for the distributed attention. + """ + for i, block in enumerate(model.blocks): + block_original = block.forward + + @functools.wraps(block_original) + def block_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + _block_ref=block, + _original_forward=block_original, + _layer_id=i, + _num_layers=len(model.blocks), + ) -> torch.Tensor: + # on the first layer, we chunk the hidden states + if _layer_id == 0: + hidden_states = hidden_states.chunk(world_size, dim=-2)[rank] + + rotary_emb = rotary_emb.chunk(world_size, dim=-2)[rank] + + # Use compiled version if available, otherwise use original (not the wrapped one!) 
+ forward_to_call = getattr(_block_ref, "compiled_forward", _original_forward) + + with RingDistributedContext(mesh, smash_config): + hidden_states = forward_to_call(hidden_states, encoder_hidden_states, temb, rotary_emb) + + # on the last layer, we sync back the hidden states + if _layer_id == _num_layers - 1: + return sync_tensor(hidden_states, dim=-2, group=dist.distributed_c10d._get_default_group()) + + return hidden_states + + block.original_forward = block_original + block.forward = block_forward.__get__(block) # type: ignore + + +def wrap_pipeline_call(model: Any, world_size: int) -> Any: + """ + Wrap the model forward pass to set up a generator with rank-specific device. + + Parameters + ---------- + model : Any + The model to wrap. + world_size : int + The number of GPUs to distribute the model on. + """ + # Set up generator with rank-specific device, if it is not explicitly specified the different + # processes might sample different seeds, we have to sync this + original_forward = model.__call__ + + @functools.wraps(original_forward) + def new_forward( + *args, + **kwargs, + ): + rank = kwargs.pop("rank") if "rank" in kwargs else dist.get_rank() + if "generator" not in kwargs: + # if we distributed manually, we can not use "dist" to get the rank, in this case we pass the rank ourselves + seed_t = torch.randint(0, torch.iinfo(torch.int64).max, [1], dtype=torch.int64, device=f"cuda:{rank}") + seed_t = sync_tensor(seed_t, dim=0, group=None) + seed_t = seed_t.chunk(world_size, dim=0)[0] + seed = seed_t.item() + seed -= torch.iinfo(torch.int64).min + generator = torch.Generator(f"cuda:{rank}").manual_seed(seed) + kwargs["generator"] = generator + + return original_forward(*args, **kwargs) + + model.__call__ = new_forward # type: ignore + + +def wrap_flux2d_transformer_forward( + model: Any, + world_size: int, + smash_config: Union[SmashConfig, SmashConfigPrefixWrapper], + rank: int, + mesh: DeviceMesh, + cache_helper: Any | None = None, +) -> Any: + """ + Wrap the transformer forward pass to chunk the inputs and intercept the torch attention function. + + Parameters + ---------- + model : Any + The transformer model to wrap. + world_size : int + The number of GPUs to distribute the model on. + smash_config : SmashConfig + The SmashConfig to use. + rank : int + The rank of the current process. + mesh : DeviceMesh + The mesh to use for the distributed attention. + cache_helper : Any | None + The cache helper if one is present in the pipe. 
+ """ + original_forward = model.forward + + @functools.wraps(original_forward) + def new_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + img_ids: torch.Tensor | None = None, + txt_ids: torch.Tensor | None = None, + *args, + **kwargs, + ): + # split all input tensors along the sequence length dimension and get chunk for this process (rank) + # we do the forward pass on two separate chunks and only "sync" when the attention is computed + # for intuition: number of chunks = number of GPUs + hidden_states = hidden_states.chunk(world_size, dim=1)[rank] + encoder_hidden_states = ( + encoder_hidden_states.chunk(world_size, dim=1)[rank] if encoder_hidden_states is not None else None + ) + img_ids = img_ids.chunk(world_size, dim=0)[rank] if img_ids is not None else None + txt_ids = txt_ids.chunk(world_size, dim=0)[rank] if txt_ids is not None else None + + # this context basically intercepts any call to F.scaled_dot_product_attention + # and replaces it with the ring attention implementation + with RingDistributedContext(mesh, smash_config): + output = self.inner_forward( + hidden_states, + encoder_hidden_states, + *args, + img_ids=img_ids, + txt_ids=txt_ids, + **kwargs, + ) + + # before we output the result, we attach the separate chunks together again + sample = output[0] + sample = sync_tensor(sample, dim=-2, group=dist.distributed_c10d._get_default_group()) + return (sample, *output[1:]) + + model.forward = new_forward.__get__(model) # type: ignore + model.inner_forward = original_forward.__get__(model if cache_helper is None else cache_helper) # type: ignore + + +def sync_tensor(tensor: torch.Tensor, dim: int, group: dist.ProcessGroup | None) -> torch.Tensor: + """ + Sync a tensor across a process group. + + Parameters + ---------- + tensor : torch.Tensor + The tensor to sync. + dim : int + The dimension to sync along. + group : dist.ProcessGroup | None + The process group to sync across. + + Returns + ------- + torch.Tensor + The synced tensor. + """ + tensor = tensor.transpose(0, dim).contiguous() + + if group is None: + group = dist.distributed_c10d._get_default_group() + + if isinstance(group, dist.ProcessGroup): + pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = group + else: + pg = group.get_group() + + x_shape = tensor.shape + tensor = tensor.flatten() + x_numel = tensor.numel() # type: ignore + tensor = dist._functional_collectives.all_gather_tensor(tensor, group=pg, gather_dim=0) # type: ignore + if isinstance(tensor, dist._functional_collectives.AsyncCollectiveTensor): + tensor.wait() + x_shape = list(x_shape) # type: ignore + x_shape[0] *= tensor.numel() // x_numel # type: ignore + tensor = tensor.reshape(x_shape) # type: ignore + tensor = tensor.transpose(0, dim) + return tensor diff --git a/src/pruna/algorithms/ring_attn/utils/__init__.py b/src/pruna/algorithms/ring_attn/utils/__init__.py new file mode 100644 index 00000000..38e0d7e5 --- /dev/null +++ b/src/pruna/algorithms/ring_attn/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/pruna/algorithms/ring_attn/utils/ring_utils.py b/src/pruna/algorithms/ring_attn/utils/ring_utils.py new file mode 100644 index 00000000..e07112cb --- /dev/null +++ b/src/pruna/algorithms/ring_attn/utils/ring_utils.py @@ -0,0 +1,144 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import contextlib +from types import ModuleType +from typing import Union + +import torch +import torch.distributed as dist +from torch.nn.functional import scaled_dot_product_attention +from torch.overrides import TorchFunctionMode + +from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper + +ring_attention: ModuleType | None = None + +with contextlib.suppress(ImportError): + # see "import_algorithm_packages" for further explanation + import torch.distributed.tensor.experimental._attention as ring_attention + + +class LocalFunc(torch.autograd.Function): + """ + Local dummy function to mark the ring attention forwarding as an autograd function. + + Parameters + ---------- + *args : Any + The arguments to the autograd class construction. + **kwargs : Any + The keyword arguments to the autograd class construction. + """ + + @staticmethod + def forward(cls, *args, **kwargs): + """ + Forward pass for the ring attention implementation. + + Parameters + ---------- + *args : Any + The arguments to the forward pass. + **kwargs : Any + The keyword arguments to the forward pass. + + Returns + ------- + torch.Tensor + The output tensor. + """ + # when distributing manually, it seems this is overwritten but we have to ensure it is False + ring_attention._cp_options.enable_load_balance = False + # FUTURE: investigate if we can use the efficient implementation here and if it makes sense + return ring_attention._scaled_dot_product_ring_flash_attention(*args, **kwargs)[:2] + + @staticmethod + def backward(cls, *args, **kwargs): + """ + Backward pass for ring attention implementation of flash attention. + + Parameters + ---------- + *args : Any + The arguments to the backward pass. + **kwargs : Any + The keyword arguments to the backward pass. + + Returns + ------- + torch.Tensor + The gradient of the output tensor. + """ + return ring_attention._scaled_dot_product_ring_flash_attention_backward(*args, **kwargs) + + +class RingDistributedContext(TorchFunctionMode): + """ + Intercept *every* call to F.scaled_dot_product_attention and routes it through the ring implementation. + + Parameters + ---------- + device_mesh : dist.DeviceMesh + The device mesh to use for the distributed attention. + smash_config : Union[SmashConfig, SmashConfigPrefixWrapper] + The SmashConfig to use. 
+ """ + + def __init__(self, device_mesh: dist.DeviceMesh, smash_config: Union[SmashConfig, SmashConfigPrefixWrapper]): + super().__init__() + self.pg = device_mesh + self.smash_config = smash_config + + def __torch_function__(self, func, types, args=(), kwargs=None): + """ + Intercept the scaled_dot_product_attention function and route it through the ring implementation. + + Parameters + ---------- + func : Callable + The function to intercept. + types : Tuple[type] + The types of the arguments. + args : Tuple + The arguments to the function. + kwargs : Dict + The keyword arguments to the function. + + Returns + ------- + Any + The result of the function. + """ + kwargs = {} if kwargs is None else kwargs + + if func is torch.Tensor.unflatten: + return torch.unflatten(*args, **kwargs) + + if func is scaled_dot_product_attention: + query, key, value, *extra = args + attn_mask = kwargs.pop("attn_mask", None) + if attn_mask is not None: + raise ValueError("Ring attention path does not support `attn_mask`; use causal masking instead.") + dropout_p = kwargs.get("dropout_p", 0.0) + is_causal = kwargs.get("is_causal", False) + scale = kwargs.get("scale", None) + + out, _ = LocalFunc.apply(self.pg, query, key, value, dropout_p, is_causal, scale) + return out.to(query.dtype) + + # fall back to default behavior + return func(*args, **kwargs) diff --git a/src/pruna/algorithms/ring_attn/utils/server_utils.py b/src/pruna/algorithms/ring_attn/utils/server_utils.py new file mode 100644 index 00000000..4bc32efa --- /dev/null +++ b/src/pruna/algorithms/ring_attn/utils/server_utils.py @@ -0,0 +1,253 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import contextlib +import multiprocessing as mp +import os +import socket + +import torch +import torch.distributed as dist + +from pruna.engine.utils import safe_memory_cleanup +from pruna.logging.logger import pruna_logger + +# Global placeholder for shared pipeline (CPU) +global_pipe = None + + +@atexit.register +def _cleanup_distributed(): + """Clean up the distributed process group even in uncontrolled system exits.""" + if dist.is_initialized(): + dist.destroy_process_group() + + +class DistributedServer: + """ + Wrapper to distribute the model across multiple GPUs. + + This is our way of avoiding to run a script with "torchrun", we do the setup ourselves and spawn processes with + copies of the model manually. + + Parameters + ---------- + pipe : Any + The pipeline to distribute. + smash_config : Any + Configuration for smashing the pipeline. 
+ + Examples + -------- + >>> server = DistributedServer(pipe, wrap_fn, config) + >>> server.start() + >>> result = server("prompt") + """ + + def __init__(self, pipe, smash_config): + self.pipe = pipe + self.pool = None + self.world_size = torch.cuda.device_count() + self.smash_config = smash_config + self.device = "cuda" + if self.world_size < 2: + raise ValueError("Distributers require at least 2 GPUs") + + def set_env_vars(self): + """Set the environment variables for torch multi-processing.""" + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + os.environ.setdefault("MASTER_PORT", str(s.getsockname()[1])) + + @staticmethod + def _static_init_worker(world_size, pipe, smash_config): + """ + Static method version of _init_worker for better pickle compatibility. + + Parameters + ---------- + world_size : int + Total number of processes in the distributed setup + pipe : Any + The pipeline object to initialize + """ + global global_pipe + # Determine rank + rank = mp.current_process()._identity[0] - 1 + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + # Bind this process to its GPU before CUDA operations + torch.cuda.set_device(rank) + # Move the unpickled CPU pipeline to the current GPU + gpu_pipe = pipe.to(f"cuda:{rank}") + gpu_pipe.smashed = True + + from pruna import smash + + global_pipe = smash(gpu_pipe, smash_config) + + @staticmethod + def _static_cleanup_worker(task_data): + """ + Static method to perform cleanup operations on workers before termination. + + This method is called on each worker process before the pool is destroyed. + You can customize this method to perform any specific cleanup operations + needed for your workers. + + Returns + ------- + bool + True if cleanup was successful, False otherwise. + + Examples + -------- + >>> success = DistributedServer._static_cleanup_worker() + >>> print(f"Worker cleanup successful: {success}") + """ + global global_pipe + + if global_pipe is None: + return + + try: + rank = torch.cuda.current_device() + pruna_logger.info(f"Cleaning up worker on rank {rank}") + + global_pipe.to("cpu") + global_pipe.destroy() + safe_memory_cleanup() + + pruna_logger.info(f"Worker cleanup completed successfully on rank {rank}") + global_pipe = None + return True + + except Exception as e: + pruna_logger.error(f"Error during worker cleanup: {e}") + return False + + @staticmethod + def _static_process_task(task_data): + """ + Static method version of _process_task for better pickle compatibility. + + Parameters + ---------- + task_data : tuple + Tuple containing (rank, args, kwargs) where: + - rank: int, the GPU rank to use + - args: tuple, positional arguments to pass to global_pipe + - kwargs: dict, keyword arguments to pass to global_pipe + + Returns + ------- + Any + The output from the global pipe, processed based on rank. 
+ + Examples + -------- + >>> task_data = (0, ("prompt",), {"num_steps": 20}) + >>> result = DistributedServer._static_process_task(task_data) + """ + rank, args, kwargs = task_data + torch.cuda.set_device(rank) + return global_pipe.__call__(*args, rank=rank, **kwargs) + + def start(self): + """Launch a worker pool, sending the CPU pipeline into each child.""" + if self.pool: + return + + self.device = "cuda" + + pruna_logger.info("Spawning distributed setup...") + pruna_logger.info("Before terminating the current process, call smashed_model.destroy() for proper cleanup.") + + self.set_env_vars() + ctx = mp.get_context("spawn") + + self.pool = ctx.Pool( + processes=self.world_size, + initializer=self._static_init_worker, + initargs=(self.world_size, self.pipe, self.smash_config), + ) + + def __call__(self, *args, **kwargs): + """ + Dispatch prompt across ranks and return rank-0's image. + + Parameters + ---------- + *args : Any + Positional arguments to pass to the pipeline + **kwargs : Any + Keyword arguments to pass to the pipeline + + Returns + ------- + Any + The result from rank 0 + """ + if not self.pool: + raise RuntimeError("Runner not started. Use start() or context manager.") + tasks = [(rank, args, kwargs) for rank in range(self.world_size)] + results = self.pool.map(self._static_process_task, tasks) + return results[0] + + def destroy(self): + """Cleanly shut down the worker pool.""" + if self.pool: + pruna_logger.info("Triggering cleanup on all workers...") + self.device = "cpu" + self.pool.map(self._static_cleanup_worker, range(self.world_size)) + + # Close the pool + self.pool.close() + for worker in self.pool._pool: + worker.join(timeout=10) + self.pool = None + + if dist.is_initialized(): + dist.destroy_process_group() + + def __getattr__(self, attr): + """ + Forward all other attributes to the global pipe. + + Parameters + ---------- + attr : str + The attribute name to forward + + Returns + ------- + Any + The attribute value from the pipeline + """ + if attr == "device": + return "cuda" if self.pool else "cpu" + return getattr(self.pipe, attr) + + def save_pretrained(self, path: str): + """ + Offer interface to save the distributed model to inform the user that it is not supported. + + Parameters + ---------- + path : str + The path to save the distributed model. + """ + raise NotImplementedError("Saving a distributed model is not supported at the moment.") diff --git a/src/pruna/algorithms/torch_compile/torch_compile.py b/src/pruna/algorithms/torch_compile/torch_compile.py index 62c4d086..05f9573f 100644 --- a/src/pruna/algorithms/torch_compile/torch_compile.py +++ b/src/pruna/algorithms/torch_compile/torch_compile.py @@ -69,6 +69,7 @@ class TorchCompile(PrunaAlgorithmBase): "flash_attn3", "deepcache", "fora", + "ring_attn", ] def get_hyperparameters(self) -> list: diff --git a/tests/algorithms/testers/ring_distributer.py b/tests/algorithms/testers/ring_distributer.py new file mode 100644 index 00000000..e1d54ec1 --- /dev/null +++ b/tests/algorithms/testers/ring_distributer.py @@ -0,0 +1,44 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from pruna.algorithms.ring_attn.ring import RingAttn +from pruna.engine.pruna_model import PrunaModel + +from .base_tester import AlgorithmTesterBase + + +@pytest.mark.distributed +class TestRingAttn(AlgorithmTesterBase): + """Test the RingAttn algorithm.""" + + models = ["flux_tiny_random", "wan_tiny_random"] + reject_models = ["opt_tiny_random"] + allow_pickle_files = False + algorithm_class = RingAttn + metrics = ["psnr"] + + def post_smash_hook(self, model: PrunaModel) -> None: + """Post-smash hook.""" + assert hasattr(model, "pool") + + def execute_load(self): + """Overwrite model loading as this is not supported for distributed models.""" + pass + + @classmethod + def execute_save(cls, smashed_model: PrunaModel): + """Overwrite model saving as this is not supported for distributed models.""" + pass From b0047faaa0a9b3985a7462d965bd3069512b84b4 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 12 Dec 2025 17:23:29 +0000 Subject: [PATCH 2/2] fixed typo --- src/pruna/algorithms/ring_attn/ring.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pruna/algorithms/ring_attn/ring.py b/src/pruna/algorithms/ring_attn/ring.py index f363d5d7..f88a12d4 100644 --- a/src/pruna/algorithms/ring_attn/ring.py +++ b/src/pruna/algorithms/ring_attn/ring.py @@ -62,6 +62,7 @@ class RingAttn(PrunaAlgorithmBase): dataset_required: bool = False compatible_before: Iterable[str | AlgorithmTag] = [ "qkv_diffusers", + "padding_pruning", ] compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile"]
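
Editor's note: for orientation, a minimal usage sketch of the pieces introduced by these patches follows. It is not part of the diff. It assumes at least two CUDA devices, a diffusers Flux pipeline kept on CPU, and that "ring_attn" is enabled on the SmashConfig under its kernel group; the group key, the prefixed hyperparameter name, and the checkpoint id are illustrative assumptions, since the wiring that makes smash() return a DistributedServer is not shown in this diff. The server is therefore constructed directly from the classes added here.

# Sketch only -- not part of the patch. Assumptions are marked in the comments.
import torch
from diffusers import FluxPipeline  # illustrative pipeline choice

from pruna.config.smash_config import SmashConfig
from pruna.algorithms.ring_attn import DistributedServer

# Keep the pipeline on CPU: DistributedServer pickles it into each spawned worker,
# and every worker moves its own copy to cuda:<rank> and calls smash() there
# (see _static_init_worker in server_utils.py).
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)

smash_config = SmashConfig()
smash_config["kernel"] = "ring_attn"                    # assumed group key for AlgorithmTag.KERNEL
smash_config["ring_attn_rotate_method"] = "ALL_TO_ALL"  # assumed prefixed hyperparameter name; value from the patch

server = DistributedServer(pipe, smash_config)  # raises if fewer than 2 GPUs are visible
server.start()                                  # spawns one worker per GPU; each rank smashes its copy
result = server("a photo of a red panda", num_inference_steps=4)  # returns rank 0's pipeline output
server.destroy()                                # shuts down workers and the process group

The call path mirrors the patch: start() spawns a worker pool, __call__ fans the prompt out to every rank and keeps rank 0's result, and destroy() runs the per-worker cleanup before tearing down the NCCL process group.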