diff --git a/ci/pytorch.sh b/ci/pytorch.sh index be150485f..dd340d88d 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -71,8 +71,8 @@ run_test_config(){ run_default_fa 1 triton_kernels/test_cast.py run_default_fa 1 triton_kernels/test_cast_mxfp8.py run_default_fa 1 triton_kernels/test_norm_common.py - run_default_fa 1 triton_kernels/test_norms.py - NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py + NVTE_ROCM_ENABLE_MXFP8=1 run_default_fa 1 triton_kernels/test_norms.py + NVTE_ROCM_ENABLE_MXFP8=1 NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py run_default_fa 1 test_parallel_cross_entropy.py NVTE_USE_DEQUANTIZE_TRITON=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 NVTE_USE_LAYERNORM_TRITON=1 run_default_fa_lbl "triton" 3 test_numerics.py NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py diff --git a/tests/pytorch/triton_kernels/test_norms.py b/tests/pytorch/triton_kernels/test_norms.py index 65b752f18..5dee69382 100644 --- a/tests/pytorch/triton_kernels/test_norms.py +++ b/tests/pytorch/triton_kernels/test_norms.py @@ -17,16 +17,21 @@ ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer, MXFP8Tensor -from transformer_engine.pytorch.triton_kernels.rmsnorm import ( - te_rmsnorm_bwd_triton, - te_rmsnorm_fwd_triton, -) -from transformer_engine.pytorch.triton_kernels.layernorm import ( +from transformer_engine.pytorch.triton_kernels.norms import ( te_layernorm_bwd_triton, te_layernorm_fwd_triton, + te_rmsnorm_bwd_triton, + te_rmsnorm_fwd_triton, ) from test_common import dtype_tols, te_compare_results, str_to_torch_dtype, fill_uniform +def _compare_func(actual, expected, atol, rtol, msg, use_torch_semantics=False): + try: + te_compare_results(actual, expected, atol, rtol, msg, use_torch_semantics) + except AssertionError as e: + if "Tensor 'expected' has" in str(e) and "NaN" in str(e): + pytest.skip("HIP reference tensor contains NaN values.") + # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() @@ -406,11 +411,11 @@ def _compare_output_tensors( quantization, fp8_dtype ): tols = dtype_tols(out_triton.dtype if quantization is None else fp8_dtype) - _compare_func = partial(te_compare_results, **tols, use_torch_semantics=True) + compare_func = partial(_compare_func, **tols, use_torch_semantics=True) dq_out_triton = out_triton.dequantize() dq_out_hip = out_hip.dequantize() - _compare_func( + compare_func( actual=dq_out_triton, expected=dq_out_hip, msg=lambda msg: f"Output does not match triton <-> hip\n\n{msg}\n", @@ -428,7 +433,7 @@ def _compare_output_tensors( if not out_hip._transpose_invalid: # The transpose data are generally uint8 so we must convert # them for floating point comparison. - _compare_func( + compare_func( actual=out_triton._transpose.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)).to(torch.float32), expected=out_hip._transpose.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)).to(torch.float32), msg=lambda msg: f"Output transpose does not match triton <-> hip\n\n{msg}\n", @@ -438,11 +443,8 @@ def _compare_output_tensors( if not isinstance(out_triton, MXFP8Tensor): raise ValueError(f"Expected a MXFP8Tensor but got {type(out_triton)} instead.") - # TODO(micky774): Figure out if we need to apply the same view - # trick to MXFP8 data as we do to FP8 transpose data. - # I suspect not. if out_hip._rowwise_data is not None: - _compare_func( + compare_func( actual=out_triton, expected=out_hip, msg=lambda msg: f"Output rowwise data does not match triton <-> hip\n\n{msg}\n", @@ -452,9 +454,9 @@ def _compare_output_tensors( assert out_triton._rowwise_data is None, "Expected no rowwise data." # We use higher precision for the scales - _compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) + compare_func = partial(_compare_func, atol=1e-6, rtol=5e-5, use_torch_semantics=True) if quantization == "fp8": - _compare_func( + compare_func( actual=out_triton._scale_inv, expected=out_hip._scale_inv, msg=lambda msg: f"Output scale inverse does not match triton <-> hip\n\n{msg}\n", @@ -469,9 +471,9 @@ def _compare_output_tensors( msg += "be None." raise ValueError(msg) if has_rscale_triton: - _compare_func( - actual=out_triton._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), - expected=out_hip._rowwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), + compare_func( + actual=out_triton._rowwise_scale_inv.view(torch.uint8), + expected=out_hip._rowwise_scale_inv.view(torch.uint8), msg=lambda msg: f"Output rowwise scale inverse does not match triton <-> hip\n\n{msg}\n", ) @@ -484,9 +486,9 @@ def _compare_output_tensors( msg += "be None." raise ValueError(msg) if has_cscale_triton: - _compare_func( - actual=out_triton._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), - expected=out_hip._columnwise_scale_inv.view(te_dtype_to_torch_dtype(out_triton._fp8_dtype)), + compare_func( + actual=out_triton._columnwise_scale_inv.view(torch.uint8), + expected=out_hip._columnwise_scale_inv.view(torch.uint8), msg=lambda msg: f"Output columnwise scale inverse does not match triton <-> hip\n\n{msg}\n", ) @@ -497,7 +499,7 @@ def _compare_quantizers( quantization ): if quantization is None: return - _compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) + compare_func = partial(_compare_func, atol=1e-6, rtol=5e-5, use_torch_semantics=True) if quantizer_triton.dtype != quantizer_hip.dtype: raise ValueError("Expected matching quantizer dtypes, but got " @@ -511,12 +513,12 @@ def _compare_quantizers( raise ValueError(f"Expected matching quantizer {usage} but got {qt_usage=} != {qh_usage=}") if quantization == "fp8": - _compare_func( + compare_func( actual=quantizer_triton.scale, expected=quantizer_hip.scale, msg=lambda msg: f"Quantizer scale does not match triton <-> hip\n\n{msg}\n", ) - _compare_func( + compare_func( actual=quantizer_triton.amax, expected=quantizer_hip.amax, msg=lambda msg: f"Quantizer amax does not match triton <-> hip\n\n{msg}\n", @@ -529,15 +531,15 @@ def _compare_stat_tensors( norm ): # We use higher precision for the remaining outputs - _compare_func = partial(te_compare_results, atol=1e-6, rtol=5e-5, use_torch_semantics=True) + compare_func = partial(_compare_func, atol=1e-6, rtol=5e-5, use_torch_semantics=True) - _compare_func( + compare_func( actual=rsigma_triton, expected=rsigma_hip, msg=lambda msg: f"rsigma does not match triton <-> hip\n\n{msg}\n", ) if norm == "layer": - _compare_func( + compare_func( actual=mu_triton, expected=mu_hip, msg=lambda msg: f"mu does not match triton <-> hip\n\n{msg}\n", @@ -579,20 +581,20 @@ def _compare_bwd_tensors( dbeta_triton, dbeta_hip, norm ): - _compare_func = partial(te_compare_results, atol=1.5e-4, rtol=1e-4, use_torch_semantics=True) + compare_func = partial(_compare_func, atol=1.5e-4, rtol=1e-4, use_torch_semantics=True) - _compare_func( + compare_func( actual=dx_triton, expected=dx_hip, msg=lambda msg: f"dx does not match triton <-> hip\n\n{msg}\n", ) - _compare_func( + compare_func( actual=dgamma_triton, expected=dgamma_hip, msg=lambda msg: f"dgamma does not match triton <-> hip\n\n{msg}\n", ) if norm == "layer": - _compare_func( + compare_func( actual=dbeta_triton, expected=dbeta_hip, msg=lambda msg: f"dbeta does not match triton <-> hip\n\n{msg}\n", diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 1f38b493c..a7af8f4db 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -20,8 +20,12 @@ from ..export import is_in_onnx_export_mode if IS_HIP_EXTENSION: - from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton - from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton + from ..triton_kernels.norms import ( + te_layernorm_fwd_triton, + te_layernorm_bwd_triton, + te_rmsnorm_fwd_triton, + te_rmsnorm_bwd_triton + ) def _get_normalization_func(normalization: str, forward: bool): use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d1aeebc9f..17b6a4d17 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -79,8 +79,7 @@ ) if IS_HIP_EXTENSION: - from ..triton_kernels.layernorm import te_layernorm_bwd_triton - from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton + from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton __all__ = ["LayerNormLinear"] diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4492abe3e..29aa3feb2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -83,8 +83,7 @@ from ...debug.pytorch.debug_state import TEDebugState if IS_HIP_EXTENSION: - from ..triton_kernels.layernorm import te_layernorm_bwd_triton - from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton + from ..triton_kernels.norms import te_layernorm_bwd_triton, te_rmsnorm_bwd_triton __all__ = ["LayerNormMLP"] diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index d429c4fa4..53d372c11 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -17,7 +17,9 @@ from transformer_engine_torch import layernorm_bwd, layernorm_fwd from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: - from ...triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton + from ...triton_kernels.norms import te_layernorm_fwd_triton, te_layernorm_bwd_triton +from ...fp8 import FP8GlobalStateManager +from ...tensor import QuantizedTensor from ...constants import TE_DType from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...export import is_in_onnx_export_mode diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index bbe805fe9..62fa88494 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -17,7 +17,7 @@ from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: - from ...triton_kernels.rmsnorm import ( + from ...triton_kernels.norms import ( te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton ) diff --git a/transformer_engine/pytorch/triton_kernels/layernorm.py b/transformer_engine/pytorch/triton_kernels/layernorm.py index 3baa64697..27e279f40 100644 --- a/transformer_engine/pytorch/triton_kernels/layernorm.py +++ b/transformer_engine/pytorch/triton_kernels/layernorm.py @@ -3,25 +3,9 @@ from itertools import product -import os -import torch - -from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer -from ..constants import TE_DType -from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor.quantized_tensor import Quantizer -from ..triton_kernels.cast import te_quantize_triton import triton import triton.language as tl -import warnings -import transformer_engine_torch as tex -from .common import ( - get_fp8_max, - te_dtype_to_torch_dtype, - te_dtype_to_triton_dtype, -) -from .norm_common import make_ln_out def get_autotune_config(full_tuning_space=False): if full_tuning_space: @@ -36,20 +20,20 @@ def get_autotune_config(full_tuning_space=False): @triton.jit def _layernorm_fwd_triton_impl( - x_ptr, - y_ptr, - w_ptr, + input_ptr, + output_ptr, + g_ptr, b_ptr, mean_ptr, - rstd_ptr, - scale_ptr, - amax_ptr, - scale_inv_ptr, - x_row_stride, - y_row_stride, + rsigma_ptr, + input_row_stride, + output_row_stride, n_rows, n_cols, - eps, + epsilon, + q_amax_ptr, + q_scale_ptr, + scale_inv_ptr, out_transpose_ptr, out_transpose_stride, ZERO_CENTERED_GAMMA: tl.constexpr, @@ -81,12 +65,12 @@ def _layernorm_fwd_triton_impl( start_row = pid if IS_FP8: - scale = tl.load(scale_ptr) + scale = tl.load(q_scale_ptr) amax = 0.0 for row_idx in range(start_row, start_row + rows_per_tile): - x_ptr_start = x_ptr + (row_idx * x_row_stride) - y_ptr_start = y_ptr + (row_idx * y_row_stride) + x_ptr_start = input_ptr + (row_idx * input_row_stride) + y_ptr_start = output_ptr + (row_idx * output_row_stride) n_cols_blks = tl.cdiv(n_cols, BLOCK_SIZE) - 1 @@ -118,16 +102,16 @@ def _layernorm_fwd_triton_impl( _var += x_block * x_block var = tl.sum(_var, axis=0) / n_cols - rstd = tl.rsqrt(var + eps) + rstd = tl.rsqrt(var + epsilon) # Write mean / rstd tl.store(mean_ptr + row_idx, mean) - tl.store(rstd_ptr + row_idx, rstd) + tl.store(rsigma_ptr + row_idx, rstd) # Normalize and store for blk_idx in range(0, n_cols_blks): cols = blk_idx * BLOCK_SIZE + col_offsets - w_block = tl.load(w_ptr + cols).to(tl.float32) + w_block = tl.load(g_ptr + cols).to(tl.float32) b_block = tl.load(b_ptr + cols).to(tl.float32) x_block = tl.load(x_ptr_start + cols).to(tl.float32) if ZERO_CENTERED_GAMMA: @@ -139,7 +123,7 @@ def _layernorm_fwd_triton_impl( amax = amax_temp if amax_temp > amax else amax y_block = y_block * scale y_block = tl.clamp(y_block, -FP8_MAX, FP8_MAX) - y_block = y_block.to(y_ptr.type.element_ty) + y_block = y_block.to(output_ptr.type.element_ty) tl.store(y_ptr_start + cols, y_block) if MAKE_TRANSPOSE: output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx @@ -148,7 +132,7 @@ def _layernorm_fwd_triton_impl( # For last iteration, do masked load and store cols = n_cols_blks * BLOCK_SIZE + col_offsets mask = cols < n_cols - w_block = tl.load(w_ptr + cols, mask=mask, other=0.0).to(tl.float32) + w_block = tl.load(g_ptr + cols, mask=mask, other=0.0).to(tl.float32) b_block = tl.load(b_ptr + cols, mask=mask, other=0.0).to(tl.float32) x_block = tl.load(x_ptr_start + cols, mask=mask, other=0.0).to(tl.float32) if ZERO_CENTERED_GAMMA: @@ -160,20 +144,20 @@ def _layernorm_fwd_triton_impl( amax = amax_temp if amax_temp > amax else amax y_block = y_block * scale y_block = tl.clamp(y_block, -FP8_MAX, FP8_MAX) - y_block = y_block.to(y_ptr.type.element_ty) + y_block = y_block.to(output_ptr.type.element_ty) tl.store(y_ptr_start + cols, y_block, mask=mask) if MAKE_TRANSPOSE: output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx tl.store(output_t_ptrs, y_block, mask=mask) if IS_FP8: + if pid == 0: + scale_inv = tl.fdiv(1.0, scale) + tl.store(scale_inv_ptr, scale_inv) if APPLY_ATOMIC: - if pid == 0: - scale_inv = tl.fdiv(1.0, scale) - tl.store(scale_inv_ptr, scale_inv) - tl.atomic_max(amax_ptr, amax, sem="relaxed") + tl.atomic_max(q_amax_ptr, amax, sem="relaxed") else: - tl.store(amax_ptr + pid, amax) + tl.store(q_amax_ptr + pid, amax) autotune_dec = triton.autotune(configs=get_autotune_config(), key=["n_rows", "n_cols"], use_cuda_graph=True) _layernorm_fwd_triton = autotune_dec(_layernorm_fwd_triton_impl) @@ -182,8 +166,6 @@ def _layernorm_fwd_triton_impl( def _layernorm_fwd_reduce_triton( amax_input_ptr, amax_output_ptr, - scale_ptr, - scale_inv_ptr, n_rows, BLOCK_SIZE: tl.constexpr, ): @@ -200,12 +182,6 @@ def _layernorm_fwd_reduce_triton( tl.atomic_max(amax_output_ptr, amax, sem="relaxed") - if pid == 0: - scale = tl.load(scale_ptr) - scale_inv = tl.fdiv(1.0, scale) - tl.store(scale_inv_ptr, scale_inv) - - @triton.jit def _layernorm_bwd_dx_fused_triton( DX, # pointer to the input gradient @@ -455,215 +431,3 @@ def _layernorm_bwd_dwdb_triton_v2( sum_db = tl.sum(db, axis=0) tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.type.element_ty), mask=cols < N) tl.store(FINAL_DB + cols, sum_db.to(FINAL_DB.type.element_ty), mask=cols < N) - -def te_layernorm_fwd_triton(input: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - ln_out: torch.Tensor, - quantizer: Quantizer, - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - autotune: bool = True,): - if sm_margin is not None and sm_margin > 0: - warnings.warn( - '"sm_margin" is not supported in the Triton based forward layer-norm kernel. ' - + f"sm_margin={sm_margin} will be ignored." - ) - device = input.device - M, N = input.shape - - IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) - MAKE_TRANSPOSE = False - - # Create empty tensors for mu and rsigma - mu = torch.empty((M,), dtype=torch.float32, device=device) - rsigma = torch.empty((M,), dtype=torch.float32, device=device) - torch_out_dtype = ( - otype if isinstance(otype, torch.dtype) - else te_dtype_to_torch_dtype(otype) - ) - # Create ln_out - ln_out = make_ln_out(ln_out, quantizer=quantizer, input_shape=input.shape, out_dtype=torch_out_dtype) - # To update the amax ptr directly with atomic max - APPLY_ATOMIC = M < 512 - - # MXFP8/fp8_current_scaling requires unfused quantization. - IS_FP8 = isinstance(quantizer, Float8Quantizer) - IS_FP8_CURRENT_SCALING = isinstance(quantizer, Float8CurrentScalingQuantizer) - - amax_temp = torch.empty((M,), dtype=torch.float32, device=device) if IS_FP8 else None - - max_fused_size = 16384 // input.element_size() - BLOCK_SIZE = min(max_fused_size, triton.next_power_of_2(N)) - - out_transpose_ptr = None - out_transpose_stride = None - - # Create necessary values for fp8 if needed - if IS_FP8: - scale = quantizer.scale - amax_out = quantizer.amax - scale_inv = ln_out._scale_inv - cast_out = ln_out._data - MAKE_TRANSPOSE = quantizer.columnwise_usage - if MAKE_TRANSPOSE: - tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) - if ln_out._transpose_invalid: - ln_out._transpose = torch.empty((ln_out._data.shape[1], ln_out._data.shape[0]), dtype=ln_out._data.dtype, device=device) - ln_out._transpose_invalid = False - out_transpose_ptr = triton.reinterpret(ln_out._transpose, tl_dtype) - out_transpose_stride = ln_out._transpose.stride(0) - else: - scale = None - amax_out = None - scale_inv = None - cast_out = ln_out - - kernel = _layernorm_fwd_triton if autotune else _layernorm_fwd_triton_impl - kernel[(M,)]( - input, - triton.reinterpret(cast_out, te_dtype_to_triton_dtype(ln_out._fp8_dtype)) if IS_FP8 else cast_out, - weight, - bias, - mu, - rsigma, - scale, - amax_out if APPLY_ATOMIC else amax_temp, - scale_inv, - input.stride(0), - cast_out.stride(0), - M, - N, - eps, - out_transpose_ptr, - out_transpose_stride, - ZERO_CENTERED_GAMMA=zero_centered_gamma, - BLOCK_SIZE=BLOCK_SIZE, - IS_FP8=IS_FP8, - APPLY_ATOMIC=APPLY_ATOMIC, - # TODO: Improve performance with persistent kernel - # Persistent kernel currently lags behind non persistent version - # It also lags behind TE implementation in a few cases - PERSISTENT=False, - FP8_MAX=get_fp8_max(quantizer.dtype) if IS_FP8 else None, - MAKE_TRANSPOSE=MAKE_TRANSPOSE - ) - - # For MXFP8, we do regular layernorm and then quantize it separately - if IS_MXFP8 or IS_FP8_CURRENT_SCALING: - ln_out = te_quantize_triton(ln_out, quantizer) - - # Reduce and find amax if "not APPLY_ATOMIC" is True. - if IS_FP8 and not APPLY_ATOMIC: - _layernorm_fwd_reduce_triton[(triton.cdiv(M, 256),)]( - amax_temp, - amax_out, - scale, - scale_inv, - M, - 256, - ) - return ln_out, mu, rsigma - -# drop in replacement for transformer_engine::pytorch::layernorm_bwd -# TODO: Add support for `sm_margin > 0`. -def te_layernorm_bwd_triton( - dz: torch.Tensor, - x: torch.Tensor, - mu: torch.Tensor, - rsigma: torch.Tensor, - gamma: torch.Tensor, - sm_margin: int, - zero_centered_gamma: bool -): - if sm_margin is not None and sm_margin > 0: - warnings.warn( - '"sm_margin" is not supported in the Triton based backward layer-norm kernel. ' - + f"sm_margin={sm_margin} will be ignored." - ) - M, N = x.shape - # calculate dw and db separately when M is small - IGNORE_DW_DB_IN_FUSED = M <= 512 - tile_num = max(min(256, M // 4), 1) - if M <= 512 and M * N < 64 * 1024 * 1024: - tile_num = M - elif M >= 8192: - tile_num = 2048 - max_fused_size = 32768 // x.element_size() - next_power = triton.next_power_of_2(N) - BLOCK_SIZE = min(max_fused_size, next_power) - # For cases with small M and large N, decrease block size to help with occupancy and register spill - if tile_num == M: - if tile_num > 256: - BLOCK_SIZE = min(BLOCK_SIZE, 2048) - else: - BLOCK_SIZE = min(BLOCK_SIZE, 4096) - USE_BLOCKED = N > BLOCK_SIZE - num_warps = min(max(BLOCK_SIZE // 256, 1), 8) - - dx = torch.empty_like(x) - if not IGNORE_DW_DB_IN_FUSED: - _dgamma = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) - _dbeta = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) - else: - _dgamma = None - _dbeta = None - dgamma = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) - dbeta = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) - grid_bwd = (tile_num,) - _layernorm_bwd_dx_fused_triton[grid_bwd]( - dx, - dz, - _dgamma, - _dbeta, - x, - gamma, - mu, - rsigma, - x.stride(0), - N, - ZERO_CENTERED_GAMMA=zero_centered_gamma, - NUM_ROWS=M, - BLOCK_SIZE_N=BLOCK_SIZE, - USE_BLOCKED=USE_BLOCKED, - num_warps=num_warps, - IGNORE_DW_DB=IGNORE_DW_DB_IN_FUSED, - ) - grid_reduce = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE_N"]),) - if not IGNORE_DW_DB_IN_FUSED: - dwdb_block_n = max(16, N // 256) - dwdb_block_n = triton.next_power_of_2(dwdb_block_n) - dwdb_block_m = (64 * 128) // dwdb_block_n - dwdb_block_m = min(triton.next_power_of_2(tile_num), dwdb_block_m) - _layernorm_bwd_dwdb_triton[grid_reduce]( - _dgamma, - _dbeta, - dgamma, - dbeta, - min(tile_num, M), - N, - BLOCK_SIZE_M=dwdb_block_m, - BLOCK_SIZE_N=dwdb_block_n, - ) - else: - dwdb_block_n = max(16, N // 256) - dwdb_block_n = triton.next_power_of_2(dwdb_block_n) - dwdb_block_m = (64 * 128) // dwdb_block_n - dwdb_block_m = min(triton.next_power_of_2(M), dwdb_block_m) - _layernorm_bwd_dwdb_triton_v2[grid_reduce]( - x, - dz, - mu, - rsigma, - x.stride(0), - dgamma, - dbeta, - M, - N, - BLOCK_SIZE_M=dwdb_block_m, - BLOCK_SIZE_N=dwdb_block_n, - ) - - return dx, dgamma, dbeta diff --git a/transformer_engine/pytorch/triton_kernels/norms.py b/transformer_engine/pytorch/triton_kernels/norms.py new file mode 100644 index 000000000..eee8a6e1d --- /dev/null +++ b/transformer_engine/pytorch/triton_kernels/norms.py @@ -0,0 +1,367 @@ +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. +# License for AMD contributions = MIT. See LICENSE for more information + +import torch +import triton +import warnings +import transformer_engine_torch as tex + +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.triton_kernels.common import ( + te_dtype_to_torch_dtype, + te_dtype_to_triton_dtype, +) +from ..tensor.quantized_tensor import Quantizer +from .norm_common import num_programs, block_size, use_blocked, make_ln_out +from .common import get_fp8_max +from .rmsnorm import ( + _rmsnorm_fwd_triton, + _rmsnorm_fwd_triton_impl, + _rmsnorm_bwd_triton, + _rmsnorm_bwd_dg_reduce_triton, +) +from .layernorm import ( + _layernorm_fwd_triton, + _layernorm_fwd_triton_impl, + _layernorm_fwd_reduce_triton, + _layernorm_bwd_dwdb_triton, + _layernorm_bwd_dwdb_triton_v2, + _layernorm_bwd_dx_fused_triton, +) + +_norm_kernels={ + "rms":{ + True: _rmsnorm_fwd_triton, + False: _rmsnorm_fwd_triton_impl, + }, + "layer":{ + True: _layernorm_fwd_triton, + False: _layernorm_fwd_triton_impl, + } +} +# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd +def te_rmsnorm_fwd_triton( + input: torch.Tensor, + weight: torch.Tensor, + eps: float, + ln_out: torch.Tensor, + quantizer: Quantizer, + otype: tex.DType, + sm_margin: int, + zero_centered_gamma: bool, + autotune: bool = True, +): + return _te_norm_fwd_triton( + kernel='rms', + input_tensor=input, + weight=weight, + bias=None, + eps=eps, + ln_out=ln_out, + quantizer=quantizer, + otype=otype, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + autotune=autotune, + ) + +# triton drop-in replacement for transformer_engine::pytorch::layernorm_fwd +def te_layernorm_fwd_triton( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + ln_out: torch.Tensor, + quantizer: Quantizer, + otype: tex.DType, + sm_margin: int, + zero_centered_gamma: bool, + autotune: bool = True, +): + return _te_norm_fwd_triton( + kernel='layer', + input_tensor=input, + weight=weight, + bias=bias, + eps=eps, + ln_out=ln_out, + quantizer=quantizer, + otype=otype, + sm_margin=sm_margin, + zero_centered_gamma=zero_centered_gamma, + autotune=autotune, + ) + +def _te_norm_fwd_triton( + kernel: str, + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + ln_out: torch.Tensor, + quantizer: Quantizer, + otype: tex.DType, + sm_margin: int, + zero_centered_gamma: bool, + autotune: bool = True, +): + if kernel not in {'rms', 'layer'}: + raise ValueError(f"Expected `kernel` in ('rms', 'layer') but got {kernel=} instead.") + if eps < 0: + raise ValueError(f"`eps` must be non-negative, but a value of {eps} was passed") + if len(input_tensor.shape) != 2: + raise ValueError( + f"The input must be a 2-dimensional matrix, but an input with {input_tensor.ndim} was passed.") + + device = input_tensor.device + N, H = input_tensor.shape + if weight.shape[0] != H: + raise ValueError( + f"The shape of `weight` must be feature-aligned, " + f"but {weight.shape[0]=} while {input_tensor.shape[1]=}" + ) + IS_FP8 = isinstance(quantizer, Float8Quantizer) + IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) + IS_FP8_CURRENT_SCALING = isinstance(quantizer, Float8CurrentScalingQuantizer) + BLOCK_SIZE = block_size(input_tensor) + USE_BLOCKED = use_blocked(input_tensor) + NUM_PRGMS = N if kernel=='layer' else num_programs(input_tensor, sm_margin) + MAKE_TRANSPOSE = False + APPLY_ATOMIC = N < 512 or kernel == 'rms' + ATOMIC_REDUCTION_BLOCK_SIZE=256 + + mu = torch.empty((N,), dtype=torch.float32, device=device) if kernel == 'layer' else None + rsigma = torch.empty((N,), dtype=torch.float32, device=device) + torch_out_dtype = ( + otype if isinstance(otype, torch.dtype) + else te_dtype_to_torch_dtype(otype) + ) + out = make_ln_out( + ln_out, + quantizer=quantizer, + input_shape=input_tensor.shape, + out_dtype=torch_out_dtype + ) + amax = None + tl_dtype = None + scale_inv_ptr = None + q_scale = None + out_ptr = out + out_transpose_ptr = None + out_transpose_stride = None + FP8_MAX = None + if IS_FP8: + MAKE_TRANSPOSE = quantizer.columnwise_usage + amax = ( + quantizer.amax if APPLY_ATOMIC else + torch.empty((NUM_PRGMS,), dtype=torch.float32, device=device) + ) + tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) + scale_inv_ptr = out._scale_inv + q_scale = quantizer.scale + out_ptr = triton.reinterpret(out._data, tl_dtype) + FP8_MAX = get_fp8_max(quantizer.dtype) + if MAKE_TRANSPOSE: + if out._transpose_invalid: + out._transpose = torch.empty( + (out._data.shape[1], out._data.shape[0]), + dtype=out._data.dtype, device=device + ) + out._transpose_invalid = False + out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) + out_transpose_stride = out._transpose.stride(0) + + grid_fwd = lambda meta: (NUM_PRGMS,) + kernel_func = _norm_kernels[kernel][autotune] + input_row_stride = input_tensor.stride(0) + output_row_stride=out_ptr.stride(0) + kwargs = dict( + input_ptr=input_tensor, + output_ptr=out_ptr, + g_ptr=weight, + rsigma_ptr=rsigma, + input_row_stride=input_row_stride, + output_row_stride=output_row_stride, + n_rows=N, n_cols=H, + epsilon=eps, + q_amax_ptr=amax, + q_scale_ptr=q_scale, + scale_inv_ptr=scale_inv_ptr, + out_transpose_ptr=out_transpose_ptr, + out_transpose_stride=out_transpose_stride, + ZERO_CENTERED_GAMMA=zero_centered_gamma, + BLOCK_SIZE=BLOCK_SIZE, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + MAKE_TRANSPOSE=MAKE_TRANSPOSE, + ) + if kernel == 'layer': + kwargs["APPLY_ATOMIC"]=APPLY_ATOMIC + kwargs["PERSISTENT"]=False # TODO: Improve persistent algo performance + kwargs["b_ptr"]=bias + kwargs["mean_ptr"]=mu + elif kernel == "rms": + kwargs["USE_BLOCKED"]=USE_BLOCKED + kwargs["NUM_PRGMS"]=NUM_PRGMS + kwargs["INPUT_ALIGNED_16"]=( + input_tensor.data_ptr() % 16 == 0 and + input_row_stride * getattr(input_tensor.dtype, 'itemsize', 1) % 16 == 0 + ) + kwargs["OUTPUT_ALIGNED_16"]=( + out_ptr.data_ptr() % 16 == 0 and + output_row_stride * getattr(out_ptr.dtype, 'itemsize', 1) % 16 == 0 + ) + + kernel_func[grid_fwd](**kwargs) + + # Reduce and find amax if "not APPLY_ATOMIC" is True for layernorm. + if IS_FP8 and not APPLY_ATOMIC: + _layernorm_fwd_reduce_triton[(triton.cdiv(N, ATOMIC_REDUCTION_BLOCK_SIZE),)]( + amax, + quantizer.amax, + N, ATOMIC_REDUCTION_BLOCK_SIZE, + ) + elif IS_MXFP8 or IS_FP8_CURRENT_SCALING: + out = quantizer.quantize(out, out=ln_out) + + return out, mu, rsigma + + +# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd +def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): + # may take non-contiguous inputs + dz_ = dz.contiguous() + x_ = x.contiguous() + rsigma_ = rsigma.contiguous() + gamma_ = gamma.contiguous() + + dx = torch.empty_like(x_) + dgamma = torch.empty_like(gamma_) + + M, N = x_.shape + blk_size = block_size(x_) + USE_BLOCKED = use_blocked(x_) + NUM_PRGMS = num_programs(x_, sm_margin) + need_reduction = N > 1 + dg_tmp_rows = x_.shape[0] if use_blocked(x_) else num_programs(x_, sm_margin) + dg_tmp = torch.empty(dg_tmp_rows, N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None + + input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) * x_.dtype.itemsize % 16 == 0) + grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(0) * dz_.dtype.itemsize % 16 == 0) + dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(0) * dx.dtype.itemsize % 16 == 0) + dg_target = dg_tmp if need_reduction else dgamma + dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(0) * dg_target.dtype.itemsize % 16 == 0) + grid_bwd = lambda meta: (NUM_PRGMS, ) + _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, + x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, + USE_BLOCKED, NUM_PRGMS, input_aligned_16, grad_output_aligned_16, + dx_aligned_16, dg_aligned_16, num_warps=8) + + if need_reduction: + grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] + _rmsnorm_bwd_dg_reduce_triton[grid_reduce](dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], + BLOCK_SIZE_M=128, BLOCK_SIZE_N=64) + + return dx, dgamma + +# drop in replacement for transformer_engine::pytorch::layernorm_bwd +# TODO: Add support for `sm_margin > 0`. +def te_layernorm_bwd_triton( + dz: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool +): + if sm_margin is not None and sm_margin > 0: + warnings.warn( + '"sm_margin" is not supported in the Triton based backward layer-norm kernel. ' + + f"sm_margin={sm_margin} will be ignored." + ) + M, N = x.shape + # calculate dw and db separately when M is small + IGNORE_DW_DB_IN_FUSED = M <= 512 + tile_num = max(min(256, M // 4), 1) + if M <= 512 and M * N < 64 * 1024 * 1024: + tile_num = M + elif M >= 8192: + tile_num = 2048 + max_fused_size = 32768 // x.element_size() + next_power = triton.next_power_of_2(N) + BLOCK_SIZE = min(max_fused_size, next_power) + # For cases with small M and large N, decrease block size to help with occupancy and register spill + if tile_num == M: + if tile_num > 256: + BLOCK_SIZE = min(BLOCK_SIZE, 2048) + else: + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + USE_BLOCKED = N > BLOCK_SIZE + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + + dx = torch.empty_like(x) + if not IGNORE_DW_DB_IN_FUSED: + _dgamma = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) + _dbeta = torch.zeros((tile_num, N), dtype=torch.float32, device=gamma.device) + else: + _dgamma = None + _dbeta = None + dgamma = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) + dbeta = torch.zeros((N,), dtype=gamma.dtype, device=gamma.device) + grid_bwd = (tile_num,) + _layernorm_bwd_dx_fused_triton[grid_bwd]( + dx, + dz, + _dgamma, + _dbeta, + x, + gamma, + mu, + rsigma, + x.stride(0), + N, + ZERO_CENTERED_GAMMA=zero_centered_gamma, + NUM_ROWS=M, + BLOCK_SIZE_N=BLOCK_SIZE, + USE_BLOCKED=USE_BLOCKED, + num_warps=num_warps, + IGNORE_DW_DB=IGNORE_DW_DB_IN_FUSED, + ) + grid_reduce = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE_N"]),) + if not IGNORE_DW_DB_IN_FUSED: + dwdb_block_n = max(16, N // 256) + dwdb_block_n = triton.next_power_of_2(dwdb_block_n) + dwdb_block_m = (64 * 128) // dwdb_block_n + dwdb_block_m = min(triton.next_power_of_2(tile_num), dwdb_block_m) + _layernorm_bwd_dwdb_triton[grid_reduce]( + _dgamma, + _dbeta, + dgamma, + dbeta, + min(tile_num, M), + N, + BLOCK_SIZE_M=dwdb_block_m, + BLOCK_SIZE_N=dwdb_block_n, + ) + else: + dwdb_block_n = max(16, N // 256) + dwdb_block_n = triton.next_power_of_2(dwdb_block_n) + dwdb_block_m = (64 * 128) // dwdb_block_n + dwdb_block_m = min(triton.next_power_of_2(M), dwdb_block_m) + _layernorm_bwd_dwdb_triton_v2[grid_reduce]( + x, + dz, + mu, + rsigma, + x.stride(0), + dgamma, + dbeta, + M, + N, + BLOCK_SIZE_M=dwdb_block_m, + BLOCK_SIZE_N=dwdb_block_n, + ) + + return dx, dgamma, dbeta diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 9f152582e..983559ab7 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -5,40 +5,27 @@ import triton import triton.language as tl from itertools import product -from .norm_common import num_programs, block_size, use_blocked, make_ln_out -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer -from transformer_engine.pytorch.triton_kernels.common import ( - te_dtype_to_torch_dtype, - te_dtype_to_triton_dtype, -) -from .common import get_fp8_max -from ..tensor.quantized_tensor import Quantizer -import transformer_engine_torch as tex - -def dg_tmp_rows(x, sm_margin=None): - return x.shape[0] if use_blocked(x) else num_programs(x, sm_margin) - def get_autotune_config(): return [triton.Config({'waves_per_eu': we}, num_warps=nw) for (we, nw) in product([0, 1, 2, 4], [4, 8, 16])] +# TODO(micky774) Implement fused MXFP8 quantization within the kernel @triton.jit def _rmsnorm_fwd_triton_impl( - output_ptr, input_ptr, - g_ptr, rsigma_ptr, + output_ptr, + g_ptr, + rsigma_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon, - amax_ptr, q_amax_ptr, q_scale_ptr, scale_inv_ptr, out_transpose_ptr, - transpose_row_stride, + out_transpose_stride, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, USE_BLOCKED: tl.constexpr, @@ -122,7 +109,7 @@ def _rmsnorm_fwd_triton_impl( rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + cols * transpose_row_stride + row_idx + output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx tl.store(output_t_ptrs, rms_norm.to(output_type)) tl.store(output_ptrs, rms_norm.to(output_type)) @@ -147,7 +134,7 @@ def _rmsnorm_fwd_triton_impl( rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + cols * transpose_row_stride + row_idx + output_t_ptrs = out_transpose_ptr + cols * out_transpose_stride + row_idx tl.store(output_t_ptrs, rms_norm.to(output_type), mask=mask) tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) @@ -180,11 +167,10 @@ def _rmsnorm_fwd_triton_impl( rms_norm = rms_norm * scale rms_norm = tl.clamp(rms_norm, -FP8_MAX, FP8_MAX) if MAKE_TRANSPOSE: - output_t_ptrs = out_transpose_ptr + col_offsets * transpose_row_stride + row_idx + output_t_ptrs = out_transpose_ptr + col_offsets * out_transpose_stride + row_idx tl.store(output_t_ptrs, rms_norm.to(output_type), mask=mask) tl.store(output_ptrs, rms_norm.to(output_type), mask=mask) if IS_FP8: - tl.store(amax_ptr + row_start, amax) tl.atomic_max(q_amax_ptr, amax, sem="relaxed") if row_start == 0: scale = tl.load(q_scale_ptr) @@ -363,165 +349,3 @@ def _rmsnorm_bwd_dg_reduce_triton(dg_in_ptr, dg_out_ptr, dg_in_stride, n_rows, n sum_dg = tl.sum(acc, axis=0) tl.store(dg_out_ptr + cols, sum_dg.to(dg_out_ptr.type.element_ty), mask=cols < n_cols) -# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_bwd -def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): - # may take non-contiguous inputs - dz_ = dz.contiguous() - x_ = x.contiguous() - rsigma_ = rsigma.contiguous() - gamma_ = gamma.contiguous() - - dx = torch.empty_like(x_) - dgamma = torch.empty_like(gamma_) - - M, N = x_.shape - blk_size = block_size(x_) - USE_BLOCKED = use_blocked(x_) - NUM_PRGMS = num_programs(x_, sm_margin) - need_reduction = N > 1 - dg_tmp = torch.empty(dg_tmp_rows(x_, sm_margin), N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None - - grid_bwd = lambda meta: (NUM_PRGMS, ) - input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(-1) % 16 == 0) - grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(-1) % 16 == 0) - dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(-1) % 16 == 0) - dg_target = dg_tmp if need_reduction else dgamma - dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(-1) % 16 == 0) - _rmsnorm_bwd_triton[grid_bwd]( - dz_, - x_, - gamma_, - rsigma_, - dx, - dg_target, - x_.stride(0), - dz_.stride(0), - M, - N, - zero_centered_gamma, - blk_size, - USE_BLOCKED, - NUM_PRGMS, - input_aligned_16, - grad_output_aligned_16, - dx_aligned_16, - dg_aligned_16, - num_warps=8, - ) - - if need_reduction: - grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] - _rmsnorm_bwd_dg_reduce_triton[grid_reduce](dg_tmp, dgamma, dg_tmp.stride(0), dg_tmp.shape[0], dg_tmp.shape[1], - BLOCK_SIZE_M=128, BLOCK_SIZE_N=64) - - return dx, dgamma - -# triton drop-in replacement for transformer_engine::pytorch::rmsnorm_fwd -def te_rmsnorm_fwd_triton( - input: torch.Tensor, - weight: torch.Tensor, - eps: float, - ln_out: torch.Tensor, - quantizer: Quantizer, - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - autotune: bool = True, -): - if eps < 0: - raise ValueError(f"`eps` must be non-negative, but a value of {eps} was passed") - if len(input.shape) != 2: - raise ValueError( - f"The input must be a 2-dimensional matrix, but an input with {input.ndim} was passed.") - - device = input.device - N, H = input.shape - if weight.shape[0] != H: - raise ValueError( - f"The shape of `weight` must be feature-aligned, " - f"but {weight.shape[0]=} while {input.shape[1]=}" - ) - IS_FP8 = isinstance(quantizer, Float8Quantizer) - IS_FP8_CURRENT_SCALING = isinstance(quantizer, Float8CurrentScalingQuantizer) - IS_MXFP8 = isinstance(quantizer, MXFP8Quantizer) - BLOCK_SIZE = block_size(input) - USE_BLOCKED = use_blocked(input) - NUM_PRGMS = num_programs(input, sm_margin) - MAKE_TRANSPOSE = False - - rsigma = torch.empty((N,), dtype=torch.float32, device=device) - torch_out_dtype = ( - otype if isinstance(otype, torch.dtype) - else te_dtype_to_torch_dtype(otype) - ) - out = make_ln_out( - ln_out, - quantizer=quantizer, - input_shape=input.shape, - out_dtype=torch_out_dtype - ) - if IS_FP8: - MAKE_TRANSPOSE = quantizer.columnwise_usage - amax = torch.empty((NUM_PRGMS,), dtype=torch.float32, device=device) - tl_dtype = te_dtype_to_triton_dtype(quantizer.dtype) - scale_inv_ptr = out._scale_inv - q_scale = quantizer.scale - q_amax = quantizer.amax - out_ptr = triton.reinterpret(out._data, tl_dtype) - FP8_MAX = get_fp8_max(quantizer.dtype) - if MAKE_TRANSPOSE: - if out._transpose_invalid: - out._transpose = torch.empty((out._data.shape[1], out._data.shape[0]), dtype=out._data.dtype, device=device) - out._transpose_invalid = False - out_transpose_ptr = triton.reinterpret(out._transpose, tl_dtype) - out_transpose_stride = out._transpose.stride(0) - else: - out_transpose_ptr = None - out_transpose_stride = None - else: - amax = None - tl_dtype = None - scale_inv_ptr = None - q_scale = None - q_amax = None - out_ptr = out - out_transpose_ptr = None - out_transpose_stride = None - FP8_MAX = None - - grid_fwd = lambda meta: (NUM_PRGMS, ) - # TODO(micky774) Implement fused MXFP8 quantization within the kernel - kernel = _rmsnorm_fwd_triton if autotune else _rmsnorm_fwd_triton_impl - input_aligned_16 = (input.data_ptr() % 16 == 0) and (input.stride(-1) % 16 == 0) - out_alignment_tensor = out._data if hasattr(out, "_data") else out - output_aligned_16 = (out_alignment_tensor.data_ptr() % 16 == 0) and ( - out_alignment_tensor.stride(-1) % 16 == 0 - ) - kernel[grid_fwd]( - out_ptr, - input, - weight, - rsigma, - input.stride(0), - out_ptr.stride(0), - N, H, eps, - amax, - q_amax, - q_scale, - scale_inv_ptr, - out_transpose_ptr, - out_transpose_stride, - zero_centered_gamma, - BLOCK_SIZE, - USE_BLOCKED, - NUM_PRGMS, - IS_FP8, - FP8_MAX, - MAKE_TRANSPOSE, - input_aligned_16, - output_aligned_16, - ) - if IS_MXFP8 or IS_FP8_CURRENT_SCALING: - out = quantizer.quantize(out, out=ln_out) - - return out, None, rsigma