From d4cfc5fe95803be3eb81a4c5ee1677841190e4a4 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 9 Dec 2025 06:23:04 +0000 Subject: [PATCH 01/34] [CI] Skipped test_gpt_full_activation_recompute tests for gfx950 --- tests/pytorch/test_numerics.py | 9 ++++++++- transformer_engine/pytorch/utils.py | 6 ++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 1787ab191..16c603422 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -28,7 +28,7 @@ is_bf16_compatible, ) if IS_HIP_EXTENSION: - from transformer_engine.pytorch.utils import is_mi200, is_mi308 + from transformer_engine.pytorch.utils import is_mi200, is_mi308, is_mi350 from transformer_engine.pytorch import ( DotProductAttention, @@ -757,6 +757,13 @@ def test_gpt_full_activation_recompute( pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) + if IS_HIP_EXTENSION and is_mi350(): + if (dtype == torch.bfloat16 + and not fp8 + and not use_reentrant + and recipe.float8_per_tensor_scaling() + ): + pytest.skip("hipBLASLt does not provide suitable algorithms on MI350 for this config.") config = model_configs[model] torch.compiler.reset() # avoid cache size limit overflow diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 9d0d71fdc..f49b98cb4 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -456,6 +456,12 @@ def is_mi308(): import re return (re.search('AMD Instinct MI308', torch.cuda.get_device_name(torch.cuda.current_device())) is not None) + @functools.lru_cache(maxsize=None) + def is_mi350(): + """check whether this machine is mi35x""" + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + return (props.major, props.minor) == (9, 5) + @functools.lru_cache(maxsize=None) def is_fp8_fnuz(): return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 4) From 7b2d2301b969b3f1ad922949671c886b8329b6b9 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 9 Dec 2025 06:26:43 +0000 Subject: [PATCH 02/34] [CI] Skipped unsupported test_basic_linear_quantized tests on gfx950 --- tests/pytorch/test_fusible_ops.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 78894d97d..d8044e53b 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -36,6 +36,9 @@ from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex +from torch.utils.cpp_extension import IS_HIP_EXTENSION +if IS_HIP_EXTENSION: + from transformer_engine.pytorch.utils import is_mi350 # Import utility functions _current_file = pathlib.Path(__file__).resolve() @@ -918,6 +921,16 @@ def test_basic_linear_quantized( quantized_grad_input: bool, ) -> None: """GEMM with FP8 inputs and outputs""" + if IS_HIP_EXTENSION and is_mi350(): + if ( + quantization + and quantization.startswith("fp8") + and quantized_compute + and (quantized_grad_input or quantized_output) + ): + pytest.skip( + "hipBLASLt does not provide suitable algorithms on gfx950 for this config." 
+ ) if quantization is None: pytest.skip("Skipping case without quantization") self._test_basic_linear( From 0ce9fbef50b4fb56a5f3fa0df0f5b98187c22deb Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Wed, 10 Dec 2025 08:34:33 +0000 Subject: [PATCH 03/34] [CI] Fixed test_numerics, test_norms, test_fused_optimizer failures for gfx950 ci enablement --- tests/pytorch/test_fused_optimizer.py | 7 ++ tests/pytorch/test_numerics.py | 22 ++++- .../pytorch/triton_kernels/rmsnorm.py | 92 +++++++++++++++---- 3 files changed, 101 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index e04f0477b..47e8820ed 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -9,6 +9,7 @@ import pytest import torch from torch import nn +from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling @@ -18,6 +19,9 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import gpu_autocast_ctx +if IS_HIP_EXTENSION: + from transformer_engine.pytorch.utils import is_mi350 + # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -378,6 +382,7 @@ def test_bf16_exp_avg(self): @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg(self): + model_tol = 3e-2 if IS_HIP_EXTENSION and is_mi350() else None self.gen_precision_aware_test( use_fp8_params=False, param_dtype=torch.bfloat16, @@ -388,6 +393,8 @@ def test_fp8_exp_avg(self): exp_avg_sq_dtype=torch.float32, master_rtol=1e-2, master_atol=1e-2, + model_rtol=model_tol, + model_atol=model_tol, ) @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 16c603422..7662858d1 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2813,10 +2813,24 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): max_seqlen_kv=config.seq_len, ) - torch.testing.assert_close( - y_bshd, - y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), - ) + if IS_HIP_EXTENSION: + tols_thd = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + # ROCm fused attention (CK) on THD can produce slightly larger error + tols_thd["atol"] = 2e-3 + _, use_aotriton, use_ck = rocm_attn_backend() + if use_aotriton and not use_ck: + tols_thd["rtol"] = tols_thd["rtol"] * 3 + torch.testing.assert_close( + y_bshd, + y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), + **tols_thd, + ) + else: + torch.testing.assert_close( + y_bshd, + y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), + ) @pytest.mark.parametrize("dtype", param_types) diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index c48a2a9b2..9f152582e 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -46,6 +46,8 @@ def _rmsnorm_fwd_triton_impl( IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, MAKE_TRANSPOSE: tl.constexpr, + INPUT_ALIGNED_16: tl.constexpr, + OUTPUT_ALIGNED_16: tl.constexpr, ): # Enable the transpose cache only 
in FP8 mode. @@ -78,7 +80,8 @@ def _rmsnorm_fwd_triton_impl( for blk_idx in tl.range(0, n_cols_blks, num_stages=2): cols = blk_idx * BLOCK_SIZE + col_offsets input_ptrs = row_input_ptr + cols - input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) x = tl.load(input_ptrs).to(tl.float32) sum_squares += tl.sum(x * x, axis=0) @@ -86,7 +89,8 @@ def _rmsnorm_fwd_triton_impl( cols = n_cols_blks * BLOCK_SIZE + col_offsets mask = cols < n_cols input_ptrs = row_input_ptr + cols - input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) sum_squares += tl.sum(x * x, axis=0) @@ -101,7 +105,8 @@ def _rmsnorm_fwd_triton_impl( for blk_idx in tl.range(0, n_cols_blks, num_stages=2): cols = blk_idx * BLOCK_SIZE + col_offsets input_ptrs = row_input_ptr + cols - input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) x = tl.load(input_ptrs).to(tl.float32) g_ptrs = g_ptr + cols g = tl.load(g_ptrs).to(tl.float32) @@ -109,6 +114,8 @@ def _rmsnorm_fwd_triton_impl( g += 1 rms_norm = x * norm_factor * g output_ptrs = row_output_ptr + cols + if OUTPUT_ALIGNED_16: + output_ptrs = tl.multiple_of(output_ptrs, (16, )) if IS_FP8: amax_temp = tl.max(tl.abs(rms_norm), axis=-1) amax = tl.maximum(amax, amax_temp) @@ -123,6 +130,8 @@ def _rmsnorm_fwd_triton_impl( cols = n_cols_blks * BLOCK_SIZE + col_offsets mask = cols < n_cols input_ptrs = row_input_ptr + cols + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) g_ptrs = g_ptr + cols g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32) @@ -130,6 +139,8 @@ def _rmsnorm_fwd_triton_impl( g += 1 rms_norm = x * norm_factor * g output_ptrs = row_output_ptr + cols + if OUTPUT_ALIGNED_16: + output_ptrs = tl.multiple_of(output_ptrs, (16, )) if IS_FP8: amax_temp = tl.max(tl.abs(rms_norm), axis=-1) amax = tl.maximum(amax, amax_temp) @@ -144,7 +155,8 @@ def _rmsnorm_fwd_triton_impl( mask = col_offsets < n_cols for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets - input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) row_norm = row * row @@ -160,7 +172,8 @@ def _rmsnorm_fwd_triton_impl( rms_norm = row * norm_factor * g output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets - output_ptrs = tl.multiple_of(output_ptrs, (16, )) + if OUTPUT_ALIGNED_16: + output_ptrs = tl.multiple_of(output_ptrs, (16, )) if IS_FP8: amax_temp = tl.max(tl.abs(rms_norm), axis=-1) amax = tl.maximum(amax, amax_temp) @@ -184,7 +197,9 @@ def _rmsnorm_fwd_triton_impl( @triton.jit def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr, input_row_stride, output_row_stride, n_rows, n_cols, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, - USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr): + USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, + INPUT_ALIGNED_16: tl.constexpr, GRAD_OUTPUT_ALIGNED_16: tl.constexpr, + DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr): row_start = 
tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) # tl.assume(input_row_stride >= 0) @@ -209,8 +224,10 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d input_ptrs = row_input_ptr + cols grad_output_ptrs = row_grad_output_ptr + cols - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if GRAD_OUTPUT_ALIGNED_16: + grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) x = tl.load(input_ptrs).to(tl.float32) grad_output = tl.load(grad_output_ptrs).to(tl.float32) @@ -241,8 +258,10 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d input_ptrs = row_input_ptr + cols grad_output_ptrs = row_grad_output_ptr + cols - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if GRAD_OUTPUT_ALIGNED_16: + grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) x = tl.load(input_ptrs).to(tl.float32) grad_output = tl.load(grad_output_ptrs).to(tl.float32) @@ -255,10 +274,14 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d n_cols) dx_ptrs = row_dx_ptr + cols + if DX_ALIGNED_16: + dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty)) dg = grad_output * x * norm_factor dg_ptrs = row_dg_ptr + cols + if DG_ALIGNED_16: + dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) tl.store(dg_ptrs, dg.to(tl.float32)) # Handle remainder @@ -277,10 +300,14 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d n_cols) dx_ptrs = row_dx_ptr + cols + if DX_ALIGNED_16: + dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) dg = grad_output * x * norm_factor dg_ptrs = row_dg_ptr + cols + if DG_ALIGNED_16: + dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) tl.store(dg_ptrs, dg.to(tl.float32), mask=mask) else: @@ -292,9 +319,12 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets dx_ptrs = dx_ptr + row_idx * input_row_stride + col_offsets - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) - dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if GRAD_OUTPUT_ALIGNED_16: + grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) + if DX_ALIGNED_16: + dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) @@ -352,9 +382,32 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): dg_tmp = torch.empty(dg_tmp_rows(x_, sm_margin), N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None grid_bwd = lambda meta: (NUM_PRGMS, ) - _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, - x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, - USE_BLOCKED, NUM_PRGMS, num_warps=8) + input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(-1) % 16 == 0) + grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(-1) % 16 == 0) + dx_aligned_16 = (dx.data_ptr() % 16 == 0) and 
(dx.stride(-1) % 16 == 0)
+    dg_target = dg_tmp if need_reduction else dgamma
+    dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(-1) % 16 == 0)
+    _rmsnorm_bwd_triton[grid_bwd](
+        dz_,
+        x_,
+        gamma_,
+        rsigma_,
+        dx,
+        dg_target,
+        x_.stride(0),
+        dz_.stride(0),
+        M,
+        N,
+        zero_centered_gamma,
+        blk_size,
+        USE_BLOCKED,
+        NUM_PRGMS,
+        input_aligned_16,
+        grad_output_aligned_16,
+        dx_aligned_16,
+        dg_aligned_16,
+        num_warps=8,
+    )
 
     if need_reduction:
         grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
@@ -439,6 +492,11 @@ def te_rmsnorm_fwd_triton(
     grid_fwd = lambda meta: (NUM_PRGMS, )
     # TODO(micky774) Implement fused MXFP8 quantization within the kernel
     kernel = _rmsnorm_fwd_triton if autotune else _rmsnorm_fwd_triton_impl
+    input_aligned_16 = (input.data_ptr() % 16 == 0) and (input.stride(-1) % 16 == 0)
+    out_alignment_tensor = out._data if hasattr(out, "_data") else out
+    output_aligned_16 = (out_alignment_tensor.data_ptr() % 16 == 0) and (
+        out_alignment_tensor.stride(-1) % 16 == 0
+    )
     kernel[grid_fwd](
         out_ptr,
         input,
@@ -460,6 +518,8 @@ def te_rmsnorm_fwd_triton(
         IS_FP8,
         FP8_MAX,
         MAKE_TRANSPOSE,
+        input_aligned_16,
+        output_aligned_16,
     )
     if IS_MXFP8 or IS_FP8_CURRENT_SCALING:
         out = quantizer.quantize(out, out=ln_out)

From 53a27bd78792b07b269d1f3be4a19c0e70f61f78 Mon Sep 17 00:00:00 2001
From: Veera Rajasekhar Reddy Gopu
Date: Fri, 12 Dec 2025 09:41:21 +0000
Subject: [PATCH 04/34] [CI] Disabled gfx950 support until FP8 GEMM layout coverage is verified with hipblaslt

---
 transformer_engine/jax/quantize/device_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/jax/quantize/device_utils.py b/transformer_engine/jax/quantize/device_utils.py
index ca90ba9fb..5fc0f0473 100644
--- a/transformer_engine/jax/quantize/device_utils.py
+++ b/transformer_engine/jax/quantize/device_utils.py
@@ -35,7 +35,8 @@ def get_device_compute_capability(gpu_id: int = 0) -> int:
 def is_fp8_gemm_with_all_layouts_supported() -> bool:
     """Return True if using Blackwell architecture, False otherwise."""
     compute_capability = get_device_compute_capability()
-    if is_hip_extension():
+    # Enable once FP8 GEMM layout coverage is validated with hipblaslt.
+    # if is_hip_extension():
         # gfx950 --> NV blackwell
-        return compute_capability == 95
+        # return compute_capability == 95
     return 100 <= compute_capability < 120

From a602c3e7436a6db8cb784ce113f5f926a08019b7 Mon Sep 17 00:00:00 2001
From: Veera Rajasekhar Reddy Gopu
Date: Fri, 12 Dec 2025 22:17:53 +0000
Subject: [PATCH 05/34] [CI] [gfx950] Disable cudaGraph for gemm and grouped-gemm

---
 transformer_engine/jax/csrc/extensions/gemm.cpp | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp
index ba2d65e3e..ec61f2a97 100644
--- a/transformer_engine/jax/csrc/extensions/gemm.cpp
+++ b/transformer_engine/jax/csrc/extensions/gemm.cpp
@@ -5,6 +5,7 @@
  ************************************************************************/

 #include "transformer_engine/gemm.h"
+#include
 #include
 #include
 #include
@@ -21,6 +22,13 @@
 namespace transformer_engine {
 namespace jax {

+#ifdef USE_ROCM
+// hipblaslt GEMM is not graph-capture safe on ROCm.
+constexpr auto GemmFFI_CudaGraph_Traits = std::initializer_list{}; +#else +constexpr auto GemmFFI_CudaGraph_Traits = FFI_CudaGraph_Traits; +#endif + static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { // Move the pointer to the next 256B aligned address return reinterpret_cast((reinterpret_cast(ptr) + 255) & @@ -200,7 +208,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("fuse_gelu") .Attr("grad") .Attr("use_split_accumulator"), - FFI_CudaGraph_Traits); + GemmFFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, @@ -593,7 +601,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode") .Attr("has_bias") .Attr("is_grouped_dense_wgrad"), - FFI_CudaGraph_Traits); + GemmFFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine From bda7a47da1ba1bc6e249bee60a9977a10858fdb4 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Mon, 15 Dec 2025 23:14:37 +0000 Subject: [PATCH 06/34] Addressed reviews --- tests/pytorch/test_fused_optimizer.py | 6 ++---- tests/pytorch/test_fusible_ops.py | 5 ++--- tests/pytorch/test_numerics.py | 4 ++-- transformer_engine/jax/csrc/extensions/gemm.cpp | 2 ++ transformer_engine/jax/quantize/device_utils.py | 3 ++- transformer_engine/pytorch/utils.py | 6 ------ 6 files changed, 10 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 47e8820ed..32abea1de 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -18,9 +18,7 @@ from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import gpu_autocast_ctx - -if IS_HIP_EXTENSION: - from transformer_engine.pytorch.utils import is_mi350 +from transformer_engine.pytorch.utils import get_device_compute_capability # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -382,7 +380,7 @@ def test_bf16_exp_avg(self): @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg(self): - model_tol = 3e-2 if IS_HIP_EXTENSION and is_mi350() else None + model_tol = 3e-2 if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) else None self.gen_precision_aware_test( use_fp8_params=False, param_dtype=torch.bfloat16, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index d8044e53b..1db81ec23 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -35,10 +35,9 @@ ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible +from transformer_engine.pytorch.utils import get_device_compute_capability import transformer_engine_torch as tex from torch.utils.cpp_extension import IS_HIP_EXTENSION -if IS_HIP_EXTENSION: - from transformer_engine.pytorch.utils import is_mi350 # Import utility functions _current_file = pathlib.Path(__file__).resolve() @@ -921,7 +920,7 @@ def test_basic_linear_quantized( quantized_grad_input: bool, ) -> None: """GEMM with FP8 inputs and outputs""" - if IS_HIP_EXTENSION and is_mi350(): + if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): if ( quantization 
and quantization.startswith("fp8") diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 7662858d1..22f0ecb69 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -28,7 +28,7 @@ is_bf16_compatible, ) if IS_HIP_EXTENSION: - from transformer_engine.pytorch.utils import is_mi200, is_mi308, is_mi350 + from transformer_engine.pytorch.utils import is_mi200, is_mi308 from transformer_engine.pytorch import ( DotProductAttention, @@ -757,7 +757,7 @@ def test_gpt_full_activation_recompute( pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) - if IS_HIP_EXTENSION and is_mi350(): + if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): if (dtype == torch.bfloat16 and not fp8 and not use_reentrant diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index ec61f2a97..3e0842ef5 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/jax/quantize/device_utils.py b/transformer_engine/jax/quantize/device_utils.py index 5fc0f0473..3f04d674c 100644 --- a/transformer_engine/jax/quantize/device_utils.py +++ b/transformer_engine/jax/quantize/device_utils.py @@ -36,7 +36,8 @@ def is_fp8_gemm_with_all_layouts_supported() -> bool: """Return True if using Blackwell architecture, False otherwise.""" compute_capability = get_device_compute_capability() # Enable once FP8 GEMM layout coverage is validated with hipblaslt. 
- # if is_hip_extension(): + if is_hip_extension(): # gfx950 --> NV blackwell # return compute_capability == 95 + return False return 100 <= compute_capability < 120 diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index f49b98cb4..9d0d71fdc 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -456,12 +456,6 @@ def is_mi308(): import re return (re.search('AMD Instinct MI308', torch.cuda.get_device_name(torch.cuda.current_device())) is not None) - @functools.lru_cache(maxsize=None) - def is_mi350(): - """check whether this machine is mi35x""" - props = torch.cuda.get_device_properties(torch.cuda.current_device()) - return (props.major, props.minor) == (9, 5) - @functools.lru_cache(maxsize=None) def is_fp8_fnuz(): return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 4) From e8d543c187e66ef498327c4d8ce3f8114d77c7fe Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 16 Dec 2025 16:18:22 +0000 Subject: [PATCH 07/34] [CI] Add MI355 nodes to github actions workflow --- .github/workflows/rocm-ci.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 67af6dc9f..cbd8206f7 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -40,9 +40,12 @@ concurrency: jobs: build_and_test: - name: Build and Test on GPU + name: Build and Test on GPU (${{ matrix.runner }}) timeout-minutes: 720 - runs-on: linux-mi325-8 + runs-on: ${{ matrix.runner }} + strategy: + matrix: + runner: [linux-mi325-8, linux-mi355-8] steps: - name: Checkout repository uses: actions/checkout@v4 From 660d8a0187de072ef10f4be26067d09d14b64fa9 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 16 Dec 2025 16:23:07 +0000 Subject: [PATCH 08/34] [CI] Update docker image --- ci/ci_config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/ci_config.json b/ci/ci_config.json index 9ef4d03a2..cb5135817 100644 --- a/ci/ci_config.json +++ b/ci/ci_config.json @@ -1,6 +1,6 @@ { "docker_images": { - "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.0.2_ubuntu22.04_py3.10_pytorch_release-2.7_9015dfdf_jax_v0.6.0_fa-v2.8.0", + "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.6.0_fa-2.8.0", "release_v1.13": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273", "release_v1.14": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273" } From c0fd7915c975435df194ab7381f34d3a84dac389 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 16 Dec 2025 16:46:36 +0000 Subject: [PATCH 09/34] [CI] add MI355 runner matrix and keep matrix legs independent --- .github/workflows/rocm-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index cbd8206f7..560a3ef53 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -44,6 +44,7 @@ jobs: timeout-minutes: 720 runs-on: ${{ matrix.runner }} strategy: + fail-fast: false matrix: runner: [linux-mi325-8, linux-mi355-8] steps: From ed1363883a541c3749a60f58a733c5ccead865bd Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Thu, 18 Dec 2025 06:09:32 +0000 Subject: [PATCH 10/34] Skip unstable Gemm tests on gfx950 --- 
tests/cpp/operator/test_cublaslt_gemm.cu | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 071470bdf..fdecefbd2 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -228,6 +228,14 @@ void performTest(const TestParams& params) { #ifdef __HIP_PLATFORM_AMD__ + // Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable. + // Re-enable after ROCm 7.2 once hipBLASLt fixes land. + if (prop.major == 9 && prop.minor == 5 && + params.transa && !params.transb && + params.m == 2304 && params.k == 768 && params.n == 4096) { + GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 until ROCm 7.2"; + } + // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm > 7.0. // hipBLASLt currently supports this config only bool fp8_gelu_fusion_config = false; @@ -450,6 +458,14 @@ void performDqTest(const TestParams ¶ms) { cudaDeviceProp prop; (void)cudaGetDeviceProperties(&prop, 0); + // Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable. + // Re-enable after ROCm 7.2 once hipBLASLt fixes land. + if (prop.major == 9 && prop.minor == 5 && + params.transa && !params.transb && + params.m == 2304 && params.k == 768 && params.n == 4096) { + GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 until ROCm 7.2"; + } + bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5); if (!mxfp8_supported) { GTEST_SKIP() << "MXFP8 is not supported in current config"; From 6bde4cbf675f29e2d0fa239a9c9111f843bc7be7 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Thu, 18 Dec 2025 17:42:50 +0000 Subject: [PATCH 11/34] Addressed reviews --- ci/ci_config.json | 2 +- tests/pytorch/test_fused_optimizer.py | 2 ++ tests/pytorch/test_numerics.py | 11 ++++------- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/ci/ci_config.json b/ci/ci_config.json index cb5135817..a7b3d5d6c 100644 --- a/ci/ci_config.json +++ b/ci/ci_config.json @@ -1,6 +1,6 @@ { "docker_images": { - "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.6.0_fa-2.8.0", + "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.7.1_fa-2.8.0", "release_v1.13": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273", "release_v1.14": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273" } diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 32abea1de..3527bb9a6 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
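
The gating pattern these CI patches converge on, once the short-lived is_mi350() helper is removed in review, is to key gfx950-specific skips and tolerances off the ROCm build flag plus the (9, 5) compute capability rather than off device-name matching. Below is a minimal sketch of that guard, assuming only the imports the diffs already use (IS_HIP_EXTENSION from torch.utils.cpp_extension and get_device_compute_capability from transformer_engine.pytorch.utils); the helper names are illustrative and not part of any patch.

    # Hedged sketch: gfx950 (MI35x) gating as used inline throughout these patches.
    import pytest
    from torch.utils.cpp_extension import IS_HIP_EXTENSION
    from transformer_engine.pytorch.utils import get_device_compute_capability

    def running_on_gfx950() -> bool:
        # (9, 5) is the compute capability the patches compare against for gfx950.
        return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5)

    def skip_on_gfx950(reason: str) -> None:
        # Mirrors the inline pytest.skip(...) calls added in the test diffs.
        if running_on_gfx950():
            pytest.skip(reason)

In the patches themselves the check stays inline at each call site, e.g. `if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): pytest.skip(...)`.
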
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 22f0ecb69..b2885d677 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -2813,14 +2813,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         max_seqlen_kv=config.seq_len,
     )

-    if IS_HIP_EXTENSION:
+    if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5):
         tols_thd = dtype_tols(dtype)
-        if dtype in (torch.float16, torch.bfloat16):
-            # ROCm fused attention (CK) on THD can produce slightly larger error
-            tols_thd["atol"] = 2e-3
-            _, use_aotriton, use_ck = rocm_attn_backend()
-            if use_aotriton and not use_ck:
-                tols_thd["rtol"] = tols_thd["rtol"] * 3
+        # On gfx950 the THD layout produces slightly different results,
+        # which lowers the precision of the final output
+        tols_thd["atol"] = 2e-3
         torch.testing.assert_close(
             y_bshd,
             y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),

From 8baf254c3c8123e47e9bd4adf4fc39a1103d2a19 Mon Sep 17 00:00:00 2001
From: Veera Rajasekhar Reddy Gopu
Date: Sat, 20 Dec 2025 00:22:52 +0000
Subject: [PATCH 12/34] Guard gfx950 TN skip by ROCm version and adjust MXFP8 Dq test size

---
 tests/cpp/operator/test_cublaslt_gemm.cu | 26 +++++++++-----------------
 1 file changed, 9 insertions(+), 17 deletions(-)

diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu
index fdecefbd2..91e586ee3 100644
--- a/tests/cpp/operator/test_cublaslt_gemm.cu
+++ b/tests/cpp/operator/test_cublaslt_gemm.cu
@@ -29,8 +29,8 @@ std::vector> test_case_sizes = {
 };

 std::vector> test_case_sizes_mxfp8 = {
-    {2304, 768, 4096},
-};
+    {768, 3072, 4096},
+};

 // A, B, Bias, Gelu, D
 // Bias type choose as bf16 in use_fp8, D_type otherwise
@@ -228,13 +228,13 @@ void performTest(const TestParams& params) {

 #ifdef __HIP_PLATFORM_AMD__

-  // Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable.
-  // Re-enable after ROCm 7.2 once hipBLASLt fixes land.
- if (prop.major == 9 && prop.minor == 5 && - params.transa && !params.transb && - params.m == 2304 && params.k == 768 && params.n == 4096) { - GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 until ROCm 7.2"; - } - bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5); if (!mxfp8_supported) { GTEST_SKIP() << "MXFP8 is not supported in current config"; From 6cddbec174c543b9806c69aed890920e5f848587 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 6 Jan 2026 16:43:10 +0000 Subject: [PATCH 13/34] Removed ROCM7.2 guards --- tests/cpp/operator/test_cublaslt_gemm.cu | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 91e586ee3..1ebe9df96 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -228,14 +228,6 @@ void performTest(const TestParams& params) { #ifdef __HIP_PLATFORM_AMD__ - #if HIP_VERSION < 70200000 - if (prop.major == 9 && prop.minor == 5 && - params.transa && !params.transb && - params.m == 2304 && params.k == 768 && params.n == 4096) { - GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 for ROCm < 7.2"; - } - #endif - // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm > 7.0. // hipBLASLt currently supports this config only bool fp8_gelu_fusion_config = false; From 5a832950ecb56eed59ff1d4580585d534d3fed51 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Wed, 7 Jan 2026 16:21:10 +0000 Subject: [PATCH 14/34] Reverted ROCM7.2 guards --- tests/cpp/operator/test_cublaslt_gemm.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 1ebe9df96..91e586ee3 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -228,6 +228,14 @@ void performTest(const TestParams& params) { #ifdef __HIP_PLATFORM_AMD__ + #if HIP_VERSION < 70200000 + if (prop.major == 9 && prop.minor == 5 && + params.transa && !params.transb && + params.m == 2304 && params.k == 768 && params.n == 4096) { + GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 for ROCm < 7.2"; + } + #endif + // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm > 7.0. // hipBLASLt currently supports this config only bool fp8_gelu_fusion_config = false; From e4d2d21d98fd4913361e1e0c0de862de65a104ed Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 12 Jan 2026 15:41:30 -0600 Subject: [PATCH 15/34] Corrected Normalization scale_inv padding removal --- .../jax/cpp_extensions/normalization.py | 12 +++++------- transformer_engine/jax/quantize/helper.py | 2 ++ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 8885ae2ea..b2cf3b236 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -321,14 +321,12 @@ def impl( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(x.shape, is_padded=False) - # slice out padding for mxfp8, noop for DelayedScaling - scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( - rowwise_scale_inv_shape - ) + # Slice out the padding for mxfp8 - the ROCm kernel writes to strided + # 2D positions, not contiguous. 
+ # For 1D MXFP8: allocated [padded_rows, padded_cols], kernel writes [:actual_rows, :actual_cols] + scale_inv = scale_inv[tuple(slice(0, dim) for dim in rowwise_scale_inv_shape)] if is_2x: - colwise_scale_inv = colwise_scale_inv.flatten()[ - : reduce(operator.mul, colwise_scale_inv_shape, 1) - ].reshape(colwise_scale_inv_shape) + colwise_scale_inv = colwise_scale_inv[tuple(slice(0, dim) for dim in colwise_scale_inv_shape)] return ( out, colwise_out, diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 0b9659a46..1941877a8 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -85,6 +85,8 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: A tuple of (bool, str) indicating support and any error message """ if is_hip_extension(): + if gpu_arch >= 95: + return True, "" return False, "FP8 block scaled gemm not yet supported for ROCm" if gpu_arch >= 100: # blackwell and above return True, "" From 93a69a97c5a6c3b03c74ebd92a3a32d155fbcbad Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 16 Jan 2026 14:59:01 -0600 Subject: [PATCH 16/34] Updated cast behavior, added safeguards around MXFP8 GEMM --- tests/jax/test_custom_call_compute.py | 39 +++++++++++++++---- tests/jax/test_distributed_layernorm_mlp.py | 14 +++++++ transformer_engine/common/gemm/rocm_gemm.cu | 6 +++ .../common/normalization/common.h | 2 +- .../common/util/rocm_cast_kernels.cuh | 6 +-- .../jax/cpp_extensions/normalization.py | 32 +-------------- .../jax/csrc/extensions/gemm.cpp | 2 + 7 files changed, 60 insertions(+), 41 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 20a8037eb..9d727ffb0 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -11,6 +11,7 @@ from functools import reduce from typing import Union import operator +from packaging import version from utils import ( assert_allclose, @@ -96,7 +97,7 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype) elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING: # Compare MXFP8 scales as uint8 - assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8)) + assert_allclose(a.scale_inv.view(jnp.uint8), b.scale_inv.view(jnp.uint8)) else: raise ValueError(f"Unsupported scaling mode {a.scaling_mode}") assert_allclose(a.data, b.data) @@ -874,6 +875,22 @@ def _use_jax_fp8_gemm(enabled=False): elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") +def _check_mxfp8_gemm_support( + with_jax_gemm, + m, n, k, + x_qtype=jnp_float8_e4m3_type, + w_qtype=jnp_float8_e4m3_type +): + if not with_jax_gemm: + if jnp_float8_e5m2_type in (x_qtype, w_qtype): + pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") + if (m % 16 != 0) or (n % 16 != 0) or (k % 128 != 0): + pytest.skip( + f"Input shape {(m, k)} x {(k, n)} is not supported by MXFP8 GEMM." 
+ ) + else: + if version.parse(jax.__version__) < version.parse("0.8.0"): + pytest.skip("MXFP8 not supported by JAX GEMM yet.") class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): @@ -919,12 +936,8 @@ def test_gemm_bf16(self, m, n, k, data_layout): @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): - if ( - not with_jax_gemm - and scaling_mode.is_1d_block_scaling() - and jnp_float8_e5m2_type in (x_qtype, w_qtype) - ): - pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") + if scaling_mode.is_1d_block_scaling(): + _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, x_qtype, w_qtype) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( @@ -981,6 +994,8 @@ def ref_func(x, w, data_layout): def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + if scaling_mode.is_1d_block_scaling(): + _check_mxfp8_gemm_support(with_jax_gemm, m, n, k) key = jax.random.PRNGKey(1) bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) @@ -1054,6 +1069,8 @@ def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_g """ Test layernorm_dense VJP Rule """ + if scaling_mode.is_1d_block_scaling(): + _check_mxfp8_gemm_support(with_jax_gemm, m, n, k) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False eps = 1e-6 @@ -1137,6 +1154,8 @@ def test_layernorm_mlp_grad( """ Test layernorm_mlp VJP Rule """ + if scaling_mode.is_1d_block_scaling(): + _check_mxfp8_gemm_support(with_jax_gemm, m, n, k) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False eps = 1e-6 @@ -1344,6 +1363,9 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("layout", ["NN"]) def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): + if scaling_mode.is_1d_block_scaling(): + pytest.skip("MXFP8 grouped GEMM is not fully supported yet in ROCm.") + fwd_dtype, bwd_dtype = fwd_bwd_dtype quantizer_set = QuantizerFactory.create_set( scaling_mode=scaling_mode, @@ -1429,6 +1451,9 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): ) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): + if scaling_mode.is_1d_block_scaling(): + pytest.skip("MXFP8 grouped GEMM is not fully supported yet in ROCm.") + fwd_dtype, bwd_dtype = fwd_bwd_dtype dtype = jnp.bfloat16 x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 694610978..1b3b4c41d 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -4,6 +4,8 @@ # # See LICENSE for license information. 
from typing import Callable, Sequence, Union, Optional +from packaging import version + import pytest import jax @@ -163,6 +165,12 @@ def _test_layernorm_mlp_grad( use_shardy, with_jax_gemm, ): + if ( + with_jax_gemm + and version.parse(jax.__version__) < version.parse("0.8.0") + and isinstance(fp8_recipe, recipe.MXFP8BlockScaling) + ): + pytest.skip("MXFP8 not supported by JAX GEMM yet.") jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" @@ -327,6 +335,12 @@ def _test_layernorm_mlp( use_shardy, with_jax_gemm, ): + if ( + with_jax_gemm + and version.parse(jax.__version__) < version.parse("0.8.0") + and isinstance(fp8_recipe, recipe.MXFP8BlockScaling) + ): + pytest.skip("MXFP8 not supported by JAX GEMM yet.") jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape layernorm_type = "rmsnorm" diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 94f1bbfbd..9e5f686e3 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1528,6 +1528,12 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK((transb ? B0 : B1) == k, "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, ")"); + // Check that K is a multiple of 128, and M/N are multiples of 16 for MXFP8 GEMM + if (inputA->scaling_mode == NVTE_MXFP8_1D_SCALING || inputB->scaling_mode == NVTE_MXFP8_1D_SCALING) { + NVTE_CHECK((k % 128) == 0, "GEMM K dimension must be multiple of 128 for MXFP8 scaling (got K=", k, ")"); + NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")"); + NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")"); + } const int lda = transa ? k : m; const int ldb = transb ? n : k; diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 31d2c0b74..c4a0a6f8f 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -464,7 +464,7 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params) TRANSFORMER_ENGINE_SWITCH_CONDITION( !(cols % (32 * sizeof(compute_t))), IS_ALIGNED, cast_mxfp8_2D_kernel<<>>( + SCALE_DIM_Y, scale_dim_X_rowwise, IS_ALIGNED><<>>( reinterpret_cast(launch_params.params.z), nullptr, reinterpret_cast(launch_params.z_tensor->data.dptr), diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index d62350e0a..ac0ce2174 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -45,7 +45,7 @@ constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // template + size_t SCALE_DIM_X, bool IS_ALIGNED> __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) cast_mxfp8_2D_kernel(const IType *input_ptr, const IType *act_input_ptr, @@ -221,7 +221,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); const e8m0_t biased_exponent = - float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 
1 : 0); // Normalization requires a +1 scale to avoid saturation + float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); // Only single thread writes the computed scaling factor if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { @@ -278,7 +278,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __builtin_assume(amax >= 0); block_amax = fmaxf(block_amax, amax); - const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation + const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index b2cf3b236..06ff2f0f2 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -324,9 +324,9 @@ def impl( # Slice out the padding for mxfp8 - the ROCm kernel writes to strided # 2D positions, not contiguous. # For 1D MXFP8: allocated [padded_rows, padded_cols], kernel writes [:actual_rows, :actual_cols] - scale_inv = scale_inv[tuple(slice(0, dim) for dim in rowwise_scale_inv_shape)] + scale_inv = jax.lax.slice(scale_inv, [0] * scale_inv.ndim, rowwise_scale_inv_shape) if is_2x: - colwise_scale_inv = colwise_scale_inv[tuple(slice(0, dim) for dim in colwise_scale_inv_shape)] + colwise_scale_inv = jax.lax.slice(colwise_scale_inv, [0] * colwise_scale_inv.ndim, colwise_scale_inv_shape) return ( out, colwise_out, @@ -994,20 +994,6 @@ def layernorm_fwd( ) colwise_scale_inv = rowwise_scale_inv - # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. - # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. - # The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: - rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( - x.shape, is_padded=False - ) - rowwise_scale_inv = rowwise_scale_inv.flatten()[ - : reduce(operator.mul, rowwise_unpadded_shape) - ].reshape(rowwise_unpadded_shape) - colwise_scale_inv = colwise_scale_inv.flatten()[ - : reduce(operator.mul, colwise_unpadded_shape) - ].reshape(colwise_unpadded_shape) - scaled_tensor = ScaledTensorFactory.create( data=rowwise_casted_output, scale_inv=rowwise_scale_inv, @@ -1194,20 +1180,6 @@ def rmsnorm_fwd( ) colwise_scale_inv = rowwise_scale_inv - # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs. - # So here we need to slice out the zero tail and reshape it to the unpadded scale shape. 
- # The ScaledTensorFactory takes care of padding when creating the ScaledTensor - if quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING: - rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes( - x.shape, is_padded=False - ) - rowwise_scale_inv = rowwise_scale_inv.flatten()[ - : reduce(operator.mul, rowwise_unpadded_shape) - ].reshape(rowwise_unpadded_shape) - colwise_scale_inv = colwise_scale_inv.flatten()[ - : reduce(operator.mul, colwise_unpadded_shape) - ].reshape(colwise_unpadded_shape) - scaled_tensor = ScaledTensorFactory.create( data=rowwise_casted_output, scale_inv=rowwise_scale_inv, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 3e0842ef5..8164c5bdd 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -360,11 +360,13 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto bias_shape = std::vector{has_bias ? n : 0}; const int arch = cuda::sm_arch(); + #ifndef __HIP_PLATFORM_AMD__ if (arch < 100 && is_fp8_gemm) { NVTE_CHECK(!lhs_is_trans && rhs_is_trans, "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); } + #endif // These lists are to keep the TensorWrapper objects alive std::vector lhs_wrapper_list; From 3940cf5498db240a717bace6666bcdaa492b1ad2 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 19 Jan 2026 10:55:46 -0600 Subject: [PATCH 17/34] Add partial progress --- tests/cpp/operator/test_cublaslt_gemm.cu | 9 ++++ tests/jax/test_custom_call_compute.py | 42 ++++++++++++------- tests/jax/test_distributed_layernorm_mlp.py | 4 +- .../jax/csrc/extensions/gemm.cpp | 23 +++++----- 4 files changed, 52 insertions(+), 26 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 61ca86a1e..5ba8411c6 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -30,6 +30,15 @@ std::vector> test_case_sizes = { std::vector> test_case_sizes_mxfp8 = { {768, 3072, 4096}, + {32, 128, 32}, + {32, 128, 64}, + {64, 128, 32}, + {128, 128, 64}, + {128, 128, 128}, + {256, 128, 128}, + {128, 256, 128}, + {128, 128, 256}, + {256, 256, 256}, }; // A, B, Bias, Gelu, D diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9d727ffb0..1e882dcc0 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -57,7 +57,7 @@ (2048, 2048, 1024), (2048, 1024, 1024), ] - +TEST_SHAPES = [(64, 32, 64), (128, 64, 128), (128, 256, 256)] jnp_float8_e4m3_type = get_jnp_float8_e4m3_type() jnp_float8_e5m2_type = get_jnp_float8_e5m2_type() @@ -882,8 +882,8 @@ def _check_mxfp8_gemm_support( w_qtype=jnp_float8_e4m3_type ): if not with_jax_gemm: - if jnp_float8_e5m2_type in (x_qtype, w_qtype): - pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") + # if jnp_float8_e5m2_type in (x_qtype, w_qtype): + # pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") if (m % 16 != 0) or (n % 16 != 0) or (k % 128 != 0): pytest.skip( f"Input shape {(m, k)} x {(k, n)} is not supported by MXFP8 GEMM." 
@@ -930,7 +930,7 @@ def test_gemm_bf16(self, m, n, k, data_layout): assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) + @pytest_parametrize_wrapper("m,n,k", TEST_SHAPES) @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @@ -988,17 +988,21 @@ def ref_func(x, w, data_layout): assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) + @pytest_parametrize_wrapper("m,n,k", TEST_SHAPES) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + key = jax.random.PRNGKey(1) if scaling_mode.is_1d_block_scaling(): + # Check for first GEMM _check_mxfp8_gemm_support(with_jax_gemm, m, n, k) - - key = jax.random.PRNGKey(1) - bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) + # Check for second GEMM + _check_mxfp8_gemm_support(with_jax_gemm, m, k, n) + bias = None + else: + bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) def primitive_func(x, w, bias, contracting_dims, quantizer_set): primitive_out = dense( @@ -1007,9 +1011,10 @@ def primitive_func(x, w, bias, contracting_dims, quantizer_set): return jnp.mean(primitive_out) def ref_func(x, w, bias, data_layout): - return jnp.mean( - self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0) - ) + out = self._ref_gemm_with_jnp_dot(x, w, data_layout) + if bias is not None: + out = out + jnp.expand_dims(bias, axis=0) + return jnp.mean(out) value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) @@ -1035,7 +1040,8 @@ def ref_func(x, w, bias, data_layout): assert_allclose(primitive_out, ref_out, dtype=jnp_float8_e4m3_type) assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp_float8_e5m2_type) + if bias is not None: + assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp_float8_e5m2_type) @pytest.fixture(name="random_inputs") @@ -1061,7 +1067,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) + @pytest.mark.parametrize("m,n,k", TEST_SHAPES) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) @@ -1070,7 +1076,10 @@ def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_g Test layernorm_dense VJP Rule """ if scaling_mode.is_1d_block_scaling(): + # Check for fwd GEMM _check_mxfp8_gemm_support(with_jax_gemm, m, n, k) + # Check for bwd GEMM + _check_mxfp8_gemm_support(with_jax_gemm, m, k, n) # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False eps = 1e-6 @@ -1142,7 +1151,7 
@@ def ref_func(x, w, gamma, beta): assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp_float8_e5m2_type) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) + @pytest.mark.parametrize("m,n,k", TEST_SHAPES) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @@ -1155,7 +1164,12 @@ def test_layernorm_mlp_grad( Test layernorm_mlp VJP Rule """ if scaling_mode.is_1d_block_scaling(): + # Check for first GEMM _check_mxfp8_gemm_support(with_jax_gemm, m, n, k) + # Check for second GEMM + _check_mxfp8_gemm_support(with_jax_gemm, m, k, n) + if use_bias: + pytest.skip("Bias is not supported for MXFP8 GEMM.") # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False eps = 1e-6 diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 1b3b4c41d..c23942a0b 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -66,7 +66,9 @@ LN_BIAS_AXES = (W_NO_SHARD_AXES,) BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES) BIAS_2_AXES = (W_NO_SHARD_AXES,) -INTERMEDIATE = 64 +# We set to 128 to ensure compatibility with MXFP8 GEMM which requires the +# reduction dim to be multiple of 128 +INTERMEDIATE = 128 # Only test with FSDP and TP as DP is not used diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 8164c5bdd..96ef18951 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -74,6 +74,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Swizzle scaling factors for MXFP8 if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + return std::make_tuple(std::move(input), input_shape); // Get the swizzle buffer NVTE_CHECK(swizzled_scale_inv->element_count() > 0, "Missing swizzled inverse scale buffer in the JAX primitive."); @@ -555,17 +556,17 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t num_non_empty_gemms = lhs_list.size(); - if (is_mxfp8_scaling) { - for (int i = 0; i < num_non_empty_gemms; i++) { - // The i-th GEMM will use the (i % num_streams)-th stream to compute, - // use the same stream to swizzle the scaling factors to make sure that - // the swizzling is done before the GEMM computation starts. - int stream_id = i % num_streams; - cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); - nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); - } - } + // if (is_mxfp8_scaling) { + // for (int i = 0; i < num_non_empty_gemms; i++) { + // // The i-th GEMM will use the (i % num_streams)-th stream to compute, + // // use the same stream to swizzle the scaling factors to make sure that + // // the swizzling is done before the GEMM computation starts. 
+ // int stream_id = i % num_streams; + // cudaStream_t stream_i = nvte_get_compute_stream(stream_id); + // nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); + // nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); + // } + // } // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM size_t num_zero_outs = zero_out_dptr_list.size(); From 496744fff3600a0382cdbffd50ff36e503daa1ef Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 19 Jan 2026 17:00:45 -0600 Subject: [PATCH 18/34] Improved guards on LayerNormMLP tests --- tests/jax/test_custom_call_compute.py | 20 +--- tests/jax/test_distributed_layernorm_mlp.py | 125 +++++++++++++++++--- tests/jax/utils.py | 18 +++ 3 files changed, 126 insertions(+), 37 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 1e882dcc0..7641afa16 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -11,12 +11,12 @@ from functools import reduce from typing import Union import operator -from packaging import version from utils import ( assert_allclose, pytest_parametrize_wrapper, use_jax_gemm, + _check_mxfp8_gemm_support, ) from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.layernorm_mlp import layernorm_mlp @@ -875,22 +875,6 @@ def _use_jax_fp8_gemm(enabled=False): elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ: os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") -def _check_mxfp8_gemm_support( - with_jax_gemm, - m, n, k, - x_qtype=jnp_float8_e4m3_type, - w_qtype=jnp_float8_e4m3_type -): - if not with_jax_gemm: - # if jnp_float8_e5m2_type in (x_qtype, w_qtype): - # pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") - if (m % 16 != 0) or (n % 16 != 0) or (k % 128 != 0): - pytest.skip( - f"Input shape {(m, k)} x {(k, n)} is not supported by MXFP8 GEMM." 
- ) - else: - if version.parse(jax.__version__) < version.parse("0.8.0"): - pytest.skip("MXFP8 not supported by JAX GEMM yet.") class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): @@ -937,7 +921,7 @@ def test_gemm_bf16(self, m, n, k, data_layout): @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): if scaling_mode.is_1d_block_scaling(): - _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, x_qtype, w_qtype) + _check_mxfp8_gemm_support(with_jax_gemm, m, n, k) x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) quantizer_set = QuantizerFactory.create_set( diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index c23942a0b..a6474d1bb 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -18,6 +18,7 @@ is_devices_enough, pytest_parametrize_wrapper, use_jax_gemm, + _check_mxfp8_gemm_support, ) from transformer_engine.common import recipe @@ -55,7 +56,7 @@ SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) DTYPES = [jnp.bfloat16, jnp.float16] -INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in] +INPUT_SHAPE = [[4, 64, 256]] # [batch, seqlen, hidden_in] LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES) @@ -66,9 +67,9 @@ LN_BIAS_AXES = (W_NO_SHARD_AXES,) BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES) BIAS_2_AXES = (W_NO_SHARD_AXES,) -# We set to 128 to ensure compatibility with MXFP8 GEMM which requires the -# reduction dim to be multiple of 128 -INTERMEDIATE = 128 +# We set to 256 to ensure compatibility with MXFP8 GEMM which requires the +# reduction dim to be multiple of 128 after sharding. 
+INTERMEDIATE = 128 * 2 # Only test with FSDP and TP as DP is not used @@ -156,6 +157,77 @@ def layernorm_fp8_mlp_prim_func( ) ) + def _check_mxfp8_layernorm_mlp_support( + self, + batch_size, + intermediate_size, + activation_size, + hidden_in, + hidden_out, + mesh_config, + use_bias, + with_jax_gemm + ): + # Check input shape compatibility with MXFP8 GEMMs + # FWD 1 + m = batch_size + k = hidden_in // mesh_config[1][1] # Account for TP sharding + n = activation_size + _check_mxfp8_gemm_support( + with_jax_gemm, + m, n, k, + use_bias + ) + # FWD 2 + k = intermediate_size // mesh_config[1][1] # Account for TP sharding + n = hidden_out + _check_mxfp8_gemm_support( + with_jax_gemm, + m, n, k, + use_bias + ) + + def _check_mxfp8_layernorm_mlp_grad_support( + self, + batch_size, + intermediate_size, + activation_size, + hidden_in, + hidden_out, + mesh_config, + use_bias, + with_jax_gemm + ): + # Check forwards + self._check_mxfp8_layernorm_mlp_support( + batch_size, + intermediate_size, + activation_size, + hidden_in, + hidden_out, + mesh_config, + use_bias, + with_jax_gemm, + ) + # BWD 1 + m = batch_size + k = hidden_out // mesh_config[1][1] # Account for TP sharding + n = intermediate_size + _check_mxfp8_gemm_support( + with_jax_gemm, + m, n, k, + use_bias + ) + # BWD 2 + m = intermediate_size + k = batch_size // mesh_config[1][1] # Account for TP sharding + n = hidden_out + _check_mxfp8_gemm_support( + with_jax_gemm, + m, n, k, + use_bias + ) + def _test_layernorm_mlp_grad( self, mesh_config, @@ -167,19 +239,29 @@ def _test_layernorm_mlp_grad( use_shardy, with_jax_gemm, ): - if ( - with_jax_gemm - and version.parse(jax.__version__) < version.parse("0.8.0") - and isinstance(fp8_recipe, recipe.MXFP8BlockScaling) - ): - pytest.skip("MXFP8 not supported by JAX GEMM yet.") + inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs( + input_shape, activation_type, use_bias, dtype + ) + if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + batch_size = x.shape[0]*x.shape[1] + intermediate_size = k2.shape[0] + activation_size = k1.shape[1]*k1.shape[2] + hidden_in = x.shape[2] + hidden_out = hidden_in + self._check_mxfp8_layernorm_mlp_grad_support( + batch_size, + intermediate_size, + activation_size, + hidden_in, + hidden_in, + mesh_config, + use_bias, + with_jax_gemm + ) jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" - inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs( - input_shape, activation_type, use_bias, dtype - ) static_inputs = [layernorm_type, activation_type] with use_jax_gemm(enabled=with_jax_gemm): @@ -337,12 +419,17 @@ def _test_layernorm_mlp( use_shardy, with_jax_gemm, ): - if ( - with_jax_gemm - and version.parse(jax.__version__) < version.parse("0.8.0") - and isinstance(fp8_recipe, recipe.MXFP8BlockScaling) - ): - pytest.skip("MXFP8 not supported by JAX GEMM yet.") + if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + self._check_mxfp8_layernorm_mlp_support( + input_shape[0]*input_shape[1], + INTERMEDIATE, + 2*INTERMEDIATE, + input_shape[2], + input_shape[2], + mesh_config, + use_bias, + with_jax_gemm + ) jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape layernorm_type = "rmsnorm" diff --git a/tests/jax/utils.py b/tests/jax/utils.py index f34fb5448..f5e5b8479 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -11,6 +11,7 @@ import operator from typing import Any, Callable, Dict, Tuple, Sequence, Union, 
Iterable, Optional, NewType from contextlib import contextmanager +from packaging import version import jax import jax.numpy as jnp @@ -49,6 +50,23 @@ def is_devices_enough(required): return len(jax.devices()) >= required +def _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias=False): + if not with_jax_gemm: + if (m % 16 != 0) or (n % 16 != 0) or (k % 128 != 0): + pytest.skip( + f"Input shape {(m, k)} x {(k, n)} is not supported by hipblaslt MXFP8 GEMM." + ) + if use_bias: + pytest.skip("hipblaslt GEMM does not yet support MXFP8 with bias.") + else: + jax_version = version.parse(jax.__version__) + if jax_version < version.parse("0.8.0"): + pytest.skip( + "MXFP8 support for JAX GEMM is added in version 0.8.0, " + f"but the current detected version is {jax_version}." + ) + + def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]: # Generate broadcast dims for drop_path. drop_path_shape = list(range(0, len(shape))) From 74a07243368ce18e55444568ba5dd30eb4b4ac28 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 20 Jan 2026 12:52:40 -0600 Subject: [PATCH 19/34] Remove swizzle in JAX GEMM primitive --- transformer_engine/jax/cpp_extensions/gemm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4ba581c66..28a42020f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -406,8 +406,9 @@ def impl( rhs_scale_inv = apply_padding_to_scale_inv( rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis ) - lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) - rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) + if not is_hip_extension(): + lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) + rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) outputs = GemmPrimitive.inner_primitive.bind( lhs, From 12181f53de97f6da3d1b240791681a15a16ab902 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 20 Jan 2026 14:34:26 -0600 Subject: [PATCH 20/34] Added unique factors for sharding scales, added xfail to test --- tests/jax/test_distributed_layernorm_mlp.py | 24 ++++++++++--------- .../jax/cpp_extensions/activation.py | 5 +++- .../jax/cpp_extensions/normalization.py | 5 +++- transformer_engine/jax/layernorm_mlp.py | 2 +- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 09ed1f4c6..46594b451 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -246,18 +246,20 @@ def _test_layernorm_mlp_grad( inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs( input_shape, activation_type, use_bias, dtype ) + if ( + (not with_jax_gemm) + and use_bias + and fp8_recipe is None + and dtype == jnp.bfloat16 + ): + pytest.xfail("Skip known failure case.") if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - batch_size = x.shape[0]*x.shape[1] - intermediate_size = k2.shape[0] - activation_size = k1.shape[1]*k1.shape[2] - hidden_in = x.shape[2] - hidden_out = hidden_in self._check_mxfp8_layernorm_mlp_grad_support( - batch_size, - intermediate_size, - activation_size, - hidden_in, - hidden_in, + input_shape[0]*input_shape[1], + INTERMEDIATE, + len(activation_type)*INTERMEDIATE, + input_shape[2], + 
input_shape[2], mesh_config, use_bias, with_jax_gemm @@ -435,7 +437,7 @@ def _test_layernorm_mlp( self._check_mxfp8_layernorm_mlp_support( input_shape[0]*input_shape[1], INTERMEDIATE, - 2*INTERMEDIATE, + len(activation_type)*INTERMEDIATE, input_shape[2], input_shape[2], mesh_config, diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index ef2643359..6121ccf20 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -912,9 +912,12 @@ def shardy_sharding_rule( dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) amax = (prefix + "amax",) + # When is_2x==False, colwise_scale_inv needs a different factor + colwise_scale_inv_rule = scale_rules.colwise_rule if is_2x else (prefix + "x_colwise_scale_inv",) + return SdyShardingRule( (dz_axes, x_axes, ("…2",)), - (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), + (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv_rule, amax, dbias), ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index d7b72e4d7..3f9fd5cf4 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -603,13 +603,16 @@ def shardy_sharding_rule( mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma amax = (prefix + "amax",) + # When is_2x==False, colwise_scale_inv needs a different factor + colwise_scale_inv_rule = scale_rules.colwise_rule if is_2x else (prefix + "x_colwise_scale_inv",) + return SdyShardingRule( (x_axes, ("…1",), ("…2",), ("…3",)), ( out, colwise_out, scale_rules.rowwise_rule, - scale_rules.colwise_rule, + colwise_scale_inv_rule, amax, mu, rsigma, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index fc957801a..264d9fc2b 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -260,7 +260,7 @@ def _layernorm_mlp_fwd_rule( assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] use_bias_1 = bias_1 is not None - use_bias_2 = bias_1 is not None + use_bias_2 = bias_2 is not None x = with_sharding_constraint_by_logical_axes(x, norm_input_axes) From ace4712974a788c1eaa6983c1d751ebbc990b15d Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 20 Jan 2026 14:37:23 -0600 Subject: [PATCH 21/34] Removed old code --- transformer_engine/jax/csrc/extensions/gemm.cpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f5ef148d2..cf396a46d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -516,18 +516,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t num_non_empty_gemms = lhs_list.size(); - // if (is_mxfp8_scaling) { - // for (int i = 0; i < num_non_empty_gemms; i++) { - // // The i-th GEMM will use the (i % num_streams)-th stream to compute, - // // use the same stream to swizzle the scaling factors to make sure that - // // the swizzling is done before the GEMM computation starts. 
- // int stream_id = i % num_streams; - // cudaStream_t stream_i = nvte_get_compute_stream(stream_id); - // nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); - // nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); - // } - // } - // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM size_t num_zero_outs = zero_out_dptr_list.size(); for (int i = 0; i < num_zero_outs; i++) { From 82bbb7eb17eb557bd42de2bc06a3903972c8b08b Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Tue, 20 Jan 2026 14:45:53 -0600 Subject: [PATCH 22/34] Added bias parameterization --- tests/jax/test_custom_call_compute.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 62bc243b5..d692a0c11 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -979,18 +979,18 @@ def ref_func(x, w, data_layout): @pytest_parametrize_wrapper("m,n,k", TEST_SHAPES) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): + @pytest_parametrize_wrapper("use_bias", [False, True]) + def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm, use_bias): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) key = jax.random.PRNGKey(1) if scaling_mode.is_1d_block_scaling(): # Check for first GEMM - _check_mxfp8_gemm_support(with_jax_gemm, m, n, k) + _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias) # Check for second GEMM - _check_mxfp8_gemm_support(with_jax_gemm, m, k, n) - bias = None - else: - bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) + _check_mxfp8_gemm_support(with_jax_gemm, m, k, n, use_bias) + + bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) if use_bias else None def primitive_func(x, w, bias, contracting_dims, quantizer_set): primitive_out = dense( From 4a016b9d439b304525edfd1e9a45a24882db92f9 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 21 Jan 2026 10:00:51 -0600 Subject: [PATCH 23/34] Refactored test guard, added bias guard in hipblasgemm --- tests/jax/test_distributed_layernorm_mlp.py | 81 ++------------------- tests/jax/test_layer.py | 26 ++++++- tests/jax/utils.py | 68 +++++++++++++++++ transformer_engine/common/gemm/rocm_gemm.cu | 1 + 4 files changed, 99 insertions(+), 77 deletions(-) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 46594b451..128a59280 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -18,7 +18,8 @@ is_devices_enough, pytest_parametrize_wrapper, use_jax_gemm, - _check_mxfp8_gemm_support, + _check_mxfp8_layernorm_mlp_grad_support, + _check_mxfp8_layernorm_mlp_support, ) from transformer_engine.common import recipe @@ -161,76 +162,6 @@ def layernorm_fp8_mlp_prim_func( ) ) - def _check_mxfp8_layernorm_mlp_support( - self, - batch_size, - intermediate_size, - activation_size, - hidden_in, - hidden_out, - mesh_config, - use_bias, - with_jax_gemm - ): - # Check input shape compatibility with MXFP8 GEMMs - # FWD 1 - m = batch_size - k = hidden_in // mesh_config[1][1] # Account for TP sharding - n = activation_size - _check_mxfp8_gemm_support( - with_jax_gemm, - m, n, k, - use_bias - ) - # FWD 2 - k = intermediate_size // mesh_config[1][1] # Account 
for TP sharding - n = hidden_out - _check_mxfp8_gemm_support( - with_jax_gemm, - m, n, k, - use_bias - ) - - def _check_mxfp8_layernorm_mlp_grad_support( - self, - batch_size, - intermediate_size, - activation_size, - hidden_in, - hidden_out, - mesh_config, - use_bias, - with_jax_gemm - ): - # Check forwards - self._check_mxfp8_layernorm_mlp_support( - batch_size, - intermediate_size, - activation_size, - hidden_in, - hidden_out, - mesh_config, - use_bias, - with_jax_gemm, - ) - # BWD 1 - m = batch_size - k = hidden_out // mesh_config[1][1] # Account for TP sharding - n = intermediate_size - _check_mxfp8_gemm_support( - with_jax_gemm, - m, n, k, - use_bias - ) - # BWD 2 - m = intermediate_size - k = batch_size // mesh_config[1][1] # Account for TP sharding - n = hidden_out - _check_mxfp8_gemm_support( - with_jax_gemm, - m, n, k, - use_bias - ) def _test_layernorm_mlp_grad( self, @@ -254,13 +185,13 @@ def _test_layernorm_mlp_grad( ): pytest.xfail("Skip known failure case.") if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - self._check_mxfp8_layernorm_mlp_grad_support( + _check_mxfp8_layernorm_mlp_grad_support( input_shape[0]*input_shape[1], INTERMEDIATE, len(activation_type)*INTERMEDIATE, input_shape[2], input_shape[2], - mesh_config, + mesh_config[1][1], use_bias, with_jax_gemm ) @@ -434,13 +365,13 @@ def _test_layernorm_mlp( with_jax_gemm, ): if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - self._check_mxfp8_layernorm_mlp_support( + _check_mxfp8_layernorm_mlp_support( input_shape[0]*input_shape[1], INTERMEDIATE, len(activation_type)*INTERMEDIATE, input_shape[2], input_shape[2], - mesh_config, + mesh_config[1][1], use_bias, with_jax_gemm ) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 6f672ade7..d38ca3381 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -17,8 +17,12 @@ dtype_tols, sync_params_values, ) -from utils import DecoderLayer as RefDecoderLayer -from utils import EncoderLayer as RefEncoderLayer +from utils import ( + DecoderLayer as RefDecoderLayer, + EncoderLayer as RefEncoderLayer, + _check_mxfp8_layernorm_mlp_grad_support, + _check_mxfp8_layernorm_mlp_support, + ) from transformer_engine.common import recipe from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType @@ -521,6 +525,15 @@ def test_backward(self, data_shape, dtype, attrs): @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test forward with fp8 enabled""" + if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + _check_mxfp8_layernorm_mlp_support( + data_shape[0]*data_shape[1], + 2048, + 2048, + data_shape[2], + data_shape[2], + use_bias=attrs.get(_KEY_OF_USE_BIAS, False), + ) # Empty MeshResource is used as we are running on a single device with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) @@ -529,6 +542,15 @@ def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test backward with fp8 enabled""" + if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + _check_mxfp8_layernorm_mlp_grad_support( + data_shape[0]*data_shape[1], + 2048, + 2048, + data_shape[2], + data_shape[2], + use_bias=attrs.get(_KEY_OF_USE_BIAS, False), + ) # Empty MeshResource is used as we are running on a single device with 
fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 43a78d07b..5cde22ffb 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -66,6 +66,74 @@ def _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias=False): f"but the current detected version is {jax_version}." ) +def _check_mxfp8_layernorm_mlp_support( + batch_size, + intermediate_size, + activation_size, + hidden_in, + hidden_out, + n_tp_shards=1, + use_bias=False, + with_jax_gemm=False, +): + # Check input shape compatibility with MXFP8 GEMMs + # FWD 1 + m = batch_size + k = hidden_in // n_tp_shards # Account for TP sharding + n = activation_size + _check_mxfp8_gemm_support( + with_jax_gemm, + m, n, k, + use_bias + ) + # FWD 2 + k = intermediate_size // n_tp_shards # Account for TP sharding + n = hidden_out + _check_mxfp8_gemm_support( + with_jax_gemm, + m, n, k, + use_bias + ) + +def _check_mxfp8_layernorm_mlp_grad_support( + batch_size, + intermediate_size, + activation_size, + hidden_in, + hidden_out, + n_tp_shards=1, + use_bias=False, + with_jax_gemm=False, +): + # Check forwards + _check_mxfp8_layernorm_mlp_support( + batch_size, + intermediate_size, + activation_size, + hidden_in, + hidden_out, + n_tp_shards, + use_bias, + with_jax_gemm, + ) + # BWD 1 + m = batch_size + k = hidden_out // n_tp_shards # Account for TP sharding + n = intermediate_size + _check_mxfp8_gemm_support( + with_jax_gemm, + m, n, k, + use_bias + ) + # BWD 2 + m = intermediate_size + k = batch_size // n_tp_shards # Account for TP sharding + n = hidden_out + _check_mxfp8_gemm_support( + with_jax_gemm, + m, n, k, + use_bias + ) def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[int]: # Generate broadcast dims for drop_path. 
diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu
index d71099dfc..c2c4c502a 100644
--- a/transformer_engine/common/gemm/rocm_gemm.cu
+++ b/transformer_engine/common/gemm/rocm_gemm.cu
@@ -1524,6 +1524,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
                ")");
   // Check that K is a multiple of 128, and M/N are multiples of 16 for MXFP8 GEMM
   if (inputA->scaling_mode == NVTE_MXFP8_1D_SCALING || inputB->scaling_mode == NVTE_MXFP8_1D_SCALING) {
+    NVTE_CHECK(inputBias->data.dptr == nullptr, "MXFP8 GEMM does not yet support bias.");
     NVTE_CHECK((k % 128) == 0, "GEMM K dimension must be multiple of 128 for MXFP8 scaling (got K=", k, ")");
     NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")");
     NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")");

From 0eed117a65d30dc6253647a122d5a67e1080db27 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Fri, 23 Jan 2026 13:29:07 -0600
Subject: [PATCH 24/34] PR comments addressed

---
 tests/jax/test_custom_call_compute.py       |  4 +++-
 tests/jax/test_distributed_layernorm_mlp.py | 13 ++++++++----
 tests/jax/utils.py                          |  4 ++++
 .../jax/cpp_extensions/normalization.py     |  2 +-
 .../jax/csrc/extensions/gemm.cpp            | 20 ++++++++++++++++---
 5 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py
index d692a0c11..086d8b2d9 100644
--- a/tests/jax/test_custom_call_compute.py
+++ b/tests/jax/test_custom_call_compute.py
@@ -58,7 +58,9 @@
     (2048, 2048, 1024),
     (2048, 1024, 1024),
 ]
-TEST_SHAPES = [(64, 32, 64), (128, 64, 128), (128, 256, 256)]
+TEST_SHAPES = [(64, 32, 64)]
+if is_hip_extension():
+    TEST_SHAPES += [(128, 64, 128), (128, 256, 256)]
 
 jnp_float8_e4m3_type = get_jnp_float8_e4m3_type()
 jnp_float8_e5m2_type = get_jnp_float8_e5m2_type()
diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py
index 128a59280..33c4089f1 100644
--- a/tests/jax/test_distributed_layernorm_mlp.py
+++ b/tests/jax/test_distributed_layernorm_mlp.py
@@ -61,7 +61,7 @@
     SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
 
 DTYPES = [jnp.bfloat16, jnp.float16]
-INPUT_SHAPE = [[4, 64, 256]]  # [batch, seqlen, hidden_in]
+INPUT_SHAPE = [[4, 64, 128]]  # [batch, seqlen, hidden_in]
 
 LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
 DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
@@ -72,9 +72,14 @@
 LN_BIAS_AXES = (W_NO_SHARD_AXES,)
 BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
 BIAS_2_AXES = (W_NO_SHARD_AXES,)
-# We set to 256 to ensure compatibility with hipblaslt MXFP8 GEMM which
-# requires the reduction dim to be multiple of 128 after sharding.
-INTERMEDIATE = 128 * 2
+
+INTERMEDIATE = 128
+
+# We set to 256 to ensure compatibility with hipblaslt MXFP8 GEMM which
+# requires the reduction dim to be multiple of 128 after sharding.
+if is_hip_extension(): + INPUT_SHAPE = [[4, 64, 256]] + INTERMEDIATE = INTERMEDIATE * 2 # Only test with FSDP and TPSP as DP is not used diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 5cde22ffb..08c866a73 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -29,6 +29,7 @@ ) from transformer_engine.jax.quantize.helper import DType as TEDType from transformer_engine.jax.util import get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type +from transformer_engine.jax.cpp_extensions.misc import is_hip_extension PRNGKey = Any Shape = Tuple[int, ...] @@ -51,6 +52,9 @@ def is_devices_enough(required): def _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias=False): + if not is_hip_extension(): + return + if not with_jax_gemm: if (m % 16 != 0) or (n % 16 != 0) or (k % 128 != 0): pytest.skip( diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 3f9fd5cf4..c0af0a060 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -321,7 +321,7 @@ def impl( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(x.shape, is_padded=False) - # Slice out the padding for mxfp8 - the ROCm kernel writes to strided + # Slice out the padding for mxfp8 -- the kernel writes to strided # 2D positions, not contiguous. # For 1D MXFP8: allocated [padded_rows, padded_cols], kernel writes [:actual_rows, :actual_cols] scale_inv = jax.lax.slice(scale_inv, [0] * scale_inv.ndim, rowwise_scale_inv_shape) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index cf396a46d..5b18e4e0d 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -321,13 +321,13 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto bias_shape = std::vector{has_bias ? n : 0}; const int arch = cuda::sm_arch(); - #ifndef __HIP_PLATFORM_AMD__ +#ifndef __HIP_PLATFORM_AMD__ if (arch < 100 && is_fp8_gemm) { NVTE_CHECK(!lhs_is_trans && rhs_is_trans, "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); } - #endif +#endif // These lists are to keep the TensorWrapper objects alive std::vector lhs_wrapper_list; @@ -516,7 +516,21 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t num_non_empty_gemms = lhs_list.size(); - // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM +#ifndef __HIP_PLATFORM_AMD__ + if (is_mxfp8_scaling) { + for (int i = 0; i < num_non_empty_gemms; i++) { + // The i-th GEMM will use the (i % num_streams)-th stream to compute, + // use the same stream to swizzle the scaling factors to make sure that + // the swizzling is done before the GEMM computation starts. 
+ int stream_id = i % num_streams; + cudaStream_t stream_i = nvte_get_compute_stream(stream_id); + nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i); + nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); + } + } +#endif + +// Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM size_t num_zero_outs = zero_out_dptr_list.size(); for (int i = 0; i < num_zero_outs; i++) { int stream_id = i % num_streams; From 49d13759916c0679e432e722cdde37ee01782b27 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 23 Jan 2026 13:32:11 -0600 Subject: [PATCH 25/34] Minor typo --- tests/jax/test_distributed_layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 33c4089f1..1127964a5 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -78,7 +78,7 @@ # We set to 256 to ensure compatibility with hipblaslt MXFP8 GEMM which # requires the reduction dim to be multiple of 128 after sharding. if is_hip_extension(): - INPUT_SHAPE = [[4, 64, 256]] + INPUT_SHAPE += [[4, 64, 256]] INTERMEDIATE = INTERMEDIATE * 2 From bae9cf7e846621720818630a119a45031367f922 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 23 Jan 2026 13:40:48 -0600 Subject: [PATCH 26/34] Updated test per PR comments --- tests/jax/test_custom_call_compute.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 086d8b2d9..b29e92d67 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -926,6 +926,13 @@ def test_gemm_bf16(self, m, n, k, data_layout): @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): + if ( + not with_jax_gemm + and scaling_mode.is_1d_block_scaling() + and jnp_float8_e5m2_type in (x_qtype, w_qtype) + and not is_hip_extension() + ): + pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.") if scaling_mode.is_1d_block_scaling(): _check_mxfp8_gemm_support(with_jax_gemm, m, n, k) @@ -981,7 +988,7 @@ def ref_func(x, w, data_layout): @pytest_parametrize_wrapper("m,n,k", TEST_SHAPES) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - @pytest_parametrize_wrapper("use_bias", [False, True]) + @pytest_parametrize_wrapper("use_bias", [False, True] if is_hip_extension() else [True]) def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm, use_bias): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -1154,11 +1161,10 @@ def test_layernorm_mlp_grad( """ if scaling_mode.is_1d_block_scaling(): # Check for first GEMM - _check_mxfp8_gemm_support(with_jax_gemm, m, n, k) + _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias) # Check for second GEMM - _check_mxfp8_gemm_support(with_jax_gemm, m, k, n) - if use_bias: - pytest.skip("Bias is not supported for MXFP8 GEMM.") + _check_mxfp8_gemm_support(with_jax_gemm, m, k, n, use_bias) + # zero_centered_gamma is already tested in TestNorm zero_centered_gamma = False eps = 1e-6 From 5090bb8372dd9671939c895156c4e1b249e19fc6 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Fri, 23 Jan 
2026 16:53:29 -0600 Subject: [PATCH 27/34] Improve test to cover padded scale_inv for mxfp8 gemm --- tests/jax/test_custom_call_compute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index b29e92d67..5745035f1 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -60,7 +60,7 @@ ] TEST_SHAPES = [(64, 32, 64)] if is_hip_extension(): - TEST_SHAPES += [(128, 64, 128), (128, 256, 256)] + TEST_SHAPES += [(64, 64, 128), (128, 256, 256)] jnp_float8_e4m3_type = get_jnp_float8_e4m3_type() jnp_float8_e5m2_type = get_jnp_float8_e5m2_type() From 65ed84be03cea7f3673c0a1575a886343c4654e9 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 28 Jan 2026 11:01:48 -0600 Subject: [PATCH 28/34] Address PR review comments --- tests/jax/test_custom_call_compute.py | 4 ++-- tests/jax/test_distributed_layernorm_mlp.py | 20 +++++++++---------- .../jax/cpp_extensions/activation.py | 1 + .../jax/cpp_extensions/normalization.py | 11 +++++----- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 5745035f1..4555936f7 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1370,7 +1370,7 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("layout", ["NN"]) def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): - if scaling_mode.is_1d_block_scaling(): + if is_hip_extension() and scaling_mode.is_1d_block_scaling(): pytest.skip("MXFP8 grouped GEMM is not fully supported yet in ROCm.") fwd_dtype, bwd_dtype = fwd_bwd_dtype @@ -1453,7 +1453,7 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): ) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): - if scaling_mode.is_1d_block_scaling(): + if is_hip_extension() and scaling_mode.is_1d_block_scaling(): pytest.skip("MXFP8 grouped GEMM is not fully supported yet in ROCm.") fwd_dtype, bwd_dtype = fwd_bwd_dtype diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 1127964a5..2097afde1 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -4,7 +4,6 @@ # # See LICENSE for license information. from typing import Callable, Sequence, Union, Optional -from packaging import version import pytest @@ -73,13 +72,14 @@ BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES) BIAS_2_AXES = (W_NO_SHARD_AXES,) -INTERMEDIATE = 128 +INTERMEDIATE = 64 # We set to 256 to ensure compatibility with hipblaslt MXFP8 GEMM which # requires the reduction dim to be multiple of 128 after sharding. 
if is_hip_extension(): INPUT_SHAPE += [[4, 64, 256]] - INTERMEDIATE = INTERMEDIATE * 2 + # TODO: Calculate intermediate size dynamically based on mesh config tpsp axis + INTERMEDIATE = 128 * 2 # Only test with FSDP and TPSP as DP is not used @@ -167,7 +167,6 @@ def layernorm_fp8_mlp_prim_func( ) ) - def _test_layernorm_mlp_grad( self, mesh_config, @@ -179,14 +178,12 @@ def _test_layernorm_mlp_grad( use_shardy, with_jax_gemm, ): - inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs( - input_shape, activation_type, use_bias, dtype - ) if ( - (not with_jax_gemm) + is_hip_extension() + and (not with_jax_gemm) and use_bias - and fp8_recipe is None - and dtype == jnp.bfloat16 + and (fp8_recipe is None) + and (dtype == jnp.bfloat16) ): pytest.xfail("Skip known failure case.") if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): @@ -204,6 +201,9 @@ def _test_layernorm_mlp_grad( device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" + inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs( + input_shape, activation_type, use_bias, dtype + ) static_inputs = [layernorm_type, activation_type] with use_jax_gemm(enabled=with_jax_gemm): diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 6121ccf20..f1f7935ea 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -912,6 +912,7 @@ def shardy_sharding_rule( dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) amax = (prefix + "amax",) + # TODO(micky774): Investigate the necessity of separate colwise_scale_inv rule. # When is_2x==False, colwise_scale_inv needs a different factor colwise_scale_inv_rule = scale_rules.colwise_rule if is_2x else (prefix + "x_colwise_scale_inv",) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index c0af0a060..d12858bdb 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -321,8 +321,7 @@ def impl( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(x.shape, is_padded=False) - # Slice out the padding for mxfp8 -- the kernel writes to strided - # 2D positions, not contiguous. + # Slice out the padding for mxfp8 -- the kernel writes to strided 2D positions, not contiguous. 
# For 1D MXFP8: allocated [padded_rows, padded_cols], kernel writes [:actual_rows, :actual_cols] scale_inv = jax.lax.slice(scale_inv, [0] * scale_inv.ndim, rowwise_scale_inv_shape) if is_2x: @@ -581,6 +580,7 @@ def shardy_sharding_rule( ): if version.parse(jax.__version__) < version.parse("0.5.0"): raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") + print(f"DEBUG RESULT SHAPES *** rowwise_scale_inv={tuple(result_types[2].shape)}, colwise_scale_inv={tuple(result_types[3].shape)}") del ( zero_centered_gamma, epsilon, @@ -596,7 +596,6 @@ def shardy_sharding_rule( len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1 ) x_axes = scale_rules.input_spec - out = x_axes colwise_out = out if is_2x else (prefix + "out_colwise",) rsigma = x_axes[:-1] @@ -604,15 +603,15 @@ def shardy_sharding_rule( amax = (prefix + "amax",) # When is_2x==False, colwise_scale_inv needs a different factor - colwise_scale_inv_rule = scale_rules.colwise_rule if is_2x else (prefix + "x_colwise_scale_inv",) - + # colwise_scale_inv_rule = scale_rules.colwise_rule if is_2x else (prefix + "x_colwise_scale_inv",) + print(f"DEBUG *** {scale_rules.rowwise_rule=} | {scale_rules.colwise_rule=}") return SdyShardingRule( (x_axes, ("…1",), ("…2",), ("…3",)), ( out, colwise_out, scale_rules.rowwise_rule, - colwise_scale_inv_rule, + scale_rules.colwise_rule, amax, mu, rsigma, From c2afecc0487f58eafcd1e734bd1b871743438a11 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 28 Jan 2026 11:05:26 -0600 Subject: [PATCH 29/34] Added arch specific guard for FP8 GEMM config --- transformer_engine/jax/csrc/extensions/gemm.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 5b18e4e0d..fba9bb916 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -327,6 +327,12 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); } +#else + if (arch < 95 && is_fp8_gemm) { + NVTE_CHECK(!lhs_is_trans && rhs_is_trans, + "For FP8 input on gfx942, only NT (row-major) GEMM is supported, ", + "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); + } #endif // These lists are to keep the TensorWrapper objects alive From d815a26c1daa89d3a2c08cd42394edd251251694 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 28 Jan 2026 11:07:40 -0600 Subject: [PATCH 30/34] Reformat inline comment --- transformer_engine/jax/cpp_extensions/normalization.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index d12858bdb..2f23c3711 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -321,8 +321,9 @@ def impl( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(x.shape, is_padded=False) - # Slice out the padding for mxfp8 -- the kernel writes to strided 2D positions, not contiguous. - # For 1D MXFP8: allocated [padded_rows, padded_cols], kernel writes [:actual_rows, :actual_cols] + # Slice out the padding for mxfp8 -- the kernel writes to strided 2D + # positions, not contiguous. 
For 1D MXFP8: allocated [padded_rows, + # padded_cols], kernel writes [:actual_rows, :actual_cols] scale_inv = jax.lax.slice(scale_inv, [0] * scale_inv.ndim, rowwise_scale_inv_shape) if is_2x: colwise_scale_inv = jax.lax.slice(colwise_scale_inv, [0] * colwise_scale_inv.ndim, colwise_scale_inv_shape) From 3cd8e5e824d5df32f47b9000e15cfdfe51b18775 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 28 Jan 2026 11:10:06 -0600 Subject: [PATCH 31/34] Minor code reformat --- .../jax/cpp_extensions/normalization.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 2f23c3711..098c5d96f 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -324,9 +324,17 @@ def impl( # Slice out the padding for mxfp8 -- the kernel writes to strided 2D # positions, not contiguous. For 1D MXFP8: allocated [padded_rows, # padded_cols], kernel writes [:actual_rows, :actual_cols] - scale_inv = jax.lax.slice(scale_inv, [0] * scale_inv.ndim, rowwise_scale_inv_shape) + scale_inv = jax.lax.slice( + scale_inv, + [0] * scale_inv.ndim, + rowwise_scale_inv_shape + ) if is_2x: - colwise_scale_inv = jax.lax.slice(colwise_scale_inv, [0] * colwise_scale_inv.ndim, colwise_scale_inv_shape) + colwise_scale_inv = jax.lax.slice( + colwise_scale_inv, + [0] * colwise_scale_inv.ndim, + colwise_scale_inv_shape + ) return ( out, colwise_out, From 02a9dbc2debb4e3f31872348c547e57244a66930 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Wed, 28 Jan 2026 11:13:35 -0600 Subject: [PATCH 32/34] Remove debug statement, reformat code --- transformer_engine/jax/cpp_extensions/activation.py | 5 ++++- transformer_engine/jax/cpp_extensions/normalization.py | 8 ++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index f1f7935ea..785519cef 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -914,7 +914,10 @@ def shardy_sharding_rule( # TODO(micky774): Investigate the necessity of separate colwise_scale_inv rule. # When is_2x==False, colwise_scale_inv needs a different factor - colwise_scale_inv_rule = scale_rules.colwise_rule if is_2x else (prefix + "x_colwise_scale_inv",) + colwise_scale_inv_rule = ( + scale_rules.colwise_rule if is_2x + else (prefix + "x_colwise_scale_inv",) + ) return SdyShardingRule( (dz_axes, x_axes, ("…2",)), diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 098c5d96f..bedf24123 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -611,9 +611,13 @@ def shardy_sharding_rule( mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma amax = (prefix + "amax",) + # TODO(micky774): Investigate the necessity of separate + # colwise_scale_inv rule. 
# When is_2x==False, colwise_scale_inv needs a different factor - # colwise_scale_inv_rule = scale_rules.colwise_rule if is_2x else (prefix + "x_colwise_scale_inv",) - print(f"DEBUG *** {scale_rules.rowwise_rule=} | {scale_rules.colwise_rule=}") + colwise_scale_inv_rule = ( + scale_rules.colwise_rule if is_2x + else (prefix + "x_colwise_scale_inv",) + ) return SdyShardingRule( (x_axes, ("…1",), ("…2",), ("…3",)), ( From 31f5737ce8e5aa9986121bc6ab946e13c72b3762 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 29 Jan 2026 09:51:01 -0600 Subject: [PATCH 33/34] Formatting --- tests/jax/test_custom_call_compute.py | 4 +++- tests/jax/test_distributed_layernorm_mlp.py | 2 -- transformer_engine/jax/cpp_extensions/normalization.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 4555936f7..f1f6fb739 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -992,14 +992,16 @@ def ref_func(x, w, data_layout): def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm, use_bias): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + key = jax.random.PRNGKey(1) + bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) if use_bias else None + if scaling_mode.is_1d_block_scaling(): # Check for first GEMM _check_mxfp8_gemm_support(with_jax_gemm, m, n, k, use_bias) # Check for second GEMM _check_mxfp8_gemm_support(with_jax_gemm, m, k, n, use_bias) - bias = jax.random.uniform(key, n, dtype=jnp.bfloat16) if use_bias else None def primitive_func(x, w, bias, contracting_dims, quantizer_set): primitive_out = dense( diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 2097afde1..074dffcc5 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -4,7 +4,6 @@ # # See LICENSE for license information. 
from typing import Callable, Sequence, Union, Optional - import pytest import jax @@ -71,7 +70,6 @@ LN_BIAS_AXES = (W_NO_SHARD_AXES,) BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES) BIAS_2_AXES = (W_NO_SHARD_AXES,) - INTERMEDIATE = 64 # We set to 256 to ensure compatibility with hipblaslt MXFP8 GEMM which diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index bedf24123..268fb657c 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -589,7 +589,6 @@ def shardy_sharding_rule( ): if version.parse(jax.__version__) < version.parse("0.5.0"): raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") - print(f"DEBUG RESULT SHAPES *** rowwise_scale_inv={tuple(result_types[2].shape)}, colwise_scale_inv={tuple(result_types[3].shape)}") del ( zero_centered_gamma, epsilon, @@ -605,6 +604,7 @@ def shardy_sharding_rule( len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1 ) x_axes = scale_rules.input_spec + out = x_axes colwise_out = out if is_2x else (prefix + "out_colwise",) rsigma = x_axes[:-1] From 8588544121ebef2d00a431b2322f362b5d255c64 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Thu, 29 Jan 2026 13:42:24 -0600 Subject: [PATCH 34/34] Reverted unnecessary shardy changes --- transformer_engine/jax/cpp_extensions/activation.py | 9 +-------- transformer_engine/jax/cpp_extensions/normalization.py | 7 ------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 785519cef..ef2643359 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -912,16 +912,9 @@ def shardy_sharding_rule( dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) amax = (prefix + "amax",) - # TODO(micky774): Investigate the necessity of separate colwise_scale_inv rule. - # When is_2x==False, colwise_scale_inv needs a different factor - colwise_scale_inv_rule = ( - scale_rules.colwise_rule if is_2x - else (prefix + "x_colwise_scale_inv",) - ) - return SdyShardingRule( (dz_axes, x_axes, ("…2",)), - (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv_rule, amax, dbias), + (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 268fb657c..ecf9aa5a8 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -611,13 +611,6 @@ def shardy_sharding_rule( mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma amax = (prefix + "amax",) - # TODO(micky774): Investigate the necessity of separate - # colwise_scale_inv rule. - # When is_2x==False, colwise_scale_inv needs a different factor - colwise_scale_inv_rule = ( - scale_rules.colwise_rule if is_2x - else (prefix + "x_colwise_scale_inv",) - ) return SdyShardingRule( (x_axes, ("…1",), ("…2",), ("…3",)), (