diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f40b28189..506bc83f0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: run: | apt-get update apt-get install -y git python3.9 pip cudnn9-cuda-12 - pip install cmake==3.21.0 pybind11[global] ninja + pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: @@ -43,7 +43,7 @@ jobs: run: | apt-get update apt-get install -y git python3.9 pip cudnn9-cuda-12 - pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript + pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: @@ -63,7 +63,7 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install pybind11[global] + run: pip install pybind11[global] nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: @@ -83,7 +83,7 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install torch pybind11[global] einops onnxscript + run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: diff --git a/benchmarks/attention/benchmark_attention_rocm.py b/benchmarks/attention/benchmark_attention_rocm.py index 0c37696ac..d3b1fd759 100644 --- a/benchmarks/attention/benchmark_attention_rocm.py +++ b/benchmarks/attention/benchmark_attention_rocm.py @@ -307,7 +307,6 @@ def sanity_checks( cfg, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=cfg.window_size, pad_between_seqs=pad_between_seqs, ) flash_ok, fused_ok, _ = avail @@ -368,7 +367,6 @@ def main(args): config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=pad_between_seqs, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends diff --git a/benchmarks/benchmark_rht_cast.py b/benchmarks/benchmark_rht_cast.py new file mode 100644 index 000000000..9c47856f7 --- /dev/null +++ b/benchmarks/benchmark_rht_cast.py @@ -0,0 +1,152 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import argparse +import torch +import pandas as pd +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +import transformer_engine.pytorch.cpp_extensions as ext + +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + +scale_padding_to = 1 +permute_scale = False + +TORCH_TO_TE_FLOAT_MAP = { + torch.bfloat16: tex.DType.kBFloat16, +} + + +def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16): + # Generate random input data + M, K = shape + x = torch.randn([M, K], dtype=input_dtype, device="cuda") + + assert shape[0] % 16 == 0, "Shape must be divisible by 16" + assert shape[1] % 16 == 0, "Shape must be divisible by 16" + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + stochastic_rounding=stochastic_rounding, + ) + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, K), dtype=x.dtype, device=x.device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + with torch.no_grad(): + stmt = "kernel_func(input, output)" + globals_dict = { + "kernel_func": nvfp4_quantizer.update_quantized, + "input": x, + "output": x_nvfp4_sut, + } + + timing = benchmark.Timer( + stmt=stmt, + globals=globals_dict, + num_threads=1, + ).blocked_autorange(min_run_time=5) + print(timing) + timing_us = timing.median * 1e6 + + input_nbytes = shape[0] * shape[1] * 2 # bf16 + output_nbytes = shape[0] * shape[1] // 2 # //2 for fp4 + sf_nbytes = shape[0] * shape[1] // 16 # //16 for 1 byte per 16 elems + + total_nbytes = ( + 0 + + input_nbytes + * 3 # Reading input for Amax(x)&Amax(RHT(x.T)), Reading input for Cast(x), Reaindg input for Cast(RHT(x.T)) + + 2 * 4 # Output 2 * float for scale & amax + + 2 * 4 # Input 2 * float + + output_nbytes * 2 # Output from Cast(x) and Cast(RHT(x.T)) + + sf_nbytes * 2 # Scale factor + ) + + throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6) + + print( + f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:" + f" {throughput_GBps} GB/s" + ) + return timing_us, throughput_GBps + + +# Nsight Compute Profiling Command: +# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_cast_transpose_1d_block.py --profile + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable profiling mode") + args = parser.parse_args() + + if args.profile: + print("Profiling is enabled.") + else: + print("Profiling is disabled.") + + shapes = [ + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 512), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (512, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), + ] + + if args.profile: + shapes = [ + (16384, 6144), + ] + + data = [] + for stochastic_rounding in [True]: # , False]: + for shape in shapes: + print( + f"Running benchmark_func with shape {shape} and stochastic_rounding" + f" {stochastic_rounding}" + ) + timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding) + data.append( + [ + "benchmark_func", + shape, + stochastic_rounding, + timing_us, + throughput_GBps, + ] + ) + + df = pd.DataFrame( + data=data, + columns=[ + "kernel", + "shape", + "stochastic_rounding", + "timing_us", + "throughput(GB/s)", + ], + ) + print(df) + df.to_csv("benchmark_cast_nvfp4.csv", index=False) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 81006d78c..834f26295 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.8.0.dev0 +2.8.0 diff --git a/build_tools/jax.py b/build_tools/jax.py index 182940c11..20679defc 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -100,11 +100,18 @@ def setup_jax_extension( # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension + # Note: Collective GEMM operations are not supported on ROCm yet + if rocm_build(): + comm_libraries = [] + else: + comm_libraries = ["nccl"] + return Pybind11Extension( "transformer_engine_jax", sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], extra_compile_args=cxx_flags, + libraries=comm_libraries, ) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index bb084293f..0799c73b9 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -27,7 +27,7 @@ def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions.""" - return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"] + return ["torch>=2.1", "einops", "onnxscript", "onnx"] def test_requirements() -> List[str]: diff --git a/build_tools/utils.py b/build_tools/utils.py index e3c5b6be8..a6514bf78 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -305,15 +305,18 @@ def get_cuda_include_dirs() -> Tuple[str, str]: @functools.lru_cache(maxsize=None) def cuda_archs() -> str: - version = cuda_version() - if os.getenv("NVTE_CUDA_ARCHS") is None: + archs = os.getenv("NVTE_CUDA_ARCHS") + if archs is None: + version = cuda_version() if version >= (13, 0): - os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120" + archs = "75;80;89;90;100;100a;103a;120" + elif version >= (12, 9): + archs = "70;80;89;90;100;100a;103a;120" elif version >= (12, 8): - os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120" + archs = "70;80;89;90;100;100a;120" else: - os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90" - return os.getenv("NVTE_CUDA_ARCHS") + archs = "70;80;89;90" + return archs def cuda_version() -> Tuple[int, ...]: diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py new file mode 100644 index 000000000..da79b2137 --- /dev/null +++ b/examples/jax/collective_gemm/common.py @@ -0,0 +1,245 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Shared functions for the comm_overlap tests""" + +import jax.numpy as jnp +import numpy as np + + +# Add this after your existing imports +def dtype_tols(dtype, rtol=None, atol=None): + """Expected numerical tolerance for a data type.""" + # Return immediately if tolerances are fully specified + if rtol is not None and atol is not None: + return {"rtol": rtol, "atol": atol} + + # Default tolerances for common dtypes + if dtype in [jnp.float32, "float32"]: + return {"rtol": 1e-5, "atol": 1e-8} + elif dtype in [jnp.float16, "float16"]: + return {"rtol": 1e-3, "atol": 1e-6} + elif dtype in [jnp.bfloat16, "bfloat16"]: + return {"rtol": 1e-2, "atol": 1e-5} + else: + return {"rtol": 1e-5, "atol": 1e-8} + + +def assert_allclose( + actual, + desired, + rtol=None, + atol=None, + dtype=None, + **kwargs, +): + """Check if two tensors are close.""" + # Infer data type if needed + if dtype is None: + if isinstance(actual, float): + dtype = "float32" + else: + dtype = actual.dtype + + # Determine tolerances + tols = {} + if rtol is None or atol is None: + tols = dtype_tols(dtype) + if rtol is not None: + tols["rtol"] = rtol + if atol is not None: + tols["atol"] = atol + + # Cast tensors to fp32 + if not isinstance(actual, float): + actual = actual.astype(jnp.float32) + if not isinstance(desired, float): + desired = desired.astype(jnp.float32) + + # Check if tensors are close + np.testing.assert_allclose(actual, desired, **tols, **kwargs) + + +def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8): + if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol): + diff = jnp.abs(ref_output - gathered_output) + mask = diff > (atol + rtol * jnp.abs(gathered_output)) + print(mask.astype(int)) + print(jnp.where(mask, diff, 0)) + + +# Shared constants for all tests +DP_AXIS = "data" +TPSP_AXIS = "tensor_sequence" +PARAMS_KEY = "params" + +# Shared functions for distributed testing +import argparse +import jax +from jax.experimental import mesh_utils +from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap + +# Global flag to track if distributed has been initialized +_distributed_initialized = False + + +def _is_distributed_initialized(): + """Check if JAX distributed has been initialized.""" + return _distributed_initialized + + +def _initialize_distributed(args): + """Initialize JAX distributed with custom arguments.""" + global _distributed_initialized + + # Check if already initialized + if _distributed_initialized: + return + + if args.coordinator_address is None or args.num_processes is None or args.process_id is None: + raise ValueError( + "All distributed initialization arguments are required: " + "--coordinator-address, --num-processes, --process-id" + ) + if args.local_device_ids is None: + assert ( + args.num_devices_per_process is not None + ), "Either local_device_ids or num_devices_per_process must be provided" + # Calculate device range for this process + # Single process single device: each process gets one unique device + # Single process multiple devices: each process gets a unique range of devices + start_device = args.process_id * args.num_devices_per_process + device_range = range(start_device, start_device + args.num_devices_per_process) + global_device_ids_for_this_process = ",".join(map(str, device_range)) + else: + # Use explicitly provided global device IDs + global_device_ids_for_this_process = args.local_device_ids + args.num_devices_per_process = len(args.local_device_ids.split(",")) + + assert args.num_devices_per_process == 1, "Only single process single GPU is supported!" + + print( + f"Initializing JAX distributed with coordinator={args.coordinator_address}, " + f"num_processes={args.num_processes}, process_id={args.process_id}" + ) + # Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process" + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_processes, + process_id=args.process_id, + local_device_ids=global_device_ids_for_this_process, + ) + + _distributed_initialized = True + jax.clear_caches() + jax.config.update( + "jax_use_shardy_partitioner", False + ) # CollectiveGEMM does not work with Shardy yet + + assert jax.local_device_count() == 1, ( + f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found" + f" {jax.local_device_count()}" + ) + + devices_per_process = 1 + num_total_devices = args.num_processes + + print( + f"Initializing CGEMM communicator with num_total_devices={num_total_devices}," + f" devices_per_process={devices_per_process}, process_id={args.process_id}" + ) + + collective_gemm_bootstrap( + num_total_devices=num_total_devices, + num_devices_per_process=devices_per_process, + process_id=args.process_id, + tensor_parallel_size=args.tensor_parallel_size, + ) + + +def _get_dp_and_tp_sizes(args): + num_gpu = args.num_processes * args.num_devices_per_process + if args.tensor_parallel_size is None: + num_gpu_dp = 2 if args.enable_data_parallel else 1 + assert ( + num_gpu > 1 and num_gpu % num_gpu_dp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_tp = num_gpu // num_gpu_dp + else: + num_gpu_tp = args.tensor_parallel_size + assert ( + num_gpu > 1 and num_gpu % num_gpu_tp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_dp = num_gpu // num_gpu_tp + return num_gpu_dp, num_gpu_tp + + +def _create_mesh(args): + """Create mesh configuration with proper validation.""" + num_gpu = args.num_processes * args.num_devices_per_process + assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices" + num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args) + + print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)") + + device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) + mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS)) + return mesh + + +def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"): + """Create common argument parser for all collective GEMM tests.""" + parser = argparse.ArgumentParser(description=description) + + # Distributed initialization arguments + parser.add_argument( + "--coordinator-address", + type=str, + default=None, + help="Coordinator address for distributed initialization", + ) + parser.add_argument( + "--num-processes", + type=int, + default=None, + help="Number of processes for distributed initialization", + ) + parser.add_argument( + "--process-id", type=int, default=None, help="Process ID for distributed initialization" + ) + parser.add_argument( + "--local-device-ids", + type=str, + default=None, + help="Local device IDs for distributed initialization (comma-separated)", + ) + parser.add_argument( + "--num-devices-per-process", type=int, default=1, help="Number of devices per process" + ) + + # Test configuration arguments + parser.add_argument( + "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size" + ) + parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing") + parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing") + parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension") + parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension") + parser.add_argument( + "--collective-type", + type=str, + default="all_gather", + choices=["all_gather", "reduce_scatter"], + help="Type of collective operation", + ) + parser.add_argument( + "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use" + ) + parser.add_argument( + "--enable-data-parallel", action="store_true", help="Enable data parallelism" + ) + parser.add_argument( + "--enable-result-check", action="store_true", default=True, help="Enable result checking" + ) + + return parser diff --git a/examples/jax/collective_gemm/conftest.py b/examples/jax/collective_gemm/conftest.py new file mode 100644 index 000000000..83937971a --- /dev/null +++ b/examples/jax/collective_gemm/conftest.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""config for collective_gemm tests""" +import pytest + + +def pytest_addoption(parser): + """Pytest hook for collective_gemm tests""" + parser.addoption("--coordinator-address", action="store", default="localhost:12345") + parser.addoption("--num-processes", action="store", default=1) + parser.addoption("--process-id", action="store", default=0) + parser.addoption("--local-device-ids", action="store", default=None) + + +@pytest.fixture(autouse=True) +def distributed_args(request): + """Fixture for querying distributed initialization arguments""" + if request.cls: + request.cls.coordinator_address = request.config.getoption("--coordinator-address") + request.cls.num_processes = int(request.config.getoption("--num-processes")) + request.cls.process_id = int(request.config.getoption("--process-id")) + request.cls.local_device_ids = request.config.getoption("--local-device-ids") + request.cls.num_devices_per_process = ( + 1 + if request.cls.local_device_ids is None + else len(request.cls.local_device_ids.split(",")) + ) diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh new file mode 100644 index 000000000..af263eb53 --- /dev/null +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -0,0 +1,119 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + +# Check if NVLINK is supported before running tests +echo "*** Checking NVLINK support***" +NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) +NVLINK_EXIT_CODE=$? + +# Check if command failed OR output indicates no NVLINK +if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then + echo "NVLINK is not supported on this platform" + echo "Collective GEMM tests require NVLINK connectivity" + echo "SKIPPING all tests" + exit 0 +else + echo "NVLINK support detected" +fi + +# Define the test files to run +TEST_FILES=( +"test_gemm.py" +"test_dense_grad.py" +"test_layernorm_mlp_grad.py" +) + +echo +echo "*** Executing tests in examples/jax/collective_gemm/ ***" + +HAS_FAILURE=0 # Global failure flag +PIDS=() # Array to store all process PIDs + +# Cleanup function to kill all processes +cleanup() { + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill -TERM "$pid" 2>/dev/null || true + fi + done + # Wait a bit and force kill if needed + sleep 2 + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done +} + +# Set up signal handlers to cleanup on exit +trap cleanup EXIT INT TERM + +# Run each test file across all GPUs +for TEST_FILE in "${TEST_FILES[@]}"; do + echo + echo "=== Starting test file: $TEST_FILE ..." + + # Clear PIDs array for this test file + PIDS=() + + for i in $(seq 0 $(($NUM_GPUS - 1))); do + # Define output file for logs + LOG_FILE="${TEST_FILE}_gpu_${i}.log" + + if [ $i -eq 0 ]; then + # For process 0: show live output AND save to log file using tee + echo "=== Live output from process 0 ===" + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \ + "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + --num-processes=$NUM_GPUS \ + --process-id=$i 2>&1 | tee "$LOG_FILE" & + PID=$! + PIDS+=($PID) + else + # For other processes: redirect to log files only + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + --num-processes=$NUM_GPUS \ + --process-id=$i > "$LOG_FILE" 2>&1 & + PID=$! + PIDS+=($PID) + fi + done + + # Wait for all processes to finish + wait + + # Check and print the log content from process 0 (now has log file thanks to tee) + if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then + echo "... $TEST_FILE SKIPPED" + elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then + echo "... $TEST_FILE FAILED" + HAS_FAILURE=1 + elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then + echo "... $TEST_FILE PASSED" + else + echo "... $TEST_FILE INVALID" + HAS_FAILURE=1 + fi + + # Remove the log files after processing them + wait + rm ${TEST_FILE}_gpu_*.log +done + +wait + +# Final cleanup (trap will also call cleanup on exit) +cleanup + +exit $HAS_FAILURE diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py new file mode 100644 index 000000000..df2dd5618 --- /dev/null +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -0,0 +1,214 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Collective Dense Gradient test on multi-GPU with tensor parallelism""" +import argparse +import unittest +import os + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding +import flax + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +from transformer_engine.jax.dense import dense + +from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.cpp_extensions.gemm import ( + CollectiveOp, + CollectiveOpSet, + noop_collective_op_set, +) +from transformer_engine.jax.sharding import MeshResource +import transformer_engine.jax.flax as te_flax + + +def _get_logical_axes(collective_op): + if collective_op.is_all_gather: + input_axes = (DP_AXIS, TPSP_AXIS, None) + weight_axes = (None, TPSP_AXIS) + bias_axes = (TPSP_AXIS,) + output_axes = (DP_AXIS, None, TPSP_AXIS) + else: # RS + input_axes = (DP_AXIS, None, TPSP_AXIS) + weight_axes = (TPSP_AXIS, None) + bias_axes = (None,) + output_axes = (DP_AXIS, TPSP_AXIS, None) + return input_axes, weight_axes, bias_axes, output_axes + + +def _get_operand_sharding(mesh, collective_op): + input_axes, weight_axes, bias_axes, _ = _get_logical_axes(collective_op) + x_sharding = NamedSharding(mesh, PartitionSpec(*input_axes)) + weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_axes)) + bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes)) + return x_sharding, weight_sharding, bias_sharding + + +def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): + output = dense( + x, + weight, + bias, + contracting_dims=((2,), (0,)), + input_axes=input_axes, + kernel_axes=weight_axes, + output_axes=output_axes, + collective_op_set=collective_op_set, + ) + return jnp.mean(output.astype(jnp.float32)) + + +def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): + return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set + ) + + +def run_dense_grad_tests(args, mesh=None): + """Execute Dense Gradient tests.""" + print(args) + _initialize_distributed(args) + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16) + bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) + + collective_op = ( + CollectiveOp.ALL_GATHER + if args.collective_type == "all_gather" + else CollectiveOp.REDUCE_SCATTER + ) + collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) + + with mesh, fp8_autocast( + enabled=False, + fp8_recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast + axis_rules = flax.linen.get_logical_axis_rules() + axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + with flax.linen.logical_axis_rules(te_extended_axis_rules): + + x_sharding, weight_sharding, bias_sharding = _get_operand_sharding(mesh, collective_op) + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + input_axes, weight_axes, _, output_axes = _get_logical_axes(collective_op) + ref_output, ref_grads = _value_and_grad_dense( + x_sharded, + weight_sharded, + bias_sharded, + input_axes, + weight_axes, + output_axes, + noop_collective_op_set, + ) + output, sharded_grads = _value_and_grad_dense( + x_sharded, + weight_sharded, + bias_sharded, + input_axes, + weight_axes, + output_axes, + collective_op_set, + ) + jax.block_until_ready(ref_output) + jax.block_until_ready(output) + gathered_grads = [] + gathered_ref_grads = [] + for ref_grad, grad in zip(ref_grads, sharded_grads): + gathered_grads.append( + jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None))) + ) + gathered_ref_grads.append( + jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None))) + ) + jax.block_until_ready(gathered_grads) + jax.block_until_ready(gathered_ref_grads) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(ref_output, output, dtype=jnp.bfloat16) + for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): + assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + + +class TestCollectiveDenseGradient(unittest.TestCase): + """Collective Dense Gradient unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective Dense Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + # Create mesh once for all tests + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_all_gather(self): + """Test Collective Dense Gradient with AllGather""" + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_bf16_reduce_scatter(self): + """Test Collective Dense Gradient with ReduceScatter""" + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 7: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_dense_grad.py --coordinator-address
--num-processes " + " --process-id [--local-device-ids ] [other args]" + ) + print( + "Example: python test_dense_grad.py --coordinator-address localhost:1234" + " --num-processes 4 --process-id 0" + ) + print( + "Example: python test_dense_grad.py --coordinator-address localhost:1234" + " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3" + ) + sys.exit(1) + + args = cgemm_parser( + "Collective Dense Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + _initialize_distributed(args) + run_dense_grad_tests(args, mesh=None) diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py new file mode 100644 index 000000000..307e4444e --- /dev/null +++ b/examples/jax/collective_gemm/test_gemm.py @@ -0,0 +1,206 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Collective GEMM test on multi-GPU with tensor parallelism + +This script uses custom distributed initialization with the following arguments: +- --coordinator-address: Coordinator address for distributed initialization +- --num-processes: Number of processes for distributed initialization +- --process-id: Process ID for distributed initialization +- --local-device-ids: Local device IDs for distributed initialization + +Example: + python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3 +""" +import unittest +import os +from functools import partial + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp +from transformer_engine.jax.sharding import MeshResource + + +def _get_operand_sharding(mesh, collective_op, is_with_dp): + + dp_axis = DP_AXIS if is_with_dp else None + if collective_op == CollectiveOp.ALL_GATHER: + x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None)) + weight_sharding = NamedSharding(mesh, PartitionSpec(None, TPSP_AXIS)) + bias_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS)) + output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS)) + else: # RS + x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS)) + weight_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS, None)) + bias_sharding = NamedSharding(mesh, PartitionSpec(None)) + output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None)) + + return x_sharding, weight_sharding, bias_sharding, output_sharding + + +def _get_dp_and_tp_sizes(args): + num_gpu = args.num_processes * args.num_devices_per_process + if args.tensor_parallel_size is None: + num_gpu_dp = 2 if args.enable_data_parallel else 1 + assert ( + num_gpu > 1 and num_gpu % num_gpu_dp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_tp = num_gpu // num_gpu_dp + else: + num_gpu_tp = args.tensor_parallel_size + assert ( + num_gpu > 1 and num_gpu % num_gpu_tp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_dp = num_gpu // num_gpu_tp + return num_gpu_dp, num_gpu_tp + + +@partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding")) +def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding): + output = tex.gemm( + x, + weight, + bias=bias, + contracting_dims=contracting_dims, + collective_op=collective_op, + ) + if output_sharding is not None: + output = jax.lax.with_sharding_constraint(output, output_sharding) + return output + + +def run_gemm_tests(args, mesh=None): + """Execute GEMM tests.""" + print(args) + # Collective GEMM requires Shardy partitioner to be disabled + jax.config.update("jax_use_shardy_partitioner", False) + + # Initialize distributed with provided arguments + _initialize_distributed(args) + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16) + bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) + collective_op = ( + CollectiveOp.ALL_GATHER + if args.collective_type == "all_gather" + else CollectiveOp.REDUCE_SCATTER + ) + + with mesh, fp8_autocast( + enabled=False, + fp8_recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + print(f"Device mesh: {mesh}") + + x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding( + mesh, collective_op, args.enable_data_parallel + ) + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + ref_output = _jitted_cgemm( + x_sharded, + weight_sharded, + bias_sharded, + contracting_dims=((2,), (0,)), + collective_op=CollectiveOp.NONE, + output_sharding=output_sharding, + ) + output = _jitted_cgemm( + x_sharded, + weight_sharded, + bias_sharded, + contracting_dims=((2,), (0,)), + collective_op=collective_op, + # CollectiveGEMM output should have a correct sharding without applying sharding constraint + output_sharding=None, + ) + assert ( + ref_output.sharding == output.sharding + ), f"ref_output.sharding={ref_output.sharding}, output.sharding={output.sharding}" + gathered_ref_output = jax.lax.with_sharding_constraint( + ref_output, NamedSharding(mesh, PartitionSpec(None)) + ) + gathered_output = jax.lax.with_sharding_constraint( + output, NamedSharding(mesh, PartitionSpec(None)) + ) + jax.block_until_ready(gathered_ref_output) + jax.block_until_ready(gathered_output) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(gathered_ref_output, gathered_output) + + +class TestCollectiveGemmWithDP(unittest.TestCase): + """Collective GEMM with DP unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective GEMM test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_all_gather_with_dp(self): + """Test Collective GEMM with AllGather""" + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_bf16_reduce_scatter_with_dp(self): + """Test Collective GEMM with ReduceScatter""" + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 5: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_gemm.py --coordinator-address
--num-processes " + " --process-id [--local-device-ids ] [other args]" + ) + sys.exit(1) + + args = cgemm_parser("Collective GEMM test on multi-GPU with tensor parallelism").parse_args() + _initialize_distributed(args) + run_gemm_tests(args, mesh=None) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py new file mode 100644 index 000000000..7bd6eb6a3 --- /dev/null +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -0,0 +1,272 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Collective Dense Gradient test on multi-GPU with tensor parallelism""" +import argparse +import unittest +import os + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding +import flax + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +from transformer_engine.jax.layernorm_mlp import layernorm_mlp + +from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.cpp_extensions.gemm import ( + CollectiveOpSet, + CollectiveOp, + noop_collective_op_set, +) +from transformer_engine.jax.sharding import MeshResource +import transformer_engine.jax.flax as te_flax + + +def _get_logical_axes(): + input_1_axes = (DP_AXIS, TPSP_AXIS, None) + weight_1_axes = (None, None, TPSP_AXIS) + bias_axes_1 = (None, TPSP_AXIS) + input_2_axes = (DP_AXIS, None, TPSP_AXIS) + weight_2_axes = (TPSP_AXIS, None) + bias_axes_2 = (None,) + return input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 + + +def _get_operand_sharding(mesh): + input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 = ( + _get_logical_axes() + ) + x_sharding = NamedSharding(mesh, PartitionSpec(*input_1_axes)) + weight_1_sharding = NamedSharding(mesh, PartitionSpec(*weight_1_axes)) + bias_1_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_1)) + weight_2_sharding = NamedSharding(mesh, PartitionSpec(*weight_2_axes)) + bias_2_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_2)) + return x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding + + +def _mean_layernorm_mlp( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, +): + output = layernorm_mlp( + x, + gamma, + beta=None, + kernels=[weight_1, weight_2], + biases=[bias_1, bias_2], + norm_type="rmsnorm", + dot_1_input_axes=input_1_axes, + dot_2_input_axes=input_2_axes, + kernel_1_axes=weight_1_axes, + kernel_2_axes=weight_2_axes, + activation_type=("gelu",), + collective_op_sets=collective_op_sets, + ) + return jnp.mean(output) + + +def _value_and_grad_layernorm_mlp( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, +): + return jax.jit( + jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10) + )( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, + ) + + +def run_layernorm_mlp_grad_tests(args, mesh=None): + """Execute Dense Gradient tests.""" + print(args) + # Collective GEMM requires Shardy partitioner to be disabled + jax.config.update("jax_use_shardy_partitioner", False) + + # Initialize distributed with provided arguments + _initialize_distributed(args) + + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_1_rng, bias_1_rng, weight_2_rng, bias_2_rng, gamma_rng = jax.random.split( + rng, 7 + ) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight_1 = jax.random.normal( + weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16 + ) / jnp.sqrt(args.hidden_in) + bias_1 = jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16) + weight_2 = jax.random.normal( + weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16 + ) / jnp.sqrt(args.hidden_out) + bias_2 = jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16) + gamma = jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) / jnp.sqrt( + args.hidden_in + ) + collective_op_set_1 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER) + collective_op_set_2 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.REDUCE_SCATTER) + collective_op_sets = (collective_op_set_1, collective_op_set_2) + noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) + + with mesh, fp8_autocast( + enabled=False, + fp8_recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast + axis_rules = flax.linen.get_logical_axis_rules() + axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + with flax.linen.logical_axis_rules(te_extended_axis_rules): + x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding = ( + _get_operand_sharding(mesh) + ) + x_sharded = jax.device_put(x, x_sharding) + weight_1_sharded = jax.device_put(weight_1, weight_1_sharding) + bias_1_sharded = jax.device_put(bias_1, bias_1_sharding) + weight_2_sharded = jax.device_put(weight_2, weight_2_sharding) + bias_2_sharded = jax.device_put(bias_2, bias_2_sharding) + + input_1_axes, weight_1_axes, _, input_2_axes, weight_2_axes, _ = _get_logical_axes() + ref_output, ref_grads = _value_and_grad_layernorm_mlp( + x_sharded, + weight_1_sharded, + bias_1_sharded, + weight_2_sharded, + bias_2_sharded, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + noop_collective_op_sets, + ) + output, sharded_grads = _value_and_grad_layernorm_mlp( + x_sharded, + weight_1_sharded, + bias_1_sharded, + weight_2_sharded, + bias_2_sharded, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, + ) + jax.block_until_ready(ref_output) + jax.block_until_ready(output) + gathered_grads = [] + gathered_ref_grads = [] + for ref_grad, grad in zip(ref_grads, sharded_grads): + gathered_grads.append( + jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None))) + ) + gathered_ref_grads.append( + jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None))) + ) + jax.block_until_ready(gathered_grads) + jax.block_until_ready(gathered_ref_grads) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(ref_output, output, dtype=jnp.bfloat16) + for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): + assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + + +class TestCollectiveLayerNormMLPGradient(unittest.TestCase): + """Collective Dense Gradient unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + # Create mesh once for all tests + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_layernorm_mlp_grad(self): + """Test Collective Dense Gradient with AllGather""" + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 7: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_layernorm_mlp_grad.py --coordinator-address
" + " --num-processes --process-id [--local-device-ids ] [other args]" + ) + print( + "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234" + " --num-processes 4 --process-id 0" + ) + print( + "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234" + " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3" + ) + sys.exit(1) + + args = cgemm_parser( + "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + _initialize_distributed(args) + run_layernorm_mlp_grad_tests(args, mesh=None) diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index 2a1ac0f8f..2a979e177 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -15,11 +15,37 @@ TEST_CASES=( "test_te_current_scaling_fp8_shardy" ) +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + echo echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***" HAS_FAILURE=0 # Global failure flag +PIDS=() # Array to store all process PIDs + +# Cleanup function to kill all processes +cleanup() { + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill -TERM "$pid" 2>/dev/null || true + fi + done + # Wait a bit and force kill if needed + sleep 2 + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done +} + +# Set up signal handlers to cleanup on exit +trap cleanup EXIT INT TERM # Run each test case across all GPUs for TEST_CASE in "${TEST_CASES[@]}"; do echo @@ -29,25 +55,40 @@ for TEST_CASE in "${TEST_CASES[@]}"; do # Define output file for logs LOG_FILE="${TEST_CASE}_gpu_${i}.log" - # Run pytest and redirect stdout and stderr to the log file - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ - --num-process=$NUM_GPUS \ - --process-id=$i > "$LOG_FILE" 2>&1 & - done + # For process 0: show live output AND save to log file using tee + if [ $i -eq 0 ]; then + echo "=== Live output from process 0 ===" + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs --junitxml=$XML_LOG_DIR/multiprocessing_encoder_${TEST_CASE}.xml \ + "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ + --num-process=$NUM_GPUS \ + --process-id=$i 2>&1 | tee "$LOG_FILE" & + PID=$! + PIDS+=($PID) + else + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ + --num-process=$NUM_GPUS \ + --process-id=$i > "$LOG_FILE" 2>&1 & + PID=$! + PIDS+=($PID) + fi + done # Wait for the process to finish wait - tail -n +7 "${TEST_CASE}_gpu_0.log" # Check and print the log content accordingly if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE SKIPPED" + elif grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then + echo "... $TEST_CASE FAILED" + HAS_FAILURE=1 elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE PASSED" else + echo "... $TEST_CASE INVALID" HAS_FAILURE=1 - echo "... $TEST_CASE FAILED" fi # Remove the log file after processing it @@ -56,4 +97,8 @@ for TEST_CASE in "${TEST_CASES[@]}"; do done wait + +# Final cleanup (trap will also call cleanup on exit) +cleanup + exit $HAS_FAILURE diff --git a/pyproject.toml b/pyproject.toml index c4df4aecc..80110af8a 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ # See LICENSE for license information. [build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax", "flax>=0.7.1"] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index d9c46347f..ae45f398e 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -29,6 +29,10 @@ wait python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" wait TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" +wait + +TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh" +wait if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index e4a3f4630..cb097d492 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" # Test without custom calls export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" +NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index b4bf0a024..7f19dda67 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -7,6 +7,8 @@ : ${TE_PATH:=/opt/transformerengine} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} : ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" # Config with the dummy feature which prevents nvinspect from being disabled. # Nvinspect will be disabled if no feature is active. @@ -20,17 +22,16 @@ pip uninstall -y nvdlfw-inspect pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git pip install pytest==8.2.1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 exit $FAIL diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 394273ca4..cdf0df888 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -31,6 +31,7 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 7f061d222..e698e997a 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" @@ -47,9 +48,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_ : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} -pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" +pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" # standard numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index 1486d5097..7fce13a3d 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -3,9 +3,11 @@ # See LICENSE for license information. -pip3 install onnxruntime==1.20.1 -pip3 install onnxruntime_extensions==0.13.0 +pip3 install onnxruntime +pip3 install onnxruntime_extensions : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" -python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 46bcf4242..c9a7a03e8 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -32,7 +32,8 @@ list(APPEND test_cuda_sources ../test_common.cu) if(USE_CUDA) list(APPEND test_cuda_sources - test_cast_float8blockwise.cu) + test_cast_float8blockwise.cu + test_cast_nvfp4_transpose.cu) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu) @@ -70,6 +71,13 @@ else() add_executable(test_operator ${test_hip_sources}) endif() +# Add profiling and debug flags for CUDA compilation +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Generate line info for device code +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g") # Add debug symbols for host code +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage +# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping + +# Find required packages find_package(OpenMP REQUIRED) if(USE_CUDA) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index b635dc00b..f1780cbe0 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -86,6 +86,7 @@ void compute_ref(const ProcessingMethod processing_method, // Cache computations for (size_t i = i_min; i < i_max; ++i) { for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); @@ -321,7 +322,7 @@ void performTest_x1(const ProcessingMethod processing_method, std::vector mismatches_scales_indices; size_t mismatches_scales = 0; - compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), + compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, mismatches_scales_indices, mismatches_scales, scale_diff_abs_tolerance, @@ -506,7 +507,7 @@ void performTest_x2(const ProcessingMethod processing_method, std::vector mismatches_scales_indices_rowwise; size_t mismatches_scales_rowwise = 0; - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, mismatches_scales_indices_rowwise, mismatches_scales_rowwise, @@ -516,7 +517,7 @@ void performTest_x2(const ProcessingMethod processing_method, std::vector mismatches_scales_indices_colwise; size_t mismatches_scales_colwise = 0; - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, mismatches_scales_indices_colwise, mismatches_scales_colwise, diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 52180786d..38a9c8d29 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -274,21 +274,22 @@ void performTest_x1(const size_t rows, ? output.rowwise_cpu_scale_inv_ptr() : output.columnwise_cpu_scale_inv_ptr(); if (rowwise) { - compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, mismatches_scales_indices, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } else { - compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, mismatches_scales_indices, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + } #ifdef __HIP_PLATFORM_AMD__ @@ -396,22 +397,22 @@ void performTest_x2(const size_t rows, std::vector mismatches_scales_indices_rowwise; size_t mismatches_scales_rowwise = 0; - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_indices_rowwise, mismatches_scales_rowwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_indices_rowwise, mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); std::vector mismatches_scales_indices_colwise; size_t mismatches_scales_colwise = 0; - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_indices_colwise, mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales_indices_colwise, mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); #ifdef __HIP_PLATFORM_AMD__ if (::testing::Test::HasFatalFailure()) return; diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu new file mode 100644 index 000000000..e905a0064 --- /dev/null +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -0,0 +1,741 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" +#include + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ActivationType { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { + const __half2_raw raw_truncated_to_fp4e2m1_pair = + __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1); + + const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair); + const double truncated_to_fp4e2m1_x = static_cast(truncated_to_fp4e2m1_pair.x); + const double truncated_to_fp4e2m1_y = static_cast(truncated_to_fp4e2m1_pair.y); + return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y}; +} + +template +std::vector create_transpose(const InputType* const input, const size_t rows, size_t cols) { + std::vector input_t(cols * rows); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const size_t idx_t = j * rows + i; + input_t[idx_t] = input[idx]; + } + } + return input_t; +} + +// Compute the global encode scale factor for a given global amax +float compute_global_encode_scaling_factor_FP4(const float global_amax) { + constexpr float fp8_max = 448.0f; // 448.0f; + constexpr float fp4_max = 6.0f; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, Numeric_Traits::maxNorm); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +// 1D Scaling: Original implementation with 1x16 blocks +template +void quantize_nvfp4_1d(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax) { + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + + constexpr size_t block_size_X = 16; + const size_t blocks_X = divide_round_up(cols, block_size_X); + + std::array cache_buffer; + for (size_t i = 0; i < block_size_X; ++i) { + cache_buffer[i] = 0.0f; + } + + for (size_t i = 0; i < rows; ++i) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t j_min = block_X * block_size_X; + const size_t j_max = j_min + block_size_X; + + // Find block amax + float block_amax = 0.0f; + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx = j - j_min; + + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + const float elt = static_cast(static_cast(act_elt)); + cache_buffer[cache_idx] = elt; + block_amax = std::max(block_amax, std::abs(elt)); + } + + // 2. Compute E4M3 scaling factor + // Compute per-block encoding/decoding scaling factor + const float S_dec_b = block_amax / 6.0f; + + // Scale & Store per-block decoding scaling factor + const float S_dec_b_fp8 = S_dec_b * S_enc; + + // Compute "correct" per-block encoding scaling factor + const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = static_cast(S_dec_b_fp8); + const float scale_reciprocal = S_enc_b_fp8; + + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const int cache_idx_x = j - j_min; + const int cache_idx_y = cache_idx_x + 1; + const float cached_x = cache_buffer[cache_idx_x]; + const float cached_y = cache_buffer[cache_idx_y]; + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output[idx_pair] = casted_to_e2m1_pair; + + // const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); + } + } + } +} + +// Compute 2D mathematical scaling factors (8x8 for 128x128 input) +template +void compute_2d_mathematical_scales(float (*OP)(const float), + const InputType* const input, + const size_t rows, + const size_t cols, + const float global_amax, + std::vector>& math_scales) { + + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + math_scales.resize(blocks_Y, std::vector(blocks_X)); + + for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t i_min = block_Y * block_size_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + const size_t j_min = block_X * block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + // Find 2D block amax over entire 16x16 region + float block_amax = 0.0f; + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + const float elt = static_cast(static_cast(act_elt)); + block_amax = std::max(block_amax, std::abs(elt)); + } + } + + // Compute E4M3 scaling factor for this 16x16 block + const float S_dec_b = block_amax / 6.0f; + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + math_scales[block_Y][block_X] = S_dec_b_fp8; + } + } +} + +// 2D Scaling: NEW implementation with proper replication +template +void quantize_nvfp4_2d(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax) { + + // Step 1: Compute mathematical 8x8 scaling factors + std::vector> math_scales; + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + // Step 2: Replicate scaling factors row-wise (128×8 storage) - only if scales is not nullptr + if (scales != nullptr) { + // Each of the 128 rows gets scaling factors from its corresponding 16×16 block + for (size_t i = 0; i < rows; ++i) { + const size_t block_Y = i / block_size_Y; + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = math_scales[block_Y][block_X]; + } + } + } + + // Step 3: Apply quantization using the mathematical scaling factors + std::array, block_size_Y> cache_buffer; + + for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t i_min = block_Y * block_size_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + const size_t j_min = block_X * block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + // Get the scaling factor for this block + const float S_dec_b_fp8 = static_cast(math_scales[block_Y][block_X]); + const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + const float scale_reciprocal = S_enc_b_fp8; + + // Process and cache data for this 16x16 block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx_y = i - i_min; + const size_t cache_idx_x = j - j_min; + + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + const float elt = static_cast(static_cast(act_elt)); + cache_buffer[cache_idx_y][cache_idx_x] = elt; + } + } + + // Apply scaling to all elements in this 16x16 block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const size_t cache_idx_y = i - i_min; + const size_t cache_idx_x1 = j - j_min; + const size_t cache_idx_x2 = std::min(cache_idx_x1 + 1, block_size_X - 1); + + const float cached_x = cache_buffer[cache_idx_y][cache_idx_x1]; + const float cached_y = ((j + 1) < j_max && cache_idx_x2 < block_size_X) ? + cache_buffer[cache_idx_y][cache_idx_x2] : 0.0f; + + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output[idx_pair] = casted_to_e2m1_pair; + } + } + } + } +} + +// Wrapper function that calls appropriate implementation based on 2D flag +template +void quantize_nvfp4(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax, + const bool use_2d_quantization = false) { + if (use_2d_quantization) { + quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + } else { + quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + } +} + +template +void compute_ref(float (*OP)(const float), + const InputType* input, + fp4e2m1x2* output, + fp4e2m1x2* output_t, + fp8e4m3* scales, + fp8e4m3* scales_t, + const float global_amax, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const size_t scales_stride_t, + const bool use_2d_quantization = false) +{ + std::vector input_t = create_transpose(input, rows, cols); + + if (use_2d_quantization) { + // Step 1: Compute mathematical 8×8 scaling factors + std::vector> math_scales; + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + // Step 2: Generate scales (128×8) by replicating row-wise + for (size_t i = 0; i < rows; ++i) { + const size_t block_Y = i / block_size_Y; + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = math_scales[block_Y][block_X]; + } + } + + // Step 3: Generate scales_t (128×8) with proper transposed block mapping + for (size_t i = 0; i < cols; ++i) { // cols = 128, which becomes rows of transposed data + const size_t block_X_orig = i / block_size_X; // i was column index in original, so maps to block_X + for (size_t block_Y_new = 0; block_Y_new < blocks_Y; ++block_Y_new) { // block in transposed coordinate + const size_t scale_idx = i * scales_stride_t + block_Y_new; + scales_t[scale_idx] = math_scales[block_Y_new][block_X_orig]; + } + } + + // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d + // (This part processes the actual FP4 data using the mathematical scaling factors) + quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled + quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled + + } else { + quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization); + quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization); + } +} + +void compare_nvfp4_tensors(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8) { + std::vector mismatch_messages; + size_t total_mismatches = 0; + + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); + + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = false; + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion) { + total_mismatches++; + std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + + std::to_string(t) + " vs " + std::to_string(r) + + " (abs_diff: " + std::to_string(fabs(t - r)) + + ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; + mismatch_messages.push_back(msg); + + // Optional: limit number of detailed messages to avoid overwhelming output + if (mismatch_messages.size() <= 100) { + std::cout << "Error in tensor " << name << ": " << msg << std::endl; + } + } + } + } + } + + // Always report summary - either success or failure + std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; + std::cout << "Total elements checked: " << (rows * cols) << std::endl; + + if (total_mismatches > 0) { + std::cout << "STATUS: FAILED for output" << std::endl; + std::cout << "Total mismatches found: " << total_mismatches << std::endl; + std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; + if (mismatch_messages.size() > 100) { + std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; + } + std::cout << "============================" << std::endl; + + GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; + } else { + std::cout << "STATUS: PASSED for output" << std::endl; + std::cout << "All elements match within tolerance!" << std::endl; + std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl; + std::cout << "============================" << std::endl; + } +} + +// Optional: Function to dump tensor data to files for detailed analysis +void dump_nvfp4_tensor_data(const std::string& prefix, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + std::string test_file = prefix + "_test.txt"; + std::string ref_file = prefix + "_ref.txt"; + std::string diff_file = prefix + "_diff.txt"; + + std::ofstream test_out(test_file); + std::ofstream ref_out(ref_file); + std::ofstream diff_out(diff_file); + + if (test_out.is_open() && ref_out.is_open() && diff_out.is_open()) { + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); + const int pos = idx + k; + + test_out << "pos[" << pos << "] = " << t << std::endl; + ref_out << "pos[" << pos << "] = " << r << std::endl; + diff_out << "pos[" << pos << "] test=" << t << " ref=" << r + << " abs_diff=" << fabs(t - r) + << " rel_diff=" << (r == 0 ? 0.0 : fabs((t - r) / r)) << std::endl; + } + } + } + std::cout << "DEBUG: Dumped tensor data to files: " << test_file << ", " << ref_file << ", " << diff_file << std::endl; + } else { + std::cout << "WARNING: Could not open files for tensor data dump" << std::endl; + } +} + +void print_detailed_tensor_comparison(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + printf("\n=== DETAILED COMPARISON for %s (%d×%d = %d elements) ===\n", + name.c_str(), rows, cols, rows * cols); + + const int total_elements = rows * cols; + const int check_count = 128; + + printf("--- FIRST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = 0; i < std::min(check_count, total_elements); ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗"); + } + + if (total_elements > 2 * check_count) { + printf("\n--- LAST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = total_elements - check_count; i < total_elements; ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗"); + } + } + printf("==================================\n"); +} + +void compareResults_nvfp4(const Tensor &test, + const void *ref, const void *ref_t, const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) { + if (if_on_gpus) test.to_cpu(); + + const fp4e2m1 *test_data = test.rowwise_cpu_dptr(); + const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); + const fp4e2m1 *ref_data = reinterpret_cast(ref); + const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); + + // Print detailed element-by-element comparison + // print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols); + // print_detailed_tensor_comparison("output_t", test_data_t, ref_data_t, cols, rows); + + // Optionally dump tensor data to files for detailed analysis + if (dump_data) { + dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols); + dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); + } + + compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol); + compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); +} + +template +void performTest(float (*OP)(const float), + const std::vector& shape) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = DType::kFloat4E2M1; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + // Use get_scale_tensor_dims for NVFP4 scale tensor dimensions + // Now that CheckScaleTensorShape is fixed, this should work correctly + const std::array scale_dims = get_scale_tensor_dims(rows, cols, 1, 16); + const std::array scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + const size_t unpadded_blocks_Y_t = scale_dims_t[0]; + const size_t unpadded_blocks_X_t = scale_dims_t[1]; + const size_t blocks_Y_t = scale_dims_t[2]; + const size_t blocks_X_t = scale_dims_t[3]; + const size_t scales_stride_t = blocks_X_t; + + Tensor input("input", shape, itype); + Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); + + std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); + std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); + std::unique_ptr ref_scales = std::make_unique(blocks_Y * blocks_X); + std::unique_ptr ref_scales_t = std::make_unique(blocks_Y_t * blocks_X_t); + + fillCase(&input, InputsFillCase::uniform); + + // Find global amax + float amax = 0.0f; + const InputType* input_dptr = input.rowwise_cpu_dptr(); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + amax = fmaxf(amax, static_cast(input_dptr[idx])); + } + } + // Set 2nd stage NVFP4 scaling factor + output.set_scale(amax); + + bool use_2d_quantization = false; + + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + output.scale(), + rows, + cols, + scales_stride, + scales_stride_t, + use_2d_quantization); + + QuantizationConfigWrapper quant_config; + + // Initialize stochastic rounding + Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); + rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed + rng_state.rowwise_cpu_dptr()[1] = 321; // rng_sequence + rng_state.from_cpu(); + quant_config.set_stochastic_rounding(false); + quant_config.set_rng_state(rng_state.data()); + + // Set 2D quantization based on compile-time flag + quant_config.set_nvfp4_2d_quantization(use_2d_quantization); + + // Call appropriate function based on operation type + // Activation functions take 3 parameters (input, output, stream) + // nvte_quantize_v2 takes 4 parameters (input, output, quant_config, stream) + if (OP == &gelu) { + nvte_gelu(input.data(), output.data(), 0); + } else if (OP == &silu) { + nvte_silu(input.data(), output.data(), 0); + } else if (OP == &relu) { + nvte_relu(input.data(), output.data(), 0); + } else if (OP == &qgelu) { + nvte_qgelu(input.data(), output.data(), 0); + } else if (OP == &srelu) { + nvte_srelu(input.data(), output.data(), 0); + } else { + nvte_quantize_v2(input.data(), output.data(), quant_config, 0); + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("DEBUG: CUDA error detected: %s\n", cudaGetErrorString(err)); + } + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + const double atol = 0.05; + const double rtol = 0.1; + + // Set dump_data=true to enable dumping tensor data to files for analysis + compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); + + const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr(); + const fp8e4m3* ref_scales_ptr = ref_scales.get(); + const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr(); + const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get(); + + size_t scale_mismatches_num = 0; + compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), + ref_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + scale_mismatches_num); + + compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_t.get(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + scale_mismatches_num); +} + +std::vector> tensor_dims = { + {32, 32}, + {32, 64}, + {64, 32}, + {64, 96}, + {128, 128}, + {256, 256}, + {512, 512}, + {1024, 1024}, + {2048, 2048}, + {128, 256}, + {8192, 128}, + {2048, 160}, + {8, 32, 1024}, + {16, 8, 4, 512}, + {1024, 16384}, + {4096, 13312}, +}; + +// Only GeLU activation tests are supported +std::vector Activation_types = { + ActivationType::Identity, + ActivationType::GeLU, + ActivationType::SiLU, + ActivationType::ReLU, + ActivationType::QGeLU, + ActivationType::SReLU, +}; + +} // namespace + +class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam + , + transformer_engine::DType>> {}; + +TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ActivationType Act_type = std::get<0>(GetParam()); + const auto tensor_dims = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + + // Skip tests if the input tensor is 1D + if (tensor_dims.size() < 2) { + GTEST_SKIP(); + } + + // Forward activations + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &gelu; break; + case ActivationType::SiLU: OP = &silu; break; + case ActivationType::ReLU: OP = &relu; break; + case ActivationType::QGeLU: OP = &qgelu; break; + case ActivationType::SReLU: OP = &srelu; break; + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + performTest(OP, tensor_dims); + ); +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: return "CAST_ONLY"; + case ActivationType::GeLU: return "GeLU"; + case ActivationType::SiLU: return "SiLU"; + case ActivationType::ReLU: return "ReLU"; + case ActivationType::QGeLU: return "QGeLU"; + case ActivationType::SReLU: return "SReLU"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::ValuesIn(Activation_types), + ::testing::ValuesIn(tensor_dims), + ::testing::Values(DType::kBFloat16)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<2>(info.param)); + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 9f926d07b..eebe4108a 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -109,6 +109,10 @@ size_t DIVUP(const size_t &x, const size_t &y){ return (((x) + ((y)-1)) / (y)); } +size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){ + return DIVUP(x, y) * y; +} + struct scale_inv_meta { std::vector shape; DType type; @@ -145,21 +149,71 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; - auto block_alignment = std::vector{128ul, 4ul}; - { - auto alignment = block_alignment[0]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(1)), alignment) * alignment; - alignment = block_alignment[1]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(32)), alignment) * alignment; - ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + const size_t block_size_X_rowwise = 32; + size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; + + const size_t block_size_Y_colwise = 32; + size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise); + size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); + ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; + + ret_rowwise.type = DType::kFloat8E8M0; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + ret_colwise.type = DType::kFloat8E8M0; + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); } - { - auto alignment = block_alignment[1]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(32)), alignment) * alignment; - alignment = block_alignment[0]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(1)), alignment) * alignment; - ret_colwise.shape = {scale_dim_0, scale_dim_1}; + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + NVTE_CHECK(last_dim % 32 == 0); + NVTE_CHECK(first_dim % 32 == 0); + + scale_inv_meta ret_rowwise, ret_colwise; + + size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y, scale_dim_X}; + + size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise); + ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t}; + + ret_rowwise.type = DType::kFloat8E4M3; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); + ret_colwise.type = DType::kFloat8E4M3; + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + const size_t block_size_X_rowwise = 32; + size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; + + const size_t block_size_Y_colwise = 32; + size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise); + size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); + ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; + ret_rowwise.type = DType::kFloat8E8M0; ret_colwise.type = DType::kFloat8E8M0; ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); @@ -178,13 +232,13 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; { - auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; + size_t scale_dim_0 = DIVUP(first_dim, 128lu); + size_t scale_dim_1 = DIVUP(DIVUP(last_dim, 128lu), 4) * 4; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { - auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; + size_t scale_dim_0 = DIVUP(last_dim, 128lu); + size_t scale_dim_1 = DIVUP(DIVUP(first_dim, 128lu), 4) * 4; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; @@ -204,13 +258,13 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; { - auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(first_dim, 4) * 4; + size_t scale_dim_0 = DIVUP(last_dim, 128lu); + size_t scale_dim_1 = DIVUP(first_dim, 4) * 4; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { - auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(last_dim, 4) * 4; + size_t scale_dim_0 = DIVUP(first_dim, 128lu); + size_t scale_dim_1 = DIVUP(last_dim, 4) * 4; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; @@ -252,14 +306,15 @@ Tensor::Tensor(const std::string& name, NVTEShape columnwise_shape = {}; std::vector columnwise_shape_vec; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING + || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { // Transpose when tensor scaling columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); for (size_t i = 0; i < shape.ndim - 1; ++i) { columnwise_shape_vec.emplace_back(shape.data[i]); } } else { - // Same shape for MX + // Same shape for MX and NVFP4 for (size_t i = 0; i < shape.ndim; ++i) { columnwise_shape_vec.emplace_back(shape.data[i]); } @@ -285,10 +340,13 @@ Tensor::Tensor(const std::string& name, std::fill_n(cpu_data_columnwise_.get(), total_size, 0); } } - tensor_.set_rowwise_data(dptr_rowwise, type, shape); - tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); - if (isFp8Type(type)) { + const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; + const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; + tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape); + tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape); + + if (isFp8Type(type) || isFp4Type(type)) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { (void)cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) (void)cudaMemset(amax, 0, sizeof(float)); @@ -307,13 +365,19 @@ Tensor::Tensor(const std::string& name, } if (columnwise) { tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, - std::vector{1}); + std::vector{1}); columnwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); } } else { - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(normalized_shape, tensor_.scaling_mode()); + if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + // Used for NVFP4 second stage scaling + cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) + cudaMemset(scale, 0, sizeof(float)); + scale_cpu_data_ = std::make_shared(0); + tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode()); auto rowwise_scale_size = rowwise_scale_meta.bytes(); auto columnwise_scale_size = colwise_scale_meta.bytes(); auto scale_shape = rowwise_scale_meta.shape; @@ -348,13 +412,16 @@ void Tensor::to_cpu() const { cudaMemcpyDeviceToHost); } if (columnwise_) { + const DType colwise_type = tensor_.dtype(); + + const size_t colwise_size = bytes(s, colwise_type); (void)cudaMemcpy(cpu_data_columnwise_.get(), - tensor_.get_columnwise_data().data_ptr, - size, - cudaMemcpyDeviceToHost); + tensor_.get_columnwise_data().data_ptr, + colwise_size, + cudaMemcpyDeviceToHost); } - if (isFp8Type(dtype())) { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) { if (tensor_.amax() != nullptr){ (void)cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), @@ -366,8 +433,7 @@ void Tensor::to_cpu() const { sizeof(float), cudaMemcpyDeviceToHost); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = rowwise_scale_meta.bytes(); (void)cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), @@ -396,15 +462,15 @@ void Tensor::from_cpu() const { (void)cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); } - if (isFp8Type(dtype())) { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) { if (tensor_.amax() != nullptr){ (void)cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } (void)cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = rowwise_scale_meta.bytes(); (void)cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, @@ -421,7 +487,7 @@ void Tensor::from_cpu() const { } void Tensor::set_scale(float scale) { - if (isFp8Type(dtype())) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { NVTE_CHECK(scale_cpu_data_); if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { *scale_cpu_data_ = scale; @@ -431,7 +497,7 @@ void Tensor::set_scale(float scale) { } void Tensor::set_scale_inv(float scale_inv) { - if (isFp8Type(dtype())) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { if (rowwise_) { NVTE_CHECK(rowwise_scale_inv_cpu_data_); } @@ -439,8 +505,7 @@ void Tensor::set_scale_inv(float scale_inv) { NVTE_CHECK(columnwise_scale_inv_cpu_data_); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(tensor_.shape(), tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); if (rowwise_) { auto num_scales = product(rowwise_scale_meta.shape); if (num_scales == 1) { @@ -473,7 +538,8 @@ void Tensor::set_scale_inv(float scale_inv) { } void Tensor::shareFP8Meta(const Tensor &other) { - if (isFp8Type(dtype()) && isFp8Type(other.dtype())) { + if ((isFp8Type(dtype()) && isFp8Type(other.dtype())) + || isFp4Type(dtype()) && isFp4Type(other.dtype())) { auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); auto my_rowwise_data = tensor_.get_rowwise_data(); new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast(my_rowwise_data.dtype), @@ -686,13 +752,30 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t } } -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - std::vector &mismatch_indices, - size_t& mismatches_num, const size_t atol, - const double abs_tolerable_mismatches_limit, - const double rel_tolerable_mismatches_limit) +template +struct CastToType; + +template <> +struct CastToType { + using type = int; +}; + +template <> +struct CastToType { + using type = float; +}; + +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + std::vector &mismatch_indices, size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit) { + using UpcastType = typename CastToType::type; + auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3); + + const size_t N = row_blocks * col_blocks; const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, std::floor(N * rel_tolerable_mismatches_limit)); @@ -701,11 +784,31 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { const int idx = i * stride + j; - const int test_val = static_cast(test[idx]); - const int ref_val = static_cast(ref[idx]); - const int abs_delta = std::abs(test_val - ref_val); + float t, r; - if (abs_delta > atol) { + bool assertion = false; + + if (std::is_same::value) { + t = static_cast(test[idx]); + r = static_cast(ref[idx]); + assertion = std::abs(t - r) > atol; + } else { + t = static_cast(*reinterpret_cast(&test[idx])); + r = static_cast(*reinterpret_cast(&ref[idx])); + const bool mismatch = (fabs(t - r) > atol_fp8e4m3) + && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3); + if (mismatch) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + } + if (assertion) { mismatches_num++; mismatch_indices.push_back(idx); } @@ -713,8 +816,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, std::cout << "Error in " << name << std::endl; for (const int index : mismatch_indices) { std::cout << "Mismatch at (" << index << "):" - << static_cast(test[index]) << " vs " - << static_cast(ref[index]) << std::endl; + << static_cast(test[index]) << " vs " + << static_cast(ref[index]) << std::endl; } GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " << tolerable_mismatches_limit << "."; @@ -723,6 +826,22 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, } } +// Instantiate templates +template +void compare_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + std::vector &mismatch_indices, size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); + +template +void compare_scaling_factors(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + std::vector &mismatch_indices, size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); + + #ifdef __HIP_PLATFORM_AMD__ void adjust_ref_for_e8m0_scale_error(const std::string &name, @@ -932,11 +1051,14 @@ bool isFp8Type(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } -int32_t getDeviceComputeCapability() -{ - cudaDeviceProp deviceProp; - (void)cudaGetDeviceProperties(&deviceProp, 0); - return 10 * deviceProp.major + deviceProp.minor; +bool isFp4Type(DType type) { + return type == DType::kFloat4E2M1; +} + +int32_t getDeviceComputeCapability() { + cudaDeviceProp deviceProp; + (void)cudaGetDeviceProperties(&deviceProp, 0); + return 10 * deviceProp.major + deviceProp.minor; } size_t first_dimension(const std::vector &shape) { @@ -954,7 +1076,8 @@ std::array get_scale_tensor_dims(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols) { - const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32); + const bool is_rowwise = (block_size_rows == 1) + && ((block_size_cols == 32) || (block_size_cols == 16)); const size_t alignment_Y = is_rowwise ? scale_tensor_alignment_Y_rowwise diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 3c0a387c6..95aff8cf8 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -74,6 +74,8 @@ using fp8e5m2 = te_hip_fp8_e5m2; using fp8e8m0 = uint8_t; #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif template @@ -235,7 +237,9 @@ class Tensor { float scale() const { if(scale_cpu_data_) { - NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!"); + NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING), + "Invalid scaling_mode!"); to_cpu(); return *scale_cpu_data_; } else { @@ -249,6 +253,8 @@ class Tensor { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -262,6 +268,8 @@ class Tensor { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -316,10 +324,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_MANTISSA_BITS = 23; // [128,4] rowwise and [4,128] colwise alignment requirement -constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_rowwise = 128; -constexpr size_t scale_tensor_alignment_X_colwise = 128; +constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_colwise = 4; +constexpr size_t scale_tensor_alignment_X_colwise = 128; inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; @@ -480,12 +488,13 @@ void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t N, float mismatch_rate_tol = 0.); -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - std::vector &mismatch_indices, size_t& mismatches_num, - const size_t scale_diff_abs_tolerance = 0, - const double abs_tolerable_mismatches_limit = 0, - const double rel_tolerable_mismatches_limit = 0); +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + std::vector &mismatch_indices, size_t& mismatches_num, + const size_t scale_diff_abs_tolerance = 0, + const double abs_tolerable_mismatches_limit = 0, + const double rel_tolerable_mismatches_limit = 0); #ifdef USE_ROCM void adjust_ref_for_e8m0_scale_error(const std::string &name, @@ -516,6 +525,7 @@ const std::string& caseName(InputsFillCase type); extern std::vector all_fp_types; bool isFp8Type(DType type); +bool isFp4Type(DType type); int32_t getDeviceComputeCapability(); constexpr int32_t hopperComputeCapability = 90; @@ -593,7 +603,7 @@ constexpr int32_t blackwellComputeCapability = 100; SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ printf("dtype: %d\n", static_cast(dtype)); \ - NVTE_ERROR("Invalid type MARKED TEST."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ @@ -612,7 +622,7 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 2."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \ @@ -620,7 +630,7 @@ constexpr int32_t blackwellComputeCapability = 100; using namespace transformer_engine; \ SWITCH_FP4_HANDLE(type, __VA_ARGS__) \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 3."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ @@ -645,5 +655,5 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 4."); \ + NVTE_ERROR("Invalid type."); \ } diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 6ec3c27a4..0a75053b1 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -788,9 +788,15 @@ def _test_quantize_dact_dbias( assert_allclose(te_output.data, jax_output.data) if is_dbias: - # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. precise_comparison = not ( - in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling() + # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. + (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()) + # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently. + or ( + activation_type == ("squared_relu",) + and in_dtype == jnp.bfloat16 + and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + ) ) assert_allclose( te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 03c0d1119..bf01e3216 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -78,8 +78,6 @@ def generate_collectives_count_ref( all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize ) other_bytes = 0 - if fp8_recipe == recipe.Float8CurrentScaling(): - allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction return generate_collectives_count( allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes ) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 10bb066a4..5e2898cf3 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -15,94 +15,28 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( get_cu_seqlens_on_cp_rank, ) +from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize import transformer_engine_torch as tex from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer -from transformer_engine.common.recipe import DelayedScaling +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling import warnings +from utils import ModelConfig, compare_and_assert dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} -def run_dpa_with_cp( - dtype="bf16", - model=None, - qkv_format="bshd", - kernel_backend="FlashAttention", - cp_comm_type="p2p", - fp8_mha=False, +def generate_input_shapes( + qkv_format: str, + config: ModelConfig, + world_size: int, + kernel_backend: str, ): - """Test DotProductAttention module with context parallelism""" - - # args are passed as strings - fp8_mha = fp8_mha == "True" - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "0" - if kernel_backend == "FlashAttention": - os.environ["NVTE_FLASH_ATTN"] = "1" - config = model_configs_flash_attn[model] - if kernel_backend == "FusedAttention": - os.environ["NVTE_FUSED_ATTN"] = "1" - config = model_configs_fused_attn[model] - - assert config.attn_mask_type in [ - "causal", - "no_mask", - ], f"{config.attn_mask_type} is an unsupported attention mask type!" - if qkv_format == "thd": - if "causal" in config.attn_mask_type: - config.attn_mask_type = "padding_causal" - else: - config.attn_mask_type = "padding" - - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - else: - device_count = torch.cuda.device_count() - device = rank % device_count - torch.cuda.set_device(device) - - print(f"[INFO] world_size:{world_size}, rank:{rank}") - - dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) - - # create flash attn comm group for CP - cp_comm_ranks = range(world_size) - assert rank in cp_comm_ranks - cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - if cp_comm_type == "a2a+p2p": - assert ( - world_size % 2 == 0 - ), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!" - cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] - cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] - cp_comm_sub_groups = [] - for sub_ranks in cp_comm_sub_ranks: - sub_group = dist.new_group(sub_ranks, backend="nccl") - if rank in sub_ranks: - cp_comm_sub_groups.append(sub_group) - - if dtype == "fp8": - fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha) - - # instantiate core attn module - core_attn = DotProductAttention( - config.num_heads, - (config.head_dim_qk, config.head_dim_v), - num_gqa_groups=config.num_gqa_groups, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=config.attn_mask_type, - window_size=config.window_size, - ) - core_attn = core_attn.cuda() - - # create flash attn inputs if qkv_format == "bshd": q_input_shape = ( config.batch_size, @@ -197,35 +131,192 @@ def run_dpa_with_cp( cu_seqlens_kv = cu_seqlens_q cu_seqlens_kv_padded = cu_seqlens_q_padded else: - assert False, f"{qkv_format} is an unsupported qkv_format!" - - q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda() - k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda() - v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda() - dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() - dout_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - scale=torch.tensor([1], dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), + assert False, f"{qkv_format=} is not supported!" + + return ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ) - # create flash attention bias + +def get_tols(config, dtype): + if dtype == "bf16": + if config.num_heads == config.num_gqa_groups: + atol = 2.5e-2 + rtol = 2.5e-2 + else: + atol = 3.5e-2 + rtol = 3.5e-2 + rmse_tol = 0.01 + elif dtype == "fp16": + atol = 5e-3 + rtol = 5e-3 + rmse_tol = 0.01 + elif dtype == "fp8": + atol = 5e-1 + rtol = 5e-1 + rmse_tol = 0.15 + else: + assert False, f"{dtype=} is not supported!" + + return atol, rtol, rmse_tol + + +def run_dpa_with_cp( + dtype="bf16", + model=None, + qkv_format="bshd", + kernel_backend="FlashAttention", + cp_comm_type="p2p", + fp8_bwd="True", + fp8_dpa="False", + fp8_mha="False", + scaling_mode="delayed", + f16_O="False", + log_level=logging.WARNING, +): + """Test DotProductAttention module with context parallelism""" + logging.root.setLevel(log_level) + + # set up environment variables and config + fp8_bwd = fp8_bwd == "True" and dtype == "fp8" + os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" + fp8_dpa = fp8_dpa == "True" and dtype == "fp8" + fp8_mha = fp8_mha == "True" and dtype == "fp8" + f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True" + os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0" + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + if kernel_backend == "FlashAttention": + os.environ["NVTE_FLASH_ATTN"] = "1" + config = model_configs_flash_attn[model] + if kernel_backend == "FusedAttention": + os.environ["NVTE_FUSED_ATTN"] = "1" + config = model_configs_fused_attn[model] + assert config.attn_mask_type in [ + "causal", + "no_mask", + ], f"{config.attn_mask_type=} is not supported!" + if qkv_format == "thd": + if "causal" in config.attn_mask_type: + config.attn_mask_type = "padding_causal" + else: + config.attn_mask_type = "padding" + + # set up distributed group + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + device_count = torch.cuda.device_count() + device = rank % device_count + torch.cuda.set_device(device) + logging.info(f"[Rank {rank}] Setup: world_size {world_size}") + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + + # set up communication group for CP + cp_comm_ranks = range(world_size) + assert rank in cp_comm_ranks + cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if cp_comm_type == "a2a+p2p": + assert world_size % 2 == 0, ( + "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size" + " = 2." + ) + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + cp_comm_sub_groups = [] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + cp_comm_sub_groups.append(sub_group) + + if dtype == "fp8": + if scaling_mode == "delayed": + fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "current": + fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + + # instantiate attention module + core_attn = DotProductAttention( + config.num_heads, + (config.head_dim_qk, config.head_dim_v), + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + window_size=config.window_size, + softmax_type=config.softmax_type, + ).cuda() + if config.softmax_type != "vanilla": + core_attn.softmax_offset.requires_grad = True + + # generate attention inputs + ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) + q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + dout_orig = torch.clamp( + torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 + ).cuda() + if scaling_mode == "delayed": + qkv_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) + dout_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) + if scaling_mode == "current": + qkv_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + dout_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + device="cuda", + ) + qkv_layout = "_".join([qkv_format] * 3) + q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] + if fp8_mha: + q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + for x in [q, k, v]: + x.requires_grad = True + if config.attn_bias_type not in ["no_bias", "alibi"]: attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() else: bias = None - # run core_attn without CP - for x in [q, k, v]: - x.requires_grad = True - + ############ run without CP ############ + logging.info(f"[Rank {rank}] Run without context parallelism") if dtype == "fp8": fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) else: fp8_context = nullcontext() - with fp8_context: + # q, k, v, out in FP8; dout in F16 out = core_attn( q, k, @@ -238,16 +329,25 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] ), + fp8_output=fp8_mha, ) - if fp8_mha: + if fp8_bwd and fp8_mha: dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: out.backward(dout) + dq, dk, dv = q.grad, k.grad, v.grad + d_softmax_offset = None + if config.softmax_type != "vanilla": + d_softmax_offset = core_attn.softmax_offset.grad + + ############ run with CP ############ + logging.info(f"[Rank {rank}] Run with context parallelism") - # run core_attn wit CP + # set up inputs q_, k_, v_, dout_, *rest = [ - x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias]) + x.clone().detach() + for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) ] bias_ = rest[0] if len(rest) else None if qkv_format == "bshd" or qkv_format == "sbhd": @@ -277,6 +377,14 @@ def run_dpa_with_cp( k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] else: assert False, f"{qkv_format} is an unsupported qkv_format!" + q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] + if scaling_mode == "delayed": + qkv_quantizer.scale.fill_(1.0) + qkv_quantizer.amax.fill_(0.0) + dout_quantizer.scale.fill_(1.0) + dout_quantizer.amax.fill_(0.0) + if fp8_mha: + q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: bias_ = bias_.view( @@ -284,20 +392,25 @@ def run_dpa_with_cp( ) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) + # set up environment core_attn.set_context_parallel_group( cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type, ) - + if config.softmax_type != "vanilla": + core_attn.softmax_offset.grad.zero_() if dtype == "fp8": - core_attn.reset_fp8_meta_tensors() + core_attn.fp8_initialized = False + core_attn.fp8_meta_tensors_initialized = False fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) else: fp8_context = nullcontext() + # run attention with fp8_context: + # q, k, v, out in FP8; dout in F16 out_ = core_attn( q_, k_, @@ -310,24 +423,32 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] ), + fp8_output=fp8_mha, ) - if fp8_mha: + if fp8_bwd and fp8_mha: dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: out_.backward(dout_) + dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad + d_softmax_offset_ = None + if config.softmax_type != "vanilla": + d_softmax_offset_ = core_attn.softmax_offset.grad.clone() + # get outputs + tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_] if fp8_mha: - assert isinstance(out, Float8Tensor) - assert isinstance(out_, Float8Tensor) - out = out.dequantize() - out_ = out_.dequantize() - - for x in [out_, q_.grad, k_.grad, v_.grad]: - assert torch.all(~torch.isnan(x)) - assert torch.all(~torch.isinf(x)) - - # compare results with and without CP + tensors_to_deq = [out, out_] if not fp8_bwd else tensors + for i, tensor in enumerate(tensors_to_deq): + tensors_to_deq[i] = tensor.dequantize() + if not fp8_bwd: + tensors[0], tensors[4] = tensors_to_deq + for tensor in tensors: + assert torch.all(~torch.isnan(tensor)) + assert torch.all(~torch.isinf(tensor)) + out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors + + ############ compare results between CP and no-CP ############ if qkv_format == "bshd" or qkv_format == "sbhd": dq, dk, dv, out = [ x.view( @@ -336,17 +457,17 @@ def run_dpa_with_cp( x.shape[seq_dim] // (2 * world_size), *x.shape[(seq_dim + 1) :], ) - for x in [q.grad, k.grad, v.grad, out] + for x in [dq, dk, dv, out] ] dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] dq_, dk_, dv_, out_ = [ x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) - for x in [q_.grad, k_.grad, v_.grad, out_] + for x in [dq_, dk_, dv_, out_] ] elif qkv_format == "thd": - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]] - dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_] + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] + dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True @@ -389,56 +510,70 @@ def run_dpa_with_cp( ).item() == 0 ) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" - - if dtype == "bf16": - if config.num_heads == config.num_gqa_groups: - tols = dict(atol=2.5e-2, rtol=2.5e-2) - else: - tols = dict(atol=3.5e-2, rtol=3.5e-2) - elif dtype == "fp16": - tols = dict(atol=5e-3, rtol=5e-3) - elif dtype == "fp8": - tols = dict(atol=5e-1, rtol=5e-1) - rmse_tol = 0.1 - else: - assert False, f"{dtype} is an unsupported dtype!" - - def _rmse(a, b): - return torch.sqrt((a - b).square().mean()).item() - def _error(a, b): - if dtype != "fp8": - torch.testing.assert_close(a, b, **tols) - else: - try: - torch.testing.assert_close(a, b, **tols) - except Exception as e: - logging.debug(e) - - rmse = _rmse(a, b) - rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) - assert ( - rmse < rmse_tol * rmse_range - ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - rmse, rmse_tol * rmse_range, rmse_tol, rmse_range - ) - - if qkv_format == "bshd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a[:, 0], b[:, 0]) - _error(a[:, 1], b[:, 1]) - elif qkv_format == "sbhd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a[0], b[0]) - _error(a[1], b[1]) - elif qkv_format == "thd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a, b) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" + atol, rtol, rmse_tol = get_tols(config, dtype) + tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_] + tensors_no_cp = [out, dq, dk, dv, d_softmax_offset] + names = ["out", "dq", "dk", "dv", "d_softmax_offset"] + names_cp = [x + "_cp" for x in names] + names_no_cp = [x + "_no_cp" for x in names] + is_fp8 = dtype == "fp8" + for i, t in enumerate(tensors_no_cp): + if t is not None: + if "softmax_offset" not in names[i]: + if qkv_format == "bshd": + compare_and_assert( + t[:, 0], + tensors_cp[i][:, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, 1], + tensors_cp[i][:, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "sbhd": + compare_and_assert( + t[0], + tensors_cp[i][0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[1], + tensors_cp[i][1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "thd": + compare_and_assert( + t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + ) + else: + compare_and_assert( + t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + ) + logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") + # destroy distribution group dist.destroy_process_group() diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index a5128653e..08f3b0c6b 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -4,7 +4,6 @@ # # See LICENSE for license information. import logging -import math import os import sys import pathlib @@ -54,19 +53,21 @@ sys.path.append(str(_current_file.parent.parent)) from utils import ( reset_rng_states, + compare_and_assert, ModelConfig, dtype_tols, get_available_attention_backends, ) -# Only run FP8 tests on H100 +# Check if hardware supports FP8 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() +# Reset RNG seed and states seed = 1234 -# Reset RNG states reset_rng_states() +# Reset FP8 global state manager @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield @@ -82,9 +83,14 @@ def reset_attn_backend(): "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD"]) yield +# Define F16 data types to test +param_types = [torch.float16] +if is_bf16_compatible(): + param_types.append(torch.bfloat16) +param_types_lean = [torch.bfloat16] model_configs_base = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "base_1_0": ModelConfig(8, 128, 16, 64), "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256), "base_2_0": ModelConfig(2, 2048, 24, 128), @@ -100,11 +106,6 @@ def reset_attn_backend(): } -param_types = [torch.float16] -if is_bf16_compatible(): # bf16 requires sm_80 or higher - param_types.append(torch.bfloat16) -param_types_lean = [torch.bfloat16] - # TODO: Enable config support in other backend(s) -- currently only the CK # backend is capable of supporting it. @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.") @@ -121,7 +122,6 @@ def test_gqa_mla_thd(): config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=True, ) if FusedAttnBackend["CK"] not in fused_attn_backends: @@ -147,7 +147,6 @@ def test_dot_product_mem_calc(): config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=pad_between_seqs, is_training=is_training, ) @@ -204,12 +203,12 @@ def test_dot_product_attention( config.window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) + # Get backends is_training = True available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=pad_between_seqs, is_training=is_training, ) @@ -220,7 +219,6 @@ def test_dot_product_attention( config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=pad_between_seqs, is_training=is_training, ) @@ -332,6 +330,7 @@ def test_dot_product_attention( share_cu_seqlens_ref, ) + # Compare results logging.info(f"[test_dot_product_attention]: is_training = {is_training}") if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") @@ -369,23 +368,102 @@ def test_dpa_checkpoint(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False, False) +model_configs_softmax = { + # test: ModelConfig(b, sq, hq, dqk) + "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), + "softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"), + "softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"), + "softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "softmax_2_1": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one" + ), + "softmax_2_2": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), + "softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"), + "softmax_3_1": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one" + ), + "softmax_3_2": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable" + ), + "softmax_4_0": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal" + ), + "softmax_4_1": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="causal", + softmax_type="off-by-one", + ), + "softmax_4_2": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="causal", + softmax_type="learnable", + ), + "softmax_5_0": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal" + ), + "softmax_5_1": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="padding_causal", + softmax_type="off-by-one", + ), + "softmax_5_2": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="padding_causal", + softmax_type="learnable", + ), +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_softmax]) +@pytest.mark.parametrize("model", model_configs_softmax.keys()) +def test_dpa_softmax(dtype, model_configs, model): + """Test DotProductAttention module with different softmax types""" + test_dot_product_attention( + dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False, False + ) + + model_configs_mla = { - # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend - "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0 - "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0 - "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0 - "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1 + # test: ModelConfig(b, sq, hq, dqk) + "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), + "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), + "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), + "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), "mla_2_1": ModelConfig( 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64 - ), # cross, 1 + ), "mla_2_2": ModelConfig( 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128 - ), # cross, 1 - "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference - "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference + ), + "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), + "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), } @@ -399,7 +477,7 @@ def test_dpa_mla(dtype, model_configs, model): model_configs_mask = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"), "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"), @@ -454,18 +532,16 @@ def test_dpa_mask(dtype, model_configs, model): model_configs_bias = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"), "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"), "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"), - "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped - "bias_1_5": ModelConfig( - 2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi" - ), # skipped + "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), + "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"), "bias_2_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_2_1": ModelConfig( 2, 128, @@ -474,10 +550,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=256, attn_mask_type="padding", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_2_2": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_2_3": ModelConfig( 2, 2048, @@ -486,13 +562,11 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="post_scale_bias", - ), # skipped - "bias_2_4": ModelConfig( - 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi" - ), # skipped + ), + "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"), "bias_2_5": ModelConfig( 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi" - ), # skipped + ), "bias_3_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), @@ -510,14 +584,14 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"), "bias_3_5": ModelConfig( 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi" - ), # skipped + ), "bias_4_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_4_1": ModelConfig( 2, 128, @@ -526,10 +600,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=256, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_4_2": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_4_3": ModelConfig( 2, 2048, @@ -538,10 +612,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_4_4": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi" - ), # skipped + ), "bias_4_5": ModelConfig( 2, 2048, @@ -550,7 +624,7 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="alibi", - ), # skipped + ), } @@ -564,7 +638,7 @@ def test_dpa_bias(dtype, model_configs, model): model_configs_bias_shapes = { - # test: b, h, hg, d, sq, skv, p, + # test: ModelConfig(b, sq, hq, dqk) "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"), "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"), @@ -602,7 +676,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "swa_1_1": ModelConfig(2, 2048, 16, 64), "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4), "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096), @@ -642,7 +716,7 @@ def test_dpa_sliding_window(dtype, model_configs, model): model_configs_alibi_slopes = { - # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type + # test: ModelConfig(b, sq, hq, dqk) "alibi_1_0": ModelConfig( 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla" ), @@ -696,7 +770,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): model_configs_layout = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "layout_0_0": ModelConfig(2, 128, 16, 64), "layout_0_1": ModelConfig( 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" @@ -744,7 +818,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"), "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), @@ -824,7 +898,6 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=pad_between_seqs, ) if FusedAttnBackend["CK"] not in fused_attn_backends: @@ -848,7 +921,6 @@ def test_dpa_qkv_layout_thd_mqa_gqa(dtype, model_configs, model, qkv_layout, pad config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=pad_between_seqs, ) if FusedAttnBackend["CK"] not in fused_attn_backends: @@ -882,7 +954,6 @@ def _run_dot_product_attention( share_cu_seqlens_ref: bool = False, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Run DotProductAttention module with one forward pass and one backward pass""" - # Set RNG and environment varables reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" @@ -1145,9 +1216,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: tp_group=None, layer_number=1, attention_type=config.attn_type, + softmax_type=config.softmax_type, ).to(dtype=dtype, device="cuda") if not is_training: block = block.eval() + if is_training and config.softmax_type != "vanilla": + block.softmax_offset.requires_grad = True cu_seqlens_q_padded = None cu_seqlens_kv_padded = None @@ -1188,12 +1262,14 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: ) if is_training: out.backward(d_out) - + d_softmax_offset = None + if is_training and config.softmax_type != "vanilla": + d_softmax_offset = block.softmax_offset.grad if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if is_training: - return out, (q.grad, k.grad, v.grad) + return out, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None) + return out, (None, None, None, d_softmax_offset) if backend == "FusedAttention": if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) @@ -1222,18 +1298,18 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 ) if is_training: - return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig) + return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset) else: - return out_orig, (None, None, None) + return out_orig, (None, None, None, d_softmax_offset) else: if is_training: - return out, (q.grad, k.grad, v.grad) + return out, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None) + return out, (None, None, None, d_softmax_offset) model_configs_te_layer = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), "te_1_1": ModelConfig( 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" @@ -1598,6 +1674,7 @@ def _run_transformer_layer( model_configs_fp8_extra_state = { + # test: ModelConfig(b, sq, hq, dqk) "large": ModelConfig(2, 128, 4, 128, num_layers=1), } @@ -1607,7 +1684,8 @@ def _run_transformer_layer( @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") @pytest.mark.parametrize("model", ["large"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_sanity_attention_extra_state(model, dtype): +def test_dpa_fp8_extra_state(model, dtype): + """Test DotProductAttention module in FP8 with checkpointing""" config = model_configs_fp8_extra_state[model] # Test backend availability is_training = True @@ -1621,9 +1699,9 @@ def test_sanity_attention_extra_state(model, dtype): if not fused_attn_supported and not flash_attn_supported: pytest.skip("No attention backend available.") - outputs = _run_attention_extra_state(dtype, config, checkpoint=False) - outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) - outputs_checkpoint_v1_6 = _run_attention_extra_state( + outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False) + outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True) + outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state( dtype, config, mimic_v1_6=True, checkpoint=True ) @@ -1645,7 +1723,8 @@ def test_sanity_attention_extra_state(model, dtype): ) -def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): +def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): + """Run DotProductAttention module in FP8 with checkpointing""" steps = 10 path = "checkpoint.pt" fp8_enabled = True @@ -1742,7 +1821,7 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig(2, 2048, 16, 128), "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), @@ -1800,22 +1879,44 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training): +@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +def test_mha_fp8_vs_f16( + dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode +): + """Test MultiHeadAttention module in FP8""" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] # Test backend availability + if scaling_mode == "delayed": + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=True, + fp8_mha=True, + ) + elif scaling_mode == "current": + fp8_recipe = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + fp8_mha=True, + ) + fp8_meta = {} + fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_format.replace("hd", "h3d"), + fp8=True, + fp8_meta=fp8_meta, is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # Skip if only unfused backend is supported - if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: - pytest.skip("Less than two backends to compare.") + if flash_attn_supported + fused_attn_supported < 1: + pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: available_backends, _, fused_attn_backends = get_available_attention_backends( config, @@ -1833,7 +1934,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, RoPE, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) os.environ["NVTE_FLASH_ATTN"] = "0" @@ -1841,20 +1942,21 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, RoPE, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( - dtype, config, False, qkv_format, input_layernorm, RoPE, is_training + dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) atol = 5e-1 rtol = 5e-1 rmse_tol = 0.15 - logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: - _error( + logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, "flash_attn_fwd_fp8", @@ -1862,8 +1964,11 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) - _error( + logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( fused_attn_fwd_fp8, fused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -1871,12 +1976,13 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) if is_training: for i in range(len(param_names[:1])): logging.debug("========== {:^25s} ==========".format(param_names[i])) - _error( + compare_and_assert( fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], f"fused_attn_bwd_fp8[{i}]", @@ -1884,10 +1990,14 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) -def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training): +def _run_mha_fp8_vs_f16( + dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe +): + """Run MultiHeadAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -1896,15 +2006,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_mha, - fp8_mha=fp8_mha, - ) - with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): rotary_pos_emb = None if RoPE: @@ -2014,7 +2115,9 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): +@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode): + """Test DotProductAttention module in FP8""" config = model_configs_fp8_vs_f16[model] # TODO(cyang): think of another way to verify dropout results @@ -2029,16 +2132,33 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" + os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1" # Test backend availability + if scaling_mode == "delayed": + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=True, + ) + elif scaling_mode == "current": + fp8_recipe = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + ) + fp8_meta = {} + fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_layout, + fp8=True, + fp8_meta=fp8_meta, is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # Skip if only unfused backend is supported if flash_attn_supported + fused_attn_supported < 1: pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: @@ -2058,33 +2178,45 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)") flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training + dtype, config, True, qkv_layout, is_training, fp8_recipe + ) + + if unfused_attn_supported: + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") + unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training, fp8_recipe ) os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training + dtype, config, True, qkv_layout, is_training, fp8_recipe ) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" if config.dropout_p == 0.0: # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, qkv_layout, is_training + dtype, config, False, qkv_layout, is_training, fp8_recipe ) atol = 5e-1 rtol = 5e-2 rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] - logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: - _error( + logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, "flash_attn_fwd_fp8", @@ -2092,14 +2224,43 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, ) + if unfused_attn_supported: + logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( + unfused_attn_fwd_fp8, + fused_attn_fwd_f16, + "unfused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, + True, + ) + if is_training: + for i, _ in enumerate(fused_attn_bwd_f16): + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + compare_and_assert( + unfused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"unfused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, + True, + ) if config.dropout_p != 0.0: # test cuDNN FP8 dropout assert torch.all( fused_attn_fwd_fp8 == 1 ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." else: - _error( + logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( fused_attn_fwd_fp8, fused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -2107,11 +2268,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, ) if is_training: for i, _ in enumerate(fused_attn_bwd_f16): logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - _error( + compare_and_assert( fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], f"fused_attn_bwd_fp8[{i}]", @@ -2119,11 +2281,13 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, ) + os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0" -def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): - +def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe): + """Run DotProductAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -2132,14 +2296,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_dpa, - ) - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) with fp8_model_init(enabled=fp8_dpa): dpa = DotProductAttention( @@ -2246,6 +2402,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, + fp8_output=fp8_dpa, ) if is_training: out.backward(out_grad) @@ -2256,7 +2413,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: model_configs_fp8 = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "fp8_1": ModelConfig(1, 512, 1, 64), "fp8_2": ModelConfig(4, 512, 16, 64), "fp8_3": ModelConfig(1, 2048, 1, 128), @@ -2312,7 +2469,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol = 5e-1 rtol = 5e-1 rmse_tol = 0.13 - _error( + compare_and_assert( fused_attn_fwd_fp8, unfused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -2320,8 +2477,9 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol, rtol, rmse_tol, + True, ) - _error( + compare_and_assert( fused_attn_bwd_fp8, unfused_attn_bwd_f16, "fused_attn_bwd_fp8", @@ -2329,6 +2487,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol, rtol, rmse_tol, + True, ) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index ece5a37de..c9fc86878 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -8,6 +8,7 @@ import subprocess import sys import pathlib +import logging import pytest import torch @@ -16,19 +17,27 @@ get_device_compute_capability, get_cudnn_version, ) +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, +) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) from utils import ModelConfig, get_available_attention_backends +pytest_logging_level = logging.getLevelName(logging.root.level) + # Initialize RNG state seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) +test_essential = True + model_configs_flash_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA @@ -63,18 +72,31 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): return args +dtypes = ["bf16", "fp16"] +qkv_formats = ["bshd", "sbhd", "thd"] +cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] +if test_essential: + configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"] + model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} + dtypes = ["bf16"] + qkv_formats = ["sbhd", "thd"] + + @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") @pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) -@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) +@pytest.mark.parametrize("qkv_format", qkv_formats) +@pytest.mark.parametrize("cp_comm_type", cp_comm_types) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") config = model_configs_flash_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type + if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and qkv_format == "thd": @@ -104,12 +126,13 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_format=qkv_format, kernel_backend="FlashAttention", cp_comm_type=cp_comm_type, + log_level=pytest_logging_level, ), check=True, ) model_configs_fused_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_2": ModelConfig( @@ -140,17 +163,42 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA + "cp_4_0": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla" + ), # GQA + "cp_4_1": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one" + ), # GQA + "cp_4_2": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), # GQA } +dtypes = ["bf16", "fp16", "fp8"] +qkv_formats = ["bshd", "sbhd", "thd"] +cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] +if test_essential: + configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} + dtypes = ["bf16", "fp8"] + qkv_formats = ["sbhd", "thd"] + + @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) -@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) -@pytest.mark.parametrize("fp8_mha", [False, True]) -def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha): +@pytest.mark.parametrize("qkv_format", qkv_formats) +@pytest.mark.parametrize("cp_comm_type", cp_comm_types) +@pytest.mark.parametrize("fp8_bwd", [True, False]) +@pytest.mark.parametrize("fp8_mha", [True, False]) +@pytest.mark.parametrize("fp8_dpa", [True, False]) +@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"]) +@pytest.mark.parametrize("f16_O", [True, False]) +def test_cp_with_fused_attention( + dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O +): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") @@ -161,10 +209,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") if (not IS_HIP_EXTENSION) and dtype == "fp8" and get_device_compute_capability() < (9, 0): pytest.skip("FP8 attention is only supported on sm90+!") + if dtype == "fp8" and not fp8_dpa and fp8_mha: + pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") + if dtype != "fp8" and fp8_bwd: + pytest.skip("Only fp8 works with fp8_bwd=True!") if IS_HIP_EXTENSION and dtype == "fp8": pytest.skip("FP8 attention has not been supported on ROCm yet!") config = model_configs_fused_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type + if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd" and cp_comm_type == "all_gather": @@ -192,19 +247,57 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" ) - if dtype != "fp8" and fp8_mha: - pytest.skip("Only fp8 works with fp8_mha=True!") + if dtype != "fp8" and (fp8_mha or fp8_dpa): + pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!") + if dtype == "fp8" and not (fp8_mha or fp8_dpa): + pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!") + if dtype != "fp8" and scaling_mode is not None: + pytest.skip("Only fp8 works with scaling_mode != None!") + if dtype == "fp8" and scaling_mode is None: + pytest.skip("fp8 only works with scaling_mode != None!") + if ( + dtype == "fp8" + and scaling_mode == "current" + and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] + ): + pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") + if f16_O and (dtype != "fp8" or scaling_mode != "current"): + pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently does not support FP8 attention!") + if dtype == "fp8" and config.softmax_type != "vanilla": + pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!") + if config.softmax_type != "vanilla" and cp_comm_type != "a2a": + pytest.skip( + "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" + ) + if config.softmax_type != "vanilla" and qkv_format == "thd": + pytest.skip( + "CP implementation does not support qkv_format=thd for non-vanilla softmax types!" + ) + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} + fp8_meta = {} + fp8_meta["recipe"] = None + fp8_meta["local_recipes"] = [] + fp8 = dtype == "fp8" and (fp8_dpa or fp8_mha) + if fp8 and scaling_mode == "delayed": + fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True) + fp8_meta["local_recipes"] = [DelayedScaling(fp8_dpa=True)] + if fp8 and scaling_mode == "current": + fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True) + fp8_meta["local_recipes"] = [ + Float8CurrentScaling(fp8_dpa=True), + DelayedScaling(fp8_dpa=True), + ] available_backends, _, fused_attn_backends = get_available_attention_backends( config, - qkv_dtype=dtypes[dtype], + qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, qkv_layout="_".join([qkv_format] * 3), - window_size=config.window_size, - context_parallel=True, + fp8=fp8, + fp8_meta=fp8_meta, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -218,7 +311,12 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha qkv_format=qkv_format, kernel_backend="FusedAttention", cp_comm_type=cp_comm_type, + fp8_bwd=fp8_bwd, + fp8_dpa=fp8_dpa, fp8_mha=fp8_mha, + scaling_mode=scaling_mode, + f16_O=f16_O, + log_level=pytest_logging_level, ), check=True, ) diff --git a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py index af71866f3..8bd8906c5 100644 --- a/tests/pytorch/attention/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -483,7 +483,6 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=False, is_training=False, fp8=is_fp8, diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 8a201b72d..f68ddce7f 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -11,6 +11,7 @@ import os import sys from functools import wraps +import math import transformer_engine.pytorch as te import torch @@ -23,10 +24,15 @@ DelayedScaling, Float8CurrentScaling, Float8BlockScaling, + NVFP4BlockScaling, Format, Recipe, + QParams, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE +from transformer_engine.pytorch.distributed import gather_along_first_dim from run_layer_with_overlap import _compare_tensors if IS_HIP_EXTENSION: @@ -53,6 +59,14 @@ ) +def nvfp4_vanilla(): + nvfp4_recipe = NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = QParams() + nvfp4_recipe.fp4_quant_fwd_weight = QParams() + nvfp4_recipe.fp4_quant_bwd_grad = QParams() + return nvfp4_recipe + + # Quantization recipe setup def quantization_recipe() -> Recipe: if QUANTIZATION == "fp8": @@ -65,6 +79,8 @@ def quantization_recipe() -> Recipe: return Float8CurrentScaling() if QUANTIZATION == "fp8_block_scaling": return Float8BlockScaling() + if QUANTIZATION == "nvfp4": + return nvfp4_vanilla() return te.fp8.get_default_fp8_recipe() @@ -113,10 +129,14 @@ def main(argv=None, namespace=None): # Quantization scheme QUANTIZATION = args.quantization global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE - if QUANTIZATION in ("fp8", "mxfp8"): + if QUANTIZATION in ("fp8", "mxfp8", "nvfp4"): SEQ_LEN = 32 BATCH_SIZE = 32 HIDDEN_SIZE = 128 + # For fp8 block scaling, block size is 128, + # and to make low precision TP work, input tensor + # must be 128x128 divisible to be eligible for + # low precision All-Gather when needed elif QUANTIZATION == "fp8_block_scaling": SEQ_LEN = 128 BATCH_SIZE = 128 @@ -124,6 +144,7 @@ def main(argv=None, namespace=None): test_dict = [ test_quantizer, + test_quantized_all_gather, test_linear, test_layernorm, test_layernorm_linear, @@ -193,6 +214,9 @@ def _get_tolerances(dtype): # row parallel & sequence parallel, because we do the all_gather in backward pass if QUANTIZATION == "fp8_cs": return {"rtol": 0.4, "atol": 0.25} + elif QUANTIZATION == "nvfp4": + # TODO(zhongboz): investigate why the tolerance is so large + return {"rtol": 0.125, "atol": 0.12} elif QUANTIZATION is not None: return {"rtol": 0.125, "atol": 0.0625} @@ -348,24 +372,36 @@ def _alloc_main_grad(model_single_node, model_distributed): ############################################### # Quantizer # ############################################### -def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size): +def _construct_quantizer(quantizer_class, low_precision_dtype, device, tp_group, tp_size): """ quantizer is the reference quantizer on a single GPU. quantizer_dist is the distributed quantizer to be tested on multiple GPUs. """ if quantizer_class == Float8CurrentScalingQuantizer: quantizer_dist = quantizer_class( - fp8_dtype=fp8_dtype, + fp8_dtype=low_precision_dtype, device=device, with_amax_reduction=True, amax_reduction_group=tp_group, ) quantizer = quantizer_class( - fp8_dtype=fp8_dtype, + fp8_dtype=low_precision_dtype, device=device, with_amax_reduction=False, ) return quantizer, quantizer_dist + elif quantizer_class == NVFP4Quantizer: + quantizer_dist = quantizer_class( + fp4_dtype=low_precision_dtype, + with_amax_reduction=True, + amax_reduction_group=tp_group, + ) + quantizer = quantizer_class( + fp4_dtype=low_precision_dtype, + with_amax_reduction=False, + amax_reduction_group=None, + ) + return quantizer, quantizer_dist else: raise ValueError(f"Unsupported quantizer class: {quantizer_class}") @@ -436,6 +472,194 @@ def test_quantizer(): _test_quantizer(input_dtype, fp8_dtype) +############################################ +# Quantized All-Gather # +############################################ + + +def _ref_zero_padding_scale_inv(scale_inv, unpadded_shape): + """ + Zero padding the scale_inv. + scale_inv shape is the padded shape, but not zero padded + unpadded_shape is the original shape before padding + """ + dim0, dim1 = scale_inv.shape + unpadded_dim0, unpadded_dim1 = unpadded_shape + pad_dim0 = (128 - unpadded_dim0 % 128) % 128 + pad_dim1 = (4 - unpadded_dim1 % 4) % 4 + new_dim0 = unpadded_dim0 + pad_dim0 + new_dim1 = unpadded_dim1 + pad_dim1 + + assert dim0 == new_dim0 + assert dim1 == new_dim1 + + # return input if no padding is needed + if pad_dim0 == 0 and pad_dim1 == 0: + return scale_inv + + # unpad first to remove random bits from torch empty + scale_inv = scale_inv[:unpadded_dim0, :unpadded_dim1].contiguous() + # using torch padding + new_scale_inv = torch.nn.functional.pad( + scale_inv, (0, pad_dim1, 0, pad_dim0), mode="constant", value=0 + ) + + assert new_scale_inv.shape == (new_dim0, new_dim1) + + return new_scale_inv + + +def _get_unpadded_scale_inv_shape(input_shape, quantizer_cls, columnwise): + """ + Calculate the unpadded shape of the scale_inv tensor. + """ + M, K = 1, 1 + M = math.prod(input_shape[:-1]) + K = input_shape[-1] + + if quantizer_cls == NVFP4Quantizer: + if columnwise: + outer = K + inner = math.ceil(M / NVFP4_BLOCK_SCALING_SIZE) + return (outer, inner) + else: + outer = M + inner = math.ceil(K / NVFP4_BLOCK_SCALING_SIZE) + return (outer, inner) + else: + raise ValueError(f"Unsupported quantizer class: {quantizer_cls}") + + +@run_distributed_test() +def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls): + """Test the quantizer under distributed settings. + + Args: + input_dtype (torch.dtype): The data type of the input. + low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8. + """ + + M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2 + + # high precision input + x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype) + # set one element of the input to a very large value, which doesn't live in rank 0 after the split + # to test the amax reduction on purpose + # x_hp_cpu[M - 1, N - 1] = 1e4 + + # get the unpadded shapes + unpadded_rowwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, False) + unpadded_columnwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, True) + + # rank 0 takes the full copy and quantize with GPU 0 for verification + if WORLD_RANK == 0: + x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda") + x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK] + + # Create quantizers + quantizer, quantizer_dist = _construct_quantizer( + quantizer_cls, low_precision_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE + ) + + # quantize the entire input + if WORLD_RANK == 0: + x_low_precision_single = quantizer(x_hp_rank0) + + # run all-gather with a quantizer as input for quantized all-gather + x_low_precision_total, _ = gather_along_first_dim( + x_hp_local_rank, NCCL_WORLD, async_op=False, quantizer=quantizer_dist + ) + + # check the outputs + if WORLD_RANK == 0: + # assert all data and scale_inv are the same + torch.testing.assert_close( + x_low_precision_single._rowwise_data, + x_low_precision_total._rowwise_data, + rtol=0.0, + atol=0.0, + ) + # check the rowwise scale without any padding + unpad_dim0, unpad_dim1 = unpadded_rowwise_scale_inv_shape + unpadded_rowwise_scale_inv_ref = x_low_precision_single._rowwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + unpadded_rowwise_scale_inv = x_low_precision_total._rowwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + torch.testing.assert_close( + unpadded_rowwise_scale_inv_ref, + unpadded_rowwise_scale_inv, + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + _ref_zero_padding_scale_inv( + x_low_precision_single._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape + ), + _ref_zero_padding_scale_inv( + x_low_precision_total._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape + ), + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + x_low_precision_single._columnwise_data, + x_low_precision_total._columnwise_data, + rtol=0.0, + atol=0.0, + ) + unpad_dim0, unpad_dim1 = unpadded_columnwise_scale_inv_shape + unpadded_columnwise_scale_inv_ref = x_low_precision_single._columnwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + unpadded_columnwise_scale_inv = x_low_precision_total._columnwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + torch.testing.assert_close( + unpadded_columnwise_scale_inv_ref, + unpadded_columnwise_scale_inv, + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + _ref_zero_padding_scale_inv( + x_low_precision_single._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape + ), + _ref_zero_padding_scale_inv( + x_low_precision_total._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape + ), + rtol=0.0, + atol=0.0, + ) + + +def test_quantized_all_gather(): + """ + Run quantized all-gather tests with various configurations. + """ + # skip this test for other quantization schemes + is_nvfp4 = QUANTIZATION == "nvfp4" + # add other recipes for testing if needed + if not is_nvfp4: + return + + input_dtypes = [torch.bfloat16] + fp4_dtype = [tex.DType.kFloat4E2M1] + fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + quantizer_cls_nvfp4 = [NVFP4Quantizer] + # add FP8 quantizers if needed + quantizer_cls_fp8 = [] + + low_precisio_dtypes = fp4_dtype if is_nvfp4 else fp8_dtype + quantizer_cls_list = quantizer_cls_nvfp4 if is_nvfp4 else quantizer_cls_fp8 + + for quantizer_cls in quantizer_cls_list: + for input_dtype in input_dtypes: + for low_precision_dtype in low_precisio_dtypes: + _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls) + + ############################################ # Linear # ############################################ @@ -536,10 +760,11 @@ def test_linear(): {"init_method": _constant}, {"fuse_wgrad_accumulation": True}, {"return_bias": True}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"delay_wgrad_compute": True}, {"save_original_input": True}, ] + for kwargs in kwargs_list: if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8": continue @@ -715,11 +940,12 @@ def test_layernorm_linear(): {"init_method": _constant}, {"fuse_wgrad_accumulation": True}, {"return_bias": True}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"zero_centered_gamma": False}, {"return_layernorm_output": True}, {"delay_wgrad_compute": True}, ] + for kwargs in kwargs_list: for parallel_mode in ["column"]: for sequence_parallel in [False, True]: @@ -821,7 +1047,7 @@ def test_layernorm_mlp(): {"normalization": "RMSNorm"}, {"zero_centered_gamma": True}, {"bias": False}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"activation": "relu"}, {"fuse_wgrad_accumulation": True}, {"return_bias": True}, @@ -919,7 +1145,7 @@ def test_transformer_layer(): {"fuse_qkv_params": True, "fuse_wgrad_accumulation": True}, {"qkv_weight_interleaved": False}, {"bias": False}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"fuse_qkv_params": True}, {"activation": "relu"}, ] diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py new file mode 100644 index 000000000..b1722b79a --- /dev/null +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -0,0 +1,718 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import argparse +import datetime +import os +import sys +from functools import wraps +import math + +import transformer_engine.pytorch as te +import torch +from torch import nn +import torch.distributed as dist +import transformer_engine_torch as tex +from transformer_engine.common.recipe import ( + NVFP4BlockScaling, + Format, + Recipe, + QParams, +) +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE +from run_layer_with_overlap import _compare_tensors + + +BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE = 128, 256, 128 +WORLD_RANK, WORLD_SIZE = None, None +NCCL_WORLD = None +LOSS_FN = nn.MSELoss() +QUANTIZATION = None + + +def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + +# Quantization recipe setup +def quantization_recipe() -> Recipe: + if QUANTIZATION == "nvfp4": + return nvfp4_rht_and_2d_quantization() + raise ValueError(f"Unsupported quantization: {QUANTIZATION}") + + +def setup_environment_for_reference(): + if QUANTIZATION == "nvfp4": + os.environ["QAT_PARAMS"] = "9003" + else: + raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}") + + +def cleanup_environment(): + if "QAT_PARAMS" in os.environ: + del os.environ["QAT_PARAMS"] + + +def main(argv=None, namespace=None): + global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION, BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE + + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + + NCCL_WORLD = dist.new_group(backend="nccl") + + WORLD_SIZE = dist.get_world_size() + + parser = argparse.ArgumentParser() + parser.add_argument("--quantization", type=str, default=None) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--hidden-size", type=int, default=128) + parser.add_argument("--out-size", type=int, default=128) + args = parser.parse_args(argv, namespace) + + # Quantization scheme + QUANTIZATION = args.quantization + BATCH_SIZE = args.batch_size + HIDDEN_SIZE = args.hidden_size + OUT_SIZE = args.out_size + + test_dict = [ + test_linear, + test_layernorm_linear, + ] + + for test in test_dict: + test() + dist.destroy_process_group() + return 0 + + +def run_distributed_test(test_name=None): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + name = test_name if test_name is not None else func.__name__ + + dist_print(f"Starting test {name} with args {args} and {kwargs}") + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + func(*args, **kwargs) + + dist.barrier() + dist_print(f"Passed test {name}") + + return wrapper + + return decorator + + +def dist_print(msg, src=None, end="\n", error=False): + stream = sys.stderr if error else sys.stdout + if WORLD_RANK == (0 if src is None else src): + stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n") + + +############################################ +# Linear # +############################################ +class TestDistributedLinearBase: + @staticmethod + def _prepare_data( + batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32 + ): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda") + w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda") + bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None + gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda") + + return x, w, bias, gradient + + @staticmethod + def _shard_tensor(x, world_size, axis): + split_size = x.size()[axis] // world_size + split_tensor = torch.split(x, split_size, axis) + out = [] + for tensor in split_tensor: + out.append(tensor.detach().clone().requires_grad_(x.requires_grad)) + return out + + @staticmethod + def _gather_tensor(local, world_size, tp_group, concat_dim): + out_list = [torch.zeros_like(local) for _ in range(world_size)] + torch.distributed.all_gather(out_list, local, tp_group) + return torch.cat(out_list, dim=concat_dim) + + @staticmethod + def _all_reduce_tensor(local, world_size, tp_group): + if world_size == 1: + return local + handle = torch.distributed.all_reduce(local, group=tp_group, async_op=False) + return local + + @staticmethod + def _get_sum_abs_error(a, b): + return torch.sum(torch.abs(a - b)) + + @staticmethod + def _get_mean_abs_relative_error(a, b): + error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b)) + return torch.mean(error) + + @classmethod + def run_linear_preprocess_parallel( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_size=1, + rank=0, + ): + if tp_size > 1: + if parallel_mode == "column": + # split w in N dim, which should be axis 0 + w = cls._shard_tensor(w, tp_size, 0)[rank] + bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None + # split gradient in N dim, which should be axis 1 + gradient = cls._shard_tensor(gradient, tp_size, 1)[rank] + if sequence_parallel: + # split x in M dim, which should be axis 0 + x = cls._shard_tensor(x, tp_size, 0)[rank] + # row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1 + if parallel_mode == "row": + # split x in K dim, which should be axis 1 + x = cls._shard_tensor(x, tp_size, 1)[rank] + # split w in K dim, which should be axis 1 + w = cls._shard_tensor(w, tp_size, 1)[rank] + if sequence_parallel: + # split gradient in M dim, which should be axis 0 + gradient = cls._shard_tensor(gradient, tp_size, 0)[rank] + return x, w, bias, gradient + + @classmethod + def run_linear_postprocess_parallel( + cls, + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ): + if tp_size > 1: + if parallel_mode == "column": + # gather y_q in N dim, which should be axis 1 + y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1) + # gather wgrad in N dim, which should be axis 0 + wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0) + # gather bgrad in N dim, which should be axis 0 + bgrad = ( + cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None + ) + if sequence_parallel: + # gather dgrad in M dim, which should be axis 0 + dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0) + if parallel_mode == "row": + # gather dgrad in K dim, which should be axis 1 + dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1) + # gather wgrad in K dim, which should be axis 1 + wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1) + if sequence_parallel: + # gather y_q in M dim, which should be axis 0 + y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0) + # we need to sum bias gradient when using TP + SP + bgrad = ( + cls._all_reduce_tensor(bgrad, tp_size, tp_group) + if bgrad is not None + else None + ) + + return y_q, dgrad, wgrad, bgrad + + @classmethod + def run_linear_one_step( + cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False + ): + # reset gradients + layer.zero_grad() + x.grad = None + + # Forward pass + if isinstance(layer, te.Linear): + # Kitchen Linear + y_q = layer.forward(x, is_first_microbatch=is_first_microbatch) + else: + # the default torch.nn.Linear + y_q = layer(x) + + # Backward pass + y_q.backward(gradient) + + # Collect gradients + dgrad = x.grad + bgrad = ( + layer._parameters["bias"].grad + if layer._parameters.get("bias", None) is not None + else None + ) + assert "weight" in layer._parameters + if fuse_wgrad_accumulation: + wgrad = layer._parameters["weight"].main_grad + assert layer._parameters["weight"].grad is None + else: + wgrad = layer._parameters["weight"].grad + + return y_q, dgrad, wgrad, bgrad + + @classmethod + def run_linear_multiple_steps( + cls, + layer, + x, + gradient, + run_num_steps, + enable_weight_cache, + fuse_wgrad_accumulation=False, + ): + """ + Run multiple steps of linear layer and collect results. + """ + + y_q_list, dgrad_list, wgrad_list = [], [], [] + bgrad_list = [] if layer._parameters.get("bias", None) is not None else None + + for i in range(run_num_steps): + x_i = (x + i).clone().detach().requires_grad_(True) + # run_linear_one_step + y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step( + layer, + x_i, + gradient, + is_first_microbatch=(i == 0) if enable_weight_cache else None, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ) + + # Collect results + y_q_list.append(y_q.detach().clone()) + dgrad_list.append(dgrad.detach().clone()) + wgrad_list.append(wgrad.detach().clone()) + if bgrad_list is not None and bgrad is not None: + bgrad_list.append(bgrad.detach().clone()) + + # Stack the results + return ( + torch.stack(y_q_list), + torch.stack(dgrad_list), + torch.stack(wgrad_list), + torch.stack(bgrad_list) if bgrad_list is not None else None, + ) + + @classmethod + def run_linear( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + run_num_steps=1, + enable_weight_cache=False, + fuse_wgrad_accumulation=False, + ): + """ + If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with + the reference single GPU run. + """ + # clone inputs and move to current device + # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] + x = x.clone().detach().requires_grad_(True).to("cuda") + w = w.clone().detach().to("cuda") + gradient = gradient.clone().detach().to("cuda") + bias = bias.clone().detach().to("cuda") if bias is not None else None + in_features = x.shape[1] + out_features = w.shape[0] + + # If Model parallel: split inputs for a given rank + x, w, bias, gradient = cls.run_linear_preprocess_parallel( + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + ) + + # set data types + params_dtype = x.dtype + + # Create linear layer and copy weights + layer = te.Linear( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ) + + layer = layer.to("cuda") + + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + if fuse_wgrad_accumulation: + assert ( + run_num_steps > 1 + ), "Fused weight gradient accumulation requires run_num_steps > 1" + layer.weight.main_grad = torch.zeros_like(layer.weight) + + # Run one step or multiple steps + if run_num_steps == 1: + y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) + else: + y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps( + layer, + x, + gradient, + run_num_steps, + enable_weight_cache, + fuse_wgrad_accumulation, + ) + + # If Model parallel: gather output and gradients from all ranks + y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel( + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ) + + return y_q, dgrad, wgrad, bgrad + + +@run_distributed_test() +def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): + """Test the linear layer with specified parallel mode and sequence parallelization. + + Args: + parallel_mode (str): 'row' or 'column' parallelism. + sequence_parallel (bool): Enable sequence parallelism if True. + kwargs (dict): Additional arguments for the linear layer. + + QUANTIZATION options: nvfp4 <=> experimental nvfp4 as a reference + """ + params_dtype = torch.bfloat16 + use_bias = kwargs.get("bias", True) + fuse_wgrad_accumulation = kwargs.get("fuse_wgrad_accumulation", False) + seed = torch.initial_seed() + recipe = quantization_recipe() + + # turn on weight quantization cache when fusing wgrad accumulation + enable_weight_cache = fuse_wgrad_accumulation + run_num_steps = 1 if not fuse_wgrad_accumulation else 5 + + x, w, bias, gradient = TestDistributedLinearBase._prepare_data( + BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype + ) + + # run the recipe under test + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + run_num_steps=1 if not fuse_wgrad_accumulation else 5, + enable_weight_cache=fuse_wgrad_accumulation, + ) + + # run the reference + setup_environment_for_reference() + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + run_num_steps=run_num_steps, + enable_weight_cache=enable_weight_cache, + ) + # Clean up env + cleanup_environment() + + # compare results, zero tolerance + if WORLD_RANK == 0: + torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch") + torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch") + torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch") + if bgrad is not None and bgrad_ref is not None: + torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch") + + +def test_linear(): + """Run linear layer tests with various configurations.""" + kwargs_list = [ + {"bias": False}, + ] + + for kwargs in kwargs_list: + if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8": + continue + for parallel_mode in ["column", "row"]: + for sequence_parallel in [False, True]: + _test_linear(parallel_mode, sequence_parallel, **kwargs) + + +############################################ +# LayerNormLinear # +############################################ +class TestDistributedLayerNormLinearBase(TestDistributedLinearBase): + + @classmethod + def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None): + # reset gradients + layer.zero_grad() + x.grad = None + + # Forward pass + y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch) + + # Backward pass + y_q.backward(gradient) + + # Collect gradients + dgrad = x.grad + + parameters = layer._parameters + + # bias and weight gradients + bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None + assert "weight" in parameters + wgrad = parameters["weight"].grad + + return y_q, ln_out, dgrad, wgrad, bgrad + + @classmethod + def run_linear_multiple_steps( + cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False + ): + # raise error, no test case for multiple steps for now + raise NotImplementedError("LayerNormLinear does not support test multiple steps for now") + + @classmethod + def run_layernorm_linear( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + run_num_steps=1, + enable_weight_cache=False, + LayerNormLinearClass=te.LayerNormLinear, + normalization="LayerNorm", + ): + """ + If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with + the reference single GPU run. + """ + # clone inputs and move to current device + # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] + x = x.clone().detach().requires_grad_(True).to("cuda") + w = w.clone().detach().to("cuda") + gradient = gradient.clone().detach().to("cuda") + bias = bias.clone().detach().to("cuda") if bias is not None else None + in_features = x.shape[1] + out_features = w.shape[0] + + # If Model parallel: split inputs for a given rank + x, w, bias, gradient = cls.run_linear_preprocess_parallel( + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + ) + + # set data types + params_dtype = x.dtype + + # Create linear layer and copy weights + layer = LayerNormLinearClass( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + normalization=normalization, + return_layernorm_output=True, + ) + + layer = layer.to("cuda") + + # Copy weights + # kitchen_linear has different parameter names + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + # Run one step + y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) + + # If Model parallel: gather output and gradients from all ranks + y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel( + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ) + + return y_q, ln_out, dgrad, wgrad, bgrad + + +@run_distributed_test() +def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs): + """Test the linear layer with specified parallel mode and sequence parallelization. + + Args: + parallel_mode (str): 'column' parallelism. + sequence_parallel (bool): Enable sequence parallelism if True. + kwargs (dict): Additional arguments for the linear layer. + """ + params_dtype = torch.bfloat16 + use_bias = kwargs.get("bias", True) + seed = torch.initial_seed() + recipe = quantization_recipe() + + # run multiple steps currently not supported for LayerNormLinear + run_num_steps = 1 + + x, w, bias, gradient = TestDistributedLayerNormLinearBase._prepare_data( + BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype + ) + + # run the recipe under test + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + run_num_steps=run_num_steps, + enable_weight_cache=False, + ) + + # run the reference + setup_environment_for_reference() + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = ( + TestDistributedLayerNormLinearBase.run_layernorm_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + run_num_steps=run_num_steps, + enable_weight_cache=False, + ) + ) + # Clean up env + cleanup_environment() + + # compare results, zero tolerance + if WORLD_RANK == 0: + torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch") + torch.testing.assert_close(ln_out, ln_out_ref, atol=0, rtol=0, msg="LN output mismatch") + torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch") + torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch") + if bgrad is not None and bgrad_ref is not None: + torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch") + + +def test_layernorm_linear(): + kwargs_list = [ + {"bias": False}, + ] + + for kwargs in kwargs_list: + for parallel_mode in ["column"]: + for sequence_parallel in [False, True]: + _test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 6dc17b126..83acd6df7 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -29,6 +29,7 @@ Float8CurrentScalingQuantizer, ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.utils import is_bf16_compatible, is_fp8_fnuz import transformer_engine_torch as tex @@ -36,17 +37,20 @@ # Import utility functions _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) -from utils import dtype_tols, make_recipe +from utils import dtype_tols, make_recipe, quantization_tols # Check what quantization schemes are supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_mxfp8_available() quantization_list: list[Optional[str]] = [None] if fp8_available: quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: quantization_list.append("mxfp8") +if nvfp4_available: + quantization_list.append("nvfp4") @functools.cache @@ -117,6 +121,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + elif quantization == "nvfp4": + test = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") if isinstance(test, QuantizedTensor) and not test_is_quantized: @@ -439,7 +451,7 @@ def _test_basic_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -611,7 +623,7 @@ def _test_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -625,6 +637,204 @@ def _test_linear( torch.testing.assert_close(db_test, db_ref, **tols) +def _test_mlp( + *, + bias: bool = True, + hidden_size: int = 32, + local_batch_size: int = 32, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + quantization: Optional[str] = None, + quantized_weight: bool = False, + sequence_parallel: bool = False, +) -> None: + """2-layer MLP + + MLP includes GELU activation in order to test op fusions. Model + performs warmup steps in order to test inter-step logic. + + """ + + # Skip invalid configurations + quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + mlp_size = hidden_size * world_size + batch_size = local_batch_size + if sequence_parallel: + batch_size *= world_size + in_shape = (batch_size, hidden_size) + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + w1_ref, w1_test = make_reference_and_test_tensors( + (mlp_size, hidden_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + b1_ref, b1_test = None, None + w2_ref, w2_test = make_reference_and_test_tensors( + (hidden_size, mlp_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + b2_ref, b2_test = None, None + if bias: + b1_ref, b1_test = make_reference_and_test_tensors( + (mlp_size,), + test_dtype=dtype, + test_device=device, + ) + b2_ref, b2_test = make_reference_and_test_tensors( + (world_size, hidden_size), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") + y_ref = torch.nn.functional.linear(y_ref, w1_ref) + if bias: + y_ref += b1_ref + y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh") + y_ref = torch.nn.functional.linear(y_ref, w2_ref) + if bias: + y_ref += b2_ref.sum(dim=0) + y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh") + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + local_mlp_size = mlp_size // world_size + local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size) + dx_ref = x_ref.grad + dw1_ref = w1_ref.grad[local_mlp_slice, :] + w1_ref = w1_ref[local_mlp_slice, :] + w1_test = w1_test[local_mlp_slice, :] + dw2_ref = w2_ref.grad[:, local_mlp_slice] + w2_ref = w2_ref[:, local_mlp_slice] + w2_test = w2_test[:, local_mlp_slice] + if bias: + db1_ref = b1_ref.grad[local_mlp_slice] + b1_ref = b1_ref[local_mlp_slice] + b1_test = b1_test[local_mlp_slice] + db2_ref = b2_ref.grad[rank, :] + b2_ref = b2_ref[rank, :] + b2_test = b2_test[rank, :] + else: + db1_ref = None + db2_ref = None + if sequence_parallel: + local_batch_slice = slice( + rank * local_batch_size, + (rank + 1) * local_batch_size, + ) + x_ref = x_ref[local_batch_slice, ...] + dx_ref = dx_ref[local_batch_slice, ...] + x_test = x_test[local_batch_slice, ...].clone() + y_ref = y_ref[local_batch_slice, ...] + dy_ref = dy_ref[local_batch_slice, ...] + dy_test = dy_test[local_batch_slice, ...].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model = te_ops.Sequential( + te_ops.GELU(), + te_ops.Linear( + hidden_size, + mlp_size, + bias=bias, + device=device, + dtype=dtype, + tensor_parallel_mode="column", + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + ), + te_ops.GELU(), + te_ops.Linear( + mlp_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + tensor_parallel_mode="row", + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + ), + te_ops.GELU(), + ) + with torch.no_grad(): + model[1].weight.copy_(w1_test) + model[3].weight.copy_(w2_test) + if bias: + model[1].bias.copy_(b1_test) + model[3].bias.copy_(b2_test) + del w1_test, w2_test, b1_test, b2_test + + # Warmup steps + for _ in range(3): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = model(x_test) + y_test.backward(dy_test) + x_test.grad = None + model[1].weight.grad = None + model[3].weight.grad = None + if bias: + model[1].bias.grad = None + model[3].bias.grad = None + + # Forward and backward step + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = model(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + torch.testing.assert_close(dw1_test, dw1_ref, **tols) + torch.testing.assert_close(dw2_test, dw2_ref, **tols) + if bias: + db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu") + db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db1_test, db1_ref, **tols) + torch.testing.assert_close(db2_test, db2_ref, **tols) + + def _test_fp8_scale_update( *, amax_history_len: int = 31, @@ -791,16 +1001,31 @@ def run_parallel_tests() -> None: for config in itertools.product( quantization_list, ("column", "row"), + (False, True), ): if rank == 0: print(f"Running _test_linear with {config=}") - quantization, tensor_parallel_mode = config + quantization, tensor_parallel_mode, sequence_parallel = config dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 _test_linear( bias=True, # bias=False is tested in _test_basic_linear dtype=dtype, quantization=quantization, tensor_parallel_mode=tensor_parallel_mode, + sequence_parallel=sequence_parallel, + ) + + # MLP + for config in itertools.product(quantization_list, (False, True)): + if rank == 0: + print(f"Running _test_mlp with {config=}") + quantization, sequence_parallel = config + dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 + _test_mlp( + bias=True, # bias=False is tested in _test_basic_linear + dtype=dtype, + quantization=quantization, + sequence_parallel=sequence_parallel, ) # FP8 scale update diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index 1ff5aff99..d09c530cb 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -31,6 +31,7 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( FP8GlobalStateManager.is_fp8_block_scaling_available() ) +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) @@ -51,7 +52,9 @@ def _run_test(quantization): all_boolean = [True, False] -@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"]) +@pytest.mark.parametrize( + "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"] +) def test_distributed(quantization): if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) @@ -61,4 +64,6 @@ def test_distributed(quantization): pytest.skip(reason_for_no_mxfp8) if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) _run_test(quantization) diff --git a/tests/pytorch/distributed/test_numerics_exact.py b/tests/pytorch/distributed/test_numerics_exact.py new file mode 100644 index 000000000..890a24804 --- /dev/null +++ b/tests/pytorch/distributed/test_numerics_exact.py @@ -0,0 +1,70 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import subprocess +from pathlib import Path + +import pytest +import torch +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +""" + Distributed numerics tests + + This numerical test aims for zero tolerance test for absolute confidence in numerics. + In the case of NVFP4, with the experimental NVFP4 quantization, we matched bitwise + result with the native silicon. For distrbuted test cases, we can do the same by thing + by comparing BF16 AG results with the low precision AG results at layer level. +""" + + +if torch.cuda.device_count() < 2: + pytest.skip("Distributed training needs at least 2 GPUs.") + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = min(4, torch.cuda.device_count()) +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def _run_test(quantization, batch_size, hidden_size, out_size): + test_path = TEST_ROOT / "run_numerics_exact.py" + test_cmd = LAUNCH_CMD + [str(test_path)] + + test_cmd += ["--quantization", quantization] + test_cmd += ["--batch-size", str(batch_size)] + test_cmd += ["--hidden-size", str(hidden_size)] + test_cmd += ["--out-size", str(out_size)] + + result = subprocess.run(test_cmd, env=os.environ, check=False) + assert result.returncode == 0 + + +all_boolean = [True, False] + + +@pytest.mark.parametrize("quantization", ["nvfp4"]) +@pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (64, 128, 128), + (128, 128, 128), + (128, 256, 256), + (512, 1024, 768), + (512, 256, 1024), + (2048, 2048, 2048), + ], +) +def test_distributed(quantization, batch_size, hidden_size, out_size): + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + + _run_test(quantization, batch_size, hidden_size, out_size) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py new file mode 100644 index 000000000..a9e73aaf9 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -0,0 +1,243 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental import utils + + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +def check_nvfp4_gemm_versus_reference( + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + M: int, + K: int, + N: int, + accumulate: bool, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, +): + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input tensors + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + x = torch.randn(x_shape, dtype=x_dtype, device=device) + w = torch.randn(w_shape, dtype=w_dtype, device=device) + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) + else: + out = None + + # Native TE NVFP4 quantization + x_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + w_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + # Quantize x and w + x_nvfp4_native = x_quantizer.make_empty( + x_shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_native = x_quantizer.update_quantized(x, x_nvfp4_native) + w_nvfp4_native = w_quantizer.make_empty( + w_shape, dtype=w_dtype, device=device, requires_grad=False + ) + w_nvfp4_native = w_quantizer.update_quantized(w, w_nvfp4_native) + + # Extract quantized data from native NVFP4Tensors + qx_data = ( + x_nvfp4_native._columnwise_data.view(dtype=torch.uint8) + if x_columnwise + else x_nvfp4_native._rowwise_data.view(dtype=torch.uint8) + ) + qw_data = ( + w_nvfp4_native._columnwise_data.view(dtype=torch.uint8) + if w_columnwise + else w_nvfp4_native._rowwise_data.view(dtype=torch.uint8) + ) + sx_native = ( + x_nvfp4_native._columnwise_scale_inv if x_columnwise else x_nvfp4_native._rowwise_scale_inv + ) + sw_native = ( + w_nvfp4_native._columnwise_scale_inv if w_columnwise else w_nvfp4_native._rowwise_scale_inv + ) + + # Trim quantized data to match the actual tensor dimensions (remove padding) + qx_data = qx_data[:M, :] + qw_data = qw_data[:N, :] + + # NVFP4 uses 16-element blocks, trim scales to remove padding + block_length = 16 # NVFP4 uses 16-element blocks + expected_sx_cols = expected_sw_cols = K // block_length + # Trim the scales to remove padding + sx_trimmed = sx_native[:M, :expected_sx_cols] + sw_trimmed = sw_native[:N, :expected_sw_cols] + + # Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn + # for the reference GEMM to work correctly + sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn) + sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) + + # Create reference quantizer for reference GEMM + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=True, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + + # Create reference quantized tensors needed by reference GEMM + x_nvfp4_ref = ref_quantizer.quantize(x) + w_nvfp4_ref = ref_quantizer.quantize(w) + + # Reference GEMM using quantizer's qgemm method + y_ref = ref_quantizer.qgemm( + qx=qx_data, + qw=qw_data, + m_params=None, # MMParams not used in reference + out_dtype=out_dtype, + sx=sx_trimmed, + sw=sw_trimmed, + bias=None, # No bias for this test + out=out.clone() if accumulate else None, + accumulate=accumulate, + gemm_type=None, # GEMMType not used in reference + qresult_x=x_nvfp4_ref, + qresult_w=w_nvfp4_ref, + ) + + # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM) + # Allocate cuBLAS workspace + workspace = torch.empty(4, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + bias = None + bias_dtype = TE_DType[torch.bfloat16] + use_gelu = False + gelu_input = None + use_grad = False + use_split_accumulator = False + + # Native cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + y_native = tex.generic_gemm( + w_nvfp4_native, + transa, + x_nvfp4_native, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_input, + use_grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] + + # just in case of accumulation, make sure y_ref and y_native are not the same tensor + assert y_ref is not y_native, "y_ref and y_native should not be the same tensor" + # Reset nans to zeros because torch.assert_close does not assume nans to be equal + assert not torch.isnan(y_ref.float()).all(), "All elements are nan" + y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref) + y_native = torch.where(y_native.isnan(), torch.zeros_like(y_native), y_native) + + # Compare results with some tolerance + torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, K, N", + [ + (128, 128, 128), + (256, 128, 256), + (256, 256, 256), + (256, 1024, 256), + (1024, 1024, 1024), + (4096, 512, 3072), + (112, 128, 96), + (304, 640, 304), + (1008, 3072, 992), + (256, 64, 256), + (128, 128, 112), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize( + "is_x_columnwise, is_w_columnwise", + [ + (False, False), # Only rowwise x rowwise is supported by reference GEMM + # Note: Reference GEMM expects inputs as (M,K) x (N,K) with rowwise quantization + # Columnwise layouts are not supported by the reference implementation + ], + ids=["rowxrow"], +) +def test_nvfp4_gemm_versus_reference( + M: int, + K: int, + N: int, + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + accumulate: bool, + is_x_columnwise: bool, + is_w_columnwise: bool, +): + check_nvfp4_gemm_versus_reference( + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + M=M, + K=K, + N=N, + accumulate=accumulate, + x_columnwise=is_x_columnwise, + w_columnwise=is_w_columnwise, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py new file mode 100644 index 000000000..ae9975839 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -0,0 +1,559 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import pytest +import torch +import transformer_engine as te +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.distributed import fp8_autocast +from transformer_engine.common import recipe + + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +class GetRecipes: + @staticmethod + def nvfp4_vanilla(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + @staticmethod + def nvfp4_rht_only(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(random_hadamard_transform=True) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(random_hadamard_transform=False) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(random_hadamard_transform=True) + return nvfp4_recipe + + @staticmethod + def nvfp4_2d_quantization_only(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(fp4_2d_quantization=False) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(fp4_2d_quantization=False) + return nvfp4_recipe + + @staticmethod + def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + @staticmethod + def nvfp4_recipe_to_test(with_rht: bool = False, with_2d_quantization: bool = False): + if with_rht and with_2d_quantization: + return GetRecipes.nvfp4_rht_and_2d_quantization() + elif with_rht: + return GetRecipes.nvfp4_rht_only() + elif with_2d_quantization: + return GetRecipes.nvfp4_2d_quantization_only() + else: + return GetRecipes.nvfp4_vanilla() + + +def setup_environment_for_reference(with_rht: bool = False, with_2d_quantization: bool = False): + if with_rht and with_2d_quantization: + os.environ["QAT_PARAMS"] = "9003" + elif with_rht: + os.environ["QAT_PARAMS"] = "960109" + elif with_2d_quantization: + os.environ["QAT_PARAMS"] = "9002" + else: + os.environ["QAT_PARAMS"] = "6010" + + +def cleanup_environment(): + if "QAT_PARAMS" in os.environ: + del os.environ["QAT_PARAMS"] + + +def reset_rng_states(): + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def check_nvfp4_module_versus_reference( + module_class, + in_features: int, + out_features: int, + bias: bool, + x_dtype: torch.dtype, + num_steps: int = 1, + with_rht: bool = False, + with_2d_quantization: bool = False, +): + """ + Compare native NVFP4 module against reference implementation. + + Args: + module_class: te.Linear or te.LayerNormLinear + in_features: Input feature dimension + out_features: Output feature dimension + bias: Whether to use bias + x_dtype: Input tensor dtype + num_steps: Number of forward/backward steps to test + """ + device = "cuda" + batch_size = 32 + seq_len = 128 + + # Create both modules with identical initialization + cleanup_environment() + reset_rng_states() + + # Create native module + print("\nCreate native module") + if module_class == te.pytorch.Linear: + native_module = te.pytorch.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + elif module_class == te.pytorch.LayerNormLinear: + native_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + else: + raise ValueError(f"Unsupported module class: {module_class}") + + # Create reference module with same weights + setup_environment_for_reference(with_rht, with_2d_quantization) + reset_rng_states() + + # Create reference module + print("Create reference module") + if module_class == te.pytorch.Linear: + ref_module = te.pytorch.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + elif module_class == te.pytorch.LayerNormLinear: + ref_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + + # Sync weights between native and reference modules + with torch.no_grad(): + # Copy main weight and bias parameters + if hasattr(native_module, "weight") and hasattr(ref_module, "weight"): + ref_module.weight.copy_(native_module.weight) + if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"): + ref_module.bias.copy_(native_module.bias) + + # Copy layer norm parameters if they exist + if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"): + ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight) + if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"): + ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias) + + nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization) + + # Training loop comparison + native_outputs = [] + ref_outputs = [] + + for step in range(num_steps): + torch.manual_seed(1234 + step) + torch.cuda.manual_seed(1234 + step) + + x_shape = (batch_size, seq_len, in_features) + x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device) + x_native = x_val.clone().detach().requires_grad_(True) + x_ref = x_native.clone().detach().requires_grad_(True) + + grad_output_shape = (batch_size, seq_len, out_features) + grad_output_val = torch.normal( + mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device + ) + grad_output = grad_output_val.clone().detach() + + # Native forward/backward + cleanup_environment() + with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + # enable weight cache by giving is_first_microbatch + y_native = native_module(x_native, is_first_microbatch=(step == 0)) + y_native.backward(grad_output) + + # Reference forward/backward + setup_environment_for_reference(with_rht, with_2d_quantization) + with fp8_autocast( + enabled=True, fp8_recipe=nvfp4_recipe + ): # Exact recipe does not play a role here + y_ref = ref_module(x_ref) + y_ref.backward(grad_output) + + # Store results + native_outputs.append( + { + "output": y_native.detach().clone(), + "input_grad": ( + x_native.grad.detach().clone() if x_native.grad is not None else None + ), + "weight_grad": ( + native_module.weight.grad.detach().clone() + if native_module.weight.grad is not None + else None + ), + "bias_grad": ( + native_module.bias.grad.detach().clone() + if bias and native_module.bias.grad is not None + else None + ), + } + ) + + ref_outputs.append( + { + "output": y_ref.detach().clone(), + "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None), + "weight_grad": ( + ref_module.weight.grad.detach().clone() + if ref_module.weight.grad is not None + else None + ), + "bias_grad": ( + ref_module.bias.grad.detach().clone() + if bias and ref_module.bias.grad is not None + else None + ), + } + ) + + # Compare results across all steps + for step in range(num_steps): + native_out = native_outputs[step] + ref_out = ref_outputs[step] + + # Compare outputs + torch.testing.assert_close( + native_out["output"], + ref_out["output"], + atol=1e-6, + rtol=1e-6, + msg=f"Output mismatch at step {step}", + ) + + # Compare input gradients + torch.testing.assert_close( + native_out["input_grad"], + ref_out["input_grad"], + atol=1e-6, + rtol=1e-6, + msg=( + f"Input gradient mismatch at step {step}. Native: {native_out['input_grad']}, Ref:" + f" {ref_out['input_grad']}" + ), + ) + + # Compare weight gradients + torch.testing.assert_close( + native_out["weight_grad"], + ref_out["weight_grad"], + atol=1e-6, + rtol=1e-6, + msg=( + f"Weight gradient mismatch at step {step}. Native: {native_out['weight_grad']}," + f" Ref: {ref_out['weight_grad']}" + ), + ) + + # Compare bias gradients + if bias and native_out["bias_grad"] is not None and ref_out["bias_grad"] is not None: + torch.testing.assert_close( + native_out["bias_grad"], + ref_out["bias_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Bias gradient mismatch at step {step}", + ) + + # Clean up + cleanup_environment() + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "in_features, out_features", + [ + (128, 256), + (256, 128), + (512, 512), + (768, 3072), + (1024, 4096), + ], +) +# @pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"]) +@pytest.mark.parametrize("bias", [False], ids=["no_bias"]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("num_steps", [1, 3], ids=["single_step", "multi_step"]) +@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"]) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"] +) +def test_nvfp4_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + x_dtype: torch.dtype, + num_steps: int, + with_rht: bool, + with_2d_quantization: bool, +): + """Test NVFP4 Linear module against reference implementation.""" + if with_rht and x_dtype != torch.bfloat16: + pytest.skip("RHT is only supported for bfloat16 input") + + check_nvfp4_module_versus_reference( + module_class=te.pytorch.Linear, + in_features=in_features, + out_features=out_features, + bias=bias, + x_dtype=x_dtype, + num_steps=num_steps, + with_rht=with_rht, + with_2d_quantization=with_2d_quantization, + ) + + +def check_nvfp4_layernorm_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + normalization: str, + x_dtype: torch.dtype, + num_steps: int = 1, + with_rht: bool = False, + with_2d_quantization: bool = False, +): + """ + Compare native NVFP4 LayerNormLinear module against reference implementation, + including ln_out. + """ + device = "cuda" + batch_size = 32 + seq_len = 128 + + # Create both modules with identical initialization + cleanup_environment() + reset_rng_states() + + # Native module + native_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + normalization=normalization, + return_layernorm_output=True, + ) + + # Reference module + setup_environment_for_reference(with_rht, with_2d_quantization) + reset_rng_states() + ref_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + normalization=normalization, + return_layernorm_output=True, + ) + + # Sync weights and LN params + with torch.no_grad(): + if hasattr(native_module, "weight") and hasattr(ref_module, "weight"): + ref_module.weight.copy_(native_module.weight) + if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"): + ref_module.bias.copy_(native_module.bias) + if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"): + if ( + native_module.layer_norm_weight is not None + and ref_module.layer_norm_weight is not None + ): + ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight) + if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"): + if native_module.layer_norm_bias is not None and ref_module.layer_norm_bias is not None: + ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias) + + nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization) + + native_outputs = [] + ref_outputs = [] + + for step in range(num_steps): + torch.manual_seed(1234 + step) + torch.cuda.manual_seed(1234 + step) + + x_shape = (batch_size, seq_len, in_features) + x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device) + x_native = x_val.clone().detach().requires_grad_(True) + x_ref = x_native.clone().detach().requires_grad_(True) + + grad_output_shape = (batch_size, seq_len, out_features) + grad_output_val = torch.normal( + mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device + ) + grad_output = grad_output_val.clone().detach() + + # Native forward/backward + cleanup_environment() + with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0)) + y_native.backward(grad_output) + + # Reference forward/backward + setup_environment_for_reference(with_rht, with_2d_quantization) + with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + y_ref, ln_out_ref = ref_module(x_ref) + y_ref.backward(grad_output) + + native_outputs.append( + { + "output": y_native.detach().clone(), + "ln_out": ln_out_native.detach().clone(), + "input_grad": ( + x_native.grad.detach().clone() if x_native.grad is not None else None + ), + "weight_grad": ( + native_module.weight.grad.detach().clone() + if native_module.weight.grad is not None + else None + ), + "bias_grad": ( + native_module.bias.grad.detach().clone() + if bias and native_module.bias.grad is not None + else None + ), + } + ) + ref_outputs.append( + { + "output": y_ref.detach().clone(), + "ln_out": ln_out_ref.detach().clone(), + "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None), + "weight_grad": ( + ref_module.weight.grad.detach().clone() + if ref_module.weight.grad is not None + else None + ), + "bias_grad": ( + ref_module.bias.grad.detach().clone() + if bias and ref_module.bias.grad is not None + else None + ), + } + ) + + # Compare results + for step in range(num_steps): + n = native_outputs[step] + r = ref_outputs[step] + torch.testing.assert_close( + n["output"], + r["output"], + atol=1e-6, + rtol=1e-6, + msg=f"Output mismatch at step {step}", + ) + torch.testing.assert_close( + n["ln_out"], + r["ln_out"], + atol=1e-6, + rtol=1e-6, + msg=f"LN output mismatch at step {step}", + ) + torch.testing.assert_close( + n["input_grad"], + r["input_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Input gradient mismatch at step {step}", + ) + torch.testing.assert_close( + n["weight_grad"], + r["weight_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Weight gradient mismatch at step {step}", + ) + if bias and n["bias_grad"] is not None and r["bias_grad"] is not None: + torch.testing.assert_close( + n["bias_grad"], + r["bias_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Bias gradient mismatch at step {step}", + ) + + cleanup_environment() + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "in_features, out_features", + [ + (128, 256), + (256, 128), + ], +) +@pytest.mark.parametrize("bias", [False], ids=["no_bias"]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("num_steps", [1], ids=["single_step"]) +@pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"], ids=["LayerNorm", "RMSNorm"]) +@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"]) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"] +) +def test_nvfp4_layernorm_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + normalization: str, + x_dtype: torch.dtype, + num_steps: int, + with_rht: bool, + with_2d_quantization: bool, +): + if with_rht and x_dtype != torch.bfloat16: + pytest.skip("RHT is only supported for bfloat16 input") + + check_nvfp4_layernorm_linear_versus_reference( + in_features=in_features, + out_features=out_features, + bias=bias, + normalization=normalization, + x_dtype=x_dtype, + num_steps=num_steps, + with_rht=with_rht, + with_2d_quantization=with_2d_quantization, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py new file mode 100644 index 000000000..dc3c4a4e9 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -0,0 +1,495 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import NVFP4BlockScaling +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.nvfp4_tensor import ( + NVFP4Quantizer, +) +from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype + + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +def check_quantization_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + swizzled_scale: bool, + use_cpp_allocator: bool, + with_2d_quantization: bool, +) -> None: + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + ) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + # Extract data from NVFP4Tensor + assert x_nvfp4_sut._rowwise_data is not None + qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + # Reference quantization + quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16) + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=quant_tile_shape, + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + # Extract data from RefNVFP4Tensor + qx_ref = ( + unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8)) + if x_nvfp4_ref.data is not None + else None + ) + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8)) + if x_nvfp4_ref.data_t is not None + else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + qx = unpack_fp4(qx) + qx_t = unpack_fp4(qx_t) if qx_t is not None else None + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of scale tensors (reference may not have padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of transpose scale tensors + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (304, 304), + (320, 256), + # Some larger tiles + (2048, 2048), + (1024, 2048), + (2048, 1024), + # # largest tile + (8192, 8192), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"]) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] +) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + swizzled_scale: bool, + use_cpp_allocator: bool, + with_2d_quantization: bool, +) -> None: + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + return_transpose=return_transpose, + swizzled_scale=swizzled_scale, + use_cpp_allocator=use_cpp_allocator, + with_2d_quantization=with_2d_quantization, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (128, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_extrema_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + extrema_high: bool, + return_transpose: bool, + use_cpp_allocator: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if extrema_high: + x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device) + else: + x = torch.zeros((M, N), dtype=x_dtype, device=device) + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + assert x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (16, 128), + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_boundary_values( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, +): + """ + Stress rounding/threshold behavior by placing values just below/above + many potential bin edges within each 16-element microblock. + Validates native vs reference byte-for-byte and scale parity. + """ + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 123 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Construct a single row with paired boundary values: v-eps, v+eps + # spanning a wide dynamic range to exercise clipping and multiple bins. + # Ensure even N and N is multiple of 16 for microblocks, which holds for 128. + base = torch.linspace(-12.0, 12.0, steps=N // 2, dtype=torch.float32, device=device) + eps = torch.full_like(base, 1e-3) + # Avoid zero eps for very small magnitudes + eps = torch.maximum(eps, 1e-4 * torch.ones_like(base)) + lower = base - eps + upper = base + eps + row = torch.empty(N, dtype=torch.float32, device=device) + row[0::2] = lower + row[1::2] = upper + x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype) + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + assert x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + # Compare only valid portion of scales (trim any padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_noncontiguous_inputs( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 17 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Start from a contiguous tensor, then make a non-contiguous view by transpose + x_base = torch.randn((M, N), dtype=x_dtype, device=device) + x_nc = x_base.t() # shape (N, M), non-contiguous + assert not x_nc.is_contiguous() + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x_nc) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x_nc.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut) + + assert x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x_nc) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + # Quantized must match + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + # Compare only valid portion of scales (trim padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py new file mode 100644 index 000000000..bb542456e --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -0,0 +1,255 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# NOTE: This file is dependent on the success of test_nvfp4_quantize_exact.py. +# Separate to make sure all the functionalities are working as expected. +# Otherwise reference implementation will get messy. + +# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality +# together with the quantization functionality. + +from typing import Tuple +import math + +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import NVFP4BlockScaling +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.nvfp4_tensor import ( + NVFP4Quantizer, +) +from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype + +import pytest +import torch + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +def check_quantization_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + contiguous: bool, + return_transpose: bool, + use_cpp_allocator: bool, + swizzled_scale: bool = False, + hadamard_dimension: int = 16, + with_rht: bool = True, + with_post_rht_amax: bool = True, + with_random_sign_mask: bool = True, +) -> None: + assert with_rht and with_post_rht_amax, "RHT and post-RHT amax reduction must be enabled." + + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + + x = x.transpose(0, 1) if not contiguous else x + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + ) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + # Extract data from NVFP4Tensor + assert x_nvfp4_sut._rowwise_data is not None + qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + amax_rowwise = x_nvfp4_sut._amax_rowwise + amax_colwise = x_nvfp4_sut._amax_columnwise + + qx = unpack_fp4(qx) + qx_t = unpack_fp4(qx_t) if qx_t is not None else None + + # Reference quantization using NVFP4QuantizerRef with built-in RHT + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + with_rht=with_rht, + with_random_sign_mask=with_random_sign_mask, + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + # Extract data from RefNVFP4Tensor + qx_ref = ( + unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8)) + if x_nvfp4_ref.data is not None + else None + ) + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + ref_amax_rowwise = x_nvfp4_ref.global_amax_row + + if return_transpose: + assert x_nvfp4_ref.data_t is not None + assert x_nvfp4_ref.scale_t is not None + qx_t_ref = unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8)) + sx_t_ref = x_nvfp4_ref.scale_t.view(dtype=torch.uint8) + # Compute transpose amax using the same reference quantizer + x_t_for_amax = ( + ref_quantizer._apply_rht(x.t().contiguous()) if with_rht else x.t().contiguous() + ) + ref_amax_colwise_t = torch.max(torch.abs(x_t_for_amax)).to(torch.float32).view(1) + else: + qx_t_ref = None + sx_t_ref = None + ref_amax_colwise_t = None + + torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + # Compare only the valid portion of scale tensors (reference may not have padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of transpose scale tensors + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (304, 304), + (320, 256), + # Some larger tiles + (2048, 2048), + (1024, 2048), + (2048, 1024), + # Real shapes, + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 512), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (512, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +def test_rht_with_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, + with_random_sign_mask: bool, +) -> None: + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + contiguous=True, + return_transpose=return_transpose, + use_cpp_allocator=use_cpp_allocator, + with_random_sign_mask=with_random_sign_mask, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +def test_nvfp4_quantization_noncontiguous_inputs( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, + with_random_sign_mask: bool, +): + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + contiguous=False, + return_transpose=return_transpose, + use_cpp_allocator=use_cpp_allocator, + with_random_sign_mask=with_random_sign_mask, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py new file mode 100755 index 000000000..46077eb20 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -0,0 +1,238 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + +seed = 12345 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +_FP4_LUT = torch.tensor( + [ + 0.0, # 0: 0000 - zero + 0.5, # 1: 0001 - smallest positive normal + 1.0, # 2: 0010 + 1.5, # 3: 0011 + 2.0, # 4: 0100 + 3.0, # 5: 0101 + 4.0, # 6: 0110 + 6.0, # 7: 0111 - largest positive normal + -0.0, # 8: 1000 - negative zero + -0.5, # 9: 1001 - smallest negative normal + -1.0, # 10: 1010 + -1.5, # 11: 1011 + -2.0, # 12: 1100 + -3.0, # 13: 1101 + -4.0, # 14: 1110 + -6.0, # 15: 1111 - largest negative normal + ], + dtype=torch.float32, +) + + +def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor: + # Convert FP4 indices to their corresponding floating point values + # Each index (0-15) represents a 4-bit FP4 value in E2M1 format + # Values based on the FP4 E2M1 specification + fp4_lut = _FP4_LUT.to(fp4.device) + return fp4_lut[fp4.to(torch.long)] + + +def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor: + sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32) + dqx = fp4_to_fp32(unpack_fp4(qx)) + sf = sf[: dqx.shape[0], : dqx.shape[1]] + dequant = dqx * sf * (amax / (6.0 * 448)) + return dequant + + +def RHT(x: torch.Tensor) -> torch.Tensor: + def get_wgrad_sign_vector() -> torch.Tensor: + """Hard-coded signs for Hadamard transform""" + return torch.tensor( + [ + 1.0, + 1.0, + 1.0, + -1.0, + 1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 1.0, + -1.0, + 1.0, + -1.0, + -1.0, + ], + dtype=torch.float32, + ) + + def _build_hadamard_matrix( + size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True + ) -> torch.Tensor: + """Construct a Hadamard matrix of given power-of-two size with entries +-1. + + Uses Sylvester construction to avoid SciPy dependency. + """ + assert (size & (size - 1)) == 0, "Hadamard size must be a power of two" + h = torch.ones((1, 1), device=device, dtype=torch.float32) + while h.shape[0] < size: + h = torch.cat( + [ + torch.cat([h, h], dim=1), + torch.cat([h, -h], dim=1), + ], + dim=0, + ) + if with_random_sign_mask: + sign_mat = get_wgrad_sign_vector().to(device) * torch.eye( + size, device=device, dtype=torch.float32 + ) + h = sign_mat @ h + return h.to(dtype) + + rht_dim = 16 + # Build H and scale + H = _build_hadamard_matrix(rht_dim, x.device, x.dtype) + scale = 1.0 / float(rht_dim) ** 0.5 + + # Perform blockwise transform along the last dimension + original_shape = x.shape + x_mat = x.contiguous().view(-1, rht_dim) + # Random sign matrix is identity in this reference (no sign flipping) + transform = H * scale + out = x_mat @ transform + return out.view(original_shape) + + +def quantize_fp4( + x: torch.Tensor, use_stochastic_rounding: bool, use_2D: bool, use_RHT: bool +) -> torch.Tensor: + nvfp4_quantizer = NVFP4Quantizer( + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=use_RHT, + with_post_rht_amax=True, + stochastic_rounding=use_stochastic_rounding, + with_2d_quantization=use_2D, + ) + + x_nvfp4_sut = nvfp4_quantizer(x) + # Extract data from NVFP4Tensor + assert x_nvfp4_sut._rowwise_data is not None + qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv + assert x_nvfp4_sut._columnwise_data is not None + qx_t: torch.Tensor = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._columnwise_scale_inv is not None + sx_t: torch.Tensor = x_nvfp4_sut._columnwise_scale_inv + + return qx, sx, qx_t, sx_t + + +def check_quantization_nvfp4_versus_reference( + x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool +) -> None: + device = "cuda" + torch.manual_seed(seed) + n_iters = 50 + + x = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1 + y = x.t().contiguous() + if use_RHT: + y = RHT(y) + amax = torch.max(torch.abs(x)).float() + q_rn, s_rn, q_t_rn, s_t_rn = quantize_fp4( + x, use_stochastic_rounding=False, use_2D=use_2D, use_RHT=use_RHT + ) + dq_rn = dequantize_fp4(q_rn, s_rn, amax) + dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax) + error_rn = (dq_rn - x).float() + me_rn = torch.sqrt((error_rn * error_rn).mean()) + error_t_rn = (dq_t_rn - y).float() + me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean()) + sr_result = torch.zeros_like(x).float() + sr_t_result = torch.zeros_like(x).float().t().contiguous() + for i in range(n_iters): + q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4( + x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT + ) + + dq_sr = dequantize_fp4(q_sr, s_sr, amax) + dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax) + + sr_result += dq_sr.float() + sr_t_result += dq_t_sr.float() + + # sr_result_tmp = sr_result / (i + 1) + # error_sr = (sr_result_tmp - x).float() + # me_sr = torch.sqrt((error_sr * error_sr).mean()) + # sr_t_result_tmp = sr_t_result / (i + 1) + # error_t_sr = (sr_t_result_tmp - y).float() + # me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean()) + # print(f"Iteration {i}: RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") + # print(f"Iteration {i}: RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") + + # Get the mean result of the stochastic rounding + # It should be more accurate than the RN result + sr_result /= n_iters + error_sr = (sr_result - x).float() + me_sr = torch.sqrt((error_sr * error_sr).mean()) + sr_t_result /= n_iters + error_t_sr = (sr_t_result - y).float() + me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean()) + + print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") + print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") + assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." + assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest." + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (8192, 8192), + (8192, 8256), # to test the nonfused RHT path + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("use_2D", [False, True], ids=str) +@pytest.mark.parametrize("use_RHT", [False, True], ids=str) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + use_2D: bool, + use_RHT: bool, + M: int, + N: int, +) -> None: + if x_dtype == torch.float32 and use_RHT: + pytest.skip("RHT is only supported with bfloat16") + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + use_2D=use_2D, + use_RHT=use_RHT, + M=M, + N=N, + ) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 90e624c94..be7a65deb 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -32,12 +32,59 @@ reset_rng_states() model_configs = { - "small": ModelConfig(32, 2, 2, 32), + "small": ModelConfig(2, 32, 2, 32), } + +def nvfp4_vanilla(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + +def check_rht_usage(recipe: recipe.Recipe) -> bool: + # if using RHT, we can only support bf16 + # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad + if recipe.nvfp4(): + if ( + recipe.fp4_quant_fwd_inp.random_hadamard_transform + or recipe.fp4_quant_fwd_weight.random_hadamard_transform + or recipe.fp4_quant_bwd_grad.random_hadamard_transform + ): + return True + return False + + +def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> bool: + supported_input_dtypes = [] + if recipe.nvfp4(): + supported_input_dtypes.append(torch.bfloat16) + # if not using RHT, we can add fp32 as well + if not check_rht_usage(recipe): + supported_input_dtypes.append(torch.float32) + return supported_input_dtypes + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) + fp8_recipes.append(nvfp4_rht_and_2d_quantization()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: @@ -278,7 +325,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__) def test_make_graphed_callables( *, module: str, @@ -295,8 +342,18 @@ def test_make_graphed_callables( pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") - if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op": - pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs") + if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op": + pytest.skip( + f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs" + ) + if fp8 and fp8_recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe" + f" {fp8_recipe.__class__.__name__}" + ) + if fp8_params: + pytest.skip("NVFP4 params not supported") # Run model with different CUDA graph settings. model_config = model_configs[model_config] @@ -334,17 +391,19 @@ def test_make_graphed_callables( "module", _test_make_graphed_callables_with_fp8_weight_caching_modules, ) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__) def test_make_graphed_callables_with_fp8_weight_caching( *, module: str, + dtype: torch.dtype, fp8_params: bool, fp8_recipe: recipe.Recipe, ) -> None: test_make_graphed_callables( module=module, - dtype=torch.float32, + dtype=dtype, fp8_params=fp8_params, fp8_recipe=fp8_recipe, fp8_weight_caching=True, diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 281fc67a5..80802fce7 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -10,7 +10,6 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex -import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.common.recipe import Float8CurrentScaling, Format from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype @@ -273,6 +272,14 @@ def run_linear_multiple_steps( if bgrad_list is not None and bgrad is not None: bgrad_list.append(bgrad.detach().clone()) + # Stack the results + return ( + torch.stack(y_q_list), + torch.stack(dgrad_list), + torch.stack(wgrad_list), + torch.stack(bgrad_list) if bgrad_list is not None else None, + ) + @classmethod def run_linear( cls, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 500b25f58..5d197c4a6 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -37,17 +37,19 @@ Float8Quantizer, ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import get_device_compute_capability import transformer_engine_torch as tex from torch.utils.cpp_extension import IS_HIP_EXTENSION # Import utility functions -from utils import dtype_tols, make_recipe, reset_rng_states +from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states -# Check if FP8 is supported +# Check for supported quantization schemes fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() # Supported data types _dtypes: list[torch.dtype] = [torch.float32, torch.float16] @@ -63,6 +65,8 @@ _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: _quantization_list.append("mxfp8") +if nvfp4_available: + _quantization_list.append("nvfp4") def maybe_skip_quantization( @@ -70,6 +74,7 @@ def maybe_skip_quantization( *, dims: Optional[Iterable[int] | int] = None, device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, ) -> None: """Skip test case if a quantization scheme is not supported""" @@ -77,12 +82,17 @@ def maybe_skip_quantization( if quantization is None: return - # Check if quantization scheme is supported + # Check if quantization scheme is supported on device + if device is not None and torch.device(device).type != "cuda": + pytest.skip("Quantization is only supported on CUDA devices") if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + # Check dims if dims is not None: if not isinstance(dims, Iterable): dims = (dims,) @@ -92,10 +102,14 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") + elif quantization == "nvfp4": + if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: + pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") - # Check if device is supported - if device is not None and torch.device(device).type != "cuda": - pytest.skip("Quantization is only supported on CUDA devices") + # Check dtype + if dtype is not None: + if quantization == "nvfp4" and dtype != torch.bfloat16: + pytest.skip("NVFP4 quantization is only supported with BF16 data") @torch.no_grad() @@ -145,6 +159,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + elif quantization == "nvfp4": + test = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") if isinstance(test, QuantizedTensor) and not test_is_quantized: @@ -399,12 +421,12 @@ def test_fp8_scale_update( torch.testing.assert_close( y, torch.full_like(y, y_val_ref), - **dtype_tols(tex.DType.kFloat8E4M3), + **quantization_tols("fp8_delayed_scaling"), ) torch.testing.assert_close( x.grad, torch.full_like(x.grad, dx_val_ref), - **dtype_tols(tex.DType.kFloat8E5M2), + **quantization_tols("fp8_delayed_scaling"), ) # Check that scaling factors match expected @@ -438,7 +460,8 @@ def test_dtype_cast( # Skip invalid configurations in_shape = (size, size) with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=init_dtype) + maybe_skip_quantization(quantization, dtype=final_dtype) # Random data dtype = torch.float32 @@ -506,7 +529,8 @@ def test_pyt_autocast( # Skip invalid configurations in_shape = (size, size) quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=model_dtype) + maybe_skip_quantization(quantization, dtype=autocast_dtype) # Construct operation recipe = make_recipe(quantization) @@ -562,7 +586,7 @@ def test_identity( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -628,7 +652,7 @@ def test_reshape( # Skip invalid configurations if memory_format == torch.channels_last and len(in_shape) != 4: pytest.skip("torch.channels_last only supports 4D tensors") - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) with_quantization = quantization is not None # Random data @@ -694,7 +718,7 @@ def test_bias( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -756,7 +780,7 @@ def test_quantize( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) if quantization == "mxfp8": maybe_skip_quantization(quantization, dims=in_shape) @@ -823,7 +847,7 @@ def _test_basic_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) quantization_needed = any( ( @@ -903,7 +927,7 @@ def _test_basic_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute or quantized_output or quantized_grad_input: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1024,7 +1048,7 @@ def test_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantization is None and (quantized_compute or quantized_weight): pytest.skip("Quantization scheme is not specified") @@ -1091,7 +1115,7 @@ def test_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1128,7 +1152,7 @@ def test_layer_norm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1189,7 +1213,7 @@ def test_layer_norm( # Expected numerical error tols = dtype_tols(dtype) if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1298,7 +1322,7 @@ def test_rmsnorm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1354,7 +1378,7 @@ def test_rmsnorm( # Explicit checks for quantization if quantized_compute: - tols = dtype_tols(y_test._quantizer.dtype) + tols = quantization_tols(quantization) expected_tensor_cls = { Float8Quantizer:Float8Tensor, Float8CurrentScalingQuantizer:Float8Tensor, @@ -1441,7 +1465,7 @@ def test_add_extra_input( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x1_ref, x1_test = make_reference_and_test_tensors( @@ -1480,8 +1504,11 @@ def test_add_extra_input( # Check results tols = dtype_tols(dtype) - if with_quantization: - tols = dtype_tols(x1_test._fp8_dtype) + if in_place: + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): + tols = dtype_tols(x1_test._fp8_dtype) + elif quantization == "nvfp4": + tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") @@ -1510,7 +1537,7 @@ def test_make_extra_output( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1583,7 +1610,7 @@ def test_activation( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) if cache_quantized_input: maybe_skip_quantization("fp8_current_scaling", device=device) @@ -1657,8 +1684,10 @@ def test_activation( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute or cache_quantized_input: - tols = dtype_tols(tex.DType.kFloat8E4M3) + if quantized_compute: + tols = quantization_tols(quantization) + elif cache_quantized_input: + tols = quantization_tols("fp8_current_scaling") # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1689,7 +1718,7 @@ def test_swiglu( quantized_compute = quantization is not None if not quantized_compute and (quantize_forward or quantize_backward): pytest.skip("Quantization scheme has not been provided") - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1723,7 +1752,7 @@ def test_swiglu( # Expected numerical error tols = dtype_tols(dtype) if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1791,7 +1820,7 @@ def test_dropout( # Skip invalid configurations quantized_input = quantization is not None - maybe_skip_quantization(quantization, dims=shape, device=device) + maybe_skip_quantization(quantization, dims=shape, device=device, dtype=dtype) # Random data # Note: Shift values to make sure inputs are non-zero @@ -1882,7 +1911,7 @@ def test_forward_linear_bias_activation( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if dtype not in (torch.float16, torch.bfloat16): pytest.skip( @@ -1953,7 +1982,7 @@ def test_forward_linear_bias_activation( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1989,7 +2018,7 @@ def test_forward_linear_bias_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2064,7 +2093,7 @@ def test_forward_linear_bias_add( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2102,7 +2131,7 @@ def test_forward_linear_scale_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2170,7 +2199,7 @@ def test_forward_linear_scale_add( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2203,7 +2232,7 @@ def test_backward_activation_bias( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0): pytest.skip("Unsupported tensor size for MXFP8") @@ -2265,7 +2294,7 @@ def test_backward_activation_bias( # Expected numerical error tols = dtype_tols(dtype) if with_quantization: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2384,7 +2413,7 @@ def test_backward_linear_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2452,7 +2481,7 @@ def test_backward_linear_add( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y1_test = y1_test.to(dtype=torch.float64, device="cpu") @@ -2487,7 +2516,7 @@ def test_backward_linear_scale( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2547,7 +2576,7 @@ def test_backward_linear_scale( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2588,7 +2617,7 @@ def test_linear( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) # Construct model @@ -2714,7 +2743,7 @@ def test_layernorm_mlp( ffn_shape = in_shape[:-1] + (ffn_hidden_size,) # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=ffn_shape, device=device) quantization_needed = quantized_compute or quantized_weight if quantization is None and quantization_needed: diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index c59bf376a..cfbb8094b 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -21,6 +21,7 @@ fp8_model_init, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.utils import is_fp8_fnuz from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear @@ -502,3 +503,39 @@ def test_quantizer_update(self, module_class): y = module(x, [batch_size]) else: y = module(x) + + +fp4_available, reason_for_no_fp4 = FP8GlobalStateManager.is_nvfp4_available() + + +@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (304, 304), + (320, 256), + # # largest tile + (8192, 8192), + ], +) +def test_fp4_dequantize(dtype, M, N): + q = NVFP4Quantizer() + a = torch.rand((M, N)).cuda().to(dtype=dtype) + starting_tensor = q(a) + dequantized_tensor = starting_tensor.dequantize() + new_tensor = q(dequantized_tensor) + torch.testing.assert_close( + new_tensor._rowwise_data, + starting_tensor._rowwise_data, + rtol=0, + atol=0, + ) + new_dequantized_tensor = new_tensor.dequantize() + torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index a7d762c3d..dbc1ed42f 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -90,9 +90,19 @@ def is_fp8_supported(config: ModelConfig): "large": ModelConfig(2, 128, 4, 128, num_layers=1), } + +def nvfp4_vanilla(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) + fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: @@ -382,6 +392,8 @@ def test_sanity_layernorm_linear( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -410,6 +422,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) @@ -440,6 +454,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") use_fp8 = fp8_recipe is not None with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): @@ -479,6 +495,8 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4(): + pytest.skip("NVFP4 not supported for grouped linear") use_fp8 = fp8_recipe is not None with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): @@ -529,6 +547,8 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -571,6 +591,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -632,6 +654,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization): pytest.skip(reason_for_no_fp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -686,6 +710,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization): pytest.skip(reason_for_no_fp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -737,6 +763,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -767,6 +795,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -801,6 +831,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -835,6 +867,8 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 684c15737..96f41c72e 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -24,6 +24,7 @@ get_attention_backend, AttentionParams, AttentionLogging, + check_set_window_size, ) from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend @@ -78,6 +79,8 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: # Transformer Engine dtypes if isinstance(dtype, tex.DType): + if dtype == tex.DType.kFloat4E2M1: + return dict(rtol=0.25, atol=0.125) # epsilon = 0.25 dtype = { tex.DType.kByte: torch.uint8, tex.DType.kInt32: torch.int32, @@ -100,10 +103,25 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: if dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz: return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 if dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 + return dict(rtol=0.25, atol=0.125) # epsilon = 0.125 raise ValueError(f"Unsupported dtype ({dtype})") +def quantization_tols(name: str) -> dict[str, float]: + """Estimated numerical error for a quantization scheme""" + if name in ( + "fp8", + "fp8_delayed_scaling", + "fp8_current_scaling", + "mxfp8", + "mxfp8_block_scaling", + ): + return dtype_tols(tex.DType.kFloat8E4M3) + if name == "nvfp4": + return dtype_tols(tex.DType.kFloat4E2M1) + raise ValueError(f"Unsupported quantization scheme ({name})") + + def make_recipe(name: Optional[str]) -> Optional[Recipe]: """Make recipe for quantization scheme""" if name is None: @@ -123,6 +141,12 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling() + if name == "nvfp4": + return transformer_engine.common.recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + ) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -143,6 +167,31 @@ def reset_rng_states() -> None: torch.cuda.set_rng_state(cuda_rng_state) +def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): + if not is_fp8: + torch.testing.assert_close(a, b, atol=atol, rtol=rtol) + return + + try: + if a.dtype != b.dtype: + a = a.to(b.dtype) + torch.testing.assert_close(a, b, atol=atol, rtol=rtol) + except Exception as e: + logging.debug(e) + + rmse = torch.sqrt((a - b).square().mean()).item() + logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) + rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + assert rmse < rmse_tol * rmse_range, ( + name_a + + " vs " + + name_b + + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + rmse, rmse_tol * rmse_range, rmse_tol, rmse_range + ) + ) + + class ModelConfig: def __init__( self, @@ -153,12 +202,15 @@ def __init__( max_seqlen_kv: int = None, num_gqa_groups: int = None, head_dim_v: int = None, + softmax_type: str = "vanilla", dropout_p: float = 0.0, attn_mask_type: str = "no_mask", attn_bias_type: str = "no_bias", alibi_type: str = "none", bias_shape: str = "1hss", window_size: Tuple[int, int] = (-1, -1), + context_parallel: bool = False, + cp_comm_type: str = "p2p", total_requests: int = None, max_ctx_len: int = None, num_layers: int = 1, @@ -177,13 +229,16 @@ def __init__( self.kv_channels = (self.head_dim_qk, self.head_dim_v) self.hidden_size = self.num_heads * self.head_dim_qk self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v + self.softmax_type = softmax_type self.dropout_p = dropout_p self.attn_mask_type = attn_mask_type self.attn_bias_type = attn_bias_type self.alibi_type = alibi_type self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross" self.bias_shape = bias_shape - self.window_size = window_size + self.window_size = check_set_window_size(self.attn_mask_type, window_size) + self.context_parallel = context_parallel + self.cp_comm_type = cp_comm_type self.total_requests = total_requests self.max_ctx_len = max_ctx_len self.num_layers = num_layers @@ -221,9 +276,7 @@ def get_available_attention_backends( config: ModelConfig, qkv_dtype: torch.dtype, qkv_layout: str, - window_size: Tuple[int, int] = (-1, -1), pad_between_seqs: bool = False, - context_parallel: bool = False, deterministic: bool = False, fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, @@ -276,19 +329,21 @@ def test(): head_dim_qk=config.head_dim_qk, head_dim_v=config.head_dim_v, attn_mask_type=config.attn_mask_type, - window_size=window_size, + window_size=config.window_size, alibi_slopes_shape=alibi_slopes_shape, core_attention_bias_type=config.attn_bias_type, core_attention_bias_shape=core_attention_bias_shape, core_attention_bias_requires_grad=core_attention_bias_requires_grad, pad_between_seqs=pad_between_seqs, attention_dropout=config.dropout_p, - context_parallel=context_parallel, + context_parallel=config.context_parallel, + cp_comm_type=config.cp_comm_type, deterministic=deterministic, fp8=fp8, fp8_meta=fp8_meta, is_training=is_training, inference_params=inference_params, + softmax_type=config.softmax_type, ) ( use_flash_attention, diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cefec6d06..bd2a023bb 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -120,6 +120,30 @@ endif() # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) +if(USE_CUDA) + # NVIDIA MathDX include directory (from Python package install location) + if(NOT DEFINED MATHDX_INCLUDE_DIR) + execute_process( + COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx + OUTPUT_VARIABLE _PIP_SHOW_MATHDX + ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR + RESULT_VARIABLE _PIP_SHOW_MATHDX_RES + OUTPUT_STRIP_TRAILING_WHITESPACE) + if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0) + message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}") + endif() + string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}") + if(NOT _MATHDX_LOC_MATCH) + message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}") + endif() + set(MATHDX_LOCATION "${CMAKE_MATCH_1}") + set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include") + endif() + if(NOT EXISTS "${MATHDX_INCLUDE_DIR}") + message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.") + endif() +endif() + # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) @@ -145,6 +169,7 @@ list(APPEND transformer_engine_SOURCES fused_attn/kv_cache.cu activation/relu.cu activation/swiglu.cu + gemm/config.cpp gemm/cublaslt_gemm.cu normalization/common.cpp normalization/layernorm/ln_api.cpp @@ -178,6 +203,7 @@ list(APPEND transformer_engine_SOURCES cudnn_utils.cpp transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise.cu + transpose/quantize_transpose_vector_blockwise_fp4.cu fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu fused_attn/fused_attn_fp8.cu @@ -185,6 +211,9 @@ list(APPEND transformer_engine_SOURCES fused_attn/utils.cu gemm/cutlass_grouped_gemm.cu util/cuda_nvml.cpp + recipe/nvfp4.cu + hadamard_transform/hadamard_transform.cu + hadamard_transform/hadamard_transform_cast_fusion.cu comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/userbuffers/userbuffers.cu @@ -261,7 +290,8 @@ target_link_libraries(transformer_engine PUBLIC CUDNN::cudnn_all) target_include_directories(transformer_engine PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR}) target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index ec29e6e12..56369db27 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -64,6 +64,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl #endif _comm_created = true; } + + initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm); +} + +void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { _use_ce = static_cast(use_ce); _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; @@ -278,6 +287,11 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); +} + +void CommOverlapBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm) { _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, @@ -288,7 +302,9 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); + if (_ub_comm->myrank == 0) { + printf("!!! [UB] Register UBuf %d\n", _ub_reg); + } _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); NVTE_CHECK_CUDA( @@ -640,6 +656,11 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, comm_type, aggregate); +} + +void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate) { _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; @@ -647,28 +668,28 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, // Create workspace tensor with userbuffer NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); - int buffer_chunk_bytes = buffer_bytes / tp_size; - _num_ubuf_chunks = tp_size; + int buffer_chunk_bytes = buffer_bytes / _tp_size; + _num_ubuf_chunks = _tp_size; if (_is_reduce_scatter) { // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk // outputs for reduction at the end of the pipelining. - buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); - _num_ubuf_chunks = tp_size * 2 - 1; + buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1); + _num_ubuf_chunks = _tp_size * 2 - 1; } void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); + if (_rank == 0) printf("!!! [UBP2P] UBuf %d\n", _ub_reg); _ubuf = TensorWrapper( buffer_ptr, - std::vector{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]}, buffer_dtype); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); for (int i = 0; i < _num_ubuf_chunks; i++) { _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), - std::vector{buffer_shape[0] / tp_size, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size, buffer_shape[1]}, buffer_dtype)); ubuf_byte_ptr += buffer_chunk_bytes; } @@ -691,7 +712,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); } - for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + for (int i = 0; i < _stream_compute.size(); i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); _stream_send.push_back(std::move(stream)); @@ -711,6 +732,38 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { } } +void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + void *dst_ptr; + if (local_chunk) { + NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); + } else { + NVTE_CHECK(_ubuf.numel() == source_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); +} + TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, size_t chunk_id) { // Start with a chunk of the source tensor @@ -851,6 +904,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, const bool do_gelu = pre_gelu_out.numel() > 0; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); @@ -919,12 +981,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } else { @@ -972,16 +1028,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } + // Copy all-gathered B from communication buffer into auxiliary output + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + _ub_comm->sms = ori_sms; for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 1ce89c512..6c7bed55a 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -670,9 +670,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * reinterpret_cast(&memhndl), sizeof(cudaIpcMemHandle_t), comm->comm_intra); + // Check for NVLINK support before attempting IPC operations + if (comm->nvsize > 1) { + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + cudaDeviceProp deviceProp; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device)); + bool peer_access_available = false; + for (int i = 0; i < comm->nvsize; i++) { + if (i != comm->nvrank) { + int can_access_peer; + cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i); + if (peer_result == cudaSuccess && can_access_peer) { + peer_access_available = true; + break; + } + } + } + if (!peer_access_available) { + free(tmp); + NVTE_ERROR( + "No peer-to-peer access available between GPUs. This platform does not support the " + "GPU-to-GPU " + "communication required for multi-GPU userbuffers. Consider using single-GPU mode."); + return 1; + } + } + for (int i = 0; i < comm->nvsize; i++) { if (i != comm->nvrank) { - NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*) + NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], cudaIpcMemLazyEnablePeerAccess)); } } @@ -693,4 +720,5 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * comm->mem_ptr[hndl] = *gpubuff; return comm->free_region++; + printf("***** Returning *****\n"); } diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index e67694c38..af3a51373 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -42,6 +42,10 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { return CUDA_R_8F_E4M3; case DType::kFloat8E5M2: return CUDA_R_8F_E5M2; +#if CUDA_VERSION >= 12080 + case DType::kFloat4E2M1: + return CUDA_R_4F_E2M1; +#endif default: NVTE_ERROR("Invalid type"); } @@ -165,7 +169,9 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_num_bits) { + const uint32_t offset_elems, const size_t type_num_bits, + const CUtensorMapSwizzle swizzle) { + cuda_driver::ensure_context_exists(); // Get a function pointer to the cuTensorMapEncodeTiled driver API // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { @@ -174,6 +180,8 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, }(); // rank is the number of dimensions of the array constexpr uint32_t rank = 2; + + // Dimension for the packed data types must reflect the number of individual U# values. uint64_t size[rank] = {globalX, globalY}; // The stride is the number of bytes to traverse from the first element of one row to the next @@ -212,7 +220,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, // Swizzling can be used to avoid shared memory bank conflicts. - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + swizzle, // L2 Promotion can be used to widen the effect of a cache-policy to a wider // set of L2 cache lines. diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index ce510334b..d1208f1a0 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -54,8 +54,14 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { return mode == NVTE_DELAYED_TENSOR_SCALING; } +inline bool is_nvfp4_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; } + +inline bool is_mxfp8_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } + inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } +inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; } + inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", end, " in a vector with ", shape.size(), " entries"); @@ -114,6 +120,7 @@ struct Tensor { SimpleTensor data; SimpleTensor columnwise_data; SimpleTensor amax; + SimpleTensor columnwise_amax; SimpleTensor scale; SimpleTensor scale_inv; SimpleTensor columnwise_scale_inv; @@ -125,6 +132,7 @@ struct Tensor { : data(), columnwise_data(), amax(nullptr, {1}, DType::kFloat32), + columnwise_amax(nullptr, {1}, DType::kFloat32), scale(nullptr, {1}, DType::kFloat32), scale_inv(nullptr, {1}, DType::kFloat32), columnwise_scale_inv(nullptr, {1}, DType::kFloat32), @@ -135,6 +143,7 @@ struct Tensor { data.clear(); columnwise_data.clear(); amax.clear(); + columnwise_amax.clear(); scale.clear(); scale_inv.clear(); columnwise_scale_inv.clear(); @@ -180,6 +189,7 @@ struct Tensor { * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569). */ switch (scaling_mode) { + case NVTE_NVFP4_1D_SCALING: case NVTE_DELAYED_TENSOR_SCALING: if (!has_data() && has_columnwise_data()) { std::vector ret; @@ -195,7 +205,6 @@ struct Tensor { } break; case NVTE_MXFP8_1D_SCALING: - case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: if (!has_data() && has_columnwise_data()) { return columnwise_data.shape; } else { @@ -267,12 +276,18 @@ struct QuantizationConfig { NVTETensor noop_tensor = nullptr; Float8BlockScaleTensorFormat float8_block_scale_tensor_format = Float8BlockScaleTensorFormat::GEMM_READY; + NVTETensor rng_state = nullptr; + bool nvfp4_2d_quantization = false; + bool stochastic_rounding = false; static constexpr size_t attr_sizes[] = { - sizeof(bool), // force_pow_2_scales - sizeof(float), // amax_epsilon - sizeof(NVTETensor), // noop_tensor - sizeof(Float8BlockScaleTensorFormat) // float8_block_scale_tensor_format + sizeof(bool), // force_pow_2_scales + sizeof(float), // amax_epsilon + sizeof(NVTETensor), // noop_tensor + sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format + sizeof(NVTETensor), // rng_seed and offset + sizeof(bool), // nvfp4_2d_quantization + sizeof(bool) // stochastic_rounding }; }; @@ -322,6 +337,8 @@ using fp8e8m0 = __nv_fp8_e8m0; #endif // CUDA_VERSION >= 12080 #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif //FP4_TYPE_SUPPORTED #else using bf16 = hip_bfloat16; @@ -370,6 +387,7 @@ struct TypeExtrema; template <> struct TypeExtrema { static constexpr float max = 6.0f; + static constexpr float max_inverse = 1.0 / max; }; #endif @@ -377,16 +395,20 @@ template <> struct TypeExtrema { #ifndef __HIP_PLATFORM_AMD__ static constexpr float max = 448.0f; + static constexpr float max_inverse = 1.0 / max; #elif defined(__HIP_DEVICE_COMPILE__) static constexpr float maxNorm = te_fp8_fnuz() ? 240.0f : 448.0f; + static constexpr float max_inverse = 1.0 / maxNorm; #else static float maxNorm; + static constexpr float max_inverse = 1.0 / 448.0f; #endif }; template <> struct TypeExtrema { static constexpr float max = 57344.0f; + static constexpr float max_inverse = 1.0 / max; }; template <> @@ -600,6 +622,18 @@ struct TypeInfo { NVTE_ERROR("Invalid type."); \ } +// Add a pack_size argument to select the packed type for FP4 +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat4E2M1: { \ + using type = __nv_fp4x2_storage_t; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ @@ -764,10 +798,11 @@ void checkCuDriverContext(CUstream stream); CUtensorMapDataType get_CUtensorMapDataType(DType dtype); // Set up parameters to create TMA descriptor. -void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, - const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, - const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_num_bits); +void create_2D_tensor_map( + CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, + const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, + const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits, + const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); #endif //#ifdef __HIP_PLATFORM_AMD__ bool is_supported_by_CC_100(); diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 795697635..77cd8d235 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -135,9 +135,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -175,7 +176,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // TODO (cyang): add is_training to nvte_get_fused_attn_backend // sm90: fwd d<=256, bwd d=128 only // sm100: fwd d<=128, bwd d<=128 - ((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) || + ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || + (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || @@ -183,7 +185,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset && + !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000)) { if (cudnn_runtime_version >= 8900) { @@ -213,7 +215,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && - !requires_64bit_ragged_offset) { + !requires_64bit_ragged_offset && + (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) { flag_m512 = true; } if ( @@ -363,7 +366,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // check 64-bit ragged offset support (supported_ragged_offset_size) && // 9.10.0/9.10.1: known bugs with SDPA F16 - (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) { + (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) && + // softmax type + // pre-9.13.1: vanilla + // 9.13.1+: vanilla, off-by-one, learnable + (cudnn_runtime_version >= 91301 || + (cudnn_runtime_version < 91301 && + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { @@ -405,14 +414,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - const NVTETensor rng_state, size_t max_seqlen, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -421,6 +432,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); const Tensor *input_QKV = convertNVTETensorCheck(QKV); const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); @@ -447,8 +459,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -463,9 +475,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O, - Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, - stream, handle); + attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -487,10 +499,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); @@ -505,6 +518,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con Tensor *input_output_dP = convertNVTETensorCheck(dP); Tensor *output_dQKV = convertNVTETensorCheck(dQKV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_QKV->data.shape.size(); @@ -529,8 +543,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -543,19 +557,22 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_arbitrary_seqlen_bwd_qkvpacked( b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO, - input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, wkspace, stream, handle); + softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O, + input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, + stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -580,14 +597,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } // NVTE fused attention FWD with packed KV void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -600,6 +618,7 @@ void nvte_fused_attn_fwd_kvpacked( const Tensor *input_Q = convertNVTETensorCheck(Q); const Tensor *input_KV = convertNVTETensorCheck(KV); const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); @@ -660,8 +679,8 @@ void nvte_fused_attn_fwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -677,10 +696,11 @@ void nvte_fused_attn_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, - input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); + dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, + window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O, + Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -702,12 +722,12 @@ void nvte_fused_attn_fwd_kvpacked( void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -723,6 +743,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor *output_dQ = convertNVTETensorCheck(dQ); Tensor *output_dKV = convertNVTETensorCheck(dKV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); size_t b = input_cu_seqlens_q->data.shape[0] - 1; @@ -755,8 +776,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -770,20 +791,23 @@ void nvte_fused_attn_bwd_kvpacked( #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_arbitrary_seqlen_bwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, - input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, + input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, + output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, + handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -809,16 +833,17 @@ void nvte_fused_attn_bwd_kvpacked( } // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -832,6 +857,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_K = convertNVTETensorCheck(K); const Tensor *input_V = convertNVTETensorCheck(V); const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); @@ -886,8 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -903,10 +929,11 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, - input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); + dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, + window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, + Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -928,14 +955,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -953,6 +981,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *output_dK = convertNVTETensorCheck(dK); Tensor *output_dV = convertNVTETensorCheck(dV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_Q->data.shape.size(); @@ -978,8 +1007,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -993,19 +1022,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, - input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, - output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, + deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 4e6c3c858..ba0f84578 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -54,10 +54,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, - void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, - void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, + void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrSoftmaxStats, + void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -75,6 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_causal = true; is_bottom_right = false; } + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); NVTE_QKV_Format q_format = nvte_get_q_format(layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); @@ -98,8 +100,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( s_q = is_ragged_q ? max_t_q : s_q; s_kv = is_ragged_kv ? max_t_kv : s_kv; } - const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; + const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; try { FADescriptor_v1 descriptor{b, h, @@ -122,11 +124,14 @@ void fused_attn_arbitrary_seqlen_fwd_impl( layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, true, tensorType, - tensorType}; + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -138,6 +143,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // O std::shared_ptr, // Stats std::shared_ptr, // bias + std::shared_ptr, // softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // page_table_k @@ -168,7 +174,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr Q, K, V, attn_scale; + std::shared_ptr Q, K, V, attn_scale, softmax_offset; std::shared_ptr bias, seq_q, seq_kv; std::shared_ptr page_table_k, page_table_v; std::shared_ptr offset_q, offset_k, offset_v, offset_o, @@ -302,6 +308,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_options.set_sink_token(softmax_offset); + } + auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options); std::vector o_stride(4); @@ -338,6 +353,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); + auto softmax_offset_tuple = + is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v) @@ -358,17 +375,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat( - std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, - page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple, + offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v, - offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = - get_graph(sdpa_f16_fprop_cache, descriptor); + auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, softmax_offset, seq_q, seq_kv, + page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats, + dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed // n.b. Care should be taken to align each of the added worksapce tensors to their type. @@ -473,6 +491,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; } + + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { NVTE_ERROR(e.what()); @@ -483,14 +506,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, - void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, - void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, - cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, + void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, + void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, + void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, + void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, + void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, + void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -506,6 +529,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( is_causal = true; is_bottom_right = false; } + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); NVTE_QKV_Format q_format = nvte_get_q_format(layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); @@ -558,11 +582,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, deterministic, tensorType, - tensorType}; + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -579,6 +606,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr, // dV std::shared_ptr, // bias std::shared_ptr, // dBias + std::shared_ptr, // softmax_offset + std::shared_ptr, // d_softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // offset_q @@ -608,7 +637,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_compute_data_type(fe::DataType_t::FLOAT); std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr bias, dBias, softmax_offset, d_softmax_offset, + seq_q, seq_kv; std::shared_ptr offset_q, offset_k, offset_v, offset_o, offset_stats; std::shared_ptr dropout_seed, dropout_offset; @@ -771,6 +801,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_backward_options.set_sink_token(softmax_offset); + d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("d_softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_backward_options.set_dsink_token(d_softmax_offset); + } + auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); @@ -796,6 +841,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr> // dV key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); + auto softmax_offset_tuple = is_softmax_offset + ? std::make_tuple(softmax_offset, d_softmax_offset) + : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto offset_qo_tuple = @@ -814,17 +862,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = - std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple, - offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, offset_qo_tuple, + offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv, - offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = - get_graph(sdpa_f16_bprop_cache, descriptor); + auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset, + d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats, + dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed // n.b. Care should be taken to align each of the added worksapce tensors to their type. @@ -938,6 +986,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { NVTE_ERROR(e.what()); @@ -949,8 +1002,9 @@ using namespace transformer_engine::fused_attn; void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -977,6 +1031,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; @@ -990,53 +1048,50 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( max_tokens = get_max_tokens(num_tokens); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1050,11 +1105,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, - devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, - handle); + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, + devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, + nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1074,9 +1129,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1122,6 +1178,12 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void *devPtrSoftmaxStats = nullptr; devPtrSoftmaxStats = output_S->data.dptr; + void *devPtrSoftmaxOffset = nullptr; + void *devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; @@ -1135,11 +1197,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK, - devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, - devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic, + devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, + devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1161,12 +1223,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1192,6 +1254,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; @@ -1216,53 +1282,50 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( max_tokens_kv = get_max_tokens(num_tokens_kv); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1277,11 +1340,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, + devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1302,10 +1365,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, + Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -1359,6 +1423,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void *devPtrSoftmaxStats = nullptr; devPtrSoftmaxStats = output_S->data.dptr; + void *devPtrSoftmaxOffset = nullptr; + void *devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1374,9 +1444,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, - devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, + devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1401,12 +1472,13 @@ void fused_attn_arbitrary_seqlen_fwd( size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1425,6 +1497,10 @@ void fused_attn_arbitrary_seqlen_fwd( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1446,53 +1522,50 @@ void fused_attn_arbitrary_seqlen_fwd( max_tokens_kv = get_max_tokens(num_tokens_kv); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1507,11 +1580,11 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, + devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1532,13 +1605,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; @@ -1577,6 +1651,12 @@ void fused_attn_arbitrary_seqlen_bwd( void *devPtrdV = output_dV->data.dptr; void *devPtrSoftmaxStats = nullptr; devPtrSoftmaxStats = output_S->data.dptr; + void *devPtrSoftmaxOffset = nullptr; + void *devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1592,9 +1672,10 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, - devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, + devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index e1a20274f..b9658b053 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -21,17 +21,19 @@ namespace transformer_engine { void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -41,21 +43,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, + Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -66,24 +69,26 @@ void fused_attn_arbitrary_seqlen_fwd( size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index d7f098376..21c544491 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1658,8 +1658,9 @@ void fused_attn_fp8_fwd_impl_v1( void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type, - void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); @@ -1672,6 +1673,13 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); + bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + NVTE_CHECK(is_current_scaling || is_delayed_scaling, + "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " + "kFloat8E5M2!"); try { FADescriptor_v1 descriptor{b, @@ -1695,11 +1703,14 @@ void fused_attn_fp8_fwd_impl_v1( layout, bias_type, mask_type, + NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, true, - fwd_tensor_type, - fwd_tensor_type}; + qkv_tensor_type, + o_tensor_type, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -1738,7 +1749,7 @@ void fused_attn_fp8_fwd_impl_v1( // otherwise, build the op_graph and the plan. Then update cache auto mha_graph = std::make_shared(); - mha_graph->set_io_data_type(fwd_tensor_type) + mha_graph->set_io_data_type(qkv_tensor_type) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); @@ -1786,7 +1797,13 @@ void fused_attn_fp8_fwd_impl_v1( descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + + if (is_delayed_scaling) { + scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + } + if (is_current_scaling) { + scale_o = mha_graph->tensor(1.0f); + } fe::graph::SDPA_fp8_attributes sdpa_options; sdpa_options = fe::graph::SDPA_fp8_attributes() @@ -1838,11 +1855,12 @@ void fused_attn_fp8_fwd_impl_v1( std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); + O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); + amax_s->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -1915,13 +1933,16 @@ void fused_attn_fp8_fwd_impl_v1( {descale_v, devPtrDescaleV}, {descale_s, devPtrDescaleS}, {scale_s, devPtrScaleS}, - {scale_o, devPtrScaleO}, {attn_scale, &scaling_factor}, {O, devPtrO}, {amax_s, devPtrAmaxS}, {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; + if (is_delayed_scaling) { + variant_pack[scale_o] = devPtrScaleO; + } + /* if (is_bias) { variant_pack[bias] = devPtrBias; } */ @@ -1962,8 +1983,9 @@ void fused_attn_fp8_bwd_impl_v1( void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, - void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type, - cudnn_frontend::DataType_t bwd_tensor_type, void* workspace, size_t* workspace_size, + void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, + cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -1977,6 +1999,15 @@ void fused_attn_fp8_bwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); + bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + NVTE_CHECK(is_current_scaling || is_delayed_scaling, + "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or " + "kFloat8E5M2!"); + bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); try { FADescriptor_v1 descriptor{b, @@ -2000,11 +2031,14 @@ void fused_attn_fp8_bwd_impl_v1( layout, bias_type, mask_type, + NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, false, - fwd_tensor_type, - bwd_tensor_type}; + qkv_tensor_type, + o_tensor_type, + do_tensor_type, + dqkv_tensor_type}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -2057,7 +2091,7 @@ void fused_attn_fp8_bwd_impl_v1( // otherwise, build the op_graph and the plan. Then update cache auto mha_graph = std::make_shared(); - mha_graph->set_io_data_type(fwd_tensor_type) + mha_graph->set_io_data_type(qkv_tensor_type) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); @@ -2097,7 +2131,8 @@ void fused_attn_fp8_bwd_impl_v1( o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); + .set_stride(o_stride) + .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") .set_dim({b, h, s_q, d}) @@ -2123,14 +2158,26 @@ void fused_attn_fp8_bwd_impl_v1( descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + if (is_O_in_F16) { + descale_o = mha_graph->tensor(1.0f); + } else { + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + } descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); - scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); - scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + + if (is_delayed_scaling) { + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + } + if (is_current_scaling) { + scale_dQ = mha_graph->tensor(1.0f); + scale_dK = mha_graph->tensor(1.0f); + scale_dV = mha_graph->tensor(1.0f); + } fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes() @@ -2212,10 +2259,10 @@ void fused_attn_fp8_bwd_impl_v1( .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - dO->set_data_type(bwd_tensor_type); - dQ->set_data_type(bwd_tensor_type); - dK->set_data_type(bwd_tensor_type); - dV->set_data_type(bwd_tensor_type); + dO->set_data_type(do_tensor_type); + dQ->set_data_type(dqkv_tensor_type); + dK->set_data_type(dqkv_tensor_type); + dV->set_data_type(dqkv_tensor_type); std::tuple, // q std::shared_ptr, // k @@ -2296,14 +2343,10 @@ void fused_attn_fp8_bwd_impl_v1( {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, - {descale_o, devPtrDescaleO}, {descale_dO, devPtrDescaledO}, {descale_s, devPtrDescaleS}, {descale_dP, devPtrDescaledP}, {scale_s, devPtrScaleS}, - {scale_dQ, devPtrScaledQ}, - {scale_dK, devPtrScaledK}, - {scale_dV, devPtrScaledV}, {scale_dP, devPtrScaledP}, {dQ, devPtrdQ}, {dK, devPtrdK}, @@ -2314,6 +2357,15 @@ void fused_attn_fp8_bwd_impl_v1( {amax_dP, devPtrAmaxdP}, }; + if (is_delayed_scaling) { + variant_pack[scale_dQ] = devPtrScaledQ; + variant_pack[scale_dK] = devPtrScaledK; + variant_pack[scale_dV] = devPtrScaledV; + } + if (!is_O_in_F16) { + variant_pack[descale_o] = devPtrDescaleO; + } + /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { @@ -2364,6 +2416,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_QKV->data.dtype; + const DType O_type = output_O->data.dtype; void* devPtrQKV = input_QKV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; @@ -2430,8 +2483,8 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, @@ -2465,6 +2518,7 @@ void fused_attn_fp8_bwd_qkvpacked( cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_QKV->data.dtype; + const DType dO_type = input_dO->data.dtype; const DType dQKV_type = output_dQKV->data.dtype; void* devPtrQKV = input_QKV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); @@ -2482,7 +2536,11 @@ void fused_attn_fp8_bwd_qkvpacked( void* devPtrDescaleV = input_QKV->scale_inv.dptr; void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; + const DType O_type = input_O->data.dtype; + void* devPtrDescaleO = nullptr; + if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { + devPtrDescaleO = input_O->scale_inv.dptr; + } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; @@ -2525,7 +2583,8 @@ void fused_attn_fp8_bwd_qkvpacked( devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, @@ -2563,6 +2622,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_Q->data.dtype; + const DType O_type = output_O->data.dtype; void* devPtrQ = input_Q->data.dptr; void* devPtrKV = input_KV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); @@ -2631,8 +2691,8 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, @@ -2669,6 +2729,7 @@ void fused_attn_fp8_bwd_kvpacked( cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_Q->data.dtype; + const DType dO_type = input_dO->data.dtype; const DType dQKV_type = output_dQ->data.dtype; void* devPtrQ = input_Q->data.dptr; void* devPtrKV = input_KV->data.dptr; @@ -2686,7 +2747,11 @@ void fused_attn_fp8_bwd_kvpacked( void* devPtrDescaleV = input_KV->scale_inv.dptr; void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; + const DType O_type = input_O->data.dtype; + void* devPtrDescaleO = nullptr; + if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { + devPtrDescaleO = input_O->scale_inv.dptr; + } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; @@ -2731,7 +2796,8 @@ void fused_attn_fp8_bwd_kvpacked( devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, @@ -2820,6 +2886,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_Q->data.dtype; + const DType O_type = output_O->data.dtype; size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); @@ -2829,8 +2896,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, @@ -2876,7 +2943,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrDescaleV = input_Q->scale_inv.dptr; void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; + const DType O_type = input_O->data.dtype; + void* devPtrDescaleO = nullptr; + if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { + devPtrDescaleO = input_O->scale_inv.dptr; + } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; @@ -2909,6 +2980,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_Q->data.dtype; + const DType dO_type = input_dO->data.dtype; const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; @@ -2922,7 +2994,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 678b63691..f03774f8e 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -107,23 +107,28 @@ struct FADescriptor_v1 { NVTE_QKV_Layout layout; NVTE_Bias_Type bias_type; NVTE_Mask_Type mask_type; + NVTE_Softmax_Type softmax_type; std::int64_t window_size_left; std::int64_t window_size_right; bool deterministic; - cudnn_frontend::DataType_t fwd_tensor_type; - cudnn_frontend::DataType_t bwd_tensor_type; + cudnn_frontend::DataType_t qkv_tensor_type; + cudnn_frontend::DataType_t o_tensor_type; + cudnn_frontend::DataType_t do_tensor_type; + cudnn_frontend::DataType_t dqkv_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, - attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, - window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < + attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, + window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, + o_tensor_type, do_tensor_type, dqkv_tensor_type) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, - rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, - rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, - rhs.bwd_tensor_type); + rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, + rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, + rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, + rhs.dqkv_tensor_type); } }; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index bb5e22887..da40a223d 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -276,9 +276,10 @@ void log_fused_attn_config( // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { using namespace transformer_engine; // by default, fused attn is enabled @@ -339,14 +340,15 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -384,8 +386,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_qkvpacked( @@ -419,10 +421,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); @@ -468,8 +471,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)){ @@ -505,14 +508,17 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } // NVTE fused attention FWD with packed KV -void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { +void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -556,8 +562,8 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_kvpacked( @@ -596,12 +602,12 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -650,8 +656,8 @@ void nvte_fused_attn_bwd_kvpacked( std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -694,16 +700,16 @@ void nvte_fused_attn_bwd_kvpacked( // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -740,8 +746,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd( @@ -780,14 +786,14 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, - size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -830,8 +836,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp new file mode 100644 index 000000000..cf211beaf --- /dev/null +++ b/transformer_engine/common/gemm/config.cpp @@ -0,0 +1,116 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "./config.h" + +#include +#include + +#include + +#include "../util/logging.h" + +NVTEMatmulConfig nvte_create_matmul_config() { return new transformer_engine::MatmulConfig; } + +void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + void *buf, size_t size_in_bytes, size_t *size_written) { + // Write attribute size + NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ", + static_cast(attr), ")"); + NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); + const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr]; + *size_written = attr_size; + + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + + // Write to buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)"); + const auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEMatmulConfigBiasTensor: + std::memcpy(buf, &config_.bias_tensor, attr_size); + break; + case kNVTEMatmulConfigDBiasTensor: + std::memcpy(buf, &config_.dbias_tensor, attr_size); + break; + case kNVTEMatmulConfigWithGELUEpilogue: + std::memcpy(buf, &config_.with_gelu_epilogue, attr_size); + break; + case kNVTEMatmulConfigWithDGELUEpilogue: + std::memcpy(buf, &config_.with_dgelu_epilogue, attr_size); + break; + case kNVTEMatmulConfigEpilogueAuxTensor: + std::memcpy(buf, &config_.epilogue_aux_tensor, attr_size); + break; + case kNVTEMatmulConfigUseSplitAccumulator: + std::memcpy(buf, &config_.use_split_accumulator, attr_size); + break; + case kNVTEMatmulConfigSMCount: + std::memcpy(buf, &config_.sm_count, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ", + static_cast(attr), ")"); + const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); + + // Read from buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)"); + auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEMatmulConfigBiasTensor: + std::memcpy(&config_.bias_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigDBiasTensor: + std::memcpy(&config_.dbias_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigWithGELUEpilogue: + std::memcpy(&config_.with_gelu_epilogue, buf, attr_size); + break; + case kNVTEMatmulConfigWithDGELUEpilogue: + std::memcpy(&config_.with_dgelu_epilogue, buf, attr_size); + break; + case kNVTEMatmulConfigEpilogueAuxTensor: + std::memcpy(&config_.epilogue_aux_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigUseSplitAccumulator: + std::memcpy(&config_.use_split_accumulator, buf, attr_size); + break; + case kNVTEMatmulConfigSMCount: + std::memcpy(&config_.sm_count, buf, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_destroy_matmul_config(NVTEMatmulConfig config) { + if (config != nullptr) { + delete reinterpret_cast(config); + } +} diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h new file mode 100644 index 000000000..54ccf06a5 --- /dev/null +++ b/transformer_engine/common/gemm/config.h @@ -0,0 +1,36 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_GEMM_CONFIG_H_ +#define TRANSFORMER_ENGINE_GEMM_CONFIG_H_ + +#include + +namespace transformer_engine { + +struct MatmulConfig { + NVTETensor bias_tensor = nullptr; + NVTETensor dbias_tensor = nullptr; + bool with_gelu_epilogue = false; + bool with_dgelu_epilogue = false; + NVTETensor epilogue_aux_tensor = nullptr; + bool use_split_accumulator = false; + int sm_count = 0; + + static constexpr size_t attr_sizes[] = { + sizeof(NVTETensor), // bias_tensor + sizeof(NVTETensor), // dbias_tensor + sizeof(bool), // with_gelu_epilogue + sizeof(bool), // with_dgelu_epilogue + sizeof(NVTETensor), // epilogue_aux_tensor + sizeof(bool), // use_split_accumulator + sizeof(int) // sm_count + }; +}; + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9c2ca9b4c..c35a2960c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -14,23 +14,58 @@ #include #include +#include #include +#include #include #include +#include #include "../common.h" +#include "../util/cuda_runtime.h" #include "../util/handle_manager.h" #include "../util/logging.h" #include "../util/multi_stream.h" -#include "common/util/cuda_runtime.h" +#include "./config.h" #ifndef __HIP_PLATFORM_AMD__ -#include "cutlass_grouped_gemm.cuh" +#include "./cutlass_grouped_gemm.cuh" #endif #ifndef __HIP_PLATFORM_AMD__ namespace { +/* Use CUDA const memory to store scalar 1 and 0 for cublas usage +*/ +__device__ __constant__ float one_device; +__device__ __constant__ float zero_device; + +inline float *GetScalarOne() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + float one = 1.0f; + NVTE_CHECK_CUDA(cudaMemcpyToSymbol(one_device, &one, sizeof(float))); + }); + // return address by cudaGetSymbolAddress + float *dev_ptr; + NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast(&dev_ptr), one_device)); + return dev_ptr; +} + +inline float *GetScalarZero() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + float zero = 0.0f; + NVTE_CHECK_CUDA(cudaMemcpyToSymbol(zero_device, &zero, sizeof(float))); + }); + // return address by cudaGetSymbolAddress + float *dev_ptr; + NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast(&dev_ptr), zero_device)); + return dev_ptr; +} + +__global__ __launch_bounds__(1) void set_float_kernel(float *ptr, float val) { *ptr = val; } + uint32_t _getAlignment(uintptr_t address) { // alignment are in bytes uint32_t alignment = 256; @@ -90,6 +125,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla bool is_A_transposed = transA == CUBLAS_OP_T; bool is_B_transposed = transB == CUBLAS_OP_T; + // Set conditions for MXFP8 and NVFP4 gemm execution. + const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode); + const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode); + // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { // Unscaled or FP8 tensor scaling @@ -110,10 +149,26 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } } - } else if (is_mxfp_scaling(A.scaling_mode)) { - // MXFP8 + } else if (nvfp4) { + // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. + + if (is_A_transposed) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(is_nvfp4_scaling(A.scaling_mode), + "Input A has unsupported combination of recipe and layout"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); + } + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; // NVFP4 gemm is only supported in TN layout. + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = k; + } else if (mxfp8) { + // MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe. // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { @@ -169,10 +224,20 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } } - } else if (is_mxfp_scaling(B.scaling_mode)) { - // MXFP8 - // Note: Row-wise and column-wise data are scaled along different - // dimensions (with matrix interpreted in row-major order). + } else if (nvfp4) { + if (is_B_transposed) { + NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode), + "Input B has unsupported combination of recipe and layout"); + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); + } else { + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = CUBLAS_OP_N; // NVFP4 gemm is only supported in TN layout. + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + } else if (mxfp8) { if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { @@ -231,7 +296,7 @@ namespace transformer_engine { void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, cublasOperation_t transb, bool grad, void* workspace, size_t workspaceSize, - float alpha, float beta, bool use_split_accumulator, int math_sm_count, + const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset = -1); #else // Use cublasLt @@ -241,7 +306,7 @@ using cublasHandleManager = detail::HandleManageramax.dptr != nullptr || inputB->amax.dptr != nullptr)) { + // Reserve some workspace for alpha scale + NVTE_CHECK(workspaceSize >= 4, + "NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ", + workspaceSize, " bytes remaining."); + workspaceSize = (workspaceSize / 4) * 4 - 4; // Remove last 4 aligned bytes + uint8_t *workspace_ptr = reinterpret_cast(workspace); + float *new_alpha_ptr = reinterpret_cast(&workspace_ptr[workspaceSize]); + + // Update alpha scale on device + // Note: Compute NVFP4 tensor scales based on amaxes and then + // divide from alpha scale. This way we only need to apply NVFP4 + // tensor scales in matmul output, instead of in matmul inputs. + float old_alpha = *reinterpret_cast(alpha); // Assumed to be on CPU + TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector{1}, DType::kFloat32); + nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, transa, inputB->nvte_tensor, !transb, + old_alpha, new_alpha_tensor.data(), stream); + alpha = new_alpha_ptr; + + // Make sure beta scale is on device + float old_beta = *reinterpret_cast(beta); // Assumed to be on CPU + if (old_beta == 0) { + beta = GetScalarZero(); // Device constant memory + } else if (old_beta == 1) { + beta = GetScalarOne(); // Device constant memory + } else { + // Move beta to workspace + NVTE_CHECK(workspaceSize >= 4, + "NVFP4 GEMM requires at least 4 byte workspace for beta scale, but only has ", + workspaceSize, " bytes remaining."); + workspaceSize = (workspaceSize / 4) * 4 - 4; // Remove last 4 aligned bytes + float *new_beta_ptr = reinterpret_cast(&workspace_ptr[workspaceSize]); + set_float_kernel<<<1, 1, 0, stream>>>(new_beta_ptr, old_beta); + NVTE_CHECK_CUDA(cudaGetLastError()); + beta = new_beta_ptr; + } + } const cudaDataType_t A_type = get_cuda_dtype(param.Atype); const cudaDataType_t B_type = get_cuda_dtype(param.Btype); @@ -290,16 +398,23 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, "FP8 input to GEMM requires inverse of scale!"); NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr, "FP8 input to GEMM requires inverse of scale!"); + NVTE_CHECK(!is_fp4_dtype(param.Atype) || param.A_scale_inv != nullptr, + "FP4 input to GEMM requires inverse of scale!"); + NVTE_CHECK(!is_fp4_dtype(param.Btype) || param.B_scale_inv != nullptr, + "FP4 input to GEMM requires inverse of scale!"); // check consistency of arguments: // if fp8 is desired, context cannot be null // fp8 + gelu fusion + fp8 aux is unavailable right now. - if (use_fp8 && gelu) { + if ((use_fp8 || use_fp4) && gelu) { NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype), "fp8 Aux output for gemm + gelu fusion not supported!"); } - if (is_fp8_dtype(outputD->data.dtype)) { - NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!"); + if (is_fp4_dtype(outputD->data.dtype)) { + NVTE_ERROR("FP4 GEMM output is not supported!"); + } + if (use_fp4 && (D_type == CUDA_R_16F)) { + NVTE_ERROR("FP4 GEMM does not support FP16 output!"); } cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); @@ -339,12 +454,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &math_sm_count, sizeof(math_sm_count))); } - // set fp8 attributes -- input and output types should already be set to fp8 as appropriate - // Note: gelu fusion isn't available right now, and we don't need + // set fp8/fp4 attributes -- input and output types should already be set to fp8/fp4 + // as appropriate. Note: gelu fusion isn't available right now, and we don't need // amax(D) either (next op is high precision). - if (use_fp8) { - // Split accumulator. - const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1; + const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode); + + if (use_fp8 || use_fp4) { + // Fast accumulation is only supported for FP8. + const int8_t fastAccuMode = (use_split_accumulator) ? 0 : use_fp8; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); @@ -353,7 +470,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, cublasLtMatmulMatrixScale_t scaling_mode_a; cublasLtMatmulMatrixScale_t scaling_mode_b; #endif // CUBLAS_VERSION >= 120800 - if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { + if (is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode)) { void *A_scale_inverse = param.A_scale_inv; void *B_scale_inverse = param.B_scale_inv; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -366,7 +483,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; #endif // CUBLAS_VERSION >= 120800 - } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) { + } else if (mxfp8_gemm) { #if CUBLAS_VERSION >= 120800 NVTE_CHECK(cublas_version() >= 120800, "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); @@ -391,6 +508,34 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #else NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION); +#endif // CUBLAS_VERSION >= 120800 + } else if (use_fp4) { // NVFP4 GEMM +#if CUBLAS_VERSION >= 120800 + NVTE_CHECK(cublas_version() >= 120800, + "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); + // make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE + cublasDataType_t scale_type = CUDA_R_32F; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); + + // Set pointer mode: alpha and beta are both device pointers + // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); + + fp8e4m3 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + fp8e4m3 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; +#else + NVTE_ERROR("FP4 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION); #endif // CUBLAS_VERSION >= 120800 } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && @@ -523,14 +668,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); -#endif -#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) +#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) NVTE_ERROR( "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); -#endif -#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \ - CUBLAS_VERSION < 130000 +#else NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ", cuda::cudart_version()); @@ -557,6 +699,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #endif } + // align the workspace to 256 B + const int required_alignment = 256; + const auto original_workspace_alignment = _getAlignment(reinterpret_cast(workspace)); + uint8_t *aligned_workspace_ptr = + reinterpret_cast(workspace) + required_alignment - original_workspace_alignment; + workspaceSize = workspaceSize - required_alignment + original_workspace_alignment; + const auto new_workspace_alignment = + _getAlignment(reinterpret_cast(aligned_workspace_ptr)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); @@ -564,7 +714,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const auto B_alignment = _getAlignment(reinterpret_cast(param.B)); const auto C_alignment = _getAlignment(reinterpret_cast(C)); const auto D_alignment = _getAlignment(reinterpret_cast(D)); - const auto workspace_alignment = _getAlignment(reinterpret_cast(workspace)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -573,8 +722,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment))); - NVTE_CHECK(workspace_alignment % 256 == 0, - "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment); + NVTE_CHECK(new_workspace_alignment % 256 == 0, + "cuBLAS workspace pointer must be aligned to 256 bytes, got ", + new_workspace_alignment); const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, @@ -585,16 +735,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C - NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ - param.A, /* A */ - Adesc, param.B, /* B */ - Bdesc, static_cast(&beta), /* beta */ - C, /* C */ - Cdesc, D, /* D */ - Ddesc, &heuristicResult.algo, /* algo */ - workspace, /* workspace */ - workspaceSize, stream)); /* stream */ + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha, /* alpha */ + param.A, /* A */ + Adesc, param.B, /* B */ + Bdesc, beta, /* beta */ + C, /* C */ + Cdesc, D, /* D */ + Ddesc, &heuristicResult.algo, /* algo */ + aligned_workspace_ptr, /* workspace */ + workspaceSize, stream)); /* stream */ // Update FP8 scale-inv in output tensor // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated. @@ -621,35 +770,117 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons int math_sm_count, cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; + + // Tensors const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); - Tensor *outputD = convertNVTETensor(D); + Tensor *outputD = convertNVTETensorCheck(D); const Tensor *biasTensor = convertNVTETensor(bias); Tensor *outputGelu = convertNVTETensor(pre_gelu_out); Tensor *wspace = convertNVTETensor(workspace); + // Scales + const float alpha = 1; + const float beta = accumulate ? 1 : 0; + + // Check for NVFP4 + // TODO Remove once alpha scale logic is moved into cublas_gemm function + if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) { + NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead."); + } + + // Launch GEMM cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false, - nullptr, stream); + &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); +} + +void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A, + const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D, + NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream) { + NVTE_API_CALL(nvte_cublas_gemm_v2); + using namespace transformer_engine; + + // Data tensors + const Tensor *A_tensor = convertNVTETensorCheck(A); + const Tensor *B_tensor = convertNVTETensorCheck(B); + const Tensor *C_tensor = convertNVTETensorCheck(C); + Tensor *D_tensor = convertNVTETensorCheck(D); + NVTE_CHECK(C_tensor == D_tensor, + "Currently nvte_cublas_gemm_v2 does not support different C and D tensors."); + + // Workspace + void *workspace_ptr = nullptr; + size_t workspace_size = 0; + Tensor *workspace_tensor = convertNVTETensor(workspace); + if (workspace_tensor != nullptr) { + workspace_ptr = workspace_tensor->data.dptr; + workspace_size = + get_buffer_size_bytes(workspace_tensor->data.numel(), workspace_tensor->data.dtype); + } + + // Additional config + MatmulConfig config_; + if (config != nullptr) { + config_ = *reinterpret_cast(config); + } + + // Configure GEMM epilogue + const bool with_grad_epilogue = (config_.dbias_tensor != nullptr || config_.with_dgelu_epilogue); + if (with_grad_epilogue) { + NVTE_CHECK(config_.bias_tensor == nullptr && !config_.with_gelu_epilogue, + "Invalid epilogue (bias=", config_.bias_tensor != nullptr, + ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue, + ", dgelu=", config_.with_dgelu_epilogue, ")."); + } + Tensor dummy_tensor; + Tensor *epilogue_bias_tensor = &dummy_tensor; + if (!with_grad_epilogue && config_.bias_tensor != nullptr) { + epilogue_bias_tensor = convertNVTETensorCheck(config_.bias_tensor); + } else if (with_grad_epilogue && config_.dbias_tensor != nullptr) { + epilogue_bias_tensor = convertNVTETensorCheck(config_.dbias_tensor); + } + Tensor *epilogue_aux_tensor = &dummy_tensor; + if (config_.with_gelu_epilogue || config_.with_dgelu_epilogue) { + NVTE_CHECK(config_.epilogue_aux_tensor != nullptr, + "Requested epilogue (bias=", config_.bias_tensor != nullptr, + ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue, + ", dgelu=", config_.with_dgelu_epilogue, ") without providing aux tensor."); + epilogue_aux_tensor = convertNVTETensor(config_.epilogue_aux_tensor); + } + + // Launch GEMM + cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor, + transa ? CUBLAS_OP_T : CUBLAS_OP_N, transb ? CUBLAS_OP_T : CUBLAS_OP_N, + with_grad_epilogue, workspace_ptr, workspace_size, alpha, beta, + config_.use_split_accumulator, config_.sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, float alpha, float beta, bool use_split_accumulator, int math_sm_count, cudaStream_t stream) { - NVTE_API_CALL(nvte_cublas_gemm_scaled); + NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; + + // Tensors const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); - Tensor *outputD = convertNVTETensor(D); + Tensor *outputD = convertNVTETensorCheck(D); const Tensor *biasTensor = convertNVTETensor(bias); Tensor *outputGelu = convertNVTETensor(pre_gelu_out); Tensor *wspace = convertNVTETensor(workspace); + // Check for NVFP4 + // TODO Remove once alpha scale logic is moved into cublas_gemm function + if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) { + NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead."); + } + + // Launch GEMM cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); + &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, @@ -666,12 +897,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); -#endif -#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) +#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) NVTE_ERROR( "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); -#endif +#else NVTE_CHECK( transformer_engine::cuda::cudart_version() >= 12020 && transformer_engine::cuda::cudart_version() < 13000, @@ -691,13 +921,17 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor const Tensor *inputCounter = convertNVTETensor(counter); Tensor *wspace = convertNVTETensor(workspace); + const void *alpha_ptr = GetScalarOne(); + const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero(); + NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split, - n_split, gemm_producer, inputCounter, stream); + alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split, + gemm_producer, inputCounter, stream); +#endif } void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, @@ -727,16 +961,40 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETens Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out[i]); Tensor *wspace = convertNVTETensorCheck(workspace[i % num_streams]); + const float alpha = 1; + const float beta = accumulate ? 1 : 0; + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], 1.0f, (accumulate) ? 1.0f : 0.0f, + wspace->data.dptr, wspace->data.shape[0], &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, detail::get_compute_stream(i % num_streams), i % num_streams); } #else - nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, - workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, - detail::get_compute_stream(i % num_streams)); + // Check whether GELU or dGELU epilogue is requested + Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]); + bool with_gelu_dgelu_epilogue = + (pre_gelu_tensor != nullptr && pre_gelu_tensor->data.dptr != nullptr); + + // Construct config + MatmulConfig config; + if (grad) { + config.dbias_tensor = bias[i]; + config.with_dgelu_epilogue = with_gelu_dgelu_epilogue; + } else { + config.bias_tensor = bias[i]; + config.with_gelu_epilogue = with_gelu_dgelu_epilogue; + } + config.epilogue_aux_tensor = pre_gelu_out[i]; + config.use_split_accumulator = use_split_accumulator; + config.sm_count = math_sm_count; + + // Launch GEMM + const float alpha = 1.f; + const float beta = accumulate ? 1.f : 0.f; + nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i], + workspace[i % num_streams], &config, + detail::get_compute_stream(i % num_streams)); #endif } diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index fef3966a5..010a2f9bc 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -939,7 +939,7 @@ void hipblaslt_gemm(const Tensor *inputA, bool grad, void* workspace, size_t workspaceSize, - float alpha, float beta, + const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count, hipStream_t stream, @@ -975,7 +975,7 @@ void hipblaslt_gemm(const Tensor *inputA, << " gelu=" << (outputPreGelu->data.dptr != nullptr) << " use_fp8=" << use_fp8 << " scale_mode=" << (a_tensor ? "tensor" : a_block ? "mxfp8" : "unsupported") - << " alpha=" << alpha << " beta=" << beta + << " alpha=" << *reinterpret_cast(alpha) << " beta=" << *reinterpret_cast(beta) << std::endl; } @@ -1193,10 +1193,10 @@ void hipblaslt_gemm(const Tensor *inputA, if (HIPBLAS_STATUS_SUCCESS == hipblaslt_ext::matmulIsAlgoSupported( handle, operationDesc, - static_cast(&alpha), + alpha, Adesc, Bdesc, - static_cast(&beta), + beta, Ddesc, Ddesc, algo_arr[0].algo, @@ -1273,12 +1273,12 @@ void hipblaslt_gemm(const Tensor *inputA, // Warm-up call NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ + alpha, /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ Bdesc, - static_cast(&beta), /* beta */ + beta, /* beta */ C, /* C */ Cdesc, D, /* D */ @@ -1295,12 +1295,12 @@ void hipblaslt_gemm(const Tensor *inputA, { NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ + alpha, /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ Bdesc, - static_cast(&beta), /* beta */ + beta, /* beta */ C, /* C */ Cdesc, D, /* D */ @@ -1356,12 +1356,12 @@ void hipblaslt_gemm(const Tensor *inputA, // D = alpha * (A * B) + beta * C NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ + alpha, /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ Bdesc, - static_cast(&beta), /* beta */ + beta, /* beta */ C, /* C */ Cdesc, D, /* D */ @@ -1501,7 +1501,7 @@ void release_service_stream(hipStream_t stream, struct ServiceStreamCtl &ctl) void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, - float alpha, float beta, bool use_split_accumulator, int math_sm_count, + const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count, [[maybe_unused]] int m_split, [[maybe_unused]] int n_split, [[maybe_unused]] bool gemm_producer, [[maybe_unused]] const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset) diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu new file mode 100644 index 000000000..9d4bec41d --- /dev/null +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -0,0 +1,876 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include + +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { +namespace { + +constexpr int kThreadsPerWarp = 32; +constexpr float k16x16HadamardScale = 0.25f; + +template +__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, + void* addr) { + auto smem_addr = static_cast(__cvta_generic_to_shared(addr)); + if constexpr (kTranspose) { + asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "r"(smem_addr)); + } else { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "r"(smem_addr)); + } +} + +template +__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, + void* addr, uint32_t stride) { + if constexpr (kTranspose) { + asm volatile( + "wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 " + "{%0,%1,%2,%3}, [%4], %5;\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "l"(addr), "r"(stride)); + } else { + asm volatile( + "wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 " + "{%0,%1,%2,%3}, [%4], %5;\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "l"(addr), "r"(stride)); + } +} + +template +__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, void* addr, + uint32_t stride) { + if constexpr (kTranspose) { + asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" + : + : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); + } else { + asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" + : + : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); + } +} + +__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) { + asm volatile( + "movmatrix.sync.aligned.m8n8.trans.b16 " + "%0, %1;\n\t" + : "=r"(a0) + : "r"(a0)); +} + +__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) { + __nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16); + float f_a = __bfloat162float(bf16x2.x); + float f_b = __bfloat162float(bf16x2.y); + asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b)); + float_dst = fabsf(float_dst); +} + +template +__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc( + uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1, + uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3, + uint32_t& amax_result) { + uint32_t zero = 0; + uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7; + asm volatile( + "wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n" + "{%0, %1, %2, %3, %4, %5, %6, %7}, \n" + "{%8, %9, %10, %11}, \n" + "{%12, %13, %14, %15}, \n" + "{%16, %17, %18, %19, %20, %21, %22, %23};\n\t" + : "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6), + "=r"(temp7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero), + "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6)); + if constexpr (kCalculateAmax) { + uint32_t max_even; + uint32_t max_odd; + // Reduction tree to amax(abs(result)) into bf16x2 reg outparam. + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2)); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3)); + // N.B. mma is only called up to once per thread for identity and transpose respectively, so + // we don't have to accumulate into amax_result and can directly store into it. + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(amax_result) + : "r"(max_even), "r"(max_odd)); + } +} + +template +__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i, + uint16_t random_sign_mask, + uint32_t* had_frag_t, + uint16_t random_sign_mask_t) { + int32_t tid = threadIdx.x % 32; // Local tid + float temp_i[2]; + float temp_t[2]; +#pragma unroll + for (int i = 0; i < 2; i++) { + // i is the vertical fragment index. + // For a 16x16 matrix matrix fragment, 4 threads fill a fragment of 8 BF16 vals. + uint32_t r = i * 8 + tid / 4; + +#pragma unroll + for (int j = 0; j < 2; j++) { +#pragma unroll + for (int k = 0; k < 2; k++) { + // k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits. + // j is the column fragment idx selecting between even and odd fragments. + // j increments 8 columns by switching fragments. + uint32_t c = j * 8 + k + tid % 4 * 2; + // 1 -> -1.0f, 0 -> 1.0f + int32_t base_sign = __popc(r & c); + if constexpr (kReturnIdentity) { + int32_t sign_i; + // Because tensor cores want the dot product dimension, + // contiguous, the regular, non-inverse hadamard swaps + // signs of columns and rows for inverse. In a simple reference, + // x.reshape(-1, 16) @ sign @ H16, this would be opposite but + // (sign @ H16) is transposed in this fragment. + if constexpr (kInverseHadamardIdentity) { + sign_i = ((random_sign_mask >> r) ^ base_sign); + } else { + sign_i = ((random_sign_mask >> c) ^ base_sign); + } + temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31)); + } + if constexpr (kReturnTransposed) { + int32_t sign_t; + if constexpr (kInverseHadamardTransposed) { + sign_t = ((random_sign_mask_t >> r) ^ base_sign); + } else { + sign_t = ((random_sign_mask_t >> c) ^ base_sign); + } + temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31)); + } + } + + if constexpr (kReturnIdentity) { + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" + : "=r"(had_frag_i[i * 2 + j]) + : "f"(temp_i[1]), "f"(temp_i[0])); + } + if constexpr (kReturnTransposed) { + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" + : "=r"(had_frag_t[i * 2 + j]) + : "f"(temp_t[1]), "f"(temp_t[0])); + } + } + } +} + +__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx, + uint32_t gmem_col_idx) { + uint32_t smem_row_idx = gmem_row_idx; + uint32_t xor_factor = (smem_row_idx * 2) % 8; + uint32_t smem_col_idx = gmem_col_idx ^ xor_factor; + return smem_row_idx * 8 + smem_col_idx; +} + +template +__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], + IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + uint32_t& local_amax_reg, + uint32_t& local_amax_t_reg) { + uint32_t a_frag[4]; // A matrix fragment + uint32_t c_frag[4]; // Result fragment + + int warp_id = threadIdx.x / kThreadsPerWarp; + int local_rank = (threadIdx.x % kThreadsPerWarp); + + int ld_row_idx = local_rank % kHadamardDimension; + int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + + uint32_t temp_amax_reg; + uint32_t temp_amax_t_reg; + + if (kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_reg) + : "r"(local_amax_reg), "r"(temp_amax_reg)); + } + + if (kReturnTransposedAmax) { + // TODO(Frank): This is not efficient, since we could directly load the + // matrix in transposed layout. + if (!kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + matrix_transpose_m8_n8_b16_inplace(a_frag[0]); + matrix_transpose_m8_n8_b16_inplace(a_frag[1]); + matrix_transpose_m8_n8_b16_inplace(a_frag[2]); + matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], + b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_t_reg) + : "r"(local_amax_t_reg), "r"(temp_amax_t_reg)); + } + + if (kReturnPreRhtAmax) { + if (!kReturnIdentityAmax && !kReturnTransposedAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[1])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[2]) + : "r"(a_frag[2]), "r"(a_frag[3])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[2])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_pre_rht_amax_reg) + : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg)); + } +} + +template +__device__ __host__ constexpr int NextPowerOf2() { + static_assert(kN > 0, "kN must be > 0"); + // Round up to the next power of 2 by counting leading zeros. + return 1 << (32 - __builtin_clz(kN - 1)); +} + +template +__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax, + const float transpose_amax, float* staging_for_pre_rht, + float* staging_for_identity, float* staging_for_transpose, + float* output_pre_rht_amax_ptr, + float* output_identity_amax_ptr, + float* output_transpose_amax_ptr, const int warpid) { + // intra-warp reduction + constexpr int kWarpSize = 32; + int local_rank = threadIdx.x % 32; + float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max(pre_rht_amax) : 0.0f; + float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max(identity_amax) : 0.0f; + float warp_transpose_amax = + kReturnTransposedAmax ? warp_reduce_max(transpose_amax) : 0.0f; + + // inter-warp reduction + if (threadIdx.x % 32 == 0) { + if (kReturnPreRhtAmax) { + staging_for_pre_rht[warpid] = warp_pre_rht_amax; + } + if (kReturnIdentityAmax) { + staging_for_identity[warpid] = warp_identity_amax; + } + if (kReturnTransposedAmax) { + staging_for_transpose[warpid] = warp_transpose_amax; + } + } + __syncthreads(); + constexpr int kNumWarpsPow2 = NextPowerOf2(); + if (warpid == 0) { + if (kReturnIdentityAmax) { + float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f; + identity_accum = warp_reduce_max(identity_accum); + if (local_rank == 0) { + atomicMaxFloat(output_identity_amax_ptr, identity_accum); + } + } + } + if (warpid == 1) { + if (kReturnTransposedAmax) { + float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f; + transpose_accum = warp_reduce_max(transpose_accum); + if (local_rank == 0) { + atomicMaxFloat(output_transpose_amax_ptr, transpose_accum); + } + } + } + if (warpid == 2) { + if (kReturnPreRhtAmax) { + float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f; + pre_rht_accum = warp_reduce_max(pre_rht_accum); + if (local_rank == 0) { + atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum); + } + } + } +} + +__launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_pre_rht_amax_ptr, + float* __restrict__ output_identity_amax_ptr, + float* __restrict__ output_transpose_amax_ptr) { + if (output_pre_rht_amax_ptr != nullptr) { + *output_pre_rht_amax_ptr = 0; + } + if (output_identity_amax_ptr != nullptr) { + *output_identity_amax_ptr = 0; + } + if (output_transpose_amax_ptr != nullptr) { + *output_transpose_amax_ptr = 0; + } +} + +template +__global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor_map_input, + float* __restrict__ output_pre_rht_amax_ptr, + float* __restrict__ output_identity_amax_ptr, + float* __restrict__ output_transpose_amax_ptr, + uint16_t random_sign_mask, uint16_t random_sign_mask_t, + uint64_t num_rows, uint64_t row_length) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0); + static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0); + + constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y; + constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X; + + constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp; + + const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X; + + extern __shared__ __align__(128) char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uint8_t* dshmem = reinterpret_cast((base_shmem_ptr + 127) & ~127ULL); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + IType* in_sh_0 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + IType* in_sh_1 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + + IType* in_shs[2] = {in_sh_0, in_sh_1}; + + constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + + const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0); + + // Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + uint64_t* mbar = reinterpret_cast(dshmem); + dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y); + + float* max_staging_identity = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_transpose = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_pre_rht = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + + initialize_barriers(mbar, + is_master_thread); + + copy_2d_to_shared(in_shs[0], reinterpret_cast(&tensor_map_input), + input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0], + is_master_thread); + + uint32_t had_frag_i[4]; + uint32_t had_frag_t[4]; + get_hadamard_matrix_fragment( + had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t); + + float local_pre_rht_amax = 0.0; + float local_amax = 0.0; + float local_amax_t = 0.0; + uint32_t local_pre_rht_amax_reg = *reinterpret_cast(&local_pre_rht_amax); + uint32_t local_amax_reg = *reinterpret_cast(&local_amax); + uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { + for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { + int stage = STAGES_X * stage_y + stage_x; + + const int next_stage = stage + 1; + const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1; + const int next_stage_y = stage_x + 1 == STAGES_X ? stage_y + 1 : stage_y; + + if (next_stage < STAGES_X * STAGES_Y) { + const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y; + const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X; + + copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong + reinterpret_cast(&tensor_map_input), input_global_offset_X, + input_global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + const size_t compute_stage_x_num = + BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)); + const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y); + + const size_t in_row_stride = BUFF_DIM_X; + + IType* in_sh_ptr = in_shs[stage % 2]; + +#pragma unroll + for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) { + const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y + + threadIdx.y * kHadamardDimension); + const int in_row_offset = row_idx_offset * in_row_stride; + +#pragma unroll + for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) { + ComputeKernel( + had_frag_i, had_frag_t, + in_sh_ptr + in_row_offset + + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), + local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + } + + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. + __syncthreads(); + } + } + } + + const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp; + + if constexpr (kReturnPreRhtAmax) { + unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax); + } + if constexpr (kReturnIdentityAmax) { + unpack_max_of_packed_bf16(local_amax_reg, local_amax); + } + if constexpr (kReturnTransposedAmax) { + unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t); + } + + ReduceMax( + local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity, + max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr, + output_transpose_amax_ptr, warpid); + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restrict__ output, + T* __restrict__ output_t, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, uint64_t num_input_rows, + uint64_t num_input_cols, float* __restrict__ amax, + float* __restrict__ amax_t, bool inverse_hadamard) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + static_assert(kHadamardDimension == 16, "Currently only hadamard dimension 16 is supported."); + + // The whole threadblock will share the same smem. + extern __shared__ __align__(16) T smem[]; + + // Each 32 threads process a 16x16 matrix. There is a (y, z) grid of 16x16. + // If y = 4, z = 4, then each threadblock is processing a 4x4 grid of 16x16 matrices. + int32_t tid = threadIdx.x; + int32_t warp_id = threadIdx.y * blockDim.z + threadIdx.z; + int32_t local_bx = threadIdx.y; + int32_t local_by = threadIdx.z; + + // Define the register fragments + uint32_t a_frag[4]; // A matrix fragment + uint32_t b_frag_i[4]; // Transposed Hadamard matrix fragment, used for A @ B(col major) + uint32_t b_frag_t[4]; // Hadamard matrix fragment, used for A.T @ B.T(col major) + uint32_t c_frag[4]; // Result fragment + + // row and col for each thread. 32 threads will work together in 128 chunk to + // load the data from global memory to shared memory. + uint32_t row = tid / (kHadamardDimension * sizeof(T) / sizeof(uint4)); + uint32_t col = tid % (kHadamardDimension * sizeof(T) / sizeof(uint4)); + + uint32_t smem_index = tid; + + uint32_t input_start_col = (blockIdx.x * blockDim.y + local_bx) * kHadamardDimension; + uint32_t input_start_row = (blockIdx.y * blockDim.z + local_by) * kHadamardDimension; + + bool load = (input_start_col < num_input_cols) && (input_start_row < num_input_rows); + if (!load) { + // Out of bound, we are returning early. No thread divergence since the whole warp + // will return early. + return; + } + + uint64_t global_offset = input_start_col + input_start_row * num_input_cols; + uint64_t global_offset_t = + kOutputTrueTransposed ? (input_start_row + input_start_col * num_input_rows) : global_offset; + + T* base_smem = smem + kHadamardDimension * kHadamardDimension * warp_id; + + uint32_t* smem_b32 = reinterpret_cast(base_smem); + uint4* smem_b128 = reinterpret_cast(base_smem); + + // Asynchronously load the data from global memory to shared memory. + const uint4* input_b128 = reinterpret_cast(input + global_offset); + // Each 16x16 chunk is divided into 4 8x8 matrices, we are trying to load each + // 8x8 chunks consecutively into the smem, so we could leverage ldmatrix m8n8x4 + // to load the data in the tensor core swizzled format. + __pipeline_memcpy_async(&smem_b128[smem_index], + &input_b128[row * num_input_cols / (sizeof(uint4) / sizeof(T)) + col], + sizeof(uint4)); + __pipeline_commit(); // Commit the memcpy. Wait when we are in the computation. + + if (inverse_hadamard) { + get_hadamard_matrix_fragment(b_frag_i, random_sign_mask, + b_frag_t, random_sign_mask_t); + } else { + get_hadamard_matrix_fragment( + b_frag_i, random_sign_mask, b_frag_t, random_sign_mask_t); + } + + float local_amax = 0.0; + float local_amax_t = 0.0; + uint32_t local_amax_reg = *reinterpret_cast(&local_amax); + uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + __pipeline_wait_prior(0); + + __syncwarp(); // ensure all lanes finished their cp.async before reading smem + + // Load the A to a_frag. + if constexpr (kComputeIdentity) { + load_matrix_16x16_from_shared(a_frag[0], a_frag[1], a_frag[2], a_frag[3], smem_b32, + kHadamardDimension); + + // 16x16 @ 16x16 leveraging all threads in the warp. + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], local_amax_reg); + + // Store the result to the shared memory in non-transposed order. + if constexpr (kReturnIdentity) { + uint4* output_b128 = reinterpret_cast(output + global_offset); + store_matrix_16x16_to_global(c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_b128, + num_input_cols); + } + } + + if constexpr (kComputeTransposed) { + if (kComputeIdentity) { + matrix_transpose_m8_n8_b16_inplace(a_frag[0]); + matrix_transpose_m8_n8_b16_inplace(a_frag[1]); + matrix_transpose_m8_n8_b16_inplace(a_frag[2]); + matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + } else { + load_matrix_16x16_from_shared(a_frag[0], + a_frag[2], // NOTE: intentional index swapping + a_frag[1], // NOTE: intentional index swapping + a_frag[3], smem_b32, kHadamardDimension); + } + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], + // 2,1 is used if we are using movmatrix instruction. + // Thus loading the matrix in 2,1 order will just be normal. + // This is to be compatible with the movmatrix instruction. + a_frag[2], // NOTE: intentional index swapping for transpose purpose. + a_frag[1], // NOTE: intentional index swapping for transpose purpose. + a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], b_frag_t[3], c_frag[0], c_frag[1], + c_frag[2], c_frag[3], local_amax_t_reg); + + // Store the result to the shared memory in non-transposed order. + if constexpr (kReturnTransposed) { + uint4* output_t_b128 = reinterpret_cast(output_t + global_offset_t); + store_matrix_16x16_to_global( + c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_t_b128, + kOutputTrueTransposed ? num_input_rows : num_input_cols); + } + } + + if constexpr (kUpdateIdentityAmax) { + unpack_max_of_packed_bf16(local_amax_reg, local_amax); + local_amax = warp_reduce_max(local_amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + local_amax = __shfl_sync(0xFFFFFFFF, local_amax, lane_zero); + // atomic CAS to output memory. + if (tid % kThreadsPerWarp == 0) { + atomicMaxFloat(amax, local_amax); + } + } + if constexpr (kUpdateTransposeAmax) { + unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t); + local_amax_t = warp_reduce_max(local_amax_t); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + local_amax_t = __shfl_sync(0xFFFFFFFF, local_amax_t, lane_zero); + // atomic CAS to output memory. + if (tid % kThreadsPerWarp == 0) { + atomicMaxFloat(amax_t, local_amax_t); + } + } +#else + NVTE_DEVICE_ERROR("Kernel is only supported on SM 9.0+."); +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 +} + +} // namespace + +void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform); + + // Check tensors + // NOTE (frsun): This is non-intuitive, we are writing the result of + // transposed RHT to the output of rowwise. + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + NVTE_CHECK(output_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Output tensor must be simple tensor, but scaling mode is ", + to_string(output_.scaling_mode), "."); + const SimpleTensor& input = input_.data; + SimpleTensor output; + SimpleTensor& output_t = output_.data; + + // Check requested outputs + const bool return_identity = output.dptr != nullptr; + const bool return_transposed = output_t.dptr != nullptr; + if (!return_identity && !return_transposed) { // Nothing to do/ill-defined behavior. + return; + } + + checkCuDriverContext(stream); + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + num_rows *= input.shape[i]; + } + + using IType = bf16; + + constexpr int kHadamardDimension = 16; + NVTE_CHECK(row_length % kHadamardDimension == 0, + "row_length must be divisible by hadamard_dimension."); + NVTE_CHECK(num_rows % kHadamardDimension == 0, + "num_rows must be divisible by hadamard_dimension"); + + constexpr uint64_t kThreadBlockX = 4; + // Configure 4 is used for Hopper, 8 is used for Blackwell for extra memory bandwidth. + constexpr uint64_t kThreadBlockY = 4; + + uint64_t kNumWarpsPerSM = kThreadBlockX * kThreadBlockY; + + // The shared memory number of bytes required for **the whole threadblock**. + size_t shmem_bytes = kHadamardDimension * kHadamardDimension * sizeof(IType) * kNumWarpsPerSM; + + dim3 block(kThreadsPerWarp, kThreadBlockX, kThreadBlockY); + + dim3 grid(DIVUP(row_length / kHadamardDimension, kThreadBlockX), + DIVUP(num_rows / kHadamardDimension, kThreadBlockY)); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transposed, kReturnTransposed, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity, kReturnIdentity, + + auto kernel = + HadamardTransformKernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes); + + kernel<<>>( + reinterpret_cast(input.dptr), reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), random_sign_mask, random_sign_mask_t, + num_rows, row_length, nullptr, nullptr, false););); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Kernel that will apply the 16x16 hadamard transform the input and input.T, and then +// get the absolute max value of the result. +void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform_amax); +#if CUDA_VERSION >= 12080 + + // Check input tensor + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor& input = input_.data; + + // Check amax tensors + SimpleTensor& output_pre_rht_amax = output_.amax; + SimpleTensor output_identity_amax; + SimpleTensor& output_transpose_amax = output_.columnwise_amax; + + // Check requested outputs + const bool return_pre_rht_amax = output_pre_rht_amax.dptr != nullptr; + const bool return_identity_amax = output_identity_amax.dptr != nullptr; + const bool return_transposed_amax = output_transpose_amax.dptr != nullptr; + if (!return_identity_amax && !return_transposed_amax && + !return_pre_rht_amax) { // Nothing to do/ill-defined behavior. + return; + } + + // Zero out amaxes if needed + ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast(output_pre_rht_amax.dptr), + reinterpret_cast(output_identity_amax.dptr), + reinterpret_cast(output_transpose_amax.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); + + checkCuDriverContext(stream); + + using IType = bf16; + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + num_rows *= input.shape[i]; + } + + constexpr int kHadamardDimension = 16; + NVTE_CHECK(row_length % kHadamardDimension == 0, + "row_length must be divisible by hadamard_dimension."); + NVTE_CHECK(num_rows % kHadamardDimension == 0, + "num_rows must be divisible by hadamard_dimension"); + + constexpr uint64_t kChunkBlockXSmall = 128; + constexpr uint64_t kChunkBlockYSmall = 128; + constexpr uint64_t kBuffDimX = 64; + constexpr uint64_t kBuffDimY = 64; + + alignas(64) CUtensorMap tensor_map_input{}; + + create_2D_tensor_map( + /*tensorMap=*/tensor_map_input, + /*tensor=*/input, + /*globalY=*/num_rows, + /*globalX=*/row_length, + /*shmemY=*/kBuffDimY, + /*shmemX=*/kBuffDimX, + /*stride_elems=*/row_length, + /*offset_elems=*/0, + /*type_num_bits=*/sizeof(IType) * 8, + /*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B); + + constexpr uint64_t kThreadBlockX = 4; + constexpr uint64_t kThreadBlockY = 1; + constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY; + + dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY); + + dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall)); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transposed_amax, kReturnTransposedAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity_amax, kReturnIdentityAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_pre_rht_amax, kReturnPreRhtAmax, + + // *2 for ping-pong + size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType); + size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) * + (kChunkBlockYSmall / kBuffDimY); + size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3; + // Add padding in case shmem ptr is not aligned to 128 bytes. + shmem_bytes = (shmem_bytes + 128); + + auto kernel = HadamardAmaxTmaKernel< + IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY, + kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax, + kReturnIdentityAmax, kReturnTransposedAmax>; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem_bytes); + + kernel<<>>( + tensor_map_input, reinterpret_cast(output_pre_rht_amax.dptr), + reinterpret_cast(output_identity_amax.dptr), + reinterpret_cast(output_transpose_amax.dptr), random_sign_mask, + random_sign_mask_t, num_rows, row_length);))); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ", + CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace transformer_engine + +void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform); + using namespace transformer_engine; + hadamard_transform(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + static_cast(random_sign_mask), + static_cast(random_sign_mask_t), stream); +} + +void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform_amax); + using namespace transformer_engine; + hadamard_transform_amax(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + static_cast(random_sign_mask), + static_cast(random_sign_mask_t), stream); +} diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu new file mode 100644 index 000000000..ce191b5ff --- /dev/null +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -0,0 +1,841 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_runtime.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "curanddx.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/numeric_conversion.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/helper_cuda.hpp" +#include "cutlass/util/print_error.hpp" + +// clang-format off + +namespace transformer_engine { +namespace detail { +namespace { + +// Define a cuRANDDx descriptor +// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10. +// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g., +// if shared memory, if needed, is enough for the described problem, usually not applicable. + +// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html +using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + curanddx::SM<800>() + curanddx::Thread()); + + +using namespace cute; +using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor + +// calculate the global encode scale factor for a given global amax. +__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { + constexpr float kFP8E4M3Max = 448.0f; + constexpr float kFP4E2M1Max = 6.0f; + // If scale is infinity, return max value of float32 + float global_encode_scale = cutlass::minimum_with_nan_propagation{}( + kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits::max()); + // If global amax is 0 or infinity, return 1 + return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale; +} + +template +struct SharedStorage { + static constexpr int AccumulatorPipelineStageCount = 16; + using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; + + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) cute::uint64_t tma_barrier[1]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + // cute::array_aligned> smem_A; + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + +}; + +CUTLASS_DEVICE +cutlass::Array +StochasticNumericConverterBase(cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + auto output_ptr = reinterpret_cast(&output); + asm volatile( \ + "{\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \ + "}" \ + : "=h"(output_ptr[0]), + "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), + "f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]), + "r"(rbits[0]), "r"(rbits[1])); +#else + NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return output; +} + +CUTLASS_DEVICE +cutlass::Array +StochasticNumericConverter(cutlass::Array const &input, cutlass::Array const *rbits) { + using result_type = cutlass::Array; + result_type output; + cutlass::Array *result_ptr = reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = reinterpret_cast const *>(rbits); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; i++) { + result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); + } + return output; +} + +template +__global__ static +void +rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, + TA const* A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, + TB const* B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, + TC * C, CStride dC, CSmemLayout , + TSFC * SFC, + TiledMMA mma, + float const* global_amax, + const size_t* rng_state) +{ + using namespace cute; + using X = Underscore; + // static constexpr bool kApplyStochasticRounding = true; + using ElementAccumulator = float; + static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static constexpr uint32_t kTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(ASmemLayout{})) * cute::sizeof_bits_v); + + static constexpr int kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v); + static constexpr int AccumulatorPipelineStageCount = 16; + + static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static constexpr int VectorSize = 16; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + // Represent the full tensors + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(16,16)); + Tensor mC = make_tensor(cute::subbyte_iterator(C), make_shape(M,N), dC); // (M,N) + + auto sfc_shape = make_shape( + M, + make_shape( make_shape(Int<16>{}, _4{}), N / 64 ) + ); + + auto sfc_stride = make_stride( + N / 16, + make_stride( make_stride(_0{}, _1{}), _4{} ) + ); + + auto sfc_layout = make_layout(sfc_shape, sfc_stride); + Tensor mSFC = make_tensor(make_gmem_ptr(SFC), sfc_layout); + + auto cluster_shape = Shape< _1, _1, _1>{}; + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + const int K_TILE_MAX = min(N, K) / 64; + uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile); + uint32_t tiles_in_n = (N + 64 - 1) / 64; + uint32_t linear_tile_idx = blockIdx.x; + uint32_t tile_idx_m = linear_tile_idx % tiles_in_m; + uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + + + auto mainloop_tiler = Shape<_128,_16,_64>{}; + auto epilogue_tiler = Shape<_128,_64,_64>{}; + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{}); + Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + Tensor gSFC_mn = local_tile(mSFC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + // Allocate SMEM + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + + // + // MMA: Define C accumulators and A/B partitioning + // + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + auto mma_epilogue = make_tiled_mma(SM100_MMA_F16BF16_SS{}, + Layout>{}); + ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster); + + + using TiledMmaEpilogue = decltype(mma_epilogue); + Tensor tCgA = thr_mma.partition_A(gA_mk); + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0,2>(ClusterTileShape{})); + auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0,2>(epilogue_tiler)); + + auto bulk_tmem_mma = TiledMMA::make_fragment_C(append(acc_shape_mma, + Int{})); + + auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(append(acc_shape_epilogue, + Int{})); + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = tma_partition(tma_load_a, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsA), group_modes<0,3>(tCgA)); + + auto [tBgB, tBsB] = tma_partition(tma_load_b, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(tCsB), group_modes<0,3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7); + + if (is_epilogue_warp && elect_one_sync()) { + cute::prefetch(raw_pointer_cast(global_amax)); + } + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + if (is_dma_warp) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0)); + } + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + bool is_first_wave = linear_tile_idx == blockIdx.x; + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_,tile_idx_m,_); + int k_tile = 0; + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + + + CUTE_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) { + int k_tile_idx_n = tile_idx_n + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_,k_tile_idx_n), tAsA(_,write_stage)); + } + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + CUTE_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ) + { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_,_,_,read_stage); + auto tCrB_nk = tCrB(_,_,0,0); + CUTE_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) + { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTE_UNROLL + for (int i = 0; i < 4; i++) { + auto accumulators = bulk_tmem_mma(_,_,_,accumulator_pipe_producer_state.index() * 4 + i); + gemm(mma, tCrA_mk(_,_,k_block * 4 + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } else if (is_epilogue_warp) { + const float global_amax_val = *global_amax; + static constexpr int FragmentSize = 256 / sizeof_bits_v; + + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int thread_idx = threadIdx.x % 128; + + Tensor tCgC = thr_mma_epilogue.partition_C(gC_mn); // (MMA,MMA_M,MMA_N) // (MMA,MMA_M,MMA_N) + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{})); + auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(thread_idx); + auto thr_r2g = tiled_r2g.get_slice(thread_idx); + + // NVFP4 non-E8 recipe constants and global scales + static constexpr float fp4_max = 6.0f; + + const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + const float global_decode_scale = 1.0f / global_encode_scale; + auto sfd_converter = cutlass::NumericConverter{}; + + do { + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) { + Tensor tCgC_mn = tCgC(_,_,_,tile_idx_m,tile_idx_n+k_tile); + + Tensor tCgSFC_mn = gSFC_mn(_,_,tile_idx_m,tile_idx_n+k_tile); + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto tCtC = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index()); + Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = make_tensor(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrC = make_tensor(shape(tDgC)); + Tensor tTR_rAcc_frag = recast>(coalesce(tTR_rAcc)); + Tensor tDrC_frag = recast>(coalesce(tDrC)); + + Tensor src = thr_r2g.retile_S(tDrC); + Tensor dst = thr_r2g.retile_D(tDgC); + + Tensor tCgSFC = make_tensor(tCgSFC_mn.data(), make_layout( + make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{}) + )); + + Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC)); + Tensor tDrSFC = make_tensor(shape(tDgSFC)); + + static constexpr int NumVecs = size(tDgC) / VectorSize; + Tensor tC_rRowSFD_frg = recast>(tDrSFC); + + cutlass::maximum_absolute_value_reduction, true> amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // TMEM_LOAD + copy(tiled_t2r, tDtC, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + + ++accumulator_pipe_consumer_state; + + // Cast data from FP32 to BF16 to FP32. + auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + + auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); + + tC_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); + auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + + // Initialize RNG for tile + const size_t rng_sequence + = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = uint4{0, 0, 0, 0}; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scales[v], cutlass::platform::numeric_limits::max()); + // auto acc_scale = acc_scales[v]; + if constexpr (kEnableStochasticRounding) { + random_uint4 = dist.generate4(rng); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], + acc_scale + ), + reinterpret_cast*>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}(cutlass::multiplies>{}(compute_frgs[v], acc_scale)); + } + } + + copy(tiled_r2g, src, dst); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC); + + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + } +} + +// this function computes RHT-GEMM for +// A: m x n: col-major +// B: 16 x 16: row-major +// C: m x n: row-major +// SFC: m x (n/16): row-major +template +void +rht_gemm_ntt_w_sfc(int m, int n, + TA const* A, + TB const* B, + TC * C, + TSFC * SFC, + float const* global_amax, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 2048) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = static_cast(m); + auto N = static_cast(n); + + // Define strides (mixed) + auto dA = make_stride(Int<1>{}, m); // (dM,dK) + auto dB = make_stride(Int<1>{}, 16); // (dN,dK) + auto dC = make_stride(n, Int<1>{}); // (dM,dN) + + auto cga_shape = Shape< _1, _1, _1>{}; + auto cga_tile_shape = Shape<_128,_16,_16>{}; + auto cluster_tile_mainloop = Shape<_128,_16,_64>{}; + + // Construct the MMA + auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS{}, + Layout>{}); + + // MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never} + + // Assert that the TiledMMA uses all CTAs in the CGA. + CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div(shape<0>(cga_tile_shape), shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div(shape<1>(cga_tile_shape), shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cga_tile_shape)); + + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>()); + + auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop), shape_div(shape<0>(cluster_tile_mainloop), size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes + constexpr int kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB); + constexpr int kReservedBytes = 256; // Reserve for barriers and other uses + constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE) + auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE) + auto sC = Layout<_1>{}; // XXX Dummy + + // Create GMEM tensors + Tensor tensorA = make_tensor(A, make_layout(make_shape(M,N), dA)); // (M,N) + Tensor tensorB = make_tensor(B, make_layout(make_shape(16,16), dB)); // (16,16) + + // Create the TiledCopy + + auto tma_load_a = make_tma_copy_A_sm100( + SM90_TMA_LOAD{}, + tensorA, + sA(_,_,_,0), + cluster_tile_mainloop, + mma); + auto tma_load_b = make_tma_copy_B_sm100( + SM90_TMA_LOAD{}, + tensorB, + sB(_,_,_,0), + cga_tile_shape, + mma); + + // Assert checks on tile sizes -- no predication + NVTE_CHECK(M % size<0>(cga_tile_shape) == 0, + "Inner dimension must be divisible by ", static_cast(size<0>(cga_tile_shape)), " but got ", M, "."); + NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0, + "Outer dimension must be divisible by ", 4 * static_cast(size<1>(cga_tile_shape)), + " but got ", N, "."); + + uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size)); + + tiles = (tiles < sm_count) ? tiles : sm_count; + + dim3 dimBlock(256); + dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape)); + dim3 dimGrid(tiles, 1, 1); + + int smem_size = sizeof(SharedStorage); + auto* kernel_ptr = &rht_gemm_device< + decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape), + TA, decltype(dA), decltype(sA), decltype(tma_load_a), + TB, decltype(dB), decltype(sB), decltype(tma_load_b), + TC, decltype(dC), decltype(sC), + TSFC, + decltype(mma), + kEnableStochasticRounding>; + + bool status = cudaFuncSetAttribute(*kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (status != cudaSuccess) { + std::cerr << "Error: Failed to set Shared Memory size." << std::endl; + return; + } + (*kernel_ptr) + <<< dimGrid, dimBlock, smem_size, stream >>> + (M, N, k_tile_size, cga_tile_shape, + A, dA, sA, tma_load_a, + B, dB, sB, tma_load_b, + C, dC, sC, + SFC, + mma, global_amax, + rng_state); +} + +// this function is used to wrap the rht_gemm_ntt_w_sfc function +//to transpose the input tensor A +template +void +rht_gemm_ttt_wrapper(int m, int n, + TA const* A, + TB const* B, + TC * C, + TSFC * SFC, + float const* global_amax, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 1024) +{ + // in addition to transpose the input tensor A + // we also need to reshape m, n to at best + // ultilize as many SMs as possible while keeping + // a relatively large contiguous dimension. + // for example, after swapping m, n for transpose purposes, + // the input / output tensor shapes for RHT-GEMM are: + // A: n x m: col-major + // B: 16 x 16: row-major + // C: n x m: row-major + // SFC: n x (m/16): row-major + rht_gemm_ntt_w_sfc( + n, m, + A, B, C, + SFC, global_amax, + rng_state, + sm_count, stream, + k_tile_size); +} + +} // namespace +} // namespace detail + +// clang-format on + +void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_, + const Tensor &hadamard_matrix_, + QuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform_cast_fusion_columnwise); + + // Check input and output tensors + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor &input = input_.data; + SimpleTensor &global_amax = output_.amax; + SimpleTensor &output_t = output_.data; + SimpleTensor &scale_inv_t = output_.scale_inv; + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (quant_config.rng_state != nullptr) { + Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TC = cutlass::float_e2m1_t; + using TSFC = cutlass::float_ue4m3_t; + + checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Hadamard matrix must be BF16 tensor, but scaling mode is ", + to_string(hadamard_matrix_.scaling_mode), "."); + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t ndim = input.shape.size(); + const size_t n = input.shape[ndim - 1]; + size_t m = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + m *= input.shape[i]; + } + + auto sm_count = transformer_engine::cuda::sm_count(); + + NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); + + NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension"); + + int k_tile_size = 1024; + + if (m == 8192 && n == 5120) { + k_tile_size = 512; + } else if (m == 8192 && n == 10240) { + k_tile_size = 1024; + } else if (m == 8192 && n == 2560) { + k_tile_size = 1280; + } else if (m == 8192 && n == 11328) { + k_tile_size = 1024; + } else if (m == 8192 && n == 512) { + k_tile_size = 256; + } else if (m == 8192 && n == 3584) { + k_tile_size = 512; + } else if (m == 11328 && n == 8192) { + k_tile_size = 1024; + } else if (m == 5120 && n == 8192) { + k_tile_size = 512; + } else if (m == 10240 && n == 8192) { + k_tile_size = 1024; + } else if (m == 2560 && n == 8192) { + k_tile_size = 1280; + } else if (m == 512 && n == 8192) { + k_tile_size = 256; + } else if (m == 3584 && n == 8192) { + k_tile_size = 512; + } else if (m < 1024 || n < 1024) { + k_tile_size = 512; + } + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kUseStochasticRounding, + detail::rht_gemm_ttt_wrapper( + /*m=*/m, + /*n=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*C=*/reinterpret_cast(output_t.dptr), + /*SFC=*/reinterpret_cast(scale_inv_t.dptr), + /*global_amax=*/reinterpret_cast(global_amax.dptr), + /*rng_state=*/rng_state, + /*sm_count=*/sm_count, + /*stream=*/stream, + /*k_tile_size=*/k_tile_size);); +} + +} // namespace transformer_engine + +void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform_cast_fusion_columnwise); + using namespace transformer_engine; + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + hadamard_transform_cast_fusion_columnwise( + *convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream); +} diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 4d65e26ce..cffc411a0 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -67,6 +67,11 @@ class CommOverlapCore { std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; + private: + void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm); + public: CommOverlapCore() {} // dummy constructor for exposing type to Python @@ -78,17 +83,26 @@ class CommOverlapCore { virtual ~CommOverlapCore(); + void *get_ubuf_dptr() { return _ubuf.dptr(); } + void set_ubuf_scale_inv(float *scale_inv) { _ubuf_scale_inv = scale_inv; _ubuf_scale_inv_initialized = true; } + virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) { + NVTE_ERROR("Operation is not implemented."); + } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, const std::vector &shape); TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, const std::vector &shape); + int get_tp_size() { return _tp_size; } + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } @@ -148,6 +162,10 @@ class CommOverlapBase : public CommOverlapCore { cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm); + public: CommOverlapBase() {} // dummy constructor for exposing type to Python @@ -224,6 +242,10 @@ class CommOverlapP2PBase : public CommOverlapCore { cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate); + public: CommOverlapP2PBase() {} // dummy constructor for exposing type to Python @@ -237,6 +259,9 @@ class CommOverlapP2PBase : public CommOverlapCore { virtual ~CommOverlapP2PBase(); + void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) override; + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 851032e04..8556596df 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -126,6 +126,24 @@ enum NVTE_Mask_Type { NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5, }; +/*! \enum NVTE_Softmax_Type + * \brief Attention softmax types as described in + * Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/pdf/2309.17453v3). + * For a given attention score S = Q*K^T, different softmax types perform different operations on S, + * NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + * NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + * NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + * where alpha is a learnable parameter in shape [H]. + */ +enum NVTE_Softmax_Type { + /*! Vanilla softmax */ + NVTE_VANILLA_SOFTMAX = 0, + /*! Off-by-one softmax */ + NVTE_OFF_BY_ONE_SOFTMAX = 1, + /*! Learnable softmax */ + NVTE_LEARNABLE_SOFTMAX = 2, +}; + /*! \enum NVTE_Fused_Attn_Backend * \brief Fused attention backends */ @@ -191,6 +209,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] qkv_layout The layout of Tensors Q, K, V. * \param[in] bias_type The attention bias type. * \param[in] attn_mask_type The attention mask type. + * \param[in] softmax_type The attention softmax type. * \param[in] dropout The dropout probability. * \param[in] num_attn_heads The number of heads in Q. * \param[in] num_gqa_groups The number of heads in K, V. @@ -203,9 +222,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right); /*! \brief Compute dot product attention with packed QKV input. * @@ -243,6 +263,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * * \param[in] QKV The QKV tensor in packed format, H3D or 3HD. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -258,19 +279,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - const NVTETensor rng_state, size_t max_seqlen, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * @@ -309,6 +330,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * e.g. M, ZInv, rng_state. * \param[out] dQKV The gradient of the QKV tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1]. * \param[in] max_seqlen Max sequence length used for computing, @@ -318,6 +340,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. @@ -327,10 +350,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream); @@ -371,6 +395,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] Q The Q tensor, in HD layouts. * \param[in] KV The KV tensor, in 2HD or H2D layouts. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -392,6 +417,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. @@ -399,13 +425,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] stream CUDA stream used for this operation. */ void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -446,6 +474,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[out] dQ The gradient of the Q tensor. * \param[out] dKV The gradient of the KV tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. @@ -459,6 +488,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. @@ -468,12 +498,12 @@ void nvte_fused_attn_fwd_kvpacked( void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream); + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. * @@ -518,6 +548,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] K The K tensor. * \param[in] V The V tensor. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -539,22 +570,24 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] qkv_layout QKV tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream); + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -602,6 +635,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[out] dK The gradient of the K tensor. * \param[out] dV The gradient of the V tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. @@ -615,6 +649,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] qkv_layout QKV tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. @@ -624,14 +659,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream); + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. * diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 58c0a1f96..dd312726a 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -17,9 +17,76 @@ #ifdef __cplusplus extern "C" { -#endif +#endif // __cplusplus -/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations. +/*! \brief Configuration for matrix multiplication. */ +typedef void *NVTEMatmulConfig; + +/*! \enum NVTEMatmulConfigAttribute + * \brief Type of option for matrix multiplication. + */ +enum NVTEMatmulConfigAttribute { + /*! Bias tensor + * + * If provided, the bias tensor is applied in the GEMM epilogue. + */ + kNVTEMatmulConfigBiasTensor = 0, + /*! Bias gradient tensor + * + * If provided, the bias gradient tensor will be filled in the GEMM epilogue. + */ + kNVTEMatmulConfigDBiasTensor = 1, + /*! Whether to compute GELU in GEMM epilogue. */ + kNVTEMatmulConfigWithGELUEpilogue = 2, + /*! Whether to compute GELU backward in GEMM epilogue. */ + kNVTEMatmulConfigWithDGELUEpilogue = 3, + /*! Auxilliary tensor for GEMM epilogue. + * + * For GELU, this will be filled with the GELU input. For GELU + * backward, this is expected to already be filled with the GELU + * input. + */ + kNVTEMatmulConfigEpilogueAuxTensor = 4, + /*! Whether to use split accumulator for FP8 GEMM. */ + kNVTEMatmulConfigUseSplitAccumulator = 5, + /*! Number of streaming multiprocessors to use in GEMM kernel. */ + kNVTEMatmulConfigSMCount = 6, + kNVTEMatmulConfigNumAttributes +}; + +/*! \brief Create a matrix multiplication configuration. */ +NVTEMatmulConfig nvte_create_matmul_config(); + +/*! \brief Query an option in matrix multiplication configuration. + * + * \param[in] config Matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. Ignored if + * NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. + */ +void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + void *buf, size_t size_in_bytes, size_t *size_written); + +/*! \brief Set an option in matrix multiplication configuration. + * + * \param[in] config Matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. + */ +void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes); + +/*! \brief Destroy a matrix multiplication configuration. */ +void nvte_destroy_matmul_config(NVTEMatmulConfig config); + +/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated). + * + * This has been deprecated in favor of nvte_cublas_gemm_v2. * * Computes: * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors @@ -46,8 +113,31 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); +/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations. + * + * Computes: + * - `D = alpha * op(A) * op(B) + beta * C` + * + * \param[in] transa Whether to transpose A matrix. + * \param[in] transb Whether to transpose B matrix. + * \param[in] alpha Scaling factor applied to matmul output. + * \param[in] A A matrix. + * \param[in] B B matrix. + * \param[in] beta Scaling factor applied to C matrix. + * \param[in] C C matrix. + * \param[out] D Output matrix. + * \param[in] workspace Workspace tensor. + * \param[in] config Additional configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A, + const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D, + NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream); + /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations, - * allowing for using a scaling factor for the GEMM result and the accumulation input + * allowing for using a scaling factor for the GEMM result and the accumulation input (deprecated) + * + * This has been deprecated in favor of nvte_cublas_gemm_v2. * * Computes: * - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors @@ -135,14 +225,16 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream to wait on. */ -void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, - const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms, - bool transa, bool transb, bool grad, NVTETensor* workspace, +void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); #ifdef __cplusplus } // extern "C" -#endif +#endif // __cplusplus + +#ifdef __cplusplus /*! \namespace transformer_engine */ @@ -157,6 +249,89 @@ namespace transformer_engine { void nvte_cublas_handle_init(); #endif +/*! \struct MatmulConfigWrapper + * \brief C++ wrapper for NVTEMatmulConfig. + */ +class MatmulConfigWrapper { + public: + MatmulConfigWrapper() : config_{nvte_create_matmul_config()} {} + + MatmulConfigWrapper(const MatmulConfigWrapper &) = delete; + MatmulConfigWrapper &operator=(const MatmulConfigWrapper &) = delete; + + MatmulConfigWrapper(MatmulConfigWrapper &&other) : config_{other.config_} { + other.config_ = nullptr; + } + MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other) { + if (config_ != nullptr) { + nvte_destroy_matmul_config(config_); + } + config_ = other.config_; + other.config_ = nullptr; + return *this; + } + + ~MatmulConfigWrapper() { + if (config_ != nullptr) { + nvte_destroy_matmul_config(config_); + config_ = nullptr; + } + } + + /*! \brief Get the underlying NVTEMatmulConfig. + * + * \return NVTEMatmulConfig held by this MatmulConfigWrapper. + */ + operator NVTEMatmulConfig() const noexcept { return config_; } + + /*! \brief Set bias tensor. */ + void set_bias_tensor(NVTETensor bias_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigBiasTensor, &bias_tensor, + sizeof(NVTETensor)); + } + + /*! \brief Set bias gradient tensor. */ + void set_dbias_tensor(NVTETensor dbias_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigDBiasTensor, &dbias_tensor, + sizeof(NVTETensor)); + } + + /*! \brief Set whether to compute GELU in GEMM epilogue. */ + void set_with_gelu_epilogue(bool with_gelu_epilogue) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue, + &with_gelu_epilogue, sizeof(bool)); + } + + /*! \brief Set whether to compute GELU backward in GEMM epilogue. */ + void set_with_dgelu_epilogue(bool with_dgelu_epilogue) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue, + &with_dgelu_epilogue, sizeof(bool)); + } + + /*! \brief Set auxilliary tensor for GEMM epilogue. */ + void set_epilogue_aux_tensor(NVTETensor epilogue_aux_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigEpilogueAuxTensor, + &epilogue_aux_tensor, sizeof(NVTETensor)); + } + + /*! \brief Set whether to use split accumulator for FP8 GEMM. */ + void set_use_split_accumulator(bool use_split_accumulator) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator, + &use_split_accumulator, sizeof(bool)); + } + + /*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */ + void set_sm_count(int sm_count) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &sm_count, sizeof(int)); + } + + private: + /*! \brief Wrapped NVTEMatmulConfig. */ + NVTEMatmulConfig config_ = nullptr; +}; + } // namespace transformer_engine +#endif // __cplusplus + #endif // TRANSFORMER_ENGINE_GEMM_H_ diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h new file mode 100644 index 000000000..a0dd325da --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -0,0 +1,68 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file hadamard_transform.h + * \brief Functions for Hadamard transforms. + */ + +#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ +#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Perform a randomized Hadamard transform on the input tensor. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] random_sign_mask 16-bit sign mask. + * \param[in] random_sign_mask_t 16-bit sign mask. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream); + +/*! \brief Perform the absolute maximum reduction on the input tensor with/without + * randomized hadamard transform. The rowwise result is the absolute maximum + * of the input tensor. The columnwise result is the absolute maximum of the + * input tensor transposed and applied randomized hadamard transformation. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] random_sign_mask 16-bit sign mask. + * \param[in] random_sign_mask_t 16-bit sign mask. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream); + +/*! \brief Perform the columnwise hadamard transform cast fusion. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] hadamard_matrix Hadamard matrix. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 89515108a..a5867276f 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -151,6 +151,10 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out, size_t start_offset, size_t block_len, const NVTEDType out_dtype, cudaStream_t stream); +void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, + const NVTETensor inpB, const bool use_rowwise_amax_B, + float alpha_in, NVTETensor alpha_out, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 70f90fa76..044e021e6 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -74,6 +74,7 @@ enum NVTETensorParam { kNVTEAmax = 3, /*!< Amax tensor */ kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */ kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ + kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */ kNVTENumTensorParams }; @@ -96,10 +97,9 @@ enum NVTEScalingMode { */ NVTE_BLOCK_SCALING_1D = 2, NVTE_BLOCK_SCALING_2D = 3, - /*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD), - and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD). - */ - NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4, + /*! Single scale per block of 16 elements consecutive in either + * rowwise or columnwise direction */ + NVTE_NVFP4_1D_SCALING = 4, NVTE_INVALID_SCALING = 100 }; @@ -338,6 +338,12 @@ enum NVTEQuantizationConfigAttribute { * likely be refactored away in the future. */ kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3, + /*! RNG state (NVTETensor with 2 elements - seed and offset */ + kNVTEQuantizationConfigRNGState = 4, + /*! Whether to use 2D block scaling for NVFP4 */ + kNVTEQuantizationConfigNVFP42DQuantization = 5, + /*! Whether to enable stochastic rounding */ + kNVTEQuantizationConfigStochasticRounding = 6, kNVTEQuantizationConfigNumAttributes }; @@ -449,6 +455,15 @@ inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; } inline bool is_fp4_dtype(const DType t) { return false; } #endif // #ifndef __HIP_PLATFORM_AMD__ +/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16) + * + * Return true if TE datatype is high precision + * \param[in] DType TE Datatype of interest + */ +inline bool is_high_precision_dtype(const DType t) { + return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16; +} + /*! \struct TensorWrapper * \brief C++ wrapper for the NVTETensor class. */ @@ -584,6 +599,11 @@ class TensorWrapper { return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape); } + template + TensorWrapper &set_columnwise_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseAmax, dptr, type, shape); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { @@ -608,6 +628,10 @@ class TensorWrapper { return get_parameter(kNVTEColumnwiseScaleInv); } + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEColumnwiseAmax); + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -856,6 +880,24 @@ class QuantizationConfigWrapper { &format, sizeof(Float8BlockScaleTensorFormat)); } + /*! \brief Set stochastic rounding state */ + void set_rng_state(NVTETensor rng_state) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigRNGState, &rng_state, + sizeof(NVTETensor)); + } + + /*! \brief Set whether to use 2D block scaling for NVFP4 */ + void set_nvfp4_2d_quantization(bool nvfp4_2d_quantization) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP42DQuantization, + &nvfp4_2d_quantization, sizeof(bool)); + } + + /*! \brief Set whether to use stochastic rounding */ + void set_stochastic_rounding(bool stochastic_rounding) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding, + &stochastic_rounding, sizeof(bool)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index 8c6fccfb5..1a1d3b287 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -30,7 +30,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_mxfp_scaling(z->scaling_mode)) { + !is_mxfp8_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 6c85cc432..b253e6cac 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -26,7 +26,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_mxfp_scaling(z->scaling_mode)) { + !is_mxfp8_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 466c2e605..8322238c3 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -6,7 +6,6 @@ """This module provides predefined FP8 recipes.""" from __future__ import annotations -import warnings import os from enum import Enum from typing import Optional, Union, Callable, NamedTuple @@ -40,9 +39,12 @@ class _FormatMaxVals(Enum): class Format(Enum): """ Supported FP8 formats. + Supported FP4 formats. Values ------ + E2M1 : + All FP4 tensors are in e2m1 format E4M3 : All FP8 tensors are in e4m3 format E5M2 : @@ -51,6 +53,7 @@ class Format(Enum): FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format """ + E2M1 = _FormatHelper(fwd=(6, 6), bwd=(6, 6)) E4M3 = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E4M3.value) E5M2 = _FormatHelper(fwd=_FormatMaxVals.E5M2.value, bwd=_FormatMaxVals.E5M2.value) HYBRID = _FormatHelper(fwd=E4M3.fwd, bwd=E5M2.bwd) @@ -58,9 +61,13 @@ class Format(Enum): @dataclass(frozen=True) class MMParams: - """for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator) - apply split accumulator or not, turning it on will increase accuracy but impact gemm performance, - so only turn it on for certain gemms + """Matrix multiplication options. + + Parameters + ---------- + use_split_accumulator : bool, default = `True` + Use FP8 fast accumulation on Hopper or Ada. For more details, + see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul. """ use_split_accumulator: bool = True @@ -71,10 +78,24 @@ class QParams: """Quantization parameters. power_2_scale: use power of 2 scale parameter amax_epsilon: optional minimum value of abs max + random_hadamard_transform: whether to use random hadamard transform + stochastic_rounding: whether to use stocastic rounding """ power_2_scale: bool = False amax_epsilon: float = 0.0 + random_hadamard_transform: bool = False + stochastic_rounding: bool = False + fp4_2d_quantization: bool = False + + def __repr__(self) -> str: + return ( + f"Qparams(\npower_2_scale={self.power_2_scale},\n" + f"amax_epsilon={self.amax_epsilon},\n" + f"random_hadamard_transform={self.random_hadamard_transform},\n" + f"stochastic_rounding={self.stochastic_rounding},\n" + f"fp4_2d_quantization={self.fp4_2d_quantization}\n)" + ) class Recipe: @@ -82,6 +103,10 @@ class Recipe: Base recipe class. """ + def nvfp4(self): + """Whether the given recipe is NVFP4 1D block scaling.""" + return isinstance(self, NVFP4BlockScaling) + def mxfp8(self): """Whether the given recipe is MXFP8 block scaling.""" return isinstance(self, MXFP8BlockScaling) @@ -200,6 +225,7 @@ def __repr__(self) -> str: f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " + f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}" ) @@ -217,10 +243,11 @@ class Float8CurrentScaling(Recipe): pass. """ + use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" fp8_format: Format = Format.HYBRID - fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0) - fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0) - fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_fwd_inp = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False) fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) @@ -229,9 +256,6 @@ class Float8CurrentScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert ( - not self.fp8_dpa and not self.fp8_mha - ), "FP8 attention is not supported for Float8CurrentScaling." def __repr__(self) -> str: return ( @@ -367,3 +391,84 @@ def __repr__(self) -> str: f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}" ) + + +@dataclass() +class NVFP4BlockScaling(Recipe): + """ + Use the NVFP4 scaling strategy. + + This is a 2-level block scaling strategy. In level 1, each group of + 16 consecutive values is scaled together using their own scaling + factor. The type of the scaling factor is E4M3 (4 bits of exponent, + 3 bits of mantissa). In level 2, a global per tensor FP32 scaling + factor is used to scale the entire tensor. + + Since the scaling happens in a particular direction (either rowwise + or columnwise), in this recipe the quantized tensor and its transpose + are not numerically equivalent. Due to this, when Transformer Engine + needs both the tensor and its transpose (e.g. to calculate both + forward and backward pass), during the quantization both versions are + computed from the high precision input to avoid double quantization + errors. + + Parameters + ---------- + fp4_format : {Format.E2M1}, default = Format.E2M1 + FP4 data type. + fp8_format : {Format.E4M3}, default = Format.E4M3 + FP8 data type. Only E4M3 is supported. + fp8_dpa: bool, default = `False` + FP8 dot product attention. Not yet supported. + fp8_mha: bool, default = `False` + FP8 multi-head attention. Not yet supported. + """ + + # Configuration envvars + disable_rht: bool = os.getenv("NVTE_NVFP4_DISABLE_RHT", "0") == "1" + disable_stochastic_rounding: bool = ( + os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1" + ) + disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" + + fp4_format: Format = Format.E2M1 + fp8_format: Format = Format.E4M3 + + # Not applying quantization to attention for now + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" + assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" + + # Quantization params + # Note: RHT is currently only applied to column-wise usage so that + # it can be used for wgrad GEMM. + self.fp4_quant_fwd_inp = QParams( + random_hadamard_transform=not self.disable_rht, + stochastic_rounding=False, + fp4_2d_quantization=False, + ) + self.fp4_quant_fwd_weight = QParams( + random_hadamard_transform=False, + stochastic_rounding=False, + fp4_2d_quantization=not self.disable_2d_quantization, + ) + self.fp4_quant_bwd_grad = QParams( + random_hadamard_transform=not self.disable_rht, + stochastic_rounding=not self.disable_stochastic_rounding, + fp4_2d_quantization=False, + ) + + def __repr__(self) -> str: + return ( + f"recipe_type={self.__class__.__name__}, " + f"fp4_format={str(self.fp4_format).split('.')[1]}, " + f"fp8_format={str(self.fp8_format).split('.')[1]}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}, " + f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " + f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " + f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " + ) diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index f3c6b7952..3f472ce81 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -28,6 +28,13 @@ using bf16__ = __hip_bfloat16; constexpr int amax_kernel_threads = 512; +__launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + *amax_ptr = 0; +} + #ifdef __HIP_PLATFORM_AMD__ template @@ -118,7 +125,8 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, #endif const float *noop_ptr, cudaStream_t stream) { // Zero out amax so we can update with atomic max - NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); + zero_amax_kernel<<<1, 1, 0, stream>>>(amax, noop_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); // Return immediately if tensor is empty if (N == 0) { @@ -220,15 +228,17 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt // Check output tensor NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); auto &output = *convertNVTETensorCheck(output_); - NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, - "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, " + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING || + output.scaling_mode == NVTE_NVFP4_1D_SCALING, + "Output tensor for amax computation must be FP8 tensor with per-tensor scaling or " + "NVFP4 1D scaling, " "but got scaling_mode=", to_string(output.scaling_mode)); NVTE_CHECK(output.amax.numel() == 1, "Output tensor for amax computation has invalid amax tensor " "(expected 1 entry, got shape=", output.amax.shape, ")"); - NVTE_CHECK(output.amax.dptr != nullptr, + NVTE_CHECK(output.amax.dptr != nullptr || output.columnwise_amax.dptr != nullptr, "Output tensor for amax computation has amax tensor without data"); NVTE_CHECK(output.amax.dtype == DType::kFloat32, "Output tensor for amax computation has invalid amax tensor " @@ -264,10 +274,12 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt } // Compute amax + float *amax_ptr = reinterpret_cast( + (output.amax.dptr != nullptr) ? output.amax.dptr : output.columnwise_amax.dptr); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); launch_amax_kernel(reinterpret_cast(input.data.dptr), - reinterpret_cast(output.amax.dptr), input.data.numel(), + amax_ptr, input.data.numel(), #ifdef __HIP_PLATFORM_AMD__ block_amax, block_capacity, #endif diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu new file mode 100644 index 000000000..5ebc7ba4f --- /dev/null +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -0,0 +1,54 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include + +#include "../common.h" +#include "../utils.cuh" + +namespace transformer_engine { +namespace nvfp4_recipe { + +// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; +constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); + +// Kernel to compute alpha *= amax_A * amax_B / factor +__global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, + const float *amax_B, float *alpha_out) { + // factor is defined in the enclosing namespace + *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; +} + +} // namespace nvfp4_recipe +} // namespace transformer_engine + +void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, + const NVTETensor inpB, const bool use_rowwise_amax_B, + float alpha_in, NVTETensor alpha_out, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_compute_per_tensor_scale); + using namespace transformer_engine; + + auto *tA = convertNVTETensor(inpA); + auto *tB = convertNVTETensor(inpB); + auto *tOut = convertNVTETensor(alpha_out); + + void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr; + void *amax_B_ptr = use_rowwise_amax_B ? tB->amax.dptr : tB->columnwise_amax.dptr; + void *alpha_ptr = tOut->data.dptr; + + // check for not null pointers + NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null"); + NVTE_CHECK(amax_B_ptr != nullptr, "amax_B_ptr is null"); + NVTE_CHECK(alpha_ptr != nullptr, "alpha_ptr is null"); + + nvfp4_recipe::compute_nvfp4_per_tensor_scale_kernel<<<1, 1, 0, stream>>>( + alpha_in, reinterpret_cast(amax_A_ptr), + reinterpret_cast(amax_B_ptr), reinterpret_cast(alpha_ptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 499f7bcff..f090bb801 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -25,7 +25,10 @@ namespace { #endif #ifndef __HIP_PLATFORM_AMD__ -constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32; + +constexpr int MXFP8_BLOCK_SIZE = 32; +constexpr int NVFP4_BLOCK_SIZE = 16; + constexpr __device__ __host__ int TB_DIM = 32; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; @@ -37,6 +40,7 @@ constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32; // HIPCC does not support __host__ qualifier for variables // and constexpr values do not need __device__ qualifier because they are compile-time constants constexpr int MXFP8_BLOCK_SIZE = 32; +constexpr int NVFP4_BLOCK_SIZE = 16; constexpr int TB_DIM = 32; constexpr int NEW_SF_TILE_DIM_K = 16; constexpr int N_SF_PER_TD_PER_TILE = 4; @@ -333,8 +337,6 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ const int original_K = kernel_args.original_k_list[tensor_id]; constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); - constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; - constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; // Get block index in grid. Emulate 2D grid. const int num_tiles_k = K / SF_TILE_DIM_K; @@ -351,9 +353,13 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ } // namespace void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { - if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { - NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + "."); - } + NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING || + input->scaling_mode == NVTE_BLOCK_SCALING_1D || + input->scaling_mode == NVTE_BLOCK_SCALING_2D || + input->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ")."); + NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()), + "Input tensor has invalid dtype (", to_string(input->dtype()), ")."); // Do nothing if tensor is empty if (input->data.numel() == 0) { @@ -364,135 +370,162 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s CheckInputTensor(*output, "scaling_factor_output"); auto& scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING, + "Unsupported scaling mode for swizzling."); + + bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING; // 1D block scaling, row-wise or colum-wise - if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - const int m = - input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1]; - const int k = - input->has_data() ? input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0]; - - constexpr int SF_TILE_DIM_M = 128; - constexpr int SF_TILE_DIM_K = 4; - - NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); - NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); - NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); - if (output->has_data()) { - NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), - output->scale_inv.shape.end(), 1, std::multiplies()), - "Input.scale_inv size is not equal to Output.scale_inv size!"); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), - output->columnwise_scale_inv.shape.end(), 1, - std::multiplies()), - "Input.columnwise_scale_inv size is not equal to " - "Output.columnwise_scale_inv size!"); + int m, k; + if (input->has_data()) { + m = input->scale_inv.shape[0]; + k = input->scale_inv.shape[1]; + } else { + if (nvfp4) { + m = input->columnwise_scale_inv.shape[0]; + k = input->columnwise_scale_inv.shape[1]; + } else { + m = input->columnwise_scale_inv.shape[1]; + k = input->columnwise_scale_inv.shape[0]; } + } - int num_tiles_m = m / SF_TILE_DIM_M; - int num_tiles_k = k / SF_TILE_DIM_K; + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + if (output->has_data()) { + NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), + output->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), + output->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + } + + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + + // For NVFP4, the scale inverse for tranposed data needs rowwise swizzle. + const bool rowwise_swizzle = input->has_data() || nvfp4; + const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4; + + dim3 block_size(TB_DIM, TB_DIM); + if (rowwise_swizzle) { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + /* there is no int3 and misaligned if using int4/int2 */ + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + + int original_M, original_K; + void *input_scale_inv_ptr, *output_scale_inv_ptr; + + if (!nvfp4 || input->has_data()) { + int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE; + original_M = input->flat_first_dim(); + original_K = input->flat_last_dim() / block_scale_size; + input_scale_inv_ptr = input->scale_inv.dptr; + output_scale_inv_ptr = output->scale_inv.dptr; + } else { + original_M = input->flat_last_dim(); + original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE; + input_scale_inv_ptr = input->columnwise_scale_inv.dptr; + output_scale_inv_ptr = output->columnwise_scale_inv.dptr; + } - dim3 block_size(TB_DIM, TB_DIM); - if (input->has_data()) { - int vec_load_size = (num_tiles_k - 1) % 4 + 1; - /* there is no int3 and misaligned if using int4/int2 */ - if (vec_load_size == 3) vec_load_size = 1; - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_first_dim(); - const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: + switch (vec_load_size) { + case 4: #ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 2: + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 2: #ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 1: + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 1: #ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; } - if (input->has_columnwise_data()) { - int vec_load_size = (num_tiles_m - 1) % 4 + 1; - if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_last_dim(); - const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: + } + if (columnwise_swizzle) { + int vec_load_size = (num_tiles_m - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_last_dim(); + const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; + // NVFP4 shouldn't end up here because it only needs rowwise swizzle + NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle"); + + switch (vec_load_size) { + case 4: #ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - case 2: + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + case 2: #ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - case 1: + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + case 1: #ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); #endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; } - - // 2D block scaling - } else { - NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -582,6 +615,8 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, } NVTE_CHECK_CUDA(cudaGetLastError()); } + +// TODO(nvfp4): Add NVFP4 support. void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { auto num_tensors = input.size(); @@ -708,7 +743,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, * WIP (Phuong): * - Opt for bank conflicts * - Adding swizzle for 2d-block scaling. -*/ + */ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swizzle_scaling_factors); using namespace transformer_engine; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 68d1f0ec5..f26a94b88 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include "common.h" #include "common/util/cuda_runtime.h" @@ -65,8 +66,8 @@ std::string to_string(const NVTEScalingMode &mode) { return "NVTE_DELAYED_TENSOR_SCALING"; case NVTE_MXFP8_1D_SCALING: return "NVTE_MXFP8_1D_SCALING"; - case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: - return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING"; + case NVTE_NVFP4_1D_SCALING: + return "NVTE_NVFP4_1D_SCALING"; case NVTE_INVALID_SCALING: return "NVTE_INVALID_SCALING"; } @@ -96,12 +97,11 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { t.columnwise_scale_inv.shape, ")"); } } else { - if (t.scaling_mode == NVTE_MXFP8_1D_SCALING || - t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { + if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { // Need (4, 128) alignment even for e8 scaling factor auto block_alignment = std::vector{128ul, 4ul}; size_t expected_x, expected_y, alignment; - const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16; + const size_t block_size_rowwise = 32; const size_t block_size_colwise = 32; if (t.has_data()) { @@ -112,6 +112,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(block_size_rowwise)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected, ", got ", @@ -124,11 +125,29 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { alignment; alignment = block_alignment[0]; expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", t.columnwise_scale_inv.shape, ")"); } + } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) { + if (t.has_data()) { + const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_first_dim(), 128); + const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_last_dim(), 16lu), 4); + const auto &expected = std::vector{expected_y, expected_x}; + NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid scale_inv shape (expected ", expected, ", got ", + t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_last_dim(), 128); + const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_first_dim(), 16lu), 4); + const auto &expected = std::vector{expected_y, expected_x}; + NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", + t.columnwise_scale_inv.shape, ")"); + } } } } @@ -156,6 +175,26 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { "(expected Float32 or Byte, got ", to_string(t.columnwise_scale_inv.dtype), ")"); } + } else if (is_fp4_dtype(type)) { + // TODO(ksivaman): Fix this to check for amaxes and other details. + // For now only needed for swizzle. + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name, + "_scale_inverse has invalid dtype " + "(expected DType::kFloat8E4M3, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP8 scaling factor input ", + name, + "_columnwise_scale_inverse has invalid dtype " + "(expected DType::kFloat8E4M3, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name); NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name); @@ -197,10 +236,29 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt "(expected Float32 or Float8E8M0, got ", to_string(t.columnwise_scale_inv.dtype), ")"); } + } else if (is_fp4_dtype(type)) { + // FP4 output needs to have the scale_inv + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name, + "_scale_inverse has invalid dtype " + "(expected Float8E4M3, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", + name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float8E4M3, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name); - // Note: amax is supported for non-FP8 output as it can be fused into the computation - // and later used for quantization with no need to compute it separately + // Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax. + // NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name); NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name); NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name); @@ -493,6 +551,9 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, case kNVTEColumnwiseScaleInv: t->columnwise_scale_inv = *param; break; + case kNVTEColumnwiseAmax: + t->columnwise_amax = *param; + break; default: NVTE_ERROR("Unknown tensor parameter!"); } @@ -516,6 +577,8 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p return t.scale_inv; case kNVTEColumnwiseScaleInv: return t.columnwise_scale_inv; + case kNVTEColumnwiseAmax: + return t.columnwise_amax; default: NVTE_ERROR("Unknown tensor parameter!"); } @@ -631,6 +694,15 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size); break; + case kNVTEQuantizationConfigRNGState: + std::memcpy(&config_.rng_state, buf, attr_size); + break; + case kNVTEQuantizationConfigNVFP42DQuantization: + std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size); + break; + case kNVTEQuantizationConfigStochasticRounding: + std::memcpy(&config_.stochastic_rounding, buf, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index abfa226e8..89266f4bb 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -8,6 +8,7 @@ #define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ #include "../common.h" +#include "transformer_engine/transformer_engine.h" namespace transformer_engine::detail { @@ -62,6 +63,14 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor const bool pow_2_scale, const SimpleTensor &noop_tensor, cudaStream_t stream); +void quantize_transpose_vector_blockwise_fp4( + const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv, + SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, + const bool return_identity, const bool return_transpose, const bool pow2_scale, + const bool swizzled_scale, const bool use_stochastic_rounding, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const SimpleTensor &noop_tensor, cudaStream_t stream); + } // namespace transformer_engine::detail #endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu new file mode 100644 index 000000000..eced2c4bb --- /dev/null +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -0,0 +1,842 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/recipe/recipe_common.cuh" +#include "common/transpose/cast_transpose.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "curanddx.hpp" + +namespace transformer_engine { + +#if CUDA_VERSION >= 12080 +namespace quantize_transpose_nvfp4 { +namespace { + +using std::int32_t; +using std::uint32_t; +using std::uint8_t; + +using transformer_engine::detail::TypeExtrema; + +// Define a cuRANDDx descriptor +// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10. +// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g., +// if shared memory, if needed, is enough for the described problem, usually not applicable. +// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html +using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + + curanddx::SM<800>() + curanddx::Thread()); + +// clang-format off +/* + +Step 1: Load input to shared memory +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 8 times +* What each thread does in each loop: + * 8 elements are read from the input at a time + * 2 elements are written to the shared memory at a time, for a total of 4 times ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 1 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 7 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 8 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 2: Cast and store to output_c +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 4 times +* What each thread does in each loop: + * 2 elements are read from the shared memory at a time, for a total of 8 times + * Every 8 consecutive threads do reduction and calculate the amax of each row + * 16 elements are quantized and write to output_c at a time ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | +| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | +| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 1 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 7 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 4 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 3: Transpose, cast and store to output_t +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 2 times +* What each thread does in each loop: + * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times + * Every 8 consecutive threads do reduction and calculate the amax of each column + * 16 elements are quantized and write to output_c at a time, for a total of 2 times ++------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+ +| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | | +| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | | +| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | | +| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | +| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | | +| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | | +| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | | +| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | | ++-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+ + +*/ +// clang-format on + +constexpr int kThreadsPerWarp = 32; + +// for fp4, we use uint8_t to store 2 fp4 numbers +constexpr int kNFP4PerContainer = 2; + +// Hyperparameters for performance tuning +constexpr int kTileDim = 128; +// constexpr int kScaleDim = 32; +constexpr int kNVecIn = 8; // The number of elements each LDG touches +constexpr int kNVecOut = 16; // The number of elements each STG touches +constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches +constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total + +// Auto-calculated constants, do not modify directly) +static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); +static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem"); +constexpr int kSMemRow = kTileDim; +constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; +constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem; +constexpr int kNumThreadsLoad = kTileDim / kNVecIn; // 16 +constexpr int kNumThreadsStore = kTileDim / kNVecOut; // 8 +// constexpr int kNumThreadsReduce = kScaleDim / kNVecOut; +static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); +static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); + +// for 2D block scaling, we need to reduce amax in warp +static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { + 0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080}; + +// max for every group_size elements in warp +template +__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) { + for (int offset = group_size / 2; offset > 0; offset /= 2) { + val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride)); + } + return val; +} + +template +__device__ __forceinline__ ScaleType ComputeDecodeScaleFP4(const float amax, + const float global_encode_scale) { + float decode_scale = amax / TypeExtrema::max; + decode_scale = decode_scale * global_encode_scale; + decode_scale = fminf(decode_scale, TypeExtrema::max); + return static_cast(decode_scale); +} + +template +__device__ __forceinline__ float ComputeEncodeScaleFP4(ScaleType decode_scale, + const float global_decode_scale) { + return fminf(1.0f / (static_cast(decode_scale) * global_decode_scale), + TypeExtrema::max); +} + +template +__device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scale) { + return static_cast(input) * encode_scale; +} + +__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { + constexpr float fp8_max = TypeExtrema::max; + constexpr float fp4_max = TypeExtrema::max; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.f || global_encode_scale == 0.f) { + return 1.f; + } + return global_encode_scale; +} + +__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + curanddx::uniform_bits dist; + random_uint4 = dist.generate4(rng); + } + // Treat uint4 as an array of 4x uint32_t elements for indexing + const uint32_t* const rbits_arr = reinterpret_cast(&random_uint4); + const uint32_t rbits = rbits_arr[rnd_idx++]; + return rbits; +} + +template +__device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, size_t col_idx, + uint32_t col_length) { + // This function takes in indices from the scale factor matrix and returns an offset in the + // swizzled format. row_idx, col_idx are original indices from the scale factor matrix (unswizzled + // index). col_length is the column length of the scale factor matrix. tile_scales_inv is the + // pointer to the scale factor matrix. + + // https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts + // For any scale factor matrix, it's 512B base block. Each base block consists of 128 rows and 4 + // columns. Base block is divided into 4 column blocks, each column block has 32 rows and 4 + // columns. + + // NOTE: There are not a lot of good illustrations about the swizzled scale factor matrix. + // To think in high level, the swizzled scale factor matrix could be composed as: + // unswizzled_scale_factor_matrix = torch.empty((M, N // 16), dtype=torch.uint8) + // cbg_cnt = N // 16 // 4 # Assuming N is divisible by 64 + // rb_cnt = M // 128 # Assuming M is divisible by 128 + // tmp = unswizzled_scale_factor_matrix.reshape(rb_cnt, 4, 32, cbg_cnt, 4) + // tmp = torch.permute(tmp, (0, 3, 2, 1, 4)) + // swizzled_scale_factor_matrix = tmp.reshape((-1, 128, 4)) + + constexpr uint32_t kTotalRowsPerBaseBlock = 128; + constexpr uint32_t kRowsPerBaseBlockCol = 32; + constexpr uint32_t kColsPerBaseBlockCol = 4; + + const size_t rb = row_idx / kTotalRowsPerBaseBlock; + const size_t rem = row_idx % kTotalRowsPerBaseBlock; + const size_t d4 = rem / kRowsPerBaseBlockCol; + const size_t d3 = rem % kRowsPerBaseBlockCol; + const size_t cbg = col_idx / kColsPerBaseBlockCol; + const size_t d5 = col_idx % kColsPerBaseBlockCol; + + const size_t cbg_cnt = DIVUP(col_length, kColsPerBaseBlockCol); + // row-major offset in the logical shape + // (rb_cnt , cbg_cnt , 32 , 4 , 4) + // Magic number 16 below comes from the fact we have kColsPerBaseBlockCol = 4, and d4 ([0-128] / + // 32 = [0-4]) + return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5; +} + +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( + const float2 in01, const float2 in23, const uint32_t rbits) { +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + uint16_t out_4x; + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" + "}" + : "=h"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + uint16_t dummy = 0; + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL +} + +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, + const float2 in23, + const uint32_t rbits) { +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + // NOTE: rbits unused for rn. + uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. + asm volatile( + "{\n" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); + return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0]; +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + uint16_t dummy = 0; + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL +} + +template +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const uint32_t rbits) { + if constexpr (kApplyStochasticRounding) { + return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits); + } else { + return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits); + } +} + +template +__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( + const IType* const input, const float* global_amax, OType* const output_c, + OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, + const size_t row_length, const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, + const size_t kScaleBlockDim, const float epsilon, const size_t* rng_state, + const float* noop_ptr) { + constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer; + using SMemVec = Vec; + using OVec = Vec; + union IVec { + Vec input_type; + Vec smem_type; + }; + + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + + const size_t block_idx_x = blockIdx.x; + const size_t block_idx_y = blockIdx.y; + const size_t rng_sequence = + threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + + extern __shared__ char smem_base[]; + SMemVec* smem = reinterpret_cast(&smem_base[0]); + + // 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode. + // Instead of static_assert, return early if these invalid modes are detected. + if constexpr (kIs2DBlockScaling && kIsE8Scaling) { + return; + } + if constexpr (kIs2DBlockScaling && !kReturnIdentity) { + return; + } + // for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4 + // use constexpr to define the size, when not using 2D, use minimal size 1x1 + constexpr int kFP4BlockScalingSize = 16; + constexpr int k2DBlockAmaxDim = kIs2DBlockScaling ? (kTileDim / kFP4BlockScalingSize) : 1; + constexpr int kNumRowsPerWarp = kThreadsPerWarp / kNumThreadsStore; // 4 + constexpr int k2DBlockAmaxReduceDim = + kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1; + __shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim]; + __shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; + + // Step 1: Load input to shared memory + { + constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory + const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = (c_g < row_length ? min(static_cast(kNVecIn), row_length - c_g) + : 0); // For not aligned case + const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + IVec input_vec; + // Step 1.1: Load from global memory (input) to registers + if constexpr (kAligned) { + input_vec.input_type.load_from(input_g); + } else { + if (r_g < num_rows) { + input_vec.input_type.load_from_elts(input_g, 0, num_ele); + } else { + input_vec.input_type.clear(); + } + } + // Step 1.2: Write to shared memory +#pragma unroll + for (int i = 0; i < kNVecIn / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i]; + } + // Step 1.3: Update input address, row index of shared memory, (and row index of global memory + // for not aligned case) + input_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + __syncthreads(); + + const int kNumThreadsReduce = kScaleBlockDim / kNVecOut; + const float global_encode_scale = + kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]); + const float global_decode_scale = 1.0 / global_encode_scale; + + // Step 2: Cast and store to output_c + if constexpr (kReturnIdentity) { + constexpr int r_stride = + kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory + const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = + (c_g < row_length ? min(static_cast(kNVecOut / kNFP4PerContainer), + (row_length - c_g) / kNFP4PerContainer) + : 0); // For not aligned case + OType* output_g = + &output_c[(r_g * row_length + c_g) / kNFP4PerContainer]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = + (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut / kNVecSMem]; + // Step 2.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } + // Step 2.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); + } + } + // Step 2.3: Reduce amax + if constexpr (kIsE8Scaling) { +#pragma unroll + for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + } + // doing shuffle sync for 2D block scaling (not applicable for E8 scaling) + if constexpr (kIs2DBlockScaling) { + // first amax shuffle sync in warp, then reduce in smem + // T0 T8 T16 T24 should do amax reduction together + constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32 + int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 + int tid_in_warp_x = threadIdx.x % kNumThreadsStore; + int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp; + CType amax_warp_reduced = groupMax( + amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]); + // now T0 ~ T8 in each warp has the reduced amax values + int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y; + if (tid_in_warp_y == 0) { + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] + [warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced; + } + __syncthreads(); + + if (data_row_idx % kFP4BlockScalingSize == 0) { + CType amax_2d = 0.0; + for (int i = 0; i < k2DBlockAmaxReduceDim; i++) { + amax_2d = fmaxf(amax_2d, + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]); + } + amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d; + } + __syncthreads(); + // every thread now knows 2D amax + amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]; + } + // Step 2.4: Compute scale + ScaleType scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + // Step 2.5: Write scale_inv + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g < num_rows); + write_scale_inv &= (c_g < row_length); + } + if (write_scale_inv) { + size_t row_idx = block_idx_y * kTileDim + r_s; + size_t col_idx = block_idx_x * (kNumThreadsStore / kNumThreadsReduce) + + (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce; + if constexpr (kSwizzledScale) { + size_t offset = scale_factor_swizzled_offset( + row_idx, col_idx, DIVUP(row_length, kScaleBlockDim)); + tile_scales_inv_c[offset] = scale_inv; + } else { + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + } + } + // Step 2.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 f2_a; + float2 f2_b; + f2_a.x = ComputeOutputFP4(smem_vec[i].data.elt[0], encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[i].data.elt[1], encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[i + 1].data.elt[0], encode_scale); + f2_b.y = ComputeOutputFP4(smem_vec[i + 1].data.elt[1], encode_scale); + const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + } + // Step 2.7: Store output_c + if constexpr (kAligned) { + output_vec.store_to(output_g); + } else { + if (r_g < num_rows) { + output_vec.store_to_elts(output_g, 0, num_ele); + } + } + // Step 2.8: Update output address, row index of shared memory (and row index of global memory + // for not aligned case) + output_g += stride_g / kNFP4PerContainer; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + // Step 3: Transpose, cast and store to output_t + if constexpr (kReturnTranspose) { + constexpr int c_stride = + kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory + constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); + const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory + int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory + size_t r_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Row in global memory + const size_t c_g = block_idx_y * kTileDim + r_s; // Column in global memory + const size_t stride_g = + static_cast(c_stride) * kNVecSMem * num_rows; // Stride in global memory + const size_t num_ele = (c_g < num_rows ? min(static_cast(kNVecOut / kNFP4PerContainer), + (num_rows - c_g) / kNFP4PerContainer) + : 0); // For not aligned case + OType* output_g = + &output_t[(r_g * num_rows + c_g) / kNFP4PerContainer]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = + (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut]; + // Step 3.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + int r = r_s + i; + int c = c_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } +#pragma unroll + for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) { + // Step 3.2: Compute local amax + CType amax = 0; + if constexpr (kIs2DBlockScaling) { + // TODO(zhongbo): 2D block scaling, directly read from amax_smem + int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 + constexpr int kNumColsPerWarp = + kThreadsPerWarp / kNumThreadsStore * kNVecSMem; // 8 elements + constexpr int kNumWarpsPerBlock = + kThreadsPerBlock / kThreadsPerWarp; // 8 warps per block + constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock; + int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp; + int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore; + int data_col_idx = iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x; + amax = amax_smem[tid_in_warp_y][data_col_idx / kFP4BlockScalingSize]; + } else { +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx])); + } + } + // Step 3.3: Reduce amax + if constexpr (kIsE8Scaling) { +#pragma unroll + for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + } + // Step 3.4: Compute scale + ScaleType scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + // Step 3.5: Write scale_inv_t + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g + smem_idx < row_length); + write_scale_inv &= (c_g < num_rows); + } + if (write_scale_inv) { + size_t row_idx = block_idx_x * kTileDim + c_s * kNVecSMem + smem_idx; + size_t col_idx = (block_idx_y * (kNumThreadsStore / kNumThreadsReduce) + + (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce); + if constexpr (kSwizzledScale) { + size_t offset = scale_factor_swizzled_offset( + row_idx, col_idx, DIVUP(num_rows, kScaleBlockDim)); + tile_scales_inv_t[offset] = scale_inv; + } else { + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + // Step 3.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 f2_a; + float2 f2_b; + f2_a.x = + ComputeOutputFP4(smem_vec[2 * i].data.elt[smem_idx], encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[2 * i + 1].data.elt[smem_idx], + encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[2 * (i + 1)].data.elt[smem_idx], + encode_scale); + f2_b.y = ComputeOutputFP4(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx], + encode_scale); + const uint32_t rbits = + kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + } + // Step 3.7: Store output_t + if constexpr (kAligned) { + output_vec.store_to(output_g + smem_idx * num_rows / kNFP4PerContainer); + } else { + if (r_g + smem_idx < row_length) { + output_vec.store_to_elts(output_g + smem_idx * num_rows / kNFP4PerContainer, 0, + num_ele); + } + } + } + // Step 3.8: Update output address, column index of shared memory (and row index of global + // memory for not aligned case) + output_g += stride_g / kNFP4PerContainer; + c_s += c_stride; + if constexpr (!kAligned) { + r_g += c_stride * kNVecSMem; + } + } + } +} + +} // namespace +} // namespace quantize_transpose_nvfp4 +#endif // CUDA_VERSION >= 12080 + +namespace detail { + +void quantize_transpose_vector_blockwise_fp4( + const SimpleTensor& input, const SimpleTensor& global_amax, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, + const bool return_identity, const bool return_transpose, const bool pow2_scale, + const bool swizzled_scale, const bool use_stochastic_rounding, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const SimpleTensor& noop_tensor, cudaStream_t stream) { + NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); +#if CUDA_VERSION >= 12080 + + // pow 2 scale is for MXFP4 since it's using E8M0 scaling + // raise error if pow2_scale is true + NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now"); + + if (!return_identity && !return_transpose) { + return; + } + + if (use_2d_quantization && !return_identity) { + return; + } + + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; + size_t num_elements = row_length; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + num_elements *= input.shape.at(i); + } + + // Early return if the input tensor is empty + if (num_elements == 0) { + return; + } + + size_t scale_stride_x = 0; + size_t scale_stride_y = 0; + + if (return_identity) { + scale_stride_x = 1; + scale_stride_y = scale_inv.shape[1]; + } + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + } + + using namespace transformer_engine::quantize_transpose_nvfp4; + + const size_t num_blocks_x = DIVUP(row_length, static_cast(kTileDim)); + const size_t num_blocks_y = DIVUP(num_rows, static_cast(kTileDim)); + + // noop tensor for cuda graph + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + + const size_t* rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor& rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY( + output.dtype, 2, OutputType, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + + using ScaleType = fp8e4m3; constexpr int kScaleBlockDim = 16; + constexpr bool kPow2Scale = false; + + const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity, kReturnIdentity, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transpose, kReturnTranspose, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + full_tile, kAligned, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + swizzled_scale, kSwizzledScale, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kApplyStochasticRounding, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_2d_quantization, kIs2DBlockScaling, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, + float, InputType, OutputType, ScaleType, kSwizzledScale, + kApplyStochasticRounding, kIs2DBlockScaling>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK(err == cudaSuccess, + "Failed to set dynamic shared memory size."); + } kernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(global_amax.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, + num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, + scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, + noop_ptr);) // kIs2DBlockScaling + ) // kApplyStochasticRounding + ) // kSwizzledScale + ) // kAligned + ) // kReturnTranspose + ) // kReturnIdentity + ) // OutputType + ) // InputType + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace detail +} // namespace transformer_engine diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index dcb3aa42d..416d51e04 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -606,6 +606,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DGATED) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2; const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { @@ -836,6 +837,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); } } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; @@ -955,6 +957,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t in_gate_mem = buff_size_aligned_in; const size_t out_act_mem = buff_size_aligned_out; const size_t out_gate_mem = buff_size_aligned_out; + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index b7c4cf837..e29a400d9 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -16,6 +16,7 @@ #include #ifndef __HIP_PLATFORM_AMD__ #include +#include "nvfp4_transpose.cuh" #endif //#ifndef __HIP_PLATFORM_AMD__ #include #include @@ -116,6 +117,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + // helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; @@ -143,8 +146,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_sh = reinterpret_cast(dshmem); IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); - OType *out_rowwise_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; @@ -292,7 +296,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float scaled_out = in * block_scale_inverse; const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_sh[shmem_offset_elt] = static_cast(scaled_out); + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); } } @@ -416,10 +420,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 2. Compute E8M0 scaling factor const e8m0_t biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const size_t stage_scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; @@ -447,7 +453,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } } @@ -462,19 +468,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t buff_offset = buff * BUFF_DIM; + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; if constexpr (ROWWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_sh[buff_offset])); + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); } if constexpr (COLWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_sh[buff_offset])); + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); } // Create a "bulk async-group" out of the previous bulk copy operation. @@ -495,18 +501,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Added extra 1-element padding per thread_X to reduce bank conflicts float *partial_dbias_rowwise = reinterpret_cast(dshmem); - constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - const size_t shmem_thread_offset = + const int shmem_thread_offset = tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { const int j = w * PACK_SIZE + e; - const size_t shmem_elt_idx = swizzled_group_offset + e; + const int shmem_elt_idx = swizzled_group_offset + e; partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; } } @@ -514,15 +520,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int i = 0; i < THREADS_Y; ++i) { // Add extra element offset per MXFP8 scaling block [1x32] - const size_t scaling_block = threadIdx.x / SCALE_DIM_X; + const int scaling_block = threadIdx.x / SCALE_DIM_X; thread_partial_dbias += partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; } } - const size_t dbias_stride = cols; - const size_t dbias_offset_Y = blockIdx.y; - const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; - const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const int dbias_stride = cols; + const int dbias_offset_Y = blockIdx.y; + const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); if (!col_out_of_bounds_dbias) { dbias_workspace[dbias_idx] = thread_partial_dbias; @@ -544,6 +550,528 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } // namespace mxfp8_kernel +namespace nvfp4_kernel { + +using namespace ptx; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 16; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; + +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16 + +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + return static_cast(block_amax * rcp_6f * S_enc); +} + +#define DIRECT_SCALING_FACTORS_STORE 1 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0, + const float *noop, float *const amax_ptr, + const float *const nvfp4_second_stage_scale_ptr, const size_t rows, + const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool ROWWISE_SCALING = true; + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE; + + static_assert(BUFF_DIM_Y >= SCALE_DIM_Y && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); + static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + + constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size + constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X; + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + + constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE; + // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of + // // threads to process one row in a single iteration + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const int tid_Y_colwise = 0; + const int tid_X_colwise = threadIdx.x; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t buff_size_nvfp4_scales = + CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); + constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); + constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); + constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + fp8e4m3 *out_rowwise_scales_sh = + reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + e8m0_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = + (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr); + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; + + const int buff_offset_in = buff * BUFF_IN_DIM; + const int buff_offset_out = buff * BUFF_OUT_DIM; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_IN_DIM; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM_Y]; + IType in_colwise_IType[SCALE_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(block_amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + if (colwise_scale_is_within_bounds) { + scales_colwise_e8m0[scale_idx] = biased_exponent; + } + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (int it = 0; it < ITERATIONS_ROWWISE; ++it) { + const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const int shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const int shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); + +#if DIRECT_SCALING_FACTORS_STORE + // Check boundaries + if (rowwise_scale_is_within_bounds) { + const int scales_offset_Y = + scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = scales_offset_X_rowwise; + const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; + scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; + } +#else + const int shmem_scales_offset_Y = + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise; + const int shmem_scales_offset_X = tid_X_rowwise; + const int scale_idx = + shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X; + out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8; +#endif + // Compute "correct" per-block encoding scaling factor + const float block_scale_inverse = + __fdiv_rn(S_enc, static_cast(S_dec_b_fp8)); // S_enc_b_fp8 + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; // Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in01 = in_IType[w].data.elt[2 * e]; + in23 = in_IType[w].data.elt[2 * e + 1]; + } else if constexpr (IS_CACHED_ACT_OP) { + in01.x = in_cached[w].data.elt[4 * e]; + in01.y = in_cached[w].data.elt[4 * e + 1]; + in23.x = in_cached[w].data.elt[4 * e + 2]; + in23.y = in_cached[w].data.elt[4 * e + 3]; + } else { + const int j = w * PACK_SIZE + 4 * e; + in01.x = in_compute_rowwise[j]; + in01.y = in_compute_rowwise[j + 1]; + in23.x = in_compute_rowwise[j + 2]; + in23.y = in_compute_rowwise[j + 3]; + } + fp4e2m1x4 &out_quad = reinterpret_cast(out.data.elt[e]); + ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse); + } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + __builtin_assume(block_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM; + const int buff_offset_mxfp8 = buff * BUFF_IN_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset_nvfp4])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset_mxfp8])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } + +#if !DIRECT_SCALING_FACTORS_STORE + // Vectorized store of scaling factors. + // Each thread stores multiple scaling factors in one store instruction. + if constexpr (ROWWISE_SCALING) { + // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise; + const int scale_idx_global = + scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise; + const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + + if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) && + (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) { + using ScalesVec_t = Vec; + const ScalesVec_t &scales = + *reinterpret_cast(&out_rowwise_scales_sh[scale_idx_shmem]); + scales.store_to(&scales_rowwise_e4m3[scale_idx_global]); + } + } +#endif + + float chunk_amax = 0.0f; + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + chunk_amax = reduce_max(thread_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, chunk_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace nvfp4_kernel + constexpr size_t FP8_CHUNK_DIM_Y = 128; constexpr size_t FP8_CHUNK_DIM_X = 128; constexpr size_t FP8_THREADS_PER_CHUNK = 128; @@ -908,7 +1436,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, #ifndef __HIP_PLATFORM_AMD__ template -static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { +void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { const size_t N = product(input.data.shape); const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); @@ -1221,6 +1749,143 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ); // NOLINT(*) } +#ifndef __HIP_PLATFORM_AMD__ +// This kernel supports only two scaling cases: +// 1. r16c0 - Rowwise NVFP4 +// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8 +template +void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { + using namespace nvfp4_kernel; + using namespace ptx; + checkCuDriverContext(stream); + + NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated."); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + bool use_colwise_scaling = output->has_columnwise_data(); + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + constexpr size_t CHUNK_DIM_Y = 128; + constexpr size_t CHUNK_DIM_X = 128; + constexpr size_t THREADS_PER_CHUNK = 128; + + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = output->scale_inv.shape[1]; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_e8m0_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const ScalingType scaling_type = + use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE; + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const nvfp4_second_stage_scale_ptr = + reinterpret_cast(output->scale.dptr); + + // Output data type is only required for the column-wise MXFP8 scaling. + // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work + const DType output_data_type = + use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, + nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4); + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8); + } + + constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_nvfp4_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4; + const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0; + + const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales; + const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0; + + const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem + + out_rowwise_scales_mem + out_colwise_scales_mem + + TMA_SHMEM_ALIGNMENT; + + const size_t dshmem_size = in_mem + out_mem; + + switch (scaling_type) { + case ScalingType::ROWWISE: + cudaFuncSetAttribute( + cast_nvfp4_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_nvfp4_kernel + <<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::BIDIMENSIONAL: + cudaFuncSetAttribute( + cast_nvfp4_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_nvfp4_kernel + <<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + }); // NOLINT(*) + ); // NOLINT(*) +} +#endif //#ifndef __HIP_PLATFORM_AMD__ + namespace detail { using Empty = transformer_engine::Empty; @@ -1457,20 +2122,33 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o auto dbias_tensor = convertNVTETensor(dbias); auto workspace_tensor = convertNVTETensor(workspace); - const QuantizationConfig *quant_config_cpp = - reinterpret_cast(quant_config); + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } - // extract noop tensor from quant_config_cpp if it's not null - const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; - const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor(); + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + // Dispatch to quantization kernel depending on data format switch (output_tensor->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { if (output_tensor->has_columnwise_data()) { NVTE_CHECK(output_tensor->has_data(), "Quantizing in only the columnwise direction not supported yet!"); if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); + cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); } else { cast_transpose_fused( *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, @@ -1478,52 +2156,91 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o } } else if (output_tensor->has_data()) { fp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); } break; } case NVTE_MXFP8_1D_SCALING: { mxfp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); break; } #ifndef __HIP_PLATFORM_AMD__ + case NVTE_NVFP4_1D_SCALING: { + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 && + output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4_quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4_quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for " + "2D quantization"); + quantize_transpose_vector_blockwise_fp4( + /*input=*/input_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } case NVTE_BLOCK_SCALING_2D: { // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor.data, stream); + /*noop_tensor=*/noop_tensor->data, stream); break; } case NVTE_BLOCK_SCALING_1D: { // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; if (output_tensor->has_data()) { - bool rowwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; + bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; } if (output_tensor->has_columnwise_data()) { - bool columnwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; + bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); columnwise_option = columnwise_compact ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; @@ -1531,7 +2248,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o quantize_transpose_vector_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor.data, stream); + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); break; } #endif diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index aaeb169b1..ed8a23631 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -21,6 +21,8 @@ #include #include +#include +#include #include #include "../common.h" @@ -30,6 +32,7 @@ #include "math.h" #include "ptx.cuh" #include "transformer_engine/activation.h" +#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transpose.h" #ifdef __HIP_PLATFORM_AMD__ #include "rocm_dequantize_kernels.cuh" @@ -219,7 +222,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } #endif // #ifndef __HIP_PLATFORM_AMD__ -static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); @@ -240,7 +243,7 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str ); // NOLINT(*) } -static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { bool use_rowwise_scaling = input.has_data(); bool use_colwise_scaling = input.has_columnwise_data(); #ifndef __HIP_PLATFORM_AMD__ @@ -336,6 +339,81 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s #endif NVTE_CHECK_CUDA(cudaGetLastError()); } + +#if CUDA_VERSION >= 12080 +template +__global__ void __launch_bounds__(512) + dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, + const float *const tensor_amax, const size_t N, const size_t M, + const size_t scale_stride) { + const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t x = thread_idx % M; + const size_t y = thread_idx / M; + + union fp4vec { + uint64_t vec; + fp4e2m1x4 small_vec[4]; + }; + using OVec = Vec; + const uint64_t *const input_vectorized = reinterpret_cast(input); + OVec *output_vec = reinterpret_cast(output); + + const size_t my_index = x + y * M; + const size_t my_scale_index = x + y * scale_stride; + const size_t my_output_index = (x + y * M) * 4; + fp4vec value; + value.vec = input_vectorized[my_index]; + fp8e4m3 scale = scales[my_scale_index]; + float amax = *tensor_amax; + constexpr float factor_inv = 1.0 / (6.0 * 448.0); + float final_scale = static_cast(scale) * amax * factor_inv; +#pragma unroll + for (int i = 0; i < 4; i++) { + float4 current = static_cast(value.small_vec[i]); + OVec out; + out.data.elt[0] = static_cast(current.x * final_scale); + out.data.elt[1] = static_cast(current.y * final_scale); + out.data.elt[2] = static_cast(current.z * final_scale); + out.data.elt[3] = static_cast(current.w * final_scale); + output_vec[my_output_index + i] = out; + } +} +#endif // CUDA_VERSION + +void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +#if CUDA_VERSION >= 12080 + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output"); + NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type."); + NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + constexpr int FP4_BLOCK_SIZE = 16; + const size_t N = input.flat_first_dim(); + const size_t M = input.flat_last_dim(); + + NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", + FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); + + const size_t Mread = M / FP4_BLOCK_SIZE; + const size_t total = N * Mread; + const size_t threads = 512; + const size_t blocks = DIVUP(total, threads); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + dequantize_fp4_kernel<<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back());); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif // CUDA_VERSION >= 12080 +} + } // namespace dequantization namespace detail { @@ -344,21 +422,29 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) CheckInputTensor(input, "cast_input"); CheckOutputTensor(*output, "cast_output"); - if (is_tensor_scaling(input.scaling_mode)) { - dequantization::fp8_dequantize(input, output, stream); - } else if (is_mxfp_scaling(input.scaling_mode)) { + switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + dequantization::fp8_dequantize(input, output, stream); + break; + } + case NVTE_MXFP8_1D_SCALING: { #ifdef __HIP_PLATFORM_AMD__ - if (1) { + if (1) { #else - if (is_supported_by_CC_100()) { + if (is_supported_by_CC_100()) { #endif - dequantization::mxfp8_dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + dequantization::mxfp8_dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + break; } - } else { - // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + case NVTE_NVFP4_1D_SCALING: { + dequantization::fp4_dequantize(input, output, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } } diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 6ab5eb958..64f43944e 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -18,6 +18,10 @@ #endif // __HIP_PLATFORM_AMD__ #include +#ifndef __HIP_PLATFORM_AMD__ +#include "nccl.h" +#endif // !__HIP_PLATFORM_AMD__ + #ifdef NVTE_WITH_CUBLASMP #include #endif // NVTE_WITH_CUBLASMP @@ -121,4 +125,14 @@ #endif // NVTE_WITH_CUBLASMP +#ifndef __HIP_PLATFORM_AMD__ +#define NVTE_CHECK_NCCL(expr) \ + do { \ + const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ + if (status_NVTE_CHECK_NCCL != ncclSuccess) { \ + NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \ + } \ + } while (false) +#endif // !__HIP_PLATFORM_AMD__ + #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/common/util/nvfp4_transpose.cuh b/transformer_engine/common/util/nvfp4_transpose.cuh new file mode 100644 index 000000000..fd9f0a074 --- /dev/null +++ b/transformer_engine/common/util/nvfp4_transpose.cuh @@ -0,0 +1,1518 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file nvfp4_transpose.cuh + * \brief CUDA kernels to cast to NVFP4 and transpose. + */ + +#ifndef TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ +#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ + +#ifndef __HIP_PLATFORM_AMD__ +#include +#include +#include + +#if CUDA_VERSION > 12080 +#include +#endif // CUDA_VERSION > 12080 + +#include + +#include "../common.h" +#include "../utils.cuh" +#include "curanddx.hpp" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { + +#if CUDA_VERSION > 12080 +namespace nvfp4_transpose { + +using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + + curanddx::SM<800>() + curanddx::Thread()); + +using namespace ptx; +using nvfp4_scale_t = fp8e4m3; + +constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts) + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_NUM = 128; + +constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; + +constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM; +constexpr size_t RNG_GENS_PER_THREAD = + SCALES_PER_THREAD / 4; // Each call generates 4x uint32_t random numbers + +constexpr size_t TILE_DIM_Y = 32; +constexpr size_t TILE_DIM_X = 128; + +// SHould this be SCALE_DIM or BLOCK_DIM? Both are 16, should work for both 1D and 2D +constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8 + +constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; +constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X; +constexpr size_t STAGES = TILES_Y * TILES_X; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = TILE_DIM_Y; +constexpr size_t BUFF_DIM_X = TILE_DIM_X; +constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X; +constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM / PACK_SIZE; + +constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM; +constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8 +constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16 + +constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; // 32/ 16 = 2 +constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM; +constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); +static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 + +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + // constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + // NOTE: Divide by 6.0f is not elegant and not efficient. + // However, this is part of the emulation code to ensure exact match. + using namespace detail; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + const float S_dec_b = block_amax / fp4_max * S_enc; + return static_cast(fminf(S_dec_b, TypeExtrema::max)); +} + +// Compute the global encode scale factor for a given global amax +__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { + using namespace detail; + constexpr float fp8_max = TypeExtrema::max; // 448.0f; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + curanddx::uniform_bits dist; + random_uint4 = dist.generate4(rng); + } + // Treat uint4 as an array of 4x uint32_t elements for indexing + const uint32_t *const rbits_arr = reinterpret_cast(&random_uint4); + const uint32_t rbits = rbits_arr[rnd_idx++]; + return rbits; +} + +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( + const uint64_t in_4x, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + // NOTE: rbits unused for rn. + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits); + } else { + return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits); + } +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( + const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale)), "r"(rbits)); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, + const float2 in23, + const float2 scale, + const uint32_t rbits) { + // NOTE: rbits unused for rn. + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits); + } else { + return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); + } +} + +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +template +__global__ void __launch_bounds__(THREADS_NUM) + nvfp4_transpose_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr, + const float *noop, const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + // const size_t tid_X_t = 0; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = + (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements + fp4e2m1x4 regs[SCALE_DIM / 4]; + +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = mul_cvt_bf16_to_fp4_4x(elts, block_scale_inverse_2x, + rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + const size_t global_offset_Y_t = block_offset_Y_t; + const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // Vectorized store scaling factors through SHMEM + if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + using ScalesVec = Vec; + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + const size_t count = // number of scales in Y dimension of this chunk + (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // Fast path: vectorized store when destination is properly aligned + scales_vec.store_to(dst); + } else { + // Safe path: element-wise store for tails or unaligned destinations + scales_vec.store_to_elts(dst, 0, count); + } + } + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void __launch_bounds__(THREADS_NUM) + nvfp4_transpose_kernel_2D(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr, + const float *noop, const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + + // NEW: 2D Block-based scaling constants + constexpr size_t BLOCK_DIM = 16; + constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2 + constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8 + constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile + constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2 + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + const size_t warp_id = threadIdx.x / 32; + const size_t lane_id = threadIdx.x % 32; + float thread_amax = 0.0f; + const size_t block_in_warp = lane_id / BLOCKS_PER_WARP; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1]; + + // Helper function for warp reduction + auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float { +#pragma unroll + for (int delta = 8; delta >= 1; delta /= 2) { + float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta); + thread_amax = fmaxf(thread_amax, other_amax); + } + return thread_amax; + }; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + +#pragma unroll + for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + const size_t block_in_tile_y = block_iter; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + for (int elem = 0; elem < BLOCK_DIM; elem += 2) { + const size_t elem_0_row = block_iter * BLOCK_DIM + elem; + const size_t elem_1_row = elem_0_row + 1; + const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + const size_t elem_1_col = elem_0_col; + + const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col; + const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col; + + IType2 val_2x; + val_2x.x = in_sh[shmem_offset_0]; + val_2x.y = in_sh[shmem_offset_1]; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x); + } + + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else { + for (int elem = 0; elem < BLOCK_DIM; ++elem) { + const size_t elem_row = block_iter * BLOCK_DIM + elem; + const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + + // Bounds checking + const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows); + const bool col_out_of_bounds = (block_offset_X + elem_col >= cols); + if (!row_out_of_bounds && !col_out_of_bounds) { + const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col; + float elt = static_cast(in_sh[shmem_offset]); + + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset] = static_cast(elt); + } + + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + } + // Warp reduction to get block amax + block_amax = warp_reduce_amax(thread_amax, block_in_warp); + + if (lane_id == 0 || lane_id == 16) { + block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax; + } + } + + // sync thread to ensure block_amax_matrix is done storing + __syncthreads(); + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 3. Scale elements + + // Load data in + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + } + } else { + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + fp4e2m1x4 regs[SCALE_DIM / 4]; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = mul_cvt_bf16_to_fp4_4x(elts, block_scale_inverse_2x, + rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = tid_X_rowwise; + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + } + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + const size_t global_offset_Y_t = block_offset_Y_t; + const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // Vectorized store scaling factors through SHMEM + if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + using ScalesVec = Vec; + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + const size_t count = // number of scales in Y dimension of this chunk + (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // Fast path: vectorized store when destination is properly aligned + scales_vec.store_to(dst); + } else { + // Safe path: element-wise store for tails or unaligned destinations + scales_vec.store_to_elts(dst, 0, count); + } + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace nvfp4_transpose +#endif // CUDA_VERSION > 12080 + +// Compile-time flag to choose kernel variant +#ifndef USE_2D_NVFP4_KERNEL +#define USE_2D_NVFP4_KERNEL 0 +#endif + +template +void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, cudaStream_t stream) { +#if CUDA_VERSION > 12080 + bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + + // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to + // return the transposed data. + // TODO(Frank): Is there a better way to do this? + bool return_transpose = output->has_columnwise_data(); + + using namespace nvfp4_transpose; + using namespace ptx; + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + if (return_transpose) { + NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated"); + } + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + NVTE_CHECK(rows % 32 == 0, + "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA + NVTE_CHECK(cols % 32 == 0, + "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_NUM; + + const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride_transpose = + return_transpose ? output->columnwise_scale_inv.shape[1] : 0; + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + nvfp4_scale_t *const scales_transpose_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const float *const amax_colwise_ptr = + reinterpret_cast(output->columnwise_amax.dptr); + + const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; + const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + using IType = bf16; + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + 4); + if (return_transpose) { + create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + } + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_data_mem = buff_size_aligned_out; + constexpr size_t out_data_transpose_mem = buff_size_aligned_out; + constexpr size_t out_scales_transpose_mem = buff_size_scales; + + constexpr size_t out_mem = out_data_mem + out_data_transpose_mem; + + constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = nvfp4_transpose_kernel; + + if constexpr (use_2d_quantization) { + kernel = nvfp4_transpose_kernel_2D; + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + });); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // CUDA_VERSION > 12080 +} +} // namespace transformer_engine + +#endif // !__HIP_PLATFORM_AMD__ +#endif // TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 7c38a337b..806b4d71d 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -16,6 +16,10 @@ #include #include +#if CUDA_VERSION >= 12080 +#include +#endif // CUDA_VERSION >= 12080 + namespace transformer_engine { namespace ptx { @@ -119,12 +123,16 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) { return __int_as_float(biased_exp << FP32_MANTISSA_BITS); } +#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL \ + ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM103_ALL))) + __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { #ifdef __HIP_PLATFORM_AMD__ #define __CUDA_ARCH_HAS_FEATURE__(x) 0 #endif -#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ - (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + uint16_t out; asm volatile( "{\n" @@ -227,18 +235,86 @@ struct alignas(2 * sizeof(T)) FPx2 { T y; }; +template +struct FPx4 { + T x1; + T x2; + T x3; + T x4; +}; + +template +struct Type2x {}; + +template <> +struct Type2x { + using type = float2; +}; + +template <> +struct Type2x { + using type = __nv_bfloat162; +}; + +template <> +struct Type2x { + using type = __half2; +}; + using floatx2 = FPx2; using bf16x2 = FPx2; using fp16x2 = FPx2; using fp8e4m3x2 = FPx2; using fp8e5m2x2 = FPx2; +using floatx4 = FPx4; +using bf16x4 = FPx4; +using fp16x4 = FPx4; +using fp8e4m3x4 = FPx4; +using fp8e5m2x4 = FPx4; + static_assert(sizeof(floatx2) == 8); static_assert(sizeof(bf16x2) == 4); static_assert(sizeof(fp16x2) == 4); static_assert(sizeof(fp8e4m3x2) == 2); static_assert(sizeof(fp8e5m2x2) == 2); +#if CUDA_VERSION >= 12080 +using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; +static_assert(sizeof(fp4e2m1x2) == 1); +static_assert(sizeof(fp4e2m1x4) == 2); +#endif // CUDA_VERSION >= 12080 + +// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1 + +// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6. + +// vt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on following architectures: +// sm_100a +// sm_101a +// sm_120a + +// When converting to .e2m1x2 data formats, the destination operand d has .b8 type. +// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format, +// and the converted values are packed in the destination operand d such that the value +// converted from input a is stored in the upper 4 bits of d and the value converted +// from input b is stored in the lower 4 bits of d. + +// SIMD like "Fused" cast + multiplication (x4) +#if CUDA_VERSION >= 12080 +template +__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23, + const float scale) { + const float x0 = static_cast(in01.x) * scale; + const float x1 = static_cast(in01.y) * scale; + const float x2 = static_cast(in23.x) * scale; + const float x3 = static_cast(in23.y) * scale; + out = fp4e2m1x4(make_float4(x0, x1, x2, x3)); +} +#endif // CUDA_VERSION >= 12080 + // SIMD like "Fused" cast + multiplication (x2) __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, const floatx2 &scale) { @@ -374,7 +450,7 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const "r"(reinterpret_cast(p2))); } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // namespace ptx diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index b243a8a0b..301748d06 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -108,7 +108,8 @@ .value("kFloat16", transformer_engine::DType::kFloat16) \ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ + .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ @@ -122,6 +123,10 @@ .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_Softmax_Type", pybind11::module_local()) \ + .value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \ + .value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \ + .value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \ pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) \ .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 799becaee..984ab7bc5 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -52,6 +52,39 @@ constexpr uint32_t THREADS_PER_WARP = 32; //////////////////////////////////////////////////////////////////////////////////////////////////// +// Device-side error +#define NVTE_DEVICE_ERROR(message) \ + do { \ + printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \ + __func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z, \ + (message)); \ + assert(0); \ + } while (false) + +// Device-side error on thread 0 +#define NVTE_DEVICE_THREAD0_ERROR(message) \ + do { \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \ + threadIdx.y == 0 && threadIdx.z == 0) { \ + NVTE_DEVICE_ERROR(message); \ + } \ + } while (false) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*) + return {a.x + b.x, a.y + b.y}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*) + a.x += b.x; + a.y += b.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template struct Sum { inline __device__ Sum() {} diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index ef2643359..576ca96ba 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -11,9 +11,10 @@ import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi if version.parse(jax.__version__) >= version.parse("0.5.0"): from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec import transformer_engine_jax @@ -30,7 +31,7 @@ should_apply_1x_fused_dbias_war_for_arch_l_100, NamedSharding, ) -from .quantization import _jax_dbias, _quantize_dbias_impl +from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( @@ -40,10 +41,6 @@ ScalingMode, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports __all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"] @@ -420,27 +417,28 @@ def shardy_sharding_rule( if version.parse(jax.__version__) < version.parse("0.5.0"): raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types - prefix = "ActLuPrimitive_" - x_rank = len(value_types[0].shape) + prefix = "ActLu_" + input_shape = value_types[0].shape + output_shape = input_shape[:-2] + input_shape[-1:] + # Here we pass len of output so that the scales are propagated correctly scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 + output_shape, unique_var=prefix + "x", flatten_axis=-1 ) - x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",) - out = (*x_axes[:-2], x_axes[-1]) - scale_inv = scale_rules.rowwise_rule + x_axes = scale_rules.input_spec + # Correct input spec with act dim + x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:] + out = scale_rules.input_spec colwise_out = (prefix + "out_colwise",) colwise_scale_inv = (prefix + "scale_inv_colwise",) if is_2x: colwise_scale_inv = scale_rules.colwise_rule if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = tuple( - multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2) - ) + colwise_out = multidim_transpose(out, transpose_axis=-1) else: colwise_out = out + colwise_scale_inv = scale_rules.colwise_rule - # amax is always a unit tensor. amax = (prefix + "amax",) return SdyShardingRule( @@ -448,7 +446,8 @@ def shardy_sharding_rule( x_axes, ("…1",), ), - (out, colwise_out, scale_inv, colwise_scale_inv, amax), + (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax), + **scale_rules.factor_sizes, ) @@ -895,26 +894,30 @@ def shardy_sharding_rule( if version.parse(jax.__version__) < version.parse("0.5.0"): raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types - prefix = "BaseDActLuDBiasQuantizePrimitive_" + prefix = "DActLuDBias_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 + value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 ) x_axes = scale_rules.input_spec dz_axes = (*x_axes[:-2], x_axes[-1]) out = x_axes + colwise_out = (prefix + "out_colwise",) + colwise_scale_inv = (prefix + "scale_inv_colwise",) if is_2x: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) else: colwise_out = out + colwise_scale_inv = scale_rules.colwise_rule dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) amax = (prefix + "amax",) return SdyShardingRule( (dz_axes, x_axes, ("…2",)), - (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), + (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + **scale_rules.factor_sizes, ) @@ -991,6 +994,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -999,6 +1003,7 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: If quantizer is None: @@ -1056,7 +1061,13 @@ def act_lu( activation_type=activation_type, quantizer=None, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + out, _ = _quantize_dbias_impl( + out, + is_dbias=False, + quantizer=quantizer, + dq_dtype=x.dtype, + amax_scope=amax_scope, + ) return out if isinstance(quantizer, DelayedScaleQuantizer): diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 45d3d8b59..881281806 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -14,7 +14,7 @@ import jax import jax.numpy as jnp -from jax import dtypes, lax +from jax import dtypes, lax, ffi from jax.sharding import PartitionSpec, NamedSharding if version.parse(jax.__version__) >= version.parse("0.5.0"): from jax.experimental.custom_partitioning import SdyShardingRule @@ -53,12 +53,6 @@ ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - - __all__ = [ "FusedAttnHelper", "fused_attn_fwd", diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 92c09bb68..c210a0046 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -16,16 +16,12 @@ from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching from jax._src import dispatch +from jax import ffi from .misc import is_hip_extension import jax import transformer_engine_jax -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - class BasePrimitive(metaclass=ABCMeta): """ @@ -182,7 +178,7 @@ def shardy_sharding_rule(*args): _primitive_registry = {} -def register_primitive(cls): +def register_primitive(cls, outer_only=False): """ Register a JAX primitive and add it to the internal registry. """ @@ -195,13 +191,14 @@ def register_primitive(cls): def name_of_wrapper_p(): return cls.name + "_wrapper" - inner_p = core.Primitive(cls.name) - dispatch.prim_requires_devices_during_lowering.add(inner_p) - inner_p.multiple_results = cls.multiple_results - inner_p.def_impl(partial(xla.apply_primitive, inner_p)) - inner_p.def_abstract_eval(cls.abstract) - mlir.register_lowering(inner_p, cls.lowering, platform="rocm" if is_hip_extension() else "cuda") - cls.inner_primitive = inner_p + if not outer_only: + inner_p = core.Primitive(cls.name) + dispatch.prim_requires_devices_during_lowering.add(inner_p) + inner_p.multiple_results = cls.multiple_results + inner_p.def_impl(partial(xla.apply_primitive, inner_p)) + inner_p.def_abstract_eval(cls.abstract) + mlir.register_lowering(inner_p, cls.lowering, platform="rocm" if is_hip_extension() else "cuda") + cls.inner_primitive = inner_p outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4ba581c66..aaac6e751 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -8,8 +8,10 @@ import math import operator from collections.abc import Iterable -from typing import Tuple, Sequence, Union +from dataclasses import dataclass from functools import partial, reduce +from typing import Tuple, Sequence, Union +from enum import Enum import warnings import jax @@ -18,8 +20,13 @@ from jax.sharding import NamedSharding, PartitionSpec from jax.experimental.custom_partitioning import SdyShardingRule -import transformer_engine_jax as tex -from transformer_engine_jax import get_num_compute_streams +from transformer_engine_jax import ( + get_num_compute_streams, + JAXX_Collective_Op, + get_device_compute_capability, + initialize_cgemm_communicator, + get_cgemm_num_max_streams, +) from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize @@ -42,11 +49,19 @@ is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, ) -from ..sharding import global_mesh_resource -from .misc import get_padded_spec +from .misc import get_padded_spec, is_all_reduce_in_float32 +from ..sharding import ( + global_mesh_resource, + tpsp_axis_size, + dp_or_fsdp_axis_size, +) __all__ = [ + "CollectiveOp", + "CollectiveOpSet", + "collective_gemm_bootstrap", + "noop_collective_op_set", "gemm", "grouped_gemm", "gemm_uses_jax_dot", @@ -69,7 +84,7 @@ def get_cublas_workspace_size_bytes() -> None: return 67_108_864 return 33_554_432 """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if tex.get_device_compute_capability(0) >= 90: + if get_device_compute_capability(0) >= 90: return 33_554_432 return 4_194_304 @@ -165,6 +180,161 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ return lhs_q, rhs_q +def collective_gemm_bootstrap( + num_total_devices, + num_devices_per_process, + process_id, + tensor_parallel_size, + num_max_streams=3, + compute_stream_priority=0, + communication_stream_priority=0, + num_sm_for_communication=2, + use_ce=True, + aggregate_all_gather=False, +): + """Initialize NCCL communicators for Collective GEMM operations. + + This function sets up the distributed communication infrastructure needed for + tensor parallel collective GEMM operations. It supports two main scenarios: + + 1. **Multi-device per process**: TP domain = single process + - Each process manages multiple GPUs (num_devices_per_process > 1) + - TP group consists of GPUs within the same process + - Example: 2 processes × 4 GPUs each = 8 total ranks, tp_size=4 + + 2. **Single device per process**: TP domain spans multiple processes + - Each process manages one GPU (num_devices_per_process = 1) + - TP group spans across multiple processes + - Example: 8 processes × 1 GPU each = 8 total ranks, tp_size=4 + + Args: + num_total_devices (int): Total number of ranks across all processes. + Must be divisible by num_devices_per_process. + num_devices_per_process (int): Number of GPUs per process. + - For multi-device: equals tp_size (e.g., 4 GPUs per process) + - For single-device: equals 1 (1 GPU per process) + process_id (int): Process identifier (0-based). + Must be in range [0, num_total_devices // num_devices_per_process). + tensor_parallel_size (int): Size of tensor parallel groups. + Must divide num_total_devices evenly. + num_max_streams (int, optional): Maximum number of CUDA streams for overlap. + Higher values enable more parallelism but use more GPU resources. Default: 3. + compute_stream_priority (int, optional): Priority for GEMM computation streams. + Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0. + communication_stream_priority (int, optional): Priority for NCCL communication streams. + Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0. + num_sm_for_communication (int, optional): Number of streaming multiprocessors + reserved for communication operations. Default: 2. + use_ce (bool, optional): Enable CUDA copy engines for memory transfers. + Can improve performance by offloading memory operations. Default: True. + aggregate_all_gather (bool, optional): Aggregate multiple small all-gather operations + into larger ones for better efficiency. Default: False. + + Raises: + AssertionError: If num_total_devices is not divisible by num_devices_per_process, + or if process_id is out of valid range. + AssertionError: If num_devices_per_process is not 1 (Temporary: only single device per process is supported for now) + RuntimeError: If NCCL initialization fails or if configuration + is invalid (e.g., insufficient GPUs). + + Example: + # Basic initialization (single device per process) + collective_gemm_bootstrap( + num_total_devices=8, + num_devices_per_process=1, + process_id=0, + tensor_parallel_size=4 + ) + + # Advanced configuration with custom performance settings + collective_gemm_bootstrap( + num_total_devices=8, + num_devices_per_process=1, + process_id=0, + tensor_parallel_size=4, + num_max_streams=5, # More parallelism + compute_stream_priority=1, # Lower compute priority + communication_stream_priority=0, # Higher comm priority + num_sm_for_communication=4, # More SMs for communication + use_ce=True, # Enable copy engines + aggregate_all_gather=True # Aggregate small operations + ) + + Note: + This function must be called after JAX distributed initialization + and before any collective GEMM operations. Each process should call + this function with its own unique process_id. + """ + + assert ( + num_devices_per_process == 1 and jax.local_device_count() == 1 + ), "Only single device per process is supported at the moment!" + assert num_total_devices % num_devices_per_process == 0, ( + f"Invalid num_total_devices={num_total_devices}," + f" num_devices_per_process={num_devices_per_process}" + ) + assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}" + initialize_cgemm_communicator( + num_total_devices, + num_devices_per_process, + process_id, + tensor_parallel_size, + num_max_streams, + compute_stream_priority, + communication_stream_priority, + num_sm_for_communication, + use_ce, + aggregate_all_gather, + ) + + +class CollectiveOp(Enum): + "Enum for Collective Type in Collective GEMM" + + NONE = JAXX_Collective_Op.NONE + ALL_GATHER = JAXX_Collective_Op.ALL_GATHER + REDUCE_SCATTER = JAXX_Collective_Op.REDUCE_SCATTER + + @property + def is_all_gather(self) -> bool: + """Check if AllGather""" + return self == CollectiveOp.ALL_GATHER + + @property + def is_reduce_scatter(self) -> bool: + """Check if ReduceScatter""" + return self == CollectiveOp.REDUCE_SCATTER + + @property + def is_none(self) -> bool: + """Check if None""" + return self == CollectiveOp.NONE + + +@dataclass(frozen=True) +class CollectiveOpSet: + """ + A set of CollectiveOp objects that provide complementary collective GEMM configurations for the Forward and Backward passes through Dense-layers. + """ + + forward: CollectiveOp + backward: CollectiveOp + + @staticmethod + def create(forward_collective_op: CollectiveOp): + """Create a set of CollectiveOp for forward and backward passes""" + if forward_collective_op.is_all_gather: + backward_collective_op = CollectiveOp.REDUCE_SCATTER + elif forward_collective_op.is_reduce_scatter: + backward_collective_op = CollectiveOp.ALL_GATHER + else: + backward_collective_op = CollectiveOp.NONE + return CollectiveOpSet(forward=forward_collective_op, backward=backward_collective_op) + + +noop_collective_op_set = CollectiveOpSet.create(forward_collective_op=CollectiveOp.NONE) + + @partial(jax.jit, static_argnums=(1, 2)) def swizzled_scale(scale_inv, flatten_axis, is_colwise): "Swizzle scale_inv via JAX transpose ops" @@ -187,7 +357,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12) + impl_static_args = 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 inner_primitive = None outer_primitive = None @@ -206,8 +376,12 @@ def abstract( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): - del use_split_accumulator + del use_split_accumulator, transpose_batch_sequence def _dims_are_consecutive(dims): if len(dims) <= 1: @@ -251,7 +425,7 @@ def _dims_are_consecutive(dims): ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." if ( scaling_mode != ScalingMode.MXFP8_1D_SCALING - and not tex.is_non_nt_fp8_gemm_supported() + and not is_fp8_gemm_with_all_layouts_supported() ): assert not lhs_is_transposed and rhs_is_transposed, ( "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " @@ -276,6 +450,19 @@ def _dims_are_consecutive(dims): out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + # Adjust output shape for comm+GEMM overlap + if not collective_op.is_none and not is_outer: # Inner abstract + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + overlap_out_shape = list(out_shape).copy() + if collective_op.is_all_gather: + overlap_out_shape[1] *= tpsp_axis_size() + else: # RS + overlap_out_shape[sequence_dim] = ( + overlap_out_shape[sequence_dim] // tpsp_axis_size() + ) + assert out_dtype == jnp.bfloat16, f"Unsupported out_dtype={out_dtype}" + output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype) + # Validate bias bias_shape = (0,) bias_dtype = out_dtype @@ -315,9 +502,12 @@ def _dims_are_consecutive(dims): pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) # Declare cuBLAS workspace + workspace_size = get_cublas_workspace_size_bytes() + if not collective_op.is_none: + workspace_size *= get_cgemm_num_max_streams() # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. - workspace_size = get_cublas_workspace_size_bytes() + 256 + workspace_size += 256 workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) return output, bias_grad, pre_gelu_out, workspace @@ -343,8 +533,12 @@ def lowering( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): - del out_dtype + del out_dtype, transpose_batch_sequence, sequence_dim, is_outer lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) @@ -363,6 +557,7 @@ def lowering( "fuse_gelu": fuse_gelu, "grad": grad, "use_split_accumulator": use_split_accumulator, + "collective_op": int(collective_op.value), } operand_output_aliases = {} @@ -391,6 +586,10 @@ def impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): if scaling_mode.is_1d_block_scaling(): lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) @@ -409,7 +608,34 @@ def impl( lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) - outputs = GemmPrimitive.inner_primitive.bind( + # Alter lhs blocks so that CGEMM RS outputs correctly + if ( + collective_op.is_reduce_scatter + and not transpose_batch_sequence + and not is_outer + and not lhs.shape[0] == 1 + ): + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + original_shape = lhs.shape + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = lhs.reshape( + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + tpsp_axis_size(), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) + lhs = reordered.reshape(original_shape) + + (output, bias_grad, pre_gelu_out, _) = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, rhs, @@ -423,8 +649,39 @@ def impl( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + collective_op=collective_op, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=sequence_dim, + is_outer=is_outer, ) - return outputs[:-1] # discard workspace array + # Alter output blocks for CGEMM AG + if ( + collective_op.is_all_gather + and not transpose_batch_sequence + and not is_outer + and not output.shape[0] == 1 + ): + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + original_shape = output.shape + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = output.reshape( + tpsp_axis_size(), + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim)) + output = reordered.reshape(original_shape) + + return [output, bias_grad, pre_gelu_out] @staticmethod def outer_impl( @@ -441,6 +698,10 @@ def outer_impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): return GemmPrimitive.impl( lhs, @@ -456,6 +717,10 @@ def outer_impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ) @staticmethod @@ -469,7 +734,12 @@ def batcher( fuse_gelu, grad, use_split_accumulator, + collective_op, + transpose_batch_sequence, + sequence_dim, + is_outer, ): + del transpose_batch_sequence, sequence_dim, is_outer assert GemmPrimitive.outer_primitive is not None lhs_bdims, _, rhs_bdims, *_ = batch_dims @@ -497,6 +767,10 @@ def batcher( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + collective_op=collective_op, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=sequence_dim, + is_outer=is_outer, ), (out_bdims, bias_bdims, pre_gelu_bdims), ) @@ -505,6 +779,8 @@ def batcher( def _parse_operand_output_specs( arg_infos, contracting_dims, + transpose_batch_sequence, + collective_op, ): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) @@ -512,14 +788,12 @@ def _parse_operand_output_specs( # Ensure that tensor sequence parallelism is not used via setting tp_resource if gsr.tp_resource is not None: - for i in range(len(lhs_specs) - 1): - if lhs_specs[i] == gsr.tp_resource and lhs_specs[i + 1] == gsr.tp_resource: - warnings.warn( - "Tensor sequence parallelism is detected as" - f" tp_resource='{gsr.tp_resource}' appears twice consecutively in" - f" lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource for" - " tensor sequence parallelism to avoid potential issues." - ) + if gsr.tp_resource in lhs_specs: + warnings.warn( + "Tensor sequence parallelism is detected as tp_resource='{gsr.tp_resource}'" + " appears in lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource" + " for tensor sequence parallelism to avoid potential issues." + ) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) @@ -541,10 +815,43 @@ def _parse_operand_output_specs( assert reduce_spec is None, "Multiple reduce dimension is detected!" reduce_spec = l + sequence_dim = None + + # Find sequence dimension in lhs_specs if tensor sequence parallel is enabled + # We only do CollectiveGemm AG on the x or dY thus they always the LHS and have sequence dim + if collective_op.is_all_gather: + try: + tpsp_idx = lhs_specs.index(gsr.tpsp_resource) + except ValueError as exc: + raise ValueError( + f"tpsp_resource '{gsr.tpsp_resource}' is not found in lhs_specs: {lhs_specs}." + " Please check your sharding configuration." + ) from exc + sequence_dim = tpsp_idx + assert (sequence_dim == 1) ^ transpose_batch_sequence, ( + "CollectiveGEMM supports only (sequence_dim=1 and transpose_batch_sequence=False)" + " or (sequence_dim=0 and transpose_batch_sequence=True). Received:" + f" sequence_dim={sequence_dim}," + f" transpose_batch_sequence={transpose_batch_sequence}." + ) + + elif collective_op.is_reduce_scatter: + assert reduce_spec == gsr.tpsp_resource, ( + "Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got" + f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}" + ) + sequence_dim = int(not transpose_batch_sequence) + if reduce_spec is not None: # Other non-reduce cdims (if exists) need to be unsharded lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs) - rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) + # Only do AG Sequence dim if not Overlap + if collective_op.is_all_gather: + rhs_cspecs = tuple( + s if s in (reduce_spec, gsr.tpsp_resource) else None for s in rhs_cspecs + ) + else: + rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) # Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden # No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim. @@ -564,13 +871,31 @@ def _parse_operand_output_specs( for spec in rhs_non_cspecs ) - # Non-contracting dims of LHS to be gathered along the SP axis. - # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for - # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. - lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs) + # Only do AG Sequence dim if not Overlap + if not collective_op.is_all_gather: + # Non-contracting dims of LHS to be gathered along the SP axis. + # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for + # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. + lhs_non_cspecs = tuple( + None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs + ) out_specs = lhs_non_cspecs + rhs_non_cspecs + # Only do AG Sequence dim if not Overlap RS + if collective_op.is_all_gather: + assert sequence_dim <= len( + lhs_non_cspecs + ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + out_specs = out_specs[:sequence_dim] + (None,) + out_specs[sequence_dim + 1 :] + elif collective_op.is_reduce_scatter: + assert sequence_dim <= len( + lhs_non_cspecs + ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + out_specs = ( + out_specs[:sequence_dim] + (gsr.tpsp_resource,) + out_specs[sequence_dim + 1 :] + ) + # specs = merge(cspecs, non_cspecs) lhs_specs, rhs_specs = map( lambda cdims, cspecs, non_cspecs: ( @@ -585,10 +910,14 @@ def _parse_operand_output_specs( bias_specs = tuple(list(rhs_non_cspecs).copy()) gelu_specs = tuple(list(out_specs).copy()) + if not collective_op.is_none: + assert sequence_dim >= 0, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + return ( (lhs_specs, rhs_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs), reduce_spec, + sequence_dim, ) @staticmethod @@ -600,6 +929,10 @@ def infer_sharding_from_operands( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, arg_infos, result_infos, @@ -608,11 +941,16 @@ def infer_sharding_from_operands( out_dtype, scaling_mode, grad, + use_split_accumulator, + result_infos, + is_outer, + sequence_dim, ) - del use_split_accumulator, result_infos - (_, (out_specs, dbias_specs, pre_gelu_specs), _) = ( - GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) + (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( + GemmPrimitive._parse_operand_output_specs( + arg_infos, contracting_dims, transpose_batch_sequence, collective_op + ) ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -637,20 +975,29 @@ def partition( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, arg_infos, result_infos, ): - del result_infos + del result_infos, is_outer, sequence_dim ( (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (out_specs, dbias_specs, pre_gelu_specs), reduce_spec, - ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) + inferred_sequence_dim, + ) = GemmPrimitive._parse_operand_output_specs( + arg_infos, + contracting_dims, + transpose_batch_sequence, + collective_op, + ) - # Assemble argument shardings - # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. + # Block scale inverses match their operands, but tensor scale inverses are unsharded. none_sharding = NamedSharding(mesh, PartitionSpec(None)) lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) @@ -699,11 +1046,19 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=inferred_sequence_dim, + is_outer=False, + collective_op=collective_op, ) - # All-Reduce GEMM output - if reduce_spec is not None: - outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + if reduce_spec is not None and not collective_op.is_reduce_scatter: + if is_all_reduce_in_float32(): # For unittest only + outputs[0] = jax.lax.psum(outputs[0].astype(jnp.float32), reduce_spec).astype( + out_dtype + ) + else: + outputs[0] = jax.lax.psum(outputs[0], reduce_spec) return outputs @@ -718,14 +1073,24 @@ def shardy_sharding_rule( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, operand_types, result_types, ): del out_dtype, grad, use_split_accumulator - del mesh, result_types + del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer - prefix = "GemmPrimitive_" + if not collective_op.is_none: + raise NotImplementedError( + "CollectiveGEMM with Shardy propagation is not supported yet! Please turn off" + " Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false" + ) + + prefix = "Gemm_" warnings.warn( "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now," @@ -759,13 +1124,8 @@ def _generate_operand_rules(name, ndim, cdims): lhs_scale_specs = ("…1",) rhs_scale_specs = ("…2",) if scaling_mode.is_1d_block_scaling(): - # Shardy rules for MXFP8 scales cannot be related to the operands because of the - # global-unpadding and local-padding workflow. This can potentially insert expensive - # re-shards in the partition call later if the scales are not already sharded correctly. - lhs_scale_specs, rhs_scale_specs = map( - lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs), - (lhs_specs, rhs_specs), - ) + lhs_scale_specs = lhs_specs + rhs_scale_specs = rhs_specs lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims) rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) @@ -810,6 +1170,8 @@ def _te_gemm( fuse_gelu: bool = False, grad: bool = False, use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, + transpose_batch_sequence: bool = False, + collective_op: CollectiveOp = CollectiveOp.NONE, ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands @@ -818,6 +1180,7 @@ def _te_gemm( lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) scaling_mode = ScalingMode.NO_SCALING + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) @@ -877,6 +1240,10 @@ def _te_gemm( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=-1, + is_outer=True, + collective_op=collective_op, ) @@ -1194,6 +1561,8 @@ def gemm( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, + transpose_batch_sequence: bool = False, + collective_op: CollectiveOp = CollectiveOp.NONE, **kwargs, ) -> Tuple[jnp.ndarray, ...]: r"""General matrix multiplication with optional quantization. @@ -1227,8 +1596,11 @@ def gemm( TE's custom call to cuBLAS GEMM. use_split_accumulator: bool, default = True Enable promoting some intermediate sums to higher precision when accumulating the result in - the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only - supported with TE's custom call to cuBLAS GEMM. + the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. + transpose_batch_sequence: bool, default = False + Transpose the batch and sequence dimensions of the input tensor. + collective_op: CollectiveOp, default = CollectiveOp.NONE + Collective operation type for collective GEMM. Returns ------- @@ -1272,6 +1644,7 @@ def gemm( "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) + assert collective_op.is_none, "JAX GEMM does not support collective GEMM" return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) outputs = _te_gemm( @@ -1280,6 +1653,8 @@ def gemm( lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, contracting_dims=contracting_dims, + transpose_batch_sequence=transpose_batch_sequence, + collective_op=collective_op, **kwargs, ) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index e7464a6da..886297df6 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -299,3 +299,11 @@ def duplicate_with_new_description(self, desc: str): Create a new NamedSharding with the same mesh and spec but with a new description. """ return NamedSharding(self.mesh, self.spec, desc=desc) + + +@functools.lru_cache(maxsize=1) +def is_all_reduce_in_float32(): + """ + Check if all-reduce is in float32 + """ + return os.getenv("NVTE_JAX_ALL_REDUCE_IN_FP32", "0") == "1" diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 89731e24a..5e0ebdeec 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -13,9 +13,10 @@ import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi if version.parse(jax.__version__) >= version.parse("0.5.0"): from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import SdyShardingRule from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec from .misc import is_hip_extension @@ -32,7 +33,7 @@ NamedSharding, get_cudnn_version, ) -from .quantization import _quantize_dbias_impl +from .quantization import _quantize_dbias_impl, AmaxScope from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( @@ -42,11 +43,6 @@ ScalingMode, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = [ "layernorm_fwd", @@ -593,9 +589,9 @@ def shardy_sharding_rule( result_types, ) - prefix = "NormFwdPrimitive_" + prefix = "NormFwd_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1 + value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1 ) x_axes = scale_rules.input_spec @@ -616,6 +612,7 @@ def shardy_sharding_rule( mu, rsigma, ), + **scale_rules.factor_sizes, ) @@ -892,6 +889,7 @@ def layernorm_fwd( zero_centered_gamma: bool, epsilon: float, quantizer: Optional[Quantizer], + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]: """Layer normalization forward pass with optional quantization. @@ -905,6 +903,7 @@ def layernorm_fwd( zero_centered_gamma: If True, gamma is zero-centered. epsilon: Small constant for numerical stability. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A tuple containing: @@ -964,7 +963,13 @@ def layernorm_fwd( epsilon=epsilon, quantizer=None, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + out, _ = _quantize_dbias_impl( + out, + is_dbias=False, + quantizer=quantizer, + dq_dtype=x.dtype, + amax_scope=amax_scope, + ) return out, mu, rsigma is_2x2x = quantizer.is_2x2x() @@ -1094,6 +1099,7 @@ def rmsnorm_fwd( zero_centered_gamma: bool, epsilon: float, quantizer: Optional[Quantizer], + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]: """Root mean square normalization forward pass with optional quantization. @@ -1105,6 +1111,7 @@ def rmsnorm_fwd( zero_centered_gamma: If True, gamma is zero-centered. epsilon: Small constant for numerical stability. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A tuple containing: @@ -1165,7 +1172,11 @@ def rmsnorm_fwd( quantizer=None, ) out, _ = _quantize_dbias_impl( - out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype + out.data, + is_dbias=False, + quantizer=quantizer, + dq_dtype=x.dtype, + amax_scope=amax_scope, ) return out, rsigma @@ -1290,6 +1301,7 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], + amax_scope: AmaxScope = AmaxScope.LOCAL, ): """Common wrapper for normalization forward pass. @@ -1306,6 +1318,7 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A tuple containing: @@ -1323,12 +1336,27 @@ def normalization_fwd( zero_centered_gamma is not supported if norm_type is 'rmsnorm'. """ if norm_type == "layernorm": - output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) + output, mu, rsigma = layernorm_fwd( + x, + gamma, + beta, + zero_centered_gamma, + epsilon, + quantizer, + amax_scope=amax_scope, + ) elif norm_type == "rmsnorm": assert ( not zero_centered_gamma ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" - output, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer) + output, rsigma = rmsnorm_fwd( + x, + gamma, + zero_centered_gamma, + epsilon, + quantizer, + amax_scope=amax_scope, + ) mu = None else: raise ValueError(f"{norm_type=} is not supported.") diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 78780ff9c..a7fedfeb7 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -8,13 +8,16 @@ from functools import reduce from typing import Tuple, Optional, Union import math +from enum import Enum from packaging import version + import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi if version.parse(jax.__version__) >= version.parse("0.5.0"): from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec import transformer_engine_jax @@ -30,7 +33,12 @@ get_min_device_compute_capability, NamedSharding, ) -from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp +from ..sharding import ( + all_reduce_max_along_all_axes_except_PP, + all_reduce_sum_along_dp_fsdp, + global_mesh_resource, + lax_paral_op, +) from ..quantize import ( ScaledTensor2x, ScaledTensor, @@ -44,11 +52,6 @@ NoScaleTensor, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] @@ -499,9 +502,9 @@ def shardy_sharding_rule( raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") del out_dtype, scale_dtype, is_outer, mesh, result_types - prefix = "BaseDBiasQuantizePrimitive_" + prefix = "DBiasQuantize_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), + value_types[0].shape, unique_var=prefix + "x", flatten_axis=flatten_axis, ) @@ -523,6 +526,7 @@ def shardy_sharding_rule( return SdyShardingRule( (x_axes, ("…1",), amax), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + **scale_rules.factor_sizes, ) @@ -537,6 +541,126 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive): """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" +class AmaxScope(Enum): + """ + Amax Scope Enum + """ + + LOCAL = 1 + TPSP = 2 + FSDP = 3 + + +class AmaxCalculationPrimitive(BasePrimitive): + """ + Amax Calculation Primitive with custom_partitioning + """ + + name = "jax_local_amax" + multiple_results = False + impl_static_args = (1,) # amax_scope + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + *, + amax_scope, + ): + """ + amax calcuation abstract + """ + del amax_scope + + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + + out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + return out_aval + + @staticmethod + def impl( + x, + amax_scope, + ): + """ + amax calcuation implementation + """ + del amax_scope + amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) + return amax + + @staticmethod + def infer_sharding_from_operands( + amax_scope, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation infer_sharding_from_operands + """ + del (amax_scope, arg_infos, result_infos) # Unused. + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="AmaxCalculationPrimitive.out_sharding", + ) + return amax_sharding + + @staticmethod + def partition( + amax_scope, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation partition + """ + del result_infos + + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="AmaxCalculationPrimitive.out_sharding", + ) + + def sharded_impl(x): + amax = AmaxCalculationPrimitive.impl( + x, + amax_scope=amax_scope, + ) + if amax_scope is AmaxScope.TPSP: # Run AR across TP/SP + gmesh = global_mesh_resource() + amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tp_resource, mesh) + amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) + + if amax_scope is AmaxScope.FSDP: # Run AR across FSDP + gmesh = global_mesh_resource() + amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) + + return amax + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + return mesh, sharded_impl, amax_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(amax_scope, mesh, value_types, result_types): + """ + amax calcuation shardy_sharding_rule + """ + del amax_scope, mesh, result_types + prefix = "AmaxCal" + input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) + output_spec = (f"{prefix}_amax",) + return SdyShardingRule((input_spec,), (output_spec,)) + + +register_primitive(AmaxCalculationPrimitive, outer_only=True) + + def _jax_quantize( x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 ): @@ -583,6 +707,7 @@ def _quantize_dbias_impl( is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -639,7 +764,10 @@ def _quantize_dbias_impl( # until the tensor is dequantized (e.g. in the GEMM). amax = x.amax if amax is None: - amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,)) + amax = AmaxCalculationPrimitive.outer_primitive.bind( + x.data, + amax_scope=amax_scope, + ) scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale @@ -711,6 +839,7 @@ def quantize( x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -721,6 +850,7 @@ def quantize( flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. is None. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A ScaledTensor containing the quantized input tensor. @@ -729,6 +859,7 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, + amax_scope=amax_scope, ) return out @@ -738,6 +869,7 @@ def quantize_dbias( quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -748,6 +880,8 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. + Returns: A tuple containing: @@ -761,6 +895,7 @@ def quantize_dbias( quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, + amax_scope=amax_scope, ) diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index 43cb11a08..575a2dd3a 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -6,22 +6,16 @@ from functools import partial, reduce import operator import warnings -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi from jax.sharding import PartitionSpec, NamedSharding from .base import BasePrimitive, register_primitive from .misc import get_padded_spec, check_valid_batch_dims from ..softmax import SoftmaxType -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = [ "scaled_softmax_fwd", diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 453a4202b..edb7f14f3 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -17,6 +17,9 @@ #include #include #include +#ifndef USE_ROCM +#include +#endif #include #include @@ -36,9 +39,6 @@ #include "transformer_engine/activation.h" #include "transformer_engine/multi_stream.h" -// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace -XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); - namespace transformer_engine { namespace jax { @@ -47,16 +47,20 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D // Activation XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, JAXX_Scaling_Mode scaling_mode, bool is_2x); // Normalization +XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, @@ -125,6 +129,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); @@ -140,4 +145,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine +// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); + #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 17fa9906b..b2b3db52c 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -148,6 +148,30 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Attr("is_2x"), FFI_CudaGraph_Traits); +Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type amax_buf, int64_t act_enum, + JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { + return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, + colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, + act_enum, scaling_mode, is_2x_int); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Attr("act_enum") + .Attr("scaling_mode") + .Attr("is_2x")); + pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, JAXX_Scaling_Mode scaling_mode, bool is_2x) { @@ -410,5 +434,39 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("is_2x") .Attr("is_dbias"), FFI_CudaGraph_Traits); + +Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, + Buffer_Type act_input_buf, Buffer_Type scale_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, + JAXX_Scaling_Mode scaling_mode, int64_t act_enum, + bool is_2x, bool is_dbias) { + return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, + act_input_buf, scale_buf, output_buf, colwise_output_buf, + scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, + workspace_buf, scaling_mode, act_enum, is_2x, is_dbias); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, + DActLuDBiasQuantizeInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // act input + .Arg() // scale + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Ret() // dbias + .Ret() // wkspace + .Attr("scaling_mode") + .Attr("act_enum") + .Attr("is_2x") + .Attr("is_dbias")); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 342953746..9ea8ba4af 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -20,10 +20,11 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right) { + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); return backend; } @@ -160,6 +161,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); + auto dummy_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); @@ -186,28 +190,30 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen"); nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); + qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), - dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { + nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), + ragged_offset_tensor.data(), dummy_page_table_tensor.data(), + dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, + kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, + query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported QKVLayout."); } @@ -279,10 +285,15 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + + auto dummy_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; + auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -297,12 +308,12 @@ static void FusedAttnForwardImpl( if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, is_training, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, workspace_tensor.data(), stream); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), + o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto kv_shape = @@ -310,12 +321,13 @@ static void FusedAttnForwardImpl( auto q_tensor = TensorWrapper(q, q_shape, dtype); auto kv_tensor = TensorWrapper(k, kv_shape, dtype); nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), - dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), + dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), + q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, + workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; @@ -324,12 +336,13 @@ static void FusedAttnForwardImpl( auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), + k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -461,6 +474,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0 min_num_segments = input_batch * max_segments_per_seq; } + auto dummy_d_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) { // the last one is the largest which will be the returned workspace size auto q_cu_seqlens_tensor = @@ -470,37 +486,38 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, - deterministic, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_bwd_qkvpacked( + qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, - kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, query_workspace_tensor.data(), - nullptr); + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, - query_workspace_tensor.data(), nullptr); + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, deterministic, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -532,14 +549,17 @@ static void FusedAttnBackwardImpl( /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); + auto dummy_d_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); @@ -557,10 +577,11 @@ static void FusedAttnBackwardImpl( s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, deterministic, + workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto kv_shape = @@ -579,10 +600,11 @@ static void FusedAttnBackwardImpl( s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), + q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, deterministic, + workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; @@ -603,11 +625,12 @@ static void FusedAttnBackwardImpl( s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, workspace_tensor.data(), stream); + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, + kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, deterministic, + workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp new file mode 100644 index 000000000..2f226c427 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -0,0 +1,261 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + #ifndef USE_ROCM + +#include "cgemm_helper.h" + +#include "common/util/system.h" +#include "nccl.h" + +namespace transformer_engine { +namespace jax { + +ncclUniqueId CommunicatorHandler::coordinate_nccl_unique_id(const std::string &id_type) { + ncclUniqueId unique_id; + + int tp_domain_id = get_tp_domain_id(); + bool is_tp_leader = (get_local_device_id_within_tp_domain() == 0); + + pid_t pgid = getpgid(0); + + std::string base_path = getenv("NVTE_JAX_NCCL_FILE_PATH", "/tmp"); + std::string id_file = base_path + "/nccl_" + id_type + "_unique_id_pgid_" + std::to_string(pgid) + + "_" + std::to_string(num_total_devices) + "_" + std::to_string(tp_size) + + "_domain_" + std::to_string(tp_domain_id) + ".bin"; + + if (is_tp_leader) { + NVTE_CHECK_NCCL(ncclGetUniqueId(&unique_id)); + + // Write the ID to a temporary file + std::ofstream file(id_file, std::ios::binary); + NVTE_CHECK(file.is_open(), "Failed to create NCCL unique ID file: ", id_file); + file.write(reinterpret_cast(&unique_id), sizeof(ncclUniqueId)); + file.close(); + } else { + // Wait for the ID file to be created and read it + int attempts = 0; + const int max_attempts = 100; + while (attempts < max_attempts) { + std::ifstream file(id_file, std::ios::binary); + if (file.is_open()) { + file.read(reinterpret_cast(&unique_id), sizeof(ncclUniqueId)); + if (file.gcount() == sizeof(ncclUniqueId)) { + file.close(); + break; + } + file.close(); + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + attempts++; + } + NVTE_CHECK(attempts < max_attempts, + "Timeout waiting for " + id_type + " NCCL unique ID file from leader: ", id_file); + } + + if (is_tp_leader) { + _nccl_id_file_name.push_back(id_file); + } + + return unique_id; +} + +void CommunicatorHandler::init(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size) { + // Validate inputs + NVTE_CHECK(num_devices_per_process == 1, + "num_devices_per_process must be == 1, got num_devices_per_process=", + num_devices_per_process); + NVTE_CHECK(num_total_devices >= 1, + "num_total_devices must be >= 1, got num_total_devices=", num_total_devices); + NVTE_CHECK( + num_total_devices % num_devices_per_process == 0, + "num_total_devices must be divisible by num_devices_per_process, got num_total_devices=", + num_total_devices, ", num_devices_per_process=", num_devices_per_process); + + // Validate TP size + NVTE_CHECK(tp_size > 0, "tp_size must be > 0, got tp_size=", tp_size); + NVTE_CHECK(num_total_devices % tp_size == 0, + "num_total_devices must be divisible by tp_size, got num_total_devices=", + num_total_devices, ", tp_size=", tp_size); + + auto &handler = get(false); + handler.num_total_devices = num_total_devices; + handler.num_devices_per_process = num_devices_per_process; + handler.process_id = process_id; + handler.num_processes = num_total_devices / num_devices_per_process; + handler.tp_size = tp_size; + handler.tp_num_domains = num_total_devices / tp_size; + + // Initialize vectors with the correct size + handler.local_device_ids_within_process.resize(num_devices_per_process); + handler.local_device_ids_within_tp_domain.resize(num_devices_per_process); + handler.tp_domain_ids.resize(num_devices_per_process); + handler.global_device_ids.resize(num_devices_per_process); + handler.tp_comms.resize(num_devices_per_process); + + NVTE_CHECK(0 <= process_id && process_id < handler.num_processes, + "Invalid process_id=", process_id, ", which is out of range [0, ", + handler.num_processes, ")"); + + // Initialize local devices and calculate their global device IDs and TP topology + for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) { + // Use the device that JAX has already assigned to this process + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + handler.local_device_ids_within_process[local_idx] = current_device; + handler.global_device_ids[local_idx] = process_id * num_devices_per_process + local_idx; + + // Calculate TP-related values for this device + int global_device_id = handler.global_device_ids[local_idx]; + if (num_devices_per_process == tp_size) { + // Scenario 1: Multi-device per process - TP domain = single process + handler.local_device_ids_within_tp_domain[local_idx] = local_idx; + handler.tp_domain_ids[local_idx] = process_id; + } else { + // Scenario 2: Single device per process - TP domain spans multiple processes + handler.local_device_ids_within_tp_domain[local_idx] = global_device_id % tp_size; + handler.tp_domain_ids[local_idx] = global_device_id / tp_size; + } + } + + ncclUniqueId tp_id = handler.coordinate_nccl_unique_id("tp"); + + NVTE_CHECK_NCCL(ncclGroupStart()); + for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) { + NVTE_CHECK_CUDA(cudaSetDevice(handler.local_device_ids_within_process[local_idx])); + int tp_local_rank = handler.local_device_ids_within_tp_domain[local_idx]; + NVTE_CHECK_NCCL( + ncclCommInitRank(&handler.tp_comms[local_idx], handler.tp_size, tp_id, tp_local_rank)); + } + NVTE_CHECK_NCCL(ncclGroupEnd()); + + // Allocate device memory for barrier operations + NVTE_CHECK_CUDA(cudaMalloc(&handler._device_barrier, sizeof(int))); + + handler._initialize = true; + + // Bootstrap UB via creating a dummy CommOverlapP2PBase object + std::vector buffer_shape{1, 1}; + auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32, + JAXX_Collective_Op::ALL_GATHER); +} + +void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size, int num_max_streams, int gemm_priority, + int comm_priority, int num_comm_sm, bool use_ce, + bool aggregate_ag) { + auto &config = CgemmConfig::get(false); + config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag); + auto &handler = CommunicatorHandler::get(false); + handler.init(num_total_devices, num_devices_per_process, process_id, tp_size); +} + +int GetCgemmNumMaxStreams() { + auto &config = CgemmConfig::get(); + return config.num_max_streams; +} + +CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector buffer_shape, + DType dtype, + JAXX_Collective_Op collective_op) { + auto &comm_handler = CommunicatorHandler::get(); + auto &cgemm_config = CgemmConfig::get(); + + int device_idx = comm_handler.get_local_device_idx_for_current_device(); + int64_t plan_id = 0; + hash_combine(plan_id, buffer_shape[0], buffer_shape[1], static_cast(dtype), + static_cast(collective_op), comm_handler.tp_size, cgemm_config.num_max_streams, + cgemm_config.gemm_priority, cgemm_config.comm_priority, cgemm_config.num_comm_sm, + cgemm_config.use_ce, cgemm_config.aggregate_ag, device_idx); + + auto it = plan_map.find(plan_id); + if (it != plan_map.end()) { + return it->second.get(); + } + + if (comm_handler.num_devices_per_process == comm_handler.tp_size) { + // Multi-device per process + } else if (comm_handler.num_devices_per_process == 1) { + // Single device per process + NVTE_CHECK(comm_handler.num_total_devices % comm_handler.tp_size == 0, + "For single device per process, num_total_devices must be divisible by tp_size, " + "got num_total_devices=", + comm_handler.num_total_devices, ", tp_size=", comm_handler.tp_size); + } else { + NVTE_ERROR("Unsupported TP configuration: num_devices_per_process=", + comm_handler.num_devices_per_process, ", tp_size=", comm_handler.tp_size, + ". Supported scenarios: " + "(1) num_devices_per_process == tp_size (multi-device per process), " + "(2) num_devices_per_process == 1 (single device per process)"); + } + + std::unique_ptr executor; + executor = std::make_unique( + buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, + comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, + comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size, + comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op), + cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority, + cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/, + cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag); + + CommOverlapCore *executor_ptr = executor.get(); + plan_map[plan_id] = std::move(executor); + return executor_ptr; +} + +void CommunicatorHandler::nccl_device_barrier_impl(ExtComm) { + NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using barrier"); + + int device_idx = get_local_device_idx_for_current_device(); + ncclComm_t tp_comm = tp_comms[device_idx]; + + NVTE_CHECK_NCCL( + ncclAllReduce(_device_barrier, _device_barrier, 1, ncclInt, ncclSum, tp_comm, nullptr)); + cudaDeviceSynchronize(); +} + +void CommunicatorHandler::nccl_allgather_impl(void *output_buf, size_t output_bytes, + void *input_buf, size_t input_bytes, ExtComm) { + NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using allgather"); + + int device_idx = get_local_device_idx_for_current_device(); + ncclComm_t tp_comm = tp_comms[device_idx]; + + size_t expected_output_bytes = input_bytes * tp_size; + NVTE_CHECK(output_bytes == expected_output_bytes, "TP allgather buffer size mismatch: expected ", + expected_output_bytes, ", got ", output_bytes); + + NVTE_CHECK_NCCL(ncclAllGather(input_buf, output_buf, input_bytes, ncclChar, tp_comm, nullptr)); + cudaDeviceSynchronize(); +} + +CommunicatorHandler::CommunicatorHandler() : _device_barrier(nullptr) { + allgather_func = [this](void *output_buf, size_t output_bytes, void *input_buf, + size_t input_bytes, ExtComm comm) { + this->nccl_allgather_impl(output_buf, output_bytes, input_buf, input_bytes, comm); + }; + barrier_func = [this](ExtComm comm) { this->nccl_device_barrier_impl(comm); }; +} + +CommunicatorHandler::~CommunicatorHandler() { + if (_initialize && !tp_comms.empty()) { + for (auto &comm : tp_comms) { + if (comm != nullptr) { + ncclCommDestroy(comm); + } + } + } + if (_device_barrier) cudaFree(_device_barrier); + + for (const auto &file_path : _nccl_id_file_name) { + std::remove(file_path.c_str()); + } +} + +} // namespace jax +} // namespace transformer_engine +#endif diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.h b/transformer_engine/jax/csrc/extensions/cgemm_helper.h new file mode 100644 index 000000000..84b2b8154 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.h @@ -0,0 +1,189 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ +#define TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../extensions.h" +#include "common/comm_gemm_overlap/userbuffers/userbuffers.h" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" +#include "transformer_engine/comm_gemm_overlap.h" + +namespace transformer_engine { +namespace jax { + +// Configuration singleton for CGEMM parameters +class CgemmConfig { + public: + int num_max_streams; + int gemm_priority; + int comm_priority; + int num_comm_sm; + bool use_ce; + bool aggregate_ag; + + static void init(int _num_max_streams, int _gemm_priority, int _comm_priority, int _num_comm_sm, + bool _use_ce, bool _aggregate_ag) { + auto &config = get(false); + config._initialized = true; + config.num_max_streams = _num_max_streams; + config.gemm_priority = _gemm_priority; + config.comm_priority = _comm_priority; + config.num_comm_sm = _num_comm_sm; + config.use_ce = _use_ce; + config.aggregate_ag = _aggregate_ag; + } + + static CgemmConfig &get(bool is_initialized = true) { + static thread_local CgemmConfig instance; + NVTE_CHECK( + instance._initialized == is_initialized, + "CgemmConfig must be initialized before using it, got is_initialized=", is_initialized); + return instance; + } + + CgemmConfig(const CgemmConfig &) = delete; + CgemmConfig &operator=(const CgemmConfig &) = delete; + + private: + CgemmConfig() = default; + ~CgemmConfig() = default; + bool _initialized = false; +}; + +// Forward declaration +class CollectiveGemmPlanRegistry; + +// NCCL communicator handler for collective GEMM operations +// Support both single process single device AND single process multi device +// Two scenarios: +// 1. Single process multiple devices: TP domain = process (num_devices_per_process == tp_size) +// 2. Single process single device: TP domain spans processes (num_devices_per_process == 1) +class CommunicatorHandler { + public: + int num_total_devices = -1; + int num_devices_per_process = -1; + int process_id = -1; + int num_processes = -1; + + int tp_size = -1; + int tp_num_domains = -1; + std::vector local_device_ids_within_tp_domain; + std::vector tp_domain_ids; + std::vector tp_comms; + + std::vector local_device_ids_within_process; + std::vector global_device_ids; + + int get_global_rank() const { + int device_idx = get_local_device_idx_for_current_device(); + return global_device_ids[device_idx]; + } + + void nccl_device_barrier_impl(ExtComm); + void nccl_allgather_impl(void *output_buf, size_t output_bytes, void *input_buf, + size_t input_bytes, ExtComm); + + ncclComm_t get_comm_for_current_device() const { + int device_idx = get_local_device_idx_for_current_device(); + return tp_comms[device_idx]; + } + + int get_local_device_idx_for_current_device() const { + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + for (int i = 0; i < num_devices_per_process; i++) { + if (local_device_ids_within_process[i] == current_device) { + return i; + } + } + NVTE_ERROR("Current CUDA device ", current_device, + " not found in local_device_ids_within_process"); + } + + int get_local_device_id_within_tp_domain() const { + int device_idx = get_local_device_idx_for_current_device(); + return local_device_ids_within_tp_domain[device_idx]; + } + + int get_tp_domain_id() const { + int device_idx = get_local_device_idx_for_current_device(); + return tp_domain_ids[device_idx]; + } + + int get_tp_num_domains() const { return tp_num_domains; } + + static void init(int num_total_devices, int num_devices_per_process, int process_id, int tp_size); + + private: + ncclUniqueId coordinate_nccl_unique_id(const std::string &id_type); + + public: + static CommunicatorHandler &get(bool is_initialized = true) { + static CommunicatorHandler instance; + NVTE_CHECK(instance._initialize == is_initialized, + "CommunicatorHandler._initialize=", instance._initialize, + ", is_initialized=", is_initialized); + return instance; + } + + ExtAllgatherOp allgather_func; + ExtBarrierOp barrier_func; + + CommunicatorHandler(const CommunicatorHandler &) = delete; + CommunicatorHandler &operator=(const CommunicatorHandler &) = delete; + + private: + CommunicatorHandler(); + ~CommunicatorHandler(); + + bool _initialize = false; + int *_device_barrier = nullptr; + std::vector _nccl_id_file_name; +}; + +// Plan registry for caching collective GEMM executors +class CollectiveGemmPlanRegistry { + public: + static CollectiveGemmPlanRegistry &getInstance() { + static thread_local CollectiveGemmPlanRegistry instance; + return instance; + } + + CommOverlapCore *get_executor(std::vector buffer_shape, DType dtype, + JAXX_Collective_Op collective_op); + + private: + CollectiveGemmPlanRegistry() {} + CollectiveGemmPlanRegistry(const CollectiveGemmPlanRegistry &) = delete; + CollectiveGemmPlanRegistry &operator=(const CollectiveGemmPlanRegistry &) = delete; + + std::unordered_map> plan_map; +}; + +// Function declarations +void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size, int num_max_streams, int gemm_priority, + int comm_priority, int num_comm_sm, bool use_ce, + bool aggregate_ag); + +int GetCgemmNumMaxStreams(); + +} // namespace jax +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h index 852a67c6c..82f062a15 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.h +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -24,6 +24,7 @@ using FFI_Stream_Type = xla::ffi::PlatformStream; using Dictionary = xla::ffi::Dictionary; constexpr auto FFI_Prepare = xla::ffi::ExecutionStage::kPrepare; +constexpr auto FFI_Initialize = xla::ffi::ExecutionStage::kInitialize; constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible}; DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type); @@ -106,5 +107,19 @@ inline static size_t te_dtype_bytes(const DType& type) { } } +template +Error_Type wrapInStreamCapture(std::function func, + cudaStream_t stream, Args... args) { + cudaGraph_t graph{}; + NVTE_CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed)); + + Error_Type error = func(stream, std::forward(args)...); + + NVTE_CHECK_CUDA(cudaStreamEndCapture(stream, &graph)); + NVTE_CHECK_CUDA(cudaGraphDestroy(graph)); + + return error; +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 7015c2f5e..98c1c4ef1 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -9,13 +9,23 @@ #include #include +#include +#include #include #include #include "../extensions.h" +#ifndef USE_ROCM +#include "cgemm_helper.h" +#endif +#include "common.h" #include "common/util/cuda_runtime.h" #include "common/util/string.h" #include "common/util/system.h" +#include "cuda_runtime.h" +#ifndef USE_ROCM +#include "nccl.h" +#endif #include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" @@ -76,12 +86,77 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( return std::make_tuple(std::move(input), input_shape); } +#ifndef USE_ROCM +Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type workspace, + JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + int64_t rhs_axis_boundary, bool lhs_transposed, + bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, + bool use_split_accumulator, JAXX_Collective_Op collective_op) { + nvte_cublas_handle_init(); + + // Init UB buffer + if (collective_op != JAXX_Collective_Op::NONE) { + auto &comm_handler = CommunicatorHandler::get(); + std::vector lhs_shape = { + product(lhs.dimensions(), 0, lhs_axis_boundary), + product(lhs.dimensions(), lhs_axis_boundary, lhs.dimensions().size())}; + std::vector rhs_shape = { + product(rhs.dimensions(), 0, rhs_axis_boundary), + product(rhs.dimensions(), rhs_axis_boundary, rhs.dimensions().size())}; + + std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], + (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; + + std::vector buffer_shape{0, 0}; + DType buffer_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size; + buffer_shape[1] = lhs_shape[1]; + buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + } else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + buffer_shape[0] = out_shape[0]; + buffer_shape[1] = out_shape[1]; + } + auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, buffer_dtype, + collective_op); + } + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, + FFI::Bind() + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // workspace + .Attr("scaling_mode") + .Attr("lhs_axis_boundary") + .Attr("rhs_axis_boundary") + .Attr("lhs_transposed") + .Attr("rhs_transposed") + .Attr("fuse_bias") + .Attr("fuse_gelu") + .Attr("grad") + .Attr("use_split_accumulator") + .Attr("collective_op")); +#endif + Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, - bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { + bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator, + JAXX_Collective_Op collective_op) { // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || @@ -93,16 +168,9 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); - // Output tensor std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); - auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); - NVTE_CHECK(out_.numel() == output->element_count(), - "cuBLAS GEMM output buffer size is incorrect, " - "expected ", - out_.numel(), " elements ", to_string_like(out_shape), " but got ", - output->element_count(), " elements ", to_string_like(output->dimensions())); // Bias input to forward pass or bias gradient output from backward pass void *bias_ptr = nullptr; @@ -143,9 +211,66 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); - nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), - rhs_transposed, lhs_transposed, grad, workspace_.data(), false, - use_split_accumulator, num_math_sm, stream); + + if (collective_op == JAXX_Collective_Op::NONE) { + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", + to_string_like(out_shape), " but got ", output->element_count(), " elements ", + to_string_like(output->dimensions())); + + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, + use_split_accumulator, num_math_sm, stream); + } else { +#ifndef USE_ROCM + std::vector buffer_shape{0, 0}; + DType buffer_dtype = out_dtype; + auto &comm_handler = CommunicatorHandler::get(); + if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size; + buffer_shape[1] = lhs_shape[1]; + out_shape[0] = out_shape[0] * comm_handler.tp_size; + buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + } else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + buffer_shape[0] = out_shape[0]; + buffer_shape[1] = out_shape[1]; + out_shape[0] = out_shape[0] / comm_handler.tp_size; + } + auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor( + buffer_shape, buffer_dtype, collective_op); + if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype); + // Prepare the auxiliary buffer for the reduce-scattered GEMM output + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), + " elements ", to_string_like(out_shape), " but got ", output->element_count(), + " elements ", to_string_like(output->dimensions())); + + // Launch GEMM+RS + executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, ubuf_out_, bias_, + pre_gelu_, workspace_, grad, false, use_split_accumulator, out_, + stream); + + } else if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + auto aux_out_ = TensorWrapper(nullptr, std::vector{0}, out_dtype); // Empty + + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), + " elements ", to_string_like(out_shape), " but got ", output->element_count(), + " elements ", to_string_like(output->dimensions())); + // Copy the distributed LHS operand into the local chunk of the communication buffer + executor->copy_into_buffer(stream, lhs_, true, make_lhs_rowwise); + // Launch AG+GEMM + executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, + workspace_, grad, false, use_split_accumulator, aux_out_, stream); + } +#else + NVTE_ERROR("Collective GEMM operations are not supported on ROCm"); +#endif + } return ffi_with_cuda_error_check(); } @@ -171,8 +296,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("fuse_bias") .Attr("fuse_gelu") .Attr("grad") - .Attr("use_split_accumulator"), - GemmFFI_CudaGraph_Traits); + .Attr("use_split_accumulator") + .Attr("collective_op"), + GemmFFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index af7f54feb..41b578f33 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -87,5 +87,31 @@ constexpr struct Alignment { std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise); +template +void hash_combine(int64_t &seed, const T &v, Rest... rest) { + seed ^= std::hash{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + (hash_combine(seed, rest), ...); +} + +enum class JAXX_Collective_Op : int64_t { + NONE = 0, + ALL_GATHER = 1, + REDUCE_SCATTER = 2, +}; +#ifndef USE_ROCM +static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) { + switch (op) { + case JAXX_Collective_Op::ALL_GATHER: + return CommOverlapType::AG; + break; + case JAXX_Collective_Op::REDUCE_SCATTER: + return CommOverlapType::RS; + break; + default: + NVTE_ERROR("Invalid Collective Op ", static_cast(op)); + break; + } +} +#endif } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index c35bc6668..523819392 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -180,6 +180,42 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("is_2x"), FFI_CudaGraph_Traits); +Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, + Buffer_Type gamma_buf, Buffer_Type beta_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type amax_buf, Result_Type mu_buf, + Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, + bool zero_centered_gamma, double epsilon, int64_t sm_margin, + JAXX_Scaling_Mode scaling_mode, bool is_2x) { + return wrapInStreamCapture( + std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf, beta_buf, output_buf, + colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf, + wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode, is_2x); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // x + .Arg() // scale + .Arg() // gamma + .Arg() // beta + .Ret() // output + .Ret() // colwise_output + .Ret() // scale_inv + .Ret() // colwise_scale_inv + .Ret() // amax + .Ret() // mu + .Ret() // rsigma + .Ret() // wkspace + .Attr("norm_type") + .Attr("zero_centered_gamma") + .Attr("epsilon") + .Attr("sm_margin") + .Attr("scaling_mode") + .Attr("is_2x")); + pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, NVTE_Norm_Type norm_type, bool zero_centered_gamma, int sm_margin) { @@ -305,5 +341,32 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardHandler, NormBackwardFFI, .Attr("sm_margin"), FFI_CudaGraph_Traits); +Error_Type NormBackwardInitializeFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, + Buffer_Type mu_buf, Buffer_Type rsigma_buf, + Buffer_Type gamma_buf, Result_Type xgrad_buf, + Result_Type wgrad_buf, Result_Type dbeta_buf, + Result_Type wkspace_buf, int64_t norm_type, + bool zero_centered_gamma, int64_t sm_margin) { + return wrapInStreamCapture(std::function(NormBackwardFFI), stream, dz_buf, x_buf, mu_buf, + rsigma_buf, gamma_buf, xgrad_buf, wgrad_buf, dbeta_buf, wkspace_buf, + norm_type, zero_centered_gamma, sm_margin); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardInitializeHandler, NormBackwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // dz + .Arg() // x + .Arg() // mu + .Arg() // rsigma + .Arg() // gamma + .Ret() // xgrad + .Ret() // wgrad + .Ret() // dbeta + .Ret() // wkspace + .Attr("norm_type") + .Attr("zero_centered_gamma") + .Attr("sm_margin")); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 563675988..793a4c59b 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -7,6 +7,10 @@ ************************************************************************/ #include "../extensions.h" +#ifndef USE_ROCM +#include "cgemm_helper.h" +#endif +#include "common/util/cuda_runtime.h" namespace transformer_engine { namespace jax { @@ -22,8 +26,12 @@ pybind11::dict Registrations() { pybind11::dict dict; // Activation - dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); - dict["te_dact_dbias_quantize_ffi"] = EncapsulateFFI(DActLuDBiasQuantizeHandler); + dict["te_act_lu_ffi"] = + pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(ActLuInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(ActLuHandler)); + dict["te_dact_dbias_quantize_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(DActLuDBiasQuantizeInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler)); // Quantization dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); @@ -45,9 +53,11 @@ pybind11::dict Registrations() { // Normalization dict["te_norm_forward_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(NormForwardInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler)); dict["te_norm_backward_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(NormBackwardInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler)); // Attention @@ -60,7 +70,7 @@ pybind11::dict Registrations() { // GEMM dict["te_gemm_ffi"] = - pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler), pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); // Grouped GEMM @@ -102,6 +112,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); +#ifndef USE_ROCM + m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); + m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); +#endif pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) @@ -184,6 +198,14 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE) .export_values(); + +#ifndef USE_ROCM + pybind11::enum_(m, "JAXX_Collective_Op", pybind11::module_local()) + .value("NONE", JAXX_Collective_Op::NONE) + .value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER) + .value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER) + .export_values(); +#endif } } // namespace jax diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 8087159a3..23df1a0ce 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -11,10 +11,12 @@ from typing import Tuple, Sequence from functools import partial +import warnings import jax import jax.numpy as jnp from . import cpp_extensions as tex +from .cpp_extensions.quantization import AmaxScope from .quantize import ( ScaledTensorFactory, ScalingMode, @@ -61,8 +63,12 @@ def dense( kernel: jnp.ndarray, bias: jnp.ndarray = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), + batch_sequence_transpose: bool = False, input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + output_axes: Tuple[str, ...] = None, + using_global_amax_of_x: bool = False, + collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set, quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -76,11 +82,20 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract + batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. + input_axes: Logical axes for sharding the activation input + kernel_axes: Logical axes for sharding the weight matrix + output_axes: Logical axes for sharding the output + using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. + collective_op_set: A set of CollectiveOp objects for forward and backward passes. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ + if batch_sequence_transpose: + warnings.warn("batch_sequence_transpose is not well tested, use with caution!") + if not get_quantize_config().is_fp8_enabled(): input_dtype = x.dtype kernel = kernel.astype(input_dtype) @@ -90,29 +105,30 @@ def dense( kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, + output_axes, + using_global_amax_of_x, + collective_op_set, quantizer_set, ) return output -@partial( - jax.custom_vjp, - nondiff_argnums=( - 3, - 4, - 5, - ), -) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9)) def _dense( x, kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, - quantizer_set, + output_axes, + using_global_amax_of_x, + collective_op_set, + quantizer_set, # need to be a diff_arg for DelayedScaling state management ): """Internal implementation of dense layer transformation with custom VJP. @@ -124,8 +140,12 @@ def _dense( kernel: Weight matrix bias: Optional bias tensor contracting_dims: Contracting dimensions specification + batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. input_axes: Logical axes for sharding the activation input + output_axes: Logical axes for sharding the output_axes kernel_axes: Logical axes for sharding the weight matrix + using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. + collective_op_set: A set of CollectiveOp objects for forward and backward passes. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: @@ -136,8 +156,12 @@ def _dense( kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, + output_axes, + using_global_amax_of_x, + collective_op_set, quantizer_set, ) return output @@ -148,8 +172,12 @@ def _dense_fwd_rule( kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, + output_axes, + using_global_amax_of_x, + collective_op_set, quantizer_set, ): """Forward pass rule for dense layer transformation. @@ -175,6 +203,7 @@ def _dense_fwd_rule( x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, + amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -182,6 +211,7 @@ def _dense_fwd_rule( kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -191,9 +221,12 @@ def _dense_fwd_rule( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set.forward, ) + output = with_sharding_constraint_by_logical_axes(output, output_axes) if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape @@ -212,8 +245,16 @@ def _dense_fwd_rule( def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, ctx, grad -): # pylint: disable=unused-argument + contracting_dims, + batch_sequence_transpose, + input_axes, + kernel_axes, + output_axes, + using_global_amax_of_x, + collective_op_set, + ctx, + grad, +): """Backward pass rule for dense layer transformation. Returns: @@ -228,6 +269,7 @@ def _dense_bwd_rule( quantizer_set, flatten_axis_k, ) = ctx + grad = with_sharding_constraint_by_logical_axes(grad, output_axes) fwd_x_contracting_dims, fwd_k_contracting_dims = map( tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims @@ -238,6 +280,7 @@ def _dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, + amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP, ) # GEMM NT @@ -254,8 +297,9 @@ def _dense_bwd_rule( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, contracting_dims=(g_contracting_dim, k_contracting_dim), + transpose_batch_sequence=batch_sequence_transpose, + collective_op=collective_op_set.backward, ) - dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims @@ -267,7 +311,10 @@ def _dense_bwd_rule( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dim, g_contracting_dim), + transpose_batch_sequence=batch_sequence_transpose, ) + + dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) return dgrad, wgrad, dbias, quantizer_set diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index fb3ac7b9a..ad66684f2 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -53,6 +53,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[ return drop_path_shape +# TODO(Phuong): move this function to sharding.py def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: """ Extend the given Flax logical axis rules with the predefined TransformerLayer's diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index fc957801a..cf77f8e0a 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -21,6 +21,7 @@ from jax.ad_checkpoint import checkpoint_name from . import cpp_extensions as tex +from .cpp_extensions.quantization import AmaxScope from .layernorm import canonicalize_norm_type from .quantize import ( with_sharding_constraint_by_logical_axes, @@ -40,6 +41,7 @@ def layernorm_mlp( norm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6, + batch_sequence_transpose: bool = False, norm_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None, @@ -48,6 +50,10 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + collective_op_sets: Tuple[tex.CollectiveOpSet] = ( + tex.noop_collective_op_set, + tex.noop_collective_op_set, + ), quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -71,6 +77,7 @@ def layernorm_mlp( norm_type: Type of normalization ("layernorm" or "rmsnorm") zero_centered_gamma: Whether to use zero-centered gamma for normalization epsilon: Small constant for numerical stability in normalization + batch_sequence_transpose: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for sharding the layernorm input dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication @@ -79,6 +86,7 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation + collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -121,6 +129,7 @@ def layernorm_mlp( norm_type, zero_centered_gamma, epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -129,12 +138,13 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -146,6 +156,7 @@ def _layernorm_mlp( norm_type: str, zero_centered_gamma: bool, epsilon: float, + batch_sequence_transpose: bool, norm_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], @@ -154,6 +165,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + collective_op_sets: Tuple[tex.CollectiveOpSet], quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -173,12 +185,16 @@ def _layernorm_mlp( norm_type: Type of normalization zero_centered_gamma: Whether to use zero-centered gamma epsilon: Small constant for numerical stability + batch_sequence_transpose: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for layernorm sharding dot_1_input_axes: Logical axes for first matrix multiplication sharding dot_2_input_axes: Logical axes for second matrix multiplication sharding + kernel_1_axes: Logical axes for first weight matrix sharding + kernel_2_axes: Logical axes for second weight matrix sharding ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) + collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations quantizer_sets: Tuple of quantizer sets Returns: @@ -195,6 +211,7 @@ def _layernorm_mlp( norm_type, zero_centered_gamma, epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -203,6 +220,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, quantizer_sets, ) return output @@ -219,6 +237,7 @@ def _layernorm_mlp_fwd_rule( norm_type, zero_centered_gamma, epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -227,6 +246,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -246,6 +266,10 @@ def _layernorm_mlp_fwd_rule( del kernel_1_axes, kernel_2_axes ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets + collective_op_set_1, collective_op_set_2 = collective_op_sets + + assert not collective_op_set_1.forward.is_reduce_scatter + assert not collective_op_set_2.forward.is_all_gather # x should be in shape of (batch..., hidden) # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate) @@ -272,13 +296,12 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, + amax_scope=AmaxScope.TPSP, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_kernel_1 = tex.quantize( - kernel_1, - flatten_axis=-2, - quantizer=ffn1_quantizer_set.kernel, + kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP ) # NN GEMM @@ -287,8 +310,10 @@ def _layernorm_mlp_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set_1.forward, ) if use_bias_1 and tex.gemm_uses_jax_dot(): @@ -317,6 +342,7 @@ def _layernorm_mlp_fwd_rule( casted_kernel_2 = tex.quantize( kernel_2, quantizer=ffn2_quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, ) # NN GEMM @@ -325,8 +351,10 @@ def _layernorm_mlp_fwd_rule( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set_2.forward, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -334,6 +362,8 @@ def _layernorm_mlp_fwd_rule( bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) + # sharding of outputs should be the same as dot_1's input + dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes) dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) ctx = ( @@ -363,6 +393,7 @@ def _layernorm_mlp_bwd_rule( norm_type, zero_centered_gamma, epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -371,6 +402,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, ctx, grad, ): @@ -409,6 +441,10 @@ def _layernorm_mlp_bwd_rule( ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets + collective_op_set_1, collective_op_set_2 = collective_op_sets + + assert not collective_op_set_1.backward.is_all_gather + assert not collective_op_set_2.backward.is_reduce_scatter # Since the sharding of outputs should be the same as dot_1's input grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) @@ -417,6 +453,7 @@ def _layernorm_mlp_bwd_rule( grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, + amax_scope=AmaxScope.TPSP, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -434,6 +471,8 @@ def _layernorm_mlp_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), + transpose_batch_sequence=batch_sequence_transpose, + collective_op=collective_op_set_2.backward, ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -448,6 +487,7 @@ def _layernorm_mlp_bwd_rule( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -474,6 +514,8 @@ def _layernorm_mlp_bwd_rule( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), + transpose_batch_sequence=batch_sequence_transpose, + collective_op=collective_op_set_1.backward, ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) @@ -484,6 +526,7 @@ def _layernorm_mlp_bwd_rule( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index e81a614f0..b7828e931 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -17,7 +17,7 @@ import operator import numpy as np -from jax.experimental.custom_partitioning import BATCHING +from jax.experimental.custom_partitioning import BATCHING, CompoundFactor from jax.tree_util import register_pytree_node_class import jax.numpy as jnp @@ -152,12 +152,15 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: @abstractmethod def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization. @@ -232,12 +235,15 @@ def get_grouped_scale_shape( return (n_groups,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization. @@ -245,7 +251,7 @@ def get_shardy_sharding_rules( The Shardy rules for the scaling mode """ del flatten_axis - input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) scale_var = BATCHING + unique_var + "_scale_inv" return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) @@ -323,20 +329,23 @@ def get_grouped_scale_shape( return (n_groups,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix - flatten_axis: Axis along which data can be flattened to 2D for quantization. + flatten_axis: Axis along which data can be flattened to 2D for quantization Returns: The Shardy rules for the scaling mode """ del flatten_axis - input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) scale_var = BATCHING + unique_var + "_scale_inv" return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) @@ -562,52 +571,55 @@ def get_grouped_scale_shape( return (n_block_x * n_block_y,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization Returns: The Shardy rules for the scaling mode """ - del flatten_axis - input_spec = [f"{unique_var}{i}" for i in range(input_rank)] - rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)] - colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)] - - # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors. - # Unfortunately, because Shardy rules are applied to the inner primitive, the - # only way to preserve the relationship is to lower unpadded scales to the - # underlying custom call and pad them in C++. Until that's implemented, the - # Shardy rules for block scales have to be completely disconnected from the - # Shardy rules for the tensor they belong to. - - # # We have to use two different factors in the two CompoundFactors because of Shardy - # # verifier requirements, even though they are the same. - # rowwise_var = unique_var - # colwise_var = f"{unique_var}_" - # input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") - # input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") - - # # The rowwise and colwise scale tensors should be sharded the same way as the input. - # # However, we need to adjust the dimensions where the block scaling factor applies. - # rowwise = input_spec.copy() - # rowwise[-1] = rowwise_var - - # colwise = input_spec.copy() - # colwise[flatten_axis - 1] = colwise_var - - # # This implementation needs to be updated for different block dims. - # assert self._block_dims == (1, 32) + input_rank = len(input_shape) + input_spec = [f"{unique_var}_{i}" for i in range(input_rank)] + flatten_axis = (flatten_axis + input_rank) % input_rank + + # This implementation needs to be updated for different block dims. + assert self._block_dims == (1, 32) + + # We have to use two different factors in the two CompoundFactors because of Shardy + # verifier requirements, even though they are the same. + blocksizes = {} + colwise_var = f"{unique_var}_None" + rowwise_var = f"{unique_var}_None" + if not input_shape[-1] == 32: + rowwise_var = input_spec[-1] + "_compound" + input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x") + blocksizes["blocksize_x"] = 32 + if not input_shape[flatten_axis - 1] == 32: + colwise_var = input_spec[flatten_axis - 1] + "_compound" + input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y") + blocksizes["blocksize_y"] = 32 + + # The rowwise and colwise scale tensors should be sharded the same way as the input. + # However, we need to adjust the dimensions where the block scaling factor applies. + rowwise = input_spec.copy() + rowwise[-1] = rowwise_var + + colwise = input_spec.copy() + colwise[flatten_axis - 1] = colwise_var return QuantizeShardyRules( tuple(input_spec), tuple(rowwise), tuple(colwise), - {}, # {"block_size_rowwise": 32, "block_size_colwise": 32}, + blocksizes, ) @@ -697,18 +709,22 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: return self._get_impl().get_quantize_layout(usage) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis=-1 + self, + input_shape, + unique_var, + flatten_axis=-1, ) -> Tuple[Tuple[str]]: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. Returns: The Shardy rules for the scaling mode """ - return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis) + return self._get_impl().get_shardy_sharding_rules(input_shape, unique_var, flatten_axis) def get_grouped_scale_shape_2x( self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 339e74e2f..7a8261269 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -13,6 +13,7 @@ from dataclasses import dataclass from typing import Callable, Optional import warnings + import jax import jax.numpy as jnp from jax.interpreters import pxla @@ -364,3 +365,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes if axis != global_mesh_resource().pp_resource: x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x + + +def tpsp_axis_size(): + """ + Get the size of the tensor parallelism axis. + Return 1 if no TP axis is set. + """ + return get_mesh_axis_size(global_mesh_resource().tpsp_resource) + + +def dp_or_fsdp_axis_size(): + """ + Get the size of the data parallelism or FSDP axis. + Return 1 if no DP/FSDP axis is set. + """ + dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource) + fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource) + return dp_size if dp_size > 1 else fsdp_size diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a823379f1..8af41ed74 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -17,16 +17,19 @@ import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION +import torch.nn.functional as F import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( - SplitAlongDim, get_device_compute_capability, - combine_tensors, split_tensor_along_dim, ) -from transformer_engine.pytorch.utils import attention_mask_func +from transformer_engine.pytorch.utils import attention_mask_func, nvtx_range_push, nvtx_range_pop +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) from transformer_engine.pytorch.tensor.quantized_tensor import ( - QuantizedTensor, + QuantizedTensorBase, prepare_for_saving, restore_from_saved, ) @@ -43,7 +46,7 @@ META_O, META_QKV, ) -from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype +from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype, FP8GlobalStateManager from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.jit import no_torch_dynamo from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( @@ -56,6 +59,9 @@ import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils as fa_utils, + combine_and_quantize, + combine_and_dequantize, + print_quantizers, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import ( AttentionLogging as attn_log, @@ -136,6 +142,58 @@ fa_utils.set_flash_attention_3_params() +# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 +_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" + + +class FP8EmulationFunc(torch.autograd.Function): + """ + Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows: + - forward : QKV (quantize+dequantize), P (pass-through), S (quantize+dequantize), O (pass-through) + - backward: dO (quantize+dequantize), dS (pass-through), dP (quantize+dequantize), dQKV (pass-through) + """ + + @staticmethod + def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout): + # pylint: disable=missing-function-docstring + if quantizer_name == "QKV_quantizer": + query_layer, key_layer, value_layer = [ + x.contiguous() for x in [tensor1, tensor2, tensor3] + ] + q_fp8, k_fp8, v_fp8 = combine_and_quantize( + qkv_layout, query_layer, key_layer, value_layer, quantizer + ) + tensors = combine_and_dequantize( + qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype + ) + elif quantizer_name in ["S_quantizer", "O_quantizer"]: + t_fp8 = quantizer(tensor1) + tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + else: + tensors = (tensor1, tensor2, tensor3) + ctx.quantizer = quantizer + ctx.quantizer_name = quantizer_name + ctx.qkv_layout = qkv_layout + return tensors[0], tensors[1], tensors[2] + + @staticmethod + def backward(ctx, grad1, grad2, grad3): + # pylint: disable=missing-function-docstring + if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]: + dt_fp8 = ctx.quantizer(grad1) + tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + elif ctx.quantizer_name == "dQKV_quantizer": + query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] + dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize( + ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer + ) + tensors = combine_and_dequantize( + ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype + ) + else: + tensors = grad1, grad2, grad3 + return tensors[0], tensors[1], tensors[2], None, None, None + class UnfusedDotProductAttention(torch.nn.Module): """Parallel attention w/o QKV and Proj Gemms @@ -149,6 +207,7 @@ def __init__( attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, layer_number: Optional[int] = None, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -156,6 +215,7 @@ def __init__( self.attention_type = attention_type self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number + self.softmax_type = softmax_type def mask_func(x, y): return ( @@ -192,6 +252,11 @@ def forward( core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, inference_params: Optional[InferenceParams] = None, + softmax_offset: torch.Tensor = None, + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers=None, + fp8_output: bool = False, ) -> torch.Tensor: """Unfused attention fprop""" assert ( @@ -289,6 +354,35 @@ def forward( if apply_qk_layer_scaling: scale /= self.layer_number + if fp8: + # get quantizers from DPA; all Nones if not fp8 + QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( + dpa_utils.get_attention_quantizers(fp8, quantizers) + ) + # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + if fp8_recipe.float8_current_scaling(): + S_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=S_quantizer.dtype, device="cuda" + ) + dP_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=dP_quantizer.dtype, device="cuda" + ) + + if "2" in qkv_layout or "3" in qkv_layout: + qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout) + qkv_layout = "_".join([qkv_format] * 3) + # quantize and dequantize QKV to emulate FP8 + query_layer, key_layer, value_layer = FP8EmulationFunc.apply( + query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout + ) + # quantize and dequantize dQKV to emulate FP8 + query_layer, key_layer, value_layer = FP8EmulationFunc.apply( + query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout + ) + # Raw attention scores. [b * np, sq, sk] if core_attention_bias_type == "no_bias": matmul_result = torch.baddbmm( @@ -333,7 +427,27 @@ def forward( dtype=query_layer.dtype ) - # attention scores and attention mask [b, np, sq, sk] + if fp8: + # quantize and dequantize dP to emulate FP8 + matmul_result, *_ = FP8EmulationFunc.apply( + matmul_result, None, None, dP_quantizer, "dP_quantizer", None + ) + + # add attention sink to the last column: [b, np, sq, sk+1] + if self.softmax_type != "vanilla": + matmul_result = torch.cat( + [ + matmul_result, + softmax_offset.to(dtype=matmul_result.dtype).expand( + matmul_result.size(0), -1, matmul_result.size(2), -1 + ), + ], + dim=-1, + ) + attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False) + attn_mask_type = "arbitrary" + + # attention scores and attention mask softmax_scale = self.layer_number if apply_qk_layer_scaling else None attention_probs = self.scale_mask_softmax( matmul_result, attention_mask, attn_mask_type, softmax_scale @@ -344,6 +458,10 @@ def forward( if "padding" in attn_mask_type: attention_probs = attention_probs.masked_fill(attention_mask, 0) + # remove attention sink: [b, np, sq, sk] + if self.softmax_type != "vanilla": + attention_probs = attention_probs[..., :-1] + # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. with self.attention_dropout_ctx(): @@ -364,6 +482,12 @@ def forward( # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + if fp8: + # quantize and dequantize S to emulate FP8 + attention_probs, *_ = FP8EmulationFunc.apply( + attention_probs, None, None, S_quantizer, "S_quantizer", None + ) + # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) @@ -398,6 +522,20 @@ def forward( # [tq, np, hn] --> [tq, hp] context_layer = context_layer.view(total_tokens, -1) + if fp8: + # quantize and dequantize O to emulate FP8 + context_layer, *_ = FP8EmulationFunc.apply( + context_layer, None, None, O_quantizer, "O_quantizer", None + ) + # quantize and dequantize dO to emulate FP8 + context_layer, *_ = FP8EmulationFunc.apply( + context_layer, None, None, dO_quantizer, "dO_quantizer", None + ) + + # quantize O + if fp8_output: + context_layer = O_quantizer(context_layer) + return context_layer @@ -496,6 +634,7 @@ def forward( quantizers=None, inference_params: Optional[InferenceParams] = None, flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), + fp8_output: bool = False, ) -> torch.Tensor: """flash-attn fprop""" @@ -701,6 +840,7 @@ def forward( quantizers=quantizers, pad_between_seqs=False, use_flash_attn_3=use_flash_attn_3, + fp8_output=fp8_output, ) else: from transformer_engine.pytorch.cpu_offload import ( @@ -800,8 +940,6 @@ def convert_to_torch_float8(tensor, dtype): ) return out - # "fp8_mha" decides outputs in fp8, while inputs are inferred from - # the real dtype assert isinstance(key_layer, query_layer.__class__) and isinstance( value_layer, query_layer.__class__ ), "q, k, and v must have the same type." @@ -848,7 +986,7 @@ def convert_to_torch_float8(tensor, dtype): if fp8: output = output.to(dtype=torch_orig_dtype) - if fp8 and fp8_meta["recipe"].fp8_mha: + if fp8 and fp8_output: O_quantizer = quantizers["scaling_fwd"][META_O] output = O_quantizer(output) @@ -876,7 +1014,7 @@ def convert_to_torch_float8(tensor, dtype): if q_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) - if fp8 and fp8_meta["recipe"].fp8_mha: + if fp8 and fp8_output: output_data = ( output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) .transpose(0, 1) @@ -900,7 +1038,7 @@ def convert_to_torch_float8(tensor, dtype): class FusedAttnFunc(torch.autograd.Function): - """Function for FusedAttention with separate Q, K, V tensors""" + """FusedAttention forward and backward implementation""" @staticmethod def forward( @@ -924,6 +1062,7 @@ def forward( qkv_layout, attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, fused_attention_backend, @@ -932,55 +1071,72 @@ def forward( fp8_meta, quantizers, deterministic, + softmax_offset, + fp8_output, + layer_number, ): # pylint: disable=missing-function-docstring - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False - - # FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn - fake_dtype = q.dtype + # add NVTX range + nvtx_label = "transformer_engine.FusedAttnFunc.forward" + nvtx_range_push(f"{nvtx_label}") + + # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + + # input types are inferred from the real data while output types are controlled by fp8_output + # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha) + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + + # whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) + # whether bwd kernel in FP8: + is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + + # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + dpa_utils.get_attention_quantizers(fp8, quantizers) ) + + # get nominal data type for out + # FP16/BF16 attention: torch.float16 or torch.bfloat16 + # FP8 attention: torch.float16 or torch.bfloat16 + out_nominal_dtype = q.dtype + if fp8: fused_attention_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - q_fp8, k_fp8, v_fp8 = None, None, None + # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) - match qkv_group: - case 1: - dim = qkv_layout.find("3") - qkv = combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = QKV_quantizer(qkv) - q_fp8, k_fp8, v_fp8 = SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True) - case 2: - q_fp8 = QKV_quantizer(q) - dim = qkv_layout.split("_")[1].find("2") - kv = combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = QKV_quantizer(kv_c) - k_fp8, v_fp8 = SplitAlongDim.apply(kv_fp8, dim, [1, 1], True) - case 3: - q_fp8 = QKV_quantizer(q) - k_fp8 = QKV_quantizer(k) - v_fp8 = QKV_quantizer(v) - case _: - raise "Invalid qkv_layout " + qkv_layout - # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn - out_fp8, aux_ctx_tensors = fused_attn_fwd( + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + + # print quantizers + print_quantizers( + "FusedAttnFunc.forward >> before: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + + # out_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -989,7 +1145,7 @@ def forward( q_fp8, k_fp8, v_fp8, - fake_dtype, + out_nominal_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, @@ -1004,45 +1160,59 @@ def forward( qkv_layout, attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, + softmax_offset, ) - if is_output_fp8: - out_ret = out_fp8 + + # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_fp8 = out_ + out = out_ + + if isinstance(out_, Float8Tensor): + if not is_output_fp8 or not is_bwd_fp8: + out = out_.dequantize().view(out_.shape) else: - out_ret = out_fp8.dequantize().view(out_fp8.shape) - # is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16 - # is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn - out_save = out_ret + if is_output_fp8 or ( + is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + ): + out_fp8 = O_quantizer(out_) + + # print quantizers + print_quantizers( + "FusedAttnFunc.forward >> after: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) - if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - # 1: qkv packed, 2: kv packed, 3: qkv separate + # return appropriate tensors + out_ret = out_fp8 if is_output_fp8 else out + + # save appropriate tensors + fp8_tensors = (None, None, None, None) + qkvo_tensors = (None, None, None, None) + if is_bwd_fp8: + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8, k_fp8, v_fp8, None) + qkvo_tensors = (None, None, None, out) + else: + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + else: if is_input_fp8: - qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape) - q, k, v = SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True) - if qkv_group == 2: - q = q.dequantize() - dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2") - kv = combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = kv.dequantize() - k, v = SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True) - if qkv_group == 3: - q = q.dequantize() - k = k.dequantize() - v = v.dequantize() - if is_output_fp8: - out_save = out_fp8.dequantize() - - fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + qkvo_tensors = (q, k, v, out) else: - # q, k, v, out_ret: torch.float16 or torch.bfloat16 - out_ret, aux_ctx_tensors = fused_attn_fwd( + # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1051,7 +1221,7 @@ def forward( q, k, v, - fake_dtype, + out_nominal_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, @@ -1066,13 +1236,23 @@ def forward( qkv_layout, attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, + softmax_offset, ) - out_save = out_ret + out = out_ + out_ret = out_ fp8_tensors = (None, None, None, None) + qkvo_tensors = (q, k, v, out) - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + nvtx_range_pop(f"{nvtx_label}") + + ctx.fp8_recipe = fp8_recipe + ctx.fp8 = is_bwd_fp8 + # assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16 + # used when some tensors are base tensors and loose the "dtype" attribute + ctx.nominal_dtype = out_nominal_dtype from transformer_engine.pytorch.cpu_offload import ( CPUOffloadEnabled, @@ -1083,7 +1263,7 @@ def forward( if ctx.fp8: tensor_list = fp8_tensors else: - tensor_list = [q, k, v, out_save] + tensor_list = [q, k, v, out] qkv_layout = "sbhd_sbhd_sbhd" mark_activation_offload(*tensor_list) @@ -1091,7 +1271,6 @@ def forward( ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 - qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, *qkvo_tensors, @@ -1105,11 +1284,14 @@ def forward( ctx.tensor_objects = tensor_objects ctx.fp8_meta = fp8_meta + ctx.layer_number = layer_number + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer ctx.S_quantizer = S_quantizer - if ctx.fp8: + if ctx.fp8 and isinstance(ctx.S_quantizer, Float8Quantizer): ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() @@ -1121,6 +1303,7 @@ def forward( ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type + ctx.softmax_type = softmax_type ctx.window_size = window_size ctx.fused_attention_backend = ( fused_attention_backend if (IS_HIP_EXTENSION or ctx.fp8) else FusedAttnBackend["F16_arbitrary_seqlen"] @@ -1133,17 +1316,15 @@ def forward( @staticmethod def backward(ctx, d_out): # pylint: disable=missing-function-docstring - if ctx.is_output_fp8: - assert isinstance( - d_out, Float8Tensor - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." - - # FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2 - fake_dtype = d_out.dtype - d_out = d_out.contiguous() + # d_out is expected to be in FP8 if is_output_fp8=True, + # but in the case it's not, convert it to FP8 before any operation + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorBase): + d_out = ctx.dO_quantizer(d_out) + if not ctx.use_FAv2_bwd: + d_out._data = d_out._data.contiguous() + elif not ctx.use_FAv2_bwd: + d_out = d_out.contiguous() ( q_fp8, k_fp8, @@ -1197,16 +1378,55 @@ def backward(ctx, d_out): dk = dk[..., : d_out.shape[-1]] dv = dv[..., : d_out.shape[-1]] else: - with torch.cuda.nvtx.range("_FusedAttn"): + with torch.cuda.nvtx.range("FusedAttnFunc.backward"): + # get nominal data type of dq, dk, dv + # FP16/BF16 attention: torch.float16 or torch.bfloat16 + # FP8 attention: torch.float16 or torch.bfloat16 + dqkv_nominal_dtype = ctx.nominal_dtype + if ctx.fp8: + # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 if ctx.is_output_fp8: d_out_fp8 = d_out else: d_out_fp8 = ctx.dO_quantizer(d_out) - dqkv_dtype = TE_DType[d_out_fp8._data.dtype] - # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn - # d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2 - dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( + + # print quantizers + print_quantizers( + "FusedAttnFunc.backward >> before: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) + + # get tex.DType for dq, dk, dv data + dqkv_te_dtype = d_out_fp8._fp8_dtype + + # q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16, + # fp8_dtype = tex.DType.kFloat8E4M3 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + # out_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # + # dq_, dk_, dv_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_ = ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8 + ) + dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, @@ -1214,10 +1434,10 @@ def backward(ctx, d_out): q_fp8, k_fp8, v_fp8, - out_fp8, + out_, d_out_fp8, - fake_dtype, - dqkv_dtype, + dqkv_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1231,44 +1451,45 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.softmax_type, ctx.window_size, ctx.deterministic, ) - # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16 - # is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2 - if not ctx.is_input_fp8: - qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_")) - if qkv_group == 1: - dim = ctx.qkv_layout.find("3") - dqkv_fp8_data = combine_tensors( - [dq_fp8._data, dk_fp8._data, dv_fp8._data], dim - ) - dqkv_fp8 = dq_fp8.make_like( - tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape - ) - dqkv = dqkv_fp8.dequantize() - dq, dk, dv = SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True) - if qkv_group == 2: - dq = dq_fp8.dequantize() - dim = ctx.qkv_layout.split("_")[1].find("2") - dkv_fp8 = combine_tensors([dk_fp8, dv_fp8], dim) - dkv_c_fp8 = dkv_fp8.view( - -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] - ) - dkv = dkv_c_fp8.dequantize() - dk, dv = SplitAlongDim.apply(dkv, dim, [1, 1], True) - if qkv_group == 3: - dq = dq_fp8.dequantize() - dk = dk_fp8.dequantize() - dv = dv_fp8.dequantize() - else: - dq, dk, dv = dq_fp8, dk_fp8, dv_fp8 + # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + dq, dk, dv = dq_, dk_, dv_ + is_float8tensor = isinstance(dq_, Float8Tensor) + if is_float8tensor and not ctx.is_input_fp8: + # return in F16 + dq, dk, dv = combine_and_dequantize( + ctx.qkv_layout, + dq_, + dk_, + dv_, + src_nominal_dtype=dq_.dtype, + ) + if not is_float8tensor and ctx.is_input_fp8: + # return in FP8 + dq, dk, dv = combine_and_quantize( + ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer + ) + + # print quantizers + print_quantizers( + "FusedAttnFunc.backward >> after: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) else: - if isinstance(d_out, QuantizedTensor): - d_out = d_out.dequantize() - dqkv_dtype = TE_DType[d_out.dtype] - # q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16 + if isinstance(d_out, QuantizedTensorBase): + d_out = d_out.dequantize(dtype=ctx.nominal_dtype) + dqkv_te_dtype = TE_DType[d_out.dtype] + # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1279,8 +1500,8 @@ def backward(ctx, d_out): v, out, d_out, - fake_dtype, - dqkv_dtype, + dqkv_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1294,42 +1515,17 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.softmax_type, ctx.window_size, ctx.deterministic, ) - # if no_bias or alibi, return dqkv - if ctx.attn_bias_type in ["no_bias", "alibi"]: - return ( - None, - None, - None, - None, - None, - None, - None, - None, - None, - dq, - dk, - dv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - # else, return (dqkv, dbias) + d_bias = None + if ctx.attn_bias_type not in ["no_bias", "alibi"]: + d_bias = rest[0] + d_softmax_offset = None + if ctx.softmax_type != "vanilla": + d_softmax_offset = rest[1] return ( None, None, @@ -1343,7 +1539,10 @@ def backward(ctx, d_out): dq, dk, dv, - rest[0], + d_bias, + None, + None, + None, None, None, None, @@ -1356,6 +1555,7 @@ def backward(ctx, d_out): None, None, None, + d_softmax_offset, None, None, ) @@ -1397,6 +1597,7 @@ def __init__( attention_type: str = "self", layer_number: Optional[int] = None, deterministic: bool = False, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -1409,6 +1610,7 @@ def __init__( ) == "1" and get_device_compute_capability() == (9, 0) self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic + self.softmax_type = softmax_type def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ @@ -1460,6 +1662,8 @@ def forward( quantizers=None, pad_between_seqs: bool = False, inference_params: Optional[InferenceParams] = None, + softmax_offset: torch.Tensor = None, + fp8_output: bool = False, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -1560,15 +1764,27 @@ def forward( ) if fp8: + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" " is required for FP8 attention!" ) assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!" - assert not context_parallel or fp8_meta["recipe"].reduce_amax, ( - "Amax reduction across TP+CP group is necessary when using context parallelism with" - " FP8!" - ) + if fp8_recipe.delayed(): + assert not context_parallel or fp8_recipe.reduce_amax, ( + "Amax reduction across TP+CP group is necessary when using context parallelism" + " with FP8!" + ) + if fp8_recipe.float8_current_scaling() and context_parallel: + all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) + for q in all_quantizers: + if isinstance(q, Float8CurrentScalingQuantizer): + q.with_amax_reduction = True + q.amax_reduction_group = ( + cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group + ) if context_parallel: assert ( @@ -1610,6 +1826,10 @@ def forward( fp8_meta=fp8_meta, quantizers=quantizers, pad_between_seqs=pad_between_seqs, + softmax_type=self.softmax_type, + softmax_offset=softmax_offset, + fp8_output=fp8_output, + layer_number=self.layer_number, ) else: with self.attention_dropout_ctx(): @@ -1633,6 +1853,7 @@ def forward( qkv_layout, core_attention_bias_type, attn_mask_type, + self.softmax_type, window_size, None, # rng_gen fused_attention_backend, @@ -1641,6 +1862,9 @@ def forward( fp8_meta, quantizers, self.deterministic, + softmax_offset, + fp8_output, + self.layer_number, ) # ...hd -> ...(hd) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 7eedd688f..1779e5fc9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -12,7 +12,6 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( - combine_tensors, get_cudnn_version, nvtx_range_pop, nvtx_range_push, @@ -23,7 +22,9 @@ fused_attn_bwd, FusedAttnBackend, ) +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.constants import ( dist_group_type, @@ -44,11 +45,18 @@ import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils as fa_utils, + combine_and_quantize, + combine_and_dequantize, + print_quantizers, ) _cu_seqlens_info_with_cp_cache = {} _seq_chunk_ids_cache_for_reordering_before_attn = {} _seq_chunk_ids_cache_for_reordering_after_attn = {} +_softmax_offset_chunk_ids_cache = {} + +# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 +_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" def flash_attn_p2p_communicate( @@ -228,11 +236,11 @@ def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): @jit_fuser def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication before attention compute.""" - # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] + # [cp, b, s, h//cp, d] -> [b, cp, s, h//cp, d] + # or [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] x = x.movedim(0, seq_dim).contiguous() - # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp, s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # or [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) # reorder the sequence chunks x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) @@ -242,13 +250,13 @@ def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_siz @jit_fuser def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication after attention compute.""" - # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp*2, s//2, h//cp, d] -> [cp*2, b, s//2, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.movedim(seq_dim, 0).contiguous() # reorder the sequence chunks x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) - # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + # [cp*2, b, s//2, h//cp, d] -> [cp, 2, b, s//2, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] x = x.view(cp_size, 2, *x.shape[1:]) return x @@ -280,16 +288,16 @@ def flash_attn_a2a_communicate( x = reorder_seq_chunks_for_a2a_before_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) - # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + # [b, cp*2, s//2, h//cp, d] -> [b, cp*s, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp*s, b, h//cp, d] a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, s, np, hn] -> [b, s, cp, np//cp, hn] - # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + # [b, s, h, d] -> [b, s, cp, h//cp, d] + # or [s, b, h, d] -> [s, b, cp, h//cp, d] x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) - # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] - # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + # [b, s, cp, h//cp, d] -> [cp, b, s, h//cp, d] + # or [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d] a2a_inputs[i] = x.movedim(-3, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): @@ -300,8 +308,8 @@ def flash_attn_a2a_communicate( ) if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp*s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # or [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( @@ -311,16 +319,65 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] - # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + # [cp, 2, b, s//2, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] + # or [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() - # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] - # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + # [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d] + # or [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs +def flash_attn_a2a_communicate_softmax_offset( + tensor: torch.Tensor, + h_dim: int, + cp_size: int, + cp_group: dist_group_type, + cp_stream: torch.cuda.Stream, + before_attn: bool, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Split/AllGather communication for softmax offset.""" + if tensor is None: + return None + + global _softmax_offset_chunk_ids_cache + device = tensor.device + if (cp_size, device) not in _softmax_offset_chunk_ids_cache: + chunk_ids = torch.arange(cp_size, dtype=torch.int32, device=device) + _softmax_offset_chunk_ids_cache[(cp_size, device)] = chunk_ids + else: + chunk_ids = _softmax_offset_chunk_ids_cache[(cp_size, device)] + + if before_attn: + # softmax_offset: split round-robin to CP ranks + # [1, h, 1, 1] -> [1, cp, h//cp, 1, 1] + shape = tensor.shape + tensor = tensor.view( + *shape[:h_dim], cp_size, shape[h_dim] // cp_size, *shape[(h_dim + 1) :] + ) + rank = get_distributed_rank(cp_group) + output = torch.index_select(tensor, dim=h_dim, index=chunk_ids[rank]) + output = output.view(*shape[:h_dim], -1, *shape[(h_dim + 1) :]) + else: + # d_softmax_offset: all-gather from all ranks to all ranks + # [1, h//cp, 1, 1] -> [1, h, 1, 1] + inp = tensor.view(-1) + output = torch.empty(cp_size * inp.shape[0], dtype=tensor.dtype, device=device) + with torch.cuda.stream(cp_stream): + torch.distributed.all_gather_into_tensor( + output, + inp, + group=cp_group, + async_op=False, + ) + torch.cuda.current_stream().wait_stream(cp_stream) + output = output.view( + *tensor.shape[:h_dim], cp_size * tensor.shape[h_dim], *tensor.shape[h_dim + 1 :] + ) + return output + + def _get_cu_seqlens_info_with_cp( batch_size: int, max_seqlen: int, @@ -420,6 +477,585 @@ def get_fa_args( ] +def cp_p2p_fwd_prepare_qkv( + q_part, + k_part, + v_part, + qkv_format, + pad_between_seqs, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + cu_seqlens_q_half, + cu_seqlens_kv_half, + rank, + step, + cp_size, + section, +): + """Prepare q, k, v and cu_seqlens for CP P2P forward""" + cu_seqlens_q_per_step = None + cu_seqlens_kv_per_step = None + if section in ["diagonal", "all"]: + if pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + rank_ = rank if section == "diagonal" else (rank - step) % cp_size + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank_, True, True + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step = cu_seqlens_q + cu_seqlens_kv_per_step = cu_seqlens_kv + + if qkv_format == "bshd": + # [b, 2, s//2, h, d] -> [b, s, h, d] + q_part, k_part, v_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q_part, k_part, v_part] + ] + elif qkv_format == "sbhd": + # [2, s//2, b, h, d] -> [s, b, h, d] + q_part, k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [q_part, k_part, v_part]] + + elif section == "lower-triangle": + if pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - step) % cp_size, + True, + False, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step = cu_seqlens_kv // (cp_size * 2) + else: + cu_seqlens_q_per_step = cu_seqlens_q + cu_seqlens_kv_per_step = cu_seqlens_kv_half + + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq, h, d] + q_part = q_part.view(q_part.shape[0], -1, *q_part.shape[-2:]) + # [b, 2, sk//2, h, d] -> [b, sk//2, h, d] + k_part = k_part[:, 0, ...] + v_part = v_part[:, 0, ...] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq, b, h, d] + q_part = q_part.view(-1, *q_part.shape[-3:]) + # [2, sk//2, b, h, d] -> [sk//2, b, h, d] + k_part = k_part[0] + v_part = v_part[0] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + + elif section == "upper-triangle": + if pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True + ) + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - step) % cp_size, + True, + True, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // (cp_size * 2) + cu_seqlens_kv_per_step = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step = cu_seqlens_q_half + cu_seqlens_kv_per_step = cu_seqlens_kv + + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + q_part = q_part[:, 1, ...] + # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part, v_part = [x.view(x.shape[0], -1, *x.shape[-2:]) for x in [k_part, v_part]] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_part = q_part[1] + # [2, sk//2, b, h, d] -> [sk, b, h, d] + k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [k_part, v_part]] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + q_part = tex.thd_read_half_tensor(q_part, cu_seqlens_q_padded, 1) + + return q_part, k_part, v_part, cu_seqlens_q_per_step, cu_seqlens_kv_per_step + + +def cp_p2p_fwd_fused_attn( + attn_bias, + attn_bias_, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + attn_mask_type, + attn_bias_type, + fp8, + q_fp8, + k_fp8, + v_fp8, + fwd_nominal_dtype, + S_quantizer_per_step, + O_quantizer_per_step, + rank, + step, + cp_size, + q_part, + k_part, + v_part, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + section, +): + """Per-tile forward call of CP P2P with FusedAttention backend""" + attn_bias_inputs = None + max_seqlen_q_ = None + max_seqlen_kv_ = None + cu_seqlens_q_ = None + cu_seqlens_kv_ = None + attn_mask_type_ = None + cu_seqlens_q_padded_ = None + cu_seqlens_kv_padded_ = None + if section in ["diagonal", "all"]: + if attn_bias is not None: + idx = (rank - step) % cp_size + attn_bias_inputs = torch.cat( + ( + attn_bias[..., idx, :], + attn_bias[..., (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = attn_mask_type + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + elif section == "lower-triangle": + k_part = k_part.contiguous() + v_part = v_part.contiguous() + if attn_bias is not None: + idx = (rank - step) % cp_size + attn_bias_inputs = attn_bias[..., idx, :].contiguous() + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv // 2 + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = ( + cu_seqlens_kv_padded // 2 if cu_seqlens_kv_padded is not None else None + ) + elif section == "upper-triangle": + q_part = q_part.contiguous() + if attn_bias is not None: + idx = (rank - step) % cp_size + attn_bias_inputs = torch.cat( + ( + attn_bias_[..., 1, :, idx, :], + attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + max_seqlen_q_ = max_seqlen_q // 2 + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + cu_seqlens_q_padded_ = cu_seqlens_q_padded // 2 if cu_seqlens_q_padded is not None else None + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + + fp8_meta_kwargs = {} + if fp8: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step + fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step + + out_per_step, aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q_, + max_seqlen_kv_, + cu_seqlens_q_, + cu_seqlens_kv_, + q_part, + k_part, + v_part, + fake_dtype=fwd_nominal_dtype, + fused_attention_backend=fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type_, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs, + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, + **fp8_meta_kwargs, + ) + + if fp8: + softmax_lse_per_step, _, rng_states = aux_ctx_tensors + else: + softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors + attn_bias = rest[0] if len(rest) > 0 else None + + return out_per_step, softmax_lse_per_step, rng_states, attn_bias + + +def cp_p2p_fwd_flash_attn( + use_flash_attn_3, + qkv_format, + fa_forward_kwargs, + flash_attn_fwd, + max_seqlen_q, + max_seqlen_kv, + q_part, + k_part, + v_part, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + section, +): + """Per-tile forward call of CP P2P with FlashAttention backend""" + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + causal_ = False + if section in ["diagonal", "all"]: + causal_ = section == "diagonal" + elif section == "lower-triangle": + max_seqlen_kv_ = max_seqlen_kv // 2 + elif section == "upper-triangle": + max_seqlen_q_ = max_seqlen_q // 2 + if section in ["lower-triangle", "upper-triangle"]: + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_forward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 + + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_, + cu_seqlens_kv=cu_seqlens_kv_, + max_seqlen_q=max_seqlen_q_, + max_seqlen_kv=max_seqlen_kv_, + ) + fa_outputs = flash_attn_fwd( + q_part, + k_part, + v_part, + *fa_forward_args_thd, + causal=causal_, + **fa_forward_kwargs, + ) + rng_states = None + if not fa_utils.v2_7_0_plus: + out_per_step = fa_outputs[4] + softmax_lse_per_step = fa_outputs[5] + if not use_flash_attn_3: + rng_states = fa_outputs[7] + else: + out_per_step = fa_outputs[0] + softmax_lse_per_step = fa_outputs[1] + if not use_flash_attn_3: + rng_states = fa_outputs[3] + + return out_per_step, softmax_lse_per_step, rng_states + + +def cp_p2p_bwd_prepare_qkv( + q_part, + k_part, + v_part, + out_part, + dout_part, + qkv_format, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + section, +): + """Prepare q, k, v and cu_seqlens for CP P2P backward""" + if section in ["diagonal", "all"]: + if qkv_format == "bshd": + # [b, 2, s//2, h, d] -> [b, s, h, d] + q_part, k_part, v_part, out_part, dout_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) + for x in [q_part, k_part, v_part, out_part, dout_part] + ] + elif qkv_format == "sbhd": + # [2, s//2, b, h, d] -> [s, b, h, d] + q_part, k_part, v_part, out_part, dout_part = [ + x.view(-1, *x.shape[-3:]) for x in [q_part, k_part, v_part, out_part, dout_part] + ] + elif section == "lower-triangle": + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq, h, d] + q_part, out_part, dout_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q_part, out_part, dout_part] + ] + # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part = k_part[:, 0] + v_part = v_part[:, 0] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq, b, h, d] + q_part, out_part, dout_part = [ + x.view(-1, *x.shape[-3:]) for x in [q_part, out_part, dout_part] + ] + # [2, sk//2, b, h, d] -> [sk, b, h, d] + k_part = k_part[0] + v_part = v_part[0] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + elif section == "upper-triangle": + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + q_part, out_part, dout_part = q_part[:, 1], out_part[:, 1], dout_part[:, 1] + # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part, v_part = [x.view(x.shape[0], -1, *x.shape[-2:]) for x in [k_part, v_part]] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_part, out_part, dout_part = q_part[1], out_part[1], dout_part[1] + # [2, sk//2, b, h, d] -> [sk, b, h, d] + k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [k_part, v_part]] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + q_part, out_part, dout_part = [ + tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) + for x in [q_part, out_part, dout_part] + ] + + return q_part, k_part, v_part, out_part, dout_part + + +def cp_p2p_bwd_fused_attn( + fp8, + fp8_recipe, + q_fp8, + kv_fp8, + out_fp8, + dout_fp8, + softmax_lse, + softmax_lse_, + rng_states, + attn_dbias, + attn_biases, + max_seqlen_q, + max_seqlen_kv, + step, + cp_size, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + attn_mask_type, + attn_bias_type, + deterministic, + fwd_nominal_dtype, + bwd_nominal_dtype, + bwd_output_te_dtype, + S_quantizer, + dP_quantizer_per_step, + dQKV_quantizer_per_step, + q_part, + k_part, + v_part, + out_part, + dout_part, + section, +): + """Per-tile backward call of CP P2P with FusedAttention backend""" + if fp8: + aux_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] + + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + attn_mask_type_ = attn_mask_type + + if section == "lower-triangle": + k_part = k_part.contiguous() + v_part = v_part.contiguous() + max_seqlen_kv_ = max_seqlen_kv // 2 + cu_seqlens_kv_padded_ = None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + elif section == "upper-triangle": + q_part, out_part, dout_part = [x.contiguous() for x in [q_part, out_part, dout_part]] + if fp8: + aux_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] + + max_seqlen_q_ = max_seqlen_q // 2 + cu_seqlens_q_padded_ = None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + + if attn_dbias is not None: + aux_tensors += [attn_biases[cp_size - step - 1]] + + fp8_meta_kwargs = {} + if fp8: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip( + [q_fp8, kv_fp8, kv_fp8], + [q_part, k_part, v_part], + ) + ] + if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_quantizer_per_step + + dq, dk, dv, dbias, *_ = fused_attn_bwd( + max_seqlen_q_, + max_seqlen_kv_, + cu_seqlens_q_per_step[cp_size - step - 1], + cu_seqlens_kv_per_step[cp_size - step - 1], + q_part, + k_part, + v_part, + out_part, + dout_part, + bwd_nominal_dtype, + bwd_output_te_dtype, + aux_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type_, + attn_bias_type=attn_bias_type, + deterministic=deterministic, + **fp8_meta_kwargs, + ) + + return dq, dk, dv, dbias + + +def cp_p2p_bwd_flash_attn( + use_flash_attn_3, + qkv_format, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + step, + cp_size, + fa_backward_kwargs, + flash_attn_bwd, + rng_states, + softmax_lse, + softmax_lse_, + q_part, + k_part, + v_part, + out_part, + dout_part, + section, +): + """Per-tile backward call of CP P2P with FlashAttention backend""" + dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_backward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - step - 1] + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + softmax_lse__ = softmax_lse + causal_ = False + if section == "diagonal": + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_backward_kwargs["window_size"] = (-1, 0) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = 0 + causal_ = True + elif section == "lower-triangle": + max_seqlen_kv_ = max_seqlen_kv // 2 + elif section == "upper-triangle": + max_seqlen_q_ = max_seqlen_q // 2 + softmax_lse__ = softmax_lse_ + + fa_backward_args_thd = get_fa_args( + False, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - step - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - step - 1], + max_seqlen_q=max_seqlen_q_, + max_seqlen_kv=max_seqlen_kv_, + dq=dq, + dk=dk, + dv=dv, + ) + flash_attn_bwd( + dout_part, + q_part, + k_part, + v_part, + out_part, + softmax_lse__, + *fa_backward_args_thd, + causal=causal_, + **fa_backward_kwargs, + ) + + return dq, dk, dv + + class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ Attention implementation with context parallelism. Exchange KV between CP ranks @@ -461,30 +1097,24 @@ def forward( quantizers, pad_between_seqs, use_flash_attn_3, + fp8_output, + layer_number, ): # pylint: disable=missing-function-docstring - nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") - enable_mla = k.shape[-1] != v.shape[-1] - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + # add NVTX range + nvtx_label = "transformer_engine.AttnFuncWithCPAndKVP2P.forward" + nvtx_range_push(f"{nvtx_label}") + + # set up CP groups for cp_comm_type = {'p2p', 'a2a+p2p'} + cp_group_a2a = None + cp_size_a2a = 1 + rank_a2a = 0 if isinstance(cp_group, list): - assert ( - qkv_format != "thd" - ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" - assert attn_bias_type == "no_bias", ( - f"{attn_bias_type} bias type is not supported with hierarchical CP implementation" - " yet!" - ) cp_group_a2a = cp_group[0] cp_size_a2a = get_distributed_world_size(cp_group_a2a) rank_a2a = get_distributed_rank(cp_group_a2a) cp_group = cp_group[1] - else: - cp_group_a2a = None - cp_size_a2a = 1 - rank_a2a = 0 - cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] @@ -494,18 +1124,19 @@ def forward( device_compute_capability < (10, 0) and cp_size == 2 ) + # set up attention args + enable_mla = k.shape[-1] != v.shape[-1] causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) batch_dim = None seq_dim = None cu_seqlens_q_half, cu_seqlens_kv_half = None, None + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") - if enable_mla: - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - else: - qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None if use_fused_attention: batch_dim = qkv_format.index("b") @@ -516,7 +1147,6 @@ def forward( q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv ) else: - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size @@ -526,79 +1156,110 @@ def forward( cu_seqlens_kv_per_step = [None for _ in range(cp_size)] fused_attn_backend = None - qkv_dtype = q.dtype amax_per_step = None S_quantizer_per_step = [None for _ in range(cp_size)] - O_CP_quantizer_per_step = [None for _ in range(cp_size)] - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = False + O_quantizer_per_step = [None for _ in range(cp_size)] + + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + fwd_nominal_dtype = q.dtype + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] ( QKV_quantizer, O_quantizer, - O_CP_quantizer, S_quantizer, dQKV_quantizer, - dQKV_CP_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) + ) = dpa_utils.get_attention_quantizers(fp8, quantizers) + + q_f16 = None + q_fp8, k_fp8, v_fp8 = (None, None, None) + # communicate for the 'a2a' part of 'a2a+p2p' + if cp_size_a2a > 1: + if fp8 and is_input_fp8: + QKV_quantizer = q._quantizer + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = (q._data, k._data, v._data) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + ) + if fp8 and is_input_fp8: + q_fp8, k_fp8, v_fp8 = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q, k, v]) + ] + q, k, v = q_fp8, k_fp8, v_fp8 + # convert qkv to the right type if fp8: - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] + assert use_fused_attention, "FP8 is only supported with Fused Attention!" + fused_attn_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha - if is_input_fp8: - QKV_quantizer = q._quantizer - q, k, v = q._data, k._data, v._data - else: - q_f16, k_f16, v_f16 = q, k, v - if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = QKV_quantizer(q_f16)._data - if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]] - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - # partial result quantizer - for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) - O_CP_quantizer_per_step[i] = O_CP_quantizer.copy() - O_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if is_input_fp8: + # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=torch.uint8 + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] else: - assert False, "FP8 is only supported with Fused Attention!" + # q_f16: torch.Tensor, dtype=fwd_nominal_dtype + # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=torch.uint8 + q_f16 = q + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.forward >> before: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + + # amax_per_step[0]: amax_s x cp_size + # amax_per_step[1]: amax_o x cp_size + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; + # only used to hold temporary scale/amax values (output only, no quantization op) + for i in range(cp_size): + S_quantizer_per_step[i] = S_quantizer.copy() + S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + O_quantizer_per_step[i] = O_quantizer.copy() + O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: + # q_f16: torch.Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype q_f16 = q if use_fused_attention: fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] - if cp_size_a2a > 1: - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) - - q, k, v = flash_attn_a2a_communicate( - [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True - ) - if not fp8: - q_f16 = q - elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16 = q - q = QKV_quantizer(q_f16)._data - + # split qkv to two halves and prepare for load balancing assert qkv_format == "thd" or ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" if causal: if qkv_format == "bshd": - # [b, s, np, hn] -> [b, 2, s//2, np, hn] + # [b, s, h, d] -> [b, 2, s//2, h, d] q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]] elif qkv_format == "sbhd": - # [s, b, np, hn] -> [2, s//2, b, np, hn] + # [s, b, h, d] -> [2, s//2, b, h, d] q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] + attn_bias_ = None if attn_bias is not None: assert len(attn_bias.shape) == 4, ( "Only support bias shape of [b, h, sq, sk] for forward, " @@ -607,7 +1268,7 @@ def forward( assert ( attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 ), "Sequence length does not meet divisible requirements!" - # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + # [b, h, sq, sk] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] attn_bias_ = attn_bias.view( *attn_bias.shape[:-2], 2, @@ -615,12 +1276,14 @@ def forward( 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size), ) - # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)] + # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] attn_bias = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) ) - assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" + # stats tensor shape: + # BHS1 before cuDNN 9.6 or flash-attention v2.6/v3 + # TH1 after cuDNN 9.6 or flash-attention v2.6/v3 softmax_lse_in_packed_format = False if qkv_format == "thd": if use_fused_attention: @@ -628,7 +1291,9 @@ def forward( else: softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3 + # set up args for FlashAttention backend flash_attn_fwd = None + fa_forward_kwargs = {} if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if use_flash_attn_3: @@ -667,11 +1332,9 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - # Flash Attn inputs + # set up inputs for forward q_inputs = [None, None] kv_inputs = [None, None] - attn_bias_inputs = [None, None] - # Flash Attn outputs out_per_step = [None for _ in range(cp_size)] softmax_lse_per_step = [None for _ in range(cp_size)] rng_states = [None for _ in range(cp_size)] @@ -683,19 +1346,15 @@ def forward( fwd_results_correction_done = torch.cuda.Event() p2p_comm_buffers = [None for _ in range(cp_size)] - if enable_mla: - # If MLA, the shape of k and v does not match, so we flatten them - # and split them after receiving them. - k_shape = k.shape - k_numel = k.numel() - v_shape = v.shape - p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) - elif qkv_format in ["bshd", "sbhd"]: - p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) - else: # qkv_format == "thd" - p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) + k_shape = k.shape + k_numel = k.numel() + v_shape = v.shape + p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) send_recv_reqs = [[], []] + # P2P communication and compute: each rank has cp_size steps + # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype + # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8 out = None for i in range(cp_size + 1): if i < cp_size: @@ -716,634 +1375,205 @@ def forward( batch_p2p_comm, ) - if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - kv_inputs[i % 2] = p2p_comm_buffers[i] + kv_inputs[i % 2] = p2p_comm_buffers[i] + k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) + v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) + q_part = q + + prepare_inputs = [ + q_part, + k_part, + v_part, + qkv_format, + pad_between_seqs, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + cu_seqlens_q_half, + cu_seqlens_kv_half, + rank, + i, + cp_size, + ] + if use_fused_attention: + fused_attn_inputs = [ + attn_bias, + attn_bias_, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + attn_mask_type, + attn_bias_type, + fp8, + q_fp8, + k_fp8, + v_fp8, + fwd_nominal_dtype, + S_quantizer_per_step[i], + O_quantizer_per_step[i], + rank, + i, + cp_size, + ] else: - # KV exchange is in BF16/FP16, cast received KV in each step - kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data - if enable_mla: - # If MLA, k and v are flattened, so split them after receiving. - k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) - v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) + flash_attn_inputs = [ + use_flash_attn_3, + qkv_format, + fa_forward_kwargs, + flash_attn_fwd, + max_seqlen_q, + max_seqlen_kv, + ] + + # cp_size = 4: + # + # step + # section | 0 1 2 3 + # -------------------- + # G 0 | d, u, u, u, + # P 1 | l, d, u, u, + # U 2 | l, l, d, u, + # 3 | l, l, l, d, + # + # Each GPU holds a slice of Q and KV. To compute the attention of each Q slice, each GPU + # runs cp_size steps to get the partial results of its own Q and all KV slices. KV is communicated + # in a point-to-point, ring fashion. For attn_mask_type = causal, there are three attention + # patterns in the cp_size x cp_size (i.e. GPU x step) matrix, the diagonal tiles, the lower-triangle + # tiles, and the upper-triangle tiles. For attn_mask_type != causal, the pattern is all the same. if causal: if i == 0: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[2:]) - v_part = v_part.view(-1, *v_part.shape[2:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - q_inputs[i % 2] = q + section = "diagonal" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias[..., idx, :], - attn_bias[..., (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - fake_dtype=qkv_dtype, - fused_attention_backend=fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, - ) - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=True, - **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] elif i <= rank: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - False, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn] - k_part = k_part[:, 0, ...] - v_part = v_part[:, 0, ...] - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk//2, b, np, hn] - k_part = k_part[0] - v_part = v_part[0] - else: - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][0] - elif qkv_format == "thd": - q_inputs[i % 2] = q - if enable_mla: - # [t, np, hn] -> [t/2, np, hn] - k_part = tex.thd_read_half_tensor( - k_part, cu_seqlens_kv_padded, 0 - ) - v_part = tex.thd_read_half_tensor( - v_part, cu_seqlens_kv_padded, 0 - ) - else: - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) + section = "lower-triangle" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv // 2, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None - if cu_seqlens_kv_padded is None - else cu_seqlens_kv_padded // 2 - ), - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv // 2, - ) - if use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_forward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_forward_kwargs["window_size_left"] = -1 - fa_forward_kwargs["window_size_right"] = -1 - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] else: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - True, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q_half - cu_seqlens_kv_per_step[i] = cu_seqlens_kv - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_inputs[i % 2] = q[:, 1, ...] - if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_inputs[i % 2] = q[1] - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[2:]) - v_part = v_part.view(-1, *v_part.shape[2:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor( - q, cu_seqlens_q_padded, 1 - ) + section = "upper-triangle" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - q_inputs[i % 2] = q_inputs[i % 2].contiguous() - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias_[..., 1, :, idx, :], - attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q // 2, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=( - None - if cu_seqlens_q_padded is None - else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q // 2, - max_seqlen_kv=max_seqlen_kv, - ) - if use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_forward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_forward_kwargs["window_size_left"] = -1 - fa_forward_kwargs["window_size_right"] = -1 - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] else: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - True, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv + # all tiles + section = "all" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias[..., idx, :], - attn_bias[..., (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, - ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, - ) - fa_outputs = flash_attn_fwd( - q, - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section) ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] + # softmax_lse correction if i > 0: - # wait until fwd restuls correction of last step is done + # wait until fwd results correction of last step is done if i > 1: flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if use_fused_attention: - # [b, np, sq, 1] -> [b, np, sq] or - # [t, np, 1] -> [t, np] + # [b, h, sq, 1] -> [b, h, sq] or + # [t, h, 1] -> [t, np] softmax_lse_per_step[i - 1].squeeze_(-1) if softmax_lse_in_packed_format: softmax_lse_per_step[i - 1] = ( softmax_lse_per_step[i - 1].transpose(0, 1).contiguous() ) if fp8: - out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32) + # dequantize out_per_step to torch.float32 + if fp8_recipe.delayed(): + out_per_step[i - 1] = out_per_step[i - 1].dequantize( + dtype=torch.float32 + ) + if fp8_recipe.float8_current_scaling(): + out_per_step[i - 1] = out_per_step[i - 1].to(dtype=torch.float32) + if i == 1: softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": @@ -1383,6 +1613,7 @@ def forward( if causal and rank < (cp_size - 1): second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1] + # fwd output correction: out in torch.float32 for i in range(cp_size): if i <= rank or not causal: if qkv_format in ["bshd", "sbhd"]: @@ -1435,7 +1666,6 @@ def forward( softmax_lse_in_packed_format, ) - kv = p2p_comm_buffers[-1] if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) ctx.batch_size = out.shape[0] @@ -1450,39 +1680,84 @@ def forward( ) if use_fused_attention: if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] + # [b*s, h, d] -> [b, s, h, d] out = out.view(ctx.batch_size, -1, *out.shape[-2:]) elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] + # [s*b, h, d] -> [s, b, h, d] out = out.view(-1, ctx.batch_size, *out.shape[-2:]) elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) + # update FP8 quantizers: amax across cp_size steps if fp8 and use_fused_attention: amax_cp_fwd = amax_per_step.amax(dim=1) S_quantizer.amax.copy_(amax_cp_fwd[0]) - O_CP_quantizer.amax.copy_(amax_cp_fwd[1]) - - out_fp8 = None - out_f16 = out.to(qkv_dtype) + O_quantizer.amax.copy_(amax_cp_fwd[1]) - if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): - out_fp8 = O_quantizer(out_f16) # final result + if fp8: + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.forward >> after: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + # prepare for return and ctx saves + out_fp8 = None + out_f16 = out.to(fwd_nominal_dtype) + if fp8 and ( + is_output_fp8 + or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)) + ): + out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 - if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, kv_save, out_save = q, kv, out_fp8._data + ctx.layer_number = layer_number + ctx.fp8_recipe = fp8_recipe + ctx.fp8 = fp8 and is_bwd_fp8 + + kv_fp8 = None + kv = p2p_comm_buffers[-1] + if fp8: + q_fp8, kv_fp8 = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8], [q, kv]) + ] + # q, kv, out + fp8_tensors = (None, None, None) + f16_tensors = (None, None, None) + if ctx.fp8: + # fwd: fp8, bwd: fp8, save all fp8 + fp8_tensors = (q_fp8, kv_fp8, out_fp8) + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + f16_tensors = (None, None, out_f16) elif fp8 and is_input_fp8: - q_save, kv_save, out_save = q, kv, out_f16 + # fwd: fp8, bwd: f16, save all f16 + # dequantize fp8 inputs + q_f16 = q_fp8.dequantize() + kv_f16 = kv_fp8.dequantize() + f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8: + # fwd: fp8, bwd: f16, save all f16 + # inputs are already in f16 + q_f16 = q_f16.view(q.shape) + kv_f16 = kv_fp8.dequantize() + f16_tensors = (q_f16, kv_f16, out_f16) else: + # fwd: f16, bwd: f16, save all f16 + # inputs and kernels are both f16 q_f16 = q_f16.view(q.shape) - q_save, kv_save, out_save = q_f16, kv, out_f16 + kv_f16 = kv + f16_tensors = (q_f16, kv_f16, out_f16) tensors_to_save, tensor_objects = prepare_for_saving( - q_save, - kv_save, - out_save, + *fp8_tensors, + *f16_tensors, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, @@ -1512,21 +1787,18 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format ctx.second_half_lse_seqlen = second_half_lse_seqlen - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 ctx.enable_mla = enable_mla - if enable_mla: - ctx.k_numel = k_numel - ctx.k_shape = k_shape - ctx.v_shape = v_shape + ctx.k_numel = k_numel + ctx.k_shape = k_shape + ctx.v_shape = v_shape - ctx.qkv_dtype = qkv_dtype + ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.dQKV_quantizer = dQKV_quantizer - ctx.dQKV_CP_quantizer = dQKV_CP_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer ctx.QKV_quantizer = QKV_quantizer @@ -1539,17 +1811,31 @@ def forward( ctx.O_quantizer.scale = O_quantizer.scale.clone() ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward") + + nvtx_range_pop(f"{nvtx_label}") return out_ret @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring - nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + + # add NVTX range + nvtx_label = "transformer_engine.AttnFuncWithCPAndKVP2P.backward" + nvtx_range_push(f"{nvtx_label}") + + # dout is expected to be in FP8 if is_output_fp8=True, + # but in the case it's not, convert it to FP8 before any operation + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorBase): + dout = ctx.dO_quantizer(dout) + if ctx.use_fused_attention: + dout._data = dout._data.contiguous() + elif ctx.use_fused_attention: + dout = dout.contiguous() + + # set up CP groups for cp_comm_type = {'p2p', 'a2a+p2p'} cp_size_a2a = ctx.cp_size_a2a rank_a2a = ctx.rank_a2a - cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] @@ -1559,33 +1845,38 @@ def backward(ctx, dout): device_compute_capability < (10, 0) and cp_size == 2 ) - q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( - restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - ) + # get saved tensors + ( + q_fp8, + kv_fp8, + out_fp8, + q, + kv, + out, + softmax_lse, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *other_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) cu_seqlens_q_per_step = other_tensors[:cp_size] cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] rng_states = other_tensors[cp_size * 2 : cp_size * 3] attn_biases = other_tensors[cp_size * 3 : cp_size * 4] + # set up attention args causal = "causal" in ctx.attn_mask_type - padding = "padding" in ctx.attn_mask_type - seq_dim = None + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") - if ctx.enable_mla: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - else: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] - else: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + # set up attention bias if attn_biases[0] is not None: - # [b, np, sq, 2*cp, sk//(2*cp)] + # [b, h, sq, 2*cp, sk//(2*cp)] attn_dbias = torch.zeros( *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device ) - # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] attn_dbias_ = attn_dbias.view( *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:] ) @@ -1593,6 +1884,7 @@ def backward(ctx, dout): attn_dbias = None attn_dbias_ = None + # set up softmax_lse softmax_lse_ = None if causal and ctx.second_half_lse_seqlen is not None: if ctx.qkv_format == "thd": @@ -1603,86 +1895,124 @@ def backward(ctx, dout): ctx.second_half_lse_seqlen, ) else: - # [b, np, sq] -> [b, np, 2, sq//2] + # [b, h, sq] -> [b, h, 2, sq//2] softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1) softmax_lse_ = softmax_lse_[..., 1, :].contiguous() if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() - # [b, np, sq//2] -> [b, np, sq//2, 1] or - # [t//2, np] -> [t//2, np, 1] + # [b, h, sq//2] -> [b, h, sq//2, 1] or + # [t//2, np] -> [t//2, h, 1] softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse = softmax_lse.transpose(0, 1).contiguous() - # [b, np, sq] -> [b, np, sq, 1] or - # [t, np] -> [t, np, 1] + # [b, h, sq] -> [b, h, sq, 1] or + # [t, np] -> [t, h, 1] softmax_lse.unsqueeze_(-1) - dout = dout.contiguous() - dq = None - dout_dtype = dout.dtype + # assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16 + # used when some tensors are base tensors and loose the "dtype" attribute + bwd_nominal_dtype = ctx.fwd_nominal_dtype + + # convert out, dout to the right type fused_attn_backend = None - fused_attn_dqkv_dtype = None amax_per_step = None dP_quantizer_per_step = [None for _ in range(cp_size)] - dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)] + dQKV_quantizer_per_step = [None for _ in range(cp_size)] + buffer_dtype = torch.uint8 + dq_buffer = None + dout_fp8 = None + bwd_output_te_dtype = None + dkv_buffer = None if ctx.fp8: - if ctx.use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] + assert ctx.use_fused_attention, "FP8 is only supported with Fused Attention!" + fused_attn_backend = FusedAttnBackend["FP8"] + q, kv, out = ( + q_fp8._data, + kv_fp8._data, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8._data + ), + ) - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.dO_quantizer = dout._quantizer - else: - dout = ctx.dO_quantizer(dout) - fused_attn_dqkv_dtype = TE_DType[dout._data.dtype] - dq_fp8 = torch.empty((cp_size, *q.shape), dtype=dout._data.dtype, device=q.device) - dkv_fp8 = torch.empty( - (cp_size, *kv.shape), dtype=dout._data.dtype, device=kv.device - ) - dkv_fp8_ = torch.empty_like(dkv_fp8) - p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] - dout = dout._data - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - for i in range(cp_size): - dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() - dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) - dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy() - dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + # dout_fp8: Float8Tensor, dtype=bwd_nominal_dtype + # dout: torch.Tensor, dtype=torch.uint8 + if ctx.is_output_fp8: + dout_fp8 = dout else: - assert False, "FP8 is only supported with Fused Attention!" + dout_fp8 = ctx.dO_quantizer(dout) + dout = dout_fp8._data + + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.backward >> before: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) + + # dout_fp8._fp8_dtype + bwd_output_te_dtype = ctx.dO_quantizer.dtype + + # create buffers for reduction in float32 + if ctx.fp8_recipe.delayed(): + dq_buffer = torch.empty( + (cp_size, *q.shape), + dtype=buffer_dtype, + device=q.device, + ) + if ctx.fp8_recipe.float8_current_scaling(): + dq_buffer = torch.empty( + q.shape, + dtype=torch.float32, + device=q.device, + ) + kv_recv_buffer = torch.empty_like(kv) + dkv_send_buffer = torch.empty( + (cp_size, *kv.shape), + dtype=buffer_dtype, + device=kv.device, + ) + dkv_recv_buffer = torch.empty_like(dkv_send_buffer) + p2p_comm_buffers = [[kv, dkv_send_buffer], [kv_recv_buffer, dkv_recv_buffer]] + if ctx.fp8_recipe.float8_current_scaling(): + dkv_buffer = torch.zeros( + kv.shape, + dtype=torch.float32, + device=kv.device, + ) + + # amax_per_step[0]: amax_dp x cp_size + # amax_per_step[1]: amax_dqkv x cp_size + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; + # only used to hold temporary scale/amax values (output only, no quantization op) + for i in range(cp_size): + dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() + dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: - if ctx.fp8_meta is not None: - if ctx.is_input_fp8: - q = ctx.QKV_quantizer.create_tensor_from_data( - q, fake_dtype=ctx.qkv_dtype, internal=True - ) - kv = ctx.QKV_quantizer.create_tensor_from_data( - kv, fake_dtype=ctx.qkv_dtype, internal=True - ) - q = q.dequantize(dtype=ctx.qkv_dtype) - kv = kv.dequantize(dtype=ctx.qkv_dtype) - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - if cp_size_a2a == 1: - dout = dout.dequantize(dtype=dout_dtype) - else: - ctx.dO_quantizer = dout._quantizer - dout = dout._data - dq = torch.empty_like(q) + if isinstance(dout, QuantizedTensorBase): + dout = dout.dequantize(dtype=bwd_nominal_dtype) + dq_buffer = torch.empty_like(q) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: - fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout_dtype] + bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] + # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) @@ -1699,11 +2029,6 @@ def backward(ctx, dout): ctx.cp_stream, True, ) - if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - dout = ctx.dO_quantizer.create_tensor_from_data( - dout, fake_dtype=dout_dtype, internal=True - ) - dout = dout.dequantize(dtype=dout_dtype) if ctx.enable_mla: out = out.view(*ctx.v_shape) @@ -1712,7 +2037,6 @@ def backward(ctx, dout): # MHA or GQA out = out.view(*q.shape) dout = dout.view(*q.shape) - send_recv_reqs = [] flash_attn_bwd = None if not ctx.use_fused_attention: @@ -1747,6 +2071,7 @@ def backward(ctx, dout): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + send_recv_reqs = [] for i in range(cp_size): # wait until KV is received for req in send_recv_reqs: @@ -1767,8 +2092,8 @@ def backward(ctx, dout): ) else: dkv_a2a_req = torch.distributed.all_to_all_single( - dkv_fp8, - dkv_fp8_, + dkv_send_buffer, + dkv_recv_buffer, group=ctx.cp_group, async_op=True, ) @@ -1785,593 +2110,146 @@ def backward(ctx, dout): ) kv = p2p_comm_buffers[i % 2][0] - q_, kv_, out_, dout_ = None, None, None, None dq_, dk_, dv_ = None, None, None - if ctx.enable_mla: - k_part = kv[: ctx.k_numel].view(*ctx.k_shape) - v_part = kv[ctx.k_numel :].view(*ctx.v_shape) - # In reversed order of fwd - if causal: - if i == (cp_size - 1): - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_, out_, dout_ = [ - x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] - ] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[-3:]) - v_part = v_part.view(-1, *v_part.shape[-3:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - elif ctx.qkv_format == "thd": - q_, kv_, out_, dout_ = q, kv, out, dout - if ctx.use_fused_attention: - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q_ - if not ctx.enable_mla: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ - - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=ctx.attn_mask_type, - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, - ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data - else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, 0) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = 0 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse, - *fa_backward_args_thd, - causal=True, - **fa_backward_kwargs, - ) - elif i >= (cp_size - rank - 1): - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_, out_, dout_ = [ - x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] - ] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part[:, 0] - v_part = v_part[:, 0] - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_ = kv[:, 0] - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part[0] - v_part = v_part[0] - else: - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_ = kv[0] - elif ctx.qkv_format == "thd": - q_, out_, dout_ = q, out, dout - if ctx.enable_mla: - # [t, np, hn] -> [t/2, np, hn] - k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) - v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) - else: - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) - if ctx.use_fused_attention: - if ctx.enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - kv_ = kv_.contiguous() - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q_ - if not ctx.enable_mla: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ + k_part = kv[: ctx.k_numel].view(*ctx.k_shape) + v_part = kv[ctx.k_numel :].view(*ctx.v_shape) + q_part, out_part, dout_part = q, out, dout - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv // 2, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 - ), - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + prepare_inputs = [ + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_format, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ] + if ctx.use_fused_attention: + fused_attn_inputs = [ + ctx.fp8, + ctx.fp8_recipe, + q_fp8, + kv_fp8, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8 + ), + dout_fp8, + softmax_lse, + softmax_lse_, + rng_states, + attn_dbias, + attn_biases, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + i, + cp_size, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + ctx.softmax_scale, + ctx.dropout_p, + qkv_layout, + ctx.attn_mask_type, + ctx.attn_bias_type, + ctx.deterministic, + ctx.fwd_nominal_dtype, + bwd_nominal_dtype, + bwd_output_te_dtype, + ctx.S_quantizer, + dP_quantizer_per_step[i], + dQKV_quantizer_per_step[i], + ] + else: + flash_attn_inputs = [ + ctx.use_flash_attn_3, + ctx.qkv_format, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + i, + cp_size, + fa_backward_kwargs, + flash_attn_bwd, + rng_states, + softmax_lse, + softmax_lse_, + ] + + # Reverse the steps in forward. In the cp_size x cp_size (i.e. GPU x step) matrix, + # there are still three sections in these tiles based on their attention pattern + # for attn_mask_type = causal, and one for attn_mask_type != causal. + if causal: + if i == (cp_size - 1): + section = "diagonal" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) + if ctx.use_fused_attention: + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv // 2, - dq=dq_, - dk=dk_, - dv=dv_, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + elif i >= (cp_size - rank - 1): + section = "lower-triangle" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) + if ctx.use_fused_attention: + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section + ) + else: + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) else: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_, out_, dout_ = q[1], out[1], dout[1] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[-3:]) - v_part = v_part.view(-1, *v_part.shape[-3:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - elif ctx.qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_, out_, dout_ = [ - tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) - for x in [q, out, dout] - ] - kv_ = kv + section = "upper-triangle" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) if ctx.use_fused_attention: - q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]] - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse_, - softmax_lse_, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - - q_part = q_ - if not ctx.enable_mla: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ - - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q // 2, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=( - None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q // 2, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse_, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) else: + section = "all" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) if ctx.use_fused_attention: - if ctx.fp8: - aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q - if not ctx.enable_mla: - k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] - v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] - out_part = out - dout_part = dout - - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=ctx.attn_mask_type, - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data - else: - dq_ = torch.empty_like(q) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] - v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] - dkv_ = torch.empty_like(kv) - dk_ = dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0] - dv_ = dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1] - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout, - q, - k_part, - v_part, - out, - softmax_lse, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq = dq_fp8[(rank + i + 1) % cp_size] + # dq, dk, dv are reduced across steps in higher precision + # DelayedScaling: collect all results in uint8 to one tensor, dequantize to float32, then reduce + # CurrentScaling: dequantize partial results from each step to float32, then reduce + if ctx.fp8 and ctx.use_fused_attention: + if ctx.fp8_recipe.delayed(): + dq_, dk_, dv_ = [x._data for x in [dq_, dk_, dv_]] + if ctx.fp8_recipe.float8_current_scaling(): + dq_, dk_, dv_ = [x.to(torch.float32) for x in [dq_, dk_, dv_]] + + # copy dq_ into the right buffer position + # buffer is cp_size x dq_size for DelayedScaling and the same size as dq for CurrentScaling + if ctx.fp8 and ctx.fp8_recipe.delayed(): + dq = dq_buffer[(rank + i + 1) % cp_size] + else: + dq = dq_buffer if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1): - # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or - # [sq, b, np, hn] -> [2, sq//2, b, np, hn] + # [b, sq, h, d] -> [b, 2, sq//2, h, d] or + # [sq, b, h, d] -> [2, sq//2, b, h, d] dq_ = dq_.view(*dq.shape) - - if ctx.fp8: + if ctx.fp8 and ctx.fp8_recipe.delayed(): if i >= (cp_size - rank - 1) or not causal: dq.copy_(dq_) else: @@ -2381,6 +2259,8 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq[0].fill_(0) dq[1].copy_(dq_) + else: + dq.copy_(dq_) elif causal: if i > (cp_size - rank - 1): dq.add_(dq_) @@ -2416,18 +2296,19 @@ def backward(ctx, dout): else: dq.add_(dq_) + # dbias correction if attn_dbias is not None: idx = (rank + i + 1) % cp_size if i == (cp_size - 1) or not causal: - # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)] + # [b, h, sq, sk//cp] -> [b, h, sq, 2, sk//(2*cp)] dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) attn_dbias[..., idx, :].copy_(dbias_[..., 0, :]) attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) elif i >= (cp_size - rank - 1): - # [b, np, sq, sk//(2*cp)] + # [b, h, sq, sk//(2*cp)] attn_dbias[..., idx, :].copy_(dbias_) else: - # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)] + # [b, h, sq//2, sk//cp] -> [b, h, sq//2, 2, sk//(2*cp)] dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) @@ -2436,254 +2317,159 @@ def backward(ctx, dout): for req in send_recv_reqs: req.wait() - if ctx.fp8: - if i < cp_size - 1: - dkv = dkv_fp8_[(rank + i + 1) % cp_size] - else: - dkv = dkv_fp8[(rank + i + 1) % cp_size] + # dkv correction + if ctx.fp8 and ctx.fp8_recipe.delayed(): + dkv = dkv_recv_buffer[(rank + i + 1) % cp_size] + elif ctx.fp8 and ctx.fp8_recipe.float8_current_scaling(): + dkv = dkv_buffer else: dkv = p2p_comm_buffers[(i + 1) % 2][1] - if ctx.use_fused_attention: - if ctx.enable_mla: - dkv_ = None - elif ctx.qkv_format in ["bshd", "sbhd"]: - dkv_ = combine_tensors([dk_, dv_], -2) - elif ctx.qkv_format == "thd": - dkv_ = torch.cat( - (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 - ) # pylint: disable=used-before-assignment - if not ctx.enable_mla and ctx.qkv_format in ["bshd", "sbhd"]: - # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or - # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] - # dkv is a buffer, so we do not need to transpose it, but only need to reshape it. - dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) - dkv_ = dkv_.movedim(-3, 0) - if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): - # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or - # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] - dkv_ = dkv_.view(*dkv.shape) - - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] or - # [2, sk//2, b, np, hn] - dk = dkv[: ctx.k_numel].view(*ctx.k_shape) - dv = dkv[ctx.k_numel :].view(*ctx.v_shape) - if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): - dk_ = dk_.view(*ctx.k_shape) - dv_ = dv_.view(*ctx.v_shape) - - if ctx.fp8: - # enable_mla and fp8 - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): - if ctx.qkv_format == "bshd": - dk[:, 0, ...].copy_(dk_) - dk[:, 1, ...].fill_(0) - dv[:, 0, ...].copy_(dv_) - dv[:, 1, ...].fill_(0) - elif ctx.qkv_format == "sbhd": - dk[0].copy_(dk_) - dk[1].fill_(0) - dv[0].copy_(dv_) - dv[1].fill_(0) - else: - dk.copy_(dk_) - dv.copy_(dv_) - elif causal: - # enable_mla and not fp8 and causal - if i == (cp_size - 1): - if rank == 0: - if ctx.qkv_format == "bshd": - dk[:, 0, ...].add_(dk_[:, 0, ...]) - dk[:, 1, ...].copy_(dk_[:, 1, ...]) - dv[:, 0, ...].add_(dv_[:, 0, ...]) - dv[:, 1, ...].copy_(dv_[:, 1, ...]) - elif ctx.qkv_format == "sbhd": - dk[0, ...].add_(dk_[0, ...]) - dk[1, ...].copy_(dk_[1, ...]) - dv[0, ...].add_(dv_[0, ...]) - dv[1, ...].copy_(dv_[1, ...]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "add", "copy" - ) - tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "add", "copy" - ) - else: - dk.add_(dk_) - dv.add_(dv_) - elif i >= (cp_size - rank - 1): - if i == 0 and rank == (cp_size - 1): - if ctx.qkv_format == "bshd": - dk[:, 0, ...].copy_(dk_) - dv[:, 0, ...].copy_(dv_) - elif ctx.qkv_format == "sbhd": - dk[0, ...].copy_(dk_) - dv[0, ...].copy_(dv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "copy", "none" - ) - tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "copy", "none" - ) - else: - if ctx.qkv_format == "bshd": - dk[:, 0, ...].add_(dk_) - dv[:, 0, ...].add_(dv_) - elif ctx.qkv_format == "sbhd": - dk[0, ...].add_(dk_) - dv[0, ...].add_(dv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "add", "none" - ) - tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "add", "none" - ) - elif i > 0: - dk.add_(dk_) - dv.add_(dv_) - else: # i == 0 + + # [b, 2, sk//2, h, d] or + # [2, sk//2, b, h, d] + dk = dkv[: ctx.k_numel].view(*ctx.k_shape) + dv = dkv[ctx.k_numel :].view(*ctx.v_shape) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + dk_ = dk_.view(*ctx.k_shape) + dv_ = dv_.view(*ctx.v_shape) + + if ctx.fp8 and ctx.fp8_recipe.delayed(): + # fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dk[:, 0, ...].copy_(dk_) + dk[:, 1, ...].fill_(0) + dv[:, 0, ...].copy_(dv_) + dv[:, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dk[0].copy_(dk_) + dk[1].fill_(0) + dv[0].copy_(dv_) + dv[1].fill_(0) + else: dk.copy_(dk_) dv.copy_(dv_) else: - # enable_mla and not fp8 and not causal - if i == 0: - dk.copy_(dk_) - dv.copy_(dv_) - else: # i > 0 + dk.copy_(dk_) + dv.copy_(dv_) + elif causal: + # not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_[:, 0, ...]) + dk[:, 1, ...].copy_(dk_[:, 1, ...]) + dv[:, 0, ...].add_(dv_[:, 0, ...]) + dv[:, 1, ...].copy_(dv_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_[0, ...]) + dk[1, ...].copy_(dk_[1, ...]) + dv[0, ...].add_(dv_[0, ...]) + dv[1, ...].copy_(dv_[1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, "add", "copy") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "add", "copy") + else: dk.add_(dk_) dv.add_(dv_) - else: - if ctx.fp8: - # fp8 - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].copy_(dkv_) - dkv[:, :, 1, ...].fill_(0) + dk[:, 0, ...].copy_(dk_) + dv[:, 0, ...].copy_(dv_) elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].copy_(dkv_) - dkv[:, 1, ...].fill_(0) + dk[0, ...].copy_(dk_) + dv[0, ...].copy_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, "copy", "none") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "copy", "none") else: - dkv.copy_(dkv_) - elif causal: - # not fp8 and causal - if i == (cp_size - 1): - if rank == 0: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) - dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_[:, 0, ...]) - dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "add", "copy" - ) - else: - dkv.add_(dkv_) - elif i >= (cp_size - rank - 1): - if i == 0 and rank == (cp_size - 1): - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "copy", "none" - ) - else: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "add", "none" - ) - elif i > 0: - dkv.add_(dkv_) - else: # i == 0 - dkv.copy_(dkv_) - else: - # not fp8 and not causal - if i == 0: - dkv.copy_(dkv_) - else: # i > 0 - dkv.add_(dkv_) + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_) + dv[:, 0, ...].add_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_) + dv[0, ...].add_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, "add", "none") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "add", "none") + elif i > 0: + dk.add_(dk_) + dv.add_(dv_) + else: # i == 0 + dk.copy_(dk_) + dv.copy_(dv_) + else: + # not fp8 and not causal + if i == 0: + dk.copy_(dk_) + dv.copy_(dv_) + else: # i > 0 + dk.add_(dk_) + dv.add_(dv_) + # sum up all cp_size for dq, dk, dv if ctx.fp8 and ctx.use_fused_attention: amax_cp_bwd = amax_per_step.amax(dim=1) ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) - ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1]) - dq = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dq_fp8, fake_dtype=torch.float32, internal=True - ) - - if ctx.enable_mla: - # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn] - dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) - dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) - dk = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dk_fp8, fake_dtype=torch.float32, internal=True - ) - dv = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dv_fp8, fake_dtype=torch.float32, internal=True - ) - dq, dk, dv = [x.dequantize(dtype=torch.float32) for x in [dq, dk, dv]] - dq, dk, dv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dk, dv]] - else: - if ctx.qkv_format in ["bshd", "sbhd"]: - # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or - # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] - dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) - dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dkv_fp8, fake_dtype=torch.float32, internal=True + ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) + + dq = dq_buffer + if ctx.fp8_recipe.delayed(): + # [cp, b, 2, sk//2, h, d] or [cp, 2, sk//2, b, h, d] + dk = dkv_recv_buffer[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) + dv = dkv_recv_buffer[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) + dq, dk, dv = [ + ctx.dQKV_quantizer.create_tensor_from_data( + x, fake_dtype=bwd_nominal_dtype, internal=ctx.dQKV_quantizer.internal + ) + for x in [dq, dk, dv] + ] + dq, dk, dv = combine_and_dequantize( + qkv_layout, + dq, + dk, + dv, + src_nominal_dtype=bwd_nominal_dtype, + des_nominal_dtype=torch.float32, ) - dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] - dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + dq, dk, dv = [x.sum(dim=0).to(bwd_nominal_dtype) for x in [dq, dk, dv]] - if causal: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - dk = dk.view(dk.shape[0], -1, *dk.shape[-2:]) - dv = dv.view(dv.shape[0], -1, *dv.shape[-2:]) - else: - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - dq = dq.view(-1, *dq.shape[-3:]) - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - dk = dk.view(-1, *dk.shape[-3:]) - dv = dv.view(-1, *dv.shape[-3:]) - else: - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) + if ctx.fp8_recipe.float8_current_scaling(): + dk = dkv[: ctx.k_numel].view(ctx.k_shape) + dv = dkv[ctx.k_numel :].view(ctx.v_shape) + + if causal and ctx.qkv_format in ["bshd", "sbhd"]: + # [b, 2, s//2, h, d] -> [b, s, h, d] + # [2, s//2, b, h, d] -> [s, b, h, d] + dim = ctx.qkv_format.index("s") + dq, dk, dv = [x.view(*x.shape[:dim], -1, *x.shape[dim + 2 :]) for x in [dq, dk, dv]] if ctx.qkv_format == "thd" and not ctx.use_fused_attention: dq[cu_seqlens_q_padded[-1] :].fill_(0) - if ctx.enable_mla: - dk[cu_seqlens_kv_padded[-1] :].fill_(0) - dv[cu_seqlens_kv_padded[-1] :].fill_(0) - else: - dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) + dk[cu_seqlens_kv_padded[-1] :].fill_(0) + dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - assert torch.uint8 not in [dq.dtype, dkv.dtype] - if ctx.enable_mla: - dq, dk, dv = [ctx.dQKV_quantizer(x)._data for x in [dq, dk, dv]] - else: - dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] - if not ctx.enable_mla: - dk, dv = dkv[0], dkv[1] + dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + + if ctx.fp8: + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.backward >> after: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) if cp_size_a2a > 1: + if ctx.fp8 and ctx.is_input_fp8: + dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv + dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -2694,20 +2480,21 @@ def backward(ctx, dout): ctx.cp_stream, False, ) + if ctx.fp8 and ctx.is_input_fp8: + dq, dk, dv = [ + Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) + for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) + ] if ctx.qkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if attn_dbias is not None: - # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] + # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, sq, sk] attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) - # converting torch.uint8 to float8tensor - if ctx.fp8 and ctx.is_input_fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) - dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) - dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + + nvtx_range_pop(f"{nvtx_label}") return ( None, @@ -2736,6 +2523,8 @@ def backward(ctx, dout): None, None, None, + None, + None, ) @@ -2865,22 +2654,22 @@ def forward( else: cu_seqlens_q_padded = None - # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] + # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn] + # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] - # [s, b, np, hn] -> [cp, s, b, np, hn] + # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) cp_stream.wait_stream(torch.cuda.current_stream()) @@ -2900,8 +2689,8 @@ def forward( for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] q_ = q.select(seq_dim, i).contiguous() kv_seq_range_per_step[i], window_size_per_step[i] = ( get_kv_seq_info_after_all_gather( @@ -2923,7 +2712,7 @@ def forward( k.shape[1], max_seqlen_kv_, k.device ) k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( @@ -3059,17 +2848,17 @@ def backward(ctx, dout): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() - # [s, b, np, hn] -> [cp, s, b, np, hn] + # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) @@ -3110,8 +2899,8 @@ def backward(ctx, dout): for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] q_ = q.select(seq_dim, i).contiguous() seq_start_idx, seq_end_idx = ( kv_seq_range_per_step[i][0], @@ -3119,13 +2908,13 @@ def backward(ctx, dout): ) max_seqlen_kv = seq_end_idx - seq_start_idx k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] out_ = out_per_step[i] dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] - dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd( + dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, cu_seqlens_q, @@ -3192,7 +2981,7 @@ def backward(ctx, dout): dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": dq[i - 1].copy_(dq_per_step[i - 1]) - # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn] + # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] dk_per_step[i - 1], dv_per_step[i - 1] = [ x.movedim(seq_dim, 0).contiguous() for x in [dk_per_step[i - 1], dv_per_step[i - 1]] @@ -3211,13 +3000,13 @@ def backward(ctx, dout): torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] dk = dk.view(-1, *dk.shape[-3:]) dv = dv.view(-1, *dv.shape[-3:]) dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) @@ -3286,6 +3075,9 @@ def forward( cp_stream, quantizers, use_flash_attn_3, + softmax_type, + softmax_offset, + fp8_output, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") @@ -3293,7 +3085,6 @@ def forward( softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) - qkv_dtype = q.dtype causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -3357,32 +3148,37 @@ def forward( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + fwd_nominal_dtype = q.dtype fused_attn_backend = None - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = False QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + dpa_utils.get_attention_quantizers(fp8, quantizers) ) + + q_fp8, k_fp8, v_fp8 = (None, None, None) if fp8: if use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if is_input_fp8: - QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16, k_f16, v_f16 = q, k, v - q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] + else: + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = S_quantizer - fp8_meta_kwargs["o_quantizer"] = O_quantizer # partial result quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer else: assert False, "FP8 is only supported with Fused Attention!" else: @@ -3394,25 +3190,23 @@ def forward( q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True ) + if softmax_type != "vanilla": + softmax_offset = flash_attn_a2a_communicate_softmax_offset( + softmax_offset, 1, cp_size, cp_group, cp_stream, True + ) - if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16, k_f16, v_f16 = q, k, v - q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] - + out_fp8 = None + out_f16 = None batch_size = q.shape[batch_dim] + q_part, k_part, v_part = q, k, v + out_part = None if use_fused_attention: - q_part, k_part, v_part = q, k, v if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v, fake_dtype=qkv_dtype, internal=True - ) - out, aux_ctx_tensors = fused_attn_fwd( + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + out_, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -3421,7 +3215,7 @@ def forward( q_part, k_part, v_part, - qkv_dtype, + fwd_nominal_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -3433,9 +3227,27 @@ def forward( cu_seqlens_kv_padded=cu_seqlens_kv_padded, window_size=window_size, **fp8_meta_kwargs, + softmax_type=softmax_type, + softmax_offset=softmax_offset, ) - if fp8: - out = out._data + if isinstance(out_, Float8Tensor): + out_fp8 = out_ + out_ = out_._data + if is_bwd_fp8 and not ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ): + out_part = out_fp8 + else: + out_part = out_fp8.dequantize(dtype=fwd_nominal_dtype) + else: + out_f16 = out_ + out_part = out_ + if ( + fp8 + and is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + ): + out_part = O_quantizer(out_) else: fa_forward_args_thd = get_fa_args( True, @@ -3447,67 +3259,67 @@ def forward( max_seqlen_kv=max_seqlen_kv, ) fa_outputs = flash_attn_fwd( - q, - k, - v, + q_part, + k_part, + v_part, *fa_forward_args_thd, causal=causal, **fa_forward_kwargs, ) if not fa_utils.v2_7_0_plus: - out, softmax_lse = fa_outputs[4], fa_outputs[5] + out_, softmax_lse = fa_outputs[4], fa_outputs[5] rng_state = fa_outputs[7] if not use_flash_attn_3 else None else: - out, softmax_lse = fa_outputs[0], fa_outputs[1] + out_, softmax_lse = fa_outputs[0], fa_outputs[1] rng_state = fa_outputs[3] if not use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] + out_part = out_ - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device) - out = flash_attn_a2a_communicate( - out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) + out_ = flash_attn_a2a_communicate( + out_, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False ) if use_fused_attention: if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] - out = out.view(batch_size, -1, *out.shape[-2:]) + # [b*s, h, d] -> [b, s, h, d] + out_ = out_.view(batch_size, -1, *out_.shape[-2:]) elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] - out = out.view(-1, batch_size, *out.shape[-2:]) + # [s*b, h, d] -> [s, b, h, d] + out_ = out_.view(-1, batch_size, *out_.shape[-2:]) - if fp8: - if is_output_fp8: - out_fp8 = O_quantizer.create_tensor_from_data( - out, fake_dtype=qkv_dtype, internal=False - ) - out_ret = out_fp8 - out = out_fp8._data - else: - out_fp8 = O_quantizer.create_tensor_from_data( - out, fake_dtype=qkv_dtype, internal=True - ) - out_f16 = out_fp8.dequantize(dtype=qkv_dtype) - out_ret = out_f16 + if fp8 and use_fused_attention: + if fp8_recipe.float8_current_scaling(): + out_f16 = out_ + if is_output_fp8: + out_fp8 = O_quantizer(out_) + if fp8_recipe.delayed(): + out_fp8 = Float8Tensor.make_like(out_fp8, data=out_, dtype=fwd_nominal_dtype) + if not is_output_fp8: + out_f16 = out_fp8.dequantize(dtype=fwd_nominal_dtype) else: - out_ret = out + out_f16 = out_ - if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, k_save, v_save, out_save = q, k, v, out - else: - if is_input_fp8: - q_save, k_save, v_save = q, k, v - else: - q_save, k_save, v_save = q_f16, k_f16, v_f16 - if is_output_fp8: - out_save = out + out_ret = out_fp8 if is_output_fp8 else out_f16 + + ctx.fp8 = fp8 and is_bwd_fp8 + fp8_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) + if ctx.fp8: + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_part, k_part, v_part, None) + f16_tensors = (None, None, None, out_part) else: - out_save = out_f16 + fp8_tensors = (q_part, k_part, v_part, out_part) + elif fp8: + q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) + f16_tensors = (q_part, k_part, v_part, out_part) + else: + f16_tensors = (q_part, k_part, v_part, out_part) tensors_to_save, tensor_objects = prepare_for_saving( - q_save, - k_save, - v_save, - out_save, + *fp8_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, @@ -3516,6 +3328,7 @@ def forward( ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects + ctx.out_shape = out_ret.shape ctx.batch_size = batch_size ctx.cp_group = cp_group @@ -3530,13 +3343,14 @@ def forward( ctx.deterministic = deterministic ctx.window_size = window_size ctx.use_fused_attention = use_fused_attention - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.fp8_recipe = fp8_recipe ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.softmax_type = softmax_type - ctx.qkv_dtype = qkv_dtype ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer @@ -3560,6 +3374,10 @@ def backward(ctx, dout): cp_size = get_distributed_world_size(ctx.cp_group) ( + q_fp8, + k_fp8, + v_fp8, + out_fp8, q, k, v, @@ -3570,23 +3388,21 @@ def backward(ctx, dout): cu_seqlens_kv_padded, *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in ctx.attn_mask_type seq_dim = ctx.qkv_format.index("s") - dout_dtype = dout.dtype + bwd_nominal_dtype = ctx.fwd_nominal_dtype + dqkv_te_dtype = None fused_attn_backend = None - fused_attn_dqkv_dtype = None + dout_fp8 = dout if ctx.fp8: if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.dO_quantizer = dout._quantizer - else: + if not isinstance(dout, QuantizedTensorBase): dout = ctx.dO_quantizer(dout) - fused_attn_dqkv_dtype = TE_DType[dout._data.dtype] + dout_fp8 = dout + dqkv_te_dtype = dout._fp8_dtype dout = dout._data fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer @@ -3596,44 +3412,23 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None: - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.dO_quantizer = dout._quantizer - dout = dout._data - if ctx.is_input_fp8: - q = ctx.QKV_quantizer.create_tensor_from_data( - q, fake_dtype=ctx.qkv_dtype, internal=True - ) - k = ctx.QKV_quantizer.create_tensor_from_data( - k, fake_dtype=ctx.qkv_dtype, internal=True - ) - v = ctx.QKV_quantizer.create_tensor_from_data( - v, fake_dtype=ctx.qkv_dtype, internal=True - ) - q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in [q, k, v]] + if isinstance(dout, QuantizedTensorBase): + dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout_dtype] + dqkv_te_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(*out.shape) + dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) + else: + dout = dout.view(*ctx.out_shape) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device) - out, dout = flash_attn_a2a_communicate( - [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) + dout = flash_attn_a2a_communicate( + dout, chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True ) - if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - out = ctx.O_quantizer.create_tensor_from_data( - out, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout = ctx.dO_quantizer.create_tensor_from_data( - dout, fake_dtype=dout_dtype, internal=True - ) - out = out.dequantize(dtype=ctx.qkv_dtype) - dout = dout.dequantize(dtype=dout_dtype) flash_attn_bwd = None if not ctx.use_fused_attention: @@ -3674,31 +3469,15 @@ def backward(ctx, dout): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + dq_fp8, dk_fp8, dv_fp8 = None, None, None if ctx.use_fused_attention: - q_part = q - k_part = k - v_part = v - out_part = out - dout_part = dout - + q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - - dq, dk, dv, _ = fused_attn_bwd( + q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + out_part = out + dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, @@ -3708,8 +3487,8 @@ def backward(ctx, dout): v_part, out_part, dout_part, - dout_dtype, - fused_attn_dqkv_dtype, + bwd_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, @@ -3722,11 +3501,11 @@ def backward(ctx, dout): window_size=ctx.window_size, deterministic=ctx.deterministic, **fp8_meta_kwargs, + softmax_type=ctx.softmax_type, ) - if ctx.fp8: - dq = dq._data - dk = dk._data - dv = dv._data + if isinstance(dq, Float8Tensor): + dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv + dq, dk, dv = [x._data for x in [dq, dk, dv]] else: softmax_lse, rng_state = aux_ctx_tensors dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] @@ -3756,7 +3535,7 @@ def backward(ctx, dout): **fa_backward_kwargs, ) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False ) @@ -3766,18 +3545,34 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + d_bias = None + d_softmax_offset = None + if ctx.use_fused_attention: + if ctx.attn_bias_type not in ["no_bias", "alibi"]: + d_bias = rest[0] + if ctx.softmax_type != "vanilla": + d_softmax_offset = rest[1] + d_softmax_offset = flash_attn_a2a_communicate_softmax_offset( + d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False + ) + if ctx.fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data( - dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - dk = ctx.dQKV_quantizer.create_tensor_from_data( - dk, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - dv = ctx.dQKV_quantizer.create_tensor_from_data( - dv, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - if not ctx.is_input_fp8: - dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]] + if ctx.fp8_recipe.float8_current_scaling() and ctx.is_input_fp8: + dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + if ctx.fp8_recipe.delayed(): + dq, dk, dv = [ + Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) + for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) + ] + if not ctx.is_input_fp8: + dq, dk, dv = combine_and_dequantize( + qkv_layout, + dq, + dk, + dv, + src_nominal_dtype=bwd_nominal_dtype, + ) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( @@ -3796,6 +3591,7 @@ def backward(ctx, dout): None, None, None, + d_bias, None, None, None, @@ -3806,6 +3602,8 @@ def backward(ctx, dout): None, None, None, + d_softmax_offset, + None, ) @@ -3838,6 +3636,10 @@ def attn_forward_func_with_cp( quantizers=None, pad_between_seqs=False, use_flash_attn_3=False, + softmax_type="vanilla", + softmax_offset=None, + fp8_output=False, + layer_number=1, ) -> torch.Tensor: """ Attention implementation with context parallelism (CP). CP partitions tensors along the sequence @@ -3901,10 +3703,15 @@ def attn_forward_func_with_cp( """ if cp_comm_type == "a2a+p2p": - assert isinstance( - cp_group, list - ), "Hierarchical CP implementation needs multi-level CP groups!" - assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" + assert ( + isinstance(cp_group, list) and len(cp_group) == 2 + ), "CP implementation a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!" + assert ( + qkv_format != "thd" + ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" + assert ( + attn_bias_type == "no_bias" + ), f"{attn_bias_type} bias type is not supported with hierarchical CP implementation yet!" if get_distributed_world_size(cp_group[0]) == 1: cp_group = cp_group[1] cp_comm_type = "p2p" @@ -3914,23 +3721,23 @@ def attn_forward_func_with_cp( else: assert isinstance( cp_group, dist_group_type - ), f"Unsupported process group for CP communication type {cp_comm_type}!" + ), f"cp_group must be {dist_group_type} type for {cp_comm_type=}!" assert qkv_format in [ "bshd", "sbhd", "thd", - ], f"QKV format of {qkv_format} is not supported with context parallelism!" + ], f"Context parallelism does not support {qkv_format=}!" assert ( qkv_format != "sbhd" or use_fused_attention - ), "FlashAttention does not support sbhd format!" + ), "Context parallelism does not support FlashAttention backend with qkv_format = 'sbhd'!" assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( - """Attention bias is only supported with FusedAttention and "causal" """ - """or "no_mask" mask types!""" + "Context parallelism only supports attention bias with FusedAttention backend and" + " non-padding mask types!" ) assert qkv_format != "thd" or ( cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None - ), "cu_seqlens_padded cannot be None with context parallelism + THD format!" + ), "cu_seqlens_padded can not be None for context parallelism and qkv_format = 'thd'!" sliding_window_attn = ( window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) @@ -3938,13 +3745,28 @@ def attn_forward_func_with_cp( assert not sliding_window_attn or cp_comm_type in [ "a2a", "all_gather", - ], "The context parallel running configs cannot support sliding window attetnion!" + ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] assert not enable_mla or cp_comm_type in [ "p2p", "a2a+p2p", - ], "The context parallel running configs cannot support MLA!" + ], "Context parallelism does not support MLA with {cp_comm_type=}!" + + if fp8 and fp8_meta is not None: + if fp8_meta["recipe"].fp8_dpa: + assert ( + softmax_type == "vanilla" + ), "Context parallelism does not support {softmax_type=} with FP8 attention!" + assert ( + softmax_type == "vanilla" or use_fused_attention + ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!" + assert ( + softmax_type == "vanilla" or cp_comm_type == "a2a" + ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" + assert ( + softmax_type == "vanilla" or qkv_format != "thd" + ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!" args = [ is_training, @@ -3977,6 +3799,8 @@ def attn_forward_func_with_cp( quantizers, pad_between_seqs, use_flash_attn_3, + fp8_output, + layer_number, ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": @@ -3985,7 +3809,18 @@ def attn_forward_func_with_cp( args += [window_size, cp_group, cp_stream, use_flash_attn_3] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": - args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3] + args += [ + window_size, + fp8, + fp8_meta, + cp_group, + cp_stream, + quantizers, + use_flash_attn_3, + softmax_type, + softmax_offset, + fp8_output, + ] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index b35b87a83..a19d08ae5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -11,10 +11,25 @@ import logging import torch +from torch.nn.parameter import Parameter import transformer_engine_torch as tex +from transformer_engine.common.recipe import ( + Format, + Recipe, + DelayedScaling, + Float8CurrentScaling, +) from transformer_engine.pytorch.utils import get_cudnn_version -from transformer_engine.pytorch.fp8 import get_fp8_te_dtype +from transformer_engine.pytorch.fp8 import ( + get_fp8_te_dtype, + FP8GlobalStateManager, + RecipeState, + DelayedScalingRecipeState, + MXFP8BlockScalingRecipeState, + Float8CurrentScalingRecipeState, + Float8BlockScalingRecipeState, +) from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.export import is_in_onnx_export_mode @@ -72,6 +87,67 @@ "_alibi_bias_require_update": False, } +""" +This feature is **experimental** and subject to change. + +Some models may use different FP8 recipes for their linear layers and attention layers. To support this, +users can either use multiple, nested fp8_autocast() contexts to assign a distinct recipe for each layer, +or use a single fp8_autocast() for the non-attention layers and configure the recipe for the attention +layers as follows. + ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| Linear | Attention | Configuration | ++===================+===========+===================================================================================+ +| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to fp8_autocast(); | +| | | export NVTE_DPA_FP8_RECIPE="F16" | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS | FP8DS | Pass FP8DS to fp8_autocast(); | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8CS | FP8DS | Pass FP8CS to fp8_autocast(); | +| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8DS | Pass NVFP4 to fp8_autocast(); | +| | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS | FP8CS | Pass FP8DS to fp8_autocast(); | +| | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,| +| | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8CS | FP8CS | Pass FP8CS to fp8_autocast(); | +| | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe | +| | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8CS | Pass NVFP4 to fp8_autocast(); | +| | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | +| | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +""" +_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") +formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2} +_dpa_fp8_format = formats[os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")] +_dpa_fp8ds_amax_algo = os.getenv("NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent") +_dpa_fp8ds_amax_histlen = int(os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1")) +_dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1" + + __all__ = ["DotProductAttention"] @@ -168,6 +244,17 @@ class DotProductAttention(TransformerEngineBaseModule): softmax_scale: Optional[float], default = `None` softmax scale for the attention scores. If `None`, defaults to `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`. + softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. + For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). Parallelism parameters ---------------------- @@ -223,6 +310,7 @@ def __init__( cp_stream: torch.cuda.Stream = None, cp_comm_type: str = "p2p", softmax_scale: Optional[float] = None, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -307,6 +395,20 @@ def __init__( self.attention_type = attention_type self.attention_dropout = attention_dropout + self.softmax_type = softmax_type + if self.softmax_type == "vanilla": + self.softmax_offset = None + if self.softmax_type == "off-by-one": + self.softmax_offset = torch.zeros( + self.num_attention_heads // self.tp_size, device="cuda" + ) + if self.softmax_type == "learnable": + self.register_parameter( + "softmax_offset", + Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")), + get_rng_state_tracker=get_rng_state_tracker, + ) + attn_kwargs = { "attention_dropout": attention_dropout, "attention_dropout_ctx": attention_dropout_ctx, @@ -328,6 +430,7 @@ def __init__( layer_number=layer_number, deterministic=self.deterministic, **attn_kwargs, + softmax_type=self.softmax_type, ) self.unfused_attention = UnfusedDotProductAttention( @@ -335,6 +438,7 @@ def __init__( attention_type=attention_type, **attn_kwargs, layer_number=layer_number, + softmax_type=self.softmax_type, ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument @@ -433,6 +537,231 @@ def set_context_parallel_group( self.cp_stream = cp_stream self.cp_comm_type = cp_comm_type + def init_fp8_metadata(self, num_gemms: int = 1) -> None: + """ + Override TransformerEngineBaseModule.init_fp8_metadata to allow for more flexible recipe support. + Initialize fp8 related metadata and tensors during fprop. + """ + _original_recipe = self.fp8_meta.get("recipe", None) + + # global recipe set in fp8_autocast() + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + + # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to + # a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well. + # + # fp8_recipe | NVTE_DPA_FP8_RECIPE | self.fp8_meta["recipe"] | self.quantizers + # -------------------------------------------------------------------------------------------- + # DelayedScaling (DS) | unset | DS | all DS + # Float8CurrentScaling (CS) | unset | DS | CS for QKV, O, dO, dQKV; DS for S, dP + # x={DS, CS} | y | refer to row x=y | refer to row x=y + fp8_recipe_dpa = fp8_recipe + fp8_recipes = fp8_recipe + if _dpa_fp8_recipe == "F16": + # ignore the recipe from fp8_autocast, set fp8_dpa = False, fp8_mha = False + fp8_recipe.fp8_dpa = False + fp8_recipe.fp8_mha = False + elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe + fake_recipe = DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "DelayedScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a DS recipe + fake_recipe = DelayedScaling( + fp8_format=_dpa_fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa + elif fp8_recipe.delayed() and _dpa_fp8_recipe == "Float8CurrentScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe + fake_recipes = [ + Float8CurrentScaling( + fp8_format=fp8_recipe.fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + fp8_recipe, + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes + elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe in ( + "", + "Float8CurrentScaling", + ): + # use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe + fake_recipe = DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = [fp8_recipe, fp8_recipe_dpa] + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format + # construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP + fake_recipes = [ + Float8CurrentScaling( + fp8_format=_dpa_fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + DelayedScaling( + fp8_format=_dpa_fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ), + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes + # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False + if not fp8_recipe_dpa.float8_per_tensor_scaling(): + assert not ( + fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha + ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" + + # reduce over TP+CP groups; expect fp8_group to be set up so + # assume attention uses the same fp8_group as GEMMs + fp8_group = FP8GlobalStateManager.get_fp8_group() + + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + self.fp8 = FP8GlobalStateManager.is_fp8_enabled() + self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + fp8_enabled = self.fp8 or self.fp8_calibration + self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration + if self.fp8_parameters or fp8_enabled: + self.fp8_meta["global_recipe"] = fp8_recipe + self.fp8_meta["local_recipes"] = ( + fp8_recipes if isinstance(fp8_recipes, List) else [fp8_recipes] + ) + + if self.fp8_parameters or fp8_enabled: + if self.fp8_initialized and fp8_recipe_dpa == self.fp8_meta["recipe"]: + # FP8 init has already been run and recipe is the same, don't do anything. + return + self.fp8_meta["recipe"] = fp8_recipe_dpa + if fp8_recipe != fp8_recipe_dpa: + # fp8_recipe has changed, rehash the key. + autocast_key = FP8GlobalStateManager.get_unique_autocast_key( + fp8_recipe_dpa, fp8_group + ) + FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + fp8_recipe_dpa, + fp8_group, + ) + else: + # If fp8 isn't enabled, turn off and return. + self.fp8_initialized = False + return + + if self.fp8_parameters and not self.fp8_initialized: + self.fp8_meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(fp8_recipes) + + if fp8_enabled: + # Set FP8 and other FP8 metadata + self.fp8_meta["num_gemms"] = num_gemms + self.fp8_meta["fp8_group"] = fp8_group + + # Set FP8_MAX per tensor according to recipe + self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd + self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + + # Allocate scales and amaxes + self.init_fp8_meta_tensors(fp8_recipes) + self.fp8_initialized = True + + self.fp8_meta["recipe"] = fp8_recipe_dpa + if fp8_recipe != fp8_recipe_dpa: + # fp8_recipe has changed, rehash the key. + autocast_key = FP8GlobalStateManager.get_unique_autocast_key( + fp8_recipe_dpa, fp8_group + ) + FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + fp8_recipe_dpa, + fp8_group, + ) + + _current_recipe = self.fp8_meta["recipe"] + if _original_recipe is not None and not ( + issubclass(_current_recipe.__class__, _original_recipe.__class__) + or issubclass(_original_recipe.__class__, _current_recipe.__class__) + ): + warnings.warn( + f"Recipe type changed from {_original_recipe.__class__.__name__} " + f"to {_current_recipe.__class__.__name__}. " + "This may affect model behavior." + ) + # Clear cached workspaces as they were created with the old recipe/quantizer type + self._fp8_workspaces.clear() + + def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None: + """Override to allow multiple recipes. Init scales and amaxes for fwd | bwd.""" + if isinstance(recipe, Recipe): + recipe = [recipe] + fp8_recipe_dpa = recipe[-1] + fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" + + # Return early if recipe state matches recipe + if self.fp8_meta_tensors_initialized: + recipe_state = self.fp8_meta[fp8_meta_tensor_key] + if fp8_recipe_dpa.delayed() and isinstance(recipe_state, DelayedScalingRecipeState): + self.adjust_amax_history_length(fp8_recipe_dpa.amax_history_len, fwd=fwd) + return + if fp8_recipe_dpa.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): + return + if fp8_recipe_dpa.float8_current_scaling() and isinstance( + recipe_state, Float8CurrentScalingRecipeState + ): + return + if fp8_recipe_dpa.float8_block_scaling() and isinstance( + recipe_state, Float8BlockScalingRecipeState + ): + return + + # When fp8_recipe=Float8CurrentScaling, recipe=[CS, DS], and QKV/dQKV, O/dO use CS quantizers, S/dP use DS quantizers. + # See table above in init_fp8_metadata for more detail. + num_gemms = [2, 1] if len(recipe) == 2 else [3] + # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and + # 2 (grad_output and grad_input) for bwd + num_fp8_tensors = [x * 3 if fwd else x * 2 for x in num_gemms] + + # Initialize recipe state and quantizers + recipe_states = [ + RecipeState.create( + recipe[i], + mode=("forward" if fwd else "backward"), + num_quantizers=num_fp8_tensors[i], + ) + for i in range(len(recipe)) + ] + + self.fp8_meta[fp8_meta_tensor_key] = ( + recipe_states[-1] if len(recipe) == 2 else recipe_states[0] + ) + self.quantizers[fp8_meta_tensor_key] = [] + for recipe_state in recipe_states: + self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers()) + @no_torch_dynamo(recursive=False) def forward( self, @@ -456,6 +785,7 @@ def forward( fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, pad_between_seqs: Optional[bool] = None, + fp8_output: Optional[bool] = False, ) -> torch.Tensor: """ Dot Product Attention Layer. @@ -628,12 +958,15 @@ def forward( pad_between_seqs: Optional[bool], default = `None` If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. If true, there are padding tokens between individual sequences in a packed batch. + fp8_output: Optional[bool], default = `False` + Whether to enforce output to be in FP8 or not. """ with torch.cuda.device(query_layer.device), self.prepare_forward( query_layer, num_gemms=3, allow_non_contiguous=True, + allow_different_data_and_param_types=self.softmax_type != "vanilla", ) as query_layer: # checks for RNG if self.rng_states_tracker is not None and is_graph_capturing(): @@ -663,6 +996,8 @@ def forward( tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types.""" + else: + fp8_output = False # checks for q/k/v shapes assert ( @@ -922,6 +1257,7 @@ def forward( False ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" + # check if there is padding between sequences when qkv_format='thd' if pad_between_seqs is None: if qkv_format == "thd": pad_between_seqs = ( @@ -957,11 +1293,13 @@ def forward( pad_between_seqs=pad_between_seqs, attention_dropout=self.attention_dropout, context_parallel=context_parallel, + cp_comm_type=self.cp_comm_type, deterministic=self.deterministic, is_training=self.training, fp8=self.fp8, fp8_meta=self.fp8_meta, inference_params=inference_params, + softmax_type=self.softmax_type, ) global _attention_backends if is_in_onnx_export_mode(): @@ -1022,6 +1360,12 @@ def forward( ) # run attention + softmax_offset = ( + self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32) + if self.softmax_offset is not None + else None + ) + if use_flash_attention: if core_attention_bias_type == "alibi": alibi_slopes, _ = dpa_utils.get_alibi( @@ -1053,6 +1397,7 @@ def forward( quantizers=self.quantizers, inference_params=inference_params, flash_attention_backend=flash_attention_backend, + fp8_output=fp8_output, ) if use_fused_attention: @@ -1071,7 +1416,6 @@ def forward( bias_dtype=query_layer.dtype, bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], ) - # checkpoint_core_attention=False if checkpoint_core_attention: return self._checkpointed_attention_forward( self.fused_attention, @@ -1101,6 +1445,8 @@ def forward( quantizers=self.quantizers, pad_between_seqs=pad_between_seqs, inference_params=inference_params, + softmax_offset=softmax_offset, + fp8_output=fp8_output, ) return self.fused_attention( query_layer, @@ -1129,6 +1475,8 @@ def forward( quantizers=self.quantizers, pad_between_seqs=pad_between_seqs, inference_params=inference_params, + softmax_offset=softmax_offset, + fp8_output=fp8_output, ) from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled @@ -1140,6 +1488,7 @@ def forward( ) if use_unfused_attention: + allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, @@ -1157,6 +1506,11 @@ def forward( core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, inference_params=inference_params, + softmax_offset=softmax_offset, + fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation, + fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + fp8_output=fp8_output, ) return self.unfused_attention( _alibi_cache, @@ -1173,5 +1527,10 @@ def forward( core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, inference_params=inference_params, + softmax_offset=softmax_offset, + fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation, + fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + fp8_output=fp8_output, ) return None diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 1677689c1..de6fff805 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -19,6 +19,7 @@ from packaging.version import Version as PkgVersion import torch +import torch.distributed as dist import torch.nn.functional as F from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex @@ -27,6 +28,7 @@ QKVLayout, AttnBiasType, AttnMaskType, + SoftmaxType, FusedAttnBackend, META_QKV, META_DQKV, @@ -34,11 +36,13 @@ META_DO, META_S, META_DP, - META_O_CP, - META_DQKV_CP, ) from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -46,6 +50,8 @@ from transformer_engine.pytorch.utils import ( get_device_compute_capability, get_cudnn_version, + SplitAlongDim, + combine_tensors, ) from transformer_engine.pytorch.export import is_in_onnx_export_mode @@ -56,6 +62,9 @@ # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) +# print quantizer info for a particular layer on a particular rank +_print_layer = int(os.getenv("NVTE_PRINT_LAYER_NUMBER", "1")) +_print_rank = int(os.getenv("NVTE_PRINT_RANK", "0")) _cu_seqlens_cache = {} @@ -209,6 +218,8 @@ class AttentionParams: Attention dropout. context_parallel: bool, default = `False` Whether context parallelism is used or not. + cp_comm_type: str, default = "p2p" + The communication type of context parallelism. deterministic: bool, default = `False` Whether to run `DotProductAttention` with determinism or not. is_training: bool, default = `True` @@ -219,6 +230,8 @@ class AttentionParams: The FP8 metadata tensor of `DotProductAttention`. inference_params: Optional[InferenceParams], default = `None` Inference-related parameters. See InferenceParams for details. + softmax_type: str, default = "vanilla" + The type of softmax operation. See DotProductAttention for details. """ qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -240,11 +253,13 @@ class AttentionParams: pad_between_seqs: bool = False attention_dropout: float = 0.0 context_parallel: bool = False + cp_comm_type: str = "p2p" deterministic: bool = False is_training: bool = True fp8: bool = False fp8_meta: Union[Dict[str, Any], None] = None inference_params: Optional[InferenceParams] = None + softmax_type: str = "vanilla" def __eq__(self, other): """ @@ -313,11 +328,13 @@ def get_attention_backend( pad_between_seqs = attention_params.pad_between_seqs attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel + cp_comm_type = attention_params.cp_comm_type deterministic = attention_params.deterministic is_training = attention_params.is_training fp8 = attention_params.fp8 fp8_meta = attention_params.fp8_meta inference_params = attention_params.inference_params + softmax_type = attention_params.softmax_type # Run config logger = logging.getLogger("DotProductAttention") @@ -346,8 +363,31 @@ def get_attention_backend( field.name: getattr(attention_params, field.name) for field in fields(attention_params) } run_config.update(attention_params_dict) + # Add FP8 environment variables to config if fp8: + # all FP8 recipes: 1: (FP8 fwd, FP8 bwd), 0: (FP8 fwd, F16 bwd) run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # Float8CurrentScaling: 1: use F16 O in bwd, 0: use FP8 O in bwd + run_config["NVTE_DPA_FP8CS_O_in_F16"] = int(os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1")) + # switch recipe to "F16", "DelayedScaling", or "Float8CurrentScaling" + _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") + run_config["NVTE_DPA_FP8_RECIPE"] = _dpa_fp8_recipe + if _dpa_fp8_recipe != "": + # config new recipe if switched + run_config["NVTE_DPA_FP8_FORMAT"] = os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID") + run_config["NVTE_DPA_FP8DS_AMAX_ALGO"] = os.getenv( + "NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent" + ) + run_config["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = int( + os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1") + ) + run_config["NVTE_DPA_FP8DS_REDUCE_AMAX"] = int( + os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") + ) + # UnfusedDotProductAttention: 1: allow FP8 emulation, 0: do not allow + run_config["NVTE_UnfusedDPA_Emulate_FP8"] = int( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") + ) logger.debug("Running with config=%s", run_config) # The following sections check if `FlashAttention` supports the provided attention params, @@ -427,8 +467,20 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False if use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") - use_unfused_attention = False + allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + if not allow_emulation: + logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") + use_unfused_attention = False + fp8_recipe = fp8_meta["recipe"] + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + if ( + use_fused_attention + and fp8_recipe.float8_current_scaling() + and device_compute_capability < (10, 0) + ): + logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") + use_fused_attention = False # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size @@ -571,6 +623,51 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FlashAttention 3 for dropout") use_flash_attention_3 = False + # Filter: Softmax type + # context_parallel | softmax_type | supported backends + # ---------------------------------------------------------------------------------------------------- + # no | vanilla | All + # no | off-by-one | FusedAttention, UnfusedDotProductAttention + # no | learnable | FusedAttention, UnfusedDotProductAttention + # yes | vanilla | FusedAttention, FlashAttention + # yes | off-by-one | FusedAttention + # yes | learnable | FusedAttention + if softmax_type != "vanilla": + logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) + use_flash_attention = False + if fp8 and fp8_meta["recipe"].fp8_dpa: + logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type) + use_fused_attention = False + logger.debug( + "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type + ) + use_unfused_attention = False + if qkv_format == "thd": + logger.debug( + "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type + ) + use_fused_attention = False + logger.debug( + "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd", + softmax_type, + ) + use_unfused_attention = False + if context_parallel: + logger.debug( + "Disabling UnfusedDotProductAttention for context parallelism with softmax_type" + " = %s", + softmax_type, + ) + use_unfused_attention = False + if cp_comm_type != "a2a": + logger.debug( + "Disabling FusedAttention for context parallelism with softmax_type = %s and" + " cp_comm_type = %s", + softmax_type, + cp_comm_type, + ) + use_fused_attention = False + # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends # ---------------------------------------------------------------------------------------------------- @@ -814,6 +911,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt QKVLayout[qkv_layout], AttnBiasType[fu_core_attention_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], attention_dropout, num_heads, num_gqa_groups, @@ -1836,11 +1934,10 @@ def check_set_window_size( return window_size -def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): +def get_attention_quantizers(fp8, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: - num_of_nones = 8 if cp_specific_quantizers else 6 - return [None] * num_of_nones + return [None] * 6 QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True QKV_quantizer.set_usage(rowwise=True, columnwise=False) @@ -1849,6 +1946,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer.internal = True S_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer.interal = True dQKV_quantizer.set_usage(rowwise=True, columnwise=False) @@ -1858,22 +1956,158 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True - dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP] - dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_CP_quantizer.internal = True - O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP] - O_CP_quantizer.set_usage(rowwise=True, columnwise=False) - - if cp_specific_quantizers: - return ( + + return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + + +def print_quantizers( + label, + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, +): + """Print the type and scale/amax of attention quantizers""" + _to_print = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL == 2 + if ( + _to_print + and _print_layer == layer_number + and ( + not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() == _print_rank) + ) + ): + names = [ + "QKV_quantizer", + "S_quantizer", + "O_quantizer", + "dO_quantizer", + "dP_quantizer", + "dQKV_quantizer", + ] + quantizers = [ QKV_quantizer, - O_quantizer, - O_CP_quantizer, S_quantizer, - dQKV_quantizer, - dQKV_CP_quantizer, + O_quantizer, dO_quantizer, dP_quantizer, - ) + dQKV_quantizer, + ] + if "forward" in label: + names = names[:3] + quantizers = quantizers[:3] + if "backward" in label: + names = names[3:] + quantizers = quantizers[3:] + for i, q in enumerate(quantizers): + type_str = "" + if q is None: + type_str = "None" + elif isinstance(q, Float8Quantizer): + type_str = "DS" + elif isinstance(q, Float8CurrentScalingQuantizer): + type_str = "CS" + print( + f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" + f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" + ) - return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + +def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): + """Combine q,k,v based on qkv_layout and quantize them together""" + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_group = len(qkv_layout.split("_")) + src_nominal_dtype = q.dtype + match qkv_group: + case 1: + dim = qkv_layout.find("3") + qkv = combine_tensors([q, k, v], dim) + qkv_fp8 = qkv_quantizer(qkv) + q_data, k_data, v_data = SplitAlongDim.apply(qkv_fp8._data, dim, [1, 1, 1], True) + case 2: + dim = qkv_layout.split("_")[1].find("2") + kv = combine_tensors([k, v], dim) + tensors = [q, kv] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv = torch.cat([x.view(-1) for x in tensors], dim=0) + qkv_fp8 = qkv_quantizer(qkv) + q_data, kv_data = [ + qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors) + ] + k_data, v_data = SplitAlongDim.apply(kv_data, dim, [1, 1], True) + case 3: + tensors = [q, k, v] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv = torch.cat([x.view(-1) for x in tensors], dim=0) + qkv_fp8 = qkv_quantizer(qkv) + q_data, k_data, v_data = [ + qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors) + ] + case _: + raise RuntimeError("Invalid qkv_layout " + qkv_layout) + + q_fp8, k_fp8, v_fp8 = [ + Float8Tensor.make_like(qkv_fp8, data=x, dtype=src_nominal_dtype) + for x in [q_data, k_data, v_data] + ] + + return q_fp8, k_fp8, v_fp8 + + +def combine_and_dequantize( + qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None +): + """Combine q,k,v based on qkv_layout and dequantize them together""" + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_group = len(qkv_layout.split("_")) + if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]): + src_nominal_dtype = q_fp8.dtype + else: + assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!" + if des_nominal_dtype is None: + des_nominal_dtype = src_nominal_dtype + + q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] + match qkv_group: + case 1: + dim = qkv_layout.find("3") + qkv_data = combine_tensors([q_data, k_data, v_data], dim) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, k, v = SplitAlongDim.apply(qkv, dim, [1, 1, 1], True) + case 2: + dim = qkv_layout.split("_")[1].find("2") + kv_data = combine_tensors([k_data, v_data], dim) + tensors = [q_data, kv_data] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv_data = torch.cat([x.reshape(-1) for x in tensors], dim=0) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, kv = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)] + k, v = SplitAlongDim.apply(kv, dim, [1, 1], True) + case 3: + tensors = [q_data, k_data, v_data] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv_data = torch.cat([x.contiguous().reshape(-1) for x in tensors], dim=0) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, k, v = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)] + case _: + raise RuntimeError("Invalid qkv_layout " + qkv_layout) + return q, k, v diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 5fd16bf1a..b2f1ff1ac 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Multi-head Attention.""" +import os import collections from typing import Callable, List, Optional, Tuple, Union import torch @@ -31,7 +32,13 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor + +# Force DotProductAttention to use a different recipe than the fp8_recipe set in fp8_autocast(). +# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling" +# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa. +_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") +_dpa_fp8_recipe_dpa = os.getenv("NVTE_DPA_FP8_RECIPE_DPA", "0") == "1" +_dpa_fp8_recipe_mha = os.getenv("NVTE_DPA_FP8_RECIPE_MHA", "0") == "1" class MultiheadAttention(torch.nn.Module): @@ -135,6 +142,17 @@ class MultiheadAttention(torch.nn.Module): For that, please use `get_qkv_layout` to gain the layout information. name: str, default = `None` name of the module, currently used for debugging purposes. + softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. + For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). Parallelism parameters ---------------------- @@ -245,6 +263,7 @@ def __init__( qk_norm_before_rope: bool = False, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -262,6 +281,7 @@ def __init__( self.return_bias = return_bias self.cp_size = 1 self.cp_rank = 0 + self.softmax_type = softmax_type kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads) @@ -416,6 +436,7 @@ def __init__( tp_group=tp_group, layer_number=self.layer_number, attention_type=self.attention_type, + softmax_type=self.softmax_type, ) # Linear @@ -556,10 +577,12 @@ def set_context_parallel_group( self.cp_size = get_distributed_world_size(cp_group) self.cp_rank = get_distributed_rank(cp_group) elif isinstance(cp_group, list): - assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" assert ( cp_comm_type == "a2a+p2p" ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!" + assert ( + len(cp_group) == 2 + ), "cp_comm_type = a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!" cp_size_a2a = get_distributed_world_size(cp_group[0]) cp_rank_a2a = get_distributed_rank(cp_group[0]) cp_size_p2p = get_distributed_world_size(cp_group[1]) @@ -716,10 +739,22 @@ def forward( # Query, Key, and Value # ====================== - fp8_mha = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.get_fp8_recipe().fp8_mha - ) + fp8 = FP8GlobalStateManager.is_fp8_enabled() + if _dpa_fp8_recipe == "": + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + fp8_dpa = fp8_recipe.fp8_dpa + fp8_mha = fp8_recipe.fp8_mha + float8_current_scaling = fp8_recipe.float8_current_scaling() + else: + fp8_dpa = _dpa_fp8_recipe_dpa + fp8_mha = _dpa_fp8_recipe_mha + float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" + # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe + qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling + # DPA: always produce FP8 output when fp8=True to take advantage of the O amax + dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) + # Proj Gemm: match DPA output except for Float8CurrentScaling + proj_fp8_grad = dpa_fp8_output and not float8_current_scaling layernorm_output = None if self.attention_type == "self": @@ -728,7 +763,7 @@ def forward( layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.return_layernorm_output: mixed_x_layer, layernorm_output = layernorm_qkv_outputs @@ -738,7 +773,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) num_queries_per_key_value = ( @@ -792,7 +827,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.qkv_weight_interleaved: @@ -847,7 +882,7 @@ def forward( layernorm_query_outputs = self.layernorm_query( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.return_layernorm_output: query_layer, layernorm_output = layernorm_query_outputs @@ -857,7 +892,7 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) # [sq, b, hp] --> [sq, b, np, hn] @@ -958,6 +993,7 @@ def forward( fast_zero_fill=fast_zero_fill, inference_params=inference_params, pad_between_seqs=pad_between_seqs, + fp8_output=dpa_fp8_output, ) # =================== @@ -966,7 +1002,7 @@ def forward( projection_output = self.proj( context_layer, is_first_microbatch=is_first_microbatch, - fp8_grad=isinstance(context_layer, QuantizedTensor), + fp8_grad=proj_fp8_grad, ) if self.return_bias: diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 48dc1ba29..f51cf63a0 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -100,3 +100,5 @@ def __missing__(self, key): dist_group_type = torch.distributed.ProcessGroup MXFP8_BLOCK_SCALING_SIZE = 32 + +NVFP4_BLOCK_SCALING_SIZE = 16 diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 16fa9f3e8..639f00de0 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -15,6 +15,7 @@ NVTE_QKV_Format, NVTE_Bias_Type, NVTE_Mask_Type, + NVTE_Softmax_Type, NVTE_Fused_Attn_Backend, ) from ..tensor.quantized_tensor import Quantizer @@ -102,6 +103,11 @@ "CK": NVTE_Fused_Attn_Backend.NVTE_CK, "No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend, } +SoftmaxType = { + "vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + "off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX, + "learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX, +} BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 @@ -112,9 +118,6 @@ META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_DP = tex.FP8BwdTensors.GRAD_INPUT3 -# repurpose some unused amax history buffers for partial results of CP fwd and bwd -META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT -META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 def fused_attn_fwd( is_training: bool, @@ -140,8 +143,10 @@ def fused_attn_fwd( qkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", + softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), rng_gen: torch.Generator = None, + softmax_offset: torch.Tensor = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -206,6 +211,8 @@ def fused_attn_fwd( type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} + softmax_type: str, default = "vanilla" + type of the attention softmax; {"vanilla", "off-by-one", "learnable"} window_size: Tuple[int, int], default = (-1, -1) sliding window size for local attention, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q @@ -214,6 +221,9 @@ def fused_attn_fwd( rng_gen: torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen + softmax_offset: torch.Tensor, default = None + softmax offset tensor in shape [1, h_q, 1, 1]. + See softmax_type in DotProductAttention for details. Returns ---------- @@ -299,6 +309,7 @@ def fused_attn_fwd( QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], window_size, cu_seqlens_q, cu_seqlens_kv, @@ -313,6 +324,7 @@ def fused_attn_fwd( s_quantizer, o_quantizer, attn_bias, + softmax_offset, rng_gen, rng_elts_per_thread, ) @@ -346,6 +358,7 @@ def fused_attn_bwd( qkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", + softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), deterministic: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -411,6 +424,8 @@ def fused_attn_bwd( type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} + softmax_type: str, default = "vanilla" + type of the attention softmax; {"vanilla", "off-by-one", "learnable"} window_size: Tuple[int, int], default = (-1, -1) sliding window size for local attention, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q @@ -430,6 +445,9 @@ def fused_attn_bwd( d_bias: torch.Tensor, optional gradient tensor of Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; same data type and shape as Bias + d_softmax_offset: torch.Tensor, optional + gradient tensor of softmax offset in shape [1, h_q, 1, 1]. + See softmax_type in DotProductAttention for details. """ if attn_scale is None: d = q.size(-1) @@ -468,6 +486,7 @@ def fused_attn_bwd( QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], window_size, deterministic, cu_seqlens_q, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index e4f4e619f..d330e023e 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -13,6 +13,8 @@ from ..tensor.quantized_tensor import Quantizer from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..tensor.utils import is_experimental +from ..experimental.gemm import experimental_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer __all__ = [ @@ -77,6 +79,24 @@ def general_gemm( if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") + # If A or B are experimental tensors -> dispatch to quantizers's qgemm implementation + if is_experimental(A) or is_experimental(B): + return experimental_gemm( + A, + B, + workspace, + out_dtype, + quantization_params, + gelu, + gelu_in, + accumulate, + layout, + out, + bias, + use_split_accumulator, + grad, + ) + debug_quantizer = None if isinstance(quantization_params, DebugQuantizer): debug_quantizer = quantization_params diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index aa6602401..6f6cbefd7 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -18,6 +18,20 @@ namespace transformer_engine::pytorch { +/*! convert fp4 data shape back to original shape */ +std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose) { + std::vector ret; + size_t start_idx = (transpose) ? 1 : 0; + for (size_t i = start_idx; i < shape.size() - 1; ++i) { + ret.push_back(shape[i]); + } + ret.push_back(shape.back() * 2); + if (transpose) { + ret.push_back(shape.front()); + } + return ret; +} + std::vector getTensorShape(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { @@ -320,4 +334,20 @@ at::Tensor allocate_amax_workspace(const TensorWrapper& input_tensor) { #endif +void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { + NVTE_SCOPED_GIL_RELEASE({ + nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val, + arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_, + at::cuda::getCurrentCUDAStream()); + }); +} + +// extract PhiloxCudaState from CUDA random number generator +at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread) { + at::PhiloxCudaState philox_args; + std::lock_guard lock(gen->mutex_); + philox_args = gen->philox_cuda_state(elts_per_thread); + return philox_args; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 07384413d..3ddf14f9d 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -204,20 +205,25 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - /*! @brief Construct a high precision tensor giving it this quantizer's amax - - Note: this member function also zeros out the amax, as it is meant to be used in conjunction with - a kernel computing the amax, which might expect the amax to be initialized to zero + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. */ - std::pair create_hp_tensor_with_amax(const std::vector& shape, - DType dtype); + std::pair create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data = std::nullopt); std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; - /*! @brief Convert to a quantized data format avoiding amax computation */ + /*! @brief Quantize to FP8, skipping local amax computation + * + * The quantizer's amax pointer is assumed to already hold the local + * amax. The amax may still be reduced across the amax reduction + * group. + */ void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt); @@ -287,6 +293,63 @@ class MXFP8Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; +#ifndef USE_ROCM +class NVFP4Quantizer : public Quantizer { + public: + // fp4 dtype + DType dtype; + // amax reduction for low precision FP4 AG + bool with_amax_reduction; + c10::intrusive_ptr amax_reduction_group; + // random hadamard transform + bool with_rht; + bool with_post_rht_amax; + // 2D block scaling + bool with_2d_quantization; + bool stochastic_rounding; + + int rht_matrix_random_sign_mask_t; + at::Tensor rht_matrix; + + explicit NVFP4Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. + */ + std::pair create_unquantized_tensor_with_amax( + TensorWrapper& quantized_tensor, DType dtype); + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; + + /*! @brief Quantize to NVFP4, skipping local amax computation + * + * The input tensor's amax pointer is assumed to already hold the + * local amax. The amax may still be reduced across the amax + * reduction group. + */ + void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); + + std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + + private: + void quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, bool compute_amax); +}; +#endif // USE_ROCM + + std::unique_ptr convert_quantizer(py::handle quantizer); std::vector getTensorShape(const at::Tensor& t); @@ -448,6 +511,14 @@ NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); #ifdef __HIP_PLATFORM_AMD__ at::Tensor allocate_amax_workspace(const TensorWrapper& input_tensor); #endif +std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); + +// unpack the PhiloxCudaState into CUDA tensor +void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); + +// extract PhiloxCudaState from CUDA random number generator +at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread); + } // namespace transformer_engine::pytorch namespace std { diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 9b527b161..791939b5a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -83,28 +83,36 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right); + +std::pair quantizer_helper(py::handle quantizer, + const std::vector &shape, DType dtype, + bool create_hp_tensor_for_cs, + std::optional data); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const std::optional cu_seqlens_q_padded, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + const std::vector window_size, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional rng_gen, size_t rng_elts_per_thread); + const std::optional SoftmaxOffset, const std::optional rng_gen, + size_t rng_elts_per_thread); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 8b0607c9e..6fb17d3ca 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -10,179 +10,291 @@ #include "common.h" #include "pybind.h" -namespace transformer_engine::pytorch { +namespace transformer_engine { +namespace pytorch { -template -py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { +namespace { + +py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t), + const at::Tensor& input, py::handle quantizer, + int shape_divisor = 1) { init_extension(); // Input tensor auto input_tensor = input.contiguous(); - const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor); // Construct output tensor auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape = input_cpp.shape(); + const auto input_shape = input_nvte.shape(); std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); output_shape.back() /= shape_divisor; auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); - - // Compute activation + auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); + + // Choose implementation + enum class Impl { + UNFUSED, + FULLY_FUSED, + FUSED_ACTIVATION_AMAX_FP8, +#ifndef USE_ROCM + FUSED_ACTIVATION_AMAX_NVFP4 +#endif + }; + Impl impl = Impl::UNFUSED; if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr())) { - // Compute activation directly - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // Compute activation in high-precision fused together with amax, then quantize. - - auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); - quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); - } else { - // Compute activation in high-precision, then quantize - - auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); - quantizer_cpp->quantize(temp_cpp, out_cpp); + impl = Impl::FUSED_ACTIVATION_AMAX_FP8; +#ifndef USE_ROCM + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else { + impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; + } +#endif + } + + // Perform compute + auto stream = at::cuda::getCurrentCUDAStream(); + switch (impl) { + case Impl::UNFUSED: + // Compute activation in high precision, then quantize + { + auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + quantizer_cpp->quantize(temp_nvte, out_nvte); + } + break; + case Impl::FULLY_FUSED: + // Compute activation directly + { + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), out_nvte.data(), stream); }); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_FP8: + // Compute activation and amax in high precision, then quantize to FP8 + { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + auto [temp_nvte, _] = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); + } + break; +#ifndef USE_ROCM + case Impl::FUSED_ACTIVATION_AMAX_NVFP4: + // Compute activation and amax in high precision, then quantize to NVFP4 + { + auto nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + auto [temp_nvte, _] = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); + } + break; +#endif + default: + NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } return out_py; } -template -py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, - py::handle quantizer) { +py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, + cudaStream_t), + const at::Tensor& grad_output, const at::Tensor& input, + py::handle quantizer) { init_extension(); // Grad output and input tensors auto grad_output_tensor = grad_output.contiguous(); auto input_tensor = input.contiguous(); - const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor); - const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& grad_output_nvte = makeTransformerEngineTensor(grad_output_tensor); + const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor); // Construct grad input tensor auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape_te = input_cpp.shape(); + const auto input_shape_te = input_nvte.shape(); const std::vector input_shape(input_shape_te.data, input_shape_te.data + input_shape_te.ndim); auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); - - // Compute activation backward + auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); + + // Choose implementation + enum class Impl { + UNFUSED, + FULLY_FUSED, + FUSED_ACTIVATION_AMAX_FP8, +#ifndef USE_ROCM + FUSED_ACTIVATION_AMAX_NVFP4 +#endif + }; + Impl impl = Impl::UNFUSED; if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr())) { - // Compute activation backward directly - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); + impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // Compute activation backward in high-precision fused together with amax, then quantize. - auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); - } else { - // Compute activation backward in high-precision, then quantize - auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp->quantize(temp_cpp, grad_input_cpp); + impl = Impl::FUSED_ACTIVATION_AMAX_FP8; +#ifndef USE_ROCM + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else { + impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; + } +#endif + } + + // Perform compute + auto stream = at::cuda::getCurrentCUDAStream(); + switch (impl) { + case Impl::UNFUSED: + // Compute activation backward in high precision, then quantize + { + auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + at::cuda::getCurrentCUDAStream()); + }); + quantizer_cpp->quantize(temp_nvte, grad_input_nvte); + } + break; + case Impl::FULLY_FUSED: + // Compute activation backward directly + { + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream); + }); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_FP8: + // Compute activation and amax in high precision, then quantize to FP8 + { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + auto [temp_nvte, _] = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE( + { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); }); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + } + break; +#ifndef USE_ROCM + case Impl::FUSED_ACTIVATION_AMAX_NVFP4: + // Compute activation and amax in high precision, then quantize to NVFP4 + { + auto nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + auto [temp_nvte, _] = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype); + NVTE_SCOPED_GIL_RELEASE( + { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); }); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + } + break; +#endif + default: + NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } return grad_input_py; } -/* GELU and variants*/ +} // namespace + +/* GELU and variants */ py::object gelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_gelu, input, quantizer); } py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dgelu, grad, input, quantizer); } py::object geglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_geglu, input, quantizer, 2); } py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dgeglu, grad, input, quantizer); } py::object qgelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_qgelu, input, quantizer); } py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dqgelu, grad, input, quantizer); } py::object qgeglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_qgeglu, input, quantizer, 2); } py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dqgeglu, grad, input, quantizer); } -/* ReLU and variants*/ +/* ReLU and variants */ py::object relu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_relu, input, quantizer); } py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_drelu, grad, input, quantizer); } py::object reglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_reglu, input, quantizer, 2); } py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dreglu, grad, input, quantizer); } py::object srelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_srelu, input, quantizer); } py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dsrelu, grad, input, quantizer); } py::object sreglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_sreglu, input, quantizer, 2); } py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dsreglu, grad, input, quantizer); } -/* Silu and variants*/ +/* Silu and variants */ py::object silu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_silu, input, quantizer); } py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dsilu, grad, input, quantizer); } py::object swiglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_swiglu, input, quantizer, 2); } py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dswiglu, grad, input, quantizer); } -} // namespace transformer_engine::pytorch + +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 6d835a5c9..344bc4ab0 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -35,22 +35,6 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); }); } -void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) { - NVTE_SCOPED_GIL_RELEASE({ - nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val, - arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_, - at::cuda::getCurrentCUDAStream()); - }); -} - -// extract PhiloxCudaState from CUDA random number generator -at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_per_thread) { - at::PhiloxCudaState philox_args; - std::lock_guard lock(gen->mutex_); - philox_args = gen->philox_cuda_state(elts_per_thread); - return philox_args; -} - } // namespace namespace transformer_engine::pytorch { @@ -58,66 +42,95 @@ namespace transformer_engine::pytorch { // get the fused attention backend NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, - max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); + bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, + max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); return fused_attention_backend; } +// helper function for S and dP quantizers +std::pair quantizer_helper(py::handle quantizer, + const std::vector &shape, DType dtype, + bool create_hp_tensor_for_cs, + std::optional data) { + std::unique_ptr T_quantizer = convert_quantizer(quantizer); + TensorWrapper te_T; + py::object py_T; + if (quantizer.is_none()) { + // high precision + auto *none_quantizer = dynamic_cast(T_quantizer.get()); + if (data.has_value()) { + std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype); + } + } else if (detail::IsFloat8Quantizers(quantizer.ptr())) { + // delayed scaling; this helps initialize scale_inv + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + std::tie(te_T, py_T) = + T_quantizer_fp8->create_tensor(shape, dtype, data, std::nullopt, std::nullopt); + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // current scaling + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + if (create_hp_tensor_for_cs) { + if (data.has_value()) { + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + } + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + NVTE_CHECK( + !data.has_value(), + "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); + } + } + return {std::move(te_T), std::move(py_T)}; +} + // fused attention FWD with separate Q, K and V tensors std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const std::optional cu_seqlens_q_padded, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + const std::vector window_size, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional rng_gen, size_t rng_elts_per_thread) { - TensorWrapper te_Q, te_K, te_V, te_O, te_S; - + const std::optional SoftmaxOffset, const std::optional rng_gen, + size_t rng_elts_per_thread) { auto none = py::none(); - std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); - std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); + // create QKV tensor wrappers + TensorWrapper te_Q, te_K, te_V; te_Q = makeTransformerEngineTensor(Q, none); te_K = makeTransformerEngineTensor(K, none); te_V = makeTransformerEngineTensor(V, none); - - // If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch do not have fp8 types. const DType qkv_type = te_Q.dtype(); - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + // create S tensor + TensorWrapper te_S; + py::object py_S; + std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + + // create O tensor + TensorWrapper te_O; + py::object py_O; + std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); std::vector q_shape = convertShape(te_Q.shape()); - std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - // create output tensor O - auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; - py::object o_python, s_python; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // Initialize FP8 tensor with scale-inverse - auto *O_quantizer_fp8 = dynamic_cast(O_quantizer.get()); - auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); - NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt, - std::nullopt, std::nullopt); - std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - } else { - std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - } - auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; + const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); // construct NVTE tensors TensorWrapper te_Bias; @@ -128,11 +141,12 @@ std::vector fused_attn_fwd( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - te_O.zero_(at::cuda::getCurrentCUDAStream()); + if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if ((h * d) % block_size == 0) { + mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + te_O.zero_(at::cuda::getCurrentCUDAStream()); + } } } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { @@ -181,12 +195,23 @@ std::vector fused_attn_fwd( DType::kInt32, nullptr, nullptr, nullptr); } + // softmax offset + TensorWrapper te_SoftmaxOffset; + if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { + auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec(); + std::vector SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()}; + te_SoftmaxOffset = + makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape, + DType::kFloat32, nullptr, nullptr, nullptr); + } + // extract rng seed and offset auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - unpack(philox_args, static_cast(rng_state.data_ptr())); + auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options); + philox_unpack(philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); // create auxiliary output tensors @@ -199,11 +224,11 @@ std::vector fused_attn_fwd( // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ nvte_fused_attn_fwd( - te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -214,52 +239,53 @@ std::vector fused_attn_fwd( // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; - output_tensors.push_back(o_python); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - // allocate memory for nvte_aux_tensor_pack.tensors - at::Tensor output_tensor; - if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = - (i < nvte_aux_tensor_pack.size - 1) - ? allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false) - : rng_state; - } - } else { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); - } + output_tensors.push_back(py_O); + auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) { output_tensors.push_back(py::cast(output_tensor)); NVTEBasicTensor temp_data = {output_tensor.data_ptr(), nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]), nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); + }; + // allocate memory for nvte_aux_tensor_pack.tensors + // f16_max512 : S [b, h, sq, skv] + // f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] + size_t i = 0; + at::Tensor output_tensor; + // intermediate softmax tensor, S or M + output_tensor = + allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); + set_tensor_param(i++, output_tensor); + // fp8 has an additional softmax stats tensor, ZInv + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + output_tensor = + allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); + set_tensor_param(i++, output_tensor); + } + // rng_state + if (i < nvte_aux_tensor_pack.size) { + set_tensor_param(i++, rng_state); + } + // bias (optional) + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { + set_tensor_param(i++, Bias.value()); + } + // softmax_offset (optional) + if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { + set_tensor_param(i++, SoftmaxOffset.value()); } // execute the kernel NVTE_SCOPED_GIL_RELEASE({ nvte_fused_attn_fwd( - te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -274,58 +300,53 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer) { auto none = py::none(); - TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; + + // create QKV, O, dO tensor wrappers + TensorWrapper te_Q, te_K, te_V, te_O, te_dO; te_Q = makeTransformerEngineTensor(Q, none); te_K = makeTransformerEngineTensor(K, none); te_V = makeTransformerEngineTensor(V, none); te_O = makeTransformerEngineTensor(O, none); te_dO = makeTransformerEngineTensor(dO, none); - // qkv type from the te_Q - std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); - const DType qkv_type = te_Q.dtype(); - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - - py::object s_python, dp_python; - std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); - std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); - auto *dP_quantizer_fp8 = dynamic_cast(dP_quantizer.get()); - NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - } else { - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); - } + // create S and dP tensors + TensorWrapper te_S, te_dP; + py::object py_S, py_dP; + std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + std::tie(te_dP, py_dP) = + quantizer_helper(dp_quantizer, {0}, DType::kFloat32, false, std::nullopt); + // create dQ, dK, dV tensors + TensorWrapper te_dQ, te_dK, te_dV; + py::object py_dQ, py_dK, py_dV; + std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); auto h_q = q_shape[q_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2]; auto d_qk = q_shape[q_shape.size() - 1]; - auto d_v = v_shape[v_shape.size() - 1]; - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - std::vector o_shape{q_shape.begin(), q_shape.end()}; - o_shape[o_shape.size() - 1] = d_v; + const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); at::Tensor dQ, dK, dV, dQKV, dKV; - py::object py_dQ, py_dK, py_dV; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); std::vector tmp_shape; + auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); + if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { + options = options.dtype(torch::kUInt8); + } + if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) { + options = options.dtype(fake_dtype); + } switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: @@ -398,39 +419,27 @@ std::vector fused_attn_bwd( default: NVTE_ERROR("QKV layout not supported!"); } - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - auto *fp8_quantizer = dynamic_cast(dQKV_quantizer.get()); - NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_dQ, py_dQ) = - fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt); - std::tie(te_dK, py_dK) = - fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt); - std::tie(te_dV, py_dV) = - fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt); - } else { - auto *none_quantizer = dynamic_cast(dQKV_quantizer.get()); - NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8"); - std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); - std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK); - std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV); - } + + std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ); + std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK); + std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV); // construct NVTE tensors - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && - dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQ.fill_(0); - dK.fill_(0); - dV.fill_(0); + if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && + dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { + mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dQ.fill_(0); + dK.fill_(0); + dV.fill_(0); + } } - - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); @@ -499,6 +508,15 @@ std::vector fused_attn_bwd( } } + // create dSoftmaxOffset in the same shape as SoftmaxOffset + at::Tensor dSoftmaxOffset; + TensorWrapper te_dSoftmaxOffset; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + options = torch::TensorOptions().dtype(at::kFloat).device(torch::kCUDA); + dSoftmaxOffset = torch::empty({1, static_cast(h_q), 1, 1}, options); + te_dSoftmaxOffset = makeTransformerEngineTensor(dSoftmaxOffset); + } + // create workspace TensorWrapper workspace; @@ -507,10 +525,10 @@ std::vector fused_attn_bwd( nvte_fused_attn_bwd( te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic, - workspace.data(), at::cuda::getCurrentCUDAStream()); + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -523,16 +541,16 @@ std::vector fused_attn_bwd( nvte_fused_attn_bwd( te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic, - workspace.data(), at::cuda::getCurrentCUDAStream()); + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - return {py_dQ, py_dK, py_dV, py::cast(dBias)}; + return {py_dQ, py_dK, py_dV, py::cast(dBias), py::cast(dSoftmaxOffset)}; } at::Tensor fa_prepare_fwd(at::Tensor qkvi) { @@ -598,7 +616,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1 int seq_dim = tensor.dim() == 3 ? 0 : 1; - int batch = cu_seqlens.size(0) - 1; int num_heads = tensor.size(seq_dim + 1); int dim_per_head = tensor.size(seq_dim + 2); int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type()); @@ -762,8 +779,6 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t NVTE_CHECK(world_size > 0); NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0); - int batch = cu_seqlens.size(0) - 1; - std::vector shape = {total_tokens / world_size}; at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int)); @@ -801,7 +816,6 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, **************************************************************************************************/ at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) { - int max_seq_len = tensor.size(1); int h = tensor.size(2); int d = tensor.size(3); std::vector shape = {t, h, d}; diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index f65614d07..89aa0f154 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -56,10 +56,25 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; } - // Unfused impl if quantizer is not supported - const bool with_fused_dbias_quantize_kernel = - detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr()); - if (!with_fused_dbias_quantize_kernel) { + // Check if fused kernel is supported + bool with_fused_kernel = false; + if (detail::IsFloat8Quantizers(quantizer.ptr())) { + auto prop = at::cuda::getCurrentDeviceProperties(); + const size_t sm_arch = 10 * prop->major + prop->minor; + if (sm_arch >= 100) { + // Fused kernel for dbias + FP8 cast on SM arch 10.0+ + with_fused_kernel = true; + } else if (quantizer_cpp->rowwise_usage && quantizer_cpp->columnwise_usage) { + // Fused kernel for dbias + FP8 cast + FP8 transpose + with_fused_kernel = true; + } + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Fused kernel for dbias + MXFP8 quantize + with_fused_kernel = true; + } + + // Apply unfused impl if fused kernel is not supported + if (!with_fused_kernel) { at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0}); quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte); return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; @@ -124,13 +139,29 @@ std::vector dact_dbias( } // Choose implementation - enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX }; + enum class Impl { + UNFUSED, + FUSED_DACT_DBIAS_QUANTIZE, + FUSED_DACT_AMAX_FP8, + FUSED_DACT_AMAX_NVFP4 + }; Impl impl = Impl::UNFUSED; if (detail::IsFloat8Quantizers(quantizer_py.ptr()) || detail::IsMXFP8Quantizers(quantizer_py.ptr())) { impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { - impl = Impl::FUSED_DACT_AMAX; + impl = Impl::FUSED_DACT_AMAX_FP8; +#ifndef USE_ROCM + } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else { + impl = Impl::FUSED_DACT_AMAX_NVFP4; + } +#endif } // Perform compute @@ -174,22 +205,42 @@ std::vector dact_dbias( }); break; } - case Impl::FUSED_DACT_AMAX: - // Fused dact-amax kernel, unfused dbias and quantize + case Impl::FUSED_DACT_AMAX_FP8: + // Fused dact-amax kernel, unfused dbias and FP8 quantize { - auto *quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(quantizer_cpp_cs != nullptr, + auto *fp8_quantizer_cpp = + dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Invalid quantizer for fused dact-amax kernel impl"); auto [temp_nvte, temp_py] = - quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype); + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); + }); + const auto temp_torch = temp_py.cast(); + at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + break; + } +#ifndef USE_ROCM + case Impl::FUSED_DACT_AMAX_NVFP4: + // Fused dact-amax kernel, unfused dbias and NVFP4 quantize + { + auto *nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, + "Invalid quantizer for fused dact-amax kernel impl"); + auto [temp_nvte, temp_py] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax( + grad_input_nvte, grad_output_dtype); NVTE_SCOPED_GIL_RELEASE({ dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); }); const auto temp_torch = temp_py.cast(); at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); - quantizer_cpp_cs->quantize_with_amax(temp_nvte, grad_input_nvte); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); break; } +#endif default: NVTE_ERROR("Invalid implementation"); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index c940181b0..bbb6392d0 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -39,7 +39,18 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob // Convert input tensor to C++ object auto input_contiguous = tensor.contiguous(); - const auto input_cpp = makeTransformerEngineTensor(input_contiguous); + auto input_cpp = makeTransformerEngineTensor(input_contiguous); + + // Set amax if use_existing_amax = true (only valid for CS) + bool use_existing_amax = false; + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + use_existing_amax = quantizer.attr("use_existing_amax").cast(); + if (use_existing_amax) { + const at::Tensor &amax = quantizer.attr("amax").cast(); + input_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + } // Initialize output tensor TensorWrapper output_cpp; @@ -59,7 +70,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob } // Perform quantization - quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); + if (use_existing_amax) { + auto *quantizer_cs = dynamic_cast(quantizer_cpp.get()); + quantizer_cs->quantize_with_amax(input_cpp, output_cpp, noop_flag_cpp); + } else { + quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); + } return output_py; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index b637d49c7..cc99f3412 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -215,6 +215,19 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + // Construct GEMM config + transformer_engine::MatmulConfigWrapper config; + if (grad) { + config.set_dbias_tensor(bias_tensor.data()); + config.set_with_dgelu_epilogue(gelu); + } else { + config.set_bias_tensor(bias_tensor.data()); + config.set_with_gelu_epilogue(gelu); + } + config.set_epilogue_aux_tensor(te_pre_gelu_out.data()); + config.set_use_split_accumulator(use_split_accumulator); + config.set_sm_count(num_math_sms); + #ifndef USE_ROCM // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; @@ -286,10 +299,9 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), - bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, - te_workspace.data(), alpha, *beta, use_split_accumulator, - num_math_sms, main_stream); + nvte_cublas_gemm_v2(transa, transb, &alpha, A_tensor.data(), B_tensor.data(), &beta.value(), + out_tensor.data(), out_tensor.data(), te_workspace.data(), config, + main_stream); }); } } else { diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 728d39cbd..70fa609b4 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -68,67 +68,108 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Input and param tensors auto none = py::none(); - const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); - TensorWrapper bias_cu; + const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); + TensorWrapper bias_nvte; if (bias.has_value()) { - bias_cu = makeTransformerEngineTensor(*bias); + bias_nvte = makeTransformerEngineTensor(*bias); } // Tensor dimensions - const size_t N = static_cast(input_cu.size(0)); - const size_t H = static_cast(input_cu.size(1)); - const std::vector size = {N, H}; + const auto shape = nvte_shape_to_vector(input_nvte.shape()); + const auto outer_size = product(shape) / shape.back(); + const auto inner_size = shape.back(); // Tensors to save for backward pass - at::Tensor mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - at::Tensor rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - TensorWrapper mu_cu = makeTransformerEngineTensor(mu); - TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma); + at::Tensor mu_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py); + TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); // Output tensor - std::unique_ptr my_quantizer = convert_quantizer(quantizer); - TensorWrapper out_cu; + auto quantizer_cpp = convert_quantizer(quantizer); + TensorWrapper out_nvte; if (out.is_none()) { - std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); } else { - out_cu = makeTransformerEngineTensor(out, quantizer); + out_nvte = makeTransformerEngineTensor(out, quantizer); } - // Determine whether to avoid fused kernel - bool force_unfused_kernel = true; - if (quantizer.is_none()) { - // No need for separate quantization step if output is unquantized - force_unfused_kernel = false; - } else if (IsFloat8Quantizers(quantizer.ptr())) { - // Always used fused kernel for FP8 delayed scaling - force_unfused_kernel = false; + // Choose implementation + enum class Impl { + // Compute norm in high precision, then quantize + UNFUSED, + // Compute norm directly + FULLY_FUSED, + // Compute norm and amax in high precision, then quantize to FP8 + FUSED_NORM_AMAX_FP8, +#ifndef USE_ROCM + // Compute norm and amax in high precision, then quantize to NVFP4 + FUSED_NORM_AMAX_NVFP4 +#endif + }; + Impl impl = Impl::UNFUSED; + if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) { + impl = Impl::FULLY_FUSED; } else if (IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - // cuDNN MXFP8 kernel requires full tile - force_unfused_kernel = N % 128 != 0 || H % 128 != 0; + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 && + inner_size % 128 == 0) { + // cuDNN MXFP8 kernel requires full 128x128 tiles + impl = Impl::FULLY_FUSED; } + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + impl = Impl::FUSED_NORM_AMAX_FP8; +#ifndef USE_ROCM + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + // TE kernel supports amax output + impl = Impl::FUSED_NORM_AMAX_NVFP4; + } +#endif } - TensorWrapper unquantized_out_cu; + + // Construct unquantized output tensor if needed + TensorWrapper unquantized_out_nvte; py::object unquantized_out; - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - std::tie(unquantized_out_cu, unquantized_out) = - my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); - } else { + TensorWrapper *kernel_out_nvte = &out_nvte; + switch (impl) { + case Impl::UNFUSED: { NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; +#ifndef USE_ROCM + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; +#endif + default: { } } - TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; // Query workspace size TensorWrapper workspace; NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps, + kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), + workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); @@ -140,24 +181,33 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Launch kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps, + kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), + workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); - // Quantize output if using unfused kernel - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); - } else { - my_quantizer->quantize(unquantized_out_cu, out_cu); + // Quantize output if needed + switch (impl) { + case Impl::UNFUSED: { + quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; +#ifndef USE_ROCM + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; +#endif + default: { } } - return {out, py::cast(mu), py::cast(rsigma)}; + return {out, py::cast(mu_py), py::cast(rsigma_py)}; } std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, @@ -256,61 +306,101 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Input and param tensors auto none = py::none(); - const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); + const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); // Tensor dimensions - const size_t N = static_cast(input_cu.shape().data[0]); - const size_t H = static_cast(input_cu.shape().data[1]); - const std::vector size = {N, H}; + const auto shape = nvte_shape_to_vector(input_nvte.shape()); + const auto outer_size = product(shape) / shape.back(); + const auto inner_size = shape.back(); // Tensors to save for backward pass - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); + at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); // Output tensor - std::unique_ptr my_quantizer = convert_quantizer(quantizer); - TensorWrapper out_cu; + auto quantizer_cpp = convert_quantizer(quantizer); + TensorWrapper out_nvte; if (out.is_none()) { - std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); } else { - out_cu = makeTransformerEngineTensor(out, quantizer); + out_nvte = makeTransformerEngineTensor(out, quantizer); } - // Determine whether to avoid fused kernel - bool force_unfused_kernel = true; - if (quantizer.is_none()) { - // No need for separate quantization step if output is unquantized - force_unfused_kernel = false; - } else if (IsFloat8Quantizers(quantizer.ptr())) { - // Always used fused kernel for FP8 delayed scaling - force_unfused_kernel = false; + // Choose implementation + enum class Impl { + // Compute norm in high precision, then quantize + UNFUSED, + // Compute norm directly + FULLY_FUSED, + // Compute norm and amax in high precision, then quantize to FP8 + FUSED_NORM_AMAX_FP8, +#ifndef USE_ROCM + // Compute norm and amax in high precision, then quantize to NVFP4 + FUSED_NORM_AMAX_NVFP4 +#endif + }; + Impl impl = Impl::UNFUSED; + if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) { + impl = Impl::FULLY_FUSED; } else if (IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - // cuDNN MXFP8 kernel requires full tile - force_unfused_kernel = N % 128 != 0 || H % 128 != 0; + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 && + inner_size % 128 == 0) { + // cuDNN MXFP8 kernel requires full 128x128 tiles + impl = Impl::FULLY_FUSED; } + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + impl = Impl::FUSED_NORM_AMAX_FP8; +#ifndef USE_ROCM + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + // TE kernel supports amax output + impl = Impl::FUSED_NORM_AMAX_NVFP4; + } +#endif } - TensorWrapper unquantized_out_cu; + + // Construct unquantized output tensor if needed + TensorWrapper unquantized_out_nvte; py::object unquantized_out; - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - std::tie(unquantized_out_cu, unquantized_out) = - my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); - } else { + TensorWrapper *kernel_out_nvte = &out_nvte; + switch (impl) { + case Impl::UNFUSED: { NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; +#ifndef USE_ROCM + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; +#endif + default: { } } - TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; // Query workspace size TensorWrapper workspace; NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), - workspace.data(), + nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(), + rsigma_nvte.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); @@ -322,24 +412,32 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Launch kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), - workspace.data(), + nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(), + rsigma_nvte.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); - // Quantize output if using unfused kernel - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); - } else { - my_quantizer->quantize(unquantized_out_cu, out_cu); + // Quantize output if needed + switch (impl) { + case Impl::UNFUSED: { + quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; +#ifndef USE_ROCM + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; +#endif + default: { } } - return {out, py::none(), py::cast(rsigma)}; + return {out, py::none(), py::cast(rsigma_py)}; } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 55b1d179e..94dc081e5 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -34,6 +34,11 @@ PyTypeObject *MXFP8QuantizerClass = nullptr; PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; +#ifndef USE_ROCM +PyTypeObject *NVFP4TensorPythonClass = nullptr; +PyTypeObject *NVFP4TensorBasePythonClass = nullptr; +PyTypeObject *NVFP4QuantizerClass = nullptr; +#endif void init_float8_extension() { if (Float8TensorPythonClass) return; @@ -88,10 +93,30 @@ void init_float8blockwise_extension() { "Internal error: could not initialize pyTorch float8blockwise extension."); } +#ifndef USE_ROCM +void init_nvfp4_extensions() { + if (NVFP4TensorPythonClass) return; + auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor"); + NVFP4QuantizerClass = reinterpret_cast( + PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer")); + NVFP4TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor")); + auto nvfp4_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.nvfp4_tensor_base"); + NVFP4TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorBase")); + NVTE_CHECK(NVFP4TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch NVFP4 extension."); +} +#endif + void init_extension() { init_float8_extension(); init_mxfp8_extension(); init_float8blockwise_extension(); +#ifndef USE_ROCM + init_nvfp4_extensions(); +#endif } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 9fd1ae4de..cc7e23c61 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -40,13 +40,14 @@ extern PyTypeObject *MXFP8QuantizerClass; extern PyTypeObject *Float8BlockwiseQTensorPythonClass; extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; extern PyTypeObject *Float8BlockwiseQuantizerClass; +#ifndef USE_ROCM +extern PyTypeObject *NVFP4TensorPythonClass; +extern PyTypeObject *NVFP4TensorBasePythonClass; +extern PyTypeObject *NVFP4QuantizerClass; +#endif void init_extension(); -void init_float8_extension(); - -void init_mxfp8_extension(); - namespace detail { inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } @@ -69,11 +70,21 @@ inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; } +#ifndef USE_ROCM +inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4QuantizerClass; } +#endif + inline bool IsFloat8BlockwiseQTensor(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass; } +#ifndef USE_ROCM +inline bool IsNVFP4Tensor(PyObject *obj) { + return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorBasePythonClass; +} +#endif + TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); template @@ -88,10 +99,27 @@ std::unique_ptr CreateMXFP8Params(const py::handle params); TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantization_params); +#ifndef USE_ROCM +TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); +#endif + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } +#ifndef USE_ROCM +constexpr std::array custom_types_converters = { + std::make_tuple(IsFloat8Tensor, IsFloat8Quantizers, NVTETensorFromFloat8Tensor, + CreateQuantizer), + std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor, + CreateQuantizer), + std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, + CreateQuantizer), + std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, + NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer), + std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor, + CreateQuantizer)}; +#else constexpr std::array custom_types_converters = { std::make_tuple(IsFloat8Tensor, IsFloat8Quantizers, NVTETensorFromFloat8Tensor, CreateQuantizer), @@ -101,7 +129,7 @@ constexpr std::array custom_types_converters = { CreateQuantizer), std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; - +#endif } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 37c13362c..5568f0a3b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -33,8 +33,20 @@ std::vector make_transpose_shape(const std::vector& shape) { return ret; } +/*! @brief Convert shape for FP4 data by dividing the last dimension by 2 */ +template +std::vector convert_shape_for_fp4(const std::vector& shape) { + std::vector ret; + for (size_t i = 0; i < shape.size() - 1; ++i) { + ret.push_back(shape[i]); + } + ret.push_back(shape.back() / 2); + return ret; +} + } // namespace +constexpr size_t NVFP4_BLOCK_SIZE = 16; constexpr size_t MXFP8_BLOCK_SIZE = 32; Quantizer::Quantizer(const py::handle& quantizer) { @@ -378,10 +390,15 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } -std::pair Float8CurrentScalingQuantizer::create_hp_tensor_with_amax( - const std::vector& shape, DType dtype) { +std::pair +Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, + DType dtype, + std::optional data) { amax.zero_(); - auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) + : NoneQuantizer(py::none()).create_tensor(shape, dtype); + TensorWrapper out_cpp = std::move(out.first); + py::object out_py = std::move(out.second); out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); return {std::move(out_cpp), std::move(out_py)}; @@ -909,7 +926,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1105,7 +1122,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s auto last_dim = shape.back(); NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); std::vector scale_shape; @@ -1126,4 +1143,574 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s return scale_shape; } +#ifndef USE_ROCM +NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); + this->with_rht = quantizer.attr("with_rht").cast(); + this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); + this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); + this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); + + // Get amax reduction group if needed for NVFP4 AG + const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); + c10::intrusive_ptr amax_reduction_group; + if (with_amax_reduction) { + auto group = quantizer.attr("_canonicalized_amax_reduction_group")(); + NVTE_CHECK(!group.is_none(), "NVFP4Quantizer could not canonicalize amax reduction group"); + amax_reduction_group = group.cast>(); + } + this->with_amax_reduction = with_amax_reduction; + this->amax_reduction_group = amax_reduction_group; + + this->rht_matrix_random_sign_mask_t = quantizer.attr("rht_matrix_random_sign_mask_t").cast(); + this->rht_matrix = quantizer.attr("rht_matrix").cast(); +} + +void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { + // set dtype for rowwise and columnwise data in tensor wrapper + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(this->dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(this->dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, + DType dtype) const { + using namespace pybind11::literals; + + // Tensor dimensions + const std::vector shape_int64(shape.begin(), shape.end()); + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; + } + } + const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", + NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); + NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, + "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, + " (got shape=", shape, ")"); + const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); + const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); + + // Allocate tensors + at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise; + at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise; + const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + if (rowwise_usage) { + const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), + rowwise_scale_inv_shape.end()); + rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); + rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + amax_rowwise = at::empty({1}, bit32_tensor_opts); + } + if (columnwise_usage) { + const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), + columnwise_scale_inv_shape.end()); + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_int64_2d = {static_cast(flat_first_dim), + static_cast(flat_last_dim)}; + const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + columnwise_data_tensor = + at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts); + columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + amax_columnwise = at::empty({1}, bit32_tensor_opts); + } + + // Convert tensors to Python + auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { + return need_cast ? py::cast(tensor) : py::none(); + }; + auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); + auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); + auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); + auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); + auto amax_rowwise_py = py_cast(amax_rowwise, rowwise_usage); + auto amax_columnwise_py = py_cast(amax_columnwise, columnwise_usage); + + // Construct Python NVFP4 tensor + py::object out_py; + if (internal) { + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorBasePythonClass)); + out_py = NVFP4TensorClass( + "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, + "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer); + } else { + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorPythonClass)); + out_py = NVFP4TensorClass( + "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, + "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer); + } + + // Construct C++ tensor + TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, + rowwise_scale_inv_shape); + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + } + if (columnwise_usage) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_2d = {flat_first_dim, flat_last_dim}; + auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1, + col_data_shape_fp4); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, + columnwise_scale_inv_shape); + out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair NVFP4Quantizer::create_unquantized_tensor_with_amax( + TensorWrapper& quantized_tensor, DType dtype) { + // Construct tensor + auto shape = convertShape(quantized_tensor.shape()); + auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + + // Register amax pointer from quantized tensor + void* amax_ptr = quantized_tensor.amax(); + if (amax_ptr == nullptr) { + amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr; + } + NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor."); + out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + + // Zero out amax + NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream())); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair NVFP4Quantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); + + // Extract buffers from Python tensor + auto get_tensor = [&tensor](const char* name) -> std::optional { + auto attr_py = tensor.attr(name); + if (attr_py.is_none()) { + return std::nullopt; + } + return attr_py.cast(); + }; + auto rowwise_data = get_tensor("_rowwise_data"); + auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); + auto columnwise_data = get_tensor("_columnwise_data"); + auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); + auto amax_rowwise = get_tensor("_amax_rowwise"); + auto amax_columnwise = get_tensor("_amax_columnwise"); + NVTE_CHECK(rowwise_data || columnwise_data, "NVFP4Tensor has no data."); + + // Tensor dimensions, shape means original shape + std::vector shape; + if (columnwise_data) { + shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); + if (rowwise_data) { + auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, + ") and column-wise data (shape=", shape, ") do not match"); + } + } else { // Already checked columnwise_data_tensor == true + shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + } + + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; + } + } + const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + + // Coerce row-wise data + if (rowwise_usage) { + if (!rowwise_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_data = at::empty(convert_shape_for_fp4(shape_int64), opts); + tensor.attr("_rowwise_data") = *rowwise_data; + } + if (!rowwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, false); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; + } + if (!amax_rowwise) { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + amax_rowwise = at::empty({1}, opts); + tensor.attr("_amax_rowwise") = *amax_rowwise; + } + } else { // rowwise_usage == false + if (rowwise_data) { + rowwise_data.reset(); + tensor.attr("_rowwise_data") = py::none(); + } + if (rowwise_scale_inv) { + rowwise_scale_inv.reset(); + tensor.attr("_rowwise_scale_inv") = py::none(); + } + if (amax_rowwise) { + amax_rowwise.reset(); + tensor.attr("_amax_rowwise") = py::none(); + } + } + + // Coerce column-wise data + if (columnwise_usage) { + if (!columnwise_data) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_int64_2d = {static_cast(flat_first_dim), + static_cast(flat_last_dim)}; + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), opts); + tensor.attr("_columnwise_data") = *columnwise_data; + } + if (!columnwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, true); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; + } + if (!amax_columnwise) { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + amax_columnwise = at::zeros({1}, opts); + tensor.attr("_amax_columnwise") = *amax_columnwise; + } + } else { // columnwise_usage == false + if (columnwise_data) { + columnwise_data.reset(); + tensor.attr("_columnwise_data") = py::none(); + } + if (columnwise_scale_inv) { + columnwise_scale_inv.reset(); + tensor.attr("_columnwise_scale_inv") = py::none(); + } + if (amax_columnwise) { + amax_columnwise.reset(); + tensor.attr("_amax_columnwise") = py::none(); + } + } + + // Construct C++ tensor + TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*rowwise_scale_inv)); + out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector{1}); + } + if (columnwise_usage) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_2d = {flat_first_dim, flat_last_dim}; + auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1, + col_data_shape_fp4); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*columnwise_scale_inv)); + out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, + bool compute_amax) { + // Nothing to be done if input is empty + if (input.numel() == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); + quant_config.set_stochastic_rounding(this->stochastic_rounding); + + // We only need RHT for columnwise usage. + // flat first dim and last dim for multi dimensional input + size_t rows = 1; + for (size_t i = 0; i < input.ndim() - 1; ++i) { + rows *= input.size(i); + } + size_t cols = input.size(input.ndim() - 1); + + TensorWrapper te_rng_state; + if (this->stochastic_rounding) { + const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); + auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + auto rng_state = torch::empty({2}, opts); + philox_unpack(philox_args, static_cast(rng_state.data_ptr())); + te_rng_state = makeTransformerEngineTensor(rng_state); + quant_config.set_rng_state(te_rng_state.data()); + } + + // Restriction for the RHT cast fusion kernel. + bool eligible_for_rht_cast_fusion = + input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; + + // Compute amax. + if (this->with_rht) { + if (input.dtype() != DType::kBFloat16) { + NVTE_CHECK(false, "RHT is only supported for bfloat16 input"); + } + if (this->with_post_rht_amax) { + // We need: + // 1. Rowwise amax = amax for input + // 2. Columnwise amax = amax for RHT(input.t) + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_amax(input.data(), out.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + } else { + // raise error since it's not supported yet + NVTE_CHECK(false, "Pre-RHT amax is not supported yet"); + } + } else { // Without RHT + if (compute_amax) { + // Amax pointers + auto rowwise_amax_ptr = out.get_amax().data_ptr; + auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; + void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + + // Compute amax of input tensor + out.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); + out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector{1}); + + // Make sure row-wise and column-wise amaxes match + if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + } + } + + // amax reduction + if (this->with_amax_reduction) { + std::vector amax_tensors; + // push amax tensors inside if they need to be reduced + auto make_amax_tensor = [](void* data_ptr) { + return at::from_blob( + data_ptr, std::vector{1}, + [](void*) {}, // deleter doing nothing since it doesn't own the data + at::device(at::kCUDA).dtype(torch::kFloat32)); + }; + if (rowwise_usage) { + amax_tensors.push_back(make_amax_tensor(out.get_amax().data_ptr)); + } + if (columnwise_usage) { + amax_tensors.push_back(make_amax_tensor(out.get_columnwise_amax().data_ptr)); + } + c10d::AllreduceCoalescedOptions opts; + opts.reduceOp = c10d::ReduceOp::MAX; + NVTE_SCOPED_GIL_RELEASE( + { this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); }); + } + + if (this->with_rht) { + if (rowwise_usage) { + // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise + TensorWrapper out_identity(out.scaling_mode()); + auto out_identity_data = out.get_rowwise_data(); + auto out_identity_scale_inv = out.get_rowwise_scale_inv(); + auto out_identity_amax = out.get_amax(); + out_identity.set_rowwise_data(out_identity_data.data_ptr, + static_cast(out_identity_data.dtype), + out_identity_data.shape); + out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, + static_cast(out_identity_scale_inv.dtype), + out_identity_scale_inv.shape); + out_identity.set_amax(out_identity_amax.data_ptr, static_cast(out_identity_amax.dtype), + out_identity_amax.shape); + + NVTE_SCOPED_GIL_RELEASE( + { nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); }); + } + + if (columnwise_usage) { + // Get the output columnwise data, scale_inv, and amax + auto out_columnwise_data = out.get_columnwise_data(); + auto out_columnwise_scale_inv = out.get_columnwise_scale_inv(); + // NOTE: should already be populated. + auto out_columnwise_amax = out.get_columnwise_amax(); + + // Create a wrapper for the columnwise output, as the rowwise output. + // The reason is due to the input `rht_output_t` is already in the transposed layout. + // Thus, we only need a rowwise quantization to generate the columnwise output. + TensorWrapper out_transpose(out.scaling_mode()); + // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail + // need to convert the shape to 2D here + auto colwise_data_shape = out_columnwise_data.shape; + std::vector colwise_data_shape_2d; + // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte + // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again + // so the multiple 2 get cancelled out + colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); + size_t last_dim = 1; + for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { + last_dim *= colwise_data_shape.data[i]; + } + colwise_data_shape_2d.push_back(last_dim); + + out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, + static_cast(out_columnwise_data.dtype), + colwise_data_shape_2d); + out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, + static_cast(out_columnwise_scale_inv.dtype), + out_columnwise_scale_inv.shape); + out_transpose.set_amax(out_columnwise_amax.data_ptr, + static_cast(out_columnwise_amax.dtype), + out_columnwise_amax.shape); + + if (!eligible_for_rht_cast_fusion) { + // Invoking fallback RHT kernel. + + // If using RHT, then amax will be computed in the RHT step + // If not using RHT, then amax will be computed based on input x + at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout + // This wrapper is going to be passed as input to the quantization kernel. + TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs + rht_output_t = + allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); + // NOTE (frsun): This is non-intuitive, we are writing the + // result of transposed RHT to the output of rowwise. + rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), + std::vector{cols, rows}); + + NVTE_SCOPED_GIL_RELEASE({ + // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. + nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + + // Quantize kernel will treat everything as rowwise input/output, which is + // intended. + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config, stream); + }); + } else { + // RHT cast fusion kernel. + NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, + "RHT matrix is not set"); + auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix); + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_cast_fusion_columnwise( + input.data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config, stream); + }); + } + } + } else { + NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); + } +} + +void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + this->quantize_impl(input, out, noop_flag, true); +} + +void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { + // Update output tensor amaxes with input tensor amax + auto input_amax_ptr = input.amax(); + auto output_rowwise_amax_ptr = out.get_amax().data_ptr; + auto output_columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; + NVTE_CHECK(input_amax_ptr != nullptr || + (output_rowwise_amax_ptr == nullptr && output_columnwise_amax_ptr == nullptr), + "Input tensor does not have pre-computed amax"); + if (input_amax_ptr != output_rowwise_amax_ptr && input_amax_ptr != nullptr && + output_rowwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(output_rowwise_amax_ptr, input_amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); + } + if (input_amax_ptr != output_columnwise_amax_ptr && input_amax_ptr != nullptr && + output_columnwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(output_columnwise_amax_ptr, input_amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); + } + input.set_amax(nullptr, DType::kFloat32, input.defaultShape); + + // Perform quantization + this->quantize_impl(input, out, std::nullopt, false); +} + +std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, + bool columnwise) const { + size_t numel = 1; + for (auto s : shape) { + numel *= s; + } + + auto last_dim = shape.back(); + auto flat_first_dim = numel / last_dim; + + NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ", + NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")"); + NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, + "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, + " (got shape=", shape, ")"); + + std::vector scale_shape; + + bool rowwise_usage = !columnwise; + + if (rowwise_usage) { + // rowwise scaling factor shape + size_t sinv0 = roundup(flat_first_dim, 128); + size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); + scale_shape = {sinv0, sinv1}; + } else { + // columnwise scaling factor shape + size_t sinv0 = roundup(last_dim, 128); + size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4); + scale_shape = {sinv0, sinv1}; + } + return scale_shape; +} +#endif } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index cb2121a45..368e9dcdf 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -116,6 +116,46 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer return ret; } +TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = tensor.attr("_fp4_dtype").cast(); + + auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); + + // Row-scaled data + if (rowwise_usage) { + const auto &data = tensor.attr("_rowwise_data").cast(); + const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); + const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast(); + ret.set_rowwise_data(data.data_ptr(), dtype, + convert_shape_back_from_fp4(getTensorShape(data), false)); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); + ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); + } + + // Column-scaled data + if (columnwise_usage) { + const auto &data = tensor.attr("_columnwise_data").cast(); + const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); + const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast(); + ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, + convert_shape_back_from_fp4(getTensorShape(data), false)); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, + getTensorShape(scale_inv)); + ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, + getTensorShape(amax_columnwise)); + } + + // Quantizer state + quantizer->set_quantization_params(&ret); + + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 44b636930..7eba63e80 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -18,22 +18,31 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING && + input.scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } - NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8, + "4-bit or 8-bit input required for swizzling scaling factors."); + + const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING; NVTEBasicTensor scale_inv; + NVTEShape nvte_input_shape; if (rowwise) { + nvte_input_shape = input.shape(); scale_inv = input.get_rowwise_scale_inv(); } else { + nvte_input_shape = input.get_columnwise_data().shape; scale_inv = input.get_columnwise_scale_inv(); } - auto input_shape = nvte_shape_to_vector(input.shape()); + auto input_shape = nvte_shape_to_vector(nvte_input_shape); auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape."); + // Allocate memory for swizzled output. auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); std::vector scale_inv_shape_int; @@ -45,36 +54,34 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); // Reconstruct input only to avoid swizzling both directions if not needed. - // Use any 8 bit type, it's irrelevant. - transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); - transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + // The specific dtype used is irrelevant, just needs to be correct bits. + transformer_engine::TensorWrapper input_cu(input.scaling_mode()); + transformer_engine::TensorWrapper output_cu(input.scaling_mode()); + + const auto input_dtype = + (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3; + const auto scale_inv_dtype = + (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; + if (rowwise) { - input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); - output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } else { - input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); - output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, - transformer_engine::DType::kFloat8E8M0, scale_inv_shape); + input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } // Launch kernel nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } else { - input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } return swizzled_scale_inv; diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e809528da..958ce9c7c 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -41,11 +41,14 @@ from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer +from .tensor.nvfp4_tensor import NVFP4Quantizer from .tensor.float8_blockwise_tensor import Float8BlockQuantizer -from .tensor.quantized_tensor import QuantizedTensor, Quantizer +from .tensor.quantized_tensor import QuantizedTensorBase, QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from .tensor._internal.nvfp4_tensor_base import NVFP4TensorBase from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from .triton.pad import pad_columnwise_scale_inv from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer @@ -1208,6 +1211,245 @@ def _all_gather_fp8_blockwise( return out, handle +def _swap_first_dims(tensor: torch.Tensor, world_size: int): + """ + Swap first 2 dimensions of a tensor to fix interleaved + data format after gathering transposed data. + + For more than 2 dimensions, we squash the trailing dimensions, + instead of the first few dimensions, that's because the shape + passed in this function is already transposed. + """ + + shape = tensor.shape + assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave." + first_dim = shape[0] + flattened_trailing = math.prod(shape[1:]) + assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave." + tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing) + tensor = tex.swap_first_dims(tensor, out=None) + return tensor.reshape(first_dim // world_size, flattened_trailing * world_size) + + +def _post_process_nvfp4_gather( + out: NVFP4TensorBase, + columnwise_data_interleaved: torch.Tensor, + columnwise_scale_inv_interleaved: torch.Tensor, + world_size: int, + handle: Optional[torch.distributed.Work] = None, +) -> NVFP4TensorBase: + """Post-process FP8 blockwise gather.""" + if handle is not None: + handle.wait() + handle = None + + # Fix the interleaved transposed data from gathering along first dim. + out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) + out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) + + # Optionally pad the scaling inverse if needed. + out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) + + +@dataclass +class _NVFP4AllGatherAsyncHandle: + """Handle for asynchronous NVFP4 all-gather.""" + + output: NVFP4TensorBase + columnwise_data_interleaved: torch.Tensor + columnwise_scale_inv_interleaved: torch.Tensor + world_size: int + async_handle: torch.distributed.Work + _synchronized: bool = False + + def wait(self) -> None: + """Wait for the async operation to complete and post-process the tensor.""" + if self._synchronized: + return + self.async_handle.wait() + _post_process_nvfp4_gather( + self.output, + self.columnwise_data_interleaved, + self.columnwise_scale_inv_interleaved, + self.world_size, + ) + self._synchronized = True + + +def _all_gather_nvfp4( + inp: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, + quantizer: NVFP4Quantizer, + out_shape: Optional[list[int]] = None, +) -> tuple[NVFP4TensorBase, Optional[torch.distributed.Work]]: + """All-gather NVFP4 tensor along first dimension.""" + + # Input tensor attributes + in_shape: Iterable[int] = None + in_shape_t: Iterable[int] = None + device: torch.device + dtype: torch.dtype + + # Construct packed shapes for input and input_t. + if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorBase): + # High-precision tensor. + in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size()) + in_shape_t = NVFP4Quantizer.convert_shape_for_fp4( + NVFP4Quantizer.get_columnwise_shape(inp.size()) + ) + device = inp.device + dtype = inp.dtype + elif isinstance(inp, NVFP4TensorBase): + if inp._rowwise_data is not None: + in_shape = inp._rowwise_data.size() + device = inp._rowwise_data.device + if inp._columnwise_data is not None: + in_shape_t = inp._columnwise_data.size() + device = inp._columnwise_data.device + dtype = torch.bfloat16 + else: + raise ValueError( + "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorBase, " + f"found {inp.__class__.__name__})" + ) + + assert in_shape is not None or in_shape_t is not None, "No data found." + + world_size = get_distributed_world_size(process_group) + + if out_shape is None: + out_shape = [in_shape[0] * world_size] + in_shape[1:] + + # For cases where inp has dimensions that cannot be quantized, + # we gather in high precision followed by a cast to NVFP4. + if ( + not isinstance(inp, NVFP4TensorBase) + and quantizer is not None + and not quantizer.is_quantizable(inp) + ): + out = torch.empty( + out_shape, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group) + out = quantizer(out) + return out, None + + # Cast input tensor to NVFP4 with required data + if not isinstance(inp, NVFP4TensorBase): + inp = quantizer(inp) + elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( + quantizer.columnwise_usage and inp._columnwise_data is None + ): + warnings.warn( + "Input and quantizer do not have matching usages. " + "Dequantizing and requantizing to NVFP4." + ) + inp = quantizer(inp.dequantize()) + + # Construct NVFP4 output tensor + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + + # Coalesce NCCL collectives for gathering data and scale inverses. + with torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) as gather_coalescing_manager: + + # Gather NVFP4 data for row-wise usage + if quantizer.rowwise_usage: + + # Remove padding from NVFP4 scale-inverses + assert in_shape is not None, "Shape not found." + in_scale_inv = inp._rowwise_scale_inv + out_scale_inv = out._rowwise_scale_inv + flattened_in_shape0 = math.prod(in_shape[:-1]) + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] + + # Launch all-gathers + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + torch.distributed.all_gather_into_tensor( + out._rowwise_data, + inp._rowwise_data, + group=process_group, + ) + + # Transfer amax to output. + out._amax_rowwise = inp._amax_rowwise + + # Gather the transposed NVFP4 data along first dimension. Fix format later. + if quantizer.columnwise_usage: + + # Remove padding from NVFP4 scale-inverses + # For doing an all-gather on transposed scale inverses, + # we need to remove padding from both dimension. + in_scale_inv = inp._columnwise_scale_inv + # take caution that for in_shape_t, flatten in the trailing dimensions! + flattened_in_shape0 = in_shape_t[0] + flattened_in_shape1 = math.prod(in_shape_t[1:]) + + # Remove dim0 padding + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + + # Remove dim1 padding (pack first). + unpadded_dim1 = flattened_in_shape1 * 2 // 16 + if in_scale_inv.size(1) != unpadded_dim1: + in_scale_inv = in_scale_inv[:, :unpadded_dim1].contiguous() + + # Construct tensor to gather transposed scale_inv (interleaved) and launch AG. + out_scale_inv = torch.empty( + [flattened_in_shape0 * world_size] + [in_scale_inv.shape[1]], + dtype=in_scale_inv.dtype, + layout=in_scale_inv.layout, + device=in_scale_inv.device, + ) + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + + # Construct tensor to gather transposed data (interleaved) and launch AG. + out_columnwise_data = torch.empty( + [inp._columnwise_data.shape[0] * world_size] + list(inp._columnwise_data.shape[1:]), + dtype=inp._columnwise_data.dtype, + layout=inp._columnwise_data.layout, + device=inp._columnwise_data.device, + ) + torch.distributed.all_gather_into_tensor( + out_columnwise_data, + inp._columnwise_data, + group=process_group, + ) + + # Transfer amax to output. + out._amax_columnwise = inp._amax_columnwise + + handle = gather_coalescing_manager if async_op else None + + # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. + if async_op and quantizer.columnwise_usage: + handle = _NVFP4AllGatherAsyncHandle( + out, out_columnwise_data, out_scale_inv, world_size, handle + ) + elif quantizer.columnwise_usage: + _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + + return out, handle + + def _all_gather_mxfp8( inp: torch.Tensor, process_group: dist_group_type, @@ -1295,7 +1537,6 @@ def _all_gather_mxfp8( flattened_in_shape0 = math.prod(in_shape[:-1]) if in_scale_inv.size(0) != flattened_in_shape0: in_scale_inv = in_scale_inv[:flattened_in_shape0] - out_scale_inv[flattened_in_shape0 * world_size :].zero_() out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] # Launch all-gathers @@ -1319,7 +1560,6 @@ def _all_gather_mxfp8( flattened_in_shape0 = math.prod(in_shape[:-1]) // 32 if in_scale_inv.size(0) != flattened_in_shape0: in_scale_inv = in_scale_inv[:flattened_in_shape0] - out_scale_inv[flattened_in_shape0 * world_size :].zero_() out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] # Launch all-gathers @@ -1351,7 +1591,7 @@ def gather_along_first_dim( # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) if world_size == 1: - if quantizer is not None and not isinstance(inp, QuantizedTensor): + if quantizer is not None and not isinstance(inp, QuantizedTensorBase): inp = quantizer(inp) return inp, None @@ -1430,13 +1670,24 @@ def gather_along_first_dim( out_shape=out_shape, ) + # NVFP4 case + if isinstance(inp, NVFP4TensorBase) or isinstance(quantizer, NVFP4Quantizer): + assert isinstance(quantizer, NVFP4Quantizer) + return _all_gather_nvfp4( + inp, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) + # High-precision communication for quantized tensors if quantizer is not None: warnings.warn( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." ) - if isinstance(inp, QuantizedTensor): + if isinstance(inp, QuantizedTensorBase): inp = inp.dequantize() # Falling back to high-precision all-gather for Float8BlockQuantizer # means that it should directly output GEMM_READY format @@ -1454,7 +1705,7 @@ def gather_along_first_dim( return out, None # Dequantize quantized tensor if not supported - if isinstance(inp, QuantizedTensor): + if isinstance(inp, QuantizedTensorBase): warnings.warn( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." diff --git a/transformer_engine/pytorch/experimental/__init__.py b/transformer_engine/pytorch/experimental/__init__.py new file mode 100644 index 000000000..11658f636 --- /dev/null +++ b/transformer_engine/pytorch/experimental/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Experimental features and APIs.""" + +from .config import set_qlinear_params, get_experimental_quantizers + + +__all__ = ["set_qlinear_params", "get_experimental_quantizers"] diff --git a/transformer_engine/pytorch/experimental/config.py b/transformer_engine/pytorch/experimental/config.py new file mode 100644 index 000000000..fec6bc938 --- /dev/null +++ b/transformer_engine/pytorch/experimental/config.py @@ -0,0 +1,201 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Config API for experimental middleware between Transformer Engine and Kitchen.""" + +import dataclasses +import enum +import os +from typing import Optional + +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.experimental import quantization +from transformer_engine.pytorch.experimental import quantization_microblock_ref +from transformer_engine.pytorch.experimental.quantization import MMParams + + +@dataclasses.dataclass() +class QLinearParams: + """Quantization parameters of linear layer. + + Contains ready-to-use quantizers for input (x), weight (w), and gradient (g) tensors. + """ + + x_quantizer: Optional[quantization.ExperimentalQuantizer] = None + w_quantizer: Optional[quantization.ExperimentalQuantizer] = None + g_quantizer: Optional[quantization.ExperimentalQuantizer] = None + + mm_fprop: Optional[MMParams] = None + mm_dgrad: Optional[MMParams] = None + mm_wgrad: Optional[MMParams] = None + + +@enum.unique +class QuantizeRecipe(enum.Enum): + """Pre-defined quantization recipes for linear layers.""" + + NON_QUANTIZE = "non_quantize" + NVFP4_REF = "nvfp4_ref" + NVFP4_REF_RHT_ONLY = "nvfp4_ref_rht_only" + NVFP4_REF_2D_QUANTIZATION_ONLY = "nvfp4_ref_2d_quantization_only" + NVFP4_REF_RHT_AND_2D_QUANTIZATION = "nvfp4_ref_rht_and_2d_quantization" + + +def get_qlinear_params_from_predefined( + recipe: QuantizeRecipe, +) -> Optional[QLinearParams]: + """Get quantization parameters for linear layer based on recipe.""" + if recipe == QuantizeRecipe.NON_QUANTIZE: + return None + if recipe == QuantizeRecipe.NVFP4_REF: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + ), + ) + if recipe == QuantizeRecipe.NVFP4_REF_RHT_ONLY: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + ) + if recipe == QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=False, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16), + pow_2_scales=False, + with_rht=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=False, + ), + ) + if recipe == QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16), + pow_2_scales=False, + with_rht=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + ) + raise ValueError(f"Unsupported quantize recipe: {recipe}") + + +def get_qlinear_params_from_qat_params(qat_params_idx: int) -> Optional[QLinearParams]: + """Load quantization options from Kitchen to Transformer Engine. + + TODO(etsykunov): Confirm docstring is correct. + """ + assert qat_params_idx > 0, "QAT_PARAMS is not set." + + if qat_params_idx == 6010: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF) + if qat_params_idx == 960109: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_ONLY) + if qat_params_idx == 9002: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY) + if qat_params_idx == 9003: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION) + raise ValueError(f"Unsupported QAT params index: {qat_params_idx}") + + +def set_qlinear_params( + qlinear_params: Optional[QLinearParams] = None, + layer_number: Optional[int] = None, + layer_name: Optional[str] = None, +) -> Optional[QLinearParams]: + """Set quantization parameters based on configuration. + + Args: + qlinear_params: Quantization parameters. If None, loaded from environment. + layer_number: The numerical index of this layer in the model structure. + layer_name: The name for this layer. + + Returns: + QLinearParams: The finalized quantization parameters for this layer. + """ + if qlinear_params is None: + qat_params_idx = int(os.getenv("QAT_PARAMS", "0")) + if qat_params_idx == 0: + return None + return get_qlinear_params_from_qat_params(qat_params_idx) + + # Apply layer-specific overrides + if layer_number is not None: + raise NotImplementedError("Layer-specific overrides are not supported yet.") + if layer_name is not None: + raise NotImplementedError("Layer-specific overrides are not supported yet.") + + return qlinear_params + + +def get_experimental_quantizers(fp8: bool, qlinear_params: QLinearParams): + """Replacement of _get_quantizers() in TE modules.""" + if not fp8: + raise ValueError("FP8 is required to be enabled for experimental quantization.") + input_quantizer = qlinear_params.x_quantizer + weight_quantizer = qlinear_params.w_quantizer + output_quantizer = None + grad_input_quantizer = None + grad_weight_quantizer = None + grad_output_quantizer = qlinear_params.g_quantizer + + return ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) diff --git a/transformer_engine/pytorch/experimental/gemm.py b/transformer_engine/pytorch/experimental/gemm.py new file mode 100644 index 000000000..d743b577b --- /dev/null +++ b/transformer_engine/pytorch/experimental/gemm.py @@ -0,0 +1,139 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""GEMM API for experimental middleware between Transformer Engine and Kitchen.""" + +from typing import Iterable, Optional + +import torch + +from transformer_engine.pytorch.experimental.quantization import ( + MMParams, + GEMMType, + ExperimentalQuantizedTensor, +) +from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer + + +def experimental_gemm( + A: ExperimentalQuantizedTensor, + B: ExperimentalQuantizedTensor, + workspace: torch.Tensor, # pylint: disable=unused-argument + out_dtype: Optional[torch.dtype] = None, + quantization_params: Optional[Quantizer] = None, # pylint: disable=unused-argument + gelu: bool = False, # pylint: disable=unused-argument + gelu_in: torch.Tensor = None, # pylint: disable=unused-argument + accumulate: bool = False, # pylint: disable=unused-argument + layout: str = "TN", + out: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + bias: Optional[torch.Tensor] = None, + use_split_accumulator: bool = False, + grad: bool = False, +) -> Iterable[Optional[torch.Tensor]]: + """Dispatch GEMM to quantizer's qgemm method.""" + assert isinstance(A, ExperimentalQuantizedTensor) and isinstance( + B, ExperimentalQuantizedTensor + ), "A and B must be ExperimentalQuantizedTensor instances" + + A, B = B, A + + # Determine GEMM type based on grad flag and layout + if not grad: + gemm_type = GEMMType.FPROP + else: + if layout == "NN": + gemm_type = GEMMType.DGRAD + elif layout == "NT": + gemm_type = GEMMType.WGRAD + else: + # Default to FPROP for other layouts + gemm_type = GEMMType.FPROP + + # Extract quantizer from QuantizedTensor to get qgemm logic + # TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B.quantizer? + quantizer = None + if hasattr(A, "quantizer") and A.quantizer is not None: + quantizer = A.quantizer + elif hasattr(B, "quantizer") and B.quantizer is not None: + quantizer = B.quantizer + else: + raise ValueError("No quantizer found in QuantizedETensor objects") + + # Create MMParams + m_params = MMParams( + out_dtype=out_dtype, + use_split_accumulator=use_split_accumulator, + ) + out_dtype = A.dtype if m_params.out_dtype is None else m_params.out_dtype + + if gemm_type == GEMMType.FPROP: + qx, sx = A.data, A.scale + qw, sw = B.data, B.scale + assert qx is not None + assert sx is not None + assert qw is not None + assert sw is not None + assert A.original_shape is not None + + # Call quantizer's qgemm method + result = quantizer.qgemm( + qx, + qw, + m_params, + out_dtype, + sx, + sw, + bias, + gemm_type=GEMMType.FPROP, + qresult_x=A, + qresult_w=B, + ) + if len(A.original_shape) > 2: + # Original input was 3D, so we need to reshape result back to 3D + batch_size = A.original_shape[0] + seq_len = A.original_shape[1] + result = result.view(batch_size, seq_len, result.shape[-1]) + elif gemm_type == GEMMType.DGRAD: + qdy, sdy = A.data, A.scale + qw_t, sw_t = B.data_t, B.scale_t + assert qdy is not None + assert sdy is not None + assert qw_t is not None + assert sw_t is not None + + result = quantizer.qgemm( + qdy, + qw_t, + m_params, + out_dtype, + sdy, + sw_t, + None, + gemm_type=GEMMType.DGRAD, + qresult_x=A, + qresult_w=B, + ) + elif gemm_type == GEMMType.WGRAD: + qdy_t, sdy_t = A.data_t, A.scale_t + qx_t, sx_t = B.data_t, B.scale_t + assert qdy_t is not None + assert sdy_t is not None + assert qx_t is not None + assert sx_t is not None + + result = quantizer.qgemm( + qdy_t, + qx_t, + m_params, + out_dtype, + sdy_t, + sx_t, + None, + gemm_type=GEMMType.WGRAD, + qresult_x=A, + qresult_w=B, + ) + + # Return in the same format as general_gemm + return result, None, None, None diff --git a/transformer_engine/pytorch/experimental/quantization.py b/transformer_engine/pytorch/experimental/quantization.py new file mode 100644 index 000000000..9adf4dabf --- /dev/null +++ b/transformer_engine/pytorch/experimental/quantization.py @@ -0,0 +1,203 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Quantization API for experimental middleware between Transformer Engine and Kitchen.""" + +from __future__ import annotations +import abc +import dataclasses +import enum +from typing import Iterable, Optional, Tuple, Union + +import torch + +from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase, Quantizer +from transformer_engine.pytorch.experimental import utils + + +@enum.unique +class GEMMType(enum.Enum): + """Type of GEMM operation being performed.""" + + FPROP = "fprop" + DGRAD = "dgrad" + WGRAD = "wgrad" + + +@dataclasses.dataclass(frozen=True) +class MMParams: + """Matrix multiplication parameters.""" + + out_dtype: torch.dtype | None = None + # Use split accumulator for more accurate FP8 GEMM + use_split_accumulator: bool = True + + +@dataclasses.dataclass +class ExperimentalQuantizedTensor(QuantizedTensorBase): + """Base class for experimental quantized tensor containers. + + An experimental container to hold quantization result, including quantized tensor, optional + transposed quantized tensor, and corresponding decoding scales. + + data: torch.Tensor + the quantized tensor. + scale: torch.Tensor + the decoding scale for the quantized tensor. Shape depends on the scaling granularity. + - if scaling type is PER_TENSOR, it should be a 1D scalar tensor. + data_t: torch.Tensor + the transposed quantized tensor (computed lazily if needed). + scale_t: torch.Tensor + the decoding scale for the transposed quantized tensor. + dtype: torch.dtype + nominal tensor datatype. + device: torch.device + device of the tensor. + quant_dtype: Union[utils.Fp4Formats, torch.dtype] + low precision tensor datatype. + original_shape: Tuple[int, ...] + original shape of the tensor. + quantizer: ExperimentalQuantizer + Builder class for quantized tensor. + """ + + data: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + data_t: Optional[torch.Tensor] = None + scale_t: Optional[torch.Tensor] = None + global_amax_row: Optional[torch.Tensor] = None + global_amax_col: Optional[torch.Tensor] = None + + dtype: Optional[torch.dtype] = None + device: Optional[torch.device] = None + quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None + original_shape: Optional[Tuple[int, ...]] = None + quantizer: Optional[ExperimentalQuantizer] = None + + @property + def experimental(self) -> bool: + """Flag to indicate this quantizer is using experimental Kitchen middleware.""" + return True + + def get_quantizer(self) -> ExperimentalQuantizer: + """Get builder for QuantizedExperimentalTensor + + Quantizer can be used for in-place operations. + + """ + if self.quantizer is not None: + return self.quantizer + raise ValueError("Quantizer is not set") + + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], ExperimentalQuantizedTensor]: + """Prepare the quantization result for saving for backward""" + tensors = [self.data, self.data_t, self.scale, self.scale_t] + self.data = None + self.data_t = None + self.scale = None + self.scale_t = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the quantization result from the saved tensors""" + self.data = tensors[0] + self.data_t = tensors[1] + self.scale = tensors[2] + self.scale_t = tensors[3] + return tensors[4:] + + def dequantize(self, *args, **kwargs) -> torch.Tensor: + """Dequantize the quantized tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement dequantize function" + ) + + # Compatibility + @property + def _data(self): + return self.data + + @_data.setter + def _data(self, value): + self.data = value + + @property + def _scale_inv(self): + return self.scale + + @_scale_inv.setter + def _scale_inv(self, value): + self.scale = value + + +class ExperimentalQuantizer(Quantizer): + """Experimental Quantizer class + + Defines the interface for experimental quantizers. + """ + + def __init__(self, *, rowwise: bool, columnwise: bool) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.internal = True + + @property + def experimental(self) -> bool: + """Flag to indicate this quantizer is using experimental Kitchen middleware""" + return True + + @abc.abstractmethod + def qgemm( + self, + qx: torch.Tensor, + qw: torch.Tensor, + m_params: MMParams, + out_dtype: torch.dtype, + sx: torch.Tensor, + sw: torch.Tensor, + bias: torch.Tensor | None = None, + out: torch.Tensor | None = None, + accumulate: bool = False, + gemm_type: GEMMType = GEMMType.FPROP, + qresult_x: ExperimentalQuantizedTensor | None = None, + qresult_w: ExperimentalQuantizedTensor | None = None, + ) -> torch.Tensor: + """Quantized GEMM interface.""" + + def dequantize(self, *args, **kwargs) -> torch.Tensor: + """Dequantize the quantized tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement dequantize function" + ) + + def update_quantized(self, *args, **kwargs) -> torch.Tensor: + """Update the quantized tensor with the given tensor in-place""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement update_quantized function" + ) + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> QuantizedTensorBase: + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement make_empty function" + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement calibrate function" + ) + + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement _get_compatible_recipe function" + ) diff --git a/transformer_engine/pytorch/experimental/quantization_microblock_ref.py b/transformer_engine/pytorch/experimental/quantization_microblock_ref.py new file mode 100644 index 000000000..da749d237 --- /dev/null +++ b/transformer_engine/pytorch/experimental/quantization_microblock_ref.py @@ -0,0 +1,811 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""NVFP4 implementations for experimental middleware between Transformer Engine and Kitchen.""" + +from typing import Optional, Tuple + +import torch + +from transformer_engine.pytorch.experimental import quantization +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.experimental.quantization import ( + ExperimentalQuantizedTensor, + ExperimentalQuantizer, +) + + +def cast_to_fp4x2(x): + """Quantize a tensor to FP4 E2M1 and store in a byte tensor""" + + result = torch.zeros_like(x, dtype=torch.uint8) + result[(x >= 0.0) & (x <= 0.25)] = 0 + result[(x > 0.25) & (x < 0.75)] = 1 + result[(x >= 0.75) & (x <= 1.25)] = 2 + result[(x > 1.25) & (x < 1.75)] = 3 + result[(x >= 1.75) & (x <= 2.5)] = 4 + result[(x > 2.5) & (x < 3.5)] = 5 + result[(x >= 3.5) & (x <= 5.0)] = 6 + result[x > 5.0] = 7 + + result[(x >= -0.25) & (x < -0.0)] = 8 + result[(x < -0.25) & (x > -0.75)] = 9 + result[(x <= -0.75) & (x >= -1.25)] = 10 + result[(x < -1.25) & (x > -1.75)] = 11 + result[(x <= -1.75) & (x >= -2.5)] = 12 + result[(x < -2.5) & (x > -3.5)] = 13 + result[(x <= -3.5) & (x >= -5.0)] = 14 + result[x < -5.0] = 15 + + return result[:, ::2] + result[:, 1::2] * 16 + + +def cast_from_fp4x2(x, dq_dtype): + """Dequantize FP4 E2M1 tensor that has been represented in a byte tensor""" + fp4_values = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + device=x.device, + dtype=dq_dtype, + ) + + # Convert to long integers for indexing + second_bit = torch.div(x, 16, rounding_mode="floor").to(torch.long) + first_bit = (x - second_bit * 16).to(torch.long) + + # Use the long integers to index fp4_values + first_bit_values = fp4_values[first_bit] + second_bit_values = fp4_values[second_bit] + + result = torch.zeros( + (first_bit_values.shape[0], first_bit_values.shape[1] * 2), + device=x.device, + dtype=dq_dtype, + ) + result[:, ::2] = first_bit_values + result[:, 1::2] = second_bit_values + + return result + + +def cast_to_e8(decode_scale): + """Cast to a value that is representable in FP8 E8M0. + + The result is in FP32, not FP8 E8M0. + """ + max_exponent = torch.tensor(127, device=decode_scale.device, dtype=torch.float32) + exponent = torch.ceil(torch.log2(decode_scale)) + exponent = torch.clamp(exponent, min=-max_exponent, max=max_exponent) + + return torch.tensor(2.0, device=decode_scale.device, dtype=torch.float32) ** exponent + + +def cast_to_e4m3(decode_scale, global_amax): + """Scale and cast to FP8 E4M3. + + decode_scale is actually the encoding scaling factor. global_amax + can be any data tensor and not just the amax. + + TODO(etsykunov): Make less unintuitive. + """ + decode_scale = decode_scale * global_amax + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=decode_scale.device, dtype=torch.float32) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + return decode_scale.to(torch.float8_e4m3fn) + + +def high_precision_gemm_ref( + a: torch.Tensor, + b: torch.Tensor, + out_dtype: torch.dtype, + accumulate: bool = False, + is_a_transposed: bool = False, + is_b_transposed: bool = False, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + scale_alpha: float = 1.0, +) -> torch.Tensor: + """GEMM implementation with unquantized data""" + # Handle transpositions + mat1, mat2 = a, b + if is_a_transposed: + mat1 = a.T + if is_b_transposed: + mat2 = b.T + + # Ensure dtype compatibility for torch.addmm + mat1 = mat1.to(out_dtype) + mat2 = mat2.to(out_dtype) + + # Determine output shape + y_shape = (mat1.size(0), mat2.size(1)) + + if bias is not None: + assert not accumulate, "Bias is not supported with accumulation" + bias = bias.to(out_dtype) + # With bias case + if out_dtype == torch.float32: + y_ref = torch.addmm(bias.repeat(mat1.size(0), 1), mat1, mat2, beta=1, alpha=1) + else: + y_ref = torch.addmm(bias, mat1, mat2, beta=1, alpha=scale_alpha) + else: + # Without bias case + if accumulate and out is not None: + y_ref = out.clone().to(out_dtype) + else: + y_ref = torch.zeros(y_shape, dtype=out_dtype, device=a.device) + torch.addmm(y_ref, mat1, mat2, beta=1, alpha=scale_alpha, out=y_ref) + + return y_ref + + +class NVFP4TensorRef(ExperimentalQuantizedTensor): + """NVFP4 tensor for middleware between Transformer Engine and Kitchen""" + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"dtype={self.dtype}, " + f"device={self.device}, " + f"quant_dtype={self.quant_dtype}, " + f"data={self.dequantize(dtype=self.dtype)}, " + f"original_shape={self.original_shape}" + ")" + ) + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> ExperimentalQuantizedTensor: + """In-place update of quantized data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, ExperimentalQuantizedTensor): + return self.quantize_(tensor.dequantize(), noop_flag=noop_flag) + self.get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from quantized tensor + """ + if dtype is None: + dtype = self.dtype + + # Ignore data_t for now + assert self.data is not None, "QuantizedTensor has no valid tensor data" + assert self.scale is not None, "QuantizedTensor has no valid scale" + tensor_data = self.data + tensor_scale = self.scale + # Dispatch to the quantizer + return self.get_quantizer().dequantize(tensor_data, tensor_scale, dtype=dtype) + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """Generate or remove quantized data based on provided usage.""" + has_data = self.data is not None + has_data_transpose = self.data_t is not None + needs_data = has_data + needs_data_transpose = has_data_transpose + + if rowwise_usage is not None: + needs_data = rowwise_usage + if columnwise_usage is not None: + needs_data_transpose = columnwise_usage + + # Generate data that is required + if needs_data and not has_data: + raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose") + if needs_data_transpose and not has_data_transpose: + if not has_data: + raise RuntimeError("FP8 data is required to generate FP8 data transpose") + self._create_transpose() + + # Delete data that is not required + if not needs_data: + self.data = None + if not needs_data_transpose: + self.data_t = None + + def _create_transpose(self): + """Create transposed quantized tensor""" + if not self.data.is_contiguous(): + self.data = self.data.contiguous() + self.data_t = self.data.t().contiguous() + self.scale_t = self.scale + + def size(self, *args, **kwargs): # pylint: disable=unused-argument + """Return the original tensor shape, not the internal packed data shape. + + FP4 quantization packs two 4-bit values into each 8-bit value, which reduces + the second dimension by half. This method returns the logical shape that + users expect, not the internal packed storage shape. + """ + assert self.original_shape is not None + return torch.Size(self.original_shape) + + +def get_wgrad_sign_vector() -> torch.Tensor: + """Hard-coded signs for Hadamard transform""" + return torch.tensor( + [1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0], + dtype=torch.float32, + ) + + +class NVFP4QuantizerRef(ExperimentalQuantizer): + """NVFP4 quantizer for middleware between Transformer Engine and Kitchen""" + + def __init__( + self, + dtype: utils.Fp4Formats, + rowwise: bool = True, + columnwise: bool = True, + pow_2_scales: bool = False, + eps: float = 0.0, + quant_tile_shape: Tuple[int, int] = (1, 16), + with_rht: bool = False, + with_random_sign_mask: bool = True, + ): + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.dtype = dtype + self.pow_2_scales = pow_2_scales + self.eps = eps + self.quant_tile_shape = quant_tile_shape + self.with_rht = with_rht + self.with_random_sign_mask = with_random_sign_mask + + @staticmethod + def _build_hadamard_matrix( + size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True + ) -> torch.Tensor: + """Construct a Hadamard matrix of given power-of-two size with entries +-1. + + Uses Sylvester construction to avoid SciPy dependency. + """ + assert (size & (size - 1)) == 0, "Hadamard size must be a power of two" + h = torch.ones((1, 1), device=device, dtype=torch.float32) + while h.shape[0] < size: + h = torch.cat( + [ + torch.cat([h, h], dim=1), + torch.cat([h, -h], dim=1), + ], + dim=0, + ) + if with_random_sign_mask: + sign_mat = get_wgrad_sign_vector().to(device) * torch.eye( + size, device=device, dtype=torch.float32 + ) + h = sign_mat @ h + return h.to(dtype) + + def _apply_rht(self, x: torch.Tensor) -> torch.Tensor: + """Apply randomized Hadamard transform without random signs (reference path). + + This matches the reference used in tests: x_reshaped @ (H * (1/sqrt(g))). + """ + # Only apply when enabled + if not self.with_rht: + return x + + # RHT dimension equals the quantization tile length (NVFP4 uses 16) + rht_dim = self.quant_tile_shape[1] + assert ( + x.shape[-1] % rht_dim == 0 + ), f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}" + + # Build H and scale + H = self._build_hadamard_matrix(rht_dim, x.device, x.dtype, self.with_random_sign_mask) + scale = 1.0 / float(rht_dim) ** 0.5 + + # Perform blockwise transform along the last dimension + original_shape = x.shape + x_mat = x.contiguous().view(-1, rht_dim) + # Random sign matrix is identity in this reference (no sign flipping) + transform = H * scale + out = x_mat @ transform + return out.view(original_shape) + + @staticmethod + def _recover_swizzled_scales( + swizzled_scale: bool, scale: torch.Tensor, m: int, n: int, block_length: int + ) -> torch.Tensor: + if not swizzled_scale: + return scale + rounded_m = utils.roundup_div(m, 128) * 128 + scale_n = utils.roundup_div(n, block_length) + rounded_n = utils.roundup_div(scale_n, 4) * 4 + # Recover swizzled scaling factor layout -> linear layout + tmp = torch.reshape(scale, (rounded_m // 128, rounded_n // 4, 32, 4, 4)) + # after permutation, the layout is [rounded_m // 128, 4, 32, rounded_n // 4, 4] + tmp = torch.permute(tmp, (0, 3, 2, 1, 4)) + result = torch.reshape(tmp, (rounded_m, rounded_n)) + return result[:m, :scale_n] + + @classmethod + def _quantize_blockwise_reference( + cls, + x: torch.Tensor, + global_amax: torch.Tensor, + tile_len_x: int, + tile_len_y: int, + *, + pow_2_scales: bool, + eps: float, # pylint: disable=unused-argument + ) -> Tuple[torch.Tensor, torch.Tensor]: + + assert x.ndim == 2 + using_2d_quantization = tile_len_x == 16 and tile_len_y == 16 + m, n = x.shape + # Compute vec_max based on the original x (before reshape) + # For 1D quantization: amax over each row chunk of 16 + # For 2D quantization: amax over each 16x16 block, but output shape is still (128, 8, 1), filled with block amax + if using_2d_quantization: + # x shape: (128, 128) + x_blocks = ( + x.unfold(0, tile_len_y, tile_len_y) + .unfold(1, tile_len_x, tile_len_x) + .to(torch.float32) + ) # (8, 8, 16, 16) + block_amax = torch.amax(torch.abs(x_blocks), dim=(-1, -2)) # (8, 8) + # Now, expand to (128, 8, 1) by repeating each block_amax for 16 rows + vec_max = block_amax.repeat_interleave(tile_len_y, dim=0).unsqueeze(-1) # (128, 8, 1) + else: + # x shape: (128, 128) + x_reshaped = x.view(m, n // tile_len_x, tile_len_x) # (128, 8, 16) + vec_max = torch.amax(torch.abs(x_reshaped), dim=-1, keepdim=True).to( + torch.float32 + ) # (128, 8, 1) + x = x.view(m, n // tile_len_x, tile_len_x) + FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX) + + if pow_2_scales: + decode_scale = cast_to_e8(decode_scale) + encode_scale = torch.div( + torch.tensor(1.0, device=x.device, dtype=torch.float32), + decode_scale.to(torch.float32), + ) + else: + global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) + global_encode_scale = torch.min( + global_encode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=global_encode_scale.device, + dtype=torch.float32, + ), + ) + if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): + global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + global_decode_scale = torch.div(1.0, global_encode_scale) + + decode_scale = decode_scale * global_encode_scale + decode_scale = torch.min( + decode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + decode_scale = decode_scale.to(torch.float8_e4m3fn) + + encode_scale = torch.min( + torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + + scaled_x = x.to(torch.float32) * encode_scale + + clipped_x = torch.clamp(scaled_x, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n) + + return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) + + @staticmethod + def _pad_tensor( + tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int] + ) -> torch.Tensor: + + assert tensor.dim() == 2, "only supports 2D tensors" + M, N = tensor.shape + padding_needed_rows = 0 + padding_needed_cols = 0 + + if row_divisor is not None and M % row_divisor != 0: + padding_needed_rows = row_divisor - (M % row_divisor) + # Check and calculate column padding if col_divisor is provided + if col_divisor is not None and N % col_divisor != 0: + padding_needed_cols = col_divisor - (N % col_divisor) + + # Return original tensor if no padding is needed + if padding_needed_rows == 0 and padding_needed_cols == 0: + return tensor + + # pad the tensor + out = torch.nn.functional.pad( + tensor, + (0, padding_needed_cols, 0, padding_needed_rows), + mode="constant", + value=0.0, + ).contiguous() + + return out + + @staticmethod + def _rm_pad_tensor(tensor: torch.Tensor, original_size: tuple[int, ...]) -> torch.Tensor: + + assert tensor.dim() == 2, "only supports 2D tensors" + M, N = original_size + out = tensor[:M, :N].contiguous() + return out + + def _quantize(self, tensor: torch.Tensor) -> Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + ]: + """ + Python implementation of microblock FP4 quantization. + + Parameters + ---------- + tensor : torch.Tensor + Input tensor to quantize (should be 2D) + + Returns + ------- + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor] + (qx, sx, qx_t, sx_t, global_amax) where: + - qx: quantized data in row-major order (if rowwise_usage), None otherwise + - sx: scale tensor for qx (if rowwise_usage), None otherwise + - qx_t: quantized data in column-major order (if columnwise_usage), None otherwise + - sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise + - global_amax: global amax tensor + """ + if self.pow_2_scales: + assert self.quant_tile_shape == ( + 1, + 32, + ), "MXFP4 only supports 1x32 tile shape." + # TODO(etsykunov): Fix bug where global_amax_row and + # global_amax_col are not defined + # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32) + else: + assert self.quant_tile_shape in ( + (1, 16), + (16, 16), + ), "NVFP4 only supports 1x16 or 16x16 tile shape." + # Prepare inputs once so we can reuse for both amax and quantization + # Row-input will always be the original input. + row_input = tensor + col_input = ( + self._apply_rht(tensor.t().contiguous()) + if self.with_rht + else tensor.t().contiguous() + ) + # Compute amax for rowwise and columnwise paths separately + global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1) + global_amax_col = ( + torch.max(torch.abs(col_input)).to(torch.float32).view(1) + if self.columnwise_usage + else global_amax_row + ) + + transpose_scales = False + + M, N = tensor.shape + if self.rowwise_usage: + x_input = row_input + x_padded = self._pad_tensor( + x_input, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] + ) + + qx, sx = self._quantize_blockwise_reference( + x_padded, + global_amax_row, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + eps=self.eps, + ) + if transpose_scales: + sx = sx.T + + qx = self._rm_pad_tensor(qx, (M, N // 2)) + + else: + qx = None + sx = None + + if self.columnwise_usage: + x_t = col_input + x_t_padded = self._pad_tensor( + x_t, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] + ) + + qx_t, sx_t = self._quantize_blockwise_reference( + x_t_padded, + global_amax_col, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + eps=self.eps, + ) + + qx_t = self._rm_pad_tensor(qx_t, (N, M // 2)) + + if transpose_scales: + sx_t = sx_t.T + else: + qx_t = None + sx_t = None + + return qx, sx, qx_t, sx_t, global_amax_row, global_amax_col + + def quantize( + self, + tensor: torch.Tensor, + **kwargs, # pylint: disable=unused-argument + ) -> NVFP4TensorRef: + # sanity checks + assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype." + + # Make it work with 3D tensors + original_shape = tensor.shape + if tensor.ndim > 2: + tensor = tensor.view(-1, tensor.shape[-1]) + + qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(tensor) + + return NVFP4TensorRef( + data=qx, + scale=sx, + data_t=qx_t, + scale_t=sx_t, + global_amax_row=global_amax_row, + global_amax_col=global_amax_col, + dtype=tensor.dtype, + device=tensor.device, + quant_dtype=self.dtype, + quantizer=self, + original_shape=original_shape, + ) + + def update_quantized( + self, + src: torch.Tensor, + dst: ExperimentalQuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> ExperimentalQuantizedTensor: + """Update the quantized tensor with the given tensor in-place + + Parameters + ---------- + src: torch.Tensor + Source tensor to copy from + dst: ExperimentalQuantizedTensor + Destination ExperimentalQuantizedTensor to update + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + """ + # Handle noop flag + if noop_flag is not None and noop_flag.item() != 0: + return dst + + # Make sure input is in expected format + if not src.is_contiguous(): + src = src.contiguous() + + # Store the original shape and reshape for processing + original_shape = src.shape + if src.ndim > 2: + src = src.view(-1, src.shape[-1]) + + qx, sx, qx_t, sx_t, global_amax = self._quantize(src) + + # Update the destination with new data + dst.data = qx + dst.scale = sx + dst.data_t = qx_t + dst.scale_t = sx_t + dst.global_amax = global_amax + dst.dtype = src.dtype + dst.quant_dtype = self.dtype + dst.original_shape = original_shape + + return dst + + @property + def supports_allgather_fp8(self) -> bool: + """Whether the tensor data can be all-gathered with an FP8 all-gather. + + TODO(etsykunov): Confirm docstring is correct. Also, this API + seems too FP8-specific and should be reconsidered. + """ + return False + + def transpose_qresult( + self, qresult: quantization.ExperimentalQuantizedTensor + ) -> quantization.ExperimentalQuantizedTensor: + """Convert row-wise data to column-wise data (?) + + TODO(etsykunov): Confirm docstring is correct. + """ + raise NotImplementedError("Transpose qresult is not implemented for FP4.") + + @property + def supports_dequantize(self) -> bool: + """Whether quantized tensor can converted to high-precision tensor""" + return False + + @property + def is_data_t_transposed_in_memory(self) -> bool: + """Whether column-wise data is stored in transposed layout. + + TODO(etsykunov): Confirm docstring is correct. + """ + raise NotImplementedError("Not implemented yet") + + def dequantize( + self, tensor: torch.Tensor, scale: torch.Tensor, dtype: Optional[torch.dtype] = None + ) -> torch.Tensor: + """Dequantize the quantized tensor""" + raise NotImplementedError("Not implemented yet") + + def qgemm( + self, + qx: torch.Tensor, + qw: torch.Tensor, + m_params: quantization.MMParams, + out_dtype: torch.dtype, + sx: torch.Tensor, + sw: torch.Tensor, + bias: torch.Tensor | None = None, + out: torch.Tensor | None = None, + accumulate: bool = False, + gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP, + qresult_x: quantization.ExperimentalQuantizedTensor | None = None, + qresult_w: quantization.ExperimentalQuantizedTensor | None = None, + ) -> torch.Tensor: + assert bias is None, "Bias is implemented for FP4 GEMM." + + high_precision_x = cast_from_fp4x2(qx, out_dtype) + high_precision_w = cast_from_fp4x2(qw, out_dtype) + + if self.pow_2_scales: + + if sx.dtype == torch.uint8: + # if scaling factor is stored in uint8 container + sx = torch.tensor(2.0, device=sx.device, dtype=torch.float32) ** ( + ( + sx.to(torch.float32) + - torch.tensor(127, device=sx.device, dtype=torch.float32) + ) + ) + sw = torch.tensor(2.0, device=sw.device, dtype=torch.float32) ** ( + ( + sw.to(torch.float32) + - torch.tensor(127, device=sw.device, dtype=torch.float32) + ) + ) + else: + # if scaling factor is torch.float8_e8m0fnu + sx = sx.to(torch.float32) + sw = sw.to(torch.float32) + + alpha = torch.tensor(1.0, device=high_precision_x.device, dtype=torch.float32) + + else: + + assert qresult_x is not None + assert qresult_w is not None + + assert qresult_x.global_amax_row is not None + assert qresult_w.global_amax_col is not None + + sx = sx.to(torch.float32) + sw = sw.to(torch.float32) + + factor = 6.0 * 6.0 * 448.0 * 448.0 + + if gemm_type == quantization.GEMMType.WGRAD: + partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col + else: + partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row + alpha = torch.div(partial_alpha, factor).squeeze(-1) + + M, K = high_precision_x.shape + N, K_w = high_precision_w.shape + assert K == K_w, "K dimension mismatch between qx and qw" + + assert K % 32 == 0, "K dimension must be divisible by 32" + assert N % 8 == 0, "N dimension must be divisible by 8" + + block_length = 32 if self.pow_2_scales else 16 + + grid_k = K // block_length + + assert sx.shape == ( + M, + K // block_length, + ), f"sx shape mismatch: expected ({M}, {K//block_length}), got {sx.shape}" + assert sw.shape == ( + N, + K // block_length, + ), f"sw shape mismatch: expected ({N}, {K//block_length}), got {sw.shape}" + + y = torch.zeros(M, N, dtype=torch.float32, device=qx.device) + + # below implementation is to match the FP4 tensor core implementation + # Each output element (i, j) is fp32 accumulation of (K // block_length) inner products + # Each inner product is sx * sw * (1, block_length) x (block_length, 1) with precision in fp32 + # Then batch the computation in M, N dimension + for k in range(grid_k): + k_start = k * block_length + k_end = k_start + block_length + + qx_block = high_precision_x[:, k_start:k_end].clone().contiguous() + qw_block = high_precision_w[:, k_start:k_end].clone().contiguous() + + # Extract scaling factors for the current blocks + sx_block = sx[:, k] + sw_block = sw[:, k] + + y += torch.outer(sx_block, sw_block) * high_precision_gemm_ref( + qx_block, qw_block, torch.float32, is_b_transposed=True + ) + + if not self.pow_2_scales and K > 0: + # only apply global scale for NVFP4 and non-empty cases + y = alpha * y + + # accumulation happens at epilogue in float32 + if accumulate: + assert out is not None, "Output tensor must be provided for accumulation." + y += out.to(torch.float32) + else: + assert out is None, "Output tensor should be None when accumulate is False." + + y = y.to(out_dtype) + return y diff --git a/transformer_engine/pytorch/experimental/utils.py b/transformer_engine/pytorch/experimental/utils.py new file mode 100644 index 000000000..20dc6f11b --- /dev/null +++ b/transformer_engine/pytorch/experimental/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Utility functions for experimental middleware between Transformer Engine and Kitchen.""" + +import enum + +import torch + + +HIGH_PRECISION_FLOAT_DTYPES = ( + torch.float, + torch.float16, + torch.bfloat16, + torch.float32, +) + + +class Fp4Formats(enum.Enum): + """FP4 data format""" + + E2M1 = "e2m1" + + +def roundup_div(x: int, y: int) -> int: + """Round up division""" + assert x >= 0 + assert y > 0 + return (x + y - 1) // y diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 15cb88b00..ccd338433 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -24,6 +24,7 @@ MXFP8BlockScaling, Float8CurrentScaling, Float8BlockScaling, + NVFP4BlockScaling, ) from .constants import dist_group_type @@ -67,6 +68,13 @@ def check_mxfp8_support() -> Tuple[bool, str]: return False, "Device compute capability 10.0 or higher required for MXFP8 execution." +def check_nvfp4_support() -> Tuple[bool, str]: + """Return if nvfp4 support is available""" + if get_device_compute_capability() >= (10, 0): # blackwell and above + return True, "" + return False, "Device compute capability 10.0 or higher required for NVFP4 execution." + + def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" if IS_HIP_EXTENSION: @@ -128,6 +136,13 @@ def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType return tex.DType.kFloat8E5M2 +def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: + """Get fp4 data type according to recipe and tensor""" + if fp4_recipe.fp4_format == Format.E2M1: + return tex.DType.kFloat4E2M1 + raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") + + def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( @@ -166,6 +181,8 @@ class FP8GlobalStateManager: reason_for_no_mxfp8 = "" fp8_block_scaling_available = None reason_for_no_fp8_block_scaling = None + nvfp4_available = None + reason_for_no_nvfp4 = "" @classmethod def reset(cls) -> None: @@ -229,6 +246,13 @@ def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: ) return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling + @classmethod + def is_nvfp4_available(cls) -> Tuple[bool, str]: + """Return if NVFP4 support is available.""" + if cls.nvfp4_available is None: + cls.nvfp4_available, cls.reason_for_no_nvfp4 = check_nvfp4_support() + return cls.nvfp4_available, cls.reason_for_no_nvfp4 + @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -505,6 +529,9 @@ def fp8_autocast_enter( if isinstance(fp8_recipe, Float8BlockScaling): fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() assert fp8_block_available, reason_for_no_fp8_block + if isinstance(fp8_recipe, NVFP4BlockScaling): + nvfp4_available, reason_for_no_nvfp4 = cls.is_nvfp4_available() + assert nvfp4_available, reason_for_no_nvfp4 @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: @@ -861,6 +888,8 @@ def create( cls = Float8CurrentScalingRecipeState elif recipe.float8_block_scaling(): cls = Float8BlockScalingRecipeState + elif recipe.nvfp4(): + cls = NVFP4BlockScalingRecipeState else: raise ValueError(f"{recipe.__class__.__name__} is not supported") return cls( @@ -965,7 +994,9 @@ def make_quantizers(self) -> list: from .tensor.float8_tensor import Float8CurrentScalingQuantizer return [ - Float8CurrentScalingQuantizer(self.dtype, device=self.device) + Float8CurrentScalingQuantizer( + self.dtype, device=self.device, force_pow_2_scales=self.recipe.use_power_2_scales + ) for i in range(self.num_quantizers) ] @@ -1108,3 +1139,79 @@ def make_quantizers(self) -> list: ] ) ) + + +class NVFP4BlockScalingRecipeState(RecipeState): + """Configuration for NVFP4 quantization. + + NVFP4 quantization does not require state. + + """ + + recipe: NVFP4BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: NVFP4BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp4_te_dtype(recipe) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + from .tensor.nvfp4_tensor import NVFP4Quantizer + + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward. It assumes forward quantizers are + # ordered [input, weight, output, ...] and backward quantizers + # are ordered [grad_output, grad_input, ...]. This doesn't + # play nicely with fusible ops: Linear op doesn't own output + # or grad input quantizers, Quantize op only owns input and + # grad output quantizers. + + if self.mode == "forward": + + def _make_quantizer(idx: int) -> NVFP4Quantizer: + qparams = ( + self.recipe.fp4_quant_fwd_weight + if idx % 3 == 1 + else self.recipe.fp4_quant_fwd_inp + ) + return NVFP4Quantizer( + fp4_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_rht=qparams.random_hadamard_transform, + with_post_rht_amax=qparams.random_hadamard_transform, + with_2d_quantization=qparams.fp4_2d_quantization, + stochastic_rounding=qparams.stochastic_rounding, + ) + + return [_make_quantizer(idx) for idx in range(self.num_quantizers)] + + if self.mode == "backward": + return [ + NVFP4Quantizer( + fp4_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, + with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, + with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, + stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, + ) + for _ in range(self.num_quantizers) + ] + + raise RuntimeError(f"Unexpected recipe mode ({self.mode})") diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 1f38b493c..366f4ab3a 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -8,16 +8,19 @@ import os from typing import Any, List, Optional, Tuple, Union, Callable -from dataclasses import dataclass - +import dataclasses import queue +from typing import Any, Callable, List, Optional, Tuple, Union + import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION from .. import cpp_extensions as tex +from .. import experimental from ..constants import TE_DType -from ..utils import get_default_init_method from ..export import is_in_onnx_export_mode +from ..tensor.utils import is_experimental +from ..utils import get_default_init_method if IS_HIP_EXTENSION: from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton @@ -188,7 +191,33 @@ def noop_cat( return _NoopCatFunc.apply(dim, *tensors) -@dataclass +def get_module_quantizers( + module: torch.nn.Module, + fp8_output: bool, + fp8_grad: bool, + debug: bool, +): + """Return the 6-tuple of quantizers for a module in a centralized way. + + Routing policy: + - If experimental quantization is enabled via environment and module.fp8 is True, + return experimental quantizers. + - Otherwise, return the module's own quantizers (debug or regular). + """ + if getattr(module, "fp8", False) and is_experimental(): + # TODO(etsykunov): Quantizer instantiation should be better + # done in the module's constructor + qlinear_params = experimental.config.set_qlinear_params() + + if qlinear_params is not None: + return experimental.config.get_experimental_quantizers(module.fp8, qlinear_params) + + if not debug: + return module._get_quantizers(fp8_output, fp8_grad) + return module._get_debug_quantizers(fp8_output, fp8_grad) + + +@dataclasses.dataclass class _ParameterInitMeta: """ Stores essential metadata needed to support deferred parameter initialization. diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b49e38544..c11ca543a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -30,6 +30,7 @@ DelayedScalingRecipeState, Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, + NVFP4BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -42,6 +43,7 @@ from ..constants import dist_group_type from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer if IS_HIP_EXTENSION: @@ -89,7 +91,8 @@ def get_cublas_workspace_size_bytes() -> None: return 33_554_432 """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: - return 33_554_432 + # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales + return 32 * 1024 * 1024 + 1024 return 4_194_304 @@ -772,6 +775,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe_state, Float8BlockScalingRecipeState ): return + if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -981,12 +986,13 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: return dtype = inp.dtype - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) + if not self.allow_different_data_and_param_types: + for name, param in self.named_parameters(): + if param is not None: + assert dtype == param.dtype, ( + "Data types for parameters must match when outside of autocasted region. " + f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" + ) self.activation_dtype = dtype def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: @@ -1080,6 +1086,7 @@ def prepare_forward( inp: torch.Tensor, num_gemms: int = 1, allow_non_contiguous: bool = False, + allow_different_data_and_param_types: bool = False, ) -> Generator[torch.Tensor, None, None]: """Checks and prep for FWD. The context manager is needed because there isn't a way for a module to know @@ -1087,6 +1094,7 @@ def prepare_forward( to setup the forward aggregated amax reduction for every module just in case. The autocast exit will pick up the most recent one. """ + self.allow_different_data_and_param_types = allow_different_data_and_param_types self.forwarded_at_least_once = True # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): @@ -1235,15 +1243,13 @@ def grad_output_preprocess( ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - if isinstance(quantizer, Float8BlockQuantizer): + # TODO(ksivaman): Re-add fusion once kernel is available. + if isinstance(quantizer, (Float8BlockQuantizer, NVFP4Quantizer)): # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance( - grad_output, - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), - ): + if not isinstance(grad_output, QuantizedTensorBase): grad_output = quantizer(grad_output) return grad_output, grad_bias diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d1aeebc9f..abe353e8b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -20,6 +20,7 @@ from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.tensor.utils import is_experimental from .base import ( fill_userbuffers_buffer_for_all_gather, get_workspace, @@ -33,6 +34,7 @@ from ..fp8 import FP8GlobalStateManager from ..utils import ( assert_dim_for_fp8_exec, + assert_dim_for_all_gather, cast_if_needed, clear_tensor_data, divide, @@ -57,7 +59,7 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ._common import apply_normalization, noop_cat, WeightGradStore +from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers from ..tensor.quantized_tensor import ( QuantizedTensor, QuantizedTensorBase, @@ -147,6 +149,8 @@ def forward( if ub_name is not None: nvtx_label = f"{nvtx_label}.{ub_name}" + with_input_all_gather = parallel_mode == "column" and sequence_parallel + # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape @@ -156,6 +160,7 @@ def forward( inputmat = inp if fp8: assert_dim_for_fp8_exec(inputmat, weight) + assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer) # Cast for native AMP nvtx_range_push(f"{nvtx_label}.norm_input_cast") @@ -169,7 +174,6 @@ def forward( weight_requires_grad = weight.requires_grad backward_needs_input = is_grad_enabled and weight_requires_grad - with_input_all_gather = parallel_mode == "column" and sequence_parallel # Configure Userbuffers communication (comm+GEMM overlap) if debug: # turn off userbuffers in debug mode @@ -202,11 +206,13 @@ def forward( # Avoid quantized norm kernel if norm output will be returned # or if a gather of ln_out must be in high precision. + experimental = is_experimental(input_quantizer) with_quantized_norm = ( fp8 and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not experimental ) # ROCm does not currently support quantized norm for Float8CurrentScalingQuantizer @@ -256,7 +262,8 @@ def forward( quantizer = None if fp8 or debug: quantizer = input_quantizer - if not with_quantized_norm: + # experimental recipe doesn't need to support quantized AG + if not with_quantized_norm and not experimental: ln_out = quantizer(ln_out) quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather @@ -1474,6 +1481,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif other recipes (mxfp8, etc) def reset_layer_norm_parameters(self) -> None: @@ -1578,11 +1587,7 @@ def forward( # Get concatenated weight and bias tensors weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad) - ) + quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug) if debug: if self.no_debug_features_active(quantizers): debug = False @@ -1817,6 +1822,28 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_linear.""" + assert recipe.nvfp4(), "Incorrect recipe." + if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # set input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: """Get the weight tensors of the module.""" unfused_weights = [getattr(self, name) for name in self.weight_names] diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4492abe3e..eadd992c7 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -20,6 +20,7 @@ from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.tensor.utils import is_experimental from .base import ( fill_userbuffers_buffer_for_all_gather, get_workspace, @@ -43,6 +44,7 @@ init_method_constant, cast_if_needed, assert_dim_for_fp8_exec, + assert_dim_for_all_gather, clear_tensor_data, requires_grad, needs_quantized_gemm, @@ -67,6 +69,7 @@ Float8Tensor, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, WeightGradStore from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload @@ -121,7 +124,8 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): } # no activation fusion written yet # Per-tensor current scaling or fp8 blockwise scaling: [] - if recipe.float8_current_scaling() or recipe.float8_block_scaling(): + # TODO(ksivaman): Fuse nvfp4 act once kernel is available. + if recipe.float8_current_scaling() or recipe.float8_block_scaling() or recipe.nvfp4(): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), @@ -220,6 +224,7 @@ def forward( inputmat = inp.view((-1, in_features)) if fp8: assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) + assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer) activation_func = _act_func( activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None @@ -267,11 +272,13 @@ def forward( # high precision layernorm output and output of the linear are returned # for debug: : layernorm output = High precision to enable processing of this norm + experimental = is_experimental(fc1_input_quantizer) with_quantized_norm = ( fp8 and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not experimental ) # ROCm does not currently support quantized norm for Float8CurrentScalingQuantizer @@ -315,7 +322,8 @@ def forward( quantizer = None if fp8 or debug: quantizer = fc1_input_quantizer - if not with_quantized_norm: + # experimental recipe doesn't need to support quantized AG + if not with_quantized_norm and not experimental: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: @@ -569,6 +577,7 @@ def forward( if not fc2_weight.requires_grad: clear_tensor_data(act_out) act_out = None + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, @@ -698,6 +707,7 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed # by the `restore_from_saved` method to construct back the actual tensors. ctx.tensor_objects = None @@ -1042,7 +1052,10 @@ def fc2_wgrad_gemm( if ctx.fp8: # TODO float8 blockwise current scaling has no bgrad fusion for now - if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer): + # TODO(ksivaman): Re-add fusion once kernel is available. + if isinstance( + ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer) + ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) else: @@ -1747,6 +1760,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif for other recipes (mxfp8, etc.) def reset_layer_norm_parameters(self) -> None: @@ -1968,7 +1983,10 @@ def _get_quantizers(self, fp8_output): fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( rowwise=True, - columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)), + columnwise=isinstance( + fc2_input_quantizer, + (MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer), + ), ) fc1_input_quantizer.internal = True if fp8_output: @@ -2173,6 +2191,28 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT2 ].amax_reduction_group = self.tp_group + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_mlp.""" + assert recipe.nvfp4(), "Incorrect recipe." + if fwd: + if self.sequence_parallel and self.set_parallel_mode: + # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.set_parallel_mode: + # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT2 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT2 + ].amax_reduction_group = self.tp_group + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: """Get the weight tensors of the module.""" return [self.fc1_weight, self.fc2_weight] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 88ed6356b..ca25fd7f9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -27,7 +27,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ._common import noop_cat, WeightGradStore +from ._common import noop_cat, WeightGradStore, get_module_quantizers from ..fp8 import FP8GlobalStateManager from ..utils import ( cast_if_needed, @@ -37,6 +37,7 @@ requires_grad, needs_quantized_gemm, assert_dim_for_fp8_exec, + assert_dim_for_all_gather, nvtx_range_pop, nvtx_range_push, ) @@ -67,6 +68,7 @@ ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.utils import is_experimental from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...debug.pytorch.debug_state import TEDebugState @@ -156,6 +158,9 @@ def forward( ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.AG + # experimental recipe check + experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer) + # ------------------------------------------------------ # Prepare input tensor # Note: Cast to expected dtype and perform tensor-parallel communication @@ -166,6 +171,7 @@ def forward( own_quantized_input = False if fp8: assert_dim_for_fp8_exec(inputmat, weight) + assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer) if save_original_input: assert not isinstance( input_quantizer, Float8Quantizer @@ -177,7 +183,7 @@ def forward( if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - if not isinstance(inputmat, QuantizedTensorBase): + if not isinstance(inputmat, QuantizedTensorBase) and not experimental: own_quantized_input = True input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) if isinstance( @@ -450,6 +456,7 @@ def forward( ctx.main_grad_func = lambda: weight.main_grad ctx.debug = debug + ctx.experimental = experimental ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = bias is not None @@ -621,7 +628,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(inputmat, QuantizedTensorBase): # Input tensor is already quantized pass - elif ctx.debug: + elif ctx.debug or ctx.experimental: # Debug quantizer will be applied immediately before wgrad GEMM pass else: @@ -710,6 +717,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # dgrad GEMM # Note: dx = dy * w + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") gemm_out, *_, reduce_scatter_out = general_gemm( weight_fp8, @@ -1363,6 +1371,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif for other recipes (mxfp8, etc.) def reset_parameters(self, defer_init=False): @@ -1447,12 +1457,7 @@ def forward( weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad) - ) - + quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug) if debug: if self.no_debug_features_active(quantizers): debug = False @@ -1694,6 +1699,28 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + linear.""" + assert recipe.nvfp4(), "Incorrect recipe." + if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # customize input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration: diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 70c70c54d..ef125f0c6 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -322,6 +322,20 @@ def pre_first_fuser_forward(self) -> None: if self.weight.device.type == "meta": self.reset_parameters() + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + super().pre_fuser_forward(requires_grad=requires_grad) + if FP8GlobalStateManager.is_fp8_enabled(): + # Configure quantizer usages + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. + weight_requires_grad = requires_grad and self.weight.requires_grad + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + grad_output_quantizer = self.get_quantizer("backward", 0) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -352,6 +366,35 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: and not getattr(self, "_with_quantized_weight", False) ) + # Recipe-specific configuration + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + if recipe is not None: + if recipe.float8_current_scaling(): + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon + if getattr(self, "sequence_parallel", False): + tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) + if tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + elif tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + if recipe.nvfp4(): + if getattr(self, "sequence_parallel", False): + tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) + if tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + elif tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + @staticmethod def _functional_forward( input: torch.Tensor, # pylint: disable=redefined-builtin @@ -731,7 +774,7 @@ def _functional_backward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(columnwise=True) + input_quantizer.set_usage(rowwise=False, columnwise=True) if with_x_all_gather: x, x_async = gather_along_first_dim( x_local, @@ -912,34 +955,13 @@ def op_forward( input_requires_grad = ctx.requires_grad weight_requires_grad = ctx.requires_grad and self.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) output_quantizer = next_op_input_quantizer grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - if with_quantized_compute: - # Configure quantizers - # Note: We cache the quantized input for backward pass, - # but discard the quantized weights. - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - weight_quantizer.set_usage(rowwise=True, columnwise=False) - - recipe = FP8GlobalStateManager.get_fp8_recipe() - if recipe.float8_current_scaling(): - input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - if self.sequence_parallel and self.tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - if self.sequence_parallel and self.tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group # Get autocast dtype if needed if torch.is_autocast_enabled(): diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 02bcfee0a..ab271e17b 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -85,7 +85,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = next_op_input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 15cc081c1..4831ae407 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -79,7 +79,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 21190d4fc..72e17f64e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -58,7 +58,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = None diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index df8843649..7e5255d13 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -480,6 +480,10 @@ def __call__( # Attempt to fuse operations if neccesary self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs) + # Initialization before forward + for idx, op in enumerate(self._basic_ops): + op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward) + # Fuser forward pass if is_grad_enabled: forward_func = _OperationFuserAutogradFunction.apply diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 903bc49d5..103ebf241 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -65,6 +65,13 @@ def is_fused_op(self) -> bool: def pre_first_fuser_forward(self) -> None: """Preprocessing before first fuser forward pass""" + def pre_fuser_forward( + self, + *, + requires_grad: bool, # pylint: disable=unused-argument + ) -> None: + """Preprocessing before fuser forward pass""" + def get_input_quantizer(self) -> Optional[Quantizer]: """Get builder class for quantized input tensor""" @@ -710,6 +717,10 @@ def pre_first_fuser_forward(self) -> None: for op in self.basic_ops: op.pre_first_fuser_forward() + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + for op in self.basic_ops: + op.pre_fuser_forward(requires_grad=requires_grad) + def forward( self, input: torch.Tensor, # pylint: disable=redefined-builtin diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 7fa12cc08..43846512d 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -54,6 +54,7 @@ def get_all_tensor_types(): Float8BlockwiseQTensor, Float8BlockwiseQTensorBase, ) + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor, NVFP4TensorBase all_tensor_types = [ torch.Tensor, @@ -64,5 +65,7 @@ def get_all_tensor_types(): MXFP8TensorBase, Float8BlockwiseQTensor, Float8BlockwiseQTensorBase, + NVFP4Tensor, + NVFP4TensorBase, ] return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py new file mode 100644 index 000000000..df187d674 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py @@ -0,0 +1,348 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for NVFP4Tensor""" + +from __future__ import annotations +from collections.abc import Iterable +import functools +import math +from typing import Any, Dict, Optional, Tuple, Union +import warnings + +import torch + +# import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ..quantized_tensor import QuantizedTensorBase + +# from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ..quantized_tensor import Quantizer +from ...utils import _empty_tensor + + +@functools.lru_cache(maxsize=None) +def _fp4_e2m1_vals(device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """Values representable in FP4 E2M1 format""" + return torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], + device=device, + dtype=dtype, + ) + + +class _FromNVFP4Func(torch.autograd.Function): + """Cast from NVFP4 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: NVFP4TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + + # Dequantize row-wise data + if tensor._rowwise_data is not None: + ### TODO(tmoon): Debug dequantize kernel and remove unfused impl + # return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype]) + + # Tensor properties + shape = list(tensor._rowwise_data.size()) + shape[-1] *= 2 + device = tensor._rowwise_data.device + + # Convert FP4E2M1 values to FP32 + data = tensor._rowwise_data.view(torch.uint8).to(torch.int32) + data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape) + data = _fp4_e2m1_vals(device, dtype=torch.float32)[data] + data = data.to(torch.float32).contiguous() + + # Convert FP8E4M3 block scales to FP32 + block_scales = tensor._rowwise_scale_inv + block_scales = block_scales.reshape(-1, block_scales.size(-1)) + block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16] + block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32) + + # Convert amax to FP32 tensor scale + tensor_scale = tensor._amax_rowwise / (6.0 * 448.0) # Scale by FP4E2M1 and FP8E4M3 max + + # Apply scales + block_data = data.view(-1, 16) + block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1) + + return data.to(dtype) + + if tensor._columnwise_data is not None: + raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!") + raise ValueError("Attempted to dequantize NVFP4 tensor with no data") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class NVFP4TensorBase(QuantizedTensorBase): + """Mixin class that holds data attributes of NVFP4Tensor. + + NVFP4Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _rowwise_scale_inv: torch.Tensor + _columnwise_scale_inv: torch.Tensor + _fp4_dtype: TE_DType + _amax_rowwise: torch.Tensor + _amax_columnwise: torch.Tensor + + def __new__( + cls, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: TE_DType, + quantizer: Optional[Quantizer], + *args, + **kwargs, + ): + + instance = super().__new__(cls, *args, **kwargs) + + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._fp4_dtype = fp4_dtype + instance._quantizer = quantizer.copy() if quantizer is not None else None + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + instance._amax_rowwise = amax_rowwise + instance._amax_columnwise = amax_columnwise + + return instance + + def clear(self): + """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" + for t in ( + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + ): + if t is not None: + t.data = _empty_tensor() + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "amax_rowwise": self._amax_rowwise, + "amax_columnwise": self._amax_columnwise, + "fp4_dtype": self._fp4_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorBase]: + """Prepare the tensor base for saving for backward""" + tensors = [ + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + ] + self._rowwise_data = None + self._columnwise_data = None + self._rowwise_scale_inv = None + self._columnwise_scale_inv = None + self._amax_rowwise = None + self._amax_columnwise = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + self._rowwise_scale_inv = tensors[2] + self._columnwise_scale_inv = tensors[3] + self._amax_rowwise = tensors[4] + self._amax_columnwise = tensors[5] + return tensors[6:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromNVFP4Func.forward(None, self, dtype) + + def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]: + # pylint: disable=missing-function-docstring + + # Infer tensor shape + shape = None + if self._rowwise_data is not None: + byte_shape = list(self._rowwise_data.size()) + shape = byte_shape[:-1] + [byte_shape[-1] * 2] + elif self._columnwise_data is not None: + warnings.warn("Attempting to get shape of NVFP4 tensor with only column-wise data.") + byte_shape = list(self._columnwise_data.size()) + shape = byte_shape[1:-1] + [byte_shape[-1] * 2, byte_shape[0]] + if shape is None: + raise RuntimeError("Attempted to get shape of NVFP4 tensor with no data") + + # Return shape or dim + if dim is None: + return torch.Size(shape) + return shape[dim] + + def view(self, shape: torch.Size): + # pylint: disable=missing-function-docstring + + # Return input tensor if view not needed + cur_shape = self.size() + if shape is None or shape == cur_shape: + return self + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "NVFP4Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})" + ) + + # Reshape data + new_rowwise_data = None + new_columnwise_data = None + if self._rowwise_data is not None: + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = self._rowwise_data.view(byte_shape) + if self._columnwise_data is not None: + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = self._columnwise_data.view(byte_shape) + + # Construct tensor + return NVFP4TensorBase( + rowwise_data=new_rowwise_data, + rowwise_scale_inv=self._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=self._columnwise_scale_inv, + amax_rowwise=self._amax_rowwise, + amax_columnwise=self._amax_columnwise, + quantizer=self._quantizer, + fp4_dtype=self._fp4_dtype, + ) + + def __repr__(self): + data_rowwise = self.dequantize() + + return ( + "NVFP4TensorBase(" + f"rowwise_scaled_data={data_rowwise}," + f"rowwise_scale_inv={self._rowwise_scale_inv}," + f"amax_rowwise={self._amax_rowwise}," + f"amax_columnwise={self._amax_columnwise}," + ")" + ) + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """ + For the NVFP4 format, columnwise scaled output is only produced by x2 + scaling kernels, so this function only disables usages. + """ + + # Default usage is based on available data + if rowwise_usage is None: + rowwise_usage = self._rowwise_data is not None + if columnwise_usage is None: + columnwise_usage = self._columnwise_data is not None + + # Update row-scaled data + if rowwise_usage: + if self._rowwise_data is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing row-scaled NVFP4 data" + ) + if self._rowwise_scale_inv is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing row-scaled scale-inverses" + ) + if self._amax_rowwise is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing per tensor" + " row-scaled scale-inverse" + ) + else: + self._rowwise_data = None + self._rowwise_scale_inv = None + self._amax_rowwise = None + + # Update column-scaled data + if columnwise_usage: + if self._columnwise_data is None: + raise RuntimeError( + "Requested column-wise usage, but NVFP4Tensor is missing column-scaled FP8 data" + ) + if self._columnwise_scale_inv is None: + raise RuntimeError( + "Requested column-wise usage, " + "but NVFP4Tensor is missing column-scaled scale-inverses" + ) + if self._amax_columnwise is None: + raise RuntimeError( + "Requested column-wise usage, " + "but NVFP4Tensor is missing per tensor column-scaled scale-inverse" + ) + else: + self._columnwise_data = None + self._columnwise_scale_inv = None + self._amax_columnwise = None diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index a0a17d1a1..d9ea375b2 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -226,6 +226,8 @@ class Float8CurrentScalingQuantizer(Quantizer): amax: torch.Tensor """FP8 datatype""" dtype: TE_DType + """amax update options""" + use_existing_amax: bool """amax reduction options""" with_amax_reduction: bool amax_reduction_group: Optional[dist_group_type] @@ -240,6 +242,7 @@ def __init__( *, rowwise: bool = True, columnwise: bool = True, + use_existing_amax: bool = False, with_amax_reduction: bool = False, amax_reduction_group: Optional[dist_group_type] = None, force_pow_2_scales: bool = False, @@ -249,6 +252,7 @@ def __init__( self.scale = torch.empty(1, dtype=torch.float32, device=device) self.amax = torch.empty(1, dtype=torch.float32, device=device) self.dtype = fp8_dtype + self.use_existing_amax = use_existing_amax self.with_amax_reduction = with_amax_reduction self.amax_reduction_group = amax_reduction_group self.force_pow_2_scales = force_pow_2_scales diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 16b1568cb..2acce32f9 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -4,7 +4,7 @@ # # See LICENSE for license information. -"""Tensor class with FP8 data""" +"""Tensor class with MXFP8 data""" from __future__ import annotations from collections.abc import Iterable import math @@ -200,8 +200,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): Reciprocal of the scaling factor applied when casting to FP8, i.e. the scaling factor that must be applied when casting from FP8 to higher - precision. Can be inferred from fp8_meta if - provided. + precision. dtype: torch.dtype, default = torch.float32 Nominal tensor datatype. diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py new file mode 100644 index 000000000..b12e89956 --- /dev/null +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -0,0 +1,898 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with NVFP4 data""" +from __future__ import annotations +from collections.abc import Iterable +import math +from typing import Optional, Tuple, Union +import functools + +import torch +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe +from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type +from ..utils import ( + canonicalize_process_group, + devices_match, + round_up_to_nearest_multiple, +) + +from ._internal.nvfp4_tensor_base import NVFP4TensorBase, _FromNVFP4Func +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc + +aten = torch.ops.aten + + +def get_no_random_sign_vector() -> torch.Tensor: + """Non-random sign vector for Hadamard transform.""" + return torch.tensor([1], dtype=torch.float32) + + +def get_sign_from_vector(vector: torch.Tensor) -> int: + """Convert sign vector to bitmask. + + Used for random Hadamard transform. + + """ + mask = 0 + for i, v in enumerate(vector): + mask |= (v == -1) << i + return mask + + +def get_wgrad_sign_vector() -> torch.Tensor: + """Hard-coded random signs for Hadamard transform. + + https://xkcd.com/221/ + + """ + return torch.tensor( + [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1], + dtype=torch.float32, + ) + + +def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor: + """Construct a 16x16 Hadamard matrix.""" + assert hadamard_dimension == 16, "Only hadamard dimension 16 is supported." + hadamard_scale = 1 / math.sqrt(hadamard_dimension) + return ( + torch.tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1], + [1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1], + [1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1], + [1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1], + [1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1], + [1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1], + [1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1], + [1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1], + [1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1], + [1, 1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1], + [1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1], + [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1], + [1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1], + [1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1, 1, 1, -1, -1], + [1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1], + ], + dtype=torch.float32, + ) + * hadamard_scale + ) + + +@functools.lru_cache(maxsize=None) +def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor: + """Construct matrix used in random Hadamard transform.""" + hadamard_dimension = 16 + if with_random_sign_mask: + signs = get_wgrad_sign_vector() + else: + signs = get_no_random_sign_vector() + sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32) + rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension) + return rht_matrix.to(dtype=torch.bfloat16).cuda() + + +@functools.lru_cache(maxsize=None) +def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int: + """Sign mask for random Hadamard transform.""" + if with_random_sign_mask: + return get_sign_from_vector(get_wgrad_sign_vector()) + return 0 + + +class NVFP4Quantizer(Quantizer): + """Builder class for NVFP4 tensors with NV block scaling""" + + dtype: TE_DType + """Random Hadamard Transform""" + with_rht: bool + with_post_rht_amax: bool + """amax reduction options""" + with_amax_reduction: bool + amax_reduction_group: Optional[dist_group_type] + + """2D block scaling, only applicable for weights.""" + with_2d_quantization: bool + + """Stochastic rounding, only applicable for gradients.""" + stochastic_rounding: bool + + """RHT matrix random sign mask""" + rht_matrix_random_sign_mask_t: int + rht_matrix: torch.Tensor + + def __init__( + self, + fp4_dtype: TE_DType = tex.DType.kFloat4E2M1, + rowwise: bool = True, + columnwise: bool = True, + with_amax_reduction: bool = False, + amax_reduction_group: Optional[dist_group_type] = None, + with_rht: bool = False, + with_post_rht_amax: bool = False, + with_2d_quantization: bool = False, + stochastic_rounding: bool = False, + with_random_sign_mask: bool = True, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.dtype = fp4_dtype + self.with_rht = with_rht + self.with_post_rht_amax = with_post_rht_amax + self.with_amax_reduction = with_amax_reduction + self.amax_reduction_group = amax_reduction_group + self.with_2d_quantization = with_2d_quantization + self.stochastic_rounding = stochastic_rounding + self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(with_random_sign_mask) + self.rht_matrix = get_rht_matrix(with_random_sign_mask) + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + + assert isinstance(dst, NVFP4Tensor), f"Cannot store quantized NVFP4 in {type(dst)} type." + + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + return dst + + def is_quantizable(self, inp: torch.Tensor) -> bool: + """Returns whether or not given inp can be quantized""" + if inp.ndim < 2: + return False + if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0: + return False + if math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0: + return False + return True + + def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. + + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For NVFP4 1D blockwise quantization, blocksize is 16 + - If columnwise: (round_to_multiple(K, 128), round_to_multiple(roundup(M / 16), 4)) + - If rowwise: (round_to_multiple(M, 128), round_to_multiple(roundup(K / 16), 4)) + Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. + CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + M, K = 1, 1 + M = math.prod(shape[:-1]) + K = shape[-1] + + if columnwise: + outer = round_up_to_nearest_multiple(K, 128) + inner = round_up_to_nearest_multiple(math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4) + return (outer, inner) + # rowwise + outer = round_up_to_nearest_multiple(M, 128) + inner = round_up_to_nearest_multiple(math.ceil(K / NVFP4_BLOCK_SCALING_SIZE), 4) + return (outer, inner) + + @staticmethod + def get_columnwise_shape(shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of a tensor after columnwise quantization. + + For NVFP4 columnwise quantization, it's performing 16x1 quantization block scaling. + + Parameters + ---------- + shape : Iterable[int] + Original shape of the tensor + + Returns + ------- + Tuple[int, ...] + New shape with dimensions rearranged for columnwise layout. + For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1). + Returns empty tuple for empty input shape. + """ + if len(shape) == 0: + return tuple() + # and then after AG, a reorganize kernel will be called to restore the shape + colwise_shape = [shape[-1]] + for i in range(len(shape) - 1): + colwise_shape.append(shape[i]) + return tuple(colwise_shape) + + @staticmethod + def convert_shape_for_fp4(shape: Iterable[int]) -> Tuple[int, ...]: + """Convert shape for FP4 data by dividing the last dimension by 2""" + shape = list(shape) + shape[-1] = shape[-1] // 2 + return tuple(shape) + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> NVFP4Tensor: + + # Canonicalize tensor attributes + if device is None: + device = torch.device("cuda") + + assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, ( + f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" + f" {NVFP4_BLOCK_SCALING_SIZE}" + ) + + flat_first_dim = math.prod(shape[:-1]) + assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, ( + f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" + f" {NVFP4_BLOCK_SCALING_SIZE}" + ) + + # Allocate FP4 data + data = None + scale_inv = None + amax_rowwise = None + if self.rowwise_usage: + data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device) + # Allocate per tensor scale inverse. FP32 format. + amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device) + + # Allocate FP8 data transpose if needed + columnwise_data = None + columnwise_scale_inv = None + amax_columnwise = None + if self.columnwise_usage: + # enforce 2D shape to avoid [S, B, H] shape and B and be 1 + # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + shape_2d = tuple([flat_first_dim, shape[-1]]) + columnwise_data = torch.empty( + self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)), + dtype=torch.uint8, + device=device, + ) + columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) + columnwise_scale_inv = torch.empty( + columnwise_scale_shape, dtype=torch.uint8, device=device + ) + amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device) + + # Construct FP8 tensor + return NVFP4Tensor( + shape=shape, + dtype=dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + fp4_dtype=self.dtype, + quantizer=self, + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + pass # Calibration is no-op + + def _canonicalized_amax_reduction_group(self) -> dist_group_type: + """Get process group for amax reduction""" + return canonicalize_process_group(self.amax_reduction_group) + + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return NVFP4BlockScaling + + +class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor): + """Quantized tensor class with FP4 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP4. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + rowwise_data: torch.Tensor + Raw FP4 data in a uint8 tensor (rowwise layout). + rowwise_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP4, i.e. the scaling factor that must + be applied when casting from FP4 to higher + precision (rowwise). + columnwise_data: torch.Tensor, optional + Raw FP4 data in a uint8 tensor (columnwise layout). + columnwise_scale_inv: torch.Tensor, optional + Reciprocal of the scaling factor for columnwise FP4 data. + amax_rowwise: torch.Tensor, optional + Rowwise amax tracking tensor. + amax_columnwise: torch.Tensor, optional + Columnwise amax tracking tensor. + fp4_dtype: TE_DType + The FP4 data type used for quantization. + quantizer: Quantizer + The quantizer instance used for this tensor. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype, used in dequantize. + """ + + # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorBase with positional args, + # which significantly reduces the Pybind11 overhead when calling the constructor from C++. + def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + amax_rowwise: Optional[torch.Tensor], + amax_columnwise: Optional[torch.Tensor], + fp4_dtype: TE_DType, + quantizer: Quantizer, + **kwargs, + ): + instance = super().__new__( + cls, + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + amax_rowwise, + amax_columnwise, + fp4_dtype, + quantizer, + *args, + **kwargs, + ) + return instance + + def __repr__(self, *, tensor_contents=None): + return f"NVFP4Tensor, data={self.dequantize(dtype=self.dtype)})" + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from NVFP4Tensor + + By default the resulting tensor's dtype is the + NVFP4Tensor's nominal dtype. + """ + # Convert PyTorch dtype to TE dtype + if dtype is None: + dtype = self.dtype + + if torch.is_grad_enabled(): + return _FromNVFP4Func.apply(self, dtype) + return _FromNVFP4Func.forward(None, self, dtype) + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. + + """ + if self._quantizer is not None: + return self._quantizer + return NVFP4Quantizer() + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> NVFP4Tensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def detach(self) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + # TODO(ksivamani): Fix the detach bug + return NVFP4Tensor.make_like(self) + + def clone(self) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + assert self._rowwise_data is not None + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> NVFP4Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if self._rowwise_data is not None and self._rowwise_data.is_contiguous( + memory_format=memory_format + ): + return self + if self._columnwise_data is not None and self._columnwise_data.is_contiguous( + memory_format=memory_format + ): + return self + raise ValueError("NVFP4Tensor does not support different memory formats!") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + if len(args) != 2: + raise RuntimeError("Unexpected args for view op (expected 2 args, got {len(args)})") + tensor = args[0] + shape = args[1] + if shape == list(tensor.size()): + return tensor.detach() + return tensor.view(shape) + + # NVFP4 dequantize not supported. Add manual support for needed funcs. + if func in (aten.empty_like.default, aten.zero_.default): + tensor = args[0] + data_init_func = torch.zeros_like if func == aten.zero_.default else torch.empty_like + scale_inv_init_func = ( + torch.ones_like if func == aten.zero_.default else torch.empty_like + ) + + if tensor._rowwise_data is not None: + rowwise_data = data_init_func(tensor._rowwise_data) + rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv) + amax_rowwise = torch.zeros_like(tensor._amax_rowwise) + else: + rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None + + if tensor._columnwise_data is not None: + columnwise_data = data_init_func(tensor._columnwise_data) + columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv) + amax_columnwise = torch.zeros_like(tensor._amax_columnwise) + else: + columnwise_data, columnwise_scale_inv, amax_columnwise = ( + None, + None, + None, + ) + + return NVFP4Tensor( + shape=tensor.shape, + dtype=tensor.dtype, + fp4_dtype=tensor._fp4_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=tensor._quantizer, + requires_grad=tensor.requires_grad, + ) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: TE_DType, + dtype: torch.dtype, + quantizer: Quantizer, + ) -> NVFP4Tensor: + """Build NVFP4Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return NVFP4Tensor( + shape=shape, + dtype=dtype, + fp4_dtype=fp4_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=quantizer, + requires_grad=False, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling""" + return ( + NVFP4Tensor._make_in_reduce_ex, + ( + self.shape, + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + self._fp4_dtype, + self.dtype, + self._quantizer, + ), + ) + + def _get_data(self) -> NVFP4Tensor: + """Get tensor data property""" + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes FP8 data if setting from a NVFP4Tensor. Otherwise + casts to FP8. + + """ + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + if not devices_match(new_device, tensor.device): + tensor = tensor.to(device=new_device) + + # Just copy FP8 data if other tensor is NVFP4Tensor + if isinstance(tensor, NVFP4Tensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + NVFP4Tensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + self._amax_rowwise = tensor._amax_rowwise + self._amax_columnwise = tensor._amax_columnwise + return + + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.update_quantized(tensor, self) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Cast to FP8 when setting NVFP4Tensor.data + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the NVFP4Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: NVFP4Tensor, + shape: Optional[list[int]] = None, + ) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + cur_shape = tensor.shape + if ctx is not None: + ctx.shape = cur_shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "NVFP4Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Reshape data + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = tensor._rowwise_data.view(byte_shape) + if tensor._columnwise_data is not None: + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = tensor._columnwise_data.view(byte_shape) + + # Construct tensor + return NVFP4Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + amax_rowwise=tensor._amax_rowwise, + amax_columnwise=tensor._amax_columnwise, + quantizer=tensor._quantizer, + fp4_dtype=tensor._fp4_dtype, + requires_grad=tensor.requires_grad, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, NVFP4Tensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + if ctx.shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." + ) + byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2] + new_rowwise_data = grad._rowwise_data.view(byte_shape) + if grad._columnwise_data is not None: + columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = grad._columnwise_data.view(byte_shape) + dgrad = NVFP4Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + amax_rowwise=grad._amax_rowwise, + amax_columnwise=grad._amax_columnwise, + quantizer=grad._quantizer, + fp4_dtype=grad._fp4_dtype, + requires_grad=grad.requires_grad, + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the NVFP4Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: NVFP4Tensor, + shape: Optional[list[int]] = None, + ) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + cur_shape = tensor.shape + if ctx is not None: + ctx.shape = cur_shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "NVFP4Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Reshape data + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = tensor._rowwise_data.reshape(byte_shape) + if tensor._columnwise_data is not None: + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = tensor._columnwise_data.reshape(byte_shape) + + # Construct tensor + return NVFP4Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + amax_rowwise=tensor._amax_rowwise, + amax_columnwise=tensor._amax_columnwise, + quantizer=tensor._quantizer, + fp4_dtype=tensor._fp4_dtype, + requires_grad=tensor.requires_grad, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, NVFP4Tensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + if ctx.shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." + ) + byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2] + new_rowwise_data = grad._rowwise_data.reshape(byte_shape) + if grad._columnwise_data is not None: + columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = grad._columnwise_data.reshape(byte_shape) + dgrad = NVFP4Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + amax_rowwise=grad._amax_rowwise, + amax_columnwise=grad._amax_columnwise, + quantizer=grad._quantizer, + fp4_dtype=grad._fp4_dtype, + requires_grad=grad.requires_grad, + ) + return dgrad, None + return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 2f634f399..05c97058e 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -268,6 +268,10 @@ def supports_only_rowwise_all_gather(self) -> bool: """Returns True if the quantizer supports only rowwise all-gather""" return False + def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument + """Returns whether or not given tensor can be quantized""" + return True + class _QuantizeFunc(torch.autograd.Function): """Cast to FP8 from other dtype""" diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 23f56da5d..a4bdf5e07 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -4,11 +4,13 @@ """Helper functions for using fp8 tensors as weights""" +import os +from typing import Optional, Union import torch import transformer_engine_torch as tex from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv -from .quantized_tensor import QuantizedTensor +from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorBase from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer @@ -450,3 +452,20 @@ def _cast_master_weights_to_fp8_blockwise_scaling( tex.fp8_block_scaling_partial_cast( master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype ) + + +def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) -> bool: + """Check if an environment or object is using experimental Kitchen middleware. + + Returns False if x is a torch.Tensor. + """ + # Detect if the environment is experimental + if x is None: + return int(os.getenv("QAT_PARAMS", "0")) > 0 + + # Detect if the object is experimental + if isinstance(x, torch.Tensor): + return False + if not isinstance(x, (Quantizer, QuantizedTensorBase)): + raise AssertionError("Object must be a Quantizer or QuantizedTensorBase instance") + return hasattr(x, "experimental") and x.experimental diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 89e43f845..8a032b2f5 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -191,6 +191,17 @@ class TransformerLayer(torch.nn.Module): and `DotProductAttention` modules. name: str, default = `None` name of the module, currently used for debugging purposes. + softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. + For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). Parallelism parameters ---------------------- @@ -306,6 +317,7 @@ def __init__( qk_norm_type: Optional[str] = None, qk_norm_eps: float = 1e-6, qk_norm_before_rope: bool = False, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -362,6 +374,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.attn_input_format = attn_input_format + self.softmax_type = softmax_type self.name = name @@ -397,6 +410,7 @@ def __init__( "qkv_format": self.attn_input_format, "seq_length": seq_length, "micro_batch_size": micro_batch_size, + "softmax_type": self.softmax_type, } self.self_attention = MultiheadAttention( diff --git a/transformer_engine/pytorch/triton/pad.py b/transformer_engine/pytorch/triton/pad.py new file mode 100644 index 000000000..29b0daf31 --- /dev/null +++ b/transformer_engine/pytorch/triton/pad.py @@ -0,0 +1,94 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""NVFP4 padding kernels + +TODO(ksivamani): Documentation + +""" + +import torch + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1), + ], + key=["out_dim0", "out_dim1"], +) +@triton.jit +def zero_pad_kernel( + inp_ptr, + out_ptr, + in_dim0: tl.constexpr, + in_dim1: tl.constexpr, + out_dim0: tl.constexpr, + out_dim1: tl.constexpr, + in_s0, + in_s1, + out_s0, + out_s1, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """Pads a tensor assuming it's a columnwise scaling inverse.""" + + # tile over OUTPUT coordinates + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # output rows + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # output cols + om = offs_m[:, None] + on = offs_n[None, :] + + # edge masking for output + out_mask = (om < out_dim0) & (on < out_dim1) + + # valid input region is simply top-left (no offsets) + in_mask = (om < in_dim0) & (on < in_dim1) + + # load valid input, else zero (masked load touches memory only where True) + x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0) + + # store to output (only within bounds of the output tile) + tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask) + + +def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor: + """Pads a tensor assuming it's a columnwise scaling inverse.""" + + assert inp.ndim == 2 + dim0, dim1 = inp.shape + + pad_x = (128 - dim0 % 128) % 128 + pad_y = (4 - dim1 % 4) % 4 + out_x = dim0 + pad_x + out_y = dim1 + pad_y + out = torch.empty((out_x, out_y), device=inp.device, dtype=inp.dtype) + + in_s0, in_s1 = inp.stride() + out_s0, out_s1 = out.stride() + + BLOCK_M, BLOCK_N = 128, 128 + grid = (triton.cdiv(out_x, BLOCK_M), triton.cdiv(out_y, BLOCK_N)) + + zero_pad_kernel[grid]( + inp, + out, + dim0, + dim1, + out_x, + out_y, + in_s0, + in_s1, + out_s0, + out_s1, + ) + return out diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index d124fbeaf..95ee5e214 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -14,8 +14,8 @@ import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION -import transformer_engine.pytorch.cpp_extensions as ext from . import torch_version +from .tensor.quantized_tensor import Quantizer from ..debug.pytorch.debug_quantization import DebugQuantizedTensor @@ -463,6 +463,16 @@ def is_fp8_fnuz(): get_torch_float8_e4m3_type = lambda: torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn get_torch_float8_e5m2_type = lambda: torch.float8_e5m2fnuz if is_fp8_fnuz() else torch.float8_e5m2 +def assert_dim_for_all_gather( + tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer +) -> None: + """Assert that tensor dimensions are supported for all-gather""" + if with_all_gather: + assert quantizer.is_quantizable(tensor), ( + "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ + ) + + def is_bf16_compatible() -> None: if IS_HIP_EXTENSION: # only MI200 and newer machines support bf16 @@ -492,6 +502,8 @@ def get_cudnn_version() -> Tuple[int, int, int]: # ROCm fused attn does not use cudnn, return high numbers to avoid tests filtering out if IS_HIP_EXTENSION: return (99, 0, 0) + import transformer_engine.pytorch.cpp_extensions as ext + encoded_version = ext.get_cudnn_version() major_version_magnitude = 1000 if encoded_version < 90000 else 10000 major, encoded_version = divmod(encoded_version, major_version_magnitude)