diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py
index a4cc95459..09e05366f 100644
--- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py
+++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py
@@ -365,8 +365,8 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma):
     dg_tmp = torch.empty(dg_tmp_rows(x_, sm_margin), N, device=x.device, dtype=torch.float32,
                          requires_grad=False) if need_reduction else None
     grid_bwd = lambda meta: (NUM_PRGMS, )
-    input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(-1) % 16 == 0)
-    grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(-1) % 16 == 0)
+    input_aligned_16 = (x_.data_ptr() % 16 == 0) and (((x_.stride(0) * x_.dtype.itemsize) % 16) == 0)
+    grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (((dz_.stride(0) * dz_.dtype.itemsize) % 16) == 0)
     _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma,
                                   x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size,
                                   USE_BLOCKED, NUM_PRGMS, input_aligned_16, grad_output_aligned_16, num_warps=8)
@@ -453,10 +453,14 @@ def te_rmsnorm_fwd_triton(
     grid_fwd = lambda meta: (NUM_PRGMS, )
     # TODO(micky774) Implement fused MXFP8 quantization within the kernel
     kernel = _rmsnorm_fwd_triton if autotune else _rmsnorm_fwd_triton_impl
-    input_aligned_16 = (input.data_ptr() % 16 == 0) and (input.stride(-1) % 16 == 0)
+    input_aligned_16 = (
+        (input.data_ptr() % 16 == 0) and
+        (((input.stride(0) * input.dtype.itemsize) % 16) == 0)
+    )
     out_alignment_tensor = out._data if hasattr(out, "_data") else out
-    output_aligned_16 = (out_alignment_tensor.data_ptr() % 16 == 0) and (
-        out_alignment_tensor.stride(-1) % 16 == 0
+    output_aligned_16 = (
+        (out_alignment_tensor.data_ptr() % 16 == 0) and
+        (((out_alignment_tensor.stride(0) * out_alignment_tensor.dtype.itemsize) % 16) == 0)
     )
     kernel[grid_fwd](
         out_ptr,
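The change in short: the old predicate tested `stride(-1) % 16 == 0`, but `Tensor.stride()` counts elements, not bytes, and the last-dimension stride is 1 for any contiguous tensor, so the check effectively never classified contiguous inputs as 16-byte aligned. The new predicate instead converts the row stride to bytes (`stride(0) * dtype.itemsize`) before testing divisibility by 16, which together with the base-pointer check guarantees every row starts on a 16-byte boundary. A minimal standalone sketch of the same predicate (the helper name and example shapes are illustrative, not from this PR):

```python
import torch

def rows_aligned_16(t: torch.Tensor) -> bool:
    # Every row starts on a 16-byte boundary iff the base pointer is
    # 16-byte aligned and the row pitch *in bytes* is a multiple of 16.
    return (t.data_ptr() % 16 == 0) and ((t.stride(0) * t.dtype.itemsize) % 16 == 0)

x = torch.empty(4, 8, dtype=torch.float16)  # row pitch: 8 elements * 2 bytes = 16 bytes
print(rows_aligned_16(x))      # True, assuming the allocator returns a 16-byte-aligned buffer
print(x.stride(-1) % 16 == 0)  # False: stride(-1) == 1 element for a contiguous tensor
```

The practical effect is that the kernel's vectorized 16-byte load/store path can actually be taken for ordinary contiguous tensors, rather than being disabled by an alignment check expressed in the wrong units.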