21 commits
822d7ea  Changed VERSION to 2.8.0 (ptrendx, Sep 19, 2025)
9797673  [JAX] Remove import jax.extend.ffi (#2193) (phu0ngng, Sep 22, 2025)
ac51322  [PyTorch] Add sink attention support from cuDNN (#2148) (cyanguwa, Sep 22, 2025)
307a993  [QA] Add pytest xml report for all tests in qa folder that use pytest… (shengfangd, Sep 23, 2025)
c70f1d9  [JAX] Local-Amax for Current-Scaling (#2183) (mingxu1067, Sep 23, 2025)
3f02a2e  [JAX] Restore Shardy Rule with CompoundFactor (#2167) (phu0ngng, Sep 23, 2025)
ee58762  [JAX] Update JAX version requirement in pyproject.toml (#2197) (phu0ngng, Sep 24, 2025)
bd8e566  [PyTorch] Unpin version of onnxscript and onnxruntime (#2202) (pggPL, Sep 26, 2025)
238a3fd  [JAX] Fix XML filename in the L0_jax_unittest (#2205) (phu0ngng, Sep 27, 2025)
de13b8c  [JAX] CollectiveGemm (#2166) (phu0ngng, Sep 27, 2025)
ef38de4  [JAX] Add xml export for `test_multiprocessing_encoder` and `test_cge… (phu0ngng, Sep 29, 2025)
c464c85  [JAX] Address tolerance check for current scaling dact dbias (#2211) (jberchtold-nvidia, Sep 29, 2025)
8a7b893  [Core][PyTorch] NVFP4 recipe (#2177) (ksivaman, Sep 29, 2025)
51d046b  Fix the segfault in the nvfp4 quantization (#2214) (ptrendx, Sep 30, 2025)
5a58f50  [PyTorch] Add FP8 attention with current scaling (#2012) (cyanguwa, Sep 30, 2025)
88d541c  [JAX] Load modules during initialize for Norm and Act primitives (#2219) (jberchtold-nvidia, Sep 30, 2025)
2b2f921  Fix the cuBLAS workspace alignment (#2223) (ptrendx, Oct 1, 2025)
071589e  [PyTorch] Set usages for linear op quantizers before forward (#2222) (timmoon10, Oct 2, 2025)
5339e97  Resolved conflicts (VeeraRajasekhar, Jan 22, 2026)
4e9f03d  Addressed cpp and torch tests (VeeraRajasekhar, Jan 23, 2026)
c01ef70  Fixed test_attention.py (VeeraRajasekhar, Jan 23, 2026)
8 changes: 4 additions & 4 deletions .github/workflows/build.yml
@@ -19,7 +19,7 @@ jobs:
       run: |
         apt-get update
         apt-get install -y git python3.9 pip cudnn9-cuda-12
-        pip install cmake==3.21.0 pybind11[global] ninja
+        pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1
     - name: 'Checkout'
       uses: actions/checkout@v3
       with:
@@ -43,7 +43,7 @@ jobs:
       run: |
         apt-get update
         apt-get install -y git python3.9 pip cudnn9-cuda-12
-        pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
+        pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
     - name: 'Checkout'
       uses: actions/checkout@v3
       with:
@@ -63,7 +63,7 @@ jobs:
     options: --user root
     steps:
     - name: 'Dependencies'
-      run: pip install pybind11[global]
+      run: pip install pybind11[global] nvidia-mathdx==25.1.1
     - name: 'Checkout'
       uses: actions/checkout@v3
       with:
@@ -83,7 +83,7 @@ jobs:
     options: --user root
     steps:
     - name: 'Dependencies'
-      run: pip install torch pybind11[global] einops onnxscript
+      run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
     - name: 'Checkout'
       uses: actions/checkout@v3
       with:
2 changes: 0 additions & 2 deletions benchmarks/attention/benchmark_attention_rocm.py
@@ -307,7 +307,6 @@ def sanity_checks(
         cfg,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
-        window_size=cfg.window_size,
         pad_between_seqs=pad_between_seqs,
     )
     flash_ok, fused_ok, _ = avail
@@ -368,7 +367,6 @@ def main(args):
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
-        window_size=config.window_size,
         pad_between_seqs=pad_between_seqs,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
152 changes: 152 additions & 0 deletions benchmarks/benchmark_rht_cast.py
@@ -0,0 +1,152 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import argparse
import torch
import pandas as pd
import torch.utils.benchmark as benchmark

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
import transformer_engine.pytorch.cpp_extensions as ext

from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer

scale_padding_to = 1
permute_scale = False

TORCH_TO_TE_FLOAT_MAP = {
    torch.bfloat16: tex.DType.kBFloat16,
}


def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):
    # Generate random input data
    M, K = shape
    x = torch.randn([M, K], dtype=input_dtype, device="cuda")

    assert shape[0] % 16 == 0, "Shape must be divisible by 16"
    assert shape[1] % 16 == 0, "Shape must be divisible by 16"

    # Quantize
    nvfp4_quantizer = NVFP4Quantizer(
        fp4_dtype=tex.DType.kFloat4E2M1,
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=True,
        with_post_rht_amax=True,
        with_random_sign_mask=True,
        stochastic_rounding=stochastic_rounding,
    )
    x_nvfp4_sut = nvfp4_quantizer.make_empty(
        (M, K), dtype=x.dtype, device=x.device, requires_grad=False
    )
    x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)

    with torch.no_grad():
        stmt = "kernel_func(input, output)"
        globals_dict = {
            "kernel_func": nvfp4_quantizer.update_quantized,
            "input": x,
            "output": x_nvfp4_sut,
        }

        timing = benchmark.Timer(
            stmt=stmt,
            globals=globals_dict,
            num_threads=1,
        ).blocked_autorange(min_run_time=5)
    print(timing)
    timing_us = timing.median * 1e6

    input_nbytes = shape[0] * shape[1] * 2  # bf16
    output_nbytes = shape[0] * shape[1] // 2  # //2 for fp4
    sf_nbytes = shape[0] * shape[1] // 16  # //16 for 1 byte per 16 elems

    total_nbytes = (
        0
        + input_nbytes
        * 3  # Reading input for Amax(x) & Amax(RHT(x.T)), reading input for Cast(x), reading input for Cast(RHT(x.T))
        + 2 * 4  # Output 2 * float for scale & amax
        + 2 * 4  # Input 2 * float
        + output_nbytes * 2  # Output from Cast(x) and Cast(RHT(x.T))
        + sf_nbytes * 2  # Scale factor
    )

    throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6)

    print(
        f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:"
        f" {throughput_GBps} GB/s"
    )
    return timing_us, throughput_GBps


# Nsight Compute Profiling Command:
# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_rht_cast.py --profile

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
    args = parser.parse_args()

    if args.profile:
        print("Profiling is enabled.")
    else:
        print("Profiling is disabled.")

    shapes = [
        (8192, 5120),
        (8192, 10240),
        (8192, 2560),
        (8192, 11328),
        (8192, 512),
        (8192, 3584),
        (5120, 8192),
        (10240, 8192),
        (2560, 8192),
        (11328, 8192),
        (512, 8192),
        (3584, 8192),
        (4096, 16384),
        (14336, 16384),
    ]

    if args.profile:
        shapes = [
            (16384, 6144),
        ]

    data = []
    for stochastic_rounding in [True]:  # , False]:
        for shape in shapes:
            print(
                f"Running benchmark_func with shape {shape} and stochastic_rounding"
                f" {stochastic_rounding}"
            )
            timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding)
            data.append(
                [
                    "benchmark_func",
                    shape,
                    stochastic_rounding,
                    timing_us,
                    throughput_GBps,
                ]
            )

    df = pd.DataFrame(
        data=data,
        columns=[
            "kernel",
            "shape",
            "stochastic_rounding",
            "timing_us",
            "throughput(GB/s)",
        ],
    )
    print(df)
    df.to_csv("benchmark_cast_nvfp4.csv", index=False)
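Note that the GB/s figure this script reports comes from the fixed byte-traffic model above, not from hardware counters: three full reads of the bf16 input (amax passes plus the two casts), two packed FP4 outputs (row-wise and column-wise), and two scale-factor tensors. A standalone restatement of that model, handy for sanity-checking the reported numbers offline (the function name modeled_bytes is ours, not part of the benchmark):

# Sketch of the traffic model used in run_kernel() above; assumes bf16 input
# (2 B/elem), packed FP4 output (0.5 B/elem), one scale byte per 16 elements.
def modeled_bytes(M: int, K: int) -> int:
    input_nbytes = M * K * 2      # bf16 input
    output_nbytes = M * K // 2    # packed FP4 output
    sf_nbytes = M * K // 16       # one scale byte per 16 elements
    return (
        input_nbytes * 3          # amax pass, Cast(x), Cast(RHT(x.T)) reads
        + 2 * 4 + 2 * 4           # scalar scale/amax floats in and out
        + output_nbytes * 2       # row-wise and column-wise FP4 outputs
        + sf_nbytes * 2           # both scale-factor tensors
    )

# Example: the (8192, 5120) shape moves ~298.8 MB per quantization call.
assert modeled_bytes(8192, 5120) == 298_844_176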
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
-2.8.0.dev0
+2.8.0
7 changes: 7 additions & 0 deletions build_tools/jax.py
@@ -100,11 +100,18 @@ def setup_jax_extension(
     # Define TE/JAX as a Pybind11Extension
     from pybind11.setup_helpers import Pybind11Extension

+    # Note: Collective GEMM operations are not supported on ROCm yet
+    if rocm_build():
+        comm_libraries = []
+    else:
+        comm_libraries = ["nccl"]
+
     return Pybind11Extension(
         "transformer_engine_jax",
         sources=[str(path) for path in sources],
         include_dirs=[str(path) for path in include_dirs],
         extra_compile_args=cxx_flags,
+        libraries=comm_libraries,
     )
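For context: setuptools forwards the `libraries` argument to the linker, so the CUDA branch effectively adds -lnccl to the link line while the ROCm branch links nothing extra. A minimal sketch of what the two branches produce ("stub.cpp" is a placeholder source, not a file in the repo):

# Sketch only: Pybind11Extension subclasses setuptools.Extension, and the
# `libraries` list becomes -l<name> flags at link time.
from pybind11.setup_helpers import Pybind11Extension

cuda_ext = Pybind11Extension("transformer_engine_jax", sources=["stub.cpp"], libraries=["nccl"])
rocm_ext = Pybind11Extension("transformer_engine_jax", sources=["stub.cpp"], libraries=[])
assert cuda_ext.libraries == ["nccl"] and rocm_ext.libraries == []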
2 changes: 1 addition & 1 deletion build_tools/pytorch.py
@@ -27,7 +27,7 @@

 def install_requirements() -> List[str]:
     """Install dependencies for TE/PyTorch extensions."""
-    return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"]
+    return ["torch>=2.1", "einops", "onnxscript", "onnx"]


 def test_requirements() -> List[str]:
15 changes: 9 additions & 6 deletions build_tools/utils.py
@@ -305,15 +305,18 @@ def get_cuda_include_dirs() -> Tuple[str, str]:

 @functools.lru_cache(maxsize=None)
 def cuda_archs() -> str:
-    version = cuda_version()
-    if os.getenv("NVTE_CUDA_ARCHS") is None:
+    archs = os.getenv("NVTE_CUDA_ARCHS")
+    if archs is None:
+        version = cuda_version()
         if version >= (13, 0):
-            os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120"
+            archs = "75;80;89;90;100;100a;103a;120"
+        elif version >= (12, 9):
+            archs = "70;80;89;90;100;100a;103a;120"
         elif version >= (12, 8):
-            os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120"
+            archs = "70;80;89;90;100;100a;120"
         else:
-            os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90"
-    return os.getenv("NVTE_CUDA_ARCHS")
+            archs = "70;80;89;90"
+    return archs


 def cuda_version() -> Tuple[int, ...]:
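The reworked cuda_archs() treats NVTE_CUDA_ARCHS purely as a user override and no longer writes the computed default back into os.environ. A quick standalone check of the selection logic (pick_archs is our name for this sketch, not the repo's):

# Sketch of the selection logic above: an explicit NVTE_CUDA_ARCHS wins,
# otherwise the arch list is keyed off the CUDA toolkit version tuple.
import os

def pick_archs(version: tuple, env=os.environ) -> str:
    archs = env.get("NVTE_CUDA_ARCHS")
    if archs is None:
        if version >= (13, 0):
            archs = "75;80;89;90;100;100a;103a;120"
        elif version >= (12, 9):
            archs = "70;80;89;90;100;100a;103a;120"
        elif version >= (12, 8):
            archs = "70;80;89;90;100;100a;120"
        else:
            archs = "70;80;89;90"
    return archs

assert pick_archs((12, 8), env={}) == "70;80;89;90;100;100a;120"
assert pick_archs((11, 8), env={"NVTE_CUDA_ARCHS": "90"}) == "90"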