18 changes: 12 additions & 6 deletions flashinfer/comm/cuda_ipc.py
@@ -17,9 +17,8 @@
import ctypes
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch.distributed as dist
from torch.distributed import ProcessGroup
from paddle.base.core import ProcessGroup

# NOTE(Zihao): we should use cuda-python instead of ctypes cuda runtime bindings.
# However, cuda-python's API is not stable yet, so we use ctypes bindings instead.
@@ -205,11 +204,18 @@ def create_shared_buffer(
handle = cudart.cudaIpcGetMemHandle(pointer)
if group is None:
group = dist.group.WORLD
world_size = dist.get_world_size(group=group)
# world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
# handles = [None] * world_size
# dist.all_gather_object(handles, handle, group=group)

# paddle and torch behave differently here, so the following code is used instead.
# TODO(bingoo): this has been fixed upstream by
# https://github.com/PaddlePaddle/Paddle/pull/77152; revert this workaround once
# the fix is available.
handles = [None]
dist.all_gather_object(handles, handle, group=group)

pointers: List[int] = []
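For reference, the commented-out lines above follow the torch.distributed convention, where all_gather_object fills a pre-sized output list in rank order; the patched call uses a form that paddle's implementation accepts until the upstream fix mentioned in the TODO is available. Below is a minimal sketch of the torch-style gather (illustration only, not part of this diff; the helper name is hypothetical):

import torch.distributed as dist

def gather_ipc_handles_torch_style(handle, group=None):
    # torch.distributed.all_gather_object writes rank i's object to index i,
    # so the output list must already have one slot per rank in the group.
    world_size = dist.get_world_size(group=group)
    handles = [None] * world_size
    dist.all_gather_object(handles, handle, group=group)
    return handles  # handles[i] is the IPC handle exported by rank i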
2 changes: 1 addition & 1 deletion flashinfer/comm/nvshmem_allreduce.py
@@ -17,7 +17,7 @@
from typing import Optional

import torch
from torch.distributed import ProcessGroup
from paddle.base.core import ProcessGroup

from .nvshmem import get_nvshmem_module

12 changes: 8 additions & 4 deletions flashinfer/comm/trtllm_ar.py
@@ -22,7 +22,7 @@

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from paddle.base.core import ProcessGroup

from ..jit.comm import gen_trtllm_comm_module
from ..utils import register_custom_op, round_up
@@ -602,9 +602,13 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
print(f"Rank {tp_rank} workspace[{i}] {hex(workspace[i])}")

# Store workspace pointers in device tensor
workspace_tensor = torch.tensor(
workspace, dtype=torch.int64, device=torch.device("cuda")
)
# workspace_tensor = torch.tensor(
# workspace, dtype=torch.int64, device=torch.device("cuda")
# )

# The paddle framework has a bug when an explicit device="cuda" is passed here;
# as a workaround, the tensor is created without specifying a device.
workspace_tensor = torch.tensor(workspace, dtype=torch.int64)

dist.barrier(group=group) # must sync after create_workspace

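The original code built the pointer table directly on the GPU; the workaround above simply drops the explicit device argument. If the downstream kernel still expects the tensor on the GPU, one possible variant (an untested assumption, not part of this diff, and it may run into the same paddle issue) is to construct on the default device and transfer afterwards:

# Hypothetical alternative, assuming the kernel needs the pointer table on the GPU:
workspace_tensor = torch.tensor(workspace, dtype=torch.int64)  # build on the default device
workspace_tensor = workspace_tensor.cuda()  # move explicitly instead of passing device= at construction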
167 changes: 167 additions & 0 deletions tests/comm/test_trtllm_allreduce_fusion_paddle.py
@@ -0,0 +1,167 @@
import socket
import pytest

import paddle
import paddle.distributed as dist_pp

paddle.enable_compat()
import flashinfer.comm as comm

import os
import numpy as np

# test parameters
token_num = 128
hidden_dim = 1024
dtype = paddle.float16
pattern_code = comm.AllReduceFusionPattern.kAllReduce
layout_code = comm.QuantizationSFLayout.LINEAR
launch_with_pdl = False
use_oneshot = True
trigger_completion_at_end = True
fp32_acc = False


def kernel(workspace_tensor, rank, world_size):
device = f"cuda:{rank}"
message_size = token_num * hidden_dim
dtype = paddle.float16
# Create input data
allreduce_in = paddle.randn(message_size, dtype=dtype, device=device)
allreduce_in_clone = allreduce_in.clone()  # keep an untouched copy for the reference all-reduce
all_reduce_out = paddle.zeros(message_size, dtype=dtype, device=device)

# Add missing required parameters
residual_in = paddle.randn(message_size, dtype=dtype, device=device)
residual_out = paddle.zeros(message_size, dtype=dtype, device=device)
norm_out = paddle.zeros(message_size, dtype=dtype, device=device)
quant_out = paddle.zeros(message_size, dtype=dtype, device=device)
scale_out = paddle.zeros(
message_size // 16, dtype=dtype, device=device
) # SF_VEC_SIZE = 16
rms_gamma = paddle.randn(hidden_dim, dtype=dtype, device=device)
rms_eps = 1e-3
scale_factor = paddle.tensor(0.5, dtype=paddle.float32, device=device)

# Run fusion operation
print("Running fusion operation...")
comm.trtllm_allreduce_fusion(
allreduce_in=allreduce_in,
world_size=world_size,
world_rank=rank,
token_num=token_num,
hidden_dim=hidden_dim,
workspace_ptrs=workspace_tensor,
launch_with_pdl=launch_with_pdl,
use_oneshot=use_oneshot,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc,
pattern_code=pattern_code,
allreduce_out=all_reduce_out,
residual_in=residual_in,
residual_out=residual_out,
norm_out=norm_out,
quant_out=quant_out,
scale_out=scale_out,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_factor,
layout_code=layout_code,
)

paddle.cuda.synchronize()

return allreduce_in_clone, all_reduce_out


def _run_simple_worker(world_size, rank, distributed_init_port):
# paddle.compat.enable_torch_proxy()
# Set all required environment variables for paddle's distributed launch
os.environ["FLAGS_SELECTED_GPUS"] = str(rank) # Key: set GPU ID
os.environ["PADDLE_TRAINER_ID"] = str(rank)
os.environ["PADDLE_TRAINERS_NUM"] = str(world_size)
os.environ["PADDLE_RANK_IN_NODE"] = str(rank)

# Build endpoint list
endpoints = ",".join(
[f"127.0.0.1:{distributed_init_port + i + 10}" for i in range(world_size)]
)
os.environ["PADDLE_TRAINER_ENDPOINTS"] = endpoints
os.environ["PADDLE_CURRENT_ENDPOINT"] = (
f"127.0.0.1:{distributed_init_port + rank + 10}"
)
# Set NCCL related environment variables (optional but recommended)
os.environ["FLAGS_SYNC_NCCL_ALLREDUCE"] = "1"

# Set device
paddle.set_device(f"gpu:{rank}")

# Initialize distributed environment
dist_pp.init_parallel_env()
group_pp = dist_pp.get_group()

try:
# Create workspace
ipc_handles, workspace_tensor = (
comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
rank,
world_size,
token_num,
hidden_dim,
group=group_pp,
use_fp32_lamport=False,
)
)

dist_pp.barrier(group=group_pp)

# Run fusion operation
allreduce_in_clone, all_reduce_out = kernel(workspace_tensor, rank, world_size)

# Calculate reference result
dist_pp.all_reduce(allreduce_in_clone, group=group_pp)
ref_allreduce_out = allreduce_in_clone.clone()

# Verify results
tolerance = 8e-2
np.testing.assert_allclose(
all_reduce_out.numpy(), ref_allreduce_out.numpy(), atol=tolerance, rtol=1e-2
)

print(f"Rank {rank}: Test passed!")

finally:
dist_pp.barrier(group=group_pp)
comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group_pp)
dist_pp.destroy_process_group(group=group_pp)


def get_open_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]


def test_trtllm_allreduce_fusion_simple():
# Fixed test parameters
world_size = 2

paddle.manual_seed(42)
paddle.cuda.manual_seed_all(42)

available_gpus = paddle.cuda.device_count()
if world_size > available_gpus:
pytest.skip(f"Requires {world_size} GPUs, but only {available_gpus} available")

distributed_init_port = get_open_port()
rank = dist_pp.get_rank()
_run_simple_worker(world_size, rank, distributed_init_port)

print("Simple allreduce fusion test: passed")


# test cmd: python -m paddle.distributed.launch --log_dir=log --devices=0,1 \
#   tests/comm/test_trtllm_allreduce_fusion_paddle.py
if __name__ == "__main__":
test_trtllm_allreduce_fusion_simple()