diff --git a/tests/pytorch/triton_kernels/test_cast_mxfp4.py b/tests/pytorch/triton_kernels/test_cast_mxfp4.py new file mode 100644 index 000000000..4d285e886 --- /dev/null +++ b/tests/pytorch/triton_kernels/test_cast_mxfp4.py @@ -0,0 +1,178 @@ +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# License for AMD contributions = MIT. See LICENSE for more information + +import math +import pytest +import torch +import numpy as np +import os + +os.environ["USE_TRITON_FUSED_CAST_TRANSPOSE"] = "1" + +from transformer_engine.pytorch.tensor.mxfp4_tensor import MXFP4Quantizer, MXFP4_BLOCK_SCALING_SIZE +from transformer_engine.pytorch.triton_kernels.cast import te_quantize_triton +from test_common import te_compare_results, fill_uniform + + +def mxfp4_quantize_cpu(input_tensor, axis='row'): + """CPU reference for MXFP4 quantization matching Triton kernel behavior with shuffle.""" + original_shape = input_tensor.shape + if input_tensor.dim() > 2: + input_tensor = input_tensor.view(-1, input_tensor.shape[-1]) + + M, N = input_tensor.shape + + if axis == 'col': + input_tensor = input_tensor.t().contiguous() + M, N = N, M + + data = input_tensor.cpu().float().numpy() + + BLOCK_SIZE = 32 + assert N % BLOCK_SIZE == 0, f"N={N} must be divisible by {BLOCK_SIZE}" + + num_blocks = N // BLOCK_SIZE + + # E2M1 FP4 lookup table + fp4_values = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]) + + # Reshape to blocks: [M, num_blocks, BLOCK_SIZE] + data_blocks = data.reshape(M, num_blocks, BLOCK_SIZE) + amax_blocks = np.max(np.abs(data_blocks), axis=2) + + # Triton's amax rounding: (amax + 0x200000) & 0xFF800000 + amax_int = amax_blocks.astype(np.float32).view(np.uint32) + amax_int = ((amax_int + 0x200000) & 0xFF800000).astype(np.uint32) + amax_rounded = amax_int.view(np.float32) + + # E8M0 scale computation: floor(log2(amax)) - 2 + 127 + scale_unbiased = np.floor(np.log2(np.maximum(amax_rounded, 1e-45))) - 2 + scale_unbiased = np.clip(scale_unbiased, -127, 127) + scales = (scale_unbiased + 127).astype(np.uint8) + scales = np.where(amax_blocks == 0, 0, scales) + + # Scale values for quantization + scale_vals = np.where(scales[:, :, None] > 0, + 2.0 ** (-(scales[:, :, None] - 127)), + 1.0) + + scaled_blocks = data_blocks * scale_vals + + # Quantize to FP4 + signs = (scaled_blocks < 0).astype(np.uint8) + abs_vals = np.abs(scaled_blocks) + diffs = np.abs(abs_vals[:, :, :, None] - fp4_values[None, None, None, :]) + indices = np.argmin(diffs, axis=3).astype(np.uint8) + fp4_encoded = (signs << 3) | indices + + fp4_flat = fp4_encoded.reshape(M, N) + + # Pack: (odd_col << 4) | even_col + fp4_even = fp4_flat[:, 0::2] + fp4_odd = fp4_flat[:, 1::2] + fp4_packed = ((fp4_odd << 4) | fp4_even).astype(np.uint8) + + def cdiv(a, b): return (a + b - 1) // b + + scale_M_pad = cdiv(M, 256) * 256 + scale_N_pad = cdiv(num_blocks, 8) * 8 + scales_padded = np.full((scale_M_pad, scale_N_pad), 127, dtype=np.uint8) + + # Copy scales directly (no data shuffle support in Triton kernel) + scales_padded[:M, :num_blocks] = scales + + fp4_packed_torch = torch.from_numpy(fp4_packed).to(input_tensor.device) + scales_torch = torch.from_numpy(scales_padded).to(input_tensor.device) + + return fp4_packed_torch, scales_torch + + +@pytest.mark.parametrize("shape", [ + (128, 128), + (256, 256), + (256, 1024), + (2048, 6144), + (16384, 128), + (32768, 160), + (4096, 1632), + (8, 32, 1024), + (16, 8, 4, 512), +]) +@pytest.mark.parametrize("in_dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize(("rowwise", "columnwise"), 
[ + (True, True), + (False, True), + (True, False) +]) +@pytest.mark.parametrize("shuffle_B_matrix", [False, True]) +def test_quantize_mxfp4(shape, in_dtype, rowwise, columnwise, shuffle_B_matrix): + """Test MXFP4 quantization for rowwise/columnwise modes with/without FP4 shuffle. + + Note: FP4 data shuffle (shuffle_B_matrix_for_aiter) is not yet supported in Triton kernel. + """ + if shuffle_B_matrix: + pytest.skip("FP4 data shuffle not yet supported in Triton kernel") + + input_tensor = fill_uniform(shape, dtype=in_dtype) + + quantizer = MXFP4Quantizer( + rowwise=rowwise, + columnwise=columnwise, + shuffle_B_matrix_for_aiter=shuffle_B_matrix + ) + out = quantizer.make_empty(input_tensor.shape, dtype=in_dtype) + quantized_out = te_quantize_triton(input_tensor, quantizer=quantizer, output=out) + + # Tolerance: allow 1 nibble diff for rare edge cases near FP4 boundaries + data_atol = 20.0 if in_dtype != torch.float32 else 16.0 + scale_atol = 2.0 if in_dtype != torch.float32 else 1.0 + + if rowwise: + ref_data, ref_scale = mxfp4_quantize_cpu(input_tensor, axis='row') + M = math.prod(input_tensor.shape[:-1]) + K = input_tensor.shape[-1] + num_blocks = K // MXFP4_BLOCK_SCALING_SIZE + + te_compare_results( + quantized_out._rowwise_data.view(torch.uint8), + ref_data, + atol=data_atol, + rtol=0.0, + msg="rowwise FP4 data mismatch", + use_torch_semantics=True + ) + + # Compare only valid (non-padded) region - no shuffle extraction needed + te_compare_results( + quantized_out._rowwise_scale.view(torch.uint8)[:M, :num_blocks], + ref_scale[:M, :num_blocks], + atol=scale_atol, + rtol=0.0, + msg="rowwise E8M0 scales mismatch", + use_torch_semantics=True + ) + + if columnwise: + ref_data, ref_scale = mxfp4_quantize_cpu(input_tensor, axis='col') + M = math.prod(input_tensor.shape[:-1]) + K = input_tensor.shape[-1] + num_blocks = M // MXFP4_BLOCK_SCALING_SIZE + + te_compare_results( + quantized_out._columnwise_data.view(torch.uint8), + ref_data, + atol=data_atol, + rtol=0.0, + msg="columnwise FP4 data mismatch", + use_torch_semantics=True + ) + + # Compare only valid (non-padded) region - no shuffle extraction needed + te_compare_results( + quantized_out._columnwise_scale.view(torch.uint8)[:K, :num_blocks], + ref_scale[:K, :num_blocks], + atol=scale_atol, + rtol=0.0, + msg="columnwise E8M0 scales mismatch", + use_torch_semantics=True + ) diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index b243a8a0b..c0622ca48 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -108,7 +108,8 @@ .value("kFloat16", transformer_engine::DType::kFloat16) \ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ + .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp4_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp4_tensor_base.py new file mode 100644 index 000000000..8e3bf857b --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/mxfp4_tensor_base.py @@ -0,0 +1,216 @@ +# Copyright (c) 2024-2025, Advanced Micro Devices, 
Inc. All rights reserved. +# See LICENSE for license information. + +"""Mixin class holding data specific for MXFP4Tensor""" + +from __future__ import annotations +from typing import Optional, Dict, Any, Tuple +import torch + +from transformer_engine_torch import DType as TE_DType + +from ..quantized_tensor import QuantizedTensorBase +from ..quantized_tensor import Quantizer +from ...utils import _empty_tensor + + +class _FromMXFP4Func(torch.autograd.Function): + """Cast from MXFP4 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: MXFP4TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + # For MXFP4, return cached high-precision data if available + # Full dequantization from FP4 will be implemented later with AITER support + if hasattr(tensor, '_data') and tensor._data is not None: + # Return cached high-precision data (used during model initialization/teardown) + return tensor._data.to(dtype) if tensor._data.dtype != dtype else tensor._data + + # If no cached data, we would need to dequantize from rowwise FP4 data + # This path should not be hit in forward-only MXFP4 training + + # TODO: Implement MXFP4 dequantization from packed FP4 using AITER kernels + raise NotImplementedError( + "MXFP4 dequantization from packed FP4 not yet implemented. " + "This should only be called during model teardown with cached high-precision data." + ) + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class MXFP4TensorBase(QuantizedTensorBase): + """Mixin class that holds data attributes of MXFP4Tensor. + + MXFP4Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + + FP4 data format: + - Data: [M, K/2] uint8 tensor (2 FP4 values packed per byte) + - Scale: [M, K/32] uint8 tensor (E8M0 format, one scale per 32-element block) + + """ + + _rowwise_data: Optional[torch.Tensor] # [M, K/2] uint8 + _columnwise_data: Optional[torch.Tensor] # [K, M/2] uint8 (transposed) + _quantizer: Optional[Quantizer] + _fp4_dtype: TE_DType + _rowwise_scale: torch.Tensor # [M, K/32] uint8 E8M0 + _columnwise_scale: torch.Tensor # [K, M/32] uint8 E8M0 + _original_shape: Optional[Tuple[int, ...]] # Original shape before reshape (for 3D inputs) + + def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale: torch.Tensor, + fp4_dtype: TE_DType, + quantizer: Optional[Quantizer] = None, + original_shape: Optional[Tuple[int, ...]] = None, + **kwargs, + ): + instance = super().__new__(cls, *args, **kwargs) + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._quantizer = quantizer + instance._fp4_dtype = fp4_dtype + instance._rowwise_scale = rowwise_scale + instance._columnwise_scale = columnwise_scale + instance._original_shape = original_shape + + return instance + + def clear(self): + """Deallocate this tensor's memory. 
Typically not needed and must be used carefully.""" + for t in ( + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale, + self._columnwise_scale, + ): + if t is not None: + t.data = _empty_tensor() + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale": self._rowwise_scale, + "columnwise_data": self._columnwise_data, + "columnwise_scale": self._columnwise_scale, + "fp4_dtype": self._fp4_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP4TensorBase]: + """Prepare the tensor base for saving for backward""" + tensors = [ + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale, + self._columnwise_scale, + ] + self._rowwise_data = None + self._columnwise_data = None + self._rowwise_scale = None + self._columnwise_scale = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + self._rowwise_scale = tensors[2] + self._columnwise_scale = tensors[3] + return tensors[4:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromMXFP4Func.forward(None, self, dtype) + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + if self._rowwise_data is not None: + # Note: Rowwise data is [M, K/2] but we report logical size [M, K] + shape = list(self._rowwise_data.size(*args, **kwargs)) + if len(shape) > 0: + shape[-1] = shape[-1] * 2 # Unpacked size + return torch.Size(shape) if not args and not kwargs else shape + # Similar logic for columnwise data + shape = list(self._columnwise_data.size(*args, **kwargs)) + if len(shape) > 0: + shape[-1] = shape[-1] * 2 # Unpacked size + return torch.Size(shape) if not args and not kwargs else shape + + def __repr__(self): + return ( + "MXFP4TensorBase(" + f"fp4_dtype={self._fp4_dtype}, " + f"rowwise_data_shape={self._rowwise_data.shape if self._rowwise_data is not None else None}, " + f"rowwise_scale_shape={self._rowwise_scale.shape if self._rowwise_scale is not None else None}" + ")" + ) + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """ + Update the usage of the MXFP4TensorBase. 
+ + """ + + # Default usage is based on available data + if rowwise_usage is None: + rowwise_usage = self._rowwise_data is not None + if columnwise_usage is None: + columnwise_usage = self._columnwise_data is not None + + # Update row-scaled data + if rowwise_usage: + if self._rowwise_data is None: + raise RuntimeError( + "Requested row-wise usage, but MXFP4Tensor is missing row-scaled FP4 data" + ) + if self._rowwise_scale is None: + raise RuntimeError( + "Requested row-wise usage, but MXFP4Tensor is missing row-scaled scales" + ) + else: + self._rowwise_data = None + self._rowwise_scale = None + + # Update column-scaled data + if columnwise_usage: + if self._columnwise_data is None: + raise RuntimeError( + "Requested column-wise usage, but MXFP4Tensor is missing column-scaled FP4 data" + ) + if self._columnwise_scale is None: + raise RuntimeError( + "Requested column-wise usage, but MXFP4Tensor is missing column-scaled scales" + ) + else: + self._columnwise_data = None + self._columnwise_scale = None diff --git a/transformer_engine/pytorch/tensor/mxfp4_tensor.py b/transformer_engine/pytorch/tensor/mxfp4_tensor.py new file mode 100644 index 000000000..084f8275b --- /dev/null +++ b/transformer_engine/pytorch/tensor/mxfp4_tensor.py @@ -0,0 +1,292 @@ +# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# See LICENSE for license information. + +"""Tensor class with MXFP4 data""" + +from __future__ import annotations +from collections.abc import Iterable +import math +from typing import Optional + +import torch +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ..constants import MXFP8_BLOCK_SCALING_SIZE # MXFP4 uses same block size +from ..utils import devices_match + +from ._internal.mxfp4_tensor_base import MXFP4TensorBase, _FromMXFP4Func +from .quantized_tensor import QuantizedTensor, Quantizer + +MXFP4_BLOCK_SCALING_SIZE = MXFP8_BLOCK_SCALING_SIZE + +aten = torch.ops.aten + + +class MXFP4Quantizer(Quantizer): + """Builder class for FP4 tensors with MX block scaling + + High-precision tensors (e.g. in FP32 or BF16) are quantized to FP4 by + dividing them into groups of 32 elements, each scaled and cast + separately using AITER's per_1x32_f4_quant_hip kernel. + + The quantization produces: + - FP4 data: [M, K/2] uint8 (2 FP4 values packed per byte) + - E8M0 scales: [M, K/32] uint8 (one scale per 32-element block) + + """ + + dtype: TE_DType + + def __init__( + self, + fp4_dtype: TE_DType = tex.DType.kFloat4E2M1, + *, + rowwise: bool = True, + columnwise: bool = True, + shuffle_B_matrix_for_aiter: bool = False, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.dtype = fp4_dtype + self.shuffle_B_matrix_for_aiter = shuffle_B_matrix_for_aiter + assert self.dtype == tex.DType.kFloat4E2M1, "Only E2M1 format supported for MXFP4" + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + + assert isinstance(dst, MXFP4Tensor), f"Cannot store quantized MXFP4 in {type(dst)} type." 
+ + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + original_shape = src.shape + if src.dim() > 2: + src = src.view(-1, src.shape[-1]) + + if src.dim() != 2: + raise ValueError( + f"MXFP4 quantization requires 2D tensors for AITER gemm_a4w4, " + f"but got tensor with shape {original_shape} (dim={len(original_shape)}). " + f"Biases and other 1D tensors should not be quantized with MXFP4." + ) + + + + with torch._C._DisableTorchDispatch(): + rowwise_fp4_uint8 = dst._rowwise_data.view(torch.uint8) if dst._rowwise_data is not None else None + rowwise_scale_uint8 = dst._rowwise_scale.view(torch.uint8) if dst._rowwise_scale is not None else None + colwise_fp4_uint8 = dst._columnwise_data.view(torch.uint8) if dst._columnwise_data is not None else None + colwise_scale_uint8 = dst._columnwise_scale.view(torch.uint8) if dst._columnwise_scale is not None else None + + # Triton kernel path - API aligned with HIP + from ..triton_kernels.cast_transpose import te_cast_transpose_mxfp4_triton + + te_cast_transpose_mxfp4_triton( + src, + rowwise_fp4_out=rowwise_fp4_uint8, + rowwise_scale_out=rowwise_scale_uint8, + colwise_fp4_out=colwise_fp4_uint8, + colwise_scale_out=colwise_scale_uint8, + shuffle_rowwise_scale=False, # Not yet supported + shuffle_colwise_scale=False, # Not yet supported + shuffle_rowwise_fp4=self.shuffle_B_matrix_for_aiter, + shuffle_colwise_fp4=self.shuffle_B_matrix_for_aiter, + use_hadamard=False, # Not yet supported + ) + + + # Update FP4 dtype + dst._fp4_dtype = self.dtype + + return dst + + def is_quantizable(self, inp: torch.Tensor) -> bool: + """Returns whether or not given inp can be quantized""" + if inp.ndim < 2: + return False + if inp.shape[-1] % MXFP4_BLOCK_SCALING_SIZE != 0: + return False + if math.prod(inp.shape[:-1]) % MXFP4_BLOCK_SCALING_SIZE != 0: + return False + return True + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> MXFP4Tensor: + + # Canonicalize tensor attributes + if device is None: + device = torch.device("cuda") + + shape = tuple(shape) + assert ( + shape[-1] % MXFP4_BLOCK_SCALING_SIZE == 0 + and math.prod(shape[:-1]) % MXFP4_BLOCK_SCALING_SIZE == 0 + ), ( + f"Incorrect shape {shape} for MXFP4. 
Tensor dims must be divisible by" + f" {MXFP4_BLOCK_SCALING_SIZE}" + ) + + M = math.prod(shape[:-1]) + K = shape[-1] + + def cdiv(a, b): + return (a + b - 1) // b + + # Allocate FP4 data: [M, K/2] + rowwise_data = torch.empty(M, K // 2, dtype=torch.float4_e2m1fn_x2, device=device) + + # Allocate PADDED scale tensors for shuffle compatibility + rowwise_scale_N = K // MXFP4_BLOCK_SCALING_SIZE + rowwise_scale_M_pad = cdiv(M, 256) * 256 + rowwise_scale_N_pad = cdiv(rowwise_scale_N, 8) * 8 + rowwise_scale = torch.empty( + rowwise_scale_M_pad, rowwise_scale_N_pad, + dtype=torch.float8_e8m0fnu, device=device + ) + + # Allocate FP4 data transpose if needed + columnwise_data = None + columnwise_scale = None + if self.columnwise_usage: + columnwise_data = torch.empty(K, M // 2, dtype=torch.float4_e2m1fn_x2, device=device) + colwise_scale_N = M // MXFP4_BLOCK_SCALING_SIZE + colwise_scale_M_pad = cdiv(K, 256) * 256 + colwise_scale_N_pad = cdiv(colwise_scale_N, 8) * 8 + columnwise_scale = torch.empty( + colwise_scale_M_pad, colwise_scale_N_pad, + dtype=torch.float8_e8m0fnu, device=device + ) + + # Construct FP4 tensor + return MXFP4Tensor( + shape=shape, + dtype=dtype, + fp4_dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale=rowwise_scale, + columnwise_data=columnwise_data, + columnwise_scale=columnwise_scale, + quantizer=self, + original_shape=None, # Will be set during update_quantized if needed + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + # No calibration needed for MXFP4 (uses per-block current scaling) + pass + + def _get_compatible_recipe(self): + """Returns recipe class that is compatible with this quantizer. + + MXFP4 doesn't have a dedicated recipe yet, return None for now. + """ + return None + + +class MXFP4Tensor(MXFP4TensorBase, QuantizedTensor): + """Experimental tensor class with FP4 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP4. The FP4 data is packed with + 2 FP4 values per byte. + + For MXFP4 forward-only training: + - Forward pass: Uses FP4 quantized data with AITER gemm_a4w4 + - Backward pass: Uses high-precision (BF16) gradients + + Parameters + ---------- + data: torch.Tensor + Raw FP4 data in a uint8 tensor [M, K/2] + fp4_dtype: transformer_engine_torch.DType, default = kFloat4E2M1 + FP4 format (E2M1: 2 bits exponent, 1 bit mantissa) + fp4_scale: torch.Tensor + E8M0 scaling factors [M, K/32], one per 32-element block + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype + + """ + + def __repr__(self, *, tensor_contents=None): + return ( + f"MXFP4Tensor(fp4_dtype={self._fp4_dtype}, " + f"shape={self.shape}, " + f"rowwise_data_shape={self._rowwise_data.shape if self._rowwise_data is not None else None})" + ) + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from MXFP4Tensor + + By default the resulting tensor's dtype is the MXFP4Tensor's nominal dtype. + + Note: For MXFP4 forward-only training, this is typically not needed as + backward pass uses high-precision activations. + """ + # Convert PyTorch dtype to TE dtype + if dtype is None: + dtype = self.dtype + + if torch.is_grad_enabled(): + return _FromMXFP4Func.apply(self, dtype) + return _FromMXFP4Func.forward(None, self, dtype) + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. 
+ + """ + if self._quantizer is not None: + return self._quantizer + return MXFP4Quantizer( + fp4_dtype=self._fp4_dtype, + ) + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> MXFP4Tensor: + """Quantize a tensor and store result in this tensor + + This updates the FP4 data and scales in-place. + + """ + quantizer = self._get_quantizer() + return quantizer.update_quantized(tensor, self, noop_flag=noop_flag) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> MXFP4Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if self._rowwise_data is not None and self._rowwise_data.is_contiguous( + memory_format=memory_format + ): + return self + if self._columnwise_data is not None and self._columnwise_data.is_contiguous( + memory_format=memory_format + ): + return self + raise ValueError("MXFP4Tensor does not support different memory formats!") + diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 2f634f399..a10cadf78 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -86,6 +86,61 @@ def update_quantizer(self, quantizer: Quantizer): self._quantizer = quantizer +class QuantizedTensorBase: + r"""Base class for all *TensorBase classes. + + This class (and its subclasses) are optimization for when + the full QuantizedTensor is not needed (when it is fully + contained inside torch.autograd function and not visible to + PyTorch's autograd). + + When creating a new tensor type X one should create both + XTensorBase class inheriting from QuantizedTensorBase and + XTensor inheriting from XTensorBase and QuantizedTensor. + XTensorBase should contain all data members needed to + implement the functionality of the tensor, while + XTensor should only implement the functionality needed + to behave like regular torch.Tensor (liek __torch_dispatch__).""" + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + r""" + Generate or remove quantized data based on provided usage. + + Parameters + ---------- + rowwise_usage : Optional[bool[, default = `None` + Whether to create or keep the data needed for using the tensor + in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it as `None` + preserves the original value in the tensor. + columnwise_usage : Optional[bool], default = `None` + Whether to create or keep the data needed for using the tensor + in columnwise fashion (e.g. as A argument in TN GEMM). Leaving it as + `None` preserves the original value in the tensor. 
+ + """ + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement update_usage function" + ) + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]: + """Prepare the tensor base for saving for backward""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement prepare_for_saving function" + ) + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement restore_from_saved function" + ) + + def prepare_for_saving( *tensors: Union[torch.Tensor, QuantizedTensorBase], ) -> Tuple[ @@ -537,4 +592,4 @@ def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor: The new tensor has the same underlying data. """ - return self.__class__.make_like(self, dtype=dtype) + return self.__class__.make_like(self, dtype=dtype) \ No newline at end of file diff --git a/transformer_engine/pytorch/triton_kernels/cast.py b/transformer_engine/pytorch/triton_kernels/cast.py index b6a7270a3..b9b975118 100644 --- a/transformer_engine/pytorch/triton_kernels/cast.py +++ b/transformer_engine/pytorch/triton_kernels/cast.py @@ -15,6 +15,7 @@ import transformer_engine_torch as tex from ..tensor.quantized_tensor import QuantizedTensor, Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor._internal.mxfp4_tensor_base import MXFP4TensorBase @functools.lru_cache(maxsize=None) def _empty_tensor() -> torch.Tensor: @@ -119,6 +120,9 @@ def te_quantize_triton( out = tex.quantize(input_tensor, quantizer, out, noop_flag) elif isinstance(out, MXFP8TensorBase): te_cast_transpose_mxfp8_triton(input_tensor, out) + elif isinstance(out, MXFP4TensorBase): + # MXFP4 uses AITER quantization directly - bypass C++ tex.quantize + out = quantizer.update_quantized(input_tensor, out, noop_flag=noop_flag) else: raise NotImplementedError(f"Not implemented for tensor type: '{type(out).__name__}'") @@ -131,4 +135,3 @@ def te_dequantize_triton(input, dtype: tex.DType): return tex.dequantize(input, dtype) else: raise NotImplementedError(f"Not implemented for tensor type: '{type(input).__name__}'") - diff --git a/transformer_engine/pytorch/triton_kernels/cast_transpose.py b/transformer_engine/pytorch/triton_kernels/cast_transpose.py index 178a19093..16d693f17 100644 --- a/transformer_engine/pytorch/triton_kernels/cast_transpose.py +++ b/transformer_engine/pytorch/triton_kernels/cast_transpose.py @@ -2,9 +2,28 @@ # License for AMD contributions = MIT. See LICENSE for more information import torch +from typing import Optional + +try: + from ..constants import MXFP8_BLOCK_SCALING_SIZE + from .common import ( + te_dtype_to_triton_dtype, + te_dtype_to_torch_dtype, + get_fp8_max, + ) +except Exception: # pragma: no cover - fallback for standalone benchmarking + MXFP8_BLOCK_SCALING_SIZE = 32 + + def _missing(*args, **kwargs): + raise ImportError( + "transformer_engine dependencies not available. " + "Ensure transformer_engine_torch is installed." 
+ ) + + te_dtype_to_triton_dtype = _missing # type: ignore + te_dtype_to_torch_dtype = _missing # type: ignore + get_fp8_max = _missing # type: ignore -from ..constants import MXFP8_BLOCK_SCALING_SIZE -import transformer_engine_torch as tex import triton import triton.language as tl from .common import ( @@ -398,6 +417,318 @@ def _cast_transpose_triton_mxfp8( colwise_y_ptr_current_chunk = colwise_y_ptr + offsets_Y[:, None] * stride_rowwise_row + offsets_X[None, :] * stride_rowwise_col tl.store(colwise_y_ptr_current_chunk, y_chunk_colwise_scaled.to(colwise_y_ptr.type.element_ty), mask=mask) +########################################## +#### cast_transpose_mxfp4 +########################################## + +@triton.jit +def _cast_transpose_triton_mxfp4( + x_ptr, + rowwise_fp4_ptr, + rowwise_scale_ptr, + colwise_fp4_ptr, + colwise_scale_ptr, + stride_x_m, + stride_x_n, + stride_rowwise_fp4_m, + stride_rowwise_fp4_n, + stride_rowwise_scale_m, + stride_rowwise_scale_n, + stride_colwise_fp4_m, + stride_colwise_fp4_n, + stride_colwise_scale_m, + stride_colwise_scale_n, + M: tl.constexpr, + N: tl.constexpr, + rowwise_scale_N: tl.constexpr, + rowwise_scale_M_pad: tl.constexpr, + rowwise_scale_N_pad: tl.constexpr, + colwise_scale_M: tl.constexpr, + colwise_scale_N: tl.constexpr, + colwise_scale_M_pad: tl.constexpr, + colwise_scale_N_pad: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + MXFP4_BLOCK_SIZE: tl.constexpr, + USE_ROWWISE: tl.constexpr, + USE_COLWISE: tl.constexpr, + SHUFFLE_ROWWISE: tl.constexpr, + SHUFFLE_COLWISE: tl.constexpr, +): + """ + MXFP4 cast + transpose (rowwise + columnwise) following the MXFP8 fused pattern. + + Example to keep in mind: + Input (M, N) = (4096, 6144) bf16 + Rowwise output = (M, N/2) uint8 (two FP4 packed per byte) + Colwise output = (N, M/2) uint8 + + Grid layout: + BLOCK_M x BLOCK_N tile (default 128 x 128). + Inside the tile we iterate over 32 x 32 MXFP4 blocks. + + Strides: + - stride_x_m / stride_x_n point into the source matrix. + - stride_rowwise_fp4_* index rowwise packed bytes. + - stride_colwise_fp4_* index columnwise packed bytes. + """ + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + stride_x_m = tl.cast(stride_x_m, tl.int64) + stride_x_n = tl.cast(stride_x_n, tl.int64) + stride_rowwise_fp4_m = tl.cast(stride_rowwise_fp4_m, tl.int64) + stride_rowwise_fp4_n = tl.cast(stride_rowwise_fp4_n, tl.int64) + stride_colwise_fp4_m = tl.cast(stride_colwise_fp4_m, tl.int64) + stride_colwise_fp4_n = tl.cast(stride_colwise_fp4_n, tl.int64) + + num_chunks_m = BLOCK_M // MXFP4_BLOCK_SIZE + num_chunks_n = BLOCK_N // MXFP4_BLOCK_SIZE + + base_m = pid_m * BLOCK_M + base_n = pid_n * BLOCK_N + + # Each BLOCK_M covers BLOCK_M / 32 MXFP4 row blocks. + row_block_base = (base_m // MXFP4_BLOCK_SIZE) + + E8_BIAS = tl.constexpr(127) + E2_BIAS = tl.constexpr(1) + + for chunk_m in range(num_chunks_m): + offs_m = base_m + chunk_m * MXFP4_BLOCK_SIZE + tl.arange(0, MXFP4_BLOCK_SIZE) + row_mask = offs_m < M + + for chunk_n in range(num_chunks_n): + offs_n = base_n + chunk_n * MXFP4_BLOCK_SIZE + tl.arange(0, MXFP4_BLOCK_SIZE) + col_mask = offs_n < N + + mask = row_mask[:, None] & col_mask[None, :] + + # Load a 32x32 bf16 tile (promoted to fp32) so both row/col passes reuse the same data. TODO @saraora to double check if this is necessary. + # offs_m = 128*k + [0..31] + # offs_n = 128*l + [0..31] + # This chunk is reused for both rowwise and columnwise passes. 
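+            # Both passes below apply the same recipe per 32-element block:
+            #   1. The block amax is rounded to a power of two by manipulating its FP32
+            #      bit pattern ((bits + 0x200000) & 0xFF800000 keeps only sign/exponent).
+            #   2. The E8M0 scale exponent is floor(log2(amax)) - 2 (biased by 127), so the
+            #      scaled amax stays within the E2M1 range (max representable value 6.0).
+            #   3. Scaled values are converted to E2M1 by extracting the FP32
+            #      sign/exponent/mantissa fields, and two 4-bit codes are packed per byte
+            #      (even element in the low nibble, odd element in the high nibble).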
+ x_chunk = tl.load( + x_ptr + offs_m[:, None] * stride_x_m + offs_n[None, :] * stride_x_n, + mask=mask, + other=0.0, + ).to(tl.float32) + + # ---------- Rowwise path ---------- + if USE_ROWWISE: + # For each row in the current tile (base_m + row0), process the elements in [base_n : base_n + 31]. + # Compute one E8M0 scale per row (32 elements). + amax_row = tl.max(tl.abs(x_chunk), axis=1, keep_dims=True) + amax_row = amax_row.to(tl.int32, bitcast=True) + amax_row = (amax_row + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_row = amax_row.to(tl.float32, bitcast=True) + scale_unbiased_row = tl.log2(amax_row).floor() - 2 + scale_unbiased_row = tl.clamp(scale_unbiased_row, min=-127, max=127) + quant_scale_row = tl.exp2(-scale_unbiased_row) + + qx_row = x_chunk * quant_scale_row + bs_row = scale_unbiased_row.to(tl.uint8) + 127 + + qx_row_u32 = qx_row.to(tl.uint32, bitcast=True) + s_row = qx_row_u32 & 0x80000000 + e_row = (qx_row_u32 >> 23) & 0xFF + m_row = qx_row_u32 & 0x7FFFFF + + adjusted_row = tl.core.sub(E8_BIAS, e_row + 1, sanitize_overflow=False) + m_row = tl.where(e_row < E8_BIAS, (0x400000 | (m_row >> 1)) >> adjusted_row, m_row) + e_row = tl.maximum(e_row, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS) + + e2m1_row = tl.minimum((((e_row << 2) | (m_row >> 21)) + 1) >> 1, 0x7) + e2m1_row = ((s_row >> 28) | e2m1_row).to(tl.uint8) + + # Pack columns (C0,C1) -> byte0, (C2,C3) -> byte1, etc. + row_pairs = tl.reshape( + e2m1_row, [MXFP4_BLOCK_SIZE, MXFP4_BLOCK_SIZE // 2, 2] + ) + vals_even, vals_odd = tl.split(row_pairs) + packed_row = vals_even | (vals_odd << 4) + + row_out_rows = offs_m + row_out_cols = ( + (pid_n * BLOCK_N) // 2 + + chunk_n * (MXFP4_BLOCK_SIZE // 2) + + tl.arange(0, MXFP4_BLOCK_SIZE // 2) + ) + row_store_mask = ( + (row_out_rows < M)[:, None] + & (row_out_cols < (N // 2))[None, :] + ) + + tl.store( + rowwise_fp4_ptr + + row_out_rows[:, None] * stride_rowwise_fp4_m + + row_out_cols[None, :] * stride_rowwise_fp4_n, + packed_row, + mask=row_store_mask, + ) + + scale_offset_x = (pid_n * num_chunks_n) + chunk_n + scale_rows = offs_m + + if SHUFFLE_ROWWISE: + # Rowwise shuffle matches AITER's e8m0_shuffle: + # view(sm//32, 2, 16, sn//8, 2, 4) -> permute(0, 3, 5, 2, 4, 1) -> view(sm, sn) + # where sm = M (rows), sn = N/32 (scale columns) + # + # For input (row=scale_rows, col=scale_offset_x): + # i0 = row // 32 + # i1 = (row % 32) // 16 + # i2 = row % 16 + # i3 = col // 8 + # i4 = (col % 8) // 4 + # i5 = col % 4 + # Output linear = i0*(sn//8*256) + i3*256 + i5*64 + i2*4 + i4*2 + i1 + i0 = scale_rows[:, None] // 32 + i1 = (scale_rows[:, None] % 32) // 16 + i2 = scale_rows[:, None] % 16 + i3 = scale_offset_x // 8 + i4 = (scale_offset_x % 8) // 4 + i5 = scale_offset_x % 4 + + # rowwise_scale_N_pad is already (N/32) rounded up to multiple of 8 + bs_offs = ( + i0 * (rowwise_scale_N_pad // 8 * 256) + + i3 * 256 + + i5 * 64 + + i2 * 4 + + i4 * 2 + + i1 + ) + mask_valid = (scale_rows < M)[:, None] & ( + scale_offset_x < rowwise_scale_N + ) + mask_pad = (scale_rows < rowwise_scale_M_pad)[:, None] & ( + scale_offset_x < rowwise_scale_N_pad + ) + vals = tl.where(mask_valid, bs_row, 127) + tl.store(rowwise_scale_ptr + bs_offs, vals, mask=mask_pad) + else: + scale_mask = (scale_rows < M)[:, None] & ( + scale_offset_x < rowwise_scale_N + ) + tl.store( + rowwise_scale_ptr + + scale_rows[:, None] * stride_rowwise_scale_m + + scale_offset_x * stride_rowwise_scale_n, + bs_row, + mask=scale_mask, + ) + + # ---------- Columnwise path ---------- + if USE_COLWISE: + # Treat columns as rows by 
transposing to reuse the same per-row logic. + # Instead of manually transposing indices, view the tile transposed. + x_col = tl.trans(x_chunk) + amax_col = tl.max(tl.abs(x_col), axis=1, keep_dims=True) + amax_col = amax_col.to(tl.int32, bitcast=True) + amax_col = (amax_col + 0x200000).to(tl.uint32, bitcast=True) & 0xFF800000 + amax_col = amax_col.to(tl.float32, bitcast=True) + scale_unbiased_col = tl.log2(amax_col).floor() - 2 + scale_unbiased_col = tl.clamp(scale_unbiased_col, min=-127, max=127) + quant_scale_col = tl.exp2(-scale_unbiased_col) + + qx_col = x_col * quant_scale_col + bs_col = scale_unbiased_col.to(tl.uint8) + 127 + + qx_col_u32 = qx_col.to(tl.uint32, bitcast=True) + s_col = qx_col_u32 & 0x80000000 + e_col = (qx_col_u32 >> 23) & 0xFF + m_col = qx_col_u32 & 0x7FFFFF + + adjusted_col = tl.core.sub(E8_BIAS, e_col + 1, sanitize_overflow=False) + m_col = tl.where(e_col < E8_BIAS, (0x400000 | (m_col >> 1)) >> adjusted_col, m_col) + e_col = tl.maximum(e_col, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS) + + e2m1_col = tl.minimum((((e_col << 2) | (m_col >> 21)) + 1) >> 1, 0x7) + e2m1_col = ((s_col >> 28) | e2m1_col).to(tl.uint8) + + # After transpose, each row in x_col is one column from the original tile. + col_pairs = tl.reshape( + e2m1_col, [MXFP4_BLOCK_SIZE, MXFP4_BLOCK_SIZE // 2, 2] + ) + vals_even, vals_odd = tl.split(col_pairs) + packed_col = vals_even | (vals_odd << 4) # [cols, row_pairs] + + col_indices = ( + base_n + chunk_n * MXFP4_BLOCK_SIZE + tl.arange(0, MXFP4_BLOCK_SIZE) + ) + row_pairs = tl.arange(0, MXFP4_BLOCK_SIZE // 2) + rowpair_base = (base_m // 2) + chunk_m * (MXFP4_BLOCK_SIZE // 2) + rowpair_indices = rowpair_base + row_pairs + + # col_indices: [base_n + chunk_n*MXFP4_BLOCK_SIZE + i for i in range(MXFP4_BLOCK_SIZE)] + # rowpair_indices: [base_m // 2 + chunk_m*(MXFP4_BLOCK_SIZE//2) + j for j in range(MXFP4_BLOCK_SIZE//2)] + col_fp4_mask = (col_indices < N)[:, None] & ( + rowpair_indices < (M // 2) + )[None, :] + + # Store directly into the [N, M/2] layout expected by columnwise tensors. + tl.store( + colwise_fp4_ptr + + col_indices[:, None] * stride_colwise_fp4_m + + rowpair_indices[None, :] * stride_colwise_fp4_n, + packed_col, + mask=col_fp4_mask, + ) + + scale_chunk = (pid_m * num_chunks_m) + chunk_m + + if SHUFFLE_COLWISE: + # Columnwise shuffle matches AITER's e8m0_shuffle: + # view(sm//32, 2, 16, sn//8, 2, 4) -> permute(0, 3, 5, 2, 4, 1) -> view(sm, sn) + # where sm = colwise_scale_M (N), sn = colwise_scale_N (M/32) + # + # For input (row=col_indices, col=scale_chunk): + # i0 = row // 32 + # i1 = (row % 32) // 16 + # i2 = row % 16 + # i3 = col // 8 + # i4 = (col % 8) // 4 + # i5 = col % 4 + # Output linear = i0*(sn//8*256) + i3*256 + i5*64 + i2*4 + i4*2 + i1 + bs_col_1d = tl.reshape(bs_col, [MXFP4_BLOCK_SIZE]) + i0 = col_indices // 32 + i1 = (col_indices % 32) // 16 + i2 = col_indices % 16 + i3 = scale_chunk // 8 + i4 = (scale_chunk % 8) // 4 + i5 = scale_chunk % 4 + + # colwise_scale_N_pad is already (M/32) rounded up to multiple of 8 + bs_offs = ( + i0 * (colwise_scale_N_pad // 8 * 256) + + i3 * 256 + + i5 * 64 + + i2 * 4 + + i4 * 2 + + i1 + ) + mask_valid = (col_indices < colwise_scale_M) & ( + scale_chunk < colwise_scale_N + ) + mask_pad = (col_indices < colwise_scale_M_pad) & ( + scale_chunk < colwise_scale_N_pad + ) + vals = tl.where(mask_valid, bs_col_1d, 127) + tl.store(colwise_scale_ptr + bs_offs, vals, mask=mask_pad) + else: + # Simple row-major layout: each column has scale_chunk entries along the N-dimension. 
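+                    # scale_chunk (= pid_m * num_chunks_m + chunk_m) selects which 32-row
+                    # block along M this scale describes; the colwise scale tensor is
+                    # logically [N, M/32] before padding.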
+ scale_mask = (col_indices < colwise_scale_M)[:, None] & ( + scale_chunk < colwise_scale_N + ) + tl.store( + colwise_scale_ptr + + col_indices[:, None] * stride_colwise_scale_m + + scale_chunk * stride_colwise_scale_n, + bs_col, + mask=scale_mask, + ) + @triton.jit def _dequantize_mxfp8_triton( x_ptr, y_ptr, @@ -580,6 +911,188 @@ def te_cast_transpose_mxfp8_triton(input, out, noop_flag=None): colwise_scale_M, colwise_scale_N, max_fp8, BLOCK_X, BLOCK_Y, GROUP_Y, MXFP8_BLOCK_SCALING_SIZE, USE_ROWWISE_SCALING, USE_COLWISE_SCALING) +def te_cast_transpose_mxfp4_triton( + input: torch.Tensor, + rowwise_fp4_out: Optional[torch.Tensor] = None, + rowwise_scale_out: Optional[torch.Tensor] = None, + colwise_fp4_out: Optional[torch.Tensor] = None, + colwise_scale_out: Optional[torch.Tensor] = None, + shuffle_rowwise_scale: bool = False, + shuffle_colwise_scale: bool = False, + shuffle_rowwise_fp4: bool = False, + shuffle_colwise_fp4: bool = False, + use_hadamard: bool = False, +) -> tuple: + """ + Fused MXFP4 quantization with optional transpose + + Performs quantization for both rowwise and columnwise layouts + + Args: + input: Input tensor [M, N] in BF16/FP16 + rowwise_fp4_out: Optional pre-allocated rowwise FP4 output [M, N/2] + rowwise_scale_out: Optional pre-allocated rowwise E8M0 scales + colwise_fp4_out: Optional pre-allocated colwise FP4 output [N, M/2] + colwise_scale_out: Optional pre-allocated colwise E8M0 scales + shuffle_rowwise_scale: Whether to apply shuffle permutation to rowwise scales + shuffle_colwise_scale: Whether to apply shuffle permutation to colwise scales + shuffle_rowwise_fp4: Whether to apply shuffle permutation to rowwise FP4 data + shuffle_colwise_fp4: Whether to apply shuffle permutation to colwise FP4 data + use_hadamard: Whether to apply Hadamard transform before quantization + + Returns: + (rowwise_fp4, rowwise_scale, colwise_fp4, colwise_scale) + """ + # Check for unsupported features + if use_hadamard: + raise NotImplementedError("Fused Hadamard transform is not supported in Triton MXFP4 kernel") + if shuffle_rowwise_fp4 or shuffle_colwise_fp4: + raise NotImplementedError("FP4 data shuffle is not supported in Triton MXFP4 kernel") + + # Reshape input to 2D + original_shape = input.shape + if input.dim() > 2: + input = input.view(-1, input.shape[-1]) + if input.dim() != 2: + raise ValueError(f"Input must be 2D or reshapeable to 2D, got shape {original_shape}") + + M, N = input.shape + MXFP4_BLOCK_SIZE = 32 + BLOCK_M = 128 + BLOCK_N = 128 + + # Validate dimensions + assert N % MXFP4_BLOCK_SIZE == 0, f"N={N} must be divisible by {MXFP4_BLOCK_SIZE}" + + device = input.device + USE_ROWWISE = rowwise_fp4_out is not None or colwise_fp4_out is None + USE_COLWISE = colwise_fp4_out is not None + + # Allocate rowwise outputs (matching AITER layout) + if USE_ROWWISE: + if rowwise_fp4_out is None: + rowwise_fp4_out = torch.empty(M, N // 2, dtype=torch.uint8, device=device) + + scaleN_row = triton.cdiv(N, MXFP4_BLOCK_SIZE) + if rowwise_scale_out is None: + if shuffle_rowwise_scale: + # AITER shuffled layout + scaleM = triton.cdiv(M, 32) * 32 + scaleN = triton.cdiv(scaleN_row, 8) * 8 + rowwise_scale_out = torch.empty( + triton.cdiv(M, 256) * 256, scaleN, + dtype=torch.uint8, device=device + ) + else: + # Non-shuffled layout + rowwise_scale_out = torch.empty(M, scaleN_row, dtype=torch.uint8, device=device) + + scaleM_pad = triton.cdiv(M, 32) * 32 + scaleN_pad = triton.cdiv(scaleN_row, 8) * 8 + else: + scaleN_row = 1 + scaleM_pad = scaleN_pad = 1 + + colwise_scale_tmp = None + 
kernel_colwise_scale = None + kernel_colwise_scale_M = kernel_colwise_scale_N = 1 + kernel_colwise_scale_M_pad = kernel_colwise_scale_N_pad = 1 + + # Allocate columnwise outputs (transposed) + if USE_COLWISE: + if colwise_fp4_out is None: + colwise_fp4_out = torch.empty(N, M // 2, dtype=torch.uint8, device=device) + + scaleN_colwise_valid = triton.cdiv(M, MXFP4_BLOCK_SIZE) + if colwise_scale_out is None: + if shuffle_colwise_scale: + # AITER shuffled layout for colwise + scaleM_colwise_pad = triton.cdiv(N, 32) * 32 + scaleN_colwise_pad = triton.cdiv(scaleN_colwise_valid, 8) * 8 + colwise_scale_out = torch.empty( + triton.cdiv(N, 256) * 256, scaleN_colwise_pad, + dtype=torch.uint8, device=device + ) + else: + # Non-shuffled layout + colwise_scale_out = torch.empty(N, scaleN_colwise_valid, dtype=torch.uint8, device=device) + + if shuffle_colwise_scale: + scaleM_colwise_pad = triton.cdiv(N, 256) * 256 + scaleN_colwise_pad = triton.cdiv(scaleN_colwise_valid, 8) * 8 + else: + scaleM_colwise_pad = N + scaleN_colwise_pad = scaleN_colwise_valid + + if shuffle_colwise_scale: + # Allocate padded temporary tensor for shuffled output + colwise_scale_tmp = torch.empty( + scaleM_colwise_pad, + scaleN_colwise_pad, + dtype=torch.uint8, + device=device, + ) + kernel_colwise_scale = colwise_scale_tmp + kernel_colwise_scale_M = N # Valid (non-padded) dimension + kernel_colwise_scale_N = scaleN_colwise_valid # Valid (non-padded) dimension + kernel_colwise_scale_M_pad = scaleM_colwise_pad + kernel_colwise_scale_N_pad = scaleN_colwise_pad + else: + kernel_colwise_scale = colwise_scale_out + kernel_colwise_scale_M = colwise_scale_out.shape[0] + kernel_colwise_scale_N = colwise_scale_out.shape[1] + kernel_colwise_scale_M_pad = scaleM_colwise_pad + kernel_colwise_scale_N_pad = scaleN_colwise_pad + else: + scaleM_colwise_pad = scaleN_colwise_pad = 1 + kernel_colwise_scale = colwise_scale_out + + # Ensure input is contiguous + if not input.is_contiguous(): + input = input.contiguous() + + # Launch kernel with (M_blocks, N_blocks) + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + + _cast_transpose_triton_mxfp4[grid]( + input, + rowwise_fp4_out if USE_ROWWISE else None, + rowwise_scale_out if USE_ROWWISE else None, + colwise_fp4_out if USE_COLWISE else None, + kernel_colwise_scale if USE_COLWISE else None, + input.stride(0), input.stride(1), + rowwise_fp4_out.stride(0) if USE_ROWWISE else 1, + rowwise_fp4_out.stride(1) if USE_ROWWISE else 1, + rowwise_scale_out.stride(0) if USE_ROWWISE else 1, + rowwise_scale_out.stride(1) if USE_ROWWISE else 1, + colwise_fp4_out.stride(0) if USE_COLWISE else 1, + colwise_fp4_out.stride(1) if USE_COLWISE else 1, + kernel_colwise_scale.stride(0) if USE_COLWISE else 1, + kernel_colwise_scale.stride(1) if USE_COLWISE else 1, + M=M, + N=N, + rowwise_scale_N=scaleN_row, + rowwise_scale_M_pad=scaleM_pad, + rowwise_scale_N_pad=scaleN_pad, + colwise_scale_M=kernel_colwise_scale_M, + colwise_scale_N=kernel_colwise_scale_N, + colwise_scale_M_pad=kernel_colwise_scale_M_pad, + colwise_scale_N_pad=kernel_colwise_scale_N_pad, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + MXFP4_BLOCK_SIZE=MXFP4_BLOCK_SIZE, + USE_ROWWISE=USE_ROWWISE, + USE_COLWISE=USE_COLWISE, + SHUFFLE_ROWWISE=shuffle_rowwise_scale, + SHUFFLE_COLWISE=shuffle_colwise_scale, + ) + + # Copy shuffled columnwise scales to output tensor (trim padding) + if USE_COLWISE and shuffle_colwise_scale: + colwise_scale_out[:N, :scaleN_colwise_valid] = kernel_colwise_scale[:N, :scaleN_colwise_valid] + + return rowwise_fp4_out, 
rowwise_scale_out, colwise_fp4_out, colwise_scale_out + def te_dequantize_mxfp8_triton(input, dtype): input_metadata = input.get_metadata() use_rowwise_scaling = input_metadata["rowwise_data"] is not None diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index d124fbeaf..d3c65b667 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -26,6 +26,11 @@ def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: return True return False +@functools.lru_cache(maxsize=None) +def _empty_tensor() -> torch.Tensor: + """Get tensor with no entries and no data""" + return torch.Tensor().cuda() + @functools.lru_cache(maxsize=None) def _empty_tensor() -> torch.Tensor:
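For reference, a minimal usage sketch of the new MXFP4 path, assembled from test_quantize_mxfp4 above; the quantizer arguments, make_empty, and te_quantize_triton all appear in this diff, while the concrete input shape and the environment variable merely mirror the test's setup and are illustrative only (assumes a ROCm build with these changes applied):

import os

# Mirrors the test setup: enable the Triton fused cast/transpose path before importing.
os.environ["USE_TRITON_FUSED_CAST_TRANSPOSE"] = "1"

import torch
from transformer_engine.pytorch.tensor.mxfp4_tensor import MXFP4Quantizer
from transformer_engine.pytorch.triton_kernels.cast import te_quantize_triton

x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")

quantizer = MXFP4Quantizer(rowwise=True, columnwise=True)
out = quantizer.make_empty(x.shape, dtype=x.dtype)           # packed FP4 + padded E8M0 buffers
qx = te_quantize_triton(x, quantizer=quantizer, output=out)  # fills them via the fused Triton kernel

# Rowwise FP4 data is [M, K/2] (two E2M1 values per byte); one E8M0 scale per
# 32-element block, stored in a scale tensor padded to multiples of (256, 8).
print(qx._rowwise_data.shape)   # torch.Size([256, 512])
print(qx._rowwise_scale.shape)  # torch.Size([256, 32])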