From 3f792975b5ae76520a8af9826f4f16693646b0ba Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Mon, 19 Jan 2026 16:08:04 +0800 Subject: [PATCH 1/3] support all reduce fusion kernel --- flashinfer/comm/cuda_ipc.py | 16 +- flashinfer/comm/nvshmem_allreduce.py | 2 +- flashinfer/comm/trtllm_ar.py | 12 +- .../test_trtllm_allreduce_fusion_paddle.py | 160 ++++++++++++++++++ 4 files changed, 182 insertions(+), 8 deletions(-) create mode 100644 tests/comm/test_trtllm_allreduce_fusion_paddle.py diff --git a/flashinfer/comm/cuda_ipc.py b/flashinfer/comm/cuda_ipc.py index e85c9f26e8..eff838dfb5 100644 --- a/flashinfer/comm/cuda_ipc.py +++ b/flashinfer/comm/cuda_ipc.py @@ -17,9 +17,10 @@ import ctypes from dataclasses import dataclass from typing import Any, Dict, List, Optional - +import paddle +paddle.compat.enable_torch_proxy() import torch.distributed as dist -from torch.distributed import ProcessGroup +from paddle.base.core import ProcessGroup # NOTE(Zihao): we should use cuda-python instead of ctypes cuda runtime bindings. # However, cuda-python's API is not stable yet, so we use ctypes bindings instead. @@ -207,9 +208,14 @@ def create_shared_buffer( group = dist.group.WORLD world_size = dist.get_world_size(group=group) rank = dist.get_rank(group=group) - handles = [None] * world_size - dist.all_gather_object(handles, handle, group=group) - handles = [None] * world_size + # handles = [None] * world_size + # dist.all_gather_object(handles, handle, group=group) + # handles = [None] * world_size + # dist.all_gather_object(handles, handle, group=group) + + # The behavior of the paddle framework and torch framework is inconsistent, + # so the following code is used instead + handles = [] dist.all_gather_object(handles, handle, group=group) pointers: List[int] = [] diff --git a/flashinfer/comm/nvshmem_allreduce.py b/flashinfer/comm/nvshmem_allreduce.py index 83797d42d8..2ecf268146 100644 --- a/flashinfer/comm/nvshmem_allreduce.py +++ b/flashinfer/comm/nvshmem_allreduce.py @@ -17,7 +17,7 @@ from typing import Optional import torch -from torch.distributed import ProcessGroup +from paddle.base.core import ProcessGroup from .nvshmem import get_nvshmem_module diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 31d6f51892..2754387e8f 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -20,9 +20,11 @@ from types import SimpleNamespace from typing import List, Optional, Tuple, Union +import paddle +paddle.compat.enable_torch_proxy() import torch import torch.distributed as dist -from torch.distributed import ProcessGroup +from paddle.base.core import ProcessGroup from ..jit.comm import gen_trtllm_comm_module from ..utils import register_custom_op, round_up @@ -602,8 +604,14 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( print(f"Rank {tp_rank} workspace[{i}] {hex(workspace[i])}") # Store workspace pointers in device tensor + # workspace_tensor = torch.tensor( + # workspace, dtype=torch.int64, device=torch.device("cuda") + # ) + + # There is a bug in the paddle framework when device="CUDA". + # Currently, the bug is being avoided by changing the source code. 
workspace_tensor = torch.tensor( - workspace, dtype=torch.int64, device=torch.device("cuda") + workspace, dtype=torch.int64 ) dist.barrier(group=group) # must sync after create_workspace diff --git a/tests/comm/test_trtllm_allreduce_fusion_paddle.py b/tests/comm/test_trtllm_allreduce_fusion_paddle.py new file mode 100644 index 0000000000..50f06de379 --- /dev/null +++ b/tests/comm/test_trtllm_allreduce_fusion_paddle.py @@ -0,0 +1,160 @@ +import socket +import pytest + +import flashinfer.comm as comm + +import paddle +import paddle.distributed as dist_pp +paddle.compat.enable_torch_proxy() + +import os +import numpy as np + +# test parameters +token_num = 128 +hidden_dim = 1024 +dtype = paddle.float16 +pattern_code = comm.AllReduceFusionPattern.kAllReduce +layout_code = comm.QuantizationSFLayout.LINEAR +launch_with_pdl = False +use_oneshot = True +trigger_completion_at_end = True +fp32_acc = False + +def kernel(workspace_tensor, rank, world_size): + device = f"cuda:{rank}" + message_size = token_num * hidden_dim + dtype = paddle.float16 + # Create input data + allreduce_in = paddle.randn(message_size, dtype=dtype, device=device) + # allreduce_in_clone = allreduce_in.clone() + all_reduce_out = paddle.zeros(message_size, dtype=dtype, device=device) + + # Add missing required parameters + residual_in = paddle.randn(message_size, dtype=dtype, device=device) + residual_out = paddle.zeros(message_size, dtype=dtype, device=device) + norm_out = paddle.zeros(message_size, dtype=dtype, device=device) + quant_out = paddle.zeros(message_size, dtype=dtype, device=device) + scale_out = paddle.zeros(message_size // 16, dtype=dtype, device=device) # SF_VEC_SIZE = 16 + rms_gamma = paddle.randn(hidden_dim, dtype=dtype, device=device) + rms_eps = 1e-3 + scale_factor = paddle.tensor(0.5, dtype=paddle.float32, device=device) + + # Run fusion operation + print("Running fusion operation...") + comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + world_size=world_size, + world_rank=rank, + token_num=token_num, + hidden_dim=hidden_dim, + workspace_ptrs=workspace_tensor, + launch_with_pdl=launch_with_pdl, + use_oneshot=use_oneshot, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=pattern_code, + allreduce_out=all_reduce_out, + residual_in=residual_in, + residual_out=residual_out, + norm_out=norm_out, + quant_out=quant_out, + scale_out=scale_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + scale_factor=scale_factor, + layout_code=layout_code, + ) + + paddle.cuda.synchronize() + + return allreduce_in, all_reduce_out + +def _run_simple_worker(world_size, rank, distributed_init_port): + + # Create workspace + # paddle.compat.enable_torch_proxy() + # Set all required environment variables + os.environ['FLAGS_selected_gpus'] = str(rank) # Key: set GPU ID + os.environ['PADDLE_TRAINER_ID'] = str(rank) + os.environ['PADDLE_TRAINERS_NUM'] = str(world_size) + os.environ['PADDLE_RANK_IN_NODE'] = str(rank) + + # Build endpoint list + endpoints = ','.join([f'127.0.0.1:{distributed_init_port+i+10}' for i in range(world_size)]) + os.environ['PADDLE_TRAINER_ENDPOINTS'] = endpoints + os.environ['PADDLE_CURRENT_ENDPOINT'] = f'127.0.0.1:{distributed_init_port+rank+10}' + # Set NCCL related environment variables (optional but recommended) + os.environ['FLAGS_sync_nccl_allreduce'] = '1' + + # Set device + paddle.set_device(f"gpu:{rank}") + + # Initialize distributed environment + dist_pp.init_parallel_env() + group_pp = dist_pp.get_group() + + try: + # Create workspace + ipc_handles, 
workspace_tensor = ( + comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + rank, + world_size, + token_num, + hidden_dim, + group=group_pp, + use_fp32_lamport=False, + ) + ) + + dist_pp.barrier(group=group_pp) + + # Run fusion operation + allreduce_in_clone, all_reduce_out = kernel(workspace_tensor, rank, world_size) + + # # Calculate reference result + dist_pp.all_reduce(allreduce_in_clone, group=group_pp) + ref_allreduce_out = allreduce_in_clone.clone() + + # # Verify results + tolerance = 8e-2 + np.testing.assert_allclose(all_reduce_out.numpy(), + ref_allreduce_out.numpy(), atol=tolerance, rtol=1e-2) + + print(f"Rank {rank}: Test passed!") + + finally: + dist_pp.barrier(group=group_pp) + comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group_pp) + dist_pp.destroy_process_group(group=group_pp) + + +def get_open_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def test_trtllm_allreduce_fusion_simple(): + # Fixed test parameters + world_size = 2 + + paddle.manual_seed(42) + paddle.cuda.manual_seed_all(42) + + available_gpus = paddle.cuda.device_count() + if world_size > available_gpus: + pytest.skip(f"Requires {world_size} GPUs, but only {available_gpus} available") + + procs = [] + distributed_init_port = get_open_port() + rank = dist_pp.get_rank() + _run_simple_worker(world_size, rank, distributed_init_port) + + print("Simple allreduce fusion test: passed") + + +# test cmd: python -m paddle.distributed.launch --log_dir=log --devices=0,1 +# ./test_torch_pp_launch.py +if __name__ == "__main__": + test_trtllm_allreduce_fusion_simple() \ No newline at end of file From 3c534ed9cccec76fb33b4de4d135fe989282a253 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Mon, 19 Jan 2026 20:16:37 +0800 Subject: [PATCH 2/3] remove redundant compat interface --- flashinfer/comm/cuda_ipc.py | 6 +-- flashinfer/comm/trtllm_ar.py | 6 +-- .../test_trtllm_allreduce_fusion_paddle.py | 47 +++++++++++-------- 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/flashinfer/comm/cuda_ipc.py b/flashinfer/comm/cuda_ipc.py index eff838dfb5..ada7f99874 100644 --- a/flashinfer/comm/cuda_ipc.py +++ b/flashinfer/comm/cuda_ipc.py @@ -17,8 +17,6 @@ import ctypes from dataclasses import dataclass from typing import Any, Dict, List, Optional -import paddle -paddle.compat.enable_torch_proxy() import torch.distributed as dist from paddle.base.core import ProcessGroup @@ -206,7 +204,7 @@ def create_shared_buffer( handle = cudart.cudaIpcGetMemHandle(pointer) if group is None: group = dist.group.WORLD - world_size = dist.get_world_size(group=group) + # world_size = dist.get_world_size(group=group) rank = dist.get_rank(group=group) # handles = [None] * world_size # dist.all_gather_object(handles, handle, group=group) @@ -215,7 +213,7 @@ def create_shared_buffer( # The behavior of the paddle framework and torch framework is inconsistent, # so the following code is used instead - handles = [] + handles = [None] dist.all_gather_object(handles, handle, group=group) pointers: List[int] = [] diff --git a/flashinfer/comm/trtllm_ar.py b/flashinfer/comm/trtllm_ar.py index 2754387e8f..82f43a515c 100644 --- a/flashinfer/comm/trtllm_ar.py +++ b/flashinfer/comm/trtllm_ar.py @@ -20,8 +20,6 @@ from types import SimpleNamespace from typing import List, Optional, Tuple, Union -import paddle -paddle.compat.enable_torch_proxy() import torch import torch.distributed as dist from paddle.base.core import ProcessGroup @@ 
-610,9 +608,7 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( # There is a bug in the paddle framework when device="CUDA". # Currently, the bug is being avoided by changing the source code. - workspace_tensor = torch.tensor( - workspace, dtype=torch.int64 - ) + workspace_tensor = torch.tensor(workspace, dtype=torch.int64) dist.barrier(group=group) # must sync after create_workspace diff --git a/tests/comm/test_trtllm_allreduce_fusion_paddle.py b/tests/comm/test_trtllm_allreduce_fusion_paddle.py index 50f06de379..7db7d9cbd1 100644 --- a/tests/comm/test_trtllm_allreduce_fusion_paddle.py +++ b/tests/comm/test_trtllm_allreduce_fusion_paddle.py @@ -1,11 +1,11 @@ import socket import pytest -import flashinfer.comm as comm - import paddle import paddle.distributed as dist_pp -paddle.compat.enable_torch_proxy() + +paddle.enable_compat() +import flashinfer.comm as comm import os import numpy as np @@ -21,6 +21,7 @@ trigger_completion_at_end = True fp32_acc = False + def kernel(workspace_tensor, rank, world_size): device = f"cuda:{rank}" message_size = token_num * hidden_dim @@ -35,7 +36,9 @@ def kernel(workspace_tensor, rank, world_size): residual_out = paddle.zeros(message_size, dtype=dtype, device=device) norm_out = paddle.zeros(message_size, dtype=dtype, device=device) quant_out = paddle.zeros(message_size, dtype=dtype, device=device) - scale_out = paddle.zeros(message_size // 16, dtype=dtype, device=device) # SF_VEC_SIZE = 16 + scale_out = paddle.zeros( + message_size // 16, dtype=dtype, device=device + ) # SF_VEC_SIZE = 16 rms_gamma = paddle.randn(hidden_dim, dtype=dtype, device=device) rms_eps = 1e-3 scale_factor = paddle.tensor(0.5, dtype=paddle.float32, device=device) @@ -70,22 +73,26 @@ def kernel(workspace_tensor, rank, world_size): return allreduce_in, all_reduce_out -def _run_simple_worker(world_size, rank, distributed_init_port): +def _run_simple_worker(world_size, rank, distributed_init_port): # Create workspace # paddle.compat.enable_torch_proxy() # Set all required environment variables - os.environ['FLAGS_selected_gpus'] = str(rank) # Key: set GPU ID - os.environ['PADDLE_TRAINER_ID'] = str(rank) - os.environ['PADDLE_TRAINERS_NUM'] = str(world_size) - os.environ['PADDLE_RANK_IN_NODE'] = str(rank) + os.environ["FLAGS_SELECTED_GPUS"] = str(rank) # Key: set GPU ID + os.environ["PADDLE_TRAINER_ID"] = str(rank) + os.environ["PADDLE_TRAINERS_NUM"] = str(world_size) + os.environ["PADDLE_RANK_IN_NODE"] = str(rank) # Build endpoint list - endpoints = ','.join([f'127.0.0.1:{distributed_init_port+i+10}' for i in range(world_size)]) - os.environ['PADDLE_TRAINER_ENDPOINTS'] = endpoints - os.environ['PADDLE_CURRENT_ENDPOINT'] = f'127.0.0.1:{distributed_init_port+rank+10}' + endpoints = ",".join( + [f"127.0.0.1:{distributed_init_port + i + 10}" for i in range(world_size)] + ) + os.environ["PADDLE_TRAINER_ENDPOINTS"] = endpoints + os.environ["PADDLE_CURRENT_ENDPOINT"] = ( + f"127.0.0.1:{distributed_init_port + rank + 10}" + ) # Set NCCL related environment variables (optional but recommended) - os.environ['FLAGS_sync_nccl_allreduce'] = '1' + os.environ["FLAGS_SYNC_NCCL_ALLREDUCE"] = "1" # Set device paddle.set_device(f"gpu:{rank}") @@ -118,8 +125,9 @@ def _run_simple_worker(world_size, rank, distributed_init_port): # # Verify results tolerance = 8e-2 - np.testing.assert_allclose(all_reduce_out.numpy(), - ref_allreduce_out.numpy(), atol=tolerance, rtol=1e-2) + np.testing.assert_allclose( + all_reduce_out.numpy(), ref_allreduce_out.numpy(), atol=tolerance, rtol=1e-2 + ) print(f"Rank {rank}: 
Test passed!") @@ -138,15 +146,14 @@ def get_open_port() -> int: def test_trtllm_allreduce_fusion_simple(): # Fixed test parameters world_size = 2 - + paddle.manual_seed(42) paddle.cuda.manual_seed_all(42) - + available_gpus = paddle.cuda.device_count() if world_size > available_gpus: pytest.skip(f"Requires {world_size} GPUs, but only {available_gpus} available") - procs = [] distributed_init_port = get_open_port() rank = dist_pp.get_rank() _run_simple_worker(world_size, rank, distributed_init_port) @@ -155,6 +162,6 @@ def test_trtllm_allreduce_fusion_simple(): # test cmd: python -m paddle.distributed.launch --log_dir=log --devices=0,1 -# ./test_torch_pp_launch.py +# ./test_torch_pp_launch.py if __name__ == "__main__": - test_trtllm_allreduce_fusion_simple() \ No newline at end of file + test_trtllm_allreduce_fusion_simple() From 7be316845252513fd75bedbf0aa36d71f88b1816 Mon Sep 17 00:00:00 2001 From: Bingoo <1575938147@qq.com> Date: Mon, 19 Jan 2026 20:23:11 +0800 Subject: [PATCH 3/3] remove redundant compat interface --- flashinfer/comm/cuda_ipc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flashinfer/comm/cuda_ipc.py b/flashinfer/comm/cuda_ipc.py index ada7f99874..46c7b2bd6b 100644 --- a/flashinfer/comm/cuda_ipc.py +++ b/flashinfer/comm/cuda_ipc.py @@ -213,6 +213,8 @@ def create_shared_buffer( # The behavior of the paddle framework and torch framework is inconsistent, # so the following code is used instead + # TODO(bingoo): The PR(https://github.com/PaddlePaddle/Paddle/pull/77152) + # has been fixed. handles = [None] dist.all_gather_object(handles, handle, group=group)