diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index f40b28189..506bc83f0 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -19,7 +19,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
- pip install cmake==3.21.0 pybind11[global] ninja
+ pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
@@ -43,7 +43,7 @@ jobs:
run: |
apt-get update
apt-get install -y git python3.9 pip cudnn9-cuda-12
- pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
+ pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
@@ -63,7 +63,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
- run: pip install pybind11[global]
+ run: pip install pybind11[global] nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
@@ -83,7 +83,7 @@ jobs:
options: --user root
steps:
- name: 'Dependencies'
- run: pip install torch pybind11[global] einops onnxscript
+ run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1
- name: 'Checkout'
uses: actions/checkout@v3
with:
diff --git a/benchmarks/attention/benchmark_attention_rocm.py b/benchmarks/attention/benchmark_attention_rocm.py
index 0c37696ac..d3b1fd759 100644
--- a/benchmarks/attention/benchmark_attention_rocm.py
+++ b/benchmarks/attention/benchmark_attention_rocm.py
@@ -307,7 +307,6 @@ def sanity_checks(
cfg,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
- window_size=cfg.window_size,
pad_between_seqs=pad_between_seqs,
)
flash_ok, fused_ok, _ = avail
@@ -368,7 +367,6 @@ def main(args):
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
- window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
diff --git a/benchmarks/benchmark_rht_cast.py b/benchmarks/benchmark_rht_cast.py
new file mode 100644
index 000000000..9c47856f7
--- /dev/null
+++ b/benchmarks/benchmark_rht_cast.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import argparse
+import torch
+import pandas as pd
+import torch.utils.benchmark as benchmark
+
+import transformer_engine.pytorch as te
+import transformer_engine_torch as tex
+import transformer_engine.pytorch.cpp_extensions as ext
+
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+
+scale_padding_to = 1
+permute_scale = False
+
+TORCH_TO_TE_FLOAT_MAP = {
+ torch.bfloat16: tex.DType.kBFloat16,
+}
+
+
+def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16):
+ # Generate random input data
+ M, K = shape
+ x = torch.randn([M, K], dtype=input_dtype, device="cuda")
+
+    assert M % 16 == 0, "M must be divisible by 16"
+    assert K % 16 == 0, "K must be divisible by 16"
+
+ # Quantize
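+    # The configuration below applies a random Hadamard transform (RHT) to the
+    # transposed data path and takes the amax after the RHT, matching the byte
+    # accounting further down (Amax(RHT(x.T)) and Cast(RHT(x.T))).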
+ nvfp4_quantizer = NVFP4Quantizer(
+ fp4_dtype=tex.DType.kFloat4E2M1,
+ rowwise=True,
+ columnwise=True,
+ with_amax_reduction=False,
+ amax_reduction_group=None,
+ with_rht=True,
+ with_post_rht_amax=True,
+ with_random_sign_mask=True,
+ stochastic_rounding=stochastic_rounding,
+ )
+ x_nvfp4_sut = nvfp4_quantizer.make_empty(
+ (M, K), dtype=x.dtype, device=x.device, requires_grad=False
+ )
+ x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
+
+ with torch.no_grad():
+ stmt = "kernel_func(input, output)"
+ globals_dict = {
+ "kernel_func": nvfp4_quantizer.update_quantized,
+ "input": x,
+ "output": x_nvfp4_sut,
+ }
+
+ timing = benchmark.Timer(
+ stmt=stmt,
+ globals=globals_dict,
+ num_threads=1,
+ ).blocked_autorange(min_run_time=5)
+ print(timing)
+ timing_us = timing.median * 1e6
+
+ input_nbytes = shape[0] * shape[1] * 2 # bf16
+ output_nbytes = shape[0] * shape[1] // 2 # //2 for fp4
+ sf_nbytes = shape[0] * shape[1] // 16 # //16 for 1 byte per 16 elems
+
+ total_nbytes = (
+ 0
+ + input_nbytes
+        * 3  # Reading input for Amax(x)&Amax(RHT(x.T)), reading input for Cast(x), reading input for Cast(RHT(x.T))
+ + 2 * 4 # Output 2 * float for scale & amax
+ + 2 * 4 # Input 2 * float
+ + output_nbytes * 2 # Output from Cast(x) and Cast(RHT(x.T))
+ + sf_nbytes * 2 # Scale factor
+ )
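+    # Worked example for shape (8192, 5120): 3 * 80 MiB of input reads
+    # + 2 * 20 MiB of FP4 outputs + 2 * 2.5 MiB of scale factors + 16 B of
+    # scalars ~= 285 MiB of modeled traffic per call.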
+
+ throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6)
+
+ print(
+ f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:"
+ f" {throughput_GBps} GB/s"
+ )
+ return timing_us, throughput_GBps
+
+
+# Nsight Compute Profiling Command:
+# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_rht_cast.py --profile
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--profile", action="store_true", help="Enable profiling mode")
+ args = parser.parse_args()
+
+ if args.profile:
+ print("Profiling is enabled.")
+ else:
+ print("Profiling is disabled.")
+
+ shapes = [
+ (8192, 5120),
+ (8192, 10240),
+ (8192, 2560),
+ (8192, 11328),
+ (8192, 512),
+ (8192, 3584),
+ (5120, 8192),
+ (10240, 8192),
+ (2560, 8192),
+ (11328, 8192),
+ (512, 8192),
+ (3584, 8192),
+ (4096, 16384),
+ (14336, 16384),
+ ]
+
+ if args.profile:
+ shapes = [
+ (16384, 6144),
+ ]
+
+ data = []
+ for stochastic_rounding in [True]: # , False]:
+ for shape in shapes:
+ print(
+ f"Running benchmark_func with shape {shape} and stochastic_rounding"
+ f" {stochastic_rounding}"
+ )
+ timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding)
+ data.append(
+ [
+ "benchmark_func",
+ shape,
+ stochastic_rounding,
+ timing_us,
+ throughput_GBps,
+ ]
+ )
+
+ df = pd.DataFrame(
+ data=data,
+ columns=[
+ "kernel",
+ "shape",
+ "stochastic_rounding",
+ "timing_us",
+ "throughput(GB/s)",
+ ],
+ )
+ print(df)
+ df.to_csv("benchmark_cast_nvfp4.csv", index=False)
diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt
index 81006d78c..834f26295 100644
--- a/build_tools/VERSION.txt
+++ b/build_tools/VERSION.txt
@@ -1 +1 @@
-2.8.0.dev0
+2.8.0
diff --git a/build_tools/jax.py b/build_tools/jax.py
index 182940c11..20679defc 100644
--- a/build_tools/jax.py
+++ b/build_tools/jax.py
@@ -100,11 +100,18 @@ def setup_jax_extension(
# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension
+ # Note: Collective GEMM operations are not supported on ROCm yet
+ if rocm_build():
+ comm_libraries = []
+ else:
+ comm_libraries = ["nccl"]
+
return Pybind11Extension(
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args=cxx_flags,
+ libraries=comm_libraries,
)
diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py
index bb084293f..0799c73b9 100644
--- a/build_tools/pytorch.py
+++ b/build_tools/pytorch.py
@@ -27,7 +27,7 @@
def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions."""
- return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"]
+ return ["torch>=2.1", "einops", "onnxscript", "onnx"]
def test_requirements() -> List[str]:
diff --git a/build_tools/utils.py b/build_tools/utils.py
index e3c5b6be8..a6514bf78 100644
--- a/build_tools/utils.py
+++ b/build_tools/utils.py
@@ -305,15 +305,18 @@ def get_cuda_include_dirs() -> Tuple[str, str]:
@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
- version = cuda_version()
- if os.getenv("NVTE_CUDA_ARCHS") is None:
+ archs = os.getenv("NVTE_CUDA_ARCHS")
+ if archs is None:
+ version = cuda_version()
if version >= (13, 0):
- os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120"
+ archs = "75;80;89;90;100;100a;103a;120"
+ elif version >= (12, 9):
+ archs = "70;80;89;90;100;100a;103a;120"
elif version >= (12, 8):
- os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120"
+ archs = "70;80;89;90;100;100a;120"
else:
- os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90"
- return os.getenv("NVTE_CUDA_ARCHS")
+ archs = "70;80;89;90"
+ return archs
def cuda_version() -> Tuple[int, ...]:
diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py
new file mode 100644
index 000000000..da79b2137
--- /dev/null
+++ b/examples/jax/collective_gemm/common.py
@@ -0,0 +1,245 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""Shared functions for the comm_overlap tests"""
+
+import jax.numpy as jnp
+import numpy as np
+
+
+def dtype_tols(dtype, rtol=None, atol=None):
+ """Expected numerical tolerance for a data type."""
+ # Return immediately if tolerances are fully specified
+ if rtol is not None and atol is not None:
+ return {"rtol": rtol, "atol": atol}
+
+ # Default tolerances for common dtypes
+ if dtype in [jnp.float32, "float32"]:
+ return {"rtol": 1e-5, "atol": 1e-8}
+ elif dtype in [jnp.float16, "float16"]:
+ return {"rtol": 1e-3, "atol": 1e-6}
+ elif dtype in [jnp.bfloat16, "bfloat16"]:
+ return {"rtol": 1e-2, "atol": 1e-5}
+ else:
+ return {"rtol": 1e-5, "atol": 1e-8}
+
+
+def assert_allclose(
+ actual,
+ desired,
+ rtol=None,
+ atol=None,
+ dtype=None,
+ **kwargs,
+):
+ """Check if two tensors are close."""
+ # Infer data type if needed
+ if dtype is None:
+ if isinstance(actual, float):
+ dtype = "float32"
+ else:
+ dtype = actual.dtype
+
+ # Determine tolerances
+ tols = {}
+ if rtol is None or atol is None:
+ tols = dtype_tols(dtype)
+ if rtol is not None:
+ tols["rtol"] = rtol
+ if atol is not None:
+ tols["atol"] = atol
+
+ # Cast tensors to fp32
+ if not isinstance(actual, float):
+ actual = actual.astype(jnp.float32)
+ if not isinstance(desired, float):
+ desired = desired.astype(jnp.float32)
+
+ # Check if tensors are close
+ np.testing.assert_allclose(actual, desired, **tols, **kwargs)
+
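+# e.g. assert_allclose(out, ref, dtype=jnp.bfloat16) compares in fp32 using the
+# bf16 tolerances above (rtol=1e-2, atol=1e-5).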
+
+def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8):
+ if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol):
+ diff = jnp.abs(ref_output - gathered_output)
+ mask = diff > (atol + rtol * jnp.abs(gathered_output))
+ print(mask.astype(int))
+ print(jnp.where(mask, diff, 0))
+
+
+# Shared constants for all tests
+DP_AXIS = "data"
+TPSP_AXIS = "tensor_sequence"
+PARAMS_KEY = "params"
+
+# Shared functions for distributed testing
+import argparse
+import jax
+from jax.experimental import mesh_utils
+from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap
+
+# Global flag to track if distributed has been initialized
+_distributed_initialized = False
+
+
+def _is_distributed_initialized():
+ """Check if JAX distributed has been initialized."""
+ return _distributed_initialized
+
+
+def _initialize_distributed(args):
+ """Initialize JAX distributed with custom arguments."""
+ global _distributed_initialized
+
+ # Check if already initialized
+ if _distributed_initialized:
+ return
+
+ if args.coordinator_address is None or args.num_processes is None or args.process_id is None:
+ raise ValueError(
+ "All distributed initialization arguments are required: "
+ "--coordinator-address, --num-processes, --process-id"
+ )
+ if args.local_device_ids is None:
+ assert (
+ args.num_devices_per_process is not None
+ ), "Either local_device_ids or num_devices_per_process must be provided"
+ # Calculate device range for this process
+ # Single process single device: each process gets one unique device
+ # Single process multiple devices: each process gets a unique range of devices
+ start_device = args.process_id * args.num_devices_per_process
+ device_range = range(start_device, start_device + args.num_devices_per_process)
+ global_device_ids_for_this_process = ",".join(map(str, device_range))
+ else:
+ # Use explicitly provided global device IDs
+ global_device_ids_for_this_process = args.local_device_ids
+ args.num_devices_per_process = len(args.local_device_ids.split(","))
+
+    assert args.num_devices_per_process == 1, "Only one device (GPU) per process is supported!"
+
+ print(
+ f"Initializing JAX distributed with coordinator={args.coordinator_address}, "
+ f"num_processes={args.num_processes}, process_id={args.process_id}"
+ )
+ # Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process"
+ jax.distributed.initialize(
+ coordinator_address=args.coordinator_address,
+ num_processes=args.num_processes,
+ process_id=args.process_id,
+ local_device_ids=global_device_ids_for_this_process,
+ )
+
+ _distributed_initialized = True
+ jax.clear_caches()
+ jax.config.update(
+ "jax_use_shardy_partitioner", False
+ ) # CollectiveGEMM does not work with Shardy yet
+
+ assert jax.local_device_count() == 1, (
+ f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found"
+ f" {jax.local_device_count()}"
+ )
+
+ devices_per_process = 1
+ num_total_devices = args.num_processes
+
+ print(
+ f"Initializing CGEMM communicator with num_total_devices={num_total_devices},"
+ f" devices_per_process={devices_per_process}, process_id={args.process_id}"
+ )
+
+ collective_gemm_bootstrap(
+ num_total_devices=num_total_devices,
+ num_devices_per_process=devices_per_process,
+ process_id=args.process_id,
+ tensor_parallel_size=args.tensor_parallel_size,
+ )
+
+
+def _get_dp_and_tp_sizes(args):
+ num_gpu = args.num_processes * args.num_devices_per_process
+ if args.tensor_parallel_size is None:
+ num_gpu_dp = 2 if args.enable_data_parallel else 1
+ assert (
+ num_gpu > 1 and num_gpu % num_gpu_dp == 0
+ ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
+ num_gpu_tp = num_gpu // num_gpu_dp
+ else:
+ num_gpu_tp = args.tensor_parallel_size
+ assert (
+ num_gpu > 1 and num_gpu % num_gpu_tp == 0
+ ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
+ num_gpu_dp = num_gpu // num_gpu_tp
+ return num_gpu_dp, num_gpu_tp
+
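+# Example: with 8 total GPUs and --tensor-parallel-size 4 this returns
+# (num_gpu_dp, num_gpu_tp) == (2, 4); with no TP size given and
+# --enable-data-parallel set, DP defaults to 2 and TP takes the rest.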
+
+def _create_mesh(args):
+ """Create mesh configuration with proper validation."""
+ num_gpu = args.num_processes * args.num_devices_per_process
+ assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices"
+ num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args)
+
+ print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)")
+
+ device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
+ mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS))
+ return mesh
+
+
+def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"):
+ """Create common argument parser for all collective GEMM tests."""
+ parser = argparse.ArgumentParser(description=description)
+
+ # Distributed initialization arguments
+ parser.add_argument(
+ "--coordinator-address",
+ type=str,
+ default=None,
+ help="Coordinator address for distributed initialization",
+ )
+ parser.add_argument(
+ "--num-processes",
+ type=int,
+ default=None,
+ help="Number of processes for distributed initialization",
+ )
+ parser.add_argument(
+ "--process-id", type=int, default=None, help="Process ID for distributed initialization"
+ )
+ parser.add_argument(
+ "--local-device-ids",
+ type=str,
+ default=None,
+ help="Local device IDs for distributed initialization (comma-separated)",
+ )
+ parser.add_argument(
+ "--num-devices-per-process", type=int, default=1, help="Number of devices per process"
+ )
+
+ # Test configuration arguments
+ parser.add_argument(
+ "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size"
+ )
+ parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing")
+ parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing")
+ parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension")
+ parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension")
+ parser.add_argument(
+ "--collective-type",
+ type=str,
+ default="all_gather",
+ choices=["all_gather", "reduce_scatter"],
+ help="Type of collective operation",
+ )
+ parser.add_argument(
+ "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use"
+ )
+ parser.add_argument(
+ "--enable-data-parallel", action="store_true", help="Enable data parallelism"
+ )
+ parser.add_argument(
+ "--enable-result-check", action="store_true", default=True, help="Enable result checking"
+ )
+
+ return parser
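+
+
+# Typical flow in the test scripts (sketch): parse args with cgemm_parser(),
+# then call _initialize_distributed(args) followed by _create_mesh(args).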
diff --git a/examples/jax/collective_gemm/conftest.py b/examples/jax/collective_gemm/conftest.py
new file mode 100644
index 000000000..83937971a
--- /dev/null
+++ b/examples/jax/collective_gemm/conftest.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""config for collective_gemm tests"""
+import pytest
+
+
+def pytest_addoption(parser):
+ """Pytest hook for collective_gemm tests"""
+ parser.addoption("--coordinator-address", action="store", default="localhost:12345")
+ parser.addoption("--num-processes", action="store", default=1)
+ parser.addoption("--process-id", action="store", default=0)
+ parser.addoption("--local-device-ids", action="store", default=None)
+
+
+@pytest.fixture(autouse=True)
+def distributed_args(request):
+ """Fixture for querying distributed initialization arguments"""
+ if request.cls:
+ request.cls.coordinator_address = request.config.getoption("--coordinator-address")
+ request.cls.num_processes = int(request.config.getoption("--num-processes"))
+ request.cls.process_id = int(request.config.getoption("--process-id"))
+ request.cls.local_device_ids = request.config.getoption("--local-device-ids")
+ request.cls.num_devices_per_process = (
+ 1
+ if request.cls.local_device_ids is None
+ else len(request.cls.local_device_ids.split(","))
+ )
diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh
new file mode 100644
index 000000000..af263eb53
--- /dev/null
+++ b/examples/jax/collective_gemm/run_test_cgemm.sh
@@ -0,0 +1,119 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
+
+: ${TE_PATH:=/opt/transformerengine}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"
+
+# Check if NVLINK is supported before running tests
+echo "*** Checking NVLINK support***"
+NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1)
+NVLINK_EXIT_CODE=$?
+
+# Check if command failed OR output indicates no NVLINK
+if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then
+ echo "NVLINK is not supported on this platform"
+ echo "Collective GEMM tests require NVLINK connectivity"
+ echo "SKIPPING all tests"
+ exit 0
+else
+ echo "NVLINK support detected"
+fi
+
+# Define the test files to run
+TEST_FILES=(
+"test_gemm.py"
+"test_dense_grad.py"
+"test_layernorm_mlp_grad.py"
+)
+
+echo
+echo "*** Executing tests in examples/jax/collective_gemm/ ***"
+
+HAS_FAILURE=0 # Global failure flag
+PIDS=() # Array to store all process PIDs
+
+# Cleanup function to kill all processes
+cleanup() {
+ for pid in "${PIDS[@]}"; do
+ if kill -0 "$pid" 2>/dev/null; then
+ echo "Killing process $pid"
+ kill -TERM "$pid" 2>/dev/null || true
+ fi
+ done
+ # Wait a bit and force kill if needed
+ sleep 2
+ for pid in "${PIDS[@]}"; do
+ if kill -0 "$pid" 2>/dev/null; then
+ echo "Force killing process $pid"
+ kill -KILL "$pid" 2>/dev/null || true
+ fi
+ done
+}
+
+# Set up signal handlers to cleanup on exit
+trap cleanup EXIT INT TERM
+
+# Run each test file across all GPUs
+for TEST_FILE in "${TEST_FILES[@]}"; do
+ echo
+ echo "=== Starting test file: $TEST_FILE ..."
+
+ # Clear PIDs array for this test file
+ PIDS=()
+
+ for i in $(seq 0 $(($NUM_GPUS - 1))); do
+ # Define output file for logs
+ LOG_FILE="${TEST_FILE}_gpu_${i}.log"
+
+ if [ $i -eq 0 ]; then
+ # For process 0: show live output AND save to log file using tee
+ echo "=== Live output from process 0 ==="
+ pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
+ -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \
+ "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
+ --num-processes=$NUM_GPUS \
+ --process-id=$i 2>&1 | tee "$LOG_FILE" &
+ PID=$!
+ PIDS+=($PID)
+ else
+ # For other processes: redirect to log files only
+ pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
+ -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
+ --num-processes=$NUM_GPUS \
+ --process-id=$i > "$LOG_FILE" 2>&1 &
+ PID=$!
+ PIDS+=($PID)
+ fi
+ done
+
+ # Wait for all processes to finish
+ wait
+
+ # Check and print the log content from process 0 (now has log file thanks to tee)
+ if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
+ echo "... $TEST_FILE SKIPPED"
+ elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
+ echo "... $TEST_FILE FAILED"
+ HAS_FAILURE=1
+ elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then
+ echo "... $TEST_FILE PASSED"
+ else
+ echo "... $TEST_FILE INVALID"
+ HAS_FAILURE=1
+ fi
+
+ # Remove the log files after processing them
+ wait
+ rm ${TEST_FILE}_gpu_*.log
+done
+
+wait
+
+# Final cleanup (trap will also call cleanup on exit)
+cleanup
+
+exit $HAS_FAILURE
diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py
new file mode 100644
index 000000000..df2dd5618
--- /dev/null
+++ b/examples/jax/collective_gemm/test_dense_grad.py
@@ -0,0 +1,214 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""
+import argparse
+import unittest
+import os
+
+import jax
+import jax.numpy as jnp
+from jax.sharding import PartitionSpec, NamedSharding
+import flax
+
+from common import (
+ assert_allclose,
+ _initialize_distributed,
+ _get_dp_and_tp_sizes,
+ _create_mesh,
+ DP_AXIS,
+ TPSP_AXIS,
+ PARAMS_KEY,
+ cgemm_parser,
+)
+
+from transformer_engine.jax.dense import dense
+
+from transformer_engine.jax.quantize import fp8_autocast
+from transformer_engine.jax.cpp_extensions.gemm import (
+ CollectiveOp,
+ CollectiveOpSet,
+ noop_collective_op_set,
+)
+from transformer_engine.jax.sharding import MeshResource
+import transformer_engine.jax.flax as te_flax
+
+
+def _get_logical_axes(collective_op):
+ if collective_op.is_all_gather:
+ input_axes = (DP_AXIS, TPSP_AXIS, None)
+ weight_axes = (None, TPSP_AXIS)
+ bias_axes = (TPSP_AXIS,)
+ output_axes = (DP_AXIS, None, TPSP_AXIS)
+ else: # RS
+ input_axes = (DP_AXIS, None, TPSP_AXIS)
+ weight_axes = (TPSP_AXIS, None)
+ bias_axes = (None,)
+ output_axes = (DP_AXIS, TPSP_AXIS, None)
+ return input_axes, weight_axes, bias_axes, output_axes
+
+
+def _get_operand_sharding(mesh, collective_op):
+ input_axes, weight_axes, bias_axes, _ = _get_logical_axes(collective_op)
+ x_sharding = NamedSharding(mesh, PartitionSpec(*input_axes))
+ weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_axes))
+ bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes))
+ return x_sharding, weight_sharding, bias_sharding
+
+
+def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
+ output = dense(
+ x,
+ weight,
+ bias,
+ contracting_dims=((2,), (0,)),
+ input_axes=input_axes,
+ kernel_axes=weight_axes,
+ output_axes=output_axes,
+ collective_op_set=collective_op_set,
+ )
+ return jnp.mean(output.astype(jnp.float32))
+
+
+def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set):
+ return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))(
+ x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set
+ )
+
+
+def run_dense_grad_tests(args, mesh=None):
+ """Execute Dense Gradient tests."""
+ print(args)
+ _initialize_distributed(args)
+ mesh = mesh or _create_mesh(args)
+
+ # Create test data
+ rng = jax.random.PRNGKey(0)
+ rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
+ x = jax.random.normal(
+ x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
+ )
+ weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
+ bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)
+
+ collective_op = (
+ CollectiveOp.ALL_GATHER
+ if args.collective_type == "all_gather"
+ else CollectiveOp.REDUCE_SCATTER
+ )
+ collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op)
+
+ with mesh, fp8_autocast(
+ enabled=False,
+ fp8_recipe=None,
+ mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
+ ):
+ # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
+ axis_rules = flax.linen.get_logical_axis_rules()
+ axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
+ te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
+ with flax.linen.logical_axis_rules(te_extended_axis_rules):
+
+ x_sharding, weight_sharding, bias_sharding = _get_operand_sharding(mesh, collective_op)
+ x_sharded = jax.device_put(x, x_sharding)
+ weight_sharded = jax.device_put(weight, weight_sharding)
+ bias_sharded = jax.device_put(bias, bias_sharding)
+
+ input_axes, weight_axes, _, output_axes = _get_logical_axes(collective_op)
+ ref_output, ref_grads = _value_and_grad_dense(
+ x_sharded,
+ weight_sharded,
+ bias_sharded,
+ input_axes,
+ weight_axes,
+ output_axes,
+ noop_collective_op_set,
+ )
+ output, sharded_grads = _value_and_grad_dense(
+ x_sharded,
+ weight_sharded,
+ bias_sharded,
+ input_axes,
+ weight_axes,
+ output_axes,
+ collective_op_set,
+ )
+ jax.block_until_ready(ref_output)
+ jax.block_until_ready(output)
+ gathered_grads = []
+ gathered_ref_grads = []
+ for ref_grad, grad in zip(ref_grads, sharded_grads):
+ gathered_grads.append(
+ jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
+ )
+ gathered_ref_grads.append(
+ jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None)))
+ )
+ jax.block_until_ready(gathered_grads)
+ jax.block_until_ready(gathered_ref_grads)
+
+ if args.enable_result_check and args.process_id == 0:
+ assert_allclose(ref_output, output, dtype=jnp.bfloat16)
+ for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
+ assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)
+
+
+class TestCollectiveDenseGradient(unittest.TestCase):
+ """Collective Dense Gradient unittests"""
+
+ def setUp(self):
+ self.args = cgemm_parser(
+ "Collective Dense Gradient test on multi-GPU with tensor parallelism"
+ ).parse_args([])
+ self.args.coordinator_address = self.coordinator_address
+ self.args.num_processes = self.num_processes
+ self.args.process_id = self.process_id
+ self.args.local_device_ids = self.local_device_ids
+ self.args.num_devices_per_process = self.num_devices_per_process
+ self.args.enable_data_parallel = True
+ self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
+ _initialize_distributed(self.args)
+ # Create mesh once for all tests
+ self.mesh = _create_mesh(self.args)
+ jax.sharding.set_mesh(self.mesh)
+ self.args.enable_result_check = True
+ os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
+
+ def tearDown(self):
+ os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
+
+ def test_te_bf16_all_gather(self):
+ """Test Collective Dense Gradient with AllGather"""
+ self.args.collective_type = "all_gather"
+ run_dense_grad_tests(self.args, self.mesh)
+
+ def test_te_bf16_reduce_scatter(self):
+ """Test Collective Dense Gradient with ReduceScatter"""
+ self.args.collective_type = "reduce_scatter"
+ run_dense_grad_tests(self.args, self.mesh)
+
+
+if __name__ == "__main__":
+ import sys
+
+ if len(sys.argv) < 7: # Need at least the 3 required distributed args
+ print("Error: This script requires distributed initialization arguments.")
+ print(
+ "Usage: python test_dense_grad.py --coordinator-address
--num-processes "
+ " --process-id [--local-device-ids ] [other args]"
+ )
+ print(
+ "Example: python test_dense_grad.py --coordinator-address localhost:1234"
+ " --num-processes 4 --process-id 0"
+ )
+ print(
+ "Example: python test_dense_grad.py --coordinator-address localhost:1234"
+ " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
+ )
+ sys.exit(1)
+
+ args = cgemm_parser(
+ "Collective Dense Gradient test on multi-GPU with tensor parallelism"
+    ).parse_args()
+ _initialize_distributed(args)
+ run_dense_grad_tests(args, mesh=None)
diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py
new file mode 100644
index 000000000..307e4444e
--- /dev/null
+++ b/examples/jax/collective_gemm/test_gemm.py
@@ -0,0 +1,206 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""Collective GEMM test on multi-GPU with tensor parallelism
+
+This script uses custom distributed initialization with the following arguments:
+- --coordinator-address: Coordinator address for distributed initialization
+- --num-processes: Number of processes for distributed initialization
+- --process-id: Process ID for distributed initialization
+- --local-device-ids: Local device IDs for distributed initialization
+
+Example:
+ python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3
+"""
+import unittest
+import os
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+from jax.sharding import PartitionSpec, NamedSharding
+
+from common import (
+ assert_allclose,
+ _initialize_distributed,
+ _get_dp_and_tp_sizes,
+ _create_mesh,
+ DP_AXIS,
+ TPSP_AXIS,
+ PARAMS_KEY,
+ cgemm_parser,
+)
+
+import transformer_engine.jax.cpp_extensions as tex
+from transformer_engine.jax.quantize import fp8_autocast
+from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp
+from transformer_engine.jax.sharding import MeshResource
+
+
+def _get_operand_sharding(mesh, collective_op, is_with_dp):
+
+ dp_axis = DP_AXIS if is_with_dp else None
+ if collective_op == CollectiveOp.ALL_GATHER:
+ x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
+ weight_sharding = NamedSharding(mesh, PartitionSpec(None, TPSP_AXIS))
+ bias_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS))
+ output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
+ else: # RS
+ x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS))
+ weight_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS, None))
+ bias_sharding = NamedSharding(mesh, PartitionSpec(None))
+ output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None))
+
+ return x_sharding, weight_sharding, bias_sharding, output_sharding
+
+
+def _get_dp_and_tp_sizes(args):
+ num_gpu = args.num_processes * args.num_devices_per_process
+ if args.tensor_parallel_size is None:
+ num_gpu_dp = 2 if args.enable_data_parallel else 1
+ assert (
+ num_gpu > 1 and num_gpu % num_gpu_dp == 0
+ ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
+ num_gpu_tp = num_gpu // num_gpu_dp
+ else:
+ num_gpu_tp = args.tensor_parallel_size
+ assert (
+ num_gpu > 1 and num_gpu % num_gpu_tp == 0
+ ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs"
+ num_gpu_dp = num_gpu // num_gpu_tp
+ return num_gpu_dp, num_gpu_tp
+
+
+@partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding"))
+def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding):
+ output = tex.gemm(
+ x,
+ weight,
+ bias=bias,
+ contracting_dims=contracting_dims,
+ collective_op=collective_op,
+ )
+ if output_sharding is not None:
+ output = jax.lax.with_sharding_constraint(output, output_sharding)
+ return output
+
+
+def run_gemm_tests(args, mesh=None):
+ """Execute GEMM tests."""
+ print(args)
+ # Collective GEMM requires Shardy partitioner to be disabled
+ jax.config.update("jax_use_shardy_partitioner", False)
+
+ # Initialize distributed with provided arguments
+ _initialize_distributed(args)
+ mesh = mesh or _create_mesh(args)
+
+ # Create test data
+ rng = jax.random.PRNGKey(0)
+ rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4)
+ x = jax.random.normal(
+ x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
+ )
+ weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16)
+ bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16)
+ collective_op = (
+ CollectiveOp.ALL_GATHER
+ if args.collective_type == "all_gather"
+ else CollectiveOp.REDUCE_SCATTER
+ )
+
+ with mesh, fp8_autocast(
+ enabled=False,
+ fp8_recipe=None,
+ mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
+ ):
+ print(f"Device mesh: {mesh}")
+
+ x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding(
+ mesh, collective_op, args.enable_data_parallel
+ )
+ x_sharded = jax.device_put(x, x_sharding)
+ weight_sharded = jax.device_put(weight, weight_sharding)
+ bias_sharded = jax.device_put(bias, bias_sharding)
+
+ ref_output = _jitted_cgemm(
+ x_sharded,
+ weight_sharded,
+ bias_sharded,
+ contracting_dims=((2,), (0,)),
+ collective_op=CollectiveOp.NONE,
+ output_sharding=output_sharding,
+ )
+ output = _jitted_cgemm(
+ x_sharded,
+ weight_sharded,
+ bias_sharded,
+ contracting_dims=((2,), (0,)),
+ collective_op=collective_op,
+ # CollectiveGEMM output should have a correct sharding without applying sharding constraint
+ output_sharding=None,
+ )
+ assert (
+ ref_output.sharding == output.sharding
+ ), f"ref_output.sharding={ref_output.sharding}, output.sharding={output.sharding}"
+ gathered_ref_output = jax.lax.with_sharding_constraint(
+ ref_output, NamedSharding(mesh, PartitionSpec(None))
+ )
+ gathered_output = jax.lax.with_sharding_constraint(
+ output, NamedSharding(mesh, PartitionSpec(None))
+ )
+ jax.block_until_ready(gathered_ref_output)
+ jax.block_until_ready(gathered_output)
+
+ if args.enable_result_check and args.process_id == 0:
+ assert_allclose(gathered_ref_output, gathered_output)
+
+
+class TestCollectiveGemmWithDP(unittest.TestCase):
+ """Collective GEMM with DP unittests"""
+
+ def setUp(self):
+ self.args = cgemm_parser(
+ "Collective GEMM test on multi-GPU with tensor parallelism"
+ ).parse_args([])
+ self.args.coordinator_address = self.coordinator_address
+ self.args.num_processes = self.num_processes
+ self.args.process_id = self.process_id
+ self.args.local_device_ids = self.local_device_ids
+ self.args.num_devices_per_process = self.num_devices_per_process
+ self.args.enable_data_parallel = True
+ self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
+ _initialize_distributed(self.args)
+ self.mesh = _create_mesh(self.args)
+ jax.sharding.set_mesh(self.mesh)
+ self.args.enable_result_check = True
+ os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
+
+ def tearDown(self):
+ os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
+
+ def test_te_bf16_all_gather_with_dp(self):
+ """Test Collective GEMM with AllGather"""
+ self.args.collective_type = "all_gather"
+ run_gemm_tests(self.args, self.mesh)
+
+ def test_te_bf16_reduce_scatter_with_dp(self):
+ """Test Collective GEMM with ReduceScatter"""
+ self.args.collective_type = "reduce_scatter"
+ run_gemm_tests(self.args, self.mesh)
+
+
+if __name__ == "__main__":
+ import sys
+
+    if len(sys.argv) < 7:  # Need at least the 3 required distributed args
+ print("Error: This script requires distributed initialization arguments.")
+ print(
+ "Usage: python test_gemm.py --coordinator-address --num-processes "
+ " --process-id [--local-device-ids ] [other args]"
+ )
+ sys.exit(1)
+
+ args = cgemm_parser("Collective GEMM test on multi-GPU with tensor parallelism").parse_args()
+ _initialize_distributed(args)
+ run_gemm_tests(args, mesh=None)
diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py
new file mode 100644
index 000000000..7bd6eb6a3
--- /dev/null
+++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py
@@ -0,0 +1,272 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+"""Collective Dense Gradient test on multi-GPU with tensor parallelism"""
+import argparse
+import unittest
+import os
+
+import jax
+import jax.numpy as jnp
+from jax.sharding import PartitionSpec, NamedSharding
+import flax
+
+from common import (
+ assert_allclose,
+ _initialize_distributed,
+ _get_dp_and_tp_sizes,
+ _create_mesh,
+ DP_AXIS,
+ TPSP_AXIS,
+ PARAMS_KEY,
+ cgemm_parser,
+)
+
+from transformer_engine.jax.layernorm_mlp import layernorm_mlp
+
+from transformer_engine.jax.quantize import fp8_autocast
+from transformer_engine.jax.cpp_extensions.gemm import (
+ CollectiveOpSet,
+ CollectiveOp,
+ noop_collective_op_set,
+)
+from transformer_engine.jax.sharding import MeshResource
+import transformer_engine.jax.flax as te_flax
+
+
+def _get_logical_axes():
+ input_1_axes = (DP_AXIS, TPSP_AXIS, None)
+ weight_1_axes = (None, None, TPSP_AXIS)
+ bias_axes_1 = (None, TPSP_AXIS)
+ input_2_axes = (DP_AXIS, None, TPSP_AXIS)
+ weight_2_axes = (TPSP_AXIS, None)
+ bias_axes_2 = (None,)
+ return input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2
+
+
+def _get_operand_sharding(mesh):
+ input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 = (
+ _get_logical_axes()
+ )
+ x_sharding = NamedSharding(mesh, PartitionSpec(*input_1_axes))
+ weight_1_sharding = NamedSharding(mesh, PartitionSpec(*weight_1_axes))
+ bias_1_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_1))
+ weight_2_sharding = NamedSharding(mesh, PartitionSpec(*weight_2_axes))
+ bias_2_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_2))
+ return x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding
+
+
+def _mean_layernorm_mlp(
+ x,
+ weight_1,
+ bias_1,
+ weight_2,
+ bias_2,
+ gamma,
+ input_1_axes,
+ input_2_axes,
+ weight_1_axes,
+ weight_2_axes,
+ collective_op_sets,
+):
+ output = layernorm_mlp(
+ x,
+ gamma,
+ beta=None,
+ kernels=[weight_1, weight_2],
+ biases=[bias_1, bias_2],
+ norm_type="rmsnorm",
+ dot_1_input_axes=input_1_axes,
+ dot_2_input_axes=input_2_axes,
+ kernel_1_axes=weight_1_axes,
+ kernel_2_axes=weight_2_axes,
+ activation_type=("gelu",),
+ collective_op_sets=collective_op_sets,
+ )
+ return jnp.mean(output)
+
+
+def _value_and_grad_layernorm_mlp(
+ x,
+ weight_1,
+ bias_1,
+ weight_2,
+ bias_2,
+ gamma,
+ input_1_axes,
+ input_2_axes,
+ weight_1_axes,
+ weight_2_axes,
+ collective_op_sets,
+):
+ return jax.jit(
+ jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10)
+ )(
+ x,
+ weight_1,
+ bias_1,
+ weight_2,
+ bias_2,
+ gamma,
+ input_1_axes,
+ input_2_axes,
+ weight_1_axes,
+ weight_2_axes,
+ collective_op_sets,
+ )
+
+
+def run_layernorm_mlp_grad_tests(args, mesh=None):
+ """Execute Dense Gradient tests."""
+ print(args)
+ # Collective GEMM requires Shardy partitioner to be disabled
+ jax.config.update("jax_use_shardy_partitioner", False)
+
+ # Initialize distributed with provided arguments
+ _initialize_distributed(args)
+
+ mesh = mesh or _create_mesh(args)
+
+ # Create test data
+ rng = jax.random.PRNGKey(0)
+ rng, x_rng, weight_1_rng, bias_1_rng, weight_2_rng, bias_2_rng, gamma_rng = jax.random.split(
+ rng, 7
+ )
+ x = jax.random.normal(
+ x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16
+ )
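+    # Note: weight_1 carries a middle dimension of 1, matching the single
+    # activation in activation_type=("gelu",) passed to layernorm_mlp below
+    # (assumed kernel layout).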
+ weight_1 = jax.random.normal(
+ weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16
+ ) / jnp.sqrt(args.hidden_in)
+ bias_1 = jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16)
+ weight_2 = jax.random.normal(
+ weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16
+ ) / jnp.sqrt(args.hidden_out)
+ bias_2 = jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16)
+ gamma = jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) / jnp.sqrt(
+ args.hidden_in
+ )
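+    # Sequence-parallel MLP pattern: all-gather before the first GEMM,
+    # reduce-scatter after the second.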
+ collective_op_set_1 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER)
+ collective_op_set_2 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.REDUCE_SCATTER)
+ collective_op_sets = (collective_op_set_1, collective_op_set_2)
+ noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set)
+
+ with mesh, fp8_autocast(
+ enabled=False,
+ fp8_recipe=None,
+ mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS),
+ ):
+ # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast
+ axis_rules = flax.linen.get_logical_axis_rules()
+ axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS))
+ te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules)
+ with flax.linen.logical_axis_rules(te_extended_axis_rules):
+ x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding = (
+ _get_operand_sharding(mesh)
+ )
+ x_sharded = jax.device_put(x, x_sharding)
+ weight_1_sharded = jax.device_put(weight_1, weight_1_sharding)
+ bias_1_sharded = jax.device_put(bias_1, bias_1_sharding)
+ weight_2_sharded = jax.device_put(weight_2, weight_2_sharding)
+ bias_2_sharded = jax.device_put(bias_2, bias_2_sharding)
+
+ input_1_axes, weight_1_axes, _, input_2_axes, weight_2_axes, _ = _get_logical_axes()
+ ref_output, ref_grads = _value_and_grad_layernorm_mlp(
+ x_sharded,
+ weight_1_sharded,
+ bias_1_sharded,
+ weight_2_sharded,
+ bias_2_sharded,
+ gamma,
+ input_1_axes,
+ input_2_axes,
+ weight_1_axes,
+ weight_2_axes,
+ noop_collective_op_sets,
+ )
+ output, sharded_grads = _value_and_grad_layernorm_mlp(
+ x_sharded,
+ weight_1_sharded,
+ bias_1_sharded,
+ weight_2_sharded,
+ bias_2_sharded,
+ gamma,
+ input_1_axes,
+ input_2_axes,
+ weight_1_axes,
+ weight_2_axes,
+ collective_op_sets,
+ )
+ jax.block_until_ready(ref_output)
+ jax.block_until_ready(output)
+ gathered_grads = []
+ gathered_ref_grads = []
+ for ref_grad, grad in zip(ref_grads, sharded_grads):
+ gathered_grads.append(
+ jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None)))
+ )
+ gathered_ref_grads.append(
+ jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None)))
+ )
+ jax.block_until_ready(gathered_grads)
+ jax.block_until_ready(gathered_ref_grads)
+
+ if args.enable_result_check and args.process_id == 0:
+ assert_allclose(ref_output, output, dtype=jnp.bfloat16)
+ for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads):
+ assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16)
+
+
+class TestCollectiveLayerNormMLPGradient(unittest.TestCase):
+ """Collective Dense Gradient unittests"""
+
+ def setUp(self):
+ self.args = cgemm_parser(
+ "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
+ ).parse_args([])
+ self.args.coordinator_address = self.coordinator_address
+ self.args.num_processes = self.num_processes
+ self.args.process_id = self.process_id
+ self.args.local_device_ids = self.local_device_ids
+ self.args.num_devices_per_process = self.num_devices_per_process
+ self.args.enable_data_parallel = True
+ self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
+ _initialize_distributed(self.args)
+ # Create mesh once for all tests
+ self.mesh = _create_mesh(self.args)
+ jax.sharding.set_mesh(self.mesh)
+ self.args.enable_result_check = True
+ os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1"
+
+ def tearDown(self):
+ os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None)
+
+ def test_te_bf16_layernorm_mlp_grad(self):
+ """Test Collective Dense Gradient with AllGather"""
+ run_layernorm_mlp_grad_tests(self.args, self.mesh)
+
+
+if __name__ == "__main__":
+ import sys
+
+ if len(sys.argv) < 7: # Need at least the 3 required distributed args
+ print("Error: This script requires distributed initialization arguments.")
+ print(
+ "Usage: python test_layernorm_mlp_grad.py --coordinator-address "
+ " --num-processes --process-id [--local-device-ids ] [other args]"
+ )
+ print(
+ "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
+ " --num-processes 4 --process-id 0"
+ )
+ print(
+ "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234"
+ " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3"
+ )
+ sys.exit(1)
+
+ args = cgemm_parser(
+ "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism"
+    ).parse_args()
+ _initialize_distributed(args)
+ run_layernorm_mlp_grad_tests(args, mesh=None)
diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh
index 2a1ac0f8f..2a979e177 100644
--- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh
+++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh
@@ -15,11 +15,37 @@ TEST_CASES=(
"test_te_current_scaling_fp8_shardy"
)
+: ${TE_PATH:=/opt/transformerengine}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"
+
echo
echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***"
HAS_FAILURE=0 # Global failure flag
+PIDS=() # Array to store all process PIDs
+
+# Cleanup function to kill all processes
+cleanup() {
+ for pid in "${PIDS[@]}"; do
+ if kill -0 "$pid" 2>/dev/null; then
+ echo "Killing process $pid"
+ kill -TERM "$pid" 2>/dev/null || true
+ fi
+ done
+ # Wait a bit and force kill if needed
+ sleep 2
+ for pid in "${PIDS[@]}"; do
+ if kill -0 "$pid" 2>/dev/null; then
+ echo "Force killing process $pid"
+ kill -KILL "$pid" 2>/dev/null || true
+ fi
+ done
+}
+
+# Set up signal handlers to cleanup on exit
+trap cleanup EXIT INT TERM
# Run each test case across all GPUs
for TEST_CASE in "${TEST_CASES[@]}"; do
echo
@@ -29,25 +55,40 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
# Define output file for logs
LOG_FILE="${TEST_CASE}_gpu_${i}.log"
- # Run pytest and redirect stdout and stderr to the log file
- pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
- -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
- --num-process=$NUM_GPUS \
- --process-id=$i > "$LOG_FILE" 2>&1 &
- done
+ # For process 0: show live output AND save to log file using tee
+ if [ $i -eq 0 ]; then
+ echo "=== Live output from process 0 ==="
+ pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
+ -vs --junitxml=$XML_LOG_DIR/multiprocessing_encoder_${TEST_CASE}.xml \
+ "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
+ --num-process=$NUM_GPUS \
+ --process-id=$i 2>&1 | tee "$LOG_FILE" &
+ PID=$!
+ PIDS+=($PID)
+ else
+ pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
+ -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
+ --num-process=$NUM_GPUS \
+ --process-id=$i > "$LOG_FILE" 2>&1 &
+ PID=$!
+ PIDS+=($PID)
+ fi
+ done
# Wait for the process to finish
wait
- tail -n +7 "${TEST_CASE}_gpu_0.log"
# Check and print the log content accordingly
if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE SKIPPED"
+ elif grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then
+ echo "... $TEST_CASE FAILED"
+ HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE PASSED"
else
+ echo "... $TEST_CASE INVALID"
HAS_FAILURE=1
- echo "... $TEST_CASE FAILED"
fi
# Remove the log file after processing it
@@ -56,4 +97,8 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
done
wait
+
+# Final cleanup (trap will also call cleanup on exit)
+cleanup
+
exit $HAS_FAILURE
diff --git a/pyproject.toml b/pyproject.toml
index c4df4aecc..80110af8a 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@
# See LICENSE for license information.
[build-system]
-requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax", "flax>=0.7.1"]
+requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh
index d9c46347f..ae45f398e 100644
--- a/qa/L0_jax_distributed_unittest/test.sh
+++ b/qa/L0_jax_distributed_unittest/test.sh
@@ -29,6 +29,10 @@ wait
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
+wait
+
+TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh"
+wait
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh
index e4a3f4630..cb097d492 100644
--- a/qa/L0_jax_unittest/test.sh
+++ b/qa/L0_jax_unittest/test.sh
@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
+NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh
index b4bf0a024..7f19dda67 100644
--- a/qa/L0_pytorch_debug_unittest/test.sh
+++ b/qa/L0_pytorch_debug_unittest/test.sh
@@ -7,6 +7,8 @@
: ${TE_PATH:=/opt/transformerengine}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
: ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"
# Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active.
@@ -20,17 +22,16 @@ pip uninstall -y nvdlfw-inspect
pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
pip install pytest==8.2.1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
# standard sanity and numerics tests with initialized debug
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
exit $FAIL
diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 394273ca4..cdf0df888 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -31,6 +31,7 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh
index 7f061d222..e698e997a 100644
--- a/qa/L1_pytorch_distributed_unittest/test.sh
+++ b/qa/L1_pytorch_distributed_unittest/test.sh
@@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
@@ -47,9 +48,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
-pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
# standard numerics tests with initialized debug
-NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
+NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh
index 1486d5097..7fce13a3d 100644
--- a/qa/L1_pytorch_onnx_unittest/test.sh
+++ b/qa/L1_pytorch_onnx_unittest/test.sh
@@ -3,9 +3,11 @@
# See LICENSE for license information.
-pip3 install onnxruntime==1.20.1
-pip3 install onnxruntime_extensions==0.13.0
+pip3 install onnxruntime
+pip3 install onnxruntime_extensions
: ${TE_PATH:=/opt/transformerengine}
+: ${XML_LOG_DIR:=/logs}
+mkdir -p "$XML_LOG_DIR"
-python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py
diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt
index 46bcf4242..c9a7a03e8 100644
--- a/tests/cpp/operator/CMakeLists.txt
+++ b/tests/cpp/operator/CMakeLists.txt
@@ -32,7 +32,8 @@ list(APPEND test_cuda_sources
../test_common.cu)
if(USE_CUDA)
list(APPEND test_cuda_sources
- test_cast_float8blockwise.cu)
+ test_cast_float8blockwise.cu
+ test_cast_nvfp4_transpose.cu)
else()
list(APPEND test_cuda_sources
test_cublaslt_gemm.cu)
@@ -70,6 +71,13 @@ else()
add_executable(test_operator ${test_hip_sources})
endif()
+# Add profiling and debug flags for CUDA compilation
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Generate line info for device code
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g") # Add debug symbols for host code
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage
+# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping
+
+# Find required packages
find_package(OpenMP REQUIRED)
if(USE_CUDA)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu
index b635dc00b..f1780cbe0 100644
--- a/tests/cpp/operator/test_cast_mxfp8.cu
+++ b/tests/cpp/operator/test_cast_mxfp8.cu
@@ -86,6 +86,7 @@ void compute_ref(const ProcessingMethod processing_method,
// Cache computations
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
+
const size_t idx = i * cols + j;
const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
@@ -321,7 +322,7 @@ void performTest_x1(const ProcessingMethod processing_method,
std::vector<int> mismatches_scales_indices;
size_t mismatches_scales = 0;
- compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
+ compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales_indices, mismatches_scales,
scale_diff_abs_tolerance,
@@ -506,7 +507,7 @@ void performTest_x2(const ProcessingMethod processing_method,
std::vector<int> mismatches_scales_indices_rowwise;
size_t mismatches_scales_rowwise = 0;
- compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(),
+ compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise,
mismatches_scales_indices_rowwise, mismatches_scales_rowwise,
@@ -516,7 +517,7 @@ void performTest_x2(const ProcessingMethod processing_method,
std::vector<int> mismatches_scales_indices_colwise;
size_t mismatches_scales_colwise = 0;
- compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(),
+ compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise,
mismatches_scales_indices_colwise, mismatches_scales_colwise,
diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
index 52180786d..38a9c8d29 100644
--- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
+++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
@@ -274,21 +274,22 @@ void performTest_x1(const size_t rows,
? output.rowwise_cpu_scale_inv_ptr()
: output.columnwise_cpu_scale_inv_ptr();
if (rowwise) {
- compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
- unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
+ compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
+ unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales_indices,
- mismatches_scales,
- scale_diff_abs_tolerance,
- abs_tolerable_mismatches_limit,
- rel_tolerable_mismatches_limit);
+ mismatches_scales,
+ scale_diff_abs_tolerance,
+ abs_tolerable_mismatches_limit,
+ rel_tolerable_mismatches_limit);
} else {
- compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
- unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
+ compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
+ unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
mismatches_scales_indices,
- mismatches_scales,
- scale_diff_abs_tolerance,
- abs_tolerable_mismatches_limit,
- rel_tolerable_mismatches_limit);
+ mismatches_scales,
+ scale_diff_abs_tolerance,
+ abs_tolerable_mismatches_limit,
+ rel_tolerable_mismatches_limit);
+
}
#ifdef __HIP_PLATFORM_AMD__
@@ -396,22 +397,22 @@ void performTest_x2(const size_t rows,
std::vector<int> mismatches_scales_indices_rowwise;
size_t mismatches_scales_rowwise = 0;
- compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(),
- ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
- unpadded_blocks_X_rowwise, scales_stride_rowwise,
- mismatches_scales_indices_rowwise, mismatches_scales_rowwise,
- scale_diff_abs_tolerance,
- abs_tolerable_mismatches_limit,
- rel_tolerable_mismatches_limit);
+ compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(),
+ ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
+ unpadded_blocks_X_rowwise, scales_stride_rowwise,
+ mismatches_scales_indices_rowwise, mismatches_scales_rowwise,
+ scale_diff_abs_tolerance,
+ abs_tolerable_mismatches_limit,
+ rel_tolerable_mismatches_limit);
std::vector<int> mismatches_scales_indices_colwise;
size_t mismatches_scales_colwise = 0;
- compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(),
- ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
- unpadded_blocks_X_colwise, scales_stride_colwise,
- mismatches_scales_indices_colwise, mismatches_scales_colwise,
- scale_diff_abs_tolerance,
- abs_tolerable_mismatches_limit,
- rel_tolerable_mismatches_limit);
+ compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(),
+ ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
+ unpadded_blocks_X_colwise, scales_stride_colwise,
+ mismatches_scales_indices_colwise, mismatches_scales_colwise,
+ scale_diff_abs_tolerance,
+ abs_tolerable_mismatches_limit,
+ rel_tolerable_mismatches_limit);
#ifdef __HIP_PLATFORM_AMD__
if (::testing::Test::HasFatalFailure()) return;
diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu
new file mode 100644
index 000000000..e905a0064
--- /dev/null
+++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu
@@ -0,0 +1,741 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include <cmath>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <memory>
+
+#include <cuda_runtime.h>
+#include <gtest/gtest.h>
+#include "../test_common.h"
+#include "transformer_engine/transformer_engine.h"
+#include <cuda_fp4.h>
+
+using namespace transformer_engine;
+using namespace test;
+
+namespace {
+
+enum ActivationType {
+ Identity,
+ GeLU,
+ SiLU,
+ ReLU,
+ QGeLU,
+ SReLU
+};
+
+double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) {
+ const __half2_raw raw_truncated_to_fp4e2m1_pair =
+ __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1);
+
+ const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair);
+ const double truncated_to_fp4e2m1_x = static_cast<double>(truncated_to_fp4e2m1_pair.x);
+ const double truncated_to_fp4e2m1_y = static_cast<double>(truncated_to_fp4e2m1_pair.y);
+ return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y};
+}
+
+template <typename InputType>
+std::vector<InputType> create_transpose(const InputType* const input, const size_t rows, size_t cols) {
+ std::vector<InputType> input_t(cols * rows);
+ for (size_t i = 0; i < rows; ++i) {
+ for (size_t j = 0; j < cols; ++j) {
+ const size_t idx = i * cols + j;
+ const size_t idx_t = j * rows + i;
+ input_t[idx_t] = input[idx];
+ }
+ }
+ return input_t;
+}
+
+// Compute the global encode scale factor for a given global amax
+float compute_global_encode_scaling_factor_FP4(const float global_amax) {
+ constexpr float fp8_max = 448.0f; // max representable magnitude in FP8 E4M3
+ constexpr float fp4_max = 6.0f; // max representable magnitude in FP4 E2M1
+ float global_encode_scale = fp8_max * fp4_max / global_amax;
+ // If the scale overflows, clamp it to the max value of float32
+ global_encode_scale = fminf(global_encode_scale, Numeric_Traits<float>::maxNorm);
+ // If global amax is 0 or infinity, return 1
+ if (global_amax == 0.0f || global_encode_scale == 0.0f) {
+ return 1.0f;
+ }
+ return global_encode_scale;
+}
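+
+// Illustrative example (comment only, not used by the test): with
+// global_amax = 2688.0f the factor above is 448.0f * 6.0f / 2688.0f = 1.0f,
+// while global_amax = 672.0f gives 4.0f, i.e. smaller-magnitude tensors are
+// scaled up before the per-block E4M3 decode factors are derived.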
+
+// 1D Scaling: Original implementation with 1x16 blocks
+template <typename InputType>
+void quantize_nvfp4_1d(float (*OP)(const float),
+ const InputType* const input,
+ fp4e2m1x2* const output,
+ fp8e4m3* const scales,
+ const size_t rows,
+ const size_t cols,
+ const size_t scales_stride,
+ const float global_amax) {
+
+ // Compute a global encoding/decoding scaling factor for all S_dec_b
+ const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
+
+ constexpr size_t block_size_X = 16;
+ const size_t blocks_X = divide_round_up(cols, block_size_X);
+
+ std::array<float, block_size_X> cache_buffer;
+ for (size_t i = 0; i < block_size_X; ++i) {
+ cache_buffer[i] = 0.0f;
+ }
+
+ for (size_t i = 0; i < rows; ++i) {
+ for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
+ const size_t j_min = block_X * block_size_X;
+ const size_t j_max = j_min + block_size_X;
+
+ // Find block amax
+ float block_amax = 0.0f;
+ for (size_t j = j_min; j < j_max; ++j) {
+ const size_t idx = i * cols + j;
+ const size_t cache_idx = j - j_min;
+
+ const float input_elt = static_cast<float>(input[idx]);
+ const float act_elt = OP(input_elt);
+
+ // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32
+ const float elt = static_cast<float>(static_cast<InputType>(act_elt));
+ cache_buffer[cache_idx] = elt;
+ block_amax = std::max(block_amax, std::abs(elt));
+ }
+
+ // 2. Compute E4M3 scaling factor
+ // Compute per-block encoding/decoding scaling factor
+ const float S_dec_b = block_amax / 6.0f;
+
+ // Scale & Store per-block decoding scaling factor
+ const float S_dec_b_fp8 = S_dec_b * S_enc;
+
+ // Compute "correct" per-block encoding scaling factor
+ const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
+
+ const size_t scale_idx = i * scales_stride + block_X;
+ scales[scale_idx] = static_cast<fp8e4m3>(S_dec_b_fp8);
+ const float scale_reciprocal = S_enc_b_fp8;
+
+ for (size_t j = j_min; j < j_max; j += 2) {
+ const int idx_pair = (i * cols + j) / 2;
+ const int cache_idx_x = j - j_min;
+ const int cache_idx_y = cache_idx_x + 1;
+ const float cached_x = cache_buffer[cache_idx_x];
+ const float cached_y = cache_buffer[cache_idx_y];
+ const float scaled_elt_x = cached_x * scale_reciprocal;
+ const float scaled_elt_y = cached_y * scale_reciprocal;
+ const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y};
+
+ fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
+ output[idx_pair] = casted_to_e2m1_pair;
+
+ // const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair);
+ }
+ }
+ }
+}
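+
+// Layout note: FP4 outputs are packed two values per byte (fp4e2m1x2), which is
+// why the element index is divided by 2 above; each row additionally stores
+// cols/16 E4M3 decode factors at stride `scales_stride`.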
+
+// Compute 2D mathematical scaling factors (8x8 for 128x128 input)
+template <typename InputType>
+void compute_2d_mathematical_scales(float (*OP)(const float),
+ const InputType* const input,
+ const size_t rows,
+ const size_t cols,
+ const float global_amax,
+ std::vector<std::vector<fp8e4m3>>& math_scales) {
+
+ const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
+ constexpr size_t block_size_Y = 16;
+ constexpr size_t block_size_X = 16;
+ const size_t blocks_Y = divide_round_up(rows, block_size_Y);
+ const size_t blocks_X = divide_round_up(cols, block_size_X);
+
+ math_scales.resize(blocks_Y, std::vector<fp8e4m3>(blocks_X));
+
+ for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) {
+ for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
+ const size_t i_min = block_Y * block_size_Y;
+ const size_t i_max = std::min(i_min + block_size_Y, rows);
+ const size_t j_min = block_X * block_size_X;
+ const size_t j_max = std::min(j_min + block_size_X, cols);
+
+ // Find 2D block amax over entire 16x16 region
+ float block_amax = 0.0f;
+ for (size_t i = i_min; i < i_max; ++i) {
+ for (size_t j = j_min; j < j_max; ++j) {
+ const size_t idx = i * cols + j;
+ const float input_elt = static_cast<float>(input[idx]);
+ const float act_elt = OP(input_elt);
+ const float elt = static_cast<float>(static_cast<InputType>(act_elt));
+ block_amax = std::max(block_amax, std::abs(elt));
+ }
+ }
+
+ // Compute E4M3 scaling factor for this 16x16 block
+ const float S_dec_b = block_amax / 6.0f;
+ const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
+ math_scales[block_Y][block_X] = S_dec_b_fp8;
+ }
+ }
+}
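+
+// Size note: with 16x16 blocks a 128x128 input yields an 8x8 grid of E4M3
+// factors; quantize_nvfp4_2d below replicates that grid row-wise into the
+// 128x8 rowwise (and, via the transpose path, columnwise) scale tensors.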
+
+// 2D Scaling: NEW implementation with proper replication
+template <typename InputType>
+void quantize_nvfp4_2d(float (*OP)(const float),
+ const InputType* const input,
+ fp4e2m1x2* const output,
+ fp8e4m3* const scales,
+ const size_t rows,
+ const size_t cols,
+ const size_t scales_stride,
+ const float global_amax) {
+
+ // Step 1: Compute mathematical 8x8 scaling factors
+ std::vector<std::vector<fp8e4m3>> math_scales;
+ compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
+
+ const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax);
+ constexpr size_t block_size_Y = 16;
+ constexpr size_t block_size_X = 16;
+ const size_t blocks_Y = divide_round_up(rows, block_size_Y);
+ const size_t blocks_X = divide_round_up(cols, block_size_X);
+
+ // Step 2: Replicate scaling factors row-wise (128×8 storage) - only if scales is not nullptr
+ if (scales != nullptr) {
+ // Each of the 128 rows gets scaling factors from its corresponding 16×16 block
+ for (size_t i = 0; i < rows; ++i) {
+ const size_t block_Y = i / block_size_Y;
+ for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
+ const size_t scale_idx = i * scales_stride + block_X;
+ scales[scale_idx] = math_scales[block_Y][block_X];
+ }
+ }
+ }
+
+ // Step 3: Apply quantization using the mathematical scaling factors
+ std::array<std::array<float, block_size_X>, block_size_Y> cache_buffer;
+
+ for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) {
+ for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
+ const size_t i_min = block_Y * block_size_Y;
+ const size_t i_max = std::min(i_min + block_size_Y, rows);
+ const size_t j_min = block_X * block_size_X;
+ const size_t j_max = std::min(j_min + block_size_X, cols);
+
+ // Get the scaling factor for this block
+ const float S_dec_b_fp8 = static_cast<float>(math_scales[block_Y][block_X]);
+ const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8;
+ const float scale_reciprocal = S_enc_b_fp8;
+
+ // Process and cache data for this 16x16 block
+ for (size_t i = i_min; i < i_max; ++i) {
+ for (size_t j = j_min; j < j_max; ++j) {
+ const size_t idx = i * cols + j;
+ const size_t cache_idx_y = i - i_min;
+ const size_t cache_idx_x = j - j_min;
+
+ const float input_elt = static_cast<float>(input[idx]);
+ const float act_elt = OP(input_elt);
+ const float elt = static_cast<float>(static_cast<InputType>(act_elt));
+ cache_buffer[cache_idx_y][cache_idx_x] = elt;
+ }
+ }
+
+ // Apply scaling to all elements in this 16x16 block
+ for (size_t i = i_min; i < i_max; ++i) {
+ for (size_t j = j_min; j < j_max; j += 2) {
+ const int idx_pair = (i * cols + j) / 2;
+ const size_t cache_idx_y = i - i_min;
+ const size_t cache_idx_x1 = j - j_min;
+ const size_t cache_idx_x2 = std::min(cache_idx_x1 + 1, block_size_X - 1);
+
+ const float cached_x = cache_buffer[cache_idx_y][cache_idx_x1];
+ const float cached_y = ((j + 1) < j_max && cache_idx_x2 < block_size_X) ?
+ cache_buffer[cache_idx_y][cache_idx_x2] : 0.0f;
+
+ const float scaled_elt_x = cached_x * scale_reciprocal;
+ const float scaled_elt_y = cached_y * scale_reciprocal;
+ const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y};
+
+ fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair);
+ output[idx_pair] = casted_to_e2m1_pair;
+ }
+ }
+ }
+ }
+}
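+
+// Difference from the 1D path: the amax (and hence the decode factor) is shared
+// by a whole 16x16 tile, so the rowwise quantization and the quantization of the
+// transposed input use identical scaling factors for corresponding elements.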
+
+// Wrapper function that calls appropriate implementation based on 2D flag
+template <typename InputType>
+void quantize_nvfp4(float (*OP)(const float),
+ const InputType* const input,
+ fp4e2m1x2* const output,
+ fp8e4m3* const scales,
+ const size_t rows,
+ const size_t cols,
+ const size_t scales_stride,
+ const float global_amax,
+ const bool use_2d_quantization = false) {
+ if (use_2d_quantization) {
+ quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
+ } else {
+ quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax);
+ }
+}
+
+template <typename InputType>
+void compute_ref(float (*OP)(const float),
+ const InputType* input,
+ fp4e2m1x2* output,
+ fp4e2m1x2* output_t,
+ fp8e4m3* scales,
+ fp8e4m3* scales_t,
+ const float global_amax,
+ const size_t rows,
+ const size_t cols,
+ const size_t scales_stride,
+ const size_t scales_stride_t,
+ const bool use_2d_quantization = false)
+{
+ std::vector<InputType> input_t = create_transpose(input, rows, cols);
+
+ if (use_2d_quantization) {
+ // Step 1: Compute mathematical 8×8 scaling factors
+ std::vector<std::vector<fp8e4m3>> math_scales;
+ compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales);
+
+ constexpr size_t block_size_Y = 16;
+ constexpr size_t block_size_X = 16;
+ const size_t blocks_Y = divide_round_up(rows, block_size_Y);
+ const size_t blocks_X = divide_round_up(cols, block_size_X);
+
+ // Step 2: Generate scales (128×8) by replicating row-wise
+ for (size_t i = 0; i < rows; ++i) {
+ const size_t block_Y = i / block_size_Y;
+ for (size_t block_X = 0; block_X < blocks_X; ++block_X) {
+ const size_t scale_idx = i * scales_stride + block_X;
+ scales[scale_idx] = math_scales[block_Y][block_X];
+ }
+ }
+
+ // Step 3: Generate scales_t (128×8) with proper transposed block mapping
+ for (size_t i = 0; i < cols; ++i) { // cols = 128, which becomes rows of transposed data
+ const size_t block_X_orig = i / block_size_X; // i was column index in original, so maps to block_X
+ for (size_t block_Y_new = 0; block_Y_new < blocks_Y; ++block_Y_new) { // block in transposed coordinate
+ const size_t scale_idx = i * scales_stride_t + block_Y_new;
+ scales_t[scale_idx] = math_scales[block_Y_new][block_X_orig];
+ }
+ }
+
+ // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d
+ // (This part processes the actual FP4 data using the mathematical scaling factors)
+ quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled
+ quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled
+
+ } else {
+ quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization);
+ quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization);
+ }
+}
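+
+// Note: in the 2D branch the scale tensors are written here (with the transposed
+// block mapping for scales_t) and quantize_nvfp4_2d is invoked with scales ==
+// nullptr so the replication step does not overwrite them.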
+
+void compare_nvfp4_tensors(const std::string& name,
+ const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
+ const int rows, const int cols,
+ double atol = 1e-5, double rtol = 1e-8) {
+ std::vector<std::string> mismatch_messages;
+ size_t total_mismatches = 0;
+
+ for (int i = 0; i < rows; ++i) {
+ for (int j = 0; j < cols; j += 2) {
+ const int idx = i * cols + j;
+ double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[idx/2]));
+ double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[idx/2]));
+
+ for (int k = 0; k < 2; ++k) {
+ const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
+ const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
+
+ bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
+ /* For Float32 the floating point comparison is enough to error out */
+ bool assertion = false;
+ if (mismatch && !assertion) {
+ /* Check if it is just a failure of round to nearest choosing different
+ side of the real value */
+ const double mean = (t + r) / 2;
+ const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
+ const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
+ const double cast_mean_p = static_cast<double>(static_cast<fp4e2m1>(mean_p));
+ const double cast_mean_m = static_cast<double>(static_cast<fp4e2m1>(mean_m));
+ assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
+ }
+ if (assertion) {
+ total_mismatches++;
+ std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " +
+ std::to_string(t) + " vs " + std::to_string(r) +
+ " (abs_diff: " + std::to_string(fabs(t - r)) +
+ ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")";
+ mismatch_messages.push_back(msg);
+
+ // Optional: limit number of detailed messages to avoid overwhelming output
+ if (mismatch_messages.size() <= 100) {
+ std::cout << "Error in tensor " << name << ": " << msg << std::endl;
+ }
+ }
+ }
+ }
+ }
+
+ // Always report summary - either success or failure
+ std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl;
+ std::cout << "Total elements checked: " << (rows * cols) << std::endl;
+
+ if (total_mismatches > 0) {
+ std::cout << "STATUS: FAILED for output" << std::endl;
+ std::cout << "Total mismatches found: " << total_mismatches << std::endl;
+ std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl;
+ if (mismatch_messages.size() > 100) {
+ std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl;
+ }
+ std::cout << "============================" << std::endl;
+
+ GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name;
+ } else {
+ std::cout << "STATUS: PASSED for output" << std::endl;
+ std::cout << "All elements match within tolerance!" << std::endl;
+ std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl;
+ std::cout << "============================" << std::endl;
+ }
+}
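+
+// Tolerance note: a mismatch is waived when the two values sit on opposite sides
+// of the midpoint after an FP4 round trip, i.e. it is attributed to round-to-nearest
+// picking a different neighbour rather than to a real numerical error.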
+
+// Optional: Function to dump tensor data to files for detailed analysis
+void dump_nvfp4_tensor_data(const std::string& prefix,
+ const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
+ const int rows, const int cols) {
+ std::string test_file = prefix + "_test.txt";
+ std::string ref_file = prefix + "_ref.txt";
+ std::string diff_file = prefix + "_diff.txt";
+
+ std::ofstream test_out(test_file);
+ std::ofstream ref_out(ref_file);
+ std::ofstream diff_out(diff_file);
+
+ if (test_out.is_open() && ref_out.is_open() && diff_out.is_open()) {
+ for (int i = 0; i < rows; ++i) {
+ for (int j = 0; j < cols; j += 2) {
+ const int idx = i * cols + j;
+ double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[idx/2]));
+ double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[idx/2]));
+
+ for (int k = 0; k < 2; ++k) {
+ const double t = (k == 0 ? test_data_pair.x : test_data_pair.y);
+ const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y);
+ const int pos = idx + k;
+
+ test_out << "pos[" << pos << "] = " << t << std::endl;
+ ref_out << "pos[" << pos << "] = " << r << std::endl;
+ diff_out << "pos[" << pos << "] test=" << t << " ref=" << r
+ << " abs_diff=" << fabs(t - r)
+ << " rel_diff=" << (r == 0 ? 0.0 : fabs((t - r) / r)) << std::endl;
+ }
+ }
+ }
+ std::cout << "DEBUG: Dumped tensor data to files: " << test_file << ", " << ref_file << ", " << diff_file << std::endl;
+ } else {
+ std::cout << "WARNING: Could not open files for tensor data dump" << std::endl;
+ }
+}
+
+void print_detailed_tensor_comparison(const std::string& name,
+ const fp4e2m1 *test_data, const fp4e2m1 *ref_data,
+ const int rows, const int cols) {
+ printf("\n=== DETAILED COMPARISON for %s (%d×%d = %d elements) ===\n",
+ name.c_str(), rows, cols, rows * cols);
+
+ const int total_elements = rows * cols;
+ const int check_count = 128;
+
+ printf("--- FIRST %d ELEMENTS ---\n", check_count);
+ printf("Index | Test_Value | Ref_Value | Match\n");
+ printf("------|---------------|---------------|-------\n");
+ for (int i = 0; i < std::min(check_count, total_elements); ++i) {
+ double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[i/2]));
+ double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[i/2]));
+
+ double t = (i % 2 == 0) ? test_pair.x : test_pair.y;
+ double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y;
+ bool match = (fabs(t - r) < 1e-6);
+
+ printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗");
+ }
+
+ if (total_elements > 2 * check_count) {
+ printf("\n--- LAST %d ELEMENTS ---\n", check_count);
+ printf("Index | Test_Value | Ref_Value | Match\n");
+ printf("------|---------------|---------------|-------\n");
+ for (int i = total_elements - check_count; i < total_elements; ++i) {
+ double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&test_data[i/2]));
+ double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast<const fp4e2m1x2*>(&ref_data[i/2]));
+
+ double t = (i % 2 == 0) ? test_pair.x : test_pair.y;
+ double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y;
+ bool match = (fabs(t - r) < 1e-6);
+
+ printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗");
+ }
+ }
+ printf("==================================\n");
+}
+
+void compareResults_nvfp4(const Tensor &test,
+ const void *ref, const void *ref_t, const int rows, const int cols,
+ double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) {
+ if (if_on_gpus) test.to_cpu();
+
+ const fp4e2m1 *test_data = test.rowwise_cpu_dptr<fp4e2m1>();
+ const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr<fp4e2m1>();
+ const fp4e2m1 *ref_data = reinterpret_cast<const fp4e2m1*>(ref);
+ const fp4e2m1 *ref_data_t = reinterpret_cast<const fp4e2m1*>(ref_t);
+
+ // Print detailed element-by-element comparison
+ // print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols);
+ // print_detailed_tensor_comparison("output_t", test_data_t, ref_data_t, cols, rows);
+
+ // Optionally dump tensor data to files for detailed analysis
+ if (dump_data) {
+ dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols);
+ dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows);
+ }
+
+ compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol);
+ compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol);
+}
+
+template <typename InputType>
+void performTest(float (*OP)(const float),
+ const std::vector<size_t>& shape) {
+ using namespace test;
+
+ DType itype = TypeInfo<InputType>::dtype;
+ DType otype = DType::kFloat4E2M1;
+
+ const size_t rows = first_dimension(shape);
+ const size_t cols = last_dimension(shape);
+
+ // Use get_scale_tensor_dims for NVFP4 scale tensor dimensions
+ // Now that CheckScaleTensorShape is fixed, this should work correctly
+ const std::array<size_t, 4> scale_dims = get_scale_tensor_dims(rows, cols, 1, 16);
+ const std::array<size_t, 4> scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16);
+
+ const size_t unpadded_blocks_Y = scale_dims[0];
+ const size_t unpadded_blocks_X = scale_dims[1];
+ const size_t blocks_Y = scale_dims[2];
+ const size_t blocks_X = scale_dims[3];
+ const size_t scales_stride = blocks_X;
+
+ const size_t unpadded_blocks_Y_t = scale_dims_t[0];
+ const size_t unpadded_blocks_X_t = scale_dims_t[1];
+ const size_t blocks_Y_t = scale_dims_t[2];
+ const size_t blocks_X_t = scale_dims_t[3];
+ const size_t scales_stride_t = blocks_X_t;
+
+ Tensor input("input", shape, itype);
+ Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING);
+
+ std::unique_ptr<fp4e2m1x2[]> ref_output = std::make_unique<fp4e2m1x2[]>(rows * (cols / 2));
+ std::unique_ptr<fp4e2m1x2[]> ref_output_t = std::make_unique<fp4e2m1x2[]>(cols * (rows / 2));
+ std::unique_ptr<fp8e4m3[]> ref_scales = std::make_unique<fp8e4m3[]>(blocks_Y * blocks_X);
+ std::unique_ptr<fp8e4m3[]> ref_scales_t = std::make_unique<fp8e4m3[]>(blocks_Y_t * blocks_X_t);
+
+ fillCase<InputType>(&input, InputsFillCase::uniform);
+
+ // Find global amax
+ float amax = 0.0f;
+ const InputType* input_dptr = input.rowwise_cpu_dptr<InputType>();
+ for (size_t i = 0; i < rows; ++i) {
+ for (size_t j = 0; j < cols; ++j) {
+ const size_t idx = i * cols + j;
+ amax = fmaxf(amax, static_cast<float>(input_dptr[idx]));
+ }
+ }
+ // Set 2nd stage NVFP4 scaling factor
+ output.set_scale(amax);
+
+ bool use_2d_quantization = false;
+
+ compute_ref(OP,
+ input.rowwise_cpu_dptr<InputType>(),
+ ref_output.get(),
+ ref_output_t.get(),
+ ref_scales.get(),
+ ref_scales_t.get(),
+ output.scale(),
+ rows,
+ cols,
+ scales_stride,
+ scales_stride_t,
+ use_2d_quantization);
+
+ QuantizationConfigWrapper quant_config;
+
+ // Initialize stochastic rounding
+ Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
+ rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123; // rng_seed
+ rng_state.rowwise_cpu_dptr<int64_t>()[1] = 321; // rng_sequence
+ rng_state.from_cpu();
+ quant_config.set_stochastic_rounding(false);
+ quant_config.set_rng_state(rng_state.data());
+
+ // Set 2D quantization based on compile-time flag
+ quant_config.set_nvfp4_2d_quantization(use_2d_quantization);
+
+ // Call appropriate function based on operation type
+ // Activation functions take 3 parameters (input, output, stream)
+ // nvte_quantize_v2 takes 4 parameters (input, output, quant_config, stream)
+ if (OP == &gelu) {
+ nvte_gelu(input.data(), output.data(), 0);
+ } else if (OP == &silu) {
+ nvte_silu(input.data(), output.data(), 0);
+ } else if (OP == &relu) {
+ nvte_relu(input.data(), output.data(), 0);
+ } else if (OP == &qgelu) {
+ nvte_qgelu(input.data(), output.data(), 0);
+ } else if (OP == &srelu) {
+ nvte_srelu(input.data(), output.data(), 0);
+ } else {
+ nvte_quantize_v2(input.data(), output.data(), quant_config, 0);
+ }
+
+ cudaDeviceSynchronize();
+ auto err = cudaGetLastError();
+ if (err != cudaSuccess) {
+ printf("DEBUG: CUDA error detected: %s\n", cudaGetErrorString(err));
+ }
+ ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
+
+ const double atol = 0.05;
+ const double rtol = 0.1;
+
+ // Set dump_data=true to enable dumping tensor data to files for analysis
+ compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false);
+
+ const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr<fp8e4m3>();
+ const fp8e4m3* ref_scales_ptr = ref_scales.get();
+ const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr<fp8e4m3>();
+ const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get();
+
+ size_t scale_mismatches_num = 0;
+ compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(),
+ ref_scales.get(),
+ unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
+ scale_mismatches_num);
+
+ compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(),
+ ref_scales_t.get(),
+ unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t,
+ scale_mismatches_num);
+}
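+
+// Test flow summary: quantize on the GPU (either via a fused activation kernel or
+// nvte_quantize_v2), then compare the packed FP4 data and the E4M3 scale tensors
+// against the CPU reference produced by compute_ref above.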
+
+std::vector<std::vector<size_t>> tensor_dims = {
+ {32, 32},
+ {32, 64},
+ {64, 32},
+ {64, 96},
+ {128, 128},
+ {256, 256},
+ {512, 512},
+ {1024, 1024},
+ {2048, 2048},
+ {128, 256},
+ {8192, 128},
+ {2048, 160},
+ {8, 32, 1024},
+ {16, 8, 4, 512},
+ {1024, 16384},
+ {4096, 13312},
+};
+
+// Activation types covered by the test suite (Identity = cast only)
+std::vector<ActivationType> Activation_types = {
+ ActivationType::Identity,
+ ActivationType::GeLU,
+ ActivationType::SiLU,
+ ActivationType::ReLU,
+ ActivationType::QGeLU,
+ ActivationType::SReLU,
+};
+
+} // namespace
+
+class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam<std::tuple<ActivationType,
+ std::vector<size_t>,
+ transformer_engine::DType>> {};
+
+TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
+ // Skip tests for pre-Blackwell architectures
+ if (getDeviceComputeCapability() < blackwellComputeCapability) {
+ GTEST_SKIP();
+ }
+
+ using namespace transformer_engine;
+ using namespace test;
+
+ const ActivationType Act_type = std::get<0>(GetParam());
+ const auto tensor_dims = std::get<1>(GetParam());
+ const DType input_type = std::get<2>(GetParam());
+
+ // Skip tests if the input tensor is 1D
+ if (tensor_dims.size() < 2) {
+ GTEST_SKIP();
+ }
+
+ // Forward activations
+ auto OP = &identity;
+ switch (Act_type) {
+ case ActivationType::GeLU: OP = &gelu; break;
+ case ActivationType::SiLU: OP = &silu; break;
+ case ActivationType::ReLU: OP = &relu; break;
+ case ActivationType::QGeLU: OP = &qgelu; break;
+ case ActivationType::SReLU: OP = &srelu; break;
+ }
+
+ TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
+ performTest<InputType>(OP, tensor_dims);
+ );
+}
+
+std::string to_string(const ActivationType Act_type) {
+ switch (Act_type) {
+ case ActivationType::Identity: return "CAST_ONLY";
+ case ActivationType::GeLU: return "GeLU";
+ case ActivationType::SiLU: return "SiLU";
+ case ActivationType::ReLU: return "ReLU";
+ case ActivationType::QGeLU: return "QGeLU";
+ case ActivationType::SReLU: return "SReLU";
+ default: return "";
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ OperatorTest,
+ FusedCastTransposeNVFP4TestSuite,
+ ::testing::Combine(
+ ::testing::ValuesIn(Activation_types),
+ ::testing::ValuesIn(tensor_dims),
+ ::testing::Values(DType::kBFloat16)),
+ [](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
+ std::string name = to_string(std::get<0>(info.param));
+ const auto& shape = std::get<1>(info.param);
+ for ( const auto& s: shape) {
+ name += "X" + std::to_string(s);
+ }
+ name += "X" + test::typeName(std::get<2>(info.param));
+ return name;
+ });
diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu
index 9f926d07b..eebe4108a 100644
--- a/tests/cpp/test_common.cu
+++ b/tests/cpp/test_common.cu
@@ -109,6 +109,10 @@ size_t DIVUP(const size_t &x, const size_t &y){
return (((x) + ((y)-1)) / (y));
}
+size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){
+ return DIVUP(x, y) * y;
+}
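+// Example: DIVUP_TO_MULTIPLE(130, 128) == 256, i.e. x rounded up to the next multiple of y.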
+
struct scale_inv_meta {
std::vector shape;
DType type;
@@ -145,21 +149,71 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
- auto block_alignment = std::vector<size_t>{128ul, 4ul};
- {
- auto alignment = block_alignment[0];
- auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
- alignment = block_alignment[1];
- auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
- ret_rowwise.shape = {scale_dim_0, scale_dim_1};
+ const size_t block_size_X_rowwise = 32;
+ size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
+ size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
+ ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
+
+ const size_t block_size_Y_colwise = 32;
+ size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
+ size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
+ ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
+
+ ret_rowwise.type = DType::kFloat8E8M0;
+ ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
+ ret_colwise.type = DType::kFloat8E8M0;
+ ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
+
+ return {ret_rowwise, ret_colwise};
+ }
+ if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
+ std::vector<size_t> shape_vec;
+ for (size_t i = 0; i < shape.ndim; ++i) {
+ shape_vec.push_back(shape.data[i]);
}
- {
- auto alignment = block_alignment[1];
- auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
- alignment = block_alignment[0];
- auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
- ret_colwise.shape = {scale_dim_0, scale_dim_1};
+ size_t first_dim = first_dimension(shape_vec);
+ size_t last_dim = last_dimension(shape_vec);
+
+ NVTE_CHECK(last_dim % 32 == 0);
+ NVTE_CHECK(first_dim % 32 == 0);
+
+ scale_inv_meta ret_rowwise, ret_colwise;
+
+ size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
+ size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise);
+ ret_rowwise.shape = {scale_dim_Y, scale_dim_X};
+
+ size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise);
+ size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise);
+ ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t};
+
+ ret_rowwise.type = DType::kFloat8E4M3;
+ ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
+ ret_colwise.type = DType::kFloat8E4M3;
+ ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3);
+
+ return {ret_rowwise, ret_colwise};
+ }
+ if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
+ std::vector<size_t> shape_vec;
+ for (size_t i = 0; i < shape.ndim; ++i) {
+ shape_vec.push_back(shape.data[i]);
}
+ size_t first_dim = first_dimension(shape_vec);
+ size_t last_dim = last_dimension(shape_vec);
+
+ scale_inv_meta ret_rowwise, ret_colwise;
+
+ const size_t block_size_X_rowwise = 32;
+ size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
+ size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
+ ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
+
+ const size_t block_size_Y_colwise = 32;
+ size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
+ size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
+ ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
+
ret_rowwise.type = DType::kFloat8E8M0;
ret_colwise.type = DType::kFloat8E8M0;
ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
@@ -178,13 +232,13 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
{
- auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
- auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(128)), 4) * 4;
+ size_t scale_dim_0 = DIVUP(first_dim, 128lu);
+ size_t scale_dim_1 = DIVUP(DIVUP(last_dim, 128lu), 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
- auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
- auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(128)), 4) * 4;
+ size_t scale_dim_0 = DIVUP(last_dim, 128lu);
+ size_t scale_dim_1 = DIVUP(DIVUP(first_dim, 128lu), 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
@@ -204,13 +258,13 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta ret_rowwise, ret_colwise;
{
- auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
- auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
+ size_t scale_dim_0 = DIVUP(last_dim, 128lu);
+ size_t scale_dim_1 = DIVUP(first_dim, 4) * 4;
ret_rowwise.shape = {scale_dim_0, scale_dim_1};
}
{
- auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
- auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
+ size_t scale_dim_0 = DIVUP(first_dim, 128lu);
+ size_t scale_dim_1 = DIVUP(last_dim, 4) * 4;
ret_colwise.shape = {scale_dim_0, scale_dim_1};
}
ret_rowwise.type = DType::kFloat32;
@@ -252,14 +306,15 @@ Tensor::Tensor(const std::string& name,
NVTEShape columnwise_shape = {};
std::vector<size_t> columnwise_shape_vec;
- if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
+ if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING
+ || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) {
// Transpose when tensor scaling
columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]);
for (size_t i = 0; i < shape.ndim - 1; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
} else {
- // Same shape for MX
+ // Same shape for MX and NVFP4
for (size_t i = 0; i < shape.ndim; ++i) {
columnwise_shape_vec.emplace_back(shape.data[i]);
}
@@ -285,10 +340,13 @@ Tensor::Tensor(const std::string& name,
std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
}
}
- tensor_.set_rowwise_data(dptr_rowwise, type, shape);
- tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape);
- if (isFp8Type(type)) {
+ const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
+ const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
+ tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
+ tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
+
+ if (isFp8Type(type) || isFp4Type(type)) {
if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
(void)cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
(void)cudaMemset(amax, 0, sizeof(float));
@@ -307,13 +365,19 @@ Tensor::Tensor(const std::string& name,
}
if (columnwise) {
tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
- std::vector<size_t>{1});
+ std::vector<size_t>{1});
columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
}
} else {
- auto [rowwise_scale_meta, colwise_scale_meta] =
- get_scales(normalized_shape, tensor_.scaling_mode());
+ if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
+ // Used for NVFP4 second stage scaling
+ cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*)
+ cudaMemset(scale, 0, sizeof(float));
+ scale_cpu_data_ = std::make_shared<float>(0);
+ tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
+ }
+ auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode());
auto rowwise_scale_size = rowwise_scale_meta.bytes();
auto columnwise_scale_size = colwise_scale_meta.bytes();
auto scale_shape = rowwise_scale_meta.shape;
@@ -348,13 +412,16 @@ void Tensor::to_cpu() const {
cudaMemcpyDeviceToHost);
}
if (columnwise_) {
+ const DType colwise_type = tensor_.dtype();
+
+ const size_t colwise_size = bytes(s, colwise_type);
(void)cudaMemcpy(cpu_data_columnwise_.get(),
- tensor_.get_columnwise_data().data_ptr,
- size,
- cudaMemcpyDeviceToHost);
+ tensor_.get_columnwise_data().data_ptr,
+ colwise_size,
+ cudaMemcpyDeviceToHost);
}
- if (isFp8Type(dtype())) {
- if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
+ if (isFp8Type(dtype()) || isFp4Type(dtype())) {
+ if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) {
if (tensor_.amax() != nullptr){
(void)cudaMemcpy(amax_cpu_data_.get(),
tensor_.amax(),
@@ -366,8 +433,7 @@ void Tensor::to_cpu() const {
sizeof(float),
cudaMemcpyDeviceToHost);
}
- auto [rowwise_scale_meta, colwise_scale_meta] =
- get_scales(s, tensor_.scaling_mode());
+ auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = rowwise_scale_meta.bytes();
(void)cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
@@ -396,15 +462,15 @@ void Tensor::from_cpu() const {
(void)cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
cudaMemcpyHostToDevice);
}
- if (isFp8Type(dtype())) {
- if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
+ if (isFp8Type(dtype()) || isFp4Type(dtype())) {
+ if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
+ || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) {
if (tensor_.amax() != nullptr){
(void)cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
(void)cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
}
- auto [rowwise_scale_meta, colwise_scale_meta] =
- get_scales(s, tensor_.scaling_mode());
+ auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
if (rowwise_) {
auto scale_size = rowwise_scale_meta.bytes();
(void)cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
@@ -421,7 +487,7 @@ void Tensor::from_cpu() const {
}
void Tensor::set_scale(float scale) {
- if (isFp8Type(dtype())) {
+ if (isFp8Type(dtype()) || isFp4Type(dtype())) {
NVTE_CHECK(scale_cpu_data_);
if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
*scale_cpu_data_ = scale;
@@ -431,7 +497,7 @@ void Tensor::set_scale(float scale) {
}
void Tensor::set_scale_inv(float scale_inv) {
- if (isFp8Type(dtype())) {
+ if (isFp8Type(dtype()) || isFp4Type(dtype())) {
if (rowwise_) {
NVTE_CHECK(rowwise_scale_inv_cpu_data_);
}
@@ -439,8 +505,7 @@ void Tensor::set_scale_inv(float scale_inv) {
NVTE_CHECK(columnwise_scale_inv_cpu_data_);
}
- auto [rowwise_scale_meta, colwise_scale_meta] =
- get_scales(tensor_.shape(), tensor_.scaling_mode());
+ auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
if (rowwise_) {
auto num_scales = product(rowwise_scale_meta.shape);
if (num_scales == 1) {
@@ -473,7 +538,8 @@ void Tensor::set_scale_inv(float scale_inv) {
}
void Tensor::shareFP8Meta(const Tensor &other) {
- if (isFp8Type(dtype()) && isFp8Type(other.dtype())) {
+ if ((isFp8Type(dtype()) && isFp8Type(other.dtype()))
+ || (isFp4Type(dtype()) && isFp4Type(other.dtype()))) {
auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
auto my_rowwise_data = tensor_.get_rowwise_data();
new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
@@ -686,13 +752,30 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t
}
}
-void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
- const size_t row_blocks, const size_t col_blocks, const size_t stride,
- std::vector<int> &mismatch_indices,
- size_t& mismatches_num, const size_t atol,
- const double abs_tolerable_mismatches_limit,
- const double rel_tolerable_mismatches_limit)
+template <typename T>
+struct CastToType;
+
+template <>
+struct CastToType<uint8_t> {
+ using type = int;
+};
+
+template <>
+struct CastToType<fp8e4m3> {
+ using type = float;
+};
+
+template <typename T>
+void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
+ const size_t row_blocks, const size_t col_blocks, const size_t stride,
+ std::vector<int> &mismatch_indices, size_t& mismatches_num, const size_t atol,
+ const double abs_tolerable_mismatches_limit,
+ const double rel_tolerable_mismatches_limit)
+{
+ using UpcastType = typename CastToType<T>::type;
+ auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3);
+
+
const size_t N = row_blocks * col_blocks;
const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit,
std::floor(N * rel_tolerable_mismatches_limit));
@@ -701,11 +784,31 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
for (int i = 0; i < row_blocks; ++i) {
for (int j = 0; j < col_blocks; ++j) {
const int idx = i * stride + j;
- const int test_val = static_cast<int>(test[idx]);
- const int ref_val = static_cast<int>(ref[idx]);
- const int abs_delta = std::abs(test_val - ref_val);
+ float t, r;
- if (abs_delta > atol) {
+ bool assertion = false;
+
+ if (std::is_same<T, uint8_t>::value) {
+ t = static_cast<float>(test[idx]);
+ r = static_cast<float>(ref[idx]);
+ assertion = std::abs(t - r) > atol;
+ } else {
+ t = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&test[idx]));
+ r = static_cast<float>(*reinterpret_cast<const fp8e4m3*>(&ref[idx]));
+ const bool mismatch = (fabs(t - r) > atol_fp8e4m3)
+ && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3);
+ if (mismatch) {
+ /* Check if it is just a failure of round to nearest choosing different
+ side of the real value */
+ const double mean = (t + r) / 2;
+ const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
+ const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
+ const double cast_mean_p = static_cast<double>(static_cast<fp8e4m3>(mean_p));
+ const double cast_mean_m = static_cast<double>(static_cast<fp8e4m3>(mean_m));
+ assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
+ }
+ }
+ if (assertion) {
mismatches_num++;
mismatch_indices.push_back(idx);
}
@@ -713,8 +816,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
std::cout << "Error in " << name << std::endl;
for (const int index : mismatch_indices) {
std::cout << "Mismatch at (" << index << "):"
- << static_cast(test[index]) << " vs "
- << static_cast(ref[index]) << std::endl;
+ << static_cast(test[index]) << " vs "
+ << static_cast(ref[index]) << std::endl;
}
GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of "
<< tolerable_mismatches_limit << ".";
@@ -723,6 +826,22 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
}
}
+// Instantiate templates
+template
+void compare_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
+ const size_t row_blocks, const size_t col_blocks, const size_t stride,
+ std::vector<int> &mismatch_indices, size_t& mismatches_num, const size_t atol,
+ const double abs_tolerable_mismatches_limit,
+ const double rel_tolerable_mismatches_limit);
+
+template
+void compare_scaling_factors(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref,
+ const size_t row_blocks, const size_t col_blocks, const size_t stride,
+ std::vector<int> &mismatch_indices, size_t& mismatches_num, const size_t atol,
+ const double abs_tolerable_mismatches_limit,
+ const double rel_tolerable_mismatches_limit);
+
+
#ifdef __HIP_PLATFORM_AMD__
void adjust_ref_for_e8m0_scale_error(const std::string &name,
@@ -932,11 +1051,14 @@ bool isFp8Type(DType type) {
return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
}
-int32_t getDeviceComputeCapability()
-{
- cudaDeviceProp deviceProp;
- (void)cudaGetDeviceProperties(&deviceProp, 0);
- return 10 * deviceProp.major + deviceProp.minor;
+bool isFp4Type(DType type) {
+ return type == DType::kFloat4E2M1;
+}
+
+int32_t getDeviceComputeCapability() {
+ cudaDeviceProp deviceProp;
+ (void)cudaGetDeviceProperties(&deviceProp, 0);
+ return 10 * deviceProp.major + deviceProp.minor;
}
size_t first_dimension(const std::vector &shape) {
@@ -954,7 +1076,8 @@ std::array<size_t, 4> get_scale_tensor_dims(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols) {
- const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32);
+ const bool is_rowwise = (block_size_rows == 1)
+ && ((block_size_cols == 32) || (block_size_cols == 16));
const size_t alignment_Y = is_rowwise
? scale_tensor_alignment_Y_rowwise
diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h
index 3c0a387c6..95aff8cf8 100644
--- a/tests/cpp/test_common.h
+++ b/tests/cpp/test_common.h
@@ -74,6 +74,8 @@ using fp8e5m2 = te_hip_fp8_e5m2;
using fp8e8m0 = uint8_t;
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
+using fp4e2m1x2 = __nv_fp4x2_e2m1;
+using fp4e2m1x4 = __nv_fp4x4_e2m1;
#endif
template <typename T>
@@ -235,7 +237,9 @@ class Tensor {
float scale() const {
if(scale_cpu_data_) {
- NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!");
+ NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)
+ || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING),
+ "Invalid scaling_mode!");
to_cpu();
return *scale_cpu_data_;
} else {
@@ -249,6 +253,8 @@ class Tensor {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
+ } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
+ NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
@@ -262,6 +268,8 @@ class Tensor {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
} else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
+ } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
+ NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
} else {
NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
}
@@ -316,10 +324,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127;
constexpr uint32_t FP32_MANTISSA_BITS = 23;
// [128,4] rowwise and [4,128] colwise alignment requirement
-constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
-constexpr size_t scale_tensor_alignment_X_colwise = 128;
+constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;
+constexpr size_t scale_tensor_alignment_X_colwise = 128;
inline size_t divide_round_up(const size_t N, const size_t M) {
return (N - 1 + M) / M;
@@ -480,12 +488,13 @@ void compareResults(const std::string &name, const float test, const float ref,
double atol = 1e-5, double rtol = 1e-8);
void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref,
size_t N, float mismatch_rate_tol = 0.);
-void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
- const size_t row_blocks, const size_t col_blocks, const size_t stride,
- std::vector<int> &mismatch_indices, size_t& mismatches_num,
- const size_t scale_diff_abs_tolerance = 0,
- const double abs_tolerable_mismatches_limit = 0,
- const double rel_tolerable_mismatches_limit = 0);
+template <typename T>
+void compare_scaling_factors(const std::string &name, const T *test, const T *ref,
+ const size_t row_blocks, const size_t col_blocks, const size_t stride,
+ std::vector<int> &mismatch_indices, size_t& mismatches_num,
+ const size_t scale_diff_abs_tolerance = 0,
+ const double abs_tolerable_mismatches_limit = 0,
+ const double rel_tolerable_mismatches_limit = 0);
#ifdef USE_ROCM
void adjust_ref_for_e8m0_scale_error(const std::string &name,
@@ -516,6 +525,7 @@ const std::string& caseName(InputsFillCase type);
extern std::vector all_fp_types;
bool isFp8Type(DType type);
+bool isFp4Type(DType type);
int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90;
@@ -593,7 +603,7 @@ constexpr int32_t blackwellComputeCapability = 100;
SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
default: \
printf("dtype: %d\n", static_cast(dtype)); \
- NVTE_ERROR("Invalid type MARKED TEST."); \
+ NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
@@ -612,7 +622,7 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
- NVTE_ERROR("Invalid type MARKED TEST 2."); \
+ NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
@@ -620,7 +630,7 @@ constexpr int32_t blackwellComputeCapability = 100;
using namespace transformer_engine; \
SWITCH_FP4_HANDLE(type, __VA_ARGS__) \
default: \
- NVTE_ERROR("Invalid type MARKED TEST 3."); \
+ NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
@@ -645,5 +655,5 @@ constexpr int32_t blackwellComputeCapability = 100;
} \
break; \
default: \
- NVTE_ERROR("Invalid type MARKED TEST 4."); \
+ NVTE_ERROR("Invalid type."); \
}
diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py
index 6ec3c27a4..0a75053b1 100644
--- a/tests/jax/test_custom_call_compute.py
+++ b/tests/jax/test_custom_call_compute.py
@@ -788,9 +788,15 @@ def _test_quantize_dact_dbias(
assert_allclose(te_output.data, jax_output.data)
if is_dbias:
- # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16.
precise_comparison = not (
- in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()
+            # TE kernels cast intermediate results to the input dtype, which reduces precision compared to the JAX implementation; for dbias this typically only affects bfloat16.
+ (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling())
+            # Due to the amax dependency, current scaling is unfused. TE stores the activation results in bf16, which loses precision relative to the JAX implementation (JAX implicitly promotes intermediate results to float32 when JIT'd). Currently this only causes a tolerance issue for squared_relu.
+ or (
+ activation_type == ("squared_relu",)
+ and in_dtype == jnp.bfloat16
+ and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING
+ )
)
assert_allclose(
te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype
diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py
index 03c0d1119..bf01e3216 100644
--- a/tests/jax/test_distributed_layernorm.py
+++ b/tests/jax/test_distributed_layernorm.py
@@ -78,8 +78,6 @@ def generate_collectives_count_ref(
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
other_bytes = 0
- if fp8_recipe == recipe.Float8CurrentScaling():
- allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
)
diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py
index 10bb066a4..5e2898cf3 100644
--- a/tests/pytorch/attention/run_attention_with_cp.py
+++ b/tests/pytorch/attention/run_attention_with_cp.py
@@ -15,94 +15,28 @@
from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
get_cu_seqlens_on_cp_rank,
)
+from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize
import transformer_engine_torch as tex
from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn
from transformer_engine.pytorch.fp8 import fp8_autocast
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
-from transformer_engine.common.recipe import DelayedScaling
+from transformer_engine.pytorch.tensor.float8_tensor import (
+ Float8Tensor,
+ Float8Quantizer,
+ Float8CurrentScalingQuantizer,
+)
+from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling
import warnings
+from utils import ModelConfig, compare_and_assert
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
-def run_dpa_with_cp(
- dtype="bf16",
- model=None,
- qkv_format="bshd",
- kernel_backend="FlashAttention",
- cp_comm_type="p2p",
- fp8_mha=False,
+def generate_input_shapes(
+ qkv_format: str,
+ config: ModelConfig,
+ world_size: int,
+ kernel_backend: str,
):
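+    """Build q/k/v/output tensor shapes and cu_seqlens tensors for the given qkv_format, model config and CP world size."""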
- """Test DotProductAttention module with context parallelism"""
-
- # args are passed as strings
- fp8_mha = fp8_mha == "True"
- os.environ["NVTE_FLASH_ATTN"] = "0"
- os.environ["NVTE_FUSED_ATTN"] = "0"
- if kernel_backend == "FlashAttention":
- os.environ["NVTE_FLASH_ATTN"] = "1"
- config = model_configs_flash_attn[model]
- if kernel_backend == "FusedAttention":
- os.environ["NVTE_FUSED_ATTN"] = "1"
- config = model_configs_fused_attn[model]
-
- assert config.attn_mask_type in [
- "causal",
- "no_mask",
- ], f"{config.attn_mask_type} is an unsupported attention mask type!"
- if qkv_format == "thd":
- if "causal" in config.attn_mask_type:
- config.attn_mask_type = "padding_causal"
- else:
- config.attn_mask_type = "padding"
-
- rank = int(os.getenv("RANK", "0"))
- world_size = int(os.getenv("WORLD_SIZE", "1"))
-
- if dist.is_initialized():
- world_size = dist.get_world_size()
- rank = dist.get_rank()
- else:
- device_count = torch.cuda.device_count()
- device = rank % device_count
- torch.cuda.set_device(device)
-
- print(f"[INFO] world_size:{world_size}, rank:{rank}")
-
- dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
-
- # create flash attn comm group for CP
- cp_comm_ranks = range(world_size)
- assert rank in cp_comm_ranks
- cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
- if cp_comm_type == "a2a+p2p":
- assert (
- world_size % 2 == 0
- ), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!"
- cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
- cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
- cp_comm_sub_groups = []
- for sub_ranks in cp_comm_sub_ranks:
- sub_group = dist.new_group(sub_ranks, backend="nccl")
- if rank in sub_ranks:
- cp_comm_sub_groups.append(sub_group)
-
- if dtype == "fp8":
- fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha)
-
- # instantiate core attn module
- core_attn = DotProductAttention(
- config.num_heads,
- (config.head_dim_qk, config.head_dim_v),
- num_gqa_groups=config.num_gqa_groups,
- attention_dropout=config.dropout_p,
- qkv_format=qkv_format,
- attn_mask_type=config.attn_mask_type,
- window_size=config.window_size,
- )
- core_attn = core_attn.cuda()
-
- # create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (
config.batch_size,
@@ -197,35 +131,192 @@ def run_dpa_with_cp(
cu_seqlens_kv = cu_seqlens_q
cu_seqlens_kv_padded = cu_seqlens_q_padded
else:
- assert False, f"{qkv_format} is an unsupported qkv_format!"
-
- q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
- k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
- v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
- dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
- dout_quantizer = Float8Quantizer(
- fp8_dtype=tex.DType.kFloat8E5M2,
- scale=torch.tensor([1], dtype=torch.float32).cuda(),
- amax=torch.tensor([0], dtype=torch.float32).cuda(),
+ assert False, f"{qkv_format=} is not supported!"
+
+ return (
+ q_input_shape,
+ k_input_shape,
+ v_input_shape,
+ attn_output_shape,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ cu_seqlens_q_padded,
+ cu_seqlens_kv_padded,
)
- # create flash attention bias
+
+def get_tols(config, dtype):
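+    """Return (atol, rtol, rmse_tol) tolerances used to compare CP and no-CP results for this config and dtype."""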
+ if dtype == "bf16":
+ if config.num_heads == config.num_gqa_groups:
+ atol = 2.5e-2
+ rtol = 2.5e-2
+ else:
+ atol = 3.5e-2
+ rtol = 3.5e-2
+ rmse_tol = 0.01
+ elif dtype == "fp16":
+ atol = 5e-3
+ rtol = 5e-3
+ rmse_tol = 0.01
+ elif dtype == "fp8":
+ atol = 5e-1
+ rtol = 5e-1
+ rmse_tol = 0.15
+ else:
+ assert False, f"{dtype=} is not supported!"
+
+ return atol, rtol, rmse_tol
+
+
+def run_dpa_with_cp(
+ dtype="bf16",
+ model=None,
+ qkv_format="bshd",
+ kernel_backend="FlashAttention",
+ cp_comm_type="p2p",
+ fp8_bwd="True",
+ fp8_dpa="False",
+ fp8_mha="False",
+ scaling_mode="delayed",
+ f16_O="False",
+ log_level=logging.WARNING,
+):
+ """Test DotProductAttention module with context parallelism"""
+ logging.root.setLevel(log_level)
+
+ # set up environment variables and config
+ fp8_bwd = fp8_bwd == "True" and dtype == "fp8"
+ os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0"
+ fp8_dpa = fp8_dpa == "True" and dtype == "fp8"
+ fp8_mha = fp8_mha == "True" and dtype == "fp8"
+ f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True"
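+    # f16_O: with FP8 current scaling, keep the attention output O in F16 instead of FP8 (NVTE_DPA_FP8CS_O_in_F16)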
+ os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0"
+ os.environ["NVTE_FLASH_ATTN"] = "0"
+ os.environ["NVTE_FUSED_ATTN"] = "0"
+ if kernel_backend == "FlashAttention":
+ os.environ["NVTE_FLASH_ATTN"] = "1"
+ config = model_configs_flash_attn[model]
+ if kernel_backend == "FusedAttention":
+ os.environ["NVTE_FUSED_ATTN"] = "1"
+ config = model_configs_fused_attn[model]
+ assert config.attn_mask_type in [
+ "causal",
+ "no_mask",
+ ], f"{config.attn_mask_type=} is not supported!"
+ if qkv_format == "thd":
+ if "causal" in config.attn_mask_type:
+ config.attn_mask_type = "padding_causal"
+ else:
+ config.attn_mask_type = "padding"
+
+ # set up distributed group
+ rank = int(os.getenv("RANK", "0"))
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
+ if dist.is_initialized():
+ world_size = dist.get_world_size()
+ rank = dist.get_rank()
+ else:
+ device_count = torch.cuda.device_count()
+ device = rank % device_count
+ torch.cuda.set_device(device)
+ logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
+ dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
+
+ # set up communication group for CP
+ cp_comm_ranks = range(world_size)
+ assert rank in cp_comm_ranks
+ cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
+ if cp_comm_type == "a2a+p2p":
+        assert world_size % 2 == 0, (
+            f"{cp_comm_type=} requires world_size % 2 == 0 as it assumes the a2a level has"
+            " cp_size = 2."
+        )
+ cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
+ cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
+ cp_comm_sub_groups = []
+ for sub_ranks in cp_comm_sub_ranks:
+ sub_group = dist.new_group(sub_ranks, backend="nccl")
+ if rank in sub_ranks:
+ cp_comm_sub_groups.append(sub_group)
+
+ if dtype == "fp8":
+ if scaling_mode == "delayed":
+ fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
+ if scaling_mode == "current":
+ fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha)
+
+ # instantiate attention module
+ core_attn = DotProductAttention(
+ config.num_heads,
+ (config.head_dim_qk, config.head_dim_v),
+ num_gqa_groups=config.num_gqa_groups,
+ attention_dropout=config.dropout_p,
+ qkv_format=qkv_format,
+ attn_mask_type=config.attn_mask_type,
+ window_size=config.window_size,
+ softmax_type=config.softmax_type,
+ ).cuda()
+ if config.softmax_type != "vanilla":
+ core_attn.softmax_offset.requires_grad = True
+
+ # generate attention inputs
+ (
+ q_input_shape,
+ k_input_shape,
+ v_input_shape,
+ attn_output_shape,
+ cu_seqlens_q,
+ cu_seqlens_kv,
+ cu_seqlens_q_padded,
+ cu_seqlens_kv_padded,
+ ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend)
+ q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+ k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+ v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda()
+ dout_orig = torch.clamp(
+ torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1
+ ).cuda()
+ if scaling_mode == "delayed":
+ qkv_quantizer = Float8Quantizer(
+ fp8_dtype=tex.DType.kFloat8E4M3,
+ scale=torch.tensor([1], dtype=torch.float32).cuda(),
+ amax=torch.tensor([0], dtype=torch.float32).cuda(),
+ )
+ dout_quantizer = Float8Quantizer(
+ fp8_dtype=tex.DType.kFloat8E5M2,
+ scale=torch.tensor([1], dtype=torch.float32).cuda(),
+ amax=torch.tensor([0], dtype=torch.float32).cuda(),
+ )
+ if scaling_mode == "current":
+ qkv_quantizer = Float8CurrentScalingQuantizer(
+ fp8_dtype=tex.DType.kFloat8E4M3,
+ device="cuda",
+ )
+ dout_quantizer = Float8CurrentScalingQuantizer(
+ fp8_dtype=tex.DType.kFloat8E5M2,
+ device="cuda",
+ )
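+    # q, k and v share one layout, e.g. qkv_format "bshd" -> qkv_layout "bshd_bshd_bshd"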
+ qkv_layout = "_".join([qkv_format] * 3)
+ q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]]
+ if fp8_mha:
+ q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer)
+ for x in [q, k, v]:
+ x.requires_grad = True
+
if config.attn_bias_type not in ["no_bias", "alibi"]:
attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
else:
bias = None
- # run core_attn without CP
- for x in [q, k, v]:
- x.requires_grad = True
-
+ ############ run without CP ############
+ logging.info(f"[Rank {rank}] Run without context parallelism")
if dtype == "fp8":
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
-
with fp8_context:
+        # with fp8_mha, q/k/v and out are FP8; dout stays in F16
out = core_attn(
q,
k,
@@ -238,16 +329,25 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded=(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
+ fp8_output=fp8_mha,
)
- if fp8_mha:
+ if fp8_bwd and fp8_mha:
dout_fp8 = dout_quantizer(dout)
out.backward(dout_fp8)
else:
out.backward(dout)
+ dq, dk, dv = q.grad, k.grad, v.grad
+ d_softmax_offset = None
+ if config.softmax_type != "vanilla":
+ d_softmax_offset = core_attn.softmax_offset.grad
+
+ ############ run with CP ############
+ logging.info(f"[Rank {rank}] Run with context parallelism")
- # run core_attn wit CP
+ # set up inputs
q_, k_, v_, dout_, *rest = [
- x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
+ x.clone().detach()
+ for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias])
]
bias_ = rest[0] if len(rest) else None
if qkv_format == "bshd" or qkv_format == "sbhd":
@@ -277,6 +377,14 @@ def run_dpa_with_cp(
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
+ q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]]
+ if scaling_mode == "delayed":
+ qkv_quantizer.scale.fill_(1.0)
+ qkv_quantizer.amax.fill_(0.0)
+ dout_quantizer.scale.fill_(1.0)
+ dout_quantizer.amax.fill_(0.0)
+ if fp8_mha:
+ q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer)
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(
@@ -284,20 +392,25 @@ def run_dpa_with_cp(
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
+ # set up environment
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)
-
+ if config.softmax_type != "vanilla":
+ core_attn.softmax_offset.grad.zero_()
if dtype == "fp8":
- core_attn.reset_fp8_meta_tensors()
+ core_attn.fp8_initialized = False
+ core_attn.fp8_meta_tensors_initialized = False
fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group)
else:
fp8_context = nullcontext()
+ # run attention
with fp8_context:
+        # with fp8_mha, q/k/v and out are FP8; dout stays in F16
out_ = core_attn(
q_,
k_,
@@ -310,24 +423,32 @@ def run_dpa_with_cp(
cu_seqlens_kv_padded=(
None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
),
+ fp8_output=fp8_mha,
)
- if fp8_mha:
+ if fp8_bwd and fp8_mha:
dout_fp8_ = dout_quantizer(dout_)
out_.backward(dout_fp8_)
else:
out_.backward(dout_)
+ dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
+ d_softmax_offset_ = None
+ if config.softmax_type != "vanilla":
+ d_softmax_offset_ = core_attn.softmax_offset.grad.clone()
+ # get outputs
+ tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_]
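+    # with fp8_mha, the outputs (and, when fp8_bwd, the gradients too) are FP8 tensors; dequantize them before comparison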
if fp8_mha:
- assert isinstance(out, Float8Tensor)
- assert isinstance(out_, Float8Tensor)
- out = out.dequantize()
- out_ = out_.dequantize()
-
- for x in [out_, q_.grad, k_.grad, v_.grad]:
- assert torch.all(~torch.isnan(x))
- assert torch.all(~torch.isinf(x))
-
- # compare results with and without CP
+ tensors_to_deq = [out, out_] if not fp8_bwd else tensors
+ for i, tensor in enumerate(tensors_to_deq):
+ tensors_to_deq[i] = tensor.dequantize()
+ if not fp8_bwd:
+ tensors[0], tensors[4] = tensors_to_deq
+ for tensor in tensors:
+ assert torch.all(~torch.isnan(tensor))
+ assert torch.all(~torch.isinf(tensor))
+ out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors
+
+ ############ compare results between CP and no-CP ############
if qkv_format == "bshd" or qkv_format == "sbhd":
dq, dk, dv, out = [
x.view(
@@ -336,17 +457,17 @@ def run_dpa_with_cp(
x.shape[seq_dim] // (2 * world_size),
*x.shape[(seq_dim + 1) :],
)
- for x in [q.grad, k.grad, v.grad, out]
+ for x in [dq, dk, dv, out]
]
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
dq_, dk_, dv_, out_ = [
x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
- for x in [q_.grad, k_.grad, v_.grad, out_]
+ for x in [dq_, dk_, dv_, out_]
]
elif qkv_format == "thd":
- dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
- dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
- dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
+ dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
+ dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
+ dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_]
cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size
cu_seqlens_q = get_cu_seqlens_on_cp_rank(
cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
@@ -389,56 +510,70 @@ def run_dpa_with_cp(
).item()
== 0
)
- else:
- assert False, f"{qkv_format} is an unsupported qkv_format!"
-
- if dtype == "bf16":
- if config.num_heads == config.num_gqa_groups:
- tols = dict(atol=2.5e-2, rtol=2.5e-2)
- else:
- tols = dict(atol=3.5e-2, rtol=3.5e-2)
- elif dtype == "fp16":
- tols = dict(atol=5e-3, rtol=5e-3)
- elif dtype == "fp8":
- tols = dict(atol=5e-1, rtol=5e-1)
- rmse_tol = 0.1
- else:
- assert False, f"{dtype} is an unsupported dtype!"
-
- def _rmse(a, b):
- return torch.sqrt((a - b).square().mean()).item()
- def _error(a, b):
- if dtype != "fp8":
- torch.testing.assert_close(a, b, **tols)
- else:
- try:
- torch.testing.assert_close(a, b, **tols)
- except Exception as e:
- logging.debug(e)
-
- rmse = _rmse(a, b)
- rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
- assert (
- rmse < rmse_tol * rmse_range
- ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
- )
-
- if qkv_format == "bshd":
- for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
- _error(a[:, 0], b[:, 0])
- _error(a[:, 1], b[:, 1])
- elif qkv_format == "sbhd":
- for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
- _error(a[0], b[0])
- _error(a[1], b[1])
- elif qkv_format == "thd":
- for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]):
- _error(a, b)
- else:
- assert False, f"{qkv_format} is an unsupported qkv_format!"
+ atol, rtol, rmse_tol = get_tols(config, dtype)
+ tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_]
+ tensors_no_cp = [out, dq, dk, dv, d_softmax_offset]
+ names = ["out", "dq", "dk", "dv", "d_softmax_offset"]
+ names_cp = [x + "_cp" for x in names]
+ names_no_cp = [x + "_no_cp" for x in names]
+ is_fp8 = dtype == "fp8"
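+    # for bshd/sbhd, the tensors were reshaped above so the sequence dim holds this rank's two CP chunks; compare each chunk separately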
+ for i, t in enumerate(tensors_no_cp):
+ if t is not None:
+ if "softmax_offset" not in names[i]:
+ if qkv_format == "bshd":
+ compare_and_assert(
+ t[:, 0],
+ tensors_cp[i][:, 0],
+ names_no_cp[i],
+ names_cp[i],
+ atol,
+ rtol,
+ rmse_tol,
+ is_fp8,
+ )
+ compare_and_assert(
+ t[:, 1],
+ tensors_cp[i][:, 1],
+ names_no_cp[i],
+ names_cp[i],
+ atol,
+ rtol,
+ rmse_tol,
+ is_fp8,
+ )
+ elif qkv_format == "sbhd":
+ compare_and_assert(
+ t[0],
+ tensors_cp[i][0],
+ names_no_cp[i],
+ names_cp[i],
+ atol,
+ rtol,
+ rmse_tol,
+ is_fp8,
+ )
+ compare_and_assert(
+ t[1],
+ tensors_cp[i][1],
+ names_no_cp[i],
+ names_cp[i],
+ atol,
+ rtol,
+ rmse_tol,
+ is_fp8,
+ )
+ elif qkv_format == "thd":
+ compare_and_assert(
+ t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
+ )
+ else:
+ compare_and_assert(
+ t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
+ )
+ logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")
+    # destroy the process group
dist.destroy_process_group()
diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index a5128653e..08f3b0c6b 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -4,7 +4,6 @@
#
# See LICENSE for license information.
import logging
-import math
import os
import sys
import pathlib
@@ -54,19 +53,21 @@
sys.path.append(str(_current_file.parent.parent))
from utils import (
reset_rng_states,
+ compare_and_assert,
ModelConfig,
dtype_tols,
get_available_attention_backends,
)
-# Only run FP8 tests on H100
+# Check if hardware supports FP8
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
+# Reset RNG seed and states
seed = 1234
-# Reset RNG states
reset_rng_states()
+# Reset FP8 global state manager
@pytest.fixture(autouse=True)
def reset_global_fp8_state():
yield
@@ -82,9 +83,14 @@ def reset_attn_backend():
"NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD"])
yield
+# Define F16 data types to test
+param_types = [torch.float16]
+if is_bf16_compatible():
+ param_types.append(torch.bfloat16)
+param_types_lean = [torch.bfloat16]
model_configs_base = {
- # test: b, h, hg, d, sq, skv, p, mask, bias
+ # test: ModelConfig(b, sq, hq, dqk)
"base_1_0": ModelConfig(8, 128, 16, 64),
"base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256),
"base_2_0": ModelConfig(2, 2048, 24, 128),
@@ -100,11 +106,6 @@ def reset_attn_backend():
}
-param_types = [torch.float16]
-if is_bf16_compatible(): # bf16 requires sm_80 or higher
- param_types.append(torch.bfloat16)
-param_types_lean = [torch.bfloat16]
-
# TODO: Enable config support in other backend(s) -- currently only the CK
# backend is capable of supporting it.
@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
@@ -121,7 +122,6 @@ def test_gqa_mla_thd():
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
- window_size=config.window_size,
pad_between_seqs=True,
)
if FusedAttnBackend["CK"] not in fused_attn_backends:
@@ -147,7 +147,6 @@ def test_dot_product_mem_calc():
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
- window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
@@ -204,12 +203,12 @@ def test_dot_product_attention(
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
+ # Get backends
is_training = True
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
- window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
@@ -220,7 +219,6 @@ def test_dot_product_attention(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
- window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
@@ -332,6 +330,7 @@ def test_dot_product_attention(
share_cu_seqlens_ref,
)
+ # Compare results
logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
@@ -369,23 +368,102 @@ def test_dpa_checkpoint(dtype, model_configs, model):
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False, False)
+model_configs_softmax = {
+ # test: ModelConfig(b, sq, hq, dqk)
+ "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
+ "softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"),
+ "softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"),
+ "softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"),
+ "softmax_2_1": ModelConfig(
+ 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
+ ),
+ "softmax_2_2": ModelConfig(
+ 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
+ ),
+ "softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"),
+ "softmax_3_1": ModelConfig(
+ 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one"
+ ),
+ "softmax_3_2": ModelConfig(
+ 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable"
+ ),
+ "softmax_4_0": ModelConfig(
+ 2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal"
+ ),
+ "softmax_4_1": ModelConfig(
+ 2,
+ 2048,
+ 64,
+ 64,
+ num_gqa_groups=8,
+ window_size=(128, 0),
+ attn_mask_type="causal",
+ softmax_type="off-by-one",
+ ),
+ "softmax_4_2": ModelConfig(
+ 2,
+ 2048,
+ 64,
+ 64,
+ num_gqa_groups=8,
+ window_size=(128, 0),
+ attn_mask_type="causal",
+ softmax_type="learnable",
+ ),
+ "softmax_5_0": ModelConfig(
+ 2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal"
+ ),
+ "softmax_5_1": ModelConfig(
+ 2,
+ 2048,
+ 64,
+ 64,
+ num_gqa_groups=8,
+ window_size=(128, 0),
+ attn_mask_type="padding_causal",
+ softmax_type="off-by-one",
+ ),
+ "softmax_5_2": ModelConfig(
+ 2,
+ 2048,
+ 64,
+ 64,
+ num_gqa_groups=8,
+ window_size=(128, 0),
+ attn_mask_type="padding_causal",
+ softmax_type="learnable",
+ ),
+}
+
+
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("model_configs", [model_configs_softmax])
+@pytest.mark.parametrize("model", model_configs_softmax.keys())
+def test_dpa_softmax(dtype, model_configs, model):
+ """Test DotProductAttention module with different softmax types"""
+ test_dot_product_attention(
+ dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False, False
+ )
+
+
model_configs_mla = {
- # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
- "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
- "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
- "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
- "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
+ # test: ModelConfig(b, sq, hq, dqk)
+ "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
+ "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),
+ "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),
+ "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),
"mla_2_1": ModelConfig(
1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
- ), # cross, 1
+ ),
"mla_2_2": ModelConfig(
1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
- ), # cross, 1
- "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
- "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
- "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
- "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
- "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference
+ ),
+ "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),
+ "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),
+ "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),
+ "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),
+ "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),
}
@@ -399,7 +477,7 @@ def test_dpa_mla(dtype, model_configs, model):
model_configs_mask = {
- # test: b, h, hg, d, sq, skv, p, mask, bias
+ # test: ModelConfig(b, sq, hq, dqk)
"mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"),
"mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"),
"mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"),
@@ -454,18 +532,16 @@ def test_dpa_mask(dtype, model_configs, model):
model_configs_bias = {
- # test: b, h, hg, d, sq, skv, p, mask, bias
+ # test: ModelConfig(b, sq, hq, dqk)
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"),
"bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"),
- "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped
- "bias_1_5": ModelConfig(
- 2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"
- ), # skipped
+ "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"),
+ "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"),
"bias_2_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias"
- ), # skipped
+ ),
"bias_2_1": ModelConfig(
2,
128,
@@ -474,10 +550,10 @@ def test_dpa_mask(dtype, model_configs, model):
max_seqlen_kv=256,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
- ), # skipped
+ ),
"bias_2_2": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias"
- ), # skipped
+ ),
"bias_2_3": ModelConfig(
2,
2048,
@@ -486,13 +562,11 @@ def test_dpa_mask(dtype, model_configs, model):
max_seqlen_kv=4096,
attn_mask_type="padding",
attn_bias_type="post_scale_bias",
- ), # skipped
- "bias_2_4": ModelConfig(
- 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"
- ), # skipped
+ ),
+ "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"),
"bias_2_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi"
- ), # skipped
+ ),
"bias_3_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
@@ -510,14 +584,14 @@ def test_dpa_mask(dtype, model_configs, model):
max_seqlen_kv=4096,
attn_mask_type="causal",
attn_bias_type="post_scale_bias",
- ), # skipped
+ ),
"bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"),
"bias_3_5": ModelConfig(
2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi"
- ), # skipped
+ ),
"bias_4_0": ModelConfig(
4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
- ), # skipped
+ ),
"bias_4_1": ModelConfig(
2,
128,
@@ -526,10 +600,10 @@ def test_dpa_mask(dtype, model_configs, model):
max_seqlen_kv=256,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
- ), # skipped
+ ),
"bias_4_2": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias"
- ), # skipped
+ ),
"bias_4_3": ModelConfig(
2,
2048,
@@ -538,10 +612,10 @@ def test_dpa_mask(dtype, model_configs, model):
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="post_scale_bias",
- ), # skipped
+ ),
"bias_4_4": ModelConfig(
4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi"
- ), # skipped
+ ),
"bias_4_5": ModelConfig(
2,
2048,
@@ -550,7 +624,7 @@ def test_dpa_mask(dtype, model_configs, model):
max_seqlen_kv=4096,
attn_mask_type="padding_causal",
attn_bias_type="alibi",
- ), # skipped
+ ),
}
@@ -564,7 +638,7 @@ def test_dpa_bias(dtype, model_configs, model):
model_configs_bias_shapes = {
- # test: b, h, hg, d, sq, skv, p,
+ # test: ModelConfig(b, sq, hq, dqk)
"bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"),
"bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
@@ -602,7 +676,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
model_configs_swa = {
- # test: b, h, hg, d, sq, skv, p, mask, bias
+ # test: ModelConfig(b, sq, hq, dqk)
"swa_1_1": ModelConfig(2, 2048, 16, 64),
"swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4),
"swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096),
@@ -642,7 +716,7 @@ def test_dpa_sliding_window(dtype, model_configs, model):
model_configs_alibi_slopes = {
- # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
+ # test: ModelConfig(b, sq, hq, dqk)
"alibi_1_0": ModelConfig(
2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla"
),
@@ -696,7 +770,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
model_configs_layout = {
- # test: b, h, hg, d, sq, skv, p, mask, bias
+ # test: ModelConfig(b, sq, hq, dqk)
"layout_0_0": ModelConfig(2, 128, 16, 64),
"layout_0_1": ModelConfig(
2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
@@ -744,7 +818,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
- # test: b, h, hg, d, sq, skv, p, mask, bias
+ # test: ModelConfig(b, sq, hq, dqk)
"layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"),
"layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"),
"layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"),
@@ -824,7 +898,6 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
- window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
if FusedAttnBackend["CK"] not in fused_attn_backends:
@@ -848,7 +921,6 @@ def test_dpa_qkv_layout_thd_mqa_gqa(dtype, model_configs, model, qkv_layout, pad
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
- window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
if FusedAttnBackend["CK"] not in fused_attn_backends:
@@ -882,7 +954,6 @@ def _run_dot_product_attention(
share_cu_seqlens_ref: bool = False,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
-
# Set RNG and environment varables
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -1145,9 +1216,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
tp_group=None,
layer_number=1,
attention_type=config.attn_type,
+ softmax_type=config.softmax_type,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()
+ if is_training and config.softmax_type != "vanilla":
+ block.softmax_offset.requires_grad = True
cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None
@@ -1188,12 +1262,14 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
)
if is_training:
out.backward(d_out)
-
+ d_softmax_offset = None
+ if is_training and config.softmax_type != "vanilla":
+ d_softmax_offset = block.softmax_offset.grad
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
if is_training:
- return out, (q.grad, k.grad, v.grad)
+ return out, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
- return out, (None, None, None)
+ return out, (None, None, None, d_softmax_offset)
if backend == "FusedAttention":
if qkv_format == "thd" and pad_between_seqs:
out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
@@ -1222,18 +1298,18 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
[v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
)
if is_training:
- return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
+ return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset)
else:
- return out_orig, (None, None, None)
+ return out_orig, (None, None, None, d_softmax_offset)
else:
if is_training:
- return out, (q.grad, k.grad, v.grad)
+ return out, (q.grad, k.grad, v.grad, d_softmax_offset)
else:
- return out, (None, None, None)
+ return out, (None, None, None, d_softmax_offset)
model_configs_te_layer = {
- # test: b, h, hg, d, sq, skv, p, mask, bias
+ # test: ModelConfig(b, sq, hq, dqk)
"te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"),
"te_1_1": ModelConfig(
4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias"
@@ -1598,6 +1674,7 @@ def _run_transformer_layer(
model_configs_fp8_extra_state = {
+ # test: ModelConfig(b, sq, hq, dqk)
"large": ModelConfig(2, 128, 4, 128, num_layers=1),
}
@@ -1607,7 +1684,8 @@ def _run_transformer_layer(
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_sanity_attention_extra_state(model, dtype):
+def test_dpa_fp8_extra_state(model, dtype):
+ """Test DotProductAttention module in FP8 with checkpointing"""
config = model_configs_fp8_extra_state[model]
# Test backend availability
is_training = True
@@ -1621,9 +1699,9 @@ def test_sanity_attention_extra_state(model, dtype):
if not fused_attn_supported and not flash_attn_supported:
pytest.skip("No attention backend available.")
- outputs = _run_attention_extra_state(dtype, config, checkpoint=False)
- outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True)
- outputs_checkpoint_v1_6 = _run_attention_extra_state(
+ outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False)
+ outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True)
+ outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state(
dtype, config, mimic_v1_6=True, checkpoint=True
)
@@ -1645,7 +1723,8 @@ def test_sanity_attention_extra_state(model, dtype):
)
-def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
+def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False):
+ """Run DotProductAttention module in FP8 with checkpointing"""
steps = 10
path = "checkpoint.pt"
fp8_enabled = True
@@ -1742,7 +1821,7 @@ def get_model(dtype, config):
model_configs_fp8_vs_f16 = {
- # test: b, h, hg, d, sq, skv, p, mask, bias
+ # test: ModelConfig(b, sq, hq, dqk)
"fp8_9": ModelConfig(2, 2048, 16, 128),
"fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12),
"fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4),
@@ -1800,22 +1879,44 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("RoPE", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
-def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training):
+@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
+def test_mha_fp8_vs_f16(
+ dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode
+):
+ """Test MultiHeadAttention module in FP8"""
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
config = model_configs_fp8_vs_f16[model]
# Test backend availability
+ if scaling_mode == "delayed":
+ fp8_recipe = recipe.DelayedScaling(
+ margin=0,
+ fp8_format=recipe.Format.HYBRID,
+ amax_history_len=1,
+ amax_compute_algo="most_recent",
+ fp8_dpa=True,
+ fp8_mha=True,
+ )
+ elif scaling_mode == "current":
+ fp8_recipe = recipe.Float8CurrentScaling(
+ fp8_format=recipe.Format.HYBRID,
+ fp8_dpa=True,
+ fp8_mha=True,
+ )
+ fp8_meta = {}
+ fp8_meta["recipe"] = fp8_recipe
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout=qkv_format.replace("hd", "h3d"),
+ fp8=True,
+ fp8_meta=fp8_meta,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
- # Skip if only unfused backend is supported
- if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
- pytest.skip("Less than two backends to compare.")
+ if flash_attn_supported + fused_attn_supported < 1:
+ pytest.skip("No FP8 attention backend available.")
if not fp8_dpa_bwd:
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
@@ -1833,7 +1934,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
- dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
+ dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -1841,20 +1942,21 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
- dtype, config, True, qkv_format, input_layernorm, RoPE, is_training
+ dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
- dtype, config, False, qkv_format, input_layernorm, RoPE, is_training
+ dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
)
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.15
- logging.debug("========== {:^25s} ==========".format("forward output"))
if flash_attn_supported:
- _error(
+ logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
+ logging.debug("========== {:^25s} ==========".format("forward output"))
+ compare_and_assert(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
@@ -1862,8 +1964,11 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
atol,
rtol,
rmse_tol,
+ True,
)
- _error(
+ logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
+ logging.debug("========== {:^25s} ==========".format("forward output"))
+ compare_and_assert(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
@@ -1871,12 +1976,13 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
atol,
rtol,
rmse_tol,
+ True,
)
if is_training:
for i in range(len(param_names[:1])):
logging.debug("========== {:^25s} ==========".format(param_names[i]))
- _error(
+ compare_and_assert(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
@@ -1884,10 +1990,14 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd,
atol,
rtol,
rmse_tol,
+ True,
)
-def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training):
+def _run_mha_fp8_vs_f16(
+ dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
+):
+ """Run MultiHeadAttention module in FP8"""
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1896,15 +2006,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
- fp8_recipe = recipe.DelayedScaling(
- margin=0,
- fp8_format=recipe.Format.HYBRID,
- amax_history_len=1,
- amax_compute_algo="most_recent",
- fp8_dpa=fp8_mha,
- fp8_mha=fp8_mha,
- )
-
with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe):
rotary_pos_emb = None
if RoPE:
@@ -2014,7 +2115,9 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@pytest.mark.parametrize("fp8_dpa_bwd", [True, False])
@pytest.mark.parametrize("is_training", [True, False])
-def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
+@pytest.mark.parametrize("scaling_mode", ["delayed", "current"])
+def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode):
+ """Test DotProductAttention module in FP8"""
config = model_configs_fp8_vs_f16[model]
# TODO(cyang): think of another way to verify dropout results
@@ -2029,16 +2132,33 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
+ os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1"
# Test backend availability
+ if scaling_mode == "delayed":
+ fp8_recipe = recipe.DelayedScaling(
+ margin=0,
+ fp8_format=recipe.Format.HYBRID,
+ amax_history_len=1,
+ amax_compute_algo="most_recent",
+ fp8_dpa=True,
+ )
+ elif scaling_mode == "current":
+ fp8_recipe = recipe.Float8CurrentScaling(
+ fp8_format=recipe.Format.HYBRID,
+ fp8_dpa=True,
+ )
+ fp8_meta = {}
+ fp8_meta["recipe"] = fp8_recipe
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=torch.float8_e4m3fn,
qkv_layout=qkv_layout,
+ fp8=True,
+ fp8_meta=fp8_meta,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
- # Skip if only unfused backend is supported
if flash_attn_supported + fused_attn_supported < 1:
pytest.skip("No FP8 attention backend available.")
if not fp8_dpa_bwd:
@@ -2058,33 +2178,45 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
- logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
+ logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
- dtype, config, True, qkv_layout, is_training
+ dtype, config, True, qkv_layout, is_training, fp8_recipe
+ )
+
+ if unfused_attn_supported:
+ os.environ["NVTE_FLASH_ATTN"] = "0"
+ os.environ["NVTE_FUSED_ATTN"] = "0"
+ _attention_backends["backend_selection_requires_update"] = True
+ logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
+ unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
+ dtype, config, True, qkv_layout, is_training, fp8_recipe
)
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
- logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
+ logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
- dtype, config, True, qkv_layout, is_training
+ dtype, config, True, qkv_layout, is_training, fp8_recipe
)
+ os.environ["NVTE_FLASH_ATTN"] = "0"
+ os.environ["NVTE_FUSED_ATTN"] = "1"
if config.dropout_p == 0.0:
# test cuDNN FP8 dropout: need an FP16/BF16 reference on Blackwell
- logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
+ logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
- dtype, config, False, qkv_layout, is_training
+ dtype, config, False, qkv_layout, is_training, fp8_recipe
)
atol = 5e-1
rtol = 5e-2
rmse_tol = 0.11
bwd_names = ["dq", "dk", "dv"]
- logging.debug("========== {:^25s} ==========".format("forward output"))
if flash_attn_supported:
- _error(
+ logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
+ logging.debug("========== {:^25s} ==========".format("forward output"))
+ compare_and_assert(
flash_attn_fwd_fp8,
fused_attn_fwd_f16,
"flash_attn_fwd_fp8",
@@ -2092,14 +2224,43 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol,
rtol,
rmse_tol,
+ True,
)
+ if unfused_attn_supported:
+ logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:"))
+ logging.debug("========== {:^25s} ==========".format("forward output"))
+ compare_and_assert(
+ unfused_attn_fwd_fp8,
+ fused_attn_fwd_f16,
+ "unfused_attn_fwd_fp8",
+ "fused_attn_fwd_f16",
+ atol,
+ rtol,
+ rmse_tol,
+ True,
+ )
+ if is_training:
+ for i, _ in enumerate(fused_attn_bwd_f16):
+ logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
+ compare_and_assert(
+ unfused_attn_bwd_fp8[i],
+ fused_attn_bwd_f16[i],
+ f"unfused_attn_bwd_fp8[{i}]",
+ f"fused_attn_bwd_f16[{i}]",
+ atol,
+ rtol,
+ rmse_tol,
+ True,
+ )
if config.dropout_p != 0.0:
# test cuDNN FP8 dropout
assert torch.all(
fused_attn_fwd_fp8 == 1
), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s."
else:
- _error(
+ logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
+ logging.debug("========== {:^25s} ==========".format("forward output"))
+ compare_and_assert(
fused_attn_fwd_fp8,
fused_attn_fwd_f16,
"fused_attn_fwd_fp8",
@@ -2107,11 +2268,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol,
rtol,
rmse_tol,
+ True,
)
if is_training:
for i, _ in enumerate(fused_attn_bwd_f16):
logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
- _error(
+ compare_and_assert(
fused_attn_bwd_fp8[i],
fused_attn_bwd_f16[i],
f"fused_attn_bwd_fp8[{i}]",
@@ -2119,11 +2281,13 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training):
atol,
rtol,
rmse_tol,
+ True,
)
+ os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0"
-def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
-
+def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe):
+ """Run DotProductAttention module in FP8"""
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -2132,14 +2296,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
- fp8_recipe = recipe.DelayedScaling(
- margin=0,
- fp8_format=recipe.Format.HYBRID,
- amax_history_len=1,
- amax_compute_algo="most_recent",
- fp8_dpa=fp8_dpa,
- )
-
qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
with fp8_model_init(enabled=fp8_dpa):
dpa = DotProductAttention(
@@ -2246,6 +2402,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
+ fp8_output=fp8_dpa,
)
if is_training:
out.backward(out_grad)
@@ -2256,7 +2413,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
model_configs_fp8 = {
- # test: b, h, hg, d, sq, skv, p, mask, bias
+ # test: ModelConfig(b, sq, hq, dqk)
"fp8_1": ModelConfig(1, 512, 1, 64),
"fp8_2": ModelConfig(4, 512, 16, 64),
"fp8_3": ModelConfig(1, 2048, 1, 128),
@@ -2312,7 +2469,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
atol = 5e-1
rtol = 5e-1
rmse_tol = 0.13
- _error(
+ compare_and_assert(
fused_attn_fwd_fp8,
unfused_attn_fwd_f16,
"fused_attn_fwd_fp8",
@@ -2320,8 +2477,9 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
atol,
rtol,
rmse_tol,
+ True,
)
- _error(
+ compare_and_assert(
fused_attn_bwd_fp8,
unfused_attn_bwd_f16,
"fused_attn_bwd_fp8",
@@ -2329,6 +2487,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
atol,
rtol,
rmse_tol,
+ True,
)
diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py
index ece5a37de..c9fc86878 100644
--- a/tests/pytorch/attention/test_attention_with_cp.py
+++ b/tests/pytorch/attention/test_attention_with_cp.py
@@ -8,6 +8,7 @@
import subprocess
import sys
import pathlib
+import logging
import pytest
import torch
@@ -16,19 +17,27 @@
get_device_compute_capability,
get_cudnn_version,
)
+from transformer_engine.common.recipe import (
+ DelayedScaling,
+ Float8CurrentScaling,
+)
from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig, get_available_attention_backends
+pytest_logging_level = logging.getLevelName(logging.root.level)
+
# Initialize RNG state
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
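+# When True, run only a reduced, essential subset of model configs, dtypes and qkv formats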
+test_essential = True
+
model_configs_flash_attn = {
- # test: b, h, hg, d, sq, skv, p, mask, bias
+ # test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA
@@ -63,18 +72,31 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
return args
+dtypes = ["bf16", "fp16"]
+qkv_formats = ["bshd", "sbhd", "thd"]
+cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
+if test_essential:
+ configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"]
+ model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs}
+ dtypes = ["bf16"]
+ qkv_formats = ["sbhd", "thd"]
+
+
@pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
-@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
+@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
-@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
+@pytest.mark.parametrize("qkv_format", qkv_formats)
+@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
config = model_configs_flash_attn[model]
+ config.context_parallel = True
+ config.cp_comm_type = cp_comm_type
+
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
@@ -104,12 +126,13 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
qkv_format=qkv_format,
kernel_backend="FlashAttention",
cp_comm_type=cp_comm_type,
+ log_level=pytest_logging_level,
),
check=True,
)
model_configs_fused_attn = {
- # test: b, h, hg, d, sq, skv, p, mask, bias
+ # test: ModelConfig(b, sq, hq, dqk)
"cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA
"cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA
"cp_1_2": ModelConfig(
@@ -140,17 +163,42 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64
), # MLA
"cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA
+ "cp_4_0": ModelConfig(
+ 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla"
+ ), # GQA
+ "cp_4_1": ModelConfig(
+ 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one"
+ ), # GQA
+ "cp_4_2": ModelConfig(
+ 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable"
+ ), # GQA
}
+dtypes = ["bf16", "fp16", "fp8"]
+qkv_formats = ["bshd", "sbhd", "thd"]
+cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"]
+if test_essential:
+ configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"]
+ model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs}
+ dtypes = ["bf16", "fp8"]
+ qkv_formats = ["sbhd", "thd"]
+
+
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
-@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
+@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
-@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
-@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
-@pytest.mark.parametrize("fp8_mha", [False, True])
-def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):
+@pytest.mark.parametrize("qkv_format", qkv_formats)
+@pytest.mark.parametrize("cp_comm_type", cp_comm_types)
+@pytest.mark.parametrize("fp8_bwd", [True, False])
+@pytest.mark.parametrize("fp8_mha", [True, False])
+@pytest.mark.parametrize("fp8_dpa", [True, False])
+@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"])
+@pytest.mark.parametrize("f16_O", [True, False])
+def test_cp_with_fused_attention(
+ dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O
+):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
@@ -161,10 +209,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
if (not IS_HIP_EXTENSION) and dtype == "fp8" and get_device_compute_capability() < (9, 0):
pytest.skip("FP8 attention is only supported on sm90+!")
+ if dtype == "fp8" and not fp8_dpa and fp8_mha:
+ pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!")
+ if dtype != "fp8" and fp8_bwd:
+ pytest.skip("Only fp8 works with fp8_bwd=True!")
if IS_HIP_EXTENSION and dtype == "fp8":
pytest.skip("FP8 attention has not been supported on ROCm yet!")
config = model_configs_fused_attn[model]
+ config.context_parallel = True
+ config.cp_comm_type = cp_comm_type
+
if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
@@ -192,19 +247,57 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)
- if dtype != "fp8" and fp8_mha:
- pytest.skip("Only fp8 works with fp8_mha=True!")
+ if dtype != "fp8" and (fp8_mha or fp8_dpa):
+ pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!")
+ if dtype == "fp8" and not (fp8_mha or fp8_dpa):
+ pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!")
+ if dtype != "fp8" and scaling_mode is not None:
+ pytest.skip("Only fp8 works with scaling_mode != None!")
+ if dtype == "fp8" and scaling_mode is None:
+ pytest.skip("fp8 only works with scaling_mode != None!")
+ if (
+ dtype == "fp8"
+ and scaling_mode == "current"
+ and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"]
+ ):
+ pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!")
+ if f16_O and (dtype != "fp8" or scaling_mode != "current"):
+ pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!")
if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently only support KV P2P!")
if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
pytest.skip("MLA CP currently does not support FP8 attention!")
+ if dtype == "fp8" and config.softmax_type != "vanilla":
+ pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!")
+ if config.softmax_type != "vanilla" and cp_comm_type != "a2a":
+ pytest.skip(
+ "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
+ )
+ if config.softmax_type != "vanilla" and qkv_format == "thd":
+ pytest.skip(
+ "CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
+ )
+
dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
+ fp8_meta = {}
+ fp8_meta["recipe"] = None
+ fp8_meta["local_recipes"] = []
+ fp8 = dtype == "fp8" and (fp8_dpa or fp8_mha)
+ if fp8 and scaling_mode == "delayed":
+ fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True)
+ fp8_meta["local_recipes"] = [DelayedScaling(fp8_dpa=True)]
+ if fp8 and scaling_mode == "current":
+ fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True)
+ fp8_meta["local_recipes"] = [
+ Float8CurrentScaling(fp8_dpa=True),
+ DelayedScaling(fp8_dpa=True),
+ ]
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
- qkv_dtype=dtypes[dtype],
+ qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn,
qkv_layout="_".join([qkv_format] * 3),
- window_size=config.window_size,
- context_parallel=True,
+ fp8=fp8,
+ fp8_meta=fp8_meta,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
@@ -218,7 +311,12 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
qkv_format=qkv_format,
kernel_backend="FusedAttention",
cp_comm_type=cp_comm_type,
+ fp8_bwd=fp8_bwd,
+ fp8_dpa=fp8_dpa,
fp8_mha=fp8_mha,
+ scaling_mode=scaling_mode,
+ f16_O=f16_O,
+ log_level=pytest_logging_level,
),
check=True,
)
diff --git a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py
index af71866f3..8bd8906c5 100644
--- a/tests/pytorch/attention/test_kv_cache.py
+++ b/tests/pytorch/attention/test_kv_cache.py
@@ -483,7 +483,6 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
- window_size=config.window_size,
pad_between_seqs=False,
is_training=False,
fp8=is_fp8,
diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py
index 8a201b72d..f68ddce7f 100644
--- a/tests/pytorch/distributed/run_numerics.py
+++ b/tests/pytorch/distributed/run_numerics.py
@@ -11,6 +11,7 @@
import os
import sys
from functools import wraps
+import math
import transformer_engine.pytorch as te
import torch
@@ -23,10 +24,15 @@
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
+ NVFP4BlockScaling,
Format,
Recipe,
+ QParams,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
+from transformer_engine.pytorch.distributed import gather_along_first_dim
from run_layer_with_overlap import _compare_tensors
if IS_HIP_EXTENSION:
@@ -53,6 +59,14 @@
)
+def nvfp4_vanilla():
+ nvfp4_recipe = NVFP4BlockScaling()
+ nvfp4_recipe.fp4_quant_fwd_inp = QParams()
+ nvfp4_recipe.fp4_quant_fwd_weight = QParams()
+ nvfp4_recipe.fp4_quant_bwd_grad = QParams()
+ return nvfp4_recipe
+
+
# Quantization recipe setup
def quantization_recipe() -> Recipe:
if QUANTIZATION == "fp8":
@@ -65,6 +79,8 @@ def quantization_recipe() -> Recipe:
return Float8CurrentScaling()
if QUANTIZATION == "fp8_block_scaling":
return Float8BlockScaling()
+ if QUANTIZATION == "nvfp4":
+ return nvfp4_vanilla()
return te.fp8.get_default_fp8_recipe()
@@ -113,10 +129,14 @@ def main(argv=None, namespace=None):
# Quantization scheme
QUANTIZATION = args.quantization
global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
- if QUANTIZATION in ("fp8", "mxfp8"):
+ if QUANTIZATION in ("fp8", "mxfp8", "nvfp4"):
SEQ_LEN = 32
BATCH_SIZE = 32
HIDDEN_SIZE = 128
+ # For fp8 block scaling, the block size is 128, so for low-precision TP to
+ # work, the input tensor dimensions must be divisible by 128 to be eligible
+ # for the low-precision all-gather when needed.
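+ # Illustrative example (values assumed from the settings just below): with
+ # SEQ_LEN = 128 and BATCH_SIZE = 128, the flattened activation has
+ # 128 * 128 = 16384 rows, which is divisible by 128 as required above.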
elif QUANTIZATION == "fp8_block_scaling":
SEQ_LEN = 128
BATCH_SIZE = 128
@@ -124,6 +144,7 @@ def main(argv=None, namespace=None):
test_dict = [
test_quantizer,
+ test_quantized_all_gather,
test_linear,
test_layernorm,
test_layernorm_linear,
@@ -193,6 +214,9 @@ def _get_tolerances(dtype):
# row parallel & sequence parallel, because we do the all_gather in backward pass
if QUANTIZATION == "fp8_cs":
return {"rtol": 0.4, "atol": 0.25}
+ elif QUANTIZATION == "nvfp4":
+ # TODO(zhongboz): investigate why the tolerance is so large
+ return {"rtol": 0.125, "atol": 0.12}
elif QUANTIZATION is not None:
return {"rtol": 0.125, "atol": 0.0625}
@@ -348,24 +372,36 @@ def _alloc_main_grad(model_single_node, model_distributed):
###############################################
# Quantizer #
###############################################
-def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
+def _construct_quantizer(quantizer_class, low_precision_dtype, device, tp_group, tp_size):
"""
quantizer is the reference quantizer on a single GPU.
quantizer_dist is the distributed quantizer to be tested on multiple GPUs.
"""
if quantizer_class == Float8CurrentScalingQuantizer:
quantizer_dist = quantizer_class(
- fp8_dtype=fp8_dtype,
+ fp8_dtype=low_precision_dtype,
device=device,
with_amax_reduction=True,
amax_reduction_group=tp_group,
)
quantizer = quantizer_class(
- fp8_dtype=fp8_dtype,
+ fp8_dtype=low_precision_dtype,
device=device,
with_amax_reduction=False,
)
return quantizer, quantizer_dist
+ elif quantizer_class == NVFP4Quantizer:
+ quantizer_dist = quantizer_class(
+ fp4_dtype=low_precision_dtype,
+ with_amax_reduction=True,
+ amax_reduction_group=tp_group,
+ )
+ quantizer = quantizer_class(
+ fp4_dtype=low_precision_dtype,
+ with_amax_reduction=False,
+ amax_reduction_group=None,
+ )
+ return quantizer, quantizer_dist
else:
raise ValueError(f"Unsupported quantizer class: {quantizer_class}")
@@ -436,6 +472,194 @@ def test_quantizer():
_test_quantizer(input_dtype, fp8_dtype)
+############################################
+# Quantized All-Gather #
+############################################
+
+
+def _ref_zero_padding_scale_inv(scale_inv, unpadded_shape):
+ """
+ Zero-pad the scale_inv tensor.
+ scale_inv already has the padded shape, but its padding region is not zero-filled
+ (it may contain arbitrary bits from torch.empty).
+ unpadded_shape is the original shape before padding.
+ """
+ dim0, dim1 = scale_inv.shape
+ unpadded_dim0, unpadded_dim1 = unpadded_shape
+ pad_dim0 = (128 - unpadded_dim0 % 128) % 128
+ pad_dim1 = (4 - unpadded_dim1 % 4) % 4
+ new_dim0 = unpadded_dim0 + pad_dim0
+ new_dim1 = unpadded_dim1 + pad_dim1
+
+ assert dim0 == new_dim0
+ assert dim1 == new_dim1
+
+ # return input if no padding is needed
+ if pad_dim0 == 0 and pad_dim1 == 0:
+ return scale_inv
+
+ # unpad first to remove random bits from torch empty
+ scale_inv = scale_inv[:unpadded_dim0, :unpadded_dim1].contiguous()
+ # using torch padding
+ new_scale_inv = torch.nn.functional.pad(
+ scale_inv, (0, pad_dim1, 0, pad_dim0), mode="constant", value=0
+ )
+
+ assert new_scale_inv.shape == (new_dim0, new_dim1)
+
+ return new_scale_inv
+
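+# Worked example for _ref_zero_padding_scale_inv (illustrative values): an
+# unpadded scale_inv of shape (100, 6) is padded up to (128, 8), i.e. dim0 to
+# the next multiple of 128 and dim1 to the next multiple of 4, with the padded
+# region filled with zeros.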
+
+def _get_unpadded_scale_inv_shape(input_shape, quantizer_cls, columnwise):
+ """
+ Calculate the unpadded shape of the scale_inv tensor.
+ """
+ M = math.prod(input_shape[:-1])
+ K = input_shape[-1]
+
+ if quantizer_cls == NVFP4Quantizer:
+ if columnwise:
+ outer = K
+ inner = math.ceil(M / NVFP4_BLOCK_SCALING_SIZE)
+ return (outer, inner)
+ else:
+ outer = M
+ inner = math.ceil(K / NVFP4_BLOCK_SCALING_SIZE)
+ return (outer, inner)
+ else:
+ raise ValueError(f"Unsupported quantizer class: {quantizer_cls}")
+
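+# Worked example for _get_unpadded_scale_inv_shape (illustrative, assuming
+# NVFP4_BLOCK_SCALING_SIZE == 16): an input of shape (M, K) = (4096, 128) gives a
+# rowwise scale_inv of shape (4096, 8) and a columnwise scale_inv of shape
+# (128, 256) before zero padding.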
+
+@run_distributed_test()
+def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls):
+ """Test the quantizer under distributed settings.
+
+ Args:
+ input_dtype (torch.dtype): The data type of the input.
+ low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8.
+ """
+
+ M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2
+
+ # high precision input
+ x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype)
+ # Optionally set one element of the input to a very large value (one that does not
+ # live in rank 0's shard after the split) to exercise the amax reduction.
+ # Disabled by default:
+ # x_hp_cpu[M - 1, N - 1] = 1e4
+
+ # get the unpadded shapes
+ unpadded_rowwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, False)
+ unpadded_columnwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, True)
+
+ # rank 0 takes the full copy and quantize with GPU 0 for verification
+ if WORLD_RANK == 0:
+ x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda")
+ x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK]
+
+ # Create quantizers
+ quantizer, quantizer_dist = _construct_quantizer(
+ quantizer_cls, low_precision_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE
+ )
+
+ # quantize the entire input
+ if WORLD_RANK == 0:
+ x_low_precision_single = quantizer(x_hp_rank0)
+
+ # run all-gather with a quantizer as input for quantized all-gather
+ x_low_precision_total, _ = gather_along_first_dim(
+ x_hp_local_rank, NCCL_WORLD, async_op=False, quantizer=quantizer_dist
+ )
+
+ # check the outputs
+ if WORLD_RANK == 0:
+ # assert all data and scale_inv are the same
+ torch.testing.assert_close(
+ x_low_precision_single._rowwise_data,
+ x_low_precision_total._rowwise_data,
+ rtol=0.0,
+ atol=0.0,
+ )
+ # check the rowwise scale without any padding
+ unpad_dim0, unpad_dim1 = unpadded_rowwise_scale_inv_shape
+ unpadded_rowwise_scale_inv_ref = x_low_precision_single._rowwise_scale_inv[
+ :unpad_dim0, :unpad_dim1
+ ]
+ unpadded_rowwise_scale_inv = x_low_precision_total._rowwise_scale_inv[
+ :unpad_dim0, :unpad_dim1
+ ]
+ torch.testing.assert_close(
+ unpadded_rowwise_scale_inv_ref,
+ unpadded_rowwise_scale_inv,
+ rtol=0.0,
+ atol=0.0,
+ )
+ torch.testing.assert_close(
+ _ref_zero_padding_scale_inv(
+ x_low_precision_single._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape
+ ),
+ _ref_zero_padding_scale_inv(
+ x_low_precision_total._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape
+ ),
+ rtol=0.0,
+ atol=0.0,
+ )
+ torch.testing.assert_close(
+ x_low_precision_single._columnwise_data,
+ x_low_precision_total._columnwise_data,
+ rtol=0.0,
+ atol=0.0,
+ )
+ unpad_dim0, unpad_dim1 = unpadded_columnwise_scale_inv_shape
+ unpadded_columnwise_scale_inv_ref = x_low_precision_single._columnwise_scale_inv[
+ :unpad_dim0, :unpad_dim1
+ ]
+ unpadded_columnwise_scale_inv = x_low_precision_total._columnwise_scale_inv[
+ :unpad_dim0, :unpad_dim1
+ ]
+ torch.testing.assert_close(
+ unpadded_columnwise_scale_inv_ref,
+ unpadded_columnwise_scale_inv,
+ rtol=0.0,
+ atol=0.0,
+ )
+ torch.testing.assert_close(
+ _ref_zero_padding_scale_inv(
+ x_low_precision_single._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape
+ ),
+ _ref_zero_padding_scale_inv(
+ x_low_precision_total._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape
+ ),
+ rtol=0.0,
+ atol=0.0,
+ )
+
+
+def test_quantized_all_gather():
+ """
+ Run quantized all-gather tests with various configurations.
+ """
+ # skip this test for other quantization schemes
+ is_nvfp4 = QUANTIZATION == "nvfp4"
+ # add other recipes for testing if needed
+ if not is_nvfp4:
+ return
+
+ input_dtypes = [torch.bfloat16]
+ fp4_dtype = [tex.DType.kFloat4E2M1]
+ fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]
+ quantizer_cls_nvfp4 = [NVFP4Quantizer]
+ # add FP8 quantizers if needed
+ quantizer_cls_fp8 = []
+
+ low_precision_dtypes = fp4_dtype if is_nvfp4 else fp8_dtype
+ quantizer_cls_list = quantizer_cls_nvfp4 if is_nvfp4 else quantizer_cls_fp8
+
+ for quantizer_cls in quantizer_cls_list:
+ for input_dtype in input_dtypes:
+ for low_precision_dtype in low_precision_dtypes:
+ _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls)
+
+
############################################
# Linear #
############################################
@@ -536,10 +760,11 @@ def test_linear():
{"init_method": _constant},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
- {"params_dtype": torch.float16},
+ {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"delay_wgrad_compute": True},
{"save_original_input": True},
]
+
for kwargs in kwargs_list:
if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
continue
@@ -715,11 +940,12 @@ def test_layernorm_linear():
{"init_method": _constant},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
- {"params_dtype": torch.float16},
+ {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"zero_centered_gamma": False},
{"return_layernorm_output": True},
{"delay_wgrad_compute": True},
]
+
for kwargs in kwargs_list:
for parallel_mode in ["column"]:
for sequence_parallel in [False, True]:
@@ -821,7 +1047,7 @@ def test_layernorm_mlp():
{"normalization": "RMSNorm"},
{"zero_centered_gamma": True},
{"bias": False},
- {"params_dtype": torch.float16},
+ {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"activation": "relu"},
{"fuse_wgrad_accumulation": True},
{"return_bias": True},
@@ -919,7 +1145,7 @@ def test_transformer_layer():
{"fuse_qkv_params": True, "fuse_wgrad_accumulation": True},
{"qkv_weight_interleaved": False},
{"bias": False},
- {"params_dtype": torch.float16},
+ {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16},
{"fuse_qkv_params": True},
{"activation": "relu"},
]
diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py
new file mode 100644
index 000000000..b1722b79a
--- /dev/null
+++ b/tests/pytorch/distributed/run_numerics_exact.py
@@ -0,0 +1,718 @@
+#!/usr/bin/python3
+
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import argparse
+import datetime
+import os
+import sys
+from functools import wraps
+import math
+
+import transformer_engine.pytorch as te
+import torch
+from torch import nn
+import torch.distributed as dist
+import transformer_engine_torch as tex
+from transformer_engine.common.recipe import (
+ NVFP4BlockScaling,
+ Format,
+ Recipe,
+ QParams,
+)
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
+from run_layer_with_overlap import _compare_tensors
+
+
+BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE = 128, 256, 128
+WORLD_RANK, WORLD_SIZE = None, None
+NCCL_WORLD = None
+LOSS_FN = nn.MSELoss()
+QUANTIZATION = None
+
+
+def nvfp4_rht_and_2d_quantization():
+ nvfp4_recipe = NVFP4BlockScaling()
+ nvfp4_recipe.fp4_quant_fwd_inp = QParams(
+ random_hadamard_transform=True, fp4_2d_quantization=False
+ )
+ nvfp4_recipe.fp4_quant_fwd_weight = QParams(
+ random_hadamard_transform=False, fp4_2d_quantization=True
+ )
+ nvfp4_recipe.fp4_quant_bwd_grad = QParams(
+ random_hadamard_transform=True, fp4_2d_quantization=False
+ )
+ return nvfp4_recipe
+
+
+# Quantization recipe setup
+def quantization_recipe() -> Recipe:
+ if QUANTIZATION == "nvfp4":
+ return nvfp4_rht_and_2d_quantization()
+ raise ValueError(f"Unsupported quantization: {QUANTIZATION}")
+
+
+def setup_environment_for_reference():
+ if QUANTIZATION == "nvfp4":
+ os.environ["QAT_PARAMS"] = "9003"
+ else:
+ raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}")
+
+
+def cleanup_environment():
+ if "QAT_PARAMS" in os.environ:
+ del os.environ["QAT_PARAMS"]
+
+
+def main(argv=None, namespace=None):
+ global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION, BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE
+
+ WORLD_RANK = int(os.getenv("RANK", "0"))
+ WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+ LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
+ LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
+
+ assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node
+ assert LOCAL_SIZE <= torch.cuda.device_count()
+ dist_init_kwargs = {
+ "backend": "nccl",
+ "rank": WORLD_RANK,
+ "world_size": WORLD_SIZE,
+ "timeout": datetime.timedelta(seconds=30),
+ }
+ dist_init_kwargs["init_method"] = "env://"
+ dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
+ assert dist.is_nccl_available()
+ torch.cuda.set_device(LOCAL_RANK)
+ dist.init_process_group(**dist_init_kwargs)
+
+ NCCL_WORLD = dist.new_group(backend="nccl")
+
+ WORLD_SIZE = dist.get_world_size()
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--quantization", type=str, default=None)
+ parser.add_argument("--batch-size", type=int, default=32)
+ parser.add_argument("--hidden-size", type=int, default=128)
+ parser.add_argument("--out-size", type=int, default=128)
+ args = parser.parse_args(argv, namespace)
+
+ # Quantization scheme
+ QUANTIZATION = args.quantization
+ BATCH_SIZE = args.batch_size
+ HIDDEN_SIZE = args.hidden_size
+ OUT_SIZE = args.out_size
+
+ test_dict = [
+ test_linear,
+ test_layernorm_linear,
+ ]
+
+ for test in test_dict:
+ test()
+ dist.destroy_process_group()
+ return 0
+
+
+def run_distributed_test(test_name=None):
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ name = test_name if test_name is not None else func.__name__
+
+ dist_print(f"Starting test {name} with args {args} and {kwargs}")
+ torch.cuda.set_device(WORLD_RANK)
+ torch.manual_seed(12345)
+ torch.cuda.manual_seed(12345)
+ func(*args, **kwargs)
+
+ dist.barrier()
+ dist_print(f"Passed test {name}")
+
+ return wrapper
+
+ return decorator
+
+
+def dist_print(msg, src=None, end="\n", error=False):
+ stream = sys.stderr if error else sys.stdout
+ if WORLD_RANK == (0 if src is None else src):
+ stream.write(f"[rank{WORLD_RANK}] {msg}{end}")
+
+
+############################################
+# Linear #
+############################################
+class TestDistributedLinearBase:
+ @staticmethod
+ def _prepare_data(
+ batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32
+ ):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
+ w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda")
+ bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None
+ gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda")
+
+ return x, w, bias, gradient
+
+ @staticmethod
+ def _shard_tensor(x, world_size, axis):
+ split_size = x.size()[axis] // world_size
+ split_tensor = torch.split(x, split_size, axis)
+ out = []
+ for tensor in split_tensor:
+ out.append(tensor.detach().clone().requires_grad_(x.requires_grad))
+ return out
+
+ @staticmethod
+ def _gather_tensor(local, world_size, tp_group, concat_dim):
+ out_list = [torch.zeros_like(local) for _ in range(world_size)]
+ torch.distributed.all_gather(out_list, local, tp_group)
+ return torch.cat(out_list, dim=concat_dim)
+
+ @staticmethod
+ def _all_reduce_tensor(local, world_size, tp_group):
+ if world_size == 1:
+ return local
+ torch.distributed.all_reduce(local, group=tp_group, async_op=False)
+ return local
+
+ @staticmethod
+ def _get_sum_abs_error(a, b):
+ return torch.sum(torch.abs(a - b))
+
+ @staticmethod
+ def _get_mean_abs_relative_error(a, b):
+ error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b))
+ return torch.mean(error)
+
+ @classmethod
+ def run_linear_preprocess_parallel(
+ cls,
+ x,
+ w,
+ bias,
+ gradient,
+ parallel_mode=None,
+ sequence_parallel=False,
+ tp_size=1,
+ rank=0,
+ ):
+ if tp_size > 1:
+ if parallel_mode == "column":
+ # split w in N dim, which should be axis 0
+ w = cls._shard_tensor(w, tp_size, 0)[rank]
+ bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None
+ # split gradient in N dim, which should be axis 1
+ gradient = cls._shard_tensor(gradient, tp_size, 1)[rank]
+ if sequence_parallel:
+ # split x in M dim, which should be axis 0
+ x = cls._shard_tensor(x, tp_size, 0)[rank]
+ # row parallel: split x and w in the K dim, which should be axis 1
+ if parallel_mode == "row":
+ # split x in K dim, which should be axis 1
+ x = cls._shard_tensor(x, tp_size, 1)[rank]
+ # split w in K dim, which should be axis 1
+ w = cls._shard_tensor(w, tp_size, 1)[rank]
+ if sequence_parallel:
+ # split gradient in M dim, which should be axis 0
+ gradient = cls._shard_tensor(gradient, tp_size, 0)[rank]
+ return x, w, bias, gradient
+
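+ # Illustrative shard shapes for the split/gather round trip above and below
+ # (assumed example, tp_size=2): with x of shape (M, K), w of shape (N, K) and
+ # gradient of shape (M, N), column parallel shards w and gradient to
+ # (N/2, K) and (M, N/2) per rank, while row parallel shards x and w to
+ # (M, K/2) and (N, K/2) per rank.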
+ @classmethod
+ def run_linear_postprocess_parallel(
+ cls,
+ y_q,
+ dgrad,
+ wgrad,
+ bgrad,
+ parallel_mode,
+ sequence_parallel,
+ tp_size,
+ tp_group,
+ ):
+ if tp_size > 1:
+ if parallel_mode == "column":
+ # gather y_q in N dim, which should be axis 1
+ y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1)
+ # gather wgrad in N dim, which should be axis 0
+ wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0)
+ # gather bgrad in N dim, which should be axis 0
+ bgrad = (
+ cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None
+ )
+ if sequence_parallel:
+ # gather dgrad in M dim, which should be axis 0
+ dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0)
+ if parallel_mode == "row":
+ # gather dgrad in K dim, which should be axis 1
+ dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1)
+ # gather wgrad in K dim, which should be axis 1
+ wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1)
+ if sequence_parallel:
+ # gather y_q in M dim, which should be axis 0
+ y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0)
+ # we need to sum bias gradient when using TP + SP
+ bgrad = (
+ cls._all_reduce_tensor(bgrad, tp_size, tp_group)
+ if bgrad is not None
+ else None
+ )
+
+ return y_q, dgrad, wgrad, bgrad
+
+ @classmethod
+ def run_linear_one_step(
+ cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False
+ ):
+ # reset gradients
+ layer.zero_grad()
+ x.grad = None
+
+ # Forward pass
+ if isinstance(layer, te.Linear):
+ # Transformer Engine Linear supports is_first_microbatch
+ y_q = layer.forward(x, is_first_microbatch=is_first_microbatch)
+ else:
+ # the default torch.nn.Linear
+ y_q = layer(x)
+
+ # Backward pass
+ y_q.backward(gradient)
+
+ # Collect gradients
+ dgrad = x.grad
+ bgrad = (
+ layer._parameters["bias"].grad
+ if layer._parameters.get("bias", None) is not None
+ else None
+ )
+ assert "weight" in layer._parameters
+ if fuse_wgrad_accumulation:
+ wgrad = layer._parameters["weight"].main_grad
+ assert layer._parameters["weight"].grad is None
+ else:
+ wgrad = layer._parameters["weight"].grad
+
+ return y_q, dgrad, wgrad, bgrad
+
+ @classmethod
+ def run_linear_multiple_steps(
+ cls,
+ layer,
+ x,
+ gradient,
+ run_num_steps,
+ enable_weight_cache,
+ fuse_wgrad_accumulation=False,
+ ):
+ """
+ Run multiple steps of linear layer and collect results.
+ """
+
+ y_q_list, dgrad_list, wgrad_list = [], [], []
+ bgrad_list = [] if layer._parameters.get("bias", None) is not None else None
+
+ for i in range(run_num_steps):
+ x_i = (x + i).clone().detach().requires_grad_(True)
+ # run_linear_one_step
+ y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(
+ layer,
+ x_i,
+ gradient,
+ is_first_microbatch=(i == 0) if enable_weight_cache else None,
+ fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+ )
+
+ # Collect results
+ y_q_list.append(y_q.detach().clone())
+ dgrad_list.append(dgrad.detach().clone())
+ wgrad_list.append(wgrad.detach().clone())
+ if bgrad_list is not None and bgrad is not None:
+ bgrad_list.append(bgrad.detach().clone())
+
+ # Stack the results
+ return (
+ torch.stack(y_q_list),
+ torch.stack(dgrad_list),
+ torch.stack(wgrad_list),
+ torch.stack(bgrad_list) if bgrad_list is not None else None,
+ )
+
+ @classmethod
+ def run_linear(
+ cls,
+ x,
+ w,
+ bias,
+ gradient,
+ parallel_mode=None,
+ sequence_parallel=False,
+ tp_group=None,
+ tp_size=1,
+ rank=0,
+ run_num_steps=1,
+ enable_weight_cache=False,
+ fuse_wgrad_accumulation=False,
+ ):
+ """
+ If model parallel, split the inputs for the given rank and return the gathered outputs and
+ gradients so that they can be compared with the reference single-GPU run.
+ """
+ # clone inputs and move to current device
+ # w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
+ x = x.clone().detach().requires_grad_(True).to("cuda")
+ w = w.clone().detach().to("cuda")
+ gradient = gradient.clone().detach().to("cuda")
+ bias = bias.clone().detach().to("cuda") if bias is not None else None
+ in_features = x.shape[1]
+ out_features = w.shape[0]
+
+ # If Model parallel: split inputs for a given rank
+ x, w, bias, gradient = cls.run_linear_preprocess_parallel(
+ x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
+ )
+
+ # set data types
+ params_dtype = x.dtype
+
+ # Create linear layer and copy weights
+ layer = te.Linear(
+ in_features,
+ out_features,
+ bias=bias is not None,
+ params_dtype=params_dtype,
+ parallel_mode=parallel_mode,
+ sequence_parallel=sequence_parallel,
+ tp_group=tp_group,
+ tp_size=tp_size,
+ fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+ )
+
+ layer = layer.to("cuda")
+
+ with torch.no_grad():
+ layer.weight.copy_(w)
+ if bias is not None:
+ layer.bias.copy_(bias)
+
+ if fuse_wgrad_accumulation:
+ assert (
+ run_num_steps > 1
+ ), "Fused weight gradient accumulation requires run_num_steps > 1"
+ layer.weight.main_grad = torch.zeros_like(layer.weight)
+
+ # Run one step or multiple steps
+ if run_num_steps == 1:
+ y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
+ else:
+ y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps(
+ layer,
+ x,
+ gradient,
+ run_num_steps,
+ enable_weight_cache,
+ fuse_wgrad_accumulation,
+ )
+
+ # If Model parallel: gather output and gradients from all ranks
+ y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
+ y_q,
+ dgrad,
+ wgrad,
+ bgrad,
+ parallel_mode,
+ sequence_parallel,
+ tp_size,
+ tp_group,
+ )
+
+ return y_q, dgrad, wgrad, bgrad
+
+
+@run_distributed_test()
+def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
+ """Test the linear layer with specified parallel mode and sequence parallelization.
+
+ Args:
+ parallel_mode (str): 'row' or 'column' parallelism.
+ sequence_parallel (bool): Enable sequence parallelism if True.
+ kwargs (dict): Additional arguments for the linear layer.
+
+ QUANTIZATION options: nvfp4 (the experimental NVFP4 implementation serves as the reference)
+ """
+ params_dtype = torch.bfloat16
+ use_bias = kwargs.get("bias", True)
+ fuse_wgrad_accumulation = kwargs.get("fuse_wgrad_accumulation", False)
+ seed = torch.initial_seed()
+ recipe = quantization_recipe()
+
+ # turn on weight quantization cache when fusing wgrad accumulation
+ enable_weight_cache = fuse_wgrad_accumulation
+ run_num_steps = 1 if not fuse_wgrad_accumulation else 5
+
+ x, w, bias, gradient = TestDistributedLinearBase._prepare_data(
+ BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype
+ )
+
+ # run the recipe under test
+ with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+ y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear(
+ x,
+ w,
+ bias,
+ gradient,
+ parallel_mode=parallel_mode,
+ sequence_parallel=sequence_parallel,
+ tp_group=NCCL_WORLD,
+ tp_size=WORLD_SIZE,
+ rank=WORLD_RANK,
+ fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+ run_num_steps=1 if not fuse_wgrad_accumulation else 5,
+ enable_weight_cache=fuse_wgrad_accumulation,
+ )
+
+ # run the reference
+ setup_environment_for_reference()
+ with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+ y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear(
+ x,
+ w,
+ bias,
+ gradient,
+ parallel_mode=parallel_mode,
+ sequence_parallel=sequence_parallel,
+ tp_group=NCCL_WORLD,
+ tp_size=WORLD_SIZE,
+ rank=WORLD_RANK,
+ fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+ run_num_steps=run_num_steps,
+ enable_weight_cache=enable_weight_cache,
+ )
+ # Clean up env
+ cleanup_environment()
+
+ # compare results, zero tolerance
+ if WORLD_RANK == 0:
+ torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch")
+ torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch")
+ torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch")
+ if bgrad is not None and bgrad_ref is not None:
+ torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch")
+
+
+def test_linear():
+ """Run linear layer tests with various configurations."""
+ kwargs_list = [
+ {"bias": False},
+ ]
+
+ for kwargs in kwargs_list:
+ if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
+ continue
+ for parallel_mode in ["column", "row"]:
+ for sequence_parallel in [False, True]:
+ _test_linear(parallel_mode, sequence_parallel, **kwargs)
+
+
+############################################
+# LayerNormLinear #
+############################################
+class TestDistributedLayerNormLinearBase(TestDistributedLinearBase):
+
+ @classmethod
+ def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None):
+ # reset gradients
+ layer.zero_grad()
+ x.grad = None
+
+ # Forward pass
+ y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch)
+
+ # Backward pass
+ y_q.backward(gradient)
+
+ # Collect gradients
+ dgrad = x.grad
+
+ parameters = layer._parameters
+
+ # bias and weight gradients
+ bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None
+ assert "weight" in parameters
+ wgrad = parameters["weight"].grad
+
+ return y_q, ln_out, dgrad, wgrad, bgrad
+
+ @classmethod
+ def run_linear_multiple_steps(
+ cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False
+ ):
+ # no multi-step test case for LayerNormLinear yet
+ raise NotImplementedError("LayerNormLinear tests do not support multiple steps yet")
+
+ @classmethod
+ def run_layernorm_linear(
+ cls,
+ x,
+ w,
+ bias,
+ gradient,
+ parallel_mode=None,
+ sequence_parallel=False,
+ tp_group=None,
+ tp_size=1,
+ rank=0,
+ run_num_steps=1,
+ enable_weight_cache=False,
+ LayerNormLinearClass=te.LayerNormLinear,
+ normalization="LayerNorm",
+ ):
+ """
+ If model parallel, split the inputs for the given rank and return the gathered outputs and
+ gradients so that they can be compared with the reference single-GPU run.
+ """
+ # clone inputs and move to current device
+ # w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
+ x = x.clone().detach().requires_grad_(True).to("cuda")
+ w = w.clone().detach().to("cuda")
+ gradient = gradient.clone().detach().to("cuda")
+ bias = bias.clone().detach().to("cuda") if bias is not None else None
+ in_features = x.shape[1]
+ out_features = w.shape[0]
+
+ # If Model parallel: split inputs for a given rank
+ x, w, bias, gradient = cls.run_linear_preprocess_parallel(
+ x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
+ )
+
+ # set data types
+ params_dtype = x.dtype
+
+ # Create linear layer and copy weights
+ layer = LayerNormLinearClass(
+ in_features,
+ out_features,
+ bias=bias is not None,
+ params_dtype=params_dtype,
+ parallel_mode=parallel_mode,
+ sequence_parallel=sequence_parallel,
+ tp_group=tp_group,
+ tp_size=tp_size,
+ normalization=normalization,
+ return_layernorm_output=True,
+ )
+
+ layer = layer.to("cuda")
+
+ # Copy weights and bias into the layer parameters
+ with torch.no_grad():
+ layer.weight.copy_(w)
+ if bias is not None:
+ layer.bias.copy_(bias)
+
+ # Run one step
+ y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
+
+ # If Model parallel: gather output and gradients from all ranks
+ y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
+ y_q,
+ dgrad,
+ wgrad,
+ bgrad,
+ parallel_mode,
+ sequence_parallel,
+ tp_size,
+ tp_group,
+ )
+
+ return y_q, ln_out, dgrad, wgrad, bgrad
+
+
+@run_distributed_test()
+def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
+ """Test the linear layer with specified parallel mode and sequence parallelization.
+
+ Args:
+ parallel_mode (str): 'column' parallelism.
+ sequence_parallel (bool): Enable sequence parallelism if True.
+ kwargs (dict): Additional arguments for the linear layer.
+ """
+ params_dtype = torch.bfloat16
+ use_bias = kwargs.get("bias", True)
+ seed = torch.initial_seed()
+ recipe = quantization_recipe()
+
+ # run multiple steps currently not supported for LayerNormLinear
+ run_num_steps = 1
+
+ x, w, bias, gradient = TestDistributedLayerNormLinearBase._prepare_data(
+ BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype
+ )
+
+ # run the recipe under test
+ with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+ y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear(
+ x,
+ w,
+ bias,
+ gradient,
+ parallel_mode=parallel_mode,
+ sequence_parallel=sequence_parallel,
+ tp_group=NCCL_WORLD,
+ tp_size=WORLD_SIZE,
+ rank=WORLD_RANK,
+ run_num_steps=run_num_steps,
+ enable_weight_cache=False,
+ )
+
+ # run the reference
+ setup_environment_for_reference()
+ with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
+ y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = (
+ TestDistributedLayerNormLinearBase.run_layernorm_linear(
+ x,
+ w,
+ bias,
+ gradient,
+ parallel_mode=parallel_mode,
+ sequence_parallel=sequence_parallel,
+ tp_group=NCCL_WORLD,
+ tp_size=WORLD_SIZE,
+ rank=WORLD_RANK,
+ run_num_steps=run_num_steps,
+ enable_weight_cache=False,
+ )
+ )
+ # Clean up env
+ cleanup_environment()
+
+ # compare results, zero tolerance
+ if WORLD_RANK == 0:
+ torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch")
+ torch.testing.assert_close(ln_out, ln_out_ref, atol=0, rtol=0, msg="LN output mismatch")
+ torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch")
+ torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch")
+ if bgrad is not None and bgrad_ref is not None:
+ torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch")
+
+
+def test_layernorm_linear():
+ kwargs_list = [
+ {"bias": False},
+ ]
+
+ for kwargs in kwargs_list:
+ for parallel_mode in ["column"]:
+ for sequence_parallel in [False, True]:
+ _test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs)
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py
index 6dc17b126..83acd6df7 100644
--- a/tests/pytorch/distributed/test_fusible_ops.py
+++ b/tests/pytorch/distributed/test_fusible_ops.py
@@ -29,6 +29,7 @@
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.utils import is_bf16_compatible, is_fp8_fnuz
import transformer_engine_torch as tex
@@ -36,17 +37,20 @@
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
-from utils import dtype_tols, make_recipe
+from utils import dtype_tols, make_recipe, quantization_tols
# Check what quantization schemes are supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
quantization_list: list[Optional[str]] = [None]
if fp8_available:
quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
quantization_list.append("mxfp8")
+if nvfp4_available:
+ quantization_list.append("nvfp4")
@functools.cache
@@ -117,6 +121,14 @@ def make_reference_and_test_tensors(
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+ elif quantization == "nvfp4":
+ test = NVFP4Quantizer(
+ with_rht=False,
+ with_post_rht_amax=False,
+ with_2d_quantization=False,
+ stochastic_rounding=False,
+ with_random_sign_mask=False,
+ )(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
@@ -439,7 +451,7 @@ def _test_basic_linear(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -611,7 +623,7 @@ def _test_linear(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -625,6 +637,204 @@ def _test_linear(
torch.testing.assert_close(db_test, db_ref, **tols)
+def _test_mlp(
+ *,
+ bias: bool = True,
+ hidden_size: int = 32,
+ local_batch_size: int = 32,
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = "cuda",
+ quantization: Optional[str] = None,
+ quantized_weight: bool = False,
+ sequence_parallel: bool = False,
+) -> None:
+ """2-layer MLP
+
+ MLP includes GELU activation in order to test op fusions. Model
+ performs warmup steps in order to test inter-step logic.
+
+ """
+
+ # Skip invalid configurations
+ quantized_compute = quantization is not None
+ if not quantized_compute and quantized_weight:
+ return
+
+ # Distributed process group
+ process_group = world_group()
+ rank = torch.distributed.get_rank(process_group)
+ world_size = torch.distributed.get_world_size(process_group)
+
+ # Tensor dimensions
+ mlp_size = hidden_size * world_size
+ batch_size = local_batch_size
+ if sequence_parallel:
+ batch_size *= world_size
+ in_shape = (batch_size, hidden_size)
+
+ # Random data
+ reset_rng()
+ x_ref, x_test = make_reference_and_test_tensors(
+ in_shape,
+ quantization=quantization,
+ test_dtype=dtype,
+ test_device=device,
+ )
+ w1_ref, w1_test = make_reference_and_test_tensors(
+ (mlp_size, hidden_size),
+ quantization=quantization,
+ test_dtype=dtype,
+ test_device=device,
+ )
+ b1_ref, b1_test = None, None
+ w2_ref, w2_test = make_reference_and_test_tensors(
+ (hidden_size, mlp_size),
+ quantization=quantization,
+ test_dtype=dtype,
+ test_device=device,
+ )
+ b2_ref, b2_test = None, None
+ if bias:
+ b1_ref, b1_test = make_reference_and_test_tensors(
+ (mlp_size,),
+ test_dtype=dtype,
+ test_device=device,
+ )
+ b2_ref, b2_test = make_reference_and_test_tensors(
+ (world_size, hidden_size),
+ test_dtype=dtype,
+ test_device=device,
+ )
+ dy_ref, dy_test = make_reference_and_test_tensors(
+ in_shape,
+ quantization=quantization,
+ test_dtype=dtype,
+ test_device=device,
+ requires_grad=False,
+ )
+
+ # Plain PyTorch implementation
+ y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
+ y_ref = torch.nn.functional.linear(y_ref, w1_ref)
+ if bias:
+ y_ref += b1_ref
+ y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
+ y_ref = torch.nn.functional.linear(y_ref, w2_ref)
+ if bias:
+ y_ref += b2_ref.sum(dim=0)
+ y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
+ y_ref.backward(dy_ref)
+
+ # Convert to distributed tensors
+ with torch.no_grad():
+ local_mlp_size = mlp_size // world_size
+ local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size)
+ dx_ref = x_ref.grad
+ dw1_ref = w1_ref.grad[local_mlp_slice, :]
+ w1_ref = w1_ref[local_mlp_slice, :]
+ w1_test = w1_test[local_mlp_slice, :]
+ dw2_ref = w2_ref.grad[:, local_mlp_slice]
+ w2_ref = w2_ref[:, local_mlp_slice]
+ w2_test = w2_test[:, local_mlp_slice]
+ if bias:
+ db1_ref = b1_ref.grad[local_mlp_slice]
+ b1_ref = b1_ref[local_mlp_slice]
+ b1_test = b1_test[local_mlp_slice]
+ db2_ref = b2_ref.grad[rank, :]
+ b2_ref = b2_ref[rank, :]
+ b2_test = b2_test[rank, :]
+ else:
+ db1_ref = None
+ db2_ref = None
+ if sequence_parallel:
+ local_batch_slice = slice(
+ rank * local_batch_size,
+ (rank + 1) * local_batch_size,
+ )
+ x_ref = x_ref[local_batch_slice, ...]
+ dx_ref = dx_ref[local_batch_slice, ...]
+ x_test = x_test[local_batch_slice, ...].clone()
+ y_ref = y_ref[local_batch_slice, ...]
+ dy_ref = dy_ref[local_batch_slice, ...]
+ dy_test = dy_test[local_batch_slice, ...].clone()
+ x_test.requires_grad_()
+
+ # Implementation with fusible operation
+ recipe = make_recipe(quantization)
+ with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+ model = te_ops.Sequential(
+ te_ops.GELU(),
+ te_ops.Linear(
+ hidden_size,
+ mlp_size,
+ bias=bias,
+ device=device,
+ dtype=dtype,
+ tensor_parallel_mode="column",
+ tensor_parallel_group=process_group,
+ sequence_parallel=sequence_parallel,
+ ),
+ te_ops.GELU(),
+ te_ops.Linear(
+ mlp_size,
+ hidden_size,
+ bias=bias,
+ device=device,
+ dtype=dtype,
+ tensor_parallel_mode="row",
+ tensor_parallel_group=process_group,
+ sequence_parallel=sequence_parallel,
+ ),
+ te_ops.GELU(),
+ )
+ with torch.no_grad():
+ model[1].weight.copy_(w1_test)
+ model[3].weight.copy_(w2_test)
+ if bias:
+ model[1].bias.copy_(b1_test)
+ model[3].bias.copy_(b2_test)
+ del w1_test, w2_test, b1_test, b2_test
+
+ # Warmup steps
+ for _ in range(3):
+ with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+ y_test = model(x_test)
+ y_test.backward(dy_test)
+ x_test.grad = None
+ model[1].weight.grad = None
+ model[3].weight.grad = None
+ if bias:
+ model[1].bias.grad = None
+ model[3].bias.grad = None
+
+ # Forward and backward step
+ with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+ y_test = model(x_test)
+ y_test.backward(dy_test)
+
+ # Expected numerical error
+ tols = dtype_tols(dtype)
+ if dtype == torch.float32:
+ tols = dtype_tols(torch.float16) # TF32 GEMM
+ if quantized_compute:
+ tols = quantization_tols(quantization)
+
+ # Check results
+ y_test = y_test.to(dtype=torch.float64, device="cpu")
+ dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+ dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
+ dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu")
+ torch.testing.assert_close(y_test, y_ref, **tols)
+ torch.testing.assert_close(dx_test, dx_ref, **tols)
+ torch.testing.assert_close(dw1_test, dw1_ref, **tols)
+ torch.testing.assert_close(dw2_test, dw2_ref, **tols)
+ if bias:
+ db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
+ db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu")
+ torch.testing.assert_close(db1_test, db1_ref, **tols)
+ torch.testing.assert_close(db2_test, db2_ref, **tols)
+
+
def _test_fp8_scale_update(
*,
amax_history_len: int = 31,
@@ -791,16 +1001,31 @@ def run_parallel_tests() -> None:
for config in itertools.product(
quantization_list,
("column", "row"),
+ (False, True),
):
if rank == 0:
print(f"Running _test_linear with {config=}")
- quantization, tensor_parallel_mode = config
+ quantization, tensor_parallel_mode, sequence_parallel = config
dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
_test_linear(
bias=True, # bias=False is tested in _test_basic_linear
dtype=dtype,
quantization=quantization,
tensor_parallel_mode=tensor_parallel_mode,
+ sequence_parallel=sequence_parallel,
+ )
+
+ # MLP
+ for config in itertools.product(quantization_list, (False, True)):
+ if rank == 0:
+ print(f"Running _test_mlp with {config=}")
+ quantization, sequence_parallel = config
+ dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
+ _test_mlp(
+ bias=True, # bias=False is tested in _test_basic_linear
+ dtype=dtype,
+ quantization=quantization,
+ sequence_parallel=sequence_parallel,
)
# FP8 scale update
diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py
index 1ff5aff99..d09c530cb 100644
--- a/tests/pytorch/distributed/test_numerics.py
+++ b/tests/pytorch/distributed/test_numerics.py
@@ -31,6 +31,7 @@
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
+nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
@@ -51,7 +52,9 @@ def _run_test(quantization):
all_boolean = [True, False]
-@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
+@pytest.mark.parametrize(
+ "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
+)
def test_distributed(quantization):
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
@@ -61,4 +64,6 @@ def test_distributed(quantization):
pytest.skip(reason_for_no_mxfp8)
if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
+ if quantization == "nvfp4" and not nvfp4_available:
+ pytest.skip(reason_for_no_nvfp4)
_run_test(quantization)
diff --git a/tests/pytorch/distributed/test_numerics_exact.py b/tests/pytorch/distributed/test_numerics_exact.py
new file mode 100644
index 000000000..890a24804
--- /dev/null
+++ b/tests/pytorch/distributed/test_numerics_exact.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import os
+import subprocess
+from pathlib import Path
+
+import pytest
+import torch
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+
+"""
+ Distributed numerics tests
+
+ These numerical tests use zero tolerance for absolute confidence in numerics.
+ In the case of NVFP4, the experimental NVFP4 quantization has been matched bitwise
+ against the native silicon. For distributed test cases, we do the same by comparing
+ BF16 all-gather results with the low-precision all-gather results at the layer level.
+"""
+
+
+if torch.cuda.device_count() < 2:
+ pytest.skip("Distributed training needs at least 2 GPUs.")
+
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
+ FP8GlobalStateManager.is_fp8_block_scaling_available()
+)
+nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
+
+TEST_ROOT = Path(__file__).parent.resolve()
+NUM_PROCS: int = min(4, torch.cuda.device_count())
+LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
+
+
+def _run_test(quantization, batch_size, hidden_size, out_size):
+ test_path = TEST_ROOT / "run_numerics_exact.py"
+ test_cmd = LAUNCH_CMD + [str(test_path)]
+
+ test_cmd += ["--quantization", quantization]
+ test_cmd += ["--batch-size", str(batch_size)]
+ test_cmd += ["--hidden-size", str(hidden_size)]
+ test_cmd += ["--out-size", str(out_size)]
+
+ result = subprocess.run(test_cmd, env=os.environ, check=False)
+ assert result.returncode == 0
+
+
+all_boolean = [True, False]
+
+
+@pytest.mark.parametrize("quantization", ["nvfp4"])
+@pytest.mark.parametrize(
+ "batch_size, hidden_size, out_size",
+ [
+ (64, 128, 128),
+ (128, 128, 128),
+ (128, 256, 256),
+ (512, 1024, 768),
+ (512, 256, 1024),
+ (2048, 2048, 2048),
+ ],
+)
+def test_distributed(quantization, batch_size, hidden_size, out_size):
+ if quantization == "nvfp4" and not nvfp4_available:
+ pytest.skip(reason_for_no_nvfp4)
+
+ _run_test(quantization, batch_size, hidden_size, out_size)
diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
new file mode 100644
index 000000000..a9e73aaf9
--- /dev/null
+++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
@@ -0,0 +1,243 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import pytest
+import torch
+import transformer_engine as te
+import transformer_engine_torch as tex
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.constants import TE_DType
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef
+from transformer_engine.pytorch.experimental import utils
+
+
+recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
+
+
+def check_nvfp4_gemm_versus_reference(
+ x_dtype: torch.dtype,
+ w_dtype: torch.dtype,
+ out_dtype: torch.dtype,
+ M: int,
+ K: int,
+ N: int,
+ accumulate: bool,
+ *,
+ x_columnwise: bool = False,
+ w_columnwise: bool = False,
+):
+ te_dtype = tex.DType.kFloat4E2M1
+
+ # Setup device and random seed
+ device = "cuda"
+ seed = 0
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+ # Input tensors
+ x_shape = (K, M) if x_columnwise else (M, K)
+ w_shape = (K, N) if w_columnwise else (N, K)
+ x = torch.randn(x_shape, dtype=x_dtype, device=device)
+ w = torch.randn(w_shape, dtype=w_dtype, device=device)
+
+ # Setup out tensor if accumulate is True
+ if accumulate:
+ out = torch.randn((M, N), dtype=out_dtype, device=device)
+ else:
+ out = None
+
+ # Native TE NVFP4 quantization
+ x_quantizer = NVFP4Quantizer(
+ fp4_dtype=te_dtype,
+ rowwise=True,
+ columnwise=True,
+ with_amax_reduction=False,
+ amax_reduction_group=None,
+ with_rht=False,
+ with_post_rht_amax=False,
+ )
+ w_quantizer = NVFP4Quantizer(
+ fp4_dtype=te_dtype,
+ rowwise=True,
+ columnwise=True,
+ with_amax_reduction=False,
+ amax_reduction_group=None,
+ with_rht=False,
+ with_post_rht_amax=False,
+ )
+
+ # Quantize x and w
+ x_nvfp4_native = x_quantizer.make_empty(
+ x_shape, dtype=x_dtype, device=device, requires_grad=False
+ )
+ x_nvfp4_native = x_quantizer.update_quantized(x, x_nvfp4_native)
+ w_nvfp4_native = w_quantizer.make_empty(
+ w_shape, dtype=w_dtype, device=device, requires_grad=False
+ )
+ w_nvfp4_native = w_quantizer.update_quantized(w, w_nvfp4_native)
+
+ # Extract quantized data from native NVFP4Tensors
+ qx_data = (
+ x_nvfp4_native._columnwise_data.view(dtype=torch.uint8)
+ if x_columnwise
+ else x_nvfp4_native._rowwise_data.view(dtype=torch.uint8)
+ )
+ qw_data = (
+ w_nvfp4_native._columnwise_data.view(dtype=torch.uint8)
+ if w_columnwise
+ else w_nvfp4_native._rowwise_data.view(dtype=torch.uint8)
+ )
+ sx_native = (
+ x_nvfp4_native._columnwise_scale_inv if x_columnwise else x_nvfp4_native._rowwise_scale_inv
+ )
+ sw_native = (
+ w_nvfp4_native._columnwise_scale_inv if w_columnwise else w_nvfp4_native._rowwise_scale_inv
+ )
+
+ # Trim quantized data to match the actual tensor dimensions (remove padding)
+ qx_data = qx_data[:M, :]
+ qw_data = qw_data[:N, :]
+
+ # NVFP4 uses 16-element blocks, trim scales to remove padding
+ block_length = 16 # NVFP4 uses 16-element blocks
+ expected_sx_cols = expected_sw_cols = K // block_length
+ # Trim the scales to remove padding
+ sx_trimmed = sx_native[:M, :expected_sx_cols]
+ sw_trimmed = sw_native[:N, :expected_sw_cols]
+
+ # Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn
+ # for the reference GEMM to work correctly
+ sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn)
+ sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn)
+
+ # Create reference quantizer for reference GEMM
+ ref_quantizer = NVFP4QuantizerRef(
+ dtype=utils.Fp4Formats.E2M1,
+ rowwise=True,
+ columnwise=True,
+ pow_2_scales=False,
+ eps=0.0,
+ quant_tile_shape=(1, 16),
+ )
+
+ # Create reference quantized tensors needed by reference GEMM
+ x_nvfp4_ref = ref_quantizer.quantize(x)
+ w_nvfp4_ref = ref_quantizer.quantize(w)
+
+ # Reference GEMM using quantizer's qgemm method
+ y_ref = ref_quantizer.qgemm(
+ qx=qx_data,
+ qw=qw_data,
+ m_params=None, # MMParams not used in reference
+ out_dtype=out_dtype,
+ sx=sx_trimmed,
+ sw=sw_trimmed,
+ bias=None, # No bias for this test
+ out=out.clone() if accumulate else None,
+ accumulate=accumulate,
+ gemm_type=None, # GEMMType not used in reference
+ qresult_x=x_nvfp4_ref,
+ qresult_w=w_nvfp4_ref,
+ )
+
+ # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM)
+ # Allocate cuBLAS workspace
+ workspace = torch.empty(4, dtype=torch.uint8, device=device)
+
+ transa = not w_columnwise
+ transb = x_columnwise
+ out_quantizer = None
+ bias = None
+ bias_dtype = TE_DType[torch.bfloat16]
+ use_gelu = False
+ gelu_input = None
+ use_grad = False
+ use_split_accumulator = False
+
+ # Native cuBLAS GEMM
+ # return type is out, bias_grad, gelu_input, extra_output
+ # We are just capturing out.
+ y_native = tex.generic_gemm(
+ w_nvfp4_native,
+ transa,
+ x_nvfp4_native,
+ transb,
+ out.clone() if accumulate else None,
+ out_quantizer,
+ TE_DType[out_dtype],
+ bias,
+ bias_dtype,
+ use_gelu,
+ gelu_input,
+ use_grad,
+ workspace,
+ workspace.shape[0],
+ accumulate,
+ use_split_accumulator,
+ )[0]
+
+ # just in case of accumulation, make sure y_ref and y_native are not the same tensor
+ assert y_ref is not y_native, "y_ref and y_native should not be the same tensor"
+ # Replace NaNs with zeros because torch.testing.assert_close does not treat NaNs as equal by default
+ assert not torch.isnan(y_ref.float()).all(), "All elements are nan"
+ y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref)
+ y_native = torch.where(y_native.isnan(), torch.zeros_like(y_native), y_native)
+
+ # Compare results with some tolerance
+ torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3)
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+ "M, K, N",
+ [
+ (128, 128, 128),
+ (256, 128, 256),
+ (256, 256, 256),
+ (256, 1024, 256),
+ (1024, 1024, 1024),
+ (4096, 512, 3072),
+ (112, 128, 96),
+ (304, 640, 304),
+ (1008, 3072, 992),
+ (256, 64, 256),
+ (128, 128, 112),
+ ],
+)
+@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
+@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
+@pytest.mark.parametrize(
+ "is_x_columnwise, is_w_columnwise",
+ [
+ (False, False), # Only rowwise x rowwise is supported by reference GEMM
+ # Note: Reference GEMM expects inputs as (M,K) x (N,K) with rowwise quantization
+ # Columnwise layouts are not supported by the reference implementation
+ ],
+ ids=["rowxrow"],
+)
+def test_nvfp4_gemm_versus_reference(
+ M: int,
+ K: int,
+ N: int,
+ x_dtype: torch.dtype,
+ w_dtype: torch.dtype,
+ out_dtype: torch.dtype,
+ accumulate: bool,
+ is_x_columnwise: bool,
+ is_w_columnwise: bool,
+):
+ check_nvfp4_gemm_versus_reference(
+ x_dtype=x_dtype,
+ w_dtype=w_dtype,
+ out_dtype=out_dtype,
+ M=M,
+ K=K,
+ N=N,
+ accumulate=accumulate,
+ x_columnwise=is_x_columnwise,
+ w_columnwise=is_w_columnwise,
+ )
diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py
new file mode 100644
index 000000000..ae9975839
--- /dev/null
+++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py
@@ -0,0 +1,559 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import os
+import pytest
+import torch
+import transformer_engine as te
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.distributed import fp8_autocast
+from transformer_engine.common import recipe
+
+
+recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
+
+
+class GetRecipes:
+ @staticmethod
+ def nvfp4_vanilla():
+ nvfp4_recipe = recipe.NVFP4BlockScaling()
+ nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
+ nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
+ nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
+ return nvfp4_recipe
+
+ @staticmethod
+ def nvfp4_rht_only():
+ nvfp4_recipe = recipe.NVFP4BlockScaling()
+ nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(random_hadamard_transform=True)
+ nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(random_hadamard_transform=False)
+ nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(random_hadamard_transform=True)
+ return nvfp4_recipe
+
+ @staticmethod
+ def nvfp4_2d_quantization_only():
+ nvfp4_recipe = recipe.NVFP4BlockScaling()
+ nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(fp4_2d_quantization=False)
+ nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True)
+ nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(fp4_2d_quantization=False)
+ return nvfp4_recipe
+
+ @staticmethod
+ def nvfp4_rht_and_2d_quantization():
+ nvfp4_recipe = recipe.NVFP4BlockScaling()
+ nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
+ random_hadamard_transform=True, fp4_2d_quantization=False
+ )
+ nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
+ random_hadamard_transform=False, fp4_2d_quantization=True
+ )
+ nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
+ random_hadamard_transform=True, fp4_2d_quantization=False
+ )
+ return nvfp4_recipe
+
+ @staticmethod
+ def nvfp4_recipe_to_test(with_rht: bool = False, with_2d_quantization: bool = False):
+ if with_rht and with_2d_quantization:
+ return GetRecipes.nvfp4_rht_and_2d_quantization()
+ elif with_rht:
+ return GetRecipes.nvfp4_rht_only()
+ elif with_2d_quantization:
+ return GetRecipes.nvfp4_2d_quantization_only()
+ else:
+ return GetRecipes.nvfp4_vanilla()
+
+
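+# The reference implementation is selected via the QAT_PARAMS environment
+# variable; the numeric codes below are assumed to map to the (RHT, 2D
+# quantization) combinations exercised here, mirroring the recipes returned by
+# GetRecipes.nvfp4_recipe_to_test.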
+def setup_environment_for_reference(with_rht: bool = False, with_2d_quantization: bool = False):
+ if with_rht and with_2d_quantization:
+ os.environ["QAT_PARAMS"] = "9003"
+ elif with_rht:
+ os.environ["QAT_PARAMS"] = "960109"
+ elif with_2d_quantization:
+ os.environ["QAT_PARAMS"] = "9002"
+ else:
+ os.environ["QAT_PARAMS"] = "6010"
+
+
+def cleanup_environment():
+ if "QAT_PARAMS" in os.environ:
+ del os.environ["QAT_PARAMS"]
+
+
+def reset_rng_states():
+ seed = 1234
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+
+def check_nvfp4_module_versus_reference(
+ module_class,
+ in_features: int,
+ out_features: int,
+ bias: bool,
+ x_dtype: torch.dtype,
+ num_steps: int = 1,
+ with_rht: bool = False,
+ with_2d_quantization: bool = False,
+):
+ """
+ Compare native NVFP4 module against reference implementation.
+
+ Args:
+        module_class: te.pytorch.Linear or te.pytorch.LayerNormLinear
+        in_features: Input feature dimension
+        out_features: Output feature dimension
+        bias: Whether to use bias
+        x_dtype: Input tensor dtype
+        num_steps: Number of forward/backward steps to test
+        with_rht: Whether to enable the random Hadamard transform
+        with_2d_quantization: Whether to enable 2D (16x16) quantization
+    """
+ device = "cuda"
+ batch_size = 32
+ seq_len = 128
+
+ # Create both modules with identical initialization
+ cleanup_environment()
+ reset_rng_states()
+
+ # Create native module
+ print("\nCreate native module")
+ if module_class == te.pytorch.Linear:
+ native_module = te.pytorch.Linear(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ device=device,
+ params_dtype=x_dtype,
+ )
+ elif module_class == te.pytorch.LayerNormLinear:
+ native_module = te.pytorch.LayerNormLinear(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ device=device,
+ params_dtype=x_dtype,
+ )
+ else:
+ raise ValueError(f"Unsupported module class: {module_class}")
+
+ # Create reference module with same weights
+ setup_environment_for_reference(with_rht, with_2d_quantization)
+ reset_rng_states()
+
+ # Create reference module
+ print("Create reference module")
+ if module_class == te.pytorch.Linear:
+ ref_module = te.pytorch.Linear(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ device=device,
+ params_dtype=x_dtype,
+ )
+ elif module_class == te.pytorch.LayerNormLinear:
+ ref_module = te.pytorch.LayerNormLinear(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ device=device,
+ params_dtype=x_dtype,
+ )
+
+ # Sync weights between native and reference modules
+ with torch.no_grad():
+ # Copy main weight and bias parameters
+ if hasattr(native_module, "weight") and hasattr(ref_module, "weight"):
+ ref_module.weight.copy_(native_module.weight)
+ if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"):
+ ref_module.bias.copy_(native_module.bias)
+
+ # Copy layer norm parameters if they exist
+ if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"):
+ ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight)
+ if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"):
+ ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias)
+
+ nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization)
+
+ # Training loop comparison
+ native_outputs = []
+ ref_outputs = []
+
+ for step in range(num_steps):
+ torch.manual_seed(1234 + step)
+ torch.cuda.manual_seed(1234 + step)
+
+ x_shape = (batch_size, seq_len, in_features)
+ x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device)
+ x_native = x_val.clone().detach().requires_grad_(True)
+ x_ref = x_native.clone().detach().requires_grad_(True)
+
+ grad_output_shape = (batch_size, seq_len, out_features)
+ grad_output_val = torch.normal(
+ mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device
+ )
+ grad_output = grad_output_val.clone().detach()
+
+ # Native forward/backward
+ cleanup_environment()
+ with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
+            # Enable weight caching by passing is_first_microbatch
+ y_native = native_module(x_native, is_first_microbatch=(step == 0))
+ y_native.backward(grad_output)
+
+ # Reference forward/backward
+ setup_environment_for_reference(with_rht, with_2d_quantization)
+ with fp8_autocast(
+ enabled=True, fp8_recipe=nvfp4_recipe
+        ):  # The exact recipe does not matter here; the reference path is driven by QAT_PARAMS
+ y_ref = ref_module(x_ref)
+ y_ref.backward(grad_output)
+
+ # Store results
+ native_outputs.append(
+ {
+ "output": y_native.detach().clone(),
+ "input_grad": (
+ x_native.grad.detach().clone() if x_native.grad is not None else None
+ ),
+ "weight_grad": (
+ native_module.weight.grad.detach().clone()
+ if native_module.weight.grad is not None
+ else None
+ ),
+ "bias_grad": (
+ native_module.bias.grad.detach().clone()
+ if bias and native_module.bias.grad is not None
+ else None
+ ),
+ }
+ )
+
+ ref_outputs.append(
+ {
+ "output": y_ref.detach().clone(),
+ "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None),
+ "weight_grad": (
+ ref_module.weight.grad.detach().clone()
+ if ref_module.weight.grad is not None
+ else None
+ ),
+ "bias_grad": (
+ ref_module.bias.grad.detach().clone()
+ if bias and ref_module.bias.grad is not None
+ else None
+ ),
+ }
+ )
+
+ # Compare results across all steps
+ for step in range(num_steps):
+ native_out = native_outputs[step]
+ ref_out = ref_outputs[step]
+
+ # Compare outputs
+ torch.testing.assert_close(
+ native_out["output"],
+ ref_out["output"],
+ atol=1e-6,
+ rtol=1e-6,
+ msg=f"Output mismatch at step {step}",
+ )
+
+ # Compare input gradients
+ torch.testing.assert_close(
+ native_out["input_grad"],
+ ref_out["input_grad"],
+ atol=1e-6,
+ rtol=1e-6,
+ msg=(
+ f"Input gradient mismatch at step {step}. Native: {native_out['input_grad']}, Ref:"
+ f" {ref_out['input_grad']}"
+ ),
+ )
+
+ # Compare weight gradients
+ torch.testing.assert_close(
+ native_out["weight_grad"],
+ ref_out["weight_grad"],
+ atol=1e-6,
+ rtol=1e-6,
+ msg=(
+ f"Weight gradient mismatch at step {step}. Native: {native_out['weight_grad']},"
+ f" Ref: {ref_out['weight_grad']}"
+ ),
+ )
+
+ # Compare bias gradients
+ if bias and native_out["bias_grad"] is not None and ref_out["bias_grad"] is not None:
+ torch.testing.assert_close(
+ native_out["bias_grad"],
+ ref_out["bias_grad"],
+ atol=1e-6,
+ rtol=1e-6,
+ msg=f"Bias gradient mismatch at step {step}",
+ )
+
+ # Clean up
+ cleanup_environment()
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+ "in_features, out_features",
+ [
+ (128, 256),
+ (256, 128),
+ (512, 512),
+ (768, 3072),
+ (1024, 4096),
+ ],
+)
+# @pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"])
+@pytest.mark.parametrize("bias", [False], ids=["no_bias"])
+@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize("num_steps", [1, 3], ids=["single_step", "multi_step"])
+@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"])
+@pytest.mark.parametrize(
+ "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"]
+)
+def test_nvfp4_linear_versus_reference(
+ in_features: int,
+ out_features: int,
+ bias: bool,
+ x_dtype: torch.dtype,
+ num_steps: int,
+ with_rht: bool,
+ with_2d_quantization: bool,
+):
+ """Test NVFP4 Linear module against reference implementation."""
+ if with_rht and x_dtype != torch.bfloat16:
+ pytest.skip("RHT is only supported for bfloat16 input")
+
+ check_nvfp4_module_versus_reference(
+ module_class=te.pytorch.Linear,
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ x_dtype=x_dtype,
+ num_steps=num_steps,
+ with_rht=with_rht,
+ with_2d_quantization=with_2d_quantization,
+ )
+
+
+def check_nvfp4_layernorm_linear_versus_reference(
+ in_features: int,
+ out_features: int,
+ bias: bool,
+ normalization: str,
+ x_dtype: torch.dtype,
+ num_steps: int = 1,
+ with_rht: bool = False,
+ with_2d_quantization: bool = False,
+):
+ """
+ Compare native NVFP4 LayerNormLinear module against reference implementation,
+ including ln_out.
+ """
+ device = "cuda"
+ batch_size = 32
+ seq_len = 128
+
+ # Create both modules with identical initialization
+ cleanup_environment()
+ reset_rng_states()
+
+ # Native module
+ native_module = te.pytorch.LayerNormLinear(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ device=device,
+ params_dtype=x_dtype,
+ normalization=normalization,
+ return_layernorm_output=True,
+ )
+
+ # Reference module
+ setup_environment_for_reference(with_rht, with_2d_quantization)
+ reset_rng_states()
+ ref_module = te.pytorch.LayerNormLinear(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ device=device,
+ params_dtype=x_dtype,
+ normalization=normalization,
+ return_layernorm_output=True,
+ )
+
+ # Sync weights and LN params
+ with torch.no_grad():
+ if hasattr(native_module, "weight") and hasattr(ref_module, "weight"):
+ ref_module.weight.copy_(native_module.weight)
+ if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"):
+ ref_module.bias.copy_(native_module.bias)
+ if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"):
+ if (
+ native_module.layer_norm_weight is not None
+ and ref_module.layer_norm_weight is not None
+ ):
+ ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight)
+ if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"):
+ if native_module.layer_norm_bias is not None and ref_module.layer_norm_bias is not None:
+ ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias)
+
+ nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization)
+
+ native_outputs = []
+ ref_outputs = []
+
+ for step in range(num_steps):
+ torch.manual_seed(1234 + step)
+ torch.cuda.manual_seed(1234 + step)
+
+ x_shape = (batch_size, seq_len, in_features)
+ x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device)
+ x_native = x_val.clone().detach().requires_grad_(True)
+ x_ref = x_native.clone().detach().requires_grad_(True)
+
+ grad_output_shape = (batch_size, seq_len, out_features)
+ grad_output_val = torch.normal(
+ mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device
+ )
+ grad_output = grad_output_val.clone().detach()
+
+ # Native forward/backward
+ cleanup_environment()
+ with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
+ y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0))
+ y_native.backward(grad_output)
+
+ # Reference forward/backward
+ setup_environment_for_reference(with_rht, with_2d_quantization)
+ with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
+ y_ref, ln_out_ref = ref_module(x_ref)
+ y_ref.backward(grad_output)
+
+ native_outputs.append(
+ {
+ "output": y_native.detach().clone(),
+ "ln_out": ln_out_native.detach().clone(),
+ "input_grad": (
+ x_native.grad.detach().clone() if x_native.grad is not None else None
+ ),
+ "weight_grad": (
+ native_module.weight.grad.detach().clone()
+ if native_module.weight.grad is not None
+ else None
+ ),
+ "bias_grad": (
+ native_module.bias.grad.detach().clone()
+ if bias and native_module.bias.grad is not None
+ else None
+ ),
+ }
+ )
+ ref_outputs.append(
+ {
+ "output": y_ref.detach().clone(),
+ "ln_out": ln_out_ref.detach().clone(),
+ "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None),
+ "weight_grad": (
+ ref_module.weight.grad.detach().clone()
+ if ref_module.weight.grad is not None
+ else None
+ ),
+ "bias_grad": (
+ ref_module.bias.grad.detach().clone()
+ if bias and ref_module.bias.grad is not None
+ else None
+ ),
+ }
+ )
+
+ # Compare results
+ for step in range(num_steps):
+ n = native_outputs[step]
+ r = ref_outputs[step]
+ torch.testing.assert_close(
+ n["output"],
+ r["output"],
+ atol=1e-6,
+ rtol=1e-6,
+ msg=f"Output mismatch at step {step}",
+ )
+ torch.testing.assert_close(
+ n["ln_out"],
+ r["ln_out"],
+ atol=1e-6,
+ rtol=1e-6,
+ msg=f"LN output mismatch at step {step}",
+ )
+ torch.testing.assert_close(
+ n["input_grad"],
+ r["input_grad"],
+ atol=1e-6,
+ rtol=1e-6,
+ msg=f"Input gradient mismatch at step {step}",
+ )
+ torch.testing.assert_close(
+ n["weight_grad"],
+ r["weight_grad"],
+ atol=1e-6,
+ rtol=1e-6,
+ msg=f"Weight gradient mismatch at step {step}",
+ )
+ if bias and n["bias_grad"] is not None and r["bias_grad"] is not None:
+ torch.testing.assert_close(
+ n["bias_grad"],
+ r["bias_grad"],
+ atol=1e-6,
+ rtol=1e-6,
+ msg=f"Bias gradient mismatch at step {step}",
+ )
+
+ cleanup_environment()
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+ "in_features, out_features",
+ [
+ (128, 256),
+ (256, 128),
+ ],
+)
+@pytest.mark.parametrize("bias", [False], ids=["no_bias"])
+@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize("num_steps", [1], ids=["single_step"])
+@pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"], ids=["LayerNorm", "RMSNorm"])
+@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"])
+@pytest.mark.parametrize(
+ "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"]
+)
+def test_nvfp4_layernorm_linear_versus_reference(
+ in_features: int,
+ out_features: int,
+ bias: bool,
+ normalization: str,
+ x_dtype: torch.dtype,
+ num_steps: int,
+ with_rht: bool,
+ with_2d_quantization: bool,
+):
+ if with_rht and x_dtype != torch.bfloat16:
+ pytest.skip("RHT is only supported for bfloat16 input")
+
+ check_nvfp4_layernorm_linear_versus_reference(
+ in_features=in_features,
+ out_features=out_features,
+ bias=bias,
+ normalization=normalization,
+ x_dtype=x_dtype,
+ num_steps=num_steps,
+ with_rht=with_rht,
+ with_2d_quantization=with_2d_quantization,
+ )
diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
new file mode 100644
index 000000000..dc3c4a4e9
--- /dev/null
+++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
@@ -0,0 +1,495 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import pytest
+import torch
+import transformer_engine as te
+import transformer_engine_torch as tex
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.common.recipe import NVFP4BlockScaling
+from transformer_engine.pytorch.constants import TE_DType
+from transformer_engine.pytorch.tensor.nvfp4_tensor import (
+ NVFP4Quantizer,
+)
+from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef
+from transformer_engine.pytorch.experimental import utils
+from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype
+
+
+recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
+
+
+def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
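+    # Each input byte packs two FP4 codes: the low nibble is element 2k and the
+    # high nibble is element 2k+1 (e.g. byte 0x3A unpacks to 0x0A then 0x03).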
+ repeated = x.repeat_interleave(2, dim=1)
+ repeated[:, 0::2] &= 0x0F
+ repeated[:, 1::2] >>= 4
+ return repeated
+
+
+def check_quantization_nvfp4_versus_reference(
+ x_dtype: torch.dtype,
+ M: int,
+ N: int,
+ return_transpose: bool,
+ swizzled_scale: bool,
+ use_cpp_allocator: bool,
+ with_2d_quantization: bool,
+) -> None:
+ te_dtype = tex.DType.kFloat4E2M1
+
+ # Setup device and random seed
+ device = "cuda"
+ seed = 0
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ # Input
+ x = torch.randn((M, N), dtype=x_dtype, device=device)
+
+ # Quantize
+ nvfp4_quantizer = NVFP4Quantizer(
+ fp4_dtype=te_dtype,
+ rowwise=True,
+ columnwise=return_transpose,
+ with_amax_reduction=False,
+ amax_reduction_group=None,
+ with_rht=False,
+ with_post_rht_amax=False,
+ with_2d_quantization=with_2d_quantization,
+ )
+ if use_cpp_allocator:
+ x_nvfp4_sut = nvfp4_quantizer(x)
+ else:
+ x_nvfp4_sut = nvfp4_quantizer.make_empty(
+ (M, N), dtype=x_dtype, device=device, requires_grad=False
+ )
+ x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
+
+ # Extract data from NVFP4Tensor
+ assert x_nvfp4_sut._rowwise_data is not None
+ qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
+ assert x_nvfp4_sut._rowwise_scale_inv is not None
+ sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
+ qx_t = (
+ x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
+ if x_nvfp4_sut._columnwise_data is not None
+ else None
+ )
+ sx_t = x_nvfp4_sut._columnwise_scale_inv
+ qx_amax = x_nvfp4_sut._amax_rowwise
+
+ # Reference quantization
+ quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16)
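+    # (1, 16) gives one scale per 1x16 microblock; (16, 16) shares a single
+    # scale across a 16x16 tile for 2D quantization.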
+ ref_quantizer = NVFP4QuantizerRef(
+ dtype=utils.Fp4Formats.E2M1,
+ rowwise=True,
+ columnwise=return_transpose,
+ pow_2_scales=False,
+ eps=0.0,
+ quant_tile_shape=quant_tile_shape,
+ )
+ x_nvfp4_ref = ref_quantizer.quantize(x)
+
+ # Extract data from RefNVFP4Tensor
+ qx_ref = (
+ unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
+ if x_nvfp4_ref.data is not None
+ else None
+ )
+ sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
+ qx_t_ref = (
+ unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
+ if x_nvfp4_ref.data_t is not None
+ else None
+ )
+ sx_t_ref = (
+ x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
+ )
+ ref_amax = x_nvfp4_ref.global_amax_row
+
+ qx = unpack_fp4(qx)
+ qx_t = unpack_fp4(qx_t) if qx_t is not None else None
+
+ torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
+
+ # Compare only the valid portion of scale tensors (reference may not have padding)
+ ref_sx_shape = sx_ref.shape
+ sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
+
+ torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
+
+ if return_transpose:
+ torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
+
+ # Compare only the valid portion of transpose scale tensors
+ ref_sx_t_shape = sx_t_ref.shape
+ sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
+ torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
+
+ torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+ "M, N",
+ [
+ # full tile cases
+ (128, 128),
+ (256, 256),
+ (256, 1024),
+ (1024, 256),
+ # Padding required cases
+ (256, 272),
+ (304, 304),
+ (320, 256),
+ # Some larger tiles
+ (2048, 2048),
+ (1024, 2048),
+ (2048, 1024),
+        # largest tile
+ (8192, 8192),
+ ],
+)
+@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize(
+ "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
+)
+@pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"])
+@pytest.mark.parametrize(
+ "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
+)
+@pytest.mark.parametrize(
+ "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"]
+)
+def test_quantization_block_tiling_versus_reference(
+ x_dtype: torch.dtype,
+ M: int,
+ N: int,
+ return_transpose: bool,
+ swizzled_scale: bool,
+ use_cpp_allocator: bool,
+ with_2d_quantization: bool,
+) -> None:
+ check_quantization_nvfp4_versus_reference(
+ x_dtype=x_dtype,
+ M=M,
+ N=N,
+ return_transpose=return_transpose,
+ swizzled_scale=swizzled_scale,
+ use_cpp_allocator=use_cpp_allocator,
+ with_2d_quantization=with_2d_quantization,
+ )
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+ "M, N",
+ [
+ (128, 128),
+ ],
+)
+@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
+@pytest.mark.parametrize(
+ "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
+)
+@pytest.mark.parametrize(
+ "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
+)
+def test_nvfp4_quantization_extrema_versus_reference(
+ x_dtype: torch.dtype,
+ M: int,
+ N: int,
+ extrema_high: bool,
+ return_transpose: bool,
+ use_cpp_allocator: bool,
+):
+ te_dtype = tex.DType.kFloat4E2M1
+
+ device = "cuda"
+ seed = 0
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+ if extrema_high:
+ x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device)
+ else:
+ x = torch.zeros((M, N), dtype=x_dtype, device=device)
+
+ nvfp4_quantizer = NVFP4Quantizer(
+ fp4_dtype=te_dtype,
+ rowwise=True,
+ columnwise=return_transpose,
+ with_amax_reduction=False,
+ amax_reduction_group=None,
+ with_rht=False,
+ with_post_rht_amax=False,
+ )
+
+ if use_cpp_allocator:
+ x_nvfp4_sut = nvfp4_quantizer(x)
+ else:
+ x_nvfp4_sut = nvfp4_quantizer.make_empty(
+ (M, N), dtype=x_dtype, device=device, requires_grad=False
+ )
+ x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
+
+ assert x_nvfp4_sut._rowwise_data is not None
+ qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
+ assert x_nvfp4_sut._rowwise_scale_inv is not None
+ sx = x_nvfp4_sut._rowwise_scale_inv
+ qx_t = (
+ x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
+ if x_nvfp4_sut._columnwise_data is not None
+ else None
+ )
+ sx_t = x_nvfp4_sut._columnwise_scale_inv
+ qx_amax = x_nvfp4_sut._amax_rowwise
+
+ ref_quantizer = NVFP4QuantizerRef(
+ dtype=utils.Fp4Formats.E2M1,
+ rowwise=True,
+ columnwise=return_transpose,
+ pow_2_scales=False,
+ eps=0.0,
+ quant_tile_shape=(1, 16),
+ )
+ x_nvfp4_ref = ref_quantizer.quantize(x)
+
+ qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
+ sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
+ qx_t_ref = (
+ x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
+ )
+ sx_t_ref = (
+ x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
+ )
+ ref_amax = x_nvfp4_ref.global_amax_row
+
+ torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
+
+ ref_sx_shape = sx_ref.shape
+ sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
+ torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
+
+ if return_transpose:
+ torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
+ ref_sx_t_shape = sx_t_ref.shape
+ sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
+ torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
+
+ torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+ "M, N",
+ [
+ (16, 128),
+ (32, 128),
+ ],
+)
+@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize(
+ "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
+)
+@pytest.mark.parametrize(
+ "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
+)
+def test_nvfp4_quantization_boundary_values(
+ x_dtype: torch.dtype,
+ M: int,
+ N: int,
+ return_transpose: bool,
+ use_cpp_allocator: bool,
+):
+ """
+ Stress rounding/threshold behavior by placing values just below/above
+ many potential bin edges within each 16-element microblock.
+ Validates native vs reference byte-for-byte and scale parity.
+ """
+ te_dtype = tex.DType.kFloat4E2M1
+
+ device = "cuda"
+ seed = 123
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+ # Construct a single row with paired boundary values: v-eps, v+eps
+ # spanning a wide dynamic range to exercise clipping and multiple bins.
+    # N must be even and a multiple of 16 for the microblocks, which holds for N=128.
+ base = torch.linspace(-12.0, 12.0, steps=N // 2, dtype=torch.float32, device=device)
+ eps = torch.full_like(base, 1e-3)
+ # Avoid zero eps for very small magnitudes
+ eps = torch.maximum(eps, 1e-4 * torch.ones_like(base))
+ lower = base - eps
+ upper = base + eps
+ row = torch.empty(N, dtype=torch.float32, device=device)
+ row[0::2] = lower
+ row[1::2] = upper
+ x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype)
+
+ nvfp4_quantizer = NVFP4Quantizer(
+ fp4_dtype=te_dtype,
+ rowwise=True,
+ columnwise=return_transpose,
+ with_amax_reduction=False,
+ amax_reduction_group=None,
+ with_rht=False,
+ with_post_rht_amax=False,
+ )
+
+ if use_cpp_allocator:
+ x_nvfp4_sut = nvfp4_quantizer(x)
+ else:
+ x_nvfp4_sut = nvfp4_quantizer.make_empty(
+ (M, N), dtype=x_dtype, device=device, requires_grad=False
+ )
+ x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
+
+ assert x_nvfp4_sut._rowwise_data is not None
+ qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
+ assert x_nvfp4_sut._rowwise_scale_inv is not None
+ sx = x_nvfp4_sut._rowwise_scale_inv
+ qx_t = (
+ x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
+ if x_nvfp4_sut._columnwise_data is not None
+ else None
+ )
+ sx_t = x_nvfp4_sut._columnwise_scale_inv
+ qx_amax = x_nvfp4_sut._amax_rowwise
+
+ ref_quantizer = NVFP4QuantizerRef(
+ dtype=utils.Fp4Formats.E2M1,
+ rowwise=True,
+ columnwise=return_transpose,
+ pow_2_scales=False,
+ eps=0.0,
+ quant_tile_shape=(1, 16),
+ )
+ x_nvfp4_ref = ref_quantizer.quantize(x)
+
+ qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
+ sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
+ qx_t_ref = (
+ x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
+ )
+ sx_t_ref = (
+ x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
+ )
+ ref_amax = x_nvfp4_ref.global_amax_row
+
+ torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
+
+ # Compare only valid portion of scales (trim any padding)
+ ref_sx_shape = sx_ref.shape
+ sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
+ torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
+
+ if return_transpose:
+ torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
+ ref_sx_t_shape = sx_t_ref.shape
+ sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
+ torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
+
+ torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+ "M, N",
+ [
+ (32, 128),
+ ],
+)
+@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize(
+ "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
+)
+@pytest.mark.parametrize(
+ "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
+)
+def test_nvfp4_quantization_noncontiguous_inputs(
+ x_dtype: torch.dtype,
+ M: int,
+ N: int,
+ return_transpose: bool,
+ use_cpp_allocator: bool,
+):
+ te_dtype = tex.DType.kFloat4E2M1
+
+ device = "cuda"
+ seed = 17
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+ # Start from a contiguous tensor, then make a non-contiguous view by transpose
+ x_base = torch.randn((M, N), dtype=x_dtype, device=device)
+ x_nc = x_base.t() # shape (N, M), non-contiguous
+ assert not x_nc.is_contiguous()
+
+ nvfp4_quantizer = NVFP4Quantizer(
+ fp4_dtype=te_dtype,
+ rowwise=True,
+ columnwise=return_transpose,
+ with_amax_reduction=False,
+ amax_reduction_group=None,
+ with_rht=False,
+ with_post_rht_amax=False,
+ )
+
+ if use_cpp_allocator:
+ x_nvfp4_sut = nvfp4_quantizer(x_nc)
+ else:
+ x_nvfp4_sut = nvfp4_quantizer.make_empty(
+ x_nc.shape, dtype=x_dtype, device=device, requires_grad=False
+ )
+ x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut)
+
+ assert x_nvfp4_sut._rowwise_data is not None
+ qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
+ assert x_nvfp4_sut._rowwise_scale_inv is not None
+ sx = x_nvfp4_sut._rowwise_scale_inv
+ qx_t = (
+ x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
+ if x_nvfp4_sut._columnwise_data is not None
+ else None
+ )
+ sx_t = x_nvfp4_sut._columnwise_scale_inv
+ qx_amax = x_nvfp4_sut._amax_rowwise
+
+ ref_quantizer = NVFP4QuantizerRef(
+ dtype=utils.Fp4Formats.E2M1,
+ rowwise=True,
+ columnwise=return_transpose,
+ pow_2_scales=False,
+ eps=0.0,
+ quant_tile_shape=(1, 16),
+ )
+ x_nvfp4_ref = ref_quantizer.quantize(x_nc)
+
+ qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None
+ sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
+ qx_t_ref = (
+ x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None
+ )
+ sx_t_ref = (
+ x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None
+ )
+ ref_amax = x_nvfp4_ref.global_amax_row
+
+ # Quantized must match
+ torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
+
+ # Compare only valid portion of scales (trim padding)
+ ref_sx_shape = sx_ref.shape
+ sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
+ torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
+
+ if return_transpose:
+ torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
+ ref_sx_t_shape = sx_t_ref.shape
+ sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
+ torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
+
+ torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0)
diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
new file mode 100644
index 000000000..bb542456e
--- /dev/null
+++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
@@ -0,0 +1,255 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+# NOTE: This file depends on test_nvfp4_quantize_exact.py passing. It is kept
+# separate so that the base quantization path is validated on its own;
+# otherwise the reference implementation would become hard to follow.
+
+# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality
+# together with the quantization functionality.
+
+from typing import Tuple
+import math
+
+import transformer_engine as te
+import transformer_engine_torch as tex
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.common.recipe import NVFP4BlockScaling
+from transformer_engine.pytorch.constants import TE_DType
+from transformer_engine.pytorch.tensor.nvfp4_tensor import (
+ NVFP4Quantizer,
+)
+from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef
+from transformer_engine.pytorch.experimental import utils
+from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype
+
+import pytest
+import torch
+
+recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
+
+
+def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
+ repeated = x.repeat_interleave(2, dim=1)
+ repeated[:, 0::2] &= 0x0F
+ repeated[:, 1::2] >>= 4
+ return repeated
+
+
+def check_quantization_nvfp4_versus_reference(
+ x_dtype: torch.dtype,
+ M: int,
+ N: int,
+ contiguous: bool,
+ return_transpose: bool,
+ use_cpp_allocator: bool,
+ swizzled_scale: bool = False,
+ hadamard_dimension: int = 16,
+ with_rht: bool = True,
+ with_post_rht_amax: bool = True,
+ with_random_sign_mask: bool = True,
+) -> None:
+    assert with_rht and with_post_rht_amax, "RHT and post-RHT amax must be enabled."
+
+ te_dtype = tex.DType.kFloat4E2M1
+
+ # Setup device and random seed
+ device = "cuda"
+ seed = 0
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+ # Input
+ x = torch.randn((M, N), dtype=x_dtype, device=device)
+
+ x = x.transpose(0, 1) if not contiguous else x
+
+ # Quantize
+ nvfp4_quantizer = NVFP4Quantizer(
+ fp4_dtype=te_dtype,
+ rowwise=True,
+ columnwise=return_transpose,
+ with_amax_reduction=False,
+ amax_reduction_group=None,
+ with_rht=with_rht,
+ with_post_rht_amax=with_post_rht_amax,
+ with_random_sign_mask=with_random_sign_mask,
+ )
+ if use_cpp_allocator:
+ x_nvfp4_sut = nvfp4_quantizer(x)
+ else:
+ x_nvfp4_sut = nvfp4_quantizer.make_empty(
+ x.shape, dtype=x_dtype, device=device, requires_grad=False
+ )
+ x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut)
+
+ # Extract data from NVFP4Tensor
+ assert x_nvfp4_sut._rowwise_data is not None
+ qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
+ assert x_nvfp4_sut._rowwise_scale_inv is not None
+ sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
+ qx_t = (
+ x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
+ if x_nvfp4_sut._columnwise_data is not None
+ else None
+ )
+ sx_t = x_nvfp4_sut._columnwise_scale_inv
+ amax_rowwise = x_nvfp4_sut._amax_rowwise
+ amax_colwise = x_nvfp4_sut._amax_columnwise
+
+ qx = unpack_fp4(qx)
+ qx_t = unpack_fp4(qx_t) if qx_t is not None else None
+
+ # Reference quantization using NVFP4QuantizerRef with built-in RHT
+ ref_quantizer = NVFP4QuantizerRef(
+ dtype=utils.Fp4Formats.E2M1,
+ rowwise=True,
+ columnwise=return_transpose,
+ pow_2_scales=False,
+ eps=0.0,
+ quant_tile_shape=(1, 16),
+ with_rht=with_rht,
+ with_random_sign_mask=with_random_sign_mask,
+ )
+ x_nvfp4_ref = ref_quantizer.quantize(x)
+ # Extract data from RefNVFP4Tensor
+ qx_ref = (
+ unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8))
+ if x_nvfp4_ref.data is not None
+ else None
+ )
+ sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None
+ ref_amax_rowwise = x_nvfp4_ref.global_amax_row
+
+ if return_transpose:
+ assert x_nvfp4_ref.data_t is not None
+ assert x_nvfp4_ref.scale_t is not None
+ qx_t_ref = unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8))
+ sx_t_ref = x_nvfp4_ref.scale_t.view(dtype=torch.uint8)
+ # Compute transpose amax using the same reference quantizer
+ x_t_for_amax = (
+ ref_quantizer._apply_rht(x.t().contiguous()) if with_rht else x.t().contiguous()
+ )
+ ref_amax_colwise_t = torch.max(torch.abs(x_t_for_amax)).to(torch.float32).view(1)
+ else:
+ qx_t_ref = None
+ sx_t_ref = None
+ ref_amax_colwise_t = None
+
+ torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0)
+
+ torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0)
+ # Compare only the valid portion of scale tensors (reference may not have padding)
+ ref_sx_shape = sx_ref.shape
+ sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]]
+ torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0)
+
+ if return_transpose:
+ torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0)
+
+ torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0)
+
+ # Compare only the valid portion of transpose scale tensors
+ ref_sx_t_shape = sx_t_ref.shape
+ sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]]
+ torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0)
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+ "M, N",
+ [
+ # full tile cases
+ (128, 128),
+ (256, 256),
+ (256, 1024),
+ (1024, 256),
+ # Padding required cases
+ (256, 272),
+ (304, 304),
+ (320, 256),
+ # Some larger tiles
+ (2048, 2048),
+ (1024, 2048),
+ (2048, 1024),
+        # Real shapes
+ (8192, 5120),
+ (8192, 10240),
+ (8192, 2560),
+ (8192, 11328),
+ (8192, 512),
+ (8192, 3584),
+ (5120, 8192),
+ (10240, 8192),
+ (2560, 8192),
+ (11328, 8192),
+ (512, 8192),
+ (3584, 8192),
+ (4096, 16384),
+ (14336, 16384),
+ ],
+)
+@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
+@pytest.mark.parametrize(
+ "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
+)
+@pytest.mark.parametrize(
+ "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
+)
+@pytest.mark.parametrize(
+ "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
+)
+def test_rht_with_quantization_block_tiling_versus_reference(
+ x_dtype: torch.dtype,
+ M: int,
+ N: int,
+ return_transpose: bool,
+ use_cpp_allocator: bool,
+ with_random_sign_mask: bool,
+) -> None:
+ check_quantization_nvfp4_versus_reference(
+ x_dtype=x_dtype,
+ M=M,
+ N=N,
+ contiguous=True,
+ return_transpose=return_transpose,
+ use_cpp_allocator=use_cpp_allocator,
+ with_random_sign_mask=with_random_sign_mask,
+ )
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+ "M, N",
+ [
+ (32, 128),
+ ],
+)
+@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
+@pytest.mark.parametrize(
+ "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
+)
+@pytest.mark.parametrize(
+ "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"]
+)
+@pytest.mark.parametrize(
+ "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"]
+)
+def test_nvfp4_quantization_noncontiguous_inputs(
+ x_dtype: torch.dtype,
+ M: int,
+ N: int,
+ return_transpose: bool,
+ use_cpp_allocator: bool,
+ with_random_sign_mask: bool,
+):
+ check_quantization_nvfp4_versus_reference(
+ x_dtype=x_dtype,
+ M=M,
+ N=N,
+ contiguous=False,
+ return_transpose=return_transpose,
+ use_cpp_allocator=use_cpp_allocator,
+ with_random_sign_mask=with_random_sign_mask,
+ )
diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py
new file mode 100755
index 000000000..46077eb20
--- /dev/null
+++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py
@@ -0,0 +1,238 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import pytest
+import torch
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+
+recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available()
+
+seed = 12345
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
+
+
+def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
+ repeated = x.repeat_interleave(2, dim=1)
+ repeated[:, 0::2] &= 0x0F
+ repeated[:, 1::2] >>= 4
+ return repeated
+
+
+_FP4_LUT = torch.tensor(
+ [
+ 0.0, # 0: 0000 - zero
+        0.5,  # 1: 0001 - smallest positive value (subnormal)
+ 1.0, # 2: 0010
+ 1.5, # 3: 0011
+ 2.0, # 4: 0100
+ 3.0, # 5: 0101
+ 4.0, # 6: 0110
+ 6.0, # 7: 0111 - largest positive normal
+ -0.0, # 8: 1000 - negative zero
+        -0.5,  # 9: 1001 - smallest-magnitude negative value (subnormal)
+ -1.0, # 10: 1010
+ -1.5, # 11: 1011
+ -2.0, # 12: 1100
+ -3.0, # 13: 1101
+ -4.0, # 14: 1110
+ -6.0, # 15: 1111 - largest negative normal
+ ],
+ dtype=torch.float32,
+)
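+# Each E2M1 code is laid out as sign(1) | exponent(2) | mantissa(1) with an
+# exponent bias of 1. For example, code 0b0101 is sign 0, exponent 0b10
+# (unbiased 1), mantissa 1 -> (1 + 0.5) * 2 = 3.0, matching the table above.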
+
+
+def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor:
+ # Convert FP4 indices to their corresponding floating point values
+ # Each index (0-15) represents a 4-bit FP4 value in E2M1 format
+ # Values based on the FP4 E2M1 specification
+ fp4_lut = _FP4_LUT.to(fp4.device)
+ return fp4_lut[fp4.to(torch.long)]
+
+
+def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
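+    # Two-level scale decode: sx holds per-16-element block scales (uint8 bytes
+    # reinterpreted as FP8 E4M3), while amax / (6.0 * 448) is the global decode
+    # factor (6 is the FP4 E2M1 max, 448 the FP8 E4M3 max). This is assumed to
+    # mirror how NVFP4Quantizer derives its scales.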
+ sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32)
+ dqx = fp4_to_fp32(unpack_fp4(qx))
+ sf = sf[: dqx.shape[0], : dqx.shape[1]]
+ dequant = dqx * sf * (amax / (6.0 * 448))
+ return dequant
+
+
+def RHT(x: torch.Tensor) -> torch.Tensor:
+ def get_wgrad_sign_vector() -> torch.Tensor:
+ """Hard-coded signs for Hadamard transform"""
+ return torch.tensor(
+ [
+ 1.0,
+ 1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ -1.0,
+ -1.0,
+ -1.0,
+ -1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ 1.0,
+ -1.0,
+ -1.0,
+ ],
+ dtype=torch.float32,
+ )
+
+ def _build_hadamard_matrix(
+ size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True
+ ) -> torch.Tensor:
+ """Construct a Hadamard matrix of given power-of-two size with entries +-1.
+
+ Uses Sylvester construction to avoid SciPy dependency.
+ """
+ assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
+ h = torch.ones((1, 1), device=device, dtype=torch.float32)
+ while h.shape[0] < size:
+ h = torch.cat(
+ [
+ torch.cat([h, h], dim=1),
+ torch.cat([h, -h], dim=1),
+ ],
+ dim=0,
+ )
+ if with_random_sign_mask:
+ sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
+ size, device=device, dtype=torch.float32
+ )
+ h = sign_mat @ h
+ return h.to(dtype)
+
+ rht_dim = 16
+ # Build H and scale
+ H = _build_hadamard_matrix(rht_dim, x.device, x.dtype)
+ scale = 1.0 / float(rht_dim) ** 0.5
+
+ # Perform blockwise transform along the last dimension
+ original_shape = x.shape
+ x_mat = x.contiguous().view(-1, rht_dim)
+    # The fixed sign mask is applied inside _build_hadamard_matrix
+    # (with_random_sign_mask defaults to True)
+ transform = H * scale
+ out = x_mat @ transform
+ return out.view(original_shape)
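+# The Hadamard construction above follows the Sylvester recursion
+# H_1 = [[1]], H_{2n} = [[H_n, H_n], [H_n, -H_n]] (e.g. H_2 = [[1, 1], [1, -1]]);
+# the 1/sqrt(16) scaling makes the blockwise transform orthonormal, so the L2
+# norm of each 16-element block is preserved.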
+
+
+def quantize_fp4(
+ x: torch.Tensor, use_stochastic_rounding: bool, use_2D: bool, use_RHT: bool
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
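+    # Returns (qx, sx, qx_t, sx_t): packed rowwise data, rowwise block scales,
+    # columnwise (transposed) data, and columnwise block scales, in the layout
+    # consumed by dequantize_fp4 above.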
+ nvfp4_quantizer = NVFP4Quantizer(
+ rowwise=True,
+ columnwise=True,
+ with_amax_reduction=False,
+ amax_reduction_group=None,
+ with_rht=use_RHT,
+ with_post_rht_amax=True,
+ stochastic_rounding=use_stochastic_rounding,
+ with_2d_quantization=use_2D,
+ )
+
+ x_nvfp4_sut = nvfp4_quantizer(x)
+ # Extract data from NVFP4Tensor
+ assert x_nvfp4_sut._rowwise_data is not None
+ qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8)
+ assert x_nvfp4_sut._rowwise_scale_inv is not None
+ sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv
+ assert x_nvfp4_sut._columnwise_data is not None
+ qx_t: torch.Tensor = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8)
+ assert x_nvfp4_sut._columnwise_scale_inv is not None
+ sx_t: torch.Tensor = x_nvfp4_sut._columnwise_scale_inv
+
+ return qx, sx, qx_t, sx_t
+
+
+def check_quantization_nvfp4_versus_reference(
+ x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
+) -> None:
+ device = "cuda"
+ torch.manual_seed(seed)
+ n_iters = 50
+
+ x = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
+ y = x.t().contiguous()
+ if use_RHT:
+ y = RHT(y)
+ amax = torch.max(torch.abs(x)).float()
+ q_rn, s_rn, q_t_rn, s_t_rn = quantize_fp4(
+ x, use_stochastic_rounding=False, use_2D=use_2D, use_RHT=use_RHT
+ )
+ dq_rn = dequantize_fp4(q_rn, s_rn, amax)
+ dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
+ error_rn = (dq_rn - x).float()
+ me_rn = torch.sqrt((error_rn * error_rn).mean())
+ error_t_rn = (dq_t_rn - y).float()
+ me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())
+ sr_result = torch.zeros_like(x).float()
+ sr_t_result = torch.zeros_like(x).float().t().contiguous()
+ for i in range(n_iters):
+ q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4(
+ x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT
+ )
+
+ dq_sr = dequantize_fp4(q_sr, s_sr, amax)
+ dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)
+
+ sr_result += dq_sr.float()
+ sr_t_result += dq_t_sr.float()
+
+ # sr_result_tmp = sr_result / (i + 1)
+ # error_sr = (sr_result_tmp - x).float()
+ # me_sr = torch.sqrt((error_sr * error_sr).mean())
+ # sr_t_result_tmp = sr_t_result / (i + 1)
+ # error_t_sr = (sr_t_result_tmp - y).float()
+ # me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
+ # print(f"Iteration {i}: RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
+ # print(f"Iteration {i}: RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
+
+    # Average the stochastic-rounding results. Since SR is unbiased
+    # (E[Q(x)] = x), the mean over many draws should have lower RMSE than the
+    # deterministic round-to-nearest result.
+ sr_result /= n_iters
+ error_sr = (sr_result - x).float()
+ me_sr = torch.sqrt((error_sr * error_sr).mean())
+ sr_t_result /= n_iters
+ error_t_sr = (sr_t_result - y).float()
+ me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())
+
+ print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
+ print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
+    assert me_sr < me_rn, "Stochastic rounding failed: error is larger than round-to-nearest."
+    assert me_t_sr < me_t_rn, "Stochastic rounding failed: error is larger than round-to-nearest."
+
+
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
+@pytest.mark.parametrize(
+ "M, N",
+ [
+ (8192, 8192),
+ (8192, 8256), # to test the nonfused RHT path
+ ],
+)
+@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize("use_2D", [False, True], ids=str)
+@pytest.mark.parametrize("use_RHT", [False, True], ids=str)
+def test_quantization_block_tiling_versus_reference(
+ x_dtype: torch.dtype,
+ use_2D: bool,
+ use_RHT: bool,
+ M: int,
+ N: int,
+) -> None:
+ if x_dtype == torch.float32 and use_RHT:
+ pytest.skip("RHT is only supported with bfloat16")
+ check_quantization_nvfp4_versus_reference(
+ x_dtype=x_dtype,
+ use_2D=use_2D,
+ use_RHT=use_RHT,
+ M=M,
+ N=N,
+ )
diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py
index 90e624c94..be7a65deb 100644
--- a/tests/pytorch/test_cuda_graphs.py
+++ b/tests/pytorch/test_cuda_graphs.py
@@ -32,12 +32,59 @@
reset_rng_states()
model_configs = {
- "small": ModelConfig(32, 2, 2, 32),
+ "small": ModelConfig(2, 32, 2, 32),
}
+
+def nvfp4_vanilla():
+ nvfp4_recipe = recipe.NVFP4BlockScaling()
+ nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
+ nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
+ nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
+ return nvfp4_recipe
+
+
+def nvfp4_rht_and_2d_quantization():
+ nvfp4_recipe = recipe.NVFP4BlockScaling()
+ nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
+ random_hadamard_transform=True, fp4_2d_quantization=False
+ )
+ nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
+ random_hadamard_transform=False, fp4_2d_quantization=True
+ )
+ nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
+ random_hadamard_transform=True, fp4_2d_quantization=False
+ )
+ return nvfp4_recipe
+
+
+def check_rht_usage(recipe: recipe.Recipe) -> bool:
+ # if using RHT, we can only support bf16
+ # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad
+ if recipe.nvfp4():
+ if (
+ recipe.fp4_quant_fwd_inp.random_hadamard_transform
+ or recipe.fp4_quant_fwd_weight.random_hadamard_transform
+ or recipe.fp4_quant_bwd_grad.random_hadamard_transform
+ ):
+ return True
+ return False
+
+
+def get_nvfp4_inp_supported_dtypes(
+    recipe: recipe.Recipe, dtype: torch.dtype
+) -> list[torch.dtype]:
+ supported_input_dtypes = []
+ if recipe.nvfp4():
+ supported_input_dtypes.append(torch.bfloat16)
+        # Without RHT, fp32 inputs are supported as well
+ if not check_rht_usage(recipe):
+ supported_input_dtypes.append(torch.float32)
+ return supported_input_dtypes
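+# For example, with the nvfp4_rht_and_2d_quantization() recipe this returns
+# [torch.bfloat16] only, since that recipe enables RHT on the input and grad paths.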
+
+
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
+ fp8_recipes.append(nvfp4_rht_and_2d_quantization())
if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
@@ -278,7 +325,7 @@ def _test_cuda_graphs(
@pytest.mark.parametrize("module", _test_cuda_graphs_modules)
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None])
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__)
def test_make_graphed_callables(
*,
module: str,
@@ -295,8 +342,18 @@ def test_make_graphed_callables(
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
- if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op":
- pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
+ if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op":
+ pytest.skip(
+ f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs"
+ )
+ if fp8 and fp8_recipe.nvfp4():
+ if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype):
+ pytest.skip(
+ f"Input dtype {dtype} not supported for NVFP4 Recipe"
+ f" {fp8_recipe.__class__.__name__}"
+ )
+ if fp8_params:
+ pytest.skip("NVFP4 params not supported")
# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
@@ -334,17 +391,19 @@ def test_make_graphed_callables(
"module",
_test_make_graphed_callables_with_fp8_weight_caching_modules,
)
+@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("fp8_params", (False, True))
-@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__)
def test_make_graphed_callables_with_fp8_weight_caching(
*,
module: str,
+ dtype: torch.dtype,
fp8_params: bool,
fp8_recipe: recipe.Recipe,
) -> None:
test_make_graphed_callables(
module=module,
- dtype=torch.float32,
+ dtype=dtype,
fp8_params=fp8_params,
fp8_recipe=fp8_recipe,
fp8_weight_caching=True,
diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py
index 281fc67a5..80802fce7 100644
--- a/tests/pytorch/test_float8_current_scaling_exact.py
+++ b/tests/pytorch/test_float8_current_scaling_exact.py
@@ -10,7 +10,6 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
-import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8CurrentScaling, Format
from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype
@@ -273,6 +272,14 @@ def run_linear_multiple_steps(
if bgrad_list is not None and bgrad is not None:
bgrad_list.append(bgrad.detach().clone())
+ # Stack the results
+ return (
+ torch.stack(y_q_list),
+ torch.stack(dgrad_list),
+ torch.stack(wgrad_list),
+ torch.stack(bgrad_list) if bgrad_list is not None else None,
+ )
+
@classmethod
def run_linear(
cls,
diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py
index 500b25f58..5d197c4a6 100644
--- a/tests/pytorch/test_fusible_ops.py
+++ b/tests/pytorch/test_fusible_ops.py
@@ -37,17 +37,19 @@
Float8Quantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.utils import get_device_compute_capability
import transformer_engine_torch as tex
from torch.utils.cpp_extension import IS_HIP_EXTENSION
# Import utility functions
-from utils import dtype_tols, make_recipe, reset_rng_states
+from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states
-# Check if FP8 is supported
+# Check for supported quantization schemes
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()
# Supported data types
_dtypes: list[torch.dtype] = [torch.float32, torch.float16]
@@ -63,6 +65,8 @@
_quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
if mxfp8_available:
_quantization_list.append("mxfp8")
+if nvfp4_available:
+ _quantization_list.append("nvfp4")
def maybe_skip_quantization(
@@ -70,6 +74,7 @@ def maybe_skip_quantization(
*,
dims: Optional[Iterable[int] | int] = None,
device: Optional[torch.device | str] = None,
+ dtype: Optional[torch.dtype] = None,
) -> None:
"""Skip test case if a quantization scheme is not supported"""
@@ -77,12 +82,17 @@ def maybe_skip_quantization(
if quantization is None:
return
- # Check if quantization scheme is supported
+ # Check if quantization scheme is supported on device
+ if device is not None and torch.device(device).type != "cuda":
+ pytest.skip("Quantization is only supported on CUDA devices")
if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
+ if quantization == "nvfp4" and not nvfp4_available:
+ pytest.skip(reason_for_no_nvfp4)
+ # Check dims
if dims is not None:
if not isinstance(dims, Iterable):
dims = (dims,)
@@ -92,10 +102,14 @@ def maybe_skip_quantization(
elif quantization == "mxfp8":
if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0:
pytest.skip("MXFP8 GEMMs require dims that are divisible by 32")
+ elif quantization == "nvfp4":
+ if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0:
+ pytest.skip("NVFP4 GEMMs require dims that are divisible by 16")
- # Check if device is supported
- if device is not None and torch.device(device).type != "cuda":
- pytest.skip("Quantization is only supported on CUDA devices")
+ # Check dtype
+ if dtype is not None:
+ if quantization == "nvfp4" and dtype != torch.bfloat16:
+ pytest.skip("NVFP4 quantization is only supported with BF16 data")
@torch.no_grad()
@@ -145,6 +159,14 @@ def make_reference_and_test_tensors(
test = quantizer(test)
elif quantization == "mxfp8":
test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+ elif quantization == "nvfp4":
+ test = NVFP4Quantizer(
+ with_rht=False,
+ with_post_rht_amax=False,
+ with_2d_quantization=False,
+ stochastic_rounding=False,
+ with_random_sign_mask=False,
+ )(test)
else:
raise ValueError(f"Unsupported quantization scheme ({quantization})")
if isinstance(test, QuantizedTensor) and not test_is_quantized:
@@ -399,12 +421,12 @@ def test_fp8_scale_update(
torch.testing.assert_close(
y,
torch.full_like(y, y_val_ref),
- **dtype_tols(tex.DType.kFloat8E4M3),
+ **quantization_tols("fp8_delayed_scaling"),
)
torch.testing.assert_close(
x.grad,
torch.full_like(x.grad, dx_val_ref),
- **dtype_tols(tex.DType.kFloat8E5M2),
+ **quantization_tols("fp8_delayed_scaling"),
)
# Check that scaling factors match expected
@@ -438,7 +460,8 @@ def test_dtype_cast(
# Skip invalid configurations
in_shape = (size, size)
with_quantization = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=init_dtype)
+ maybe_skip_quantization(quantization, dtype=final_dtype)
# Random data
dtype = torch.float32
@@ -506,7 +529,8 @@ def test_pyt_autocast(
# Skip invalid configurations
in_shape = (size, size)
quantized_compute = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=model_dtype)
+ maybe_skip_quantization(quantization, dtype=autocast_dtype)
# Construct operation
recipe = make_recipe(quantization)
@@ -562,7 +586,7 @@ def test_identity(
# Skip invalid configurations
with_quantization = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
@@ -628,7 +652,7 @@ def test_reshape(
# Skip invalid configurations
if memory_format == torch.channels_last and len(in_shape) != 4:
pytest.skip("torch.channels_last only supports 4D tensors")
- maybe_skip_quantization(quantization, device=device)
+ maybe_skip_quantization(quantization, device=device, dtype=dtype)
with_quantization = quantization is not None
# Random data
@@ -694,7 +718,7 @@ def test_bias(
# Skip invalid configurations
with_quantization = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
@@ -756,7 +780,7 @@ def test_quantize(
# Skip invalid configurations
with_quantization = quantization is not None
- maybe_skip_quantization(quantization, device=device)
+ maybe_skip_quantization(quantization, device=device, dtype=dtype)
if quantization == "mxfp8":
maybe_skip_quantization(quantization, dims=in_shape)
@@ -823,7 +847,7 @@ def _test_basic_linear(
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
quantization_needed = any(
(
@@ -903,7 +927,7 @@ def _test_basic_linear(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute or quantized_output or quantized_grad_input:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1024,7 +1048,7 @@ def test_linear(
out_shape = in_shape[:-1] + [out_features]
# Skip invalid configurations
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantization is None and (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not specified")
@@ -1091,7 +1115,7 @@ def test_linear(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1128,7 +1152,7 @@ def test_layer_norm(
in_shape = list(in_shape)[:-1] + list(weight_shape)
# Skip invalid configurations
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
@@ -1189,7 +1213,7 @@ def test_layer_norm(
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1298,7 +1322,7 @@ def test_rmsnorm(
in_shape = list(in_shape)[:-1] + list(weight_shape)
# Skip invalid configurations
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
@@ -1354,7 +1378,7 @@ def test_rmsnorm(
# Explicit checks for quantization
if quantized_compute:
- tols = dtype_tols(y_test._quantizer.dtype)
+ tols = quantization_tols(quantization)
expected_tensor_cls = {
Float8Quantizer: Float8Tensor,
Float8CurrentScalingQuantizer: Float8Tensor,
@@ -1441,7 +1465,7 @@ def test_add_extra_input(
# Skip invalid configurations
with_quantization = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x1_ref, x1_test = make_reference_and_test_tensors(
@@ -1480,8 +1504,11 @@ def test_add_extra_input(
# Check results
tols = dtype_tols(dtype)
- if with_quantization:
- tols = dtype_tols(x1_test._fp8_dtype)
+ if in_place:
+ if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"):
+ tols = dtype_tols(x1_test._fp8_dtype)
+ elif quantization == "nvfp4":
+ tols = dtype_tols(x1_test._fp4_dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu")
dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu")
@@ -1510,7 +1537,7 @@ def test_make_extra_output(
# Skip invalid configurations
with_quantization = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
@@ -1583,7 +1610,7 @@ def test_activation(
# Skip invalid configurations
quantized_compute = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
if cache_quantized_input:
maybe_skip_quantization("fp8_current_scaling", device=device)
@@ -1657,8 +1684,10 @@ def test_activation(
# Expected numerical error
tols = dtype_tols(dtype)
- if quantized_compute or cache_quantized_input:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ if quantized_compute:
+ tols = quantization_tols(quantization)
+ elif cache_quantized_input:
+ tols = quantization_tols("fp8_current_scaling")
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1689,7 +1718,7 @@ def test_swiglu(
quantized_compute = quantization is not None
if not quantized_compute and (quantize_forward or quantize_backward):
pytest.skip("Quantization scheme has not been provided")
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
@@ -1723,7 +1752,7 @@ def test_swiglu(
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1791,7 +1820,7 @@ def test_dropout(
# Skip invalid configurations
quantized_input = quantization is not None
- maybe_skip_quantization(quantization, dims=shape, device=device)
+ maybe_skip_quantization(quantization, dims=shape, device=device, dtype=dtype)
# Random data
# Note: Shift values to make sure inputs are non-zero
@@ -1882,7 +1911,7 @@ def test_forward_linear_bias_activation(
# Skip invalid configurations
quantized_compute = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if dtype not in (torch.float16, torch.bfloat16):
pytest.skip(
@@ -1953,7 +1982,7 @@ def test_forward_linear_bias_activation(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -1989,7 +2018,7 @@ def test_forward_linear_bias_add(
# Skip invalid configurations
quantized_compute = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2064,7 +2093,7 @@ def test_forward_linear_bias_add(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2102,7 +2131,7 @@ def test_forward_linear_scale_add(
# Skip invalid configurations
quantized_compute = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2170,7 +2199,7 @@ def test_forward_linear_scale_add(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2203,7 +2232,7 @@ def test_backward_activation_bias(
# Skip invalid configurations
with_quantization = quantization is not None
- maybe_skip_quantization(quantization, device=device)
+ maybe_skip_quantization(quantization, device=device, dtype=dtype)
if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0):
pytest.skip("Unsupported tensor size for MXFP8")
@@ -2265,7 +2294,7 @@ def test_backward_activation_bias(
# Expected numerical error
tols = dtype_tols(dtype)
if with_quantization:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2384,7 +2413,7 @@ def test_backward_linear_add(
# Skip invalid configurations
quantized_compute = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2452,7 +2481,7 @@ def test_backward_linear_add(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y1_test = y1_test.to(dtype=torch.float64, device="cpu")
@@ -2487,7 +2516,7 @@ def test_backward_linear_scale(
# Skip invalid configurations
quantized_compute = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
if quantized_compute and dtype not in (torch.float16, torch.bfloat16):
pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output")
@@ -2547,7 +2576,7 @@ def test_backward_linear_scale(
if dtype == torch.float32:
tols = dtype_tols(torch.float16) # TF32 GEMM
if quantized_compute:
- tols = dtype_tols(tex.DType.kFloat8E4M3)
+ tols = quantization_tols(quantization)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -2588,7 +2617,7 @@ def test_linear(
# Skip invalid configurations
quantized_compute = quantization is not None
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
# Construct model
@@ -2714,7 +2743,7 @@ def test_layernorm_mlp(
ffn_shape = in_shape[:-1] + (ffn_hidden_size,)
# Skip invalid configurations
- maybe_skip_quantization(quantization, dims=in_shape, device=device)
+ maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
quantization_needed = quantized_compute or quantized_weight
if quantization is None and quantization_needed:
diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py
index c59bf376a..cfbb8094b 100644
--- a/tests/pytorch/test_recipe.py
+++ b/tests/pytorch/test_recipe.py
@@ -21,6 +21,7 @@
fp8_model_init,
)
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.utils import is_fp8_fnuz
from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
@@ -502,3 +503,39 @@ def test_quantizer_update(self, module_class):
y = module(x, [batch_size])
else:
y = module(x)
+
+
+fp4_available, reason_for_no_fp4 = FP8GlobalStateManager.is_nvfp4_available()
+
+
+@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4)
+@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str)
+@pytest.mark.parametrize(
+ "M, N",
+ [
+ # full tile cases
+ (128, 128),
+ (256, 1024),
+ (1024, 256),
+ # Padding required cases
+ (256, 272),
+ (304, 304),
+ (320, 256),
+ # largest tile
+ (8192, 8192),
+ ],
+)
+def test_fp4_dequantize(dtype, M, N):
+ q = NVFP4Quantizer()
+ a = torch.rand((M, N)).cuda().to(dtype=dtype)
+ starting_tensor = q(a)
+ dequantized_tensor = starting_tensor.dequantize()
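+ # Round-trip check: quantizing the dequantized tensor again is expected to
+ # reproduce the original rowwise FP4 payload bit-for-bit (rtol=0, atol=0).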
+ new_tensor = q(dequantized_tensor)
+ torch.testing.assert_close(
+ new_tensor._rowwise_data,
+ starting_tensor._rowwise_data,
+ rtol=0,
+ atol=0,
+ )
+ new_dequantized_tensor = new_tensor.dequantize()
+ torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor)
diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index a7d762c3d..dbc1ed42f 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -90,9 +90,19 @@ def is_fp8_supported(config: ModelConfig):
"large": ModelConfig(2, 128, 4, 128, num_layers=1),
}
+
+def nvfp4_vanilla():
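+ """Create an NVFP4 block-scaling recipe with default QParams for fwd input, fwd weight, and bwd grad."""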
+ nvfp4_recipe = recipe.NVFP4BlockScaling()
+ nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams()
+ nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams()
+ nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams()
+ return nvfp4_recipe
+
+
fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
+ fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this
if fp8_block_scaling_available:
fp8_recipes.append(recipe.Float8BlockScaling())
if fp8_available:
@@ -382,6 +392,8 @@ def test_sanity_layernorm_linear(
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4() and dtype == torch.float16:
+ pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
@@ -410,6 +422,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4() and dtype == torch.float16:
+ pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
@@ -440,6 +454,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4() and dtype == torch.float16:
+ pytest.skip("FP16 output for NVFP4 not supported")
use_fp8 = fp8_recipe is not None
with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
@@ -479,6 +495,8 @@ def test_sanity_grouped_linear(
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4():
+ pytest.skip("NVFP4 not supported for grouped linear")
use_fp8 = fp8_recipe is not None
with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
@@ -529,6 +547,8 @@ def test_sanity_layernorm_mlp(
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4() and dtype == torch.float16:
+ pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
@@ -571,6 +591,8 @@ def test_sanity_gpt(
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4() and dtype == torch.float16:
+ pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
@@ -632,6 +654,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization):
pytest.skip(reason_for_no_fp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4() and dtype == torch.float16:
+ pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
@@ -686,6 +710,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization):
pytest.skip(reason_for_no_fp8)
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4() and dtype == torch.float16:
+ pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
@@ -737,6 +763,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4() and dtype == torch.float16:
+ pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
@@ -767,6 +795,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model):
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4() and dtype == torch.float16:
+ pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
@@ -801,6 +831,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4() and dtype == torch.float16:
+ pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
@@ -835,6 +867,8 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
+ if fp8_recipe.nvfp4() and dtype == torch.float16:
+ pytest.skip("FP16 output for NVFP4 not supported")
sigma = 0.023
init_method = init_method_normal(sigma)
diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py
index 684c15737..96f41c72e 100644
--- a/tests/pytorch/utils.py
+++ b/tests/pytorch/utils.py
@@ -24,6 +24,7 @@
get_attention_backend,
AttentionParams,
AttentionLogging,
+ check_set_window_size,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
@@ -78,6 +79,8 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
# Transformer Engine dtypes
if isinstance(dtype, tex.DType):
+ if dtype == tex.DType.kFloat4E2M1:
+ return dict(rtol=0.25, atol=0.125) # epsilon = 0.25
dtype = {
tex.DType.kByte: torch.uint8,
tex.DType.kInt32: torch.int32,
@@ -100,10 +103,25 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]:
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz:
return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625
if dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz:
- return dict(rtol=0.25, atol=0.125) # epsilon = 0.152
+ return dict(rtol=0.25, atol=0.125) # epsilon = 0.125
raise ValueError(f"Unsupported dtype ({dtype})")
+def quantization_tols(name: str) -> dict[str, float]:
+ """Estimated numerical error for a quantization scheme"""
+ if name in (
+ "fp8",
+ "fp8_delayed_scaling",
+ "fp8_current_scaling",
+ "mxfp8",
+ "mxfp8_block_scaling",
+ ):
+ return dtype_tols(tex.DType.kFloat8E4M3)
+ if name == "nvfp4":
+ return dtype_tols(tex.DType.kFloat4E2M1)
+ raise ValueError(f"Unsupported quantization scheme ({name})")
+
+
def make_recipe(name: Optional[str]) -> Optional[Recipe]:
"""Make recipe for quantization scheme"""
if name is None:
@@ -123,6 +141,12 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]:
)
if name == "fp8_block_scaling":
return transformer_engine.common.recipe.Float8BlockScaling()
+ if name == "nvfp4":
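+ # NVFP4 recipe with random Hadamard transform, stochastic rounding, and
+ # 2D quantization disabled (see flags below).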
+ return transformer_engine.common.recipe.NVFP4BlockScaling(
+ disable_rht=True,
+ disable_stochastic_rounding=True,
+ disable_2d_quantization=True,
+ )
raise ValueError(f"Unsupported quantization scheme ({name})")
@@ -143,6 +167,31 @@ def reset_rng_states() -> None:
torch.cuda.set_rng_state(cuda_rng_state)
+def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8):
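+ """Elementwise comparison with atol/rtol; for FP8 data an elementwise mismatch
+ is only logged, and the check instead enforces an RMSE bound relative to the
+ combined value range (rmse < rmse_tol * range)."""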
+ if not is_fp8:
+ torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
+ return
+
+ try:
+ if a.dtype != b.dtype:
+ a = a.to(b.dtype)
+ torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
+ except Exception as e:
+ logging.debug(e)
+
+ rmse = torch.sqrt((a - b).square().mean()).item()
+ logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse))
+ rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item())
+ assert rmse < rmse_tol * rmse_range, (
+ name_a
+ + " vs "
+ + name_b
+ + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+ rmse, rmse_tol * rmse_range, rmse_tol, rmse_range
+ )
+ )
+
+
class ModelConfig:
def __init__(
self,
@@ -153,12 +202,15 @@ def __init__(
max_seqlen_kv: int = None,
num_gqa_groups: int = None,
head_dim_v: int = None,
+ softmax_type: str = "vanilla",
dropout_p: float = 0.0,
attn_mask_type: str = "no_mask",
attn_bias_type: str = "no_bias",
alibi_type: str = "none",
bias_shape: str = "1hss",
window_size: Tuple[int, int] = (-1, -1),
+ context_parallel: bool = False,
+ cp_comm_type: str = "p2p",
total_requests: int = None,
max_ctx_len: int = None,
num_layers: int = 1,
@@ -177,13 +229,16 @@ def __init__(
self.kv_channels = (self.head_dim_qk, self.head_dim_v)
self.hidden_size = self.num_heads * self.head_dim_qk
self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v
+ self.softmax_type = softmax_type
self.dropout_p = dropout_p
self.attn_mask_type = attn_mask_type
self.attn_bias_type = attn_bias_type
self.alibi_type = alibi_type
self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross"
self.bias_shape = bias_shape
- self.window_size = window_size
+ self.window_size = check_set_window_size(self.attn_mask_type, window_size)
+ self.context_parallel = context_parallel
+ self.cp_comm_type = cp_comm_type
self.total_requests = total_requests
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
@@ -221,9 +276,7 @@ def get_available_attention_backends(
config: ModelConfig,
qkv_dtype: torch.dtype,
qkv_layout: str,
- window_size: Tuple[int, int] = (-1, -1),
pad_between_seqs: bool = False,
- context_parallel: bool = False,
deterministic: bool = False,
fp8: bool = False,
fp8_meta: Optional[Dict[str, Any]] = None,
@@ -276,19 +329,21 @@ def test():
head_dim_qk=config.head_dim_qk,
head_dim_v=config.head_dim_v,
attn_mask_type=config.attn_mask_type,
- window_size=window_size,
+ window_size=config.window_size,
alibi_slopes_shape=alibi_slopes_shape,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias_shape=core_attention_bias_shape,
core_attention_bias_requires_grad=core_attention_bias_requires_grad,
pad_between_seqs=pad_between_seqs,
attention_dropout=config.dropout_p,
- context_parallel=context_parallel,
+ context_parallel=config.context_parallel,
+ cp_comm_type=config.cp_comm_type,
deterministic=deterministic,
fp8=fp8,
fp8_meta=fp8_meta,
is_training=is_training,
inference_params=inference_params,
+ softmax_type=config.softmax_type,
)
(
use_flash_attention,
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index cefec6d06..bd2a023bb 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -120,6 +120,30 @@ endif()
# Python
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
+if(USE_CUDA)
+ # NVIDIA MathDX include directory (from Python package install location)
+ if(NOT DEFINED MATHDX_INCLUDE_DIR)
+ execute_process(
+ COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx
+ OUTPUT_VARIABLE _PIP_SHOW_MATHDX
+ ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR
+ RESULT_VARIABLE _PIP_SHOW_MATHDX_RES
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+ if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0)
+ message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}")
+ endif()
+ string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}")
+ if(NOT _MATHDX_LOC_MATCH)
+ message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}")
+ endif()
+ set(MATHDX_LOCATION "${CMAKE_MATCH_1}")
+ set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include")
+ endif()
+ if(NOT EXISTS "${MATHDX_INCLUDE_DIR}")
+ message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.")
+ endif()
+endif()
+
# Configure Transformer Engine library
include_directories(${PROJECT_SOURCE_DIR}/..)
set(transformer_engine_SOURCES)
@@ -145,6 +169,7 @@ list(APPEND transformer_engine_SOURCES
fused_attn/kv_cache.cu
activation/relu.cu
activation/swiglu.cu
+ gemm/config.cpp
gemm/cublaslt_gemm.cu
normalization/common.cpp
normalization/layernorm/ln_api.cpp
@@ -178,6 +203,7 @@ list(APPEND transformer_engine_SOURCES
cudnn_utils.cpp
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise.cu
+ transpose/quantize_transpose_vector_blockwise_fp4.cu
fused_attn/fused_attn_f16_max512_seqlen.cu
fused_attn/fused_attn_f16_arbitrary_seqlen.cu
fused_attn/fused_attn_fp8.cu
@@ -185,6 +211,9 @@ list(APPEND transformer_engine_SOURCES
fused_attn/utils.cu
gemm/cutlass_grouped_gemm.cu
util/cuda_nvml.cpp
+ recipe/nvfp4.cu
+ hadamard_transform/hadamard_transform.cu
+ hadamard_transform/hadamard_transform_cast_fusion.cu
comm_gemm_overlap/userbuffers/ipcsocket.cc
comm_gemm_overlap/userbuffers/userbuffers-host.cpp
comm_gemm_overlap/userbuffers/userbuffers.cu
@@ -261,7 +290,8 @@ target_link_libraries(transformer_engine PUBLIC
CUDNN::cudnn_all)
target_include_directories(transformer_engine PRIVATE
- ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+ ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR})
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
index ec29e6e12..56369db27 100644
--- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -64,6 +64,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
#endif
_comm_created = true;
}
+
+ initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority,
+ num_comm_sm, set_sm_margin, use_ce, atomic_gemm);
+}
+
+void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams,
+ int comm_cga_size, int gemm_priority, int comm_priority,
+ int num_comm_sm, bool set_sm_margin, bool use_ce,
+ bool atomic_gemm) {
_use_ce = static_cast<int>(use_ce);
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
@@ -278,6 +287,11 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
atomic_gemm) {
+ initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm);
+}
+
+void CommOverlapBase::initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
+ bool rs_overlap_first_gemm) {
_rs_overlap_first_gemm = rs_overlap_first_gemm;
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
@@ -288,7 +302,9 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
- if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
+ if (_ub_comm->myrank == 0) {
+ printf("!!! [UB] Register UBuf %d\n", _ub_reg);
+ }
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
NVTE_CHECK_CUDA(
@@ -640,6 +656,11 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) {
+ initialize(buffer_shape, buffer_dtype, comm_type, aggregate);
+}
+
+void CommOverlapP2PBase::initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
+ CommOverlapType comm_type, bool aggregate) {
_is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate;
@@ -647,28 +668,28 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
// Create workspace tensor with userbuffer
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype);
- int buffer_chunk_bytes = buffer_bytes / tp_size;
- _num_ubuf_chunks = tp_size;
+ int buffer_chunk_bytes = buffer_bytes / _tp_size;
+ _num_ubuf_chunks = _tp_size;
if (_is_reduce_scatter) {
// GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold received GEMM chunk
// outputs for reduction at the end of the pipelining.
- buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1);
- _num_ubuf_chunks = tp_size * 2 - 1;
+ buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1);
+ _num_ubuf_chunks = _tp_size * 2 - 1;
}
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
- if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
+ if (_rank == 0) printf("!!! [UBP2P] UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(
buffer_ptr,
- std::vector<size_t>{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
+ std::vector<size_t>{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]},
buffer_dtype);
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
for (int i = 0; i < _num_ubuf_chunks; i++) {
_ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
- std::vector<size_t>{buffer_shape[0] / tp_size, buffer_shape[1]},
+ std::vector<size_t>{buffer_shape[0] / _tp_size, buffer_shape[1]},
buffer_dtype));
ubuf_byte_ptr += buffer_chunk_bytes;
}
@@ -691,7 +712,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
}
- for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
+ for (int i = 0; i < _stream_compute.size(); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
_stream_send.push_back(std::move(stream));
@@ -711,6 +732,38 @@ CommOverlapP2PBase::~CommOverlapP2PBase() {
}
}
+void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source,
+ bool local_chunk, bool rowwise) {
+ // Check element size
+ const size_t element_size = source.element_size();
+ NVTE_CHECK(_ubuf.element_size() == element_size,
+ "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
+ "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
+ " bytes)");
+
+ // Input data
+ const size_t source_size = source.numel();
+ const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr();
+
+ // Userbuffers data
+ void *dst_ptr;
+ if (local_chunk) {
+ NVTE_CHECK(_ubufs[_tp_id].numel() == source_size,
+ "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
+ "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
+ dst_ptr = _ubufs[_tp_id].dptr();
+ } else {
+ NVTE_CHECK(_ubuf.numel() == source_size,
+ "Tried to copy an invalid tensor into a Userbuffers buffer ",
+ "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")");
+ dst_ptr = _ubuf.dptr();
+ }
+
+ // Copy data
+ NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size,
+ cudaMemcpyDeviceToDevice, stream));
+}
+
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
size_t chunk_id) {
// Start with a chunk of the source tensor
@@ -851,6 +904,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
const bool do_gelu = pre_gelu_out.numel() > 0;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
+ // Check B copy sizing
+ if (B_copy.numel() > 0) {
+ NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ",
+ _ubuf.numel(), " elements but got ", B_copy.numel());
+ NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(),
+ "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8,
+ "-bit data type but got ", B_copy.element_size() * 8, "-bit");
+ }
+
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
@@ -919,12 +981,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
- } else if (B_copy.numel() > 0) {
- assert(B_copy.numel() == _ubufs[_tp_id].numel());
- assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
- NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
- _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
- _stream_send[0]));
}
}
} else {
@@ -972,16 +1028,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
- } else if (B_copy.numel() > 0) {
- assert(B_copy.numel() == _ubufs[_tp_id].numel());
- assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
- NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
- _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice,
- _stream_send[0]));
}
}
}
+ // Copy all-gathered B from communication buffer into auxiliary output
+ if (B_copy.numel() > 0) {
+ NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(),
+ cudaMemcpyDeviceToDevice, _stream_send[0]));
+ }
+
_ub_comm->sms = ori_sms;
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
index 1ce89c512..6c7bed55a 100644
--- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
+++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
@@ -670,9 +670,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
reinterpret_cast(&memhndl), sizeof(cudaIpcMemHandle_t),
comm->comm_intra);
+ // Check for NVLINK support before attempting IPC operations
+ if (comm->nvsize > 1) {
+ int current_device;
+ NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
+ cudaDeviceProp deviceProp;
+ NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device));
+ bool peer_access_available = false;
+ for (int i = 0; i < comm->nvsize; i++) {
+ if (i != comm->nvrank) {
+ int can_access_peer;
+ cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i);
+ if (peer_result == cudaSuccess && can_access_peer) {
+ peer_access_available = true;
+ break;
+ }
+ }
+ }
+ if (!peer_access_available) {
+ free(tmp);
+ NVTE_ERROR(
+ "No peer-to-peer access available between GPUs. This platform does not support the "
+ "GPU-to-GPU "
+ "communication required for multi-GPU userbuffers. Consider using single-GPU mode.");
+ return 1;
+ }
+ }
+
for (int i = 0; i < comm->nvsize; i++) {
if (i != comm->nvrank) {
- NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*)
+ NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i],
cudaIpcMemLazyEnablePeerAccess));
}
}
@@ -693,4 +720,5 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
comm->mem_ptr[hndl] = *gpubuff;
return comm->free_region++;
+ printf("***** Returning *****\n");
}
diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu
index e67694c38..af3a51373 100644
--- a/transformer_engine/common/common.cu
+++ b/transformer_engine/common/common.cu
@@ -42,6 +42,10 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
return CUDA_R_8F_E4M3;
case DType::kFloat8E5M2:
return CUDA_R_8F_E5M2;
+#if CUDA_VERSION >= 12080
+ case DType::kFloat4E2M1:
+ return CUDA_R_4F_E2M1;
+#endif
default:
NVTE_ERROR("Invalid type");
}
@@ -165,7 +169,9 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
- const uint32_t offset_elems, const size_t type_num_bits) {
+ const uint32_t offset_elems, const size_t type_num_bits,
+ const CUtensorMapSwizzle swizzle) {
+ cuda_driver::ensure_context_exists();
// Get a function pointer to the cuTensorMapEncodeTiled driver API
// Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
@@ -174,6 +180,8 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
}();
// rank is the number of dimensions of the array
constexpr uint32_t rank = 2;
+
+ // Dimensions for packed data types must reflect the number of individual values, not the number of packed storage elements.
uint64_t size[rank] = {globalX, globalY};
// The stride is the number of bytes to traverse from the first element of one row to the next
@@ -212,7 +220,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
// Swizzling can be used to avoid shared memory bank conflicts.
- CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
+ swizzle,
// L2 Promotion can be used to widen the effect of a cache-policy to a wider
// set of L2 cache lines.
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index ce510334b..d1208f1a0 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -54,8 +54,14 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) {
return mode == NVTE_DELAYED_TENSOR_SCALING;
}
+inline bool is_nvfp4_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; }
+
+inline bool is_mxfp8_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }
+
inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; }
+inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; }
+
inline size_t product(const std::vector<size_t> &shape, const size_t begin, const size_t end) {
NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ",
end, " in a vector with ", shape.size(), " entries");
@@ -114,6 +120,7 @@ struct Tensor {
SimpleTensor data;
SimpleTensor columnwise_data;
SimpleTensor amax;
+ SimpleTensor columnwise_amax;
SimpleTensor scale;
SimpleTensor scale_inv;
SimpleTensor columnwise_scale_inv;
@@ -125,6 +132,7 @@ struct Tensor {
: data(),
columnwise_data(),
amax(nullptr, {1}, DType::kFloat32),
+ columnwise_amax(nullptr, {1}, DType::kFloat32),
scale(nullptr, {1}, DType::kFloat32),
scale_inv(nullptr, {1}, DType::kFloat32),
columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
@@ -135,6 +143,7 @@ struct Tensor {
data.clear();
columnwise_data.clear();
amax.clear();
+ columnwise_amax.clear();
scale.clear();
scale_inv.clear();
columnwise_scale_inv.clear();
@@ -180,6 +189,7 @@ struct Tensor {
* https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
*/
switch (scaling_mode) {
+ case NVTE_NVFP4_1D_SCALING:
case NVTE_DELAYED_TENSOR_SCALING:
if (!has_data() && has_columnwise_data()) {
std::vector<size_t> ret;
@@ -195,7 +205,6 @@ struct Tensor {
}
break;
case NVTE_MXFP8_1D_SCALING:
- case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING:
if (!has_data() && has_columnwise_data()) {
return columnwise_data.shape;
} else {
@@ -267,12 +276,18 @@ struct QuantizationConfig {
NVTETensor noop_tensor = nullptr;
Float8BlockScaleTensorFormat float8_block_scale_tensor_format =
Float8BlockScaleTensorFormat::GEMM_READY;
+ NVTETensor rng_state = nullptr;
+ bool nvfp4_2d_quantization = false;
+ bool stochastic_rounding = false;
static constexpr size_t attr_sizes[] = {
- sizeof(bool), // force_pow_2_scales
- sizeof(float), // amax_epsilon
- sizeof(NVTETensor), // noop_tensor
- sizeof(Float8BlockScaleTensorFormat) // float8_block_scale_tensor_format
+ sizeof(bool), // force_pow_2_scales
+ sizeof(float), // amax_epsilon
+ sizeof(NVTETensor), // noop_tensor
+ sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format
+ sizeof(NVTETensor), // rng_seed and offset
+ sizeof(bool), // nvfp4_2d_quantization
+ sizeof(bool) // stochastic_rounding
};
};
@@ -322,6 +337,8 @@ using fp8e8m0 = __nv_fp8_e8m0;
#endif // CUDA_VERSION >= 12080
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
+using fp4e2m1x2 = __nv_fp4x2_e2m1;
+using fp4e2m1x4 = __nv_fp4x4_e2m1;
#endif //FP4_TYPE_SUPPORTED
#else
using bf16 = hip_bfloat16;
@@ -370,6 +387,7 @@ struct TypeExtrema;
template <>
struct TypeExtrema<fp4e2m1> {
static constexpr float max = 6.0f;
+ static constexpr float max_inverse = 1.0 / max;
};
#endif
@@ -377,16 +395,20 @@ template <>
struct TypeExtrema<fp8e4m3> {
#ifndef __HIP_PLATFORM_AMD__
static constexpr float max = 448.0f;
+ static constexpr float max_inverse = 1.0 / max;
#elif defined(__HIP_DEVICE_COMPILE__)
static constexpr float maxNorm = te_fp8_fnuz() ? 240.0f : 448.0f;
+ static constexpr float max_inverse = 1.0 / maxNorm;
#else
static float maxNorm;
+ static constexpr float max_inverse = 1.0 / 448.0f;
#endif
};
template <>
struct TypeExtrema<fp8e5m2> {
static constexpr float max = 57344.0f;
+ static constexpr float max_inverse = 1.0 / max;
};
template <>
@@ -600,6 +622,18 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}
+// Add a pack_size argument to select the packed type for FP4
+#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \
+ switch (dtype) { \
+ using namespace transformer_engine; \
+ case DType::kFloat4E2M1: { \
+ using type = __nv_fp4x2_storage_t; \
+ { __VA_ARGS__ } \
+ } break; \
+ default: \
+ NVTE_ERROR("Invalid type."); \
+ }
+
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
@@ -764,10 +798,11 @@ void checkCuDriverContext(CUstream stream);
CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
// Set up parameters to create TMA descriptor.
-void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
- const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
- const uint32_t shmemX, const uint32_t stride_elems,
- const uint32_t offset_elems, const size_t type_num_bits);
+void create_2D_tensor_map(
+ CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY,
+ const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX,
+ const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits,
+ const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE);
#endif //#ifdef __HIP_PLATFORM_AMD__
bool is_supported_by_CC_100();
diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index 795697635..77cd8d235 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -135,9 +135,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
- NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
- size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
- size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
+ NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+ float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+ size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
+ int64_t window_size_right) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
@@ -175,7 +176,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// sm90: fwd d<=256, bwd d=128 only
// sm100: fwd d<=128, bwd d<=128
- ((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) ||
+ ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) ||
+ (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) ||
(sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
@@ -183,7 +185,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
- !requires_64bit_ragged_offset &&
+ !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) &&
// 9.10.0: known bugs with SDPA FP8
(cudnn_runtime_version != 91000)) {
if (cudnn_runtime_version >= 8900) {
@@ -213,7 +215,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
(qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) &&
((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) &&
- !requires_64bit_ragged_offset) {
+ !requires_64bit_ragged_offset &&
+ (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) {
flag_m512 = true;
}
if (
@@ -363,7 +366,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// check 64-bit ragged offset support
(supported_ragged_offset_size) &&
// 9.10.0/9.10.1: known bugs with SDPA F16
- (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) {
+ (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) &&
+ // softmax type
+ // pre-9.13.1: vanilla
+ // 9.13.1+: vanilla, off-by-one, learnable
+ (cudnn_runtime_version >= 91301 ||
+ (cudnn_runtime_version < 91301 &&
+ softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
@@ -405,14 +414,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}
// NVTE fused attention FWD with packed QKV
-void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
- NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
- const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
- const NVTETensor rng_state, size_t max_seqlen, bool is_training,
- float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
+void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
+ const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
+ NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
+ const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
+ size_t max_seqlen, bool is_training, float attn_scale,
+ float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
- int64_t window_size_left, int64_t window_size_right,
- NVTETensor workspace, cudaStream_t stream) {
+ NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, NVTETensor workspace,
+ cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
@@ -421,6 +432,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
const Tensor *input_QKV = convertNVTETensorCheck(QKV);
const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+ const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
Tensor *input_output_S = convertNVTETensorCheck(S);
Tensor *output_O = convertNVTETensorCheck(O);
Tensor *wkspace = convertNVTETensor(workspace);
@@ -447,8 +459,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
- max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
+ is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
+ h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -463,9 +475,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
#if (CUDNN_VERSION >= 8900)
fused_attn_arbitrary_seqlen_fwd_qkvpacked(
b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type,
- attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O,
- Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
- stream, handle);
+ attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias,
+ input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded,
+ input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -487,10 +499,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
- NVTETensor dBias, const NVTETensor cu_seqlens,
- const NVTETensor cu_seqlens_padded, size_t max_seqlen,
- float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
- NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+ NVTETensor dBias, NVTETensor dSoftmaxOffset,
+ const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
+ size_t max_seqlen, float attn_scale, float dropout,
+ NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+ NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
@@ -505,6 +518,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
Tensor *input_output_dP = convertNVTETensorCheck(dP);
Tensor *output_dQKV = convertNVTETensorCheck(dQKV);
Tensor *output_dBias = convertNVTETensorCheck(dBias);
+ Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
Tensor *wkspace = convertNVTETensor(workspace);
auto ndim = input_QKV->data.shape.size();
@@ -529,8 +543,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
- max_seqlen, d, d, window_size_left, window_size_right);
+ true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
+ max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -543,19 +557,22 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- Tensor *input_Bias, *input_rng_state;
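+ // Aux tensor layout: [softmax stats, rng_state, bias (if present), softmax offset (if present)]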
+ size_t i = 0;
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ Tensor *input_Bias, *input_SoftmaxOffset;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
- input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
- } else {
- input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ }
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
fused_attn_arbitrary_seqlen_bwd_qkvpacked(
b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
- window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO,
- input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded,
- input_rng_state, wkspace, stream, handle);
+ softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O,
+ input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias,
+ output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace,
+ stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
@@ -580,14 +597,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
}
// NVTE fused attention FWD with packed KV
void nvte_fused_attn_fwd_kvpacked(
- const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
- NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
- const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
- const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
- size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
+ const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
+ NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
+ const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+ const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
+ const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
+ size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
- int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
- cudaStream_t stream) {
+ NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+ NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -600,6 +618,7 @@ void nvte_fused_attn_fwd_kvpacked(
const Tensor *input_Q = convertNVTETensorCheck(Q);
const Tensor *input_KV = convertNVTETensorCheck(KV);
const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+ const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
Tensor *input_output_S = convertNVTETensorCheck(S);
Tensor *output_O = convertNVTETensorCheck(O);
Tensor *wkspace = convertNVTETensor(workspace);
@@ -660,8 +679,8 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
- max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
+ is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
+ h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -677,10 +696,11 @@ void nvte_fused_attn_fwd_kvpacked(
fused_attn_arbitrary_seqlen_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
- dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
- input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
- input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
- input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
+ dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
+ window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O,
+ Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+ input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
+ wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -702,12 +722,12 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
- NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
- const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
- size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
- NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
- int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
- cudaStream_t stream) {
+ NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
+ const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+ const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
+ float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+ NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -723,6 +743,7 @@ void nvte_fused_attn_bwd_kvpacked(
Tensor *output_dQ = convertNVTETensorCheck(dQ);
Tensor *output_dKV = convertNVTETensorCheck(dKV);
Tensor *output_dBias = convertNVTETensorCheck(dBias);
+ Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
Tensor *wkspace = convertNVTETensor(workspace);
size_t b = input_cu_seqlens_q->data.shape[0] - 1;
@@ -755,8 +776,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
- max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
+ true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
+ h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -770,20 +791,23 @@ void nvte_fused_attn_bwd_kvpacked(
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8903)
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- Tensor *input_Bias, *input_rng_state;
+ size_t i = 0;
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ Tensor *input_Bias, *input_SoftmaxOffset;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
- input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
- } else {
- input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ }
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
fused_attn_arbitrary_seqlen_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
- bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q,
- input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias,
- input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
- input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
+ bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic,
+ input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ,
+ output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv,
+ input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream,
+ handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
@@ -809,16 +833,17 @@ void nvte_fused_attn_bwd_kvpacked(
}
// NVTE fused attention FWD with separate Q, K and V
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
- const NVTETensor Bias, NVTETensor S, NVTETensor O,
- NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
- const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+ const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+ NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+ const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+ const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
- int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
- cudaStream_t stream) {
+ NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -832,6 +857,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const Tensor *input_K = convertNVTETensorCheck(K);
const Tensor *input_V = convertNVTETensorCheck(V);
const Tensor *input_Bias = convertNVTETensorCheck(Bias);
+ const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset);
Tensor *input_output_S = convertNVTETensorCheck(S);
Tensor *output_O = convertNVTETensorCheck(O);
Tensor *wkspace = convertNVTETensor(workspace);
@@ -886,8 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
- max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+ is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
+ h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -903,10 +929,11 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_arbitrary_seqlen_fwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale,
- dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right,
- input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
- input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
- input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
+ dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left,
+ window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O,
+ Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+ input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state,
+ wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
@@ -928,14 +955,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
- NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
- const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+ NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
+ const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+ const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
- NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
- int64_t window_size_right, bool deterministic, NVTETensor workspace,
- cudaStream_t stream) {
+ NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+ int64_t window_size_left, int64_t window_size_right, bool deterministic,
+ NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -953,6 +981,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
Tensor *output_dK = convertNVTETensorCheck(dK);
Tensor *output_dV = convertNVTETensorCheck(dV);
Tensor *output_dBias = convertNVTETensorCheck(dBias);
+ Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset);
Tensor *wkspace = convertNVTETensor(workspace);
auto ndim = input_Q->data.shape.size();
@@ -978,8 +1007,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
- max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+ true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
+ h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -993,19 +1022,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- Tensor *input_Bias, *input_rng_state;
+ size_t i = 0;
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ Tensor *input_Bias, *input_SoftmaxOffset;
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
- input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
- } else {
- input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ }
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
}
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
- qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic,
- input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK,
- output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
+ qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
+ deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias,
+ input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
+ output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
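As a minimal sketch of the aux-tensor convention the hunks above establish (slot 0 = softmax stats S, slot 1 = RNG state, then Bias only for a true post-scale bias, then SoftmaxOffset only for a non-vanilla softmax), a hypothetical caller-side helper could unpack the pack as below. The helper name is illustrative, and the assumption that NVTE_Softmax_Type / NVTE_VANILLA_SOFTMAX are exposed from transformer_engine/fused_attn.h follows the usage in these hunks rather than anything stated elsewhere in the patch.

#include <cstddef>
#include <transformer_engine/fused_attn.h>

// Mirrors the running-index convention used in nvte_fused_attn_bwd above:
// S, RNG state, then Bias (only for a real post-scale bias), then
// SoftmaxOffset (only when softmax_type != NVTE_VANILLA_SOFTMAX).
static void unpack_aux_ctx(const NVTETensorPack *aux, NVTE_Bias_Type bias_type,
                           NVTE_Softmax_Type softmax_type, NVTETensor *S, NVTETensor *rng_state,
                           NVTETensor *bias, NVTETensor *softmax_offset) {
  size_t i = 0;
  *S = aux->tensors[i++];
  *rng_state = aux->tensors[i++];
  *bias =
      ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) ? aux->tensors[i++] : nullptr;
  *softmax_offset = (softmax_type != NVTE_VANILLA_SOFTMAX) ? aux->tensors[i++] : nullptr;
}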
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index 4e6c3c858..ba0f84578 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -54,10 +54,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k,
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
- NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
- int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias,
- void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
- void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
+ NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+ int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
+ void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrSoftmaxStats,
+ void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
+ void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
@@ -75,6 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
is_causal = true;
is_bottom_right = false;
}
+ bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
bool is_dropout = (is_training && dropout_probability != 0.0f);
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
@@ -98,8 +100,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
s_q = is_ragged_q ? max_t_q : s_q;
s_kv = is_ragged_kv ? max_t_kv : s_kv;
}
- const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
+ const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
try {
FADescriptor_v1 descriptor{b,
h,
@@ -122,11 +124,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
layout,
bias_type,
mask_type,
+ softmax_type,
window_size_left,
window_size_right,
true,
tensorType,
- tensorType};
+ cudnn_frontend::DataType_t::NOT_SET,
+ cudnn_frontend::DataType_t::NOT_SET,
+ cudnn_frontend::DataType_t::NOT_SET};
namespace fe = cudnn_frontend;
using graph_and_tensors =
@@ -138,6 +143,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // O
std::shared_ptr<fe::graph::Tensor_attributes>, // Stats
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
+ std::shared_ptr<fe::graph::Tensor_attributes>, // softmax_offset
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // page_table_k
@@ -168,7 +174,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
- std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale;
+ std::shared_ptr<fe::graph::Tensor_attributes> Q, K, V, attn_scale, softmax_offset;
std::shared_ptr<fe::graph::Tensor_attributes> bias, seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> page_table_k, page_table_v;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
@@ -302,6 +308,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
}
+ if (is_softmax_offset) {
+ softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+ .set_name("softmax_offset")
+ .set_dim({1, h, 1, 1})
+ .set_stride({h, 1, 1, 1})
+ .set_data_type(fe::DataType_t::FLOAT));
+ sdpa_options.set_sink_token(softmax_offset);
+ }
+
auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options);
std::vector<int64_t> o_stride(4);
@@ -338,6 +353,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
auto Stats_tuple = std::make_tuple(Stats);
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
+ auto softmax_offset_tuple =
+ is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr);
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v)
@@ -358,17 +375,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
- auto return_tuple = std::tuple_cat(
- std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple,
- page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
+ auto return_tuple =
+ std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple,
+ softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple,
+ offset_kv_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
- auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v,
- offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
- get_graph(sdpa_f16_fprop_cache, descriptor);
+ auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, softmax_offset, seq_q, seq_kv,
+ page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats,
+ dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
// n.b. Care should be taken to align each of the added workspace tensors to their type.
@@ -473,6 +491,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
variant_pack[dropout_seed] = devPtrDropoutSeed;
variant_pack[dropout_offset] = devPtrDropoutOffset;
}
+
+ if (is_softmax_offset) {
+ variant_pack[softmax_offset] = devPtrSoftmaxOffset;
+ }
+
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
@@ -483,14 +506,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v,
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
- NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
- int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose,
- void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias,
- void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
- void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
- void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV,
- cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size,
- cudaStream_t stream, cudnnHandle_t handle) {
+ NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+ int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
+ void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
+ void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
+ void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
+ void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
+ void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
+ void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -506,6 +529,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
is_causal = true;
is_bottom_right = false;
}
+ bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
bool is_dropout = (dropout_probability != 0.0f);
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
@@ -558,11 +582,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
layout,
bias_type,
mask_type,
+ softmax_type,
window_size_left,
window_size_right,
deterministic,
tensorType,
- tensorType};
+ cudnn_frontend::DataType_t::NOT_SET,
+ cudnn_frontend::DataType_t::NOT_SET,
+ cudnn_frontend::DataType_t::NOT_SET};
namespace fe = cudnn_frontend;
using graph_and_tensors =
@@ -579,6 +606,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>, // dV
std::shared_ptr<fe::graph::Tensor_attributes>, // bias
std::shared_ptr<fe::graph::Tensor_attributes>, // dBias
+ std::shared_ptr<fe::graph::Tensor_attributes>, // softmax_offset
+ std::shared_ptr<fe::graph::Tensor_attributes>, // d_softmax_offset
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_q
std::shared_ptr<fe::graph::Tensor_attributes>, // seq_kv
std::shared_ptr<fe::graph::Tensor_attributes>, // offset_q
@@ -608,7 +637,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
.set_compute_data_type(fe::DataType_t::FLOAT);
std::shared_ptr<fe::graph::Tensor_attributes> q, k, v, o, dO, stats, attn_scale;
- std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, seq_q, seq_kv;
+ std::shared_ptr<fe::graph::Tensor_attributes> bias, dBias, softmax_offset, d_softmax_offset,
+ seq_q, seq_kv;
std::shared_ptr<fe::graph::Tensor_attributes> offset_q, offset_k, offset_v, offset_o,
offset_stats;
std::shared_ptr<fe::graph::Tensor_attributes> dropout_seed, dropout_offset;
@@ -771,6 +801,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset);
}
+ if (is_softmax_offset) {
+ softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+ .set_name("softmax_offset")
+ .set_dim({1, h, 1, 1})
+ .set_stride({h, 1, 1, 1})
+ .set_data_type(fe::DataType_t::FLOAT));
+ sdpa_backward_options.set_sink_token(softmax_offset);
+ d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes()
+ .set_name("d_softmax_offset")
+ .set_dim({1, h, 1, 1})
+ .set_stride({h, 1, 1, 1})
+ .set_data_type(fe::DataType_t::FLOAT));
+ sdpa_backward_options.set_dsink_token(d_softmax_offset);
+ }
+
auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options);
dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride);
@@ -796,6 +841,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>> // dV
key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV);
auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr);
+ auto softmax_offset_tuple = is_softmax_offset
+ ? std::make_tuple(softmax_offset, d_softmax_offset)
+ : std::make_tuple(nullptr, nullptr);
auto padding_tuple =
is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr);
auto offset_qo_tuple =
@@ -814,17 +862,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle));
NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle));
- auto return_tuple =
- std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple,
- offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple);
+ auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple,
+ softmax_offset_tuple, padding_tuple, offset_qo_tuple,
+ offset_kv_tuple, offset_s_tuple, dropout_tuple);
cache.insert({descriptor, return_tuple});
return return_tuple;
};
- auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv,
- offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] =
- get_graph(sdpa_f16_bprop_cache, descriptor);
+ auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset,
+ d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats,
+ dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor);
// Exit to request upper level API to allocate memory if needed
// n.b. Care should be taken to align each of the added workspace tensors to their type.
@@ -938,6 +986,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
variant_pack[dropout_offset] = devPtrDropoutOffset;
}
+ if (is_softmax_offset) {
+ variant_pack[softmax_offset] = devPtrSoftmaxOffset;
+ variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset;
+ }
+
NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace));
} catch (cudnn_frontend::cudnnException &e) {
NVTE_ERROR(e.what());
@@ -949,8 +1002,9 @@ using namespace transformer_engine::fused_attn;
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
- NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
- int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
+ NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+ int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
+ const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
@@ -977,6 +1031,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
+ void *devPtrSoftmaxOffset = nullptr;
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+ }
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
@@ -990,53 +1048,50 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
max_tokens = get_max_tokens(num_tokens);
}
+ size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_S->data.dptr = nullptr;
+ if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
+ output_S->data.shape = {max_tokens, num_attn_heads, 1};
+ } else {
+ output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
+ }
+ output_S->data.dtype = DType::kFloat32;
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_rng_state->data.dptr = nullptr;
+ output_rng_state->data.shape = {2};
+ output_rng_state->data.dtype = DType::kInt64;
+
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
- Aux_CTX_Tensors->size = 3;
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- output_S->data.dptr = nullptr;
- if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
- output_S->data.shape = {max_tokens, num_attn_heads, 1};
- } else {
- output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
- }
- output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- output_rng_state->data.dptr = nullptr;
- output_rng_state->data.shape = {2};
- output_rng_state->data.dtype = DType::kInt64;
- Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen};
output_bias->data.dtype = QKV_type;
- } else {
- Aux_CTX_Tensors->size = 2;
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- output_S->data.dptr = nullptr;
- if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
- output_S->data.shape = {max_tokens, num_attn_heads, 1};
- } else {
- output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1};
- }
- output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- output_rng_state->data.dptr = nullptr;
- output_rng_state->data.shape = {2};
- output_rng_state->data.dtype = DType::kInt64;
}
- } else if (Aux_CTX_Tensors->size == 2) {
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- output_rng_state->data.dptr = rng_state->data.dptr;
- } else if (Aux_CTX_Tensors->size == 3) {
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_softmax_offset->data.dptr = nullptr;
+ output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
+ output_softmax_offset->data.dtype = DType::kFloat32;
+ }
+
+ Aux_CTX_Tensors->size = i;
+ } else if (Aux_CTX_Tensors->size >= 2) {
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
- Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
- output_bias->data.dptr = devPtrBias;
+ if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_bias->data.dptr = devPtrBias;
+ }
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
+ }
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
@@ -1050,11 +1105,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
fused_attn_arbitrary_seqlen_fwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training,
- attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
- devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
- devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets,
- devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream,
- handle);
+ attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+ window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
+ devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr,
+ nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type),
+ workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
@@ -1074,9 +1129,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
- NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
- bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
- const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
+ NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
+ const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
+ Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
@@ -1122,6 +1178,12 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
+ void *devPtrSoftmaxOffset = nullptr;
+ void *devPtrdSoftmaxOffset = nullptr;
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+ devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
+ }
void *devPtrCuSeqlens = cu_seqlens->data.dptr;
void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr;
@@ -1135,11 +1197,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim,
max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout,
- bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK,
- devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
- devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens,
- devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
- &workspace_size, stream, handle);
+ bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic,
+ devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset,
+ devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed,
+ devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets,
+ get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
@@ -1161,12 +1223,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
- int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
- const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
- NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
- const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
- const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
- Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
+ NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+ const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
+ const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
+ const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+ const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
+ const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
@@ -1192,6 +1254,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
+ void *devPtrSoftmaxOffset = nullptr;
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+ }
void *devPtrO = output_O->data.dptr;
void *devPtrS = nullptr;
@@ -1216,53 +1282,50 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
+ size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_S->data.dptr = nullptr;
+ if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
+ output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
+ } else {
+ output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
+ }
+ output_S->data.dtype = DType::kFloat32;
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_rng_state->data.dptr = nullptr;
+ output_rng_state->data.shape = {2};
+ output_rng_state->data.dtype = DType::kInt64;
+
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
- Aux_CTX_Tensors->size = 3;
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- output_S->data.dptr = nullptr;
- if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
- output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
- } else {
- output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
- }
- output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- output_rng_state->data.dptr = nullptr;
- output_rng_state->data.shape = {2};
- output_rng_state->data.dtype = DType::kInt64;
- Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
- } else {
- Aux_CTX_Tensors->size = 2;
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- output_S->data.dptr = nullptr;
- if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
- output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
- } else {
- output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
- }
- output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- output_rng_state->data.dptr = nullptr;
- output_rng_state->data.shape = {2};
- output_rng_state->data.dtype = DType::kInt64;
}
- } else if (Aux_CTX_Tensors->size == 2) {
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- output_rng_state->data.dptr = rng_state->data.dptr;
- } else if (Aux_CTX_Tensors->size == 3) {
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_softmax_offset->data.dptr = nullptr;
+ output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
+ output_softmax_offset->data.dtype = DType::kFloat32;
+ }
+
+ Aux_CTX_Tensors->size = i;
+ } else if (Aux_CTX_Tensors->size >= 2) {
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
- Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
- output_bias->data.dptr = devPtrBias;
+ if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_bias->data.dptr = devPtrBias;
+ }
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
+ }
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
@@ -1277,11 +1340,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
- attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
- devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
- devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
- devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
- &workspace_size, stream, handle);
+ attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+ window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
+ devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
+ devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
+ get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
@@ -1302,10 +1365,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
- NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
- bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
- const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
- Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
+ NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
+ const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+ const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
+ Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
@@ -1359,6 +1423,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
+ void *devPtrSoftmaxOffset = nullptr;
+ void *devPtrdSoftmaxOffset = nullptr;
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+ devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
+ }
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
@@ -1374,9 +1444,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
- qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ,
- devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV,
- devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
+ qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
+ deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
+ devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
+ devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
@@ -1401,12 +1472,13 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
- NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
- const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
- Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
- const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
- const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
- const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
+ NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
+ const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
+ NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+ const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
+ const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
+ Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
@@ -1425,6 +1497,10 @@ void fused_attn_arbitrary_seqlen_fwd(
bias_b = input_Bias->data.shape[0];
bias_h = input_Bias->data.shape[1];
}
+ void *devPtrSoftmaxOffset = nullptr;
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+ }
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
@@ -1446,53 +1522,50 @@ void fused_attn_arbitrary_seqlen_fwd(
max_tokens_kv = get_max_tokens(num_tokens_kv);
}
+ size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_S->data.dptr = nullptr;
+ if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
+ output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
+ } else {
+ output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
+ }
+ output_S->data.dtype = DType::kFloat32;
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_rng_state->data.dptr = nullptr;
+ output_rng_state->data.shape = {2};
+ output_rng_state->data.dtype = DType::kInt64;
+
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
- Aux_CTX_Tensors->size = 3;
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- output_S->data.dptr = nullptr;
- if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
- output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
- } else {
- output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
- }
- output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- output_rng_state->data.dptr = nullptr;
- output_rng_state->data.shape = {2};
- output_rng_state->data.dtype = DType::kInt64;
- Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_bias->data.dptr = nullptr;
output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv};
output_bias->data.dtype = QKV_type;
- } else {
- Aux_CTX_Tensors->size = 2;
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- output_S->data.dptr = nullptr;
- if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
- output_S->data.shape = {max_tokens_q, num_attn_heads, 1};
- } else {
- output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
- }
- output_S->data.dtype = DType::kFloat32;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- output_rng_state->data.dptr = nullptr;
- output_rng_state->data.shape = {2};
- output_rng_state->data.dtype = DType::kInt64;
}
- } else if (Aux_CTX_Tensors->size == 2) {
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
- devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
- output_rng_state->data.dptr = rng_state->data.dptr;
- } else if (Aux_CTX_Tensors->size == 3) {
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
+
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_softmax_offset->data.dptr = nullptr;
+ output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1};
+ output_softmax_offset->data.dtype = DType::kFloat32;
+ }
+
+ Aux_CTX_Tensors->size = i;
+ } else if (Aux_CTX_Tensors->size >= 2) {
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
devPtrS = output_S->data.dptr;
- Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]);
+ Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
- Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]);
- output_bias->data.dptr = devPtrBias;
+ if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
+ Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_bias->data.dptr = devPtrBias;
+ }
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_softmax_offset->data.dptr = devPtrSoftmaxOffset;
+ }
} else {
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
}
@@ -1507,11 +1580,11 @@ void fused_attn_arbitrary_seqlen_fwd(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
- attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right,
- devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed,
- devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
- devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
- &workspace_size, stream, handle);
+ attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
+ window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS,
+ devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV,
+ devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV,
+ get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
@@ -1532,13 +1605,14 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
- NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
- int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
- const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
- Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
- const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
- const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
- cudaStream_t stream, cudnnHandle_t handle) {
+ NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+ int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
+ const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
+ const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
+ Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
+ Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+ const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
+ Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
@@ -1577,6 +1651,12 @@ void fused_attn_arbitrary_seqlen_bwd(
void *devPtrdV = output_dV->data.dptr;
void *devPtrSoftmaxStats = nullptr;
devPtrSoftmaxStats = output_S->data.dptr;
+ void *devPtrSoftmaxOffset = nullptr;
+ void *devPtrdSoftmaxOffset = nullptr;
+ if (softmax_type != NVTE_VANILLA_SOFTMAX) {
+ devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr;
+ devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr;
+ }
void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr;
void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr;
@@ -1592,9 +1672,10 @@ void fused_attn_arbitrary_seqlen_bwd(
fused_attn_arbitrary_seqlen_bwd_impl(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
- qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ,
- devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV,
- devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
+ qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
+ deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
+ devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
+ devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
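For orientation, a host-side reference of the sink-token ("softmax offset") formulation that the {1, h, 1, 1} FP32 tensors wired through above are commonly understood to feed: each head contributes one extra logit to the softmax normalizer, and its gradient is what dSoftmaxOffset would accumulate. This is an assumption about the cuDNN set_sink_token semantics used as a mental model; the patch itself only plumbs the tensors through and does not state the math.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference softmax over one attention row with a per-head sink/offset logit.
// The probabilities over real keys sum to < 1; the missing mass
// exp(offset - m) / Z is absorbed by the sink and attends to nothing.
std::vector<float> softmax_with_offset(const std::vector<float> &logits, float offset) {
  float m = offset;
  for (float x : logits) m = std::max(m, x);  // max-shift for numerical stability
  float denom = std::exp(offset - m);         // sink contribution to the normalizer Z
  for (float x : logits) denom += std::exp(x - m);
  std::vector<float> probs(logits.size());
  for (size_t j = 0; j < logits.size(); ++j) probs[j] = std::exp(logits[j] - m) / denom;
  return probs;
}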
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
index e1a20274f..b9658b053 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h
@@ -21,17 +21,19 @@ namespace transformer_engine {
void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
- NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
- int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
+ NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+ int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV,
+ const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded,
const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
- NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
- bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO,
- const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias,
+ NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
+ const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
+ Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
@@ -41,21 +43,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked(
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale,
float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
- int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
- const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O,
- NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
- const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
- const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
- Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+ NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+ const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias,
+ const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
+ const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
+ const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
+ const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd_kvpacked(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
- NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
- bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O,
- const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ,
- Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q,
+ NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
+ const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
+ const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
+ Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
@@ -66,24 +69,26 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k,
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
- NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right,
- const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
- Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q,
- const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
- const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
- const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
+ NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
+ const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
+ NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+ const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
+ const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
+ Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
void fused_attn_arbitrary_seqlen_bwd(
size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
- NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left,
- int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K,
- const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
- Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
- const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
- const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
- cudaStream_t stream, cudnnHandle_t handle);
+ NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
+ int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
+ const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
+ const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
+ Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
+ Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
+ const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
+ Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
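Every call site in the files above follows the same null-unless-enabled convention for the new tensors: the softmax-offset device pointers start out as nullptr and are only taken from the optional input/output tensors when an offset-based softmax type is requested. A minimal, self-contained sketch of that guard pattern follows; the SoftmaxType and SimpleTensor definitions are simplified stand-ins for the NVTE types, not the library's own declarations.

// Simplified stand-ins for the NVTE enum and tensor (illustration only, not NVTE types).
enum class SoftmaxType { Vanilla, LearnableOffset };
struct SimpleTensor {
  void *dptr = nullptr;
};

// Mirrors the pointer resolution in the backward function above: leave the pointers
// null for vanilla softmax, otherwise read them from the optional offset tensors.
inline void resolve_softmax_offset_ptrs(SoftmaxType softmax_type,
                                        const SimpleTensor *input_softmax_offset,
                                        SimpleTensor *output_d_softmax_offset,
                                        void **devPtrSoftmaxOffset,
                                        void **devPtrdSoftmaxOffset) {
  *devPtrSoftmaxOffset = nullptr;
  *devPtrdSoftmaxOffset = nullptr;
  if (softmax_type != SoftmaxType::Vanilla) {
    *devPtrSoftmaxOffset = input_softmax_offset->dptr;
    *devPtrdSoftmaxOffset = output_d_softmax_offset->dptr;
  }
}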
diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
index d7f098376..21c544491 100644
--- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1658,8 +1658,9 @@ void fused_attn_fp8_fwd_impl_v1(
void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK,
void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO,
void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV,
- void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type,
- void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
+ void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type,
+ cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size,
+ cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI);
@@ -1672,6 +1673,13 @@ void fused_attn_fp8_fwd_impl_v1(
auto bias_h = h;
NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
+ bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
+ o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
+ bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 ||
+ o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2);
+ NVTE_CHECK(is_current_scaling || is_delayed_scaling,
+ "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or "
+ "kFloat8E5M2!");
try {
FADescriptor_v1 descriptor{b,
@@ -1695,11 +1703,14 @@ void fused_attn_fp8_fwd_impl_v1(
layout,
bias_type,
mask_type,
+ NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
- fwd_tensor_type,
- fwd_tensor_type};
+ qkv_tensor_type,
+ o_tensor_type,
+ cudnn_frontend::DataType_t::NOT_SET,
+ cudnn_frontend::DataType_t::NOT_SET};
namespace fe = cudnn_frontend;
using graph_and_tensors =
@@ -1738,7 +1749,7 @@ void fused_attn_fp8_fwd_impl_v1(
// otherwise, build the op_graph and the plan. Then update cache
auto mha_graph = std::make_shared<fe::graph::Graph>();
- mha_graph->set_io_data_type(fwd_tensor_type)
+ mha_graph->set_io_data_type(qkv_tensor_type)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
@@ -1786,7 +1797,13 @@ void fused_attn_fp8_fwd_impl_v1(
descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
- scale_o = mha_graph->tensor_like(descale_q, "Scale_O");
+
+ if (is_delayed_scaling) {
+ scale_o = mha_graph->tensor_like(descale_q, "Scale_O");
+ }
+ if (is_current_scaling) {
+ scale_o = mha_graph->tensor(1.0f);
+ }
fe::graph::SDPA_fp8_attributes sdpa_options;
sdpa_options = fe::graph::SDPA_fp8_attributes()
@@ -1838,11 +1855,12 @@ void fused_attn_fp8_fwd_impl_v1(
std::vector<int64_t> o_stride(4);
generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout,
NVTE_QKV_Matrix::NVTE_O_Matrix);
- O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride);
+ O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type);
amax_o->set_output(true)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
+
amax_s->set_output(true)
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
@@ -1915,13 +1933,16 @@ void fused_attn_fp8_fwd_impl_v1(
{descale_v, devPtrDescaleV},
{descale_s, devPtrDescaleS},
{scale_s, devPtrScaleS},
- {scale_o, devPtrScaleO},
{attn_scale, &scaling_factor},
{O, devPtrO},
{amax_s, devPtrAmaxS},
{amax_o, devPtrAmaxO},
{Stats, devPtrM}};
+ if (is_delayed_scaling) {
+ variant_pack[scale_o] = devPtrScaleO;
+ }
+
/* if (is_bias) {
variant_pack[bias] = devPtrBias;
} */
@@ -1962,8 +1983,9 @@ void fused_attn_fp8_bwd_impl_v1(
void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV,
void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV,
void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed,
- void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type,
- cudnn_frontend::DataType_t bwd_tensor_type, void* workspace, size_t* workspace_size,
+ void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type,
+ cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type,
+ cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
@@ -1977,6 +1999,15 @@ void fused_attn_fp8_bwd_impl_v1(
auto bias_h = h;
NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!");
NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!");
+ bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF ||
+ dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
+ bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 ||
+ dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2);
+ NVTE_CHECK(is_current_scaling || is_delayed_scaling,
+ "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or "
+ "kFloat8E5M2!");
+ bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF ||
+ o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16);
try {
FADescriptor_v1 descriptor{b,
@@ -2000,11 +2031,14 @@ void fused_attn_fp8_bwd_impl_v1(
layout,
bias_type,
mask_type,
+ NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
false,
- fwd_tensor_type,
- bwd_tensor_type};
+ qkv_tensor_type,
+ o_tensor_type,
+ do_tensor_type,
+ dqkv_tensor_type};
namespace fe = cudnn_frontend;
using graph_and_tensors =
@@ -2057,7 +2091,7 @@ void fused_attn_fp8_bwd_impl_v1(
// otherwise, build the op_graph and the plan. Then update cache
auto mha_graph = std::make_shared<fe::graph::Graph>();
- mha_graph->set_io_data_type(fwd_tensor_type)
+ mha_graph->set_io_data_type(qkv_tensor_type)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
@@ -2097,7 +2131,8 @@ void fused_attn_fp8_bwd_impl_v1(
o = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("O")
.set_dim({b, h, s_q, d})
- .set_stride(o_stride));
+ .set_stride(o_stride)
+ .set_data_type(o_tensor_type));
dO = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("dO")
.set_dim({b, h, s_q, d})
@@ -2123,14 +2158,26 @@ void fused_attn_fp8_bwd_impl_v1(
descale_k = mha_graph->tensor_like(descale_q, "Descale_q");
descale_v = mha_graph->tensor_like(descale_q, "Descale_V");
descale_s = mha_graph->tensor_like(descale_q, "Descale_S");
- descale_o = mha_graph->tensor_like(descale_q, "Descale_O");
descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP");
+ if (is_O_in_F16) {
+ descale_o = mha_graph->tensor(1.0f);
+ } else {
+ descale_o = mha_graph->tensor_like(descale_q, "Descale_O");
+ }
descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO");
scale_s = mha_graph->tensor_like(descale_q, "Scale_S");
scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP");
- scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ");
- scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK");
- scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV");
+
+ if (is_delayed_scaling) {
+ scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ");
+ scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK");
+ scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV");
+ }
+ if (is_current_scaling) {
+ scale_dQ = mha_graph->tensor(1.0f);
+ scale_dK = mha_graph->tensor(1.0f);
+ scale_dV = mha_graph->tensor(1.0f);
+ }
fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options;
sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes()
@@ -2212,10 +2259,10 @@ void fused_attn_fp8_bwd_impl_v1(
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::FLOAT);
- dO->set_data_type(bwd_tensor_type);
- dQ->set_data_type(bwd_tensor_type);
- dK->set_data_type(bwd_tensor_type);
- dV->set_data_type(bwd_tensor_type);
+ dO->set_data_type(do_tensor_type);
+ dQ->set_data_type(dqkv_tensor_type);
+ dK->set_data_type(dqkv_tensor_type);
+ dV->set_data_type(dqkv_tensor_type);
std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // q
std::shared_ptr<fe::graph::Tensor_attributes>, // k
@@ -2296,14 +2343,10 @@ void fused_attn_fp8_bwd_impl_v1(
{descale_q, devPtrDescaleQ},
{descale_k, devPtrDescaleK},
{descale_v, devPtrDescaleV},
- {descale_o, devPtrDescaleO},
{descale_dO, devPtrDescaledO},
{descale_s, devPtrDescaleS},
{descale_dP, devPtrDescaledP},
{scale_s, devPtrScaleS},
- {scale_dQ, devPtrScaledQ},
- {scale_dK, devPtrScaledK},
- {scale_dV, devPtrScaledV},
{scale_dP, devPtrScaledP},
{dQ, devPtrdQ},
{dK, devPtrdK},
@@ -2314,6 +2357,15 @@ void fused_attn_fp8_bwd_impl_v1(
{amax_dP, devPtrAmaxdP},
};
+ if (is_delayed_scaling) {
+ variant_pack[scale_dQ] = devPtrScaledQ;
+ variant_pack[scale_dK] = devPtrScaledK;
+ variant_pack[scale_dV] = devPtrScaledV;
+ }
+ if (!is_O_in_F16) {
+ variant_pack[descale_o] = devPtrDescaleO;
+ }
+
/* if (is_bias) {
variant_pack[bias] = devPtrBias;
if ((bias_b == 1) && (bias_h == h)) {
@@ -2364,6 +2416,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_QKV->data.dtype;
+ const DType O_type = output_O->data.dtype;
void* devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
size_t stride = 0;
@@ -2430,8 +2483,8 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens,
- devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
- &workspace_size, stream, handle);
+ devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+ get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
@@ -2465,6 +2518,7 @@ void fused_attn_fp8_bwd_qkvpacked(
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_QKV->data.dtype;
+ const DType dO_type = input_dO->data.dtype;
const DType dQKV_type = output_dQKV->data.dtype;
void* devPtrQKV = input_QKV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
@@ -2482,7 +2536,11 @@ void fused_attn_fp8_bwd_qkvpacked(
void* devPtrDescaleV = input_QKV->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
- void* devPtrDescaleO = input_O->scale_inv.dptr;
+ const DType O_type = input_O->data.dtype;
+ void* devPtrDescaleO = nullptr;
+ if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
+ devPtrDescaleO = input_O->scale_inv.dptr;
+ }
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
@@ -2525,7 +2583,8 @@ void fused_attn_fp8_bwd_qkvpacked(
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
- get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+ get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
+ workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
@@ -2563,6 +2622,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
+ const DType O_type = output_O->data.dtype;
void* devPtrQ = input_Q->data.dptr;
void* devPtrKV = input_KV->data.dptr;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
@@ -2631,8 +2691,8 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
- devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
- &workspace_size, stream, handle);
+ devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+ get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale,
@@ -2669,6 +2729,7 @@ void fused_attn_fp8_bwd_kvpacked(
cudnnHandle_t handle) {
using namespace transformer_engine;
const DType QKV_type = input_Q->data.dtype;
+ const DType dO_type = input_dO->data.dtype;
const DType dQKV_type = output_dQ->data.dtype;
void* devPtrQ = input_Q->data.dptr;
void* devPtrKV = input_KV->data.dptr;
@@ -2686,7 +2747,11 @@ void fused_attn_fp8_bwd_kvpacked(
void* devPtrDescaleV = input_KV->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
- void* devPtrDescaleO = input_O->scale_inv.dptr;
+ const DType O_type = input_O->data.dtype;
+ void* devPtrDescaleO = nullptr;
+ if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
+ devPtrDescaleO = input_O->scale_inv.dptr;
+ }
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
@@ -2731,7 +2796,8 @@ void fused_attn_fp8_bwd_kvpacked(
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
- get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+ get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
+ workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout,
@@ -2820,6 +2886,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_Q->data.dtype;
+ const DType O_type = output_O->data.dtype;
size_t workspace_size = 0;
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
@@ -2829,8 +2896,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM,
devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS,
devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
- devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
- &workspace_size, stream, handle);
+ devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
+ get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_fwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale,
@@ -2876,7 +2943,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
void* devPtrDescaleV = input_Q->scale_inv.dptr;
void* devPtrO = input_O->data.dptr;
- void* devPtrDescaleO = input_O->scale_inv.dptr;
+ const DType O_type = input_O->data.dtype;
+ void* devPtrDescaleO = nullptr;
+ if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) {
+ devPtrDescaleO = input_O->scale_inv.dptr;
+ }
void* devPtrdO = input_dO->data.dptr;
void* devPtrDescaledO = input_dO->scale_inv.dptr;
@@ -2909,6 +2980,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
reinterpret_cast<void*>(reinterpret_cast<uint64_t*>(rng_state->data.dptr) + 1);
const DType QKV_type = input_Q->data.dtype;
+ const DType dO_type = input_dO->data.dtype;
const DType dQKV_type = output_dQ->data.dtype;
size_t workspace_size = 0;
@@ -2922,7 +2994,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou
devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP,
devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV,
devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type),
- get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle);
+ get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type),
+ workspace->data.dptr, &workspace_size, stream, handle);
} else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
fused_attn::fused_attn_fp8_bwd_impl(
batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout,
diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h
index 678b63691..f03774f8e 100644
--- a/transformer_engine/common/fused_attn/utils.h
+++ b/transformer_engine/common/fused_attn/utils.h
@@ -107,23 +107,28 @@ struct FADescriptor_v1 {
NVTE_QKV_Layout layout;
NVTE_Bias_Type bias_type;
NVTE_Mask_Type mask_type;
+ NVTE_Softmax_Type softmax_type;
std::int64_t window_size_left;
std::int64_t window_size_right;
bool deterministic;
- cudnn_frontend::DataType_t fwd_tensor_type;
- cudnn_frontend::DataType_t bwd_tensor_type;
+ cudnn_frontend::DataType_t qkv_tensor_type;
+ cudnn_frontend::DataType_t o_tensor_type;
+ cudnn_frontend::DataType_t do_tensor_type;
+ cudnn_frontend::DataType_t dqkv_tensor_type;
bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
- attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left,
- window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) <
+ attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
+ window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
+ o_tensor_type, do_tensor_type, dqkv_tensor_type) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
- rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left,
- rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type,
- rhs.bwd_tensor_type);
+ rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
+ rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
+ rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
+ rhs.dqkv_tensor_type);
}
};
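The reason every new field has to appear in both std::tie lists is that FADescriptor_v1's operator< is what lets the descriptor serve as the ordered-map key for the cached cuDNN graphs (see the "build the op_graph and the plan. Then update cache" comments in fused_attn_fp8.cu above); a field omitted from the comparison would let two configurations that differ only in that field alias the same cached graph. A minimal sketch of that keying pattern follows, with hypothetical names (DescriptorKey, GraphCache, lookup_or_build) and a deliberately reduced set of fields.

#include <map>
#include <memory>
#include <tuple>

// Reduced stand-in for FADescriptor_v1: only two fields, but the same
// std::tie-based ordering so it can key a std::map.
struct DescriptorKey {
  int softmax_type;
  int o_tensor_type;
  bool operator<(const DescriptorKey &rhs) const {
    return std::tie(softmax_type, o_tensor_type) <
           std::tie(rhs.softmax_type, rhs.o_tensor_type);
  }
};

struct CachedGraph {};  // placeholder for the built graph plus its tensor handles

using GraphCache = std::map<DescriptorKey, std::shared_ptr<CachedGraph>>;

// Cache hit: reuse the previously built graph; cache miss: build and store it.
inline std::shared_ptr<CachedGraph> lookup_or_build(GraphCache &cache,
                                                    const DescriptorKey &key) {
  if (auto it = cache.find(key); it != cache.end()) {
    return it->second;
  }
  auto built = std::make_shared<CachedGraph>();
  cache.emplace(key, built);
  return built;
}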
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
index bb5e22887..da40a223d 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
@@ -276,9 +276,10 @@ void log_fused_attn_config(
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
- NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
- size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
- size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
+ NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+ float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
+ size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
+ int64_t window_size_right) {
using namespace transformer_engine;
// by default, fused attn is enabled
@@ -339,14 +340,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// NVTE fused attention FWD with packed QKV
-void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
- NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
+ const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
+ NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
- int64_t window_size_left, int64_t window_size_right,
- NVTETensor workspace, cudaStream_t stream) {
+ NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
@@ -384,8 +386,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right));
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
- max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
+ is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
+ h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) {
fused_attn_ck_fwd_qkvpacked(
@@ -419,10 +421,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
- NVTETensor dBias, const NVTETensor cu_seqlens,
- const NVTETensor cu_seqlens_padded, size_t max_seqlen,
- float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
- NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+ NVTETensor dBias, NVTETensor dSoftmaxOffset,
+ const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
+ size_t max_seqlen, float attn_scale, float dropout,
+ NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+ NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
@@ -468,8 +471,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right));
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
- max_seqlen, d, d, window_size_left, window_size_right);
+ true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
+ max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) {
if((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)){
@@ -505,14 +508,17 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
}
// NVTE fused attention FWD with packed KV
-void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
- NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
- const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
- const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
- size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
- NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
- int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
- cudaStream_t stream) {
+void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias,
+ const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
+ NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
+ const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+ const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
+ const NVTETensor page_table_v, const NVTETensor rng_state,
+ size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
+ float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
+ NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+ NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -556,8 +562,8 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right));
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
- max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
+ is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
+ h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) {
fused_attn_ck_fwd_kvpacked(
@@ -596,12 +602,12 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
- NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
- const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
- size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
- NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
- int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
- cudaStream_t stream) {
+ NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
+ const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
+ const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
+ float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+ NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -650,8 +656,8 @@ void nvte_fused_attn_bwd_kvpacked(
std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right));
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
- max_seqlen_kv, d, d, window_size_left, window_size_right);
+ true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
+ h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
@@ -694,16 +700,16 @@ void nvte_fused_attn_bwd_kvpacked(
// NVTE fused attention FWD with separate Q, K and V
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
- const NVTETensor Bias, NVTETensor S, NVTETensor O,
- NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
+ const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+ NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
- int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
- cudaStream_t stream) {
+ NVTE_Softmax_Type softmax_type, int64_t window_size_left,
+ int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -740,8 +746,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right));
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
- max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+ is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
+ h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) {
fused_attn_ck_fwd(
@@ -780,14 +786,14 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
- NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
- const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
- const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
- size_t max_seqlen_kv, float attn_scale, float dropout,
- NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
- NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
- int64_t window_size_right, bool deterministic, NVTETensor workspace,
- cudaStream_t stream) {
+ NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
+ const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+ const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
+ size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale,
+ float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+ NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+ int64_t window_size_left, int64_t window_size_right, bool deterministic,
+ NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
@@ -830,8 +836,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right));
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
- true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
- max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+ true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
+ h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp
new file mode 100644
index 000000000..cf211beaf
--- /dev/null
+++ b/transformer_engine/common/gemm/config.cpp
@@ -0,0 +1,116 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include "./config.h"
+
+#include <cstring>
+#include <string>
+
+#include <transformer_engine/gemm.h>
+
+#include "../util/logging.h"
+
+NVTEMatmulConfig nvte_create_matmul_config() { return new transformer_engine::MatmulConfig; }
+
+void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
+ void *buf, size_t size_in_bytes, size_t *size_written) {
+ // Write attribute size
+ NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
+ static_cast<int>(attr), ")");
+ NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
+ const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
+ *size_written = attr_size;
+
+ // Return immediately if buffer is not provided
+ if (buf == nullptr) {
+ return;
+ }
+
+ // Check buffer size
+ NVTE_CHECK(size_in_bytes >= attr_size,
+ "Buffer is too small for matmul config attribute "
+ "(attribute ",
+ static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
+ " bytes)");
+
+ // Write to buffer
+ NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
+ const auto &config_ = *reinterpret_cast<const transformer_engine::MatmulConfig *>(config);
+ switch (attr) {
+ case kNVTEMatmulConfigBiasTensor:
+ std::memcpy(buf, &config_.bias_tensor, attr_size);
+ break;
+ case kNVTEMatmulConfigDBiasTensor:
+ std::memcpy(buf, &config_.dbias_tensor, attr_size);
+ break;
+ case kNVTEMatmulConfigWithGELUEpilogue:
+ std::memcpy(buf, &config_.with_gelu_epilogue, attr_size);
+ break;
+ case kNVTEMatmulConfigWithDGELUEpilogue:
+ std::memcpy(buf, &config_.with_dgelu_epilogue, attr_size);
+ break;
+ case kNVTEMatmulConfigEpilogueAuxTensor:
+ std::memcpy(buf, &config_.epilogue_aux_tensor, attr_size);
+ break;
+ case kNVTEMatmulConfigUseSplitAccumulator:
+ std::memcpy(buf, &config_.use_split_accumulator, attr_size);
+ break;
+ case kNVTEMatmulConfigSMCount:
+ std::memcpy(buf, &config_.sm_count, attr_size);
+ break;
+ default:
+ NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast(attr), ")");
+ }
+}
+
+void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
+ const void *buf, size_t size_in_bytes) {
+ // Check attribute and buffer
+ NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
+ static_cast<int>(attr), ")");
+ const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
+ NVTE_CHECK(size_in_bytes >= attr_size,
+ "Buffer is too small for matmul config attribute "
+ "(attribute ",
+ static_cast