From 7c72ee05ee6df6943e469e839a449fab778e9feb Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 5 Sep 2025 16:26:48 -0500 Subject: [PATCH 01/16] Initial refactor --- tests/pytorch/triton_kernels/test_norms.py | 8 +- transformer_engine/pytorch/module/_common.py | 8 +- .../pytorch/module/layernorm_linear.py | 3 +- .../pytorch/module/layernorm_mlp.py | 3 +- .../pytorch/ops/basic/layer_norm.py | 2 +- .../pytorch/ops/basic/rmsnorm.py | 2 +- .../pytorch/triton_kernels/layernorm.py | 225 +----------- .../pytorch/triton_kernels/norms.py | 342 ++++++++++++++++++ .../pytorch/triton_kernels/rmsnorm.py | 111 +----- 9 files changed, 364 insertions(+), 340 deletions(-) create mode 100644 transformer_engine/pytorch/triton_kernels/norms.py diff --git a/tests/pytorch/triton_kernels/test_norms.py b/tests/pytorch/triton_kernels/test_norms.py index 44c481e29..b9b652600 100644 --- a/tests/pytorch/triton_kernels/test_norms.py +++ b/tests/pytorch/triton_kernels/test_norms.py @@ -17,13 +17,11 @@ ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor -from transformer_engine.pytorch.triton_kernels.rmsnorm import ( - te_rmsnorm_bwd_triton, - te_rmsnorm_fwd_triton, -) -from transformer_engine.pytorch.triton_kernels.layernorm import ( +from transformer_engine.pytorch.triton_kernels.norms import ( te_layernorm_bwd_triton, te_layernorm_fwd_triton, + te_rmsnorm_bwd_triton, + te_rmsnorm_fwd_triton, ) from test_common import dtype_tols, te_compare_results, str_to_torch_dtype, fill_uniform diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index c9a823fe3..42efc3e0f 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -24,8 +24,12 @@ _use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0"))) if IS_HIP_EXTENSION: - from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton - from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton + from ..triton_kernels.norms import ( + te_layernorm_fwd_triton, + te_layernorm_bwd_triton, + te_rmsnorm_fwd_triton, + te_rmsnorm_bwd_triton + ) def _get_normalization_func(normalization: str, forward: bool): use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index e5770d9e7..01d5c99df 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -68,8 +68,7 @@ ) if IS_HIP_EXTENSION: - from ..triton_kernels.layernorm import te_layernorm_bwd_triton - from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton + from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 31780d4e0..ac8d9c626 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -75,8 +75,7 @@ ) if IS_HIP_EXTENSION: - from ..triton_kernels.layernorm import te_layernorm_bwd_triton - from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton + from ..triton_kernels.norms import 
te_layernorm_bwd_triton, te_rmsnorm_bwd_triton from ..rocm_utils import create_fp8_weight_transpose_cache, clear_fp8_weight_transpose_cache diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index c94459bc3..d6294ee89 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -17,7 +17,7 @@ from transformer_engine_torch import layernorm_bwd, layernorm_fwd from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: - from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton + from ...triton_kernels.norms import te_layernorm_fwd_triton, te_layernorm_bwd_triton from ...fp8 import FP8GlobalStateManager from ...tensor import QuantizedTensor from ...constants import TE_DType diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index e945d25fc..a7bbe5d7f 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -17,7 +17,7 @@ from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: - from ...triton_kernels.rmsnorm import ( + from ...triton_kernels.norms import ( te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton ) diff --git a/transformer_engine/pytorch/triton_kernels/layernorm.py b/transformer_engine/pytorch/triton_kernels/layernorm.py index 265093b73..cfdd5c335 100644 --- a/transformer_engine/pytorch/triton_kernels/layernorm.py +++ b/transformer_engine/pytorch/triton_kernels/layernorm.py @@ -167,10 +167,10 @@ def _layernorm_fwd_triton_impl( tl.store(output_t_ptrs, y_block, mask=mask) if IS_FP8: + if pid == 0: + scale_inv = tl.fdiv(1.0, scale) + tl.store(scale_inv_ptr, scale_inv) if APPLY_ATOMIC: - if pid == 0: - scale_inv = tl.fdiv(1.0, scale) - tl.store(scale_inv_ptr, scale_inv) tl.atomic_max(amax_ptr, amax, sem="relaxed") else: tl.store(amax_ptr + pid, amax) @@ -182,8 +182,6 @@ def _layernorm_fwd_triton_impl( def _layernorm_fwd_reduce_triton( amax_input_ptr, amax_output_ptr, - scale_ptr, - scale_inv_ptr, n_rows, BLOCK_SIZE: tl.constexpr, ): @@ -200,12 +198,6 @@ def _layernorm_fwd_reduce_triton( tl.atomic_max(amax_output_ptr, amax, sem="relaxed") - if pid == 0: - scale = tl.load(scale_ptr) - scale_inv = tl.fdiv(1.0, scale) - tl.store(scale_inv_ptr, scale_inv) - - @triton.jit def _layernorm_bwd_dx_fused_triton( DX, # pointer to the input gradient @@ -455,214 +447,3 @@ def _layernorm_bwd_dwdb_triton_v2( sum_db = tl.sum(db, axis=0) tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.type.element_ty), mask=cols < N) tl.store(FINAL_DB + cols, sum_db.to(FINAL_DB.type.element_ty), mask=cols < N) - -def te_layernorm_fwd_triton(input: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - ln_out: torch.Tensor, - quantizer: Quantizer, - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - autotune: bool = True,): - if sm_margin is not None and sm_margin > 0: - warnings.warn( - '"sm_margin" is not supported in the Triton based forward layer-norm kernel. ' - + f"sm_margin={sm_margin} will be ignored." 
- ) - device = input.device - M, N = input.shape - - IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) - MAKE_TRANSPOSE = False - - # Create empty tensors for mu and rsigma - mu = torch.empty((M,), dtype=torch.float32, device=device) - rsigma = torch.empty((M,), dtype=torch.float32, device=device) - torch_out_dtype = ( - otype if isinstance(otype, torch.dtype) - else te_dtype_to_torch_dtype(otype) - ) - # Create ln_out - ln_out = make_ln_out(ln_out, quantizer=quantizer, input_shape=input.shape, out_dtype=torch_out_dtype) - # To update the amax ptr directly with atomic max - APPLY_ATOMIC = M < 512 - - # MXFP8 is handled regularly, hence quantizer of Float8Quantizer is considered FP8 - IS_FP8 = isinstance(quantizer, Float8Quantizer) - - amax_temp = torch.empty((M,), dtype=torch.float32, device=device) if IS_FP8 else None - - max_fused_size = 16384 // input.element_size() - BLOCK_SIZE = min(max_fused_size, triton.next_power_of_2(N)) - - out_transpose_ptr = None - out_transpose_stride = None - - # Create necessary values for fp8 if needed - if IS_FP8: - scale = quantizer.scale - amax_out = quantizer.amax - scale_inv = ln_out._scale_inv - cast_out = ln_out._data - MAKE_TRANSPOSE = quantizer.columnwise_usage - if MAKE_TRANSPOSE: - tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) - if ln_out._transpose_invalid: - ln_out._transpose = torch.empty((ln_out._data.shape[1], ln_out._data.shape[0]), dtype=ln_out._data.dtype, device=device) - ln_out._transpose_invalid = False - out_transpose_ptr = triton.reinterpret(ln_out._transpose, tl_dtype) - out_transpose_stride = ln_out._transpose.stride(0) - else: - scale = None - amax_out = None - scale_inv = None - cast_out = ln_out - - kernel = _layernorm_fwd_triton if autotune else _layernorm_fwd_triton_impl - kernel[(M,)]( - input, - triton.reinterpret(cast_out, te_dtype_to_triton_dtype(ln_out._fp8_dtype)) if IS_FP8 else cast_out, - weight, - bias, - mu, - rsigma, - scale, - amax_out if APPLY_ATOMIC else amax_temp, - scale_inv, - input.stride(0), - cast_out.stride(0), - M, - N, - eps, - out_transpose_ptr, - out_transpose_stride, - ZERO_CENTERED_GAMMA=zero_centered_gamma, - BLOCK_SIZE=BLOCK_SIZE, - IS_FP8=IS_FP8, - APPLY_ATOMIC=APPLY_ATOMIC, - # TODO: Improve performance with persistent kernel - # Persistent kernel currently lags behind non persistent version - # It also lags behind TE implementation in a few cases - PERSISTENT=False, - FP8_MAX=get_fp8_max(quantizer.dtype) if IS_FP8 else None, - MAKE_TRANSPOSE=MAKE_TRANSPOSE - ) - - # For MXFP8, we do regular layernorm and then quantize it separately - if IS_MXFP8: - ln_out = te_quantize_triton(ln_out, quantizer) - - # Reduce and find amax if "not APPLY_ATOMIC" is True. - if IS_FP8 and not APPLY_ATOMIC: - _layernorm_fwd_reduce_triton[(triton.cdiv(M, 256),)]( - amax_temp, - amax_out, - scale, - scale_inv, - M, - 256, - ) - return ln_out, mu, rsigma - -# drop in replacement for transformer_engine::pytorch::layernorm_bwd -# TODO: Add support for `sm_margin > 0`. -def te_layernorm_bwd_triton( - dz: torch.Tensor, - x: torch.Tensor, - mu: torch.Tensor, - rsigma: torch.Tensor, - gamma: torch.Tensor, - sm_margin: int, - zero_centered_gamma: bool -): - if sm_margin is not None and sm_margin > 0: - warnings.warn( - '"sm_margin" is not supported in the Triton based backward layer-norm kernel. ' - + f"sm_margin={sm_margin} will be ignored." 
- ) - M, N = x.shape - # calculate dw and db separately when M is small - IGNORE_DW_DB_IN_FUSED = M <= 512 - tile_num = max(min(256, M // 4), 1) - if M <= 512 and M * N < 64 * 1024 * 1024: - tile_num = M - elif M >= 8192: - tile_num = 2048 - max_fused_size = 32768 // x.element_size() - next_power = triton.next_power_of_2(N) - BLOCK_SIZE = min(max_fused_size, next_power) - # For cases with small M and large N, decrease block size to help with occupancy and register spill - if tile_num == M: - if tile_num > 256: - BLOCK_SIZE = min(BLOCK_SIZE, 2048) - else: - BLOCK_SIZE = min(BLOCK_SIZE, 4096) - USE_BLOCKED = N > BLOCK_SIZE - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - - dx = torch.empty_like(x) - if not IGNORE_DW_DB_IN_FUSED: - _dgamma = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) - _dbeta = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) - else: - _dgamma = None - _dbeta = None - dgamma = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) - dbeta = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) - grid_bwd = (tile_num,) - _layernorm_bwd_dx_fused_triton[grid_bwd]( - dx, - dz, - _dgamma, - _dbeta, - x, - gamma, - mu, - rsigma, - x.stride(0), - N, - ZERO_CENTERED_GAMMA=zero_centered_gamma, - NUM_ROWS=M, - BLOCK_SIZE_N=BLOCK_SIZE, - USE_BLOCKED=USE_BLOCKED, - num_warps=num_warps, - IGNORE_DW_DB=IGNORE_DW_DB_IN_FUSED, - ) - grid_reduce = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE_N"]),) - if not IGNORE_DW_DB_IN_FUSED: - dwdb_block_n = max(16, N // 256) - dwdb_block_n = triton.next_power_of_2(dwdb_block_n) - dwdb_block_m = (64 * 128) // dwdb_block_n - dwdb_block_m = min(triton.next_power_of_2(tile_num), dwdb_block_m) - _layernorm_bwd_dwdb_triton[grid_reduce]( - _dgamma, - _dbeta, - dgamma, - dbeta, - min(tile_num, M), - N, - BLOCK_SIZE_M=dwdb_block_m, - BLOCK_SIZE_N=dwdb_block_n, - ) - else: - dwdb_block_n = max(16, N // 256) - dwdb_block_n = triton.next_power_of_2(dwdb_block_n) - dwdb_block_m = (64 * 128) // dwdb_block_n - dwdb_block_m = min(triton.next_power_of_2(M), dwdb_block_m) - _layernorm_bwd_dwdb_triton_v2[grid_reduce]( - x, - dz, - mu, - rsigma, - x.stride(0), - dgamma, - dbeta, - M, - N, - BLOCK_SIZE_M=dwdb_block_m, - BLOCK_SIZE_N=dwdb_block_n, - ) - - return dx, dgamma, dbeta diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py new file mode 100644 index 000000000..ebfc5d741 --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -0,0 +1,342 @@ +import torch +import triton +import warnings +import transformer_engine_torch as tex + +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.triton_kernels.common import ( + te_dtype_to_torch_dtype, + te_dtype_to_triton_dtype, +) +from ..tensor.quantized_tensor import Quantizer +from .norm_common import num_programs, block_size, use_blocked, make_ln_out +from .common import get_fp8_max +from .rmsnorm import ( + _rmsnorm_fwd_triton, + _rmsnorm_fwd_triton_impl, + _rmsnorm_bwd_triton, + _rmsnorm_bwd_dg_reduce_triton, +) +from .layernorm import ( + _layernorm_fwd_triton, + _layernorm_fwd_triton_impl, + _layernorm_fwd_reduce_triton, + _layernorm_bwd_dwdb_triton, + _layernorm_bwd_dwdb_triton_v2, + _layernorm_bwd_dx_fused_triton, +) + +_norm_kernels={ + "rms":{ + True: _rmsnorm_fwd_triton, + False: _rmsnorm_fwd_triton_impl, + }, + "layer":{ + 
True: _layernorm_fwd_triton, + False: _layernorm_fwd_triton_impl, + } +} +# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd +def te_rmsnorm_fwd_triton( + input: torch.Tensor, + weight: torch.Tensor, + eps: float, + ln_out: torch.Tensor, + quantizer: Quantizer, + otype: tex.DType, + sm_margin: int, + zero_centered_gamma: bool, + autotune: bool = True, +): + return te_norm_fwd_triton( + kernel='rms', + input=input, + weight=weight, + bias=None, + eps=eps, + ln_out=ln_out, + quantizer=quantizer, + otype=otype, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + autotune=autotune, + ) + +# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd +def te_layernorm_fwd_triton( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + ln_out: torch.Tensor, + quantizer: Quantizer, + otype: tex.DType, + sm_margin: int, + zero_centered_gamma: bool, + autotune: bool = True, +): + return te_norm_fwd_triton( + kernel='layer', + input=input, + weight=weight, + bias=bias, + eps=eps, + ln_out=ln_out, + quantizer=quantizer, + otype=otype, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + autotune=autotune, + ) + +# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd +def te_norm_fwd_triton( + kernel: str, + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + ln_out: torch.Tensor, + quantizer: Quantizer, + otype: tex.DType, + sm_margin: int, + zero_centered_gamma: bool, + autotune: bool = True, +): + if kernel not in {'rms', 'layer'}: + raise ValueError(f"Expected `kernel` in ('rms', 'layer') but got {kernel=} instead.") + if eps < 0: + raise ValueError(f"`eps` must be non-negative, but a value of {eps} was passed") + if len(input.shape) != 2: + raise ValueError( + f"The input must be a 2-dimensional matrix, but an input with {input.ndim} was passed.") + + device = input.device + N, H = input.shape + if weight.shape[0] != H: + raise ValueError( + f"The shape of `weight` must be feature-aligned, " + f"but {weight.shape[0]=} while {input.shape[1]=}" + ) + IS_FP8 = isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) + IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) + BLOCK_SIZE = block_size(input) + USE_BLOCKED = use_blocked(input) + NUM_PRGMS = num_programs(input, sm_margin) + MAKE_TRANSPOSE = False + APPLY_ATOMIC = N < 512 and kernel == 'layer' + ATOMIC_REDUCTION_BLOCK_SIZE=256 + + mu = torch.empty((N,), dtype=torch.float32, device=device) if kernel == 'layer' else None + rsigma = torch.empty((N,), dtype=torch.float32, device=device) + torch_out_dtype = ( + otype if isinstance(otype, torch.dtype) + else te_dtype_to_torch_dtype(otype) + ) + out = make_ln_out( + ln_out, + quantizer=quantizer, + input_shape=input.shape, + out_dtype=torch_out_dtype + ) + amax = None + tl_dtype = None + scale_inv_ptr = None + q_scale = None + out_ptr = out + out_transpose_ptr = None + out_transpose_stride = None + FP8_MAX = None + if IS_FP8: + MAKE_TRANSPOSE = quantizer.columnwise_usage + amax = ( + torch.empty((NUM_PRGMS,), dtype=torch.float32, device=device) + if APPLY_ATOMIC else quantizer.amax + ) + tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) + scale_inv_ptr = out._scale_inv + q_scale = quantizer.scale + out_ptr = triton.reinterpret(out._data, tl_dtype) + FP8_MAX = get_fp8_max(quantizer.dtype) + if MAKE_TRANSPOSE: + if out._transpose_invalid: + out._transpose = torch.empty( + (out._data.shape[1], out._data.shape[0]), + dtype=out._data.dtype, 
device=device + ) + out._transpose_invalid = False + out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) + out_transpose_stride = out._transpose.stride(0) + + grid_fwd = lambda meta: (NUM_PRGMS, ) + kernel = _norm_kernels[kernel][autotune] + kernel[grid_fwd]( + out_ptr, + input, + weight, + mu, + bias, + rsigma, + input.stride(0), + out_ptr.stride(0), + N, H, eps, + amax, + q_scale, + scale_inv_ptr, + out_transpose_ptr, + out_transpose_stride, + zero_centered_gamma, + BLOCK_SIZE, + USE_BLOCKED, + NUM_PRGMS, + IS_FP8, + APPLY_ATOMIC, + FP8_MAX, + MAKE_TRANSPOSE, + ) + if IS_MXFP8: + out = quantizer.quantize(out, out=ln_out) + + # Reduce and find amax if "not APPLY_ATOMIC" is True for layernorm. + if IS_FP8 and not APPLY_ATOMIC: + _layernorm_fwd_reduce_triton[(triton.cdiv(N, ATOMIC_REDUCTION_BLOCK_SIZE),)]( + amax, + quantizer.amax, + N, ATOMIC_REDUCTION_BLOCK_SIZE, + ) + + return out, mu, rsigma + + +# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd +def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): + # may take non-contiguous inputs + dz_ = dz.contiguous() + x_ = x.contiguous() + rsigma_ = rsigma.contiguous() + gamma_ = gamma.contiguous() + + dx = torch.empty_like(x_) + dgamma = torch.empty_like(gamma_) + + M, N = x_.shape + blk_size = block_size(x_) + USE_BLOCKED = use_blocked(x_) + NUM_PRGMS = num_programs(x_, sm_margin) + need_reduction = N > 1 + dg_tmp_rows = x_.shape[0] if use_blocked(x_) else num_programs(x_, sm_margin) + dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None + + grid_bwd = lambda meta: (NUM_PRGMS, ) + _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, + x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, + USE_BLOCKED, NUM_PRGMS, num_warps=8) + + if need_reduction: + grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] + _rmsnorm_bwd_dg_reduce_triton[grid_reduce](dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], + BLOCK_SIZE_M=128, BLOCK_SIZE_N=64) + + return dx, dgamma + +# drop in replacement for transformer_engine::pytorch::layernorm_bwd +# TODO: Add support for `sm_margin > 0`. +def te_layernorm_bwd_triton( + dz: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool +): + if sm_margin is not None and sm_margin > 0: + warnings.warn( + '"sm_margin" is not supported in the Triton based backward layer-norm kernel. ' + + f"sm_margin={sm_margin} will be ignored." 
+ ) + M, N = x.shape + # calculate dw and db separately when M is small + IGNORE_DW_DB_IN_FUSED = M <= 512 + tile_num = max(min(256, M // 4), 1) + if M <= 512 and M * N < 64 * 1024 * 1024: + tile_num = M + elif M >= 8192: + tile_num = 2048 + max_fused_size = 32768 // x.element_size() + next_power = triton.next_power_of_2(N) + BLOCK_SIZE = min(max_fused_size, next_power) + # For cases with small M and large N, decrease block size to help with occupancy and register spill + if tile_num == M: + if tile_num > 256: + BLOCK_SIZE = min(BLOCK_SIZE, 2048) + else: + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + USE_BLOCKED = N > BLOCK_SIZE + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + + dx = torch.empty_like(x) + if not IGNORE_DW_DB_IN_FUSED: + _dgamma = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) + _dbeta = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) + else: + _dgamma = None + _dbeta = None + dgamma = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) + dbeta = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) + grid_bwd = (tile_num,) + _layernorm_bwd_dx_fused_triton[grid_bwd]( + dx, + dz, + _dgamma, + _dbeta, + x, + gamma, + mu, + rsigma, + x.stride(0), + N, + ZERO_CENTERED_GAMMA=zero_centered_gamma, + NUM_ROWS=M, + BLOCK_SIZE_N=BLOCK_SIZE, + USE_BLOCKED=USE_BLOCKED, + num_warps=num_warps, + IGNORE_DW_DB=IGNORE_DW_DB_IN_FUSED, + ) + grid_reduce = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE_N"]),) + if not IGNORE_DW_DB_IN_FUSED: + dwdb_block_n = max(16, N // 256) + dwdb_block_n = triton.next_power_of_2(dwdb_block_n) + dwdb_block_m = (64 * 128) // dwdb_block_n + dwdb_block_m = min(triton.next_power_of_2(tile_num), dwdb_block_m) + _layernorm_bwd_dwdb_triton[grid_reduce]( + _dgamma, + _dbeta, + dgamma, + dbeta, + min(tile_num, M), + N, + BLOCK_SIZE_M=dwdb_block_m, + BLOCK_SIZE_N=dwdb_block_n, + ) + else: + dwdb_block_n = max(16, N // 256) + dwdb_block_n = triton.next_power_of_2(dwdb_block_n) + dwdb_block_m = (64 * 128) // dwdb_block_n + dwdb_block_m = min(triton.next_power_of_2(M), dwdb_block_m) + _layernorm_bwd_dwdb_triton_v2[grid_reduce]( + x, + dz, + mu, + rsigma, + x.stride(0), + dgamma, + dbeta, + M, + N, + BLOCK_SIZE_M=dwdb_block_m, + BLOCK_SIZE_N=dwdb_block_n, + ) + + return dx, dgamma, dbeta diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index b62d61ced..9133c0188 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -24,16 +24,19 @@ def get_autotune_config(): return [triton.Config({'waves_per_eu': we}, num_warps=nw) for (we, nw) in product([0, 1, 2, 4], [4, 8, 16])] +# TODO(micky774) Implement fused MXFP8 quantization within the kernel @triton.jit def _rmsnorm_fwd_triton_impl( output_ptr, input_ptr, - g_ptr, rsigma_ptr, + g_ptr, + bias_ptr, # Unused, for API purposes only + mu_ptr, # Unused, for API purposes only + rsigma_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon, - amax_ptr, q_amax_ptr, q_scale_ptr, scale_inv_ptr, @@ -44,6 +47,7 @@ def _rmsnorm_fwd_triton_impl( USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, IS_FP8: tl.constexpr, + APPLY_ATOMIC: tl.constexpr, # Unused, for API purposes only FP8_MAX: tl.constexpr, MAKE_TRANSPOSE: tl.constexpr, ): @@ -171,7 +175,6 @@ def _rmsnorm_fwd_triton_impl( tl.store(output_t_ptrs, rms_norm.to(output_type), mask=mask) tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) if IS_FP8: - tl.store(amax_ptr + row_start, 
amax) tl.atomic_max(q_amax_ptr, amax, sem="relaxed") if row_start == 0: scale = tl.load(q_scale_ptr) @@ -362,105 +365,3 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): BLOCK_SIZE_M=128, BLOCK_SIZE_N=64) return dx, dgamma - -# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd -def te_rmsnorm_fwd_triton( - input: torch.Tensor, - weight: torch.Tensor, - eps: float, - ln_out: torch.Tensor, - quantizer: Quantizer, - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - autotune: bool = True, -): - if eps < 0: - raise ValueError(f"`eps` must be non-negative, but a value of {eps} was passed") - if len(input.shape) != 2: - raise ValueError( - f"The input must be a 2-dimensional matrix, but an input with {input.ndim} was passed.") - - device = input.device - N, H = input.shape - if weight.shape[0] != H: - raise ValueError( - f"The shape of `weight` must be feature-aligned, " - f"but {weight.shape[0]=} while {input.shape[1]=}" - ) - IS_FP8 = isinstance(quantizer, Float8Quantizer) - IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) - BLOCK_SIZE = block_size(input) - USE_BLOCKED = use_blocked(input) - NUM_PRGMS = num_programs(input, sm_margin) - MAKE_TRANSPOSE = False - - rsigma = torch.empty((N,), dtype=torch.float32, device=device) - torch_out_dtype = ( - otype if isinstance(otype, torch.dtype) - else te_dtype_to_torch_dtype(otype) - ) - out = make_ln_out( - ln_out, - quantizer=quantizer, - input_shape=input.shape, - out_dtype=torch_out_dtype - ) - if IS_FP8: - MAKE_TRANSPOSE = quantizer.columnwise_usage - amax = torch.empty((NUM_PRGMS,), dtype=torch.float32, device=device) - tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) - scale_inv_ptr = out._scale_inv - q_scale = quantizer.scale - q_amax = quantizer.amax - out_ptr = triton.reinterpret(out._data, tl_dtype) - FP8_MAX = get_fp8_max(quantizer.dtype) - if MAKE_TRANSPOSE: - if out._transpose_invalid: - out._transpose = torch.empty((out._data.shape[1], out._data.shape[0]), dtype=out._data.dtype, device=device) - out._transpose_invalid = False - out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) - out_transpose_stride = out._transpose.stride(0) - else: - out_transpose_ptr = None - out_transpose_stride = None - else: - amax = None - tl_dtype = None - scale_inv_ptr = None - q_scale = None - q_amax = None - out_ptr = out - out_transpose_ptr = None - out_transpose_stride = None - FP8_MAX = None - - grid_fwd = lambda meta: (NUM_PRGMS, ) - # TODO(micky774) Implement fused MXFP8 quantization within the kernel - kernel = _rmsnorm_fwd_triton if autotune else _rmsnorm_fwd_triton_impl - kernel[grid_fwd]( - out_ptr, - input, - weight, - rsigma, - input.stride(0), - out_ptr.stride(0), - N, H, eps, - amax, - q_amax, - q_scale, - scale_inv_ptr, - out_transpose_ptr, - out_transpose_stride, - zero_centered_gamma, - BLOCK_SIZE, - USE_BLOCKED, - NUM_PRGMS, - IS_FP8, - FP8_MAX, - MAKE_TRANSPOSE, - ) - if IS_MXFP8: - out = quantizer.quantize(out, out=ln_out) - - return out, None, rsigma From 5b2ea1c3117eb968a8ac7d9568ee73bc571b65a4 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 5 Sep 2025 17:03:47 -0500 Subject: [PATCH 02/16] Minor API correction --- .../pytorch/triton_kernels/layernorm.py | 16 ------- .../pytorch/triton_kernels/norms.py | 30 ++++++------- .../pytorch/triton_kernels/rmsnorm.py | 45 +------------------ 3 files changed, 16 insertions(+), 75 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/layernorm.py 
b/transformer_engine/pytorch/triton_kernels/layernorm.py index cfdd5c335..8d9d9336f 100644 --- a/transformer_engine/pytorch/triton_kernels/layernorm.py +++ b/transformer_engine/pytorch/triton_kernels/layernorm.py @@ -3,25 +3,9 @@ from itertools import product -import os -import torch - -from ..tensor.float8_tensor import Float8Quantizer -from ..constants import TE_DType -from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor.quantized_tensor import Quantizer -from ..triton_kernels.cast import te_quantize_triton import triton import triton.language as tl -import warnings -import transformer_engine_torch as tex -from .common import ( - get_fp8_max, - te_dtype_to_torch_dtype, - te_dtype_to_triton_dtype, -) -from .norm_common import make_ln_out def get_autotune_config(full_tuning_space=False): if full_tuning_space: diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index ebfc5d741..443f2c0fb 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -51,7 +51,7 @@ def te_rmsnorm_fwd_triton( ): return te_norm_fwd_triton( kernel='rms', - input=input, + input_tensor=input, weight=weight, bias=None, eps=eps, @@ -78,7 +78,7 @@ def te_layernorm_fwd_triton( ): return te_norm_fwd_triton( kernel='layer', - input=input, + input_tensor=input, weight=weight, bias=bias, eps=eps, @@ -93,7 +93,7 @@ def te_layernorm_fwd_triton( # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd def te_norm_fwd_triton( kernel: str, - input: torch.Tensor, + input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, @@ -108,22 +108,22 @@ def te_norm_fwd_triton( raise ValueError(f"Expected `kernel` in ('rms', 'layer') but got {kernel=} instead.") if eps < 0: raise ValueError(f"`eps` must be non-negative, but a value of {eps} was passed") - if len(input.shape) != 2: + if len(input_tensor.shape) != 2: raise ValueError( - f"The input must be a 2-dimensional matrix, but an input with {input.ndim} was passed.") + f"The input must be a 2-dimensional matrix, but an input with {input_tensor.ndim} was passed.") - device = input.device - N, H = input.shape + device = input_tensor.device + N, H = input_tensor.shape if weight.shape[0] != H: raise ValueError( f"The shape of `weight` must be feature-aligned, " - f"but {weight.shape[0]=} while {input.shape[1]=}" + f"but {weight.shape[0]=} while {input_tensor.shape[1]=}" ) IS_FP8 = isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) - BLOCK_SIZE = block_size(input) - USE_BLOCKED = use_blocked(input) - NUM_PRGMS = num_programs(input, sm_margin) + BLOCK_SIZE = block_size(input_tensor) + USE_BLOCKED = use_blocked(input_tensor) + NUM_PRGMS = num_programs(input_tensor, sm_margin) MAKE_TRANSPOSE = False APPLY_ATOMIC = N < 512 and kernel == 'layer' ATOMIC_REDUCTION_BLOCK_SIZE=256 @@ -137,7 +137,7 @@ def te_norm_fwd_triton( out = make_ln_out( ln_out, quantizer=quantizer, - input_shape=input.shape, + input_shape=input_tensor.shape, out_dtype=torch_out_dtype ) amax = None @@ -172,13 +172,13 @@ def te_norm_fwd_triton( grid_fwd = lambda meta: (NUM_PRGMS, ) kernel = _norm_kernels[kernel][autotune] kernel[grid_fwd]( + input_tensor, out_ptr, - input, weight, - mu, bias, + mu, rsigma, - input.stride(0), + input_tensor.stride(0), out_ptr.stride(0), N, H, eps, amax, diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py 
b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 9133c0188..237561341 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -5,20 +5,6 @@ import triton import triton.language as tl from itertools import product -from .norm_common import num_programs, block_size, use_blocked, make_ln_out -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer -from transformer_engine.pytorch.triton_kernels.common import ( - te_dtype_to_torch_dtype, - te_dtype_to_triton_dtype, -) -from .common import get_fp8_max -from ..tensor.quantized_tensor import Quantizer -import transformer_engine_torch as tex - -def dg_tmp_rows(x, sm_margin=None): - return x.shape[0] if use_blocked(x) else num_programs(x, sm_margin) - def get_autotune_config(): return [triton.Config({'waves_per_eu': we}, num_warps=nw) for (we, nw) in product([0, 1, 2, 4], [4, 8, 16])] @@ -27,8 +13,8 @@ def get_autotune_config(): # TODO(micky774) Implement fused MXFP8 quantization within the kernel @triton.jit def _rmsnorm_fwd_triton_impl( - output_ptr, input_ptr, + output_ptr, g_ptr, bias_ptr, # Unused, for API purposes only mu_ptr, # Unused, for API purposes only @@ -336,32 +322,3 @@ def _rmsnorm_bwd_dg_reduce_triton(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n sum_dg = tl.sum(acc, axis=0) tl.store(dg_out_ptr + cols, sum_dg.to(dg_out_ptr.type.element_ty), mask=cols < n_cols) -# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd -def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): - # may take non-contiguous inputs - dz_ = dz.contiguous() - x_ = x.contiguous() - rsigma_ = rsigma.contiguous() - gamma_ = gamma.contiguous() - - dx = torch.empty_like(x_) - dgamma = torch.empty_like(gamma_) - - M, N = x_.shape - blk_size = block_size(x_) - USE_BLOCKED = use_blocked(x_) - NUM_PRGMS = num_programs(x_, sm_margin) - need_reduction = N > 1 - dg_tmp = torch.empty(dg_tmp_rows(x_, sm_margin), N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None - - grid_bwd = lambda meta: (NUM_PRGMS, ) - _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, - x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, - USE_BLOCKED, NUM_PRGMS, num_warps=8) - - if need_reduction: - grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] - _rmsnorm_bwd_dg_reduce_triton[grid_reduce](dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], - BLOCK_SIZE_M=128, BLOCK_SIZE_N=64) - - return dx, dgamma From 077b8fc0767f4c359c0cc5878f0762515625189e Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 5 Sep 2025 17:09:34 -0500 Subject: [PATCH 03/16] Corrected atomic behaivor --- transformer_engine/pytorch/triton_kernels/norms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index 443f2c0fb..69a86fefd 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -125,7 +125,7 @@ def te_norm_fwd_triton( USE_BLOCKED = use_blocked(input_tensor) NUM_PRGMS = num_programs(input_tensor, sm_margin) MAKE_TRANSPOSE = False - APPLY_ATOMIC = N < 512 and kernel == 'layer' + APPLY_ATOMIC = N < 512 or kernel == 'rms' ATOMIC_REDUCTION_BLOCK_SIZE=256 mu = torch.empty((N,), 
dtype=torch.float32, device=device) if kernel == 'layer' else None From 0011e5f3415db9c38c034fa4937612f7738793b7 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 5 Sep 2025 18:04:41 -0500 Subject: [PATCH 04/16] API update --- .../pytorch/triton_kernels/layernorm.py | 42 ++++++------ .../pytorch/triton_kernels/norms.py | 65 ++++++++++--------- .../pytorch/triton_kernels/rmsnorm.py | 11 ++-- 3 files changed, 61 insertions(+), 57 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/layernorm.py b/transformer_engine/pytorch/triton_kernels/layernorm.py index 8d9d9336f..27e279f40 100644 --- a/transformer_engine/pytorch/triton_kernels/layernorm.py +++ b/transformer_engine/pytorch/triton_kernels/layernorm.py @@ -20,20 +20,20 @@ def get_autotune_config(full_tuning_space=False): @triton.jit def _layernorm_fwd_triton_impl( - x_ptr, - y_ptr, - w_ptr, + input_ptr, + output_ptr, + g_ptr, b_ptr, mean_ptr, - rstd_ptr, - scale_ptr, - amax_ptr, - scale_inv_ptr, - x_row_stride, - y_row_stride, + rsigma_ptr, + input_row_stride, + output_row_stride, n_rows, n_cols, - eps, + epsilon, + q_amax_ptr, + q_scale_ptr, + scale_inv_ptr, out_transpose_ptr, out_transpose_stride, ZERO_CENTERED_GAMMA: tl.constexpr, @@ -65,12 +65,12 @@ def _layernorm_fwd_triton_impl( start_row = pid if IS_FP8: - scale = tl.load(scale_ptr) + scale = tl.load(q_scale_ptr) amax = 0.0 for row_idx in range(start_row, start_row + rows_per_tile): - x_ptr_start = x_ptr + (row_idx * x_row_stride) - y_ptr_start = y_ptr + (row_idx * y_row_stride) + x_ptr_start = input_ptr + (row_idx * input_row_stride) + y_ptr_start = output_ptr + (row_idx * output_row_stride) n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1 @@ -102,16 +102,16 @@ def _layernorm_fwd_triton_impl( _var += x_block * x_block var = tl.sum(_var, axis=0) / n_cols - rstd = tl.rsqrt(var + eps) + rstd = tl.rsqrt(var + epsilon) # Write mean / rstd tl.store(mean_ptr + row_idx, mean) - tl.store(rstd_ptr + row_idx, rstd) + tl.store(rsigma_ptr + row_idx, rstd) # Normalize and store for blk_idx in range(0, n_cols_blks): cols = blk_idx * BLOCK_SIZE + col_offsets - w_block = tl.load(w_ptr + cols).to(tl.float32) + w_block = tl.load(g_ptr + cols).to(tl.float32) b_block = tl.load(b_ptr + cols).to(tl.float32) x_block = tl.load(x_ptr_start + cols).to(tl.float32) if ZERO_CENTERED_GAMMA: @@ -123,7 +123,7 @@ def _layernorm_fwd_triton_impl( amax = amax_temp if amax_temp > amax else amax y_block = y_block * scale y_block = tl.clamp(y_block, -FP8_MAX, FP8_MAX) - y_block = y_block.to(y_ptr.type.element_ty) + y_block = y_block.to(output_ptr.type.element_ty) tl.store(y_ptr_start + cols, y_block) if MAKE_TRANSPOSE: output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx @@ -132,7 +132,7 @@ def _layernorm_fwd_triton_impl( # For last iteration, do masked load and store cols = n_cols_blks * BLOCK_SIZE + col_offsets mask = cols < n_cols - w_block = tl.load(w_ptr + cols, mask=mask, other=0.0).to(tl.float32) + w_block = tl.load(g_ptr + cols, mask=mask, other=0.0).to(tl.float32) b_block = tl.load(b_ptr + cols, mask=mask, other=0.0).to(tl.float32) x_block = tl.load(x_ptr_start + cols, mask=mask, other=0.0).to(tl.float32) if ZERO_CENTERED_GAMMA: @@ -144,7 +144,7 @@ def _layernorm_fwd_triton_impl( amax = amax_temp if amax_temp > amax else amax y_block = y_block * scale y_block = tl.clamp(y_block, -FP8_MAX, FP8_MAX) - y_block = y_block.to(y_ptr.type.element_ty) + y_block = y_block.to(output_ptr.type.element_ty) tl.store(y_ptr_start + cols, y_block, mask=mask) if MAKE_TRANSPOSE: output_t_ptrs = 
out_transpose_ptr + cols * out_transpose_stride + row_idx @@ -155,9 +155,9 @@ def _layernorm_fwd_triton_impl( scale_inv = tl.fdiv(1.0, scale) tl.store(scale_inv_ptr, scale_inv) if APPLY_ATOMIC: - tl.atomic_max(amax_ptr, amax, sem="relaxed") + tl.atomic_max(q_amax_ptr, amax, sem="relaxed") else: - tl.store(amax_ptr + pid, amax) + tl.store(q_amax_ptr + pid, amax) autotune_dec = triton.autotune(configs=get_autotune_config(), key=["n_rows", "n_cols"], use_cuda_graph=True) _layernorm_fwd_triton = autotune_dec(_layernorm_fwd_triton_impl) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index 69a86fefd..ce6ea0af7 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -151,8 +151,8 @@ def te_norm_fwd_triton( if IS_FP8: MAKE_TRANSPOSE = quantizer.columnwise_usage amax = ( + quantizer.amax if APPLY_ATOMIC else torch.empty((NUM_PRGMS,), dtype=torch.float32, device=device) - if APPLY_ATOMIC else quantizer.amax ) tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) scale_inv_ptr = out._scale_inv @@ -169,35 +169,40 @@ def te_norm_fwd_triton( out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) out_transpose_stride = out._transpose.stride(0) - grid_fwd = lambda meta: (NUM_PRGMS, ) - kernel = _norm_kernels[kernel][autotune] - kernel[grid_fwd]( - input_tensor, - out_ptr, - weight, - bias, - mu, - rsigma, - input_tensor.stride(0), - out_ptr.stride(0), - N, H, eps, - amax, - q_scale, - scale_inv_ptr, - out_transpose_ptr, - out_transpose_stride, - zero_centered_gamma, - BLOCK_SIZE, - USE_BLOCKED, - NUM_PRGMS, - IS_FP8, - APPLY_ATOMIC, - FP8_MAX, - MAKE_TRANSPOSE, + grid_fwd = lambda meta: (N if kernel=='layer' else NUM_PRGMS,) + kernel_func = _norm_kernels[kernel][autotune] + kwargs = dict( + input_ptr=input_tensor, + output_ptr=out_ptr, + g_ptr=weight, + rsigma_ptr=rsigma, + input_row_stride=input_tensor.stride(0), + output_row_stride=out_ptr.stride(0), + n_rows=N, n_cols=H, + epsilon=eps, + q_amax_ptr=amax, + q_scale_ptr=q_scale, + scale_inv_ptr=scale_inv_ptr, + out_transpose_ptr=out_transpose_ptr, + out_transpose_stride=out_transpose_stride, + ZERO_CENTERED_GAMMA=zero_centered_gamma, + BLOCK_SIZE=BLOCK_SIZE, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + MAKE_TRANSPOSE=MAKE_TRANSPOSE, + ) + if kernel == 'layer': + kwargs["APPLY_ATOMIC"]=APPLY_ATOMIC + kwargs["PERSISTENT"]=False # TODO: Improve persistent algo performance + kwargs["b_ptr"]=bias + kwargs["mean_ptr"]=mu + elif kernel == "rms": + kwargs["USE_BLOCKED"]=USE_BLOCKED + kwargs["NUM_PRGMS"]=NUM_PRGMS + + kernel_func[grid_fwd]( + **kwargs, ) - if IS_MXFP8: - out = quantizer.quantize(out, out=ln_out) - # Reduce and find amax if "not APPLY_ATOMIC" is True for layernorm. 
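# A minimal PyTorch sketch (not part of the patch) of what this reduction step
# computes: with APPLY_ATOMIC each Triton program atomically maxes its local amax
# into quantizer.amax, otherwise every program writes a partial amax and
# _layernorm_fwd_reduce_triton folds the partials into the single global amax.
# The tensors below are hypothetical stand-ins for the kernel buffers.
import torch

num_programs = 8
partial_amax = torch.rand(num_programs)          # per-program |x|.max() values
global_amax = torch.zeros(1)                     # quantizer.amax equivalent

# Atomic path: each program updates the global value directly.
for p in range(num_programs):                    # models tl.atomic_max(..., sem="relaxed")
    global_amax = torch.maximum(global_amax, partial_amax[p:p + 1])

# Two-stage path: store partials, then reduce once at the end.
global_amax_two_stage = partial_amax.max().reshape(1)
assert torch.equal(global_amax, global_amax_two_stage)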
if IS_FP8 and not APPLY_ATOMIC: _layernorm_fwd_reduce_triton[(triton.cdiv(N, ATOMIC_REDUCTION_BLOCK_SIZE),)]( @@ -205,6 +210,8 @@ def te_norm_fwd_triton( quantizer.amax, N, ATOMIC_REDUCTION_BLOCK_SIZE, ) + elif IS_MXFP8: + out = quantizer.quantize(out, out=ln_out) return out, mu, rsigma diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 237561341..0394ccf73 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -16,8 +16,6 @@ def _rmsnorm_fwd_triton_impl( input_ptr, output_ptr, g_ptr, - bias_ptr, # Unused, for API purposes only - mu_ptr, # Unused, for API purposes only rsigma_ptr, input_row_stride, output_row_stride, @@ -27,13 +25,12 @@ def _rmsnorm_fwd_triton_impl( q_scale_ptr, scale_inv_ptr, out_transpose_ptr, - transpose_row_stride, + out_transpose_stride, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, IS_FP8: tl.constexpr, - APPLY_ATOMIC: tl.constexpr, # Unused, for API purposes only FP8_MAX: tl.constexpr, MAKE_TRANSPOSE: tl.constexpr, ): @@ -105,7 +102,7 @@ def _rmsnorm_fwd_triton_impl( rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + cols * transpose_row_stride + row_idx + output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx tl.store(output_t_ptrs, rms_norm.to(output_type)) tl.store(output_ptrs, rms_norm.to(output_type)) @@ -126,7 +123,7 @@ def _rmsnorm_fwd_triton_impl( rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + cols * transpose_row_stride + row_idx + output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx tl.store(output_t_ptrs, rms_norm.to(output_type), mask=mask) tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) @@ -157,7 +154,7 @@ def _rmsnorm_fwd_triton_impl( rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + col_offsets * transpose_row_stride + row_idx + output_t_ptrs = out_transpose_ptr + col_offsets * out_transpose_stride + row_idx tl.store(output_t_ptrs, rms_norm.to(output_type), mask=mask) tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) if IS_FP8: From 26298c8eda594f51a8bb077b6e795f6fdd0870ce Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 8 Sep 2025 09:48:27 -0500 Subject: [PATCH 05/16] Formatting --- .../pytorch/triton_kernels/norms.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index ce6ea0af7..efb9e5cbf 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -189,7 +189,7 @@ def te_norm_fwd_triton( BLOCK_SIZE=BLOCK_SIZE, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, - MAKE_TRANSPOSE=MAKE_TRANSPOSE, + MAKE_TRANSPOSE=MAKE_TRANSPOSE, ) if kernel == 'layer': kwargs["APPLY_ATOMIC"]=APPLY_ATOMIC @@ -199,10 +199,9 @@ def te_norm_fwd_triton( elif kernel == "rms": kwargs["USE_BLOCKED"]=USE_BLOCKED kwargs["NUM_PRGMS"]=NUM_PRGMS - - kernel_func[grid_fwd]( - **kwargs, - ) + + kernel_func[grid_fwd](**kwargs) + # Reduce and find amax if "not APPLY_ATOMIC" is True for layernorm. 
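# A minimal sketch (not part of the patch) of the APPLY_ATOMIC policy after
# PATCH 03: the RMSNorm kernel always folds amax atomically, while LayerNorm
# does so only for small row counts and otherwise defers to the reduction
# kernel launched below.
def apply_atomic(n_rows: int, kernel: str) -> bool:
    return n_rows < 512 or kernel == "rms"

assert apply_atomic(64, "layer") and apply_atomic(100_000, "rms")
assert not apply_atomic(100_000, "layer")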
if IS_FP8 and not APPLY_ATOMIC: _layernorm_fwd_reduce_triton[(triton.cdiv(N, ATOMIC_REDUCTION_BLOCK_SIZE),)]( @@ -250,12 +249,12 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): # drop in replacement for transformer_engine::pytorch::layernorm_bwd # TODO: Add support for `sm_margin > 0`. def te_layernorm_bwd_triton( - dz: torch.Tensor, - x: torch.Tensor, - mu: torch.Tensor, - rsigma: torch.Tensor, - gamma: torch.Tensor, - sm_margin: int, + dz: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, zero_centered_gamma: bool ): if sm_margin is not None and sm_margin > 0: From 18eb6e77451f2d0d52b2f55162be8cac05710434 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 28 Jan 2026 12:37:05 -0600 Subject: [PATCH 06/16] Added skip for failing HIP kernels --- tests/pytorch/triton_kernels/test_norms.py | 43 +++++++++++++--------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/tests/pytorch/triton_kernels/test_norms.py b/tests/pytorch/triton_kernels/test_norms.py index 039532bd7..07874a841 100644 --- a/tests/pytorch/triton_kernels/test_norms.py +++ b/tests/pytorch/triton_kernels/test_norms.py @@ -25,6 +25,13 @@ ) from test_common import dtype_tols, te_compare_results, str_to_torch_dtype, fill_uniform +def _compare_func(actual, expected, atol, rtol, msg, use_torch_semantics=False): + try: + te_compare_results(actual, expected, atol, rtol, msg, use_torch_semantics) + except AssertionError as e: + if "Tensor 'expected' has" in str(e) and "NaN" in str(e): + pytest.skip("HIP reference tensor contains NaN values.") + # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -404,11 +411,11 @@ def _compare_output_tensors( quantization, fp8_dtype ): tols = dtype_tols(out_triton.dtype if quantization is None else fp8_dtype) - _compare_func = partial(te_compare_results, **tols, use_torch_semantics=True) + compare_func = partial(_compare_func, **tols, use_torch_semantics=True) dq_out_triton = out_triton.dequantize() dq_out_hip = out_hip.dequantize() - _compare_func( + compare_func( actual=dq_out_triton, expected=dq_out_hip, msg=lambda msg: f"Output does not match triton <-> hip\n\n{msg}\n", @@ -426,7 +433,7 @@ def _compare_output_tensors( if not out_hip._transpose_invalid: # The transpose data are generally uint8 so we must convert # them for floating point comparison. - _compare_func( + compare_func( actual=out_triton._transpose.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)).to(torch.float32), expected=out_hip._transpose.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)).to(torch.float32), msg=lambda msg: f"Output transpose does not match triton <-> hip\n\n{msg}\n", @@ -440,7 +447,7 @@ def _compare_output_tensors( # trick to MXFP8 data as we do to FP8 transpose data. # I suspect not. if out_hip._rowwise_data is not None: - _compare_func( + compare_func( actual=out_triton, expected=out_hip, msg=lambda msg: f"Output rowwise data does not match triton <-> hip\n\n{msg}\n", @@ -450,9 +457,9 @@ def _compare_output_tensors( assert out_triton._rowwise_data is None, "Expected no rowwise data." 
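# A minimal sketch (not part of the patch) of the per-tensor FP8 bookkeeping that
# the scale comparisons below exercise: the forward kernel scales and clamps the
# normalized output, stores scale_inv = 1 / scale (tl.fdiv in the kernel), and the
# tests then compare scale_inv and the dequantized data against the HIP reference.
# The scale-update rule and the concrete values here are hypothetical, for
# illustration only.
import torch

fp8_max = 448.0                                   # E4M3 max, as used via get_fp8_max
amax = torch.tensor(3.2)                          # hypothetical running amax
scale = fp8_max / amax                            # hypothetical delayed-scaling update
scale_inv = 1.0 / scale                           # what the kernel writes to _scale_inv

x = torch.randn(4, 8)
q = torch.clamp(x * scale, -fp8_max, fp8_max)     # scaled/clamped output, pre-cast
x_roundtrip = q * scale_inv                       # dequantize step used by the tests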
# We use higher precision for the scales - _compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) + compare_func = partial(_compare_func, atol=1e-6, rtol=5e-5, use_torch_semantics=True) if quantization == "fp8": - _compare_func( + compare_func( actual=out_triton._scale_inv, expected=out_hip._scale_inv, msg=lambda msg: f"Output scale inverse does not match triton <-> hip\n\n{msg}\n", @@ -467,7 +474,7 @@ def _compare_output_tensors( msg += "be None." raise ValueError(msg) if has_rscale_triton: - _compare_func( + compare_func( actual=out_triton._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), expected=out_hip._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), msg=lambda msg: f"Output rowwise scale inverse does not match triton <-> hip\n\n{msg}\n", @@ -482,7 +489,7 @@ def _compare_output_tensors( msg += "be None." raise ValueError(msg) if has_cscale_triton: - _compare_func( + compare_func( actual=out_triton._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), expected=out_hip._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), msg=lambda msg: f"Output columnwise scale inverse does not match triton <-> hip\n\n{msg}\n", @@ -495,7 +502,7 @@ def _compare_quantizers( quantization ): if quantization is None: return - _compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) + compare_func = partial(_compare_func, atol=1e-6, rtol=5e-5, use_torch_semantics=True) if quantizer_triton.dtype != quantizer_hip.dtype: raise ValueError("Expected matching quantizer dtypes, but got " @@ -509,12 +516,12 @@ def _compare_quantizers( raise ValueError(f"Expected matching quantizer {usage} but got {qt_usage=} != {qh_usage=}") if quantization == "fp8": - _compare_func( + compare_func( actual=quantizer_triton.scale, expected=quantizer_hip.scale, msg=lambda msg: f"Quantizer scale does not match triton <-> hip\n\n{msg}\n", ) - _compare_func( + compare_func( actual=quantizer_triton.amax, expected=quantizer_hip.amax, msg=lambda msg: f"Quantizer amax does not match triton <-> hip\n\n{msg}\n", @@ -527,15 +534,15 @@ def _compare_stat_tensors( norm ): # We use higher precision for the remaining outputs - _compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) + compare_func = partial(_compare_func, atol=1e-6, rtol=5e-5, use_torch_semantics=True) - _compare_func( + compare_func( actual=rsigma_triton, expected=rsigma_hip, msg=lambda msg: f"rsigma does not match triton <-> hip\n\n{msg}\n", ) if norm == "layer": - _compare_func( + compare_func( actual=mu_triton, expected=mu_hip, msg=lambda msg: f"mu does not match triton <-> hip\n\n{msg}\n", @@ -577,20 +584,20 @@ def _compare_bwd_tensors( dbeta_triton, dbeta_hip, norm ): - _compare_func = partial(te_compare_results, atol=1.5e-4, rtol=1e-4, use_torch_semantics=True) + compare_func = partial(_compare_func, atol=1.5e-4, rtol=1e-4, use_torch_semantics=True) - _compare_func( + compare_func( actual=dx_triton, expected=dx_hip, msg=lambda msg: f"dx does not match triton <-> hip\n\n{msg}\n", ) - _compare_func( + compare_func( actual=dgamma_triton, expected=dgamma_hip, msg=lambda msg: f"dgamma does not match triton <-> hip\n\n{msg}\n", ) if norm == "layer": - _compare_func( + compare_func( actual=dbeta_triton, expected=dbeta_hip, msg=lambda msg: f"dbeta does not match triton <-> hip\n\n{msg}\n", From a72b507eea56050d301de88770e3c09ca24fb80c Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 
28 Jan 2026 13:23:18 -0600 Subject: [PATCH 07/16] Updated to account for alignment args --- .../pytorch/triton_kernels/norms.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index efb9e5cbf..b3bde6423 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -48,6 +48,8 @@ def te_rmsnorm_fwd_triton( sm_margin: int, zero_centered_gamma: bool, autotune: bool = True, + INPUT_ALIGNED_16: bool = False, + OUTPUT_ALIGNED_16: bool = False, ): return te_norm_fwd_triton( kernel='rms', @@ -61,6 +63,8 @@ def te_rmsnorm_fwd_triton( sm_margin=sm_margin, zero_centered_gamma=zero_centered_gamma, autotune=autotune, + INPUT_ALIGNED_16=INPUT_ALIGNED_16, + OUTPUT_ALIGNED_16=OUTPUT_ALIGNED_16, ) # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd @@ -103,6 +107,8 @@ def te_norm_fwd_triton( sm_margin: int, zero_centered_gamma: bool, autotune: bool = True, + INPUT_ALIGNED_16: bool = False, + OUTPUT_ALIGNED_16: bool = False, ): if kernel not in {'rms', 'layer'}: raise ValueError(f"Expected `kernel` in ('rms', 'layer') but got {kernel=} instead.") @@ -199,6 +205,8 @@ def te_norm_fwd_triton( elif kernel == "rms": kwargs["USE_BLOCKED"]=USE_BLOCKED kwargs["NUM_PRGMS"]=NUM_PRGMS + kwargs["INPUT_ALIGNED_16"]=INPUT_ALIGNED_16 + kwargs["OUTPUT_ALIGNED_16"]=OUTPUT_ALIGNED_16 kernel_func[grid_fwd](**kwargs) @@ -216,7 +224,16 @@ def te_norm_fwd_triton( # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd -def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): +def te_rmsnorm_bwd_triton( + dz, x, + rsigma, gamma, + sm_margin, + zero_centered_gamma, + INPUT_ALIGNED_16=False, + GRAD_OUTPUT_ALIGNED_16=False, + DX_ALIGNED_16=False, + DG_ALIGNED_16=False, +): # may take non-contiguous inputs dz_ = dz.contiguous() x_ = x.contiguous() @@ -237,7 +254,8 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): grid_bwd = lambda meta: (NUM_PRGMS, ) _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, - USE_BLOCKED, NUM_PRGMS, num_warps=8) + USE_BLOCKED, NUM_PRGMS, INPUT_ALIGNED_16, GRAD_OUTPUT_ALIGNED_16, + DX_ALIGNED_16, DG_ALIGNED_16, num_warps=8) if need_reduction: grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] From a64b5f1bd0ef6d9019717cc9ddc69b217750f208 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 28 Jan 2026 13:43:03 -0600 Subject: [PATCH 08/16] Updated CI script for MI350 runs, minor code cleaning --- ci/pytorch.sh | 4 ++-- transformer_engine/pytorch/triton_kernels/norms.py | 9 ++++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index be150485f..dd340d88d 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -71,8 +71,8 @@ run_test_config(){ run_default_fa 1 triton_kernels/test_cast.py run_default_fa 1 triton_kernels/test_cast_mxfp8.py run_default_fa 1 triton_kernels/test_norm_common.py - run_default_fa 1 triton_kernels/test_norms.py - NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py + NVTE_ROCM_ENABLE_MXFP8=1 run_default_fa 1 triton_kernels/test_norms.py + NVTE_ROCM_ENABLE_MXFP8=1 NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py run_default_fa 1 
test_parallel_cross_entropy.py NVTE_USE_DEQUANTIZE_TRITON=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 NVTE_USE_LAYERNORM_TRITON=1 run_default_fa_lbl "triton" 3 test_numerics.py NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index b3bde6423..be7f907d8 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -51,7 +51,7 @@ def te_rmsnorm_fwd_triton( INPUT_ALIGNED_16: bool = False, OUTPUT_ALIGNED_16: bool = False, ): - return te_norm_fwd_triton( + return _te_norm_fwd_triton( kernel='rms', input_tensor=input, weight=weight, @@ -67,7 +67,7 @@ def te_rmsnorm_fwd_triton( OUTPUT_ALIGNED_16=OUTPUT_ALIGNED_16, ) -# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd +# triton drop-in replacement for transformer_engine::pytorch::layernorm_fwd def te_layernorm_fwd_triton( input: torch.Tensor, weight: torch.Tensor, @@ -80,7 +80,7 @@ def te_layernorm_fwd_triton( zero_centered_gamma: bool, autotune: bool = True, ): - return te_norm_fwd_triton( + return _te_norm_fwd_triton( kernel='layer', input_tensor=input, weight=weight, @@ -94,8 +94,7 @@ def te_layernorm_fwd_triton( autotune=autotune, ) -# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd -def te_norm_fwd_triton( +def _te_norm_fwd_triton( kernel: str, input_tensor: torch.Tensor, weight: torch.Tensor, From 1d2554caf964a3c73bfdbab4df6152f60102976c Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 28 Jan 2026 13:50:30 -0600 Subject: [PATCH 09/16] Streamlined implementation --- .../pytorch/triton_kernels/norms.py | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index be7f907d8..79c570e0a 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -48,8 +48,6 @@ def te_rmsnorm_fwd_triton( sm_margin: int, zero_centered_gamma: bool, autotune: bool = True, - INPUT_ALIGNED_16: bool = False, - OUTPUT_ALIGNED_16: bool = False, ): return _te_norm_fwd_triton( kernel='rms', @@ -63,8 +61,6 @@ def te_rmsnorm_fwd_triton( sm_margin=sm_margin, zero_centered_gamma=zero_centered_gamma, autotune=autotune, - INPUT_ALIGNED_16=INPUT_ALIGNED_16, - OUTPUT_ALIGNED_16=OUTPUT_ALIGNED_16, ) # triton drop-in replacement for transformer_engine::pytorch::layernorm_fwd @@ -106,8 +102,6 @@ def _te_norm_fwd_triton( sm_margin: int, zero_centered_gamma: bool, autotune: bool = True, - INPUT_ALIGNED_16: bool = False, - OUTPUT_ALIGNED_16: bool = False, ): if kernel not in {'rms', 'layer'}: raise ValueError(f"Expected `kernel` in ('rms', 'layer') but got {kernel=} instead.") @@ -204,8 +198,8 @@ def _te_norm_fwd_triton( elif kernel == "rms": kwargs["USE_BLOCKED"]=USE_BLOCKED kwargs["NUM_PRGMS"]=NUM_PRGMS - kwargs["INPUT_ALIGNED_16"]=INPUT_ALIGNED_16 - kwargs["OUTPUT_ALIGNED_16"]=OUTPUT_ALIGNED_16 + kwargs["INPUT_ALIGNED_16"]=(input_tensor.data_ptr() % 16 == 0) and (input_tensor.stride(-1) % 16 == 0) + kwargs["OUTPUT_ALIGNED_16"]=(out_ptr.data_ptr() % 16 == 0) and (out_ptr.stride(-1) % 16 == 0) kernel_func[grid_fwd](**kwargs) @@ -223,16 +217,7 @@ def _te_norm_fwd_triton( # triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd -def te_rmsnorm_bwd_triton( - dz, x, - 
rsigma, gamma, - sm_margin, - zero_centered_gamma, - INPUT_ALIGNED_16=False, - GRAD_OUTPUT_ALIGNED_16=False, - DX_ALIGNED_16=False, - DG_ALIGNED_16=False, -): +def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): # may take non-contiguous inputs dz_ = dz.contiguous() x_ = x.contiguous() @@ -250,11 +235,17 @@ def te_rmsnorm_bwd_triton( dg_tmp_rows = x_.shape[0] if use_blocked(x_) else num_programs(x_, sm_margin) dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None + input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(-1) % 16 == 0) + grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(-1) % 16 == 0) + dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(-1) % 16 == 0) + dg_target = dg_tmp if need_reduction else dgamma + dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(-1) % 16 == 0) + grid_bwd = lambda meta: (NUM_PRGMS, ) _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, - USE_BLOCKED, NUM_PRGMS, INPUT_ALIGNED_16, GRAD_OUTPUT_ALIGNED_16, - DX_ALIGNED_16, DG_ALIGNED_16, num_warps=8) + USE_BLOCKED, NUM_PRGMS, input_aligned_16, grad_output_aligned_16, + dx_aligned_16, dg_aligned_16, num_warps=8) if need_reduction: grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] From fd59057e1d1f0c875ae9474cef3c1c5a76f94eb1 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 28 Jan 2026 14:10:30 -0600 Subject: [PATCH 10/16] Corrected alignment calculation --- .../pytorch/triton_kernels/norms.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index 79c570e0a..62ba92594 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -170,13 +170,15 @@ def _te_norm_fwd_triton( grid_fwd = lambda meta: (N if kernel=='layer' else NUM_PRGMS,) kernel_func = _norm_kernels[kernel][autotune] + input_row_stride = input_tensor.stride(0) + output_row_stride=out_ptr.stride(0) kwargs = dict( input_ptr=input_tensor, output_ptr=out_ptr, g_ptr=weight, rsigma_ptr=rsigma, - input_row_stride=input_tensor.stride(0), - output_row_stride=out_ptr.stride(0), + input_row_stride=input_row_stride, + output_row_stride=output_row_stride, n_rows=N, n_cols=H, epsilon=eps, q_amax_ptr=amax, @@ -198,8 +200,8 @@ def _te_norm_fwd_triton( elif kernel == "rms": kwargs["USE_BLOCKED"]=USE_BLOCKED kwargs["NUM_PRGMS"]=NUM_PRGMS - kwargs["INPUT_ALIGNED_16"]=(input_tensor.data_ptr() % 16 == 0) and (input_tensor.stride(-1) % 16 == 0) - kwargs["OUTPUT_ALIGNED_16"]=(out_ptr.data_ptr() % 16 == 0) and (out_ptr.stride(-1) % 16 == 0) + kwargs["INPUT_ALIGNED_16"]=(input_tensor.data_ptr() % 16 == 0) and (input_row_stride % 16 == 0) + kwargs["OUTPUT_ALIGNED_16"]=(out_ptr.data_ptr() % 16 == 0) and (output_row_stride % 16 == 0) kernel_func[grid_fwd](**kwargs) @@ -235,12 +237,11 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): dg_tmp_rows = x_.shape[0] if use_blocked(x_) else num_programs(x_, sm_margin) dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None - input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(-1) % 16 == 0) - grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and 
(dz_.stride(-1) % 16 == 0) - dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(-1) % 16 == 0) + input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) % 16 == 0) + grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(0) % 16 == 0) + dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(0) % 16 == 0) dg_target = dg_tmp if need_reduction else dgamma - dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(-1) % 16 == 0) - + dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) % 16 == 0) grid_bwd = lambda meta: (NUM_PRGMS, ) _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, From bbd4240c008dea2c41ce77cbe0e02fd8b63b0637 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 28 Jan 2026 14:14:11 -0600 Subject: [PATCH 11/16] Add copyright --- transformer_engine/pytorch/triton_kernels/norms.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index 62ba92594..034fdd957 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -1,3 +1,6 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# License for AMD contributions = MIT. See LICENSE for more information + import torch import triton import warnings From 92aecf2791d9836ff519dbff26e4efdf25137d83 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 28 Jan 2026 14:57:52 -0600 Subject: [PATCH 12/16] Updated alignment calculation --- transformer_engine/pytorch/triton_kernels/norms.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index 034fdd957..535b8e105 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -203,8 +203,8 @@ def _te_norm_fwd_triton( elif kernel == "rms": kwargs["USE_BLOCKED"]=USE_BLOCKED kwargs["NUM_PRGMS"]=NUM_PRGMS - kwargs["INPUT_ALIGNED_16"]=(input_tensor.data_ptr() % 16 == 0) and (input_row_stride % 16 == 0) - kwargs["OUTPUT_ALIGNED_16"]=(out_ptr.data_ptr() % 16 == 0) and (output_row_stride % 16 == 0) + kwargs["INPUT_ALIGNED_16"]=(input_tensor.data_ptr() % 16 == 0) and (input_row_stride * input_tensor.dtype.itemsize % 16 == 0) + kwargs["OUTPUT_ALIGNED_16"]=(out_ptr.data_ptr() % 16 == 0) and (output_row_stride * out_ptr.dtype.itemsize % 16 == 0) kernel_func[grid_fwd](**kwargs) @@ -240,11 +240,11 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): dg_tmp_rows = x_.shape[0] if use_blocked(x_) else num_programs(x_, sm_margin) dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None - input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) % 16 == 0) - grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(0) % 16 == 0) - dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(0) % 16 == 0) + input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) * x_.dtype.itemsize % 16 == 0) + grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(0) * dz_.dtype.itemsize % 16 == 0) + dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(0) * dx.dtype.itemsize % 16 == 0) dg_target = dg_tmp if need_reduction else dgamma - 
dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) % 16 == 0) + dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) * dg_target.dtype.itemsize % 16 == 0) grid_bwd = lambda meta: (NUM_PRGMS, ) _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, From 12bb15609b96e6723b200cb355519d32a4b628e8 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 29 Jan 2026 10:45:26 -0600 Subject: [PATCH 13/16] Corrected FP8_CS handling --- transformer_engine/pytorch/triton_kernels/norms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index 535b8e105..9f9f234a7 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -121,8 +121,9 @@ def _te_norm_fwd_triton( f"The shape of `weight` must be feature-aligned, " f"but {weight.shape[0]=} while {input_tensor.shape[1]=}" ) - IS_FP8 = isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) + IS_FP8 = isinstance(quantizer, Float8Quantizer) IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) + IS_FP8_CURRENT_SCALING = isinstance(quantizer, Float8CurrentScalingQuantizer) BLOCK_SIZE = block_size(input_tensor) USE_BLOCKED = use_blocked(input_tensor) NUM_PRGMS = num_programs(input_tensor, sm_margin) @@ -215,7 +216,7 @@ def _te_norm_fwd_triton( quantizer.amax, N, ATOMIC_REDUCTION_BLOCK_SIZE, ) - elif IS_MXFP8: + elif IS_MXFP8 or IS_FP8_CURRENT_SCALING: out = quantizer.quantize(out, out=ln_out) return out, mu, rsigma From 19250390550ff69eb5b07db3ab27e05686cea03d Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 30 Jan 2026 13:38:00 -0600 Subject: [PATCH 14/16] Corrected layernorm memory access bug --- tests/pytorch/triton_kernels/test_norms.py | 11 ++++------- transformer_engine/pytorch/triton_kernels/norms.py | 12 +++++++++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/pytorch/triton_kernels/test_norms.py b/tests/pytorch/triton_kernels/test_norms.py index 07874a841..5dee69382 100644 --- a/tests/pytorch/triton_kernels/test_norms.py +++ b/tests/pytorch/triton_kernels/test_norms.py @@ -443,9 +443,6 @@ def _compare_output_tensors( if not isinstance(out_triton, MXFP8Tensor): raise ValueError(f"Expected a MXFP8Tensor but got {type(out_triton)} instead.") - # TODO(micky774): Figure out if we need to apply the same view - # trick to MXFP8 data as we do to FP8 transpose data. - # I suspect not. 
if out_hip._rowwise_data is not None: compare_func( actual=out_triton, @@ -475,8 +472,8 @@ def _compare_output_tensors( raise ValueError(msg) if has_rscale_triton: compare_func( - actual=out_triton._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), - expected=out_hip._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), + actual=out_triton._rowwise_scale_inv.view(torch.uint8), + expected=out_hip._rowwise_scale_inv.view(torch.uint8), msg=lambda msg: f"Output rowwise scale inverse does not match triton <-> hip\n\n{msg}\n", ) @@ -490,8 +487,8 @@ def _compare_output_tensors( raise ValueError(msg) if has_cscale_triton: compare_func( - actual=out_triton._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), - expected=out_hip._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), + actual=out_triton._columnwise_scale_inv.view(torch.uint8), + expected=out_hip._columnwise_scale_inv.view(torch.uint8), msg=lambda msg: f"Output columnwise scale inverse does not match triton <-> hip\n\n{msg}\n", ) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index 9f9f234a7..3421b365f 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -155,7 +155,7 @@ def _te_norm_fwd_triton( MAKE_TRANSPOSE = quantizer.columnwise_usage amax = ( quantizer.amax if APPLY_ATOMIC else - torch.empty((NUM_PRGMS,), dtype=torch.float32, device=device) + torch.empty((N,), dtype=torch.float32, device=device) ) tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) scale_inv_ptr = out._scale_inv @@ -204,8 +204,14 @@ def _te_norm_fwd_triton( elif kernel == "rms": kwargs["USE_BLOCKED"]=USE_BLOCKED kwargs["NUM_PRGMS"]=NUM_PRGMS - kwargs["INPUT_ALIGNED_16"]=(input_tensor.data_ptr() % 16 == 0) and (input_row_stride * input_tensor.dtype.itemsize % 16 == 0) - kwargs["OUTPUT_ALIGNED_16"]=(out_ptr.data_ptr() % 16 == 0) and (output_row_stride * out_ptr.dtype.itemsize % 16 == 0) + kwargs["INPUT_ALIGNED_16"]=( + input_tensor.data_ptr() % 16 == 0 and + input_row_stride * getattr(input_tensor.dtype, 'itemsize', 1) % 16 == 0 + ) + kwargs["OUTPUT_ALIGNED_16"]=( + out_ptr.data_ptr() % 16 == 0 and + output_row_stride * getattr(out_ptr.dtype, 'itemsize', 1) % 16 == 0 + ) kernel_func[grid_fwd](**kwargs) From 6f9b6c5f066e4eae053bf157c4898f60d685d579 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 30 Jan 2026 14:34:22 -0600 Subject: [PATCH 15/16] Corrected amax dims --- transformer_engine/pytorch/triton_kernels/norms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index 3421b365f..333ef1e9a 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -126,7 +126,7 @@ def _te_norm_fwd_triton( IS_FP8_CURRENT_SCALING = isinstance(quantizer, Float8CurrentScalingQuantizer) BLOCK_SIZE = block_size(input_tensor) USE_BLOCKED = use_blocked(input_tensor) - NUM_PRGMS = num_programs(input_tensor, sm_margin) + NUM_PRGMS = N if kernel=='layer' else num_programs(input_tensor, sm_margin) MAKE_TRANSPOSE = False APPLY_ATOMIC = N < 512 or kernel == 'rms' ATOMIC_REDUCTION_BLOCK_SIZE=256 @@ -172,7 +172,7 @@ def _te_norm_fwd_triton( out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) out_transpose_stride = out._transpose.stride(0) - grid_fwd = lambda meta: (N 
if kernel=='layer' else NUM_PRGMS,) + grid_fwd = lambda meta: (NUM_PRGMS,) kernel_func = _norm_kernels[kernel][autotune] input_row_stride = input_tensor.stride(0) output_row_stride=out_ptr.stride(0) From dc3ed87c817865ef180053f8735cceb292518bac Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 30 Jan 2026 14:35:59 -0600 Subject: [PATCH 16/16] Adjusted amax init --- transformer_engine/pytorch/triton_kernels/norms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py index 333ef1e9a..eee8a6e1d 100644 --- a/transformer_engine/pytorch/triton_kernels/norms.py +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -155,7 +155,7 @@ def _te_norm_fwd_triton( MAKE_TRANSPOSE = quantizer.columnwise_usage amax = ( quantizer.amax if APPLY_ATOMIC else - torch.empty((N,), dtype=torch.float32, device=device) + torch.empty((NUM_PRGMS,), dtype=torch.float32, device=device) ) tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) scale_inv_ptr = out._scale_inv
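
The *_ALIGNED_16 flags threaded through patches 07-12 reduce to a single predicate: a tensor's rows can be treated as 16-byte aligned when both the base pointer and the row stride, measured in bytes, are multiples of 16 (patch 12 is what converts the stride from element units to bytes via dtype.itemsize). A minimal sketch of that check, assuming a hypothetical helper name and using torch.Tensor.element_size() in place of dtype.itemsize:

    import torch

    # Hypothetical helper, for illustration only; not part of the patches above.
    def row_aligned_16(t: torch.Tensor) -> bool:
        # stride(0) counts elements, so convert to bytes before testing alignment.
        row_stride_bytes = t.stride(0) * t.element_size()
        return (t.data_ptr() % 16 == 0) and (row_stride_bytes % 16 == 0)

    # A contiguous bf16 tensor with 1024 columns has a 2048-byte row stride, so the
    # check reduces to whether the allocation's base address is 16-byte aligned.
    x = torch.empty(8, 1024, dtype=torch.bfloat16)
    print(row_aligned_16(x))

Passing the resulting booleans to the Triton kernels as kernel arguments presumably lets them specialize and use vectorized 16-byte accesses only when doing so is safe.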