turbodiffusion/ops/core.py (46 changes: 29 additions, 17 deletions)
@@ -101,6 +101,7 @@ def _rms_norm_fwd_fused(
     Rstd,
     x_stride,
     y_stride,
+    M,  # number of rows in X
     N: tl.constexpr,  # number of columns in X
     N2: tl.constexpr,  # N padded up to a power of two
     eps,  # epsilon to avoid division by zero
@@ -110,30 +111,32 @@ def _rms_norm_fwd_fused(
     pid = tl.program_id(0)
     rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
     cols = tl.arange(0, N2)
-    mask = cols < N
+    col_mask = cols < N
+    row_mask = rows < M
+    mask = row_mask[:, None] & col_mask[None, :]
 
     x_ptr = X + rows[:, None] * x_stride + cols[None, :]
     y_ptr = Y + rows[:, None] * y_stride + cols[None, :]
 
-    x = tl.load(x_ptr, mask=mask[None, :], other=0.0).to(tl.float32)
+    x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32)
 
     # Compute variance
     _var = x * x
     var = tl.sum(_var, axis=1) / N
     rstd = 1 / tl.sqrt(var + eps)
 
     # Write rstd
-    tl.store(Rstd + rows, rstd)
+    tl.store(Rstd + rows, rstd, mask=row_mask)
     rstd = tl.reshape(rstd, (BLOCK_M, 1))
 
     # Normalize and apply linear transformation
-    w = tl.load(W + cols)
+    w = tl.load(W + cols, mask=col_mask, other=0.0)
     x_hat = x * rstd
     y = x_hat * w
 
     # Write output
     y = y.to(Y.type.element_ty)
-    tl.store(y_ptr, y, mask=mask[None, :])
+    tl.store(y_ptr, y, mask=mask)
 
 
 def rmsnorm(x, w, eps):
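For reference, each row here is transformed as y = x * rsqrt(mean(x^2) + eps) * w, with the reduction carried out in float32 and the result cast back to the output dtype. A minimal PyTorch sketch of the same math (`rmsnorm_ref` is our name for illustration, not part of this PR):

```python
import torch

def rmsnorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
    # Upcast to float32 for the reduction, mirroring .to(tl.float32) in the kernel.
    x32 = x.float()
    rstd = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Scale by w and cast back, mirroring y.to(Y.type.element_ty).
    return (x32 * rstd * w.float()).to(x.dtype)
```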
@@ -177,6 +180,7 @@ def rmsnorm(x, w, eps):
         rstd,  #
         x.stride(0),
         y.stride(0),
+        M,
         N,
         N2,
         eps,
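The new `M` argument exists to support launch grids that tile the rows. The wrapper's grid computation is outside this hunk, so the exact form below is assumed, but the arithmetic shows why `row_mask` is needed whenever M is not a multiple of BLOCK_M:

```python
import triton

M, BLOCK_M = 1000, 32              # BLOCK_M value assumed for illustration
grid = (triton.cdiv(M, BLOCK_M),)  # 32 programs
# The last program computes rows = 31 * 32 + tl.arange(0, 32), i.e. 992..1023,
# but only rows 992..999 exist. row_mask = rows < M is what keeps its loads
# and stores for rows 1000..1023 from touching out-of-bounds memory.
```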
@@ -200,6 +204,7 @@ def _layer_norm_param_fwd_fused(
     Rstd,  # pointer to the 1/std
     x_stride,  # how much to increase the pointer when moving by 1 row
     y_stride,  # how much to increase the pointer when moving by 1 row
+    M,  # number of rows in X
     N: tl.constexpr,  # number of columns in X
     N2: tl.constexpr,  # N padded up to a power of two
     eps,  # epsilon to avoid division by zero
@@ -209,12 +214,14 @@ def _layer_norm_param_fwd_fused(
     pid = tl.program_id(0)
     rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
     cols = tl.arange(0, N2)
-    mask = cols < N
+    col_mask = cols < N
+    row_mask = rows < M
+    mask = row_mask[:, None] & col_mask[None, :]
 
     x_ptr = X + rows[:, None] * x_stride + cols[None, :]
     y_ptr = Y + rows[:, None] * y_stride + cols[None, :]
 
-    x = tl.load(x_ptr, mask=mask[None, :], other=0.0).to(tl.float32)
+    x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32)
 
     # Compute mean and variance
     mean = tl.sum(x, axis=1, keep_dims=True) / N
@@ -226,20 +233,20 @@ def _layer_norm_param_fwd_fused(
     # Write mean / rstd
     _mean = tl.reshape(mean, (BLOCK_M))
     _rstd = tl.reshape(rstd, (BLOCK_M))
-    tl.store(Mean + rows, _mean)
-    tl.store(Rstd + rows, _rstd)
+    tl.store(Mean + rows, _mean, mask=row_mask)
+    tl.store(Rstd + rows, _rstd, mask=row_mask)
 
     # Normalize and apply linear transformation
     x_hat = (x - mean) * rstd
 
-    w = tl.load(W + cols)
-    b = tl.load(B + cols)
+    w = tl.load(W + cols, mask=col_mask, other=0.0)
+    b = tl.load(B + cols, mask=col_mask, other=0.0)
 
     x_hat = x_hat * w + b
 
     # Write output
     x_hat = x_hat.to(Y.type.element_ty)
-    tl.store(y_ptr, x_hat, mask=mask[None, :])
+    tl.store(y_ptr, x_hat, mask=mask)
 
 
 def layernorm_param(x, w, b, eps):
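The variance computation for this kernel sits in a collapsed part of the diff, so the reference below assumes the standard biased (divide-by-N) variance; everything else mirrors the visible lines. `layernorm_param_ref` is our name for illustration:

```python
import torch

def layernorm_param_ref(x, w, b, eps):
    x32 = x.float()
    mean = x32.mean(dim=-1, keepdim=True)
    var = (x32 - mean).pow(2).mean(dim=-1, keepdim=True)  # biased variance, assumed
    x_hat = (x32 - mean) * torch.rsqrt(var + eps)
    return (x_hat * w.float() + b.float()).to(x.dtype)
```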
@@ -271,6 +278,7 @@ def layernorm_param(x, w, b, eps):
         rstd,  #
         x.stride(0),
         y.stride(0),
+        M,
         N,
         N2,
         eps,
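A natural regression test for this change is a row count that no plausible BLOCK_M divides evenly, checked against torch.nn.functional.layer_norm. This sketch assumes the wrapper returns only y and that loose fp16 tolerances are acceptable:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1000, 768, device="cuda", dtype=torch.float16)  # ragged row count
w = torch.randn(768, device="cuda", dtype=torch.float16)
b = torch.randn(768, device="cuda", dtype=torch.float16)

y = layernorm_param(x, w, b, 1e-5)
ref = F.layer_norm(x.float(), (768,), w.float(), b.float(), eps=1e-5)
torch.testing.assert_close(y.float(), ref, rtol=1e-2, atol=1e-2)
```

Before this change, the stores of the final partial tile were masked only along columns, so whenever 1000 % BLOCK_M != 0 they could write past the end of y, Mean, and Rstd.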
@@ -298,6 +306,7 @@ def _layer_norm_noparam_fwd_fused(
     Rstd,  # pointer to the 1/std
     x_stride,  # how much to increase the pointer when moving by 1 row
     y_stride,  # how much to increase the pointer when moving by 1 row
+    M,  # number of rows in X
     N: tl.constexpr,  # number of columns in X
     N2: tl.constexpr,  # N padded up to a power of two
     eps,  # epsilon to avoid division by zero
@@ -307,12 +316,14 @@ def _layer_norm_noparam_fwd_fused(
     pid = tl.program_id(0)
     rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
     cols = tl.arange(0, N2)
-    mask = cols < N
+    col_mask = cols < N
+    row_mask = rows < M
+    mask = row_mask[:, None] & col_mask[None, :]
 
     x_ptr = X + rows[:, None] * x_stride + cols[None, :]
     y_ptr = Y + rows[:, None] * y_stride + cols[None, :]
 
-    x = tl.load(x_ptr, mask=mask[None, :], other=0.0).to(tl.float32)
+    x = tl.load(x_ptr, mask=mask, other=0.0).to(tl.float32)
 
     # Compute mean and variance
     mean = tl.sum(x, axis=1, keep_dims=True) / N
@@ -324,15 +335,15 @@ def _layer_norm_noparam_fwd_fused(
     # Write mean / rstd
     _mean = tl.reshape(mean, (BLOCK_M))
     _rstd = tl.reshape(rstd, (BLOCK_M))
-    tl.store(Mean + rows, _mean)
-    tl.store(Rstd + rows, _rstd)
+    tl.store(Mean + rows, _mean, mask=row_mask)
+    tl.store(Rstd + rows, _rstd, mask=row_mask)
 
     # Normalize
     x_hat = (x - mean) * rstd
 
     # Write output
     x_hat = x_hat.to(Y.type.element_ty)
-    tl.store(y_ptr, x_hat, mask=mask[None, :])
+    tl.store(y_ptr, x_hat, mask=mask)
 
 
 def layernorm_noparam(x, eps):
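The no-parameter variant is the same normalization with the affine step dropped, so it should match layer_norm with its default unit weight and zero bias. A one-line reference under the same assumptions as above:

```python
import torch.nn.functional as F

def layernorm_noparam_ref(x, eps):
    # Pure (x - mean) * rsqrt(var + eps), no scale or shift.
    return F.layer_norm(x.float(), x.shape[-1:], eps=eps).to(x.dtype)
```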
@@ -364,6 +375,7 @@ def layernorm_noparam(x, eps):
         rstd,  #
         x.stride(0),
         y.stride(0),
+        M,
         N,
         N2,
         eps,
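A note on the N/N2 pair that all three wrappers pass: tl.arange needs a power-of-two extent, so the column range runs to N2 and is trimmed back with col_mask = cols < N. The wrapper bodies are collapsed in this diff, so the padding computation below is the usual pattern rather than confirmed code:

```python
import triton

N = 768                          # feature width, for illustration
N2 = triton.next_power_of_2(N)   # 1024; cols = tl.arange(0, N2), col_mask = cols < N
```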