From 8461f7e66f7452ab3c4e20e755820aa5920f35f2 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Wed, 28 Jan 2026 14:23:04 -0600
Subject: [PATCH 1/2] Added hotfix

---
 transformer_engine/pytorch/triton_kernels/rmsnorm.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py
index a4cc95459..e43c46280 100644
--- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py
+++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py
@@ -365,8 +365,8 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma):
     dg_tmp = torch.empty(dg_tmp_rows(x_, sm_margin), N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None

     grid_bwd = lambda meta: (NUM_PRGMS, )
-    input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(-1) % 16 == 0)
-    grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(-1) % 16 == 0)
+    input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) % 16 == 0)
+    grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(0) % 16 == 0)
     _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, x_.stride(0),
                                   dz_.stride(0), M, N, zero_centered_gamma, blk_size, USE_BLOCKED, NUM_PRGMS,
                                   input_aligned_16, grad_output_aligned_16, num_warps=8)
@@ -453,10 +453,10 @@ def te_rmsnorm_fwd_triton(
     grid_fwd = lambda meta: (NUM_PRGMS, )
     # TODO(micky774) Implement fused MXFP8 quantization within the kernel
     kernel = _rmsnorm_fwd_triton if autotune else _rmsnorm_fwd_triton_impl
-    input_aligned_16 = (input.data_ptr() % 16 == 0) and (input.stride(-1) % 16 == 0)
+    input_aligned_16 = (input.data_ptr() % 16 == 0) and (input.stride(0) % 16 == 0)
     out_alignment_tensor = out._data if hasattr(out, "_data") else out
     output_aligned_16 = (out_alignment_tensor.data_ptr() % 16 == 0) and (
-        out_alignment_tensor.stride(-1) % 16 == 0
+        out_alignment_tensor.stride(0) % 16 == 0
     )
     kernel[grid_fwd](
         out_ptr,

From d9aa3092eaebc4ed427d18b81723d95aa035b419 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Wed, 28 Jan 2026 14:39:49 -0600
Subject: [PATCH 2/2] Updated to account for element size

---
 .../pytorch/triton_kernels/rmsnorm.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py
index e43c46280..09e05366f 100644
--- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py
+++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py
@@ -365,8 +365,8 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma):
     dg_tmp = torch.empty(dg_tmp_rows(x_, sm_margin), N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None

     grid_bwd = lambda meta: (NUM_PRGMS, )
-    input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(0) % 16 == 0)
-    grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(0) % 16 == 0)
+    input_aligned_16 = (x_.data_ptr() % 16 == 0) and (((x_.stride(0) * x_.dtype.itemsize) % 16) == 0)
+    grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (((dz_.stride(0) * dz_.dtype.itemsize) % 16) == 0)
     _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, x_.stride(0),
                                   dz_.stride(0), M, N, zero_centered_gamma, blk_size, USE_BLOCKED, NUM_PRGMS,
                                   input_aligned_16, grad_output_aligned_16, num_warps=8)
@@ -453,10 +453,14 @@ def te_rmsnorm_fwd_triton(
     grid_fwd = lambda meta: (NUM_PRGMS, )
     # TODO(micky774) Implement fused MXFP8 quantization within the kernel
     kernel = _rmsnorm_fwd_triton if autotune else _rmsnorm_fwd_triton_impl
-    input_aligned_16 = (input.data_ptr() % 16 == 0) and (input.stride(0) % 16 == 0)
+    input_aligned_16 = (
+        (input.data_ptr() % 16 == 0) and
+        (((input.stride(0) * input.dtype.itemsize) % 16) == 0)
+    )
     out_alignment_tensor = out._data if hasattr(out, "_data") else out
-    output_aligned_16 = (out_alignment_tensor.data_ptr() % 16 == 0) and (
-        out_alignment_tensor.stride(0) % 16 == 0
+    output_aligned_16 = (
+        (out_alignment_tensor.data_ptr() % 16 == 0) and
+        (((out_alignment_tensor.stride(0) * out_alignment_tensor.dtype.itemsize) % 16) == 0)
     )
     kernel[grid_fwd](
         out_ptr,
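
Note (not part of the patch series): the second commit changes the alignment checks from element counts to byte counts, since stride(0) is measured in elements while 16-byte alignment is a property of addresses. Below is a minimal standalone sketch of that idea; the helper name row_pitch_aligned_16 and the example shape are illustrative assumptions, and it assumes a PyTorch version where torch.dtype.itemsize is available.

    # Illustrative sketch only; mirrors the byte-alignment test used in the patch.
    import torch

    def row_pitch_aligned_16(t: torch.Tensor) -> bool:
        # stride(0) counts elements per row; multiply by the element size to get
        # the row pitch in bytes before testing 16-byte alignment.
        return (t.data_ptr() % 16 == 0) and ((t.stride(0) * t.dtype.itemsize) % 16 == 0)

    x = torch.empty(32, 136, dtype=torch.bfloat16)
    # 136 bf16 elements per row is 272 bytes, which is divisible by 16, but
    # 136 % 16 != 0, so a pure element-count check would mislabel this tensor.
    print(x.stride(0) % 16 == 0)    # False: element-count check (pre-patch style)
    print(row_pitch_aligned_16(x))  # True: byte check (post-patch style)

The same reasoning applies in reverse for sub-byte or multi-byte dtypes: whether a row pitch of N elements is 16-byte aligned depends on the dtype, which is exactly what multiplying by dtype.itemsize accounts for.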