Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
Expand Down Expand Up @@ -85,14 +85,6 @@ def setup_pytorch_extension(
if version < (12, 0):
raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")

if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
mpi_path = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_path / "include")
cxx_flags.append("-DNVTE_UB_WITH_MPI")

library_dirs = []
libraries = []
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
Expand All @@ -106,12 +98,22 @@ def setup_pytorch_extension(
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")

if bool(int(os.getenv("NVTE_ENABLE_ROCSHMEM", 0))):
cxx_flags.append("-DNVTE_ENABLE_ROCSHMEM")
mpi_home = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi"))
include_dirs.append(mpi_home / "include")
library_dirs.append(mpi_home / "lib")
libraries.append("mpi_cxx")
libraries.append("mpi")
cxx_flags.extend(["-DNVTE_ENABLE_ROCSHMEM", "-DOMPI_SKIP_MPICXX"])

extra_link_args = []
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
mpi_path = Path(os.getenv("MPI_HOME", "/usr/lib/x86_64-linux-gnu/openmpi"))
include_dirs.append(mpi_path / "include")
library_dirs.append(mpi_path / "lib")
libraries.append("mpi")
cxx_flags.extend(["-DNVTE_UB_WITH_MPI", "-DOMPI_SKIP_MPICXX"])

# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
Expand All @@ -125,4 +127,5 @@ def setup_pytorch_extension(
extra_compile_args={"cxx": cxx_flags},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
extra_link_args=[str(arg) for arg in extra_link_args],
)
2 changes: 2 additions & 0 deletions ci/pytorch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ run_test_config_mgpu(){
configure_omp_threads 8
run_default_fa 1 test_fused_optimizer.py
run_default_fa 3 test_sanity_import.py
run_default_fa 3 distributed/test_fusible_ops_with_userbuffers.py
run_default_fa 3 distributed/test_comm_gemm_overlap.py
run_default_fa 2 distributed/test_fusible_ops.py
run_default_fa 2 distributed/test_numerics.py
run_default_fa 1 distributed/test_torch_fsdp2.py
Expand Down
4 changes: 2 additions & 2 deletions examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None):
)
parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
parser.add_argument(
"--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
"--fp8", action="store_true", default=False, help="Enables the te.autocast() context."
)
parser.add_argument(
"--no-comm-overlap",
Expand Down Expand Up @@ -299,7 +299,7 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False)

dist_print(" |-- Forward pass", group=tp_group, debug=True)
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world):
with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world):
y = model(x)
if isinstance(y, tuple):
out, *_ = y
Expand Down
Loading
Loading