From 1e363163eb33e2f9b607f2ab4196c53fa27febdb Mon Sep 17 00:00:00 2001
From: kerryi
Date: Tue, 20 Jan 2026 19:04:27 +0800
Subject: [PATCH] fix: add row boundary checks to Triton normalization kernels

Fix an illegal memory access in the _rms_norm_fwd_fused,
_layer_norm_param_fwd_fused, and _layer_norm_noparam_fwd_fused kernels.
The kernels had no row boundary check (rows < M), so the last block of
each launch read and wrote past the end of its tensors whenever the
number of rows M is not divisible by BLOCK_M (32).

Changes:
- Added an M parameter (number of rows) to all three kernels
- Added row_mask = rows < M
- Changed mask from 1D (cols only) to 2D (rows & cols)
- Applied proper masking to all tl.load and tl.store operations
---
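Notes (review aid, not part of the commit message): the sketch below is a
minimal repro for the shape class that used to crash. It assumes a CUDA
device, that rmsnorm is importable from turbodiffusion.ops.core as in the
diff, and that it returns the normalized tensor; the float32 reference
computation and the tolerances are illustrative only.

    import torch
    from turbodiffusion.ops.core import rmsnorm  # wrapper patched below

    M, N, eps = 33, 128, 1e-6  # M = 33 is not a multiple of BLOCK_M = 32
    x = torch.randn(M, N, device="cuda", dtype=torch.float16)
    w = torch.randn(N, device="cuda", dtype=torch.float16)

    # Before this patch the second block covered rows 32..63 with no row
    # mask, so rows 33..63 were read from and written to out of bounds.
    # With row_mask = rows < M the launch is safe for any M.
    y = rmsnorm(x, w, eps)

    # Plain float32 RMSNorm as a reference.
    x32 = x.float()
    ref = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps) * w.float()
    torch.testing.assert_close(y.float(), ref, rtol=1e-2, atol=1e-2)
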
 turbodiffusion/ops/core.py | 46 ++++++++++++++++++++++++--------------
 1 file changed, 29 insertions(+), 17 deletions(-)

diff --git a/turbodiffusion/ops/core.py b/turbodiffusion/ops/core.py
index 6d74fe5..6e26e00 100644
--- a/turbodiffusion/ops/core.py
+++ b/turbodiffusion/ops/core.py
@@ -101,6 +101,7 @@ def _rms_norm_fwd_fused(
     Rstd,
     x_stride,
     y_stride,
+    M,  # number of rows in X
     N: tl.constexpr,  # number of columns in X,
     N2: tl.constexpr,
     eps,  # epsilon to avoid division by zero
@@ -110,12 +111,14 @@
     pid = tl.program_id(0)
     rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
     cols = tl.arange(0, N2)
-    mask = cols < N
+    col_mask = cols < N
+    row_mask = rows < M
+    mask = row_mask[:, None] & col_mask[None, :]
 
     x_ptr = X + rows[:, None] * x_stride + cols[None, :]
     y_ptr = Y + rows[:, None] * y_stride + cols[None, :]
 
-    x = tl.load(x_ptr, mask=mask[None, :], other=0.0).to(tl.float32)
+    x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32)
 
     # Compute variance
     _var = x * x
@@ -123,17 +126,17 @@
     rstd = 1 / tl.sqrt(var + eps)
 
     # Write mean / rstd
-    tl.store(Rstd + rows, rstd)
+    tl.store(Rstd + rows, rstd, mask=row_mask)
     rstd = tl.reshape(rstd, (BLOCK_M, 1))
 
     # Normalize and apply linear transformation
-    w = tl.load(W + cols)
+    w = tl.load(W + cols, mask=col_mask, other=0.0)
     x_hat = x * rstd
     y = x_hat * w
 
     # Write output
     y = y.to(Y.type.element_ty)
-    tl.store(y_ptr, y, mask=mask[None, :])
+    tl.store(y_ptr, y, mask=mask)
 
 
 def rmsnorm(x, w, eps):
@@ -177,6 +180,7 @@ def rmsnorm(x, w, eps):
         rstd,  #
         x.stride(0),
         y.stride(0),
+        M,
         N,
         N2,
         eps,
@@ -200,6 +204,7 @@ def _layer_norm_param_fwd_fused(
     Rstd,  # pointer to the 1/std
     x_stride,  # how much to increase the pointer when moving by 1 row
     y_stride,  # how much to increase the pointer when moving by 1 row
+    M,  # number of rows in X
     N: tl.constexpr,  # number of columns in X,
     N2: tl.constexpr,  # number of columns in X,
     eps,  # epsilon to avoid division by zero
@@ -209,12 +214,14 @@
     pid = tl.program_id(0)
     rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
     cols = tl.arange(0, N2)
-    mask = cols < N
+    col_mask = cols < N
+    row_mask = rows < M
+    mask = row_mask[:, None] & col_mask[None, :]
 
     x_ptr = X + rows[:, None] * x_stride + cols[None, :]
     y_ptr = Y + rows[:, None] * y_stride + cols[None, :]
 
-    x = tl.load(x_ptr, mask=mask[None, :], other=0.0).to(tl.float32)
+    x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32)
 
     # Compute mean and Variance
     mean = tl.sum(x, axis=1, keep_dims=True) / N
@@ -226,20 +233,20 @@
     # Write mean / rstd
     _mean = tl.reshape(mean, (BLOCK_M))
     _rstd = tl.reshape(rstd, (BLOCK_M))
-    tl.store(Mean + rows, _mean)
-    tl.store(Rstd + rows, _rstd)
+    tl.store(Mean + rows, _mean, mask=row_mask)
+    tl.store(Rstd + rows, _rstd, mask=row_mask)
 
     # Normalize and apply linear transformation
     x_hat = (x - mean) * rstd
 
-    w = tl.load(W + cols)
-    b = tl.load(B + cols)
+    w = tl.load(W + cols, mask=col_mask, other=0.0)
+    b = tl.load(B + cols, mask=col_mask, other=0.0)
     x_hat = x_hat * w + b
 
     # Write output
     x_hat = x_hat.to(Y.type.element_ty)
-    tl.store(y_ptr, x_hat, mask=mask[None, :])
+    tl.store(y_ptr, x_hat, mask=mask)
 
 
 def layernorm_param(x, w, b, eps):
@@ -271,6 +278,7 @@ def layernorm_param(x, w, b, eps):
         rstd,  #
         x.stride(0),
         y.stride(0),
+        M,
         N,
         N2,
         eps,
@@ -298,6 +306,7 @@ def _layer_norm_noparam_fwd_fused(
     Rstd,  # pointer to the 1/std
     x_stride,  # how much to increase the pointer when moving by 1 row
     y_stride,  # how much to increase the pointer when moving by 1 row
+    M,  # number of rows in X
     N: tl.constexpr,  # number of columns in X,
     N2: tl.constexpr,  # number of columns in X,
     eps,  # epsilon to avoid division by zero
@@ -307,12 +316,14 @@
     pid = tl.program_id(0)
     rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, N2)
-    mask = cols < N
+    col_mask = cols < N
+    row_mask = rows < M
+    mask = row_mask[:, None] & col_mask[None, :]
 
     x_ptr = X + rows[:, None] * x_stride + cols[None, :]
     y_ptr = Y + rows[:, None] * y_stride + cols[None, :]
 
-    x = tl.load(x_ptr, mask=mask[None, :], other=0.0).to(tl.float32)
+    x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32)
 
     # Compute mean and Variance
     mean = tl.sum(x, axis=1, keep_dims=True) / N
@@ -324,15 +335,15 @@
     # Write mean / rstd
     _mean = tl.reshape(mean, (BLOCK_M))
     _rstd = tl.reshape(rstd, (BLOCK_M))
-    tl.store(Mean + rows, _mean)
-    tl.store(Rstd + rows, _rstd)
+    tl.store(Mean + rows, _mean, mask=row_mask)
+    tl.store(Rstd + rows, _rstd, mask=row_mask)
 
     # Normalize and apply linear transformation
     x_hat = (x - mean) * rstd
 
     # Write output
     x_hat = x_hat.to(Y.type.element_ty)
-    tl.store(y_ptr, x_hat, mask=mask[None, :])
+    tl.store(y_ptr, x_hat, mask=mask)
 
 
 def layernorm_noparam(x, eps):
@@ -364,6 +375,7 @@ def layernorm_noparam(x, eps):
         rstd,  #
         x.stride(0),
         y.stride(0),
+        M,
         N,
         N2,
         eps,