From 823adfda2c46bc4f02e831925b0fb3af31c5c6c3 Mon Sep 17 00:00:00 2001 From: alextmagro Date: Thu, 30 Oct 2025 15:57:29 -0500 Subject: [PATCH 1/2] ROCm UserBuffers for Comm Overlap --- build_tools/pytorch.py | 23 +- ci/pytorch.sh | 2 + .../te_layer_with_overlap.py | 4 +- .../te_layer_with_overlap_profile.py | 504 +++++++++++++ .../pytorch/comm_gemm_overlap/ub_config.json | 15 + hipify_custom_map.json | 9 +- setup.py | 11 +- .../cpp/operator/test_normalization_mxfp8.cu | 6 +- .../distributed/run_layer_with_overlap.py | 7 + .../distributed/test_comm_gemm_overlap.py | 7 +- transformer_engine/common/CMakeLists.txt | 27 +- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 129 +++- .../rocm_comm_gemm_overlap.cpp | 664 ++++++++++++++++++ .../userbuffers/userbuffers-host.cpp | 43 +- .../userbuffers/userbuffers.cu | 105 ++- .../userbuffers/userbuffers.h | 27 +- .../transformer_engine/comm_gemm_overlap.h | 91 ++- .../include/transformer_engine/multi_stream.h | 6 + .../common/util/cuda_runtime.cpp | 8 +- transformer_engine/common/util/cuda_runtime.h | 2 - .../common/util/pybind_helper.h | 18 +- transformer_engine/pytorch/csrc/common.h | 2 - transformer_engine/pytorch/csrc/extensions.h | 10 - .../csrc/extensions/comm_gemm_overlap.cpp | 5 +- .../pytorch/csrc/extensions/gemm.cpp | 22 +- .../pytorch/csrc/extensions/pybind.cpp | 6 - transformer_engine/pytorch/module/base.py | 10 +- .../pytorch/ops/fused/__init__.py | 21 +- transformer_engine/pytorch/transformer.py | 6 +- 29 files changed, 1641 insertions(+), 149 deletions(-) create mode 100644 examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py create mode 100644 examples/pytorch/comm_gemm_overlap/ub_config.json create mode 100644 transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index bb084293f..8fcf72f7e 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -85,14 +85,6 @@ def setup_pytorch_extension( if version < (12, 0): raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") - if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): - assert ( - os.getenv("MPI_HOME") is not None - ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" - mpi_path = Path(os.getenv("MPI_HOME")) - include_dirs.append(mpi_path / "include") - cxx_flags.append("-DNVTE_UB_WITH_MPI") - library_dirs = [] libraries = [] if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): @@ -106,12 +98,22 @@ def setup_pytorch_extension( cxx_flags.append("-DNVTE_ENABLE_NVSHMEM") if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", 0))): - cxx_flags.append("-DNVTE_ENABLE_ROCSHMEM") mpi_home = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi")) include_dirs.append(mpi_home / "include") library_dirs.append(mpi_home / "lib") - libraries.append("mpi_cxx") + libraries.append("mpi") + cxx_flags.extend(["-DNVTE_ENABLE_ROCSHMEM", "-DOMPI_SKIP_MPICXX"]) + extra_link_args = [] + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!" 
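+        # Example invocation (assumed, not taken from this patch):
+        #   NVTE_UB_WITH_MPI=1 MPI_HOME=/path/to/mpi pip install .
+        # MPI headers/libs are resolved from MPI_HOME below and the extension is
+        # compiled with -DNVTE_UB_WITH_MPI (and -DOMPI_SKIP_MPICXX for OpenMPI).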
+ mpi_path = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi")) + include_dirs.append(mpi_path / "include") + library_dirs.append(mpi_path / "lib") + libraries.append("mpi") + cxx_flags.extend(["-DNVTE_UB_WITH_MPI", "-DOMPI_SKIP_MPICXX"]) # Construct PyTorch CUDA extension sources = [str(path) for path in sources] @@ -125,4 +127,5 @@ def setup_pytorch_extension( extra_compile_args={"cxx": cxx_flags}, libraries=[str(lib) for lib in libraries], library_dirs=[str(lib_dir) for lib_dir in library_dirs], + extra_link_args=[str(arg) for arg in extra_link_args], ) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index be150485f..853cfc8f9 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -89,6 +89,8 @@ run_test_config_mgpu(){ configure_omp_threads 8 run_default_fa 1 test_fused_optimizer.py run_default_fa 3 test_sanity_import.py + run_default_fa 3 distributed/test_fusible_ops_with_userbuffers.py + run_default_fa 3 distributed/test_comm_gemm_overlap.py run_default_fa 2 distributed/test_fusible_ops.py run_default_fa 2 distributed/test_numerics.py run_default_fa 1 distributed/test_torch_fsdp2.py diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index d52e97d65..1fd40305c 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None): ) parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument( - "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + "--fp8", action="store_true", default=False, help="Enables the te.autocast() context." ) parser.add_argument( "--no-comm-overlap", @@ -299,7 +299,7 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) dist_print(" |-- Forward pass", group=tp_group, debug=True) with torch.amp.autocast("cuda", dtype=torch.bfloat16): - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world): y = model(x) if isinstance(y, tuple): out, *_ = y diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py new file mode 100644 index 000000000..ba5afd2b6 --- /dev/null +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py @@ -0,0 +1,504 @@ +#!/usr/bin/python3 + +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
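+# Profiling variant of examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py:
+# trains a single Transformer Engine layer with Userbuffers comm+GEMM overlap and,
+# with --profile, wraps the training loop in torch.profiler to save traces.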
+ +import os +import sys +import socket +import fcntl +import struct +import argparse +import warnings + +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel + +import torch.profiler + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpp_extensions as tex +from transformer_engine.common.recipe import Format, DelayedScaling + +warnings.filterwarnings("ignore", category=DeprecationWarning) +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + + +def _te_layer_argtype(name): + te_layers = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + layer_map = dict(zip([layer.__name__.lower() for layer in te_layers], te_layers)) + if name.lower() not in layer_map.keys(): + raise argparse.ArgumentTypeError( + f"Invalid TE layer name! Please choose from: {layer_map.keys()}" + ) + return layer_map[name.lower()] + + +def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser( + description="Train a Transformer Engine module with GEMM+comm overlap via Userbuffers." + ) + parser.add_argument( + "-i", "--num-iters", type=int, default=10, help="Number of dummy 'training' iterations." + ) + parser.add_argument("-b", "--batch-size", type=int, default=8, help="Input batch size.") + parser.add_argument("-s", "--seq-length", type=int, default=16384, help="Input sequence length.") + parser.add_argument( + "-n", "--num-heads", type=int, default=64, help="Number of attention heads." + ) + parser.add_argument( + "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head." + ) + parser.add_argument( + "--layer-type", + type=_te_layer_argtype, + default=te.TransformerLayer, + help="Transformer Engine layer to train with comm+GEMM overlap.", + ) + parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") + parser.add_argument( + "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + ) + parser.add_argument( + "--no-comm-overlap", + action="store_true", + default=False, + help="Disable the comm+GEMM overlap.", + ) + parser.add_argument( + "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas." 
+ ) + parser.add_argument( + "--tcp-init", + action="store_true", + default=False, + help="Initialize torch.distributed with TcpStore.", + ) + parser.add_argument( + "--bind-to-device", + action="store_true", + default=False, + help="Initialize torch.distributed with `device_id` to bind each rank to a single device.", + ) + parser.add_argument( + "--bootstrap-backend", + type=str.lower, + default="nccl", + choices=["gloo", "mpi", "nccl"], + help="Communications backend for host tensor collectives during Userbuffers bootstrapping.", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="Print out from every rank instead of just the root rank of relevant process groups.", + ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Print out additional debug information.", + ) + parser.add_argument( + "--profile", + action="store_true", + default=False, + help="Enable PyTorch profiler.", + ) + parser.add_argument( + "--profile-dir", + type=str, + default="./logs/profiler_traces", + help="Directory to save PyTorch profiler traces.", + ) + parser.add_argument( + "--ub_config", + type=str, + default="./ub_config.json", + help="Userbuffer configuration file.", + ) + + args = parser.parse_args(argv, namespace) + if args.bootstrap_backend == "nccl": + args.bind_to_device = True + return args + + +def _get_layer_args(config, tp_group, tp_size, reference=False): + hidden_size = config.num_heads * config.head_dim + input_shape = [config.seq_length, config.batch_size, hidden_size] + args = [hidden_size] + kwargs = { + "params_dtype": torch.float32, + "device": "cuda", + "tp_group": tp_group, + "tp_size": tp_size, + "sequence_parallel": True, + } + kwargs["ub_overlap_ag"] = not config.no_comm_overlap + + if config.layer_type is te.Linear: + input_shape[2] = hidden_size // tp_size + args.append(hidden_size) + kwargs["parallel_mode"] = "row" + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + kwargs["ub_name"] = "proj" + else: + input_shape[0] = config.seq_length // tp_size + if config.layer_type is te.LayerNormLinear: + args.append(3 * hidden_size) + kwargs["parallel_mode"] = "column" + kwargs["ub_name"] = "qkv" + else: + kwargs["set_parallel_mode"] = True + kwargs["ub_overlap_rs"] = not config.no_comm_overlap + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + # args.append(4 * hidden_size) + args.append(int(3.5 * hidden_size)) + + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not config.no_comm_overlap + kwargs["hidden_dropout"] = 0.0 + + return args, kwargs, input_shape + +def create_ub_cfgs(config_file: str, tp_size: int = 8): + import json + with open(config_file, 'r') as f: + data = json.load(f) + cfgs = {} + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None + layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] + layers_all_gather_overlap = [ + "qkv_fprop", + "qkv_dgrad", + "proj_dgrad", + "proj_wgrad", + "fc1_fprop", + "fc1_dgrad", + "fc2_dgrad", + "fc2_wgrad", + ] + + for name, method in data.items(): + if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None: + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range() + + cfg = { + "method": method, + 
"is_reduce_scatter": name in layers_reduce_scatter_overlap, + "num_sm": 1 if method == "ring_exchange" else 16, + "cga_size": 1 if method == "ring_exchange" else 2, + "set_sm_margin": False, + "num_splits": tp_size if method == "ring_exchange" else 4, + "aggregate": False, + "atomic_gemm": False, + "use_ce": True, + "fp8_buf": name in layers_all_gather_overlap, + "comm_priority": _MAX_STREAM_PRIORITY, + "gemm_priority": _MIN_STREAM_PRIORITY, + } + + cfgs[name] = cfg + + return cfgs + +def _train(opts): + if "OMPI_COMM_WORLD_SIZE" in os.environ: + # Execution with `mpirun -np N` + WORLD_RANK = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) + WORLD_SIZE = int(os.getenv("OMPI_COMM_WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("OMPI_COMM_WORLD_LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) + opts.tcp_init = True + opts.bind_to_device = True + opts.bootstrap_backend = "mpi" + elif "TORCHELASTIC_RUN_ID" in os.environ: + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + else: + raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + NUM_NODES = WORLD_SIZE // LOCAL_SIZE + + # Initialize torch.distributed global process group and get DP/TP groups + torch.cuda.set_device(LOCAL_RANK) + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + if opts.tcp_init or NUM_NODES > 1: + if NUM_NODES > 1: + assert ( + "MASTER_ADDR" in os.environ + ), "Multi-node run requires MASTER_ADDR to be set in the environment." + MASTER_ADDR = os.getenv("MASTER_ADDR", socket.gethostbyname(socket.gethostname())) + MASTER_PORT = os.getenv("MASTER_PORT", "1234") + dist_init_kwargs["init_method"] = f"tcp://{MASTER_ADDR}:{MASTER_PORT}" + if opts.bind_to_device or opts.bootstrap_backend == "nccl": + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + + def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False): + if debug and not opts.debug: + return + group_rank = dist.get_rank(group) + stream = sys.stderr if error else sys.stdout + if group_rank == src: + stream.write(f"[rank{WORLD_RANK}] {msg}{end}") + dist.barrier(group) + + dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") + + # Figure out process groups for tensor- and data-parallelism (if any) + if NUM_NODES > 1: + # Create a list of world ranks on this node + hostname = socket.gethostname() + ifname = os.getenv( + "NVTE_UB_SOCKET_IFNAME", + os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), + ) + + if ifname is not None: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + hostname = socket.inet_ntoa( + fcntl.ioctl( + s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) + )[20:24] + ) + except OSError as err: + raise OSError(f"Invalid network interface: {ifname}") from err + + hostnames = [None for _ in range(WORLD_SIZE)] + dist.all_gather_object(hostnames, hostname) + unique_hosts = [] + for host in hostnames: + if host not in unique_hosts: + unique_hosts.append(host) + assert len(unique_hosts) == NUM_NODES + + ranks_per_node_list = [[] for _ in range(NUM_NODES)] + self_node_idx = -1 + for i, host in enumerate(hostnames): + node_idx = unique_hosts.index(host) + 
ranks_per_node_list[node_idx].append(i) + if host == hostname: + self_node_idx = node_idx + assert self_node_idx >= 0 + self_node_ranks = ranks_per_node_list[self_node_idx] + + if opts.num_replicas > 1: + # Split node ranks into multiple replicas + assert len(self_node_ranks) % opts.num_replicas == 0 + tp_size = len(self_node_ranks) // opts.num_replicas + ranks_per_replica_list = [] + for node_ranks in ranks_per_node_list: + for i in range(opts.num_replicas): + start = i * tp_size + end = start + tp_size + ranks_per_replica_list.append(node_ranks[start:end]) + + self_replica_idx = -1 + for i, replica_ranks in enumerate(ranks_per_replica_list): + if WORLD_RANK in replica_ranks: + self_replica_idx = i + break + assert self_replica_idx >= 0 + + else: + # The entire node is the tensor-parallel group + ranks_per_replica_list = ranks_per_node_list + self_replica_idx = self_node_idx + + tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") + ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) + + else: + if opts.num_replicas > 1: + # Mixed data- and tensor-parallelism on a single node + # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions + all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") + ranks_per_replica_tensor = all_ranks.reshape( + (opts.num_replicas, LOCAL_SIZE // opts.num_replicas) + ) + tp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.tolist(), backend="nccl" + ) + dp_group, _ = dist.new_subgroups_by_enumeration( + ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" + ) + else: + dp_group = None + tp_group = nccl_world + + tp_rank = dist.get_rank(tp_group) + tp_size = dist.get_world_size(tp_group) + dist_print( + f"Created tensor-parallel group: {dist.get_process_group_ranks(tp_group)}", + group=tp_group, + ) + if dp_group is not None: + dp_rank = dist.get_rank(dp_group) + dist_print( + f"Created data-parallel group: {dist.get_process_group_ranks(dp_group)}", + group=dp_group, + ) + else: + dp_rank = 0 + + # Intialize userbuffers + hidden_size = opts.num_heads * opts.head_dim + batched_size = opts.seq_length * opts.batch_size + if not opts.no_comm_overlap: + te.module.base.initialize_ub( + [batched_size, hidden_size], + tp_size, + use_fp8=opts.fp8, + dtype=torch.bfloat16, + bootstrap_backend=opts.bootstrap_backend, + ub_cfgs=create_ub_cfgs(opts.ub_config, tp_size) + ) + # Initialize the fused LayerNorm + Multi-layer Perceptron module + torch.manual_seed(opts.seed + dp_rank) + torch.cuda.manual_seed(opts.seed + tp_rank) + layer_args, layer_kwargs, input_size = _get_layer_args(opts, tp_group, tp_size) + model = opts.layer_type(*layer_args, **layer_kwargs) + if dp_group is not None: + model = DistributedDataParallel(model, dim=1, process_group=dp_group) + + # Initialize optimizer with model parameters + optim = torch.optim.Adam(model.parameters(), lr=0.0001) + + # Fp8 recipe setup + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + + if opts.profile: + log_dir = os.path.join(opts.profile_dir, f"rank_{WORLD_RANK}") + os.makedirs(log_dir, exist_ok=True) + dist_print(f"Profiler traces will be saved to: {log_dir}", group=nccl_world) + + schedule = torch.profiler.schedule(wait=1, warmup=2, active=5, repeat=1) + + on_trace_ready = 
torch.profiler.tensorboard_trace_handler( + log_dir, worker_name=f"rank_{WORLD_RANK}" + ) + + profiler_activities = [ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + import time + + start_time = time.time() + with torch.profiler.profile( + schedule=schedule, + # record_shapes=True, + # with_stack=True, + # with_flops=True, + # with_modules=True, + on_trace_ready=on_trace_ready, + profile_memory=True, + activities=profiler_activities, + ) as prof: + dist_print("Starting training iterations...") + for i in range(opts.num_iters): + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) + loss.backward() + + dist_print(" |-- Optimizer step", group=tp_group, debug=True) + optim.step() + + prof.step() + torch.cuda.synchronize() + end_time = time.time() + total_wall_clock_time = end_time - start_time + print(f"Total Wall Clock Time: {total_wall_clock_time:.4f} seconds") + # total_flops = sum([item.flops for item in prof.key_averages()]) + # print(f"Total FLOPs: {total_flops}") + else: + dist_print("Starting training iterations...") + for i in range(opts.num_iters): + dist_print(f" Iter {i+1}", group=tp_group, debug=True) + + dist_print(" |-- Generate random input batch", group=tp_group, debug=True) + x = torch.randn(input_size, dtype=torch.float32, device="cuda", requires_grad=True) + + dist_print(" |-- Forward pass", group=tp_group, debug=True) + with torch.amp.autocast("cuda", dtype=torch.bfloat16): + with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + y = model(x) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + dist_print(" |-- Compute loss", group=tp_group, debug=True) + loss = out.sum() + + dist_print(" |-- Backward pass", group=tp_group, debug=True) + loss.backward() + + dist_print(" |-- Optimizer step", group=tp_group, debug=True) + optim.step() + + + dist_print("Finished training!") + te.module.base.destroy_ub() + + dist_print("Destroying all process groups...", debug=True) + dist.destroy_process_group() + if opts.debug and WORLD_RANK == 0: + print("Exiting...\n", end="", flush=True) + + return 0 + + +if __name__ == "__main__": + sys.exit(_train(_parse_args())) \ No newline at end of file diff --git a/examples/pytorch/comm_gemm_overlap/ub_config.json b/examples/pytorch/comm_gemm_overlap/ub_config.json new file mode 100644 index 000000000..a26c7f9f1 --- /dev/null +++ b/examples/pytorch/comm_gemm_overlap/ub_config.json @@ -0,0 +1,15 @@ +{ + "qkv_fprop": "ring_exchange", + "fc1_fprop": "ring_exchange", + "fc2_dgrad": "ring_exchange", + "proj_wgrad": "ring_exchange", + "fc2_wgrad": "ring_exchange", + + + "proj_fprop": "ring_exchange", + "fc2_fprop": "ring_exchange", + + "qkv_dgrad": "ring_exchange", + "fc1_dgrad": "ring_exchange" + +} \ No newline at end of file diff --git a/hipify_custom_map.json b/hipify_custom_map.json index 97824bbdb..e2f29d848 100644 --- a/hipify_custom_map.json +++ b/hipify_custom_map.json @@ -6,7 +6,14 
@@ "ATen/cudnn/Handle.h" : "ATen/miopen/Handle.h", "CUfunc_cache" : "hipFuncCache_t", "" : "", - "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)" + "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)", + "cudaLaunchKernel": "hipLaunchKernel", + "CUmemGenericAllocationHandle": "hipMemGenericAllocationHandle_t", + "\"cuda_runtime.h\"": "\"hip_runtime.h\"", + "cudaLaunchConfig_t": "hipLaunchConfig_t", + "cudaLaunchAttribute": "hipLaunchAttribute", + "cudaLaunchAttributeCooperative": "hipLaunchAttributeCooperative", + "CUdeviceptr": "hipDeviceptr_t" } } diff --git a/setup.py b/setup.py index 1ae476311..9f94f24a7 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,12 @@ def run(self): def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" cmake_flags = [] + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + assert ( + os.getenv("MPI_HOME") is not None + ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" + cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") + if rocm_build(): cmake_flags.append("-DUSE_ROCM=ON") if os.getenv("NVTE_AOTRITON_PATH"): @@ -99,11 +105,6 @@ def setup_common_extension() -> CMakeExtension: else: cmake_flags.append("-DUSE_ROCM=OFF") cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)] - if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): - assert ( - os.getenv("MPI_HOME") is not None - ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" - cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))): assert ( diff --git a/tests/cpp/operator/test_normalization_mxfp8.cu b/tests/cpp/operator/test_normalization_mxfp8.cu index e87ed2209..40c5be719 100644 --- a/tests/cpp/operator/test_normalization_mxfp8.cu +++ b/tests/cpp/operator/test_normalization_mxfp8.cu @@ -131,10 +131,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, DType wtype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input("input", std::vector{ N, H }, itype); + Tensor input("input2", std::vector{ N, H }, itype); Tensor z("z", std::vector{ N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING); - Tensor gamma("gamma", std::vector{ H }, wtype); - Tensor beta("beta", std::vector{ H }, wtype); + Tensor gamma("gamma2", std::vector{ H }, wtype); + Tensor beta("beta2", std::vector{ H }, wtype); Tensor mu("mu", std::vector{ N }, DType::kFloat32); Tensor rsigma("rsigma", std::vector{ N }, DType::kFloat32); Tensor workspace; diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 2a6e55b2c..ac185fd10 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -1,5 +1,7 @@ #!/usr/bin/python3 +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -30,6 +32,11 @@ warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) +import transformer_engine.pytorch.cpp_extensions as tex +os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" +if not tex.device_supports_multicast(): + os.environ["UB_SKIPMC"] = "1" + class multi_module_model(torch.nn.Module): def __init__(self, module, num_layers, *args, **kwargs): diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 74d1dc69c..6ccd3942d 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,6 +13,9 @@ import transformer_engine.pytorch.cpp_extensions as tex from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.jax.cpp_extensions.misc import is_hip_extension + + if torch.cuda.device_count() < 2: pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.") @@ -180,7 +185,7 @@ def test_bulk_overlaps(comm_type, quantization, connections): Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ if connections == 8: - if torch.cuda.get_device_properties(0).major != 9: + if is_hip_extension() or torch.cuda.get_device_properties(0).major != 9: pytest.skip( "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability" " 9.0 (HOPPER ARCH)." diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cefec6d06..b8e137bc1 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -170,7 +170,11 @@ list(APPEND transformer_engine_SOURCES fused_router/fused_topk_with_score_function.cu recipe/current_scaling.cu recipe/delayed_scaling.cu - recipe/fp8_block_scaling.cu) + recipe/fp8_block_scaling.cu + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/userbuffers/userbuffers.cu + comm_gemm_overlap/comm_gemm_overlap.cpp) if(USE_CUDA) # Removed indent to minimize code diff with NV upstream # Files unique in cuda building @@ -184,11 +188,7 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn.cpp fused_attn/utils.cu gemm/cutlass_grouped_gemm.cu - util/cuda_nvml.cpp - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/userbuffers/userbuffers.cu - comm_gemm_overlap/comm_gemm_overlap.cpp) + util/cuda_nvml.cpp) if (NVTE_WITH_CUBLASMP) list(APPEND transformer_engine_SOURCES @@ -203,7 +203,8 @@ else() fused_attn_rocm/fused_attn_ck.cpp fused_attn_rocm/utils.cpp gemm/rocm_gemm.cu - amd_detail/system.cpp) + amd_detail/system.cpp + comm_gemm_overlap/rocm_comm_gemm_overlap.cpp) # process source code files set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..) 
@@ -270,11 +271,15 @@ target_include_directories(transformer_engine PRIVATE ${CUTLASS_TOOLS_INCLUDE_DIR}) # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI +# Changed option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) if (NVTE_UB_WITH_MPI) - find_package(MPI REQUIRED) - target_link_libraries(transformer_engine PUBLIC MPI::MPI_CXX) - target_include_directories(transformer_engine PRIVATE ${MPI_CXX_INCLUDES}) + # OpenMPI C++ headers are deprecated -- flag unused w/ MPICH + add_definitions(-DOMPI_SKIP_MPICXX) + + target_include_directories(transformer_engine PRIVATE "$ENV{MPI_HOME}/include") + target_link_directories(transformer_engine PRIVATE "$ENV{MPI_HOME}/lib") + target_link_libraries(transformer_engine PUBLIC mpi) target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI) endif() @@ -451,7 +456,7 @@ endif() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}") message(STATUS "Threads per parallel build job: ${BUILD_THREADS_PER_JOB}") else() - set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3") + set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3 -fopenmp") set(HIP_HCC_FLAGS "${CMAKE_HIP_FLAGS} -mavx2 -mf16c -mfma -std=c++17") # Ask hcc to generate device code during compilation so we can use # host linker to link. diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index ec29e6e12..d56e57a69 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -21,6 +21,12 @@ #define HALF_BYTES 2 #define UB_MAX_SM 32 +#ifdef __HIP_PLATFORM_AMD__ +#define half_dtype hip_bfloat16 +#define __nv_fp8_e5m2 te_hip_fp8_e5m2 +#define __nv_fp8_e4m3 te_hip_fp8_e4m3 +#endif + using namespace std::placeholders; namespace transformer_engine { @@ -64,6 +70,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl #endif _comm_created = true; } + + initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm); +} + +void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { _use_ce = static_cast(use_ce); _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; @@ -74,7 +89,7 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl _gemm_priority = gemm_priority; _comm_priority = comm_priority; } - for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { + for (int i = 0; i < std::max(num_max_streams, num_splits); i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority)); _stream_compute.push_back(std::move(stream)); @@ -278,6 +293,11 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); +} + +void CommOverlapBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm) { _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 
3, @@ -288,7 +308,9 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); + if (_ub_comm->myrank == 0) { + printf("!!! [UB] Register UBuf %d\n", _ub_reg); + } _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); NVTE_CHECK_CUDA( @@ -337,7 +359,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, comm_elements, _ub_comm, _stream_comm, - (cudaEvent_t)_comm_launch_event); + (cudaEvent_t)_comm_launch_event); } else { reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, (cudaEvent_t)_comm_launch_event); @@ -455,7 +477,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); -} // split_overlap_rs +} // atomic_gemm_overlap_rs /* ** Split FPROP GEMM + ReduceScatter @@ -603,6 +625,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, cudaStream_t stream_main) { + int comm_bytes = _ubuf.bytes(); int comm_bytes_per_rank = comm_bytes / _tp_size; @@ -640,6 +663,11 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, comm_type, aggregate); +} + +void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate) { _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; @@ -647,28 +675,28 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, // Create workspace tensor with userbuffer NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); - int buffer_chunk_bytes = buffer_bytes / tp_size; - _num_ubuf_chunks = tp_size; + int buffer_chunk_bytes = buffer_bytes / _tp_size; + _num_ubuf_chunks = _tp_size; if (_is_reduce_scatter) { // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk // outputs for reduction at the end of the pipelining. - buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); - _num_ubuf_chunks = tp_size * 2 - 1; + buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1); + _num_ubuf_chunks = _tp_size * 2 - 1; } void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); + if (_rank == 0) printf("!!! 
[UBP2P] UBuf %d\n", _ub_reg); _ubuf = TensorWrapper( buffer_ptr, - std::vector{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]}, buffer_dtype); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); for (int i = 0; i < _num_ubuf_chunks; i++) { _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), - std::vector{buffer_shape[0] / tp_size, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size, buffer_shape[1]}, buffer_dtype)); ubuf_byte_ptr += buffer_chunk_bytes; } @@ -691,15 +719,28 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); } - for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + for (int i = 0; i < _stream_compute.size(); i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); _stream_send.push_back(std::move(stream)); } + for (int i = 0; i < 7; i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + l_stream_send.push_back(std::move(stream)); + } + for (int i = 0; i < 7; i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + l_stream_recv.push_back(std::move(stream)); + } NVTE_CHECK_CUDA( cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); + for (int i = 0; i < 7; i++) { + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&l_stop_recv[i], 0)); + } } CommOverlapP2PBase::~CommOverlapP2PBase() { @@ -709,6 +750,43 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { for (size_t i = 0; i < _stream_send.size(); i++) { cudaStreamDestroy(_stream_send[i]); } + for (int i = 0; i < 7; i++) { + cudaStreamDestroy(l_stream_recv[i]); + cudaStreamDestroy(l_stream_send[i]); + cudaEventDestroy(l_stop_recv[i]); + } +} + +void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? 
source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + void *dst_ptr; + if (local_chunk) { + NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); + } else { + NVTE_CHECK(_ubuf.numel() == source_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); } TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, @@ -851,6 +929,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, const bool do_gelu = pre_gelu_out.numel() > 0; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); @@ -919,12 +1006,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } else { @@ -972,16 +1053,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } + // Copy all-gathered B from communication buffer into auxiliary output + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + _ub_comm->sms = ori_sms; for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); diff --git a/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp new file mode 100644 index 000000000..ea05ea95f --- /dev/null +++ b/transformer_engine/common/comm_gemm_overlap/rocm_comm_gemm_overlap.cpp @@ -0,0 +1,664 @@ 
+/************************************************************************* + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_driver.h" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" +#include "common/util/system.h" +#include "userbuffers/userbuffers.h" + +namespace transformer_engine { +#if 0 +// Recursive doubling AG code for future reference +void CommOverlapP2PBase::rocm_split_overlap_ag_rd(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t k = (transa) ? A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + const int comm_bytes = _ubufs[0].bytes(); + const bool do_gelu = pre_gelu_out.numel() > 0; + const size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + + int steps = 31 - __builtin_clz(_tp_size); + + // Chunk dims + std::vector input_b_chunk_shape = + (transb ? std::vector{k, n_chunk} : std::vector{n_chunk, k}); + std::vector output_chunk_shape = {n_chunk, m}; + size_t input_b_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; + + // GEMM + auto input_b_chunk = + get_buffer_chunk_like(B, input_b_chunk_size * _tp_id, input_b_chunk_shape); + auto output_chunk = + get_tensor_chunk(D, output_chunk_size * _tp_id, output_chunk_shape); + auto aux_chunk = + (do_gelu) + ? 
get_tensor_chunk(pre_gelu_out, output_chunk_size * _tp_id, {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (_tp_id % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[_tp_id % _stream_compute.size()]); + + std::vector owned_chunks; + owned_chunks.reserve(_tp_size); + owned_chunks.push_back(_tp_id); + size_t offset = 1; + + for (int step = 0; step < steps; step++) { + int send_rank = (_tp_id + offset) % _tp_size; + int recv_rank = (_tp_id - offset + _tp_size) % _tp_size; + + for (int i = 0; i < owned_chunks.size(); i++) { + size_t send_offset = owned_chunks[i] * comm_bytes; + userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, + comm_bytes, _ub_comm, send_rank, _stream_send[i % _stream_send.size()]); + } + + std::vector new_chunks; + for (size_t i = 0; i < owned_chunks.size(); i++) { + size_t new_chunk_id = (recv_rank + i * offset) % _tp_size; + if (new_chunk_id >= _tp_size || + std::find(owned_chunks.begin(), owned_chunks.end(), new_chunk_id) != owned_chunks.end()) continue; + size_t recv_offset = new_chunk_id * comm_bytes; + size_t stream_id = new_chunks.size() % _stream_compute.size(); + + userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, + comm_bytes, _ub_comm, recv_rank, _stream_recv); + + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[stream_id], _stop_recv, 0)); + + auto input_b_chunk = get_buffer_chunk_like(B, input_b_chunk_size * new_chunk_id, input_b_chunk_shape); + output_chunk = get_tensor_chunk(D, output_chunk_size * new_chunk_id, output_chunk_shape); + aux_chunk = (do_gelu) ? 
get_tensor_chunk(pre_gelu_out, output_chunk_size * new_chunk_id, {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + workspace_chunk = get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), + aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, + use_split_accumulator, _math_sms, + _stream_compute[stream_id]); + + new_chunks.push_back(new_chunk_id); + } + owned_chunks.insert(owned_chunks.end(), new_chunks.begin(), new_chunks.end()); + offset <<= 1; + } + + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + + _ub_comm->sms = ori_sms; + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); +} // rocm_split_overlap_ag_rd +#endif // #if 0 + +// TODO: Generalize for TP other than 2,4,8 using Walecki construction +constexpr int tp_next_8[7][8] = { + {1, 5, 4, 6, 3, 2, 7, 0}, + {2, 6, 1, 0, 5, 7, 4, 3}, + {3, 7, 0, 5, 6, 4, 1, 2}, + {4, 3, 6, 2, 7, 0, 5, 1}, + {5, 2, 7, 4, 1, 3, 0, 6}, + {6, 0, 5, 7, 2, 1, 3, 4}, + {7, 4, 3, 1, 0, 6, 2, 5}, +}; + +constexpr int tp_prev_8[7][8] = { + {7, 0, 5, 4, 2, 1, 3, 6}, + {3, 2, 0, 7, 6, 4, 1, 5}, + {2, 6, 7, 0, 5, 3, 4, 1}, + {5, 7, 3, 1, 0, 6, 2, 4}, + {6, 4, 1, 5, 3, 0, 7, 2}, + {1, 5, 4, 6, 7, 2, 0, 3}, + {4, 3, 6, 2, 1, 7, 5, 0}, +}; + +// No full Hamiltonian decomposition for TP=4 TP=6 (Tillson’s Theorem) +// Further optimization for these cases may be multiring w/ RD for example +constexpr int tp_next_4[2][4] = { + {1, 2, 3, 0}, + {3, 0, 1, 2}, +}; + +constexpr int tp_prev_4[2][4] = { + {3, 0, 1, 2}, + {1, 2, 3, 0} +}; + +template +constexpr bool multiring_hamiltonian_check(const int (&next)[NUM_RINGS][TP_SIZE]) { + for (int r = 0; r < NUM_RINGS; ++r) { + bool visited[TP_SIZE] = {}; + + int curr = 0; + for (int step = 0; step < TP_SIZE; ++step) { + if (visited[curr]) return false; + visited[curr] = true; + curr = next[r][curr]; + } + + if (curr != 0) return false; + + for (int i = 0; i < TP_SIZE; ++i) { + if (!visited[i]) return false; + } + } + return true; +} + +template +constexpr bool rings_are_unique( + const int next[NUM_RINGS][TP_SIZE]) +{ + for (int src = 0; src < TP_SIZE; ++src) { + bool seen[TP_SIZE] = {}; + + for (int r = 0; r < NUM_RINGS; ++r) { + int dst = next[r][src]; + + // No self-send + if (dst == src) + return false; + + if (seen[dst]) + return false; + + seen[dst] = true; + } + } + return true; +} + +template +constexpr bool prev_is_inverse_of_next( + const int next[NUM_RINGS][TP_SIZE], + const int prev[NUM_RINGS][TP_SIZE]) +{ + for (int r = 0; r < NUM_RINGS; ++r) { + for (int i = 0; i < TP_SIZE; ++i) { + int n = next[r][i]; + int p = prev[r][i]; + + if (n < 0 || n >= TP_SIZE) return false; + if (p < 0 || p >= TP_SIZE) return false; + + if (prev[r][n] != i) return false; + if (next[r][p] != i) return false; + } + } + return true; +} + +static_assert(multiring_hamiltonian_check<2,4>(tp_next_4), "Non-Hamiltonian ring present!"); 
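+// Worked example of the tables above: ring 0 of tp_next_8 traces the cycle
+// 0 -> 1 -> 5 -> 2 -> 4 -> 3 -> 6 -> 7 -> 0, visiting every rank exactly once,
+// which is the Hamiltonian property these static_asserts verify at compile time.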
+static_assert(multiring_hamiltonian_check<7,8>(tp_next_8), "Non-Hamiltonian ring present!"); + +static_assert(rings_are_unique<2,4>(tp_next_4), "Rings overlap"); +static_assert(rings_are_unique<7,8>(tp_next_8), "Rings overlap"); + +static_assert(prev_is_inverse_of_next<2,4>(tp_next_4, tp_prev_4), "tp_prev_4 is not inverse of tp_next_4"); +static_assert(prev_is_inverse_of_next<7,8>(tp_next_8, tp_prev_8), "tp_prev_8 is not inverse of tp_next_8"); + +// TODO: Introduce HIPGraphs for dependency management. +void CommOverlapP2PBase::rocm_split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + // Get GEMM dimensions between TN and NN input layouts + const size_t m = (transa) ? A.size(0) : A.size(1); + const size_t k = (transa) ? A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + // Get communication and GEMM output chunk sizes + const int comm_bytes = _ubufs[0].bytes(); + const bool do_gelu = pre_gelu_out.numel() > 0; + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + const int max_rings = (_tp_size == 4) ? 2 : + (_tp_size == 6) ? 4 : + _tp_size - 1; + const int num_rings = std::min({ + transformer_engine::getenv("GPU_MAX_HW_QUEUES", 4), + _tp_size - 1, + max_rings + }); + + const int *next, *prev; + switch (_tp_size) { + case 8: + next = reinterpret_cast(tp_next_8); + prev = reinterpret_cast(tp_prev_8); + break; + case 4: + next = reinterpret_cast(tp_next_4); + prev = reinterpret_cast(tp_prev_4); + break; + case 2: + return this->split_overlap_ag(A, transa, B, transb, D, bias, pre_gelu_out, workspace, grad, + accumulate, use_split_accumulator, B_copy, stream_main); + default: + NVTE_ERROR("ROCm supports TP sizes of 2, 4, 8 only."); + } + + const int alignment = 256; + const int base_slice_bytes = (comm_bytes / num_rings) & ~(alignment - 1); + const int total_base_bytes = base_slice_bytes * num_rings; + const int remainder_bytes = comm_bytes - total_base_bytes; + + const size_t base_n_slice = n_chunk / num_rings; + const size_t remainder_n = n_chunk - (base_n_slice * num_rings); + + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size()); + } + + auto get_slice_info = [&](int ring) -> std::pair { + size_t offset = ring * base_slice_bytes; + int size = base_slice_bytes; + if (ring == num_rings - 1) + size += remainder_bytes; + return {offset, size}; + }; + + auto get_slice_n = [&](int ring) -> size_t { + return base_n_slice + (ring == num_rings - 1 ? 
remainder_n : 0); + }; + + auto get_chunk_id = [&](int ring, int step) { + int owner = _tp_id; + for (int s = 0; s < step; ++s) + owner = prev[ring * _tp_size + owner]; + return owner; + }; + + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + + for (int r = 0; r < num_rings; ++r) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_recv[r], _start_compute, 0)); + } + + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); + } + + const int total_slices = _tp_size * num_rings; + std::vector slice_events(total_slices); + + for (int i = 0; i < total_slices; i++) { + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&slice_events[i], cudaEventDisableTiming)); + } + + auto get_event = [&](int chunk, int ring) { + return slice_events[chunk * num_rings + ring]; + }; + + for (int r = 0; r < num_rings; r++) { + NVTE_CHECK_CUDA(cudaEventRecord(get_event(_tp_id, r), stream_main)); + } + + auto get_slice_offset = [&](int chunk, int ring) { + auto [ring_offset, _] = get_slice_info(ring); + return chunk * comm_bytes + ring_offset; + }; + + auto launch_slice_gemm = [&](int ring_id, int step) { + int chunk_id = get_chunk_id(ring_id, step); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[ring_id], + get_event(chunk_id, ring_id), 0)); + size_t n_slice = get_slice_n(ring_id); + + size_t input_b_slice_elems = n_slice * k; + size_t output_slice_elems = n_slice * m; + + size_t b_elem_offset = chunk_id * n_chunk * k; + size_t d_elem_offset = chunk_id * n_chunk * m; + + for (int r = 0; r < ring_id; r++) { + size_t prev_n = get_slice_n(r); + b_elem_offset += prev_n * k; + d_elem_offset += prev_n * m; + } + + std::vector input_b_slice_shape = + (transb ? std::vector{k, n_slice} : std::vector{n_slice, k}); + std::vector output_slice_shape = {n_slice, m}; + + auto input_b_slice = get_buffer_chunk_like(B, b_elem_offset, input_b_slice_shape); + auto output_slice = get_tensor_chunk(D, d_elem_offset, output_slice_shape); + + auto aux_slice = (do_gelu) + ? 
get_tensor_chunk(pre_gelu_out, d_elem_offset, {n_slice, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + + auto workspace_chunk = get_tensor_chunk(workspace, ring_id * workspace_size_chunk, + {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_slice.data(), output_slice.data(), bias.data(), + aux_slice.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[ring_id]); + }; + + for (int step = 0; step < _tp_size; step++) { + for (int r = 0; r < num_rings; r++) { + if (step < _tp_size - 1) { + int curr_chunk_id = get_chunk_id(r, step); + int next_recv_chunk_id = get_chunk_id(r, step + 1); + + int next_rank = next[r * _tp_size + _tp_id]; + int prev_rank = prev[r * _tp_size + _tp_id]; + + size_t send_off = get_slice_offset(curr_chunk_id, r); + size_t recv_off = get_slice_offset(next_recv_chunk_id, r); + + auto [_, slice_bytes] = get_slice_info(r); + + if (step > 0) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], get_event(curr_chunk_id, r), 0)); + } + + { + int peerlocal = next_rank % _ub_comm->nvsize; + void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, _ub_comm, _ub_reg, r); + void *srcptr = reinterpret_cast(_ub_comm->mem_ptr[_ub_reg]) + send_off; + void *dstptr = reinterpret_cast(_ub_comm->peer_ptr[_ub_reg][peerlocal]) + send_off; + + NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, slice_bytes, cudaMemcpyDeviceToDevice, l_stream_send[r])); + uint32_t signal_val = step + 1; + hipStreamWriteValue32(l_stream_send[r], flagptr, signal_val, 0); + } + + { + int peerlocal = prev_rank % _ub_comm->nvsize; + void *flagptr = GET_RECV_PTR_BY_INDEX(prev_rank, _ub_comm, _ub_reg, r); + + uint32_t signal_val = step + 1; + hipStreamWaitValue32(l_stream_recv[r], flagptr, signal_val, hipStreamWaitValueGte, 0xFFFFFFFF); + } + + NVTE_CHECK_CUDA(cudaEventRecord(get_event(next_recv_chunk_id, r), l_stream_recv[r])); + } + } + + for (int r = 0; r < num_rings; r++) { + launch_slice_gemm(r, step); + } + } + + if (B_copy.numel() > 0) { + for (int r = 0; r < num_rings; r++) { + int last_chunk = get_chunk_id(r, _tp_size - 1); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[0], get_event(last_chunk, r), 0)); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, l_stream_send[0])); + } + + _ub_comm->sms = ori_sms; + + for (auto& s : _stream_compute) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, s)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + + for (int r = 0; r < num_rings; r++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, l_stream_send[r])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, l_stream_recv[r])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + } + + for (auto& ev : slice_events) { + NVTE_CHECK_CUDA(cudaEventDestroy(ev)); + } +} // CommOverlapP2PBase::rocm_split_overlap_ag + +void CommOverlapP2PBase::rocm_split_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { + int ori_sms = _ub_comm->sms; + _ub_comm->use_ce = _use_ce; + _ub_comm->sms = _num_comm_sm; + _ub_comm->cga_size = _cga_size; + + // GEMM dimensions + const size_t m = transa ? 
A.size(0) : A.size(1); + const size_t k = transa ? A.size(1) : A.size(0); + const size_t n_chunk = _ubufs[0].size(0); + const int comm_bytes = _ubufs[0].bytes(); + + size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + + const int max_rings = (_tp_size == 4) ? 2 : + (_tp_size == 6) ? 4 : + _tp_size - 1; + + const int num_rings = std::min({ + transformer_engine::getenv("GPU_MAX_HW_QUEUES", 4), + _tp_size - 1, + max_rings + }); + + const int *next, *prev; + switch (_tp_size) { + case 8: + next = reinterpret_cast(tp_next_8); + prev = reinterpret_cast(tp_prev_8); + break; + case 4: + next = reinterpret_cast(tp_next_4); + prev = reinterpret_cast(tp_prev_4); + break; + case 2: + return this->split_overlap_rs(A, transa, B, transb, D, bias, pre_gelu_out, workspace, grad, + accumulate, use_split_accumulator, rs_output, stream_main); + default: + NVTE_ERROR("ROCm supports TP sizes of 2, 4, 8 only."); + } + + const int alignment = 256; + const int base_slice_bytes = (comm_bytes / num_rings) & ~(alignment - 1); + const int total_base_bytes = base_slice_bytes * num_rings; + const int remainder_bytes = comm_bytes - total_base_bytes; + + const size_t base_n_slice = n_chunk / num_rings; + const size_t remainder_n = n_chunk - base_n_slice * num_rings; + + auto get_slice_info = [&](int ring) -> std::pair { + size_t offset = ring * base_slice_bytes; + int size = base_slice_bytes; + if (ring == num_rings - 1) + size += remainder_bytes; + return {offset, size}; + }; + + auto get_slice_n = [&](int ring) -> size_t { + return base_n_slice + (ring == num_rings - 1 ? remainder_n : 0); + }; + + auto get_chunk_id = [&](int ring, int step) { + int owner = _tp_id; + for (int s = 0; s < step; ++s) + owner = prev[ring * _tp_size + owner]; + return owner; + }; + + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); + for (int r = 0; r < num_rings; r++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_recv[r], _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[r], _start_compute, 0)); + } + + for (auto &s : _stream_compute) + NVTE_CHECK_CUDA(cudaStreamWaitEvent(s, _start_compute, 0)); + + const int total_slices = _tp_size * num_rings; + std::vector slice_events(total_slices); + + for (auto &e : slice_events) + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&e, cudaEventDisableTiming)); + + auto get_event = [&](int chunk, int ring) { + return slice_events[chunk * num_rings + ring]; + }; + + for (int r = 0; r < num_rings; ++r) + NVTE_CHECK_CUDA(cudaEventRecord(get_event(_tp_id, r), stream_main)); + + auto get_slice_offset = [&](int chunk, int ring) { + auto [ring_offset, _] = get_slice_info(ring); + return chunk * comm_bytes + ring_offset; + }; + + auto launch_slice_gemm = [&](int chunk_id, int ring_id, int step) { + size_t n_slice = get_slice_n(ring_id); + + size_t b_elem_offset = chunk_id * n_chunk * k; + size_t d_elem_offset = chunk_id * n_chunk * m; + + for (int r = 0; r < ring_id; ++r) { + b_elem_offset += get_slice_n(r) * k; + d_elem_offset += get_slice_n(r) * m; + } + + auto input_b_slice = get_tensor_chunk(B, b_elem_offset, transb ? 
std::vector{k, n_slice} : std::vector{n_slice, k}); + auto output_slice = get_tensor_chunk(D, d_elem_offset, {n_slice, m}); // D acts as the accumulation buffer + auto workspace_chunk = get_tensor_chunk(workspace, ring_id * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(A.data(), input_b_slice.data(), output_slice.data(), bias.data(), + pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), + accumulate, use_split_accumulator, _math_sms, + _stream_compute[ring_id]); + NVTE_CHECK_CUDA(cudaEventRecord(get_event(chunk_id, ring_id), _stream_compute[ring_id])); + }; + + for (int step = 0; step < _tp_size; ++step) { + for (int r = 0; r < num_rings; ++r) { + int curr_chunk = get_chunk_id(r, step); + launch_slice_gemm(curr_chunk, r, step); + } + + if (step > 0) { + int prev_step = step - 1; + + for (int r = 0; r < num_rings; ++r) { + int chunk_to_send = get_chunk_id(r, prev_step); + + NVTE_CHECK_CUDA(cudaStreamWaitEvent(l_stream_send[r], get_event(chunk_to_send, r), 0)); + + size_t send_off = get_slice_offset(chunk_to_send, r); + auto [_, slice_bytes] = get_slice_info(r); + + int next_rank = next[r * _tp_size + _tp_id]; + int prev_rank = prev[r * _tp_size + _tp_id]; + + { + int peerlocal = next_rank % _ub_comm->nvsize; + void *srcptr = reinterpret_cast(_ub_comm->mem_ptr[_ub_reg]) + send_off; + void *dstptr = reinterpret_cast(_ub_comm->peer_ptr[_ub_reg][peerlocal]) + send_off; + void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, _ub_comm, _ub_reg, r); + + NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, slice_bytes, cudaMemcpyDeviceToDevice, l_stream_send[r])); + uint32_t signal_val = prev_step + 1; // Use step count as signal + hipStreamWriteValue32(l_stream_send[r], flagptr, signal_val, 0); + } + + { + int peerlocal = prev_rank % _ub_comm->nvsize; + void *flagptr = GET_RECV_PTR_BY_INDEX(prev_rank, _ub_comm, _ub_reg, r); + uint32_t signal_val = prev_step + 1; + hipStreamWaitValue32(l_stream_recv[r], flagptr, signal_val, hipStreamWaitValueGte, 0xFFFFFFFF); + } + } + } + } + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); + } + for (int r = 0; r < num_rings; r++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, l_stream_send[r])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, l_stream_recv[r])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); + } + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + D.dtype(), fp8_type, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, + _ubufs[0].numel(), stream_main);); + } else { + reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); + } + + _ub_comm->sms = ori_sms; + + // Cleanup events + for (auto &e : slice_events) NVTE_CHECK_CUDA(cudaEventDestroy(e)); +} // rocm_split_overlap_rs + +} // namespace transformer_engine diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 1ce89c512..72324dd23 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ 
b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -357,12 +357,12 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, NVTE_CHECK_CUDA(cudaDeviceSynchronize()); register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true); NVTE_CHECK_CUDA( - cudaMalloc(reinterpret_cast(&(*comm)->send_id), (*comm)->nranks * sizeof(int))); + cudaMalloc(reinterpret_cast(&(*comm)->send_id), (*comm)->nranks * NVTE_MAX_RINGS * sizeof(int))); NVTE_CHECK_CUDA(cudaMalloc(reinterpret_cast(&(*comm)->recv_id), - NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); - NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int))); + NVTE_MAX_REGIONS * (*comm)->nranks * NVTE_MAX_RINGS * sizeof(int))); + NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * NVTE_MAX_RINGS * sizeof(int))); NVTE_CHECK_CUDA( - cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int))); + cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * NVTE_MAX_RINGS * sizeof(int))); (*comm)->sms = 16; (*comm)->threads = 1024; @@ -375,8 +375,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, cudaMalloc(reinterpret_cast(&(*comm)->flags_baseptr), 2 * GPU_PAGE_SIZE)); NVTE_CHECK_CUDA(cudaMemset((*comm)->flags_baseptr, 0, 2 * GPU_PAGE_SIZE)); (*comm)->flags = reinterpret_cast( +#ifdef __HIP_PLATFORM_AMD__ + (reinterpret_cast((*comm)->flags) + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); +#else ((CUdeviceptr)(*comm)->flags_baseptr + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK); - +#endif using namespace std; sched_param param; @@ -670,9 +673,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * reinterpret_cast(&memhndl), sizeof(cudaIpcMemHandle_t), comm->comm_intra); + // Check for NVLINK support before attempting IPC operations + if (comm->nvsize > 1) { + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + cudaDeviceProp deviceProp; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device)); + bool peer_access_available = false; + for (int i = 0; i < comm->nvsize; i++) { + if (i != comm->nvrank) { + int can_access_peer; + cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i); + if (peer_result == cudaSuccess && can_access_peer) { + peer_access_available = true; + break; + } + } + } + if (!peer_access_available) { + free(tmp); + NVTE_ERROR( + "No peer-to-peer access available between GPUs. This platform does not support the " + "GPU-to-GPU " + "communication required for multi-GPU userbuffers. 
Consider using single-GPU mode."); + return 1; + } + } + for (int i = 0; i < comm->nvsize; i++) { if (i != comm->nvrank) { - NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*) + NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], cudaIpcMemLazyEnablePeerAccess)); } } @@ -693,4 +723,5 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * comm->mem_ptr[hndl] = *gpubuff; return comm->free_region++; + printf("***** Returning *****\n"); } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 1dcd54d0d..58ceb9d9d 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -5,6 +5,15 @@ ************************************************************************/ #include + +#ifdef __HIP_PLATFORM_AMD__ +#include +#include +#include "amd_detail/hip_float8.h" +#define half_dtype hip_bfloat16 +#define __nv_fp8_e5m2 te_hip_fp8_e5m2 +#define __nv_fp8_e4m3 te_hip_fp8_e4m3 +#else #include #include @@ -13,6 +22,7 @@ #else #define half_dtype half #endif +#endif #include #include @@ -24,6 +34,7 @@ #define MAX_THREADS 1024 +#if !defined(__HIP_PLATFORM_AMD__) && defined(__HIP_PLATFORM_NVIDIA__) #define ATOMIC_CONSUMER(chunk) \ if (counters) { \ if (threadIdx.x == 0 && blockIdx.x == 0) { \ @@ -34,6 +45,18 @@ } \ if (blockIdx.x == 0) __syncthreads(); \ } +#else +#define ATOMIC_CONSUMER(chunk) \ + if (counters) { \ + if (threadIdx.x == 0 && blockIdx.x == 0) { \ + while (0 != (atomicCAS(((unsigned int *)counters) + chunk, 0, 0))) { \ + } \ + ((unsigned int *)counters)[chunk] = 1; \ + __threadfence(); \ + } \ + if (blockIdx.x == 0) __syncthreads(); \ + } +#endif #define ATOMIC_PRODUCER(chunk) \ if (counters) { \ @@ -62,7 +85,7 @@ printf("[%s:%s:%d] " message "\n", FILENAME(__FILE__), __FUNCTION__, __LINE__, __VA_ARGS__) // Report and error on timeout -#define CHECK_TIMEOUT(t, timeout) ((clock64() - (t)) > timeout) +#define CHECK_TIMEOUT(t, timeout) (((uint64_t)clock64() - (t)) > timeout) template __global__ void __launch_bounds__(MAX_THREADS) @@ -132,7 +155,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence_system(); + if (threadIdx.x == 0) __threadfence(); __syncthreads(); if (threadIdx.x < RANKS) { @@ -477,7 +500,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence_system(); + if (threadIdx.x == 0) __threadfence(); __syncthreads(); if (threadIdx.x < RANKS) { @@ -708,7 +731,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence_system(); + if (threadIdx.x == 0) __threadfence(); __syncthreads(); __shared__ int lastSM; @@ -1025,7 +1048,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // reset counter for next producer. ((unsigned int *)counters)[0] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } __syncthreads(); @@ -1116,7 +1143,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // reset counter for next producer. 
((unsigned int *)counters)[chunk_i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } __syncthreads(); @@ -1329,7 +1360,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); - if (threadIdx.x == 0) __threadfence_system(); + if (threadIdx.x == 0) __threadfence(); __syncthreads(); __shared__ int lastSM; @@ -1357,6 +1388,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } } // fp16 inplace allgather kernel (Volta,Hopper) +#ifndef __HIP_PLATFORM_AMD__ #define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ cudaLaunchAttribute attribute_ub[2]; \ @@ -1367,6 +1399,15 @@ __global__ void __launch_bounds__(MAX_THREADS) attribute_ub[0].id = cudaLaunchAttributeCooperative; \ cfg.attrs = attribute_ub; \ cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1; +#else +#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[1]; \ + attribute_ub[0].id = cudaLaunchAttributeCooperative; \ + attribute_ub[0].value.cooperative = 1; \ + cfg.attrs = attribute_ub; \ + cfg.numAttrs = 1; +#endif #if (CUDART_VERSION >= 12030) #define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \ @@ -1378,6 +1419,11 @@ __global__ void __launch_bounds__(MAX_THREADS) #define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2 #endif +#ifdef __HIP_PLATFORM_AMD__ +#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ + cudaLaunchConfig_t cfg; \ + NVTE_ERROR("SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT is not supported for AMD GPUs") +#else #define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \ @@ -1389,6 +1435,7 @@ __global__ void __launch_bounds__(MAX_THREADS) attribute_ub[0].id = cudaLaunchAttributeCooperative; \ cfg.attrs = attribute_ub; \ cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH; +#endif #define callranks_ag(x) \ if (ar_nvsize == x) { \ @@ -2049,7 +2096,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (threadIdx.x) return; - __threadfence_system(); + __threadfence(); atomicAdd_system(flagptr, 1); // otherwise need local SM sync before sending flag } else { // 0 bytes and 1 SM only @@ -2111,7 +2158,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (threadIdx.x) return; - __threadfence_system(); + __threadfence(); atomicAdd_system(send_flagptr, 1); // otherwise need local SM sync before sending flag } else { // 0 bytes and 1 SM only @@ -2169,7 +2216,7 @@ __global__ void __launch_bounds__(MAX_THREADS) } __syncthreads(); if (threadIdx.x) return; - __threadfence_system(); + __threadfence(); atomicAdd_system(send_flagptr, 1); // otherwise need local SM sync before sending flag } else { // 0 bytes and 1 SM only @@ -2196,7 +2243,11 @@ __global__ void __launch_bounds__(MAX_THREADS) // Decrement atomic val to signal current output tile finish if (counters) { ((unsigned int *)counters)[0] = 0; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } } @@ -2236,7 +2287,7 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat } __syncthreads(); if (!threadIdx.x) { - __threadfence_system(); + __threadfence(); atomicAdd_system(send_flagptr, 1); // otherwise need local SM sync before sending 
flag } @@ -2267,7 +2318,11 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat // Decrement atomic val to signal current output tile finish if (counters) { ((unsigned int *)counters)[recv_chunk_id /*chunk_i+1*/] = 0; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } @@ -2284,6 +2339,7 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat // Return TRUE if two ranks share the same NV domain #define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize)) +#ifndef __HIP_PLATFORM_AMD__ // Moved to header for visibility // Index corresponds to the type of flag: // 0 - Send index counter // 1 - CE start index counter @@ -2303,12 +2359,13 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (recv_peer) * NVTE_MAX_REGIONS + (dsth) + \ (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ sizeof(int))) +#endif // #ifndef __HIP_PLATFORM_AMD__ void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, - const int peer, cudaStream_t stream) { + const int peer, cudaStream_t stream, int ring_id) { int peerlocal = peer % comm->nvsize; - void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0); + void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, ring_id); // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1); // void *ce_send_end_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 2); bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); @@ -2317,7 +2374,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return; if (comm->push == 0) { - kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]), + kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer * NVTE_MAX_RINGS + ring_id]), reinterpret_cast(flagptr)); NVTE_CHECK_CUDA(cudaGetLastError()); } else { @@ -2330,7 +2387,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds // kuserbuffers_inc<<<1, 1, 0, stream>>>(reinterpret_cast(ce_send_end_ptr)); } SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream); - int *arg1 = &comm->send_id[peer], *arg2 = reinterpret_cast(flagptr); + int *arg1 = &comm->send_id[peer * NVTE_MAX_RINGS + ring_id], *arg2 = reinterpret_cast(flagptr); int4 *arg3 = reinterpret_cast(srcptr), *arg4 = reinterpret_cast(dstptr); int arg5 = signalonly ? 0 : bytes / 16; void *kernelArgs[] = {reinterpret_cast(&arg1), reinterpret_cast(&arg2), @@ -2500,9 +2557,9 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, - const int peer, cudaStream_t stream) { + const int peer, cudaStream_t stream, int ring_id) { int peerlocal = peer % comm->nvsize; - void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 0); + void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, ring_id); bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0); assert(INTRANODE(peer)); @@ -2514,12 +2571,12 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds kuserbuffers_pullrecv<<sms, signalonly ? 
1 : 1024, 0, stream>>>( comm->myrank, peer, comm->nvrank, peerlocal, - &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast(flagptr), + &(comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_MAX_RINGS + ring_id]), reinterpret_cast(flagptr), reinterpret_cast(srcptr), reinterpret_cast(dstptr), signalonly ? 0 : bytes / 16, comm->ub_timeout); NVTE_CHECK_CUDA(cudaGetLastError()); if (!signalonly) { - kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler])); + kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_MAX_RINGS + ring_id])); NVTE_CHECK_CUDA(cudaGetLastError()); } if (comm->use_ce) { @@ -2528,7 +2585,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds } else { kuserbuffers_pushrecv<<<1, 1, 0, stream>>>( comm->myrank, peer, comm->nvrank, peerlocal, - &comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler], reinterpret_cast(flagptr), + &comm->recv_id[(peer * NVTE_MAX_REGIONS + dsthandler) * NVTE_MAX_RINGS + ring_id], reinterpret_cast(flagptr), signalonly || comm->sms, comm->ub_timeout, reinterpret_cast(0 ? // temporary disable GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 1) @@ -2576,7 +2633,11 @@ static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { // COMM kernel need to explicitely flash gmem. // GEMM kernel already executed, and can not see gmem // change without COMM kernel explicitely make change +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } // consumer @@ -2586,7 +2647,11 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) { while (0 != (atomicCAS((unsigned int *)atomic_ptr + chunk_i, 0, 0))) { } ((unsigned int *)atomic_ptr)[chunk_i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } @@ -2598,7 +2663,11 @@ static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i while (0 != (atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) { } ((unsigned int *)atomic_ptr)[i] = 1; +#ifndef __HIP_PLATFORM_AMD__ asm volatile("fence.sc.gpu;\n"); +#else + __threadfence(); +#endif } } } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 4d52fbb64..c7f26df06 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -35,6 +35,7 @@ using ExtBarrierOp = std::function; #define NVTE_LAUNCH_GPU 1 #define NVTE_LAUNCH_CPU 2 #define NVTE_MAX_NVLINK 32 +#define NVTE_MAX_RINGS 7 #define NVTE_UB_MEM_UC_CONTIG 1 #define NVTE_UB_MEM_MC_CREATED 2 @@ -63,6 +64,28 @@ using ExtBarrierOp = std::function; #define NVTE_HF_NVREDUCEDONE (userbuffers_op_types + 3) #define NVTE_MAX_SHARP 16 +#ifdef __HIP_PLATFORM_AMD__ // Moved to header for visibility +// Index corresponds to the type of flag: +// 0 - Send index counter +// 1 - CE start index counter +// 2 - CE end index counter +#define GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsth, index) \ + ((reinterpret_cast((comm)->peer_ptr[0][(peerlocal)])) + \ + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (comm)->myrank * NVTE_MAX_REGIONS + (dsth) + \ + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ + sizeof(int))) + +// Index corresponds to the type of flag: +// 0 - Receive index counter +// 1 - CE start index counter +// 2 - CE end index counter +#define 
GET_RECV_PTR_BY_INDEX(recv_peer, comm, dsth, index) \ + ((reinterpret_cast<char *>((comm)->mem_ptr[0])) + \ + ((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + (recv_peer) * NVTE_MAX_REGIONS + (dsth) + \ + (index) * NVTE_MAX_NVLINK * NVTE_MAX_REGIONS) * \ + sizeof(int))) +#endif // #ifdef __HIP_PLATFORM_AMD__ + typedef struct ub_request { int optype; int blocksize; @@ -268,10 +291,10 @@ output is strided: row starts separated by stride elements*/ void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, - const int peer, cudaStream_t stream = 0); + const int peer, cudaStream_t stream = 0, int ring_id = 0); void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes, communicator *comm, - const int peer, cudaStream_t stream = 0); + const int peer, cudaStream_t stream = 0, int ring_id = 0); void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset, const size_t recv_offset, const size_t bytes, communicator *comm, const int send_peer, const int recv_peer, cudaStream_t stream = 0); diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 4d65e26ce..3b36ee951 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -15,7 +17,7 @@ #include "common/comm_gemm_overlap/userbuffers/userbuffers.h" -#define NVTE_COMM_OVERLAP_MAX_STREAMS 3 +#define NVTE_COMM_OVERLAP_MAX_STREAMS 7 namespace transformer_engine { @@ -37,7 +39,7 @@ enum class CommOverlapAlgo { ATOMIC_GEMM_RS = 5, ATOMIC_GEMM_AG_P2P = 6, ATOMIC_GEMM_RS_P2P = 7, - EXTERNAL_BULK_OVERLAP_AG = 8, + EXTERNAL_BULK_OVERLAP_AG = 8 }; class CommOverlapCore { @@ -67,6 +69,11 @@ class CommOverlapCore { std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; + private: + void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm); + public: CommOverlapCore() {} // dummy constructor for exposing type to Python @@ -78,23 +85,37 @@ class CommOverlapCore { virtual ~CommOverlapCore(); + void *get_ubuf_dptr() { return _ubuf.dptr(); } + void set_ubuf_scale_inv(float *scale_inv) { _ubuf_scale_inv = scale_inv; _ubuf_scale_inv_initialized = true; } + virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) { + NVTE_ERROR("Operation is not implemented."); + } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, const std::vector &shape); TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, const std::vector &shape); + int get_tp_size() { return _tp_size; } + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } + bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + virtual bool is_aggregate() { + NVTE_ERROR("Operation is not implemented."); + } + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, @@ -119,6 +140,14 @@ class CommOverlapCore { NVTE_ERROR("Operation is not implemented."); } + virtual void rocm_split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + virtual void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, @@ -139,6 +168,14 @@ class CommOverlapCore { cudaStream_t stream_main) { NVTE_ERROR("Operation is not implemented."); } + + virtual void rocm_split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { @@ -148,6 +185,10 @@ class CommOverlapBase : public CommOverlapCore { cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm); + public: CommOverlapBase() {} // dummy constructor for exposing type to Python @@ -207,6 +248,22 @@ class CommOverlapBase : public 
CommOverlapCore { void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, cudaStream_t stream_main) override; + + void rocm_split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + + void rocm_split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { @@ -220,9 +277,13 @@ class CommOverlapP2PBase : public CommOverlapCore { int _num_ubuf_chunks; int _self_chunk_id; std::vector _ubufs; - std::vector _stream_send; + std::vector _stream_send, l_stream_send, l_stream_recv; cudaStream_t _stream_recv; - cudaEvent_t _stop_send, _stop_recv; + cudaEvent_t _stop_send, _stop_recv, l_stop_recv[7]; + + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate); public: CommOverlapP2PBase() {} // dummy constructor for exposing type to Python @@ -237,6 +298,9 @@ class CommOverlapP2PBase : public CommOverlapCore { virtual ~CommOverlapP2PBase(); + void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) override; + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, @@ -286,6 +350,25 @@ class CommOverlapP2PBase : public CommOverlapCore { TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) override; + /* + ** ROCm Multiring ReduceScatter + GEMM + */ + void rocm_split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; + + /* + ** ROCm Multiring AllGather + GEMM + */ + void rocm_split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override; + + bool is_aggregate() { return _aggregate; } // needed for rocm pathing /* ** This function overlaps the AG for the current communicator object with the GEMM for the overlap_gemm object. 
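For reviewers tracing the new rocm_split_overlap_ag / rocm_split_overlap_rs entry points declared above, the stream-ordered handshake they rely on reduces to the self-contained HIP sketch below. It is illustrative only, not part of the patch: the real implementation signals through userbuffers' registered flag regions (GET_SEND_PTR_BY_INDEX / GET_RECV_PTR_BY_INDEX) and copies between peer-mapped buffers across ranks, whereas this sketch assumes a single device, plain hipMalloc allocations, and hardware support for stream wait/write operations on them.

// handshake_sketch.cpp -- illustrative only; mirrors the producer/consumer
// signaling used by rocm_split_overlap_ag/rs: the sender copies a slice with
// hipMemcpyAsync and publishes a step counter with hipStreamWriteValue32; the
// receiver blocks its stream with hipStreamWaitValue32 until the counter
// reaches that step, so GEMM work enqueued afterwards sees the copied slice.
#include <hip/hip_runtime.h>

#include <cstdint>
#include <cstdio>

#define HIP_CHECK(call)                                               \
  do {                                                                \
    hipError_t err_ = (call);                                         \
    if (err_ != hipSuccess) {                                         \
      std::printf("HIP error %s at %s:%d\n", hipGetErrorString(err_), \
                  __FILE__, __LINE__);                                \
      return 1;                                                       \
    }                                                                 \
  } while (0)

int main() {
  const size_t slice_bytes = 1 << 20;
  void *src = nullptr, *dst = nullptr;
  uint32_t *flag = nullptr;
  HIP_CHECK(hipMalloc(&src, slice_bytes));
  HIP_CHECK(hipMalloc(&dst, slice_bytes));
  HIP_CHECK(hipMalloc(reinterpret_cast<void **>(&flag), sizeof(uint32_t)));
  HIP_CHECK(hipMemset(flag, 0, sizeof(uint32_t)));

  hipStream_t send_stream, recv_stream;
  HIP_CHECK(hipStreamCreate(&send_stream));
  HIP_CHECK(hipStreamCreate(&recv_stream));

  const uint32_t signal_val = 1;  // plays the role of "step + 1" in the patch
  // Sender side: move the slice, then publish the step counter.
  HIP_CHECK(hipMemcpyAsync(dst, src, slice_bytes, hipMemcpyDeviceToDevice, send_stream));
  HIP_CHECK(hipStreamWriteValue32(send_stream, flag, signal_val, 0));
  // Receiver side: block the consumer stream until the counter reaches the step.
  HIP_CHECK(hipStreamWaitValue32(recv_stream, flag, signal_val, hipStreamWaitValueGte, 0xFFFFFFFFu));

  HIP_CHECK(hipStreamSynchronize(recv_stream));
  HIP_CHECK(hipStreamSynchronize(send_stream));
  HIP_CHECK(hipStreamDestroy(send_stream));
  HIP_CHECK(hipStreamDestroy(recv_stream));
  HIP_CHECK(hipFree(src));
  HIP_CHECK(hipFree(dst));
  HIP_CHECK(hipFree(flag));
  std::printf("handshake complete\n");
  return 0;
}

The per-ring l_stream_send / l_stream_recv streams added by the patch play exactly the send_stream / recv_stream roles shown here, with the published value advancing by one per ring step.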
diff --git a/transformer_engine/common/include/transformer_engine/multi_stream.h b/transformer_engine/common/include/transformer_engine/multi_stream.h index e406a0786..cf67711f1 100644 --- a/transformer_engine/common/include/transformer_engine/multi_stream.h +++ b/transformer_engine/common/include/transformer_engine/multi_stream.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -11,7 +13,11 @@ #ifndef TRANSFORMER_ENGINE_MULTI_STREAM_H #define TRANSFORMER_ENGINE_MULTI_STREAM_H +#ifdef __HIP_PLATFORM_AMD__ +#include "util/hip_runtime.h" +#else #include "cuda_runtime.h" +#endif #ifdef __cplusplus extern "C" { diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 6f3f117d4..f49ff2c0b 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -27,7 +27,7 @@ namespace { #include "string_path_cuda_include.h" } // namespace -#endif // __HIP_PLATFORM_AMD__ +#endif // #ifndef __HIP_PLATFORM_AMD__ int num_devices() { auto query_num_devices = []() -> int { @@ -103,7 +103,6 @@ int sm_count(int device_id) { return cache[device_id]; } -#ifndef __HIP_PLATFORM_AMD__ void stream_priority_range(int *low_priority, int *high_priority, int device_id) { static std::vector> cache(num_devices()); static std::vector flags(num_devices()); @@ -124,6 +123,11 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id) *high_priority = cache[device_id].second; } +#ifdef __HIP_PLATFORM_AMD__ +bool supports_multicast(int _) { + return false; +} +#else bool supports_multicast(int device_id) { #if CUDART_VERSION >= 12010 // NOTE: This needs to be guarded at compile-time and run-time because the diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index 069981347..b23bcaef4 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -50,7 +50,6 @@ const std::string &sm_arch_name(int device_id = -1); */ int sm_count(int device_id = -1); -#ifndef __HIP_PLATFORM_AMD__ /* \brief Minimum and maximum stream priorities supported on device * * \param[in] device_id CUDA device (default is current device) @@ -68,7 +67,6 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id * \return CUDA multicast support flag */ bool supports_multicast(int device_id = -1); -#endif /* \brief Path to CUDA/ROCm Toolkit headers * diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index b243a8a0b..b4c1eaad0 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -10,10 +10,7 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include -//TODO: rocm does not support comm gemm overlap yet -#ifndef USE_ROCM #include -#endif #include #include @@ -35,9 +32,6 @@ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); #endif -// Define comm overlap handles if not using ROCm -#ifndef USE_ROCM - #define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \ pybind11::enum_(m, "CommOverlapType", \ pybind11::module_local()) \ @@ -56,7 +50,9 @@ .value("ATOMIC_GEMM_AG_P2P", 
transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ .value("EXTERNAL_BULK_OVERLAP_AG", \ - transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \ + transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG) \ + .value("SPLIT_PIPELINED_AG_RD_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_RD_P2P); \ py::class_>(m, "CommOverlapCore", \ pybind11::module_local()) \ @@ -91,14 +87,6 @@ py::call_guard(), py::arg("device_id") = -1); \ m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ py::call_guard()); -#else -#define NVTE_DECLARE_COMM_OVERLAP_HANDLES(m) \ - pybind11::class_(m, "CommOverlapType", \ - pybind11::module_local()); \ - py::class_>(m, "CommOverlapCore", \ - pybind11::module_local()); -#endif #define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ pybind11::enum_(m, "DType", pybind11::module_local()) \ diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 07384413d..7e015cc53 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -30,9 +30,7 @@ #include #include #include -#ifndef USE_ROCM #include -#endif #include #include #include diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 9b527b161..df408fad9 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -17,14 +17,6 @@ class CommOverlapHelper; class CommOverlap; class CommOverlapP2P; -#ifdef USE_ROCM -namespace transformer_engine { -//dummy CommOverlapCore, CommOverlapType in rocm -class CommOverlapCore{}; -class CommOverlapType{}; -} -#endif - namespace transformer_engine::pytorch { /*************************************************************************************************** @@ -504,7 +496,6 @@ void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at: } // namespace transformer_engine::pytorch -#ifndef USE_ROCM /*************************************************************************************************** * Comm+GEMM Overlap Wrappers **************************************************************************************************/ @@ -576,6 +567,5 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm std::pair get_communication_stream(); }; // CommOverlapP2P -#endif // !USE_ROCM #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 4aa2df2c9..29dd884fd 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -5,7 +5,6 @@ * * See LICENSE for license information. 
************************************************************************/ -#ifndef USE_ROCM #include "../extensions.h" #include "transformer_engine/transformer_engine.h" @@ -314,10 +313,12 @@ std::pair CommOverlapP2P::get_communication_stream() { at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())}; } +#ifndef USE_ROCM void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm( CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) { auto main_stream = at::cuda::getCurrentCUDAStream(); allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream), at::cuda::CUDAStream(recv_stream), main_stream); } -#endif // !USE_ROCM +#endif + \ No newline at end of file diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index b637d49c7..e7d3809fd 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -229,7 +229,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans #endif if (comm_overlap) { -#ifndef USE_ROCM // Prepare extra output tensor TensorWrapper extra_output_tensor; if (extra_output.has_value()) { @@ -238,7 +237,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans extra_output_tensor = makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); } - // Direct GEMM call to the correct overlap if (bulk_overlap) { NVTE_SCOPED_GIL_RELEASE({ @@ -255,6 +253,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); +#ifdef __HIP_PLATFORM_AMD__ + } else if (!comm_overlap->is_aggregate()) { + NVTE_SCOPED_GIL_RELEASE({ + comm_overlap->rocm_split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + }); +#endif // #ifdef __HIP_PLATFORM_AMD } else { NVTE_SCOPED_GIL_RELEASE({ comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, @@ -272,17 +279,22 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans extra_output_tensor, main_stream); }); } else { +#ifdef __HIP_PLATFORM_AMD__ + NVTE_SCOPED_GIL_RELEASE({ + comm_overlap->rocm_split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, extra_output_tensor, + main_stream); +#else NVTE_SCOPED_GIL_RELEASE({ comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); +#endif }); } } -#else - NVTE_ERROR("ROCm TE does not support comm_overlap\n"); -#endif //!USE_ROCM } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 55b1d179e..7aa8796ad 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -465,7 +465,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("GRAD_OUTPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT3) .value("GRAD_INPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT3); -#ifndef USE_ROCM py::class_(m, "CommOverlapHelper") .def(py::init<>(), py::call_guard()) .def(py::init>(), @@ 
-505,9 +504,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) .def("get_communication_stream", &CommOverlapP2P::get_communication_stream); -#else - m.def("CommOverlapHelper", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); - m.def("CommOverlap", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); - m.def("CommOverlapP2P", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); -#endif //USE_ROCM } diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b49e38544..5d7c1e08b 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -16,6 +16,7 @@ from contextlib import contextmanager import logging from types import MethodType +from itertools import chain import torch import torch.nn.functional as F @@ -325,7 +326,7 @@ def initialize_ub( # AG-RS overlap pairs of layers forming a tensor-parallel block ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} - external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"} + external_gemm_to_overlap = {} if IS_HIP_EXTENSION else {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"} global layers_atomic_ring_exchange layers_atomic_ring_exchange = [] @@ -346,7 +347,7 @@ def get_default_config(name): "is_reduce_scatter": is_reduce_scatter, "num_sm": 1 if method == "ring_exchange" else 16, "cga_size": 1 if method == "ring_exchange" else 2, - "set_sm_margin": not method == "ring_exchange", + "set_sm_margin": not method == "ring_exchange" and not IS_HIP_EXTENSION, "num_splits": tp_size if method == "ring_exchange" else 4, "aggregate": False, "atomic_gemm": False, @@ -426,6 +427,7 @@ def add_ub( if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf) else dtype ) + if method == "ring_exchange": ub_obj = tex.CommOverlapP2P( shape, # Communication buffer shape @@ -478,9 +480,7 @@ def add_ub( new_method = user_ub_cfg[name]["method"] methods[new_method].append(name) - for name in ( - methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"] - ): + for name in chain.from_iterable(methods.values()): ub_cfg = get_default_config(name) if user_ub_cfg is not None and name in user_ub_cfg: fp8_buf = (name in layers_all_gather_overlap) or ( diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index d14454dc0..1b1fe5a59 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -34,13 +34,12 @@ ForwardLinearScaleAdd, fuse_forward_linear_scale_add, ) -from torch.utils.cpp_extension import IS_HIP_EXTENSION -if not IS_HIP_EXTENSION: - from .userbuffers_backward_linear import ( - UserbuffersBackwardLinear, - fuse_userbuffers_backward_linear, - ) - from .userbuffers_forward_linear import ( - UserbuffersForwardLinear, - fuse_userbuffers_forward_linear, - ) + +from .userbuffers_backward_linear import ( + UserbuffersBackwardLinear, + fuse_userbuffers_backward_linear, +) +from .userbuffers_forward_linear import ( + UserbuffersForwardLinear, + fuse_userbuffers_forward_linear, +) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 89e43f845..6154463da 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -36,6 +36,8 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +from torch.utils.cpp_extension import IS_HIP_EXTENSION + warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") @@ -295,8 +297,8 @@ def __init__( ub_overlap_ag: bool = True, ub_overlap_rs: bool = True, ub_overlap_rs_dgrad: bool = False, - ub_bulk_dgrad: bool = True, - ub_bulk_wgrad: bool = True, + ub_bulk_dgrad: bool = not IS_HIP_EXTENSION, + ub_bulk_wgrad: bool = not IS_HIP_EXTENSION, bias: bool = True, activation: str = "gelu", normalization: str = "LayerNorm", From d7796535eafc9a28330ef316ea9da984ebd946de Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 27 Jan 2026 15:52:53 +0000 Subject: [PATCH 2/2] Copyrights and cleanup --- build_tools/pytorch.py | 2 +- .../comm_gemm_overlap/te_layer_with_overlap_profile.py | 2 +- tests/cpp/operator/test_normalization_mxfp8.cu | 6 +++--- tests/pytorch/distributed/run_layer_with_overlap.py | 2 +- transformer_engine/common/CMakeLists.txt | 2 +- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 2 ++ .../comm_gemm_overlap/userbuffers/userbuffers-host.cpp | 2 ++ .../common/comm_gemm_overlap/userbuffers/userbuffers.cu | 2 ++ .../common/comm_gemm_overlap/userbuffers/userbuffers.h | 2 ++ .../common/include/transformer_engine/comm_gemm_overlap.h | 2 +- .../common/include/transformer_engine/multi_stream.h | 2 +- transformer_engine/common/util/cuda_runtime.cpp | 2 +- transformer_engine/common/util/cuda_runtime.h | 2 +- transformer_engine/common/util/pybind_helper.h | 6 ++---- transformer_engine/pytorch/csrc/common.h | 2 +- transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/comm_gemm_overlap.cpp | 3 +-- transformer_engine/pytorch/module/base.py | 2 +- transformer_engine/pytorch/transformer.py | 2 ++ 19 files changed, 27 insertions(+), 20 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 8fcf72f7e..d0640dfd0 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py index ba5afd2b6..02b9b9696 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap_profile.py @@ -1,7 +1,7 @@ #!/usr/bin/python3 # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/cpp/operator/test_normalization_mxfp8.cu b/tests/cpp/operator/test_normalization_mxfp8.cu index 40c5be719..e87ed2209 100644 --- a/tests/cpp/operator/test_normalization_mxfp8.cu +++ b/tests/cpp/operator/test_normalization_mxfp8.cu @@ -131,10 +131,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, DType wtype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input("input2", std::vector{ N, H }, itype); + Tensor input("input", std::vector{ N, H }, itype); Tensor z("z", std::vector{ N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING); - Tensor gamma("gamma2", std::vector{ H }, wtype); - Tensor beta("beta2", std::vector{ H }, wtype); + Tensor gamma("gamma", std::vector{ H }, wtype); + Tensor beta("beta", std::vector{ H }, wtype); Tensor mu("mu", std::vector{ N }, DType::kFloat32); Tensor rsigma("rsigma", std::vector{ N }, DType::kFloat32); Tensor workspace; diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index ac185fd10..8e31f83a0 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -1,7 +1,7 @@ #!/usr/bin/python3 # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b8e137bc1..ceb862e15 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index d56e57a69..59531a41b 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 72324dd23..c7c9646ee 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 58ceb9d9d..83cb621a2 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index c7f26df06..72a0df7df 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 3b36ee951..4c36617d5 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/include/transformer_engine/multi_stream.h b/transformer_engine/common/include/transformer_engine/multi_stream.h index cf67711f1..64cb5a021 100644 --- a/transformer_engine/common/include/transformer_engine/multi_stream.h +++ b/transformer_engine/common/include/transformer_engine/multi_stream.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index f49ff2c0b..1ad2be342 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index b23bcaef4..e1aac6699 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index b4c1eaad0..66ae76917 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -50,9 +50,7 @@ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ .value("EXTERNAL_BULK_OVERLAP_AG", \ - transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG) \ - .value("SPLIT_PIPELINED_AG_RD_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_RD_P2P); \ + transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \ py::class_>(m, "CommOverlapCore", \ pybind11::module_local()) \ diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 7e015cc53..8bb29aa4c 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index df408fad9..2ac1931b7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 29dd884fd..fe13d2f56 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -321,4 +321,3 @@ void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm( at::cuda::CUDAStream(recv_stream), main_stream); } #endif - \ No newline at end of file diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 5d7c1e08b..3aa466cf0 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 6154463da..230b5ae1c 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information.
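As a closing note on PATCH 1/2: the ring-slice bookkeeping in rocm_split_overlap_ag/rs (base_slice_bytes rounded down to a 256-byte multiple, the remainder folded into the last ring, and per-chunk offsets of chunk * comm_bytes + ring * base_slice_bytes) can be sanity-checked in isolation. The standalone program below mirrors that arithmetic; the comm_bytes and num_rings values are arbitrary examples, not values taken from the patch.

// slice_check.cpp -- standalone sanity check of the ring-slice arithmetic used
// by rocm_split_overlap_ag/rs; illustrative only, not part of the patch.
#include <cassert>
#include <cstdio>
#include <initializer_list>

int main() {
  const int alignment = 256;
  for (int comm_bytes : {786432, 1 << 20, 3072000}) {
    for (int num_rings : {1, 2, 3, 4, 7}) {
      const int base_slice_bytes = (comm_bytes / num_rings) & ~(alignment - 1);
      const int remainder_bytes = comm_bytes - base_slice_bytes * num_rings;
      int covered = 0;
      for (int ring = 0; ring < num_rings; ++ring) {
        const int offset = ring * base_slice_bytes;  // as in get_slice_info(ring)
        int size = base_slice_bytes;
        if (ring == num_rings - 1) size += remainder_bytes;
        assert(offset % alignment == 0);  // every ring slice starts 256-B aligned
        assert(offset == covered);        // slices are contiguous within the chunk
        covered += size;
      }
      assert(covered == comm_bytes);      // slices tile the per-rank chunk exactly
      std::printf("comm_bytes=%d rings=%d base=%d remainder=%d OK\n",
                  comm_bytes, num_rings, base_slice_bytes, remainder_bytes);
    }
  }
  return 0;
}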