From 27cff30833ee145e97a6c84dc976b3def08eb2da Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Tue, 6 Jan 2026 10:58:01 -0800
Subject: [PATCH 1/8] Provide a vLLM general plugin that registers
oink::rmsnorm and oink::fused_add_rms_norm backed by an SM100 CuTeDSL
RMSNorm kernel.

The ops are torch.compile-friendly (stride-preserving for padded-row inputs)
and the fused op matches vLLM's in-place residual-add RMSNorm semantics.
---
oink/README.md | 57 +
oink/pyproject.toml | 29 +
oink/src/kernelagent_oink/__init__.py | 95 +
.../kernelagent_oink/blackwell/__init__.py | 3 +
.../kernelagent_oink/blackwell/lite_quack.py | 350 +++
.../blackwell/oink_custom_ops.py | 224 ++
.../src/kernelagent_oink/blackwell/rmsnorm.py | 2660 +++++++++++++++++
7 files changed, 3418 insertions(+)
create mode 100644 oink/README.md
create mode 100644 oink/pyproject.toml
create mode 100644 oink/src/kernelagent_oink/__init__.py
create mode 100644 oink/src/kernelagent_oink/blackwell/__init__.py
create mode 100644 oink/src/kernelagent_oink/blackwell/lite_quack.py
create mode 100644 oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
create mode 100644 oink/src/kernelagent_oink/blackwell/rmsnorm.py
diff --git a/oink/README.md b/oink/README.md
new file mode 100644
index 0000000..427f69f
--- /dev/null
+++ b/oink/README.md
@@ -0,0 +1,57 @@
+# KernelAgent Oink (vLLM plugin)
+
+This subproject provides an **out-of-tree vLLM plugin** that registers
+`torch.library.custom_op` entrypoints under the `oink::` namespace:
+
+- `torch.ops.oink.rmsnorm`
+- `torch.ops.oink.fused_add_rms_norm`
+
+The implementation is backed by a CuTeDSL (CUTLASS) RMSNorm kernel tuned for
+**NVIDIA Blackwell (SM100)**.
+
+## Install (editable)
+
+From the `KernelAgent` repo root:
+
+```bash
+pip install -e ./oink
+```
+
+This plugin requires the CuTeDSL stack:
+
+```bash
+pip install nvidia-cutlass-dsl cuda-python
+```
+
+## Use with vLLM
+
+1. Enable the vLLM integration:
+
+```bash
+export VLLM_USE_OINK_RMSNORM=1
+```
+
+2. Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` /
+CUDA graphs. In Python:
+
+```python
+from vllm import LLM
+
+llm = LLM(
+ model=...,
+ tensor_parallel_size=...,
+ enforce_eager=False,
+ compilation_config={"custom_ops": ["none", "+rms_norm"]},
+)
+```
+
+Without `+rms_norm`, Inductor may fuse RMSNorm into larger Triton kernels, in
+which case neither vLLM's CUDA RMSNorm kernel nor Oink's will run.
+
+## Notes
+
+- This plugin is designed to be **safe to import even when disabled**; it only
+  registers ops when `VLLM_USE_OINK_RMSNORM` is truthy (`"1"`, `"true"`, `"yes"`,
+  or `"on"`).
+- The ops preserve **padded-row layouts** for 2D tensors (shape `[M, N]`,
+ `stride(1) == 1`, and potentially `stride(0) > N`), which is required for
+ `torch.compile` stride verification on some models (e.g., MLA padded inputs).
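+
+## Quick smoke test (outside vLLM)
+
+A minimal sketch for exercising the registered ops directly (assumes an SM100
+GPU, the CuTeDSL stack installed, and the flag set before `register()` runs):
+
+```python
+import os
+
+os.environ["VLLM_USE_OINK_RMSNORM"] = "1"
+
+import torch
+from kernelagent_oink import register
+
+register()  # normally invoked by vLLM's general-plugin loader
+
+M, N, pad = 8, 4096, 128
+buf = torch.randn(M, N + pad, device="cuda", dtype=torch.bfloat16)
+x = buf[:, :N]  # padded-row view: stride(0) == N + pad, stride(1) == 1
+w = torch.ones(N, device="cuda", dtype=torch.bfloat16)
+
+y = torch.ops.oink.rmsnorm(x, w, 1e-6)
+assert y.stride() == x.stride()  # padded layout is preserved
+
+residual = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
+torch.ops.oink.fused_add_rms_norm(x, residual, w, 1e-6)  # mutates x and residual
+```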
diff --git a/oink/pyproject.toml b/oink/pyproject.toml
new file mode 100644
index 0000000..a9ec306
--- /dev/null
+++ b/oink/pyproject.toml
@@ -0,0 +1,29 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "kernelagent-oink"
+version = "0.1.0"
+description = "vLLM plugin that registers Oink Blackwell RMSNorm custom ops"
+readme = "README.md"
+requires-python = ">=3.10"
+license = {text = "Apache-2.0"}
+authors = [{name = "PyTorch Labs"}]
+
+# Keep dependencies minimal, but include the CuTeDSL stack required by the
+# Blackwell RMSNorm implementation.
+#
+# We intentionally do NOT depend on `torch` here because vLLM already pins and
+# provides a compatible PyTorch build.
+dependencies = [
+ "nvidia-cutlass-dsl",
+ "cuda-python",
+]
+
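+# vLLM's general-plugin loader discovers this entrypoint and calls
+# `kernelagent_oink.register()` in every process (engine and workers).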
+[project.entry-points."vllm.general_plugins"]
+oink = "kernelagent_oink:register"
+
+[tool.setuptools.packages.find]
+where = ["src"]
+include = ["kernelagent_oink*"]
diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py
new file mode 100644
index 0000000..542e59e
--- /dev/null
+++ b/oink/src/kernelagent_oink/__init__.py
@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+_OPS_REGISTERED = False
+
+
+def _env_truthy(name: str) -> bool:
+ val = os.environ.get(name)
+ if val is None:
+ return False
+ return val.strip().lower() in ("1", "true", "yes", "on")
+
+
+def _infer_cuda_device_index() -> int:
+ local_rank = os.environ.get("LOCAL_RANK")
+ if local_rank is not None:
+ try:
+ return int(local_rank)
+ except ValueError:
+ pass
+ return 0
+
+
+def _compute_cutedsl_arch(major: int, minor: int) -> str:
+ # CuTeDSL uses an "a" suffix for >= Hopper.
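+    # e.g. (9, 0) -> "sm_90a", (10, 0) -> "sm_100a"; pre-Hopper (8, 0) -> "sm_80".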
+ suffix = "a" if major >= 9 else ""
+ # Match cutlass/base_dsl/env_manager.py: map sm_110 -> sm_101.
+ if major == 11 and minor == 0:
+ major, minor = 10, 1
+ return f"sm_{major}{minor}{suffix}"
+
+
+def register() -> None:
+ """vLLM plugin entrypoint.
+
+ This function must be safe to call multiple times and must not raise.
+ vLLM executes it in multiple processes (engine + workers).
+ """
+ global _OPS_REGISTERED
+
+ if _OPS_REGISTERED:
+ return
+
+ # Gate on the vLLM integration flag so installing the package does not
+ # change behavior unless explicitly enabled.
+ if not _env_truthy("VLLM_USE_OINK_RMSNORM"):
+ return
+
+ try:
+ import torch
+ except Exception as e: # pragma: no cover
+ logger.debug("Oink plugin: torch import failed: %s", e)
+ return
+
+ try:
+ if not torch.cuda.is_available():
+ return
+ device_index = _infer_cuda_device_index()
+ major, minor = torch.cuda.get_device_capability(device_index)
+ sm = 10 * int(major) + int(minor)
+ if sm < 100:
+ return
+
+ # Ensure required deps are importable before registering ops so that vLLM
+ # doesn't detect ops that would later fail at first use.
+ try:
+ import cutlass # noqa: F401
+ import cuda.bindings.driver as _cuda # noqa: F401
+ except Exception as e:
+ logger.warning(
+ "Oink plugin: CuTeDSL deps missing; skipping op registration. "
+ "Install `nvidia-cutlass-dsl` + `cuda-python`. Error: %s",
+ e,
+ )
+ return
+
+ # Ensure CuTeDSL sees a target arch early. If the user has already set it,
+ # respect their choice.
+ os.environ.setdefault("CUTE_DSL_ARCH", _compute_cutedsl_arch(int(major), int(minor)))
+
+ # Import registers the ops via torch.library.custom_op decorators.
+ from .blackwell import oink_custom_ops # noqa: F401
+ except Exception as e: # pragma: no cover
+ # Do not raise: vLLM plugin loader does not guard plugin execution.
+ logger.exception("Oink plugin: failed to register ops: %s", e)
+ return
+
+ _OPS_REGISTERED = True
+
+
+__all__ = ["register"]
diff --git a/oink/src/kernelagent_oink/blackwell/__init__.py b/oink/src/kernelagent_oink/blackwell/__init__.py
new file mode 100644
index 0000000..4d21ee8
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import annotations
+
+__all__ = []
diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py
new file mode 100644
index 0000000..3c3f750
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py
@@ -0,0 +1,350 @@
+"""
+Lightweight local clone of the small subset of helpers that the SM100
+RMSNorm CuTeDSL kernels depend on.
+
+This module intentionally avoids importing the `quack` package so that
+Oink Blackwell kernels can run without Quack installed, while keeping
+numerical behaviour and performance close to the original reference
+implementations.
+"""
+
+from __future__ import annotations
+
+import math
+import operator
+from typing import Callable, Optional, Tuple
+
+import cuda.bindings.driver as cuda # type: ignore
+import torch
+from torch import Tensor
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Float32, Int32, const_expr
+from cutlass.cute.runtime import from_dlpack
+from cutlass.cutlass_dsl import T, dsl_user_op
+from cutlass._mlir.dialects import llvm, nvvm, vector
+
+
+# -------------------------
+# Dtype mapping
+# -------------------------
+
+TORCH2CUTE_DTYPE = {
+ torch.float16: cutlass.Float16,
+ torch.bfloat16: cutlass.BFloat16,
+ torch.float32: cutlass.Float32,
+}
+
+
+# -------------------------
+# Tensor conversion helpers
+# -------------------------
+
+def convert_from_dlpack(
+ x: Tensor,
+ leading_dim: int,
+ alignment: int = 16,
+ divisibility: int = 1,
+) -> cute.Tensor:
+ """
+    Wrap a torch.Tensor in a CuTeDSL tensor with layout metadata that
+ matches the logical leading dimension and alignment/divisibility
+ constraints expected by SM100 kernels.
+ """
+ return (
+ from_dlpack(x, assumed_align=alignment)
+ .mark_layout_dynamic(leading_dim=leading_dim)
+ .mark_compact_shape_dynamic(
+ mode=leading_dim,
+ stride_order=x.dim_order(),
+ divisibility=divisibility,
+ )
+ )
+
+
+# -------------------------
+# SM90/SM100 cluster helpers
+# -------------------------
+
+
+@dsl_user_op
+def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
+ return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
+
+
+@dsl_user_op
+def set_block_rank(
+ smem_ptr: cute.Pointer,
+ peer_cta_rank_in_cluster: cute.Int32,
+ *,
+ loc=None,
+ ip=None,
+) -> cutlass.Int32:
+ """Map the given smem pointer to the address at another CTA rank in the cluster."""
+ smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value()
+ return cutlass.Int32(
+ llvm.inline_asm(
+ T.i32(),
+ [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()],
+ "mapa.shared::cluster.u32 $0, $1, $2;",
+ "=r,r,r",
+ has_side_effects=False,
+ is_align_stack=False,
+ asm_dialect=llvm.AsmDialect.AD_ATT,
+ )
+ )
+
+
+@dsl_user_op
+def store_shared_remote(
+ val: float | Float32 | Int32 | cutlass.Int64,
+ smem_ptr: cute.Pointer,
+ mbar_ptr: cute.Pointer,
+ peer_cta_rank_in_cluster: cute.typing.Int,
+ *,
+ loc=None,
+ ip=None,
+) -> None:
+ remote_smem_ptr_i32 = set_block_rank(
+ smem_ptr,
+ peer_cta_rank_in_cluster,
+ loc=loc,
+ ip=ip,
+ ).ir_value()
+ remote_mbar_ptr_i32 = set_block_rank(
+ mbar_ptr,
+ peer_cta_rank_in_cluster,
+ loc=loc,
+ ip=ip,
+ ).ir_value()
+ if const_expr(isinstance(val, float)):
+ val = Float32(val)
+ assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
+ suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
+ constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
+ llvm.inline_asm(
+ None,
+ [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32],
+ f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];",
+ f"r,{constraint},r",
+ has_side_effects=True,
+ is_align_stack=False,
+ asm_dialect=llvm.AsmDialect.AD_ATT,
+ )
+
+
+@cute.jit
+def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
+ """
+ Build a predicate tensor for the K dimension only. Values beyond
+ `limit` are masked out.
+ """
+ tApA = cute.make_fragment(
+ cute.make_layout(
+ (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
+ stride=(cute.size(tAcA, mode=[2]), 0, 1),
+ ),
+ cutlass.Boolean,
+ )
+ for rest_v in cutlass.range_constexpr(tApA.shape[0]):
+ for rest_k in cutlass.range_constexpr(tApA.shape[2]):
+ tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
+ return tApA
+
+
+@dsl_user_op
+def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+ """
+ Return a tensor whose iterator is offset by an Int64 byte offset
+ computed from `coord` and the tensor's strides.
+ """
+ flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
+ flat_stride = cute.flatten_to_tuple(tensor.stride)
+ assert len(flat_coord_i64) == len(flat_stride), (
+ "Coordinate and stride must have the same length"
+ )
+ offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride))
+ assert isinstance(tensor.iterator, cute.Pointer)
+ new_ptr = cute.make_ptr(
+ tensor.element_type,
+ tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+ tensor.memspace,
+ assumed_align=tensor.iterator.max_alignment,
+ )
+ return cute.make_tensor(new_ptr, tensor.layout)
+
+
+# -------------------------
+# Reduction helpers
+# -------------------------
+
+
+@cute.jit
+def warp_reduce(
+ val: cute.TensorSSA | cute.Numeric,
+ op: Callable,
+ width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
+) -> cute.TensorSSA | cute.Numeric:
+ """
+ Warp-level reduction for either scalar values or small TensorSSA
+ fragments.
+ """
+ if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
+ res = cute.make_fragment(val.shape, val.dtype)
+ res.store(val)
+ for i in cutlass.range_constexpr(cute.size(val.shape)):
+ res[i] = warp_reduce(res[i], op, width)
+ return res.load()
+ for i in cutlass.range_constexpr(int(math.log2(width))):
+ val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
+ return val
+
+
+@cute.jit
+def block_reduce(
+ val: cute.Numeric,
+ op: Callable,
+ reduction_buffer: cute.Tensor,
+ init_val: cute.Numeric = 0.0,
+) -> cute.Numeric:
+ """Block-level reduction across warps."""
+ lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+ warps_per_row = cute.size(reduction_buffer.shape[1])
+ row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+ if lane_idx == 0:
+ reduction_buffer[row_idx, col_idx] = val
+ cute.arch.barrier()
+ block_reduce_val = init_val
+ if lane_idx < warps_per_row:
+ block_reduce_val = reduction_buffer[row_idx, lane_idx]
+ return warp_reduce(block_reduce_val, op)
+
+
+@cute.jit
+def cluster_reduce(
+ val: cute.Numeric,
+ op: Callable,
+ reduction_buffer: cute.Tensor,
+ mbar_ptr: cute.Pointer,
+ init_val: cute.Numeric = 0.0,
+ phase: Optional[cutlass.Int32] = None,
+) -> cute.Numeric:
+ """
+ Cluster-wide reduction using shared memory and mbarrier. The
+ reduction_buffer has shape (rows_per_block, (warps_per_row, cluster_n)).
+ """
+ cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+ lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
+ row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ )
+ if lane_idx < cluster_n:
+ store_shared_remote(
+ val,
+ elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
+ mbar_ptr,
+ peer_cta_rank_in_cluster=lane_idx,
+ )
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
+ block_reduce_val = init_val
+ num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
+ for i in cutlass.range_constexpr(num_iter):
+ idx = lane_idx + i * cute.arch.WARP_SIZE
+ if idx < cute.size(reduction_buffer, mode=[1]):
+ block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx])
+ return warp_reduce(block_reduce_val, op)
+
+
+@cute.jit
+def block_or_cluster_reduce(
+ val: cute.Numeric,
+ op: Callable,
+ reduction_buffer: cute.Tensor,
+ mbar_ptr: Optional[cute.Pointer],
+ phase: Optional[cutlass.Int32] = None,
+ init_val: cute.Numeric = 0.0,
+) -> cute.Numeric:
+ """Dispatch between block or cluster reduction depending on mbar_ptr."""
+ if cutlass.const_expr(mbar_ptr is None):
+ return block_reduce(val, op, reduction_buffer, init_val=init_val)
+ return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase)
+
+
+@cute.jit
+def row_reduce(
+ x: cute.TensorSSA | cute.Numeric,
+ op: cute.ReductionOp,
+ threads_per_row: cutlass.Constexpr[int],
+ reduction_buffer: Optional[cute.Tensor] = None,
+ mbar_ptr: Optional[cute.Pointer] = None,
+ phase: Optional[cutlass.Int32] = None,
+ init_val: cute.Numeric = 0.0,
+ hook_fn: Optional[Callable] = None,
+) -> cute.Numeric:
+ """
+ Row-wise reduction used by RMSNorm and similar kernels.
+
+ reduction_buffer must have shape
+ (num_warps / warps_per_row, (warps_per_row, cluster_n)).
+ """
+ if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+ val = x.reduce(op, init_val=init_val, reduction_profile=0)
+ else:
+ val = x
+ warp_op = {
+ cute.ReductionOp.ADD: operator.add,
+ cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
+ cute.ReductionOp.MIN: min,
+ cute.ReductionOp.MUL: operator.mul,
+ }[op]
+ val = warp_reduce(
+ val,
+ warp_op,
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
+ )
+ if cutlass.const_expr(hook_fn is not None):
+ hook_fn()
+ if cutlass.const_expr(reduction_buffer is not None):
+ warps_per_row, cluster_n = reduction_buffer.shape[1]
+ assert cluster_n == 1 or mbar_ptr is not None, (
+ "mbar_ptr must be provided for cluster reduction"
+ )
+ if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+ val = block_or_cluster_reduce(
+ val,
+ warp_op,
+ reduction_buffer,
+ mbar_ptr,
+ phase=phase,
+ init_val=init_val,
+ )
+ return val
+
+
+# -------------------------
+# SM count helper
+# -------------------------
+
+
+def get_sm_count(N: int, device: torch.device) -> int:
+ """
+ Heuristic for the number of persistent CTAs (sm_count) based on N and
+ the GPU's SM count. This mirrors the behaviour used in Quack's
+ RMSNorm kernels but lives entirely in this local module.
+ """
+ sm_count_multiple = (
+ 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+ )
+ sm_count = torch.cuda.get_device_properties(device).multi_processor_count
+ sm_count = (
+ sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
+ )
+ return sm_count
+
diff --git a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
new file mode 100644
index 0000000..8225025
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
@@ -0,0 +1,224 @@
+"""
+Torch custom ops wrapping Oink's Blackwell RMSNorm kernels.
+
+These ops are designed to be:
+- Architecture-aware (use CuTeDSL SM100 kernels when available, fall back
+ to a safe reference elsewhere).
+- Layout-preserving for 2D row-major inputs, including padded MLA-style
+ layouts where stride(0) > N and stride(1) == 1.
+- torch.compile-friendly via proper fake implementations that mirror
+ runtime shapes and strides.
+
+Public ops (Python signatures):
+
+ torch.ops.oink.rmsnorm(x: Tensor, weight: Tensor, eps: float) -> Tensor
+ Functional RMSNorm. Returns a new tensor with the same shape and
+ stride as x when using the fast CuTeDSL path.
+
+ torch.ops.oink.fused_add_rms_norm(
+ x: Tensor, residual: Tensor, weight: Tensor, eps: float
+ ) -> None
+ In-place fused residual-add + RMSNorm matching vLLM semantics:
+ residual = x + residual (stored into `residual`)
+ x = RMSNorm(residual, w) (stored into `x`)
+ Mutates `x` and `residual` in-place and returns None.
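+
+Illustrative usage (a sketch; assumes the ops have been registered, e.g. via
+the vLLM plugin entrypoint, and that a capable CUDA device is present):
+
+    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
+    w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)
+    y = torch.ops.oink.rmsnorm(x, w, 1e-6)
+
+    residual = torch.randn_like(x)
+    torch.ops.oink.fused_add_rms_norm(x, residual, w, 1e-6)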
+"""
+
+import importlib
+import threading
+
+import torch
+from torch.library import custom_op
+
+_RMSNORM_MOD: object | None = None
+_RMSNORM_MOD_LOCK = threading.Lock()
+
+
+def _get_rmsnorm_mod():
+ """Lazy import to keep plugin registration lightweight.
+
+ Importing the CuTeDSL kernel stack can be expensive and may require a CUDA
+ context. We defer it until the first actual execution of the custom op.
+ """
+ global _RMSNORM_MOD
+
+ cached = _RMSNORM_MOD
+ if cached is not None:
+ return cached
+
+ with _RMSNORM_MOD_LOCK:
+ if _RMSNORM_MOD is None:
+ _RMSNORM_MOD = importlib.import_module("kernelagent_oink.blackwell.rmsnorm")
+ return _RMSNORM_MOD
+
+
+def _get_sm(device: torch.device | None = None) -> int:
+ """Return SM version as an int (e.g., 100 for SM100 / Blackwell)."""
+ if device is None:
+ device = torch.device("cuda")
+ major, minor = torch.cuda.get_device_capability(device)
+ return 10 * major + minor
+
+
+#
+# RMSNorm (functional)
+#
+
+@custom_op("oink::rmsnorm", mutates_args=())
+def oink_rmsnorm(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float,
+) -> torch.Tensor:
+ """
+ Functional RMSNorm entrypoint.
+
+ This op is model-agnostic. It expects a 2D [M, N] view of the input
+ where the last dimension is contiguous (stride(1) == 1). The leading
+ dimension stride(0) may be larger than N (padded-row layouts), and
+ will be preserved on the fast CuTeDSL path.
+
+ On SM100 (and newer), this dispatches to the tuned CuTeDSL Blackwell
+ RMSNorm kernel in rmsnorm.rmsnorm_forward, which in turn selects the
+ best internal schedule (including DSv3-specific stage-2 kernels where
+ applicable) and preserves the input's 2D stride when using the
+ pointer-based path.
+
+ On older architectures it falls back to a safe PyTorch reference
+ implementation for correctness.
+ """
+ assert x.is_cuda, "oink::rmsnorm requires CUDA tensors"
+ assert x.dim() == 2, "oink::rmsnorm expects a 2D [M, N] tensor view"
+ assert weight.dim() == 1, "weight must be 1D [N]"
+
+ sm = _get_sm(x.device)
+ if sm >= 100:
+ # Use the tuned CuTeDSL SM100 kernel. The public API already
+ # contains all necessary gating and layout checks internally.
+ _rms = _get_rmsnorm_mod()
+ y, _rstd, _res = _rms.rmsnorm_forward(
+ x,
+ weight=weight,
+ bias=None,
+ residual=None,
+ eps=eps,
+ store_rstd=False,
+ )
+ return y
+
+ # Fallback: reference implementation (correctness-first).
+ _rms = _get_rmsnorm_mod()
+ return _rms.rmsnorm_ref(
+ x,
+ w=weight,
+ b=None,
+ residual=None,
+ eps=eps,
+ )
+
+
+@oink_rmsnorm.register_fake
+def oink_rmsnorm_fake(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float,
+) -> torch.Tensor:
+ """
+ Fake (meta) implementation for oink::rmsnorm.
+
+ We must preserve x's logical layout (shape + stride) so that Inductor's
+ CUDA graph capture sees the same stride contract as the real kernel.
+ """
+ # x is a FakeTensor here; x.shape/x.stride()/x.device/x.dtype are defined.
+ return torch.empty_strided(
+ x.shape,
+ x.stride(),
+ device=x.device,
+ dtype=x.dtype,
+ )
+
+
+#
+# Fused residual-add + RMSNorm (in-place, vLLM semantics)
+#
+
+@custom_op("oink::fused_add_rms_norm", mutates_args=("x", "residual"))
+def oink_fused_add_rms_norm(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float,
+) -> None:
+ """
+ In-place fused residual-add + RMSNorm:
+
+ residual <- x + residual
+ x <- RMSNorm(residual, weight, eps)
+
+ Returns:
+ None (mutates `x` and `residual` in-place).
+ """
+ assert x.is_cuda and residual.is_cuda, "oink::fused_add_rms_norm requires CUDA tensors"
+ assert x.shape == residual.shape, "x and residual must have the same shape"
+ assert x.dtype == residual.dtype, "x and residual must have the same dtype"
+ assert weight.dim() == 1, "weight must be 1D [N]"
+
+ sm = _get_sm(x.device)
+ if sm >= 100:
+ _rms = _get_rmsnorm_mod()
+ # Prefer the lowest-overhead in-place entrypoint (returns None).
+ if hasattr(_rms, "fused_add_rmsnorm_inplace_"):
+ _rms.fused_add_rmsnorm_inplace_( # type: ignore[misc]
+ x,
+ residual,
+ weight,
+ eps=eps,
+ )
+ return None
+ # Backward-compatible wrapper (returns (x, residual)).
+ if hasattr(_rms, "fused_add_rmsnorm_forward_inplace"):
+ _rms.fused_add_rmsnorm_forward_inplace( # type: ignore[misc]
+ x,
+ residual,
+ weight,
+ eps=eps,
+ )
+ return None
+
+ # Extremely defensive fallback if the Oink module doesn't provide
+ # the in-place entrypoint.
+ y, z = _rms.fused_add_rmsnorm_forward(x, residual, weight, eps=eps)
+ x.copy_(y)
+ residual.copy_(z)
+ return None
+
+ # Non-SM100 fallback: keep semantics in-place (correctness-first).
+ residual.add_(x)
+ _rms = _get_rmsnorm_mod()
+ y = _rms.rmsnorm_ref(residual, w=weight, b=None, residual=None, eps=eps)
+ x.copy_(y)
+ return None
+
+
+@oink_fused_add_rms_norm.register_fake
+def oink_fused_add_rms_norm_fake(
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float,
+) -> None:
+ """
+ Fake (meta) implementation for oink::fused_add_rms_norm.
+
+    This op mutates `x` and `residual` in-place and returns None, so there is
+    nothing to allocate here; the mutated buffers keep their original shapes
+    and strides.
+ """
+ return None
+
+
+__all__ = [
+ "oink_rmsnorm",
+ "oink_fused_add_rms_norm",
+]
diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
new file mode 100644
index 0000000..d6c2c20
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
@@ -0,0 +1,2660 @@
+"""
+RMSNorm kernel for SM100 (Blackwell) in CuTeDSL.
+
+This implementation targets Blackwell with:
+- A stride-preserving pointer path for padded-row layouts (e.g. MLA, stride(0) > N).
+- A one-pass fused-add RMSNorm schedule for bf16/fp16 (DSv3 N=7168) that keeps
+ `x + residual` in registers (avoids re-reading gmem) and uses FP32 accumulation.
+- Optional experimental schedule knobs (env vars) to explore copy widths and
+ stage-2 cp.async variants.
+
+Note: This file expects the local CuTeDSL (cutlass) and SM100 helper modules
+to be available in the Python environment (e.g., `nvidia-cutlass-dsl` and
+`cuda-python`). It is shipped as part of the KernelAgent Oink vLLM plugin.
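+
+Illustrative usage of the host-side API (a sketch; assumes an SM100 GPU and a
+2D row-major input whose last dimension is contiguous):
+
+    x = torch.randn(4096, 7168, device="cuda", dtype=torch.bfloat16)
+    w = torch.ones(7168, device="cuda", dtype=torch.bfloat16)
+    y, _rstd, _res = rmsnorm_forward(
+        x, weight=w, bias=None, residual=None, eps=1e-6, store_rstd=False
+    )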
+"""
+
+from __future__ import annotations
+
+import ctypes
+import importlib.metadata
+import os
+import re
+import subprocess
+import sys
+import threading
+from typing import Optional, Tuple
+
+_HERE = os.path.dirname(__file__)
+
+# CuTeDSL caches generated MLIR into a tempdir under a global default
+# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across
+# `nvidia-cutlass-dsl` versions (e.g. 4.3.2 vs 4.3.4), and cross-version cache
+# sharing causes noisy "invalid section ID" warnings (and disables cache reuse).
+#
+# If the user has not pinned `CUTE_DSL_CACHE_DIR`, isolate by version so multiple
+# CuTeDSL envs can coexist on the same machine without stepping on each other.
+if "CUTE_DSL_CACHE_DIR" not in os.environ:
+ try:
+ _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
+ except Exception:
+ _dsl_ver = "unknown"
+ _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
+ _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
+ _tmp = os.environ.get("TMPDIR") or "/tmp"
+ os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
+ _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
+ )
+
+try:
+ import cutlass # type: ignore # noqa: F401
+except Exception as e:
+ raise ImportError(
+ "kernelagent_oink.blackwell.rmsnorm requires CuTeDSL's Python package "
+ "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)."
+ ) from e
+
+import torch
+from torch import Tensor
+
+import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Float32, Int32, const_expr
+from cutlass.cute import runtime as rt
+
+# Simple compile cache declared early so direct execution works
+_PTR_COMPILE_CACHE = {}
+
+# Thread-local cache for the fast-launch path. We keep per-thread packed args and
+# pointer/scalar storage so concurrent callers don't race on in-place updates.
+_PTR_FAST_LAUNCH_TLS = threading.local()
+
+def _env_flag(name: str, default: bool) -> bool:
+ val = os.environ.get(name)
+ if val is None:
+ return default
+ return val.strip().lower() not in {"0", "false", "no", "off", ""}
+
+
+# Fast-launch uses a few private-ish CuTeDSL internals (packed args plumbing and
+# runtime pointer descriptors). Keep it enabled by default for our pinned CuTeDSL
+# environment, but allow disabling it via env var and auto-disable it if those
+# internals are not present in a future upgrade.
+_ENABLE_FAST_LAUNCH = _env_flag("OINK_CUTEDSL_FAST_LAUNCH", default=True)
+_FAST_LAUNCH_SUPPORTED = True
+
+# Fused-add RMSNorm schedule knobs (read once at import time; set env vars before
+# importing this module if you want to override).
+_DIRECT_GMEM_POLICY = (os.environ.get("OINK_RMSNORM_DIRECT_GMEM", "auto").strip().lower() or "auto")
+_COPY_BITS_POLICY = (os.environ.get("OINK_RMSNORM_COPY_BITS", "auto").strip().lower() or "auto")
+_ENABLE_CLUSTER_ILP = _env_flag("OINK_RMSNORM_ENABLE_CLUSTER_ILP", default=False)
+_ENABLE_CLUSTER_ILP_UNSAFE = _env_flag("OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE", default=False)
+_ENABLE_TPR256 = _env_flag("OINK_RMSNORM_ENABLE_TPR256", default=False)
+_ENABLE_STAGE2 = _env_flag("OINK_RMSNORM_ENABLE_STAGE2", default=False)
+
+# CuTeDSL stability probe for the experimental cluster_n>1 + direct-GMEM schedule.
+#
+# Some CuTeDSL builds segfault during JIT compilation when combining:
+# - cluster launches (cluster_n>1) and
+# - direct-GMEM loads/stores (no staging SMEM tiles).
+#
+# We keep the schedule gated behind `OINK_RMSNORM_ENABLE_CLUSTER_ILP=1` +
+# `OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1`, and additionally run a one-time
+# out-of-process compile probe so we can safely fall back to the staged SMEM
+# path instead of crashing the parent process.
+#
+# This is (currently) sensitive to the vector width: we have observed
+# reproducible segfaults for the 256b universal-copy path, while the 128b path
+# can succeed. Cache the maximum supported copy width (0 = unsupported).
+_CLUSTER_DIRECT_GMEM_MAX_COPY_BITS: Optional[int] = None
+_CLUSTER_DIRECT_GMEM_PROBE_LOCK = threading.Lock()
+_CLUSTER_DIRECT_GMEM_PROBE_WARNED = False
+
+
+def _probe_cluster_direct_gmem_max_copy_bits() -> int:
+ global _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS
+ global _CLUSTER_DIRECT_GMEM_PROBE_WARNED
+
+ override = os.environ.get("OINK_RMSNORM_CLUSTER_DIRECT_GMEM_MAX_COPY_BITS")
+ if override is not None and override.strip() != "":
+ try:
+ value = int(override)
+ except ValueError:
+ value = 0
+ value = 256 if value >= 256 else 128 if value >= 128 else 0
+ _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS = value
+ return value
+
+ if _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS is not None:
+ return _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS
+
+ with _CLUSTER_DIRECT_GMEM_PROBE_LOCK:
+ if _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS is not None:
+ return _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS
+
+ script_template = r"""
+import os
+
+os.environ["OINK_CUTEDSL_FAST_LAUNCH"] = "0"
+
+import cutlass
+import cutlass.cute as cute
+import cuda.bindings.driver as cuda
+from cutlass import Float32, Int32
+from cutlass.cute import runtime as rt
+
+from kernelagent_oink.blackwell import rmsnorm
+
+N = 7168
+dtype = cutlass.BFloat16
+
+copy_bits = int(os.environ["OINK_PROBE_COPY_BITS"])
+assumed_align = int(os.environ["OINK_PROBE_ASSUMED_ALIGN"])
+
+op = rmsnorm.RMSNormSM100(
+ N,
+ dtype,
+ stage=1,
+ copy_bits=copy_bits,
+ use_async=False,
+ direct_gmem=True,
+)
+op._cluster_n_override = 2 # 2 CTAs per row
+
+ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+ptr_res = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+ptr_w = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+
+_ = cute.compile(
+ op.launch_from_ptrs_fused_add_inplace,
+ ptr_x,
+ ptr_w,
+ ptr_res,
+ Int32(4096),
+ Int32(N),
+ Int32(N),
+ cuda.CUstream(0),
+ Float32(1e-6),
+)
+print(f"ok {copy_bits}")
+"""
+
+ env = os.environ.copy()
+ env["PYTHONNOUSERSITE"] = "1"
+
+ def run_probe(copy_bits: int, assumed_align: int):
+ probe_env = env.copy()
+ probe_env["OINK_PROBE_COPY_BITS"] = str(copy_bits)
+ probe_env["OINK_PROBE_ASSUMED_ALIGN"] = str(assumed_align)
+ return subprocess.run(
+ [sys.executable, "-c", script_template],
+ env=probe_env,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ timeout=120.0,
+ )
+
+ proc_256 = None
+ proc_128 = None
+ try:
+ proc_256 = run_probe(256, 32)
+ if proc_256.returncode == 0:
+ max_bits = 256
+ else:
+ proc_128 = run_probe(128, 16)
+ max_bits = 128 if proc_128.returncode == 0 else 0
+ except Exception:
+ max_bits = 0
+
+ if not _CLUSTER_DIRECT_GMEM_PROBE_WARNED and max_bits != 256:
+ _CLUSTER_DIRECT_GMEM_PROBE_WARNED = True
+ if max_bits == 128:
+ print(
+ "Oink: cluster_n>1 + direct_gmem 256b compile probe failed; "
+ "using 128b copies for the cluster ILP schedule.",
+ file=sys.stderr,
+ )
+ if proc_256 is not None and proc_256.stderr:
+ tail = "\n".join(proc_256.stderr.splitlines()[-12:])
+ print(f"Oink: probe stderr tail:\n{tail}", file=sys.stderr)
+ else:
+ rc = (
+ proc_128.returncode
+ if proc_128 is not None
+ else (proc_256.returncode if proc_256 is not None else "unknown")
+ )
+ print(
+ "Oink: cluster_n>1 + direct_gmem compile probe failed; "
+ f"falling back to staged SMEM path (returncode={rc}).",
+ file=sys.stderr,
+ )
+ failing_proc = proc_128 if proc_128 is not None else proc_256
+ if failing_proc is not None and failing_proc.stderr:
+ tail = "\n".join(failing_proc.stderr.splitlines()[-12:])
+ print(f"Oink: probe stderr tail:\n{tail}", file=sys.stderr)
+
+ _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS = max_bits
+ return max_bits
+
+def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
+ parts = version.split(".")
+ nums: list[int] = []
+ for part in parts[:3]:
+ match = re.match(r"^(\d+)", part)
+ nums.append(int(match.group(1)) if match is not None else 0)
+ while len(nums) < 3:
+ nums.append(0)
+ return nums[0], nums[1], nums[2]
+
+
+def _cutlass_dsl_version() -> Optional[Tuple[int, int, int]]:
+ try:
+ return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl"))
+ except Exception:
+ return None
+
+
+_CUTLASS_DSL_VERSION = _cutlass_dsl_version()
+# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around
+# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the
+# older signature for 4.3.2, but switch to a 4.3.4-compatible signature when we
+# detect 4.3.4+ (or when version detection is unavailable).
+_KERNEL_ACCEPTS_LAYOUT_ARGS = _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4)
+
+if _ENABLE_CLUSTER_ILP and not _ENABLE_CLUSTER_ILP_UNSAFE:
+ # We have observed reproducible segfaults in some CuTeDSL builds when using
+ # cluster launches for this schedule. Require an explicit UNSAFE opt-in to
+ # avoid accidental crashes.
+ _ENABLE_CLUSTER_ILP = False
+ print(
+ "Oink: OINK_RMSNORM_ENABLE_CLUSTER_ILP requested but disabled by default due to "
+ "known instability; set OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1 to force-enable.",
+ file=sys.stderr,
+ )
+
+
+def _fast_launch_enabled() -> bool:
+ return _ENABLE_FAST_LAUNCH and _FAST_LAUNCH_SUPPORTED
+
+
+def _direct_gmem_from_policy(*, default: bool) -> bool:
+ """Resolve the direct-GMEM schedule flag from the (import-time) policy string."""
+ if _DIRECT_GMEM_POLICY in {"0", "false", "no", "off"}:
+ return False
+ if _DIRECT_GMEM_POLICY in {"1", "true", "yes", "on"}:
+ return True
+ return default
+
+
+def _copy_bits_from_policy(*, default: int, can_use_256: bool) -> int:
+ """Resolve copy width (in bits) from the (import-time) policy string."""
+ if _COPY_BITS_POLICY in {"128"}:
+ return 128
+ if _COPY_BITS_POLICY in {"256"} and can_use_256:
+ return 256
+ return default
+
+
+class _StableI32Arg:
+ """A stable Int32 runtime arg (avoids per-call Int32().__c_pointers__ allocations)."""
+
+ def __init__(self, value: int):
+ self._c_value = ctypes.c_int32(int(value))
+ self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p)
+
+ def set(self, value: int) -> None:
+ self._c_value.value = int(value)
+
+ def __c_pointers__(self):
+ return [self._c_pointer]
+
+
+class _StableF32Arg:
+ """A stable Float32 runtime arg (avoids per-call Float32().__c_pointers__ allocations)."""
+
+ def __init__(self, value: float):
+ self._c_value = ctypes.c_float(float(value))
+ self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p)
+
+ def set(self, value: float) -> None:
+ self._c_value.value = float(value)
+
+ def __c_pointers__(self):
+ return [self._c_pointer]
+
+
+def _tls_fast_launch_cache() -> dict[tuple[object, ...], object]:
+ cache = getattr(_PTR_FAST_LAUNCH_TLS, "cache", None)
+ if cache is None:
+ cache = {}
+ _PTR_FAST_LAUNCH_TLS.cache = cache
+ return cache
+
+
+def _set_runtime_ptr(ptr: object, device_ptr: int) -> None:
+ # Runtime pointer objects cache a `ctypes.c_void_p` descriptor and pass
+ # its address to the compiled function. Updating `_desc.value` updates
+ # the device pointer without changing the address of the descriptor.
+ #
+ # This relies on internal CuTeDSL runtime pointer fields (`_desc`, `_pointer`,
+ # etc.). If these internals change in a future CuTeDSL upgrade, callers
+ # should catch AttributeError and fall back to the regular launch path.
+ device_ptr = int(device_ptr)
+ ptr._pointer = device_ptr # type: ignore[attr-defined]
+ if getattr(ptr, "_c_pointer", None) is None:
+ ptr.__c_pointers__() # type: ignore[attr-defined]
+ ptr._desc.value = device_ptr # type: ignore[attr-defined]
+
+
+class _PtrRmsnormFastLaunch:
+ def __init__(
+ self,
+ *,
+ compiled: object,
+ executor: object,
+ capi_func: object,
+ ptr_x: object,
+ ptr_w: Optional[object],
+ ptr_out: object,
+ arg_m: _StableI32Arg,
+ arg_n: _StableI32Arg,
+ arg_ld: _StableI32Arg,
+ arg_eps: _StableF32Arg,
+ stream: cuda.CUstream,
+ packed_args: object,
+ keepalive: tuple[object, ...],
+ ):
+ self._compiled = compiled
+ self._executor = executor
+ self._capi_func = capi_func
+ self._ptr_x = ptr_x
+ self._ptr_w = ptr_w
+ self._ptr_out = ptr_out
+ self._arg_m = arg_m
+ self._arg_n = arg_n
+ self._arg_ld = arg_ld
+ self._arg_eps = arg_eps
+ self._stream = stream
+ self._packed_args = packed_args
+ self._keepalive = keepalive
+
+ self._use_fast_launch = True
+
+ self._cuda_result = getattr(executor, "cuda_result", None)
+
+ self._last_x_ptr = -1
+ self._last_w_ptr = -1
+ self._last_out_ptr = -1
+ self._last_m = -1
+ self._last_ld = -1
+ self._last_eps = float("nan")
+
+ def launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Optional[Tensor],
+ out: Tensor,
+ M: int,
+ N: int,
+ ld: int,
+ eps: float,
+ ) -> None:
+ if not _fast_launch_enabled() or not self._use_fast_launch:
+ self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps)
+ return
+
+ x_ptr = x.data_ptr()
+ if x_ptr != self._last_x_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_x, x_ptr)
+ self._last_x_ptr = x_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps)
+ return
+
+ if self._ptr_w is not None:
+ w_ptr = weight.data_ptr() # type: ignore[union-attr]
+ if w_ptr != self._last_w_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_w, w_ptr)
+ self._last_w_ptr = w_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps)
+ return
+
+ out_ptr = out.data_ptr()
+ if out_ptr != self._last_out_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_out, out_ptr)
+ self._last_out_ptr = out_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps)
+ return
+
+ if M != self._last_m:
+ self._arg_m.set(M)
+ self._last_m = M
+ if ld != self._last_ld:
+ self._arg_ld.set(ld)
+ self._last_ld = ld
+ if eps != self._last_eps:
+ self._arg_eps.set(eps)
+ self._last_eps = eps
+
+ # Clear the error slot before launch (mirrors JitExecutor behavior).
+ if self._cuda_result is not None:
+ self._cuda_result.value = 0
+
+ ret = self._capi_func(self._packed_args) # type: ignore[misc]
+ if ret != 0:
+ raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}")
+ if self._cuda_result is not None:
+ err = int(self._cuda_result.value)
+ if err != 0:
+ raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})")
+
+ def _disable_fast_launch(self) -> None:
+ global _FAST_LAUNCH_SUPPORTED
+ self._use_fast_launch = False
+ _FAST_LAUNCH_SUPPORTED = False
+
+ def _fallback_launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Optional[Tensor],
+ out: Tensor,
+ M: int,
+ N: int,
+ ld: int,
+ eps: float,
+ ) -> None:
+ # If the packed-args or runtime pointer mutation path stops working
+ # (e.g. due to a CuTeDSL upgrade), fall back to the regular call path.
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_w = (
+ rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ if weight is not None
+ else None
+ )
+ self._compiled(
+ ptr_x,
+ ptr_w,
+ None, # ptr_b
+ None, # ptr_res
+ ptr_out,
+ None, # ptr_res_out
+ None, # ptr_rstd
+ Int32(M),
+ Int32(N),
+ Int32(ld),
+ self._stream,
+ Float32(eps),
+ )
+
+
+class _PtrFusedAddRmsnormFastLaunch:
+ def __init__(
+ self,
+ *,
+ compiled: object,
+ executor: object,
+ capi_func: object,
+ ptr_x: object,
+ ptr_w: object,
+ ptr_res: object,
+ arg_m: _StableI32Arg,
+ arg_n: _StableI32Arg,
+ arg_ld_x: _StableI32Arg,
+ arg_eps: _StableF32Arg,
+ stream: cuda.CUstream,
+ assumed_align: int,
+ packed_args: object,
+ keepalive: tuple[object, ...],
+ ):
+ self._compiled = compiled
+ self._executor = executor
+ self._capi_func = capi_func
+ self._ptr_x = ptr_x
+ self._ptr_w = ptr_w
+ self._ptr_res = ptr_res
+ self._arg_m = arg_m
+ self._arg_n = arg_n
+ self._arg_ld_x = arg_ld_x
+ self._arg_eps = arg_eps
+ self._stream = stream
+ self._assumed_align = int(assumed_align)
+ self._packed_args = packed_args
+ self._keepalive = keepalive
+
+ self._use_fast_launch = True
+
+ self._cuda_result = getattr(executor, "cuda_result", None)
+
+ self._last_x_ptr = -1
+ self._last_w_ptr = -1
+ self._last_res_ptr = -1
+ self._last_m = -1
+ self._last_ld_x = -1
+ self._last_eps = float("nan")
+
+ def launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Tensor,
+ residual: Tensor,
+ M: int,
+ N: int,
+ ld_x: int,
+ eps: float,
+ ) -> None:
+ if not _fast_launch_enabled() or not self._use_fast_launch:
+ self._fallback_launch(
+ x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps
+ )
+ return
+
+ x_ptr = x.data_ptr()
+ if x_ptr != self._last_x_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_x, x_ptr)
+ self._last_x_ptr = x_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps
+ )
+ return
+
+ w_ptr = weight.data_ptr()
+ if w_ptr != self._last_w_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_w, w_ptr)
+ self._last_w_ptr = w_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps
+ )
+ return
+
+ res_ptr = residual.data_ptr()
+ if res_ptr != self._last_res_ptr:
+ try:
+ _set_runtime_ptr(self._ptr_res, res_ptr)
+ self._last_res_ptr = res_ptr
+ except AttributeError:
+ self._disable_fast_launch()
+ self._fallback_launch(
+ x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps
+ )
+ return
+
+ if M != self._last_m:
+ self._arg_m.set(M)
+ self._last_m = M
+ if ld_x != self._last_ld_x:
+ self._arg_ld_x.set(ld_x)
+ self._last_ld_x = ld_x
+ if eps != self._last_eps:
+ self._arg_eps.set(eps)
+ self._last_eps = eps
+
+ if self._cuda_result is not None:
+ self._cuda_result.value = 0
+
+ ret = self._capi_func(self._packed_args) # type: ignore[misc]
+ if ret != 0:
+ raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}")
+ if self._cuda_result is not None:
+ err = int(self._cuda_result.value)
+ if err != 0:
+ raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})")
+
+ def _disable_fast_launch(self) -> None:
+ global _FAST_LAUNCH_SUPPORTED
+ self._use_fast_launch = False
+ _FAST_LAUNCH_SUPPORTED = False
+
+ def _fallback_launch(
+ self,
+ *,
+ x: Tensor,
+ weight: Tensor,
+ residual: Tensor,
+ M: int,
+ N: int,
+ ld_x: int,
+ eps: float,
+ ) -> None:
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ ptr_x = rt.make_ptr(
+ dtype,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ ptr_res = rt.make_ptr(
+ dtype,
+ residual.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ ptr_w = rt.make_ptr(
+ dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=self._assumed_align,
+ )
+ self._compiled(
+ ptr_x,
+ ptr_w,
+ ptr_res,
+ Int32(M),
+ Int32(N),
+ Int32(ld_x),
+ self._stream,
+ Float32(eps),
+ )
+
+
+def _get_fast_ptr_rmsnorm_launcher(
+ *,
+ compiled: object,
+ dtype: type[cutlass.Numeric],
+ N: int,
+ device_index: int,
+ stream_handle: int,
+ has_weight: bool,
+ eps: float,
+) -> Optional[_PtrRmsnormFastLaunch]:
+ if not _fast_launch_enabled():
+ return None
+ # Keyed by the compiled object identity so schedule changes (e.g. copy width,
+ # async/staged variants, etc.) never alias in the fast-launch cache.
+ key = ("ptr_fast", id(compiled), N, dtype, device_index, int(stream_handle), has_weight)
+ cache = _tls_fast_launch_cache()
+ cached = cache.get(key)
+ if cached is not None:
+ return cached # type: ignore[return-value]
+
+ # Create stable runtime args and pointer descriptors once.
+ ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_out = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_w = (
+ rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) if has_weight else None
+ )
+
+ arg_m = _StableI32Arg(0)
+ arg_n = _StableI32Arg(N)
+ arg_ld = _StableI32Arg(N)
+ arg_eps = _StableF32Arg(eps)
+
+ stream = cuda.CUstream(int(stream_handle))
+
+ # Create an executor (loads the CUDA library once).
+ executor = compiled.to(device_index) # type: ignore[attr-defined]
+
+ # Use generate_execution_args once to build the packed args array, and keep
+ # any adapted args alive for the lifetime of the cache entry.
+ try:
+ exe_args, adapted_args = executor.generate_execution_args(
+ ptr_x,
+ ptr_w,
+ None, # ptr_b
+ None, # ptr_res
+ ptr_out,
+ None, # ptr_res_out
+ None, # ptr_rstd
+ arg_m,
+ arg_n,
+ arg_ld,
+ stream,
+ arg_eps,
+ )
+ packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined]
+ capi_func = compiled.capi_func # type: ignore[attr-defined]
+ except AttributeError:
+ global _FAST_LAUNCH_SUPPORTED
+ _FAST_LAUNCH_SUPPORTED = False
+ return None
+
+ keepalive: tuple[object, ...] = (
+ executor,
+ ptr_x,
+ ptr_w,
+ ptr_out,
+ arg_m,
+ arg_n,
+ arg_ld,
+ arg_eps,
+ stream,
+ *adapted_args,
+ )
+
+ launcher = _PtrRmsnormFastLaunch(
+ compiled=compiled,
+ executor=executor,
+ capi_func=capi_func,
+ ptr_x=ptr_x,
+ ptr_w=ptr_w,
+ ptr_out=ptr_out,
+ arg_m=arg_m,
+ arg_n=arg_n,
+ arg_ld=arg_ld,
+ arg_eps=arg_eps,
+ stream=stream,
+ packed_args=packed_args,
+ keepalive=keepalive,
+ )
+ cache[key] = launcher
+ return launcher
+
+
+def _get_fast_ptr_fused_add_rmsnorm_launcher(
+ *,
+ compiled: object,
+ dtype: type[cutlass.Numeric],
+ N: int,
+ device_index: int,
+ stream_handle: int,
+ copy_bits: int,
+ use_async: bool,
+ tpr: int,
+ direct_gmem: bool,
+ assumed_align: int,
+ eps: float,
+) -> Optional[_PtrFusedAddRmsnormFastLaunch]:
+ if not _fast_launch_enabled():
+ return None
+ key = (
+ "ptr_fused_add_fast",
+ id(compiled),
+ N,
+ dtype,
+ device_index,
+ int(stream_handle),
+ int(copy_bits),
+ bool(use_async),
+ int(tpr),
+ bool(direct_gmem),
+ int(assumed_align),
+ )
+ cache = _tls_fast_launch_cache()
+ cached = cache.get(key)
+ if cached is not None:
+ return cached # type: ignore[return-value]
+
+ ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+ ptr_res = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+ ptr_w = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+
+ arg_m = _StableI32Arg(0)
+ arg_n = _StableI32Arg(N)
+ arg_ld_x = _StableI32Arg(N)
+ arg_eps = _StableF32Arg(eps)
+
+ stream = cuda.CUstream(int(stream_handle))
+
+ executor = compiled.to(device_index) # type: ignore[attr-defined]
+
+ try:
+ exe_args, adapted_args = executor.generate_execution_args(
+ ptr_x,
+ ptr_w,
+ ptr_res,
+ arg_m,
+ arg_n,
+ arg_ld_x,
+ stream,
+ arg_eps,
+ )
+ packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined]
+ capi_func = compiled.capi_func # type: ignore[attr-defined]
+ except AttributeError:
+ global _FAST_LAUNCH_SUPPORTED
+ _FAST_LAUNCH_SUPPORTED = False
+ return None
+
+ keepalive: tuple[object, ...] = (
+ executor,
+ ptr_x,
+ ptr_w,
+ ptr_res,
+ arg_m,
+ arg_n,
+ arg_ld_x,
+ arg_eps,
+ stream,
+ *adapted_args,
+ )
+
+ launcher = _PtrFusedAddRmsnormFastLaunch(
+ compiled=compiled,
+ executor=executor,
+ capi_func=capi_func,
+ ptr_x=ptr_x,
+ ptr_w=ptr_w,
+ ptr_res=ptr_res,
+ arg_m=arg_m,
+ arg_n=arg_n,
+ arg_ld_x=arg_ld_x,
+ arg_eps=arg_eps,
+ stream=stream,
+ assumed_align=assumed_align,
+ packed_args=packed_args,
+ keepalive=keepalive,
+ )
+ cache[key] = launcher
+ return launcher
+
+
+# Local helpers for reduction, dtype mapping, and coordinate/predicate utilities.
+#
+# NOTE: Avoid `from . import ...` imports here: CuTeDSL's AST preprocessor may
+# mishandle that form (module=None in the AST). Use fully-qualified imports.
+from kernelagent_oink.blackwell import lite_quack as qutils
+from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce
+
+
+# -------------------------
+# Copy helpers (allow up to 256b)
+# -------------------------
+
+@cute.jit
+def get_copy_atom_bw(
+ dtype: type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False
+) -> cute.CopyAtom:
+ # cp.async (SIMT) supports up to 128b per op; use 256b for sync when possible
+ max_bits = const_expr(128 if is_async else 256)
+ num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width))
+ from cutlass.cute.nvgpu import cpasync
+ # Prefer GLOBAL cache policy for bulk streaming reads at large M
+ copy_op = (
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL)
+ if is_async
+ else cute.nvgpu.CopyUniversalOp()
+ )
+ return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+
+
+@cute.jit
+def copy_tiled(
+ src: cute.Tensor,
+ dst: cute.Tensor,
+ *,
+ pred: Optional[cute.Tensor] = None,
+ num_copy_elems: int = 1,
+ is_async: bool = False,
+) -> None:
+ atom = get_copy_atom_bw(src.element_type, num_copy_elems, is_async)
+ cute.copy(atom, src, dst, pred=pred)
+
+
+# -------------------------
+# RMSNorm Kernel (SM100)
+# -------------------------
+
+
+class RMSNormSM100:
+ def __init__(
+ self,
+ N: int,
+ dtype: type[cutlass.Numeric],
+ stage: Optional[int] = None,
+ *,
+ copy_bits: int = 128,
+ use_async: bool = True,
+ direct_gmem: bool = False,
+ ):
+ self.N = N
+ self.dtype = dtype
+ # Match Quack default for RMSNorm: stage = 1 unless explicitly overridden
+ self.stage = 1 if stage is None else stage
+ self.reduction_dtype = cutlass.Float32
+ self.copy_bits = int(copy_bits)
+ self.use_async = bool(use_async)
+ self.direct_gmem = bool(direct_gmem)
+
+ def _threads_per_row(self) -> int:
+ try:
+ return self._tpr_override # type: ignore[attr-defined]
+ except Exception:
+ pass
+ # Tune mid-size buckets for large-M rows.
+ N = self.N
+ # DSv3 MLA (padded/strided) hot shape. Prefer a threads-per-row that
+ # makes the tile width exactly match N with 128b vectors (bf16/fp16),
+ # avoiding the ~33% padded work from rounding 1536 -> 2048.
+ if N == 1536 and self.dtype.width == 16:
+ return 96
+ # DSv3 default hidden size (7168). Choose a threads-per-row that matches
+ # the selected vector width to avoid padded work:
+ # - 128b copies (vec=8 for bf16/fp16): 7168/8 = 896 = 7 * 128 -> tpr=128
+ # - 256b copies (vec=16 for bf16/fp16): 7168/16 = 448 = 2 * 224 -> tpr=224
+ #
+ # The fused direct-GMEM path often uses 256b copies on 32B-aligned
+ # tensors, while the non-fused path defaults to 128b copies.
+ if N == 7168 and self.dtype.width == 16:
+ return 224 if self.copy_bits >= 256 else 128
+ # For small-N, use at least one full warp per row. The kernel
+ # implementation assumes one row per CTA; returning <32 here can
+ # produce multi-row tiles (cols_per_block > 1) which is not supported.
+ if N <= 1024:
+ return 32
+ elif N <= 4096:
+ return 128
+ elif N <= 8192:
+ # Allow an override (used by 2-rows/CTA path for N≈6k/8k)
+ try:
+ return self._tpr_override # type: ignore[attr-defined]
+ except Exception:
+ return 128
+ elif N <= 16384:
+ return 256
+ else:
+ return 256
+
+ def _cluster_n(self) -> int:
+ try:
+ return self._cluster_n_override # type: ignore[attr-defined]
+ except Exception:
+ pass
+ N = self.N
+ # Default policy
+ if N <= 8192:
+ return 1
+ if const_expr(self.dtype.width == 16):
+ if N <= 16 * 1024:
+ return 2
+ elif N <= 32 * 1024:
+ return 2
+ elif N <= 64 * 1024:
+ return 4
+ elif N <= 128 * 1024:
+ return 8
+ else:
+ return 16
+ else:
+ if N <= 32 * 1024:
+ return 1
+ elif N <= 64 * 1024:
+ return 2
+ elif N <= 128 * 1024:
+ return 4
+ elif N <= 256 * 1024:
+ return 8
+ else:
+ return 16
+
+ def _num_threads(self) -> int:
+ # Favor 128 threads up to N=16k to reduce per-row partitioning overhead.
+ # This keeps cols_per_block=1 at N=8192 (bf16), which benchmarks faster for large-M.
+ try:
+ return self._nt_override # type: ignore[attr-defined]
+ except Exception:
+ if self.N == 1536 and self.dtype.width == 16:
+ return 96
+ if self.N == 7168 and self.dtype.width == 16:
+ return 224 if self.copy_bits >= 256 else 128
+ if self.N <= 1024:
+ return 32
+ return 128 if self.N <= 16384 else 256
+
+ def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout]:
+ vecsize = num_copy_bits // self.dtype.width
+ num_threads = self._num_threads()
+ assert num_threads % cute.arch.WARP_SIZE == 0
+ tpr = self._threads_per_row()
+ cluster_n = self._cluster_n()
+ # Allow tails: compute number of vector columns with ceil
+ num_cols_vec = cute.ceil_div(self.N, vecsize)
+ num_blocks_N = cute.ceil_div(num_cols_vec, tpr * cluster_n)
+ cols_per_block = num_threads // tpr
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * tpr)
+ tv_layout = cute.make_layout(
+ ((tpr, cols_per_block), (vecsize, num_blocks_N)),
+ stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * tpr)),
+ )
+ return tiler_mn, tv_layout
+
+ def _smem_bytes(self, tiler_mn, num_warps) -> int:
+ # smem for X tile (+ residual if present) + reduction buffers + mbar(s)
+ return (
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+ + self.stage * num_warps * self._cluster_n() * (self.reduction_dtype.width // 8)
+ + self.stage * (cutlass.Int64.width // 8)
+ )
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ):
+ # Make last dim static (N)
+ semistatic_shape = (*mX.shape[:-1], self.N)
+
+ def new_stride(t):
+ return (
+ cute.assume(t.stride[0], divby=256 // t.element_type.width),
+ t.stride[1],
+ )
+
+ mX, mRes, mO, mResO = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ if const_expr(t is not None)
+ else None
+ for t in (mX, mRes, mO, mResO)
+ ]
+ assert mX.element_type == self.dtype
+ assert mO.element_type == self.dtype
+
+ copy_bits = int(self.copy_bits)
+ tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits)
+ num_threads = cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._num_threads()
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ threads_per_row = tv_layout.shape[0][0] if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._threads_per_row()
+ warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1)
+ cluster_n = self._cluster_n()
+
+ if const_expr(mW is not None):
+ mW = cute.make_tensor(
+ mW.iterator, cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ )
+ if const_expr(mB is not None):
+ mB = cute.make_tensor(
+ mB.iterator, cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ )
+ if const_expr(mRstd is not None):
+ mRstd = cute.make_tensor(
+ mRstd.iterator, cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
+ )
+
+ # No SMEM reload mode switch; overlap is controlled in the K-loop path
+
+ # Compute smem usage considering staged buffers.
+ #
+ # In direct-gmem mode, we skip the gmem->smem tiles entirely and only
+ # keep the reduction buffers in shared memory.
+ stage_bufs = 2 if self.stage > 1 else 1
+ tile_bytes_x = (
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs
+ if const_expr(not self.direct_gmem)
+ else 0
+ )
+ tile_bytes_res = (
+ cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) * stage_bufs
+ if const_expr(mRes is not None and not self.direct_gmem)
+ else 0
+ )
+ red_bytes = self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8)
+ # mbarriers are only allocated/used for cluster_n>1. Some CuTeDSL builds
+ # require mbarrier state to be 16B-aligned in shared memory; account for
+ # the alignment padding when computing dynamic smem bytes.
+ smem_bytes = tile_bytes_x + tile_bytes_res + red_bytes
+ if cluster_n > 1:
+ # Align up to 16B before placing the mbarrier array.
+ smem_bytes = ((smem_bytes + 15) // 16) * 16
+ smem_bytes += self.stage * (cutlass.Int64.width // 8)
+
+ kernel = (
+ self.kernel(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(cluster_n),
+ const_expr(num_warps),
+ const_expr(warps_per_row),
+ const_expr(threads_per_row),
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ )
+ )
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=([1, cluster_n, 1] if cluster_n > 1 else None),
+ smem=smem_bytes,
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_x: cute.Pointer,
+ ptr_w: Optional[cute.Pointer],
+ ptr_b: Optional[cute.Pointer],
+ ptr_res: Optional[cute.Pointer],
+ ptr_out: cute.Pointer,
+ ptr_res_out: Optional[cute.Pointer],
+ ptr_rstd: Optional[cute.Pointer],
+ M: Int32,
+ N_dyn: Int32,
+ ld: Int32,
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ):
+ """Pointer-based entrypoint to reuse the existing RMSNorm schedule.
+
+ This reconstructs cute.Tensor views from raw pointers plus sizes,
+ avoiding any DLPack conversions at the Python boundary.
+ """
+        # Build the [M, N] layouts from the dynamic arguments (M, N_dyn, ld)
+        # rather than plain Python ints, so that the subsequent
+        # cute.assume(...) in __call__ sees a dynamic expression for the
+        # leading-dimension stride.
+        # The compile-time N for the kernel (self.N) is still used to
+        # specialize the schedule.
+        # Assume row-major [M, N] with an arbitrary leading-dimension stride
+        # (common for padded-row / packed-attention layouts).
+ layout_mn = cute.make_layout((M, N_dyn), stride=(ld, 1))
+ layout_n = cute.make_layout((N_dyn,), stride=(1,))
+ layout_m = cute.make_layout((M,), stride=(1,))
+
+ mX = cute.make_tensor(ptr_x, layout_mn)
+ mO = cute.make_tensor(ptr_out, layout_mn)
+
+ mRes = (
+ cute.make_tensor(ptr_res, layout_mn)
+ if const_expr(ptr_res is not None)
+ else None
+ )
+ mResO = (
+ cute.make_tensor(ptr_res_out, layout_mn)
+ if const_expr(ptr_res_out is not None)
+ else None
+ )
+ mW = (
+ cute.make_tensor(ptr_w, layout_n)
+ if const_expr(ptr_w is not None)
+ else None
+ )
+ mB = (
+ cute.make_tensor(ptr_b, layout_n)
+ if const_expr(ptr_b is not None)
+ else None
+ )
+ mRstd = (
+ cute.make_tensor(ptr_rstd, layout_m)
+ if const_expr(ptr_rstd is not None)
+ else None
+ )
+
+ # Reuse the main JIT entry to launch the scheduled kernel.
+ self.__call__(mX, mW, mB, mRes, mO, mResO, mRstd, stream, eps)
+
+ @cute.jit
+ def launch_from_ptrs_fused_add_inplace(
+ self,
+ ptr_x: cute.Pointer,
+ ptr_w: cute.Pointer,
+ ptr_res: cute.Pointer,
+ M: Int32,
+ N_dyn: Int32,
+ ld_x: Int32,
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ):
+ """Pointer-based entrypoint for vLLM-style fused_add_rms_norm semantics.
+
+ This specialized entrypoint supports:
+ - `x` / output with an arbitrary leading-dimension stride (`ld_x`), and
+ - `residual` / residual-out as a contiguous [M, N] tensor (ld_res = N).
+
+ Both `x` and `residual` are updated in-place:
+ residual <- x + residual
+ x <- RMSNorm(residual) * weight
+ """
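+        # Reference semantics in plain PyTorch (sketch, for clarity only; the
+        # kernel performs the add and normalization in fp32 internally). For
+        # torch tensors x / residual of shape [M, N] and weight of shape [N]:
+        #   z = x.float() + residual.float()
+        #   y = z * torch.rsqrt(z.pow(2).mean(-1, keepdim=True) + eps)
+        #   residual.copy_(z.to(residual.dtype))
+        #   x.copy_((y * weight.float()).to(x.dtype))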
+ layout_x = cute.make_layout((M, N_dyn), stride=(ld_x, 1))
+ layout_res = cute.make_layout((M, N_dyn), stride=(N_dyn, 1))
+ layout_n = cute.make_layout((N_dyn,), stride=(1,))
+
+ mX = cute.make_tensor(ptr_x, layout_x)
+ mO = cute.make_tensor(ptr_x, layout_x)
+ mRes = cute.make_tensor(ptr_res, layout_res)
+ mResO = cute.make_tensor(ptr_res, layout_res)
+ mW = cute.make_tensor(ptr_w, layout_n)
+
+ self.__call__(
+ mX,
+ mW,
+ None, # bias
+ mRes,
+ mO,
+ mResO,
+ None, # rstd
+ stream,
+ eps,
+ )
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ cluster_n: cutlass.Constexpr[int],
+ num_warps: cutlass.Constexpr[int],
+ warps_per_row: cutlass.Constexpr[int],
+ threads_per_row: cutlass.Constexpr[int],
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if const_expr(cluster_n > 1):
+ cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+ else:
+ cta_rank_in_cluster = const_expr(0)
+ n_off = cta_rank_in_cluster * tiler_mn[1]
+
+ smem = cutlass.utils.SmemAllocator()
+ # Allocate one or two SMEM buffers depending on stage depth
+ sX0 = (
+ smem.allocate_tensor(
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32
+ )
+ if const_expr(not self.direct_gmem)
+ else None
+ )
+ sX1 = (
+ smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ if const_expr(self.stage > 1 and not self.direct_gmem)
+ else None
+ )
+ sRes0 = (
+ smem.allocate_tensor(
+ mRes.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32
+ )
+ if const_expr(mRes is not None and not self.direct_gmem)
+ else None
+ )
+ sRes1 = (
+ smem.allocate_tensor(
+ mRes.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ if const_expr(mRes is not None and self.stage > 1 and not self.direct_gmem)
+ else None
+ )
+
+ # Reduction buffers + mbar for cluster reduce (reused by row_reduce helper)
+ red_layout = cute.make_ordered_layout(
+ (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+ order=(1, 0, 2),
+ )
+ reduction_buffer = smem.allocate_tensor(self.reduction_dtype, red_layout, byte_alignment=4)
+ if const_expr(cluster_n > 1):
+ # Some CuTeDSL builds appear sensitive to the shared-memory alignment of
+ # mbarrier state. `SmemAllocator.allocate_array` does not currently
+ # expose an alignment parameter, so allocate an Int64 tensor with an
+ # explicit alignment and pass its iterator as the pointer.
+ mbar_tensor = smem.allocate_tensor(
+ cutlass.Int64,
+ cute.make_layout((self.stage,), stride=(1,)),
+ byte_alignment=16,
+ )
+ mbar_ptr = mbar_tensor.iterator
+ else:
+ mbar_ptr = None
+
+ shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+ limit_k = shape[1] - n_off
+
+ # Tiled copy setup
+ num_copy_elems_X = tv_layout.shape[1][0]
+ use_async = const_expr(self.use_async and self.N >= 1024 and not self.direct_gmem)
+ copy_atom = get_copy_atom_bw(mX.element_type, num_copy_elems_X, is_async=use_async)
+ thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx)
+
+ # Tail predicate for the N dimension (when tile width > N). Reuse this
+ # for W/B loads so we never read past the end of those 1D tensors.
+ is_even_N_wb = const_expr(shape[1] == tiler_mn[1] * cluster_n)
+ if const_expr(not is_even_N_wb):
+ cX0 = cute.local_tile(idX, tiler_mn, (0, 0))
+ tXp_wb = qutils.predicate_k(thr_copy.partition_S(cX0), limit=limit_k)
+ else:
+ tXp_wb = None
+
+ # Weight/bias loads:
+ #
+ # - Direct-GMEM schedule: load weight/bias up front to hide latency.
+ # - Staged SMEM schedule: loading after the reduction reduces register
+ # pressure during the long-scoreboard reduction phase (better for large-M),
+ # but it measurably hurts small-M latency for the non-fused (no residual,
+ # no bias) case. For that specific case, prefetch weight up front as well.
+ tXrW = None
+ tXrB = None
+ prefetch_w_early = bool(
+ mW is not None and (self.direct_gmem or (mRes is None and mB is None))
+ )
+ if const_expr(prefetch_w_early):
+ gW = cute.local_tile(qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0))
+ tXgW = thr_copy.partition_S(gW)
+ tXrW = cute.make_fragment_like(tXgW)
+ if const_expr(not is_even_N_wb):
+ tXrW.fill(0)
+ cute.copy(
+ get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False),
+ tXgW,
+ tXrW,
+ pred=tXp_wb,
+ )
+ if const_expr(self.direct_gmem and mB is not None):
+ gB = cute.local_tile(qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0))
+ tXgB = thr_copy.partition_S(gB)
+ tXrB = cute.make_fragment_like(tXgB)
+ if const_expr(not is_even_N_wb):
+ tXrB.fill(0)
+ cute.copy(
+ get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False),
+ tXgB,
+ tXrB,
+ pred=tXp_wb,
+ )
+
+ # Non-persistent per-CTA execution (one tile in M)
+ self._init_cluster(tidx, mbar_ptr)
+
+ mX_i, mRes_i, mO_i, mResO_i = [
+ qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) if t is not None else None
+ for t in (mX, mRes, mO, mResO)
+ ]
+ mX_i, mRes_i, mO_i, mResO_i = [
+ qutils.domain_offset_i64((0, n_off), t) if t is not None else None
+ for t in (mX_i, mRes_i, mO_i, mResO_i)
+ ]
+ gX_i = cute.local_tile(mX_i, tiler_mn, (0, 0))
+ gO_i = cute.local_tile(mO_i, tiler_mn, (0, 0))
+ gRes_i = (
+ cute.local_tile(mRes_i, tiler_mn, (0, 0)) if const_expr(mRes is not None) else None
+ )
+ gResO_i = (
+ cute.local_tile(mResO_i, tiler_mn, (0, 0)) if const_expr(mResO is not None) else None
+ )
+ gRstd_i = (
+ cute.local_tile(mRstd, tiler_mn, (bidx, 0)) if const_expr(mRstd is not None) else None
+ )
+ cX_i = cute.local_tile(idX, tiler_mn, (bidx, 0))
+
+ # Common identity/row index partitions reused by both default and K-loop paths
+ tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None]
+ row_i = tXcX_i[0][0]
+ tXgRstd_i = thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None
+
+ # Stage-2 intra-row K-loop cp.async ping-pong (two tiles). This reduces
+ # per-thread fragment size and can improve memory-latency hiding for
+ # N=7168 at large M. It is enabled by setting `stage=2` when constructing
+ # the RMSNormSM100 op (see `_fused_add_rmsnorm_forward_ptr_inplace`).
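+        # Tiling arithmetic (illustrative, assuming vecsize = 8 for bf16 and
+        # threads_per_row = 128): tile_factor = 4096 // (8 * 128) = 4, so
+        # tile_n = 4096 and num_tiles = ceil(7168 / 4096) = 2 ping-pong tiles.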
+ if const_expr(
+ self.stage > 1 and not self.direct_gmem and use_async and cluster_n == 1 and shape[1] == 7168
+ ):
+ vecsize = tv_layout.shape[1][0]
+ tpr = threads_per_row
+ target_tile_n = const_expr(4096)
+ tile_factor = const_expr(target_tile_n // (vecsize * tpr))
+ if const_expr(tile_factor > 0):
+ tile_n = vecsize * tpr * tile_factor
+ num_tiles = cute.ceil_div(shape[1], tile_n)
+
+ tiler_mn_tile = (tiler_mn[0], tile_n)
+ sX0_tile = cute.local_tile(sX0, tiler_mn_tile, (0, 0))
+ sX1_tile = cute.local_tile(sX1, tiler_mn_tile, (0, 0))
+ sRes0_tile = (
+ cute.local_tile(sRes0, tiler_mn_tile, (0, 0))
+ if const_expr(mRes is not None)
+ else None
+ )
+ sRes1_tile = (
+ cute.local_tile(sRes1, tiler_mn_tile, (0, 0))
+ if const_expr(mRes is not None)
+ else None
+ )
+
+ tv_layout_tile = cute.make_layout(
+ ((tpr, tiler_mn[0]), (vecsize, tile_factor)),
+ stride=(
+ (vecsize * tiler_mn[0], 1),
+ (tiler_mn[0], tiler_mn[0] * vecsize * tpr),
+ ),
+ )
+ thr_copy_tile = cute.make_tiled_copy(copy_atom, tv_layout_tile, tiler_mn_tile).get_slice(
+ tidx
+ )
+
+ # Accumulate per-thread partial sums across tiles; reduce once.
+ sum_sq_thread = cute.Float32(0.0)
+
+ # Preload tile 0 into sX0/sRes0.
+ k_off0 = const_expr(0) * tile_n
+ gX_0 = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off0), mX_i), tiler_mn_tile, (0, 0)
+ )
+ tXgX_0 = thr_copy_tile.partition_S(gX_0)
+ tXsX_0 = thr_copy_tile.partition_D(sX0_tile)
+ cX_0 = cute.local_tile(
+ cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, 0)
+ )
+ tXc_0 = thr_copy_tile.partition_S(cX_0)
+ tXp_0 = qutils.predicate_k(tXc_0, limit=limit_k)
+
+ tXp_ping = tXp_0
+ tXp_pong = tXp_0
+
+ if row_i < shape[0]:
+ copy_tiled(tXgX_0, tXsX_0, num_copy_elems=vecsize, is_async=True, pred=tXp_0)
+ if const_expr(mRes is not None):
+ gRes_0 = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off0), mRes_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgRes_0 = thr_copy_tile.partition_S(gRes_0)
+ tXsRes_0 = thr_copy_tile.partition_D(sRes0_tile)
+ copy_tiled(
+ tXgRes_0,
+ tXsRes_0,
+ num_copy_elems=vecsize,
+ is_async=True,
+ pred=tXp_0,
+ )
+ cute.arch.cp_async_commit_group()
+
+ for t in cutlass.range_constexpr(num_tiles):
+ next_t = t + 1
+ if next_t < num_tiles:
+ k_off_n = next_t * tile_n
+ gX_n = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off_n), mX_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgX_n = thr_copy_tile.partition_S(gX_n)
+ cX_n = cute.local_tile(
+ cute.domain_offset((0, k_off_n), cX_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXc_n = thr_copy_tile.partition_S(cX_n)
+ tXp_n = qutils.predicate_k(tXc_n, limit=limit_k)
+
+ if const_expr((t % 2) == 0):
+ tXsX_n = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_n = (
+ thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ )
+ tXp_pong = tXp_n
+ else:
+ tXsX_n = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_n = (
+ thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ )
+ tXp_ping = tXp_n
+
+ if row_i < shape[0]:
+ copy_tiled(
+ tXgX_n, tXsX_n, num_copy_elems=vecsize, is_async=True, pred=tXp_n
+ )
+ if const_expr(mRes is not None):
+ gRes_n = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off_n), mRes_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgRes_n = thr_copy_tile.partition_S(gRes_n)
+ copy_tiled(
+ tXgRes_n,
+ tXsRes_n,
+ num_copy_elems=vecsize,
+ is_async=True,
+ pred=tXp_n,
+ )
+ cute.arch.cp_async_commit_group()
+
+ cute.arch.cp_async_wait_group(1 if next_t < num_tiles else 0)
+
+ # Current tile buffer (ping/pong).
+ if const_expr((t % 2) == 0):
+ tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ )
+ pred_cur = tXp_ping
+ else:
+ tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ )
+ pred_cur = tXp_pong
+
+ k_off = t * tile_n
+ gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, 0))
+ tXgX_t = thr_copy_tile.partition_S(gX_t)
+ tXrX_t = cute.make_fragment_like(tXgX_t)
+ cute.autovec_copy(tXsX_cur, tXrX_t)
+ x_t = tXrX_t.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ gRes_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, 0)
+ )
+ tXgRes_t = thr_copy_tile.partition_S(gRes_t)
+ tXrRes_t = cute.make_fragment_like(tXgRes_t)
+ cute.autovec_copy(tXsRes_cur, tXrRes_t)
+ x_t += tXrRes_t.load().to(cute.Float32)
+
+ if const_expr(mResO is not None):
+ gResO_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mResO_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
+ tXgResO_t = thr_copy_tile.partition_D(gResO_t)
+ tXrResO_t = cute.make_fragment_like(tXgResO_t)
+ tXrResO_t.store(x_t.to(tXrResO_t.element_type))
+ if row_i < shape[0]:
+ copy_tiled(
+ tXrResO_t,
+ tXgResO_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=pred_cur,
+ )
+
+ sum_sq_thread = sum_sq_thread + (x_t * x_t).reduce(
+ cute.ReductionOp.ADD,
+ init_val=0.0,
+ reduction_profile=0,
+ )
+
+ sum_sq = row_reduce(
+ sum_sq_thread,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ init_val=0.0,
+ )
+ rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True)
+
+ if const_expr(mRstd is not None):
+ if tXcX_i[0][1] == 0 and row_i < shape[0]:
+ tXgRstd_i[0] = rstd
+
+ for t in cutlass.range_constexpr(num_tiles):
+ k_off = t * tile_n
+ cX_t = cute.local_tile(cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, 0))
+ tXc_t = thr_copy_tile.partition_S(cX_t)
+ tXp_t = qutils.predicate_k(tXc_t, limit=limit_k)
+
+ if const_expr((t % 2) == 0):
+ tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ )
+ else:
+ tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ )
+
+ gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, 0))
+ tXgX_t = thr_copy_tile.partition_S(gX_t)
+ tXrX_t = cute.make_fragment_like(tXgX_t)
+ cute.autovec_copy(tXsX_cur, tXrX_t)
+ x_t = tXrX_t.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ gRes_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, 0)
+ )
+ tXgRes_t = thr_copy_tile.partition_S(gRes_t)
+ tXrRes_t = cute.make_fragment_like(tXgRes_t)
+ cute.autovec_copy(tXsRes_cur, tXrRes_t)
+ x_t += tXrRes_t.load().to(cute.Float32)
+
+ y_t = x_t * rstd
+ if const_expr(mW is not None):
+ gW_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mW), tiler_mn_tile, (0, 0)
+ )
+ tWgW_t = thr_copy_tile.partition_S(gW_t)
+ tWrW_t = cute.make_fragment_like(tWgW_t)
+ copy_tiled(tWgW_t, tWrW_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+ y_t = y_t * tWrW_t.load().to(cute.Float32)
+ if const_expr(mB is not None):
+ gB_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mB), tiler_mn_tile, (0, 0)
+ )
+ tWgB_t = thr_copy_tile.partition_S(gB_t)
+ tWrB_t = cute.make_fragment_like(tWgB_t)
+ copy_tiled(tWgB_t, tWrB_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+ y_t = y_t + tWrB_t.load().to(cute.Float32)
+
+ gO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mO_i), tiler_mn_tile, (0, 0))
+ tXgO_t = thr_copy_tile.partition_D(gO_t)
+ tXrO_t = cute.make_fragment_like(tXgO_t)
+ tXrO_t.store(y_t.to(tXrO_t.element_type))
+ if row_i < shape[0]:
+ copy_tiled(tXrO_t, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+
+ return
+
+ # Single-stage path: one-row-per-CTA
+ tXgX_i = thr_copy.partition_S(gX_i)
+ tXgRes_i = thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None
+ tXgO_i = thr_copy.partition_D(gO_i)
+ tXgResO_i = thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None
+ # tXgRstd_i / tXcX_i / row_i prepared above
+ is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n)
+ tXpX_i = (
+ qutils.predicate_k(thr_copy.partition_S(cX_i), limit=limit_k) if not is_even_N_i else None
+ )
+
+ tXrX = cute.make_fragment_like(tXgX_i)
+ tXrRes = cute.make_fragment_like(tXgRes_i) if const_expr(mRes is not None) else None
+ if const_expr(self.direct_gmem):
+ if const_expr(not is_even_N_i):
+ tXrX.fill(0)
+ if const_expr(tXrRes is not None):
+ tXrRes.fill(0)
+ if row_i < shape[0]:
+ cute.copy(copy_atom, tXgX_i, tXrX, pred=tXpX_i)
+ if const_expr(tXrRes is not None):
+ cute.copy(copy_atom, tXgRes_i, tXrRes, pred=tXpX_i)
+ else:
+ # If N is not a multiple of the tile width, the predicated gmem->smem
+ # copies leave out-of-bounds lanes uninitialized. Clear the SMEM tile
+ # so masked lanes read as 0 for reduction/output.
+ if const_expr(not is_even_N_i):
+ thr_copy.partition_D(sX0).fill(0)
+ if const_expr(mRes is not None):
+ thr_copy.partition_D(sRes0).fill(0)
+
+ if row_i < shape[0]:
+ cute.copy(copy_atom, tXgX_i, thr_copy.partition_D(sX0), pred=tXpX_i)
+ if const_expr(mRes is not None):
+ cute.copy(copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i)
+ if const_expr(use_async):
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ cute.autovec_copy(thr_copy.partition_D(sX0), tXrX)
+ if const_expr(tXrRes is not None):
+ cute.autovec_copy(thr_copy.partition_D(sRes0), tXrRes)
+ x_red = tXrX.load().to(cute.Float32)
+ if const_expr(tXrRes is not None):
+ x_red += tXrRes.load().to(cute.Float32)
+
+ if const_expr(mResO is not None):
+ tXrResO = cute.make_fragment_like(tXgResO_i)
+ tXrResO.store(x_red.to(tXrResO.element_type))
+ if row_i < shape[0]:
+ cute.copy(
+ get_copy_atom_bw(tXrResO.element_type, num_copy_elems_X, is_async=False),
+ tXrResO,
+ tXgResO_i,
+ pred=tXpX_i,
+ )
+
+ sum_sq = row_reduce(
+ x_red * x_red,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ init_val=0.0,
+ )
+ rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True)
+
+ if const_expr(mRstd is not None):
+ if (
+ tXcX_i[0][1] == 0
+ and row_i < shape[0]
+ and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXgRstd_i[0] = rstd
+
+ if const_expr(not self.direct_gmem and (mRes is not None or mB is not None)):
+ # Load weight/bias after the reduction so they don't inflate register
+ # pressure during the long-scoreboard reduction phase (helping occupancy
+ # when registers are the limiting factor).
+ if const_expr(mW is not None):
+ gW = cute.local_tile(qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0))
+ tXgW = thr_copy.partition_S(gW)
+ tXrW = cute.make_fragment_like(tXgW)
+ if const_expr(not is_even_N_wb):
+ tXrW.fill(0)
+ cute.copy(
+ get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False),
+ tXgW,
+ tXrW,
+ pred=tXp_wb,
+ )
+ if const_expr(mB is not None):
+ gB = cute.local_tile(qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0))
+ tXgB = thr_copy.partition_S(gB)
+ tXrB = cute.make_fragment_like(tXgB)
+ if const_expr(not is_even_N_wb):
+ tXrB.fill(0)
+ cute.copy(
+ get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False),
+ tXgB,
+ tXrB,
+ pred=tXp_wb,
+ )
+
+ # Reuse `x_red` (x + residual, in fp32) for the output path so we don't
+ # keep both `tXrX` and `tXrRes` fragments live across the reduction.
+ y = x_red * rstd
+ if const_expr(mW is not None):
+ y = y * tXrW.load().to(cute.Float32)
+ if const_expr(mB is not None):
+ y = y + tXrB.load().to(cute.Float32)
+
+ tXrO = cute.make_fragment_like(tXgO_i)
+ tXrO.store(y.to(tXrO.element_type))
+ if row_i < shape[0]:
+ cute.copy(
+ get_copy_atom_bw(tXrO.element_type, num_copy_elems_X, is_async=False),
+ tXrO,
+ tXgO_i,
+ pred=tXpX_i,
+ )
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ cluster_n: cutlass.Constexpr[int],
+ num_warps: cutlass.Constexpr[int],
+ warps_per_row: cutlass.Constexpr[int],
+ threads_per_row: cutlass.Constexpr[int],
+ ):
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ cluster_n,
+ num_warps,
+ warps_per_row,
+ threads_per_row,
+ )
+ else:
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ ):
+ copy_bits = int(self.copy_bits)
+ tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits)
+ num_threads = self._num_threads()
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ threads_per_row = self._threads_per_row()
+ warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1)
+ cluster_n = self._cluster_n()
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(cluster_n),
+ const_expr(num_warps),
+ const_expr(warps_per_row),
+ const_expr(threads_per_row),
+ )
+
+ @cute.jit
+ def _init_cluster(self, tidx: cutlass.Int32, mbar_ptr: Optional[cute.Pointer]):
+ if const_expr(mbar_ptr is not None):
+ if tidx < self.stage:
+ cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+ cute.arch.mbarrier_init_fence()
+ cute.arch.cluster_arrive_relaxed()
+
+
+def _can_use_ptr_path(
+ x: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ residual: Optional[Tensor],
+) -> bool:
+ """Fast path precondition for the pointer-based CuTeDSL entry.
+
+ We require a row-major 2D layout where the last dimension is
+ contiguous (stride(1) == 1). The leading dimension (stride(0))
+ may be larger than N (padded-row / packed-attention layouts),
+ and is passed to the kernel as `ld`.
+ """
+ if x.stride(1) != 1:
+ return False
+ # All participating tensors are interpreted as the same element type
+ # (derived from x.dtype) in the pointer-based path. If dtypes differ,
+ # we'd read the wrong bit patterns and silently produce incorrect output.
+ if residual is not None and residual.dtype != x.dtype:
+ return False
+ if weight is not None and weight.dtype != x.dtype:
+ return False
+ if bias is not None and bias.dtype != x.dtype:
+ return False
+ # The kernel assumes `ld` satisfies a divisibility constraint used by
+ # cute.assume(..., divby=...) for vectorization.
+ elem_bits = TORCH2CUTE_DTYPE[x.dtype].width
+ divby = 256 // elem_bits
+ if (x.stride(0) % divby) != 0:
+ return False
+ # The kernel uses 128-bit vectorized copies (16B). Require at least 16B
+ # alignment on all participating tensors to avoid misaligned global loads.
+ if (x.data_ptr() % 16) != 0:
+ return False
+ if residual is not None and residual.stride(1) != 1:
+ return False
+ if residual is not None and residual.stride(0) != x.stride(0):
+ return False
+ if residual is not None and (residual.data_ptr() % 16) != 0:
+ return False
+ if weight is not None and not weight.is_contiguous():
+ return False
+ if bias is not None and not bias.is_contiguous():
+ return False
+ if weight is not None and (weight.data_ptr() % 16) != 0:
+ return False
+ if bias is not None and (bias.data_ptr() % 16) != 0:
+ return False
+ return True
+
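+# Example of a layout accepted by this predicate (illustrative sketch; the
+# shapes below are arbitrary and the alignment assumption holds for typical
+# 16B-aligned CUDA allocations):
+#
+#   M, N, ld = 4, 7168, 7680                      # padded rows: ld > N
+#   base = torch.empty(M * ld, device="cuda", dtype=torch.bfloat16)
+#   x = base.as_strided((M, N), (ld, 1))
+#   w = torch.randn(N, device="cuda", dtype=torch.bfloat16)
+#   assert _can_use_ptr_path(x, w, None, None)
+#
+# Here divby = 256 // 16 = 16 and ld % 16 == 0, so the vectorization
+# assumption on the leading-dimension stride is satisfied.
+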
+
+def _can_use_ptr_path_fused_add_inplace(
+ x: Tensor,
+ weight: Tensor,
+ residual: Tensor,
+) -> bool:
+ """Fast-path precondition for fused_add_rmsnorm_forward_inplace.
+
+ We allow the common vLLM layout where:
+ - `x` is strided/padded row-major (stride(1) == 1, stride(0) >= N)
+ - `residual` is contiguous row-major (stride(0) == N)
+ """
+ if x.stride(1) != 1:
+ return False
+ if residual.dtype != x.dtype:
+ return False
+ if weight.dtype != x.dtype:
+ return False
+ if residual.stride(1) != 1:
+ return False
+ if not residual.is_contiguous():
+ return False
+ if not weight.is_contiguous():
+ return False
+
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ divby = 256 // dtype.width
+ if (x.stride(0) % divby) != 0:
+ return False
+ if (residual.stride(0) % divby) != 0:
+ return False
+
+ if (x.data_ptr() % 16) != 0:
+ return False
+ if (residual.data_ptr() % 16) != 0:
+ return False
+ if (weight.data_ptr() % 16) != 0:
+ return False
+ return True
+
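+# The fused-add in-place contract differs from the generic one above mainly in
+# requiring a fully contiguous residual. Illustrative sketch (arbitrary shapes):
+#
+#   M, N, ld = 4, 7168, 7680
+#   base = torch.empty(M * ld, device="cuda", dtype=torch.bfloat16)
+#   x = base.as_strided((M, N), (ld, 1))                            # padded rows
+#   res = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)    # contiguous
+#   w = torch.randn(N, device="cuda", dtype=torch.bfloat16)
+#   assert _can_use_ptr_path_fused_add_inplace(x, w, res)
+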
+
+def _rmsnorm_forward_ptr(
+ x: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ residual: Optional[Tensor],
+ eps: float,
+ store_rstd: bool,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ """Pointer-based RMSNorm forward that bypasses DLPack entirely.
+
+ This path reconstructs cute.Tensor views from raw device pointers
+ and explicit layouts inside the JIT graph, avoiding any runtime
+ DLPack conversions while reusing the tuned RMSNormSM100 schedule.
+ """
+ assert x.is_cuda
+ assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand."
+ M, N = x.shape
+
+    # Preserve the input's 2D stride so downstream users that rely on
+    # padded-row layouts (stride(0) > N) continue to see the expected layout.
+ out = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
+ residual_out: Optional[Tensor] = None
+ rstd: Optional[Tensor] = None
+
+ if residual is not None:
+ residual_out = torch.empty_strided(
+ residual.shape, residual.stride(), device=residual.device, dtype=residual.dtype
+ )
+ if store_rstd:
+ rstd = torch.empty(M, device=x.device, dtype=torch.float32)
+
+ _rmsnorm_forward_ptr_into(
+ x=x,
+ weight=weight,
+ bias=bias,
+ residual=residual,
+ out=out,
+ residual_out=residual_out,
+ rstd=rstd,
+ eps=eps,
+ )
+ return out, rstd, residual_out
+
+
+def _rmsnorm_forward_ptr_into(
+ x: Tensor,
+ weight: Optional[Tensor],
+ bias: Optional[Tensor],
+ residual: Optional[Tensor],
+ out: Tensor,
+ residual_out: Optional[Tensor],
+ rstd: Optional[Tensor],
+ eps: float,
+) -> None:
+ """Internal helper that launches the pointer-based kernel into preallocated outputs.
+
+ This enables integration into frameworks like vLLM that manage their
+ own buffers and prefer in-place or out-parameter semantics.
+ """
+ assert x.is_cuda
+ assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand."
+ M, N = x.shape
+ device_index = x.get_device()
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+
+ if bias is None and residual is None and residual_out is None and rstd is None:
+ # Fast-launch path: cache packed args and update pointers/scalars in-place to
+ # avoid Python-side argument marshalling overhead that dominates small-batch cases.
+ #
+ # If fast-launch is disabled (or CuTeDSL internals changed), we fall back
+ # to calling the compiled function directly.
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ has_weight = weight is not None
+
+ stage = 1
+ compiled_key = (
+ "ptr",
+ N,
+ dtype,
+ False, # residual
+ has_weight,
+ False, # bias
+ False, # residual_out
+ False, # rstd
+ stage,
+ device_index,
+ )
+ compiled = _PTR_COMPILE_CACHE.get(compiled_key)
+ if compiled is None:
+ op = RMSNormSM100(N, dtype, stage=stage)
+ ld_val = int(x.stride(0))
+ ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_w = (
+ rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ if has_weight
+ else None
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld = Int32(ld_val)
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_x,
+ ptr_w,
+ None, # ptr_b
+ None, # ptr_res
+ ptr_out,
+ None, # ptr_res_out
+ None, # ptr_rstd
+ Int32(M),
+ Int32(N),
+ ld,
+ stream,
+ Float32(eps),
+ )
+ _PTR_COMPILE_CACHE[compiled_key] = compiled
+
+ launcher = _get_fast_ptr_rmsnorm_launcher(
+ compiled=compiled,
+ dtype=dtype,
+ N=N,
+ device_index=device_index,
+ stream_handle=stream_handle,
+ has_weight=has_weight,
+ eps=eps,
+ )
+ ld_val = int(x.stride(0))
+ if launcher is not None:
+ launcher.launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld_val, eps=eps)
+ return
+
+ ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_w = (
+ rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ if has_weight
+ else None
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld = Int32(ld_val)
+ compiled(
+ ptr_x,
+ ptr_w,
+ None, # ptr_b
+ None, # ptr_res
+ ptr_out,
+ None, # ptr_res_out
+ None, # ptr_rstd
+ Int32(M),
+ Int32(N),
+ ld,
+ stream,
+ Float32(eps),
+ )
+ return
+
+ # Fallback: general path (supports bias/residual/rstd, but is slower to launch).
+ stage = 1
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ key = (
+ "ptr",
+ N,
+ dtype,
+ residual is not None,
+ weight is not None,
+ bias is not None,
+ residual_out is not None,
+ rstd is not None,
+ stage,
+ device_index,
+ )
+ compiled = _PTR_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = RMSNormSM100(N, dtype, stage=stage)
+ ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_res = (
+ rt.make_ptr(dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ if residual is not None
+ else None
+ )
+ ptr_res_out = (
+ rt.make_ptr(
+ dtype, residual_out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ if residual_out is not None
+ else None
+ )
+ ptr_w = (
+ rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ if weight is not None
+ else None
+ )
+ ptr_b = (
+ rt.make_ptr(dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ if bias is not None
+ else None
+ )
+ ptr_rstd = (
+ rt.make_ptr(
+ cutlass.Float32, rstd.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=4
+ )
+ if rstd is not None
+ else None
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld = Int32(int(x.stride(0)))
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_x,
+ ptr_w,
+ ptr_b,
+ ptr_res,
+ ptr_out,
+ ptr_res_out,
+ ptr_rstd,
+ Int32(M),
+ Int32(N),
+ ld,
+ stream,
+ Float32(eps),
+ )
+ _PTR_COMPILE_CACHE[key] = compiled
+ ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_res = (
+ rt.make_ptr(dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ if residual is not None
+ else None
+ )
+ ptr_res_out = (
+ rt.make_ptr(dtype, residual_out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ if residual_out is not None
+ else None
+ )
+ ptr_w = (
+ rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ if weight is not None
+ else None
+ )
+ ptr_b = (
+ rt.make_ptr(dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ if bias is not None
+ else None
+ )
+ ptr_rstd = (
+ rt.make_ptr(cutlass.Float32, rstd.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=4)
+ if rstd is not None
+ else None
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld = Int32(int(x.stride(0)))
+ compiled(
+ ptr_x,
+ ptr_w,
+ ptr_b,
+ ptr_res,
+ ptr_out,
+ ptr_res_out,
+ ptr_rstd,
+ Int32(M),
+ Int32(N),
+ ld,
+ stream,
+ Float32(eps),
+ )
+
+
+def _fused_add_rmsnorm_forward_ptr_inplace(
+ x: Tensor,
+ residual: Tensor,
+ weight: Tensor,
+ eps: float,
+) -> None:
+ """Pointer-based fused_add_rmsnorm that updates `x` and `residual` in-place."""
+ assert x.is_cuda
+ assert x.dim() == 2
+ assert residual.is_cuda
+ assert residual.dim() == 2
+ assert x.shape == residual.shape
+
+ M, N = x.shape
+ device_index = x.get_device()
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ stage = 1
+
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+
+ # Latency-optimized schedule for small-M cases: avoid the gmem->smem
+ # staging path (large dynamic smem + extra barriers) and load directly
+ # from gmem into registers.
+ copy_bits = 128
+ # Use a direct-GMEM schedule (no staging SMEM tiles) for DSv3 hidden size
+ # (7168, bf16/fp16). This improves both:
+ # - small-M latency (fewer barriers + less dynamic shared memory), and
+ # - large-M bandwidth (lower overhead, better vectorization when 32B-aligned).
+ #
+ # This is a policy decision: it is tuned for DSv3's N=7168. If you want to
+ # benchmark other models/shapes, you can override it with:
+ # - OINK_RMSNORM_DIRECT_GMEM=0 (force staging/cp.async path)
+ # - OINK_RMSNORM_DIRECT_GMEM=1 (force direct-gmem path)
+ direct_gmem = _direct_gmem_from_policy(default=bool(dtype.width == 16 and N == 7168))
+ use_async = not direct_gmem
+ tpr_override: Optional[int] = None
+ nt_override: Optional[int] = None
+ cluster_n_override: Optional[int] = None
+ direct_gmem_max_copy_bits: Optional[int] = None
+
+ # Experimental stage-2 cp.async path (2-tile ping-pong) for N=7168. This is
+ # primarily about improving memory-latency hiding / reducing long-scoreboard
+ # stalls for large-M workloads.
+ if _ENABLE_STAGE2 and dtype.width == 16 and N == 7168 and M >= 4096:
+ stage = 2
+ direct_gmem = False
+ use_async = True
+
+ # Experimental ILP variant (clusters): split each row across 2 CTAs.
+ #
+ # NOTE: This is currently opt-in because some CuTeDSL builds exhibit
+ # instability with cluster launches for this specific schedule. To reduce
+ # the chance of accidental crashes, we require an additional explicit
+ # opt-in via `OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1`.
+ if _ENABLE_CLUSTER_ILP and not _ENABLE_STAGE2:
+ if dtype.width == 16 and N == 7168 and M >= 4096:
+ cluster_n_override = 2
+ if direct_gmem:
+ # Cluster launches + direct-GMEM has exhibited reproducible compiler
+ # instability (segfaults) in some CuTeDSL builds, especially for the
+ # 256b vector path. Probe it out-of-process once so we can safely
+ # select a working copy width (or fall back to the staged SMEM path)
+ # instead of crashing the parent process.
+ max_bits = _probe_cluster_direct_gmem_max_copy_bits()
+ if max_bits == 0:
+ direct_gmem = False
+ use_async = True
+ else:
+ direct_gmem_max_copy_bits = max_bits
+
+ # Experimental per-row partitioning: use 256 threads/row for N=7168 to
+ # increase concurrency/ILP (accepts a small tail-predicate region).
+ if _ENABLE_TPR256 and cluster_n_override is None and not _ENABLE_STAGE2:
+ if dtype.width == 16 and N == 7168 and M >= 4096:
+ tpr_override = 256
+ nt_override = 256
+
+ can_use_256 = bool(
+ direct_gmem
+ and (
+ direct_gmem_max_copy_bits is None
+ or direct_gmem_max_copy_bits >= 256
+ )
+ and dtype.width == 16
+ and (x.data_ptr() % 32) == 0
+ and (residual.data_ptr() % 32) == 0
+ and (weight.data_ptr() % 32) == 0
+ )
+ assumed_align = 32 if can_use_256 else 16
+ if can_use_256:
+ copy_bits = 256
+
+ copy_bits = _copy_bits_from_policy(default=copy_bits, can_use_256=can_use_256)
+ if copy_bits == 128:
+ assumed_align = 16
+ elif copy_bits == 256 and can_use_256:
+ assumed_align = 32
+ else:
+ copy_bits = 128
+ assumed_align = 16
+
+ key = (
+ "ptr_fused_add_inplace",
+ N,
+ dtype,
+ stage,
+ device_index,
+ copy_bits,
+ use_async,
+ tpr_override,
+ nt_override,
+ direct_gmem,
+ cluster_n_override,
+ )
+ compiled = _PTR_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = RMSNormSM100(
+ N,
+ dtype,
+ stage=stage,
+ copy_bits=copy_bits,
+ use_async=use_async,
+ direct_gmem=direct_gmem,
+ )
+ if tpr_override is not None:
+ op._tpr_override = tpr_override # type: ignore[attr-defined]
+ if nt_override is not None:
+ op._nt_override = nt_override # type: ignore[attr-defined]
+ if cluster_n_override is not None:
+ op._cluster_n_override = cluster_n_override # type: ignore[attr-defined]
+ ptr_x = rt.make_ptr(
+ dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_res = rt.make_ptr(
+ dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_w = rt.make_ptr(
+ dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld_x = Int32(int(x.stride(0)))
+ compiled = cute.compile(
+ op.launch_from_ptrs_fused_add_inplace,
+ ptr_x,
+ ptr_w,
+ ptr_res,
+ Int32(M),
+ Int32(N),
+ ld_x,
+ stream,
+ Float32(eps),
+ )
+ _PTR_COMPILE_CACHE[key] = compiled
+ launcher = _get_fast_ptr_fused_add_rmsnorm_launcher(
+ compiled=compiled,
+ dtype=dtype,
+ N=N,
+ device_index=device_index,
+ stream_handle=stream_handle,
+ copy_bits=copy_bits,
+ use_async=use_async,
+ tpr=tpr_override or 0,
+ direct_gmem=direct_gmem,
+ assumed_align=assumed_align,
+ eps=eps,
+ )
+ if launcher is not None:
+ launcher.launch(
+ x=x,
+ weight=weight,
+ residual=residual,
+ M=M,
+ N=N,
+ ld_x=int(x.stride(0)),
+ eps=eps,
+ )
+ return
+
+ # Fast-launch is disabled/unavailable (or CuTeDSL internals changed). Fall back
+ # to calling the compiled function directly.
+ ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+ ptr_res = rt.make_ptr(
+ dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_w = rt.make_ptr(
+ dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ stream = cuda.CUstream(stream_handle)
+ ld_x = Int32(int(x.stride(0)))
+ compiled(ptr_x, ptr_w, ptr_res, Int32(M), Int32(N), ld_x, stream, Float32(eps))
+
+
+# -------------------------
+# Public API (forward + verify)
+# -------------------------
+
+
+def rmsnorm_forward(
+ x: Tensor,
+ weight: Optional[Tensor] = None,
+ bias: Optional[Tensor] = None,
+ residual: Optional[Tensor] = None,
+ eps: float = 1e-6,
+ store_rstd: bool = False,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ assert x.is_cuda
+ assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand."
+ M, N = x.shape
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+
+ # For DSv3 big-M outliers on SM100, keep using the dedicated
+ # stage-2 K-loop implementation, which is already tuned and
+ # parity-checked against the reference.
+ use_stage2_big_dsv3 = bool(
+ M >= 65536 and N in (6144, 8192) and x.dtype in (torch.float16, torch.bfloat16)
+ )
+ if use_stage2_big_dsv3:
+ try:
+ import rmsnorm_with_stage2 as rms2 # type: ignore[import-not-found]
+ except Exception:
+ rms2 = None # type: ignore[assignment]
+ if rms2 is not None:
+ y, rstd, residual_out = rms2.rmsnorm_forward_with_stage2(
+ x, weight=weight, bias=bias, residual=residual, eps=eps, store_rstd=store_rstd
+ )
+ # Preserve stride contracts for torch.compile consistency, even
+ # when using the optional stage-2 implementation.
+ if y.stride() != x.stride():
+ y_strided = torch.empty_strided(
+ x.shape, x.stride(), device=x.device, dtype=x.dtype
+ )
+ y_strided.copy_(y)
+ y = y_strided
+ if residual is not None and residual_out is not None:
+ if residual_out.stride() != residual.stride():
+ residual_out_strided = torch.empty_strided(
+ residual.shape,
+ residual.stride(),
+ device=residual.device,
+ dtype=residual.dtype,
+ )
+ residual_out_strided.copy_(residual_out)
+ residual_out = residual_out_strided
+ return y, rstd, residual_out
+
+    # Default: use the pointer-based entry whenever we can represent the
+    # inputs as a row-major [M, N] view with stride(1) == 1. For the rare
+    # layouts we cannot express safely without DLPack, fall back to the
+    # PyTorch reference implementation.
+ if _can_use_ptr_path(x, weight, bias, residual):
+ return _rmsnorm_forward_ptr(x, weight, bias, residual, eps, store_rstd)
+
+ # Safe fallback (correctness-first). This is expected to be rare in vLLM.
+ y = rmsnorm_ref(x, weight, bias, residual, eps)
+ # Preserve the input stride contract even on the fallback path so
+ # torch.compile sees a consistent output layout across all branches.
+ if y.stride() != x.stride():
+ y_strided = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
+ y_strided.copy_(y)
+ y = y_strided
+ rstd = None
+ if store_rstd:
+ xf = x.float()
+ if residual is not None:
+ xf = xf + residual.float()
+ rstd = torch.rsqrt(xf.square().mean(dim=-1) + eps).to(torch.float32)
+ residual_out = None
+ if residual is not None:
+ residual_out = (x.float() + residual.float()).to(x.dtype)
+ if residual_out.stride() != residual.stride():
+ residual_out_strided = torch.empty_strided(
+ residual.shape,
+ residual.stride(),
+ device=residual.device,
+ dtype=residual.dtype,
+ )
+ residual_out_strided.copy_(residual_out)
+ residual_out = residual_out_strided
+ return y, rstd, residual_out
+
+
+def rmsnorm_ref(
+ x: Tensor,
+ w: Optional[Tensor] = None,
+ b: Optional[Tensor] = None,
+ residual: Optional[Tensor] = None,
+ eps: float = 1e-6,
+) -> Tensor:
+ xf = x.float()
+ if residual is not None:
+ xf = xf + residual.float()
+ rstd = torch.rsqrt(xf.square().mean(dim=-1, keepdim=True) + eps)
+ y = xf * rstd
+ if w is not None:
+ y = y * w.float()
+ if b is not None:
+ y = y + b.float()
+ return y.to(x.dtype)
+
+
+def fused_add_rmsnorm_forward(
+ x: Tensor,
+ residual: Tensor,
+ weight: Tensor,
+ eps: float = 1e-6,
+) -> Tuple[Tensor, Tensor]:
+    """Fused residual-add + RMSNorm for SM100 in CuTeDSL.
+
+ This is a convenience wrapper around ``rmsnorm_forward`` that matches the
+ semantics of vLLM's ``fused_add_rms_norm``:
+
+ z = x + residual
+ y = RMSNorm(z, weight, eps)
+
+ It returns ``(y, z)`` where ``z`` has the same dtype/shape as the inputs.
+ """
+ assert x.is_cuda and residual.is_cuda
+ assert x.shape == residual.shape
+ assert x.dtype == residual.dtype
+
+ orig_shape = x.shape
+ N = orig_shape[-1]
+
+ x_2d = x.view(-1, N)
+ res_2d = residual.view(-1, N)
+
+ y_2d, _rstd, z_2d = rmsnorm_forward(
+ x_2d,
+ weight=weight,
+ bias=None,
+ residual=res_2d,
+ eps=eps,
+ store_rstd=False,
+ )
+
+ y = y_2d.view(orig_shape)
+ z = z_2d.view(orig_shape)
+ return y, z
+
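+# Usage sketch (illustrative; requires an SM100 GPU and the CuTeDSL stack):
+#
+#   x = torch.randn(32, 7168, device="cuda", dtype=torch.bfloat16)
+#   res = torch.randn_like(x)
+#   w = torch.randn(7168, device="cuda", dtype=torch.bfloat16)
+#   y, z = fused_add_rmsnorm_forward(x, res, w, eps=1e-6)
+#   torch.testing.assert_close(z, (x.float() + res.float()).to(x.dtype))
+#   torch.testing.assert_close(y, rmsnorm_ref(x, w, residual=res), rtol=2e-2, atol=2e-2)
+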
+
+def fused_add_rmsnorm_forward_inplace(
+ x: Tensor,
+ residual: Tensor,
+ weight: Tensor,
+ eps: float = 1e-6,
+) -> Tuple[Tensor, Tensor]:
+ """In-place fused residual-add + RMSNorm matching vLLM semantics.
+
+ This variant writes:
+
+ z = x + residual (stored into ``residual``)
+ y = RMSNorm(z, w) (stored into ``x``)
+
+ i.e., it uses ``x`` as the normalized output buffer and ``residual`` as
+ the residual-out buffer, mirroring vLLM's fused_add_rms_norm kernel.
+ """
+ fused_add_rmsnorm_inplace_(x, residual, weight, eps=eps)
+ return x, residual
+
+
+def fused_add_rmsnorm_inplace_(
+ x: Tensor,
+ residual: Tensor,
+ weight: Tensor,
+ eps: float = 1e-6,
+) -> None:
+ """In-place fused residual-add + RMSNorm matching vLLM semantics.
+
+ This is the lowest-overhead Python entrypoint (returns `None`) intended
+ for performance-critical call sites like `torch.ops.oink.fused_add_rms_norm`.
+ """
+ assert x.is_cuda and residual.is_cuda
+ assert x.shape == residual.shape
+ assert x.dtype == residual.dtype
+
+ N = x.shape[-1]
+ x_2d = x if x.dim() == 2 else x.view(-1, N)
+ res_2d = residual if residual.dim() == 2 else residual.view(-1, N)
+
+ # Fast path: vLLM-compatible layout where x may be strided/padded but
+ # residual is contiguous. This updates both tensors in-place without
+ # additional allocations.
+ if _can_use_ptr_path_fused_add_inplace(x_2d, weight, res_2d):
+ _fused_add_rmsnorm_forward_ptr_inplace(x_2d, res_2d, weight, eps)
+ return None
+
+ # Fallback: allocate via the regular fused path, then copy results into
+ # the user-provided buffers so that semantics remain identical.
+ y, z = fused_add_rmsnorm_forward(x, residual, weight, eps)
+ x.copy_(y)
+ residual.copy_(z)
+ return None
+
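+# In-place usage sketch (illustrative): both buffers are mutated, matching
+# torch.ops.oink.fused_add_rms_norm semantics.
+#
+#   x = torch.randn(8, 7168, device="cuda", dtype=torch.bfloat16)
+#   res = torch.randn_like(x)
+#   w = torch.randn(7168, device="cuda", dtype=torch.bfloat16)
+#   x_ref = rmsnorm_ref(x, w, residual=res)
+#   z_ref = (x.float() + res.float()).to(x.dtype)
+#   fused_add_rmsnorm_inplace_(x, res, w)      # mutates x and res
+#   torch.testing.assert_close(res, z_ref)
+#   torch.testing.assert_close(x, x_ref, rtol=2e-2, atol=2e-2)
+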
+
+if __name__ == "__main__":
+ # Minimal ad-hoc test (functionality only). For performance comparisons, use the benchmark harness.
+ if not torch.cuda.is_available():
+ print("CUDA not available; functional test skipped.")
+ sys.exit(0)
+ M, N = 1024, 8192
+ dtype = torch.bfloat16
+ x = torch.randn(M, N, device="cuda", dtype=dtype)
+ w = torch.randn(N, device="cuda", dtype=dtype)
+ y_ref = rmsnorm_ref(x, w)
+ y, _, _ = rmsnorm_forward(x, w)
+ torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)
+ print("RMSNormSM100 correctness check passed.")
+
+# (compile cache moved to top)
From 3003c1398f700cf13095eb40f1e5bab84de013f7 Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Tue, 6 Jan 2026 12:25:31 -0800
Subject: [PATCH 2/8] Fix oink ruff lint and add license headers
---
oink/src/kernelagent_oink/__init__.py | 14 ++++++++
.../kernelagent_oink/blackwell/__init__.py | 14 ++++++++
.../kernelagent_oink/blackwell/lite_quack.py | 20 ++++++++---
.../blackwell/oink_custom_ops.py | 16 ++++++++-
.../src/kernelagent_oink/blackwell/rmsnorm.py | 33 +++++++++++++------
5 files changed, 82 insertions(+), 15 deletions(-)
diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py
index 542e59e..d9f25d0 100644
--- a/oink/src/kernelagent_oink/__init__.py
+++ b/oink/src/kernelagent_oink/__init__.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import annotations
import logging
diff --git a/oink/src/kernelagent_oink/blackwell/__init__.py b/oink/src/kernelagent_oink/blackwell/__init__.py
index 4d21ee8..a92109a 100644
--- a/oink/src/kernelagent_oink/blackwell/__init__.py
+++ b/oink/src/kernelagent_oink/blackwell/__init__.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import annotations
__all__ = []
diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py
index 3c3f750..8c05b47 100644
--- a/oink/src/kernelagent_oink/blackwell/lite_quack.py
+++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
Lightweight local clone of the small subset of helpers that the SM100
RMSNorm CuteDSL kernels depend on.
@@ -12,9 +26,8 @@
import math
import operator
-from typing import Callable, Optional, Tuple
+from typing import Callable, Optional
-import cuda.bindings.driver as cuda # type: ignore
import torch
from torch import Tensor
@@ -23,7 +36,7 @@
from cutlass import Float32, Int32, const_expr
from cutlass.cute.runtime import from_dlpack
from cutlass.cutlass_dsl import T, dsl_user_op
-from cutlass._mlir.dialects import llvm, nvvm, vector
+from cutlass._mlir.dialects import llvm
# -------------------------
@@ -347,4 +360,3 @@ def get_sm_count(N: int, device: torch.device) -> int:
sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
)
return sm_count
-
diff --git a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
index 8225025..92423d9 100644
--- a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
+++ b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
@@ -1,4 +1,16 @@
-from __future__ import annotations
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
"""
Torch custom ops wrapping Oink's Blackwell RMSNorm kernels.
@@ -26,6 +38,8 @@
Mutates `x` and `residual` in-place and returns None.
"""
+from __future__ import annotations
+
import importlib
import threading
diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
index d6c2c20..a77938c 100644
--- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py
+++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
RMSNorm kernel for SM100 (Blackwell) in CuteDSL.
@@ -53,15 +67,15 @@
"(`cutlass`, typically provided by `nvidia-cutlass-dsl`)."
) from e
-import torch
-from torch import Tensor
+import torch # noqa: E402
+from torch import Tensor # noqa: E402
-import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
+import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python # noqa: E402
-import cutlass
-import cutlass.cute as cute
-from cutlass import Float32, Int32, const_expr
-from cutlass.cute import runtime as rt
+import cutlass # noqa: E402
+import cutlass.cute as cute # noqa: E402
+from cutlass import Float32, Int32, const_expr # noqa: E402
+from cutlass.cute import runtime as rt # noqa: E402
# Simple compile cache declared early so direct execution works
_PTR_COMPILE_CACHE = {}
@@ -862,8 +876,8 @@ def _get_fast_ptr_fused_add_rmsnorm_launcher(
#
# NOTE: Avoid `from . import ...` imports here: CuTeDSL's AST preprocessor may
# mishandle that form (module=None in the AST). Use fully-qualified imports.
-from kernelagent_oink.blackwell import lite_quack as qutils
-from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce
+from kernelagent_oink.blackwell import lite_quack as qutils # noqa: E402
+from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce # noqa: E402
# -------------------------
@@ -2458,7 +2472,6 @@ def rmsnorm_forward(
assert x.is_cuda
assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand."
M, N = x.shape
- dtype = TORCH2CUTE_DTYPE[x.dtype]
# For DSv3 big-M outliers on SM100, keep using the dedicated
# stage-2 K-loop implementation, which is already tuned and
From 1468088ecab42c675d8e4d9e0bf465bd5f68644a Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Tue, 6 Jan 2026 12:27:54 -0800
Subject: [PATCH 3/8] Format oink with ruff
---
oink/src/kernelagent_oink/__init__.py | 4 +-
.../kernelagent_oink/blackwell/lite_quack.py | 41 +-
.../blackwell/oink_custom_ops.py | 6 +-
.../src/kernelagent_oink/blackwell/rmsnorm.py | 472 ++++++++++++++----
4 files changed, 403 insertions(+), 120 deletions(-)
diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py
index d9f25d0..bbbd7c1 100644
--- a/oink/src/kernelagent_oink/__init__.py
+++ b/oink/src/kernelagent_oink/__init__.py
@@ -94,7 +94,9 @@ def register() -> None:
# Ensure CuTeDSL sees a target arch early. If the user has already set it,
# respect their choice.
- os.environ.setdefault("CUTE_DSL_ARCH", _compute_cutedsl_arch(int(major), int(minor)))
+ os.environ.setdefault(
+ "CUTE_DSL_ARCH", _compute_cutedsl_arch(int(major), int(minor))
+ )
# Import registers the ops via torch.library.custom_op decorators.
from .blackwell import oink_custom_ops # noqa: F401
diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py
index 8c05b47..14ae723 100644
--- a/oink/src/kernelagent_oink/blackwell/lite_quack.py
+++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py
@@ -54,6 +54,7 @@
# Tensor conversion helpers
# -------------------------
+
def convert_from_dlpack(
x: Tensor,
leading_dim: int,
@@ -82,7 +83,9 @@ def convert_from_dlpack(
@dsl_user_op
-def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
+def elem_pointer(
+ x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None
+) -> cute.Pointer:
return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
@@ -133,7 +136,9 @@ def store_shared_remote(
).ir_value()
if const_expr(isinstance(val, float)):
val = Float32(val)
- assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
+ assert isinstance(val, (Float32, Int32, cutlass.Int64)), (
+ "val must be Float32, Int32, or Int64"
+ )
suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
llvm.inline_asm(
@@ -155,19 +160,27 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
"""
tApA = cute.make_fragment(
cute.make_layout(
- (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
+ (
+ cute.size(tAcA, mode=[0, 1]),
+ cute.size(tAcA, mode=[1]),
+ cute.size(tAcA, mode=[2]),
+ ),
stride=(cute.size(tAcA, mode=[2]), 0, 1),
),
cutlass.Boolean,
)
for rest_v in cutlass.range_constexpr(tApA.shape[0]):
for rest_k in cutlass.range_constexpr(tApA.shape[2]):
- tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
+ tApA[rest_v, 0, rest_k] = cute.elem_less(
+ tAcA[(0, rest_v), 0, rest_k][1], limit
+ )
return tApA
@dsl_user_op
-def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+def domain_offset_i64(
+ coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
+) -> cute.Tensor:
"""
Return a tensor whose iterator is offset by an Int64 byte offset
computed from `coord` and the tensor's strides.
@@ -287,7 +300,9 @@ def block_or_cluster_reduce(
"""Dispatch between block or cluster reduction depending on mbar_ptr."""
if cutlass.const_expr(mbar_ptr is None):
return block_reduce(val, op, reduction_buffer, init_val=init_val)
- return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase)
+ return cluster_reduce(
+ val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase
+ )
@cute.jit
@@ -313,7 +328,9 @@ def row_reduce(
val = x
warp_op = {
cute.ReductionOp.ADD: operator.add,
- cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
+ cute.ReductionOp.MAX: cute.arch.fmax
+ if cutlass.const_expr(x.dtype == Float32)
+ else max,
cute.ReductionOp.MIN: min,
cute.ReductionOp.MUL: operator.mul,
}[op]
@@ -353,10 +370,16 @@ def get_sm_count(N: int, device: torch.device) -> int:
RMSNorm kernels but lives entirely in this local module.
"""
sm_count_multiple = (
- 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+ 16
+ if N <= 256
+ else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
)
sm_count = torch.cuda.get_device_properties(device).multi_processor_count
sm_count = (
- sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2
+ sm_count * sm_count_multiple
+ if N <= 8192
+ else sm_count // 2
+ if N <= 16384
+ else sm_count * 2
)
return sm_count
diff --git a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
index 92423d9..a96a4c7 100644
--- a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
+++ b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py
@@ -80,6 +80,7 @@ def _get_sm(device: torch.device | None = None) -> int:
# RMSNorm (functional)
#
+
@custom_op("oink::rmsnorm", mutates_args=())
def oink_rmsnorm(
x: torch.Tensor,
@@ -158,6 +159,7 @@ def oink_rmsnorm_fake(
# Fused residual-add + RMSNorm (in-place, vLLM semantics)
#
+
@custom_op("oink::fused_add_rms_norm", mutates_args=("x", "residual"))
def oink_fused_add_rms_norm(
x: torch.Tensor,
@@ -174,7 +176,9 @@ def oink_fused_add_rms_norm(
Returns:
None (mutates `x` and `residual` in-place).
"""
- assert x.is_cuda and residual.is_cuda, "oink::fused_add_rms_norm requires CUDA tensors"
+ assert x.is_cuda and residual.is_cuda, (
+ "oink::fused_add_rms_norm requires CUDA tensors"
+ )
assert x.shape == residual.shape, "x and residual must have the same shape"
assert x.dtype == residual.dtype, "x and residual must have the same dtype"
assert weight.dim() == 1, "weight must be 1D [N]"
diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
index a77938c..c7fc1b3 100644
--- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py
+++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
@@ -84,6 +84,7 @@
# pointer/scalar storage so concurrent callers don't race on in-place updates.
_PTR_FAST_LAUNCH_TLS = threading.local()
+
def _env_flag(name: str, default: bool) -> bool:
val = os.environ.get(name)
if val is None:
@@ -100,10 +101,16 @@ def _env_flag(name: str, default: bool) -> bool:
# Fused-add RMSNorm schedule knobs (read once at import time; set env vars before
# importing this module if you want to override).
-_DIRECT_GMEM_POLICY = (os.environ.get("OINK_RMSNORM_DIRECT_GMEM", "auto").strip().lower() or "auto")
-_COPY_BITS_POLICY = (os.environ.get("OINK_RMSNORM_COPY_BITS", "auto").strip().lower() or "auto")
+_DIRECT_GMEM_POLICY = (
+ os.environ.get("OINK_RMSNORM_DIRECT_GMEM", "auto").strip().lower() or "auto"
+)
+_COPY_BITS_POLICY = (
+ os.environ.get("OINK_RMSNORM_COPY_BITS", "auto").strip().lower() or "auto"
+)
_ENABLE_CLUSTER_ILP = _env_flag("OINK_RMSNORM_ENABLE_CLUSTER_ILP", default=False)
-_ENABLE_CLUSTER_ILP_UNSAFE = _env_flag("OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE", default=False)
+_ENABLE_CLUSTER_ILP_UNSAFE = _env_flag(
+ "OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE", default=False
+)
_ENABLE_TPR256 = _env_flag("OINK_RMSNORM_ENABLE_TPR256", default=False)
_ENABLE_STAGE2 = _env_flag("OINK_RMSNORM_ENABLE_STAGE2", default=False)
@@ -252,6 +259,7 @@ def run_probe(copy_bits: int, assumed_align: int):
_CLUSTER_DIRECT_GMEM_MAX_COPY_BITS = max_bits
return max_bits
+
def _parse_version_tuple(version: str) -> Tuple[int, int, int]:
parts = version.split(".")
nums: list[int] = []
@@ -275,7 +283,9 @@ def _cutlass_dsl_version() -> Optional[Tuple[int, int, int]]:
# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the
# older signature for 4.3.2, but switch to a 4.3.4-compatible signature when we
# detect 4.3.4+ (or when version detection is unavailable).
-_KERNEL_ACCEPTS_LAYOUT_ARGS = _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4)
+_KERNEL_ACCEPTS_LAYOUT_ARGS = (
+ _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4)
+)
if _ENABLE_CLUSTER_ILP and not _ENABLE_CLUSTER_ILP_UNSAFE:
# We have observed reproducible segfaults in some CuTeDSL builds when using
@@ -427,7 +437,9 @@ def launch(
self._last_x_ptr = x_ptr
except AttributeError:
self._disable_fast_launch()
- self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps)
+ self._fallback_launch(
+ x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps
+ )
return
if self._ptr_w is not None:
@@ -438,7 +450,9 @@ def launch(
self._last_w_ptr = w_ptr
except AttributeError:
self._disable_fast_launch()
- self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps)
+ self._fallback_launch(
+ x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps
+ )
return
out_ptr = out.data_ptr()
@@ -448,7 +462,9 @@ def launch(
self._last_out_ptr = out_ptr
except AttributeError:
self._disable_fast_launch()
- self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps)
+ self._fallback_launch(
+ x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps
+ )
return
if M != self._last_m:
@@ -492,10 +508,19 @@ def _fallback_launch(
# If the packed-args or runtime pointer mutation path stops working
# (e.g. due to a CuTeDSL upgrade), fall back to the regular call path.
dtype = TORCH2CUTE_DTYPE[x.dtype]
- ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
- ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_x = rt.make_ptr(
+ dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_out = rt.make_ptr(
+ dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
ptr_w = (
- rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ rt.make_ptr(
+ dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
if weight is not None
else None
)
@@ -695,7 +720,15 @@ def _get_fast_ptr_rmsnorm_launcher(
return None
# Keyed by the compiled object identity so schedule changes (e.g. copy width,
# async/staged variants, etc.) never alias in the fast-launch cache.
- key = ("ptr_fast", id(compiled), N, dtype, device_index, int(stream_handle), has_weight)
+ key = (
+ "ptr_fast",
+ id(compiled),
+ N,
+ dtype,
+ device_index,
+ int(stream_handle),
+ has_weight,
+ )
cache = _tls_fast_launch_cache()
cached = cache.get(key)
if cached is not None:
@@ -705,7 +738,9 @@ def _get_fast_ptr_rmsnorm_launcher(
ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16)
ptr_out = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16)
ptr_w = (
- rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) if has_weight else None
+ rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ if has_weight
+ else None
)
arg_m = _StableI32Arg(0)
@@ -808,9 +843,15 @@ def _get_fast_ptr_fused_add_rmsnorm_launcher(
if cached is not None:
return cached # type: ignore[return-value]
- ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
- ptr_res = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
- ptr_w = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+ ptr_x = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_res = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
+ ptr_w = rt.make_ptr(
+ dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
arg_m = _StableI32Arg(0)
arg_n = _StableI32Arg(N)
@@ -884,6 +925,7 @@ def _get_fast_ptr_fused_add_rmsnorm_launcher(
# Copy helpers (allow up to 256b)
# -------------------------
+
@cute.jit
def get_copy_atom_bw(
dtype: type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False
@@ -892,6 +934,7 @@ def get_copy_atom_bw(
max_bits = const_expr(128 if is_async else 256)
num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width))
from cutlass.cute.nvgpu import cpasync
+
# Prefer GLOBAL cache policy for bulk streaming reads at large M
copy_op = (
cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL)
@@ -1037,7 +1080,10 @@ def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout]
tiler_mn = (cols_per_block, vecsize * num_blocks_N * tpr)
tv_layout = cute.make_layout(
((tpr, cols_per_block), (vecsize, num_blocks_N)),
- stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * tpr)),
+ stride=(
+ (vecsize * cols_per_block, 1),
+ (cols_per_block, cols_per_block * vecsize * tpr),
+ ),
)
return tiler_mn, tv_layout
@@ -1045,7 +1091,10 @@ def _smem_bytes(self, tiler_mn, num_warps) -> int:
# smem for X tile (+ residual if present) + reduction buffers + mbar(s)
return (
cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
- + self.stage * num_warps * self._cluster_n() * (self.reduction_dtype.width // 8)
+ + self.stage
+ * num_warps
+ * self._cluster_n()
+ * (self.reduction_dtype.width // 8)
+ self.stage * (cutlass.Int64.width // 8)
)
@@ -1072,7 +1121,9 @@ def new_stride(t):
)
mX, mRes, mO, mResO = [
- cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ cute.make_tensor(
+ t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))
+ )
if const_expr(t is not None)
else None
for t in (mX, mRes, mO, mResO)
@@ -1082,23 +1133,34 @@ def new_stride(t):
copy_bits = int(self.copy_bits)
tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits)
- num_threads = cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._num_threads()
+ num_threads = (
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._num_threads()
+ )
num_warps = num_threads // cute.arch.WARP_SIZE
- threads_per_row = tv_layout.shape[0][0] if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._threads_per_row()
+ threads_per_row = (
+ tv_layout.shape[0][0]
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._threads_per_row()
+ )
warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1)
cluster_n = self._cluster_n()
if const_expr(mW is not None):
mW = cute.make_tensor(
- mW.iterator, cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ mW.iterator,
+ cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
)
if const_expr(mB is not None):
mB = cute.make_tensor(
- mB.iterator, cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ mB.iterator,
+ cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
)
if const_expr(mRstd is not None):
mRstd = cute.make_tensor(
- mRstd.iterator, cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
+ mRstd.iterator,
+ cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))),
)
# No SMEM reload mode switch; overlap is controlled in the K-loop path
@@ -1114,11 +1176,14 @@ def new_stride(t):
else 0
)
tile_bytes_res = (
- cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) * stage_bufs
+ cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn))
+ * stage_bufs
if const_expr(mRes is not None and not self.direct_gmem)
else 0
)
- red_bytes = self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8)
+ red_bytes = (
+ self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8)
+ )
# mbarriers are only allocated/used for cluster_n>1. Some CuTeDSL builds
# require mbarrier state to be 16B-aligned in shared memory; account for
# the alignment padding when computing dynamic smem bytes.
@@ -1211,14 +1276,10 @@ def launch_from_ptrs(
else None
)
mW = (
- cute.make_tensor(ptr_w, layout_n)
- if const_expr(ptr_w is not None)
- else None
+ cute.make_tensor(ptr_w, layout_n) if const_expr(ptr_w is not None) else None
)
mB = (
- cute.make_tensor(ptr_b, layout_n)
- if const_expr(ptr_b is not None)
- else None
+ cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None
)
mRstd = (
cute.make_tensor(ptr_rstd, layout_m)
@@ -1303,7 +1364,9 @@ def _kernel_impl(
# Allocate one or two SMEM buffers depending on stage depth
sX0 = (
smem.allocate_tensor(
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
)
if const_expr(not self.direct_gmem)
else None
@@ -1319,7 +1382,9 @@ def _kernel_impl(
)
sRes0 = (
smem.allocate_tensor(
- mRes.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32
+ mRes.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
)
if const_expr(mRes is not None and not self.direct_gmem)
else None
@@ -1339,7 +1404,9 @@ def _kernel_impl(
(num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
order=(1, 0, 2),
)
- reduction_buffer = smem.allocate_tensor(self.reduction_dtype, red_layout, byte_alignment=4)
+ reduction_buffer = smem.allocate_tensor(
+ self.reduction_dtype, red_layout, byte_alignment=4
+ )
if const_expr(cluster_n > 1):
# Some CuTeDSL builds appear sensitive to the shared-memory alignment of
# mbarrier state. `SmemAllocator.allocate_array` does not currently
@@ -1360,8 +1427,12 @@ def _kernel_impl(
# Tiled copy setup
num_copy_elems_X = tv_layout.shape[1][0]
- use_async = const_expr(self.use_async and self.N >= 1024 and not self.direct_gmem)
- copy_atom = get_copy_atom_bw(mX.element_type, num_copy_elems_X, is_async=use_async)
+ use_async = const_expr(
+ self.use_async and self.N >= 1024 and not self.direct_gmem
+ )
+ copy_atom = get_copy_atom_bw(
+ mX.element_type, num_copy_elems_X, is_async=use_async
+ )
thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx)
# Tail predicate for the N dimension (when tile width > N). Reuse this
@@ -1386,7 +1457,9 @@ def _kernel_impl(
mW is not None and (self.direct_gmem or (mRes is None and mB is None))
)
if const_expr(prefetch_w_early):
- gW = cute.local_tile(qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0))
+ gW = cute.local_tile(
+ qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0)
+ )
tXgW = thr_copy.partition_S(gW)
tXrW = cute.make_fragment_like(tXgW)
if const_expr(not is_even_N_wb):
@@ -1398,7 +1471,9 @@ def _kernel_impl(
pred=tXp_wb,
)
if const_expr(self.direct_gmem and mB is not None):
- gB = cute.local_tile(qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0))
+ gB = cute.local_tile(
+ qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0)
+ )
tXgB = thr_copy.partition_S(gB)
tXrB = cute.make_fragment_like(tXgB)
if const_expr(not is_even_N_wb):
@@ -1414,7 +1489,9 @@ def _kernel_impl(
self._init_cluster(tidx, mbar_ptr)
mX_i, mRes_i, mO_i, mResO_i = [
- qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) if t is not None else None
+ qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t)
+ if t is not None
+ else None
for t in (mX, mRes, mO, mResO)
]
mX_i, mRes_i, mO_i, mResO_i = [
@@ -1424,27 +1501,39 @@ def _kernel_impl(
gX_i = cute.local_tile(mX_i, tiler_mn, (0, 0))
gO_i = cute.local_tile(mO_i, tiler_mn, (0, 0))
gRes_i = (
- cute.local_tile(mRes_i, tiler_mn, (0, 0)) if const_expr(mRes is not None) else None
+ cute.local_tile(mRes_i, tiler_mn, (0, 0))
+ if const_expr(mRes is not None)
+ else None
)
gResO_i = (
- cute.local_tile(mResO_i, tiler_mn, (0, 0)) if const_expr(mResO is not None) else None
+ cute.local_tile(mResO_i, tiler_mn, (0, 0))
+ if const_expr(mResO is not None)
+ else None
)
gRstd_i = (
- cute.local_tile(mRstd, tiler_mn, (bidx, 0)) if const_expr(mRstd is not None) else None
+ cute.local_tile(mRstd, tiler_mn, (bidx, 0))
+ if const_expr(mRstd is not None)
+ else None
)
cX_i = cute.local_tile(idX, tiler_mn, (bidx, 0))
# Common identity/row index partitions reused by both default and K-loop paths
tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None]
row_i = tXcX_i[0][0]
- tXgRstd_i = thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None
+ tXgRstd_i = (
+ thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None
+ )
# Stage-2 intra-row K-loop cp.async ping-pong (two tiles). This reduces
# per-thread fragment size and can improve memory-latency hiding for
# N=7168 at large M. It is enabled by setting `stage=2` when constructing
# the RMSNormSM100 op (see `_fused_add_rmsnorm_forward_ptr_inplace`).
if const_expr(
- self.stage > 1 and not self.direct_gmem and use_async and cluster_n == 1 and shape[1] == 7168
+ self.stage > 1
+ and not self.direct_gmem
+ and use_async
+ and cluster_n == 1
+ and shape[1] == 7168
):
vecsize = tv_layout.shape[1][0]
tpr = threads_per_row
@@ -1475,9 +1564,9 @@ def _kernel_impl(
(tiler_mn[0], tiler_mn[0] * vecsize * tpr),
),
)
- thr_copy_tile = cute.make_tiled_copy(copy_atom, tv_layout_tile, tiler_mn_tile).get_slice(
- tidx
- )
+ thr_copy_tile = cute.make_tiled_copy(
+ copy_atom, tv_layout_tile, tiler_mn_tile
+ ).get_slice(tidx)
# Accumulate per-thread partial sums across tiles; reduce once.
sum_sq_thread = cute.Float32(0.0)
@@ -1499,7 +1588,13 @@ def _kernel_impl(
tXp_pong = tXp_0
if row_i < shape[0]:
- copy_tiled(tXgX_0, tXsX_0, num_copy_elems=vecsize, is_async=True, pred=tXp_0)
+ copy_tiled(
+ tXgX_0,
+ tXsX_0,
+ num_copy_elems=vecsize,
+ is_async=True,
+ pred=tXp_0,
+ )
if const_expr(mRes is not None):
gRes_0 = cute.local_tile(
qutils.domain_offset_i64((0, k_off0), mRes_i),
@@ -1538,19 +1633,27 @@ def _kernel_impl(
if const_expr((t % 2) == 0):
tXsX_n = thr_copy_tile.partition_D(sX1_tile)
tXsRes_n = (
- thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
)
tXp_pong = tXp_n
else:
tXsX_n = thr_copy_tile.partition_D(sX0_tile)
tXsRes_n = (
- thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
)
tXp_ping = tXp_n
if row_i < shape[0]:
copy_tiled(
- tXgX_n, tXsX_n, num_copy_elems=vecsize, is_async=True, pred=tXp_n
+ tXgX_n,
+ tXsX_n,
+ num_copy_elems=vecsize,
+ is_async=True,
+ pred=tXp_n,
)
if const_expr(mRes is not None):
gRes_n = cute.local_tile(
@@ -1574,25 +1677,35 @@ def _kernel_impl(
if const_expr((t % 2) == 0):
tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
tXsRes_cur = (
- thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
)
pred_cur = tXp_ping
else:
tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
tXsRes_cur = (
- thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
)
pred_cur = tXp_pong
k_off = t * tile_n
- gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, 0))
+ gX_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mX_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
tXgX_t = thr_copy_tile.partition_S(gX_t)
tXrX_t = cute.make_fragment_like(tXgX_t)
cute.autovec_copy(tXsX_cur, tXrX_t)
x_t = tXrX_t.load().to(cute.Float32)
if const_expr(mRes is not None):
gRes_t = cute.local_tile(
- qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, 0)
+ qutils.domain_offset_i64((0, k_off), mRes_i),
+ tiler_mn_tile,
+ (0, 0),
)
tXgRes_t = thr_copy_tile.partition_S(gRes_t)
tXrRes_t = cute.make_fragment_like(tXgRes_t)
@@ -1639,29 +1752,41 @@ def _kernel_impl(
for t in cutlass.range_constexpr(num_tiles):
k_off = t * tile_n
- cX_t = cute.local_tile(cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, 0))
+ cX_t = cute.local_tile(
+ cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, 0)
+ )
tXc_t = thr_copy_tile.partition_S(cX_t)
tXp_t = qutils.predicate_k(tXc_t, limit=limit_k)
if const_expr((t % 2) == 0):
tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
tXsRes_cur = (
- thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
)
else:
tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
tXsRes_cur = (
- thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
)
- gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, 0))
+ gX_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mX_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
tXgX_t = thr_copy_tile.partition_S(gX_t)
tXrX_t = cute.make_fragment_like(tXgX_t)
cute.autovec_copy(tXsX_cur, tXrX_t)
x_t = tXrX_t.load().to(cute.Float32)
if const_expr(mRes is not None):
gRes_t = cute.local_tile(
- qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, 0)
+ qutils.domain_offset_i64((0, k_off), mRes_i),
+ tiler_mn_tile,
+ (0, 0),
)
tXgRes_t = thr_copy_tile.partition_S(gRes_t)
tXrRes_t = cute.make_fragment_like(tXgRes_t)
@@ -1671,43 +1796,77 @@ def _kernel_impl(
y_t = x_t * rstd
if const_expr(mW is not None):
gW_t = cute.local_tile(
- qutils.domain_offset_i64((0, k_off), mW), tiler_mn_tile, (0, 0)
+ qutils.domain_offset_i64((0, k_off), mW),
+ tiler_mn_tile,
+ (0, 0),
)
tWgW_t = thr_copy_tile.partition_S(gW_t)
tWrW_t = cute.make_fragment_like(tWgW_t)
- copy_tiled(tWgW_t, tWrW_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+ copy_tiled(
+ tWgW_t,
+ tWrW_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=tXp_t,
+ )
y_t = y_t * tWrW_t.load().to(cute.Float32)
if const_expr(mB is not None):
gB_t = cute.local_tile(
- qutils.domain_offset_i64((0, k_off), mB), tiler_mn_tile, (0, 0)
+ qutils.domain_offset_i64((0, k_off), mB),
+ tiler_mn_tile,
+ (0, 0),
)
tWgB_t = thr_copy_tile.partition_S(gB_t)
tWrB_t = cute.make_fragment_like(tWgB_t)
- copy_tiled(tWgB_t, tWrB_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+ copy_tiled(
+ tWgB_t,
+ tWrB_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=tXp_t,
+ )
y_t = y_t + tWrB_t.load().to(cute.Float32)
- gO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mO_i), tiler_mn_tile, (0, 0))
+ gO_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mO_i),
+ tiler_mn_tile,
+ (0, 0),
+ )
tXgO_t = thr_copy_tile.partition_D(gO_t)
tXrO_t = cute.make_fragment_like(tXgO_t)
tXrO_t.store(y_t.to(tXrO_t.element_type))
if row_i < shape[0]:
- copy_tiled(tXrO_t, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+ copy_tiled(
+ tXrO_t,
+ tXgO_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=tXp_t,
+ )
return
# Single-stage path: one-row-per-CTA
tXgX_i = thr_copy.partition_S(gX_i)
- tXgRes_i = thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None
+ tXgRes_i = (
+ thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None
+ )
tXgO_i = thr_copy.partition_D(gO_i)
- tXgResO_i = thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None
+ tXgResO_i = (
+ thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None
+ )
# tXgRstd_i / tXcX_i / row_i prepared above
is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n)
tXpX_i = (
- qutils.predicate_k(thr_copy.partition_S(cX_i), limit=limit_k) if not is_even_N_i else None
+ qutils.predicate_k(thr_copy.partition_S(cX_i), limit=limit_k)
+ if not is_even_N_i
+ else None
)
tXrX = cute.make_fragment_like(tXgX_i)
- tXrRes = cute.make_fragment_like(tXgRes_i) if const_expr(mRes is not None) else None
+ tXrRes = (
+ cute.make_fragment_like(tXgRes_i) if const_expr(mRes is not None) else None
+ )
if const_expr(self.direct_gmem):
if const_expr(not is_even_N_i):
tXrX.fill(0)
@@ -1729,7 +1888,9 @@ def _kernel_impl(
if row_i < shape[0]:
cute.copy(copy_atom, tXgX_i, thr_copy.partition_D(sX0), pred=tXpX_i)
if const_expr(mRes is not None):
- cute.copy(copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i)
+ cute.copy(
+ copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i
+ )
if const_expr(use_async):
cute.arch.cp_async_commit_group()
cute.arch.cp_async_wait_group(0)
@@ -1746,7 +1907,9 @@ def _kernel_impl(
tXrResO.store(x_red.to(tXrResO.element_type))
if row_i < shape[0]:
cute.copy(
- get_copy_atom_bw(tXrResO.element_type, num_copy_elems_X, is_async=False),
+ get_copy_atom_bw(
+ tXrResO.element_type, num_copy_elems_X, is_async=False
+ ),
tXrResO,
tXgResO_i,
pred=tXpX_i,
@@ -1775,7 +1938,9 @@ def _kernel_impl(
# pressure during the long-scoreboard reduction phase (helping occupancy
# when registers are the limiting factor).
if const_expr(mW is not None):
- gW = cute.local_tile(qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0))
+ gW = cute.local_tile(
+ qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0)
+ )
tXgW = thr_copy.partition_S(gW)
tXrW = cute.make_fragment_like(tXgW)
if const_expr(not is_even_N_wb):
@@ -1787,7 +1952,9 @@ def _kernel_impl(
pred=tXp_wb,
)
if const_expr(mB is not None):
- gB = cute.local_tile(qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0))
+ gB = cute.local_tile(
+ qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0)
+ )
tXgB = thr_copy.partition_S(gB)
tXrB = cute.make_fragment_like(tXgB)
if const_expr(not is_even_N_wb):
@@ -1818,6 +1985,7 @@ def _kernel_impl(
)
if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
@cute.kernel
def kernel(
self,
@@ -1853,6 +2021,7 @@ def kernel(
threads_per_row,
)
else:
+
@cute.kernel
def kernel(
self,
@@ -2015,7 +2184,10 @@ def _rmsnorm_forward_ptr(
if residual is not None:
residual_out = torch.empty_strided(
- residual.shape, residual.stride(), device=residual.device, dtype=residual.dtype
+ residual.shape,
+ residual.stride(),
+ device=residual.device,
+ dtype=residual.dtype,
)
if store_rstd:
rstd = torch.empty(M, device=x.device, dtype=torch.float32)
@@ -2082,10 +2254,19 @@ def _rmsnorm_forward_ptr_into(
if compiled is None:
op = RMSNormSM100(N, dtype, stage=stage)
ld_val = int(x.stride(0))
- ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
- ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_x = rt.make_ptr(
+ dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_out = rt.make_ptr(
+ dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
ptr_w = (
- rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ rt.make_ptr(
+ dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
if has_weight
else None
)
@@ -2122,10 +2303,19 @@ def _rmsnorm_forward_ptr_into(
launcher.launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld_val, eps=eps)
return
- ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
- ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_x = rt.make_ptr(
+ dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_out = rt.make_ptr(
+ dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
ptr_w = (
- rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ rt.make_ptr(
+ dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
if has_weight
else None
)
@@ -2167,33 +2357,55 @@ def _rmsnorm_forward_ptr_into(
compiled = _PTR_COMPILE_CACHE.get(key)
if compiled is None:
op = RMSNormSM100(N, dtype, stage=stage)
- ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
- ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_x = rt.make_ptr(
+ dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_out = rt.make_ptr(
+ dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
ptr_res = (
- rt.make_ptr(dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ rt.make_ptr(
+ dtype,
+ residual.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
if residual is not None
else None
)
ptr_res_out = (
rt.make_ptr(
- dtype, residual_out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ dtype,
+ residual_out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
)
if residual_out is not None
else None
)
ptr_w = (
- rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ rt.make_ptr(
+ dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
if weight is not None
else None
)
ptr_b = (
- rt.make_ptr(dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ rt.make_ptr(
+ dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
if bias is not None
else None
)
ptr_rstd = (
rt.make_ptr(
- cutlass.Float32, rstd.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=4
+ cutlass.Float32,
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
)
if rstd is not None
else None
@@ -2216,30 +2428,50 @@ def _rmsnorm_forward_ptr_into(
Float32(eps),
)
_PTR_COMPILE_CACHE[key] = compiled
- ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
- ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_x = rt.make_ptr(
+ dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_out = rt.make_ptr(
+ dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
ptr_res = (
- rt.make_ptr(dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ rt.make_ptr(
+ dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
if residual is not None
else None
)
ptr_res_out = (
- rt.make_ptr(dtype, residual_out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ rt.make_ptr(
+ dtype,
+ residual_out.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
if residual_out is not None
else None
)
ptr_w = (
- rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ rt.make_ptr(
+ dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
if weight is not None
else None
)
ptr_b = (
- rt.make_ptr(dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ rt.make_ptr(
+ dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
if bias is not None
else None
)
ptr_rstd = (
- rt.make_ptr(cutlass.Float32, rstd.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=4)
+ rt.make_ptr(
+ cutlass.Float32,
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
if rstd is not None
else None
)
@@ -2296,7 +2528,9 @@ def _fused_add_rmsnorm_forward_ptr_inplace(
# benchmark other models/shapes, you can override it with:
# - OINK_RMSNORM_DIRECT_GMEM=0 (force staging/cp.async path)
# - OINK_RMSNORM_DIRECT_GMEM=1 (force direct-gmem path)
- direct_gmem = _direct_gmem_from_policy(default=bool(dtype.width == 16 and N == 7168))
+ direct_gmem = _direct_gmem_from_policy(
+ default=bool(dtype.width == 16 and N == 7168)
+ )
use_async = not direct_gmem
tpr_override: Optional[int] = None
nt_override: Optional[int] = None
@@ -2340,13 +2574,9 @@ def _fused_add_rmsnorm_forward_ptr_inplace(
tpr_override = 256
nt_override = 256
-
can_use_256 = bool(
direct_gmem
- and (
- direct_gmem_max_copy_bits is None
- or direct_gmem_max_copy_bits >= 256
- )
+ and (direct_gmem_max_copy_bits is None or direct_gmem_max_copy_bits >= 256)
and dtype.width == 16
and (x.data_ptr() % 32) == 0
and (residual.data_ptr() % 32) == 0
@@ -2395,13 +2625,22 @@ def _fused_add_rmsnorm_forward_ptr_inplace(
if cluster_n_override is not None:
op._cluster_n_override = cluster_n_override # type: ignore[attr-defined]
ptr_x = rt.make_ptr(
- dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ dtype,
+ x.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
)
ptr_res = rt.make_ptr(
- dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ dtype,
+ residual.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
)
ptr_w = rt.make_ptr(
- dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
)
stream = cuda.CUstream(stream_handle)
ld_x = Int32(int(x.stride(0)))
@@ -2444,12 +2683,20 @@ def _fused_add_rmsnorm_forward_ptr_inplace(
# Fast-launch is disabled/unavailable (or CuTeDSL internals changed). Fall back
# to calling the compiled function directly.
- ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align)
+ ptr_x = rt.make_ptr(
+ dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ )
ptr_res = rt.make_ptr(
- dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ dtype,
+ residual.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
)
ptr_w = rt.make_ptr(
- dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align
+ dtype,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=assumed_align,
)
stream = cuda.CUstream(stream_handle)
ld_x = Int32(int(x.stride(0)))
@@ -2486,7 +2733,12 @@ def rmsnorm_forward(
rms2 = None # type: ignore[assignment]
if rms2 is not None:
y, rstd, residual_out = rms2.rmsnorm_forward_with_stage2(
- x, weight=weight, bias=bias, residual=residual, eps=eps, store_rstd=store_rstd
+ x,
+ weight=weight,
+ bias=bias,
+ residual=residual,
+ eps=eps,
+ store_rstd=store_rstd,
)
# Preserve stride contracts for torch.compile consistency, even
# when using the optional stage-2 implementation.
@@ -2519,7 +2771,9 @@ def rmsnorm_forward(
# Preserve the input stride contract even on the fallback path so
# torch.compile sees a consistent output layout across all branches.
if y.stride() != x.stride():
- y_strided = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
+ y_strided = torch.empty_strided(
+ x.shape, x.stride(), device=x.device, dtype=x.dtype
+ )
y_strided.copy_(y)
y = y_strided
rstd = None
From 4c9a826cb2b1a3d48fbffe156f10943b41ec8f80 Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Wed, 21 Jan 2026 11:06:37 -0800
Subject: [PATCH 4/8] oink: SM100 suite refresh (strict parity + quack-style
benches)
- Switch correctness gate to PyTorch ref + record err stats
- Tighten Softmax/LayerNorm tolerances (Quack-like)
- Quack-style benchmark suite layout + SVG plots
- Packaging/README polish for publishability
---
oink/README.md | 94 +-
oink/benchmarks/README.md | 146 +
oink/benchmarks/benchmark/bench_utils.py | 255 ++
.../benchmark_cross_entropy_sm100.py | 426 +++
.../benchmark_fused_add_rmsnorm_sm100.py | 296 ++
.../benchmark/benchmark_hbm_roofline_sm100.py | 226 ++
.../benchmark/benchmark_layernorm_sm100.py | 393 +++
.../benchmark/benchmark_rmsnorm_bwd_sm100.py | 434 +++
.../benchmark/benchmark_rmsnorm_sm100.py | 337 ++
.../benchmark/benchmark_softmax_sm100.py | 292 ++
.../media/sm100_bf16_oink_vs_quack.svg | 2259 +++++++++++++
.../media/sm100_bf16_oink_vs_quack_dsv3.svg | 2600 +++++++++++++++
.../sm100_bf16_oink_vs_quack_dsv3_all.svg | 2936 ++++++++++++++++
..._bf16_oink_vs_quack_dsv3_cross_entropy.svg | 1687 ++++++++++
...bf16_oink_vs_quack_dsv3_with_layernorm.svg | 2720 +++++++++++++++
...m100_bf16_oink_vs_quack_with_layernorm.svg | 2580 ++++++++++++++
.../media/sm100_fp16_oink_vs_quack.svg | 2280 +++++++++++++
.../media/sm100_fp16_oink_vs_quack_dsv3.svg | 2621 +++++++++++++++
.../sm100_fp16_oink_vs_quack_dsv3_all.svg | 2957 +++++++++++++++++
..._fp16_oink_vs_quack_dsv3_cross_entropy.svg | 1708 ++++++++++
...fp16_oink_vs_quack_dsv3_with_layernorm.svg | 2741 +++++++++++++++
...m100_fp16_oink_vs_quack_with_layernorm.svg | 2601 +++++++++++++++
.../benchmarks/readme/plot_quack_style_svg.py | 431 +++
oink/benchmarks/readme/run_sm100_suite.py | 302 ++
oink/benchmarks/readme/summarize_results.py | 205 ++
oink/pyproject.toml | 24 +-
oink/src/kernelagent_oink/__init__.py | 28 +-
.../blackwell/cross_entropy.py | 1209 +++++++
.../kernelagent_oink/blackwell/layernorm.py | 1368 ++++++++
.../kernelagent_oink/blackwell/lite_quack.py | 1001 +++++-
.../src/kernelagent_oink/blackwell/rmsnorm.py | 467 ++-
.../blackwell/rmsnorm_with_stage2.py | 805 +++++
.../src/kernelagent_oink/blackwell/softmax.py | 749 +++++
33 files changed, 39026 insertions(+), 152 deletions(-)
create mode 100644 oink/benchmarks/README.md
create mode 100644 oink/benchmarks/benchmark/bench_utils.py
create mode 100644 oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
create mode 100644 oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
create mode 100644 oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
create mode 100644 oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
create mode 100644 oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
create mode 100644 oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
create mode 100644 oink/benchmarks/benchmark/benchmark_softmax_sm100.py
create mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg
create mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg
create mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg
create mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg
create mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg
create mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg
create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg
create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg
create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg
create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg
create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg
create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg
create mode 100644 oink/benchmarks/readme/plot_quack_style_svg.py
create mode 100644 oink/benchmarks/readme/run_sm100_suite.py
create mode 100644 oink/benchmarks/readme/summarize_results.py
create mode 100644 oink/src/kernelagent_oink/blackwell/cross_entropy.py
create mode 100644 oink/src/kernelagent_oink/blackwell/layernorm.py
create mode 100644 oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py
create mode 100644 oink/src/kernelagent_oink/blackwell/softmax.py
diff --git a/oink/README.md b/oink/README.md
index 427f69f..aeb0c09 100644
--- a/oink/README.md
+++ b/oink/README.md
@@ -1,13 +1,33 @@
-# KernelAgent Oink (vLLM plugin)
+# KernelAgent-Oink
-This subproject provides an **out-of-tree vLLM plugin** that registers
-`torch.library.custom_op` entrypoints under the `oink::` namespace:
+KernelAgent-Oink is a small **CuTeDSL (CUTLASS DSL) kernel library** for
+**NVIDIA Blackwell (SM100 / GB200 / B200-class)**, bundled as a lightweight
+Python package that can be used standalone or as a **vLLM general plugin**.
-- `torch.ops.oink.rmsnorm`
-- `torch.ops.oink.fused_add_rms_norm`
+At the moment, the vLLM integration exposes the following `torch.library.custom_op`
+entrypoints under the `oink::` namespace:
-The implementation is backed by a CuTeDSL (CUTLASS) RMSNorm kernel tuned for
-**NVIDIA Blackwell (SM100)**.
+- `torch.ops.oink.rmsnorm(x, weight, eps) -> Tensor`
+- `torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps) -> None` (in-place)
+
+The package also includes additional SM100 kernels used by the benchmark suite:
+LayerNorm, Softmax (fwd+bwd), and CrossEntropy (fwd+bwd).
+
+## Requirements
+
+- GPU: **SM100** for the fast CuTeDSL paths. On other GPUs, Oink falls back to
+ reference PyTorch implementations for correctness.
+- Python dependencies:
+ - `nvidia-cutlass-dsl` (CuTeDSL)
+ - `cuda-python`
+ - `torch` (provided by your environment / vLLM)
+
+Recommended env vars:
+
+```bash
+export CUTE_DSL_ARCH=sm_100a
+export PYTORCH_ALLOC_CONF=expandable_segments:True
+```
## Install (editable)
@@ -17,22 +37,23 @@ From the `KernelAgent` repo root:
pip install -e ./oink
```
-This plugin requires the CuTeDSL stack:
+For running the in-repo benchmark suite / plots:
```bash
-pip install nvidia-cutlass-dsl cuda-python
+pip install -e "./oink[bench]"
```
-## Use with vLLM
+## Usage
+
+### vLLM (general plugin)
-1. Enable the vLLM integration:
+1) Enable the plugin:
```bash
export VLLM_USE_OINK_RMSNORM=1
```
-2. Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` /
-CUDA graphs. In Python:
+2) Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` / CUDA graphs:
```python
from vllm import LLM
@@ -45,13 +66,44 @@ llm = LLM(
)
```
-Without `+rms_norm`, Inductor may fuse RMSNorm into larger Triton kernels and
-neither vLLM's CUDA RMSNorm nor Oink will run.
+Without `+rms_norm`, Inductor may fuse RMSNorm into larger kernels and neither
+vLLM’s CUDA RMSNorm nor Oink will run.
+
+### Direct PyTorch usage (manual op registration)
+
+For standalone use (outside vLLM), register the custom ops once:
+
+```python
+import kernelagent_oink
+import torch
+
+kernelagent_oink.register(force=True)
+
+x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
+w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
+y = torch.ops.oink.rmsnorm(x, w, 1e-6)
+```
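+
+The fused op mutates `x` and `residual` in-place and returns `None`. A minimal
+sketch continuing the example above (tensor names are illustrative):
+
+```python
+residual = torch.randn_like(x)
+# Follows vLLM's fused_add_rms_norm semantics: the residual sum is written back
+# into `residual`, and the weighted RMSNorm of that sum is written back into `x`.
+torch.ops.oink.fused_add_rms_norm(x, residual, w, 1e-6)
+```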
+
+## Benchmarks
+
+The repo includes a Quack-style benchmark suite (tables + SVG plots) to compare
+Oink against Quack on SM100 and to reproduce the reported speedups.
+
+- How to run + methodology: `oink/benchmarks/README.md`
+- Pre-generated plots: `oink/benchmarks/media/`
+
+
+<!-- Oink-vs-Quack SM100 benchmark plots (SVGs from oink/benchmarks/media/) are embedded here. -->
+
-## Notes
+## Links
-- This plugin is designed to be **safe to import even when disabled**; it only
- registers ops when `VLLM_USE_OINK_RMSNORM` is truthy (`"1"` / `"true"`).
-- The ops preserve **padded-row layouts** for 2D tensors (shape `[M, N]`,
- `stride(1) == 1`, and potentially `stride(0) > N`), which is required for
- `torch.compile` stride verification on some models (e.g., MLA padded inputs).
+| What | Link |
+|---|---|
+| Quack (expert baseline) | https://github.com/Dao-AILab/quack |
+| KernelAgent (agentic framework) | https://github.com/meta-pytorch/KernelAgent |
+| vLLM PR (Oink RMSNorm integration) | https://github.com/vllm-project/vllm/pull/31828 |
diff --git a/oink/benchmarks/README.md b/oink/benchmarks/README.md
new file mode 100644
index 0000000..ceb7932
--- /dev/null
+++ b/oink/benchmarks/README.md
@@ -0,0 +1,146 @@
+# SM100 Benchmarks (KernelAgent-Oink vs Quack)
+
+This folder contains SM100 (GB200 / Blackwell) microbenchmarks for the Oink
+CuTeDSL kernels vendored into KernelAgent, comparing against Quack’s SM100
+kernels where Quack provides an equivalent API.
+
+## Prereqs
+
+- GPU: **SM100** (`torch.cuda.get_device_capability() == (10, 0)`).
+- Python deps in your environment:
+ - `torch`
+ - `nvidia-cutlass-dsl` (CuTeDSL)
+ - `cuda-python`
+ - `triton` (only for `triton.testing.do_bench`)
+ - `quack` (optional; only needed for Oink-vs-Quack comparisons)
+
+Recommended env vars:
+
+```bash
+export PYTORCH_ALLOC_CONF=expandable_segments:True
+export CUTE_DSL_ARCH=sm_100a
+```
+
+## Shape suites
+
+- **Quack-suite**: `(batch, seq) ∈ {1,4,8,16,32} × {8192,16384,32768,65536,131072}`,
+ with `hidden = 4096` so `M = batch * seq`, `N = 4096`.
+- **DeepSeek-V3-like (DSv3)**
+ - RMSNorm / LayerNorm / Softmax: `M ∈ {4096, 16384, 65536}`, `N ∈ {6144, 7168, 8192}`
+ - Cross-entropy: `M ∈ {4096, 16384, 65536}`, `N ∈ {3072, 6144, 8192, 12288}`
+
+## Correctness gates
+
+By default, each script runs a per-shape `torch.testing.assert_close` check
+vs a **pure-PyTorch reference** **before** emitting timing numbers. When Quack
+is available for that op/path, the script also validates Quack vs the *same*
+reference (so speedups can’t come from looser numerics).
+
+Disable with `--skip-verify` only for quick smoke tests.
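+
+As a minimal illustration of the gate (not the exact harness code; the real
+scripts stream the reference block-by-block and record error statistics), an
+RMSNorm check against a pure-PyTorch reference looks roughly like:
+
+```python
+import torch
+
+def rmsnorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
+    # Float32 reference, cast back to the activation dtype.
+    xf = x.float()
+    rstd = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
+    return (xf * rstd * w.float()).to(x.dtype)
+
+x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
+w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
+y = torch.ops.oink.rmsnorm(x, w, 1e-6)  # assumes the oink ops are registered
+# Tolerances here are illustrative; each script uses op-specific tolerances.
+torch.testing.assert_close(y, rmsnorm_ref(x, w), atol=2e-2, rtol=2e-2)
+```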
+
+## Running benchmarks
+
+All scripts support:
+
+- `--quack-suite` or `--dsv3` (or `--configs MxN,...`)
+- `--dtype {bf16,fp16,fp32}`
+- `--iters <n>` and `--warmup-ms <ms>` for kernel-only timing
+- `--json <path>` and/or `--csv <path>` outputs (meta + rows)
+
+### One-command suite
+
+Run the full Quack-suite + DSv3 set (Oink vs Quack) and write all JSON artifacts
+to a timestamped directory:
+
+```bash
+python oink/benchmarks/readme/run_sm100_suite.py --dtype bf16
+```
+
+Turn the JSON artifacts into Markdown tables (with geomean speedups):
+
+```bash
+python oink/benchmarks/readme/summarize_results.py --in-dir /tmp/kernelagent_oink_sm100_suite_<timestamp> \
+ --out /tmp/kernelagent_oink_sm100_suite_summary.md
+```
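+
+If you prefer to post-process the JSON yourself, each artifact is a dict with
+`meta` and `rows` keys. A sketch of a geomean-speedup computation (the per-row
+timing field names below are assumptions; check the rows your run actually emits):
+
+```python
+import json
+import math
+
+with open("/tmp/oink_rmsnorm_fwd_quack_suite.json") as f:
+    rows = json.load(f)["rows"]
+# "oink_ms" / "quack_ms" are placeholder names for the per-shape mean times.
+speedups = [r["quack_ms"] / r["oink_ms"] for r in rows if r.get("quack_ms")]
+geomean = math.exp(sum(math.log(s) for s in speedups) / len(speedups))
+print(f"geomean speedup: {geomean:.2f}x over {len(speedups)} shapes")
+```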
+
+### Measured HBM roofline (STREAM-like)
+
+To contextualize the `*_tbps` numbers as a fraction of a *measured* bandwidth
+ceiling (rather than a theoretical spec), run:
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op both --gb 2 \
+ --json /tmp/hbm_roofline_sm100_bf16.json
+```
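+
+The script times STREAM-like memory operations with `triton.testing.do_bench`;
+a stripped-down sketch of a plain copy measurement (buffer size and timing
+knobs are illustrative; the real script exposes `--op` and `--gb`):
+
+```python
+import torch
+from triton.testing import do_bench
+
+nbytes = 2 * 1024**3  # ~2 GiB per buffer
+src = torch.empty(nbytes // 2, device="cuda", dtype=torch.bfloat16)
+dst = torch.empty_like(src)
+ms = do_bench(lambda: dst.copy_(src), warmup=25, rep=100, return_mode="mean")
+# A copy reads `src` and writes `dst`, so 2 * nbytes move per iteration.
+print(f"measured copy bandwidth ≈ {2 * nbytes / (ms * 1e-3) / 1e12:.2f} TB/s")
+```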
+
+### RMSNorm forward
+
+```bash
+python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 200 --warmup-ms 25 \
+ --json /tmp/oink_rmsnorm_fwd_quack_suite.json
+
+python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 200 --warmup-ms 25 \
+ --json /tmp/oink_rmsnorm_fwd_dsv3.json
+
+# vLLM-style inference weights (weight dtype == activation dtype)
+python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --quack-suite --iters 200 --warmup-ms 25 \
+ --json /tmp/oink_rmsnorm_fwd_quack_suite_wsame.json
+```
+
+### Fused Add + RMSNorm (vLLM-style, in-place)
+
+This is a good "roofline case study" kernel (heavy read/write traffic, very little extra math):
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \
+ --json /tmp/fused_add_rmsnorm_sm100_bf16.json
+```
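+
+For reference, the IO model behind the reported `*_tbps` numbers for this op is
+simple; a sketch (the benchmark script's exact accounting may differ slightly):
+
+```python
+def fused_add_rmsnorm_tbps(M: int, N: int, ms: float, elem_bytes: int = 2) -> float:
+    """Achieved bandwidth given a measured kernel time in milliseconds.
+
+    In-place semantics: read x, residual, and weight; write x and residual.
+    """
+    bytes_moved = (4 * M * N + N) * elem_bytes
+    return bytes_moved / (ms * 1e-3) / 1e12
+```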
+
+### RMSNorm backward
+
+```bash
+python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 100 --warmup-ms 25 \
+ --csv /tmp/oink_rmsnorm_bwd_quack_suite.csv
+
+python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 100 --warmup-ms 25 \
+ --csv /tmp/oink_rmsnorm_bwd_dsv3.csv
+```
+
+### Softmax (forward + backward)
+
+```bash
+python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \
+ --json /tmp/oink_softmax_fwd_bwd_quack_suite.json
+
+python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \
+ --json /tmp/oink_softmax_fwd_bwd_dsv3.json
+```
+
+### Cross-entropy (forward + backward)
+
+```bash
+python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \
+ --json /tmp/oink_cross_entropy_fwd_bwd_quack_suite.json
+
+python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \
+ --json /tmp/oink_cross_entropy_fwd_bwd_dsv3.json
+```
+
+### LayerNorm forward
+
+```bash
+python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --quack-suite --iters 200 --warmup-ms 25 \
+ --json /tmp/oink_layernorm_fwd_quack_suite.json
+
+python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --dsv3 --iters 200 --warmup-ms 25 \
+ --json /tmp/oink_layernorm_fwd_dsv3.json
+```
+
+## Notes
+
+- These scripts intentionally avoid importing any external Oink checkout so the
+ results reflect the in-tree KernelAgent Oink kernels.
+- For RMSNorm, the `rmsnorm_with_stage2` implementation is a **fallback** that
+ is only used when the pointer-based fast path cannot be used (e.g. when
+ `weight.dtype != x.dtype`, or when layouts/alignments are incompatible). You
+ can force it for A/B testing via `KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2=1`.
diff --git a/oink/benchmarks/benchmark/bench_utils.py b/oink/benchmarks/benchmark/bench_utils.py
new file mode 100644
index 0000000..0abb005
--- /dev/null
+++ b/oink/benchmarks/benchmark/bench_utils.py
@@ -0,0 +1,255 @@
+from __future__ import annotations
+
+import csv
+import json
+import math
+import os
+import subprocess
+import sys
+from dataclasses import asdict, dataclass
+from datetime import datetime
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
+
+import torch
+from triton.testing import do_bench as triton_do_bench
+
+
+@dataclass(frozen=True)
+class DeviceMeta:
+ device: str
+ capability: Tuple[int, int]
+ torch: str
+ cuda: str
+ cute_dsl_arch: str
+ git_sha: str
+ timestamp: str
+
+
+def _try_git_sha() -> str:
+ here = os.path.dirname(os.path.abspath(__file__))
+ repo_root = os.path.abspath(os.path.join(here, "..", ".."))
+ try:
+ out = subprocess.check_output(
+ ["git", "rev-parse", "HEAD"],
+ cwd=repo_root,
+ stderr=subprocess.DEVNULL,
+ text=True,
+ )
+ return out.strip()
+ except Exception:
+ return ""
+
+
+def collect_device_meta(device: Optional[torch.device] = None) -> DeviceMeta:
+ if device is None:
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ timestamp = datetime.now().isoformat(timespec="seconds")
+ return DeviceMeta(
+ device=str(props.name),
+ capability=(int(props.major), int(props.minor)),
+ torch=str(torch.__version__),
+ cuda=str(getattr(torch.version, "cuda", "unknown")),
+ cute_dsl_arch=os.environ.get("CUTE_DSL_ARCH", ""),
+ git_sha=_try_git_sha(),
+ timestamp=timestamp,
+ )
+
+
+def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float:
+ """Approximate HBM peak bandwidth in GB/s for roofline fractions."""
+ if device is None:
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ if sm >= 100:
+ return 8000.0
+ return 2000.0
+
+
+def do_bench_triton(fn: Callable[[], Any], *, warmup_ms: int = 25, rep_ms: int = 100) -> float:
+ """Kernel-only timing consistent with the Oink benchmark harnesses."""
+ return float(triton_do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="mean"))
+
+
+def parse_dtype(s: str) -> torch.dtype:
+ s = s.lower()
+ if s == "bf16":
+ return torch.bfloat16
+ if s == "fp16":
+ return torch.float16
+ if s == "fp32":
+ return torch.float32
+ raise ValueError(f"Unsupported dtype: {s}")
+
+
+def parse_configs(s: str) -> List[Tuple[int, int]]:
+ out: List[Tuple[int, int]] = []
+ for part in s.split(","):
+ m, n = part.lower().split("x")
+ out.append((int(m), int(n)))
+ return out
+
+
+def quack_suite_configs() -> List[Tuple[int, int, int]]:
+ """Return (batch, seq, hidden) triples following Quack's common grid (hidden=4096)."""
+ batch_sizes = [1, 4, 8, 16, 32]
+ seq_lengths = [8192, 16384, 32768, 65536, 131072]
+ hidden = 4096
+ cfgs: List[Tuple[int, int, int]] = []
+ for bs in batch_sizes:
+ for sl in seq_lengths:
+ M = bs * sl
+ if M * hidden > (2**31):
+ continue
+ cfgs.append((bs, sl, hidden))
+ return cfgs
+
+
+def ensure_oink_src_on_path() -> None:
+ """Make the in-repo KernelAgent Oink package importable without an editable install."""
+ here = os.path.dirname(os.path.abspath(__file__))
+ oink_src = os.path.abspath(os.path.join(here, "..", "..", "src"))
+ if oink_src not in sys.path:
+ sys.path.insert(0, oink_src)
+
+
+def write_csv(path: str, rows: Sequence[Dict[str, Any]]) -> None:
+ if not rows:
+ return
+ os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
+ file_exists = os.path.exists(path)
+ with open(path, "a", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=sorted(rows[0].keys()))
+ if not file_exists:
+ writer.writeheader()
+ for row in rows:
+ writer.writerow(row)
+
+
+def write_json(path: str, meta: DeviceMeta, rows: Sequence[Dict[str, Any]], *, extra: Dict[str, Any] | None = None) -> None:
+ os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
+ payload: Dict[str, Any] = {
+ "meta": {**asdict(meta), **(extra or {})},
+ "rows": list(rows),
+ }
+ with open(path, "w") as f:
+ json.dump(payload, f, indent=2)
+
+
+def iter_row_blocks(M: int, block_rows: int) -> Iterable[Tuple[int, int]]:
+ """Yield (start, end) row index ranges for a 2D (M, N) matrix.
+
+ The intent is to make correctness references for large tensors tractable
+ without materializing full float32 intermediates.
+ """
+ if M < 0:
+ raise ValueError(f"M must be non-negative, got {M}")
+ if block_rows <= 0:
+ raise ValueError(f"block_rows must be > 0, got {block_rows}")
+ for start in range(0, M, block_rows):
+ yield start, min(M, start + block_rows)
+
+
+@dataclass
+class ErrorStats:
+ """Numerical error stats between an output and a reference.
+
+ Notes:
+ - `max_abs` and `rel_l2` are computed exactly (streamed).
+ - `p99_abs` is computed over a deterministic strided sample of abs error
+ values (to keep very large tensors tractable).
+ """
+
+ max_abs: float
+ p99_abs: float
+ rel_l2: float
+ p99_sample_elems: int
+ p99_sample_stride: int
+
+
+class ErrorStatsAccumulator:
+ """Stream error stats over (output_block, ref_block) pairs.
+
+ This is intended for large 2D tensors where we compute reference results
+ block-by-block to avoid materializing full float32 intermediates.
+ """
+
+ def __init__(self, *, total_elems: int, p99_target_samples: int = 1_000_000):
+ if total_elems <= 0:
+ raise ValueError(f"total_elems must be > 0, got {total_elems}")
+ if p99_target_samples <= 0:
+ raise ValueError(f"p99_target_samples must be > 0, got {p99_target_samples}")
+ self.total_elems = int(total_elems)
+ self.p99_target_samples = int(p99_target_samples)
+ # Deterministic strided sampling across the flattened tensor order.
+ self.sample_stride = max(1, self.total_elems // self.p99_target_samples)
+ self._global_offset = 0
+
+ self._max_abs = 0.0
+ self._err_sq_sum = 0.0
+ self._ref_sq_sum = 0.0
+ self._abs_err_samples: List[torch.Tensor] = []
+
+ def update(self, out: torch.Tensor, ref: torch.Tensor) -> None:
+ if out.shape != ref.shape:
+ raise ValueError(f"shape mismatch: out={tuple(out.shape)} ref={tuple(ref.shape)}")
+
+ # Compute error in float32 for stable reductions.
+ err_f32 = (out - ref).to(torch.float32)
+ abs_err = err_f32.abs()
+
+ # Exact reductions.
+ self._max_abs = max(self._max_abs, float(abs_err.max().item()))
+ self._err_sq_sum += float((err_f32 * err_f32).sum(dtype=torch.float64).item())
+ ref_f32 = ref.to(torch.float32)
+ self._ref_sq_sum += float((ref_f32 * ref_f32).sum(dtype=torch.float64).item())
+
+ # Deterministic strided sample for p99_abs.
+ flat = abs_err.flatten()
+ block_elems = int(flat.numel())
+ if block_elems <= 0:
+ return
+
+ stride = int(self.sample_stride)
+ first = (-int(self._global_offset)) % stride
+ if first < block_elems:
+ idx = torch.arange(first, block_elems, step=stride, device=flat.device, dtype=torch.int64)
+ # Gather a modest number of values (≈ block_elems/stride).
+ vals = flat.index_select(0, idx).detach().to(device="cpu", dtype=torch.float32)
+ self._abs_err_samples.append(vals)
+
+ self._global_offset += block_elems
+
+ def finalize(self) -> ErrorStats:
+ if self._abs_err_samples:
+ samples = torch.cat(self._abs_err_samples, dim=0)
+ if samples.numel() > self.p99_target_samples:
+ samples = samples[: self.p99_target_samples]
+ p99 = float(torch.quantile(samples, 0.99).item()) if samples.numel() > 0 else 0.0
+ sample_elems = int(samples.numel())
+ else:
+ p99 = 0.0
+ sample_elems = 0
+
+ denom = math.sqrt(self._ref_sq_sum) if self._ref_sq_sum > 0 else 0.0
+ rel_l2 = (math.sqrt(self._err_sq_sum) / denom) if denom > 0 else 0.0
+
+ return ErrorStats(
+ max_abs=float(self._max_abs),
+ p99_abs=float(p99),
+ rel_l2=float(rel_l2),
+ p99_sample_elems=int(sample_elems),
+ p99_sample_stride=int(self.sample_stride),
+ )
+
+
+def error_stats_to_row(prefix: str, stats: ErrorStats) -> Dict[str, Any]:
+ """Flatten ErrorStats into JSON-friendly row fields."""
+ return {
+ f"{prefix}_max_abs": float(stats.max_abs),
+ f"{prefix}_p99_abs": float(stats.p99_abs),
+ f"{prefix}_rel_l2": float(stats.rel_l2),
+ f"{prefix}_p99_sample_elems": int(stats.p99_sample_elems),
+ f"{prefix}_p99_sample_stride": int(stats.p99_sample_stride),
+ }
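+
+
+# A minimal usage sketch (illustrative only; `out_blocks` / `ref_blocks` stand in
+# for whatever block-wise outputs and references a harness produces):
+#
+#   acc = ErrorStatsAccumulator(total_elems=M * N)
+#   for out_blk, ref_blk in zip(out_blocks, ref_blocks):
+#       acc.update(out_blk, ref_blk)
+#   row.update(error_stats_to_row("ours_err_y", acc.finalize()))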
diff --git a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
new file mode 100644
index 0000000..8bcac15
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
@@ -0,0 +1,426 @@
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ detect_hbm_peak_gbps,
+ do_bench_triton,
+ error_stats_to_row,
+ ensure_oink_src_on_path,
+ iter_row_blocks,
+ parse_configs,
+ parse_dtype,
+ quack_suite_configs,
+ write_csv,
+ write_json,
+)
+
+ensure_oink_src_on_path()
+
+from kernelagent_oink.blackwell import cross_entropy as oink_ce # noqa: E402
+
+try:
+ from quack.cross_entropy import cross_entropy_bwd as quack_ce_bwd # type: ignore
+ from quack.cross_entropy import cross_entropy_fwd as quack_ce_fwd # type: ignore
+except Exception:
+ quack_ce_fwd = None
+ quack_ce_bwd = None
+
+
+# Match Quack's unit-test defaults (tests/test_cross_entropy.py).
+_VERIFY_TOL_LOSS = dict(atol=5e-5, rtol=1e-5) # float32 outputs (loss/lse)
+_VERIFY_TOL_DX = {
+ torch.float32: dict(atol=5e-5, rtol=1e-5),
+ # FP16 `dx` is low-precision; allow ~1 ulp at typical magnitudes.
+ torch.float16: dict(atol=1e-3, rtol=1e-3),
+ # BF16 `dx` is low-precision; allow ~1 ulp at typical magnitudes.
+ torch.bfloat16: dict(atol=1e-2, rtol=1e-2),
+}
+
+
+def bytes_io_model_ce(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ target_dtype: torch.dtype = torch.int64,
+ mode: str,
+) -> int:
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ t_elem = torch.tensor(0, dtype=target_dtype).element_size()
+ # Forward:
+ # read logits (M*N) + read target (M) + write loss (M fp32) + write lse (M fp32)
+ fwd = M * N * elem + M * t_elem + 2 * M * 4
+ # Backward (reduction="none" path):
+ # read logits (M*N) + read target (M) + read dloss (M fp32) + read lse (M fp32) + write dx (M*N)
+ bwd = 2 * M * N * elem + M * t_elem + 2 * M * 4
+
+ if mode == "fwd":
+ return int(fwd)
+ if mode == "bwd":
+ return int(bwd)
+ if mode == "fwd_bwd":
+ # Logical IO for dx given (logits, target, dloss): read logits + read target
+ # + read dloss + write dx. (Intermediate lse/loss are implementation details.)
+ return int(2 * M * N * elem + M * t_elem + M * 4)
+ raise ValueError(f"Unsupported mode: {mode}")
+
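+# Worked example of the IO model above (assumed shape, not a measured number):
+# M=8192, N=4096, bf16 logits (2 B), int64 targets (8 B), mode="fwd_bwd":
+#   2*8192*4096*2 + 8192*8 + 8192*4 = 134,316,032 bytes ≈ 0.134 GB per call,
+# which is the numerator used for the reported GB/s.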
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ Ms = [4096, 16384, 65536]
+ Ns = [3072, 6144, 8192, 12288]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def _verify_parity(logits: torch.Tensor, target: torch.Tensor, *, ignore_index: int) -> dict[str, object]:
+ dtype = logits.dtype
+ ref_block_rows = 512
+ dloss = torch.randn(logits.size(0), device=logits.device, dtype=torch.float32) # upstream grad
+
+ with torch.no_grad():
+ loss_o, lse_o = oink_ce.cross_entropy_forward(
+ logits, target, ignore_index=ignore_index, reduction="none"
+ )
+ dx_o = oink_ce.cross_entropy_backward(dloss, logits, target, lse_o, ignore_index=ignore_index)
+ dx_fused_o = oink_ce.cross_entropy_fwd_bwd(
+ dloss,
+ logits,
+ target,
+ ignore_index=ignore_index,
+ )
+
+ loss_q = None
+ lse_q = None
+ dx_q = None
+ if quack_ce_fwd is not None and quack_ce_bwd is not None:
+ loss_q, lse_q = quack_ce_fwd(
+ logits,
+ target,
+ target_logit=None,
+ ignore_index=ignore_index,
+ return_lse=True,
+ return_dx=False,
+ inplace_backward=False,
+ )
+ dx_q = quack_ce_bwd(
+ logits,
+ target,
+ dloss,
+ lse_q,
+ ignore_index=ignore_index,
+ inplace_backward=False,
+ )
+
+ M = int(logits.shape[0])
+ N = int(logits.shape[1])
+ loss_acc_ours = ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000))
+ lse_acc_ours = ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000))
+ dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ dx_fused_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ loss_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000))
+ if (quack_ce_fwd is not None and quack_ce_bwd is not None)
+ else None
+ )
+ lse_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000))
+ if (quack_ce_fwd is not None and quack_ce_bwd is not None)
+ else None
+ )
+ dx_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if (quack_ce_fwd is not None and quack_ce_bwd is not None)
+ else None
+ )
+
+ # Match Quack tests: compare to a PyTorch reference computed on float32 logits.
+ # Chunk over rows so we don't materialize a full (M, N) float32 tensor.
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ logits_f32 = logits[start:end].float().requires_grad_(True)
+ target_blk = target[start:end]
+ dloss_blk = dloss[start:end]
+
+ loss_ref = torch.nn.functional.cross_entropy(
+ logits_f32,
+ target_blk,
+ reduction="none",
+ ignore_index=ignore_index,
+ )
+ lse_ref = torch.logsumexp(logits_f32, dim=-1)
+ (dx_ref_f32,) = torch.autograd.grad(loss_ref, logits_f32, grad_outputs=dloss_blk)
+ dx_ref = dx_ref_f32.to(dtype)
+
+ torch.testing.assert_close(loss_o[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS)
+ torch.testing.assert_close(lse_o[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS)
+ torch.testing.assert_close(dx_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype])
+ torch.testing.assert_close(dx_fused_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype])
+ loss_acc_ours.update(loss_o[start:end], loss_ref.detach())
+ lse_acc_ours.update(lse_o[start:end], lse_ref.detach())
+ dx_acc_ours.update(dx_o[start:end], dx_ref)
+ dx_fused_acc_ours.update(dx_fused_o[start:end], dx_ref)
+
+ if loss_q is not None and lse_q is not None and dx_q is not None:
+ torch.testing.assert_close(loss_q[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS)
+ torch.testing.assert_close(lse_q[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS)
+ torch.testing.assert_close(dx_q[start:end], dx_ref, **_VERIFY_TOL_DX[dtype])
+ assert loss_acc_quack is not None and lse_acc_quack is not None and dx_acc_quack is not None
+ loss_acc_quack.update(loss_q[start:end], loss_ref.detach())
+ lse_acc_quack.update(lse_q[start:end], lse_ref.detach())
+ dx_acc_quack.update(dx_q[start:end], dx_ref)
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_loss", loss_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_lse", lse_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_dx_fused", dx_fused_acc_ours.finalize()))
+ if loss_acc_quack is not None and lse_acc_quack is not None and dx_acc_quack is not None:
+ stats.update(error_stats_to_row("quack_err_loss", loss_acc_quack.finalize()))
+ stats.update(error_stats_to_row("quack_err_lse", lse_acc_quack.finalize()))
+ stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize()))
+ return stats
+
+
+def bench_single(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ warmup_ms: int,
+ iters_ms: int,
+ mode: str,
+ verify: bool,
+ ignore_index: int,
+) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]:
+ device = torch.device("cuda")
+ logits = 0.1 * torch.randn(M, N, device=device, dtype=dtype)
+ target = torch.randint(0, N, (M,), device=device, dtype=torch.int64)
+ # Sprinkle some ignore_index entries for robustness (and to match reduction semantics).
+ if ignore_index is not None:
+ mask = torch.rand(M, device=device) < 0.01
+ target[mask] = int(ignore_index)
+ dloss = torch.randn(M, device=device, dtype=torch.float32)
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(logits, target, ignore_index=int(ignore_index))
+
+ bytes_io = bytes_io_model_ce(M, N, dtype, target_dtype=target.dtype, mode=mode)
+
+ if mode == "fwd":
+ fn_oink = lambda: oink_ce.cross_entropy_forward(
+ logits, target, ignore_index=int(ignore_index), reduction="none"
+ )
+ fn_quack = (
+ None
+ if quack_ce_fwd is None
+ else (
+ lambda: quack_ce_fwd(
+ logits,
+ target,
+ target_logit=None,
+ ignore_index=int(ignore_index),
+ return_lse=True,
+ return_dx=False,
+ inplace_backward=False,
+ )
+ )
+ )
+ elif mode == "bwd":
+ with torch.no_grad():
+ _loss_o, lse_o = oink_ce.cross_entropy_forward(
+ logits, target, ignore_index=int(ignore_index), reduction="none"
+ )
+ if quack_ce_fwd is not None:
+ _loss_q, lse_q = quack_ce_fwd(
+ logits,
+ target,
+ target_logit=None,
+ ignore_index=int(ignore_index),
+ return_lse=True,
+ return_dx=False,
+ inplace_backward=False,
+ )
+ else:
+ lse_q = None
+ fn_oink = lambda: oink_ce.cross_entropy_backward(
+ dloss, logits, target, lse_o, ignore_index=int(ignore_index)
+ )
+ fn_quack = (
+ None
+ if (quack_ce_bwd is None or lse_q is None)
+ else (
+ lambda: quack_ce_bwd(
+ logits,
+ target,
+ dloss,
+ lse_q,
+ ignore_index=int(ignore_index),
+ inplace_backward=False,
+ )
+ )
+ )
+ elif mode == "fwd_bwd":
+ fn_oink = lambda: oink_ce.cross_entropy_fwd_bwd(
+ dloss,
+ logits,
+ target,
+ ignore_index=int(ignore_index),
+ )
+ fn_quack = (
+ None
+ if (quack_ce_fwd is None or quack_ce_bwd is None)
+ else (
+ lambda: quack_ce_bwd(
+ logits,
+ target,
+ dloss,
+ quack_ce_fwd(
+ logits,
+ target,
+ target_logit=None,
+ ignore_index=int(ignore_index),
+ return_lse=True,
+ return_dx=False,
+ inplace_backward=False,
+ )[1],
+ ignore_index=int(ignore_index),
+ inplace_backward=False,
+ )
+ )
+ )
+ else:
+ raise ValueError(f"Unsupported mode: {mode}")
+
+ ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
+
+ if fn_quack is None:
+ return (ms_oink, gbps_oink), None, stats
+
+ ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
+ return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats
+
+
+def main() -> None:
+ if not torch.cuda.is_available():
+ raise SystemExit("CUDA not available")
+
+ torch.cuda.set_device(0)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
+
+ p = argparse.ArgumentParser()
+ p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
+ p.add_argument("--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"])
+ p.add_argument("--ignore-index", type=int, default=-100)
+ p.add_argument("--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only).")
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows")
+ p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)")
+ p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
+ p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid (vocab=4096)")
+ p.add_argument("--dsv3", action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {3072,6144,8192,12288}")
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs PyTorch float32-logits cross entropy)",
+ )
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+
+ if args.quack_suite:
+ cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
+ elif args.dsv3:
+ cfgs = dsv3_configs()
+ else:
+ cfgs = parse_configs(args.configs)
+
+ hbm_peak = detect_hbm_peak_gbps(device)
+ meta = collect_device_meta(device)
+
+ rows_out: List[Dict[str, Any]] = []
+ for (M, N) in cfgs:
+ print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", flush=True)
+ (ms_oink, gbps_oink), quack, stats = bench_single(
+ M=M,
+ N=N,
+ dtype=dtype,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ mode=str(args.mode),
+ verify=not args.skip_verify,
+ ignore_index=int(args.ignore_index),
+ )
+ row: Dict[str, Any] = {
+ "M": M,
+ "N": N,
+ "dtype": args.dtype,
+ "mode": args.mode,
+ "ignore_index": int(args.ignore_index),
+ "ours_ms": ms_oink,
+ "ours_gbps": gbps_oink,
+ "ours_tbps": gbps_oink / 1000.0,
+ "ours_hbm_frac": gbps_oink / hbm_peak,
+ }
+ if quack is not None:
+ ms_q, gbps_q = quack
+ row.update(
+ {
+ "quack_ms": ms_q,
+ "quack_gbps": gbps_q,
+ "quack_tbps": gbps_q / 1000.0,
+ "speedup_vs_quack": ms_q / ms_oink,
+ }
+ )
+ row.update(stats)
+ rows_out.append(row)
+
+ if args.csv is not None:
+ write_csv(args.csv, rows_out)
+ if args.json is not None:
+ write_json(
+ args.json,
+ meta,
+ rows_out,
+ extra={
+ "method": "triton.testing.do_bench(mean)",
+ "warmup_ms": int(args.warmup_ms),
+ "rep_ms": int(args.iters),
+ "io_model_bytes": "mode-dependent; see bytes_io_model_ce in script",
+ },
+ )
+
+ headers = ["M", "N", "mode", "ours_ms", "ours_tbps"]
+ if quack_ce_fwd is not None and quack_ce_bwd is not None:
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows_out:
+ parts: List[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
new file mode 100644
index 0000000..b75f892
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
@@ -0,0 +1,296 @@
+"""
+Benchmark fused_add_rmsnorm (in-place) on SM100.
+
+This matches vLLM's fused_add_rms_norm semantics:
+ z = x + residual (stored into residual)
+ y = RMSNorm(z, w) (stored into x)
+
+Why this exists:
+- It is a common inference hot path (vLLM).
+- It is strongly memory-bound (reads/writes two MxN tensors), making it a good
+ roofline case study for Blackwell.
+
+Example:
+ CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \\
+ --json /tmp/fused_add_rmsnorm_sm100_bf16.json
+
+DSv3 suite (Oink vs Quack, multi-shape):
+ CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --dsv3 \\
+ --json /tmp/kernelagent_oink_sm100_suite_bf16/fused_add_rmsnorm_dsv3.json
+"""
+
+import argparse
+import os
+from typing import Any, Dict, List, Tuple
+
+import torch
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ detect_hbm_peak_gbps,
+ do_bench_triton,
+ error_stats_to_row,
+ ensure_oink_src_on_path,
+ iter_row_blocks,
+ parse_dtype,
+ write_json,
+)
+
+ensure_oink_src_on_path()
+
+from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402
+
+_VERIFY_TOL = {
+ # Align with Quack's RMSNorm unit-test defaults (tests/test_rmsnorm.py).
+ torch.float32: dict(atol=1e-4, rtol=1e-3),
+ torch.float16: dict(atol=1e-2, rtol=1e-3),
+ torch.bfloat16: dict(atol=1e-1, rtol=1e-2),
+}
+
+try:
+ # Use the low-level mutating custom op to avoid per-iteration allocations
+ # (critical for fair comparisons on small/medium M).
+ from quack.rmsnorm import _rmsnorm_fwd as quack_rmsnorm_fwd_mut # type: ignore
+except Exception:
+ quack_rmsnorm_fwd_mut = None
+
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ Ms = [4096, 16384, 65536]
+ Ns = [6144, 7168, 8192]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def bytes_io_model_fused_add_rmsnorm_inplace(M: int, N: int, dtype: torch.dtype) -> int:
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ # Read x + read residual + write x + write residual + read weight
+ return int((4 * M * N + N) * elem)
+
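+# Worked example of the IO model above (assumed shape, not a measured number):
+# M=65536, N=4096, bf16 (2 B): (4*65536*4096 + 4096) * 2 = 2,147,491,840 bytes
+# ≈ 2.15 GB per call, so a ~0.27 ms run would correspond to ~8 TB/s.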
+
+def _verify_parity(
+ *,
+ x: torch.Tensor,
+ residual: torch.Tensor,
+ w: torch.Tensor,
+ eps: float,
+) -> dict[str, object]:
+ tol = _VERIFY_TOL[x.dtype]
+ ref_block_rows = 4096
+ M = int(x.shape[0])
+ N = int(x.shape[1])
+
+ y_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ z_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ y_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd_mut is not None else None
+ z_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd_mut is not None else None
+
+ x_o = x.clone()
+ r_o = residual.clone()
+ out_q = None
+ res_out_q = None
+ with torch.no_grad():
+ oink_rmsnorm.fused_add_rmsnorm_inplace_(x_o, r_o, w, eps=eps)
+
+ if quack_rmsnorm_fwd_mut is not None:
+ out_q = torch.empty_like(x)
+ res_out_q = torch.empty_like(residual)
+ quack_rmsnorm_fwd_mut(
+ x,
+ w,
+ out_q,
+ None, # bias
+ None, # rstd
+ None, # mean
+ residual,
+ res_out_q,
+ eps,
+ False, # is_layernorm
+ )
+
+    # Pure-PyTorch reference (float32 accumulation), chunked over rows.
+    w_f32 = w.float()
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ z = x[start:end] + residual[start:end]
+ zf = z.float()
+ rstd = torch.rsqrt(zf.square().mean(dim=-1, keepdim=True) + eps)
+ y_ref = ((zf * rstd) * w_f32).to(x.dtype)
+
+ torch.testing.assert_close(x_o[start:end], y_ref, **tol)
+ torch.testing.assert_close(r_o[start:end], z, **tol)
+ y_acc_ours.update(x_o[start:end], y_ref)
+ z_acc_ours.update(r_o[start:end], z)
+ if out_q is not None and res_out_q is not None:
+ torch.testing.assert_close(out_q[start:end], y_ref, **tol)
+ torch.testing.assert_close(res_out_q[start:end], z, **tol)
+ assert y_acc_quack is not None and z_acc_quack is not None
+ y_acc_quack.update(out_q[start:end], y_ref)
+ z_acc_quack.update(res_out_q[start:end], z)
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_residual_out", z_acc_ours.finalize()))
+ if y_acc_quack is not None and z_acc_quack is not None:
+ stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize()))
+ stats.update(error_stats_to_row("quack_err_residual_out", z_acc_quack.finalize()))
+ return stats
+
+
+def bench_one(
+ *,
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ warmup_ms: int,
+ iters_ms: int,
+ verify: bool,
+) -> Dict[str, Any]:
+ device = torch.device("cuda")
+ x = torch.randn((M, N), device=device, dtype=dtype)
+ residual = torch.randn_like(x)
+ w = torch.randn((N,), device=device, dtype=dtype)
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(x=x, residual=residual, w=w, eps=1e-6)
+
+ bytes_io = bytes_io_model_fused_add_rmsnorm_inplace(M, N, dtype)
+
+ fn = lambda: oink_rmsnorm.fused_add_rmsnorm_inplace_(x, residual, w, eps=1e-6)
+ ms = do_bench_triton(fn, warmup_ms=warmup_ms, rep_ms=iters_ms)
+
+ gbps = bytes_io / (ms * 1e-3) / 1e9
+ tbps = gbps / 1000.0
+ hbm_frac = gbps / detect_hbm_peak_gbps(device)
+
+ row: Dict[str, Any] = dict(
+ M=int(M),
+ N=int(N),
+ dtype="bf16" if dtype is torch.bfloat16 else ("fp16" if dtype is torch.float16 else "fp32"),
+ ours_ms=float(ms),
+ ours_gbps=float(gbps),
+ ours_tbps=float(tbps),
+ ours_hbm_frac=float(hbm_frac),
+ )
+ row.update(stats)
+
+ if quack_rmsnorm_fwd_mut is not None:
+ out_q = torch.empty_like(x)
+ res_out_q = torch.empty_like(residual)
+
+ fn_q = lambda: quack_rmsnorm_fwd_mut(
+ x,
+ w,
+ out_q,
+ None, # bias
+ None, # rstd
+ None, # mean
+ residual,
+ res_out_q,
+ 1e-6,
+ False, # is_layernorm
+ )
+ ms_q = do_bench_triton(fn_q, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_q = bytes_io / (ms_q * 1e-3) / 1e9
+ row.update(
+ dict(
+ quack_ms=float(ms_q),
+ quack_gbps=float(gbps_q),
+ quack_tbps=float(gbps_q / 1000.0),
+ speedup_vs_quack=float(ms_q / ms),
+ )
+ )
+
+ return row
+
+
+def _dtype_label(dtype: torch.dtype) -> str:
+ if dtype is torch.bfloat16:
+ return "bf16"
+ if dtype is torch.float16:
+ return "fp16"
+ return "fp32"
+
+
+def _print_table(rows: List[Dict[str, Any]]) -> None:
+ if not rows:
+ return
+ headers = ["M", "N", "ours_ms", "ours_tbps"]
+ has_quack = any("quack_ms" in r for r in rows)
+ if has_quack:
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows:
+ parts: List[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+def main() -> None:
+ p = argparse.ArgumentParser()
+ p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
+ p.add_argument("--M", type=int, default=65536)
+ p.add_argument("--N", type=int, default=4096)
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
+ )
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument("--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)")
+ p.add_argument("--skip-verify", action="store_true")
+ p.add_argument("--json", type=str, default=None)
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+ meta = collect_device_meta(torch.device("cuda"))
+
+ cfgs = dsv3_configs() if bool(args.dsv3) else [(int(args.M), int(args.N))]
+ rows: List[Dict[str, Any]] = []
+ for (M, N) in cfgs:
+ print(f"bench M={M:<8d} N={N:<6d} dtype={_dtype_label(dtype)} fused_add_rmsnorm ...", flush=True)
+ rows.append(
+ bench_one(
+ M=int(M),
+ N=int(N),
+ dtype=dtype,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ verify=not bool(args.skip_verify),
+ )
+ )
+
+ _print_table(rows)
+
+ if args.json:
+ write_json(
+ args.json,
+ meta,
+ rows,
+ extra=dict(
+ io_model_bytes="(4*M*N + N)*elem_size",
+ warmup_ms=int(args.warmup_ms),
+ rep_ms=int(args.iters),
+ method="triton.testing.do_bench(mean)",
+ note="Oink fused_add_rmsnorm_inplace_ vs Quack quack::_rmsnorm_fwd(residual=..., residual_out=...) when available",
+ ),
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
new file mode 100644
index 0000000..971a03c
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
@@ -0,0 +1,226 @@
+"""
+HBM roofline microbenchmark for SM100 (GB200 / Blackwell).
+
+This script measures a STREAM-like bandwidth ceiling using a simple Triton kernel
+that performs a large contiguous copy (read + write) and/or triad (read + read + write)
+over a large buffer.
+
+Why this exists:
+- The benchmark harnesses for Oink ops report an "ours_tbps" derived from an IO model.
+- For roofline discussions, comparing against a *measured* device bandwidth ceiling
+ is often more meaningful than quoting a marketing/theoretical spec.
+
+Example:
+ CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op copy --gb 2
+ CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype fp16 --op triad --gb 2
+"""
+
+import argparse
+import os
+from typing import Any, Dict, List, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+from bench_utils import ( # noqa: E402
+ collect_device_meta,
+ do_bench_triton,
+ parse_dtype,
+ write_json,
+)
+
+
+@triton.jit
+def _copy_kernel(
+ x_ptr,
+ y_ptr,
+ n_elements,
+ BLOCK: tl.constexpr,
+):
+    # Use 64-bit offsets so buffers with more than 2**31 elements (large --gb
+    # values) do not overflow the index math.
+    pid = tl.program_id(0).to(tl.int64)
+    offsets = pid * BLOCK + tl.arange(0, BLOCK)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask, other=0)
+ tl.store(y_ptr + offsets, x, mask=mask)
+
+
+@triton.jit
+def _triad_kernel(
+ x_ptr,
+ y_ptr,
+ n_elements,
+ BLOCK: tl.constexpr,
+):
+    # Use 64-bit offsets so buffers with more than 2**31 elements (large --gb
+    # values) do not overflow the index math.
+    pid = tl.program_id(0).to(tl.int64)
+    offsets = pid * BLOCK + tl.arange(0, BLOCK)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask, other=0)
+ y = tl.load(y_ptr + offsets, mask=mask, other=0)
+ tl.store(y_ptr + offsets, x + y, mask=mask)
+
+
+def _bytes_moved(n_elements: int, elem_size: int, *, op: str) -> int:
+ if op == "copy":
+ return int(2 * n_elements * elem_size) # read x + write y
+ if op == "triad":
+ return int(3 * n_elements * elem_size) # read x + read y + write y
+ raise ValueError(f"Unsupported op: {op}")
+
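+# Worked example (assumed --gb 2 with bf16, not a measured number): each tensor
+# holds 2*1024**3 / 2 = 1,073,741,824 elements, so a copy moves
+# 2 * 1,073,741,824 * 2 = 4,294,967,296 bytes ≈ 4.29 GB; a 1.0 ms copy would
+# therefore report ≈ 4.29 TB/s, and triad moves 1.5x as many bytes.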
+
+def bench_one(
+ *,
+ n_elements: int,
+ dtype: torch.dtype,
+ op: str,
+ block: int,
+ num_warps: int,
+ warmup_ms: int,
+ iters_ms: int,
+) -> Tuple[float, float]:
+ device = torch.device("cuda")
+ x = torch.empty((n_elements,), device=device, dtype=dtype)
+ y = torch.empty_like(x)
+ # Avoid pathological compression-friendly patterns (e.g. all-zeros) that can
+ # artificially inflate apparent bandwidth on some GPUs. Random-ish data is
+ # a closer match to ML workloads.
+ x.uniform_(-1, 1)
+ y.uniform_(-1, 1)
+
+ grid = (triton.cdiv(n_elements, block),)
+
+ if op == "copy":
+ launch = lambda: _copy_kernel[grid](
+ x,
+ y,
+ n_elements,
+ BLOCK=block,
+ num_warps=num_warps,
+ num_stages=4,
+ )
+ elif op == "triad":
+ launch = lambda: _triad_kernel[grid](
+ x,
+ y,
+ n_elements,
+ BLOCK=block,
+ num_warps=num_warps,
+ num_stages=4,
+ )
+ else:
+ raise ValueError(f"Unsupported op: {op}")
+
+ # Force compilation out of the timed region.
+ launch()
+ torch.cuda.synchronize()
+
+ ms = do_bench_triton(launch, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ moved = _bytes_moved(n_elements, x.element_size(), op=op)
+ tbps = moved / (ms * 1e-3) / 1e12
+ return ms, tbps
+
+
+def _print_summary(rows: List[Dict[str, Any]]) -> None:
+ if not rows:
+ return
+ best = max(rows, key=lambda r: float(r["tbps"]))
+ print("\nSummary (STREAM-like):")
+ print(f"- best_tbps: {best['tbps']:.3f} TB/s ({best['op']}, BLOCK={best['block']}, warps={best['num_warps']})")
+
+
+def main() -> None:
+ p = argparse.ArgumentParser()
+ p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
+ p.add_argument("--op", type=str, default="copy", choices=["copy", "triad", "both"])
+ p.add_argument("--gb", type=float, default=2.0, help="Size per tensor in GB (default: 2)")
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument("--iters", type=int, default=100, help="rep_ms for do_bench (default: 100)")
+ p.add_argument("--json", type=str, default=None, help="Write JSON results to this path")
+ p.add_argument("--no-sweep", action="store_true", help="Disable tuning sweep; run a single config")
+ p.add_argument("--block", type=int, default=2048, help="BLOCK size when --no-sweep is set")
+ p.add_argument("--warps", type=int, default=8, help="num_warps when --no-sweep is set")
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ cap = (int(props.major), int(props.minor))
+ if cap != (10, 0):
+ raise RuntimeError(f"Expected SM100 (10,0), got {cap} ({props.name})")
+
+ elem_size = torch.tensor(0, dtype=dtype).element_size()
+ bytes_per_tensor = int(args.gb * (1024**3))
+ n_elements = max(1, bytes_per_tensor // elem_size)
+
+ ops: List[str]
+ if args.op == "both":
+ ops = ["copy", "triad"]
+ else:
+ ops = [args.op]
+
+ if args.no_sweep:
+ sweep: List[Tuple[int, int]] = [(int(args.block), int(args.warps))]
+ else:
+ # A tiny hand-tuned sweep that keeps compile overhead reasonable.
+ sweep = [
+ (1024, 4),
+ (1024, 8),
+ (2048, 4),
+ (2048, 8),
+ (4096, 8),
+ ]
+
+ print(f"Running on {props.name} (SM{props.major}{props.minor})")
+ print(f"- dtype: {args.dtype} (elem={elem_size}B)")
+ print(f"- n_elements: {n_elements:,} (~{(n_elements * elem_size) / (1024**3):.2f} GiB per tensor)")
+ print(f"- ops: {ops}")
+ print(f"- sweep: {sweep}")
+
+ meta = collect_device_meta(device)
+ rows: List[Dict[str, Any]] = []
+ for op in ops:
+ for block, warps in sweep:
+ ms, tbps = bench_one(
+ n_elements=n_elements,
+ dtype=dtype,
+ op=op,
+ block=block,
+ num_warps=warps,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ )
+ rows.append(
+ dict(
+ op=op,
+ dtype=str(args.dtype),
+ n_elements=int(n_elements),
+ elem_size_B=int(elem_size),
+ block=int(block),
+ num_warps=int(warps),
+ warmup_ms=int(args.warmup_ms),
+ rep_ms=int(args.iters),
+ ms=float(ms),
+ tbps=float(tbps),
+ )
+ )
+ print(f"- {op:5s} BLOCK={block:4d} warps={warps}: {tbps:.3f} TB/s ({ms:.4f} ms)")
+
+ _print_summary(rows)
+
+ if args.json:
+ # Write meta + detailed rows for reproducibility.
+ extra = dict(
+ bytes_model="copy:2*N*elem, triad:3*N*elem",
+ bytes_per_tensor=int(bytes_per_tensor),
+ gb_per_tensor=float(args.gb),
+ )
+ write_json(args.json, meta, rows, extra=extra)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
new file mode 100644
index 0000000..778e3e2
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
@@ -0,0 +1,393 @@
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ detect_hbm_peak_gbps,
+ do_bench_triton,
+ error_stats_to_row,
+ ensure_oink_src_on_path,
+ iter_row_blocks,
+ parse_configs,
+ parse_dtype,
+ quack_suite_configs,
+ write_csv,
+ write_json,
+)
+
+ensure_oink_src_on_path()
+
+from kernelagent_oink.blackwell import layernorm as oink_ln # noqa: E402
+
+try:
+ # Quack exposes LayerNorm through the RMSNorm module (is_layernorm=True path).
+ from quack.rmsnorm import layernorm_fwd as quack_layernorm # type: ignore
+except Exception:
+ quack_layernorm = None
+
+_VERIFY_TOL_Y = {
+ # Match Quack's unit-test defaults (tests/test_layernorm.py).
+ torch.float32: dict(atol=1e-4, rtol=1e-4),
+ torch.float16: dict(atol=1e-3, rtol=1e-3),
+ torch.bfloat16: dict(atol=1e-2, rtol=1e-2),
+}
+
+# Quack checks rstd/mean (fp32) with a tighter fixed tolerance.
+_VERIFY_TOL_STATS = dict(atol=6e-4, rtol=6e-4)
+
+
+def bytes_io_model_layernorm(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ has_bias: bool,
+ return_rstd: bool,
+ return_mean: bool,
+ weight_dtype: torch.dtype = torch.float32,
+) -> int:
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ w_elem = torch.tensor(0, dtype=weight_dtype).element_size()
+ total = 0
+ # Read x + write y
+ total += 2 * M * N * elem
+ # Read weight (+ optional bias) along feature dim
+ total += N * w_elem
+ if has_bias:
+ total += N * w_elem
+ # Optional per-row stats (fp32)
+ if return_rstd:
+ total += M * 4
+ if return_mean:
+ total += M * 4
+ return int(total)
+
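+# Worked example of the IO model above (assumed shape, not a measured number):
+# M=16384, N=7168, bf16 activations (2 B), fp32 weight (4 B), no bias/stats:
+#   2*16384*7168*2 + 7168*4 = 469,790,720 bytes ≈ 0.47 GB per call.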
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ Ms = [4096, 16384, 65536]
+ Ns = [6144, 7168, 8192]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def _verify_parity(
+ x: torch.Tensor,
+ w: torch.Tensor,
+ b: torch.Tensor | None,
+ *,
+ eps: float,
+ return_rstd: bool,
+ return_mean: bool,
+) -> dict[str, object]:
+ tol_y = _VERIFY_TOL_Y[x.dtype]
+ ref_block_rows = 4096
+ M = int(x.shape[0])
+ N = int(x.shape[1])
+
+ y_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ y_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N) if (quack_layernorm is not None and b is None) else None
+ )
+ with torch.no_grad():
+ ours = oink_ln.layernorm(
+ x,
+ w,
+ bias=b,
+ eps=eps,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ )
+ quack = None
+ if quack_layernorm is not None and b is None:
+ quack = quack_layernorm(
+ x,
+ w,
+ eps=eps,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ )
+ torch.cuda.synchronize()
+
+ def _unpack(out):
+ if return_rstd and return_mean:
+ y, rstd, mean = out
+ elif return_rstd and not return_mean:
+ y, rstd = out
+ mean = None
+ elif return_mean and not return_rstd:
+ y, mean = out
+ rstd = None
+ else:
+ y, rstd, mean = out, None, None
+ return y, rstd, mean
+
+ y_o, rstd_o, mean_o = _unpack(ours)
+ y_q, rstd_q, mean_q = _unpack(quack) if quack is not None else (None, None, None)
+
+ # Pure-PyTorch reference (float32 accumulation), matching Quack's unit tests:
+ # - compute ref output via F.layer_norm on float32
+ # - compute mean/rstd from float32 input
+ rstd_ref_all = torch.empty((M,), device=x.device, dtype=torch.float32) if return_rstd else None
+ mean_ref_all = torch.empty((M,), device=x.device, dtype=torch.float32) if return_mean else None
+
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ x_f32 = x[start:end].float()
+ y_ref_f32 = torch.nn.functional.layer_norm(x_f32, w.shape, w, b, eps)
+ y_ref = y_ref_f32.to(x.dtype)
+ torch.testing.assert_close(y_o[start:end], y_ref, **tol_y)
+ y_acc_ours.update(y_o[start:end], y_ref)
+ if y_q is not None:
+ torch.testing.assert_close(y_q[start:end], y_ref, **tol_y)
+ assert y_acc_quack is not None
+ y_acc_quack.update(y_q[start:end], y_ref)
+
+ # Per-row stats in fp32, as in Quack's tests.
+ if return_rstd or return_mean:
+ mean_f32 = x_f32.mean(dim=-1)
+ if return_mean:
+ assert mean_ref_all is not None
+ mean_ref_all[start:end] = mean_f32
+ if return_rstd:
+ var_f32 = ((x_f32 - mean_f32.unsqueeze(1)) ** 2).mean(dim=-1)
+ rstd_ref = 1.0 / torch.sqrt(var_f32 + eps)
+ assert rstd_ref_all is not None
+ rstd_ref_all[start:end] = rstd_ref
+
+ assert rstd_o is not None
+ torch.testing.assert_close(rstd_o[start:end], rstd_ref, **_VERIFY_TOL_STATS)
+ if rstd_q is not None:
+ torch.testing.assert_close(rstd_q[start:end], rstd_ref, **_VERIFY_TOL_STATS)
+
+ if return_mean:
+ mean_ref = mean_f32
+ assert mean_o is not None
+ torch.testing.assert_close(mean_o[start:end], mean_ref, **_VERIFY_TOL_STATS)
+ if mean_q is not None:
+ torch.testing.assert_close(mean_q[start:end], mean_ref, **_VERIFY_TOL_STATS)
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize()))
+ if y_acc_quack is not None:
+ stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize()))
+
+ if return_rstd:
+ assert rstd_o is not None and rstd_ref_all is not None
+ rstd_acc_ours = ErrorStatsAccumulator(
+ total_elems=int(rstd_ref_all.numel()), p99_target_samples=int(rstd_ref_all.numel())
+ )
+ rstd_acc_ours.update(rstd_o, rstd_ref_all)
+ stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize()))
+ if rstd_q is not None:
+ rstd_acc_quack = ErrorStatsAccumulator(
+ total_elems=int(rstd_ref_all.numel()), p99_target_samples=int(rstd_ref_all.numel())
+ )
+ rstd_acc_quack.update(rstd_q, rstd_ref_all)
+ stats.update(error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize()))
+
+ if return_mean:
+ assert mean_o is not None and mean_ref_all is not None
+ mean_acc_ours = ErrorStatsAccumulator(
+ total_elems=int(mean_ref_all.numel()), p99_target_samples=int(mean_ref_all.numel())
+ )
+ mean_acc_ours.update(mean_o, mean_ref_all)
+ stats.update(error_stats_to_row("ours_err_mean", mean_acc_ours.finalize()))
+ if mean_q is not None:
+ mean_acc_quack = ErrorStatsAccumulator(
+ total_elems=int(mean_ref_all.numel()), p99_target_samples=int(mean_ref_all.numel())
+ )
+ mean_acc_quack.update(mean_q, mean_ref_all)
+ stats.update(error_stats_to_row("quack_err_mean", mean_acc_quack.finalize()))
+
+ return stats
+
+
+def bench_single(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ eps: float,
+ warmup_ms: int,
+ iters_ms: int,
+ verify: bool,
+ return_rstd: bool,
+ return_mean: bool,
+ has_bias: bool,
+) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]:
+ device = torch.device("cuda")
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ w = torch.randn(N, device=device, dtype=torch.float32)
+ b = torch.randn(N, device=device, dtype=torch.float32) if has_bias else None
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(x, w, b, eps=eps, return_rstd=return_rstd, return_mean=return_mean)
+
+ bytes_io = bytes_io_model_layernorm(
+ M,
+ N,
+ dtype,
+ has_bias=has_bias,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ weight_dtype=w.dtype,
+ )
+
+ fn_oink = lambda: oink_ln.layernorm(
+ x,
+ w,
+ bias=b,
+ eps=eps,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ )
+ ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
+
+ if quack_layernorm is None or has_bias:
+ return (ms_oink, gbps_oink), None, stats
+
+ fn_quack = lambda: quack_layernorm(
+ x,
+ w,
+ eps=eps,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ )
+ ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
+ return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats
+
+
+def main() -> None:
+ if not torch.cuda.is_available():
+ raise SystemExit("CUDA not available")
+
+ torch.cuda.set_device(0)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
+
+ p = argparse.ArgumentParser()
+ p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
+ p.add_argument("--eps", type=float, default=1e-6)
+ p.add_argument("--return-rstd", action="store_true")
+ p.add_argument("--return-mean", action="store_true")
+ p.add_argument("--with-bias", action="store_true", help="Benchmark bias path (Quack compare skipped)")
+ p.add_argument("--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only).")
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows")
+ p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)")
+ p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
+ p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid (hidden=4096)")
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
+ )
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference; Quack compare skipped when bias is enabled)",
+ )
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+ eps = float(args.eps)
+
+ if args.quack_suite:
+ cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
+ elif args.dsv3:
+ cfgs = dsv3_configs()
+ else:
+ cfgs = parse_configs(args.configs)
+
+ hbm_peak = detect_hbm_peak_gbps(device)
+ meta = collect_device_meta(device)
+
+ rows_out: List[Dict[str, Any]] = []
+ for (M, N) in cfgs:
+ print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True)
+ (ms_oink, gbps_oink), quack, stats = bench_single(
+ M=M,
+ N=N,
+ dtype=dtype,
+ eps=eps,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ verify=not args.skip_verify,
+ return_rstd=bool(args.return_rstd),
+ return_mean=bool(args.return_mean),
+ has_bias=bool(args.with_bias),
+ )
+ row: Dict[str, Any] = {
+ "M": M,
+ "N": N,
+ "dtype": args.dtype,
+ "eps": eps,
+ "return_rstd": bool(args.return_rstd),
+ "return_mean": bool(args.return_mean),
+ "with_bias": bool(args.with_bias),
+ "ours_ms": ms_oink,
+ "ours_gbps": gbps_oink,
+ "ours_tbps": gbps_oink / 1000.0,
+ "ours_hbm_frac": gbps_oink / hbm_peak,
+ }
+ if quack is not None:
+ ms_q, gbps_q = quack
+ row.update(
+ {
+ "quack_ms": ms_q,
+ "quack_gbps": gbps_q,
+ "quack_tbps": gbps_q / 1000.0,
+ "speedup_vs_quack": ms_q / ms_oink,
+ }
+ )
+ row.update(stats)
+ rows_out.append(row)
+
+ if args.csv is not None:
+ write_csv(args.csv, rows_out)
+ if args.json is not None:
+ write_json(
+ args.json,
+ meta,
+ rows_out,
+ extra={
+ "method": "triton.testing.do_bench(mean)",
+ "warmup_ms": int(args.warmup_ms),
+ "rep_ms": int(args.iters),
+ "io_model_bytes": "see bytes_io_model_layernorm in script",
+ },
+ )
+
+ headers = ["M", "N", "ours_ms", "ours_tbps"]
+ if quack_layernorm is not None and (not args.with_bias):
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows_out:
+ parts: List[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
new file mode 100644
index 0000000..01c390d
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
@@ -0,0 +1,434 @@
+from __future__ import annotations
+
+import argparse
+import csv
+import os
+import sys
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import torch
+from triton.testing import do_bench as triton_do_bench
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+# Make the in-repo KernelAgent Oink package importable without an editable install.
+_HERE = os.path.dirname(os.path.abspath(__file__))
+_OINK_SRC = os.path.abspath(os.path.join(_HERE, "..", "src"))
+if _OINK_SRC not in sys.path:
+ sys.path.insert(0, _OINK_SRC)
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ error_stats_to_row,
+ iter_row_blocks,
+ write_json,
+)
+from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402
+
+try:
+ from quack.rmsnorm import rmsnorm_bwd as quack_rmsnorm_bwd # type: ignore
+except Exception:
+ quack_rmsnorm_bwd = None
+
+_VERIFY_TOL_DX = {
+ # Match Quack's unit-test defaults (tests/test_rmsnorm.py).
+ torch.float32: dict(atol=1e-4, rtol=1e-3),
+ torch.float16: dict(atol=1e-2, rtol=1e-3),
+ torch.bfloat16: dict(atol=1e-1, rtol=1e-2),
+}
+
+
+def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float:
+ """Approximate HBM peak bandwidth in GB/s for roofline fractions."""
+ if device is None:
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ if sm >= 100:
+ return 8000.0
+ return 2000.0
+
+
+@dataclass
+class Result:
+ ms: float
+ gbps: float
+
+
+def do_bench_triton(fn, warmup_ms: int = 25, rep_ms: int = 100) -> float:
+ # Kernel-only timing consistent with the existing Oink forward harness.
+ return float(triton_do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="mean"))
+
+
+def bytes_io_model_bwd(
+ M: int, N: int, dtype: torch.dtype, *, weight_dtype: torch.dtype = torch.float32
+) -> int:
+ """A simple IO model for RMSNorm backward.
+
+ This intentionally ignores partial-reduction scratch buffers (`dw_partial` /
+ `db_partial`) since those are highly implementation-specific and depend on
+ sm_count; we still report speedups and times regardless.
+ """
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ w_elem = torch.tensor(0, dtype=weight_dtype).element_size()
+ # Read x + dout + write dx
+ total = 3 * M * N * elem
+ # Read weight + write dw
+ total += 2 * N * w_elem
+ # Read rstd (fp32 per row)
+ total += M * 4
+ return int(total)
+
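+# Worked example of the IO model above (assumed shape, not a measured number):
+# M=65536, N=8192, bf16 activations (2 B), fp32 weight (4 B):
+#   3*65536*8192*2 + 2*8192*4 + 65536*4 = 3,221,553,152 bytes ≈ 3.22 GB per call.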
+
+def parse_dtype(s: str) -> torch.dtype:
+ s = s.lower()
+ if s == "bf16":
+ return torch.bfloat16
+ if s == "fp16":
+ return torch.float16
+ if s == "fp32":
+ return torch.float32
+ raise ValueError(f"Unsupported dtype: {s}")
+
+
+def parse_configs(s: str) -> List[Tuple[int, int]]:
+ out: List[Tuple[int, int]] = []
+ for part in s.split(","):
+ m, n = part.lower().split("x")
+ out.append((int(m), int(n)))
+ return out
+
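+# Example: parse_configs("1024x4096,8192x4096") -> [(1024, 4096), (8192, 4096)].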
+
+def quack_suite_configs() -> List[Tuple[int, int, int]]:
+ """Return (batch, seq, hidden) triples following Quack's grid (hidden=4096)."""
+ batch_sizes = [1, 4, 8, 16, 32]
+ seq_lengths = [8192, 16384, 32768, 65536, 131072]
+ hidden = 4096
+ cfgs: List[Tuple[int, int, int]] = []
+ for bs in batch_sizes:
+ for sl in seq_lengths:
+ M = bs * sl
+ if M * hidden > (2**31):
+ continue
+ cfgs.append((bs, sl, hidden))
+ return cfgs
+
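+# Example of the M * hidden > 2**31 guard above: (bs=32, sl=131072) gives
+# 32*131072*4096 ≈ 1.7e10 elements and is skipped, while (bs=4, sl=131072)
+# gives exactly 2**31 and is kept (the comparison is strict).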
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ Ms = [4096, 16384, 65536]
+ Ns = [6144, 7168, 8192]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def _verify_parity(
+ x: torch.Tensor,
+ w: torch.Tensor,
+ dout: torch.Tensor,
+ rstd: torch.Tensor,
+ *,
+ has_bias: bool,
+ has_residual: bool,
+) -> dict[str, object]:
+ tol_dx = _VERIFY_TOL_DX[x.dtype]
+ ref_block_rows = 1024
+ M, N = int(x.shape[0]), int(x.shape[1])
+
+ dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ dx_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_bwd is not None else None
+
+ with torch.no_grad():
+ dx_oink, dw_oink, db_oink, dres_oink = oink_rmsnorm.rmsnorm_backward(
+ x,
+ w,
+ dout,
+ rstd,
+ dresidual_out=None,
+ has_bias=has_bias,
+ has_residual=has_residual,
+ )
+
+ dx_quack = None
+ dw_quack = None
+ db_quack = None
+ dres_quack = None
+ if quack_rmsnorm_bwd is not None:
+ dx_quack, dw_quack, db_quack, dres_quack = quack_rmsnorm_bwd(
+ x,
+ w,
+ dout,
+ rstd,
+ dresidual_out=None,
+ has_bias=has_bias,
+ has_residual=has_residual,
+ )
+ torch.cuda.synchronize()
+
+ # Pure-PyTorch reference, matching Quack's rmsnorm_bwd_ref (float32 math for x_hat).
+ # Chunk over rows to avoid materializing an (M, N) float32 tensor for large shapes.
+ dw_accum = torch.zeros((N,), device=x.device, dtype=torch.float32)
+ w_f32 = w.float()
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ x_f32 = x[start:end].float()
+ rstd_blk = rstd[start:end]
+ x_hat = x_f32 * rstd_blk.unsqueeze(1)
+ # Match Quack/PyTorch reference behavior: gradient math uses float32
+ # intermediates even when (x, w, dout) are bf16/fp16.
+ dout_f32 = dout[start:end].float()
+ wdy = dout_f32 * w_f32
+ c1 = (x_hat * wdy).mean(dim=-1, keepdim=True)
+ dx_ref = ((wdy - x_hat * c1) * rstd_blk.unsqueeze(1)).to(x.dtype)
+
+ torch.testing.assert_close(dx_oink[start:end], dx_ref, **tol_dx)
+ dx_acc_ours.update(dx_oink[start:end], dx_ref)
+ if dx_quack is not None:
+ torch.testing.assert_close(dx_quack[start:end], dx_ref, **tol_dx)
+ assert dx_acc_quack is not None
+ dx_acc_quack.update(dx_quack[start:end], dx_ref)
+
+ if dw_oink is not None:
+ dw_accum += (dout_f32 * x_hat).sum(dim=0)
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize()))
+ if dx_acc_quack is not None:
+ stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize()))
+
+ if dw_oink is not None:
+ dw_ref = dw_accum.to(w.dtype)
+        if w.dtype == torch.float32:
+            # Weight grad is sensitive to reduction order; use a slightly larger
+            # absolute tolerance in the suite harness (Quack's unit tests use
+            # smaller M, where dw is typically tighter).
+            dw_tol = dict(atol=2e-3, rtol=1e-3)
+            torch.testing.assert_close(dw_oink, dw_ref, **dw_tol)
+            if dw_quack is not None:
+                torch.testing.assert_close(dw_quack, dw_ref, **dw_tol)
+        else:
+            # For fp16/bf16 weights, `dw` is low-precision and grows with M; use an
+            # ulp/magnitude-aware tolerance rather than a fixed epsilon, comparing
+            # in float32.
+            dw_ref_f32 = dw_ref.to(torch.float32)
+            dw_oink_f32 = dw_oink.to(torch.float32)
+            scale = float(dw_ref_f32.abs().max().item())
+            dw_atol = max(2.0 * torch.finfo(w.dtype).eps * scale, 1e-3)
+            dw_tol = dict(atol=dw_atol, rtol=1e-3)
+            torch.testing.assert_close(dw_oink_f32, dw_ref_f32, **dw_tol)
+            if dw_quack is not None:
+                torch.testing.assert_close(dw_quack.to(torch.float32), dw_ref_f32, **dw_tol)
+
+ # Record weight-grad error stats (small, so exact p99 over the full vector).
+ dw_acc_ours = ErrorStatsAccumulator(total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel()))
+ dw_acc_ours.update(dw_oink, dw_ref)
+ stats.update(error_stats_to_row("ours_err_dw", dw_acc_ours.finalize()))
+ if dw_quack is not None:
+ dw_acc_quack = ErrorStatsAccumulator(
+ total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel())
+ )
+ dw_acc_quack.update(dw_quack, dw_ref)
+ stats.update(error_stats_to_row("quack_err_dw", dw_acc_quack.finalize()))
+
+ assert db_oink is None and db_quack is None
+ assert dres_oink is None and dres_quack is None
+ return stats
+
+
+def bench_single(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ weight_dtype: torch.dtype,
+ iters_ms: int,
+ eps: float,
+ warmup_ms: int,
+ verify: bool,
+) -> Tuple[Result, Result | None, dict[str, object]]:
+ device = torch.device("cuda")
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ w = torch.randn(N, device=device, dtype=weight_dtype)
+ dout = torch.randn(M, N, device=device, dtype=dtype)
+ # rstd is fp32 per row; compute once outside the timed region.
+ with torch.no_grad():
+ xf = x.float()
+ rstd = torch.rsqrt(xf.square().mean(dim=-1) + eps).to(torch.float32)
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(x, w, dout, rstd, has_bias=False, has_residual=False)
+
+ fn_oink = lambda: oink_rmsnorm.rmsnorm_backward(
+ x,
+ w,
+ dout,
+ rstd,
+ dresidual_out=None,
+ has_bias=False,
+ has_residual=False,
+ )
+
+ ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ bytes_io = bytes_io_model_bwd(M, N, dtype, weight_dtype=w.dtype)
+ gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
+ ours = Result(ms=ms_oink, gbps=gbps_oink)
+
+ if quack_rmsnorm_bwd is None:
+ return ours, None, stats
+
+ fn_quack = lambda: quack_rmsnorm_bwd(
+ x,
+ w,
+ dout,
+ rstd,
+ dresidual_out=None,
+ has_bias=False,
+ has_residual=False,
+ )
+ ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
+ return ours, Result(ms=ms_quack, gbps=gbps_quack), stats
+
+
+def main() -> None:
+ if not torch.cuda.is_available():
+ raise SystemExit("CUDA not available")
+
+ torch.cuda.set_device(0)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
+
+ p = argparse.ArgumentParser()
+ p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
+ p.add_argument(
+ "--weight-dtype",
+ type=str,
+ default="fp32",
+ choices=["same", "fp16", "bf16", "fp32"],
+ help="RMSNorm weight dtype. `same` matches activation dtype.",
+ )
+ p.add_argument("--eps", type=float, default=1e-6)
+ p.add_argument(
+ "--iters",
+ type=int,
+ default=100,
+ help="Triton do_bench rep_ms (kernel-only).",
+ )
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows")
+ p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)")
+ p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
+ p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid")
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
+ )
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs a pure-PyTorch RMSNorm backward reference)",
+ )
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+ if args.weight_dtype == "same":
+ weight_dtype = dtype
+ else:
+ weight_dtype = parse_dtype(args.weight_dtype)
+ eps = float(args.eps)
+
+ if args.quack_suite:
+ cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
+ elif args.dsv3:
+ cfgs = dsv3_configs()
+ else:
+ cfgs = parse_configs(args.configs)
+
+ hbm_peak = detect_hbm_peak_gbps(device)
+
+ rows_out: list[dict[str, object]] = []
+
+ for (M, N) in cfgs:
+ print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True)
+ ours, quack, stats = bench_single(
+ M=M,
+ N=N,
+ dtype=dtype,
+ weight_dtype=weight_dtype,
+ iters_ms=int(args.iters),
+ eps=eps,
+ warmup_ms=int(args.warmup_ms),
+ verify=not args.skip_verify,
+ )
+
+ row: dict[str, object] = {
+ "M": M,
+ "N": N,
+ "dtype": args.dtype,
+ "weight_dtype": args.weight_dtype,
+ "ours_ms": ours.ms,
+ "ours_gbps": ours.gbps,
+ "ours_tbps": ours.gbps / 1000.0,
+ "ours_hbm_frac": ours.gbps / hbm_peak,
+ }
+ if quack is not None:
+ row.update(
+ {
+ "quack_ms": quack.ms,
+ "quack_gbps": quack.gbps,
+ "quack_tbps": quack.gbps / 1000.0,
+ "speedup_vs_quack": quack.ms / ours.ms,
+ }
+ )
+ row.update(stats)
+ rows_out.append(row)
+
+ if args.csv is not None:
+ file_exists = os.path.exists(args.csv)
+ with open(args.csv, "a", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=sorted(row.keys()))
+ if not file_exists:
+ writer.writeheader()
+ writer.writerow(row)
+
+ if args.json is not None:
+ meta = collect_device_meta(device)
+ write_json(
+ args.json,
+ meta,
+ rows_out,
+ extra={
+ "method": "triton.testing.do_bench(mean)",
+ "warmup_ms": int(args.warmup_ms),
+ "rep_ms": int(args.iters),
+ "io_model_bytes": "see bytes_io_model_bwd in script",
+ "weight_dtype": str(args.weight_dtype),
+ },
+ )
+
+ # Print a small summary table.
+ headers = ["M", "N", "dtype", "ours_ms", "ours_tbps", "ours_hbm_frac"]
+ if quack_rmsnorm_bwd is not None:
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows_out:
+ parts: list[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
new file mode 100644
index 0000000..e55e9ff
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
@@ -0,0 +1,337 @@
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ detect_hbm_peak_gbps,
+ do_bench_triton,
+ error_stats_to_row,
+ ensure_oink_src_on_path,
+ iter_row_blocks,
+ parse_configs,
+ parse_dtype,
+ quack_suite_configs,
+ write_csv,
+ write_json,
+)
+
+ensure_oink_src_on_path()
+
+from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402
+
+try:
+ from quack.rmsnorm import rmsnorm_fwd as quack_rmsnorm_fwd # type: ignore
+except Exception:
+ quack_rmsnorm_fwd = None
+
+_VERIFY_TOL_Y = {
+ # Match Quack's unit-test defaults (tests/test_rmsnorm.py).
+ torch.float32: dict(atol=1e-4, rtol=1e-3),
+ torch.float16: dict(atol=1e-2, rtol=1e-3),
+ # NOTE: bf16 ulp grows with magnitude; a slightly larger rtol is more robust
+ # for the large-M suite shapes (and fused paths that can see larger values).
+ torch.bfloat16: dict(atol=1e-1, rtol=1e-2),
+}
+
+_VERIFY_TOL_RSTD = {
+ torch.float32: dict(atol=1e-5, rtol=1e-5),
+ torch.float16: dict(atol=1e-3, rtol=1e-3),
+ torch.bfloat16: dict(atol=1e-3, rtol=1e-3),
+}
+
+
+def bytes_io_model_fwd(
+ M: int, N: int, dtype: torch.dtype, *, weight_dtype: torch.dtype = torch.float32
+) -> int:
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ w_elem = torch.tensor(0, dtype=weight_dtype).element_size()
+ # Read x + write y
+ total = 2 * M * N * elem
+ # Read weight
+ total += N * w_elem
+ return int(total)
+
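+# Worked example of the IO model above (assumed shape, not a measured number):
+# M=65536, N=8192, bf16 activations (2 B), fp32 weight (4 B):
+#   2*65536*8192*2 + 8192*4 = 2,147,516,416 bytes ≈ 2.15 GB per call.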
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ # DSv3-ish hidden sizes used throughout the Oink/Quack SM100 suite tables.
+ Ms = [4096, 16384, 65536]
+ Ns = [6144, 7168, 8192]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def _verify_parity(
+ x: torch.Tensor,
+ w: torch.Tensor,
+ *,
+ eps: float,
+ store_rstd: bool,
+) -> dict[str, object]:
+ tol_y = _VERIFY_TOL_Y[x.dtype]
+ tol_rstd = _VERIFY_TOL_RSTD[x.dtype]
+ ref_block_rows = 4096
+ M = int(x.shape[0])
+ N = int(x.shape[1])
+
+ y_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ y_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd is not None else None
+
+ with torch.no_grad():
+ y_o, rstd_o, res_o = oink_rmsnorm.rmsnorm_forward(
+ x,
+ weight=w,
+ bias=None,
+ residual=None,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+ y_q = None
+ rstd_q = None
+ if quack_rmsnorm_fwd is not None:
+ # Quack returns (out, residual_out, rstd).
+ y_q, res_q, rstd_q = quack_rmsnorm_fwd(
+ x,
+ w,
+ bias=None,
+ residual=None,
+ out_dtype=None,
+ residual_dtype=None,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+
+ # Pure-PyTorch reference (float32 accumulation), chunked over rows to avoid
+ # materializing an (M, N) float32 tensor for large Quack-suite shapes.
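+    # RMSNorm reference: y = x * rsqrt(mean(x^2, dim=-1) + eps) * w, accumulated
+    # in fp32 and cast back to x.dtype.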
+ w_f32 = w.float()
+ rstd_ref = torch.empty((M,), device=x.device, dtype=torch.float32)
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ x_f32 = x[start:end].float()
+ rstd_blk = torch.rsqrt(x_f32.square().mean(dim=-1) + eps)
+ rstd_ref[start:end] = rstd_blk
+
+ y_ref_blk_f32 = (x_f32 * rstd_blk.unsqueeze(1)) * w_f32
+ y_ref_blk = y_ref_blk_f32.to(x.dtype)
+ torch.testing.assert_close(y_o[start:end], y_ref_blk, **tol_y)
+ y_acc_ours.update(y_o[start:end], y_ref_blk)
+ if y_q is not None:
+ torch.testing.assert_close(y_q[start:end], y_ref_blk, **tol_y)
+ assert y_acc_quack is not None
+ y_acc_quack.update(y_q[start:end], y_ref_blk)
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize()))
+ if y_acc_quack is not None:
+ stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize()))
+
+ if store_rstd:
+ assert rstd_o is not None
+ torch.testing.assert_close(rstd_o, rstd_ref, **tol_rstd)
+ if y_q is not None:
+ assert rstd_q is not None
+ torch.testing.assert_close(rstd_q, rstd_ref, **tol_rstd)
+ # Stats for rstd are cheap (M elements); compute exact p99 over all rows.
+        rstd_acc_ours = ErrorStatsAccumulator(
+            total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel())
+        )
+ rstd_acc_ours.update(rstd_o, rstd_ref)
+ stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize()))
+ if rstd_q is not None:
+ rstd_acc_quack = ErrorStatsAccumulator(
+ total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel())
+ )
+ rstd_acc_quack.update(rstd_q, rstd_ref)
+ stats.update(error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize()))
+ # Residual output semantics differ slightly across implementations:
+ # - Oink returns `None` when residual is None.
+ # - Quack returns `x` as a safe alias in that case.
+ #
+ # For parity we focus on `y` (and optional `rstd`) for the residual=None path.
+ assert res_o is None
+ if quack_rmsnorm_fwd is not None:
+ assert res_q is x
+ return stats
+
+
+def bench_single(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ weight_dtype: torch.dtype,
+ eps: float,
+ warmup_ms: int,
+ iters_ms: int,
+ verify: bool,
+ store_rstd: bool,
+) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]:
+ device = torch.device("cuda")
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ w = torch.randn(N, device=device, dtype=weight_dtype)
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(x, w, eps=eps, store_rstd=store_rstd)
+
+ bytes_io = bytes_io_model_fwd(M, N, dtype, weight_dtype=w.dtype)
+
+ fn_oink = lambda: oink_rmsnorm.rmsnorm_forward(
+ x,
+ weight=w,
+ bias=None,
+ residual=None,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+ ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
+
+ if quack_rmsnorm_fwd is None:
+ return (ms_oink, gbps_oink), None, stats
+
+ fn_quack = lambda: quack_rmsnorm_fwd(
+ x,
+ w,
+ bias=None,
+ residual=None,
+ out_dtype=None,
+ residual_dtype=None,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+ ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
+ return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats
+
+
+def main() -> None:
+ if not torch.cuda.is_available():
+ raise SystemExit("CUDA not available")
+
+ torch.cuda.set_device(0)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
+
+ p = argparse.ArgumentParser()
+ p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
+ p.add_argument(
+ "--weight-dtype",
+ type=str,
+ default="fp32",
+ choices=["same", "fp16", "bf16", "fp32"],
+ help="RMSNorm weight dtype. `same` matches activation dtype (vLLM-style inference).",
+ )
+ p.add_argument("--eps", type=float, default=1e-6)
+ p.add_argument("--store-rstd", action="store_true", help="Also write rstd (fp32 per row)")
+ p.add_argument("--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only).")
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows")
+ p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)")
+ p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
+ p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid")
+ p.add_argument("--dsv3", action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}")
+ p.add_argument("--skip-verify", action="store_true", help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference)")
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+ if args.weight_dtype == "same":
+ weight_dtype = dtype
+ else:
+ weight_dtype = parse_dtype(args.weight_dtype)
+ eps = float(args.eps)
+
+ if args.quack_suite:
+ cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
+ elif args.dsv3:
+ cfgs = dsv3_configs()
+ else:
+ cfgs = parse_configs(args.configs)
+
+ hbm_peak = detect_hbm_peak_gbps(device)
+ meta = collect_device_meta(device)
+
+ rows_out: List[Dict[str, Any]] = []
+ for (M, N) in cfgs:
+ print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True)
+ (ms_oink, gbps_oink), quack, stats = bench_single(
+ M=M,
+ N=N,
+ dtype=dtype,
+ weight_dtype=weight_dtype,
+ eps=eps,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ verify=not args.skip_verify,
+ store_rstd=bool(args.store_rstd),
+ )
+ row: Dict[str, Any] = {
+ "M": M,
+ "N": N,
+ "dtype": args.dtype,
+ "weight_dtype": args.weight_dtype,
+ "eps": eps,
+ "store_rstd": bool(args.store_rstd),
+ "ours_ms": ms_oink,
+ "ours_gbps": gbps_oink,
+ "ours_tbps": gbps_oink / 1000.0,
+ "ours_hbm_frac": gbps_oink / hbm_peak,
+ }
+ if quack is not None:
+ ms_q, gbps_q = quack
+ row.update(
+ {
+ "quack_ms": ms_q,
+ "quack_gbps": gbps_q,
+ "quack_tbps": gbps_q / 1000.0,
+ "speedup_vs_quack": ms_q / ms_oink,
+ }
+ )
+ row.update(stats)
+ rows_out.append(row)
+
+ if args.csv is not None:
+ write_csv(args.csv, rows_out)
+ if args.json is not None:
+ write_json(
+ args.json,
+ meta,
+ rows_out,
+ extra={
+ "method": "triton.testing.do_bench(mean)",
+ "warmup_ms": int(args.warmup_ms),
+ "rep_ms": int(args.iters),
+ "io_model_bytes": "(2*M*N)*elem_size + N*weight_elem_size",
+ "store_rstd": bool(args.store_rstd),
+ "weight_dtype": str(args.weight_dtype),
+ },
+ )
+
+ # Print a compact summary table.
+ headers = ["M", "N", "ours_ms", "ours_tbps"]
+ if quack_rmsnorm_fwd is not None:
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows_out:
+ parts: List[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
new file mode 100644
index 0000000..93c5af3
--- /dev/null
+++ b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
@@ -0,0 +1,292 @@
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+# Reduce fragmentation pressure on busy GPUs.
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+
+# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
+os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+from bench_utils import ( # noqa: E402
+ ErrorStatsAccumulator,
+ collect_device_meta,
+ detect_hbm_peak_gbps,
+ do_bench_triton,
+ error_stats_to_row,
+ ensure_oink_src_on_path,
+ iter_row_blocks,
+ parse_configs,
+ parse_dtype,
+ quack_suite_configs,
+ write_csv,
+ write_json,
+)
+
+ensure_oink_src_on_path()
+
+from kernelagent_oink.blackwell import softmax as oink_softmax # noqa: E402
+
+try:
+ from quack.softmax import softmax_bwd as quack_softmax_bwd # type: ignore
+ from quack.softmax import softmax_fwd as quack_softmax_fwd # type: ignore
+except Exception:
+ quack_softmax_fwd = None
+ quack_softmax_bwd = None
+
+_VERIFY_TOL = {
+ # Match Quack's unit-test defaults (tests/test_softmax.py).
+ torch.float32: dict(atol=1e-4, rtol=1e-4),
+ torch.float16: dict(atol=1e-3, rtol=1e-3),
+ torch.bfloat16: dict(atol=1e-2, rtol=1e-2),
+}
+
+
+def bytes_io_model_softmax(M: int, N: int, dtype: torch.dtype, *, mode: str) -> int:
+ elem = torch.tensor(0, dtype=dtype).element_size()
+ if mode == "fwd":
+ return int(2 * M * N * elem) # read x + write y
+ if mode == "bwd":
+ return int(3 * M * N * elem) # read dy + read y + write dx
+ if mode == "fwd_bwd":
+ # Logical IO for dx given (x, dy): read x + read dy + write dx.
+ # (The intermediate y=softmax(x) is an implementation detail and is
+ # intentionally not counted here.)
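+        # e.g. (illustrative) M=8192, N=4096 in bf16: 3*8192*4096*2 = 201_326_592 bytes.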
+ return int(3 * M * N * elem)
+ raise ValueError(f"Unsupported mode: {mode}")
+
+
+def dsv3_configs() -> List[Tuple[int, int]]:
+ Ms = [4096, 16384, 65536]
+ Ns = [6144, 7168, 8192]
+ return [(m, n) for m in Ms for n in Ns]
+
+
+def _verify_parity(x: torch.Tensor) -> dict[str, object]:
+ tol = _VERIFY_TOL[x.dtype]
+ ref_block_rows = 4096
+ dy = torch.randn_like(x) # upstream grad
+
+ with torch.no_grad():
+ y_o = oink_softmax.softmax_forward(x)
+ dx_o = oink_softmax.softmax_backward(dy, y_o)
+ dx_fused_o = oink_softmax.softmax_fwd_bwd(dy, x)
+
+ y_q = None
+ dx_q = None
+ if quack_softmax_fwd is not None and quack_softmax_bwd is not None:
+ y_q = quack_softmax_fwd(x)
+ dx_q = quack_softmax_bwd(dy, y_q)
+
+ M = int(x.shape[0])
+ N = int(x.shape[1])
+ y_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ dx_fused_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
+ y_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if (quack_softmax_fwd is not None and quack_softmax_bwd is not None)
+ else None
+ )
+ dx_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if (quack_softmax_fwd is not None and quack_softmax_bwd is not None)
+ else None
+ )
+
+ # Match Quack tests: compare to PyTorch softmax refs (fwd+bwd), chunked.
+ for start, end in iter_row_blocks(M, ref_block_rows):
+ x_blk = x[start:end]
+ dy_blk = dy[start:end]
+ y_ref_blk = torch.softmax(x_blk, dim=-1)
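+        # Softmax backward identity: dx = y * (dy - sum(dy * y, dim=-1, keepdim=True)).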
+ dot = torch.sum(dy_blk * y_ref_blk, dim=-1, keepdim=True, dtype=torch.float32)
+ dx_ref_blk = (dy_blk - dot.to(dy_blk.dtype)) * y_ref_blk
+
+ torch.testing.assert_close(y_o[start:end], y_ref_blk, **tol)
+ torch.testing.assert_close(dx_o[start:end], dx_ref_blk, **tol)
+ torch.testing.assert_close(dx_fused_o[start:end], dx_ref_blk, **tol)
+ y_acc_ours.update(y_o[start:end], y_ref_blk)
+ dx_acc_ours.update(dx_o[start:end], dx_ref_blk)
+ dx_fused_acc_ours.update(dx_fused_o[start:end], dx_ref_blk)
+ if y_q is not None and dx_q is not None:
+ torch.testing.assert_close(y_q[start:end], y_ref_blk, **tol)
+ torch.testing.assert_close(dx_q[start:end], dx_ref_blk, **tol)
+ assert y_acc_quack is not None and dx_acc_quack is not None
+ y_acc_quack.update(y_q[start:end], y_ref_blk)
+ dx_acc_quack.update(dx_q[start:end], dx_ref_blk)
+
+ stats: dict[str, object] = {}
+ stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize()))
+ stats.update(error_stats_to_row("ours_err_dx_fused", dx_fused_acc_ours.finalize()))
+ if y_acc_quack is not None and dx_acc_quack is not None:
+ stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize()))
+ stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize()))
+ return stats
+
+
+def bench_single(
+ M: int,
+ N: int,
+ dtype: torch.dtype,
+ *,
+ warmup_ms: int,
+ iters_ms: int,
+ mode: str,
+ verify: bool,
+) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]:
+ device = torch.device("cuda")
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ dy = torch.randn_like(x)
+
+ stats: dict[str, object] = {}
+ if verify:
+ stats = _verify_parity(x)
+
+ bytes_io = bytes_io_model_softmax(M, N, dtype, mode=mode)
+
+ if mode == "fwd":
+ fn_oink = lambda: oink_softmax.softmax_forward(x)
+ fn_quack = None if quack_softmax_fwd is None else (lambda: quack_softmax_fwd(x))
+ elif mode == "bwd":
+ with torch.no_grad():
+ y_o = oink_softmax.softmax_forward(x)
+ y_q = quack_softmax_fwd(x) if quack_softmax_fwd is not None else None
+ fn_oink = lambda: oink_softmax.softmax_backward(dy, y_o)
+ fn_quack = (
+ None
+ if (quack_softmax_bwd is None or y_q is None)
+ else (lambda: quack_softmax_bwd(dy, y_q))
+ )
+ elif mode == "fwd_bwd":
+ fn_oink = lambda: oink_softmax.softmax_fwd_bwd(dy, x)
+ fn_quack = (
+ None
+ if (quack_softmax_fwd is None or quack_softmax_bwd is None)
+ else (lambda: quack_softmax_bwd(dy, quack_softmax_fwd(x)))
+ )
+ else:
+ raise ValueError(f"Unsupported mode: {mode}")
+
+ ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
+
+ if fn_quack is None:
+ return (ms_oink, gbps_oink), None, stats
+
+ ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
+ gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
+ return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats
+
+
+def main() -> None:
+ if not torch.cuda.is_available():
+ raise SystemExit("CUDA not available")
+
+ torch.cuda.set_device(0)
+ device = torch.device("cuda")
+ props = torch.cuda.get_device_properties(device)
+ sm = props.major * 10 + props.minor
+ print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
+
+ p = argparse.ArgumentParser()
+ p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
+ p.add_argument("--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"])
+ p.add_argument("--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only).")
+ p.add_argument("--warmup-ms", type=int, default=25)
+ p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows")
+ p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)")
+ p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
+ p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid")
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
+ )
+ p.add_argument("--skip-verify", action="store_true", help="Skip correctness checks (Oink/Quack vs PyTorch softmax)")
+ args = p.parse_args()
+
+ dtype = parse_dtype(args.dtype)
+
+ if args.quack_suite:
+ cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()]
+ elif args.dsv3:
+ cfgs = dsv3_configs()
+ else:
+ cfgs = parse_configs(args.configs)
+
+ hbm_peak = detect_hbm_peak_gbps(device)
+ meta = collect_device_meta(device)
+
+ rows_out: List[Dict[str, Any]] = []
+ for (M, N) in cfgs:
+ print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", flush=True)
+ (ms_oink, gbps_oink), quack, stats = bench_single(
+ M=M,
+ N=N,
+ dtype=dtype,
+ warmup_ms=int(args.warmup_ms),
+ iters_ms=int(args.iters),
+ mode=str(args.mode),
+ verify=not args.skip_verify,
+ )
+ row: Dict[str, Any] = {
+ "M": M,
+ "N": N,
+ "dtype": args.dtype,
+ "mode": args.mode,
+ "ours_ms": ms_oink,
+ "ours_gbps": gbps_oink,
+ "ours_tbps": gbps_oink / 1000.0,
+ "ours_hbm_frac": gbps_oink / hbm_peak,
+ }
+ if quack is not None:
+ ms_q, gbps_q = quack
+ row.update(
+ {
+ "quack_ms": ms_q,
+ "quack_gbps": gbps_q,
+ "quack_tbps": gbps_q / 1000.0,
+ "speedup_vs_quack": ms_q / ms_oink,
+ }
+ )
+ row.update(stats)
+ rows_out.append(row)
+
+ if args.csv is not None:
+ write_csv(args.csv, rows_out)
+ if args.json is not None:
+ write_json(
+ args.json,
+ meta,
+ rows_out,
+ extra={
+ "method": "triton.testing.do_bench(mean)",
+ "warmup_ms": int(args.warmup_ms),
+ "rep_ms": int(args.iters),
+ "io_model_bytes": "mode-dependent: fwd=2*M*N, bwd=3*M*N, fwd_bwd=3*M*N (all * elem_size; fwd_bwd counts logical x+dy+dx)",
+ },
+ )
+
+ headers = ["M", "N", "mode", "ours_ms", "ours_tbps"]
+ if quack_softmax_fwd is not None and quack_softmax_bwd is not None:
+ headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"]
+ print("\nSummary:")
+ print(" ".join(h.rjust(14) for h in headers))
+ for r in rows_out:
+ parts: List[str] = []
+ for h in headers:
+ v = r.get(h)
+ if isinstance(v, float):
+ parts.append(f"{v:14.4f}")
+ else:
+ parts.append(f"{str(v):>14}")
+ print(" ".join(parts))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg
new file mode 100644
index 0000000..e32e3a7
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg
@@ -0,0 +1,2259 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg
new file mode 100644
index 0000000..b70ba9b
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg
@@ -0,0 +1,2600 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg
new file mode 100644
index 0000000..f5cd53c
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg
@@ -0,0 +1,2936 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg
new file mode 100644
index 0000000..db39e3c
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg
@@ -0,0 +1,1687 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg
new file mode 100644
index 0000000..e8d4cc6
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg
@@ -0,0 +1,2720 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg
new file mode 100644
index 0000000..a5670bd
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg
@@ -0,0 +1,2580 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg
new file mode 100644
index 0000000..0a021e9
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg
@@ -0,0 +1,2280 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg
new file mode 100644
index 0000000..9a58fde
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg
@@ -0,0 +1,2621 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg
new file mode 100644
index 0000000..bff56d0
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg
@@ -0,0 +1,2957 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg
new file mode 100644
index 0000000..6a16fe8
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg
@@ -0,0 +1,1708 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg
new file mode 100644
index 0000000..242d013
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg
@@ -0,0 +1,2741 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg
new file mode 100644
index 0000000..dac54ac
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg
@@ -0,0 +1,2601 @@
+<!-- SVG markup omitted from this patch excerpt -->
diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py
new file mode 100644
index 0000000..1799f2e
--- /dev/null
+++ b/oink/benchmarks/readme/plot_quack_style_svg.py
@@ -0,0 +1,431 @@
+from __future__ import annotations
+
+"""
+Generate Quack-style SVG performance plots (Oink vs Quack) from the SM100 suite
+JSON artifacts under `/tmp/kernelagent_oink_sm100_suite_{bf16,fp16}`.
+
+The intent is to match Quack's README visual style:
+ - one horizontal row of panels (panel count depends on the suite):
+ - Quack-suite: RMSNorm / Softmax / CrossEntropy
+ - DSv3 (hidden-size): Fused Add+RMSNorm / Softmax / LayerNorm
+ - DSv3 (all ops, 4-panel): Fused Add+RMSNorm / Softmax / LayerNorm / CrossEntropy
+ - DSv3 CrossEntropy: CrossEntropy-only (single panel)
+ - y-axis: model memory bandwidth (GB/s) derived from an IO model
+ - x-axis: a small set of labeled (M, N) shape points
+ - thick lines + markers, dashed y-grid, compact legend
+ - optional horizontal roofline line (measured STREAM-like HBM peak)
+
+Example:
+ python oink/benchmarks/readme/plot_quack_style_svg.py \\
+ --in-dir /tmp/kernelagent_oink_sm100_suite_bf16 \\
+ --suite quack_suite \\
+ --roofline-json /tmp/hbm_roofline_sm100_bf16.json \\
+ --out oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg
+
+For completeness, we can also include LayerNorm as an extra panel (Quack's
+own README plot does not include LayerNorm):
+ python oink/benchmarks/readme/plot_quack_style_svg.py \\
+ --in-dir /tmp/kernelagent_oink_sm100_suite_bf16 \\
+ --suite quack_suite \\
+ --include-layernorm \\
+ --roofline-json /tmp/hbm_roofline_sm100_bf16.json \\
+ --out oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg
+
+Note on DSv3 suite:
+- The DSv3 plot intentionally covers only the hidden-size ops (fused Add+RMSNorm,
+ Softmax, LayerNorm) which share the same `(M, N)` sweep.
+- CrossEntropy in DSv3 uses a vocab-size-like `N` sweep and is plotted separately
+ via `--suite dsv3_cross_entropy` to avoid a mixed x-axis with gaps.
+- For README embedding convenience, `--suite dsv3_all` renders a 4-panel
+ single-row figure where the CrossEntropy panel uses its own x-axis.
+- The RMSNorm panel uses the real block primitive (fused residual-add + RMSNorm)
+ when available: `fused_add_rmsnorm_dsv3.json`.
+"""
+
+import argparse
+import json
+import math
+import os
+from collections import defaultdict
+from statistics import median
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
+
+
+def _load_json(path: str) -> Dict[str, Any]:
+ with open(path) as f:
+ return json.load(f)
+
+
+def _fmt_k(v: int) -> str:
+ # Match Quack's x-axis labels: "32K" means 32768 (1024-based).
+ if v % 1024 == 0:
+ return f"{v // 1024}K"
+ return str(v)
+
+
+def _shape_label(m: int, n: int) -> str:
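+    # e.g. _fmt_k(32768) == "32K", so _shape_label(65536, 7168) == "(64K, 7K)".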
+ return f"({_fmt_k(m)}, {_fmt_k(n)})"
+
+
+def _gbps_from_row(prefix: str, row: Mapping[str, Any]) -> Optional[float]:
+ # Prefer GB/s in the JSON if present; otherwise fall back to TB/s.
+ gbps_key = f"{prefix}_gbps"
+ tbps_key = f"{prefix}_tbps"
+ if gbps_key in row and row[gbps_key] is not None:
+ return float(row[gbps_key])
+ if tbps_key in row and row[tbps_key] is not None:
+ return float(row[tbps_key]) * 1000.0
+ return None
+
+
+def _aggregate_by_shape(rows: Sequence[Mapping[str, Any]]) -> Dict[Tuple[int, int], Dict[str, float]]:
+ """Aggregate duplicate (M, N) rows using median (more robust than mean)."""
+ buckets: dict[tuple[int, int], dict[str, list[float]]] = defaultdict(
+ lambda: defaultdict(list)
+ )
+ for r in rows:
+ m = int(r["M"])
+ n = int(r["N"])
+ ours = _gbps_from_row("ours", r)
+ quack = _gbps_from_row("quack", r)
+ if ours is not None:
+ buckets[(m, n)]["ours"].append(ours)
+ if quack is not None:
+ buckets[(m, n)]["quack"].append(quack)
+
+ out: Dict[Tuple[int, int], Dict[str, float]] = {}
+ for k, vs in buckets.items():
+ if not vs["ours"] or not vs["quack"]:
+ continue
+ out[k] = dict(ours=float(median(vs["ours"])), quack=float(median(vs["quack"])))
+ return out
+
+
+def _sort_shapes(shapes: Iterable[Tuple[int, int]]) -> List[Tuple[int, int]]:
+ # Sort by N then M to keep the x-axis stable across panels.
+ return sorted(set(shapes), key=lambda x: (x[1], x[0]))
+
+
+def _read_roofline_gbps(path: str) -> float:
+ payload = _load_json(path)
+ rows = payload.get("rows", [])
+ best_tbps = max(float(r["tbps"]) for r in rows)
+ return best_tbps * 1000.0
+
+
+def _ensure_matplotlib():
+ try:
+ import matplotlib as mpl # noqa: F401
+ import matplotlib.pyplot as plt # noqa: F401
+ except Exception as e: # pragma: no cover
+ raise SystemExit(
+ "matplotlib is required to generate SVG plots.\n"
+ "Install with: `python -m pip install matplotlib`"
+ ) from e
+
+
+def _plot(
+ *,
+ panels: Sequence[Tuple[str, Dict[Tuple[int, int], Dict[str, float]]]],
+ roofline_gbps: Optional[float],
+ out_path: str,
+ title: str,
+ shape_policy: str,
+ per_panel_x: bool,
+) -> None:
+ _ensure_matplotlib()
+ import matplotlib as mpl
+ import matplotlib.pyplot as plt
+
+ mpl.rcParams.update(
+ {
+ # Quack-style: embed glyphs as paths for consistent rendering.
+ "svg.fonttype": "path",
+ "font.family": "DejaVu Sans",
+ "axes.titlesize": 18,
+ "axes.labelsize": 16,
+ "xtick.labelsize": 10,
+ "ytick.labelsize": 12,
+ }
+ )
+
+ # Colors roughly matching Quack's SVG palette.
+ COLOR_OINK = "#5ba3f5"
+ COLOR_QUACK = "#ff4444"
+ COLOR_ROOF = "#4d4d4d"
+
+ fig, axes = plt.subplots(
+ nrows=1,
+ ncols=len(panels),
+ figsize=(6.0 * len(panels), 5.6),
+ constrained_layout=False,
+ sharey=True,
+ )
+ if len(panels) == 1:
+ axes = [axes]
+
+ max_y = 0.0
+ for ax, (panel_title, data) in zip(axes, panels):
+ if per_panel_x:
+ shapes = _sort_shapes(data.keys())
+ else:
+ # Quack-style plots use a single shared x-axis across panels. Prefer
+ # the intersection so every panel has a value at every x tick
+ # (cleaner than rendering gaps), and fall back to the union if the
+ # intersection is empty.
+ shape_sets = [set(d.keys()) for _n, d in panels]
+ if shape_policy in {"first", "primary"}:
+ shapes = _sort_shapes(shape_sets[0]) if shape_sets else []
+ elif shape_policy == "intersection" and shape_sets:
+ common = set.intersection(*shape_sets)
+ shapes = _sort_shapes(common) if common else []
+ elif shape_policy == "union":
+ shapes = _sort_shapes(s for _n, d in panels for s in d.keys())
+ else:
+ raise ValueError(f"Unsupported shape_policy: {shape_policy}")
+ if not shapes:
+ shapes = _sort_shapes(s for _n, d in panels for s in d.keys())
+
+ x = list(range(len(shapes)))
+ x_labels = [_shape_label(m, n) for (m, n) in shapes]
+
+ ours_y: List[float] = []
+ quack_y: List[float] = []
+ for s in shapes:
+ rec = data.get(s)
+ if rec is None: # only possible in shared-x mode with union
+ ours_y.append(math.nan)
+ quack_y.append(math.nan)
+ continue
+ ours_y.append(float(rec["ours"]))
+ quack_y.append(float(rec["quack"]))
+        max_y = max(
+            max_y,
+            *(v for v in ours_y if math.isfinite(v)),
+            *(v for v in quack_y if math.isfinite(v)),
+        )
+
+ ax.plot(
+ x,
+ ours_y,
+ marker="o",
+ linewidth=5,
+ markersize=7,
+ color=COLOR_OINK,
+ label="KernelAgent-Oink (ours)",
+ )
+ ax.plot(
+ x,
+ quack_y,
+ marker="o",
+ linewidth=5,
+ markersize=7,
+ color=COLOR_QUACK,
+ label="Quack",
+ )
+ if roofline_gbps is not None:
+ ax.axhline(
+ roofline_gbps,
+ color=COLOR_ROOF,
+ linewidth=3,
+ linestyle=(0, (4, 6)),
+ label="HBM peak (measured)" if ax is axes[0] else None,
+ )
+ max_y = max(max_y, float(roofline_gbps))
+
+ ax.set_title(panel_title)
+ ax.set_xticks(x)
+ ax.set_xticklabels(x_labels, rotation=-45, ha="left")
+ if per_panel_x:
+ # DSv3 "all ops" figure: each panel has its own x-axis. Make the
+ # semantics explicit so readers don't assume the same `N` meaning
+ # across panels (CrossEntropy uses a classes/vocab-shard-like axis).
+ if "cross" in panel_title.lower():
+ ax.set_xlabel("Shape (M, C classes)")
+ else:
+ ax.set_xlabel("Shape (M, N hidden)")
+
+ # Quack-like dashed y-grid.
+ ax.grid(axis="y", linestyle=(0, (4, 7.2)), linewidth=0.8, color="#b0b0b0")
+ ax.set_axisbelow(True)
+
+ # Light spines (Quack SVG uses a light gray frame).
+ for spine in ax.spines.values():
+ spine.set_color("#d3d3d3")
+ spine.set_linewidth(1.5)
+
+ axes[0].set_ylabel("Memory Bandwidth (GB/s)")
+
+ # A little headroom above the tallest curve/roofline.
+ ymax = max_y * 1.08 if max_y > 0 else 1.0
+ for ax in axes:
+ ax.set_ylim(0.0, ymax)
+
+ # Tight layout for the axes area, reserving headroom for the suptitle and a
+ # shared legend. In some matplotlib versions, figure-level legends can
+ # overlap the middle panel title unless we reserve a slightly taller header
+ # band.
+ fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.70))
+
+ # Single shared legend across the top (like Quack), but keep it inside the
+ # reserved header band so it doesn't overlap the middle panel title.
+ handles, labels = axes[0].get_legend_handles_labels()
+ # Quack's legend fits nicely in one row because their plots are 3-panel and
+ # therefore wide. For single-panel figures, a 3-column legend can overflow
+ # the canvas and get clipped in the SVG, so we stack it vertically.
+ legend_ncol = min(3, len(labels))
+ legend_fontsize = 13
+ if len(panels) == 1:
+ legend_ncol = 1
+ legend_fontsize = 12
+ fig.legend(
+ handles,
+ labels,
+ loc="upper center",
+ ncol=legend_ncol,
+ frameon=False,
+ bbox_to_anchor=(0.5, 0.91),
+ fontsize=legend_fontsize,
+ handlelength=2.5,
+ )
+ # Single-panel figures (e.g. DSv3 CrossEntropy) are much narrower than the
+ # Quack-style 3-panel plots; use a slightly smaller suptitle font to avoid
+ # clipping in the exported SVG.
+ suptitle_fs = 22 if len(panels) > 1 else 18
+ fig.suptitle(title, y=0.98, fontsize=suptitle_fs)
+
+ out_path = os.path.abspath(out_path)
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ # Use a tight bounding box so rotated x tick labels and the figure-level
+ # legend don't get clipped in SVG exports (matplotlib can be fragile here
+ # across versions).
+ fig.savefig(out_path, format="svg", bbox_inches="tight", pad_inches=0.02)
+ plt.close(fig)
+
+
+def _panel_files_for_suite(suite: str) -> List[Tuple[str, str]]:
+ if suite == "quack_suite":
+ return [
+ ("RMSNorm (fp32 weight)", "rmsnorm_fwd_quack_suite_wfp32.json"),
+ ("Softmax (fwd+bwd)", "softmax_fwd_bwd_quack_suite.json"),
+ ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_quack_suite.json"),
+ ]
+ if suite == "dsv3":
+ return [
+ ("Fused Add+RMSNorm (fwd)", "fused_add_rmsnorm_dsv3.json"),
+ ("Softmax (fwd+bwd)", "softmax_fwd_bwd_dsv3.json"),
+ ("LayerNorm (fwd)", "layernorm_fwd_dsv3.json"),
+ ]
+ if suite == "dsv3_all":
+ return [
+ ("Fused Add+RMSNorm (fwd)", "fused_add_rmsnorm_dsv3.json"),
+ ("Softmax (fwd+bwd)", "softmax_fwd_bwd_dsv3.json"),
+ ("LayerNorm (fwd)", "layernorm_fwd_dsv3.json"),
+ ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_dsv3.json"),
+ ]
+ if suite == "dsv3_cross_entropy":
+ return [
+ ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_dsv3.json"),
+ ]
+ raise ValueError(f"Unsupported suite: {suite}")
+
+
+def _layernorm_file_for_suite(suite: str) -> str:
+ if suite == "quack_suite":
+ return "layernorm_fwd_quack_suite.json"
+ raise ValueError(f"Unsupported suite: {suite}")
+
+
+def main() -> None:
+ p = argparse.ArgumentParser(
+ description="Generate Quack-style SVG plots from KernelAgent-Oink suite JSONs."
+ )
+ p.add_argument(
+ "--in-dir", type=str, required=True, help="Directory containing suite JSON outputs"
+ )
+ p.add_argument(
+ "--suite",
+ type=str,
+ default="quack_suite",
+ choices=["quack_suite", "dsv3", "dsv3_all", "dsv3_cross_entropy"],
+ )
+ p.add_argument(
+ "--include-layernorm",
+ action="store_true",
+ help="Add a LayerNorm (fwd) panel (only meaningful for `--suite quack_suite`).",
+ )
+ p.add_argument(
+ "--shape-policy",
+ type=str,
+ default="intersection",
+ choices=["intersection", "union", "first"],
+ help=(
+ "How to pick x-axis shapes across panels. "
+ "`intersection` matches Quack-style (only shapes common to every panel). "
+ "`first` uses the first panel's shapes (keeps DSv3 N=7168 visible). "
+ "`union` includes every shape across panels (may create gaps)."
+ ),
+ )
+ p.add_argument("--roofline-json", type=str, default=None, help="Optional /tmp/hbm_roofline_sm100_*.json path")
+ p.add_argument("--out", type=str, required=True, help="Output SVG path")
+ p.add_argument("--title", type=str, default=None, help="Optional figure title override")
+ args = p.parse_args()
+
+ in_dir = os.path.abspath(args.in_dir)
+ if not os.path.isdir(in_dir):
+ raise SystemExit(f"--in-dir is not a directory: {in_dir}")
+
+ roofline_gbps = _read_roofline_gbps(args.roofline_json) if args.roofline_json else None
+
+ panel_files = list(_panel_files_for_suite(str(args.suite)))
+ if args.include_layernorm:
+ if args.suite != "quack_suite":
+ raise SystemExit("--include-layernorm is only supported for `--suite quack_suite`.")
+ panel_files.append(("LayerNorm (fwd)", _layernorm_file_for_suite(str(args.suite))))
+
+ panels: List[Tuple[str, Dict[Tuple[int, int], Dict[str, float]]]] = []
+ for panel_title, filename in panel_files:
+ path = os.path.join(in_dir, filename)
+ if not os.path.exists(path):
+ raise SystemExit(f"Missing required JSON: {path}")
+ payload = _load_json(path)
+ rows = payload.get("rows", [])
+ if not isinstance(rows, list):
+ rows = []
+ panels.append((panel_title, _aggregate_by_shape(rows)))
+
+ if args.title is not None:
+ title = str(args.title)
+ else:
+ # Try to infer dtype from the first panel's JSON.
+ first_json = os.path.join(in_dir, panel_files[0][1])
+ payload = _load_json(first_json)
+ rows = payload.get("rows", [])
+ dtype = rows[0].get("dtype", "") if rows else ""
+ if args.suite == "quack_suite":
+ suite_name = "Quack-suite"
+ elif args.suite == "dsv3":
+ suite_name = "DSv3 (hidden-size ops)"
+ elif args.suite == "dsv3_all":
+ suite_name = "DSv3 (4 ops)"
+ elif args.suite == "dsv3_cross_entropy":
+ # Keep this short: this suite is rendered as a single panel, so the
+ # figure is much narrower than the 3-panel plots.
+ suite_name = "DSv3 CrossEntropy"
+ else:
+ suite_name = str(args.suite)
+ suffix = " (+LayerNorm)" if (args.suite == "quack_suite" and args.include_layernorm) else ""
+ if args.suite == "dsv3_cross_entropy":
+ title = f"SM100 {dtype.upper()} — {suite_name}{suffix}"
+ else:
+ title = f"SM100 {dtype.upper()} Kernel Benchmarks (Oink vs Quack) — {suite_name}{suffix}"
+
+ _plot(
+ panels=panels,
+ roofline_gbps=roofline_gbps,
+ out_path=str(args.out),
+ title=title,
+ shape_policy=str(args.shape_policy),
+ per_panel_x=(str(args.suite) == "dsv3_all"),
+ )
+ print(f"Wrote: {os.path.abspath(args.out)}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py
new file mode 100644
index 0000000..5ac1091
--- /dev/null
+++ b/oink/benchmarks/readme/run_sm100_suite.py
@@ -0,0 +1,302 @@
+from __future__ import annotations
+
+import argparse
+import os
+import subprocess
+import sys
+from datetime import datetime
+from typing import List, Tuple
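+
+# Typical invocation (illustrative; the output directory is just an example):
+#   python oink/benchmarks/readme/run_sm100_suite.py --dtype bf16 \
+#       --out-dir /tmp/kernelagent_oink_sm100_suite_bf16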
+
+
+def _ts() -> str:
+ return datetime.now().strftime("%Y%m%d_%H%M%S")
+
+
+def _run(cmd: List[str], *, dry_run: bool) -> None:
+ print("+", " ".join(cmd), flush=True)
+ if dry_run:
+ return
+ subprocess.run(cmd, check=True)
+
+
+def main() -> None:
+ p = argparse.ArgumentParser()
+ p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
+ p.add_argument(
+ "--out-dir",
+ type=str,
+ default=None,
+ help="Directory to write JSON outputs (default: /tmp/kernelagent_oink_sm100_suite_)",
+ )
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs PyTorch / pure-PyTorch references)",
+ )
+ p.add_argument("--dry-run", action="store_true", help="Print commands without executing them")
+ args = p.parse_args()
+
+ # Standardize env for standalone runs outside the vLLM plugin.
+ os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
+ os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
+
+ out_dir = args.out_dir or f"/tmp/kernelagent_oink_sm100_suite_{_ts()}"
+ os.makedirs(out_dir, exist_ok=True)
+
+ here = os.path.dirname(os.path.abspath(__file__))
+ bench_dir = os.path.abspath(os.path.join(here, "..", "benchmark"))
+ py = sys.executable
+
+ def script(name: str) -> str:
+ return os.path.join(bench_dir, name)
+
+ common = ["--dtype", args.dtype]
+ if args.skip_verify:
+ common = [*common, "--skip-verify"]
+
+ runs: List[Tuple[str, List[str]]] = [
+ (
+ "rmsnorm_fwd_quack_suite_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--quack-suite",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wfp32.json"),
+ ],
+ ),
+ (
+ "rmsnorm_fwd_dsv3_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--dsv3",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_dsv3_wfp32.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_quack_suite_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--quack-suite",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wfp32.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_dsv3_wfp32",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "fp32",
+ "--dsv3",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_dsv3_wfp32.json"),
+ ],
+ ),
+ # vLLM inference-style RMSNorm (weight dtype == activation dtype).
+ (
+ "rmsnorm_fwd_quack_suite_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--quack-suite",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wsame.json"),
+ ],
+ ),
+ (
+ "rmsnorm_fwd_dsv3_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--dsv3",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_fwd_dsv3_wsame.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_quack_suite_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--quack-suite",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wsame.json"),
+ ],
+ ),
+ (
+ "rmsnorm_bwd_dsv3_wsame",
+ [
+ py,
+ script("benchmark_rmsnorm_bwd_sm100.py"),
+ *common,
+ "--weight-dtype",
+ "same",
+ "--dsv3",
+ "--iters",
+ "100",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "rmsnorm_bwd_dsv3_wsame.json"),
+ ],
+ ),
+ (
+ "softmax_fwd_bwd_quack_suite",
+ [
+ py,
+ script("benchmark_softmax_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--quack-suite",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "softmax_fwd_bwd_quack_suite.json"),
+ ],
+ ),
+ (
+ "softmax_fwd_bwd_dsv3",
+ [
+ py,
+ script("benchmark_softmax_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--dsv3",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "softmax_fwd_bwd_dsv3.json"),
+ ],
+ ),
+ (
+ "cross_entropy_fwd_bwd_quack_suite",
+ [
+ py,
+ script("benchmark_cross_entropy_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--quack-suite",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "cross_entropy_fwd_bwd_quack_suite.json"),
+ ],
+ ),
+ (
+ "cross_entropy_fwd_bwd_dsv3",
+ [
+ py,
+ script("benchmark_cross_entropy_sm100.py"),
+ *common,
+ "--mode",
+ "fwd_bwd",
+ "--dsv3",
+ "--iters",
+ "50",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "cross_entropy_fwd_bwd_dsv3.json"),
+ ],
+ ),
+ (
+ "layernorm_fwd_quack_suite",
+ [
+ py,
+ script("benchmark_layernorm_sm100.py"),
+ *common,
+ "--quack-suite",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "layernorm_fwd_quack_suite.json"),
+ ],
+ ),
+ (
+ "layernorm_fwd_dsv3",
+ [
+ py,
+ script("benchmark_layernorm_sm100.py"),
+ *common,
+ "--dsv3",
+ "--iters",
+ "200",
+ "--warmup-ms",
+ "25",
+ "--json",
+ os.path.join(out_dir, "layernorm_fwd_dsv3.json"),
+ ],
+ ),
+ ]
+
+ print(f"Writing results to: {out_dir}", flush=True)
+ for name, cmd in runs:
+ print(f"\n== {name} ==", flush=True)
+ _run(cmd, dry_run=bool(args.dry_run))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/benchmarks/readme/summarize_results.py b/oink/benchmarks/readme/summarize_results.py
new file mode 100644
index 0000000..70782dd
--- /dev/null
+++ b/oink/benchmarks/readme/summarize_results.py
@@ -0,0 +1,205 @@
+from __future__ import annotations
+
+import argparse
+import json
+import math
+import os
+from typing import Any, Dict, Iterable, List, Optional, Sequence
+
+
+def _load_json(path: str) -> Dict[str, Any]:
+ with open(path) as f:
+ return json.load(f)
+
+
+def _fmt_cell(v: object) -> str:
+ if v is None:
+ return ""
+ if isinstance(v, float):
+ if math.isfinite(v):
+ av = abs(v)
+ # Use scientific notation for very small values so we don't render
+ # meaningful error stats as "0.0000".
+ if av != 0.0 and av < 1e-3:
+ return f"{v:.2e}"
+ return f"{v:.4f}"
+ return str(v)
+ return str(v)
+
+
+def _md_table(rows: Sequence[Dict[str, Any]], columns: Sequence[str]) -> str:
+ header = "| " + " | ".join(columns) + " |"
+ sep = "|" + "|".join(["---"] * len(columns)) + "|"
+ lines = [header, sep]
+ for r in rows:
+ lines.append("| " + " | ".join(_fmt_cell(r.get(c)) for c in columns) + " |")
+ return "\n".join(lines)
+
+
+def _pick_columns(rows: Sequence[Dict[str, Any]]) -> List[str]:
+ preferred = [
+ "M",
+ "N",
+ "dtype",
+ "weight_dtype",
+ "mode",
+ "eps",
+ "store_rstd",
+ "return_rstd",
+ "return_mean",
+ "ignore_index",
+ "ours_ms",
+ "ours_tbps",
+ "ours_hbm_frac",
+ "quack_ms",
+ "quack_tbps",
+ "speedup_vs_quack",
+ ]
+ present = set().union(*(r.keys() for r in rows)) if rows else set()
+ cols = [c for c in preferred if c in present]
+ # Fall back to a stable sorted view if we missed everything (shouldn't happen).
+ return cols or sorted(present)
+
+
+def _geomean(values: Iterable[float]) -> Optional[float]:
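+    # Geometric mean of the positive, finite entries; e.g. [2.0, 0.5] -> 1.0.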
+ logs: List[float] = []
+ for v in values:
+ if v <= 0 or not math.isfinite(v):
+ continue
+ logs.append(math.log(v))
+ if not logs:
+ return None
+ return math.exp(sum(logs) / len(logs))
+
+
+def _collect_error_prefixes(rows: Sequence[Dict[str, Any]]) -> List[str]:
+ """Infer error-stat prefixes like `ours_err_dx` from row keys."""
+ prefixes: set[str] = set()
+ for r in rows:
+ for k in r.keys():
+ if not isinstance(k, str):
+ continue
+ if not k.endswith("_max_abs"):
+ continue
+ if "err_" not in k:
+ continue
+ prefixes.add(k[: -len("_max_abs")])
+ return sorted(prefixes)
+
+
+def _summarize_error_stats(rows: Sequence[Dict[str, Any]]) -> str:
+ prefixes = _collect_error_prefixes(rows)
+ if not prefixes:
+ return ""
+
+ out_rows: List[Dict[str, Any]] = []
+ for pfx in prefixes:
+ # Per-prefix worst-case across rows.
+ max_abs_vals = [float(r[pfx + "_max_abs"]) for r in rows if (pfx + "_max_abs") in r]
+ p99_abs_vals = [float(r[pfx + "_p99_abs"]) for r in rows if (pfx + "_p99_abs") in r]
+ rel_l2_vals = [float(r[pfx + "_rel_l2"]) for r in rows if (pfx + "_rel_l2") in r]
+ if not max_abs_vals and not p99_abs_vals and not rel_l2_vals:
+ continue
+ out_rows.append(
+ {
+ "metric": pfx,
+ "max_abs (max over shapes)": max(max_abs_vals) if max_abs_vals else None,
+ "p99_abs (max over shapes)": max(p99_abs_vals) if p99_abs_vals else None,
+ "rel_l2 (max over shapes)": max(rel_l2_vals) if rel_l2_vals else None,
+ }
+ )
+
+ if not out_rows:
+ return ""
+
+ cols = ["metric", "max_abs (max over shapes)", "p99_abs (max over shapes)", "rel_l2 (max over shapes)"]
+ return "\n".join(["", "### Error Stats (vs PyTorch ref)", "", _md_table(out_rows, cols), ""])
+
+
+def summarize_one(path: str) -> str:
+ payload = _load_json(path)
+ meta = payload.get("meta", {})
+ rows = payload.get("rows", [])
+ if not isinstance(rows, list):
+ rows = []
+
+ cols = _pick_columns(rows)
+ parts: List[str] = []
+
+ base = os.path.basename(path)
+ parts.append(f"## `{base}`")
+ if meta:
+ device = meta.get("device")
+ cap = meta.get("capability")
+ torch_ver = meta.get("torch")
+ cuda_ver = meta.get("cuda")
+ git_sha = meta.get("git_sha")
+ ts = meta.get("timestamp")
+ parts.append("")
+ parts.append(
+ f"- device: `{device}` | capability: `{cap}` | torch: `{torch_ver}` | cuda: `{cuda_ver}` | git_sha: `{git_sha}` | timestamp: `{ts}`"
+ )
+ method = meta.get("method")
+ if method is not None:
+ parts.append(f"- method: `{method}`")
+ if meta.get("warmup_ms") is not None and meta.get("rep_ms") is not None:
+ parts.append(f"- warmup_ms: `{meta.get('warmup_ms')}` | rep_ms: `{meta.get('rep_ms')}`")
+
+ if rows:
+ parts.append("")
+ parts.append(_md_table(rows, cols))
+
+ speeds = [float(r["speedup_vs_quack"]) for r in rows if "speedup_vs_quack" in r]
+ gm = _geomean(speeds)
+ if gm is not None:
+ parts.append("")
+ parts.append(f"- geomean speedup vs Quack: `{gm:.3f}x` (over {len(speeds)} shapes)")
+
+ err_block = _summarize_error_stats(rows)
+ if err_block:
+ parts.append(err_block.rstrip())
+ else:
+ parts.append("")
+ parts.append("_No rows found in JSON._")
+
+ parts.append("")
+ return "\n".join(parts)
+
+
+def main() -> None:
+ p = argparse.ArgumentParser(description="Summarize KernelAgent-Oink benchmark JSONs into Markdown tables.")
+ p.add_argument("--in-dir", type=str, required=True, help="Directory containing benchmark JSON files")
+ p.add_argument("--out", type=str, default=None, help="Optional output markdown path (default: stdout)")
+ args = p.parse_args()
+
+ in_dir = os.path.abspath(args.in_dir)
+ if not os.path.isdir(in_dir):
+ raise SystemExit(f"--in-dir is not a directory: {in_dir}")
+
+ json_paths = sorted(
+ os.path.join(in_dir, name) for name in os.listdir(in_dir) if name.endswith(".json")
+ )
+ if not json_paths:
+ raise SystemExit(f"No .json files found under: {in_dir}")
+
+ out_parts: List[str] = []
+ out_parts.append("# KernelAgent-Oink SM100 Benchmark Summary")
+ out_parts.append("")
+ out_parts.append(f"Input directory: `{in_dir}`")
+ out_parts.append("")
+ for path in json_paths:
+ out_parts.append(summarize_one(path))
+
+ text = "\n".join(out_parts).rstrip() + "\n"
+ if args.out is None:
+ print(text, end="")
+ return
+
+ out_path = os.path.abspath(args.out)
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ with open(out_path, "w") as f:
+ f.write(text)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/oink/pyproject.toml b/oink/pyproject.toml
index a9ec306..0d19d6e 100644
--- a/oink/pyproject.toml
+++ b/oink/pyproject.toml
@@ -5,11 +5,26 @@ build-backend = "setuptools.build_meta"
[project]
name = "kernelagent-oink"
version = "0.1.0"
-description = "vLLM plugin that registers Oink Blackwell RMSNorm custom ops"
+description = "CuTeDSL kernels for Blackwell (SM100), shipped as a vLLM plugin"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "Apache-2.0"}
authors = [{name = "PyTorch Labs"}]
+keywords = ["cuda", "cutlass", "cute", "cutedsl", "blackwell", "sm100", "vllm"]
+classifiers = [
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3 :: Only",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+
+[project.urls]
+Repository = "https://github.com/meta-pytorch/KernelAgent"
+Documentation = "https://github.com/meta-pytorch/KernelAgent/tree/main/oink"
+Issues = "https://github.com/meta-pytorch/KernelAgent/issues"
# Keep dependencies minimal, but include the CuTeDSL stack required by the
# Blackwell RMSNorm implementation.
@@ -21,6 +36,13 @@ dependencies = [
"cuda-python",
]
+[project.optional-dependencies]
+# Optional extras for running the in-repo benchmark suite (not needed for vLLM integration).
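+# Install with, e.g.: pip install -e "./oink[bench]"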
+bench = [
+ "matplotlib",
+ "triton",
+]
+
[project.entry-points."vllm.general_plugins"]
oink = "kernelagent_oink:register"
diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py
index bbbd7c1..f5c36a6 100644
--- a/oink/src/kernelagent_oink/__init__.py
+++ b/oink/src/kernelagent_oink/__init__.py
@@ -12,6 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+KernelAgent-Oink: SM100 CuTeDSL kernels + optional vLLM plugin.
+
+This package can be loaded as a vLLM "general plugin" (entrypoint group
+`vllm.general_plugins`). In that mode it registers Oink custom ops only when
+explicitly enabled via an environment variable (so installing the package does
+not change behavior by default).
+
+For standalone usage (outside vLLM), call `kernelagent_oink.register(force=True)`
+to register the custom ops explicitly.
+"""
+
from __future__ import annotations
import logging
@@ -48,11 +60,14 @@ def _compute_cutedsl_arch(major: int, minor: int) -> str:
return f"sm_{major}{minor}{suffix}"
-def register() -> None:
- """vLLM plugin entrypoint.
+def register(*, force: bool = False) -> None:
+ """Register Oink torch custom ops.
+
+ - vLLM plugin mode (default): no-op unless `VLLM_USE_OINK_RMSNORM` is truthy.
+ - Standalone mode: pass `force=True` to register explicitly.
- This function must be safe to call multiple times and must not raise.
- vLLM executes it in multiple processes (engine + workers).
+ This function must be safe to call multiple times and must not raise. vLLM
+ executes it in multiple processes (engine + workers).
"""
global _OPS_REGISTERED
@@ -60,8 +75,9 @@ def register() -> None:
return
# Gate on the vLLM integration flag so installing the package does not
- # change behavior unless explicitly enabled.
- if not _env_truthy("VLLM_USE_OINK_RMSNORM"):
+ # change behavior unless explicitly enabled. For standalone usage (outside
+ # vLLM), callers can pass force=True to register the ops explicitly.
+ if not force and not _env_truthy("VLLM_USE_OINK_RMSNORM"):
return
try:
diff --git a/oink/src/kernelagent_oink/blackwell/cross_entropy.py b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
new file mode 100644
index 0000000..94f052f
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
@@ -0,0 +1,1209 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Cross-entropy forward + backward kernels for SM100 (Blackwell) in CuTeDSL.
+
+This module implements numerically stable cross-entropy over the last
+dimension of 2D logits tensors `(M, N)` together with its backward pass,
+targeting SM100 with Quack-style tiling, cp.async pipelines, and (for the
+forward pass) optional cluster-wide online softmax reductions, but without
+depending on the external `quack` package at runtime.
+
+Public APIs:
+
+- ``cross_entropy_forward(logits, target, ignore_index=-100, reduction="none")``
+ returns ``(loss, lse)`` where ``loss`` follows the requested reduction and
+ ``lse`` is always per-example log-sum-exp (shape ``(M,)``).
+- ``cross_entropy_backward(dloss, logits, target, lse, ignore_index=-100)``
+ returns per-logit gradients ``dlogits`` matching PyTorch /
+ ``quack.cross_entropy_bwd`` semantics for ``reduction="none"``.
+- ``cross_entropy(logits, target, ignore_index=-100, reduction="mean"|"sum"|"none")``
+ is a convenience wrapper that mirrors ``torch.nn.functional.cross_entropy``
+ reductions using the SM100 CuteDSL kernels for the forward pass.
+
+The kernels are self-contained and use only local helpers in
+`kernelagent_oink.blackwell.lite_quack` plus CuTeDSL/CUTLASS.
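+
+Example (illustrative sketch; shapes, dtypes, and device placement are arbitrary):
+
+    logits = torch.randn(8192, 32768, device="cuda", dtype=torch.bfloat16)
+    target = torch.randint(0, 32768, (8192,), device="cuda")
+    loss, lse = cross_entropy_forward(logits, target, reduction="none")
+    dlogits = cross_entropy_backward(torch.ones_like(loss), logits, target, lse)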
+"""
+
+from __future__ import annotations
+
+import importlib.metadata
+import math
+import os
+import re
+from typing import Literal, Optional, Type
+
+import torch
+from torch import Tensor
+
+import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
+
+# CuTeDSL caches generated MLIR into a tempdir under a global default
+# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across
+# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy
+# warnings (and disables cache reuse), so default to a per-version cache directory.
+if "CUTE_DSL_CACHE_DIR" not in os.environ:
+ try:
+ _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
+ except Exception:
+ _dsl_ver = "unknown"
+ _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
+ _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
+ _tmp = os.environ.get("TMPDIR") or "/tmp"
+ os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
+ _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
+ )
+
+try:
+ import cutlass # type: ignore # noqa: F401
+except Exception as e:
+ raise ImportError(
+ "kernelagent_oink.blackwell.cross_entropy requires CuTeDSL's Python package "
+ "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)."
+ ) from e
+
+import cutlass.cute as cute
+from cutlass import Boolean, Float32, Int32, const_expr
+from cutlass.cute import runtime as rt
+from cutlass.cute.runtime import from_dlpack
+
+from kernelagent_oink.blackwell.lite_quack import (
+ _KERNEL_ACCEPTS_LAYOUT_ARGS,
+ TORCH2CUTE_DTYPE,
+ ReductionBase,
+ domain_offset_i64,
+ fill_oob,
+ online_softmax_reduce,
+ predicate_k,
+)
+
+_FWD_COMPILE_CACHE: dict[tuple[type[cutlass.Numeric], int], cute.Kernel] = {}
+_BWD_COMPILE_CACHE: dict[tuple[type[cutlass.Numeric], int], cute.Kernel] = {}
+_PTR_FWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+_PTR_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+
+
+def _convert_logits_2d(x: Tensor) -> cute.Tensor:
+ """Convert a 2D logits tensor (M, N) into a CuTe tensor.
+
+ We assume 16-byte alignment and mark the layout compact and row-major
+ in the last dimension, matching the conventions used in the SM100
+ softmax and RMSNorm kernels.
+ """
+ assert x.dim() == 2, "Input logits must be 2D (M, N)"
+ return (
+ from_dlpack(x.detach(), assumed_align=16)
+ .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+ )
+
+
+def _convert_1d(t: Tensor, assumed_align: int) -> cute.Tensor:
+ """Convert a 1D tensor with a fully dynamic layout."""
+ assert t.dim() == 1, "Expected a 1D tensor"
+ return from_dlpack(t.detach(), assumed_align=assumed_align).mark_layout_dynamic()
+
+
+class CrossEntropyFwdSM100(ReductionBase):
+ """SM100-tuned cross-entropy forward kernel.
+
+ This mirrors the structure of ``quack.cross_entropy.CrossEntropy`` but
+ is simplified to always use the single-pass online softmax reduction and
+ never computes gradients inside the forward kernel.
+ """
+
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ # Use one stage with an Int64 reduction buffer packing (max, sum_exp)
+ # pairs via lite_quack.online_softmax_reduce.
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64)
+
+ def _calculate_threads_per_row(self) -> int:
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ )
+ )
+
+ def _set_cluster_n(self) -> None:
+ # Match Quack's cluster_n growth policy while keeping it explicit so
+ # we can tune SM100-specific shapes later if needed.
+ N = self.N
+ if const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else: # fp32
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mLoss: cute.Tensor, # (M,)
+ mLSE: Optional[cute.Tensor], # (M,)
+ ignore_index: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ assert mX.element_type == self.dtype
+ self._set_cluster_n()
+ # If N is not divisible by the full 128-bit vector width, step down
+ # to the largest compatible vector size as in Quack.
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
+ num_threads = (
+ cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ kernel = (
+ self.kernel(
+ mX,
+ mTarget,
+ mLoss,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(
+ mX,
+ mTarget,
+ mLoss,
+ mLSE,
+ ignore_index,
+ )
+ )
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_logits: cute.Pointer,
+ ptr_target: cute.Pointer,
+ ptr_loss: cute.Pointer,
+ ptr_lse: cute.Pointer,
+ M: Int32,
+ ld: Int32,
+ ignore_index: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions."""
+ ld_assumed = cute.assume(ld, divby=128 // self.dtype.width)
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ layout_m = cute.make_layout((M,), stride=(1,))
+ mX = cute.make_tensor(ptr_logits, layout_mn)
+ mTarget = cute.make_tensor(ptr_target, layout_m)
+ mLoss = cute.make_tensor(ptr_loss, layout_m)
+ mLSE = cute.make_tensor(ptr_lse, layout_m)
+ self.__call__(mX, mTarget, mLoss, mLSE, ignore_index, stream)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mLoss: cute.Tensor, # (M,)
+ mLSE: Optional[cute.Tensor], # (M,)
+ ignore_index: Int32, # Index to ignore in loss computation
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = const_expr(0)
+
+ shape: cute.Shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+
+ # Slice per-CTA region; use 64-bit indexing for large tensors.
+ mX_off = domain_offset_i64((bidx * tiler_mn[0], 0), mX)
+ gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y))
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+ # Copy setup: gmem -> smem via cp.async, 128-bit or narrower as needed.
+ num_copy_elems_X = tv_layout.shape[1][0]
+ num_copy_bits_X = mX.element_type.width * num_copy_elems_X
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ gX.element_type,
+ num_bits_per_copy=num_copy_bits_X,
+ )
+ thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+ tXrX = cute.make_fragment_like(tXgX)
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+ row = tXcX[0][0]
+ target = Int32.zero
+ if row < shape[0]:
+ target = Int32(mTarget[row])
+
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+ tXpX = (
+ predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ # Fill out-of-bounds values with -inf so they are ignored in max/sum.
+ if const_expr(not is_even_N):
+ fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+
+ should_ignore = Boolean(target == ignore_index)
+
+ # Load the target logit if this row is not ignored. Use Int64 indexing
+ # to safely handle very large tensors.
+ target_logit = Float32.zero
+ if row < shape[0] and tXcX[0][1] == 0 and not should_ignore:
+ mX_row = domain_offset_i64((row, 0), mX)
+ target_logit = Float32(mX_row[0, target])
+
+ threads_per_row = tv_layout.shape[0][0]
+ max_x, denom, _ = online_softmax_reduce(
+ x,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+ phase=None,
+ return_exp_x=False,
+ )
+
+ # Write loss and lse to gmem. Only one CTA in the cluster writes to
+ # avoid duplicate stores.
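+        # For a kept row: lse = max(x) + log(sum(exp(x - max(x)))) and
+        # loss = lse - x[target] = -log(softmax(x)[target]); ignored rows
+        # get loss = 0 while lse is still written.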
+ if (
+ tXcX[0][1] == 0
+ and row < shape[0]
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ lse = max_x + cute.math.log(denom, fastmath=True)
+ loss_val = (lse - target_logit) if not should_ignore else Float32.zero
+ mLoss[row] = mLoss.element_type(loss_val)
+ if const_expr(mLSE is not None):
+ mLSE[row] = lse
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mLoss: cute.Tensor, # (M,)
+ mLSE: Optional[cute.Tensor], # (M,)
+ ignore_index: Int32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ self._kernel_impl(
+ mX,
+ mTarget,
+ mLoss,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mLoss: cute.Tensor, # (M,)
+ mLSE: Optional[cute.Tensor], # (M,)
+ ignore_index: Int32,
+ ) -> None:
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
+ self._kernel_impl(
+ mX,
+ mTarget,
+ mLoss,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+
+
+class CrossEntropyBackwardSM100:
+ """SM100-tuned cross-entropy backward kernel.
+
+ This is a direct port of ``quack.cross_entropy.CrossEntropyBackward`` to
+    the local lite_quack helpers, using cp.async tiling over the (M, N)
+    logits and broadcasting ``dloss`` / ``lse`` / ``target`` along the N
+    dimension with stride 0.
+ """
+
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ self.dtype = dtype
+ self.N = N
+
+ def _get_num_threads(self) -> int:
+        # Keep in sync with _get_tv_layout(): N is capped at 16k per tile, so
+        # this currently always resolves to 128 threads.
+ N = min(self.N, 16384)
+ return 128 if N <= 16384 else 256
+
+ def _calculate_threads_per_row(self) -> int:
+ N = min(self.N, 16384) # We split by blocks of 16k in N.
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ )
+ )
+
+ def _get_tv_layout(self, num_copy_bits: int = 128) -> tuple[cute.Shape, cute.Layout]:
+ vecsize = num_copy_bits // self.dtype.width
+ assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
+ N = min(self.N, 16384)
+ num_threads = 128 if N <= 16384 else 256
+ threads_per_row = self._calculate_threads_per_row()
+ cols_per_block = num_threads // threads_per_row
+ num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row)
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
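+        # Worked example (illustrative), bf16 logits with N=8192 and 128-bit
+        # copies: vecsize=8, threads_per_row=128, cols_per_block=1,
+        # num_blocks_N=8, so tiler_mn=(1, 8192) covers a full row per CTA.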
+ tv_layout = cute.make_layout(
+ ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+ stride=(
+ (vecsize * cols_per_block, 1),
+ (cols_per_block, cols_per_block * vecsize * threads_per_row),
+ ),
+ )
+ return tiler_mn, tv_layout
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mTarget: cute.Tensor,
+ mDLoss: cute.Tensor,
+ mdX: cute.Tensor,
+ mLSE: cute.Tensor,
+ ignore_index: Int32, # Index to ignore in gradient computation
+ stream: cuda.CUstream,
+ ) -> None:
+ assert mX.element_type == self.dtype
+ assert mdX.element_type == self.dtype
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
+ num_threads = (
+ cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ )
+ # Broadcast (M,) tensors along the N dimension with stride 0.
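+        # e.g. a (M,) tensor with stride (1,) becomes an (M, N) view with
+        # stride (1, 0), so each row's scalar is broadcast across all columns.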
+ mDLoss, mTarget, mLSE = [
+ cute.make_tensor(
+ X.iterator,
+ cute.append(X.layout, cute.make_layout((self.N,), stride=(0,))),
+ )
+ for X in (mDLoss, mTarget, mLSE)
+ ]
+ smem_size = cute.size_in_bytes(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ )
+ kernel = (
+ self.kernel(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ ignore_index,
+ )
+ )
+ kernel.launch(
+ grid=[
+ cute.ceil_div(mX.shape[0], tiler_mn[0]),
+ cute.ceil_div(mX.shape[1], tiler_mn[1]),
+ 1,
+ ],
+ block=[num_threads, 1, 1],
+ smem=smem_size,
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_logits: cute.Pointer,
+ ptr_target: cute.Pointer,
+ ptr_dloss: cute.Pointer,
+ ptr_dx: cute.Pointer,
+ ptr_lse: cute.Pointer,
+ M: Int32,
+ ld: Int32,
+ ignore_index: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions."""
+ ld_assumed = cute.assume(ld, divby=128 // self.dtype.width)
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ layout_m = cute.make_layout((M,), stride=(1,))
+ mX = cute.make_tensor(ptr_logits, layout_mn)
+ mdX = cute.make_tensor(ptr_dx, layout_mn)
+ mTarget = cute.make_tensor(ptr_target, layout_m)
+ mDLoss = cute.make_tensor(ptr_dloss, layout_m)
+ mLSE = cute.make_tensor(ptr_lse, layout_m)
+ self.__call__(mX, mTarget, mDLoss, mdX, mLSE, ignore_index, stream)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ mLSE: cute.Tensor, # (M,)
+ ignore_index: Int32, # Index to ignore in gradient computation
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, bidy, _ = cute.arch.block_idx()
+ shape = mX.shape
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+
+ idX = cute.make_identity_tensor(shape)
+ mX_off, mdX_off = [
+ domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX)
+ ]
+ gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX_off, mdX_off)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, bidy))
+
+ num_copy_elems_X = tv_layout.shape[1][0]
+ num_copy_bits_X = mX.element_type.width * num_copy_elems_X
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ gX.element_type,
+ num_bits_per_copy=num_copy_bits_X,
+ )
+ copy_atom_store_dX = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ gdX.element_type,
+ num_bits_per_copy=num_copy_bits_X,
+ )
+ thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
+
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+ tXcFull = thr_copy_X.partition_S(cX)
+ tXgdX = thr_copy_dX.partition_D(gdX)
+
+ tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)]
+
+ is_even_N = const_expr(shape[1] % tiler_mn[1] == 0)
+ row = tXcX[0][0]
+ tXpX = (
+ predicate_k(thr_copy_X.partition_S(cX), limit=shape[1])
+ if not is_even_N
+ else None
+ )
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+ if const_expr(not is_even_N):
+ fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+
+ target = Int32.zero
+ dloss = Float32.zero
+ lse = Float32.zero
+ if row < shape[0]:
+ target = Int32(mTarget[row])
+ should_ignore = Boolean(target == ignore_index)
+ dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero
+ lse = Float32(mLSE[row])
+
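+        # dCE/dx_j = (softmax(x)_j - 1[j == target]) * dloss; softmax is
+        # recomputed from the saved lse as exp(x - lse), using exp2 with a
+        # log2(e) scale for the fast hardware path.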
+ log2_e = math.log2(math.e)
+ probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True)
+ prob_shifted = probs - 1.0
+ mask = cute.make_fragment_like(tXrX, cutlass.Boolean)
+ for i in cutlass.range(cute.size(tXcFull), unroll_full=True):
+ mask[i] = tXcFull[i][1] == target
+ grad = cute.where(mask.load(), prob_shifted, probs)
+ grad = grad * dloss
+
+ tXrdX.store(grad.to(tXrdX.element_type))
+ tXpdX = (
+ predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1])
+ if not is_even_N
+ else None
+ )
+ if row < shape[0]:
+ cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ mLSE: cute.Tensor, # (M,)
+ ignore_index: Int32, # Index to ignore in gradient computation
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ self._kernel_impl(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor, # (M, N)
+ mTarget: cute.Tensor, # (M,)
+ mDLoss: cute.Tensor, # (M,)
+ mdX: cute.Tensor, # (M, N)
+ mLSE: cute.Tensor, # (M,)
+ ignore_index: Int32, # Index to ignore in gradient computation
+ ) -> None:
+ num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
+ tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
+ self._kernel_impl(
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ ignore_index,
+ tv_layout,
+ tiler_mn,
+ )
+
+
+def cross_entropy_forward(
+ logits: Tensor,
+ target: Tensor,
+ ignore_index: int = -100,
+ reduction: Literal["none", "mean", "sum"] = "none",
+) -> tuple[Tensor, Tensor]:
+ """SM100 CuteDSL cross-entropy forward pass.
+
+ Args:
+ logits: Tensor of shape ``(M, N)`` on CUDA.
+ target: Tensor of shape ``(M,)`` with integer class indices.
+ ignore_index: Target value to ignore when computing the loss.
+ reduction: One of ``"none"``, ``"mean"``, or ``"sum"`` following
+ ``torch.nn.functional.cross_entropy`` semantics.
+
+ Returns:
+ A tuple ``(loss, lse)`` where:
+        - ``loss`` has shape ``(M,)`` if ``reduction="none"`` (rows whose
+          target equals ``ignore_index`` get a loss of 0) or is a scalar
+          otherwise.
+ - ``lse`` is the per-example log-sum-exp with shape ``(M,)``.
+ """
+ assert logits.dim() == 2, "logits must be 2D (M, N)"
+ assert target.dim() == 1, "target must be 1D (M,)"
+ assert logits.shape[0] == target.shape[0], "Batch dimensions must match"
+ assert logits.is_cuda and target.is_cuda, "logits and target must be on CUDA device"
+ assert logits.dtype in TORCH2CUTE_DTYPE, "Unsupported logits dtype"
+ assert target.dtype in (torch.int32, torch.int64), "target must be int32 or int64"
+
+ M, N = logits.shape
+ device = logits.device
+ dtype_cute = TORCH2CUTE_DTYPE[logits.dtype]
+
+ loss = torch.empty(M, device=device, dtype=torch.float32)
+ lse = torch.empty(M, device=device, dtype=torch.float32)
+
+ if _can_use_ptr_path_logits(logits) and _can_use_ptr_path_target(target):
+ _cross_entropy_forward_ptr_into(
+ logits=logits,
+ target=target,
+ loss=loss,
+ lse=lse,
+ ignore_index=int(ignore_index),
+ )
+ if reduction == "none":
+ return loss, lse
+ with torch.no_grad():
+ mask = target != ignore_index
+ if reduction == "sum":
+ reduced = loss.sum()
+ elif reduction == "mean":
+ valid = mask.sum()
+ if valid > 0:
+ reduced = loss[mask].sum() / valid.to(loss.dtype)
+ else:
+ reduced = loss.sum() * 0.0
+ else:
+ raise ValueError(
+ f"Invalid reduction mode: {reduction}. Expected 'none', 'mean', or 'sum'."
+ )
+ return reduced, lse
+
+ mX = _convert_logits_2d(logits)
+ mTarget = _convert_1d(target.to(torch.int64), assumed_align=8)
+ mLoss = _convert_1d(loss, assumed_align=4)
+ mLSE = _convert_1d(lse, assumed_align=4)
+
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype_cute, N)
+ kernel = _FWD_COMPILE_CACHE.get(compile_key)
+ if kernel is None:
+ op = CrossEntropyFwdSM100(dtype_cute, N)
+ kernel = cute.compile(
+ op,
+ mX,
+ mTarget,
+ mLoss,
+ mLSE,
+ Int32(ignore_index),
+ current_stream,
+ )
+ _FWD_COMPILE_CACHE[compile_key] = kernel
+
+ kernel(mX, mTarget, mLoss, mLSE, Int32(ignore_index), current_stream)
+
+ if reduction == "none":
+ return loss, lse
+
+ with torch.no_grad():
+ mask = target != ignore_index
+ if reduction == "sum":
+ reduced = loss.sum()
+ elif reduction == "mean":
+ valid = mask.sum()
+ if valid > 0:
+ reduced = loss[mask].sum() / valid.to(loss.dtype)
+ else:
+ reduced = loss.sum() * 0.0
+ else:
+ raise ValueError(
+ f"Invalid reduction mode: {reduction}. Expected 'none', 'mean', or 'sum'."
+ )
+ return reduced, lse
+
+
+def _cross_entropy_backward_sm100(
+ logits: Tensor,
+ target: Tensor,
+ dloss: Tensor,
+ lse: Tensor,
+ dx: Tensor,
+ ignore_index: int = -100,
+) -> None:
+ """Internal SM100 cross-entropy backward dispatch using CuteDSL."""
+ assert logits.dim() == 2, "logits must be 2D (M, N)"
+ assert target.dim() == 1, "target must be 1D (M,)"
+ assert dloss.dim() == 1, "dloss must be 1D (M,)"
+ assert lse.dim() == 1, "lse must be 1D (M,)"
+ assert logits.shape[0] == target.shape[0] == dloss.shape[0] == lse.shape[0], (
+ "Batch dimensions must match"
+ )
+ assert logits.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, (
+ "All tensors must be on CUDA device"
+ )
+ assert logits.dtype in TORCH2CUTE_DTYPE, "Unsupported logits dtype"
+ assert target.dtype in (torch.int32, torch.int64), "target must be int32 or int64"
+
+ M, N = logits.shape
+ dtype_cute = TORCH2CUTE_DTYPE[logits.dtype]
+
+ if (
+ _can_use_ptr_path_logits(logits)
+ and _can_use_ptr_path_logits(dx)
+ and _can_use_ptr_path_target(target)
+ and _can_use_ptr_path_f32_1d(dloss)
+ and _can_use_ptr_path_f32_1d(lse)
+ and logits.stride() == dx.stride()
+ ):
+ _cross_entropy_backward_ptr_into(
+ logits=logits,
+ target=target,
+ dloss=dloss,
+ lse=lse,
+ dx=dx,
+ ignore_index=int(ignore_index),
+ )
+ return
+
+ mX = _convert_logits_2d(logits)
+ mdX = _convert_logits_2d(dx)
+ mTarget = _convert_1d(target.to(torch.int64), assumed_align=8)
+ mDLoss = _convert_1d(dloss.to(torch.float32), assumed_align=4)
+ mLSE = _convert_1d(lse.to(torch.float32), assumed_align=4)
+
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype_cute, N)
+ kernel = _BWD_COMPILE_CACHE.get(compile_key)
+ if kernel is None:
+ op = CrossEntropyBackwardSM100(dtype_cute, N)
+ kernel = cute.compile(
+ op,
+ mX,
+ mTarget,
+ mDLoss,
+ mdX,
+ mLSE,
+ Int32(ignore_index),
+ current_stream,
+ )
+ _BWD_COMPILE_CACHE[compile_key] = kernel
+
+ kernel(mX, mTarget, mDLoss, mdX, mLSE, Int32(ignore_index), current_stream)
+
+
+def _can_use_ptr_path_logits(x: Tensor) -> bool:
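+    # Accepts 2D row-major tensors (stride(1) == 1) whose rows stay 16B-aligned.
+    # Illustrative example: a bf16 view with stride (N + 64, 1) passes, since
+    # divby = 128 // 16 = 8 and (N + 64) % 8 == 0 whenever N % 8 == 0.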
+ if not x.is_cuda or x.dim() != 2:
+ return False
+ if x.dtype not in TORCH2CUTE_DTYPE:
+ return False
+ if x.stride(1) != 1:
+ return False
+ if (x.data_ptr() % 16) != 0:
+ return False
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ divby = 128 // dtype_x.width
+ if (x.stride(0) % divby) != 0:
+ return False
+ return True
+
+
+def _can_use_ptr_path_target(t: Tensor) -> bool:
+ if not t.is_cuda or t.dim() != 1:
+ return False
+ if t.dtype is not torch.int64:
+ return False
+ if not t.is_contiguous():
+ return False
+ if t.stride(0) != 1:
+ return False
+ if (t.data_ptr() % 8) != 0:
+ return False
+ return True
+
+
+def _can_use_ptr_path_f32_1d(t: Tensor) -> bool:
+ if not t.is_cuda or t.dim() != 1:
+ return False
+ if t.dtype is not torch.float32:
+ return False
+ if not t.is_contiguous():
+ return False
+ if t.stride(0) != 1:
+ return False
+ if (t.data_ptr() % 4) != 0:
+ return False
+ return True
+
+
+def _cross_entropy_forward_ptr_into(
+ *,
+ logits: Tensor,
+ target: Tensor,
+ loss: Tensor,
+ lse: Tensor,
+ ignore_index: int,
+) -> None:
+ assert logits.is_cuda and logits.dim() == 2
+ assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0]
+ assert target.dtype is torch.int64
+ assert loss.is_cuda and loss.shape == (logits.shape[0],) and loss.dtype is torch.float32
+ assert lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32
+
+ M, N = logits.shape
+ device_index = logits.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream))
+
+ dtype_x = TORCH2CUTE_DTYPE[logits.dtype]
+ key = ("ptr_fwd", int(N), dtype_x, int(device_index))
+ compiled = _PTR_FWD_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = CrossEntropyFwdSM100(dtype_x, int(N))
+ ptr_logits = rt.make_ptr(
+ dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64,
+ target.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=8,
+ )
+ ptr_loss = rt.make_ptr(
+ cutlass.Float32,
+ loss.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ ptr_lse = rt.make_ptr(
+ cutlass.Float32,
+ lse.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_logits,
+ ptr_target,
+ ptr_loss,
+ ptr_lse,
+ Int32(int(M)),
+ Int32(int(logits.stride(0))),
+ Int32(int(ignore_index)),
+ stream,
+ )
+ _PTR_FWD_COMPILE_CACHE[key] = compiled
+
+ ptr_logits = rt.make_ptr(
+ dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64,
+ target.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=8,
+ )
+ ptr_loss = rt.make_ptr(
+ cutlass.Float32,
+ loss.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ ptr_lse = rt.make_ptr(
+ cutlass.Float32,
+ lse.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ compiled(
+ ptr_logits,
+ ptr_target,
+ ptr_loss,
+ ptr_lse,
+ Int32(int(M)),
+ Int32(int(logits.stride(0))),
+ Int32(int(ignore_index)),
+ stream,
+ )
+
+
+def _cross_entropy_backward_ptr_into(
+ *,
+ logits: Tensor,
+ target: Tensor,
+ dloss: Tensor,
+ lse: Tensor,
+ dx: Tensor,
+ ignore_index: int,
+) -> None:
+ assert logits.is_cuda and logits.dim() == 2
+ assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0]
+ assert target.dtype is torch.int64
+ assert dloss.is_cuda and dloss.shape == (logits.shape[0],) and dloss.dtype is torch.float32
+ assert lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32
+ assert dx.is_cuda and dx.shape == logits.shape and dx.dtype == logits.dtype
+ assert dx.stride() == logits.stride(), "Pointer path expects dx to match logits strides"
+
+ M, N = logits.shape
+ device_index = logits.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream))
+
+ dtype_x = TORCH2CUTE_DTYPE[logits.dtype]
+ key = ("ptr_bwd", int(N), dtype_x, int(device_index))
+ compiled = _PTR_BWD_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = CrossEntropyBackwardSM100(dtype_x, int(N))
+ ptr_logits = rt.make_ptr(
+ dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64,
+ target.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=8,
+ )
+ ptr_dloss = rt.make_ptr(
+ cutlass.Float32,
+ dloss.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_lse = rt.make_ptr(
+ cutlass.Float32,
+ lse.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_logits,
+ ptr_target,
+ ptr_dloss,
+ ptr_dx,
+ ptr_lse,
+ Int32(int(M)),
+ Int32(int(logits.stride(0))),
+ Int32(int(ignore_index)),
+ stream,
+ )
+ _PTR_BWD_COMPILE_CACHE[key] = compiled
+
+ ptr_logits = rt.make_ptr(
+ dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_target = rt.make_ptr(
+ cutlass.Int64,
+ target.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=8,
+ )
+ ptr_dloss = rt.make_ptr(
+ cutlass.Float32,
+ dloss.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_lse = rt.make_ptr(
+ cutlass.Float32,
+ lse.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ compiled(
+ ptr_logits,
+ ptr_target,
+ ptr_dloss,
+ ptr_dx,
+ ptr_lse,
+ Int32(int(M)),
+ Int32(int(logits.stride(0))),
+ Int32(int(ignore_index)),
+ stream,
+ )
+
+
+def cross_entropy_backward(
+ dloss: Tensor,
+ logits: Tensor,
+ target: Tensor,
+ lse: Tensor,
+ ignore_index: int = -100,
+) -> Tensor:
+ """SM100 CuteDSL cross-entropy backward pass.
+
+ Args:
+ dloss: Upstream gradient of shape ``(M,)`` corresponding to
+ ``reduction="none"``.
+ logits: Input logits tensor of shape ``(M, N)``.
+ target: Integer class indices of shape ``(M,)``.
+ lse: Per-example log-sum-exp tensor of shape ``(M,)`` as returned
+ by :func:`cross_entropy_forward`.
+ ignore_index: Target value to ignore in gradient computation.
+
+ Returns:
+ ``dlogits`` of shape ``(M, N)`` with the same dtype as ``logits``.
+ """
+ assert logits.dim() == 2, "logits must be 2D (M, N)"
+ assert dloss.dim() == 1, "dloss must be 1D (M,)"
+ assert logits.size(0) == dloss.size(0), "Batch dimensions must match"
+ assert logits.is_cuda and dloss.is_cuda, "logits and dloss must be on CUDA device"
+
+ dx = torch.empty_like(logits)
+ _cross_entropy_backward_sm100(
+ logits,
+ target,
+ dloss,
+ lse,
+ dx,
+ ignore_index=ignore_index,
+ )
+ return dx
+
+
+def cross_entropy(
+ logits: Tensor,
+ target: Tensor,
+ ignore_index: int = -100,
+ reduction: Literal["none", "mean", "sum"] = "mean",
+) -> Tensor:
+ """Convenience wrapper mirroring ``torch.nn.functional.cross_entropy`` reductions.
+
+ This uses :func:`cross_entropy_forward` under the hood but returns only
+ the reduced loss tensor.
+ """
+ loss, _lse = cross_entropy_forward(
+ logits,
+ target,
+ ignore_index=ignore_index,
+ reduction="none",
+ )
+ if reduction == "none":
+ return loss
+ mask = target != ignore_index
+ if reduction == "sum":
+ return loss.sum()
+ if reduction == "mean":
+ valid = mask.sum()
+ if valid > 0:
+ return loss[mask].sum() / valid.to(loss.dtype)
+ return loss.sum() * 0.0
+ raise ValueError(
+ f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'."
+ )
+
+
+def verify_cross_entropy_parity(
+ M: int,
+ N: int,
+ dtype: torch.dtype = torch.bfloat16,
+ ignore_index: int = -100,
+) -> None:
+ """Compare SM100 CuteDSL cross-entropy against PyTorch for a single shape."""
+ device = torch.device("cuda")
+ torch.manual_seed(0)
+
+ logits = 0.1 * torch.randn(M, N, device=device, dtype=dtype)
+ logits.requires_grad_(True)
+ target = torch.randint(0, N, (M,), device=device, dtype=torch.int64)
+
+ # Optionally sprinkle some ignore_index entries for robustness.
+ if ignore_index != -100:
+ mask = torch.rand(M, device=device) < 0.1
+ target[mask] = ignore_index
+
+ loss, lse = cross_entropy_forward(logits, target, ignore_index=ignore_index, reduction="none")
+
+ logits_ref = logits.detach().clone().requires_grad_()
+ target_ref = target.detach().clone()
+ loss_ref = torch.nn.functional.cross_entropy(
+ logits_ref.float(),
+ target_ref,
+ ignore_index=ignore_index,
+ reduction="none",
+ )
+
+ # Forward parity
+ if dtype in (torch.float16, torch.bfloat16):
+ atol = 5e-2
+ rtol = 5e-2
+ else:
+ atol = 1e-4
+ rtol = 1e-4
+ torch.testing.assert_close(loss, loss_ref, atol=atol, rtol=rtol)
+
+ # Backward parity
+ dloss = torch.randn_like(loss_ref)
+ (dx_ref,) = torch.autograd.grad(loss_ref, logits_ref, grad_outputs=dloss)
+ dx = cross_entropy_backward(dloss, logits, target, lse, ignore_index=ignore_index)
+ torch.testing.assert_close(dx, dx_ref.to(logits.dtype), atol=atol, rtol=rtol)
+
+
+if __name__ == "__main__":
+ # Minimal functional check when executed directly. For performance
+ # comparisons and detailed tuning, use the dedicated benchmark harness.
+ if not torch.cuda.is_available():
+ print("CUDA not available; cross-entropy parity check skipped.")
+ raise SystemExit(0)
+
+ M, N = 1024, 8192
+ dtype = torch.bfloat16
+ verify_cross_entropy_parity(M, N, dtype=dtype, ignore_index=-100)
+ print("SM100 cross-entropy CuteDSL parity check passed.")
diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py
new file mode 100644
index 0000000..05b11de
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/layernorm.py
@@ -0,0 +1,1368 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+LayerNorm kernel for SM100 (Blackwell) in CuteDSL.
+
+This implementation:
+- Mirrors Quack's LayerNorm tiling / cluster policy / cp.async pipeline
+ but uses only local helpers so that it does not depend on the external
+ `quack` package at runtime.
+- Supports fp16 / bf16 / fp32 inputs with fp32 accumulation.
+- Optionally writes out per-row `rstd` and `mean` buffers for reuse in
+ backward or fused kernels.
+
+Backward is implemented with dedicated CuteDSL kernels for input and
+parameter gradients (dx, dweight, dbias), avoiding PyTorch autograd
+while matching `torch.nn.functional.layer_norm`'s gradients numerically.
+"""
+
+from __future__ import annotations
+
+import importlib.metadata
+import os
+import re
+import operator
+from typing import Optional, Tuple, Type
+
+import torch
+from torch import Tensor
+
+import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
+
+# CuTeDSL caches generated MLIR into a tempdir under a global default
+# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across
+# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy
+# warnings (and disables cache reuse).
+if "CUTE_DSL_CACHE_DIR" not in os.environ:
+ try:
+ _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
+ except Exception:
+ _dsl_ver = "unknown"
+ _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
+ _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
+ _tmp = os.environ.get("TMPDIR") or "/tmp"
+ os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
+ _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
+ )
+
+try:
+ import cutlass # type: ignore # noqa: F401
+except Exception as e:
+ raise ImportError(
+ "kernelagent_oink.blackwell.layernorm requires CuTeDSL's Python package "
+ "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)."
+ ) from e
+
+import cutlass.cute as cute
+from cutlass import Float32, Int32, const_expr
+from cutlass.cute import runtime as rt
+from cutlass.cute.runtime import from_dlpack
+
+# Simple compile cache for the forward kernel
+_COMPILE_CACHE: dict[Tuple[int, type[cutlass.Numeric], bool, bool, bool], object] = {}
+_PTR_COMPILE_CACHE: dict[Tuple[object, ...], object] = {}
+
+# Backward compile caches: one for dx, one for parameter gradients.
+_BWD_DX_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric]], object] = {}
+_BWD_PARAM_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric], bool], object] = {}
+
+# Local helpers cloned from Quack via lite_quack so that this kernel does
+# not depend on `quack` at runtime.
+from kernelagent_oink.blackwell.lite_quack import (
+ _KERNEL_ACCEPTS_LAYOUT_ARGS,
+ TORCH2CUTE_DTYPE,
+ ReductionBase as _ReductionBase,
+ convert_from_dlpack as convert_from_dlpack_cute,
+ domain_offset_i64,
+ get_sm_count,
+ predicate_k,
+ row_reduce,
+ warp_reduce,
+)
+
+
+def _convert_row_major(t: Tensor) -> cute.Tensor:
+ """
+ Convert a 2D row-major torch.Tensor to a CuTeDSL tensor with a compact,
+ dynamic layout on the leading dimension.
+ """
+ return from_dlpack(t.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0,
+ stride_order=(0, 1),
+ )
+
+
+class LayerNormSM100(_ReductionBase):
+ """
+ SM100 LayerNorm forward kernel.
+
+ This mirrors `quack.layernorm.LayerNorm`'s schedule:
+ - Stage=2 pipeline: first pass computes mean, second pass computes
+ variance / rstd and normalization.
+ - Threads-per-row and cluster_n policy follow Quack's LayerNorm
+ heuristics to keep tensor-core friendly tiles across N.
+ - Optional `reload_from` hint enables reloading X from SMEM for large-N
+ shapes to shorten register lifetimes.
+
+ Differences vs Quack:
+ - Bias is optional and supported directly in the kernel.
+ - Dtype mapping and reduction helpers come from `lite_quack`.
+ """
+
+ def __init__(self, dtype: type[cutlass.Numeric], N: int):
+ super().__init__(dtype, N, stage=2) # 2 stages for mean and var
+ # Default reload policy mirrors Quack: use SMEM reload only for
+ # very large hidden sizes. We keep this conservative for LayerNorm
+ # and tune primarily via threads-per-block / cluster_n.
+ self.reload_from: Optional[str] = None if N <= 16384 else "smem"
+ self.delay_w_load: bool = False
+
+ def _calculate_threads_per_row(self) -> int:
+ # Match Quack's LayerNorm threads-per-row buckets.
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ )
+ )
+
+ def _set_cluster_n(self) -> None:
+ # Cluster_n policy mirrors quack.layernorm.LayerNorm._set_cluster_n.
+ N = self.N
+ if const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else:
+ cluster_n = (
+ 1
+ if N <= 32 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mB: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ):
+ assert mX.element_type == self.dtype
+ assert mO.element_type == self.dtype
+
+ # Tiling and cluster policy (mirrors Quack LayerNorm).
+ self._set_cluster_n()
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = (
+ cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+
+ # Expand weight / bias to match tiler_mn[0] rows per CTA.
+ mW = cute.make_tensor(
+ mW.iterator,
+ cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
+ )
+ if const_expr(mB is not None):
+ mB = cute.make_tensor(
+ mB.iterator,
+ cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
+ )
+ if const_expr(mRstd is not None):
+ mRstd = cute.make_tensor(
+ mRstd.iterator,
+ cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))),
+ )
+ if const_expr(mMean is not None):
+ mMean = cute.make_tensor(
+ mMean.iterator,
+ cute.append(mMean.layout, cute.make_layout((self.N,), stride=(0,))),
+ )
+
+ kernel = (
+ self.kernel(
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(self.reload_from),
+ const_expr(self.delay_w_load),
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ eps,
+ )
+ )
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[
+ 1,
+ self.cluster_n,
+ 1,
+ ]
+ if const_expr(self.cluster_n > 1)
+ else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_x: cute.Pointer,
+ ptr_w: cute.Pointer,
+ ptr_b: Optional[cute.Pointer],
+ ptr_out: cute.Pointer,
+ ptr_rstd: Optional[cute.Pointer],
+ ptr_mean: Optional[cute.Pointer],
+ M: Int32,
+ ld: Int32,
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions.
+
+ This reconstructs cute.Tensor views from raw device pointers + explicit
+ layouts inside the JIT graph, reusing the tuned LayerNormSM100 schedule.
+ """
+ # The kernel uses 128-bit vectorized copies for X. Mirror Quack's
+ # `divisibility=128 // dtype.width` contract so the compiler can
+ # prove alignment for cp.async.
+ ld_assumed = cute.assume(ld, divby=128 // self.dtype.width)
+ # Match `mark_compact_shape_dynamic(mode=0, ...)`: M is dynamic, N is static.
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ layout_n = cute.make_layout((self.N,), stride=(1,))
+ layout_m = cute.make_layout((M,), stride=(1,))
+
+ mX = cute.make_tensor(ptr_x, layout_mn)
+ mO = cute.make_tensor(ptr_out, layout_mn)
+ mW = cute.make_tensor(ptr_w, layout_n)
+ mB = cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None
+ mRstd = (
+ cute.make_tensor(ptr_rstd, layout_m)
+ if const_expr(ptr_rstd is not None)
+ else None
+ )
+ mMean = (
+ cute.make_tensor(ptr_mean, layout_m)
+ if const_expr(ptr_mean is not None)
+ else None
+ )
+
+ self.__call__(mX, mW, mB, mO, mRstd, mMean, stream, eps)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mB: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ reload_from: cutlass.Constexpr,
+ delay_w_load: cutlass.Constexpr,
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = const_expr(0)
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+ shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+
+ # Slice for CTAs: use domain_offset_i64 to handle >2^31 elements.
+ mX, mO = [
+ domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)
+ ]
+ gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
+ gB = (
+ cute.local_tile(mB, tiler_mn, (0, cluster_y))
+ if const_expr(mB is not None)
+ else None
+ )
+ gRstd = (
+ cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
+ if const_expr(mRstd is not None)
+ else None
+ )
+ gMean = (
+ cute.local_tile(mMean, tiler_mn, (bidx, cluster_y))
+ if const_expr(mMean is not None)
+ else None
+ )
+
+ # Copy atoms for X / W / B / O.
+ copy_atom_load_X = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ mX.element_type,
+ num_bits_per_copy=128,
+ )
+ copy_atom_load_X_async = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ mX.element_type,
+ num_bits_per_copy=128,
+ )
+ copy_atom_load_WB = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ mW.element_type,
+ num_bits_per_copy=128,
+ )
+ copy_atom_store_O = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ mO.element_type,
+ num_bits_per_copy=128,
+ )
+
+ thr_copy_X = cute.make_tiled_copy(
+ copy_atom_load_X_async,
+ tv_layout,
+ tiler_mn,
+ ).get_slice(tidx)
+ thr_copy_WB = cute.make_tiled_copy(
+ copy_atom_load_WB,
+ tv_layout,
+ tiler_mn,
+ ).get_slice(tidx)
+ thr_copy_O = cute.make_tiled_copy(
+ copy_atom_store_O,
+ tv_layout,
+ tiler_mn,
+ ).get_slice(tidx)
+
+ tWgW = thr_copy_WB.partition_S(gW)
+ tBgB = (
+ thr_copy_WB.partition_S(gB)
+ if const_expr(gB is not None)
+ else None
+ )
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXgO = thr_copy_O.partition_D(gO)
+ tXrRstd = (
+ thr_copy_O.partition_D(gRstd)
+ if const_expr(mRstd is not None)
+ else None
+ )
+ tXrMean = (
+ thr_copy_O.partition_D(gMean)
+ if const_expr(mMean is not None)
+ else None
+ )
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
+
+ # Fragments for gmem->rmem.
+ tWrW = cute.make_fragment_like(tWgW)
+ tBrB = (
+ cute.make_fragment_like(tBgB)
+ if const_expr(mB is not None)
+ else None
+ )
+ tXrW = thr_copy_X.retile(tWrW)
+ tXrB = (
+ thr_copy_X.retile(tBrB)
+ if const_expr(mB is not None)
+ else None
+ )
+ tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=False)
+
+ tXpX = predicate_k(
+ thr_copy_X.partition_S(cX),
+ limit=shape[1],
+ )
+ row = tXcX[0][0]
+ if row < shape[0]:
+ cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+
+ tWpW = predicate_k(
+ thr_copy_WB.partition_S(cX),
+ limit=shape[1],
+ )
+ if const_expr(not delay_w_load):
+ cute.copy(copy_atom_load_WB, tWgW, tWrW, pred=tWpW)
+ if const_expr(mB is not None):
+ cute.copy(copy_atom_load_WB, tBgB, tBrB, pred=tWpW)
+
+ cute.arch.cp_async_wait_group(0)
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+ threads_per_row = tv_layout.shape[0][0]
+ sum_x = row_reduce(
+ x,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ hook_fn=(
+ cute.arch.cluster_wait
+ if const_expr(self.cluster_n > 1)
+ else None
+ ),
+ )
+ mean = sum_x / shape[1]
+
+ if const_expr(reload_from == "smem"):
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+ elif const_expr(reload_from == "gmem"):
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ x = tXrX.load().to(Float32)
+
+ sum_sq_x_sub_mean = row_reduce(
+ (x - mean) * (x - mean),
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 1],
+ mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ )
+ rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True)
+
+ if const_expr(mRstd is not None):
+ if (
+ tXcX[0][1] == 0
+ and row < shape[0]
+ and (
+ self.cluster_n == 1
+ or cute.arch.block_idx_in_cluster() == 0
+ )
+ ):
+ tXrRstd[0] = rstd
+
+ if const_expr(mMean is not None):
+ if (
+ tXcX[0][1] == 0
+ and row < shape[0]
+ and (
+ self.cluster_n == 1
+ or cute.arch.block_idx_in_cluster() == 0
+ )
+ ):
+ tXrMean[0] = mean
+
+ if const_expr(delay_w_load):
+ cute.copy(copy_atom_load_WB, tWgW, tWrW, pred=tWpW)
+ if const_expr(mB is not None):
+ cute.copy(copy_atom_load_WB, tBgB, tBrB, pred=tWpW)
+
+ if const_expr(reload_from == "smem"):
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+ elif const_expr(reload_from == "gmem"):
+ cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX)
+ x = tXrX.load().to(Float32)
+
+ x_hat = (x - mean) * rstd
+ w = tXrW.load().to(Float32)
+ y = x_hat * w
+ if const_expr(mB is not None):
+ b = tXrB.load().to(Float32)
+ y = y + b
+
+ tXrO.store(y.to(tXrO.element_type))
+ tOpO = predicate_k(
+ thr_copy_O.partition_S(cX),
+ limit=shape[1],
+ )
+ if row < shape[0]:
+ cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mB: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ reload_from: cutlass.Constexpr,
+ delay_w_load: cutlass.Constexpr,
+ ):
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ eps,
+ tv_layout,
+ tiler_mn,
+ reload_from,
+ delay_w_load,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mB: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mRstd: Optional[cute.Tensor],
+ mMean: Optional[cute.Tensor],
+ eps: Float32,
+ ):
+ tiler_mn, tv_layout = self._get_tv_layout()
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(self.reload_from),
+ const_expr(self.delay_w_load),
+ )
+
+
+# -----------------------------------------------------------------------------
+# Public Python API
+# -----------------------------------------------------------------------------
+
+
+def layernorm(
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor] = None,
+ eps: float = 1e-6,
+ return_rstd: bool = False,
+ return_mean: bool = False,
+):
+ """
+ LayerNorm forward pass using the SM100 CuteDSL kernel.
+
+ Args:
+ x: Input tensor of shape (M, N).
+ weight: Scale parameter of shape (N,), typically fp32.
+ bias: Optional bias parameter of shape (N,).
+ eps: Small value for numerical stability.
+ return_rstd: Whether to return per-row reciprocal std (shape (M,)).
+        return_mean: Whether to return per-row mean (shape (M,)).
+
+    Returns:
+        out, optionally followed by rstd and/or mean (in that order) when the
+        corresponding return_* flags are set.
+    """
+ assert x.is_cuda and weight.is_cuda, "x and weight must be CUDA tensors"
+ assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand."
+ assert weight.dim() == 1, "weight must be 1D"
+ assert x.shape[1] == weight.shape[0], "Last dim of x must match weight.size(0)"
+ if bias is not None:
+ assert bias.is_cuda, "bias must be on CUDA"
+ assert bias.dim() == 1 and bias.shape[0] == weight.shape[0], (
+ "bias must be 1D and match weight"
+ )
+
+ M, N = x.shape
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+
+ rstd = torch.empty(M, device=x.device, dtype=torch.float32) if return_rstd else None
+ mean = torch.empty(M, device=x.device, dtype=torch.float32) if return_mean else None
+
+    # Fast path: bypass DLPack conversions when x is 2D row-major
+    # (stride(1) == 1, rows 16B-aligned, possibly padded) and weight/bias
+    # are fp32 (Quack-style).
+ if _can_use_ptr_path(x, weight, bias):
+ out = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
+ _layernorm_forward_ptr_into(
+ x=x,
+ weight=weight,
+ bias=bias,
+ out=out,
+ rstd=rstd,
+ mean=mean,
+ eps=eps,
+ )
+ if return_mean and return_rstd:
+ return out, rstd, mean
+ if return_rstd and not return_mean:
+ return out, rstd
+ if return_mean and not return_rstd:
+ return out, mean
+ return out
+
+ out = torch.empty_like(x)
+ mX = _convert_row_major(x)
+ mO = _convert_row_major(out)
+
+ # Weight/bias live in feature dimension (N).
+ mW = convert_from_dlpack_cute(
+ weight.detach(),
+ leading_dim=0,
+ alignment=16,
+ divisibility=128 // cutlass.Float32.width,
+ )
+ mB = (
+ convert_from_dlpack_cute(
+ bias.detach(),
+ leading_dim=0,
+ alignment=16,
+ divisibility=128 // cutlass.Float32.width,
+ )
+ if bias is not None
+ else None
+ )
+
+ mRstd = (
+ from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ if rstd is not None
+ else None
+ )
+ mMean = (
+ from_dlpack(mean.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ if mean is not None
+ else None
+ )
+
+ stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+ key = (N, dtype, mB is not None, mRstd is not None, mMean is not None)
+ compiled = _COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = LayerNormSM100(dtype, N)
+ compiled = cute.compile(
+ op,
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ stream,
+ Float32(eps),
+ )
+ _COMPILE_CACHE[key] = compiled
+
+ compiled(
+ mX,
+ mW,
+ mB,
+ mO,
+ mRstd,
+ mMean,
+ stream,
+ Float32(eps),
+ )
+
+ if return_mean and return_rstd:
+ return out, rstd, mean
+ if return_rstd and not return_mean:
+ return out, rstd
+ if return_mean and not return_rstd:
+ return out, mean
+ return out
+
+
+def _can_use_ptr_path(x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> bool:
+ """Return True if we can safely use the pointer-based fast path.
+
+ This is intentionally conservative: we target the common inference-like
+ layout (2D row-major with stride(1)==1) and Quack-style fp32 weights.
+ """
+ if not x.is_cuda or x.dim() != 2:
+ return False
+ if x.stride(1) != 1:
+ return False
+ if not weight.is_cuda or weight.dim() != 1:
+ return False
+ if weight.dtype != torch.float32:
+ return False
+ if not weight.is_contiguous():
+ return False
+ if bias is not None:
+ if not bias.is_cuda or bias.dim() != 1:
+ return False
+ if bias.dtype != torch.float32:
+ return False
+ if not bias.is_contiguous():
+ return False
+ # Require 16B alignment for 128-bit vector copies (matches Quack's assumed_align=16).
+ if (x.data_ptr() % 16) != 0:
+ return False
+ if (weight.data_ptr() % 16) != 0:
+ return False
+ if bias is not None and (bias.data_ptr() % 16) != 0:
+ return False
+ # The kernel uses 128-bit vectorized loads; require the leading dimension
+ # to preserve 16B alignment for every row start.
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ divby = 128 // dtype_x.width
+ if (x.stride(0) % divby) != 0:
+ return False
+ return True
+
+
+def _layernorm_forward_ptr_into(
+ *,
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor],
+ out: Tensor,
+ rstd: Optional[Tensor],
+ mean: Optional[Tensor],
+ eps: float,
+) -> None:
+ """Launch the pointer-based LayerNorm kernel into preallocated outputs."""
+ assert x.is_cuda and x.dim() == 2
+ M, N = x.shape
+ assert weight.is_cuda and weight.dim() == 1 and weight.shape[0] == N
+ if bias is not None:
+ assert bias.is_cuda and bias.dim() == 1 and bias.shape[0] == N
+ assert out.is_cuda and out.shape == x.shape and out.dtype == x.dtype
+ assert out.stride() == x.stride(), "Pointer path expects out to match x strides"
+ if rstd is not None:
+ assert rstd.is_cuda and rstd.shape == (M,) and rstd.dtype == torch.float32
+ if mean is not None:
+ assert mean.is_cuda and mean.shape == (M,) and mean.dtype == torch.float32
+
+ device_index = x.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream_handle = int(torch.cuda.current_stream().cuda_stream)
+ stream = cuda.CUstream(stream_handle)
+
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ key = (
+ "ptr",
+ int(N),
+ dtype_x,
+ bias is not None,
+ rstd is not None,
+ mean is not None,
+ int(device_index),
+ )
+ compiled = _PTR_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = LayerNormSM100(dtype_x, int(N))
+ ptr_x = rt.make_ptr(
+ dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_out = rt.make_ptr(
+ dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_w = rt.make_ptr(
+ cutlass.Float32,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
+ ptr_b = (
+ rt.make_ptr(
+ cutlass.Float32,
+ bias.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
+ if bias is not None
+ else None
+ )
+ ptr_rstd = (
+ rt.make_ptr(
+ cutlass.Float32,
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if rstd is not None
+ else None
+ )
+ ptr_mean = (
+ rt.make_ptr(
+ cutlass.Float32,
+ mean.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if mean is not None
+ else None
+ )
+ ld = Int32(int(x.stride(0)))
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_x,
+ ptr_w,
+ ptr_b,
+ ptr_out,
+ ptr_rstd,
+ ptr_mean,
+ Int32(int(M)),
+ ld,
+ stream,
+ Float32(float(eps)),
+ )
+ _PTR_COMPILE_CACHE[key] = compiled
+
+ ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_out = rt.make_ptr(
+ dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_w = rt.make_ptr(
+ cutlass.Float32,
+ weight.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
+ ptr_b = (
+ rt.make_ptr(
+ cutlass.Float32,
+ bias.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=16,
+ )
+ if bias is not None
+ else None
+ )
+ ptr_rstd = (
+ rt.make_ptr(
+ cutlass.Float32,
+ rstd.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if rstd is not None
+ else None
+ )
+ ptr_mean = (
+ rt.make_ptr(
+ cutlass.Float32,
+ mean.data_ptr(),
+ mem_space=rt.AddressSpace.gmem,
+ assumed_align=4,
+ )
+ if mean is not None
+ else None
+ )
+ ld = Int32(int(x.stride(0)))
+ compiled(
+ ptr_x,
+ ptr_w,
+ ptr_b,
+ ptr_out,
+ ptr_rstd,
+ ptr_mean,
+ Int32(int(M)),
+ ld,
+ stream,
+ Float32(float(eps)),
+ )
+
+
+def layernorm_ref(
+ x: Tensor,
+ weight: Tensor,
+ bias: Optional[Tensor] = None,
+ eps: float = 1e-6,
+) -> Tensor:
+ """
+ Reference LayerNorm implemented via torch.nn.functional.layer_norm.
+ """
+ x_f32 = x.float()
+ w = weight.float()
+ b = bias.float() if bias is not None else None
+ y = torch.nn.functional.layer_norm(x_f32, (x.shape[-1],), w, b, eps)
+ return y.to(x.dtype)
+
+
+def _as_2d(x: Tensor) -> Tuple[Tensor, Tuple[int, ...]]:
+ if x.dim() == 2:
+ return x, x.shape
+ original_shape = x.shape
+ M = int(torch.prod(torch.tensor(original_shape[:-1])).item())
+ N = original_shape[-1]
+ return x.reshape(M, N), original_shape
+
+
+def _restore_shape(x: Tensor, shape: Tuple[int, ...]) -> Tensor:
+ return x.reshape(shape)
+
+
+@cute.kernel
+def _layernorm_backward_dx_kernel(
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mdO: cute.Tensor,
+ mRstd: cute.Tensor,
+ mMean: cute.Tensor,
+ mdX: cute.Tensor,
+):
+ """
+ Simple CTA-per-row LayerNorm backward kernel for dx only.
+
+ Each block processes one row of shape (N,), using block_threads threads.
+ It performs two passes over the row:
+ 1) Compute mean_wdy and mean_xhat_wdy in fp32.
+ 2) Compute dx using the standard LayerNorm backward formula:
+ dx = rstd * (wdy - mean_wdy - x_hat * mean_xhat_wdy),
+ where wdy = dy * gamma and x_hat = (x - mean) * rstd.
+ """
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+
+ block_threads = const_expr(256)
+ shape = mX.shape
+ M = shape[0]
+ N = shape[1]
+
+ row = bidx
+ if row < M:
+ # Shared buffers for warp-level reductions across the block.
+ smem = cutlass.utils.SmemAllocator()
+ num_warps = const_expr(block_threads // cute.arch.WARP_SIZE)
+ warp_sums_layout = cute.make_layout((num_warps,), stride=(1,))
+ warp_sums_wdy = smem.allocate_tensor(Float32, warp_sums_layout, byte_alignment=4)
+ warp_sums_xhatwdy = smem.allocate_tensor(Float32, warp_sums_layout, byte_alignment=4)
+
+ lane = cute.arch.lane_idx()
+ warp_idx = cute.arch.warp_idx()
+
+ rstd_val = mRstd[row].to(Float32)
+ mean_val = mMean[row].to(Float32)
+
+ # Pass 1: compute local partial sums of wdy and x_hat*wdy.
+ local_wdy = Float32(0.0)
+ local_xhatwdy = Float32(0.0)
+ for col in cutlass.range(tidx, N, block_threads):
+ x_val = mX[row, col].to(Float32)
+ dy_val = mdO[row, col].to(Float32)
+ gamma = mW[col].to(Float32)
+ x_mu = x_val - mean_val
+ x_hat = x_mu * rstd_val
+ wdy = dy_val * gamma
+ local_wdy += wdy
+ local_xhatwdy += x_hat * wdy
+
+ # Warp-level reduction, then block-level reduction via shared memory.
+ red_op = operator.add # type: ignore[assignment]
+ local_wdy = warp_reduce(local_wdy, red_op)
+ local_xhatwdy = warp_reduce(local_xhatwdy, red_op)
+
+ if lane == 0:
+ warp_sums_wdy[warp_idx] = local_wdy
+ warp_sums_xhatwdy[warp_idx] = local_xhatwdy
+
+ cute.arch.barrier()
+
+ total_wdy = Float32(0.0)
+ total_xhatwdy = Float32(0.0)
+ if warp_idx == 0 and lane == 0:
+ for wi in cutlass.range_constexpr(num_warps):
+ total_wdy += warp_sums_wdy[wi]
+ total_xhatwdy += warp_sums_xhatwdy[wi]
+ # Store totals back into first slots for broadcast.
+ warp_sums_wdy[0] = total_wdy
+ warp_sums_xhatwdy[0] = total_xhatwdy
+
+ cute.arch.barrier()
+
+ total_wdy = warp_sums_wdy[0]
+ total_xhatwdy = warp_sums_xhatwdy[0]
+ inv_N = Float32(1.0 / float(N))
+ mean_wdy = total_wdy * inv_N
+ mean_xhatwdy = total_xhatwdy * inv_N
+
+ # Pass 2: compute dx and write back.
+ for col in cutlass.range(tidx, N, block_threads):
+ x_val = mX[row, col].to(Float32)
+ dy_val = mdO[row, col].to(Float32)
+ gamma = mW[col].to(Float32)
+ x_mu = x_val - mean_val
+ x_hat = x_mu * rstd_val
+ wdy = dy_val * gamma
+ dx_val = (wdy - mean_wdy - x_hat * mean_xhatwdy) * rstd_val
+ mdX[row, col] = dx_val.to(mdX.element_type)
+
+
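+# For reference (illustrative only; not used by the kernel above): the same dx
+# formula expressed with plain torch ops, taking the row means over the last
+# dimension in fp32.
+def _layernorm_backward_dx_ref(
+    dout: Tensor, x: Tensor, weight: Tensor, rstd: Tensor, mean: Tensor
+) -> Tensor:
+    x_hat = (x.float() - mean[:, None].float()) * rstd[:, None].float()
+    wdy = dout.float() * weight.float()
+    mean_wdy = wdy.mean(dim=-1, keepdim=True)
+    mean_xhat_wdy = (x_hat * wdy).mean(dim=-1, keepdim=True)
+    dx = (wdy - mean_wdy - x_hat * mean_xhat_wdy) * rstd[:, None].float()
+    return dx.to(x.dtype)
+
+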
+@cute.jit
+def _layernorm_backward_dx(
+ mX: cute.Tensor,
+ mW: cute.Tensor,
+ mdO: cute.Tensor,
+ mRstd: cute.Tensor,
+ mMean: cute.Tensor,
+ mdX: cute.Tensor,
+ stream: cuda.CUstream,
+) -> None:
+ """
+ JIT wrapper that launches the dx-only LayerNorm backward kernel.
+ One CTA processes one row of length N with 256 threads.
+ """
+ M = mX.shape[0]
+ _layernorm_backward_dx_kernel(
+ mX,
+ mW,
+ mdO,
+ mRstd,
+ mMean,
+ mdX,
+ ).launch(
+ grid=[M, 1, 1],
+ block=[256, 1, 1],
+ stream=stream,
+ )
+
+
+@cute.kernel
+def _layernorm_backward_param_kernel(
+ mX: cute.Tensor,
+ mdO: cute.Tensor,
+ mRstd: cute.Tensor,
+ mMean: cute.Tensor,
+ mdW_partial: Optional[cute.Tensor],
+ mdB_partial: Optional[cute.Tensor],
+ num_blocks: Int32,
+) -> None:
+ """
+ Parameter-gradient kernel for LayerNorm.
+
+ Each CTA accumulates partial dweight/dbias over a stripe of rows:
+ - Grid dim X: num_blocks (sm_count-style persistent CTAs).
+ - Threads in a CTA partition the N dimension.
+ - For each assigned column, a thread streams over rows
+ row = blockIdx.x, blockIdx.x + num_blocks, ...
+
+ This mirrors the persistent-CTA pattern used by RMSNorm backward,
+ but uses a simpler per-thread accumulation since columns are
+ independent.
+ """
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+
+ block_threads = const_expr(256)
+ M = mX.shape[0]
+ N = mX.shape[1]
+
+ if bidx < num_blocks:
+ for col in cutlass.range(tidx, N, block_threads):
+ dw_local = Float32(0.0)
+ db_local = Float32(0.0)
+ for row in cutlass.range(bidx, M, num_blocks):
+ x_val = mX[row, col].to(Float32)
+ dy_val = mdO[row, col].to(Float32)
+ rstd_val = mRstd[row].to(Float32)
+ mean_val = mMean[row].to(Float32)
+ x_mu = x_val - mean_val
+ x_hat = x_mu * rstd_val
+ dw_local += dy_val * x_hat
+ db_local += dy_val
+
+ if const_expr(mdW_partial is not None):
+ mdW_partial[bidx, col] = dw_local
+ if const_expr(mdB_partial is not None):
+ mdB_partial[bidx, col] = db_local
+
+
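+# For reference (illustrative only): the partial sums the kernel above writes
+# into the (num_blocks, N) buffers correspond, in plain torch, to
+#
+#   x_hat   = (x - mean[:, None]) * rstd[:, None]
+#   dweight = (dout * x_hat).sum(dim=0)
+#   dbias   = dout.sum(dim=0)
+#
+# with the final reduction over the partial rows done on the host
+# (see layernorm_backward below).
+
+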
+@cute.jit
+def _layernorm_backward_param(
+ mX: cute.Tensor,
+ mdO: cute.Tensor,
+ mRstd: cute.Tensor,
+ mMean: cute.Tensor,
+ mdW_partial: Optional[cute.Tensor],
+ mdB_partial: Optional[cute.Tensor],
+ num_blocks: Int32,
+ stream: cuda.CUstream,
+) -> None:
+ """
+ JIT wrapper that launches the parameter-gradient kernel.
+ """
+ _layernorm_backward_param_kernel(
+ mX,
+ mdO,
+ mRstd,
+ mMean,
+ mdW_partial,
+ mdB_partial,
+ num_blocks,
+ ).launch(
+ grid=[num_blocks, 1, 1],
+ block=[256, 1, 1],
+ stream=stream,
+ )
+
+
+def _layernorm_backward_dx_sm100(
+ dout_2d: Tensor,
+ x_2d: Tensor,
+ weight: Tensor,
+ rstd_1d: Tensor,
+ mean_1d: Tensor,
+ dx_2d: Tensor,
+) -> None:
+ """
+ Host-side helper to run the dx-only LayerNorm backward kernel.
+ """
+ M, N = x_2d.shape
+ assert dout_2d.shape == (M, N)
+ assert rstd_1d.numel() == M
+ assert mean_1d.numel() == M
+
+ dtype = TORCH2CUTE_DTYPE[x_2d.dtype]
+
+ mX = _convert_row_major(x_2d)
+ mdO = _convert_row_major(dout_2d)
+ mdX = _convert_row_major(dx_2d)
+
+ mW = convert_from_dlpack_cute(
+ weight.detach(),
+ leading_dim=0,
+ alignment=16,
+ divisibility=128 // cutlass.Float32.width,
+ )
+ mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+
+ stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+ key = (N, dtype)
+ compiled = _BWD_DX_COMPILE_CACHE.get(key)
+ if compiled is None:
+ compiled = cute.compile(
+ _layernorm_backward_dx,
+ mX,
+ mW,
+ mdO,
+ mRstd,
+ mMean,
+ mdX,
+ stream,
+ )
+ _BWD_DX_COMPILE_CACHE[key] = compiled
+
+ compiled(
+ mX,
+ mW,
+ mdO,
+ mRstd,
+ mMean,
+ mdX,
+ stream,
+ )
+
+
+def _layernorm_backward_params_sm100(
+ dout_2d: Tensor,
+ x_2d: Tensor,
+ rstd_1d: Tensor,
+ mean_1d: Tensor,
+ dw_partial: Optional[Tensor],
+ db_partial: Optional[Tensor],
+ sm_count: int,
+) -> None:
+ """
+ Host-side helper to run the parameter-gradient kernel that populates
+ dw_partial / db_partial of shape (sm_count, N).
+ """
+ M, N = x_2d.shape
+ assert dout_2d.shape == (M, N)
+ assert rstd_1d.numel() == M
+ assert mean_1d.numel() == M
+ if dw_partial is None and db_partial is None:
+ return
+
+ dtype = TORCH2CUTE_DTYPE[x_2d.dtype]
+
+ mX = _convert_row_major(x_2d)
+ mdO = _convert_row_major(dout_2d)
+ mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+
+ mdW_partial = (
+ from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ if dw_partial is not None
+ else None
+ )
+ mdB_partial = (
+ from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ if db_partial is not None
+ else None
+ )
+
+ stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+ has_bias = db_partial is not None
+ key = (N, dtype, has_bias)
+ compiled = _BWD_PARAM_COMPILE_CACHE.get(key)
+ if compiled is None:
+ compiled = cute.compile(
+ _layernorm_backward_param,
+ mX,
+ mdO,
+ mRstd,
+ mMean,
+ mdW_partial,
+ mdB_partial,
+ Int32(sm_count),
+ stream,
+ )
+ _BWD_PARAM_COMPILE_CACHE[key] = compiled
+
+ compiled(
+ mX,
+ mdO,
+ mRstd,
+ mMean,
+ mdW_partial,
+ mdB_partial,
+ Int32(sm_count),
+ stream,
+ )
+
+
+def layernorm_backward(
+ dout: Tensor,
+ x: Tensor,
+ weight: Tensor,
+ rstd: Tensor,
+ mean: Tensor,
+ bias: Optional[Tensor] = None,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ """
+ LayerNorm backward implemented in CuteDSL / CUTLASS.
+
+ Computes gradients w.r.t. input, weight, and optional bias using
+ two kernels:
+ - A dx kernel (CTA-per-row) that streams over N.
+ - A parameter-gradient kernel that accumulates dw/db over a
+ persistent grid of CTAs across the M dimension.
+ """
+ assert x.shape == dout.shape, "x and dout must have the same shape"
+ assert x.is_cuda and dout.is_cuda, "x and dout must be CUDA tensors"
+ assert weight.dim() == 1, "weight must be 1D"
+ if bias is not None:
+ assert bias.dim() == 1, "bias must be 1D"
+
+ x_2d, orig_shape = _as_2d(x)
+ dout_2d, _ = _as_2d(dout)
+ M, N = x_2d.shape
+
+ # Flatten to 2D for the kernels.
+ mean_flat = mean.view(M)
+ rstd_flat = rstd.view(M)
+
+ dx_2d = torch.empty_like(x_2d)
+ _layernorm_backward_dx_sm100(
+ dout_2d,
+ x_2d,
+ weight,
+ rstd_flat,
+ mean_flat,
+ dx_2d,
+ )
+
+ device = x.device
+ sm_count = get_sm_count(N, device, M=M, dtype=x.dtype)
+
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ db_partial = (
+ torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ if bias is not None
+ else None
+ )
+
+ _layernorm_backward_params_sm100(
+ dout_2d,
+ x_2d,
+ rstd_flat,
+ mean_flat,
+ dw_partial,
+ db_partial,
+ sm_count,
+ )
+
+ dweight = dw_partial.sum(dim=0).to(weight.dtype)
+ dbias = db_partial.sum(dim=0).to(bias.dtype) if bias is not None else None
+
+ dx = _restore_shape(dx_2d, orig_shape)
+ return dx, dweight, dbias
+
+
+if __name__ == "__main__":
+ # Allow direct execution for a quick functional check.
+ if not torch.cuda.is_available():
+ print("CUDA not available; LayerNormSM100 test skipped.")
+ raise SystemExit(0)
+
+ device = "cuda"
+ M, N = 2048, 4096
+ dtype = torch.bfloat16
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ w = torch.randn(N, device=device, dtype=torch.float32)
+ b = torch.randn(N, device=device, dtype=torch.float32)
+
+ y_ref = layernorm_ref(x, w, b)
+ y, rstd, mean = layernorm(x, w, b, return_rstd=True, return_mean=True)
+ torch.testing.assert_close(
+ y,
+ y_ref,
+ atol=5e-2 if dtype != torch.float32 else 1e-5,
+ rtol=5e-2 if dtype != torch.float32 else 1e-5,
+ )
+
+ print("LayerNormSM100 forward correctness check passed.")
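+
+    # Optional backward check (a minimal sketch, not part of the kernel API):
+    # compare the CuTeDSL backward against torch.autograd on the same inputs.
+    # Tolerances are loose because the activations are bf16.
+    dout = torch.randn_like(x)
+    dx, dw, db = layernorm_backward(dout, x, w, rstd, mean, bias=b)
+    x_ref = x.detach().clone().requires_grad_(True)
+    w_ref = w.detach().clone().requires_grad_(True)
+    b_ref = b.detach().clone().requires_grad_(True)
+    y_autograd = torch.nn.functional.layer_norm(
+        x_ref.float(), (N,), w_ref, b_ref, 1e-6
+    ).to(dtype)
+    y_autograd.backward(dout)
+    torch.testing.assert_close(dx, x_ref.grad, atol=5e-2, rtol=5e-2)
+    torch.testing.assert_close(dw, w_ref.grad, atol=5e-2, rtol=5e-2)
+    torch.testing.assert_close(db, b_ref.grad, atol=5e-2, rtol=5e-2)
+    print("LayerNormSM100 backward correctness check passed (vs. torch.autograd).")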
diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py
index 14ae723..1bc15b1 100644
--- a/oink/src/kernelagent_oink/blackwell/lite_quack.py
+++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py
@@ -13,21 +13,24 @@
# limitations under the License.
"""
-Lightweight local clone of the small subset of helpers that the SM100
+Lightweight local clone of the small subset of Quack helpers that the SM100
RMSNorm CuteDSL kernels depend on.
This module intentionally avoids importing the `quack` package so that
-Oink Blackwell kernels can run without Quack installed, while keeping
-numerical behaviour and performance close to the original reference
-implementations.
+KernelAgent Oink SM100 kernels can run without Quack installed, while keeping
+numerical behaviour and performance identical to the reference kernels.
"""
from __future__ import annotations
import math
import operator
-from typing import Callable, Optional
+import importlib.metadata
+import re
+from functools import partial
+from typing import Callable, Optional, Tuple, Type
+import cuda.bindings.driver as cuda # type: ignore
import torch
from torch import Tensor
@@ -36,11 +39,39 @@
from cutlass import Float32, Int32, const_expr
from cutlass.cute.runtime import from_dlpack
from cutlass.cutlass_dsl import T, dsl_user_op
-from cutlass._mlir.dialects import llvm
+from cutlass._mlir.dialects import llvm, nvvm, vector
+
+
+def _parse_version_tuple(version: str) -> tuple[int, int, int]:
+ parts = version.split(".")
+ nums: list[int] = []
+ for part in parts[:3]:
+ match = re.match(r"^(\d+)", part)
+ nums.append(int(match.group(1)) if match is not None else 0)
+ while len(nums) < 3:
+ nums.append(0)
+ return nums[0], nums[1], nums[2]
+
+
+def _cutlass_dsl_version() -> Optional[tuple[int, int, int]]:
+ try:
+ return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl"))
+ except Exception:
+ return None
+
+
+_CUTLASS_DSL_VERSION = _cutlass_dsl_version()
+# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around
+# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the
+# older signature for <4.3.4, but switch to a 4.3.4+ compatible signature when
+# we detect 4.3.4+ (or when version detection is unavailable).
+_KERNEL_ACCEPTS_LAYOUT_ARGS = (
+ _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4)
+)
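+# For illustration: _parse_version_tuple("4.3.4") == (4, 3, 4) and
+# _parse_version_tuple("4.2.0.dev1") == (4, 2, 0), so a pre-release suffix
+# never bumps the detected version.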
# -------------------------
-# Dtype mapping
+# Dtype mapping (from quack.cute_dsl_utils)
# -------------------------
TORCH2CUTE_DTYPE = {
@@ -51,21 +82,15 @@
# -------------------------
-# Tensor conversion helpers
+# Tensor conversion helpers (from quack.utils)
# -------------------------
-
def convert_from_dlpack(
x: Tensor,
leading_dim: int,
alignment: int = 16,
divisibility: int = 1,
) -> cute.Tensor:
- """
- Wrap a torch.Tensor in a CuteDSL tensor with layout metadata that
- matches the logical leading dimension and alignment/divisibility
- constraints expected by SM100 kernels.
- """
return (
from_dlpack(x, assumed_align=alignment)
.mark_layout_dynamic(leading_dim=leading_dim)
@@ -78,14 +103,12 @@ def convert_from_dlpack(
# -------------------------
-# SM90/SM100 cluster helpers
+# SM90/SM100 cluster helpers (from quack.utils)
# -------------------------
@dsl_user_op
-def elem_pointer(
- x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None
-) -> cute.Pointer:
+def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
@@ -136,9 +159,7 @@ def store_shared_remote(
).ir_value()
if const_expr(isinstance(val, float)):
val = Float32(val)
- assert isinstance(val, (Float32, Int32, cutlass.Int64)), (
- "val must be Float32, Int32, or Int64"
- )
+ assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
llvm.inline_asm(
@@ -154,37 +175,22 @@ def store_shared_remote(
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
- """
- Build a predicate tensor for the K dimension only. Values beyond
- `limit` are masked out.
- """
+ # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if".
tApA = cute.make_fragment(
cute.make_layout(
- (
- cute.size(tAcA, mode=[0, 1]),
- cute.size(tAcA, mode=[1]),
- cute.size(tAcA, mode=[2]),
- ),
+ (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
stride=(cute.size(tAcA, mode=[2]), 0, 1),
),
cutlass.Boolean,
)
for rest_v in cutlass.range_constexpr(tApA.shape[0]):
for rest_k in cutlass.range_constexpr(tApA.shape[2]):
- tApA[rest_v, 0, rest_k] = cute.elem_less(
- tAcA[(0, rest_v), 0, rest_k][1], limit
- )
+ tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
return tApA
@dsl_user_op
-def domain_offset_i64(
- coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
-) -> cute.Tensor:
- """
- Return a tensor whose iterator is offset by an Int64 byte offset
- computed from `coord` and the tensor's strides.
- """
+def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
flat_stride = cute.flatten_to_tuple(tensor.stride)
assert len(flat_coord_i64) == len(flat_stride), (
@@ -201,8 +207,81 @@ def domain_offset_i64(
return cute.make_tensor(new_ptr, tensor.layout)
+@dsl_user_op
+def coord_offset_i64(
+ idx: cute.typing.Int,
+ tensor: cute.Tensor,
+ dim: int,
+ *,
+ loc=None,
+ ip=None,
+) -> cute.Tensor:
+ offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim])
+ assert isinstance(tensor.iterator, cute.Pointer)
+ new_ptr = cute.make_ptr(
+ tensor.element_type,
+ tensor.iterator.toint() + offset * tensor.element_type.width // 8,
+ tensor.memspace,
+ assumed_align=tensor.iterator.max_alignment,
+ )
+ return cute.make_tensor(new_ptr, tensor.layout)
+
+
+@cute.jit
+def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cutlass.Numeric) -> None:
+ """Fill out-of-bounds values in shared memory tensor."""
+ tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0])
+ tXrX_fill.fill(fill_value)
+ for rest_v in cutlass.range_constexpr(const_expr(tXsX.shape[0][1])):
+ for rest_k in cutlass.range_constexpr(const_expr(tXsX.shape[2])):
+ if const_expr(tXpX is not None):
+ if not tXpX[rest_v, 0, rest_k]:
+ cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+ else:
+ cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k])
+
+
+@dsl_user_op
+def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64:
+ """Pack two f32 values into a single i64.
+
+ This mirrors quack.utils.f32x2_to_i64 and is used by online_softmax_reduce
+ to store (max, sum_exp) pairs in an Int64 reduction buffer.
+ """
+ vec_f32x2 = vector.from_elements(
+ T.vector(2, T.f32()),
+ (a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)),
+ loc=loc,
+ ip=ip,
+ )
+ vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2, loc=loc, ip=ip)
+ res = cutlass.Int64(
+ vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
+ )
+ return res
+
+
+@dsl_user_op
+def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
+ """Unpack a single i64 into two f32 values, inverse of f32x2_to_i64."""
+ vec_i64x1 = vector.from_elements(
+ T.vector(1, T.i64()),
+ (c.ir_value(loc=loc, ip=ip),),
+ loc=loc,
+ ip=ip,
+ )
+ vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1, loc=loc, ip=ip)
+ res0 = Float32(
+ vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
+ )
+ res1 = Float32(
+ vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
+ )
+ return res0, res1
+
+
# -------------------------
-# Reduction helpers
+# Reduction helpers (from quack.reduce)
# -------------------------
@@ -212,10 +291,6 @@ def warp_reduce(
op: Callable,
width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
- """
- Warp-level reduction for either scalar values or small TensorSSA
- fragments.
- """
if cutlass.const_expr(isinstance(val, cute.TensorSSA)):
res = cute.make_fragment(val.shape, val.dtype)
res.store(val)
@@ -234,7 +309,7 @@ def block_reduce(
reduction_buffer: cute.Tensor,
init_val: cute.Numeric = 0.0,
) -> cute.Numeric:
- """Block-level reduction across warps."""
+    """reduction_buffer has shape (num_warps / warps_per_row, warps_per_row)."""
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
warps_per_row = cute.size(reduction_buffer.shape[1])
row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
@@ -256,10 +331,7 @@ def cluster_reduce(
init_val: cute.Numeric = 0.0,
phase: Optional[cutlass.Int32] = None,
) -> cute.Numeric:
- """
- Cluster-wide reduction using shared memory and mbarrier. The
- reduction_buffer has shape (rows_per_block, (warps_per_row, cluster_n)).
- """
+ """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))."""
cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
@@ -297,12 +369,10 @@ def block_or_cluster_reduce(
phase: Optional[cutlass.Int32] = None,
init_val: cute.Numeric = 0.0,
) -> cute.Numeric:
- """Dispatch between block or cluster reduction depending on mbar_ptr."""
+ """Perform either block or cluster reduction based on whether mbar_ptr is provided."""
if cutlass.const_expr(mbar_ptr is None):
return block_reduce(val, op, reduction_buffer, init_val=init_val)
- return cluster_reduce(
- val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase
- )
+ return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase)
@cute.jit
@@ -316,21 +386,14 @@ def row_reduce(
init_val: cute.Numeric = 0.0,
hook_fn: Optional[Callable] = None,
) -> cute.Numeric:
- """
- Row-wise reduction used by RMSNorm and similar kernels.
-
- reduction_buffer must have shape
- (num_warps / warps_per_row, (warps_per_row, cluster_n)).
- """
+ """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))."""
if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
val = x.reduce(op, init_val=init_val, reduction_profile=0)
else:
val = x
warp_op = {
cute.ReductionOp.ADD: operator.add,
- cute.ReductionOp.MAX: cute.arch.fmax
- if cutlass.const_expr(x.dtype == Float32)
- else max,
+ cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
cute.ReductionOp.MIN: min,
cute.ReductionOp.MUL: operator.mul,
}[op]
@@ -358,28 +421,808 @@ def row_reduce(
return val
+@cute.jit
+def row_reduce_add(
+ x: cute.TensorSSA | cute.Numeric,
+ threads_per_row: cutlass.Constexpr[int],
+ reduction_buffer: Optional[cute.Tensor] = None,
+ mbar_ptr: Optional[cute.Pointer] = None,
+ phase: Optional[cutlass.Int32] = None,
+ init_val: cute.Numeric = 0.0,
+ hook_fn: Optional[Callable] = None,
+) -> cute.Numeric:
+ """Specialized row_reduce for ADD reductions.
+
+ This mirrors row_reduce but hardcodes the ADD operation so we avoid
+ dynamic dispatch on the reduction op. It is used by bandwidth-bound
+ kernels like RMSNorm backward where the reduction is always ADD in
+ Float32.
+ """
+ if cutlass.const_expr(isinstance(x, cute.TensorSSA)):
+ val = x.reduce(cute.ReductionOp.ADD, init_val=init_val, reduction_profile=0)
+ else:
+ val = x
+ val = warp_reduce(
+ val,
+ operator.add,
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
+ )
+ if cutlass.const_expr(hook_fn is not None):
+ hook_fn()
+ if cutlass.const_expr(reduction_buffer is not None):
+ warps_per_row, cluster_n = reduction_buffer.shape[1]
+ assert cluster_n == 1 or mbar_ptr is not None, (
+ "mbar_ptr must be provided for cluster reduction"
+ )
+ if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+ val = block_or_cluster_reduce(
+ val,
+ operator.add,
+ reduction_buffer,
+ mbar_ptr,
+ phase=phase,
+ init_val=init_val,
+ )
+ return val
+
+
+@cute.jit
+def online_softmax_reduce(
+ x: cute.TensorSSA,
+ threads_per_row: cutlass.Constexpr[int],
+ reduction_buffer: Optional[cute.Tensor] = None,
+ mbar_ptr: Optional[cute.Pointer] = None,
+ hook_fn: Optional[Callable] = None,
+ phase: Optional[cutlass.Int32] = None,
+ return_exp_x: bool = False,
+) -> tuple[Float32, Float32, Optional[cute.TensorSSA]]:
+ """Online softmax reduction over a row.
+
+ This mirrors quack.reduce.online_softmax_reduce and computes:
+ - max_x: row-wise maximum of x
+ - sum_exp_x: row-wise sum of exp(x - max_x)
+ - exp_x (optional): per-element exp(x - max_x_final) if return_exp_x is True
+ """
+ assert x.dtype == Float32, "x must be of type Float32"
+    # reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))
+    # with Int64 elements, each packing a (max, sum_exp) f32 pair.
+ max_x = warp_reduce(
+ x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
+ cute.arch.fmax,
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
+ )
+ log2_e = math.log2(math.e)
+ exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
+ sum_exp_x = warp_reduce(
+ exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
+ operator.add,
+ width=min(threads_per_row, cute.arch.WARP_SIZE),
+ )
+ if cutlass.const_expr(hook_fn is not None):
+ hook_fn()
+ if cutlass.const_expr(reduction_buffer is not None):
+ rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape
+ assert cluster_n == 1 or mbar_ptr is not None, (
+ "mbar_ptr must be provided for cluster reduction"
+ )
+ if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1):
+ assert reduction_buffer.element_type == cutlass.Int64, (
+ "reduction_buffer must be of type Int64"
+ )
+ lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx()
+ row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row
+ if cutlass.const_expr(mbar_ptr is None):
+ if lane_idx == 0:
+ reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x)
+ cute.arch.barrier()
+ max_x_single_warp = -Float32.inf
+ sum_exp_x = 0.0
+ if lane_idx < warps_per_row:
+ max_x_single_warp, sum_exp_x = i64_to_f32x2(
+ reduction_buffer[row_idx, lane_idx]
+ )
+ max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
+ sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
+ sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+ if cutlass.const_expr(return_exp_x):
+ exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
+ max_x = max_x_final
+ else:
+ cta_rank_in_cluster = cute.arch.block_idx_in_cluster()
+ if warp_idx == 0:
+ with cute.arch.elect_one():
+ num_warps = rows_per_block * warps_per_row
+ cute.arch.mbarrier_arrive_and_expect_tx(
+ mbar_ptr,
+ num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ )
+ if lane_idx < cluster_n:
+ store_shared_remote(
+ f32x2_to_i64(max_x, sum_exp_x),
+ elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
+ mbar_ptr,
+ peer_cta_rank_in_cluster=lane_idx,
+ )
+ cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
+ num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
+ max_x_single_warp = cute.make_fragment(num_iter, Float32)
+ max_x_single_warp.fill(-Float32.inf)
+ sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32)
+ sum_exp_x_single_warp.fill(0.0)
+ for i in cutlass.range_constexpr(num_iter):
+ idx = lane_idx + i * cute.arch.WARP_SIZE
+ if idx < cute.size(reduction_buffer, mode=[1]):
+ max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2(
+ reduction_buffer[row_idx, idx]
+ )
+ max_x_final = max_x_single_warp.load().reduce(
+ cute.ReductionOp.MAX,
+ init_val=-Float32.inf,
+ reduction_profile=0,
+ )
+ max_x_final = warp_reduce(max_x_final, cute.arch.fmax)
+ sum_exp_x = 0.0
+ for i in cutlass.range_constexpr(num_iter):
+ sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp(
+ max_x_single_warp[i] - max_x_final,
+ fastmath=True,
+ )
+ sum_exp_x = warp_reduce(sum_exp_x, operator.add)
+ if cutlass.const_expr(return_exp_x):
+ exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
+ max_x = max_x_final
+ return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None)
+
+
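+# For reference (illustrative only): the per-warp (max, sum_exp) pairs packed
+# into the Int64 reduction buffer above are combined with the standard online
+# softmax merge rule; in plain Python,
+#
+#   def merge(m1, s1, m2, s2):
+#       m = max(m1, m2)
+#       return m, s1 * math.exp(m1 - m) + s2 * math.exp(m2 - m)
+#
+# applied across warps (and, for cluster_n > 1, across CTAs in the cluster).
+
+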
+# -------------------------
+# Copy helpers (minimal subset of quack.copy_utils)
+# -------------------------
+
+
+@dsl_user_op
+def get_copy_atom(
+ dtype: Type[cutlass.Numeric],
+ num_copy_elems: int,
+ is_async: bool = False,
+ *,
+ loc=None,
+ ip=None,
+) -> cute.CopyAtom:
+ from cutlass.cute.nvgpu import cpasync
+
+ num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
+ copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
+ return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits, loc=loc, ip=ip)
+
+
+@dsl_user_op
+def copy(
+ src: cute.Tensor,
+ dst: cute.Tensor,
+ *,
+ pred: Optional[cute.Tensor] = None,
+ num_copy_elems: int = 1,
+ is_async: bool = False,
+ loc=None,
+ ip=None,
+ **kwargs,
+) -> None:
+ copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async, loc=loc, ip=ip)
+ cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
+
+
+# -------------------------
+# Reduction base (from quack.reduction_base)
+# -------------------------
+
+
+class ReductionBase:
+ def __init__(
+ self,
+ dtype: Type[cutlass.Numeric],
+ N: int,
+ stage: int,
+ reduction_dtype: Type[cutlass.Numeric] = cutlass.Float32,
+ ):
+ self.dtype = dtype
+ self.N = N
+ self.stage = stage
+ self.reduction_dtype = reduction_dtype
+
+ def _calculate_threads_per_row(self) -> int:
+ raise NotImplementedError()
+
+ def _set_cluster_n(self) -> None:
+ self.cluster_n = 1
+
+ def _get_num_threads(self) -> int:
+ return 128 if self.N <= 16384 else 256
+
+ def _get_tv_layout(self, num_copy_bits: int = 128) -> Tuple[cute.Shape, cute.Layout]:
+ vecsize = num_copy_bits // self.dtype.width
+ assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
+ num_threads = self._get_num_threads()
+ assert num_threads % cute.arch.WARP_SIZE == 0
+
+ threads_per_row = self._calculate_threads_per_row()
+ self._set_cluster_n()
+ num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
+ cols_per_block = num_threads // threads_per_row
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
+ tv_layout = cute.make_layout(
+ ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)),
+ stride=(
+ (vecsize * cols_per_block, 1),
+ (cols_per_block, cols_per_block * vecsize * threads_per_row),
+ ),
+ )
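+        # Worked example (illustrative): for bf16 (width 16) with
+        # num_copy_bits=128, vecsize = 8. Under RMSNormBackward's policy at
+        # N=8192 (threads_per_row=256, num_threads=256, cluster_n=1):
+        #   num_blocks_N   = ceil_div(8192 // 8, 256) = 4
+        #   cols_per_block = 256 // 256 = 1
+        #   tiler_mn       = (1, 8 * 4 * 256) = (1, 8192)
+        # i.e. one CTA covers a full row, each thread copying 4 vectors of
+        # 8 elements.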
+ return tiler_mn, tv_layout
+
+ def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int:
+ return (
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
+ + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+ + self.stage * (cutlass.Int64.width // 8)
+ )
+
+ def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int) -> cute.Layout:
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
+ return cute.make_ordered_layout(
+ (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+ order=(1, 0, 2),
+ )
+
+ def _allocate_reduction_buffer_and_mbar(
+ self,
+ smem: cutlass.utils.SmemAllocator,
+ tv_layout: cute.Layout,
+ is_persistent: bool = False,
+ ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
+ reduction_buffer = smem.allocate_tensor(
+ self.reduction_dtype,
+ self._get_reduction_buffer_layout(tv_layout, self.cluster_n),
+ byte_alignment=4,
+ )
+ if cutlass.const_expr(self.cluster_n > 1):
+ mbar_ptr = smem.allocate_array(
+ cutlass.Int64,
+ num_elems=self.stage if not is_persistent else self.stage * 2,
+ )
+ else:
+ mbar_ptr = None
+ return reduction_buffer, mbar_ptr
+
+ @cute.jit
+ def _initialize_cluster(
+ self,
+ tidx: cutlass.Int32,
+ mbar_ptr: Optional[cute.Pointer],
+ num_warps: int,
+ is_persistent: bool = False,
+ ) -> None:
+ if cutlass.const_expr(self.cluster_n > 1 and mbar_ptr is not None):
+ if tidx < self.stage:
+ cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+ if cutlass.const_expr(is_persistent):
+ cute.arch.mbarrier_init(
+ mbar_ptr + self.stage + tidx,
+ num_warps * self.cluster_n,
+ )
+ cute.arch.mbarrier_init_fence()
+ cute.arch.cluster_arrive_relaxed()
+
+
# -------------------------
-# SM count helper
+# RMSNorm backward base (from quack.rmsnorm.RMSNormBackward)
# -------------------------
-def get_sm_count(N: int, device: torch.device) -> int:
+class RMSNormBackward(ReductionBase):
+ def __init__(self, dtype: cutlass.Numeric, N: int):
+ # 2 stages for double buffering when computing mean of x_hat * wdy
+ super().__init__(dtype, N, stage=2, reduction_dtype=Float32)
+ self.reload_wdy = None if N <= 16 * 1024 else "smem"
+ if self.N > 128 * 1024 and self.dtype.width >= 32:
+ raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
+
+ def _get_num_threads(self) -> int:
+ return 128 if self.N <= 4096 else 256
+
+ def _calculate_threads_per_row(self) -> int:
+ N = self.N
+ return (
+ 8
+ if N <= 64
+ else (
+ 16
+ if N <= 128
+ else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
+ )
+ )
+
+ def _set_cluster_n(self) -> None:
+ N = self.N
+ cluster_n = (
+ 1
+ if N <= 8 * 1024
+ else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)))
+ )
+ self.cluster_n = cluster_n
+
+ def _smem_size_in_bytes(self, tiler_mn, num_warps: int, do_dtype=None) -> int:
+ if do_dtype is None:
+ do_dtype = self.dtype
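+        # Illustrative sizing (bf16, N=8192, tiler_mn=(1, 8192), 8 warps,
+        # cluster_n=1): 2 x 16 KiB double-buffered X tiles, 2 x 16 KiB dO
+        # tiles, 64 B of reduction buffer and 32 B of mbarriers, i.e. roughly
+        # 64 KiB of SMEM.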
+ return (
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+ + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2
+ + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+ + self.stage * (cutlass.Int64.width // 8) * 2
+ )
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mdO: cute.Tensor,
+ mdResO: Optional[cute.Tensor],
+ mRstd: cute.Tensor,
+ mdX: cute.Tensor,
+ mdW: Optional[cute.Tensor],
+ mdRes: Optional[cute.Tensor],
+ mdB: Optional[cute.Tensor],
+ sm_count: Int32,
+ stream: cuda.CUstream,
+ ):
+ semistatic_shape = (*mX.shape[:-1], self.N)
+
+ def new_stride(t):
+ return (
+ cute.assume(t.stride[0], divby=128 // t.element_type.width),
+ t.stride[1],
+ )
+
+ mX, mdO, mdResO, mdX, mdRes = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ if const_expr(t is not None)
+ else None
+ for t in (mX, mdO, mdResO, mdX, mdRes)
+ ]
+ self._set_cluster_n()
+ largest_dtype_width = const_expr(
+ max(
+ mX.element_type.width,
+ mdO.element_type.width,
+ mdX.element_type.width,
+ mdResO.element_type.width if mdResO is not None else 0,
+ mdRes.element_type.width if mdRes is not None else 0,
+ )
+ )
+ tiler_mn, tv_layout = self._get_tv_layout(
+ num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+ )
+ num_threads = (
+ cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ if const_expr(mW is not None):
+ mW_expanded_layout = cute.prepend(
+ mW.layout,
+ cute.make_layout((tiler_mn[0],), stride=(0,)),
+ )
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+
+ num_blocks = sm_count
+ kernel = (
+ self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn)
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes)
+ )
+ kernel.launch(
+ grid=[num_blocks, self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
+ stream=stream,
+ )
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mdO: cute.Tensor,
+ mdResO: Optional[cute.Tensor],
+ mRstd: cute.Tensor,
+ mdX: cute.Tensor,
+ mdW: Optional[cute.Tensor],
+ mdB: Optional[cute.Tensor],
+ mdRes: Optional[cute.Tensor],
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx_start, _, _ = cute.arch.block_idx()
+ gdim, _, _ = cute.arch.grid_dim()
+ if const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = const_expr(0)
+
+ shape = mX.shape
+ M = shape[0]
+ is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n)
+
+ idX = cute.make_identity_tensor(shape)
+
+ smem = cutlass.utils.SmemAllocator()
+ smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
+ sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
+ sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16)
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem,
+ tv_layout,
+ is_persistent=True,
+ )
+ if const_expr(mbar_ptr is not None):
+ mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2
+ else:
+ mbar_full_ptr, mbar_empty_ptr = None, None
+
+ num_copy_elems_X = tv_layout.shape[1][0]
+ copy_atom_load_X = get_copy_atom(mX.element_type, num_copy_elems_X, is_async=False)
+ thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ copy_fn = partial(copy, num_copy_elems=num_copy_elems_X)
+
+ gX, gdO, gdResO, gdX, gdRes, cX = [
+ cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None
+ for mT in (mX, mdO, mdResO, mdX, mdRes, idX)
+ ]
+ gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) if mW is not None else None
+ gdW, gdB = [
+ cute.local_tile(mT, (1, tiler_mn[1]), (bidx_start, cluster_y))
+ if const_expr(mT is not None)
+ else None
+ for mT in (mdW, mdB)
+ ]
+
+ tXgX = thr_copy_X.partition_S(gX)
+ tXsX = thr_copy_X.partition_D(sX)
+ tXgdO = thr_copy_X.partition_S(gdO)
+ tXsdO = thr_copy_X.partition_D(sdO)
+ tXgdX = thr_copy_X.partition_D(gdX)
+ if const_expr(mdResO is not None):
+ tXgdResO = thr_copy_X.partition_S(gdResO)
+ if const_expr(mdRes is not None):
+ tXgdRes = thr_copy_X.partition_D(gdRes)
+ tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
+
+ tXrX, tXrdO, tXrdX = [
+ cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
+ ]
+ tXrdResO = None
+ if const_expr(mdResO is not None):
+ tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0])
+ tXrdRes = None
+ if const_expr(mdRes is not None):
+ tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0])
+
+ tXpX = (
+ predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1])
+ if not is_even_N
+ else None
+ )
+
+ tXgdW, tXrdW = None, None
+ tXgdB, tXrdB = None, None
+ if const_expr(mdW is not None):
+ tXgdW = thr_copy_X.partition_S(gdW)
+ tXrdW = cute.make_fragment_like(tXgdW, Float32)
+ if const_expr(mdB is not None):
+ tXgdB = thr_copy_X.partition_S(gdB)
+ tXrdB = cute.make_fragment_like(tXgdB, Float32)
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+
+ self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True)
+
+ tXrW = None
+ if const_expr(mW is not None):
+ tXgW = thr_copy_X.partition_S(gW)
+ tXrW = cute.make_fragment_like(tXgW)
+ if not is_even_N:
+ tXrW.fill(0.0)
+ copy_fn(tXgW, tXrW, pred=tXpX)
+
+ row = tXcX[None, None, None, bidx_start][0][0]
+ if row < M:
+ tXgX_cur = coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0]
+ tXgdO_cur = coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0]
+ copy_fn(tXgX_cur, tXsX[None, None, None, 0], pred=tXpX, is_async=True)
+ copy_fn(tXgdO_cur, tXsdO[None, None, None, 0], pred=tXpX, is_async=True)
+ elif tiler_mn[0] > 1:
+ fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero)
+ fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero)
+ cute.arch.cp_async_commit_group()
+
+ if const_expr(self.cluster_n > 1):
+ cute.arch.cluster_wait()
+
+ threads_per_row = tv_layout.shape[0][0]
+ if const_expr(mdW is not None):
+ tXrdW.fill(0.0)
+ if const_expr(mdB is not None):
+ tXrdB.fill(0.0)
+ stage = Int32(0)
+ producer_phase = Int32(1)
+ consumer_phase = Int32(0)
+ for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
+ row = tXcX[None, None, None, bidx][0][0]
+ if row + gdim * tiler_mn[0] < M:
+ tXgX_cur = coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
+ tXgdO_cur = coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
+ copy_fn(tXgX_cur, tXsX[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
+ copy_fn(tXgdO_cur, tXsdO[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
+ elif tiler_mn[0] > 1:
+ fill_oob(
+ tXsX[None, None, None, stage ^ 1],
+ None,
+ fill_value=mX.element_type.zero,
+ )
+ fill_oob(
+ tXsdO[None, None, None, stage ^ 1],
+ None,
+ fill_value=mdO.element_type.zero,
+ )
+ cute.arch.cp_async_commit_group()
+            rstd_val = Float32.zero
+ if row < M or tiler_mn[0] == 1:
+ rstd_val = mRstd[row]
+ if const_expr(mdResO is not None):
+ tXgdResO_cur = coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
+ if row < M or tiler_mn[0] == 1:
+ copy_fn(tXgdResO_cur, tXrdResO, pred=tXpX)
+ elif tiler_mn[0] > 1:
+ tXrdResO.fill(0.0)
+ cute.arch.cp_async_wait_group(1)
+ cute.autovec_copy(tXsX[None, None, None, stage], tXrX)
+ x = tXrX.load().to(cute.Float32)
+ cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
+ dout = tXrdO.load().to(cute.Float32)
+ x_hat = x * rstd_val
+ wdy = dout
+ if const_expr(mW is not None):
+ wdy *= tXrW.load().to(Float32)
+ if const_expr(self.cluster_n > 1):
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
+ mean_xhat_wdy = (
+ row_reduce_add(
+ x_hat * wdy,
+ threads_per_row,
+ reduction_buffer[None, None, stage],
+ (mbar_full_ptr + stage if const_expr(self.cluster_n > 1) else None),
+ phase=consumer_phase,
+ init_val=0.0,
+ )
+ / shape[1]
+ )
+
+ if const_expr(self.cluster_n > 1):
+ cute.arch.fence_proxy(
+ cute.arch.ProxyKind.async_shared,
+ space=cute.arch.SharedSpace.shared_cta,
+ )
+ cute.arch.sync_warp()
+ lane_idx = cute.arch.lane_idx()
+ if lane_idx < self.cluster_n:
+ cute.arch.mbarrier_arrive(
+ mbar_empty_ptr + stage,
+ peer_cta_rank_in_cluster=lane_idx,
+ )
+
+ if const_expr(self.reload_wdy == "smem"):
+ cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO)
+ dout = tXrdO.load().to(cute.Float32)
+ wdy = dout
+ if const_expr(mW is not None):
+ wdy *= tXrW.load().to(Float32)
+
+ dx = (wdy - x_hat * mean_xhat_wdy) * rstd_val
+ if const_expr(mdResO is not None):
+ dx += tXrdResO.load().to(cute.Float32)
+ tXrdX.store(dx.to(tXrdX.element_type))
+ if row < M or tiler_mn[0] == 1:
+ tXgdX_cur = coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0]
+ copy_fn(tXrdX, tXgdX_cur, pred=tXpX)
+ if const_expr(mdRes is not None):
+ tXrdRes.store(dx.to(tXrdRes.element_type))
+ tXgdRes_cur = coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
+ if row < M or tiler_mn[0] == 1:
+ copy_fn(tXrdRes, tXgdRes_cur, pred=tXpX)
+ if const_expr(mdW is not None):
+ tXrdW.store(tXrdW.load() + dout * x_hat)
+ if const_expr(mdB is not None):
+ tXrdB.store(tXrdB.load() + dout)
+
+ stage ^= 1
+ if stage == 0:
+ consumer_phase ^= 1
+ producer_phase ^= 1
+
+ if const_expr(tiler_mn[0] > 1):
+ if const_expr(mdW is not None):
+ sdW = cute.make_tensor(
+ cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ )
+ tXsdW = thr_copy_X.partition_D(sdW)
+ cute.arch.barrier()
+ row0 = tXcX[None, None, None, 0][0][0]
+ if row0 > 0:
+ cute.autovec_copy(tXrdW, tXsdW)
+ cute.arch.barrier()
+ if row0 == 0:
+ for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
+ tXrdW_other = cute.make_fragment_like(tXrdW)
+ tXsdW_other = cute.make_tensor(
+ tXsdW.iterator + i * sdW.stride[0],
+ tXsdW.layout,
+ )
+ cute.autovec_copy(tXsdW_other, tXrdW_other)
+ tXrdW.store(tXrdW.load() + tXrdW_other.load())
+ copy_fn(tXrdW, tXgdW, pred=tXpX)
+ cute.arch.barrier()
+ if const_expr(mdB is not None):
+ sdB = cute.make_tensor(
+ cute.recast_ptr(sX.iterator, dtype=cute.Float32),
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ )
+ tXsdB = thr_copy_X.partition_D(sdB)
+ cute.arch.barrier()
+ row0 = tXcX[None, None, None, 0][0][0]
+ if row0 > 0:
+ cute.autovec_copy(tXrdB, tXsdB)
+ cute.arch.barrier()
+ if row0 == 0:
+ for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])):
+ tXrdB_other = cute.make_fragment_like(tXrdB)
+ tXsdB_other = cute.make_tensor(
+ tXsdB.iterator + i * sdB.stride[0],
+ tXsdB.layout,
+ )
+ cute.autovec_copy(tXsdB_other, tXrdB_other)
+ tXrdB.store(tXrdB.load() + tXrdB_other.load())
+ copy_fn(tXrdB, tXgdB, pred=tXpX)
+ else:
+ if const_expr(mdW is not None):
+ copy_fn(tXrdW, tXgdW, pred=tXpX)
+ if const_expr(mdB is not None):
+ copy_fn(tXrdB, tXgdB, pred=tXpX)
+
+ if const_expr(self.cluster_n > 1):
+ stage ^= 1
+ if stage == 0:
+ producer_phase ^= 1
+ cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mdO: cute.Tensor,
+ mdResO: Optional[cute.Tensor],
+ mRstd: cute.Tensor,
+ mdX: cute.Tensor,
+ mdW: Optional[cute.Tensor],
+ mdB: Optional[cute.Tensor],
+ mdRes: Optional[cute.Tensor],
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ):
+ self._kernel_impl(
+ mX,
+ mW,
+ mdO,
+ mdResO,
+ mRstd,
+ mdX,
+ mdW,
+ mdB,
+ mdRes,
+ tv_layout,
+ tiler_mn,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mdO: cute.Tensor,
+ mdResO: Optional[cute.Tensor],
+ mRstd: cute.Tensor,
+ mdX: cute.Tensor,
+ mdW: Optional[cute.Tensor],
+ mdB: Optional[cute.Tensor],
+ mdRes: Optional[cute.Tensor],
+ ):
+ largest_dtype_width = const_expr(
+ max(
+ mX.element_type.width,
+ mdO.element_type.width,
+ mdX.element_type.width,
+ mdResO.element_type.width if mdResO is not None else 0,
+ mdRes.element_type.width if mdRes is not None else 0,
+ )
+ )
+ tiler_mn, tv_layout = self._get_tv_layout(
+ num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+ )
+ self._kernel_impl(
+ mX,
+ mW,
+ mdO,
+ mdResO,
+ mRstd,
+ mdX,
+ mdW,
+ mdB,
+ mdRes,
+ tv_layout,
+ tiler_mn,
+ )
+
+
+# -------------------------
+# SM count helper (from quack.rmsnorm._get_sm_count)
+# -------------------------
+
+
+def get_sm_count(
+ N: int,
+ device: torch.device,
+ M: Optional[int] = None,
+ dtype: Optional[torch.dtype] = None,
+) -> int:
"""
- Heuristic for the number of persistent CTAs (sm_count) based on N and
- the GPU's SM count. This mirrors the behaviour used in Quack's
- RMSNorm kernels but lives entirely in this local module.
+ SM count heuristic for reduction-style kernels.
+
+ This starts from Quack's _get_sm_count policy and layers on SM100 /
+ DSv3-specific tuning so that:
+ - For DSv3-style shapes (large-M, N in {6144, 8192}, fp16/bf16),
+ sm_count is reduced for very large M to cut down the number of
+ dw_partial/db_partial rows that ever hit HBM.
+ - For Quack-suite hidden=4096, small-M shapes, sm_count is modestly
+ increased to improve SM occupancy, matching the existing SM100
+ tuning used by both RMSNorm and LayerNorm.
"""
+ props = torch.cuda.get_device_properties(device)
+ num_sms = props.multi_processor_count
+
sm_count_multiple = (
- 16
- if N <= 256
- else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
- )
- sm_count = torch.cuda.get_device_properties(device).multi_processor_count
- sm_count = (
- sm_count * sm_count_multiple
- if N <= 8192
- else sm_count // 2
- if N <= 16384
- else sm_count * 2
+ 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
)
+ sm_count = num_sms
+ if N <= 8192:
+ sm_count = sm_count * sm_count_multiple
+ elif N <= 16384:
+ sm_count = sm_count // 2
+ else:
+ sm_count = sm_count * 2
+
+ # Quack-suite tuning: for small-M, hidden=4096 shapes (M<=8192) and
+ # 16-bit dtypes, increase sm_count to improve occupancy. This mirrors
+ # the existing SM100 RMSNorm/LayerNorm heuristics.
+ if (
+ dtype in (torch.float16, torch.bfloat16)
+ and M is not None
+ and M <= 8192
+ and N == 4096
+ ):
+ sm_count = min(sm_count * 2, num_sms * 4)
+
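+    # Worked example (illustrative, assuming a hypothetical 148-SM device):
+    #   N=4096  -> 148 * 2 = 296, boosted to 592 for fp16/bf16 when M <= 8192;
+    #   N=8192  -> 148;  N=16384 -> 74;  N=32768 -> 296.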
return sm_count
diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
index c7fc1b3..1e080a3 100644
--- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py
+++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
@@ -76,6 +76,7 @@
import cutlass.cute as cute # noqa: E402
from cutlass import Float32, Int32, const_expr # noqa: E402
from cutlass.cute import runtime as rt # noqa: E402
+from cutlass.cute.runtime import from_dlpack # noqa: E402
# Simple compile cache declared early so direct execution works
_PTR_COMPILE_CACHE = {}
@@ -114,6 +115,14 @@ def _env_flag(name: str, default: bool) -> bool:
_ENABLE_TPR256 = _env_flag("OINK_RMSNORM_ENABLE_TPR256", default=False)
_ENABLE_STAGE2 = _env_flag("OINK_RMSNORM_ENABLE_STAGE2", default=False)
+# Forward dispatch control:
+# - Default behavior: use the pointer-based path when safe, otherwise fall back
+# to the stage-2 module (then the torch reference).
+# - If you want to force stage-2 even when the pointer path is available (for
+# experimentation / A-B testing), set this env var **before** importing this
+# module.
+_FORCE_RMSNORM_STAGE2_FWD = _env_flag("KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2", default=False)
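+# Example (illustrative; the script name below is just a placeholder):
+#   KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2=1 python your_vllm_script.py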
+
# CuTeDSL stability probe for the experimental cluster_n>1 + direct-GMEM schedule.
#
# Some CuTeDSL builds segfault during JIT compilation when combining:
@@ -918,7 +927,13 @@ def _get_fast_ptr_fused_add_rmsnorm_launcher(
# NOTE: Avoid `from . import ...` imports here: CuTeDSL's AST preprocessor may
# mishandle that form (module=None in the AST). Use fully-qualified imports.
from kernelagent_oink.blackwell import lite_quack as qutils # noqa: E402
-from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce # noqa: E402
+from kernelagent_oink.blackwell.lite_quack import ( # noqa: E402
+ TORCH2CUTE_DTYPE,
+ RMSNormBackward as BaseRMSNormBackward,
+ convert_from_dlpack as convert_from_dlpack_cute,
+ get_sm_count,
+ row_reduce,
+)
# -------------------------
@@ -2720,52 +2735,57 @@ def rmsnorm_forward(
assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand."
M, N = x.shape
- # For DSv3 big-M outliers on SM100, keep using the dedicated
- # stage-2 K-loop implementation, which is already tuned and
- # parity-checked against the reference.
- use_stage2_big_dsv3 = bool(
- M >= 65536 and N in (6144, 8192) and x.dtype in (torch.float16, torch.bfloat16)
- )
- if use_stage2_big_dsv3:
- try:
- import rmsnorm_with_stage2 as rms2 # type: ignore[import-not-found]
- except Exception:
- rms2 = None # type: ignore[assignment]
- if rms2 is not None:
- y, rstd, residual_out = rms2.rmsnorm_forward_with_stage2(
- x,
- weight=weight,
- bias=bias,
- residual=residual,
- eps=eps,
- store_rstd=store_rstd,
- )
- # Preserve stride contracts for torch.compile consistency, even
- # when using the optional stage-2 implementation.
- if y.stride() != x.stride():
- y_strided = torch.empty_strided(
- x.shape, x.stride(), device=x.device, dtype=x.dtype
- )
- y_strided.copy_(y)
- y = y_strided
- if residual is not None and residual_out is not None:
- if residual_out.stride() != residual.stride():
- residual_out_strided = torch.empty_strided(
- residual.shape,
- residual.stride(),
- device=residual.device,
- dtype=residual.dtype,
- )
- residual_out_strided.copy_(residual_out)
- residual_out = residual_out_strided
- return y, rstd, residual_out
-
- # Default: use the pointer-based entry whenever we can represent the
- # inputs as a row-major [M, N] view with stride(1) == 1. For rare layouts
- # we can't safely express without DLPack, fall back to a torch reference.
- if _can_use_ptr_path(x, weight, bias, residual):
+ # Fast path: use the pointer-based entry whenever we can represent the
+ # inputs as a row-major [M, N] view with stride(1) == 1 and dtype contracts
+ # are satisfied (vLLM uses this in inference).
+ #
+ # When the pointer path can't be used (e.g. float32 weights for Quack-style
+ # APIs, or non-standard layouts), fall back to the CuTeDSL stage-2 module
+ # before using the slow torch reference implementation.
+ force_stage2 = _FORCE_RMSNORM_STAGE2_FWD
+
+ if not force_stage2 and _can_use_ptr_path(x, weight, bias, residual):
return _rmsnorm_forward_ptr(x, weight, bias, residual, eps, store_rstd)
+ # CuTeDSL fallback for cases that aren't safe for the pointer path.
+ # Import lazily to keep vLLM plugin startup and common inference fast paths
+ # lightweight.
+ try:
+ import importlib
+
+ rms2 = importlib.import_module(
+ ".rmsnorm_with_stage2",
+ package=__package__ or "kernelagent_oink.blackwell",
+ )
+ except Exception:
+ rms2 = None # type: ignore[assignment]
+ if rms2 is not None:
+ y, rstd, residual_out = rms2.rmsnorm_forward_with_stage2(
+ x,
+ weight=weight,
+ bias=bias,
+ residual=residual,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+ # Preserve stride contracts for torch.compile consistency, even
+ # when using the optional stage-2 implementation.
+ if y.stride() != x.stride():
+ y_strided = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
+ y_strided.copy_(y)
+ y = y_strided
+ if residual is not None and residual_out is not None:
+ if residual_out.stride() != residual.stride():
+ residual_out_strided = torch.empty_strided(
+ residual.shape,
+ residual.stride(),
+ device=residual.device,
+ dtype=residual.dtype,
+ )
+ residual_out_strided.copy_(residual_out)
+ residual_out = residual_out_strided
+ return y, rstd, residual_out
+
# Safe fallback (correctness-first). This is expected to be rare in vLLM.
y = rmsnorm_ref(x, weight, bias, residual, eps)
# Preserve the input stride contract even on the fallback path so
@@ -2910,6 +2930,363 @@ def fused_add_rmsnorm_inplace_(
return None
+# -------------------------
+# Backward kernel (SM100)
+# -------------------------
+
+
+class RMSNormBackwardSM100(BaseRMSNormBackward):
+ """SM100-tuned RMSNorm backward.
+
+ This is a thin wrapper around the generic `lite_quack.RMSNormBackward`
+ base implementation, with SM100-friendly tiling heuristics that mirror
+ the forward policy used by Oink.
+ """
+
+ def __init__(self, dtype: cutlass.Numeric, N: int):
+ super().__init__(dtype, N)
+
+ def _get_num_threads(self) -> int:
+ # Keep 128 threads only up to N=4k; use 256 for larger rows to ensure
+ # threads_per_row <= num_threads across buckets.
+ try:
+ return self._nt_override # type: ignore[attr-defined]
+ except Exception:
+ return 128 if self.N <= 4096 else 256
+
+ def _calculate_threads_per_row(self) -> int:
+ # Mirror RMSNormSM100 forward's tiling.
+ N = self.N
+ if N <= 64:
+ return 8
+ if N <= 128:
+ return 16
+ if N <= 1024:
+ return 32
+ if N <= 4096:
+ return 128
+ if N <= 8192:
+ try:
+ return self._tpr_override # type: ignore[attr-defined]
+ except Exception:
+ return 128
+ if N <= 16384:
+ return 256
+ return 256
+
+ def _set_cluster_n(self) -> None:
+ # Reuse the SM100 forward cluster growth policy so large-N shapes can
+ # fan out across multiple CTAs in the same row.
+ try:
+ self.cluster_n = self._cluster_n_override # type: ignore[attr-defined]
+ return
+ except Exception:
+ pass
+
+ N = self.N
+ if N <= 8192:
+ cluster_n = 1
+ elif self.dtype.width == 16:
+ if N <= 16 * 1024:
+ cluster_n = 2
+ elif N <= 32 * 1024:
+ cluster_n = 2
+ elif N <= 64 * 1024:
+ cluster_n = 4
+ elif N <= 128 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+ else:
+ if N <= 32 * 1024:
+ cluster_n = 1
+ elif N <= 64 * 1024:
+ cluster_n = 2
+ elif N <= 128 * 1024:
+ cluster_n = 4
+ elif N <= 256 * 1024:
+ cluster_n = 8
+ else:
+ cluster_n = 16
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mdO: cute.Tensor,
+ mdResO: Optional[cute.Tensor],
+ mRstd: cute.Tensor,
+ mdX: cute.Tensor,
+ mdW: Optional[cute.Tensor],
+ mdRes: Optional[cute.Tensor],
+ mdB: Optional[cute.Tensor],
+ sm_count: Int32,
+ stream: cuda.CUstream,
+ ):
+ # Match forward's 32B alignment on the leading dimension to unlock
+ # wider vectorization when legal.
+ semistatic_shape = (*mX.shape[:-1], self.N)
+
+ def new_stride(t):
+ return (
+ cute.assume(t.stride[0], divby=256 // t.element_type.width),
+ t.stride[1],
+ )
+
+ mX, mdO, mdResO, mdX, mdRes = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ if const_expr(t is not None)
+ else None
+ for t in (mX, mdO, mdResO, mdX, mdRes)
+ ]
+
+ self._set_cluster_n()
+ largest_dtype_width = const_expr(
+ max(
+ mX.element_type.width,
+ mdO.element_type.width,
+ mdX.element_type.width,
+ mdResO.element_type.width if mdResO is not None else 0,
+ mdRes.element_type.width if mdRes is not None else 0,
+ )
+ )
+ tiler_mn, tv_layout = self._get_tv_layout(
+ num_copy_bits=128 // largest_dtype_width * mX.element_type.width
+ )
+ num_threads = (
+ cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ if const_expr(mW is not None):
+ mW_expanded_layout = cute.prepend(
+ mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))
+ )
+ mW = cute.make_tensor(mW.iterator, mW_expanded_layout)
+
+ num_blocks = sm_count
+ kernel = (
+ self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn)
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes)
+ )
+ kernel.launch(
+ grid=[num_blocks, self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
+ stream=stream,
+ )
+
+
+_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+
+
+def _rmsnorm_bwd_sm100(
+ x: Tensor,
+ weight: Optional[Tensor],
+ dout: Tensor,
+ rstd: Tensor,
+ dx: Tensor,
+ dw_partial: Optional[Tensor],
+ db_partial: Optional[Tensor] = None,
+ dresidual_out: Optional[Tensor] = None,
+ dresidual: Optional[Tensor] = None,
+ sm_count: Optional[int] = None,
+) -> None:
+ """SM100-specific RMSNorm backward dispatch.
+
+ Mirrors Quack's `quack.rmsnorm._rmsnorm_bwd`, but instantiates
+ `RMSNormBackwardSM100` (SM100-tuned heuristics).
+ """
+ assert x.dim() == 2, "Input must be 2D"
+ assert x.is_cuda, "Input tensor must be on CUDA device"
+ assert x.dtype in (torch.float16, torch.bfloat16, torch.float32)
+
+ if weight is not None:
+ assert weight.dim() == 1
+ assert x.shape[-1] == weight.shape[0]
+ assert weight.is_cuda
+ assert weight.dtype in (torch.float32, torch.bfloat16, torch.float16)
+ if dresidual_out is not None:
+ assert dresidual_out.shape == x.shape
+ assert dresidual_out.is_cuda
+ assert dresidual_out.dtype in (torch.float16, torch.bfloat16, torch.float32)
+ if dresidual is not None:
+ assert dresidual.shape == x.shape
+ assert dresidual.is_cuda
+ assert dresidual.dtype in (torch.float16, torch.bfloat16, torch.float32)
+
+ M, N = x.size(0), x.size(1)
+ device = x.device
+ if dw_partial is None and db_partial is None:
+ assert sm_count is not None
+ else:
+ sm_count = (
+ dw_partial.shape[0] if dw_partial is not None else db_partial.shape[0]
+ )
+
+ # Match Quack's conversion strategy for activations/gradients: keep the
+ # (M, N) layout dynamic without enforcing additional compact-shape
+ # constraints. This reduces per-call Python overhead for small-M shapes.
+    def convert_from_dlpack(t: Tensor) -> cute.Tensor:
+        return from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(
+            leading_dim=1
+        )
+
+ x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [
+ convert_from_dlpack(t) if t is not None else None
+ for t in (x, dout, dresidual_out, dx, dresidual)
+ ]
+
+ if weight is not None:
+ weight_dtype = TORCH2CUTE_DTYPE[weight.dtype]
+ weight_tensor = convert_from_dlpack_cute(
+ weight.detach(),
+ leading_dim=0,
+ divisibility=128 // weight_dtype.width,
+ )
+ else:
+ weight_tensor = None
+
+ dw_partial_tensor = (
+ from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ if dw_partial is not None
+ else None
+ )
+ db_partial_tensor = (
+ from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
+ if db_partial is not None
+ else None
+ )
+ rstd_tensor = (
+ from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ )
+
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (
+ M,
+ N,
+ x_tensor.element_type,
+ weight_tensor.element_type if weight is not None else None,
+ db_partial.dtype if db_partial is not None else None,
+ dresidual.dtype if dresidual is not None else None,
+ dresidual_out.dtype if dresidual_out is not None else None,
+ )
+ kernel = _BWD_COMPILE_CACHE.get(compile_key)
+ if kernel is None:
+ op = RMSNormBackwardSM100(x_tensor.element_type, N)
+
+ # Shape-specific tuning overrides for DSv3-style N=8192 rows.
+ if isinstance(op, RMSNormBackwardSM100) and N == 8192:
+ if M >= 65536:
+ op._tpr_override = 256 # type: ignore[attr-defined]
+ op._nt_override = 256 # type: ignore[attr-defined]
+ elif M >= 16384:
+ op._tpr_override = 256 # type: ignore[attr-defined]
+
+ kernel = cute.compile(
+ op,
+ x_tensor,
+ weight_tensor,
+ dout_tensor,
+ dres_out_tensor,
+ rstd_tensor,
+ dx_tensor,
+ dw_partial_tensor,
+ dres_tensor,
+ db_partial_tensor,
+ Int32(sm_count if sm_count is not None else 0),
+ current_stream,
+ )
+ _BWD_COMPILE_CACHE[compile_key] = kernel
+
+ kernel(
+ x_tensor,
+ weight_tensor,
+ dout_tensor,
+ dres_out_tensor,
+ rstd_tensor,
+ dx_tensor,
+ dw_partial_tensor,
+ dres_tensor,
+ db_partial_tensor,
+ Int32(sm_count if sm_count is not None else 0),
+ current_stream,
+ )
+
+
+def rmsnorm_backward(
+ x: Tensor,
+ weight: Optional[Tensor],
+ dout: Tensor,
+ rstd: Tensor,
+ dresidual_out: Optional[Tensor] = None,
+ has_bias: bool = False,
+ has_residual: bool = False,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+ """Public SM100 RMSNorm backward entry point.
+
+ Signature mirrors `quack.rmsnorm.rmsnorm_bwd` for easy comparisons.
+ """
+ device = x.device
+ M, N = x.size(0), x.size(1)
+ dx = torch.empty_like(x)
+ if dresidual_out is not None and dresidual_out.dtype != dx.dtype:
+ dresidual = torch.empty_like(x, dtype=dresidual_out.dtype)
+ else:
+ dresidual = None
+
+ # Shared SM100 tuning policy (used by both RMSNorm and LayerNorm).
+ sm_count = get_sm_count(N, device, M=M, dtype=x.dtype)
+
+ # Quack-suite smallest case (M=8192, N=4096) is extremely sensitive to
+ # Python/allocator overhead because the kernel itself is very fast.
+ #
+ # The default `lite_quack.get_sm_count` adds a small-M occupancy boost for
+ # N=4096, which increases `dw_partial` size and can amplify allocator
+ # pressure in benchmark/verify loops. Clamp to Quack's baseline policy
+ # (`sm_count = num_sms * 2` for N=4096) for this regime.
+ if N == 4096 and M <= 8192 and x.dtype in (torch.float16, torch.bfloat16):
+ try:
+ num_sms = torch.cuda.get_device_properties(device).multi_processor_count
+ sm_count = min(int(sm_count), int(num_sms) * 2)
+ except Exception:
+ pass
+
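+    # Size illustration (hedged, assuming a device with 148 SMs): the clamped
+    # N=4096 policy above gives sm_count = 2 * 148 = 296, so dw_partial below
+    # is a [296, 4096] fp32 buffer (~4.6 MiB) reduced once with sum(dim=0).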
+ if weight is not None:
+ dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ else:
+ dw_partial = None
+ db_partial = (
+ torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None
+ )
+
+ _rmsnorm_bwd_sm100(
+ x,
+ weight,
+ dout,
+ rstd,
+ dx,
+ dw_partial,
+ db_partial,
+ dresidual_out,
+ dresidual,
+ sm_count,
+ )
+
+ dw = dw_partial.sum(dim=0).to(weight.dtype) if weight is not None else None
+ db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None
+ if has_residual and dresidual is None:
+ dresidual = dx
+ return dx, dw, db, dresidual
+
+
+# Quack-style alias for benchmarks
+rmsnorm_bwd = rmsnorm_backward
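+
+# Minimal usage sketch (hedged; tensor shapes are illustrative). Given the
+# forward activations `x` (M, N), the weight `w` (N,), an upstream gradient
+# `dout`, and the per-row `rstd` saved by the forward pass:
+#
+#   dx, dw, db, dres = rmsnorm_backward(x, w, dout, rstd)
+#
+# `dw` is the fp32 partial-sum reduction cast back to `w.dtype`; `db` and
+# `dres` are returned only when `has_bias` / `has_residual` are set.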
+
+
if __name__ == "__main__":
# Minimal ad-hoc test (functionality only). For performance comparisons, use the benchmark harness.
if not torch.cuda.is_available():
diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py
new file mode 100644
index 0000000..b53da12
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py
@@ -0,0 +1,805 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+RMSNorm kernel for SM100 (Blackwell) in CuteDSL, with the experimental
+stage-2 cp.async ping-pong path preserved for N≈6k/8k.
+
+This file is a fork of rmsnorm.py that preserves the K-loop cp.async path
+behind `self.stage > 1`, whereas the main implementation in rmsnorm.py has
+been simplified to a single-stage schedule.
+"""
+
+from __future__ import annotations
+
+import importlib.metadata
+import re
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+
+import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
+
+import cutlass
+import cutlass.cute as cute
+from cutlass import Float32, const_expr
+from cutlass.cute.runtime import from_dlpack
+
+from kernelagent_oink.blackwell import lite_quack as qutils
+from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce
+
+_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+
+
+def _parse_version_tuple(version: str) -> tuple[int, int, int]:
+ parts = version.split(".")
+ nums: list[int] = []
+ for part in parts[:3]:
+ match = re.match(r"^(\d+)", part)
+ nums.append(int(match.group(1)) if match is not None else 0)
+ while len(nums) < 3:
+ nums.append(0)
+ return nums[0], nums[1], nums[2]
+
+
+def _cutlass_dsl_version() -> Optional[tuple[int, int, int]]:
+ try:
+ return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl"))
+ except Exception:
+ return None
+
+
+_CUTLASS_DSL_VERSION = _cutlass_dsl_version()
+# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around
+# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the
+# older signature for <4.3.4, but switch to a 4.3.4+ compatible signature when
+# we detect 4.3.4+ (or when version detection is unavailable).
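+# For example, with the parsing above: "4.3.3" -> (4, 3, 3) keeps the
+# layout-arg kernel signature, while "4.3.4.dev0" -> (4, 3, 4) (or a missing
+# package) selects the 4.3.4+ compatible one.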
+_KERNEL_ACCEPTS_LAYOUT_ARGS = (
+ _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4)
+)
+
+
+@cute.jit
+def get_copy_atom_bw(
+ dtype: type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False
+) -> cute.CopyAtom:
+ max_bits = const_expr(128 if is_async else 256)
+ num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width))
+ from cutlass.cute.nvgpu import cpasync
+
+ copy_op = (
+ cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL)
+ if is_async
+ else cute.nvgpu.CopyUniversalOp()
+ )
+ return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits)
+
+
+@cute.jit
+def copy_tiled(
+ src: cute.Tensor,
+ dst: cute.Tensor,
+ *,
+ pred: Optional[cute.Tensor] = None,
+ num_copy_elems: int = 1,
+ is_async: bool = False,
+) -> None:
+ atom = get_copy_atom_bw(src.element_type, num_copy_elems, is_async)
+ cute.copy(atom, src, dst, pred=pred)
+
+
+class RMSNormSM100WithStage2:
+ def __init__(self, N: int, dtype: type[cutlass.Numeric], stage: Optional[int] = None):
+ self.N = N
+ self.dtype = dtype
+ self.stage = 1 if stage is None else stage
+ self.reduction_dtype = cutlass.Float32
+
+ def _threads_per_row(self) -> int:
+ N = self.N
+ if N <= 64:
+ return 8
+ elif N <= 128:
+ return 16
+ elif N <= 1024:
+ return 32
+ elif N <= 4096:
+ return 128
+ elif N <= 8192:
+ try:
+ return self._tpr_override # type: ignore[attr-defined]
+ except Exception:
+ return 128
+ elif N <= 16384:
+ return 256
+ else:
+ return 256
+
+ def _cluster_n(self) -> int:
+ N = self.N
+ if N <= 8192:
+ return 1
+ if const_expr(self.dtype.width == 16):
+ if N <= 16 * 1024:
+ return 2
+ elif N <= 32 * 1024:
+ return 2
+ elif N <= 64 * 1024:
+ return 4
+ elif N <= 128 * 1024:
+ return 8
+ else:
+ return 16
+ else:
+ if N <= 32 * 1024:
+ return 1
+ elif N <= 64 * 1024:
+ return 2
+ elif N <= 128 * 1024:
+ return 4
+ elif N <= 256 * 1024:
+ return 8
+ else:
+ return 16
+
+ def _num_threads(self) -> int:
+ try:
+ return self._nt_override # type: ignore[attr-defined]
+ except Exception:
+ return 128 if self.N <= 16384 else 256
+
+ def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout]:
+ vecsize = num_copy_bits // self.dtype.width
+ num_threads = self._num_threads()
+ assert num_threads % cute.arch.WARP_SIZE == 0
+ tpr = self._threads_per_row()
+ cluster_n = self._cluster_n()
+ num_cols_vec = cute.ceil_div(self.N, vecsize)
+ num_blocks_N = cute.ceil_div(num_cols_vec, tpr * cluster_n)
+ cols_per_block = num_threads // tpr
+ tiler_mn = (cols_per_block, vecsize * num_blocks_N * tpr)
+ tv_layout = cute.make_layout(
+ ((tpr, cols_per_block), (vecsize, num_blocks_N)),
+ stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * tpr)),
+ )
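+        # Default heuristics example (bf16, N=8192, no overrides): vecsize=8,
+        # tpr=128, cluster_n=1, num_blocks_N=8, cols_per_block=1, so
+        # tiler_mn = (1, 8192), i.e. one row per CTA covering the whole row.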
+ return tiler_mn, tv_layout
+
+ @cute.jit
+ def __call__(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ stream: cuda.CUstream,
+ eps: Float32 = 1e-6,
+ ):
+ semistatic_shape = (*mX.shape[:-1], self.N)
+
+ def new_stride(t):
+ return (
+ cute.assume(t.stride[0], divby=256 // t.element_type.width),
+ t.stride[1],
+ )
+
+ mX, mRes, mO, mResO = [
+ cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ if const_expr(t is not None)
+ else None
+ for t in (mX, mRes, mO, mResO)
+ ]
+ assert mX.element_type == self.dtype
+ assert mO.element_type == self.dtype
+
+ copy_bits = const_expr(128)
+ tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits)
+ num_threads = (
+ cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ threads_per_row = (
+ tv_layout.shape[0][0] if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._threads_per_row()
+ )
+ warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1)
+ cluster_n = self._cluster_n()
+
+ if const_expr(mW is not None):
+ mW = cute.make_tensor(
+ mW.iterator, cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ )
+ if const_expr(mB is not None):
+ mB = cute.make_tensor(
+ mB.iterator, cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ )
+ if const_expr(mRstd is not None):
+ mRstd = cute.make_tensor(
+ mRstd.iterator, cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
+ )
+
+ stage_bufs = 2 if self.stage > 1 else 1
+ tile_bytes_x = cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs
+ tile_bytes_res = (
+ cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) * stage_bufs
+ if const_expr(mRes is not None)
+ else 0
+ )
+ red_bytes = self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8)
+ mbar_bytes = self.stage * (cutlass.Int64.width // 8)
+ smem_bytes = tile_bytes_x + tile_bytes_res + red_bytes + mbar_bytes
+
+ kernel = (
+ self.kernel(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(num_warps),
+ const_expr(warps_per_row),
+ const_expr(threads_per_row),
+ )
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ )
+ )
+
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=([1, cluster_n, 1] if const_expr(cluster_n > 1) else None),
+ smem=smem_bytes,
+ stream=stream,
+ )
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ num_warps: cutlass.Constexpr[int],
+ warps_per_row: cutlass.Constexpr[int],
+ threads_per_row: cutlass.Constexpr[int],
+ ):
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ cluster_n = self._cluster_n()
+ cluster_y = const_expr(0) if const_expr(cluster_n == 1) else cute.arch.block_idx()[1]
+
+ smem = cutlass.utils.SmemAllocator()
+ sX0 = smem.allocate_tensor(
+ mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32
+ )
+ sX1 = (
+ smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ if const_expr(self.stage > 1)
+ else None
+ )
+ sRes0 = (
+ smem.allocate_tensor(
+ mRes.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32
+ )
+ if const_expr(mRes is not None)
+ else None
+ )
+ sRes1 = (
+ smem.allocate_tensor(
+ mRes.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
+ )
+ if const_expr(mRes is not None and self.stage > 1)
+ else None
+ )
+
+ reduction_buffer, mbar_ptr = self._alloc_reduction_and_mbar(smem, num_warps, warps_per_row)
+
+ shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+
+ num_copy_elems_X = tv_layout.shape[1][0]
+ use_async = const_expr(self.N >= 1024)
+ copy_atom = get_copy_atom_bw(mX.element_type, num_copy_elems_X, is_async=use_async)
+ thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx)
+
+ gW, gB = [
+ cute.local_tile(t, tiler_mn, (0, cluster_y)) if const_expr(t is not None) else None
+ for t in (mW, mB)
+ ]
+ tXgW = thr_copy.partition_S(gW) if const_expr(mW is not None) else None
+ tXgB = thr_copy.partition_S(gB) if const_expr(mB is not None) else None
+ tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None
+ tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
+ if const_expr(mW is not None):
+ cute.copy(get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False), tXgW, tXrW)
+ if const_expr(mB is not None):
+ cute.copy(get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False), tXgB, tXrB)
+
+ self._init_cluster(tidx, mbar_ptr)
+
+ mX_i, mRes_i, mO_i, mResO_i = [
+ qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) if t is not None else None
+ for t in (mX, mRes, mO, mResO)
+ ]
+ gX_i = cute.local_tile(mX_i, tiler_mn, (0, cluster_y))
+ gO_i = cute.local_tile(mO_i, tiler_mn, (0, cluster_y))
+ gRes_i = (
+ cute.local_tile(mRes_i, tiler_mn, (0, cluster_y)) if const_expr(mRes is not None) else None
+ )
+ gResO_i = (
+ cute.local_tile(mResO_i, tiler_mn, (0, cluster_y)) if const_expr(mResO is not None) else None
+ )
+ gRstd_i = (
+ cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y)) if const_expr(mRstd is not None) else None
+ )
+ cX_i = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+
+ tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None]
+ row_i = tXcX_i[0][0]
+ tXgRstd_i = thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None
+
+ # Intra-row K-loop cp.async ping-pong (two-pass) for N≈6k/8k (stage=2)
+ if const_expr(self.stage > 1 and (shape[1] == 6144 or shape[1] == 8192)):
+ vecsize = tv_layout.shape[1][0]
+ tpr = threads_per_row
+ target_tile_n = const_expr(4096 if shape[1] == 6144 else 8192)
+ tile_factor = const_expr(target_tile_n // (vecsize * tpr))
+ tile_n = vecsize * tpr * tile_factor
+ num_tiles = cute.ceil_div(shape[1], tile_n)
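+            # Worked example (bf16, tpr=128, 128-bit copies => vecsize=8):
+            # N=6144 -> target_tile_n=4096, tile_factor=4, tile_n=4096,
+            # num_tiles=2, so successive tiles ping-pong between sX0 and sX1.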
+
+ tiler_mn_tile = (tiler_mn[0], tile_n)
+ sX0_tile = cute.local_tile(sX0, tiler_mn_tile, (0, 0))
+ sX1_tile = cute.local_tile(sX1, tiler_mn_tile, (0, 0)) if const_expr(self.stage > 1) else None
+ sRes0_tile = (
+ cute.local_tile(sRes0, tiler_mn_tile, (0, 0)) if const_expr(mRes is not None) else None
+ )
+ sRes1_tile = (
+ cute.local_tile(sRes1, tiler_mn_tile, (0, 0)) if const_expr(mRes is not None and self.stage > 1) else None
+ )
+
+ tv_layout_tile = cute.make_layout(
+ ((tpr, tiler_mn[0]), (vecsize, tile_factor)),
+ stride=((vecsize * tiler_mn[0], 1), (tiler_mn[0], tiler_mn[0] * vecsize * tpr)),
+ )
+ thr_copy_tile = cute.make_tiled_copy(copy_atom, tv_layout_tile, tiler_mn_tile).get_slice(tidx)
+
+ sum_sq_acc = cute.Float32(0.0)
+ k_off0 = const_expr(0) * tile_n
+ gX_0 = cute.local_tile(qutils.domain_offset_i64((0, k_off0), mX_i), tiler_mn_tile, (0, cluster_y))
+ tXgX_0 = thr_copy_tile.partition_S(gX_0)
+ tXsX_0 = thr_copy_tile.partition_D(sX0_tile)
+ cX_0 = cute.local_tile(cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, cluster_y))
+ tXc_0 = thr_copy_tile.partition_S(cX_0)
+ tXp_0 = qutils.predicate_k(tXc_0, limit=shape[1])
+ tXp_ping = tXp_0
+ tXp_pong = tXp_0
+ if row_i < shape[0]:
+ copy_tiled(tXgX_0, tXsX_0, num_copy_elems=vecsize, is_async=use_async, pred=tXp_0)
+ if const_expr(mRes is not None):
+ gRes_0 = cute.local_tile(qutils.domain_offset_i64((0, k_off0), mRes_i), tiler_mn_tile, (0, cluster_y))
+ tXgRes_0 = thr_copy_tile.partition_S(gRes_0)
+ tXsRes_0 = thr_copy_tile.partition_D(sRes0_tile)
+ copy_tiled(tXgRes_0, tXsRes_0, num_copy_elems=vecsize, is_async=use_async, pred=tXp_0)
+ if const_expr(use_async):
+ cute.arch.cp_async_commit_group()
+
+ for t in cutlass.range_constexpr(num_tiles):
+ next_t = t + 1
+ if next_t < num_tiles:
+ k_off_n = next_t * tile_n
+ gX_n = cute.local_tile(qutils.domain_offset_i64((0, k_off_n), mX_i), tiler_mn_tile, (0, cluster_y))
+ tXgX_n = thr_copy_tile.partition_S(gX_n)
+ cX_n = cute.local_tile(cute.domain_offset((0, k_off_n), cX_i), tiler_mn_tile, (0, cluster_y))
+ tXc_n = thr_copy_tile.partition_S(cX_n)
+ tXp_n = qutils.predicate_k(tXc_n, limit=shape[1])
+ if const_expr((t % 2) == 0):
+ tXsX_n = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_n = (
+ thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ )
+ tXp_pong = tXp_n
+ else:
+ tXsX_n = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_n = (
+ thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ )
+ tXp_ping = tXp_n
+ if row_i < shape[0]:
+ copy_tiled(tXgX_n, tXsX_n, num_copy_elems=vecsize, is_async=use_async, pred=tXp_n)
+ if const_expr(mRes is not None):
+ gRes_n = cute.local_tile(qutils.domain_offset_i64((0, k_off_n), mRes_i), tiler_mn_tile, (0, cluster_y))
+ tXgRes_n = thr_copy_tile.partition_S(gRes_n)
+ copy_tiled(tXgRes_n, tXsRes_n, num_copy_elems=vecsize, is_async=use_async, pred=tXp_n)
+ if const_expr(use_async):
+ cute.arch.cp_async_commit_group()
+ if const_expr(use_async):
+ cute.arch.cp_async_wait_group(1 if next_t < num_tiles else 0)
+
+ if const_expr((t % 2) == 0):
+ tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_cur = thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ pred_cur = tXp_ping
+ else:
+ tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_cur = thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ pred_cur = tXp_pong
+ qutils.fill_oob(tXsX_cur, pred_cur, mX.element_type.zero)
+ if const_expr(mRes is not None):
+ qutils.fill_oob(tXsRes_cur, pred_cur, mRes.element_type.zero)
+
+ k_off = t * tile_n
+ gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, cluster_y))
+ tXgX_t = thr_copy_tile.partition_S(gX_t)
+ tXrX = cute.make_fragment_like(tXgX_t)
+ cute.autovec_copy(tXsX_cur, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ gRes_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, cluster_y))
+ tXgRes_t = thr_copy_tile.partition_S(gRes_t)
+ tXrRes = cute.make_fragment_like(tXgRes_t)
+ cute.autovec_copy(tXsRes_cur, tXrRes)
+ x += tXrRes.load().to(cute.Float32)
+
+ if const_expr(mResO is not None):
+ gResO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mResO_i), tiler_mn_tile, (0, cluster_y))
+ tXgResO_t = thr_copy_tile.partition_D(gResO_t)
+ tXrResO = cute.make_fragment_like(tXgResO_t)
+ tXrResO.store(x.to(tXrResO.element_type))
+ if row_i < shape[0]:
+ copy_tiled(tXrResO, tXgResO_t, num_copy_elems=vecsize, is_async=False, pred=pred_cur)
+
+ sum_sq_tile = row_reduce(
+ x * x,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ init_val=0.0,
+ hook_fn=(cute.arch.cluster_wait if const_expr(cluster_n > 1) else None),
+ )
+ sum_sq_acc = sum_sq_acc + sum_sq_tile
+
+ rstd = cute.math.rsqrt(sum_sq_acc / shape[1] + eps, fastmath=True)
+ if const_expr(mRstd is not None):
+ if (
+ tXcX_i[0][1] == 0
+ and row_i < shape[0]
+ and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXgRstd_i[0] = rstd
+
+ for t in cutlass.range_constexpr(num_tiles):
+ k_off = t * tile_n
+ cX_t = cute.local_tile(cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, cluster_y))
+ tXc_t = thr_copy_tile.partition_S(cX_t)
+ tXp_t = qutils.predicate_k(tXc_t, limit=shape[1])
+
+ if const_expr((t % 2) == 0):
+ tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ )
+ else:
+ tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ )
+
+ qutils.fill_oob(tXsX_cur, tXp_t, mX.element_type.zero)
+ if const_expr(mRes is not None):
+ qutils.fill_oob(tXsRes_cur, tXp_t, mRes.element_type.zero)
+
+ gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, cluster_y))
+ tXgX_t = thr_copy_tile.partition_S(gX_t)
+ tXrX = cute.make_fragment_like(tXgX_t)
+ cute.autovec_copy(tXsX_cur, tXrX)
+ x = tXrX.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ gRes_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, cluster_y))
+ tXgRes_t = thr_copy_tile.partition_S(gRes_t)
+ tXrRes = cute.make_fragment_like(tXgRes_t)
+ cute.autovec_copy(tXsRes_cur, tXrRes)
+ x += tXrRes.load().to(cute.Float32)
+
+ y = x * rstd
+ if const_expr(mW is not None):
+ gW_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mW), tiler_mn_tile, (0, cluster_y))
+ tWgW_t = thr_copy_tile.partition_S(gW_t)
+ tWrW_t = cute.make_fragment_like(tWgW_t)
+ copy_tiled(tWgW_t, tWrW_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+ y = y * tWrW_t.load().to(cute.Float32)
+ if const_expr(mB is not None):
+ gB_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mB), tiler_mn_tile, (0, cluster_y))
+ tWgB_t = thr_copy_tile.partition_S(gB_t)
+ tWrB_t = cute.make_fragment_like(tWgB_t)
+ copy_tiled(tWgB_t, tWrB_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+ y = y + tWrB_t.load().to(cute.Float32)
+
+ gO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mO_i), tiler_mn_tile, (0, cluster_y))
+ tXgO_t = thr_copy_tile.partition_D(gO_t)
+ tXrO = cute.make_fragment_like(tXgO_t)
+ tXrO.store(y.to(tXrO.element_type))
+ if row_i < shape[0]:
+ copy_tiled(tXrO, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+
+ return
+
+ # Fallback: single-stage path identical to current rmsnorm.py
+ tXgX_i = thr_copy.partition_S(gX_i)
+ tXgRes_i = thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None
+ tXgO_i = thr_copy.partition_D(gO_i)
+ tXgResO_i = thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None
+ is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n)
+ tXpX_i = (
+ qutils.predicate_k(thr_copy.partition_S(cX_i), limit=shape[1]) if not is_even_N_i else None
+ )
+
+ if row_i < shape[0]:
+ cute.copy(copy_atom, tXgX_i, thr_copy.partition_D(sX0), pred=tXpX_i)
+ if const_expr(mRes is not None):
+ cute.copy(copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i)
+ if const_expr(use_async):
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ tXrX = cute.make_fragment_like(tXgX_i)
+ cute.autovec_copy(thr_copy.partition_D(sX0), tXrX)
+ x = tXrX.load().to(cute.Float32)
+ if const_expr(mRes is not None):
+ tXrRes = cute.make_fragment_like(tXgRes_i)
+ cute.autovec_copy(thr_copy.partition_D(sRes0), tXrRes)
+ x += tXrRes.load().to(cute.Float32)
+
+ if const_expr(mResO is not None):
+ tXrResO = cute.make_fragment_like(tXgResO_i)
+ tXrResO.store(x.to(tXrResO.element_type))
+ if row_i < shape[0]:
+ cute.copy(
+ get_copy_atom_bw(tXrResO.element_type, num_copy_elems_X, is_async=False),
+ tXrResO,
+ tXgResO_i,
+ )
+
+ sum_sq = row_reduce(
+ x * x,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ init_val=0.0,
+ hook_fn=(cute.arch.cluster_wait if const_expr(cluster_n > 1) else None),
+ )
+ rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True)
+
+ if const_expr(mRstd is not None):
+ if (
+ tXcX_i[0][1] == 0
+ and row_i < shape[0]
+ and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
+ ):
+ tXgRstd_i[0] = rstd
+
+ y = x * rstd
+ if const_expr(mW is not None):
+ y = y * tXrW.load().to(cute.Float32)
+ if const_expr(mB is not None):
+ y = y + tXrB.load().to(cute.Float32)
+
+ tXrO = cute.make_fragment_like(tXgO_i)
+ tXrO.store(y.to(tXrO.element_type))
+ if row_i < shape[0]:
+ cute.copy(
+ get_copy_atom_bw(tXrO.element_type, num_copy_elems_X, is_async=False),
+ tXrO,
+ tXgO_i,
+ )
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ num_warps: cutlass.Constexpr[int],
+ warps_per_row: cutlass.Constexpr[int],
+ threads_per_row: cutlass.Constexpr[int],
+ ):
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ num_warps,
+ warps_per_row,
+ threads_per_row,
+ )
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mW: Optional[cute.Tensor],
+ mB: Optional[cute.Tensor],
+ mRes: Optional[cute.Tensor],
+ mO: cute.Tensor,
+ mResO: Optional[cute.Tensor],
+ mRstd: Optional[cute.Tensor],
+ eps: Float32,
+ ):
+ copy_bits = const_expr(128)
+ tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits)
+ num_threads = self._num_threads()
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ threads_per_row = self._threads_per_row()
+ warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1)
+ self._kernel_impl(
+ mX,
+ mW,
+ mB,
+ mRes,
+ mO,
+ mResO,
+ mRstd,
+ eps,
+ tv_layout,
+ tiler_mn,
+ const_expr(num_warps),
+ const_expr(warps_per_row),
+ const_expr(threads_per_row),
+ )
+
+ @cute.jit
+ def _alloc_reduction_and_mbar(
+ self,
+ smem: cutlass.utils.SmemAllocator,
+ num_warps: cutlass.Constexpr[int],
+ warps_per_row: cutlass.Constexpr[int],
+ ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]:
+ cluster_n = self._cluster_n()
+ red_layout = cute.make_ordered_layout(
+ (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
+ order=(1, 0, 2),
+ )
+ reduction_buffer = smem.allocate_tensor(self.reduction_dtype, red_layout, byte_alignment=4)
+ if const_expr(cluster_n > 1):
+ mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
+ else:
+ mbar_ptr = None
+ return reduction_buffer, mbar_ptr
+
+ @cute.jit
+ def _init_cluster(self, tidx: cutlass.Int32, mbar_ptr: Optional[cute.Pointer]):
+ if const_expr(mbar_ptr is not None):
+ if tidx < self.stage:
+ cute.arch.mbarrier_init(mbar_ptr + tidx, 1)
+ cute.arch.mbarrier_init_fence()
+ cute.arch.cluster_arrive_relaxed()
+
+
+def rmsnorm_forward_with_stage2(
+ x: Tensor,
+ weight: Optional[Tensor] = None,
+ bias: Optional[Tensor] = None,
+ residual: Optional[Tensor] = None,
+ eps: float = 1e-6,
+ store_rstd: bool = False,
+) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+ assert x.is_cuda
+ assert x.dim() == 2
+ M, N = x.shape
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+
+ convert_x = lambda t: from_dlpack(t.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1)
+ mX = convert_x(x)
+ mRes = convert_x(residual) if residual is not None else None
+ out = torch.empty_like(x, dtype=x.dtype)
+ mO = from_dlpack(out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1)
+
+ mW = (
+ from_dlpack(weight.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=0)
+ if weight is not None
+ else None
+ )
+ mB = (
+ from_dlpack(bias.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=0)
+ if bias is not None
+ else None
+ )
+ if store_rstd:
+ rstd = torch.empty(M, device=x.device, dtype=torch.float32)
+ mRstd = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ else:
+ rstd = None
+ mRstd = None
+
+ residual_out = None
+ mResO = None
+ if residual is not None:
+ residual_out = torch.empty_like(residual)
+ mResO = from_dlpack(residual_out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1)
+
+ # Enable the intra-row cp.async K-loop only for DSv3-style large-N rows
+ # with very large M, where there is enough work per row to amortize the
+ # pipeline start-up cost. Mid-size M shapes are better served by the
+ # simpler single-stage schedule.
+ use_kloop = bool(M >= 65536 and N in (6144, 8192))
+ stage = 2 if use_kloop else 1
+ op = RMSNormSM100WithStage2(N, dtype, stage=stage)
+ if use_kloop:
+ op._tpr_override = 128 # type: ignore[attr-defined]
+ # Prefer 1 row/CTA at N=6144; keep 2 rows/CTA at N=8192 to match
+ # the original tuning there.
+ op._nt_override = (128 if N == 6144 else 256) # type: ignore[attr-defined]
+
+ stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+ key = (
+ N,
+ dtype,
+ mRes is not None,
+ mW is not None,
+ mB is not None,
+ mResO is not None,
+ mRstd is not None,
+ stage,
+ )
+ compiled = _COMPILE_CACHE.get(key)
+ if compiled is None:
+ compiled = cute.compile(op, mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps))
+ _COMPILE_CACHE[key] = compiled
+ compiled(mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps))
+ return out, rstd, residual_out
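+
+
+# Minimal usage sketch (hedged; shapes are illustrative). The stage-2 K-loop
+# only engages for very large M at N in (6144, 8192); other shapes fall back
+# to the single-stage schedule.
+#
+#   x = torch.randn(65536, 8192, device="cuda", dtype=torch.bfloat16)
+#   w = torch.randn(8192, device="cuda", dtype=torch.bfloat16)
+#   out, rstd, res_out = rmsnorm_forward_with_stage2(x, w, store_rstd=True)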
diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py
new file mode 100644
index 0000000..a2f2581
--- /dev/null
+++ b/oink/src/kernelagent_oink/blackwell/softmax.py
@@ -0,0 +1,749 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Softmax forward + backward kernels for SM100 (Blackwell) in CuteDSL.
+
+This module implements numerically stable softmax over the last dimension of
+2D tensors (M, N) and its backward pass, targeting SM100 with Quack-style
+tiling, cp.async pipelines, and cluster reductions, but without depending on
+the `quack` package at runtime.
+
+The kernels are self-contained and use only local helpers in
+`kernelagent_oink.blackwell.lite_quack` plus CuTeDSL/CUTLASS.
+"""
+
+from __future__ import annotations
+
+import importlib.metadata
+import math
+import os
+import re
+from typing import Optional, Type
+
+import torch
+from torch import Tensor
+
+import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python
+
+# CuTeDSL caches generated MLIR into a tempdir under a global default
+# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ
+# across `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes
+# noisy warnings (and disables cache reuse), so unless the user has already
+# set CUTE_DSL_CACHE_DIR we namespace the cache directory by DSL version.
+if "CUTE_DSL_CACHE_DIR" not in os.environ:
+ try:
+ _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl")
+ except Exception:
+ _dsl_ver = "unknown"
+ _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver)
+ _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user"
+ _tmp = os.environ.get("TMPDIR") or "/tmp"
+ os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join(
+ _tmp, _user, f"cutlass_python_cache_{_dsl_ver}"
+ )
+
+try:
+ import cutlass # type: ignore # noqa: F401
+except Exception as e:
+ raise ImportError(
+ "kernelagent_oink.blackwell.softmax requires CuTeDSL's Python package "
+ "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)."
+ ) from e
+
+import cutlass.cute as cute
+from cutlass import Float32, Int32, const_expr
+from cutlass.cute import runtime as rt
+from cutlass.cute.runtime import from_dlpack
+
+from kernelagent_oink.blackwell.lite_quack import (
+ _KERNEL_ACCEPTS_LAYOUT_ARGS,
+ TORCH2CUTE_DTYPE,
+ ReductionBase,
+ domain_offset_i64,
+ fill_oob,
+ online_softmax_reduce,
+ predicate_k,
+ row_reduce,
+)
+
+_FWD_COMPILE_CACHE: dict[tuple[Type[cutlass.Numeric], int], object] = {}
+_BWD_COMPILE_CACHE: dict[tuple[Type[cutlass.Numeric], int], object] = {}
+_PTR_FWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+_PTR_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {}
+
+
+class SoftmaxFwdSM100(ReductionBase):
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ # One-stage online reduction: pack (max, sum_exp) into Int64 reduction buffer.
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64)
+
+ def _calculate_threads_per_row(self) -> int:
+ # Match Quack's bucketed policy for Softmax.
+ N = self.N
+ if N <= 64:
+ return 8
+ if N <= 128:
+ return 16
+ if N <= 3072:
+ return 32
+ if N <= 6144:
+ return 64
+ if N <= 16384:
+ return 128
+ return 256
+
+ def _set_cluster_n(self) -> None:
+ # Quack-style growth of cluster_n with N and dtype.
+ N = self.N
+ if const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else:
+ cluster_n = (
+ 1
+ if N <= 32 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ @cute.jit
+ def __call__(self, mX: cute.Tensor, mO: cute.Tensor, stream: cuda.CUstream) -> None:
+ assert mX.element_type == self.dtype
+ assert mO.element_type == self.dtype
+ # Use the generic ReductionBase tiling with 128-bit vectorization.
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = (
+ cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ kernel = (
+ self.kernel(mX, mO, tv_layout, tiler_mn)
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(mX, mO)
+ )
+ kernel.launch(
+ grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_x: cute.Pointer,
+ ptr_out: cute.Pointer,
+ M: Int32,
+ ld: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions.
+
+ Reconstructs cute.Tensor views from raw pointers + explicit layouts
+ inside the JIT graph, matching the existing SM100 schedule.
+ """
+ # Mirror Quack/LayerNorm contracts: assume 16B alignment and an LD that
+ # preserves 128-bit vectorized copies for every row start.
+ ld_assumed = cute.assume(ld, divby=128 // self.dtype.width)
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ mX = cute.make_tensor(ptr_x, layout_mn)
+ mO = cute.make_tensor(ptr_out, layout_mn)
+ self.__call__(mX, mO, stream)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mX: cute.Tensor,
+ mO: cute.Tensor,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = const_expr(0)
+
+ shape = mX.shape
+ idX = cute.make_identity_tensor(shape)
+
+ # Slice per-CTA region; use 64-bit indexing for large tensors.
+ mX, mO = [domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
+ gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+
+ smem = cutlass.utils.SmemAllocator()
+ sX = smem.allocate_tensor(
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+        # Copy atoms: 128-bit cp.async for gmem -> smem loads and 128-bit
+        # vectorized universal copies for register -> gmem stores.
+ copy_atom_load = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ mX.element_type,
+ num_bits_per_copy=128,
+ )
+ copy_atom_store = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ gO.element_type,
+ num_bits_per_copy=128,
+ )
+
+ thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
+
+ tXgX = thr_copy_load.partition_S(gX)
+ tXsX = thr_copy_load.partition_D(sX)
+ tXgO = thr_copy_store.partition_D(gO)
+ tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
+
+ # Register fragments.
+ tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+ # Predicate and cp.async pipeline for potential tail tiles.
+ is_even_N = const_expr(self.N == tiler_mn[1] * self.cluster_n)
+ tXpX = (
+ predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+
+ if tXcX[0][0] < shape[0]:
+ cute.copy(copy_atom_load, tXgX, tXsX, pred=tXpX)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ if const_expr(not is_even_N):
+ fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
+
+ cute.autovec_copy(tXsX, tXrX)
+ x = tXrX.load().to(Float32)
+ threads_per_row = tv_layout.shape[0][0]
+
+ # Online softmax reduction: compute max and sum_exp in a single pass, with
+ # optional cluster-wide aggregation via an Int64 reduction buffer.
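+        # Per row: y_i = exp(x_i - max_j x_j) / sum_j exp(x_j - max_j x_j);
+        # subtracting the row max keeps exp() in range for fp32 accumulation.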
+ max_x, denom, exp_x = online_softmax_reduce(
+ x,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+ phase=None,
+ return_exp_x=True,
+ )
+
+ y = exp_x * cute.arch.rcp_approx(denom)
+ tXrO.store(y.to(tXrO.element_type))
+
+ tOpO = (
+ predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+
+ if tXcX[0][0] < shape[0]:
+ cute.copy(copy_atom_store, tXrO, tXgO, pred=tOpO)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mO: cute.Tensor,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ self._kernel_impl(mX, mO, tv_layout, tiler_mn)
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mX: cute.Tensor,
+ mO: cute.Tensor,
+ ) -> None:
+ tiler_mn, tv_layout = self._get_tv_layout()
+ self._kernel_impl(mX, mO, tv_layout, tiler_mn)
+
+
+class SoftmaxBwdSM100(ReductionBase):
+ def __init__(self, dtype: Type[cutlass.Numeric], N: int):
+ # One stage for dot(dy, y) per row.
+ super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32)
+
+ def _calculate_threads_per_row(self) -> int:
+ # Match Quack backward softmax buckets.
+ N = self.N
+ if N <= 64:
+ return 8
+ if N <= 128:
+ return 16
+ if N <= 3072:
+ return 32
+ if N <= 6144:
+ return 64
+ if N <= 8192:
+ return 128
+ return 256
+
+ def _set_cluster_n(self) -> None:
+ N = self.N
+ if const_expr(self.dtype.width == 16):
+ cluster_n = (
+ 1
+ if N <= 16 * 1024
+ else (
+ 2
+ if N <= 32 * 1024
+ else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16))
+ )
+ )
+ else:
+ cluster_n = (
+ 1
+ if N <= 32 * 1024
+ else (
+ 2
+ if N <= 64 * 1024
+ else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16))
+ )
+ )
+ self.cluster_n = cluster_n
+
+ def _get_num_threads(self) -> int:
+ # Slightly more aggressive threading for large N than the base class.
+ return 128 if self.N <= 8192 else 256
+
+ def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int:
+ # Store both y and dy tiles plus reduction buffers and mbarriers.
+ return (
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+ + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+ + self.stage * (cutlass.Int64.width // 8)
+ )
+
+ @cute.jit
+ def __call__(
+ self,
+ mdY: cute.Tensor,
+ mY: cute.Tensor,
+ mdX: cute.Tensor,
+ stream: cuda.CUstream,
+ ) -> None:
+ assert mdY.element_type == self.dtype
+ assert mY.element_type == self.dtype
+ assert mdX.element_type == self.dtype
+ # Use the generic ReductionBase tiling with 128-bit vectorization.
+ tiler_mn, tv_layout = self._get_tv_layout()
+ num_threads = (
+ cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ )
+ num_warps = num_threads // cute.arch.WARP_SIZE
+ kernel = (
+ self.kernel(mdY, mY, mdX, tv_layout, tiler_mn)
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self.kernel(mdY, mY, mdX)
+ )
+ kernel.launch(
+ grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1],
+ block=[num_threads, 1, 1],
+ cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None,
+ smem=self._smem_size_in_bytes(tiler_mn, num_warps),
+ stream=stream,
+ )
+
+ @cute.jit
+ def launch_from_ptrs(
+ self,
+ ptr_dy: cute.Pointer,
+ ptr_y: cute.Pointer,
+ ptr_dx: cute.Pointer,
+ M: Int32,
+ ld: Int32,
+ stream: cuda.CUstream,
+ ) -> None:
+ """Pointer-based entrypoint that bypasses DLPack conversions."""
+ ld_assumed = cute.assume(ld, divby=128 // self.dtype.width)
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1))
+ mdY = cute.make_tensor(ptr_dy, layout_mn)
+ mY = cute.make_tensor(ptr_y, layout_mn)
+ mdX = cute.make_tensor(ptr_dx, layout_mn)
+ self.__call__(mdY, mY, mdX, stream)
+
+ @cute.jit
+ def _kernel_impl(
+ self,
+ mdY: cute.Tensor,
+ mY: cute.Tensor,
+ mdX: cute.Tensor,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ tidx, _, _ = cute.arch.thread_idx()
+ bidx, _, _ = cute.arch.block_idx()
+ if const_expr(self.cluster_n > 1):
+ cluster_y = cute.arch.block_idx()[1]
+ else:
+ cluster_y = const_expr(0)
+
+ shape = mdY.shape
+ idX = cute.make_identity_tensor(shape)
+
+ mdY, mY, mdX = [
+ domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX)
+ ]
+ gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)]
+ cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
+
+ smem = cutlass.utils.SmemAllocator()
+ sdY = smem.allocate_tensor(
+ mdY.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ sY = smem.allocate_tensor(
+ mY.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=16,
+ )
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+
+ copy_atom_load = cute.make_copy_atom(
+ cute.nvgpu.cpasync.CopyG2SOp(),
+ mdY.element_type,
+ num_bits_per_copy=128,
+ )
+ copy_atom_store = cute.make_copy_atom(
+ cute.nvgpu.CopyUniversalOp(),
+ gdX.element_type,
+ num_bits_per_copy=128,
+ )
+
+ thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
+
+ tdYgdY = thr_copy_load.partition_S(gdY)
+ tdYsdY = thr_copy_load.partition_D(sdY)
+ tYgY = thr_copy_load.partition_S(gY)
+ tYsY = thr_copy_load.partition_D(sY)
+ tdXgdX = thr_copy_store.partition_D(gdX)
+ tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
+
+ tdYrdY, tYrY, tdXrdX = [cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)]
+
+ num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
+ self._initialize_cluster(tidx, mbar_ptr, num_warps)
+
+ is_even_N = const_expr(self.N == tiler_mn[1] * self.cluster_n)
+ tdYpdY = (
+ predicate_k(thr_copy_load.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+
+ if tXcX[0][0] < shape[0]:
+ cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tdYpdY)
+ cute.copy(copy_atom_load, tYgY, tYsY, pred=tdYpdY)
+ cute.arch.cp_async_commit_group()
+ cute.arch.cp_async_wait_group(0)
+
+ cute.autovec_copy(tdYsdY, tdYrdY)
+ cute.autovec_copy(tYsY, tYrY)
+ dy = tdYrdY.load().to(Float32)
+ y = tYrY.load().to(Float32)
+
+ threads_per_row = tv_layout.shape[0][0]
+ dot = row_reduce(
+ dy * y,
+ cute.ReductionOp.ADD,
+ threads_per_row,
+ reduction_buffer[None, None, 0],
+ mbar_ptr if const_expr(self.cluster_n > 1) else None,
+ init_val=0.0,
+ hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None,
+ )
+
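+        # Softmax backward per row: dx_i = y_i * (dy_i - sum_j dy_j * y_j).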
+ dx = y * (dy - dot)
+ tdXrdX.store(dx.to(tdXrdX.element_type))
+
+ tdXpdX = (
+ predicate_k(thr_copy_store.partition_S(cX), limit=shape[1])
+ if const_expr(not is_even_N)
+ else None
+ )
+ if tXcX[0][0] < shape[0]:
+ cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX)
+
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mdY: cute.Tensor,
+ mY: cute.Tensor,
+ mdX: cute.Tensor,
+ tv_layout: cute.Layout,
+ tiler_mn: cute.Shape,
+ ) -> None:
+ self._kernel_impl(mdY, mY, mdX, tv_layout, tiler_mn)
+ else:
+
+ @cute.kernel
+ def kernel(
+ self,
+ mdY: cute.Tensor,
+ mY: cute.Tensor,
+ mdX: cute.Tensor,
+ ) -> None:
+ tiler_mn, tv_layout = self._get_tv_layout()
+ self._kernel_impl(mdY, mY, mdX, tv_layout, tiler_mn)
+
+
+def _convert_2d_tensor(x: Tensor) -> cute.Tensor:
+ # Match Quack's Softmax conversion exactly: assume 16B alignment and mark
+ # the shape compact with row-major stride order (0, 1), with mode=0 (batch).
+ # We intentionally do not call mark_layout_dynamic here to avoid the
+ # leading_dim stride==1 constraint used in RMSNorm.
+ return (
+ from_dlpack(x.detach(), assumed_align=16)
+ .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+ )
+
+
+def _can_use_ptr_path_2d(x: Tensor) -> bool:
+ """Conservative guard for the pointer-based fast path."""
+ if not x.is_cuda or x.dim() != 2:
+ return False
+ if x.dtype not in TORCH2CUTE_DTYPE:
+ return False
+ # Require row-major last-dim contiguous.
+ if x.stride(1) != 1:
+ return False
+ # Require 16B alignment (matches from_dlpack(..., assumed_align=16)).
+ if (x.data_ptr() % 16) != 0:
+ return False
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ divby = 128 // dtype_x.width
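+    # e.g. 128-bit copies: bf16/fp16 -> divby = 8 elements, fp32 -> divby = 4.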
+ # Softmax uses ReductionBase default num_copy_bits=128, so N must be divisible.
+ if (x.shape[1] % divby) != 0:
+ return False
+ # Ensure each row start remains aligned for 128-bit vectorized copies.
+ if (x.stride(0) % divby) != 0:
+ return False
+ return True
+
+
+def _softmax_forward_ptr_into(*, x: Tensor, out: Tensor) -> None:
+ """Launch the pointer-based Softmax forward kernel into preallocated `out`."""
+ assert x.is_cuda and x.dim() == 2
+ assert out.is_cuda and out.shape == x.shape and out.dtype == x.dtype
+ assert out.stride() == x.stride(), "Pointer path expects out to match x strides"
+
+ M, N = x.shape
+ device_index = x.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream))
+
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype]
+ key = ("ptr_fwd", int(N), dtype_x, int(device_index))
+ compiled = _PTR_FWD_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = SoftmaxFwdSM100(dtype_x, int(N))
+ ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_out = rt.make_ptr(
+ dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ld = Int32(int(x.stride(0)))
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_x,
+ ptr_out,
+ Int32(int(M)),
+ ld,
+ stream,
+ )
+ _PTR_FWD_COMPILE_CACHE[key] = compiled
+
+ ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_out = rt.make_ptr(dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ compiled(ptr_x, ptr_out, Int32(int(M)), Int32(int(x.stride(0))), stream)
+
+
+def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None:
+ """Launch the pointer-based Softmax backward kernel into preallocated `dx`."""
+ assert dy.is_cuda and dy.dim() == 2
+ assert y.is_cuda and y.shape == dy.shape and y.dtype == dy.dtype
+ assert dx.is_cuda and dx.shape == dy.shape and dx.dtype == dy.dtype
+ assert dy.stride() == y.stride() == dx.stride(), "Pointer path expects matching strides"
+
+ M, N = dy.shape
+ device_index = dy.get_device()
+ if torch.cuda.current_device() != device_index:
+ torch.cuda.set_device(device_index)
+ stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream))
+
+ dtype_x = TORCH2CUTE_DTYPE[dy.dtype]
+ key = ("ptr_bwd", int(N), dtype_x, int(device_index))
+ compiled = _PTR_BWD_COMPILE_CACHE.get(key)
+ if compiled is None:
+ op = SoftmaxBwdSM100(dtype_x, int(N))
+ ptr_dy = rt.make_ptr(dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_y = rt.make_ptr(dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ld = Int32(int(dy.stride(0)))
+ compiled = cute.compile(
+ op.launch_from_ptrs,
+ ptr_dy,
+ ptr_y,
+ ptr_dx,
+ Int32(int(M)),
+ ld,
+ stream,
+ )
+ _PTR_BWD_COMPILE_CACHE[key] = compiled
+
+ ptr_dy = rt.make_ptr(dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_y = rt.make_ptr(dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ compiled(ptr_dy, ptr_y, ptr_dx, Int32(int(M)), Int32(int(dy.stride(0))), stream)
+
+
+def softmax_forward(x: Tensor) -> Tensor:
+ """SM100 CuteDSL softmax forward pass: y = softmax(x, dim=-1)."""
+ assert x.dim() == 2, "Input must be 2D (M, N)"
+ assert x.is_cuda, "Input must be on CUDA device"
+ assert x.dtype in TORCH2CUTE_DTYPE, "Unsupported dtype"
+
+ N = x.size(1)
+ dtype = TORCH2CUTE_DTYPE[x.dtype]
+ if _can_use_ptr_path_2d(x):
+ out = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
+ _softmax_forward_ptr_into(x=x, out=out)
+ return out
+
+ out = torch.empty_like(x)
+
+ x_tensor = _convert_2d_tensor(x)
+ out_tensor = _convert_2d_tensor(out)
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype, N)
+ kernel = _FWD_COMPILE_CACHE.get(compile_key)
+ if kernel is None:
+ op = SoftmaxFwdSM100(dtype, N)
+ kernel = cute.compile(op, x_tensor, out_tensor, current_stream)
+ _FWD_COMPILE_CACHE[compile_key] = kernel
+ kernel(x_tensor, out_tensor, current_stream)
+ return out
+
+
+def softmax_backward(dy: Tensor, y: Tensor) -> Tensor:
+ """SM100 CuteDSL softmax backward pass."""
+ assert dy.dim() == 2 and y.dim() == 2, "dy and y must be 2D (M, N)"
+ assert dy.shape == y.shape, "dy and y must have the same shape"
+ assert dy.is_cuda and y.is_cuda, "dy and y must be on CUDA device"
+ assert dy.dtype in TORCH2CUTE_DTYPE, "Unsupported dtype"
+ assert y.dtype == dy.dtype, "dy and y must have the same dtype"
+
+ N = dy.size(1)
+ dtype = TORCH2CUTE_DTYPE[dy.dtype]
+ if _can_use_ptr_path_2d(dy) and _can_use_ptr_path_2d(y) and dy.stride() == y.stride():
+ dx = torch.empty_strided(dy.shape, dy.stride(), device=dy.device, dtype=dy.dtype)
+ _softmax_backward_ptr_into(dy=dy, y=y, dx=dx)
+ return dx
+
+ dx = torch.empty_like(dy)
+
+ dy_tensor = _convert_2d_tensor(dy)
+ y_tensor = _convert_2d_tensor(y)
+ dx_tensor = _convert_2d_tensor(dx)
+ current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+ compile_key = (dtype, N)
+ kernel = _BWD_COMPILE_CACHE.get(compile_key)
+ if kernel is None:
+ op = SoftmaxBwdSM100(dtype, N)
+ kernel = cute.compile(op, dy_tensor, y_tensor, dx_tensor, current_stream)
+ _BWD_COMPILE_CACHE[compile_key] = kernel
+ kernel(dy_tensor, y_tensor, dx_tensor, current_stream)
+ return dx
+
+
+class SoftmaxFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x: Tensor) -> Tensor:
+ y = softmax_forward(x)
+ ctx.save_for_backward(y)
+ return y
+
+ @staticmethod
+    def backward(ctx, dy: Tensor) -> Tensor:
+ (y,) = ctx.saved_tensors
+ dx = softmax_backward(dy, y)
+ return dx
+
+
+def softmax(x: Tensor) -> Tensor:
+ """Autograd-friendly softmax using the SM100 CuteDSL kernel."""
+ return SoftmaxFunction.apply(x)
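+
+
+# Minimal usage sketch (hedged; shapes are illustrative only):
+#
+#   x = torch.randn(4096, 32768, device="cuda", dtype=torch.bfloat16,
+#                   requires_grad=True)
+#   y = softmax(x)                   # forward via SoftmaxFwdSM100
+#   y.backward(torch.ones_like(y))   # backward via SoftmaxBwdSM100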
+
+
+def _torch_softmax_reference(x: Tensor) -> Tensor:
+ return torch.nn.functional.softmax(x, dim=-1)
+
+
+def verify_softmax_parity(
+ M: int,
+ N: int,
+ dtype: torch.dtype = torch.bfloat16,
+ atol: float = 5e-2,
+ rtol: float = 5e-2,
+) -> None:
+ """Compare SM100 CuteDSL softmax against PyTorch for a single shape."""
+ device = torch.device("cuda")
+ x = torch.randn(M, N, device=device, dtype=dtype)
+ x.requires_grad_(True)
+
+ # Forward parity
+ y_ref = _torch_softmax_reference(x)
+ y = softmax(x)
+ torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol)
+
+ # Backward parity
+ dy = torch.randn_like(y)
+ (dx_ref,) = torch.autograd.grad(y_ref, x, dy, retain_graph=False)
+ dx = softmax_backward(dy, y)
+ torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol)
From 9b29732bb333a98ddbb5f750c1ff407d9050c5e7 Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Wed, 21 Jan 2026 20:06:57 -0800
Subject: [PATCH 5/8] oink: fix ruff lint
---
.../benchmark_cross_entropy_sm100.py | 89 ++++++++++---------
.../benchmark_fused_add_rmsnorm_sm100.py | 36 ++++----
.../benchmark/benchmark_hbm_roofline_sm100.py | 42 +++++----
.../benchmark/benchmark_layernorm_sm100.py | 34 +++----
.../benchmark/benchmark_rmsnorm_bwd_sm100.py | 39 ++++----
.../benchmark/benchmark_rmsnorm_sm100.py | 40 +++++----
.../benchmark/benchmark_softmax_sm100.py | 42 ++++++---
.../benchmarks/readme/plot_quack_style_svg.py | 4 +-
.../kernelagent_oink/blackwell/layernorm.py | 2 +-
.../kernelagent_oink/blackwell/lite_quack.py | 2 +-
.../src/kernelagent_oink/blackwell/rmsnorm.py | 13 +--
.../blackwell/rmsnorm_with_stage2.py | 10 ++-
.../src/kernelagent_oink/blackwell/softmax.py | 3 +-
13 files changed, 199 insertions(+), 157 deletions(-)
diff --git a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
index 8bcac15..18399c7 100644
--- a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
@@ -219,14 +219,16 @@ def bench_single(
bytes_io = bytes_io_model_ce(M, N, dtype, target_dtype=target.dtype, mode=mode)
if mode == "fwd":
- fn_oink = lambda: oink_ce.cross_entropy_forward(
- logits, target, ignore_index=int(ignore_index), reduction="none"
- )
- fn_quack = (
- None
- if quack_ce_fwd is None
- else (
- lambda: quack_ce_fwd(
+ def fn_oink():
+ return oink_ce.cross_entropy_forward(
+ logits, target, ignore_index=int(ignore_index), reduction="none"
+ )
+
+ fn_quack = None
+ if quack_ce_fwd is not None:
+
+ def fn_quack():
+ return quack_ce_fwd(
logits,
target,
target_logit=None,
@@ -235,8 +237,7 @@ def bench_single(
return_dx=False,
inplace_backward=False,
)
- )
- )
+
elif mode == "bwd":
with torch.no_grad():
_loss_o, lse_o = oink_ce.cross_entropy_forward(
@@ -254,14 +255,17 @@ def bench_single(
)
else:
lse_q = None
- fn_oink = lambda: oink_ce.cross_entropy_backward(
- dloss, logits, target, lse_o, ignore_index=int(ignore_index)
- )
- fn_quack = (
- None
- if (quack_ce_bwd is None or lse_q is None)
- else (
- lambda: quack_ce_bwd(
+
+ def fn_oink():
+ return oink_ce.cross_entropy_backward(
+ dloss, logits, target, lse_o, ignore_index=int(ignore_index)
+ )
+
+ fn_quack = None
+ if quack_ce_bwd is not None and lse_q is not None:
+
+ def fn_quack():
+ return quack_ce_bwd(
logits,
target,
dloss,
@@ -269,37 +273,38 @@ def bench_single(
ignore_index=int(ignore_index),
inplace_backward=False,
)
- )
- )
+
elif mode == "fwd_bwd":
- fn_oink = lambda: oink_ce.cross_entropy_fwd_bwd(
- dloss,
- logits,
- target,
- ignore_index=int(ignore_index),
- )
- fn_quack = (
- None
- if (quack_ce_fwd is None or quack_ce_bwd is None)
- else (
- lambda: quack_ce_bwd(
+ def fn_oink():
+ return oink_ce.cross_entropy_fwd_bwd(
+ dloss,
+ logits,
+ target,
+ ignore_index=int(ignore_index),
+ )
+
+ fn_quack = None
+ if quack_ce_fwd is not None and quack_ce_bwd is not None:
+
+ def fn_quack():
+ _loss_q, lse_q = quack_ce_fwd(
+ logits,
+ target,
+ target_logit=None,
+ ignore_index=int(ignore_index),
+ return_lse=True,
+ return_dx=False,
+ inplace_backward=False,
+ )
+ return quack_ce_bwd(
logits,
target,
dloss,
- quack_ce_fwd(
- logits,
- target,
- target_logit=None,
- ignore_index=int(ignore_index),
- return_lse=True,
- return_dx=False,
- inplace_backward=False,
- )[1],
+ lse_q,
ignore_index=int(ignore_index),
inplace_backward=False,
)
- )
- )
+
else:
raise ValueError(f"Unsupported mode: {mode}")
diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
index b75f892..6418e61 100644
--- a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
"""
Benchmark fused_add_rmsnorm (in-place) on SM100.
@@ -21,9 +19,11 @@
--json /tmp/kernelagent_oink_sm100_suite_bf16/fused_add_rmsnorm_dsv3.json
"""
+from __future__ import annotations
+
import argparse
import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
import torch
@@ -165,7 +165,9 @@ def bench_one(
bytes_io = bytes_io_model_fused_add_rmsnorm_inplace(M, N, dtype)
- fn = lambda: oink_rmsnorm.fused_add_rmsnorm_inplace_(x, residual, w, eps=1e-6)
+ def fn():
+ oink_rmsnorm.fused_add_rmsnorm_inplace_(x, residual, w, eps=1e-6)
+
ms = do_bench_triton(fn, warmup_ms=warmup_ms, rep_ms=iters_ms)
gbps = bytes_io / (ms * 1e-3) / 1e9
@@ -187,18 +189,20 @@ def bench_one(
out_q = torch.empty_like(x)
res_out_q = torch.empty_like(residual)
- fn_q = lambda: quack_rmsnorm_fwd_mut(
- x,
- w,
- out_q,
- None, # bias
- None, # rstd
- None, # mean
- residual,
- res_out_q,
- 1e-6,
- False, # is_layernorm
- )
+ def fn_q():
+ quack_rmsnorm_fwd_mut(
+ x,
+ w,
+ out_q,
+ None, # bias
+ None, # rstd
+ None, # mean
+ residual,
+ res_out_q,
+ 1e-6,
+ False, # is_layernorm
+ )
+
ms_q = do_bench_triton(fn_q, warmup_ms=warmup_ms, rep_ms=iters_ms)
gbps_q = bytes_io / (ms_q * 1e-3) / 1e9
row.update(
diff --git a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
index 971a03c..8ec4bfd 100644
--- a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
"""
HBM roofline microbenchmark for SM100 (GB200 / Blackwell).
@@ -17,9 +15,11 @@
CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype fp16 --op triad --gb 2
"""
+from __future__ import annotations
+
import argparse
import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple
import torch
import triton
@@ -95,23 +95,27 @@ def bench_one(
grid = (triton.cdiv(n_elements, block),)
if op == "copy":
- launch = lambda: _copy_kernel[grid](
- x,
- y,
- n_elements,
- BLOCK=block,
- num_warps=num_warps,
- num_stages=4,
- )
+ def launch():
+ _copy_kernel[grid](
+ x,
+ y,
+ n_elements,
+ BLOCK=block,
+ num_warps=num_warps,
+ num_stages=4,
+ )
+
elif op == "triad":
- launch = lambda: _triad_kernel[grid](
- x,
- y,
- n_elements,
- BLOCK=block,
- num_warps=num_warps,
- num_stages=4,
- )
+ def launch():
+ _triad_kernel[grid](
+ x,
+ y,
+ n_elements,
+ BLOCK=block,
+ num_warps=num_warps,
+ num_stages=4,
+ )
+
else:
raise ValueError(f"Unsupported op: {op}")
diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
index 778e3e2..a9865d1 100644
--- a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
@@ -244,27 +244,31 @@ def bench_single(
weight_dtype=w.dtype,
)
- fn_oink = lambda: oink_ln.layernorm(
- x,
- w,
- bias=b,
- eps=eps,
- return_rstd=return_rstd,
- return_mean=return_mean,
- )
+ def fn_oink():
+ return oink_ln.layernorm(
+ x,
+ w,
+ bias=b,
+ eps=eps,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ )
+
ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
if quack_layernorm is None or has_bias:
return (ms_oink, gbps_oink), None, stats
- fn_quack = lambda: quack_layernorm(
- x,
- w,
- eps=eps,
- return_rstd=return_rstd,
- return_mean=return_mean,
- )
+ def fn_quack():
+ return quack_layernorm(
+ x,
+ w,
+ eps=eps,
+ return_rstd=return_rstd,
+ return_mean=return_mean,
+ )
+
ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
index 01c390d..4ba1c47 100644
--- a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
@@ -262,15 +262,16 @@ def bench_single(
if verify:
stats = _verify_parity(x, w, dout, rstd, has_bias=False, has_residual=False)
- fn_oink = lambda: oink_rmsnorm.rmsnorm_backward(
- x,
- w,
- dout,
- rstd,
- dresidual_out=None,
- has_bias=False,
- has_residual=False,
- )
+ def fn_oink():
+ return oink_rmsnorm.rmsnorm_backward(
+ x,
+ w,
+ dout,
+ rstd,
+ dresidual_out=None,
+ has_bias=False,
+ has_residual=False,
+ )
ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
bytes_io = bytes_io_model_bwd(M, N, dtype, weight_dtype=w.dtype)
@@ -280,15 +281,17 @@ def bench_single(
if quack_rmsnorm_bwd is None:
return ours, None, stats
- fn_quack = lambda: quack_rmsnorm_bwd(
- x,
- w,
- dout,
- rstd,
- dresidual_out=None,
- has_bias=False,
- has_residual=False,
- )
+ def fn_quack():
+ return quack_rmsnorm_bwd(
+ x,
+ w,
+ dout,
+ rstd,
+ dresidual_out=None,
+ has_bias=False,
+ has_residual=False,
+ )
+
ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
return ours, Result(ms=ms_quack, gbps=gbps_quack), stats
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
index e55e9ff..20ed8ac 100644
--- a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
@@ -183,30 +183,34 @@ def bench_single(
bytes_io = bytes_io_model_fwd(M, N, dtype, weight_dtype=w.dtype)
- fn_oink = lambda: oink_rmsnorm.rmsnorm_forward(
- x,
- weight=w,
- bias=None,
- residual=None,
- eps=eps,
- store_rstd=store_rstd,
- )
+ def fn_oink():
+ return oink_rmsnorm.rmsnorm_forward(
+ x,
+ weight=w,
+ bias=None,
+ residual=None,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+
ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms)
gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9
if quack_rmsnorm_fwd is None:
return (ms_oink, gbps_oink), None, stats
- fn_quack = lambda: quack_rmsnorm_fwd(
- x,
- w,
- bias=None,
- residual=None,
- out_dtype=None,
- residual_dtype=None,
- eps=eps,
- store_rstd=store_rstd,
- )
+ def fn_quack():
+ return quack_rmsnorm_fwd(
+ x,
+ w,
+ bias=None,
+ residual=None,
+ out_dtype=None,
+ residual_dtype=None,
+ eps=eps,
+ store_rstd=store_rstd,
+ )
+
ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms)
gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9
return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats
diff --git a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
index 93c5af3..7826efc 100644
--- a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
@@ -150,25 +150,39 @@ def bench_single(
bytes_io = bytes_io_model_softmax(M, N, dtype, mode=mode)
if mode == "fwd":
- fn_oink = lambda: oink_softmax.softmax_forward(x)
- fn_quack = None if quack_softmax_fwd is None else (lambda: quack_softmax_fwd(x))
+ def fn_oink():
+ return oink_softmax.softmax_forward(x)
+
+ fn_quack = None
+ if quack_softmax_fwd is not None:
+
+ def fn_quack():
+ return quack_softmax_fwd(x)
+
elif mode == "bwd":
with torch.no_grad():
y_o = oink_softmax.softmax_forward(x)
y_q = quack_softmax_fwd(x) if quack_softmax_fwd is not None else None
- fn_oink = lambda: oink_softmax.softmax_backward(dy, y_o)
- fn_quack = (
- None
- if (quack_softmax_bwd is None or y_q is None)
- else (lambda: quack_softmax_bwd(dy, y_q))
- )
+
+ def fn_oink():
+ return oink_softmax.softmax_backward(dy, y_o)
+
+ fn_quack = None
+ if quack_softmax_bwd is not None and y_q is not None:
+
+ def fn_quack():
+ return quack_softmax_bwd(dy, y_q)
+
elif mode == "fwd_bwd":
- fn_oink = lambda: oink_softmax.softmax_fwd_bwd(dy, x)
- fn_quack = (
- None
- if (quack_softmax_fwd is None or quack_softmax_bwd is None)
- else (lambda: quack_softmax_bwd(dy, quack_softmax_fwd(x)))
- )
+ def fn_oink():
+ return oink_softmax.softmax_fwd_bwd(dy, x)
+
+ fn_quack = None
+ if quack_softmax_fwd is not None and quack_softmax_bwd is not None:
+
+ def fn_quack():
+ return quack_softmax_bwd(dy, quack_softmax_fwd(x))
+
else:
raise ValueError(f"Unsupported mode: {mode}")
diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py
index 1799f2e..c089b2b 100644
--- a/oink/benchmarks/readme/plot_quack_style_svg.py
+++ b/oink/benchmarks/readme/plot_quack_style_svg.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
"""
Generate Quack-style SVG performance plots (Oink vs Quack) from the SM100 suite
JSON artifacts under `/tmp/kernelagent_oink_sm100_suite_{bf16,fp16}`.
@@ -42,6 +40,8 @@
when available: `fused_add_rmsnorm_dsv3.json`.
"""
+from __future__ import annotations
+
import argparse
import json
import math
diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py
index 05b11de..0e4d640 100644
--- a/oink/src/kernelagent_oink/blackwell/layernorm.py
+++ b/oink/src/kernelagent_oink/blackwell/layernorm.py
@@ -80,7 +80,7 @@
# Local helpers cloned from Quack via lite_quack so that this kernel does
# not depend on `quack` at runtime.
-from kernelagent_oink.blackwell.lite_quack import (
+from kernelagent_oink.blackwell.lite_quack import ( # noqa: E402
_KERNEL_ACCEPTS_LAYOUT_ARGS,
TORCH2CUTE_DTYPE,
ReductionBase as _ReductionBase,
diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py
index 1bc15b1..590d773 100644
--- a/oink/src/kernelagent_oink/blackwell/lite_quack.py
+++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py
@@ -39,7 +39,7 @@
from cutlass import Float32, Int32, const_expr
from cutlass.cute.runtime import from_dlpack
from cutlass.cutlass_dsl import T, dsl_user_op
-from cutlass._mlir.dialects import llvm, nvvm, vector
+from cutlass._mlir.dialects import llvm, vector
def _parse_version_tuple(version: str) -> tuple[int, int, int]:
diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
index 1e080a3..9df9f16 100644
--- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py
+++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
@@ -3119,7 +3119,6 @@ def _rmsnorm_bwd_sm100(
assert dresidual.dtype in (torch.float16, torch.bfloat16, torch.float32)
M, N = x.size(0), x.size(1)
- device = x.device
if dw_partial is None and db_partial is None:
assert sm_count is not None
else:
@@ -3130,12 +3129,14 @@ def _rmsnorm_bwd_sm100(
# Match Quack's conversion strategy for activations/gradients: keep the
# (M, N) layout dynamic without enforcing additional compact-shape
# constraints. This reduces per-call Python overhead for small-M shapes.
- convert_from_dlpack = lambda t: from_dlpack( # type: ignore[assignment]
- t.detach(),
- assumed_align=16,
- ).mark_layout_dynamic(leading_dim=1)
+ def _convert_mx(t: Tensor) -> cute.Tensor:
+ return from_dlpack(
+ t.detach(),
+ assumed_align=16,
+ ).mark_layout_dynamic(leading_dim=1)
+
x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [
- convert_from_dlpack(t) if t is not None else None
+ _convert_mx(t) if t is not None else None
for t in (x, dout, dresidual_out, dx, dresidual)
]
diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py
index b53da12..fec5bf4 100644
--- a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py
+++ b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py
@@ -744,9 +744,13 @@ def rmsnorm_forward_with_stage2(
M, N = x.shape
dtype = TORCH2CUTE_DTYPE[x.dtype]
- convert_x = lambda t: from_dlpack(t.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1)
- mX = convert_x(x)
- mRes = convert_x(residual) if residual is not None else None
+ def _convert_x(t: Tensor) -> cute.Tensor:
+ return from_dlpack(
+ t.detach(), assumed_align=32
+ ).mark_layout_dynamic(leading_dim=1)
+
+ mX = _convert_x(x)
+ mRes = _convert_x(residual) if residual is not None else None
out = torch.empty_like(x, dtype=x.dtype)
mO = from_dlpack(out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1)
diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py
index a2f2581..a8a2791 100644
--- a/oink/src/kernelagent_oink/blackwell/softmax.py
+++ b/oink/src/kernelagent_oink/blackwell/softmax.py
@@ -27,10 +27,9 @@
from __future__ import annotations
import importlib.metadata
-import math
import os
import re
-from typing import Optional, Type
+from typing import Type
import torch
from torch import Tensor
From 0543b6f6948c8c110a72d7273ec343f8f897baaa Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Wed, 21 Jan 2026 20:08:22 -0800
Subject: [PATCH 6/8] oink: ruff format
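This patch only reflows code to satisfy `ruff format`; no behavior changes are
intended. For reference, a minimal sketch of the dominant change pattern, taken
from the benchmark hunks below (long call sites wrapped at the line-length
limit):

```python
import argparse

p = argparse.ArgumentParser()

# before: a single long call site
# p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])

# after: ruff format wraps the argument list onto multiple lines
p.add_argument(
    "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
)
```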
---
oink/benchmarks/benchmark/bench_utils.py | 34 +-
.../benchmark_cross_entropy_sm100.py | 97 +++--
.../benchmark_fused_add_rmsnorm_sm100.py | 35 +-
.../benchmark/benchmark_hbm_roofline_sm100.py | 44 ++-
.../benchmark/benchmark_layernorm_sm100.py | 82 +++--
.../benchmark/benchmark_rmsnorm_bwd_sm100.py | 32 +-
.../benchmark/benchmark_rmsnorm_sm100.py | 55 ++-
.../benchmark/benchmark_softmax_sm100.py | 39 +-
.../benchmarks/readme/plot_quack_style_svg.py | 44 ++-
oink/benchmarks/readme/run_sm100_suite.py | 8 +-
oink/benchmarks/readme/summarize_results.py | 61 +++-
.../blackwell/cross_entropy.py | 81 ++++-
.../kernelagent_oink/blackwell/layernorm.py | 90 +++--
.../kernelagent_oink/blackwell/lite_quack.py | 179 ++++++---
.../src/kernelagent_oink/blackwell/rmsnorm.py | 32 +-
.../blackwell/rmsnorm_with_stage2.py | 342 ++++++++++++++----
.../src/kernelagent_oink/blackwell/softmax.py | 100 +++--
17 files changed, 1019 insertions(+), 336 deletions(-)
diff --git a/oink/benchmarks/benchmark/bench_utils.py b/oink/benchmarks/benchmark/bench_utils.py
index 0abb005..0a9ae4b 100644
--- a/oink/benchmarks/benchmark/bench_utils.py
+++ b/oink/benchmarks/benchmark/bench_utils.py
@@ -67,7 +67,9 @@ def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float:
return 2000.0
-def do_bench_triton(fn: Callable[[], Any], *, warmup_ms: int = 25, rep_ms: int = 100) -> float:
+def do_bench_triton(
+ fn: Callable[[], Any], *, warmup_ms: int = 25, rep_ms: int = 100
+) -> float:
"""Kernel-only timing consistent with the Oink benchmark harnesses."""
return float(triton_do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="mean"))
@@ -127,7 +129,13 @@ def write_csv(path: str, rows: Sequence[Dict[str, Any]]) -> None:
writer.writerow(row)
-def write_json(path: str, meta: DeviceMeta, rows: Sequence[Dict[str, Any]], *, extra: Dict[str, Any] | None = None) -> None:
+def write_json(
+ path: str,
+ meta: DeviceMeta,
+ rows: Sequence[Dict[str, Any]],
+ *,
+ extra: Dict[str, Any] | None = None,
+) -> None:
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
payload: Dict[str, Any] = {
"meta": {**asdict(meta), **(extra or {})},
@@ -179,7 +187,9 @@ def __init__(self, *, total_elems: int, p99_target_samples: int = 1_000_000):
if total_elems <= 0:
raise ValueError(f"total_elems must be > 0, got {total_elems}")
if p99_target_samples <= 0:
- raise ValueError(f"p99_target_samples must be > 0, got {p99_target_samples}")
+ raise ValueError(
+ f"p99_target_samples must be > 0, got {p99_target_samples}"
+ )
self.total_elems = int(total_elems)
self.p99_target_samples = int(p99_target_samples)
# Deterministic strided sampling across the flattened tensor order.
@@ -193,7 +203,9 @@ def __init__(self, *, total_elems: int, p99_target_samples: int = 1_000_000):
def update(self, out: torch.Tensor, ref: torch.Tensor) -> None:
if out.shape != ref.shape:
- raise ValueError(f"shape mismatch: out={tuple(out.shape)} ref={tuple(ref.shape)}")
+ raise ValueError(
+ f"shape mismatch: out={tuple(out.shape)} ref={tuple(ref.shape)}"
+ )
# Compute error in float32 for stable reductions.
err_f32 = (out - ref).to(torch.float32)
@@ -214,9 +226,13 @@ def update(self, out: torch.Tensor, ref: torch.Tensor) -> None:
stride = int(self.sample_stride)
first = (-int(self._global_offset)) % stride
if first < block_elems:
- idx = torch.arange(first, block_elems, step=stride, device=flat.device, dtype=torch.int64)
+ idx = torch.arange(
+ first, block_elems, step=stride, device=flat.device, dtype=torch.int64
+ )
# Gather a modest number of values (≈ block_elems/stride).
- vals = flat.index_select(0, idx).detach().to(device="cpu", dtype=torch.float32)
+ vals = (
+ flat.index_select(0, idx).detach().to(device="cpu", dtype=torch.float32)
+ )
self._abs_err_samples.append(vals)
self._global_offset += block_elems
@@ -226,7 +242,11 @@ def finalize(self) -> ErrorStats:
samples = torch.cat(self._abs_err_samples, dim=0)
if samples.numel() > self.p99_target_samples:
samples = samples[: self.p99_target_samples]
- p99 = float(torch.quantile(samples, 0.99).item()) if samples.numel() > 0 else 0.0
+ p99 = (
+ float(torch.quantile(samples, 0.99).item())
+ if samples.numel() > 0
+ else 0.0
+ )
sample_elems = int(samples.numel())
else:
p99 = 0.0
diff --git a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
index 18399c7..ff1a99b 100644
--- a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
@@ -84,16 +84,22 @@ def dsv3_configs() -> List[Tuple[int, int]]:
return [(m, n) for m in Ms for n in Ns]
-def _verify_parity(logits: torch.Tensor, target: torch.Tensor, *, ignore_index: int) -> dict[str, object]:
+def _verify_parity(
+ logits: torch.Tensor, target: torch.Tensor, *, ignore_index: int
+) -> dict[str, object]:
dtype = logits.dtype
ref_block_rows = 512
- dloss = torch.randn(logits.size(0), device=logits.device, dtype=torch.float32) # upstream grad
+ dloss = torch.randn(
+ logits.size(0), device=logits.device, dtype=torch.float32
+ ) # upstream grad
with torch.no_grad():
loss_o, lse_o = oink_ce.cross_entropy_forward(
logits, target, ignore_index=ignore_index, reduction="none"
)
- dx_o = oink_ce.cross_entropy_backward(dloss, logits, target, lse_o, ignore_index=ignore_index)
+ dx_o = oink_ce.cross_entropy_backward(
+ dloss, logits, target, lse_o, ignore_index=ignore_index
+ )
dx_fused_o = oink_ce.cross_entropy_fwd_bwd(
dloss,
logits,
@@ -125,8 +131,12 @@ def _verify_parity(logits: torch.Tensor, target: torch.Tensor, *, ignore_index:
M = int(logits.shape[0])
N = int(logits.shape[1])
- loss_acc_ours = ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000))
- lse_acc_ours = ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000))
+ loss_acc_ours = ErrorStatsAccumulator(
+ total_elems=M, p99_target_samples=min(M, 1_000_000)
+ )
+ lse_acc_ours = ErrorStatsAccumulator(
+ total_elems=M, p99_target_samples=min(M, 1_000_000)
+ )
dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
dx_fused_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
loss_acc_quack = (
@@ -159,23 +169,39 @@ def _verify_parity(logits: torch.Tensor, target: torch.Tensor, *, ignore_index:
ignore_index=ignore_index,
)
lse_ref = torch.logsumexp(logits_f32, dim=-1)
- (dx_ref_f32,) = torch.autograd.grad(loss_ref, logits_f32, grad_outputs=dloss_blk)
+ (dx_ref_f32,) = torch.autograd.grad(
+ loss_ref, logits_f32, grad_outputs=dloss_blk
+ )
dx_ref = dx_ref_f32.to(dtype)
- torch.testing.assert_close(loss_o[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS)
- torch.testing.assert_close(lse_o[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS)
+ torch.testing.assert_close(
+ loss_o[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS
+ )
+ torch.testing.assert_close(
+ lse_o[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS
+ )
torch.testing.assert_close(dx_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype])
- torch.testing.assert_close(dx_fused_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype])
+ torch.testing.assert_close(
+ dx_fused_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype]
+ )
loss_acc_ours.update(loss_o[start:end], loss_ref.detach())
lse_acc_ours.update(lse_o[start:end], lse_ref.detach())
dx_acc_ours.update(dx_o[start:end], dx_ref)
dx_fused_acc_ours.update(dx_fused_o[start:end], dx_ref)
if loss_q is not None and lse_q is not None and dx_q is not None:
- torch.testing.assert_close(loss_q[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS)
- torch.testing.assert_close(lse_q[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS)
+ torch.testing.assert_close(
+ loss_q[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS
+ )
+ torch.testing.assert_close(
+ lse_q[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS
+ )
torch.testing.assert_close(dx_q[start:end], dx_ref, **_VERIFY_TOL_DX[dtype])
- assert loss_acc_quack is not None and lse_acc_quack is not None and dx_acc_quack is not None
+ assert (
+ loss_acc_quack is not None
+ and lse_acc_quack is not None
+ and dx_acc_quack is not None
+ )
loss_acc_quack.update(loss_q[start:end], loss_ref.detach())
lse_acc_quack.update(lse_q[start:end], lse_ref.detach())
dx_acc_quack.update(dx_q[start:end], dx_ref)
@@ -185,7 +211,11 @@ def _verify_parity(logits: torch.Tensor, target: torch.Tensor, *, ignore_index:
stats.update(error_stats_to_row("ours_err_lse", lse_acc_ours.finalize()))
stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize()))
stats.update(error_stats_to_row("ours_err_dx_fused", dx_fused_acc_ours.finalize()))
- if loss_acc_quack is not None and lse_acc_quack is not None and dx_acc_quack is not None:
+ if (
+ loss_acc_quack is not None
+ and lse_acc_quack is not None
+ and dx_acc_quack is not None
+ ):
stats.update(error_stats_to_row("quack_err_loss", loss_acc_quack.finalize()))
stats.update(error_stats_to_row("quack_err_lse", lse_acc_quack.finalize()))
stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize()))
@@ -219,6 +249,7 @@ def bench_single(
bytes_io = bytes_io_model_ce(M, N, dtype, target_dtype=target.dtype, mode=mode)
if mode == "fwd":
+
def fn_oink():
return oink_ce.cross_entropy_forward(
logits, target, ignore_index=int(ignore_index), reduction="none"
@@ -275,6 +306,7 @@ def fn_quack():
)
elif mode == "fwd_bwd":
+
def fn_oink():
return oink_ce.cross_entropy_fwd_bwd(
dloss,
@@ -330,16 +362,34 @@ def main() -> None:
print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
p = argparse.ArgumentParser()
- p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
- p.add_argument("--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"])
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
+ p.add_argument(
+ "--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"]
+ )
p.add_argument("--ignore-index", type=int, default=-100)
- p.add_argument("--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only).")
+ p.add_argument(
+ "--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only)."
+ )
p.add_argument("--warmup-ms", type=int, default=25)
- p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows")
- p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)")
+ p.add_argument(
+ "--csv", type=str, default=None, help="Optional CSV output path; appends rows"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Optional JSON output path (meta + rows)"
+ )
p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
- p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid (vocab=4096)")
- p.add_argument("--dsv3", action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {3072,6144,8192,12288}")
+ p.add_argument(
+ "--quack-suite",
+ action="store_true",
+ help="Run Quack-style batch/seq grid (vocab=4096)",
+ )
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {3072,6144,8192,12288}",
+ )
p.add_argument(
"--skip-verify",
action="store_true",
@@ -360,8 +410,11 @@ def main() -> None:
meta = collect_device_meta(device)
rows_out: List[Dict[str, Any]] = []
- for (M, N) in cfgs:
- print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", flush=True)
+ for M, N in cfgs:
+ print(
+ f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...",
+ flush=True,
+ )
(ms_oink, gbps_oink), quack, stats = bench_single(
M=M,
N=N,
diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
index 6418e61..863712d 100644
--- a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
@@ -90,8 +90,16 @@ def _verify_parity(
y_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
z_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
- y_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd_mut is not None else None
- z_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd_mut is not None else None
+ y_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if quack_rmsnorm_fwd_mut is not None
+ else None
+ )
+ z_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if quack_rmsnorm_fwd_mut is not None
+ else None
+ )
x_o = x.clone()
r_o = residual.clone()
@@ -141,7 +149,9 @@ def _verify_parity(
stats.update(error_stats_to_row("ours_err_residual_out", z_acc_ours.finalize()))
if y_acc_quack is not None and z_acc_quack is not None:
stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize()))
- stats.update(error_stats_to_row("quack_err_residual_out", z_acc_quack.finalize()))
+ stats.update(
+ error_stats_to_row("quack_err_residual_out", z_acc_quack.finalize())
+ )
return stats
@@ -177,7 +187,9 @@ def fn():
row: Dict[str, Any] = dict(
M=int(M),
N=int(N),
- dtype="bf16" if dtype is torch.bfloat16 else ("fp16" if dtype is torch.float16 else "fp32"),
+ dtype="bf16"
+ if dtype is torch.bfloat16
+ else ("fp16" if dtype is torch.float16 else "fp32"),
ours_ms=float(ms),
ours_gbps=float(gbps),
ours_tbps=float(tbps),
@@ -247,7 +259,9 @@ def _print_table(rows: List[Dict[str, Any]]) -> None:
def main() -> None:
p = argparse.ArgumentParser()
- p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]
+ )
p.add_argument("--M", type=int, default=65536)
p.add_argument("--N", type=int, default=4096)
p.add_argument(
@@ -256,7 +270,9 @@ def main() -> None:
help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
)
p.add_argument("--warmup-ms", type=int, default=25)
- p.add_argument("--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)")
+ p.add_argument(
+ "--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)"
+ )
p.add_argument("--skip-verify", action="store_true")
p.add_argument("--json", type=str, default=None)
args = p.parse_args()
@@ -266,8 +282,11 @@ def main() -> None:
cfgs = dsv3_configs() if bool(args.dsv3) else [(int(args.M), int(args.N))]
rows: List[Dict[str, Any]] = []
- for (M, N) in cfgs:
- print(f"bench M={M:<8d} N={N:<6d} dtype={_dtype_label(dtype)} fused_add_rmsnorm ...", flush=True)
+ for M, N in cfgs:
+ print(
+ f"bench M={M:<8d} N={N:<6d} dtype={_dtype_label(dtype)} fused_add_rmsnorm ...",
+ flush=True,
+ )
rows.append(
bench_one(
M=int(M),
diff --git a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
index 8ec4bfd..c22294e 100644
--- a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
@@ -95,6 +95,7 @@ def bench_one(
grid = (triton.cdiv(n_elements, block),)
if op == "copy":
+
def launch():
_copy_kernel[grid](
x,
@@ -106,6 +107,7 @@ def launch():
)
elif op == "triad":
+
def launch():
_triad_kernel[grid](
x,
@@ -134,20 +136,38 @@ def _print_summary(rows: List[Dict[str, Any]]) -> None:
return
best = max(rows, key=lambda r: float(r["tbps"]))
print("\nSummary (STREAM-like):")
- print(f"- best_tbps: {best['tbps']:.3f} TB/s ({best['op']}, BLOCK={best['block']}, warps={best['num_warps']})")
+ print(
+ f"- best_tbps: {best['tbps']:.3f} TB/s ({best['op']}, BLOCK={best['block']}, warps={best['num_warps']})"
+ )
def main() -> None:
p = argparse.ArgumentParser()
- p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]
+ )
p.add_argument("--op", type=str, default="copy", choices=["copy", "triad", "both"])
- p.add_argument("--gb", type=float, default=2.0, help="Size per tensor in GB (default: 2)")
+ p.add_argument(
+ "--gb", type=float, default=2.0, help="Size per tensor in GB (default: 2)"
+ )
p.add_argument("--warmup-ms", type=int, default=25)
- p.add_argument("--iters", type=int, default=100, help="rep_ms for do_bench (default: 100)")
- p.add_argument("--json", type=str, default=None, help="Write JSON results to this path")
- p.add_argument("--no-sweep", action="store_true", help="Disable tuning sweep; run a single config")
- p.add_argument("--block", type=int, default=2048, help="BLOCK size when --no-sweep is set")
- p.add_argument("--warps", type=int, default=8, help="num_warps when --no-sweep is set")
+ p.add_argument(
+ "--iters", type=int, default=100, help="rep_ms for do_bench (default: 100)"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Write JSON results to this path"
+ )
+ p.add_argument(
+ "--no-sweep",
+ action="store_true",
+ help="Disable tuning sweep; run a single config",
+ )
+ p.add_argument(
+ "--block", type=int, default=2048, help="BLOCK size when --no-sweep is set"
+ )
+ p.add_argument(
+ "--warps", type=int, default=8, help="num_warps when --no-sweep is set"
+ )
args = p.parse_args()
dtype = parse_dtype(args.dtype)
@@ -181,7 +201,9 @@ def main() -> None:
print(f"Running on {props.name} (SM{props.major}{props.minor})")
print(f"- dtype: {args.dtype} (elem={elem_size}B)")
- print(f"- n_elements: {n_elements:,} (~{(n_elements * elem_size) / (1024**3):.2f} GiB per tensor)")
+ print(
+ f"- n_elements: {n_elements:,} (~{(n_elements * elem_size) / (1024**3):.2f} GiB per tensor)"
+ )
print(f"- ops: {ops}")
print(f"- sweep: {sweep}")
@@ -212,7 +234,9 @@ def main() -> None:
tbps=float(tbps),
)
)
- print(f"- {op:5s} BLOCK={block:4d} warps={warps}: {tbps:.3f} TB/s ({ms:.4f} ms)")
+ print(
+ f"- {op:5s} BLOCK={block:4d} warps={warps}: {tbps:.3f} TB/s ({ms:.4f} ms)"
+ )
_print_summary(rows)
diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
index a9865d1..3c0e37d 100644
--- a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
@@ -97,7 +97,9 @@ def _verify_parity(
y_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
y_acc_quack = (
- ErrorStatsAccumulator(total_elems=M * N) if (quack_layernorm is not None and b is None) else None
+ ErrorStatsAccumulator(total_elems=M * N)
+ if (quack_layernorm is not None and b is None)
+ else None
)
with torch.no_grad():
ours = oink_ln.layernorm(
@@ -138,8 +140,12 @@ def _unpack(out):
# Pure-PyTorch reference (float32 accumulation), matching Quack's unit tests:
# - compute ref output via F.layer_norm on float32
# - compute mean/rstd from float32 input
- rstd_ref_all = torch.empty((M,), device=x.device, dtype=torch.float32) if return_rstd else None
- mean_ref_all = torch.empty((M,), device=x.device, dtype=torch.float32) if return_mean else None
+ rstd_ref_all = (
+ torch.empty((M,), device=x.device, dtype=torch.float32) if return_rstd else None
+ )
+ mean_ref_all = (
+ torch.empty((M,), device=x.device, dtype=torch.float32) if return_mean else None
+ )
for start, end in iter_row_blocks(M, ref_block_rows):
x_f32 = x[start:end].float()
@@ -165,16 +171,24 @@ def _unpack(out):
rstd_ref_all[start:end] = rstd_ref
assert rstd_o is not None
- torch.testing.assert_close(rstd_o[start:end], rstd_ref, **_VERIFY_TOL_STATS)
+ torch.testing.assert_close(
+ rstd_o[start:end], rstd_ref, **_VERIFY_TOL_STATS
+ )
if rstd_q is not None:
- torch.testing.assert_close(rstd_q[start:end], rstd_ref, **_VERIFY_TOL_STATS)
+ torch.testing.assert_close(
+ rstd_q[start:end], rstd_ref, **_VERIFY_TOL_STATS
+ )
if return_mean:
mean_ref = mean_f32
assert mean_o is not None
- torch.testing.assert_close(mean_o[start:end], mean_ref, **_VERIFY_TOL_STATS)
+ torch.testing.assert_close(
+ mean_o[start:end], mean_ref, **_VERIFY_TOL_STATS
+ )
if mean_q is not None:
- torch.testing.assert_close(mean_q[start:end], mean_ref, **_VERIFY_TOL_STATS)
+ torch.testing.assert_close(
+ mean_q[start:end], mean_ref, **_VERIFY_TOL_STATS
+ )
stats: dict[str, object] = {}
stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize()))
@@ -184,30 +198,38 @@ def _unpack(out):
if return_rstd:
assert rstd_o is not None and rstd_ref_all is not None
rstd_acc_ours = ErrorStatsAccumulator(
- total_elems=int(rstd_ref_all.numel()), p99_target_samples=int(rstd_ref_all.numel())
+ total_elems=int(rstd_ref_all.numel()),
+ p99_target_samples=int(rstd_ref_all.numel()),
)
rstd_acc_ours.update(rstd_o, rstd_ref_all)
stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize()))
if rstd_q is not None:
rstd_acc_quack = ErrorStatsAccumulator(
- total_elems=int(rstd_ref_all.numel()), p99_target_samples=int(rstd_ref_all.numel())
+ total_elems=int(rstd_ref_all.numel()),
+ p99_target_samples=int(rstd_ref_all.numel()),
)
rstd_acc_quack.update(rstd_q, rstd_ref_all)
- stats.update(error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize()))
+ stats.update(
+ error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize())
+ )
if return_mean:
assert mean_o is not None and mean_ref_all is not None
mean_acc_ours = ErrorStatsAccumulator(
- total_elems=int(mean_ref_all.numel()), p99_target_samples=int(mean_ref_all.numel())
+ total_elems=int(mean_ref_all.numel()),
+ p99_target_samples=int(mean_ref_all.numel()),
)
mean_acc_ours.update(mean_o, mean_ref_all)
stats.update(error_stats_to_row("ours_err_mean", mean_acc_ours.finalize()))
if mean_q is not None:
mean_acc_quack = ErrorStatsAccumulator(
- total_elems=int(mean_ref_all.numel()), p99_target_samples=int(mean_ref_all.numel())
+ total_elems=int(mean_ref_all.numel()),
+ p99_target_samples=int(mean_ref_all.numel()),
)
mean_acc_quack.update(mean_q, mean_ref_all)
- stats.update(error_stats_to_row("quack_err_mean", mean_acc_quack.finalize()))
+ stats.update(
+ error_stats_to_row("quack_err_mean", mean_acc_quack.finalize())
+ )
return stats
@@ -232,7 +254,9 @@ def bench_single(
stats: dict[str, object] = {}
if verify:
- stats = _verify_parity(x, w, b, eps=eps, return_rstd=return_rstd, return_mean=return_mean)
+ stats = _verify_parity(
+ x, w, b, eps=eps, return_rstd=return_rstd, return_mean=return_mean
+ )
bytes_io = bytes_io_model_layernorm(
M,
@@ -285,17 +309,33 @@ def main() -> None:
print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
p = argparse.ArgumentParser()
- p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
p.add_argument("--eps", type=float, default=1e-6)
p.add_argument("--return-rstd", action="store_true")
p.add_argument("--return-mean", action="store_true")
- p.add_argument("--with-bias", action="store_true", help="Benchmark bias path (Quack compare skipped)")
- p.add_argument("--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only).")
+ p.add_argument(
+ "--with-bias",
+ action="store_true",
+ help="Benchmark bias path (Quack compare skipped)",
+ )
+ p.add_argument(
+ "--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only)."
+ )
p.add_argument("--warmup-ms", type=int, default=25)
- p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows")
- p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)")
+ p.add_argument(
+ "--csv", type=str, default=None, help="Optional CSV output path; appends rows"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Optional JSON output path (meta + rows)"
+ )
p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
- p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid (hidden=4096)")
+ p.add_argument(
+ "--quack-suite",
+ action="store_true",
+ help="Run Quack-style batch/seq grid (hidden=4096)",
+ )
p.add_argument(
"--dsv3",
action="store_true",
@@ -322,7 +362,7 @@ def main() -> None:
meta = collect_device_meta(device)
rows_out: List[Dict[str, Any]] = []
- for (M, N) in cfgs:
+ for M, N in cfgs:
print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True)
(ms_oink, gbps_oink), quack, stats = bench_single(
M=M,
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
index 4ba1c47..b9909e7 100644
--- a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
@@ -140,7 +140,11 @@ def _verify_parity(
M, N = int(x.shape[0]), int(x.shape[1])
dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
- dx_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_bwd is not None else None
+ dx_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if quack_rmsnorm_bwd is not None
+ else None
+ )
with torch.no_grad():
dx_oink, dw_oink, db_oink, dres_oink = oink_rmsnorm.rmsnorm_backward(
@@ -216,7 +220,9 @@ def _verify_parity(
dw_tol = dict(atol=dw_atol, rtol=1e-3)
torch.testing.assert_close(dw_oink_f32, dw_ref_f32, **dw_tol)
if dw_quack is not None:
- torch.testing.assert_close(dw_quack.to(torch.float32), dw_ref_f32, **dw_tol)
+ torch.testing.assert_close(
+ dw_quack.to(torch.float32), dw_ref_f32, **dw_tol
+ )
dw_tol = None # handled above
if dw_tol is not None:
torch.testing.assert_close(dw_oink, dw_ref, **dw_tol)
@@ -224,7 +230,9 @@ def _verify_parity(
torch.testing.assert_close(dw_quack, dw_ref, **dw_tol)
# Record weight-grad error stats (small, so exact p99 over the full vector).
- dw_acc_ours = ErrorStatsAccumulator(total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel()))
+ dw_acc_ours = ErrorStatsAccumulator(
+ total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel())
+ )
dw_acc_ours.update(dw_oink, dw_ref)
stats.update(error_stats_to_row("ours_err_dw", dw_acc_ours.finalize()))
if dw_quack is not None:
@@ -308,7 +316,9 @@ def main() -> None:
print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
p = argparse.ArgumentParser()
- p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
p.add_argument(
"--weight-dtype",
type=str,
@@ -324,10 +334,16 @@ def main() -> None:
help="Triton do_bench rep_ms (kernel-only).",
)
p.add_argument("--warmup-ms", type=int, default=25)
- p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows")
- p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)")
+ p.add_argument(
+ "--csv", type=str, default=None, help="Optional CSV output path; appends rows"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Optional JSON output path (meta + rows)"
+ )
p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
- p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid")
+ p.add_argument(
+ "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid"
+ )
p.add_argument(
"--dsv3",
action="store_true",
@@ -358,7 +374,7 @@ def main() -> None:
rows_out: list[dict[str, object]] = []
- for (M, N) in cfgs:
+ for M, N in cfgs:
print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True)
ours, quack, stats = bench_single(
M=M,
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
index 20ed8ac..f4c8a5f 100644
--- a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
@@ -85,7 +85,11 @@ def _verify_parity(
N = int(x.shape[1])
y_acc_ours = ErrorStatsAccumulator(total_elems=M * N)
- y_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd is not None else None
+ y_acc_quack = (
+ ErrorStatsAccumulator(total_elems=M * N)
+ if quack_rmsnorm_fwd is not None
+ else None
+ )
with torch.no_grad():
y_o, rstd_o, res_o = oink_rmsnorm.rmsnorm_forward(
@@ -141,15 +145,20 @@ def _verify_parity(
assert rstd_q is not None
torch.testing.assert_close(rstd_q, rstd_ref, **tol_rstd)
# Stats for rstd are cheap (M elements); compute exact p99 over all rows.
- rstd_acc_ours = ErrorStatsAccumulator(total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel()))
+ rstd_acc_ours = ErrorStatsAccumulator(
+ total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel())
+ )
rstd_acc_ours.update(rstd_o, rstd_ref)
stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize()))
if rstd_q is not None:
rstd_acc_quack = ErrorStatsAccumulator(
- total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel())
+ total_elems=int(rstd_ref.numel()),
+ p99_target_samples=int(rstd_ref.numel()),
)
rstd_acc_quack.update(rstd_q, rstd_ref)
- stats.update(error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize()))
+ stats.update(
+ error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize())
+ )
# Residual output semantics differ slightly across implementations:
# - Oink returns `None` when residual is None.
# - Quack returns `x` as a safe alias in that case.
@@ -227,7 +236,9 @@ def main() -> None:
print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
p = argparse.ArgumentParser()
- p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
p.add_argument(
"--weight-dtype",
type=str,
@@ -236,15 +247,33 @@ def main() -> None:
help="RMSNorm weight dtype. `same` matches activation dtype (vLLM-style inference).",
)
p.add_argument("--eps", type=float, default=1e-6)
- p.add_argument("--store-rstd", action="store_true", help="Also write rstd (fp32 per row)")
- p.add_argument("--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only).")
+ p.add_argument(
+ "--store-rstd", action="store_true", help="Also write rstd (fp32 per row)"
+ )
+ p.add_argument(
+ "--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only)."
+ )
p.add_argument("--warmup-ms", type=int, default=25)
- p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows")
- p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)")
+ p.add_argument(
+ "--csv", type=str, default=None, help="Optional CSV output path; appends rows"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Optional JSON output path (meta + rows)"
+ )
p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
- p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid")
- p.add_argument("--dsv3", action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}")
- p.add_argument("--skip-verify", action="store_true", help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference)")
+ p.add_argument(
+ "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid"
+ )
+ p.add_argument(
+ "--dsv3",
+ action="store_true",
+ help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
+ )
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference)",
+ )
args = p.parse_args()
dtype = parse_dtype(args.dtype)
@@ -265,7 +294,7 @@ def main() -> None:
meta = collect_device_meta(device)
rows_out: List[Dict[str, Any]] = []
- for (M, N) in cfgs:
+ for M, N in cfgs:
print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True)
(ms_oink, gbps_oink), quack, stats = bench_single(
M=M,
diff --git a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
index 7826efc..995b09f 100644
--- a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
@@ -150,6 +150,7 @@ def bench_single(
bytes_io = bytes_io_model_softmax(M, N, dtype, mode=mode)
if mode == "fwd":
+
def fn_oink():
return oink_softmax.softmax_forward(x)
@@ -174,6 +175,7 @@ def fn_quack():
return quack_softmax_bwd(dy, y_q)
elif mode == "fwd_bwd":
+
def fn_oink():
return oink_softmax.softmax_fwd_bwd(dy, x)
@@ -208,20 +210,36 @@ def main() -> None:
print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})")
p = argparse.ArgumentParser()
- p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
- p.add_argument("--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"])
- p.add_argument("--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only).")
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
+ p.add_argument(
+ "--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"]
+ )
+ p.add_argument(
+ "--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only)."
+ )
p.add_argument("--warmup-ms", type=int, default=25)
- p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows")
- p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)")
+ p.add_argument(
+ "--csv", type=str, default=None, help="Optional CSV output path; appends rows"
+ )
+ p.add_argument(
+ "--json", type=str, default=None, help="Optional JSON output path (meta + rows)"
+ )
p.add_argument("--configs", type=str, default="1024x4096,8192x4096")
- p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid")
+ p.add_argument(
+ "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid"
+ )
p.add_argument(
"--dsv3",
action="store_true",
help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}",
)
- p.add_argument("--skip-verify", action="store_true", help="Skip correctness checks (Oink/Quack vs PyTorch softmax)")
+ p.add_argument(
+ "--skip-verify",
+ action="store_true",
+ help="Skip correctness checks (Oink/Quack vs PyTorch softmax)",
+ )
args = p.parse_args()
dtype = parse_dtype(args.dtype)
@@ -237,8 +255,11 @@ def main() -> None:
meta = collect_device_meta(device)
rows_out: List[Dict[str, Any]] = []
- for (M, N) in cfgs:
- print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", flush=True)
+ for M, N in cfgs:
+ print(
+ f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...",
+ flush=True,
+ )
(ms_oink, gbps_oink), quack, stats = bench_single(
M=M,
N=N,
diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py
index c089b2b..af76832 100644
--- a/oink/benchmarks/readme/plot_quack_style_svg.py
+++ b/oink/benchmarks/readme/plot_quack_style_svg.py
@@ -78,7 +78,9 @@ def _gbps_from_row(prefix: str, row: Mapping[str, Any]) -> Optional[float]:
return None
-def _aggregate_by_shape(rows: Sequence[Mapping[str, Any]]) -> Dict[Tuple[int, int], Dict[str, float]]:
+def _aggregate_by_shape(
+ rows: Sequence[Mapping[str, Any]],
+) -> Dict[Tuple[int, int], Dict[str, float]]:
"""Aggregate duplicate (M, N) rows using median (more robust than mean)."""
buckets: dict[tuple[int, int], dict[str, list[float]]] = defaultdict(
lambda: defaultdict(list)
@@ -199,7 +201,11 @@ def _plot(
continue
ours_y.append(float(rec["ours"]))
quack_y.append(float(rec["quack"]))
- max_y = max(max_y, *(v for v in ours_y if math.isfinite(v)), *(v for v in quack_y if math.isfinite(v)))
+ max_y = max(
+ max_y,
+ *(v for v in ours_y if math.isfinite(v)),
+ *(v for v in quack_y if math.isfinite(v)),
+ )
ax.plot(
x,
@@ -337,7 +343,10 @@ def main() -> None:
description="Generate Quack-style SVG plots from KernelAgent-Oink suite JSONs."
)
p.add_argument(
- "--in-dir", type=str, required=True, help="Directory containing suite JSON outputs"
+ "--in-dir",
+ type=str,
+ required=True,
+ help="Directory containing suite JSON outputs",
)
p.add_argument(
"--suite",
@@ -362,22 +371,35 @@ def main() -> None:
"`union` includes every shape across panels (may create gaps)."
),
)
- p.add_argument("--roofline-json", type=str, default=None, help="Optional /tmp/hbm_roofline_sm100_*.json path")
+ p.add_argument(
+ "--roofline-json",
+ type=str,
+ default=None,
+ help="Optional /tmp/hbm_roofline_sm100_*.json path",
+ )
p.add_argument("--out", type=str, required=True, help="Output SVG path")
- p.add_argument("--title", type=str, default=None, help="Optional figure title override")
+ p.add_argument(
+ "--title", type=str, default=None, help="Optional figure title override"
+ )
args = p.parse_args()
in_dir = os.path.abspath(args.in_dir)
if not os.path.isdir(in_dir):
raise SystemExit(f"--in-dir is not a directory: {in_dir}")
- roofline_gbps = _read_roofline_gbps(args.roofline_json) if args.roofline_json else None
+ roofline_gbps = (
+ _read_roofline_gbps(args.roofline_json) if args.roofline_json else None
+ )
panel_files = list(_panel_files_for_suite(str(args.suite)))
if args.include_layernorm:
if args.suite != "quack_suite":
- raise SystemExit("--include-layernorm is only supported for `--suite quack_suite`.")
- panel_files.append(("LayerNorm (fwd)", _layernorm_file_for_suite(str(args.suite))))
+ raise SystemExit(
+ "--include-layernorm is only supported for `--suite quack_suite`."
+ )
+ panel_files.append(
+ ("LayerNorm (fwd)", _layernorm_file_for_suite(str(args.suite)))
+ )
panels: List[Tuple[str, Dict[Tuple[int, int], Dict[str, float]]]] = []
for panel_title, filename in panel_files:
@@ -410,7 +432,11 @@ def main() -> None:
suite_name = "DSv3 CrossEntropy"
else:
suite_name = str(args.suite)
- suffix = " (+LayerNorm)" if (args.suite == "quack_suite" and args.include_layernorm) else ""
+ suffix = (
+ " (+LayerNorm)"
+ if (args.suite == "quack_suite" and args.include_layernorm)
+ else ""
+ )
if args.suite == "dsv3_cross_entropy":
title = f"SM100 {dtype.upper()} — {suite_name}{suffix}"
else:
diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py
index 5ac1091..c31d4b5 100644
--- a/oink/benchmarks/readme/run_sm100_suite.py
+++ b/oink/benchmarks/readme/run_sm100_suite.py
@@ -21,7 +21,9 @@ def _run(cmd: List[str], *, dry_run: bool) -> None:
def main() -> None:
p = argparse.ArgumentParser()
- p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])
+ p.add_argument(
+ "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]
+ )
p.add_argument(
"--out-dir",
type=str,
@@ -33,7 +35,9 @@ def main() -> None:
action="store_true",
help="Skip correctness checks (Oink/Quack vs PyTorch / pure-PyTorch references)",
)
- p.add_argument("--dry-run", action="store_true", help="Print commands without executing them")
+ p.add_argument(
+ "--dry-run", action="store_true", help="Print commands without executing them"
+ )
args = p.parse_args()
# Standardize env for standalone runs outside the vLLM plugin.
diff --git a/oink/benchmarks/readme/summarize_results.py b/oink/benchmarks/readme/summarize_results.py
index 70782dd..29b768e 100644
--- a/oink/benchmarks/readme/summarize_results.py
+++ b/oink/benchmarks/readme/summarize_results.py
@@ -95,16 +95,26 @@ def _summarize_error_stats(rows: Sequence[Dict[str, Any]]) -> str:
out_rows: List[Dict[str, Any]] = []
for pfx in prefixes:
# Per-prefix worst-case across rows.
- max_abs_vals = [float(r[pfx + "_max_abs"]) for r in rows if (pfx + "_max_abs") in r]
- p99_abs_vals = [float(r[pfx + "_p99_abs"]) for r in rows if (pfx + "_p99_abs") in r]
- rel_l2_vals = [float(r[pfx + "_rel_l2"]) for r in rows if (pfx + "_rel_l2") in r]
+ max_abs_vals = [
+ float(r[pfx + "_max_abs"]) for r in rows if (pfx + "_max_abs") in r
+ ]
+ p99_abs_vals = [
+ float(r[pfx + "_p99_abs"]) for r in rows if (pfx + "_p99_abs") in r
+ ]
+ rel_l2_vals = [
+ float(r[pfx + "_rel_l2"]) for r in rows if (pfx + "_rel_l2") in r
+ ]
if not max_abs_vals and not p99_abs_vals and not rel_l2_vals:
continue
out_rows.append(
{
"metric": pfx,
- "max_abs (max over shapes)": max(max_abs_vals) if max_abs_vals else None,
- "p99_abs (max over shapes)": max(p99_abs_vals) if p99_abs_vals else None,
+ "max_abs (max over shapes)": max(max_abs_vals)
+ if max_abs_vals
+ else None,
+ "p99_abs (max over shapes)": max(p99_abs_vals)
+ if p99_abs_vals
+ else None,
"rel_l2 (max over shapes)": max(rel_l2_vals) if rel_l2_vals else None,
}
)
@@ -112,8 +122,15 @@ def _summarize_error_stats(rows: Sequence[Dict[str, Any]]) -> str:
if not out_rows:
return ""
- cols = ["metric", "max_abs (max over shapes)", "p99_abs (max over shapes)", "rel_l2 (max over shapes)"]
- return "\n".join(["", "### Error Stats (vs PyTorch ref)", "", _md_table(out_rows, cols), ""])
+ cols = [
+ "metric",
+ "max_abs (max over shapes)",
+ "p99_abs (max over shapes)",
+ "rel_l2 (max over shapes)",
+ ]
+ return "\n".join(
+ ["", "### Error Stats (vs PyTorch ref)", "", _md_table(out_rows, cols), ""]
+ )
def summarize_one(path: str) -> str:
@@ -143,7 +160,9 @@ def summarize_one(path: str) -> str:
if method is not None:
parts.append(f"- method: `{method}`")
if meta.get("warmup_ms") is not None and meta.get("rep_ms") is not None:
- parts.append(f"- warmup_ms: `{meta.get('warmup_ms')}` | rep_ms: `{meta.get('rep_ms')}`")
+ parts.append(
+ f"- warmup_ms: `{meta.get('warmup_ms')}` | rep_ms: `{meta.get('rep_ms')}`"
+ )
if rows:
parts.append("")
@@ -153,7 +172,9 @@ def summarize_one(path: str) -> str:
gm = _geomean(speeds)
if gm is not None:
parts.append("")
- parts.append(f"- geomean speedup vs Quack: `{gm:.3f}x` (over {len(speeds)} shapes)")
+ parts.append(
+ f"- geomean speedup vs Quack: `{gm:.3f}x` (over {len(speeds)} shapes)"
+ )
err_block = _summarize_error_stats(rows)
if err_block:
@@ -167,9 +188,21 @@ def summarize_one(path: str) -> str:
def main() -> None:
- p = argparse.ArgumentParser(description="Summarize KernelAgent-Oink benchmark JSONs into Markdown tables.")
- p.add_argument("--in-dir", type=str, required=True, help="Directory containing benchmark JSON files")
- p.add_argument("--out", type=str, default=None, help="Optional output markdown path (default: stdout)")
+ p = argparse.ArgumentParser(
+ description="Summarize KernelAgent-Oink benchmark JSONs into Markdown tables."
+ )
+ p.add_argument(
+ "--in-dir",
+ type=str,
+ required=True,
+ help="Directory containing benchmark JSON files",
+ )
+ p.add_argument(
+ "--out",
+ type=str,
+ default=None,
+ help="Optional output markdown path (default: stdout)",
+ )
args = p.parse_args()
in_dir = os.path.abspath(args.in_dir)
@@ -177,7 +210,9 @@ def main() -> None:
raise SystemExit(f"--in-dir is not a directory: {in_dir}")
json_paths = sorted(
- os.path.join(in_dir, name) for name in os.listdir(in_dir) if name.endswith(".json")
+ os.path.join(in_dir, name)
+ for name in os.listdir(in_dir)
+ if name.endswith(".json")
)
if not json_paths:
raise SystemExit(f"No .json files found under: {in_dir}")
diff --git a/oink/src/kernelagent_oink/blackwell/cross_entropy.py b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
index 94f052f..3e6eef1 100644
--- a/oink/src/kernelagent_oink/blackwell/cross_entropy.py
+++ b/oink/src/kernelagent_oink/blackwell/cross_entropy.py
@@ -103,9 +103,8 @@ def _convert_logits_2d(x: Tensor) -> cute.Tensor:
softmax and RMSNorm kernels.
"""
assert x.dim() == 2, "Input logits must be 2D (M, N)"
- return (
- from_dlpack(x.detach(), assumed_align=16)
- .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+ return from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0, stride_order=(0, 1)
)
@@ -136,7 +135,11 @@ def _calculate_threads_per_row(self) -> int:
else (
16
if N <= 128
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ else (
+ 32
+ if N <= 3072
+ else (64 if N <= 6144 else (128 if N <= 16384 else 256))
+ )
)
)
@@ -183,7 +186,9 @@ def __call__(
num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
num_threads = (
- cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
)
num_warps = num_threads // cute.arch.WARP_SIZE
kernel = (
@@ -267,7 +272,9 @@ def _kernel_impl(
cute.make_ordered_layout(tiler_mn, order=(1, 0)),
byte_alignment=16,
)
- reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout
+ )
# Copy setup: gmem -> smem via cp.async, 128-bit or narrower as needed.
num_copy_elems_X = tv_layout.shape[1][0]
@@ -277,7 +284,9 @@ def _kernel_impl(
gX.element_type,
num_bits_per_copy=num_copy_bits_X,
)
- thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_X = cute.make_tiled_copy(
+ copy_atom_load_X, tv_layout, tiler_mn
+ ).get_slice(tidx)
tXgX = thr_copy_X.partition_S(gX)
tXsX = thr_copy_X.partition_D(sX)
@@ -414,13 +423,21 @@ def _calculate_threads_per_row(self) -> int:
else (
16
if N <= 128
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ else (
+ 32
+ if N <= 3072
+ else (64 if N <= 6144 else (128 if N <= 16384 else 256))
+ )
)
)
- def _get_tv_layout(self, num_copy_bits: int = 128) -> tuple[cute.Shape, cute.Layout]:
+ def _get_tv_layout(
+ self, num_copy_bits: int = 128
+ ) -> tuple[cute.Shape, cute.Layout]:
vecsize = num_copy_bits // self.dtype.width
- assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
+ assert self.N % vecsize == 0, (
+ f"Input N {self.N} is not divisible by vector size {vecsize}"
+ )
N = min(self.N, 16384)
num_threads = 128 if N <= 16384 else 256
threads_per_row = self._calculate_threads_per_row()
@@ -452,7 +469,9 @@ def __call__(
num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits)
num_threads = (
- cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
)
# Broadcast (M,) tensors along the N dimension with stride 0.
mDLoss, mTarget, mLSE = [
@@ -564,8 +583,12 @@ def _kernel_impl(
gdX.element_type,
num_bits_per_copy=num_copy_bits_X,
)
- thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
- thr_copy_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_X = cute.make_tiled_copy(
+ copy_atom_load_X, tv_layout, tiler_mn
+ ).get_slice(tidx)
+ thr_copy_dX = cute.make_tiled_copy(
+ copy_atom_store_dX, tv_layout, tiler_mn
+ ).get_slice(tidx)
tXgX = thr_copy_X.partition_S(gX)
tXsX = thr_copy_X.partition_D(sX)
@@ -898,8 +921,14 @@ def _cross_entropy_forward_ptr_into(
assert logits.is_cuda and logits.dim() == 2
assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0]
assert target.dtype is torch.int64
- assert loss.is_cuda and loss.shape == (logits.shape[0],) and loss.dtype is torch.float32
- assert lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32
+ assert (
+ loss.is_cuda
+ and loss.shape == (logits.shape[0],)
+ and loss.dtype is torch.float32
+ )
+ assert (
+ lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32
+ )
M, N = logits.shape
device_index = logits.get_device()
@@ -991,10 +1020,18 @@ def _cross_entropy_backward_ptr_into(
assert logits.is_cuda and logits.dim() == 2
assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0]
assert target.dtype is torch.int64
- assert dloss.is_cuda and dloss.shape == (logits.shape[0],) and dloss.dtype is torch.float32
- assert lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32
+ assert (
+ dloss.is_cuda
+ and dloss.shape == (logits.shape[0],)
+ and dloss.dtype is torch.float32
+ )
+ assert (
+ lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32
+ )
assert dx.is_cuda and dx.shape == logits.shape and dx.dtype == logits.dtype
- assert dx.stride() == logits.stride(), "Pointer path expects dx to match logits strides"
+ assert dx.stride() == logits.stride(), (
+ "Pointer path expects dx to match logits strides"
+ )
M, N = logits.shape
device_index = logits.get_device()
@@ -1060,7 +1097,9 @@ def _cross_entropy_backward_ptr_into(
mem_space=rt.AddressSpace.gmem,
assumed_align=4,
)
- ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
ptr_lse = rt.make_ptr(
cutlass.Float32,
lse.data_ptr(),
@@ -1169,7 +1208,9 @@ def verify_cross_entropy_parity(
mask = torch.rand(M, device=device) < 0.1
target[mask] = ignore_index
- loss, lse = cross_entropy_forward(logits, target, ignore_index=ignore_index, reduction="none")
+ loss, lse = cross_entropy_forward(
+ logits, target, ignore_index=ignore_index, reduction="none"
+ )
logits_ref = logits.detach().clone().requires_grad_()
target_ref = target.detach().clone()
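The `cross_entropy.py` hunks above reflow asserts and calls around `cross_entropy_forward(..., reduction="none")`, which returns per-row `loss` and `lse`, plus a parity check against a PyTorch reference. A minimal dense sketch of the conventional per-row quantities, assuming ignored rows contribute zero loss (the kernel's exact `ignore_index`/`lse` convention is defined in the full file, not in these hunks):

```python
import torch

def cross_entropy_forward_ref(logits: torch.Tensor,
                              target: torch.Tensor,
                              ignore_index: int = -100):
    # Per-row loss = logsumexp(logits) - logits[target]; zero for ignored rows.
    lse = torch.logsumexp(logits.float(), dim=-1)                # (M,)
    rows = torch.arange(logits.shape[0], device=logits.device)
    picked = logits.float()[rows, target.clamp_min(0)]           # (M,)
    loss = lse - picked
    loss = torch.where(target == ignore_index, torch.zeros_like(loss), loss)
    return loss, lse
```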
diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py
index 0e4d640..67f67ce 100644
--- a/oink/src/kernelagent_oink/blackwell/layernorm.py
+++ b/oink/src/kernelagent_oink/blackwell/layernorm.py
@@ -138,7 +138,11 @@ def _calculate_threads_per_row(self) -> int:
else (
16
if N <= 128
- else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256)))
+ else (
+ 32
+ if N <= 3072
+ else (64 if N <= 6144 else (128 if N <= 16384 else 256))
+ )
)
)
@@ -186,7 +190,9 @@ def __call__(
self._set_cluster_n()
tiler_mn, tv_layout = self._get_tv_layout()
num_threads = (
- cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
)
num_warps = num_threads // cute.arch.WARP_SIZE
@@ -281,7 +287,9 @@ def launch_from_ptrs(
mX = cute.make_tensor(ptr_x, layout_mn)
mO = cute.make_tensor(ptr_out, layout_mn)
mW = cute.make_tensor(ptr_w, layout_n)
- mB = cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None
+ mB = (
+ cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None
+ )
mRstd = (
cute.make_tensor(ptr_rstd, layout_m)
if const_expr(ptr_rstd is not None)
@@ -323,15 +331,15 @@ def _kernel_impl(
cute.make_ordered_layout(tiler_mn, order=(1, 0)),
byte_alignment=16,
)
- reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout
+ )
shape = mX.shape
idX = cute.make_identity_tensor(shape)
# Slice for CTAs: use domain_offset_i64 to handle >2^31 elements.
- mX, mO = [
- domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)
- ]
+ mX, mO = [domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)]
gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)]
cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
gW = cute.local_tile(mW, tiler_mn, (0, cluster_y))
@@ -390,39 +398,23 @@ def _kernel_impl(
).get_slice(tidx)
tWgW = thr_copy_WB.partition_S(gW)
- tBgB = (
- thr_copy_WB.partition_S(gB)
- if const_expr(gB is not None)
- else None
- )
+ tBgB = thr_copy_WB.partition_S(gB) if const_expr(gB is not None) else None
tXgX = thr_copy_X.partition_S(gX)
tXsX = thr_copy_X.partition_D(sX)
tXgO = thr_copy_O.partition_D(gO)
tXrRstd = (
- thr_copy_O.partition_D(gRstd)
- if const_expr(mRstd is not None)
- else None
+ thr_copy_O.partition_D(gRstd) if const_expr(mRstd is not None) else None
)
tXrMean = (
- thr_copy_O.partition_D(gMean)
- if const_expr(mMean is not None)
- else None
+ thr_copy_O.partition_D(gMean) if const_expr(mMean is not None) else None
)
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None]
# Fragments for gmem->rmem.
tWrW = cute.make_fragment_like(tWgW)
- tBrB = (
- cute.make_fragment_like(tBgB)
- if const_expr(mB is not None)
- else None
- )
+ tBrB = cute.make_fragment_like(tBgB) if const_expr(mB is not None) else None
tXrW = thr_copy_X.retile(tWrW)
- tXrB = (
- thr_copy_X.retile(tBrB)
- if const_expr(mB is not None)
- else None
- )
+ tXrB = thr_copy_X.retile(tBrB) if const_expr(mB is not None) else None
tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)]
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
@@ -458,9 +450,7 @@ def _kernel_impl(
mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None,
init_val=0.0,
hook_fn=(
- cute.arch.cluster_wait
- if const_expr(self.cluster_n > 1)
- else None
+ cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None
),
)
mean = sum_x / shape[1]
@@ -486,10 +476,7 @@ def _kernel_impl(
if (
tXcX[0][1] == 0
and row < shape[0]
- and (
- self.cluster_n == 1
- or cute.arch.block_idx_in_cluster() == 0
- )
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
):
tXrRstd[0] = rstd
@@ -497,10 +484,7 @@ def _kernel_impl(
if (
tXcX[0][1] == 0
and row < shape[0]
- and (
- self.cluster_n == 1
- or cute.arch.block_idx_in_cluster() == 0
- )
+ and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0)
):
tXrMean[0] = mean
@@ -861,7 +845,9 @@ def _layernorm_forward_ptr_into(
)
_PTR_COMPILE_CACHE[key] = compiled
- ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_x = rt.make_ptr(
+ dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
ptr_out = rt.make_ptr(
dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
)
@@ -978,8 +964,12 @@ def _layernorm_backward_dx_kernel(
smem = cutlass.utils.SmemAllocator()
num_warps = const_expr(block_threads // cute.arch.WARP_SIZE)
warp_sums_layout = cute.make_layout((num_warps,), stride=(1,))
- warp_sums_wdy = smem.allocate_tensor(Float32, warp_sums_layout, byte_alignment=4)
- warp_sums_xhatwdy = smem.allocate_tensor(Float32, warp_sums_layout, byte_alignment=4)
+ warp_sums_wdy = smem.allocate_tensor(
+ Float32, warp_sums_layout, byte_alignment=4
+ )
+ warp_sums_xhatwdy = smem.allocate_tensor(
+ Float32, warp_sums_layout, byte_alignment=4
+ )
lane = cute.arch.lane_idx()
warp_idx = cute.arch.warp_idx()
@@ -1177,8 +1167,12 @@ def _layernorm_backward_dx_sm100(
alignment=16,
divisibility=128 // cutlass.Float32.width,
)
- mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
- mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
+ )
+ mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
+ )
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
key = (N, dtype)
@@ -1231,8 +1225,12 @@ def _layernorm_backward_params_sm100(
mX = _convert_row_major(x_2d)
mdO = _convert_row_major(dout_2d)
- mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
- mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
+ )
+ mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
+ )
mdW_partial = (
from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0)
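The `layernorm.py` backward hunks allocate `warp_sums_wdy` and `warp_sums_xhatwdy`, the two row reductions required by the standard LayerNorm input-gradient identity. A minimal dense sketch of that identity (reference only; the per-SM partial dW/dB accumulation handled by `_layernorm_backward_params_sm100` is a separate path):

```python
import torch

def layernorm_backward_dx_ref(x, dy, w, mean, rstd):
    # dx = rstd * (w*dy - mean(w*dy) - xhat * mean(xhat * w*dy)), row-wise.
    # The two means correspond to the warp_sums_wdy / warp_sums_xhatwdy buffers.
    xhat = (x.float() - mean[:, None]) * rstd[:, None]
    wdy = w.float() * dy.float()
    c1 = wdy.mean(dim=-1, keepdim=True)
    c2 = (xhat * wdy).mean(dim=-1, keepdim=True)
    dx = (wdy - c1 - xhat * c2) * rstd[:, None]
    return dx.to(x.dtype)
```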
diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py
index 590d773..e8ce93a 100644
--- a/oink/src/kernelagent_oink/blackwell/lite_quack.py
+++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py
@@ -85,6 +85,7 @@ def _cutlass_dsl_version() -> Optional[tuple[int, int, int]]:
# Tensor conversion helpers (from quack.utils)
# -------------------------
+
def convert_from_dlpack(
x: Tensor,
leading_dim: int,
@@ -108,7 +109,9 @@ def convert_from_dlpack(
@dsl_user_op
-def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
+def elem_pointer(
+ x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None
+) -> cute.Pointer:
return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
@@ -159,7 +162,9 @@ def store_shared_remote(
).ir_value()
if const_expr(isinstance(val, float)):
val = Float32(val)
- assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64"
+ assert isinstance(val, (Float32, Int32, cutlass.Int64)), (
+ "val must be Float32, Int32, or Int64"
+ )
suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)]
constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)]
llvm.inline_asm(
@@ -178,19 +183,27 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
# Only compute predicates for the "k" dimension. For the mn dimension, we will use "if".
tApA = cute.make_fragment(
cute.make_layout(
- (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
+ (
+ cute.size(tAcA, mode=[0, 1]),
+ cute.size(tAcA, mode=[1]),
+ cute.size(tAcA, mode=[2]),
+ ),
stride=(cute.size(tAcA, mode=[2]), 0, 1),
),
cutlass.Boolean,
)
for rest_v in cutlass.range_constexpr(tApA.shape[0]):
for rest_k in cutlass.range_constexpr(tApA.shape[2]):
- tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
+ tApA[rest_v, 0, rest_k] = cute.elem_less(
+ tAcA[(0, rest_v), 0, rest_k][1], limit
+ )
return tApA
@dsl_user_op
-def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor:
+def domain_offset_i64(
+ coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
+) -> cute.Tensor:
flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord))
flat_stride = cute.flatten_to_tuple(tensor.stride)
assert len(flat_coord_i64) == len(flat_stride), (
@@ -228,7 +241,9 @@ def coord_offset_i64(
@cute.jit
-def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cutlass.Numeric) -> None:
+def fill_oob(
+ tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cutlass.Numeric
+) -> None:
"""Fill out-of-bounds values in shared memory tensor."""
tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0])
tXrX_fill.fill(fill_value)
@@ -256,7 +271,9 @@ def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64:
)
vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2, loc=loc, ip=ip)
res = cutlass.Int64(
- vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
+ vector.extract(
+ vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip
+ )
)
return res
@@ -272,10 +289,14 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float
)
vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1, loc=loc, ip=ip)
res0 = Float32(
- vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)
+ vector.extract(
+ vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip
+ )
)
res1 = Float32(
- vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)
+ vector.extract(
+ vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip
+ )
)
return res0, res1
@@ -372,7 +393,9 @@ def block_or_cluster_reduce(
"""Perform either block or cluster reduction based on whether mbar_ptr is provided."""
if cutlass.const_expr(mbar_ptr is None):
return block_reduce(val, op, reduction_buffer, init_val=init_val)
- return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase)
+ return cluster_reduce(
+ val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase
+ )
@cute.jit
@@ -393,7 +416,9 @@ def row_reduce(
val = x
warp_op = {
cute.ReductionOp.ADD: operator.add,
- cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max,
+ cute.ReductionOp.MAX: cute.arch.fmax
+ if cutlass.const_expr(x.dtype == Float32)
+ else max,
cute.ReductionOp.MIN: min,
cute.ReductionOp.MUL: operator.mul,
}[op]
@@ -521,7 +546,9 @@ def online_softmax_reduce(
reduction_buffer[row_idx, lane_idx]
)
max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax)
- sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True)
+ sum_exp_x *= cute.math.exp(
+ max_x_single_warp - max_x_final, fastmath=True
+ )
sum_exp_x = warp_reduce(sum_exp_x, operator.add)
if cutlass.const_expr(return_exp_x):
exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True)
@@ -533,16 +560,23 @@ def online_softmax_reduce(
num_warps = rows_per_block * warps_per_row
cute.arch.mbarrier_arrive_and_expect_tx(
mbar_ptr,
- num_warps * cluster_n * reduction_buffer.element_type.width // 8,
+ num_warps
+ * cluster_n
+ * reduction_buffer.element_type.width
+ // 8,
)
if lane_idx < cluster_n:
store_shared_remote(
f32x2_to_i64(max_x, sum_exp_x),
- elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))),
+ elem_pointer(
+ reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))
+ ),
mbar_ptr,
peer_cta_rank_in_cluster=lane_idx,
)
- cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0)
+ cute.arch.mbarrier_wait(
+ mbar_ptr, phase=phase if phase is not None else 0
+ )
num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE)
max_x_single_warp = cute.make_fragment(num_iter, Float32)
max_x_single_warp.fill(-Float32.inf)
@@ -591,7 +625,9 @@ def get_copy_atom(
num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width))
copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp()
- return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits, loc=loc, ip=ip)
+ return cute.make_copy_atom(
+ copy_op, dtype, num_bits_per_copy=num_copy_bits, loc=loc, ip=ip
+ )
@dsl_user_op
@@ -606,7 +642,9 @@ def copy(
ip=None,
**kwargs,
) -> None:
- copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async, loc=loc, ip=ip)
+ copy_atom = get_copy_atom(
+ src.element_type, num_copy_elems, is_async, loc=loc, ip=ip
+ )
cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs)
@@ -637,15 +675,21 @@ def _set_cluster_n(self) -> None:
def _get_num_threads(self) -> int:
return 128 if self.N <= 16384 else 256
- def _get_tv_layout(self, num_copy_bits: int = 128) -> Tuple[cute.Shape, cute.Layout]:
+ def _get_tv_layout(
+ self, num_copy_bits: int = 128
+ ) -> Tuple[cute.Shape, cute.Layout]:
vecsize = num_copy_bits // self.dtype.width
- assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}"
+ assert self.N % vecsize == 0, (
+ f"Input N {self.N} is not divisible by vector size {vecsize}"
+ )
num_threads = self._get_num_threads()
assert num_threads % cute.arch.WARP_SIZE == 0
threads_per_row = self._calculate_threads_per_row()
self._set_cluster_n()
- num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n)
+ num_blocks_N = cute.ceil_div(
+ self.N // vecsize, threads_per_row * self.cluster_n
+ )
cols_per_block = num_threads // threads_per_row
tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row)
tv_layout = cute.make_layout(
@@ -660,11 +704,16 @@ def _get_tv_layout(self, num_copy_bits: int = 128) -> Tuple[cute.Shape, cute.Lay
def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int:
return (
cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))
- + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+ + self.stage
+ * num_warps
+ * self.cluster_n
+ * (self.reduction_dtype.width // 8)
+ self.stage * (cutlass.Int64.width // 8)
)
- def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int) -> cute.Layout:
+ def _get_reduction_buffer_layout(
+ self, tv_layout: cute.Layout, cluster_n: int
+ ) -> cute.Layout:
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1)
return cute.make_ordered_layout(
@@ -723,7 +772,9 @@ def __init__(self, dtype: cutlass.Numeric, N: int):
super().__init__(dtype, N, stage=2, reduction_dtype=Float32)
self.reload_wdy = None if N <= 16 * 1024 else "smem"
if self.N > 128 * 1024 and self.dtype.width >= 32:
- raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits")
+ raise ValueError(
+ "RMSNormBackward does not support N > 128k with dtype >= 32 bits"
+ )
def _get_num_threads(self) -> int:
return 128 if self.N <= 4096 else 256
@@ -736,7 +787,11 @@ def _calculate_threads_per_row(self) -> int:
else (
16
if N <= 128
- else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256)))
+ else (
+ 32
+ if N <= 256
+ else (64 if N <= 512 else (128 if N <= 4096 else 256))
+ )
)
)
@@ -745,7 +800,11 @@ def _set_cluster_n(self) -> None:
cluster_n = (
1
if N <= 8 * 1024
- else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)))
+ else (
+ 2
+ if N <= 16 * 1024
+ else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16))
+ )
)
self.cluster_n = cluster_n
@@ -755,7 +814,10 @@ def _smem_size_in_bytes(self, tiler_mn, num_warps: int, do_dtype=None) -> int:
return (
cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
+ cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2
- + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+ + self.stage
+ * num_warps
+ * self.cluster_n
+ * (self.reduction_dtype.width // 8)
+ self.stage * (cutlass.Int64.width // 8) * 2
)
@@ -783,7 +845,9 @@ def new_stride(t):
)
mX, mdO, mdResO, mdX, mdRes = [
- cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ cute.make_tensor(
+ t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))
+ )
if const_expr(t is not None)
else None
for t in (mX, mdO, mdResO, mdX, mdRes)
@@ -802,7 +866,9 @@ def new_stride(t):
num_copy_bits=128 // largest_dtype_width * mX.element_type.width
)
num_threads = (
- cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
)
num_warps = num_threads // cute.arch.WARP_SIZE
if const_expr(mW is not None):
@@ -814,7 +880,9 @@ def new_stride(t):
num_blocks = sm_count
kernel = (
- self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn)
+ self.kernel(
+ mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn
+ )
if _KERNEL_ACCEPTS_LAYOUT_ARGS
else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes)
)
@@ -822,7 +890,9 @@ def new_stride(t):
grid=[num_blocks, self.cluster_n, 1],
block=[num_threads, 1, 1],
cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
- smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
+ smem=self._smem_size_in_bytes(
+ tiler_mn, num_warps, do_dtype=mdO.element_type
+ ),
stream=stream,
)
@@ -856,7 +926,9 @@ def _kernel_impl(
idX = cute.make_identity_tensor(shape)
smem = cutlass.utils.SmemAllocator()
- smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2))
+ smem_layout = cute.make_ordered_layout(
+ (tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2)
+ )
sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16)
sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16)
reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
@@ -870,8 +942,12 @@ def _kernel_impl(
mbar_full_ptr, mbar_empty_ptr = None, None
num_copy_elems_X = tv_layout.shape[1][0]
- copy_atom_load_X = get_copy_atom(mX.element_type, num_copy_elems_X, is_async=False)
- thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx)
+ copy_atom_load_X = get_copy_atom(
+ mX.element_type, num_copy_elems_X, is_async=False
+ )
+ thr_copy_X = cute.make_tiled_copy(
+ copy_atom_load_X, tv_layout, tiler_mn
+ ).get_slice(tidx)
copy_fn = partial(copy, num_copy_elems=num_copy_elems_X)
gX, gdO, gdResO, gdX, gdRes, cX = [
@@ -898,7 +974,8 @@ def _kernel_impl(
tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None]
tXrX, tXrdO, tXrdX = [
- cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX)
+ cute.make_fragment_like(thr[None, None, None, 0])
+ for thr in (tXgX, tXgdO, tXgdX)
]
tXrdResO = None
if const_expr(mdResO is not None):
@@ -959,10 +1036,24 @@ def _kernel_impl(
for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim):
row = tXcX[None, None, None, bidx][0][0]
if row + gdim * tiler_mn[0] < M:
- tXgX_cur = coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0]
- tXgdO_cur = coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0]
- copy_fn(tXgX_cur, tXsX[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
- copy_fn(tXgdO_cur, tXsdO[None, None, None, stage ^ 1], pred=tXpX, is_async=True)
+ tXgX_cur = coord_offset_i64(bidx + gdim, tXgX, dim=3)[
+ None, None, None, 0
+ ]
+ tXgdO_cur = coord_offset_i64(bidx + gdim, tXgdO, dim=3)[
+ None, None, None, 0
+ ]
+ copy_fn(
+ tXgX_cur,
+ tXsX[None, None, None, stage ^ 1],
+ pred=tXpX,
+ is_async=True,
+ )
+ copy_fn(
+ tXgdO_cur,
+ tXsdO[None, None, None, stage ^ 1],
+ pred=tXpX,
+ is_async=True,
+ )
elif tiler_mn[0] > 1:
fill_oob(
tXsX[None, None, None, stage ^ 1],
@@ -979,7 +1070,9 @@ def _kernel_impl(
if row < M or tiler_mn[0] == 1:
rstd_val = mRstd[row]
if const_expr(mdResO is not None):
- tXgdResO_cur = coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0]
+ tXgdResO_cur = coord_offset_i64(bidx, tXgdResO, dim=3)[
+ None, None, None, 0
+ ]
if row < M or tiler_mn[0] == 1:
copy_fn(tXgdResO_cur, tXrdResO, pred=tXpX)
elif tiler_mn[0] > 1:
@@ -1036,7 +1129,9 @@ def _kernel_impl(
copy_fn(tXrdX, tXgdX_cur, pred=tXpX)
if const_expr(mdRes is not None):
tXrdRes.store(dx.to(tXrdRes.element_type))
- tXgdRes_cur = coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0]
+ tXgdRes_cur = coord_offset_i64(bidx, tXgdRes, dim=3)[
+ None, None, None, 0
+ ]
if row < M or tiler_mn[0] == 1:
copy_fn(tXrdRes, tXgdRes_cur, pred=tXpX)
if const_expr(mdW is not None):
@@ -1204,7 +1299,9 @@ def get_sm_count(
num_sms = props.multi_processor_count
sm_count_multiple = (
- 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
+ 16
+ if N <= 256
+ else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1)))
)
sm_count = num_sms
if N <= 8192:
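The `_smem_size_in_bytes` helpers reformatted above in `lite_quack.py` sum three terms: the data tile(s), one reduction slot per (warp, cluster CTA) per stage, and one 8-byte mbarrier per stage. A small standalone sketch of the same arithmetic (helper name hypothetical), with an illustrative bf16 tile:

```python
def smem_bytes(tile_elems: int, dtype_bytes: int, stage: int,
               num_warps: int, cluster_n: int, red_dtype_bytes: int = 4) -> int:
    # tile storage + per-(warp, cluster CTA) fp32 reduction slots per stage
    # + one Int64 mbarrier per stage, mirroring the forward-path formula above.
    return (tile_elems * dtype_bytes
            + stage * num_warps * cluster_n * red_dtype_bytes
            + stage * 8)

# Illustrative: a (4, 4096) bf16 tile, 1 stage, 4 warps, no cluster.
print(smem_bytes(4 * 4096, 2, stage=1, num_warps=4, cluster_n=1))  # 32792
```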
diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
index 9df9f16..e921947 100644
--- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py
+++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py
@@ -121,7 +121,9 @@ def _env_flag(name: str, default: bool) -> bool:
# - If you want to force stage-2 even when the pointer path is available (for
# experimentation / A-B testing), set this env var **before** importing this
# module.
-_FORCE_RMSNORM_STAGE2_FWD = _env_flag("KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2", default=False)
+_FORCE_RMSNORM_STAGE2_FWD = _env_flag(
+ "KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2", default=False
+)
# CuTeDSL stability probe for the experimental cluster_n>1 + direct-GMEM schedule.
#
@@ -2771,7 +2773,9 @@ def rmsnorm_forward(
# Preserve stride contracts for torch.compile consistency, even
# when using the optional stage-2 implementation.
if y.stride() != x.stride():
- y_strided = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
+ y_strided = torch.empty_strided(
+ x.shape, x.stride(), device=x.device, dtype=x.dtype
+ )
y_strided.copy_(y)
y = y_strided
if residual is not None and residual_out is not None:
@@ -3036,7 +3040,9 @@ def new_stride(t):
)
mX, mdO, mdResO, mdX, mdRes = [
- cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ cute.make_tensor(
+ t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))
+ )
if const_expr(t is not None)
else None
for t in (mX, mdO, mdResO, mdX, mdRes)
@@ -3056,7 +3062,9 @@ def new_stride(t):
num_copy_bits=128 // largest_dtype_width * mX.element_type.width
)
num_threads = (
- cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
)
num_warps = num_threads // cute.arch.WARP_SIZE
if const_expr(mW is not None):
@@ -3067,7 +3075,9 @@ def new_stride(t):
num_blocks = sm_count
kernel = (
- self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn)
+ self.kernel(
+ mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn
+ )
if _KERNEL_ACCEPTS_LAYOUT_ARGS
else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes)
)
@@ -3075,7 +3085,9 @@ def new_stride(t):
grid=[num_blocks, self.cluster_n, 1],
block=[num_threads, 1, 1],
cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None,
- smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type),
+ smem=self._smem_size_in_bytes(
+ tiler_mn, num_warps, do_dtype=mdO.element_type
+ ),
stream=stream,
)
@@ -3160,8 +3172,8 @@ def _convert_mx(t: Tensor) -> cute.Tensor:
if db_partial is not None
else None
)
- rstd_tensor = (
- from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
)
current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
@@ -3261,7 +3273,9 @@ def rmsnorm_backward(
else:
dw_partial = None
db_partial = (
- torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None
+ torch.empty(sm_count, N, device=device, dtype=torch.float32)
+ if has_bias
+ else None
)
_rmsnorm_bwd_sm100(
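The `rmsnorm_forward` hunk above copies a stage-2 result into an `empty_strided` buffer whenever its strides diverge from the input's, so the padded-row stride contract survives for `torch.compile`. The same contract as a standalone helper (name hypothetical), with a padded-row example:

```python
import torch

def restride_like(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # If the produced tensor does not match x's (possibly padded-row) strides,
    # materialize one that does, mirroring the copy in the hunk above.
    if y.stride() == x.stride():
        return y
    y_strided = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
    y_strided.copy_(y)
    return y_strided

x = torch.randn(4, 128)[:, :100]           # padded-row view: stride (128, 1)
y = restride_like(torch.randn(4, 100), x)  # contiguous result gets re-strided
assert y.stride() == x.stride()
```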
diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py
index fec5bf4..2b5b36d 100644
--- a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py
+++ b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py
@@ -101,7 +101,9 @@ def copy_tiled(
class RMSNormSM100WithStage2:
- def __init__(self, N: int, dtype: type[cutlass.Numeric], stage: Optional[int] = None):
+ def __init__(
+ self, N: int, dtype: type[cutlass.Numeric], stage: Optional[int] = None
+ ):
self.N = N
self.dtype = dtype
self.stage = 1 if stage is None else stage
@@ -172,7 +174,10 @@ def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout]
tiler_mn = (cols_per_block, vecsize * num_blocks_N * tpr)
tv_layout = cute.make_layout(
((tpr, cols_per_block), (vecsize, num_blocks_N)),
- stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * tpr)),
+ stride=(
+ (vecsize * cols_per_block, 1),
+ (cols_per_block, cols_per_block * vecsize * tpr),
+ ),
)
return tiler_mn, tv_layout
@@ -198,7 +203,9 @@ def new_stride(t):
)
mX, mRes, mO, mResO = [
- cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)))
+ cute.make_tensor(
+ t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))
+ )
if const_expr(t is not None)
else None
for t in (mX, mRes, mO, mResO)
@@ -209,36 +216,48 @@ def new_stride(t):
copy_bits = const_expr(128)
tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits)
num_threads = (
- cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._num_threads()
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._num_threads()
)
num_warps = num_threads // cute.arch.WARP_SIZE
threads_per_row = (
- tv_layout.shape[0][0] if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._threads_per_row()
+ tv_layout.shape[0][0]
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._threads_per_row()
)
warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1)
cluster_n = self._cluster_n()
if const_expr(mW is not None):
mW = cute.make_tensor(
- mW.iterator, cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ mW.iterator,
+ cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
)
if const_expr(mB is not None):
mB = cute.make_tensor(
- mB.iterator, cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,)))
+ mB.iterator,
+ cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))),
)
if const_expr(mRstd is not None):
mRstd = cute.make_tensor(
- mRstd.iterator, cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,)))
+ mRstd.iterator,
+ cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))),
)
stage_bufs = 2 if self.stage > 1 else 1
- tile_bytes_x = cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs
+ tile_bytes_x = (
+ cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs
+ )
tile_bytes_res = (
- cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) * stage_bufs
+ cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn))
+ * stage_bufs
if const_expr(mRes is not None)
else 0
)
- red_bytes = self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8)
+ red_bytes = (
+ self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8)
+ )
mbar_bytes = self.stage * (cutlass.Int64.width // 8)
smem_bytes = tile_bytes_x + tile_bytes_res + red_bytes + mbar_bytes
@@ -299,11 +318,15 @@ def _kernel_impl(
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
cluster_n = self._cluster_n()
- cluster_y = const_expr(0) if const_expr(cluster_n == 1) else cute.arch.block_idx()[1]
+ cluster_y = (
+ const_expr(0) if const_expr(cluster_n == 1) else cute.arch.block_idx()[1]
+ )
smem = cutlass.utils.SmemAllocator()
sX0 = smem.allocate_tensor(
- mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32
+ mX.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
)
sX1 = (
smem.allocate_tensor(
@@ -316,7 +339,9 @@ def _kernel_impl(
)
sRes0 = (
smem.allocate_tensor(
- mRes.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32
+ mRes.element_type,
+ cute.make_ordered_layout(tiler_mn, order=(1, 0)),
+ byte_alignment=32,
)
if const_expr(mRes is not None)
else None
@@ -331,18 +356,24 @@ def _kernel_impl(
else None
)
- reduction_buffer, mbar_ptr = self._alloc_reduction_and_mbar(smem, num_warps, warps_per_row)
+ reduction_buffer, mbar_ptr = self._alloc_reduction_and_mbar(
+ smem, num_warps, warps_per_row
+ )
shape = mX.shape
idX = cute.make_identity_tensor(shape)
num_copy_elems_X = tv_layout.shape[1][0]
use_async = const_expr(self.N >= 1024)
- copy_atom = get_copy_atom_bw(mX.element_type, num_copy_elems_X, is_async=use_async)
+ copy_atom = get_copy_atom_bw(
+ mX.element_type, num_copy_elems_X, is_async=use_async
+ )
thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx)
gW, gB = [
- cute.local_tile(t, tiler_mn, (0, cluster_y)) if const_expr(t is not None) else None
+ cute.local_tile(t, tiler_mn, (0, cluster_y))
+ if const_expr(t is not None)
+ else None
for t in (mW, mB)
]
tXgW = thr_copy.partition_S(gW) if const_expr(mW is not None) else None
@@ -350,32 +381,50 @@ def _kernel_impl(
tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None
tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None
if const_expr(mW is not None):
- cute.copy(get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False), tXgW, tXrW)
+ cute.copy(
+ get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False),
+ tXgW,
+ tXrW,
+ )
if const_expr(mB is not None):
- cute.copy(get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False), tXgB, tXrB)
+ cute.copy(
+ get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False),
+ tXgB,
+ tXrB,
+ )
self._init_cluster(tidx, mbar_ptr)
mX_i, mRes_i, mO_i, mResO_i = [
- qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) if t is not None else None
+ qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t)
+ if t is not None
+ else None
for t in (mX, mRes, mO, mResO)
]
gX_i = cute.local_tile(mX_i, tiler_mn, (0, cluster_y))
gO_i = cute.local_tile(mO_i, tiler_mn, (0, cluster_y))
gRes_i = (
- cute.local_tile(mRes_i, tiler_mn, (0, cluster_y)) if const_expr(mRes is not None) else None
+ cute.local_tile(mRes_i, tiler_mn, (0, cluster_y))
+ if const_expr(mRes is not None)
+ else None
)
gResO_i = (
- cute.local_tile(mResO_i, tiler_mn, (0, cluster_y)) if const_expr(mResO is not None) else None
+ cute.local_tile(mResO_i, tiler_mn, (0, cluster_y))
+ if const_expr(mResO is not None)
+ else None
)
gRstd_i = (
- cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y)) if const_expr(mRstd is not None) else None
+ cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y))
+ if const_expr(mRstd is not None)
+ else None
)
cX_i = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None]
row_i = tXcX_i[0][0]
- tXgRstd_i = thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None
+ tXgRstd_i = (
+ thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None
+ )
# Intra-row K-loop cp.async ping-pong (two-pass) for N≈6k/8k (stage=2)
if const_expr(self.stage > 1 and (shape[1] == 6144 or shape[1] == 8192)):
@@ -388,37 +437,72 @@ def _kernel_impl(
tiler_mn_tile = (tiler_mn[0], tile_n)
sX0_tile = cute.local_tile(sX0, tiler_mn_tile, (0, 0))
- sX1_tile = cute.local_tile(sX1, tiler_mn_tile, (0, 0)) if const_expr(self.stage > 1) else None
+ sX1_tile = (
+ cute.local_tile(sX1, tiler_mn_tile, (0, 0))
+ if const_expr(self.stage > 1)
+ else None
+ )
sRes0_tile = (
- cute.local_tile(sRes0, tiler_mn_tile, (0, 0)) if const_expr(mRes is not None) else None
+ cute.local_tile(sRes0, tiler_mn_tile, (0, 0))
+ if const_expr(mRes is not None)
+ else None
)
sRes1_tile = (
- cute.local_tile(sRes1, tiler_mn_tile, (0, 0)) if const_expr(mRes is not None and self.stage > 1) else None
+ cute.local_tile(sRes1, tiler_mn_tile, (0, 0))
+ if const_expr(mRes is not None and self.stage > 1)
+ else None
)
tv_layout_tile = cute.make_layout(
((tpr, tiler_mn[0]), (vecsize, tile_factor)),
- stride=((vecsize * tiler_mn[0], 1), (tiler_mn[0], tiler_mn[0] * vecsize * tpr)),
+ stride=(
+ (vecsize * tiler_mn[0], 1),
+ (tiler_mn[0], tiler_mn[0] * vecsize * tpr),
+ ),
)
- thr_copy_tile = cute.make_tiled_copy(copy_atom, tv_layout_tile, tiler_mn_tile).get_slice(tidx)
+ thr_copy_tile = cute.make_tiled_copy(
+ copy_atom, tv_layout_tile, tiler_mn_tile
+ ).get_slice(tidx)
sum_sq_acc = cute.Float32(0.0)
k_off0 = const_expr(0) * tile_n
- gX_0 = cute.local_tile(qutils.domain_offset_i64((0, k_off0), mX_i), tiler_mn_tile, (0, cluster_y))
+ gX_0 = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off0), mX_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tXgX_0 = thr_copy_tile.partition_S(gX_0)
tXsX_0 = thr_copy_tile.partition_D(sX0_tile)
- cX_0 = cute.local_tile(cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, cluster_y))
+ cX_0 = cute.local_tile(
+ cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, cluster_y)
+ )
tXc_0 = thr_copy_tile.partition_S(cX_0)
tXp_0 = qutils.predicate_k(tXc_0, limit=shape[1])
tXp_ping = tXp_0
tXp_pong = tXp_0
if row_i < shape[0]:
- copy_tiled(tXgX_0, tXsX_0, num_copy_elems=vecsize, is_async=use_async, pred=tXp_0)
+ copy_tiled(
+ tXgX_0,
+ tXsX_0,
+ num_copy_elems=vecsize,
+ is_async=use_async,
+ pred=tXp_0,
+ )
if const_expr(mRes is not None):
- gRes_0 = cute.local_tile(qutils.domain_offset_i64((0, k_off0), mRes_i), tiler_mn_tile, (0, cluster_y))
+ gRes_0 = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off0), mRes_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tXgRes_0 = thr_copy_tile.partition_S(gRes_0)
tXsRes_0 = thr_copy_tile.partition_D(sRes0_tile)
- copy_tiled(tXgRes_0, tXsRes_0, num_copy_elems=vecsize, is_async=use_async, pred=tXp_0)
+ copy_tiled(
+ tXgRes_0,
+ tXsRes_0,
+ num_copy_elems=vecsize,
+ is_async=use_async,
+ pred=tXp_0,
+ )
if const_expr(use_async):
cute.arch.cp_async_commit_group()
@@ -426,29 +510,57 @@ def _kernel_impl(
next_t = t + 1
if next_t < num_tiles:
k_off_n = next_t * tile_n
- gX_n = cute.local_tile(qutils.domain_offset_i64((0, k_off_n), mX_i), tiler_mn_tile, (0, cluster_y))
+ gX_n = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off_n), mX_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tXgX_n = thr_copy_tile.partition_S(gX_n)
- cX_n = cute.local_tile(cute.domain_offset((0, k_off_n), cX_i), tiler_mn_tile, (0, cluster_y))
+ cX_n = cute.local_tile(
+ cute.domain_offset((0, k_off_n), cX_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tXc_n = thr_copy_tile.partition_S(cX_n)
tXp_n = qutils.predicate_k(tXc_n, limit=shape[1])
if const_expr((t % 2) == 0):
tXsX_n = thr_copy_tile.partition_D(sX1_tile)
tXsRes_n = (
- thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
)
tXp_pong = tXp_n
else:
tXsX_n = thr_copy_tile.partition_D(sX0_tile)
tXsRes_n = (
- thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
)
tXp_ping = tXp_n
if row_i < shape[0]:
- copy_tiled(tXgX_n, tXsX_n, num_copy_elems=vecsize, is_async=use_async, pred=tXp_n)
+ copy_tiled(
+ tXgX_n,
+ tXsX_n,
+ num_copy_elems=vecsize,
+ is_async=use_async,
+ pred=tXp_n,
+ )
if const_expr(mRes is not None):
- gRes_n = cute.local_tile(qutils.domain_offset_i64((0, k_off_n), mRes_i), tiler_mn_tile, (0, cluster_y))
+ gRes_n = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off_n), mRes_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tXgRes_n = thr_copy_tile.partition_S(gRes_n)
- copy_tiled(tXgRes_n, tXsRes_n, num_copy_elems=vecsize, is_async=use_async, pred=tXp_n)
+ copy_tiled(
+ tXgRes_n,
+ tXsRes_n,
+ num_copy_elems=vecsize,
+ is_async=use_async,
+ pred=tXp_n,
+ )
if const_expr(use_async):
cute.arch.cp_async_commit_group()
if const_expr(use_async):
@@ -456,36 +568,62 @@ def _kernel_impl(
if const_expr((t % 2) == 0):
tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
- tXsRes_cur = thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
pred_cur = tXp_ping
else:
tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
- tXsRes_cur = thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ tXsRes_cur = (
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
+ )
pred_cur = tXp_pong
qutils.fill_oob(tXsX_cur, pred_cur, mX.element_type.zero)
if const_expr(mRes is not None):
qutils.fill_oob(tXsRes_cur, pred_cur, mRes.element_type.zero)
k_off = t * tile_n
- gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, cluster_y))
+ gX_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mX_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tXgX_t = thr_copy_tile.partition_S(gX_t)
tXrX = cute.make_fragment_like(tXgX_t)
cute.autovec_copy(tXsX_cur, tXrX)
x = tXrX.load().to(cute.Float32)
if const_expr(mRes is not None):
- gRes_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, cluster_y))
+ gRes_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mRes_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tXgRes_t = thr_copy_tile.partition_S(gRes_t)
tXrRes = cute.make_fragment_like(tXgRes_t)
cute.autovec_copy(tXsRes_cur, tXrRes)
x += tXrRes.load().to(cute.Float32)
if const_expr(mResO is not None):
- gResO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mResO_i), tiler_mn_tile, (0, cluster_y))
+ gResO_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mResO_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tXgResO_t = thr_copy_tile.partition_D(gResO_t)
tXrResO = cute.make_fragment_like(tXgResO_t)
tXrResO.store(x.to(tXrResO.element_type))
if row_i < shape[0]:
- copy_tiled(tXrResO, tXgResO_t, num_copy_elems=vecsize, is_async=False, pred=pred_cur)
+ copy_tiled(
+ tXrResO,
+ tXgResO_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=pred_cur,
+ )
sum_sq_tile = row_reduce(
x * x,
@@ -494,7 +632,9 @@ def _kernel_impl(
reduction_buffer[None, None, 0],
mbar_ptr,
init_val=0.0,
- hook_fn=(cute.arch.cluster_wait if const_expr(cluster_n > 1) else None),
+ hook_fn=(
+ cute.arch.cluster_wait if const_expr(cluster_n > 1) else None
+ ),
)
sum_sq_acc = sum_sq_acc + sum_sq_tile
@@ -509,32 +649,46 @@ def _kernel_impl(
for t in cutlass.range_constexpr(num_tiles):
k_off = t * tile_n
- cX_t = cute.local_tile(cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, cluster_y))
+ cX_t = cute.local_tile(
+ cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, cluster_y)
+ )
tXc_t = thr_copy_tile.partition_S(cX_t)
tXp_t = qutils.predicate_k(tXc_t, limit=shape[1])
if const_expr((t % 2) == 0):
tXsX_cur = thr_copy_tile.partition_D(sX0_tile)
tXsRes_cur = (
- thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None
+ thr_copy_tile.partition_D(sRes0_tile)
+ if const_expr(mRes is not None)
+ else None
)
else:
tXsX_cur = thr_copy_tile.partition_D(sX1_tile)
tXsRes_cur = (
- thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None
+ thr_copy_tile.partition_D(sRes1_tile)
+ if const_expr(mRes is not None)
+ else None
)
qutils.fill_oob(tXsX_cur, tXp_t, mX.element_type.zero)
if const_expr(mRes is not None):
qutils.fill_oob(tXsRes_cur, tXp_t, mRes.element_type.zero)
- gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, cluster_y))
+ gX_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mX_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tXgX_t = thr_copy_tile.partition_S(gX_t)
tXrX = cute.make_fragment_like(tXgX_t)
cute.autovec_copy(tXsX_cur, tXrX)
x = tXrX.load().to(cute.Float32)
if const_expr(mRes is not None):
- gRes_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, cluster_y))
+ gRes_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mRes_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tXgRes_t = thr_copy_tile.partition_S(gRes_t)
tXrRes = cute.make_fragment_like(tXgRes_t)
cute.autovec_copy(tXsRes_cur, tXrRes)
@@ -542,35 +696,67 @@ def _kernel_impl(
y = x * rstd
if const_expr(mW is not None):
- gW_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mW), tiler_mn_tile, (0, cluster_y))
+ gW_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mW),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tWgW_t = thr_copy_tile.partition_S(gW_t)
tWrW_t = cute.make_fragment_like(tWgW_t)
- copy_tiled(tWgW_t, tWrW_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+ copy_tiled(
+ tWgW_t,
+ tWrW_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=tXp_t,
+ )
y = y * tWrW_t.load().to(cute.Float32)
if const_expr(mB is not None):
- gB_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mB), tiler_mn_tile, (0, cluster_y))
+ gB_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mB),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tWgB_t = thr_copy_tile.partition_S(gB_t)
tWrB_t = cute.make_fragment_like(tWgB_t)
- copy_tiled(tWgB_t, tWrB_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+ copy_tiled(
+ tWgB_t,
+ tWrB_t,
+ num_copy_elems=vecsize,
+ is_async=False,
+ pred=tXp_t,
+ )
y = y + tWrB_t.load().to(cute.Float32)
- gO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mO_i), tiler_mn_tile, (0, cluster_y))
+ gO_t = cute.local_tile(
+ qutils.domain_offset_i64((0, k_off), mO_i),
+ tiler_mn_tile,
+ (0, cluster_y),
+ )
tXgO_t = thr_copy_tile.partition_D(gO_t)
tXrO = cute.make_fragment_like(tXgO_t)
tXrO.store(y.to(tXrO.element_type))
if row_i < shape[0]:
- copy_tiled(tXrO, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t)
+ copy_tiled(
+ tXrO, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t
+ )
return
# Fallback: single-stage path identical to current rmsnorm.py
tXgX_i = thr_copy.partition_S(gX_i)
- tXgRes_i = thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None
+ tXgRes_i = (
+ thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None
+ )
tXgO_i = thr_copy.partition_D(gO_i)
- tXgResO_i = thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None
+ tXgResO_i = (
+ thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None
+ )
is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n)
tXpX_i = (
- qutils.predicate_k(thr_copy.partition_S(cX_i), limit=shape[1]) if not is_even_N_i else None
+ qutils.predicate_k(thr_copy.partition_S(cX_i), limit=shape[1])
+ if not is_even_N_i
+ else None
)
if row_i < shape[0]:
@@ -594,7 +780,9 @@ def _kernel_impl(
tXrResO.store(x.to(tXrResO.element_type))
if row_i < shape[0]:
cute.copy(
- get_copy_atom_bw(tXrResO.element_type, num_copy_elems_X, is_async=False),
+ get_copy_atom_bw(
+ tXrResO.element_type, num_copy_elems_X, is_async=False
+ ),
tXrResO,
tXgResO_i,
)
@@ -715,7 +903,9 @@ def _alloc_reduction_and_mbar(
(num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage),
order=(1, 0, 2),
)
- reduction_buffer = smem.allocate_tensor(self.reduction_dtype, red_layout, byte_alignment=4)
+ reduction_buffer = smem.allocate_tensor(
+ self.reduction_dtype, red_layout, byte_alignment=4
+ )
if const_expr(cluster_n > 1):
mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage)
else:
@@ -745,9 +935,9 @@ def rmsnorm_forward_with_stage2(
dtype = TORCH2CUTE_DTYPE[x.dtype]
def _convert_x(t: Tensor) -> cute.Tensor:
- return from_dlpack(
- t.detach(), assumed_align=32
- ).mark_layout_dynamic(leading_dim=1)
+ return from_dlpack(t.detach(), assumed_align=32).mark_layout_dynamic(
+ leading_dim=1
+ )
mX = _convert_x(x)
mRes = _convert_x(residual) if residual is not None else None
@@ -755,7 +945,9 @@ def _convert_x(t: Tensor) -> cute.Tensor:
mO = from_dlpack(out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1)
mW = (
- from_dlpack(weight.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=0)
+ from_dlpack(weight.detach(), assumed_align=32).mark_layout_dynamic(
+ leading_dim=0
+ )
if weight is not None
else None
)
@@ -766,7 +958,9 @@ def _convert_x(t: Tensor) -> cute.Tensor:
)
if store_rstd:
rstd = torch.empty(M, device=x.device, dtype=torch.float32)
- mRstd = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0)
+ mRstd = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(
+ leading_dim=0
+ )
else:
rstd = None
mRstd = None
@@ -775,7 +969,9 @@ def _convert_x(t: Tensor) -> cute.Tensor:
mResO = None
if residual is not None:
residual_out = torch.empty_like(residual)
- mResO = from_dlpack(residual_out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1)
+ mResO = from_dlpack(
+ residual_out.detach(), assumed_align=32
+ ).mark_layout_dynamic(leading_dim=1)
# Enable the intra-row cp.async K-loop only for DSv3-style large-N rows
# with very large M, where there is enough work per row to amortize the
@@ -788,7 +984,7 @@ def _convert_x(t: Tensor) -> cute.Tensor:
op._tpr_override = 128 # type: ignore[attr-defined]
# Prefer 1 row/CTA at N=6144; keep 2 rows/CTA at N=8192 to match
# the original tuning there.
- op._nt_override = (128 if N == 6144 else 256) # type: ignore[attr-defined]
+ op._nt_override = 128 if N == 6144 else 256 # type: ignore[attr-defined]
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
key = (
@@ -803,7 +999,9 @@ def _convert_x(t: Tensor) -> cute.Tensor:
)
compiled = _COMPILE_CACHE.get(key)
if compiled is None:
- compiled = cute.compile(op, mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps))
+ compiled = cute.compile(
+ op, mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps)
+ )
_COMPILE_CACHE[key] = compiled
compiled(mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps))
return out, rstd, residual_out
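`rmsnorm_forward_with_stage2` returns `(out, rstd, residual_out)`: the residual is added to `x` before the row reduction, and that pre-normalization sum is what the kernel writes to `residual_out`. A minimal dense reference of that fused contract, assuming the usual RMSNorm definition with fp32 accumulation and a per-row fp32 `rstd`:

```python
import torch

def fused_add_rmsnorm_ref(x, weight, residual=None, eps: float = 1e-6):
    # z = x + residual is both the value that gets normalized and the tensor
    # returned as residual_out; rstd is kept in fp32 per row.
    z = x.float() + (residual.float() if residual is not None else 0.0)
    rstd = torch.rsqrt(z.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = (z * rstd) * weight.float()
    residual_out = z.to(x.dtype) if residual is not None else None
    return out.to(x.dtype), rstd.squeeze(-1), residual_out
```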
diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py
index a8a2791..6a7eb54 100644
--- a/oink/src/kernelagent_oink/blackwell/softmax.py
+++ b/oink/src/kernelagent_oink/blackwell/softmax.py
@@ -134,7 +134,9 @@ def __call__(self, mX: cute.Tensor, mO: cute.Tensor, stream: cuda.CUstream) -> N
# Use the generic ReductionBase tiling with 128-bit vectorization.
tiler_mn, tv_layout = self._get_tv_layout()
num_threads = (
- cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
)
num_warps = num_threads // cute.arch.WARP_SIZE
kernel = (
@@ -201,7 +203,9 @@ def _kernel_impl(
cute.make_ordered_layout(tiler_mn, order=(1, 0)),
byte_alignment=16,
)
- reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout
+ )
# Copy atoms for gmem <-> smem and smem <-> gmem.
# Use 128-bit cp.async for global->shared and 128-bit vectorized stores.
@@ -216,8 +220,12 @@ def _kernel_impl(
num_bits_per_copy=128,
)
- thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx)
- thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_load = cute.make_tiled_copy(
+ copy_atom_load, tv_layout, tiler_mn
+ ).get_slice(tidx)
+ thr_copy_store = cute.make_tiled_copy(
+ copy_atom_store, tv_layout, tiler_mn
+ ).get_slice(tidx)
tXgX = thr_copy_load.partition_S(gX)
tXsX = thr_copy_load.partition_D(sX)
@@ -349,7 +357,10 @@ def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int:
# Store both y and dy tiles plus reduction buffers and mbarriers.
return (
cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2
- + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8)
+ + self.stage
+ * num_warps
+ * self.cluster_n
+ * (self.reduction_dtype.width // 8)
+ self.stage * (cutlass.Int64.width // 8)
)
@@ -367,7 +378,9 @@ def __call__(
# Use the generic ReductionBase tiling with 128-bit vectorization.
tiler_mn, tv_layout = self._get_tv_layout()
num_threads = (
- cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads()
+ cute.size(tv_layout, mode=[0])
+ if _KERNEL_ACCEPTS_LAYOUT_ARGS
+ else self._get_num_threads()
)
num_warps = num_threads // cute.arch.WARP_SIZE
kernel = (
@@ -423,7 +436,9 @@ def _kernel_impl(
mdY, mY, mdX = [
domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX)
]
- gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)]
+ gdY, gY, gdX = [
+ cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)
+ ]
cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y))
smem = cutlass.utils.SmemAllocator()
@@ -437,7 +452,9 @@ def _kernel_impl(
cute.make_ordered_layout(tiler_mn, order=(1, 0)),
byte_alignment=16,
)
- reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout)
+ reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(
+ smem, tv_layout
+ )
copy_atom_load = cute.make_copy_atom(
cute.nvgpu.cpasync.CopyG2SOp(),
@@ -450,8 +467,12 @@ def _kernel_impl(
num_bits_per_copy=128,
)
- thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx)
- thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx)
+ thr_copy_load = cute.make_tiled_copy(
+ copy_atom_load, tv_layout, tiler_mn
+ ).get_slice(tidx)
+ thr_copy_store = cute.make_tiled_copy(
+ copy_atom_store, tv_layout, tiler_mn
+ ).get_slice(tidx)
tdYgdY = thr_copy_load.partition_S(gdY)
tdYsdY = thr_copy_load.partition_D(sdY)
@@ -460,7 +481,9 @@ def _kernel_impl(
tdXgdX = thr_copy_store.partition_D(gdX)
tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None]
- tdYrdY, tYrY, tdXrdX = [cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)]
+ tdYrdY, tYrY, tdXrdX = [
+ cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)
+ ]
num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE
self._initialize_cluster(tidx, mbar_ptr, num_warps)
@@ -535,9 +558,8 @@ def _convert_2d_tensor(x: Tensor) -> cute.Tensor:
# the shape compact with row-major stride order (0, 1), with mode=0 (batch).
# We intentionally do not call mark_layout_dynamic here to avoid the
# leading_dim stride==1 constraint used in RMSNorm.
- return (
- from_dlpack(x.detach(), assumed_align=16)
- .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1))
+ return from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic(
+ mode=0, stride_order=(0, 1)
)
@@ -581,7 +603,9 @@ def _softmax_forward_ptr_into(*, x: Tensor, out: Tensor) -> None:
compiled = _PTR_FWD_COMPILE_CACHE.get(key)
if compiled is None:
op = SoftmaxFwdSM100(dtype_x, int(N))
- ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_x = rt.make_ptr(
+ dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
ptr_out = rt.make_ptr(
dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
)
@@ -596,8 +620,12 @@ def _softmax_forward_ptr_into(*, x: Tensor, out: Tensor) -> None:
)
_PTR_FWD_COMPILE_CACHE[key] = compiled
- ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
- ptr_out = rt.make_ptr(dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_x = rt.make_ptr(
+ dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_out = rt.make_ptr(
+ dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
compiled(ptr_x, ptr_out, Int32(int(M)), Int32(int(x.stride(0))), stream)
@@ -606,7 +634,9 @@ def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None:
assert dy.is_cuda and dy.dim() == 2
assert y.is_cuda and y.shape == dy.shape and y.dtype == dy.dtype
assert dx.is_cuda and dx.shape == dy.shape and dx.dtype == dy.dtype
- assert dy.stride() == y.stride() == dx.stride(), "Pointer path expects matching strides"
+ assert dy.stride() == y.stride() == dx.stride(), (
+ "Pointer path expects matching strides"
+ )
M, N = dy.shape
device_index = dy.get_device()
@@ -619,9 +649,15 @@ def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None:
compiled = _PTR_BWD_COMPILE_CACHE.get(key)
if compiled is None:
op = SoftmaxBwdSM100(dtype_x, int(N))
- ptr_dy = rt.make_ptr(dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
- ptr_y = rt.make_ptr(dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
- ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_dy = rt.make_ptr(
+ dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_y = rt.make_ptr(
+ dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
ld = Int32(int(dy.stride(0)))
compiled = cute.compile(
op.launch_from_ptrs,
@@ -634,9 +670,15 @@ def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None:
)
_PTR_BWD_COMPILE_CACHE[key] = compiled
- ptr_dy = rt.make_ptr(dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
- ptr_y = rt.make_ptr(dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
- ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16)
+ ptr_dy = rt.make_ptr(
+ dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_y = rt.make_ptr(
+ dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
+ ptr_dx = rt.make_ptr(
+ dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16
+ )
compiled(ptr_dy, ptr_y, ptr_dx, Int32(int(M)), Int32(int(dy.stride(0))), stream)
@@ -679,8 +721,14 @@ def softmax_backward(dy: Tensor, y: Tensor) -> Tensor:
N = dy.size(1)
dtype = TORCH2CUTE_DTYPE[dy.dtype]
- if _can_use_ptr_path_2d(dy) and _can_use_ptr_path_2d(y) and dy.stride() == y.stride():
- dx = torch.empty_strided(dy.shape, dy.stride(), device=dy.device, dtype=dy.dtype)
+ if (
+ _can_use_ptr_path_2d(dy)
+ and _can_use_ptr_path_2d(y)
+ and dy.stride() == y.stride()
+ ):
+ dx = torch.empty_strided(
+ dy.shape, dy.stride(), device=dy.device, dtype=dy.dtype
+ )
_softmax_backward_ptr_into(dy=dy, y=y, dx=dx)
return dx
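As a side note on the stride-preserving allocation above, the following standalone sketch (shapes and padding chosen purely for illustration) shows how `torch.empty_strided` keeps a padded-row `[M, N]` layout intact:

```python
import torch

# Build a padded-row [M, N] view: unit stride along N, row stride > N.
M, N, row_stride = 4, 6, 8  # illustrative sizes only
storage = torch.randn(M * row_stride)
dy = storage.as_strided((M, N), (row_stride, 1))
assert dy.stride() == (row_stride, 1)

# Allocating the output with the input's strides preserves the padding,
# which is what the pointer path (and torch.compile stride checks) rely on.
dx = torch.empty_strided(dy.shape, dy.stride(), device=dy.device, dtype=dy.dtype)
assert dx.shape == dy.shape and dx.stride() == dy.stride()
```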
From 7e818eecf3448f9f8173ac4ad5154177fea830e8 Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Wed, 21 Jan 2026 20:11:32 -0800
Subject: [PATCH 7/8] oink: add license headers to benchmarks
---
oink/benchmarks/benchmark/bench_utils.py | 14 ++++++++++++++
.../benchmark/benchmark_cross_entropy_sm100.py | 14 ++++++++++++++
.../benchmark/benchmark_fused_add_rmsnorm_sm100.py | 14 ++++++++++++++
.../benchmark/benchmark_hbm_roofline_sm100.py | 14 ++++++++++++++
.../benchmark/benchmark_layernorm_sm100.py | 14 ++++++++++++++
.../benchmark/benchmark_rmsnorm_bwd_sm100.py | 14 ++++++++++++++
.../benchmark/benchmark_rmsnorm_sm100.py | 14 ++++++++++++++
.../benchmark/benchmark_softmax_sm100.py | 14 ++++++++++++++
oink/benchmarks/readme/plot_quack_style_svg.py | 14 ++++++++++++++
oink/benchmarks/readme/run_sm100_suite.py | 14 ++++++++++++++
oink/benchmarks/readme/summarize_results.py | 14 ++++++++++++++
11 files changed, 154 insertions(+)
diff --git a/oink/benchmarks/benchmark/bench_utils.py b/oink/benchmarks/benchmark/bench_utils.py
index 0a9ae4b..ef996ec 100644
--- a/oink/benchmarks/benchmark/bench_utils.py
+++ b/oink/benchmarks/benchmark/bench_utils.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import annotations
import csv
diff --git a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
index ff1a99b..3c8bf44 100644
--- a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import annotations
import argparse
diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
index 863712d..8a0227b 100644
--- a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
Benchmark fused_add_rmsnorm (in-place) on SM100.
diff --git a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
index c22294e..22fb48d 100644
--- a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
HBM roofline microbenchmark for SM100 (GB200 / Blackwell).
diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
index 3c0e37d..20895b7 100644
--- a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import annotations
import argparse
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
index b9909e7..31b335b 100644
--- a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import annotations
import argparse
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
index f4c8a5f..39e6cd7 100644
--- a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import annotations
import argparse
diff --git a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
index 995b09f..a5b2b3c 100644
--- a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import annotations
import argparse
diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py
index af76832..88eebdf 100644
--- a/oink/benchmarks/readme/plot_quack_style_svg.py
+++ b/oink/benchmarks/readme/plot_quack_style_svg.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
"""
Generate Quack-style SVG performance plots (Oink vs Quack) from the SM100 suite
JSON artifacts under `/tmp/kernelagent_oink_sm100_suite_{bf16,fp16}`.
diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py
index c31d4b5..fb9d603 100644
--- a/oink/benchmarks/readme/run_sm100_suite.py
+++ b/oink/benchmarks/readme/run_sm100_suite.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import annotations
import argparse
diff --git a/oink/benchmarks/readme/summarize_results.py b/oink/benchmarks/readme/summarize_results.py
index 29b768e..684694d 100644
--- a/oink/benchmarks/readme/summarize_results.py
+++ b/oink/benchmarks/readme/summarize_results.py
@@ -1,3 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from __future__ import annotations
import argparse
From 5d195d6c2d80b9415d8a5433c54494ccac769eac Mon Sep 17 00:00:00 2001
From: Laura Wang <3700467+Laurawly@users.noreply.github.com>
Date: Thu, 22 Jan 2026 09:59:37 -0800
Subject: [PATCH 8/8] oink: update SM100 benchmarks, plots, and Blackwell kernels
---
oink/benchmarks/README.md | 6 +
.../benchmark_fused_add_rmsnorm_sm100.py | 55 +-
.../benchmark/benchmark_rmsnorm_bwd_sm100.py | 11 +-
.../media/sm100_bf16_oink_vs_quack.svg | 350 ++--
.../media/sm100_bf16_oink_vs_quack_dsv3.svg | 714 +++----
.../sm100_bf16_oink_vs_quack_dsv3_all.svg | 570 +++---
..._bf16_oink_vs_quack_dsv3_cross_entropy.svg | 180 +-
...bf16_oink_vs_quack_dsv3_with_layernorm.svg | 1742 ++++++++---------
...m100_bf16_oink_vs_quack_with_layernorm.svg | 460 ++---
.../media/sm100_fp16_oink_vs_quack.svg | 350 ++--
.../media/sm100_fp16_oink_vs_quack_dsv3.svg | 714 +++----
.../sm100_fp16_oink_vs_quack_dsv3_all.svg | 570 +++---
..._fp16_oink_vs_quack_dsv3_cross_entropy.svg | 180 +-
...fp16_oink_vs_quack_dsv3_with_layernorm.svg | 1742 ++++++++---------
...m100_fp16_oink_vs_quack_with_layernorm.svg | 460 ++---
oink/benchmarks/readme/run_sm100_suite.py | 17 +
.../blackwell/cross_entropy.py | 1045 +++++++++-
.../kernelagent_oink/blackwell/fast_launch.py | 115 ++
.../kernelagent_oink/blackwell/layernorm.py | 845 ++++++--
.../kernelagent_oink/blackwell/lite_quack.py | 181 +-
.../src/kernelagent_oink/blackwell/rmsnorm.py | 1287 +++++++++++-
.../src/kernelagent_oink/blackwell/softmax.py | 834 +++++++-
22 files changed, 7996 insertions(+), 4432 deletions(-)
create mode 100644 oink/src/kernelagent_oink/blackwell/fast_launch.py
diff --git a/oink/benchmarks/README.md b/oink/benchmarks/README.md
index ceb7932..a5c4676 100644
--- a/oink/benchmarks/README.md
+++ b/oink/benchmarks/README.md
@@ -96,6 +96,12 @@ CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsn
--json /tmp/fused_add_rmsnorm_sm100_bf16.json
```
+Note on the Quack baseline: Oink exposes an **in-place** fused op (updates `x` and `residual`).
+Quack’s fused kernel produces `out` and `residual_out` out-of-place, so by default the benchmark
+times `quack::_rmsnorm_fwd` **plus** two explicit copies (`x.copy_(out)`, `residual.copy_(residual_out)`)
+to match the in-place semantics (integration-realistic). Use `--quack-baseline kernel` to time only
+the Quack fused kernel with preallocated outputs.
+
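To make the default (`kernel_inplace`) baseline concrete, here is a minimal sketch of what it times; the `quack_fused_rmsnorm` callable and its argument order are simplified placeholders, not Quack's actual `_rmsnorm_fwd` signature:

```python
import torch

def make_quack_inplace_step(x, residual, weight, quack_fused_rmsnorm):
    """Return a closure that reproduces Oink's in-place semantics on top of
    an out-of-place Quack-style fused kernel (illustrative sketch only)."""
    out = torch.empty_like(x)
    residual_out = torch.empty_like(residual)

    def step():
        # Out-of-place fused add + RMSNorm (placeholder signature).
        quack_fused_rmsnorm(x, weight, out, residual, residual_out)
        # Two explicit copies make the result in-place, matching the
        # torch.ops.oink.fused_add_rms_norm / vLLM semantics:
        x.copy_(out)                  # x <- rmsnorm(x + residual)
        residual.copy_(residual_out)  # residual <- x + residual
        return x, residual

    return step
```

Timing `step` therefore includes the copy traffic that `--quack-baseline kernel` deliberately excludes.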
### RMSNorm backward
```bash
diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
index 8a0227b..1787d7d 100644
--- a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py
@@ -31,6 +31,15 @@
DSv3 suite (Oink vs Quack, multi-shape):
CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --dsv3 \\
--json /tmp/kernelagent_oink_sm100_suite_bf16/fused_add_rmsnorm_dsv3.json
+
+Quack baseline note:
+- Oink exposes an **in-place** fused op (writes `x` and `residual` in-place).
+- Quack provides an equivalent fused kernel, but typically returns `out` and
+ `residual_out` (out-of-place) and does not expose a public "update my input
+ buffers in-place" API.
+- For integration realism (vLLM-style semantics) we default to timing the
+  Quack fused kernel plus two explicit copies that apply the in-place updates,
+  so the benchmark covers the full semantic cost.
"""
from __future__ import annotations
@@ -177,6 +186,7 @@ def bench_one(
warmup_ms: int,
iters_ms: int,
verify: bool,
+ quack_baseline: str,
) -> Dict[str, Any]:
device = torch.device("cuda")
x = torch.randn((M, N), device=device, dtype=dtype)
@@ -212,23 +222,40 @@ def fn():
row.update(stats)
if quack_rmsnorm_fwd_mut is not None:
- out_q = torch.empty_like(x)
- res_out_q = torch.empty_like(residual)
+ x_q = x.clone()
+ residual_q = residual.clone()
+ out_q = torch.empty_like(x_q)
+ res_out_q = torch.empty_like(residual_q)
- def fn_q():
+ def fn_q_kernel():
quack_rmsnorm_fwd_mut(
- x,
+ x_q,
w,
out_q,
None, # bias
None, # rstd
None, # mean
- residual,
+ residual_q,
res_out_q,
1e-6,
False, # is_layernorm
)
+ if quack_baseline == "kernel":
+ fn_q = fn_q_kernel
+ elif quack_baseline == "kernel_inplace":
+
+ def fn_q():
+ fn_q_kernel()
+ # Apply the same in-place semantics as vLLM expects:
+ # - x is overwritten with y
+ # - residual is overwritten with z = x + residual
+ x_q.copy_(out_q)
+ residual_q.copy_(res_out_q)
+
+ else:
+ raise ValueError(f"Unknown quack_baseline: {quack_baseline}")
+
ms_q = do_bench_triton(fn_q, warmup_ms=warmup_ms, rep_ms=iters_ms)
gbps_q = bytes_io / (ms_q * 1e-3) / 1e9
row.update(
@@ -287,6 +314,18 @@ def main() -> None:
p.add_argument(
"--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)"
)
+ p.add_argument(
+ "--quack-baseline",
+ type=str,
+ default="kernel_inplace",
+ choices=["kernel", "kernel_inplace"],
+ help=(
+ "How to time Quack for the in-place fused op.\n"
+ "- kernel: Quack fused kernel only (preallocated out/residual_out).\n"
+ "- kernel_inplace: Quack fused kernel + 2 explicit copies to apply "
+ "in-place semantics (integration-realistic)."
+ ),
+ )
p.add_argument("--skip-verify", action="store_true")
p.add_argument("--json", type=str, default=None)
args = p.parse_args()
@@ -309,6 +348,7 @@ def main() -> None:
warmup_ms=int(args.warmup_ms),
iters_ms=int(args.iters),
verify=not bool(args.skip_verify),
+ quack_baseline=str(args.quack_baseline),
)
)
@@ -324,7 +364,10 @@ def main() -> None:
warmup_ms=int(args.warmup_ms),
rep_ms=int(args.iters),
method="triton.testing.do_bench(mean)",
- note="Oink fused_add_rmsnorm_inplace_ vs Quack quack::_rmsnorm_fwd(residual=..., residual_out=...) when available",
+ note=(
+ "Oink fused_add_rmsnorm_inplace_ vs Quack baseline "
+ f"({args.quack_baseline}) when available"
+ ),
),
)
diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
index 31b335b..50ecb2e 100644
--- a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
+++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py
@@ -17,7 +17,6 @@
import argparse
import csv
import os
-import sys
from dataclasses import dataclass
from typing import List, Optional, Tuple
@@ -30,19 +29,17 @@
# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM.
os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a")
-# Make the in-repo KernelAgent Oink package importable without an editable install.
-_HERE = os.path.dirname(os.path.abspath(__file__))
-_OINK_SRC = os.path.abspath(os.path.join(_HERE, "..", "src"))
-if _OINK_SRC not in sys.path:
- sys.path.insert(0, _OINK_SRC)
-
from bench_utils import ( # noqa: E402
ErrorStatsAccumulator,
collect_device_meta,
+ ensure_oink_src_on_path,
error_stats_to_row,
iter_row_blocks,
write_json,
)
+
+ensure_oink_src_on_path()
+
from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402
try:
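The `ensure_oink_src_on_path` helper imported above is not shown in this hunk; a plausible implementation (an assumption based on the inline code it replaces, not the actual `bench_utils.py` contents) would be:

```python
import os
import sys

def ensure_oink_src_on_path() -> None:
    """Make the in-repo `kernelagent_oink` package importable without an
    editable install by prepending `oink/src` to sys.path (idempotent)."""
    here = os.path.dirname(os.path.abspath(__file__))
    oink_src = os.path.abspath(os.path.join(here, "..", "src"))
    if oink_src not in sys.path:
        sys.path.insert(0, oink_src)
```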
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg
index e32e3a7..96b5b83 100644
--- a/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg
@@ [generated SVG diff body omitted] @@
[The plot was regenerated: the embedded creation timestamp changes from
 2026-01-12T23:31:37 to 2026-01-22T03:16:57; the remaining changes are
 re-emitted clip-path IDs and path/marker geometry with no readable content.]
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg
index b70ba9b..254623e 100644
--- a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg
@@ -1,12 +1,12 @@
-