diff --git a/tests/comm/test_trtllm_allreduce_fusion_paddle.py b/tests/comm/test_trtllm_allreduce_fusion_paddle.py
index 7db7d9cbd1..f5b61ca9eb 100644
--- a/tests/comm/test_trtllm_allreduce_fusion_paddle.py
+++ b/tests/comm/test_trtllm_allreduce_fusion_paddle.py
@@ -1,10 +1,10 @@
 import socket
 import pytest
-
 import paddle
 import paddle.distributed as dist_pp
 
 paddle.enable_compat()
+from paddle.device.cuda.graphs import CUDAGraph
 
 import flashinfer.comm as comm
 import os
@@ -36,15 +36,12 @@ def kernel(workspace_tensor, rank, world_size):
     residual_out = paddle.zeros(message_size, dtype=dtype, device=device)
     norm_out = paddle.zeros(message_size, dtype=dtype, device=device)
     quant_out = paddle.zeros(message_size, dtype=dtype, device=device)
-    scale_out = paddle.zeros(
-        message_size // 16, dtype=dtype, device=device
-    )  # SF_VEC_SIZE = 16
+    scale_out = paddle.zeros(message_size // 16, dtype=dtype, device=device)
     rms_gamma = paddle.randn(hidden_dim, dtype=dtype, device=device)
     rms_eps = 1e-3
     scale_factor = paddle.tensor(0.5, dtype=paddle.float32, device=device)
 
     # Run fusion operation
-    print("Running fusion operation...")
     comm.trtllm_allreduce_fusion(
         allreduce_in=allreduce_in,
         world_size=world_size,
@@ -69,7 +66,7 @@ def kernel(workspace_tensor, rank, world_size):
         layout_code=layout_code,
     )
 
-    paddle.cuda.synchronize()
+    # paddle.cuda.synchronize()
 
     return allreduce_in, all_reduce_out
 
@@ -91,6 +88,7 @@ def _run_simple_worker(world_size, rank, distributed_init_port):
     os.environ["PADDLE_CURRENT_ENDPOINT"] = (
         f"127.0.0.1:{distributed_init_port + rank + 10}"
     )
+
     # Set NCCL related environment variables (optional but recommended)
     os.environ["FLAGS_SYNC_NCCL_ALLREDUCE"] = "1"
 
@@ -117,7 +115,25 @@ def _run_simple_worker(world_size, rank, distributed_init_port):
     dist_pp.barrier(group=group_pp)
 
     # Run fusion operation
-    allreduce_in_clone, all_reduce_out = kernel(workspace_tensor, rank, world_size)
+    loop = 5
+    s = paddle.cuda.Stream()
+    s.wait_stream(paddle.cuda.current_stream())
+    with paddle.cuda.stream(s):
+        for _ in range(loop):
+            allreduce_in_clone, all_reduce_out = kernel(
+                workspace_tensor, rank, world_size
+            )
+
+    g = CUDAGraph()
+    g.capture_begin()
+    for _ in range(loop):
+        allreduce_in_clone, all_reduce_out = kernel(
+            workspace_tensor, rank, world_size
+        )
+    g.capture_end()
+
+    g.replay()
+    paddle.cuda.synchronize()
 
     # # Calculate reference result
     dist_pp.all_reduce(allreduce_in_clone, group=group_pp)
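
For readers unfamiliar with the capture pattern introduced above, the following is a minimal single-GPU sketch of the same warm-up, capture, and replay sequence, with the fused allreduce replaced by a trivial elementwise op. It assumes a CUDA build of Paddle and the torch-style `paddle.cuda` / `CUDAGraph` APIs that the test itself relies on (via `paddle.enable_compat()`); the tensor shapes and the `step()` helper are illustrative only, not part of the test.

```python
import paddle
from paddle.device.cuda.graphs import CUDAGraph

paddle.enable_compat()

device = "cuda:0"
x = paddle.randn(1024, dtype=paddle.float16, device=device)
y = paddle.zeros(1024, dtype=paddle.float16, device=device)


def step():
    # Stand-in for kernel(workspace_tensor, rank, world_size): writes into a
    # preallocated output so tensor addresses stay fixed across graph replays.
    paddle.assign(x * 2.0, output=y)


loop = 5

# Warm up on a side stream, as the test does, so one-time work (lazy
# allocations, context setup) happens before capture begins.
s = paddle.cuda.Stream()
s.wait_stream(paddle.cuda.current_stream())
with paddle.cuda.stream(s):
    for _ in range(loop):
        step()

# Record the same work into a CUDA graph, then replay it once and synchronize.
g = CUDAGraph()
g.capture_begin()
for _ in range(loop):
    step()
g.capture_end()

g.replay()
paddle.cuda.synchronize()
```

The usual motivation for warming up outside the capture region, which the test follows, is that initialization done during capture (allocator growth, communicator setup) can make the capture fail or record work that should only run once; after `capture_end()` the recorded kernels can be relaunched cheaply with `replay()`.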