diff --git a/flashinfer/comm/cuda_ipc.py b/flashinfer/comm/cuda_ipc.py
index e85c9f26e8..46c7b2bd6b 100644
--- a/flashinfer/comm/cuda_ipc.py
+++ b/flashinfer/comm/cuda_ipc.py
@@ -17,9 +17,8 @@ import ctypes
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
-
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from paddle.base.core import ProcessGroup
 
 # NOTE(Zihao): we should use cuda-python instead of ctypes cuda runtime bindings.
 # However, cuda-python's API is not stable yet, so we use ctypes bindings instead.
@@ -205,11 +204,18 @@ def create_shared_buffer(
     handle = cudart.cudaIpcGetMemHandle(pointer)
     if group is None:
         group = dist.group.WORLD
-    world_size = dist.get_world_size(group=group)
+    # world_size = dist.get_world_size(group=group)
     rank = dist.get_rank(group=group)
-    handles = [None] * world_size
-    dist.all_gather_object(handles, handle, group=group)
-    handles = [None] * world_size
+    # handles = [None] * world_size
+    # dist.all_gather_object(handles, handle, group=group)
+    # handles = [None] * world_size
+    # dist.all_gather_object(handles, handle, group=group)
+
+    # The paddle and torch frameworks behave inconsistently here, so the
+    # following code is used instead.
+    # TODO(bingoo): this has been fixed by
+    # https://github.com/PaddlePaddle/Paddle/pull/77152.
+    handles = [None]
     dist.all_gather_object(handles, handle, group=group)
 
     pointers: List[int] = []
diff --git a/flashinfer/comm/nvshmem_allreduce.py b/flashinfer/comm/nvshmem_allreduce.py
index 83797d42d8..2ecf268146 100644
--- a/flashinfer/comm/nvshmem_allreduce.py
+++ b/flashinfer/comm/nvshmem_allreduce.py
@@ -17,7 +17,7 @@ from typing import Optional
 
 import torch
-from torch.distributed import ProcessGroup
+from paddle.base.core import ProcessGroup
 
 from .nvshmem import get_nvshmem_module
 
diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py
index 31d6f51892..82f43a515c 100644
--- a/flashinfer/comm/trtllm_ar.py
+++ b/flashinfer/comm/trtllm_ar.py
@@ -22,7 +22,7 @@ import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from paddle.base.core import ProcessGroup
 
 from ..jit.comm import gen_trtllm_comm_module
 from ..utils import register_custom_op, round_up
@@ -602,9 +602,13 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
         print(f"Rank {tp_rank} workspace[{i}] {hex(workspace[i])}")
 
     # Store workspace pointers in device tensor
-    workspace_tensor = torch.tensor(
-        workspace, dtype=torch.int64, device=torch.device("cuda")
-    )
+    # workspace_tensor = torch.tensor(
+    #     workspace, dtype=torch.int64, device=torch.device("cuda")
+    # )
+
+    # There is a bug in the paddle framework when device="cuda" is specified,
+    # so it is worked around here by omitting the device argument.
+    workspace_tensor = torch.tensor(workspace, dtype=torch.int64)
 
     dist.barrier(group=group)  # must sync after create_workspace
 
diff --git a/tests/comm/test_trtllm_allreduce_fusion_paddle.py b/tests/comm/test_trtllm_allreduce_fusion_paddle.py
new file mode 100644
index 0000000000..7db7d9cbd1
--- /dev/null
+++ b/tests/comm/test_trtllm_allreduce_fusion_paddle.py
@@ -0,0 +1,167 @@
+import socket
+import pytest
+
+import paddle
+import paddle.distributed as dist_pp
+
+paddle.enable_compat()
+import flashinfer.comm as comm
+
+import os
+import numpy as np
+
+# test parameters
+token_num = 128
+hidden_dim = 1024
+dtype = paddle.float16
+pattern_code = comm.AllReduceFusionPattern.kAllReduce
+layout_code = comm.QuantizationSFLayout.LINEAR
+launch_with_pdl = False
+use_oneshot = True
+trigger_completion_at_end = True
+fp32_acc = False
+
+
+def kernel(workspace_tensor, rank, world_size):
+    device = f"cuda:{rank}"
+    message_size = token_num * hidden_dim
+    dtype = paddle.float16
+    # Create input data
+    allreduce_in = paddle.randn(message_size, dtype=dtype, device=device)
+    # allreduce_in_clone = allreduce_in.clone()
+    all_reduce_out = paddle.zeros(message_size, dtype=dtype, device=device)
+
+    # Additional required parameters
+    residual_in = paddle.randn(message_size, dtype=dtype, device=device)
+    residual_out = paddle.zeros(message_size, dtype=dtype, device=device)
+    norm_out = paddle.zeros(message_size, dtype=dtype, device=device)
+    quant_out = paddle.zeros(message_size, dtype=dtype, device=device)
+    scale_out = paddle.zeros(
+        message_size // 16, dtype=dtype, device=device
+    )  # SF_VEC_SIZE = 16
+    rms_gamma = paddle.randn(hidden_dim, dtype=dtype, device=device)
+    rms_eps = 1e-3
+    scale_factor = paddle.tensor(0.5, dtype=paddle.float32, device=device)
+
+    # Run fusion operation
+    print("Running fusion operation...")
+    comm.trtllm_allreduce_fusion(
+        allreduce_in=allreduce_in,
+        world_size=world_size,
+        world_rank=rank,
+        token_num=token_num,
+        hidden_dim=hidden_dim,
+        workspace_ptrs=workspace_tensor,
+        launch_with_pdl=launch_with_pdl,
+        use_oneshot=use_oneshot,
+        trigger_completion_at_end=trigger_completion_at_end,
+        fp32_acc=fp32_acc,
+        pattern_code=pattern_code,
+        allreduce_out=all_reduce_out,
+        residual_in=residual_in,
+        residual_out=residual_out,
+        norm_out=norm_out,
+        quant_out=quant_out,
+        scale_out=scale_out,
+        rms_gamma=rms_gamma,
+        rms_eps=rms_eps,
+        scale_factor=scale_factor,
+        layout_code=layout_code,
+    )
+
+    paddle.cuda.synchronize()
+
+    return allreduce_in, all_reduce_out
+
+
+def _run_simple_worker(world_size, rank, distributed_init_port):
+    # Set up the paddle distributed environment
+    # paddle.compat.enable_torch_proxy()
+    # Set all required environment variables
+    os.environ["FLAGS_SELECTED_GPUS"] = str(rank)  # important: selects the GPU for this rank
+    os.environ["PADDLE_TRAINER_ID"] = str(rank)
+    os.environ["PADDLE_TRAINERS_NUM"] = str(world_size)
+    os.environ["PADDLE_RANK_IN_NODE"] = str(rank)
+
+    # Build endpoint list
+    endpoints = ",".join(
+        [f"127.0.0.1:{distributed_init_port + i + 10}" for i in range(world_size)]
+    )
+    os.environ["PADDLE_TRAINER_ENDPOINTS"] = endpoints
+    os.environ["PADDLE_CURRENT_ENDPOINT"] = (
+        f"127.0.0.1:{distributed_init_port + rank + 10}"
+    )
+    # Set NCCL related environment variables (optional but recommended)
+    os.environ["FLAGS_SYNC_NCCL_ALLREDUCE"] = "1"
+
+    # Set device
+    paddle.set_device(f"gpu:{rank}")
+
+    # Initialize distributed environment
+    dist_pp.init_parallel_env()
+    group_pp = dist_pp.get_group()
+
+    try:
+        # Create workspace
+        ipc_handles, workspace_tensor = (
+            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
+                rank,
+                world_size,
+                token_num,
+                hidden_dim,
+                group=group_pp,
+                use_fp32_lamport=False,
+            )
+        )
+
+        dist_pp.barrier(group=group_pp)
+
+        # Run fusion operation
+        allreduce_in_clone, all_reduce_out = kernel(workspace_tensor, rank, world_size)
+
+        # Calculate reference result
+        dist_pp.all_reduce(allreduce_in_clone, group=group_pp)
+        ref_allreduce_out = allreduce_in_clone.clone()
+
+        # Verify results
+        tolerance = 8e-2
+        np.testing.assert_allclose(
+            all_reduce_out.numpy(), ref_allreduce_out.numpy(), atol=tolerance, rtol=1e-2
+        )
+
+        print(f"Rank {rank}: Test passed!")
+
+    finally:
+        dist_pp.barrier(group=group_pp)
+        comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group_pp)
+        dist_pp.destroy_process_group(group=group_pp)
+
+
+def get_open_port() -> int:
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("127.0.0.1", 0))
+        return s.getsockname()[1]
+
+
+def test_trtllm_allreduce_fusion_simple():
+    # Fixed test parameters
+    world_size = 2
+
+    paddle.manual_seed(42)
+    paddle.cuda.manual_seed_all(42)
+
+    available_gpus = paddle.cuda.device_count()
+    if world_size > available_gpus:
+        pytest.skip(f"Requires {world_size} GPUs, but only {available_gpus} available")
+
+    distributed_init_port = get_open_port()
+    rank = dist_pp.get_rank()
+    _run_simple_worker(world_size, rank, distributed_init_port)
+
+    print("Simple allreduce fusion test: passed")
+
+
+# test cmd: python -m paddle.distributed.launch --log_dir=log --devices=0,1
+#     ./test_trtllm_allreduce_fusion_paddle.py
+if __name__ == "__main__":
+    test_trtllm_allreduce_fusion_simple()
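For context on the `handles = [None]` change in `create_shared_buffer`: `torch.distributed.all_gather_object` expects a caller-allocated list of length `world_size` and fills it in place, which is the contract the now commented-out code relied on; the in-source TODO notes that Paddle's implementation did not follow this contract until PaddlePaddle/Paddle#77152. The minimal sketch below illustrates only the torch-side contract and is not part of this patch; the script name and the `gloo` backend are illustrative choices so it can run without GPUs.

# Minimal sketch of the torch.distributed.all_gather_object contract that the
# original flashinfer code relied on (not part of this patch). Run under
# torchrun, e.g.:
#   torchrun --nproc_per_node=2 all_gather_object_demo.py
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # gloo chosen only so no GPU is needed
world_size = dist.get_world_size()
rank = dist.get_rank()

# torch requires the output list to be pre-sized to world_size;
# the call then fills it in place with one object per rank.
gathered = [None] * world_size
dist.all_gather_object(gathered, f"handle-from-rank-{rank}")
assert len(gathered) == world_size and None not in gathered
print(rank, gathered)

dist.destroy_process_group()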