From 27cff30833ee145e97a6c84dc976b3def08eb2da Mon Sep 17 00:00:00 2001 From: Laura Wang <3700467+Laurawly@users.noreply.github.com> Date: Tue, 6 Jan 2026 10:58:01 -0800 Subject: [PATCH 1/8] Provide a vLLM general plugin that registers oink::rmsnorm and oink::fused_add_rms_norm backed by an SM100 CuTeDSL RMSNorm kernel. The ops are torch.compile-friendly (stride-preserving for padded-row inputs) and the fused op matches vLLM's in-place residual-add RMSNorm semantics. --- oink/README.md | 57 + oink/pyproject.toml | 29 + oink/src/kernelagent_oink/__init__.py | 95 + .../kernelagent_oink/blackwell/__init__.py | 3 + .../kernelagent_oink/blackwell/lite_quack.py | 350 +++ .../blackwell/oink_custom_ops.py | 224 ++ .../src/kernelagent_oink/blackwell/rmsnorm.py | 2660 +++++++++++++++++ 7 files changed, 3418 insertions(+) create mode 100644 oink/README.md create mode 100644 oink/pyproject.toml create mode 100644 oink/src/kernelagent_oink/__init__.py create mode 100644 oink/src/kernelagent_oink/blackwell/__init__.py create mode 100644 oink/src/kernelagent_oink/blackwell/lite_quack.py create mode 100644 oink/src/kernelagent_oink/blackwell/oink_custom_ops.py create mode 100644 oink/src/kernelagent_oink/blackwell/rmsnorm.py diff --git a/oink/README.md b/oink/README.md new file mode 100644 index 0000000..427f69f --- /dev/null +++ b/oink/README.md @@ -0,0 +1,57 @@ +# KernelAgent Oink (vLLM plugin) + +This subproject provides an **out-of-tree vLLM plugin** that registers +`torch.library.custom_op` entrypoints under the `oink::` namespace: + +- `torch.ops.oink.rmsnorm` +- `torch.ops.oink.fused_add_rms_norm` + +The implementation is backed by a CuTeDSL (CUTLASS) RMSNorm kernel tuned for +**NVIDIA Blackwell (SM100)**. + +## Install (editable) + +From the `KernelAgent` repo root: + +```bash +pip install -e ./oink +``` + +This plugin requires the CuTeDSL stack: + +```bash +pip install nvidia-cutlass-dsl cuda-python +``` + +## Use with vLLM + +1. Enable the vLLM integration: + +```bash +export VLLM_USE_OINK_RMSNORM=1 +``` + +2. Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` / +CUDA graphs. In Python: + +```python +from vllm import LLM + +llm = LLM( + model=..., + tensor_parallel_size=..., + enforce_eager=False, + compilation_config={"custom_ops": ["none", "+rms_norm"]}, +) +``` + +Without `+rms_norm`, Inductor may fuse RMSNorm into larger Triton kernels and +neither vLLM's CUDA RMSNorm nor Oink will run. + +## Notes + +- This plugin is designed to be **safe to import even when disabled**; it only + registers ops when `VLLM_USE_OINK_RMSNORM` is truthy (`"1"` / `"true"`). +- The ops preserve **padded-row layouts** for 2D tensors (shape `[M, N]`, + `stride(1) == 1`, and potentially `stride(0) > N`), which is required for + `torch.compile` stride verification on some models (e.g., MLA padded inputs). diff --git a/oink/pyproject.toml b/oink/pyproject.toml new file mode 100644 index 0000000..a9ec306 --- /dev/null +++ b/oink/pyproject.toml @@ -0,0 +1,29 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "kernelagent-oink" +version = "0.1.0" +description = "vLLM plugin that registers Oink Blackwell RMSNorm custom ops" +readme = "README.md" +requires-python = ">=3.10" +license = {text = "Apache-2.0"} +authors = [{name = "PyTorch Labs"}] + +# Keep dependencies minimal, but include the CuTeDSL stack required by the +# Blackwell RMSNorm implementation. 
+# +# We intentionally do NOT depend on `torch` here because vLLM already pins and +# provides a compatible PyTorch build. +dependencies = [ + "nvidia-cutlass-dsl", + "cuda-python", +] + +[project.entry-points."vllm.general_plugins"] +oink = "kernelagent_oink:register" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["kernelagent_oink*"] diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py new file mode 100644 index 0000000..542e59e --- /dev/null +++ b/oink/src/kernelagent_oink/__init__.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import logging +import os + +logger = logging.getLogger(__name__) + +_OPS_REGISTERED = False + + +def _env_truthy(name: str) -> bool: + val = os.environ.get(name) + if val is None: + return False + return val.strip().lower() in ("1", "true", "yes", "on") + + +def _infer_cuda_device_index() -> int: + local_rank = os.environ.get("LOCAL_RANK") + if local_rank is not None: + try: + return int(local_rank) + except ValueError: + pass + return 0 + + +def _compute_cutedsl_arch(major: int, minor: int) -> str: + # CuTeDSL uses an "a" suffix for >= Hopper. + suffix = "a" if major >= 9 else "" + # Match cutlass/base_dsl/env_manager.py: map sm_110 -> sm_101. + if major == 11 and minor == 0: + major, minor = 10, 1 + return f"sm_{major}{minor}{suffix}" + + +def register() -> None: + """vLLM plugin entrypoint. + + This function must be safe to call multiple times and must not raise. + vLLM executes it in multiple processes (engine + workers). + """ + global _OPS_REGISTERED + + if _OPS_REGISTERED: + return + + # Gate on the vLLM integration flag so installing the package does not + # change behavior unless explicitly enabled. + if not _env_truthy("VLLM_USE_OINK_RMSNORM"): + return + + try: + import torch + except Exception as e: # pragma: no cover + logger.debug("Oink plugin: torch import failed: %s", e) + return + + try: + if not torch.cuda.is_available(): + return + device_index = _infer_cuda_device_index() + major, minor = torch.cuda.get_device_capability(device_index) + sm = 10 * int(major) + int(minor) + if sm < 100: + return + + # Ensure required deps are importable before registering ops so that vLLM + # doesn't detect ops that would later fail at first use. + try: + import cutlass # noqa: F401 + import cuda.bindings.driver as _cuda # noqa: F401 + except Exception as e: + logger.warning( + "Oink plugin: CuTeDSL deps missing; skipping op registration. " + "Install `nvidia-cutlass-dsl` + `cuda-python`. Error: %s", + e, + ) + return + + # Ensure CuTeDSL sees a target arch early. If the user has already set it, + # respect their choice. + os.environ.setdefault("CUTE_DSL_ARCH", _compute_cutedsl_arch(int(major), int(minor))) + + # Import registers the ops via torch.library.custom_op decorators. + from .blackwell import oink_custom_ops # noqa: F401 + except Exception as e: # pragma: no cover + # Do not raise: vLLM plugin loader does not guard plugin execution. 
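+        # Swallowing the exception keeps vLLM engine/worker startup alive even
+        # if Oink's registration path breaks; the only consequence is that the
+        # oink:: ops are unavailable for this run.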
+ logger.exception("Oink plugin: failed to register ops: %s", e) + return + + _OPS_REGISTERED = True + + +__all__ = ["register"] diff --git a/oink/src/kernelagent_oink/blackwell/__init__.py b/oink/src/kernelagent_oink/blackwell/__init__.py new file mode 100644 index 0000000..4d21ee8 --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/__init__.py @@ -0,0 +1,3 @@ +from __future__ import annotations + +__all__ = [] diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py new file mode 100644 index 0000000..3c3f750 --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py @@ -0,0 +1,350 @@ +""" +Lightweight local clone of the small subset of helpers that the SM100 +RMSNorm CuteDSL kernels depend on. + +This module intentionally avoids importing the `quack` package so that +Oink Blackwell kernels can run without Quack installed, while keeping +numerical behaviour and performance close to the original reference +implementations. +""" + +from __future__ import annotations + +import math +import operator +from typing import Callable, Optional, Tuple + +import cuda.bindings.driver as cuda # type: ignore +import torch +from torch import Tensor + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute.runtime import from_dlpack +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm, nvvm, vector + + +# ------------------------- +# Dtype mapping +# ------------------------- + +TORCH2CUTE_DTYPE = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, +} + + +# ------------------------- +# Tensor conversion helpers +# ------------------------- + +def convert_from_dlpack( + x: Tensor, + leading_dim: int, + alignment: int = 16, + divisibility: int = 1, +) -> cute.Tensor: + """ + Wrap a torch.Tensor in a CuteDSL tensor with layout metadata that + matches the logical leading dimension and alignment/divisibility + constraints expected by SM100 kernels. 
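+
+    Illustrative call (a sketch, not the only valid usage): for a row-major
+    2D CUDA tensor `x` of a 16-bit dtype whose last dimension is contiguous,
+    a typical invocation looks like
+
+        convert_from_dlpack(x, leading_dim=x.ndim - 1, alignment=16, divisibility=8)
+
+    where `divisibility=8` matches 128-bit vectorized copies of a 16-bit dtype.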
+ """ + return ( + from_dlpack(x, assumed_align=alignment) + .mark_layout_dynamic(leading_dim=leading_dim) + .mark_compact_shape_dynamic( + mode=leading_dim, + stride_order=x.dim_order(), + divisibility=divisibility, + ) + ) + + +# ------------------------- +# SM90/SM100 cluster helpers +# ------------------------- + + +@dsl_user_op +def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, + peer_cta_rank_in_cluster: cute.Int32, + *, + loc=None, + ip=None, +) -> cutlass.Int32: + """Map the given smem pointer to the address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return cutlass.Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote( + val: float | Float32 | Int32 | cutlass.Int64, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: cute.typing.Int, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, + peer_cta_rank_in_cluster, + loc=loc, + ip=ip, + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, + peer_cta_rank_in_cluster, + loc=loc, + ip=ip, + ).ir_value() + if const_expr(isinstance(val, float)): + val = Float32(val) + assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64" + suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)] + constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)] + llvm.inline_asm( + None, + [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32], + f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];", + f"r,{constraint},r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@cute.jit +def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: + """ + Build a predicate tensor for the K dimension only. Values beyond + `limit` are masked out. + """ + tApA = cute.make_fragment( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +@dsl_user_op +def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + """ + Return a tensor whose iterator is offset by an Int64 byte offset + computed from `coord` and the tensor's strides. 
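+
+    The per-coordinate products are accumulated in Int64 (element units) and
+    only then converted to a byte offset using the element width, so offsets
+    such as `row_start * stride0` do not overflow 32-bit arithmetic when
+    `M * stride0` exceeds 2**31 elements.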
+ """ + flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(tensor.stride) + assert len(flat_coord_i64) == len(flat_stride), ( + "Coordinate and stride must have the same length" + ) + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + assert isinstance(tensor.iterator, cute.Pointer) + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +# ------------------------- +# Reduction helpers +# ------------------------- + + +@cute.jit +def warp_reduce( + val: cute.TensorSSA | cute.Numeric, + op: Callable, + width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.TensorSSA | cute.Numeric: + """ + Warp-level reduction for either scalar values or small TensorSSA + fragments. + """ + if cutlass.const_expr(isinstance(val, cute.TensorSSA)): + res = cute.make_fragment(val.shape, val.dtype) + res.store(val) + for i in cutlass.range_constexpr(cute.size(val.shape)): + res[i] = warp_reduce(res[i], op, width) + return res.load() + for i in cutlass.range_constexpr(int(math.log2(width))): + val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)) + return val + + +@cute.jit +def block_reduce( + val: cute.Numeric, + op: Callable, + reduction_buffer: cute.Tensor, + init_val: cute.Numeric = 0.0, +) -> cute.Numeric: + """Block-level reduction across warps.""" + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + warps_per_row = cute.size(reduction_buffer.shape[1]) + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + if lane_idx == 0: + reduction_buffer[row_idx, col_idx] = val + cute.arch.barrier() + block_reduce_val = init_val + if lane_idx < warps_per_row: + block_reduce_val = reduction_buffer[row_idx, lane_idx] + return warp_reduce(block_reduce_val, op) + + +@cute.jit +def cluster_reduce( + val: cute.Numeric, + op: Callable, + reduction_buffer: cute.Tensor, + mbar_ptr: cute.Pointer, + init_val: cute.Numeric = 0.0, + phase: Optional[cutlass.Int32] = None, +) -> cute.Numeric: + """ + Cluster-wide reduction using shared memory and mbarrier. The + reduction_buffer has shape (rows_per_block, (warps_per_row, cluster_n)). 
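+
+    Protocol sketch: each warp writes its partial value into every CTA in the
+    cluster via st.async (lane i targets CTA rank i), the mbarrier is armed
+    with the expected transaction byte count, and once the wait completes each
+    CTA reduces its full (warps_per_row, cluster_n) slice locally.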
+ """ + cta_rank_in_cluster = cute.arch.block_idx_in_cluster() + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + if warp_idx == 0: + with cute.arch.elect_one(): + num_warps = rows_per_block * warps_per_row + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_ptr, + num_warps * cluster_n * reduction_buffer.element_type.width // 8, + ) + if lane_idx < cluster_n: + store_shared_remote( + val, + elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))), + mbar_ptr, + peer_cta_rank_in_cluster=lane_idx, + ) + cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0) + block_reduce_val = init_val + num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE) + for i in cutlass.range_constexpr(num_iter): + idx = lane_idx + i * cute.arch.WARP_SIZE + if idx < cute.size(reduction_buffer, mode=[1]): + block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx]) + return warp_reduce(block_reduce_val, op) + + +@cute.jit +def block_or_cluster_reduce( + val: cute.Numeric, + op: Callable, + reduction_buffer: cute.Tensor, + mbar_ptr: Optional[cute.Pointer], + phase: Optional[cutlass.Int32] = None, + init_val: cute.Numeric = 0.0, +) -> cute.Numeric: + """Dispatch between block or cluster reduction depending on mbar_ptr.""" + if cutlass.const_expr(mbar_ptr is None): + return block_reduce(val, op, reduction_buffer, init_val=init_val) + return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase) + + +@cute.jit +def row_reduce( + x: cute.TensorSSA | cute.Numeric, + op: cute.ReductionOp, + threads_per_row: cutlass.Constexpr[int], + reduction_buffer: Optional[cute.Tensor] = None, + mbar_ptr: Optional[cute.Pointer] = None, + phase: Optional[cutlass.Int32] = None, + init_val: cute.Numeric = 0.0, + hook_fn: Optional[Callable] = None, +) -> cute.Numeric: + """ + Row-wise reduction used by RMSNorm and similar kernels. + + reduction_buffer must have shape + (num_warps / warps_per_row, (warps_per_row, cluster_n)). + """ + if cutlass.const_expr(isinstance(x, cute.TensorSSA)): + val = x.reduce(op, init_val=init_val, reduction_profile=0) + else: + val = x + warp_op = { + cute.ReductionOp.ADD: operator.add, + cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max, + cute.ReductionOp.MIN: min, + cute.ReductionOp.MUL: operator.mul, + }[op] + val = warp_reduce( + val, + warp_op, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + if cutlass.const_expr(hook_fn is not None): + hook_fn() + if cutlass.const_expr(reduction_buffer is not None): + warps_per_row, cluster_n = reduction_buffer.shape[1] + assert cluster_n == 1 or mbar_ptr is not None, ( + "mbar_ptr must be provided for cluster reduction" + ) + if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1): + val = block_or_cluster_reduce( + val, + warp_op, + reduction_buffer, + mbar_ptr, + phase=phase, + init_val=init_val, + ) + return val + + +# ------------------------- +# SM count helper +# ------------------------- + + +def get_sm_count(N: int, device: torch.device) -> int: + """ + Heuristic for the number of persistent CTAs (sm_count) based on N and + the GPU's SM count. This mirrors the behaviour used in Quack's + RMSNorm kernels but lives entirely in this local module. 
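+
+    Worked example (an SM count of 148 is assumed purely for illustration):
+    N=1024 -> multiple 8, sm_count = 148 * 8 = 1184 persistent CTAs;
+    N=7168 -> multiple 1, sm_count = 148;
+    N=16384 -> the `N <= 16384` branch returns 148 // 2 = 74.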
+ """ + sm_count_multiple = ( + 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) + ) + sm_count = torch.cuda.get_device_properties(device).multi_processor_count + sm_count = ( + sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2 + ) + return sm_count + diff --git a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py new file mode 100644 index 0000000..8225025 --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +""" +Torch custom ops wrapping Oink's Blackwell RMSNorm kernels. + +These ops are designed to be: +- Architecture-aware (use CuTeDSL SM100 kernels when available, fall back + to a safe reference elsewhere). +- Layout-preserving for 2D row-major inputs, including padded MLA-style + layouts where stride(0) > N and stride(1) == 1. +- torch.compile-friendly via proper fake implementations that mirror + runtime shapes and strides. + +Public ops (Python signatures): + + torch.ops.oink.rmsnorm(x: Tensor, weight: Tensor, eps: float) -> Tensor + Functional RMSNorm. Returns a new tensor with the same shape and + stride as x when using the fast CuTeDSL path. + + torch.ops.oink.fused_add_rms_norm( + x: Tensor, residual: Tensor, weight: Tensor, eps: float + ) -> None + In-place fused residual-add + RMSNorm matching vLLM semantics: + residual = x + residual (stored into `residual`) + x = RMSNorm(residual, w) (stored into `x`) + Mutates `x` and `residual` in-place and returns None. +""" + +import importlib +import threading + +import torch +from torch.library import custom_op + +_RMSNORM_MOD: object | None = None +_RMSNORM_MOD_LOCK = threading.Lock() + + +def _get_rmsnorm_mod(): + """Lazy import to keep plugin registration lightweight. + + Importing the CuTeDSL kernel stack can be expensive and may require a CUDA + context. We defer it until the first actual execution of the custom op. + """ + global _RMSNORM_MOD + + cached = _RMSNORM_MOD + if cached is not None: + return cached + + with _RMSNORM_MOD_LOCK: + if _RMSNORM_MOD is None: + _RMSNORM_MOD = importlib.import_module("kernelagent_oink.blackwell.rmsnorm") + return _RMSNORM_MOD + + +def _get_sm(device: torch.device | None = None) -> int: + """Return SM version as an int (e.g., 100 for SM100 / Blackwell).""" + if device is None: + device = torch.device("cuda") + major, minor = torch.cuda.get_device_capability(device) + return 10 * major + minor + + +# +# RMSNorm (functional) +# + +@custom_op("oink::rmsnorm", mutates_args=()) +def oink_rmsnorm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, +) -> torch.Tensor: + """ + Functional RMSNorm entrypoint. + + This op is model-agnostic. It expects a 2D [M, N] view of the input + where the last dimension is contiguous (stride(1) == 1). The leading + dimension stride(0) may be larger than N (padded-row layouts), and + will be preserved on the fast CuTeDSL path. + + On SM100 (and newer), this dispatches to the tuned CuTeDSL Blackwell + RMSNorm kernel in rmsnorm.rmsnorm_forward, which in turn selects the + best internal schedule (including DSv3-specific stage-2 kernels where + applicable) and preserves the input's 2D stride when using the + pointer-based path. + + On older architectures it falls back to a safe PyTorch reference + implementation for correctness. 
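+
+    Illustrative usage (shapes, dtype, and eps below are assumptions):
+
+        x = torch.randn(4096, 7168, device="cuda", dtype=torch.bfloat16)
+        w = torch.ones(7168, device="cuda", dtype=torch.bfloat16)
+        y = torch.ops.oink.rmsnorm(x, w, 1e-6)
+
+    `y` has the same shape and dtype as `x` (and, on the fast path, the same
+    strides).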
+ """ + assert x.is_cuda, "oink::rmsnorm requires CUDA tensors" + assert x.dim() == 2, "oink::rmsnorm expects a 2D [M, N] tensor view" + assert weight.dim() == 1, "weight must be 1D [N]" + + sm = _get_sm(x.device) + if sm >= 100: + # Use the tuned CuTeDSL SM100 kernel. The public API already + # contains all necessary gating and layout checks internally. + _rms = _get_rmsnorm_mod() + y, _rstd, _res = _rms.rmsnorm_forward( + x, + weight=weight, + bias=None, + residual=None, + eps=eps, + store_rstd=False, + ) + return y + + # Fallback: reference implementation (correctness-first). + _rms = _get_rmsnorm_mod() + return _rms.rmsnorm_ref( + x, + w=weight, + b=None, + residual=None, + eps=eps, + ) + + +@oink_rmsnorm.register_fake +def oink_rmsnorm_fake( + x: torch.Tensor, + weight: torch.Tensor, + eps: float, +) -> torch.Tensor: + """ + Fake (meta) implementation for oink::rmsnorm. + + We must preserve x's logical layout (shape + stride) so that Inductor's + CUDA graph capture sees the same stride contract as the real kernel. + """ + # x is a FakeTensor here; x.shape/x.stride()/x.device/x.dtype are defined. + return torch.empty_strided( + x.shape, + x.stride(), + device=x.device, + dtype=x.dtype, + ) + + +# +# Fused residual-add + RMSNorm (in-place, vLLM semantics) +# + +@custom_op("oink::fused_add_rms_norm", mutates_args=("x", "residual")) +def oink_fused_add_rms_norm( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, +) -> None: + """ + In-place fused residual-add + RMSNorm: + + residual <- x + residual + x <- RMSNorm(residual, weight, eps) + + Returns: + None (mutates `x` and `residual` in-place). + """ + assert x.is_cuda and residual.is_cuda, "oink::fused_add_rms_norm requires CUDA tensors" + assert x.shape == residual.shape, "x and residual must have the same shape" + assert x.dtype == residual.dtype, "x and residual must have the same dtype" + assert weight.dim() == 1, "weight must be 1D [N]" + + sm = _get_sm(x.device) + if sm >= 100: + _rms = _get_rmsnorm_mod() + # Prefer the lowest-overhead in-place entrypoint (returns None). + if hasattr(_rms, "fused_add_rmsnorm_inplace_"): + _rms.fused_add_rmsnorm_inplace_( # type: ignore[misc] + x, + residual, + weight, + eps=eps, + ) + return None + # Backward-compatible wrapper (returns (x, residual)). + if hasattr(_rms, "fused_add_rmsnorm_forward_inplace"): + _rms.fused_add_rmsnorm_forward_inplace( # type: ignore[misc] + x, + residual, + weight, + eps=eps, + ) + return None + + # Extremely defensive fallback if the Oink module doesn't provide + # the in-place entrypoint. + y, z = _rms.fused_add_rmsnorm_forward(x, residual, weight, eps=eps) + x.copy_(y) + residual.copy_(z) + return None + + # Non-SM100 fallback: keep semantics in-place (correctness-first). + residual.add_(x) + _rms = _get_rmsnorm_mod() + y = _rms.rmsnorm_ref(residual, w=weight, b=None, residual=None, eps=eps) + x.copy_(y) + return None + + +@oink_fused_add_rms_norm.register_fake +def oink_fused_add_rms_norm_fake( + x: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + eps: float, +) -> None: + """ + Fake (meta) implementation for oink::fused_add_rms_norm. + + Because this op mutates its inputs in-place, the outputs alias the input + buffers and therefore have the same shapes and strides. 
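+
+    Because the op is registered with mutates_args=("x", "residual") and
+    returns None, returning None here is sufficient: the FakeTensor metadata
+    of `x` and `residual` (shape, stride, dtype) is left untouched, which is
+    exactly what the in-place kernel does at runtime.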
+ """ + return None + + +__all__ = [ + "oink_rmsnorm", + "oink_fused_add_rms_norm", +] diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py new file mode 100644 index 0000000..d6c2c20 --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py @@ -0,0 +1,2660 @@ +""" +RMSNorm kernel for SM100 (Blackwell) in CuteDSL. + +This implementation targets Blackwell with: +- A stride-preserving pointer path for padded-row layouts (e.g. MLA stride0> N). +- A one-pass fused-add RMSNorm schedule for bf16/fp16 (DSv3 N=7168) that keeps + `x + residual` in registers (avoids re-reading gmem) and uses FP32 accumulation. +- Optional experimental schedule knobs (env vars) to explore copy widths and + stage-2 cp.async variants. + +Note: This file expects the local CuTeDSL (cutlass) and SM100 helper modules +to be available in the Python environment (e.g., `nvidia-cutlass-dsl` and +`cuda-python`). It is shipped as part of the KernelAgent Oink vLLM plugin. +""" + +from __future__ import annotations + +import ctypes +import importlib.metadata +import os +import re +import subprocess +import sys +import threading +from typing import Optional, Tuple + +_HERE = os.path.dirname(__file__) + +# CuTeDSL caches generated MLIR into a tempdir under a global default +# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across +# `nvidia-cutlass-dsl` versions (e.g. 4.3.2 vs 4.3.4), and cross-version cache +# sharing causes noisy "invalid section ID" warnings (and disables cache reuse). +# +# If the user has not pinned `CUTE_DSL_CACHE_DIR`, isolate by version so multiple +# CuTeDSL envs can coexist on the same machine without stepping on each other. +if "CUTE_DSL_CACHE_DIR" not in os.environ: + try: + _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl") + except Exception: + _dsl_ver = "unknown" + _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver) + _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user" + _tmp = os.environ.get("TMPDIR") or "/tmp" + os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join( + _tmp, _user, f"cutlass_python_cache_{_dsl_ver}" + ) + +try: + import cutlass # type: ignore # noqa: F401 +except Exception as e: + raise ImportError( + "kernelagent_oink.blackwell.rmsnorm requires CuTeDSL's Python package " + "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)." + ) from e + +import torch +from torch import Tensor + +import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute import runtime as rt + +# Simple compile cache declared early so direct execution works +_PTR_COMPILE_CACHE = {} + +# Thread-local cache for the fast-launch path. We keep per-thread packed args and +# pointer/scalar storage so concurrent callers don't race on in-place updates. +_PTR_FAST_LAUNCH_TLS = threading.local() + +def _env_flag(name: str, default: bool) -> bool: + val = os.environ.get(name) + if val is None: + return default + return val.strip().lower() not in {"0", "false", "no", "off", ""} + + +# Fast-launch uses a few private-ish CuTeDSL internals (packed args plumbing and +# runtime pointer descriptors). Keep it enabled by default for our pinned CuTeDSL +# environment, but allow disabling it via env var and auto-disable it if those +# internals are not present in a future upgrade. 
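+#
+# Setting OINK_CUTEDSL_FAST_LAUNCH=0 (before importing this module) forces all
+# launches through the regular compiled-callable path instead of the packed-args
+# capi_func path.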
+_ENABLE_FAST_LAUNCH = _env_flag("OINK_CUTEDSL_FAST_LAUNCH", default=True) +_FAST_LAUNCH_SUPPORTED = True + +# Fused-add RMSNorm schedule knobs (read once at import time; set env vars before +# importing this module if you want to override). +_DIRECT_GMEM_POLICY = (os.environ.get("OINK_RMSNORM_DIRECT_GMEM", "auto").strip().lower() or "auto") +_COPY_BITS_POLICY = (os.environ.get("OINK_RMSNORM_COPY_BITS", "auto").strip().lower() or "auto") +_ENABLE_CLUSTER_ILP = _env_flag("OINK_RMSNORM_ENABLE_CLUSTER_ILP", default=False) +_ENABLE_CLUSTER_ILP_UNSAFE = _env_flag("OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE", default=False) +_ENABLE_TPR256 = _env_flag("OINK_RMSNORM_ENABLE_TPR256", default=False) +_ENABLE_STAGE2 = _env_flag("OINK_RMSNORM_ENABLE_STAGE2", default=False) + +# CuTeDSL stability probe for the experimental cluster_n>1 + direct-GMEM schedule. +# +# Some CuTeDSL builds segfault during JIT compilation when combining: +# - cluster launches (cluster_n>1) and +# - direct-GMEM loads/stores (no staging SMEM tiles). +# +# We keep the schedule gated behind `OINK_RMSNORM_ENABLE_CLUSTER_ILP=1` + +# `OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1`, and additionally run a one-time +# out-of-process compile probe so we can safely fall back to the staged SMEM +# path instead of crashing the parent process. +# +# This is (currently) sensitive to the vector width: we have observed +# reproducible segfaults for the 256b universal-copy path, while the 128b path +# can succeed. Cache the maximum supported copy width (0 = unsupported). +_CLUSTER_DIRECT_GMEM_MAX_COPY_BITS: Optional[int] = None +_CLUSTER_DIRECT_GMEM_PROBE_LOCK = threading.Lock() +_CLUSTER_DIRECT_GMEM_PROBE_WARNED = False + + +def _probe_cluster_direct_gmem_max_copy_bits() -> int: + global _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS + global _CLUSTER_DIRECT_GMEM_PROBE_WARNED + + override = os.environ.get("OINK_RMSNORM_CLUSTER_DIRECT_GMEM_MAX_COPY_BITS") + if override is not None and override.strip() != "": + try: + value = int(override) + except ValueError: + value = 0 + value = 256 if value >= 256 else 128 if value >= 128 else 0 + _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS = value + return value + + if _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS is not None: + return _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS + + with _CLUSTER_DIRECT_GMEM_PROBE_LOCK: + if _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS is not None: + return _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS + + script_template = r""" +import os + +os.environ["OINK_CUTEDSL_FAST_LAUNCH"] = "0" + +import cutlass +import cutlass.cute as cute +import cuda.bindings.driver as cuda +from cutlass import Float32, Int32 +from cutlass.cute import runtime as rt + +from kernelagent_oink.blackwell import rmsnorm + +N = 7168 +dtype = cutlass.BFloat16 + +copy_bits = int(os.environ["OINK_PROBE_COPY_BITS"]) +assumed_align = int(os.environ["OINK_PROBE_ASSUMED_ALIGN"]) + +op = rmsnorm.RMSNormSM100( + N, + dtype, + stage=1, + copy_bits=copy_bits, + use_async=False, + direct_gmem=True, +) +op._cluster_n_override = 2 # 2 CTAs per row + +ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) +ptr_res = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) +ptr_w = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) + +_ = cute.compile( + op.launch_from_ptrs_fused_add_inplace, + ptr_x, + ptr_w, + ptr_res, + Int32(4096), + Int32(N), + Int32(N), + cuda.CUstream(0), + Float32(1e-6), +) +print(f"ok {copy_bits}") +""" + + env = os.environ.copy() + env["PYTHONNOUSERSITE"] = "1" + + 
def run_probe(copy_bits: int, assumed_align: int): + probe_env = env.copy() + probe_env["OINK_PROBE_COPY_BITS"] = str(copy_bits) + probe_env["OINK_PROBE_ASSUMED_ALIGN"] = str(assumed_align) + return subprocess.run( + [sys.executable, "-c", script_template], + env=probe_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=120.0, + ) + + proc_256 = None + proc_128 = None + try: + proc_256 = run_probe(256, 32) + if proc_256.returncode == 0: + max_bits = 256 + else: + proc_128 = run_probe(128, 16) + max_bits = 128 if proc_128.returncode == 0 else 0 + except Exception: + max_bits = 0 + + if not _CLUSTER_DIRECT_GMEM_PROBE_WARNED and max_bits != 256: + _CLUSTER_DIRECT_GMEM_PROBE_WARNED = True + if max_bits == 128: + print( + "Oink: cluster_n>1 + direct_gmem 256b compile probe failed; " + "using 128b copies for the cluster ILP schedule.", + file=sys.stderr, + ) + if proc_256 is not None and proc_256.stderr: + tail = "\n".join(proc_256.stderr.splitlines()[-12:]) + print(f"Oink: probe stderr tail:\n{tail}", file=sys.stderr) + else: + rc = ( + proc_128.returncode + if proc_128 is not None + else (proc_256.returncode if proc_256 is not None else "unknown") + ) + print( + "Oink: cluster_n>1 + direct_gmem compile probe failed; " + f"falling back to staged SMEM path (returncode={rc}).", + file=sys.stderr, + ) + failing_proc = proc_128 if proc_128 is not None else proc_256 + if failing_proc is not None and failing_proc.stderr: + tail = "\n".join(failing_proc.stderr.splitlines()[-12:]) + print(f"Oink: probe stderr tail:\n{tail}", file=sys.stderr) + + _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS = max_bits + return max_bits + +def _parse_version_tuple(version: str) -> Tuple[int, int, int]: + parts = version.split(".") + nums: list[int] = [] + for part in parts[:3]: + match = re.match(r"^(\d+)", part) + nums.append(int(match.group(1)) if match is not None else 0) + while len(nums) < 3: + nums.append(0) + return nums[0], nums[1], nums[2] + + +def _cutlass_dsl_version() -> Optional[Tuple[int, int, int]]: + try: + return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl")) + except Exception: + return None + + +_CUTLASS_DSL_VERSION = _cutlass_dsl_version() +# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around +# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the +# older signature for 4.3.2, but switch to a 4.3.4-compatible signature when we +# detect 4.3.4+ (or when version detection is unavailable). +_KERNEL_ACCEPTS_LAYOUT_ARGS = _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4) + +if _ENABLE_CLUSTER_ILP and not _ENABLE_CLUSTER_ILP_UNSAFE: + # We have observed reproducible segfaults in some CuTeDSL builds when using + # cluster launches for this schedule. Require an explicit UNSAFE opt-in to + # avoid accidental crashes. 
+ _ENABLE_CLUSTER_ILP = False + print( + "Oink: OINK_RMSNORM_ENABLE_CLUSTER_ILP requested but disabled by default due to " + "known instability; set OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1 to force-enable.", + file=sys.stderr, + ) + + +def _fast_launch_enabled() -> bool: + return _ENABLE_FAST_LAUNCH and _FAST_LAUNCH_SUPPORTED + + +def _direct_gmem_from_policy(*, default: bool) -> bool: + """Resolve the direct-GMEM schedule flag from the (import-time) policy string.""" + if _DIRECT_GMEM_POLICY in {"0", "false", "no", "off"}: + return False + if _DIRECT_GMEM_POLICY in {"1", "true", "yes", "on"}: + return True + return default + + +def _copy_bits_from_policy(*, default: int, can_use_256: bool) -> int: + """Resolve copy width (in bits) from the (import-time) policy string.""" + if _COPY_BITS_POLICY in {"128"}: + return 128 + if _COPY_BITS_POLICY in {"256"} and can_use_256: + return 256 + return default + + +class _StableI32Arg: + """A stable Int32 runtime arg (avoids per-call Int32().__c_pointers__ allocations).""" + + def __init__(self, value: int): + self._c_value = ctypes.c_int32(int(value)) + self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p) + + def set(self, value: int) -> None: + self._c_value.value = int(value) + + def __c_pointers__(self): + return [self._c_pointer] + + +class _StableF32Arg: + """A stable Float32 runtime arg (avoids per-call Float32().__c_pointers__ allocations).""" + + def __init__(self, value: float): + self._c_value = ctypes.c_float(float(value)) + self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p) + + def set(self, value: float) -> None: + self._c_value.value = float(value) + + def __c_pointers__(self): + return [self._c_pointer] + + +def _tls_fast_launch_cache() -> dict[tuple[object, ...], object]: + cache = getattr(_PTR_FAST_LAUNCH_TLS, "cache", None) + if cache is None: + cache = {} + _PTR_FAST_LAUNCH_TLS.cache = cache + return cache + + +def _set_runtime_ptr(ptr: object, device_ptr: int) -> None: + # Runtime pointer objects cache a `ctypes.c_void_p` descriptor and pass + # its address to the compiled function. Updating `_desc.value` updates + # the device pointer without changing the address of the descriptor. + # + # This relies on internal CuTeDSL runtime pointer fields (`_desc`, `_pointer`, + # etc.). If these internals change in a future CuTeDSL upgrade, callers + # should catch AttributeError and fall back to the regular launch path. 
+ device_ptr = int(device_ptr) + ptr._pointer = device_ptr # type: ignore[attr-defined] + if getattr(ptr, "_c_pointer", None) is None: + ptr.__c_pointers__() # type: ignore[attr-defined] + ptr._desc.value = device_ptr # type: ignore[attr-defined] + + +class _PtrRmsnormFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_x: object, + ptr_w: Optional[object], + ptr_out: object, + arg_m: _StableI32Arg, + arg_n: _StableI32Arg, + arg_ld: _StableI32Arg, + arg_eps: _StableF32Arg, + stream: cuda.CUstream, + packed_args: object, + keepalive: tuple[object, ...], + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_x = ptr_x + self._ptr_w = ptr_w + self._ptr_out = ptr_out + self._arg_m = arg_m + self._arg_n = arg_n + self._arg_ld = arg_ld + self._arg_eps = arg_eps + self._stream = stream + self._packed_args = packed_args + self._keepalive = keepalive + + self._use_fast_launch = True + + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_x_ptr = -1 + self._last_w_ptr = -1 + self._last_out_ptr = -1 + self._last_m = -1 + self._last_ld = -1 + self._last_eps = float("nan") + + def launch( + self, + *, + x: Tensor, + weight: Optional[Tensor], + out: Tensor, + M: int, + N: int, + ld: int, + eps: float, + ) -> None: + if not _fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps) + return + + x_ptr = x.data_ptr() + if x_ptr != self._last_x_ptr: + try: + _set_runtime_ptr(self._ptr_x, x_ptr) + self._last_x_ptr = x_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps) + return + + if self._ptr_w is not None: + w_ptr = weight.data_ptr() # type: ignore[union-attr] + if w_ptr != self._last_w_ptr: + try: + _set_runtime_ptr(self._ptr_w, w_ptr) + self._last_w_ptr = w_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps) + return + + out_ptr = out.data_ptr() + if out_ptr != self._last_out_ptr: + try: + _set_runtime_ptr(self._ptr_out, out_ptr) + self._last_out_ptr = out_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld != self._last_ld: + self._arg_ld.set(ld) + self._last_ld = ld + if eps != self._last_eps: + self._arg_eps.set(eps) + self._last_eps = eps + + # Clear the error slot before launch (mirrors JitExecutor behavior). + if self._cuda_result is not None: + self._cuda_result.value = 0 + + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + global _FAST_LAUNCH_SUPPORTED + self._use_fast_launch = False + _FAST_LAUNCH_SUPPORTED = False + + def _fallback_launch( + self, + *, + x: Tensor, + weight: Optional[Tensor], + out: Tensor, + M: int, + N: int, + ld: int, + eps: float, + ) -> None: + # If the packed-args or runtime pointer mutation path stops working + # (e.g. due to a CuTeDSL upgrade), fall back to the regular call path. 
+ dtype = TORCH2CUTE_DTYPE[x.dtype] + ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_w = ( + rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + if weight is not None + else None + ) + self._compiled( + ptr_x, + ptr_w, + None, # ptr_b + None, # ptr_res + ptr_out, + None, # ptr_res_out + None, # ptr_rstd + Int32(M), + Int32(N), + Int32(ld), + self._stream, + Float32(eps), + ) + + +class _PtrFusedAddRmsnormFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_x: object, + ptr_w: object, + ptr_res: object, + arg_m: _StableI32Arg, + arg_n: _StableI32Arg, + arg_ld_x: _StableI32Arg, + arg_eps: _StableF32Arg, + stream: cuda.CUstream, + assumed_align: int, + packed_args: object, + keepalive: tuple[object, ...], + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_x = ptr_x + self._ptr_w = ptr_w + self._ptr_res = ptr_res + self._arg_m = arg_m + self._arg_n = arg_n + self._arg_ld_x = arg_ld_x + self._arg_eps = arg_eps + self._stream = stream + self._assumed_align = int(assumed_align) + self._packed_args = packed_args + self._keepalive = keepalive + + self._use_fast_launch = True + + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_x_ptr = -1 + self._last_w_ptr = -1 + self._last_res_ptr = -1 + self._last_m = -1 + self._last_ld_x = -1 + self._last_eps = float("nan") + + def launch( + self, + *, + x: Tensor, + weight: Tensor, + residual: Tensor, + M: int, + N: int, + ld_x: int, + eps: float, + ) -> None: + if not _fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch( + x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps + ) + return + + x_ptr = x.data_ptr() + if x_ptr != self._last_x_ptr: + try: + _set_runtime_ptr(self._ptr_x, x_ptr) + self._last_x_ptr = x_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps + ) + return + + w_ptr = weight.data_ptr() + if w_ptr != self._last_w_ptr: + try: + _set_runtime_ptr(self._ptr_w, w_ptr) + self._last_w_ptr = w_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps + ) + return + + res_ptr = residual.data_ptr() + if res_ptr != self._last_res_ptr: + try: + _set_runtime_ptr(self._ptr_res, res_ptr) + self._last_res_ptr = res_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, weight=weight, residual=residual, M=M, N=N, ld_x=ld_x, eps=eps + ) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld_x != self._last_ld_x: + self._arg_ld_x.set(ld_x) + self._last_ld_x = ld_x + if eps != self._last_eps: + self._arg_eps.set(eps) + self._last_eps = eps + + if self._cuda_result is not None: + self._cuda_result.value = 0 + + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + global _FAST_LAUNCH_SUPPORTED + self._use_fast_launch = False + _FAST_LAUNCH_SUPPORTED = False + + 
def _fallback_launch( + self, + *, + x: Tensor, + weight: Tensor, + residual: Tensor, + M: int, + N: int, + ld_x: int, + eps: float, + ) -> None: + dtype = TORCH2CUTE_DTYPE[x.dtype] + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + ptr_res = rt.make_ptr( + dtype, + residual.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + ptr_w = rt.make_ptr( + dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + self._compiled( + ptr_x, + ptr_w, + ptr_res, + Int32(M), + Int32(N), + Int32(ld_x), + self._stream, + Float32(eps), + ) + + +def _get_fast_ptr_rmsnorm_launcher( + *, + compiled: object, + dtype: type[cutlass.Numeric], + N: int, + device_index: int, + stream_handle: int, + has_weight: bool, + eps: float, +) -> Optional[_PtrRmsnormFastLaunch]: + if not _fast_launch_enabled(): + return None + # Keyed by the compiled object identity so schedule changes (e.g. copy width, + # async/staged variants, etc.) never alias in the fast-launch cache. + key = ("ptr_fast", id(compiled), N, dtype, device_index, int(stream_handle), has_weight) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + # Create stable runtime args and pointer descriptors once. + ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_out = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_w = ( + rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) if has_weight else None + ) + + arg_m = _StableI32Arg(0) + arg_n = _StableI32Arg(N) + arg_ld = _StableI32Arg(N) + arg_eps = _StableF32Arg(eps) + + stream = cuda.CUstream(int(stream_handle)) + + # Create an executor (loads the CUDA library once). + executor = compiled.to(device_index) # type: ignore[attr-defined] + + # Use generate_execution_args once to build the packed args array, and keep + # any adapted args alive for the lifetime of the cache entry. + try: + exe_args, adapted_args = executor.generate_execution_args( + ptr_x, + ptr_w, + None, # ptr_b + None, # ptr_res + ptr_out, + None, # ptr_res_out + None, # ptr_rstd + arg_m, + arg_n, + arg_ld, + stream, + arg_eps, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + global _FAST_LAUNCH_SUPPORTED + _FAST_LAUNCH_SUPPORTED = False + return None + + keepalive: tuple[object, ...] 
= ( + executor, + ptr_x, + ptr_w, + ptr_out, + arg_m, + arg_n, + arg_ld, + arg_eps, + stream, + *adapted_args, + ) + + launcher = _PtrRmsnormFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_x=ptr_x, + ptr_w=ptr_w, + ptr_out=ptr_out, + arg_m=arg_m, + arg_n=arg_n, + arg_ld=arg_ld, + arg_eps=arg_eps, + stream=stream, + packed_args=packed_args, + keepalive=keepalive, + ) + cache[key] = launcher + return launcher + + +def _get_fast_ptr_fused_add_rmsnorm_launcher( + *, + compiled: object, + dtype: type[cutlass.Numeric], + N: int, + device_index: int, + stream_handle: int, + copy_bits: int, + use_async: bool, + tpr: int, + direct_gmem: bool, + assumed_align: int, + eps: float, +) -> Optional[_PtrFusedAddRmsnormFastLaunch]: + if not _fast_launch_enabled(): + return None + key = ( + "ptr_fused_add_fast", + id(compiled), + N, + dtype, + device_index, + int(stream_handle), + int(copy_bits), + bool(use_async), + int(tpr), + bool(direct_gmem), + int(assumed_align), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) + ptr_res = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) + ptr_w = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) + + arg_m = _StableI32Arg(0) + arg_n = _StableI32Arg(N) + arg_ld_x = _StableI32Arg(N) + arg_eps = _StableF32Arg(eps) + + stream = cuda.CUstream(int(stream_handle)) + + executor = compiled.to(device_index) # type: ignore[attr-defined] + + try: + exe_args, adapted_args = executor.generate_execution_args( + ptr_x, + ptr_w, + ptr_res, + arg_m, + arg_n, + arg_ld_x, + stream, + arg_eps, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + global _FAST_LAUNCH_SUPPORTED + _FAST_LAUNCH_SUPPORTED = False + return None + + keepalive: tuple[object, ...] = ( + executor, + ptr_x, + ptr_w, + ptr_res, + arg_m, + arg_n, + arg_ld_x, + arg_eps, + stream, + *adapted_args, + ) + + launcher = _PtrFusedAddRmsnormFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_x=ptr_x, + ptr_w=ptr_w, + ptr_res=ptr_res, + arg_m=arg_m, + arg_n=arg_n, + arg_ld_x=arg_ld_x, + arg_eps=arg_eps, + stream=stream, + assumed_align=assumed_align, + packed_args=packed_args, + keepalive=keepalive, + ) + cache[key] = launcher + return launcher + + +# Local helpers for reduction, dtype mapping, and coordinate/predicate utilities. +# +# NOTE: Avoid `from . import ...` imports here: CuTeDSL's AST preprocessor may +# mishandle that form (module=None in the AST). Use fully-qualified imports. 
+from kernelagent_oink.blackwell import lite_quack as qutils +from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce + + +# ------------------------- +# Copy helpers (allow up to 256b) +# ------------------------- + +@cute.jit +def get_copy_atom_bw( + dtype: type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False +) -> cute.CopyAtom: + # cp.async (SIMT) supports up to 128b per op; use 256b for sync when possible + max_bits = const_expr(128 if is_async else 256) + num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width)) + from cutlass.cute.nvgpu import cpasync + # Prefer GLOBAL cache policy for bulk streaming reads at large M + copy_op = ( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL) + if is_async + else cute.nvgpu.CopyUniversalOp() + ) + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@cute.jit +def copy_tiled( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, +) -> None: + atom = get_copy_atom_bw(src.element_type, num_copy_elems, is_async) + cute.copy(atom, src, dst, pred=pred) + + +# ------------------------- +# RMSNorm Kernel (SM100) +# ------------------------- + + +class RMSNormSM100: + def __init__( + self, + N: int, + dtype: type[cutlass.Numeric], + stage: Optional[int] = None, + *, + copy_bits: int = 128, + use_async: bool = True, + direct_gmem: bool = False, + ): + self.N = N + self.dtype = dtype + # Match Quack default for RMSNorm: stage = 1 unless explicitly overridden + self.stage = 1 if stage is None else stage + self.reduction_dtype = cutlass.Float32 + self.copy_bits = int(copy_bits) + self.use_async = bool(use_async) + self.direct_gmem = bool(direct_gmem) + + def _threads_per_row(self) -> int: + try: + return self._tpr_override # type: ignore[attr-defined] + except Exception: + pass + # Tune mid-size buckets for large-M rows. + N = self.N + # DSv3 MLA (padded/strided) hot shape. Prefer a threads-per-row that + # makes the tile width exactly match N with 128b vectors (bf16/fp16), + # avoiding the ~33% padded work from rounding 1536 -> 2048. + if N == 1536 and self.dtype.width == 16: + return 96 + # DSv3 default hidden size (7168). Choose a threads-per-row that matches + # the selected vector width to avoid padded work: + # - 128b copies (vec=8 for bf16/fp16): 7168/8 = 896 = 7 * 128 -> tpr=128 + # - 256b copies (vec=16 for bf16/fp16): 7168/16 = 448 = 2 * 224 -> tpr=224 + # + # The fused direct-GMEM path often uses 256b copies on 32B-aligned + # tensors, while the non-fused path defaults to 128b copies. + if N == 7168 and self.dtype.width == 16: + return 224 if self.copy_bits >= 256 else 128 + # For small-N, use at least one full warp per row. The kernel + # implementation assumes one row per CTA; returning <32 here can + # produce multi-row tiles (cols_per_block > 1) which is not supported. 
+ if N <= 1024: + return 32 + elif N <= 4096: + return 128 + elif N <= 8192: + # Allow an override (used by 2-rows/CTA path for N≈6k/8k) + try: + return self._tpr_override # type: ignore[attr-defined] + except Exception: + return 128 + elif N <= 16384: + return 256 + else: + return 256 + + def _cluster_n(self) -> int: + try: + return self._cluster_n_override # type: ignore[attr-defined] + except Exception: + pass + N = self.N + # Default policy + if N <= 8192: + return 1 + if const_expr(self.dtype.width == 16): + if N <= 16 * 1024: + return 2 + elif N <= 32 * 1024: + return 2 + elif N <= 64 * 1024: + return 4 + elif N <= 128 * 1024: + return 8 + else: + return 16 + else: + if N <= 32 * 1024: + return 1 + elif N <= 64 * 1024: + return 2 + elif N <= 128 * 1024: + return 4 + elif N <= 256 * 1024: + return 8 + else: + return 16 + + def _num_threads(self) -> int: + # Favor 128 threads up to N=16k to reduce per-row partitioning overhead. + # This keeps cols_per_block=1 at N=8192 (bf16), which benchmarks faster for large-M. + try: + return self._nt_override # type: ignore[attr-defined] + except Exception: + if self.N == 1536 and self.dtype.width == 16: + return 96 + if self.N == 7168 and self.dtype.width == 16: + return 224 if self.copy_bits >= 256 else 128 + if self.N <= 1024: + return 32 + return 128 if self.N <= 16384 else 256 + + def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout]: + vecsize = num_copy_bits // self.dtype.width + num_threads = self._num_threads() + assert num_threads % cute.arch.WARP_SIZE == 0 + tpr = self._threads_per_row() + cluster_n = self._cluster_n() + # Allow tails: compute number of vector columns with ceil + num_cols_vec = cute.ceil_div(self.N, vecsize) + num_blocks_N = cute.ceil_div(num_cols_vec, tpr * cluster_n) + cols_per_block = num_threads // tpr + tiler_mn = (cols_per_block, vecsize * num_blocks_N * tpr) + tv_layout = cute.make_layout( + ((tpr, cols_per_block), (vecsize, num_blocks_N)), + stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * tpr)), + ) + return tiler_mn, tv_layout + + def _smem_bytes(self, tiler_mn, num_warps) -> int: + # smem for X tile (+ residual if present) + reduction buffers + mbar(s) + return ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) + + self.stage * num_warps * self._cluster_n() * (self.reduction_dtype.width // 8) + + self.stage * (cutlass.Int64.width // 8) + ) + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ): + # Make last dim static (N) + semistatic_shape = (*mX.shape[:-1], self.N) + + def new_stride(t): + return ( + cute.assume(t.stride[0], divby=256 // t.element_type.width), + t.stride[1], + ) + + mX, mRes, mO, mResO = [ + cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))) + if const_expr(t is not None) + else None + for t in (mX, mRes, mO, mResO) + ] + assert mX.element_type == self.dtype + assert mO.element_type == self.dtype + + copy_bits = int(self.copy_bits) + tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits) + num_threads = cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._num_threads() + num_warps = num_threads // cute.arch.WARP_SIZE + threads_per_row = tv_layout.shape[0][0] if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._threads_per_row() + warps_per_row = 
max(threads_per_row // cute.arch.WARP_SIZE, 1) + cluster_n = self._cluster_n() + + if const_expr(mW is not None): + mW = cute.make_tensor( + mW.iterator, cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))) + ) + if const_expr(mB is not None): + mB = cute.make_tensor( + mB.iterator, cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))) + ) + if const_expr(mRstd is not None): + mRstd = cute.make_tensor( + mRstd.iterator, cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))) + ) + + # No SMEM reload mode switch; overlap is controlled in the K-loop path + + # Compute smem usage considering staged buffers. + # + # In direct-gmem mode, we skip the gmem->smem tiles entirely and only + # keep the reduction buffers in shared memory. + stage_bufs = 2 if self.stage > 1 else 1 + tile_bytes_x = ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs + if const_expr(not self.direct_gmem) + else 0 + ) + tile_bytes_res = ( + cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) * stage_bufs + if const_expr(mRes is not None and not self.direct_gmem) + else 0 + ) + red_bytes = self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8) + # mbarriers are only allocated/used for cluster_n>1. Some CuTeDSL builds + # require mbarrier state to be 16B-aligned in shared memory; account for + # the alignment padding when computing dynamic smem bytes. + smem_bytes = tile_bytes_x + tile_bytes_res + red_bytes + if cluster_n > 1: + # Align up to 16B before placing the mbarrier array. + smem_bytes = ((smem_bytes + 15) // 16) * 16 + smem_bytes += self.stage * (cutlass.Int64.width // 8) + + kernel = ( + self.kernel( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + const_expr(cluster_n), + const_expr(num_warps), + const_expr(warps_per_row), + const_expr(threads_per_row), + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + ) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1], + block=[num_threads, 1, 1], + cluster=([1, cluster_n, 1] if cluster_n > 1 else None), + smem=smem_bytes, + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_w: Optional[cute.Pointer], + ptr_b: Optional[cute.Pointer], + ptr_res: Optional[cute.Pointer], + ptr_out: cute.Pointer, + ptr_res_out: Optional[cute.Pointer], + ptr_rstd: Optional[cute.Pointer], + M: Int32, + N_dyn: Int32, + ld: Int32, + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ): + """Pointer-based entrypoint to reuse the existing RMSNorm schedule. + + This reconstructs cute.Tensor views from raw pointers plus sizes, + avoiding any DLPack conversions at the Python boundary. + """ + # Use a dynamic N for the leading-dimension stride so that the + # subsequent cute.assume(...) in __call__ sees a dynamic expression + # rather than a plain Python int. + # The compile-time N for the kernel (self.N) is still used to + # specialize the schedule. + # Assume row-major [M, N] with an arbitrary leading-dimension stride + # (common for padded-row / packed-attention layouts). 
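+        # For example (illustrative numbers only): a DSv3 MLA activation with
+        # N=1536 stored in a row buffer padded to ld=2048 is viewed here as an
+        # [M, N] tensor with stride (2048, 1), and that stride is preserved in
+        # the output view as well.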
+ layout_mn = cute.make_layout((M, N_dyn), stride=(ld, 1)) + layout_n = cute.make_layout((N_dyn,), stride=(1,)) + layout_m = cute.make_layout((M,), stride=(1,)) + + mX = cute.make_tensor(ptr_x, layout_mn) + mO = cute.make_tensor(ptr_out, layout_mn) + + mRes = ( + cute.make_tensor(ptr_res, layout_mn) + if const_expr(ptr_res is not None) + else None + ) + mResO = ( + cute.make_tensor(ptr_res_out, layout_mn) + if const_expr(ptr_res_out is not None) + else None + ) + mW = ( + cute.make_tensor(ptr_w, layout_n) + if const_expr(ptr_w is not None) + else None + ) + mB = ( + cute.make_tensor(ptr_b, layout_n) + if const_expr(ptr_b is not None) + else None + ) + mRstd = ( + cute.make_tensor(ptr_rstd, layout_m) + if const_expr(ptr_rstd is not None) + else None + ) + + # Reuse the main JIT entry to launch the scheduled kernel. + self.__call__(mX, mW, mB, mRes, mO, mResO, mRstd, stream, eps) + + @cute.jit + def launch_from_ptrs_fused_add_inplace( + self, + ptr_x: cute.Pointer, + ptr_w: cute.Pointer, + ptr_res: cute.Pointer, + M: Int32, + N_dyn: Int32, + ld_x: Int32, + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ): + """Pointer-based entrypoint for vLLM-style fused_add_rms_norm semantics. + + This specialized entrypoint supports: + - `x` / output with an arbitrary leading-dimension stride (`ld_x`), and + - `residual` / residual-out as a contiguous [M, N] tensor (ld_res = N). + + Both `x` and `residual` are updated in-place: + residual <- x + residual + x <- RMSNorm(residual) * weight + """ + layout_x = cute.make_layout((M, N_dyn), stride=(ld_x, 1)) + layout_res = cute.make_layout((M, N_dyn), stride=(N_dyn, 1)) + layout_n = cute.make_layout((N_dyn,), stride=(1,)) + + mX = cute.make_tensor(ptr_x, layout_x) + mO = cute.make_tensor(ptr_x, layout_x) + mRes = cute.make_tensor(ptr_res, layout_res) + mResO = cute.make_tensor(ptr_res, layout_res) + mW = cute.make_tensor(ptr_w, layout_n) + + self.__call__( + mX, + mW, + None, # bias + mRes, + mO, + mResO, + None, # rstd + stream, + eps, + ) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + cluster_n: cutlass.Constexpr[int], + num_warps: cutlass.Constexpr[int], + warps_per_row: cutlass.Constexpr[int], + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + if const_expr(cluster_n > 1): + cta_rank_in_cluster = cute.arch.block_idx_in_cluster() + else: + cta_rank_in_cluster = const_expr(0) + n_off = cta_rank_in_cluster * tiler_mn[1] + + smem = cutlass.utils.SmemAllocator() + # Allocate one or two SMEM buffers depending on stage depth + sX0 = ( + smem.allocate_tensor( + mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32 + ) + if const_expr(not self.direct_gmem) + else None + ) + sX1 = ( + smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + if const_expr(self.stage > 1 and not self.direct_gmem) + else None + ) + sRes0 = ( + smem.allocate_tensor( + mRes.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32 + ) + if const_expr(mRes is not None and not self.direct_gmem) + else None + ) + sRes1 = ( + smem.allocate_tensor( + mRes.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + if 
const_expr(mRes is not None and self.stage > 1 and not self.direct_gmem) + else None + ) + + # Reduction buffers + mbar for cluster reduce (reused by row_reduce helper) + red_layout = cute.make_ordered_layout( + (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), + order=(1, 0, 2), + ) + reduction_buffer = smem.allocate_tensor(self.reduction_dtype, red_layout, byte_alignment=4) + if const_expr(cluster_n > 1): + # Some CuTeDSL builds appear sensitive to the shared-memory alignment of + # mbarrier state. `SmemAllocator.allocate_array` does not currently + # expose an alignment parameter, so allocate an Int64 tensor with an + # explicit alignment and pass its iterator as the pointer. + mbar_tensor = smem.allocate_tensor( + cutlass.Int64, + cute.make_layout((self.stage,), stride=(1,)), + byte_alignment=16, + ) + mbar_ptr = mbar_tensor.iterator + else: + mbar_ptr = None + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + limit_k = shape[1] - n_off + + # Tiled copy setup + num_copy_elems_X = tv_layout.shape[1][0] + use_async = const_expr(self.use_async and self.N >= 1024 and not self.direct_gmem) + copy_atom = get_copy_atom_bw(mX.element_type, num_copy_elems_X, is_async=use_async) + thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx) + + # Tail predicate for the N dimension (when tile width > N). Reuse this + # for W/B loads so we never read past the end of those 1D tensors. + is_even_N_wb = const_expr(shape[1] == tiler_mn[1] * cluster_n) + if const_expr(not is_even_N_wb): + cX0 = cute.local_tile(idX, tiler_mn, (0, 0)) + tXp_wb = qutils.predicate_k(thr_copy.partition_S(cX0), limit=limit_k) + else: + tXp_wb = None + + # Weight/bias loads: + # + # - Direct-GMEM schedule: load weight/bias up front to hide latency. + # - Staged SMEM schedule: loading after the reduction reduces register + # pressure during the long-scoreboard reduction phase (better for large-M), + # but it measurably hurts small-M latency for the non-fused (no residual, + # no bias) case. For that specific case, prefetch weight up front as well. 
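+ #
+ # Concretely, `prefetch_w_early` below resolves to:
+ #   direct-gmem schedule, weight present           -> prefetch weight early
+ #   staged schedule, no residual and no bias       -> prefetch weight early
+ #   staged schedule, residual and/or bias present  -> load after the reduction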
+ tXrW = None + tXrB = None + prefetch_w_early = bool( + mW is not None and (self.direct_gmem or (mRes is None and mB is None)) + ) + if const_expr(prefetch_w_early): + gW = cute.local_tile(qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0)) + tXgW = thr_copy.partition_S(gW) + tXrW = cute.make_fragment_like(tXgW) + if const_expr(not is_even_N_wb): + tXrW.fill(0) + cute.copy( + get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False), + tXgW, + tXrW, + pred=tXp_wb, + ) + if const_expr(self.direct_gmem and mB is not None): + gB = cute.local_tile(qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0)) + tXgB = thr_copy.partition_S(gB) + tXrB = cute.make_fragment_like(tXgB) + if const_expr(not is_even_N_wb): + tXrB.fill(0) + cute.copy( + get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False), + tXgB, + tXrB, + pred=tXp_wb, + ) + + # Non-persistent per-CTA execution (one tile in M) + self._init_cluster(tidx, mbar_ptr) + + mX_i, mRes_i, mO_i, mResO_i = [ + qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) if t is not None else None + for t in (mX, mRes, mO, mResO) + ] + mX_i, mRes_i, mO_i, mResO_i = [ + qutils.domain_offset_i64((0, n_off), t) if t is not None else None + for t in (mX_i, mRes_i, mO_i, mResO_i) + ] + gX_i = cute.local_tile(mX_i, tiler_mn, (0, 0)) + gO_i = cute.local_tile(mO_i, tiler_mn, (0, 0)) + gRes_i = ( + cute.local_tile(mRes_i, tiler_mn, (0, 0)) if const_expr(mRes is not None) else None + ) + gResO_i = ( + cute.local_tile(mResO_i, tiler_mn, (0, 0)) if const_expr(mResO is not None) else None + ) + gRstd_i = ( + cute.local_tile(mRstd, tiler_mn, (bidx, 0)) if const_expr(mRstd is not None) else None + ) + cX_i = cute.local_tile(idX, tiler_mn, (bidx, 0)) + + # Common identity/row index partitions reused by both default and K-loop paths + tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None] + row_i = tXcX_i[0][0] + tXgRstd_i = thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None + + # Stage-2 intra-row K-loop cp.async ping-pong (two tiles). This reduces + # per-thread fragment size and can improve memory-latency hiding for + # N=7168 at large M. It is enabled by setting `stage=2` when constructing + # the RMSNormSM100 op (see `_fused_add_rmsnorm_forward_ptr_inplace`). + if const_expr( + self.stage > 1 and not self.direct_gmem and use_async and cluster_n == 1 and shape[1] == 7168 + ): + vecsize = tv_layout.shape[1][0] + tpr = threads_per_row + target_tile_n = const_expr(4096) + tile_factor = const_expr(target_tile_n // (vecsize * tpr)) + if const_expr(tile_factor > 0): + tile_n = vecsize * tpr * tile_factor + num_tiles = cute.ceil_div(shape[1], tile_n) + + tiler_mn_tile = (tiler_mn[0], tile_n) + sX0_tile = cute.local_tile(sX0, tiler_mn_tile, (0, 0)) + sX1_tile = cute.local_tile(sX1, tiler_mn_tile, (0, 0)) + sRes0_tile = ( + cute.local_tile(sRes0, tiler_mn_tile, (0, 0)) + if const_expr(mRes is not None) + else None + ) + sRes1_tile = ( + cute.local_tile(sRes1, tiler_mn_tile, (0, 0)) + if const_expr(mRes is not None) + else None + ) + + tv_layout_tile = cute.make_layout( + ((tpr, tiler_mn[0]), (vecsize, tile_factor)), + stride=( + (vecsize * tiler_mn[0], 1), + (tiler_mn[0], tiler_mn[0] * vecsize * tpr), + ), + ) + thr_copy_tile = cute.make_tiled_copy(copy_atom, tv_layout_tile, tiler_mn_tile).get_slice( + tidx + ) + + # Accumulate per-thread partial sums across tiles; reduce once. + sum_sq_thread = cute.Float32(0.0) + + # Preload tile 0 into sX0/sRes0. 
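+ # Rough worked example (default bf16 schedule: copy_bits=128 so vecsize=8,
+ # threads_per_row=128, tile_factor=4): tile_n = 8 * 128 * 4 = 4096, so
+ # N=7168 is processed as two tiles covering columns [0, 4096) and
+ # [4096, 7168), with the unused 1024-column tail of the second tile
+ # masked off by the `predicate_k` masks.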
+ k_off0 = const_expr(0) * tile_n + gX_0 = cute.local_tile( + qutils.domain_offset_i64((0, k_off0), mX_i), tiler_mn_tile, (0, 0) + ) + tXgX_0 = thr_copy_tile.partition_S(gX_0) + tXsX_0 = thr_copy_tile.partition_D(sX0_tile) + cX_0 = cute.local_tile( + cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, 0) + ) + tXc_0 = thr_copy_tile.partition_S(cX_0) + tXp_0 = qutils.predicate_k(tXc_0, limit=limit_k) + + tXp_ping = tXp_0 + tXp_pong = tXp_0 + + if row_i < shape[0]: + copy_tiled(tXgX_0, tXsX_0, num_copy_elems=vecsize, is_async=True, pred=tXp_0) + if const_expr(mRes is not None): + gRes_0 = cute.local_tile( + qutils.domain_offset_i64((0, k_off0), mRes_i), + tiler_mn_tile, + (0, 0), + ) + tXgRes_0 = thr_copy_tile.partition_S(gRes_0) + tXsRes_0 = thr_copy_tile.partition_D(sRes0_tile) + copy_tiled( + tXgRes_0, + tXsRes_0, + num_copy_elems=vecsize, + is_async=True, + pred=tXp_0, + ) + cute.arch.cp_async_commit_group() + + for t in cutlass.range_constexpr(num_tiles): + next_t = t + 1 + if next_t < num_tiles: + k_off_n = next_t * tile_n + gX_n = cute.local_tile( + qutils.domain_offset_i64((0, k_off_n), mX_i), + tiler_mn_tile, + (0, 0), + ) + tXgX_n = thr_copy_tile.partition_S(gX_n) + cX_n = cute.local_tile( + cute.domain_offset((0, k_off_n), cX_i), + tiler_mn_tile, + (0, 0), + ) + tXc_n = thr_copy_tile.partition_S(cX_n) + tXp_n = qutils.predicate_k(tXc_n, limit=limit_k) + + if const_expr((t % 2) == 0): + tXsX_n = thr_copy_tile.partition_D(sX1_tile) + tXsRes_n = ( + thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + ) + tXp_pong = tXp_n + else: + tXsX_n = thr_copy_tile.partition_D(sX0_tile) + tXsRes_n = ( + thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + ) + tXp_ping = tXp_n + + if row_i < shape[0]: + copy_tiled( + tXgX_n, tXsX_n, num_copy_elems=vecsize, is_async=True, pred=tXp_n + ) + if const_expr(mRes is not None): + gRes_n = cute.local_tile( + qutils.domain_offset_i64((0, k_off_n), mRes_i), + tiler_mn_tile, + (0, 0), + ) + tXgRes_n = thr_copy_tile.partition_S(gRes_n) + copy_tiled( + tXgRes_n, + tXsRes_n, + num_copy_elems=vecsize, + is_async=True, + pred=tXp_n, + ) + cute.arch.cp_async_commit_group() + + cute.arch.cp_async_wait_group(1 if next_t < num_tiles else 0) + + # Current tile buffer (ping/pong). 
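+ # Timeline sketch for num_tiles == 2: at t=0 the wait above leaves at most
+ # one cp.async group in flight, so tile 0 is complete in sX0/sRes0 while
+ # tile 1 may still be streaming into sX1/sRes1; at t=1 wait_group(0) drains
+ # everything and the loop computes from sX1/sRes1.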
+ if const_expr((t % 2) == 0): + tXsX_cur = thr_copy_tile.partition_D(sX0_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + ) + pred_cur = tXp_ping + else: + tXsX_cur = thr_copy_tile.partition_D(sX1_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + ) + pred_cur = tXp_pong + + k_off = t * tile_n + gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, 0)) + tXgX_t = thr_copy_tile.partition_S(gX_t) + tXrX_t = cute.make_fragment_like(tXgX_t) + cute.autovec_copy(tXsX_cur, tXrX_t) + x_t = tXrX_t.load().to(cute.Float32) + if const_expr(mRes is not None): + gRes_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, 0) + ) + tXgRes_t = thr_copy_tile.partition_S(gRes_t) + tXrRes_t = cute.make_fragment_like(tXgRes_t) + cute.autovec_copy(tXsRes_cur, tXrRes_t) + x_t += tXrRes_t.load().to(cute.Float32) + + if const_expr(mResO is not None): + gResO_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mResO_i), + tiler_mn_tile, + (0, 0), + ) + tXgResO_t = thr_copy_tile.partition_D(gResO_t) + tXrResO_t = cute.make_fragment_like(tXgResO_t) + tXrResO_t.store(x_t.to(tXrResO_t.element_type)) + if row_i < shape[0]: + copy_tiled( + tXrResO_t, + tXgResO_t, + num_copy_elems=vecsize, + is_async=False, + pred=pred_cur, + ) + + sum_sq_thread = sum_sq_thread + (x_t * x_t).reduce( + cute.ReductionOp.ADD, + init_val=0.0, + reduction_profile=0, + ) + + sum_sq = row_reduce( + sum_sq_thread, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + init_val=0.0, + ) + rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True) + + if const_expr(mRstd is not None): + if tXcX_i[0][1] == 0 and row_i < shape[0]: + tXgRstd_i[0] = rstd + + for t in cutlass.range_constexpr(num_tiles): + k_off = t * tile_n + cX_t = cute.local_tile(cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, 0)) + tXc_t = thr_copy_tile.partition_S(cX_t) + tXp_t = qutils.predicate_k(tXc_t, limit=limit_k) + + if const_expr((t % 2) == 0): + tXsX_cur = thr_copy_tile.partition_D(sX0_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + ) + else: + tXsX_cur = thr_copy_tile.partition_D(sX1_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + ) + + gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, 0)) + tXgX_t = thr_copy_tile.partition_S(gX_t) + tXrX_t = cute.make_fragment_like(tXgX_t) + cute.autovec_copy(tXsX_cur, tXrX_t) + x_t = tXrX_t.load().to(cute.Float32) + if const_expr(mRes is not None): + gRes_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, 0) + ) + tXgRes_t = thr_copy_tile.partition_S(gRes_t) + tXrRes_t = cute.make_fragment_like(tXgRes_t) + cute.autovec_copy(tXsRes_cur, tXrRes_t) + x_t += tXrRes_t.load().to(cute.Float32) + + y_t = x_t * rstd + if const_expr(mW is not None): + gW_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mW), tiler_mn_tile, (0, 0) + ) + tWgW_t = thr_copy_tile.partition_S(gW_t) + tWrW_t = cute.make_fragment_like(tWgW_t) + copy_tiled(tWgW_t, tWrW_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + y_t = y_t * tWrW_t.load().to(cute.Float32) + if const_expr(mB is not None): + gB_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mB), tiler_mn_tile, (0, 0) + ) + tWgB_t = thr_copy_tile.partition_S(gB_t) + 
tWrB_t = cute.make_fragment_like(tWgB_t) + copy_tiled(tWgB_t, tWrB_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + y_t = y_t + tWrB_t.load().to(cute.Float32) + + gO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mO_i), tiler_mn_tile, (0, 0)) + tXgO_t = thr_copy_tile.partition_D(gO_t) + tXrO_t = cute.make_fragment_like(tXgO_t) + tXrO_t.store(y_t.to(tXrO_t.element_type)) + if row_i < shape[0]: + copy_tiled(tXrO_t, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + + return + + # Single-stage path: one-row-per-CTA + tXgX_i = thr_copy.partition_S(gX_i) + tXgRes_i = thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None + tXgO_i = thr_copy.partition_D(gO_i) + tXgResO_i = thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None + # tXgRstd_i / tXcX_i / row_i prepared above + is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n) + tXpX_i = ( + qutils.predicate_k(thr_copy.partition_S(cX_i), limit=limit_k) if not is_even_N_i else None + ) + + tXrX = cute.make_fragment_like(tXgX_i) + tXrRes = cute.make_fragment_like(tXgRes_i) if const_expr(mRes is not None) else None + if const_expr(self.direct_gmem): + if const_expr(not is_even_N_i): + tXrX.fill(0) + if const_expr(tXrRes is not None): + tXrRes.fill(0) + if row_i < shape[0]: + cute.copy(copy_atom, tXgX_i, tXrX, pred=tXpX_i) + if const_expr(tXrRes is not None): + cute.copy(copy_atom, tXgRes_i, tXrRes, pred=tXpX_i) + else: + # If N is not a multiple of the tile width, the predicated gmem->smem + # copies leave out-of-bounds lanes uninitialized. Clear the SMEM tile + # so masked lanes read as 0 for reduction/output. + if const_expr(not is_even_N_i): + thr_copy.partition_D(sX0).fill(0) + if const_expr(mRes is not None): + thr_copy.partition_D(sRes0).fill(0) + + if row_i < shape[0]: + cute.copy(copy_atom, tXgX_i, thr_copy.partition_D(sX0), pred=tXpX_i) + if const_expr(mRes is not None): + cute.copy(copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i) + if const_expr(use_async): + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + cute.autovec_copy(thr_copy.partition_D(sX0), tXrX) + if const_expr(tXrRes is not None): + cute.autovec_copy(thr_copy.partition_D(sRes0), tXrRes) + x_red = tXrX.load().to(cute.Float32) + if const_expr(tXrRes is not None): + x_red += tXrRes.load().to(cute.Float32) + + if const_expr(mResO is not None): + tXrResO = cute.make_fragment_like(tXgResO_i) + tXrResO.store(x_red.to(tXrResO.element_type)) + if row_i < shape[0]: + cute.copy( + get_copy_atom_bw(tXrResO.element_type, num_copy_elems_X, is_async=False), + tXrResO, + tXgResO_i, + pred=tXpX_i, + ) + + sum_sq = row_reduce( + x_red * x_red, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + init_val=0.0, + ) + rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True) + + if const_expr(mRstd is not None): + if ( + tXcX_i[0][1] == 0 + and row_i < shape[0] + and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + tXgRstd_i[0] = rstd + + if const_expr(not self.direct_gmem and (mRes is not None or mB is not None)): + # Load weight/bias after the reduction so they don't inflate register + # pressure during the long-scoreboard reduction phase (helping occupancy + # when registers are the limiting factor). 
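+ #
+ # Rough numbers (bf16, N=7168, 128 threads, one row per CTA, ignoring how
+ # the compiler actually packs fragments): each thread already holds a
+ # 56-element x fragment (~28 32-bit registers) plus a residual fragment of
+ # the same size, and the deferred weight fragment is another ~28 registers.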
+ if const_expr(mW is not None): + gW = cute.local_tile(qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0)) + tXgW = thr_copy.partition_S(gW) + tXrW = cute.make_fragment_like(tXgW) + if const_expr(not is_even_N_wb): + tXrW.fill(0) + cute.copy( + get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False), + tXgW, + tXrW, + pred=tXp_wb, + ) + if const_expr(mB is not None): + gB = cute.local_tile(qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0)) + tXgB = thr_copy.partition_S(gB) + tXrB = cute.make_fragment_like(tXgB) + if const_expr(not is_even_N_wb): + tXrB.fill(0) + cute.copy( + get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False), + tXgB, + tXrB, + pred=tXp_wb, + ) + + # Reuse `x_red` (x + residual, in fp32) for the output path so we don't + # keep both `tXrX` and `tXrRes` fragments live across the reduction. + y = x_red * rstd + if const_expr(mW is not None): + y = y * tXrW.load().to(cute.Float32) + if const_expr(mB is not None): + y = y + tXrB.load().to(cute.Float32) + + tXrO = cute.make_fragment_like(tXgO_i) + tXrO.store(y.to(tXrO.element_type)) + if row_i < shape[0]: + cute.copy( + get_copy_atom_bw(tXrO.element_type, num_copy_elems_X, is_async=False), + tXrO, + tXgO_i, + pred=tXpX_i, + ) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + cluster_n: cutlass.Constexpr[int], + num_warps: cutlass.Constexpr[int], + warps_per_row: cutlass.Constexpr[int], + threads_per_row: cutlass.Constexpr[int], + ): + self._kernel_impl( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + cluster_n, + num_warps, + warps_per_row, + threads_per_row, + ) + else: + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + ): + copy_bits = int(self.copy_bits) + tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits) + num_threads = self._num_threads() + num_warps = num_threads // cute.arch.WARP_SIZE + threads_per_row = self._threads_per_row() + warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1) + cluster_n = self._cluster_n() + self._kernel_impl( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + const_expr(cluster_n), + const_expr(num_warps), + const_expr(warps_per_row), + const_expr(threads_per_row), + ) + + @cute.jit + def _init_cluster(self, tidx: cutlass.Int32, mbar_ptr: Optional[cute.Pointer]): + if const_expr(mbar_ptr is not None): + if tidx < self.stage: + cute.arch.mbarrier_init(mbar_ptr + tidx, 1) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + + +def _can_use_ptr_path( + x: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + residual: Optional[Tensor], +) -> bool: + """Fast path precondition for the pointer-based CuTeDSL entry. + + We require a row-major 2D layout where the last dimension is + contiguous (stride(1) == 1). The leading dimension (stride(0)) + may be larger than N (padded-row / packed-attention layouts), + and is passed to the kernel as `ld`. 
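+ 
+ For example, a bf16 tensor of shape (M, 7168) with stride (7680, 1) and a
+ 16-byte-aligned data pointer qualifies (7680 is a multiple of the required
+ divisor 256 // 16 == 16), whereas any tensor with stride(1) != 1 does not.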
+ """ + if x.stride(1) != 1: + return False + # All participating tensors are interpreted as the same element type + # (derived from x.dtype) in the pointer-based path. If dtypes differ, + # we'd read the wrong bit patterns and silently produce incorrect output. + if residual is not None and residual.dtype != x.dtype: + return False + if weight is not None and weight.dtype != x.dtype: + return False + if bias is not None and bias.dtype != x.dtype: + return False + # The kernel assumes `ld` satisfies a divisibility constraint used by + # cute.assume(..., divby=...) for vectorization. + elem_bits = TORCH2CUTE_DTYPE[x.dtype].width + divby = 256 // elem_bits + if (x.stride(0) % divby) != 0: + return False + # The kernel uses 128-bit vectorized copies (16B). Require at least 16B + # alignment on all participating tensors to avoid misaligned global loads. + if (x.data_ptr() % 16) != 0: + return False + if residual is not None and residual.stride(1) != 1: + return False + if residual is not None and residual.stride(0) != x.stride(0): + return False + if residual is not None and (residual.data_ptr() % 16) != 0: + return False + if weight is not None and not weight.is_contiguous(): + return False + if bias is not None and not bias.is_contiguous(): + return False + if weight is not None and (weight.data_ptr() % 16) != 0: + return False + if bias is not None and (bias.data_ptr() % 16) != 0: + return False + return True + + +def _can_use_ptr_path_fused_add_inplace( + x: Tensor, + weight: Tensor, + residual: Tensor, +) -> bool: + """Fast-path precondition for fused_add_rmsnorm_forward_inplace. + + We allow the common vLLM layout where: + - `x` is strided/padded row-major (stride(1) == 1, stride(0) >= N) + - `residual` is contiguous row-major (stride(0) == N) + """ + if x.stride(1) != 1: + return False + if residual.dtype != x.dtype: + return False + if weight.dtype != x.dtype: + return False + if residual.stride(1) != 1: + return False + if not residual.is_contiguous(): + return False + if not weight.is_contiguous(): + return False + + dtype = TORCH2CUTE_DTYPE[x.dtype] + divby = 256 // dtype.width + if (x.stride(0) % divby) != 0: + return False + if (residual.stride(0) % divby) != 0: + return False + + if (x.data_ptr() % 16) != 0: + return False + if (residual.data_ptr() % 16) != 0: + return False + if (weight.data_ptr() % 16) != 0: + return False + return True + + +def _rmsnorm_forward_ptr( + x: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + residual: Optional[Tensor], + eps: float, + store_rstd: bool, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Pointer-based RMSNorm forward that bypasses DLPack entirely. + + This path reconstructs cute.Tensor views from raw device pointers + and explicit layouts inside the JIT graph, avoiding any runtime + DLPack conversions while reusing the tuned RMSNormSM100 schedule. + """ + assert x.is_cuda + assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand." + M, N = x.shape + + # Preserve the input's 2D stride so downstream users that rely on + # padded-row layouts (stride0 > N) continue to see the expected layout. 
+ out = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + residual_out: Optional[Tensor] = None + rstd: Optional[Tensor] = None + + if residual is not None: + residual_out = torch.empty_strided( + residual.shape, residual.stride(), device=residual.device, dtype=residual.dtype + ) + if store_rstd: + rstd = torch.empty(M, device=x.device, dtype=torch.float32) + + _rmsnorm_forward_ptr_into( + x=x, + weight=weight, + bias=bias, + residual=residual, + out=out, + residual_out=residual_out, + rstd=rstd, + eps=eps, + ) + return out, rstd, residual_out + + +def _rmsnorm_forward_ptr_into( + x: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + residual: Optional[Tensor], + out: Tensor, + residual_out: Optional[Tensor], + rstd: Optional[Tensor], + eps: float, +) -> None: + """Internal helper that launches the pointer-based kernel into preallocated outputs. + + This enables integration into frameworks like vLLM that manage their + own buffers and prefer in-place or out-parameter semantics. + """ + assert x.is_cuda + assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand." + M, N = x.shape + device_index = x.get_device() + dtype = TORCH2CUTE_DTYPE[x.dtype] + + if bias is None and residual is None and residual_out is None and rstd is None: + # Fast-launch path: cache packed args and update pointers/scalars in-place to + # avoid Python-side argument marshalling overhead that dominates small-batch cases. + # + # If fast-launch is disabled (or CuTeDSL internals changed), we fall back + # to calling the compiled function directly. + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + has_weight = weight is not None + + stage = 1 + compiled_key = ( + "ptr", + N, + dtype, + False, # residual + has_weight, + False, # bias + False, # residual_out + False, # rstd + stage, + device_index, + ) + compiled = _PTR_COMPILE_CACHE.get(compiled_key) + if compiled is None: + op = RMSNormSM100(N, dtype, stage=stage) + ld_val = int(x.stride(0)) + ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_w = ( + rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + if has_weight + else None + ) + stream = cuda.CUstream(stream_handle) + ld = Int32(ld_val) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_w, + None, # ptr_b + None, # ptr_res + ptr_out, + None, # ptr_res_out + None, # ptr_rstd + Int32(M), + Int32(N), + ld, + stream, + Float32(eps), + ) + _PTR_COMPILE_CACHE[compiled_key] = compiled + + launcher = _get_fast_ptr_rmsnorm_launcher( + compiled=compiled, + dtype=dtype, + N=N, + device_index=device_index, + stream_handle=stream_handle, + has_weight=has_weight, + eps=eps, + ) + ld_val = int(x.stride(0)) + if launcher is not None: + launcher.launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld_val, eps=eps) + return + + ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_w = ( + rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + if has_weight + else None + ) + stream = cuda.CUstream(stream_handle) + ld = Int32(ld_val) + compiled( + ptr_x, + ptr_w, + None, # ptr_b + None, # ptr_res + ptr_out, + None, # 
ptr_res_out + None, # ptr_rstd + Int32(M), + Int32(N), + ld, + stream, + Float32(eps), + ) + return + + # Fallback: general path (supports bias/residual/rstd, but is slower to launch). + stage = 1 + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + key = ( + "ptr", + N, + dtype, + residual is not None, + weight is not None, + bias is not None, + residual_out is not None, + rstd is not None, + stage, + device_index, + ) + compiled = _PTR_COMPILE_CACHE.get(key) + if compiled is None: + op = RMSNormSM100(N, dtype, stage=stage) + ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_res = ( + rt.make_ptr(dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + if residual is not None + else None + ) + ptr_res_out = ( + rt.make_ptr( + dtype, residual_out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + if residual_out is not None + else None + ) + ptr_w = ( + rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + if weight is not None + else None + ) + ptr_b = ( + rt.make_ptr(dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + if bias is not None + else None + ) + ptr_rstd = ( + rt.make_ptr( + cutlass.Float32, rstd.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=4 + ) + if rstd is not None + else None + ) + stream = cuda.CUstream(stream_handle) + ld = Int32(int(x.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_w, + ptr_b, + ptr_res, + ptr_out, + ptr_res_out, + ptr_rstd, + Int32(M), + Int32(N), + ld, + stream, + Float32(eps), + ) + _PTR_COMPILE_CACHE[key] = compiled + ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_res = ( + rt.make_ptr(dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + if residual is not None + else None + ) + ptr_res_out = ( + rt.make_ptr(dtype, residual_out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + if residual_out is not None + else None + ) + ptr_w = ( + rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + if weight is not None + else None + ) + ptr_b = ( + rt.make_ptr(dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + if bias is not None + else None + ) + ptr_rstd = ( + rt.make_ptr(cutlass.Float32, rstd.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=4) + if rstd is not None + else None + ) + stream = cuda.CUstream(stream_handle) + ld = Int32(int(x.stride(0))) + compiled( + ptr_x, + ptr_w, + ptr_b, + ptr_res, + ptr_out, + ptr_res_out, + ptr_rstd, + Int32(M), + Int32(N), + ld, + stream, + Float32(eps), + ) + + +def _fused_add_rmsnorm_forward_ptr_inplace( + x: Tensor, + residual: Tensor, + weight: Tensor, + eps: float, +) -> None: + """Pointer-based fused_add_rmsnorm that updates `x` and `residual` in-place.""" + assert x.is_cuda + assert x.dim() == 2 + assert residual.is_cuda + assert residual.dim() == 2 + assert x.shape == residual.shape + + M, N = x.shape + device_index = x.get_device() + dtype = TORCH2CUTE_DTYPE[x.dtype] + stage = 1 + + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + 
stream_handle = int(torch.cuda.current_stream().cuda_stream) + + # Latency-optimized schedule for small-M cases: avoid the gmem->smem + # staging path (large dynamic smem + extra barriers) and load directly + # from gmem into registers. + copy_bits = 128 + # Use a direct-GMEM schedule (no staging SMEM tiles) for DSv3 hidden size + # (7168, bf16/fp16). This improves both: + # - small-M latency (fewer barriers + less dynamic shared memory), and + # - large-M bandwidth (lower overhead, better vectorization when 32B-aligned). + # + # This is a policy decision: it is tuned for DSv3's N=7168. If you want to + # benchmark other models/shapes, you can override it with: + # - OINK_RMSNORM_DIRECT_GMEM=0 (force staging/cp.async path) + # - OINK_RMSNORM_DIRECT_GMEM=1 (force direct-gmem path) + direct_gmem = _direct_gmem_from_policy(default=bool(dtype.width == 16 and N == 7168)) + use_async = not direct_gmem + tpr_override: Optional[int] = None + nt_override: Optional[int] = None + cluster_n_override: Optional[int] = None + direct_gmem_max_copy_bits: Optional[int] = None + + # Experimental stage-2 cp.async path (2-tile ping-pong) for N=7168. This is + # primarily about improving memory-latency hiding / reducing long-scoreboard + # stalls for large-M workloads. + if _ENABLE_STAGE2 and dtype.width == 16 and N == 7168 and M >= 4096: + stage = 2 + direct_gmem = False + use_async = True + + # Experimental ILP variant (clusters): split each row across 2 CTAs. + # + # NOTE: This is currently opt-in because some CuTeDSL builds exhibit + # instability with cluster launches for this specific schedule. To reduce + # the chance of accidental crashes, we require an additional explicit + # opt-in via `OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE=1`. + if _ENABLE_CLUSTER_ILP and not _ENABLE_STAGE2: + if dtype.width == 16 and N == 7168 and M >= 4096: + cluster_n_override = 2 + if direct_gmem: + # Cluster launches + direct-GMEM has exhibited reproducible compiler + # instability (segfaults) in some CuTeDSL builds, especially for the + # 256b vector path. Probe it out-of-process once so we can safely + # select a working copy width (or fall back to the staged SMEM path) + # instead of crashing the parent process. + max_bits = _probe_cluster_direct_gmem_max_copy_bits() + if max_bits == 0: + direct_gmem = False + use_async = True + else: + direct_gmem_max_copy_bits = max_bits + + # Experimental per-row partitioning: use 256 threads/row for N=7168 to + # increase concurrency/ILP (accepts a small tail-predicate region). 
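+ #
+ # Illustration of the tail (bf16, 128-bit copies so vecsize=8, cluster_n=1):
+ # the per-row tile width is vecsize * ceil_div(896, tpr) * tpr, where
+ # 896 = 7168 / 8 vector columns. That is exactly 7168 columns at tpr=128
+ # but 8192 columns at tpr=256, so the tpr=256 variant predicates off a
+ # 1024-column tail on every row.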
+ if _ENABLE_TPR256 and cluster_n_override is None and not _ENABLE_STAGE2: + if dtype.width == 16 and N == 7168 and M >= 4096: + tpr_override = 256 + nt_override = 256 + + + can_use_256 = bool( + direct_gmem + and ( + direct_gmem_max_copy_bits is None + or direct_gmem_max_copy_bits >= 256 + ) + and dtype.width == 16 + and (x.data_ptr() % 32) == 0 + and (residual.data_ptr() % 32) == 0 + and (weight.data_ptr() % 32) == 0 + ) + assumed_align = 32 if can_use_256 else 16 + if can_use_256: + copy_bits = 256 + + copy_bits = _copy_bits_from_policy(default=copy_bits, can_use_256=can_use_256) + if copy_bits == 128: + assumed_align = 16 + elif copy_bits == 256 and can_use_256: + assumed_align = 32 + else: + copy_bits = 128 + assumed_align = 16 + + key = ( + "ptr_fused_add_inplace", + N, + dtype, + stage, + device_index, + copy_bits, + use_async, + tpr_override, + nt_override, + direct_gmem, + cluster_n_override, + ) + compiled = _PTR_COMPILE_CACHE.get(key) + if compiled is None: + op = RMSNormSM100( + N, + dtype, + stage=stage, + copy_bits=copy_bits, + use_async=use_async, + direct_gmem=direct_gmem, + ) + if tpr_override is not None: + op._tpr_override = tpr_override # type: ignore[attr-defined] + if nt_override is not None: + op._nt_override = nt_override # type: ignore[attr-defined] + if cluster_n_override is not None: + op._cluster_n_override = cluster_n_override # type: ignore[attr-defined] + ptr_x = rt.make_ptr( + dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_res = rt.make_ptr( + dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_w = rt.make_ptr( + dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + stream = cuda.CUstream(stream_handle) + ld_x = Int32(int(x.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs_fused_add_inplace, + ptr_x, + ptr_w, + ptr_res, + Int32(M), + Int32(N), + ld_x, + stream, + Float32(eps), + ) + _PTR_COMPILE_CACHE[key] = compiled + launcher = _get_fast_ptr_fused_add_rmsnorm_launcher( + compiled=compiled, + dtype=dtype, + N=N, + device_index=device_index, + stream_handle=stream_handle, + copy_bits=copy_bits, + use_async=use_async, + tpr=tpr_override or 0, + direct_gmem=direct_gmem, + assumed_align=assumed_align, + eps=eps, + ) + if launcher is not None: + launcher.launch( + x=x, + weight=weight, + residual=residual, + M=M, + N=N, + ld_x=int(x.stride(0)), + eps=eps, + ) + return + + # Fast-launch is disabled/unavailable (or CuTeDSL internals changed). Fall back + # to calling the compiled function directly. + ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) + ptr_res = rt.make_ptr( + dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_w = rt.make_ptr( + dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + stream = cuda.CUstream(stream_handle) + ld_x = Int32(int(x.stride(0))) + compiled(ptr_x, ptr_w, ptr_res, Int32(M), Int32(N), ld_x, stream, Float32(eps)) + + +# ------------------------- +# Public API (forward + verify) +# ------------------------- + + +def rmsnorm_forward( + x: Tensor, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + residual: Optional[Tensor] = None, + eps: float = 1e-6, + store_rstd: bool = False, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + assert x.is_cuda + assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand." 
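+ # Callers with [batch, seq, hidden] activations are expected to flatten to
+ # 2D first (e.g. hidden_states.view(-1, hidden_states.shape[-1])) and view
+ # the result back afterwards; this function only ever sees an (M, N) view.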
+ M, N = x.shape + dtype = TORCH2CUTE_DTYPE[x.dtype] + + # For DSv3 big-M outliers on SM100, keep using the dedicated + # stage-2 K-loop implementation, which is already tuned and + # parity-checked against the reference. + use_stage2_big_dsv3 = bool( + M >= 65536 and N in (6144, 8192) and x.dtype in (torch.float16, torch.bfloat16) + ) + if use_stage2_big_dsv3: + try: + import rmsnorm_with_stage2 as rms2 # type: ignore[import-not-found] + except Exception: + rms2 = None # type: ignore[assignment] + if rms2 is not None: + y, rstd, residual_out = rms2.rmsnorm_forward_with_stage2( + x, weight=weight, bias=bias, residual=residual, eps=eps, store_rstd=store_rstd + ) + # Preserve stride contracts for torch.compile consistency, even + # when using the optional stage-2 implementation. + if y.stride() != x.stride(): + y_strided = torch.empty_strided( + x.shape, x.stride(), device=x.device, dtype=x.dtype + ) + y_strided.copy_(y) + y = y_strided + if residual is not None and residual_out is not None: + if residual_out.stride() != residual.stride(): + residual_out_strided = torch.empty_strided( + residual.shape, + residual.stride(), + device=residual.device, + dtype=residual.dtype, + ) + residual_out_strided.copy_(residual_out) + residual_out = residual_out_strided + return y, rstd, residual_out + + # Default: use the pointer-based entry whenever we can represent the + # inputs as a row-major [M, N] view with stride(1) == 1. For rare layouts + # we can't safely express without DLPack, fall back to a torch reference. + if _can_use_ptr_path(x, weight, bias, residual): + return _rmsnorm_forward_ptr(x, weight, bias, residual, eps, store_rstd) + + # Safe fallback (correctness-first). This is expected to be rare in vLLM. + y = rmsnorm_ref(x, weight, bias, residual, eps) + # Preserve the input stride contract even on the fallback path so + # torch.compile sees a consistent output layout across all branches. + if y.stride() != x.stride(): + y_strided = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + y_strided.copy_(y) + y = y_strided + rstd = None + if store_rstd: + xf = x.float() + if residual is not None: + xf = xf + residual.float() + rstd = torch.rsqrt(xf.square().mean(dim=-1) + eps).to(torch.float32) + residual_out = None + if residual is not None: + residual_out = (x.float() + residual.float()).to(x.dtype) + if residual_out.stride() != residual.stride(): + residual_out_strided = torch.empty_strided( + residual.shape, + residual.stride(), + device=residual.device, + dtype=residual.dtype, + ) + residual_out_strided.copy_(residual_out) + residual_out = residual_out_strided + return y, rstd, residual_out + + +def rmsnorm_ref( + x: Tensor, + w: Optional[Tensor] = None, + b: Optional[Tensor] = None, + residual: Optional[Tensor] = None, + eps: float = 1e-6, +) -> Tensor: + xf = x.float() + if residual is not None: + xf = xf + residual.float() + rstd = torch.rsqrt(xf.square().mean(dim=-1, keepdim=True) + eps) + y = xf * rstd + if w is not None: + y = y * w.float() + if b is not None: + y = y + b.float() + return y.to(x.dtype) + + +def fused_add_rmsnorm_forward( + x: Tensor, + residual: Tensor, + weight: Tensor, + eps: float = 1e-6, +) -> Tuple[Tensor, Tensor]: + """Fused residual-add + RMSNorm for SM100 in CuteDSL. + + This is a convenience wrapper around ``rmsnorm_forward`` that matches the + semantics of vLLM's ``fused_add_rms_norm``: + + z = x + residual + y = RMSNorm(z, weight, eps) + + It returns ``(y, z)`` where ``z`` has the same dtype/shape as the inputs. 
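+ 
+ Example (shapes only)::
+ 
+     y, z = fused_add_rmsnorm_forward(x, residual, weight)  # x, residual: [..., N]
+ 
+ Unlike ``fused_add_rmsnorm_forward_inplace`` below, this wrapper leaves
+ ``x`` and ``residual`` untouched and returns freshly allocated outputs.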
+ """ + assert x.is_cuda and residual.is_cuda + assert x.shape == residual.shape + assert x.dtype == residual.dtype + + orig_shape = x.shape + N = orig_shape[-1] + + x_2d = x.view(-1, N) + res_2d = residual.view(-1, N) + + y_2d, _rstd, z_2d = rmsnorm_forward( + x_2d, + weight=weight, + bias=None, + residual=res_2d, + eps=eps, + store_rstd=False, + ) + + y = y_2d.view(orig_shape) + z = z_2d.view(orig_shape) + return y, z + + +def fused_add_rmsnorm_forward_inplace( + x: Tensor, + residual: Tensor, + weight: Tensor, + eps: float = 1e-6, +) -> Tuple[Tensor, Tensor]: + """In-place fused residual-add + RMSNorm matching vLLM semantics. + + This variant writes: + + z = x + residual (stored into ``residual``) + y = RMSNorm(z, w) (stored into ``x``) + + i.e., it uses ``x`` as the normalized output buffer and ``residual`` as + the residual-out buffer, mirroring vLLM's fused_add_rms_norm kernel. + """ + fused_add_rmsnorm_inplace_(x, residual, weight, eps=eps) + return x, residual + + +def fused_add_rmsnorm_inplace_( + x: Tensor, + residual: Tensor, + weight: Tensor, + eps: float = 1e-6, +) -> None: + """In-place fused residual-add + RMSNorm matching vLLM semantics. + + This is the lowest-overhead Python entrypoint (returns `None`) intended + for performance-critical call sites like `torch.ops.oink.fused_add_rms_norm`. + """ + assert x.is_cuda and residual.is_cuda + assert x.shape == residual.shape + assert x.dtype == residual.dtype + + N = x.shape[-1] + x_2d = x if x.dim() == 2 else x.view(-1, N) + res_2d = residual if residual.dim() == 2 else residual.view(-1, N) + + # Fast path: vLLM-compatible layout where x may be strided/padded but + # residual is contiguous. This updates both tensors in-place without + # additional allocations. + if _can_use_ptr_path_fused_add_inplace(x_2d, weight, res_2d): + _fused_add_rmsnorm_forward_ptr_inplace(x_2d, res_2d, weight, eps) + return None + + # Fallback: allocate via the regular fused path, then copy results into + # the user-provided buffers so that semantics remain identical. + y, z = fused_add_rmsnorm_forward(x, residual, weight, eps) + x.copy_(y) + residual.copy_(z) + return None + + +if __name__ == "__main__": + # Minimal ad-hoc test (functionality only). For performance comparisons, use the benchmark harness. 
+ if not torch.cuda.is_available(): + print("CUDA not available; functional test skipped.") + sys.exit(0) + M, N = 1024, 8192 + dtype = torch.bfloat16 + x = torch.randn(M, N, device="cuda", dtype=dtype) + w = torch.randn(N, device="cuda", dtype=dtype) + y_ref = rmsnorm_ref(x, w) + y, _, _ = rmsnorm_forward(x, w) + torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3) + print("RMSNormSM100 correctness check passed.") + +# (compile cache moved to top) From 3003c1398f700cf13095eb40f1e5bab84de013f7 Mon Sep 17 00:00:00 2001 From: Laura Wang <3700467+Laurawly@users.noreply.github.com> Date: Tue, 6 Jan 2026 12:25:31 -0800 Subject: [PATCH 2/8] Fix oink ruff lint and add license headers --- oink/src/kernelagent_oink/__init__.py | 14 ++++++++ .../kernelagent_oink/blackwell/__init__.py | 14 ++++++++ .../kernelagent_oink/blackwell/lite_quack.py | 20 ++++++++--- .../blackwell/oink_custom_ops.py | 16 ++++++++- .../src/kernelagent_oink/blackwell/rmsnorm.py | 33 +++++++++++++------ 5 files changed, 82 insertions(+), 15 deletions(-) diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py index 542e59e..d9f25d0 100644 --- a/oink/src/kernelagent_oink/__init__.py +++ b/oink/src/kernelagent_oink/__init__.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations import logging diff --git a/oink/src/kernelagent_oink/blackwell/__init__.py b/oink/src/kernelagent_oink/blackwell/__init__.py index 4d21ee8..a92109a 100644 --- a/oink/src/kernelagent_oink/blackwell/__init__.py +++ b/oink/src/kernelagent_oink/blackwell/__init__.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations __all__ = [] diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py index 3c3f750..8c05b47 100644 --- a/oink/src/kernelagent_oink/blackwell/lite_quack.py +++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Lightweight local clone of the small subset of helpers that the SM100 RMSNorm CuteDSL kernels depend on. @@ -12,9 +26,8 @@ import math import operator -from typing import Callable, Optional, Tuple +from typing import Callable, Optional -import cuda.bindings.driver as cuda # type: ignore import torch from torch import Tensor @@ -23,7 +36,7 @@ from cutlass import Float32, Int32, const_expr from cutlass.cute.runtime import from_dlpack from cutlass.cutlass_dsl import T, dsl_user_op -from cutlass._mlir.dialects import llvm, nvvm, vector +from cutlass._mlir.dialects import llvm # ------------------------- @@ -347,4 +360,3 @@ def get_sm_count(N: int, device: torch.device) -> int: sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2 ) return sm_count - diff --git a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py index 8225025..92423d9 100644 --- a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py +++ b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py @@ -1,4 +1,16 @@ -from __future__ import annotations +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Torch custom ops wrapping Oink's Blackwell RMSNorm kernels. @@ -26,6 +38,8 @@ Mutates `x` and `residual` in-place and returns None. """ +from __future__ import annotations + import importlib import threading diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py index d6c2c20..a77938c 100644 --- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ RMSNorm kernel for SM100 (Blackwell) in CuteDSL. @@ -53,15 +67,15 @@ "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)." 
) from e -import torch -from torch import Tensor +import torch # noqa: E402 +from torch import Tensor # noqa: E402 -import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python +import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python # noqa: E402 -import cutlass -import cutlass.cute as cute -from cutlass import Float32, Int32, const_expr -from cutlass.cute import runtime as rt +import cutlass # noqa: E402 +import cutlass.cute as cute # noqa: E402 +from cutlass import Float32, Int32, const_expr # noqa: E402 +from cutlass.cute import runtime as rt # noqa: E402 # Simple compile cache declared early so direct execution works _PTR_COMPILE_CACHE = {} @@ -862,8 +876,8 @@ def _get_fast_ptr_fused_add_rmsnorm_launcher( # # NOTE: Avoid `from . import ...` imports here: CuTeDSL's AST preprocessor may # mishandle that form (module=None in the AST). Use fully-qualified imports. -from kernelagent_oink.blackwell import lite_quack as qutils -from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce +from kernelagent_oink.blackwell import lite_quack as qutils # noqa: E402 +from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce # noqa: E402 # ------------------------- @@ -2458,7 +2472,6 @@ def rmsnorm_forward( assert x.is_cuda assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand." M, N = x.shape - dtype = TORCH2CUTE_DTYPE[x.dtype] # For DSv3 big-M outliers on SM100, keep using the dedicated # stage-2 K-loop implementation, which is already tuned and From 1468088ecab42c675d8e4d9e0bf465bd5f68644a Mon Sep 17 00:00:00 2001 From: Laura Wang <3700467+Laurawly@users.noreply.github.com> Date: Tue, 6 Jan 2026 12:27:54 -0800 Subject: [PATCH 3/8] Format oink with ruff --- oink/src/kernelagent_oink/__init__.py | 4 +- .../kernelagent_oink/blackwell/lite_quack.py | 41 +- .../blackwell/oink_custom_ops.py | 6 +- .../src/kernelagent_oink/blackwell/rmsnorm.py | 472 ++++++++++++++---- 4 files changed, 403 insertions(+), 120 deletions(-) diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py index d9f25d0..bbbd7c1 100644 --- a/oink/src/kernelagent_oink/__init__.py +++ b/oink/src/kernelagent_oink/__init__.py @@ -94,7 +94,9 @@ def register() -> None: # Ensure CuTeDSL sees a target arch early. If the user has already set it, # respect their choice. - os.environ.setdefault("CUTE_DSL_ARCH", _compute_cutedsl_arch(int(major), int(minor))) + os.environ.setdefault( + "CUTE_DSL_ARCH", _compute_cutedsl_arch(int(major), int(minor)) + ) # Import registers the ops via torch.library.custom_op decorators. 
from .blackwell import oink_custom_ops # noqa: F401 diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py index 8c05b47..14ae723 100644 --- a/oink/src/kernelagent_oink/blackwell/lite_quack.py +++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py @@ -54,6 +54,7 @@ # Tensor conversion helpers # ------------------------- + def convert_from_dlpack( x: Tensor, leading_dim: int, @@ -82,7 +83,9 @@ def convert_from_dlpack( @dsl_user_op -def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: +def elem_pointer( + x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None +) -> cute.Pointer: return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) @@ -133,7 +136,9 @@ def store_shared_remote( ).ir_value() if const_expr(isinstance(val, float)): val = Float32(val) - assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64" + assert isinstance(val, (Float32, Int32, cutlass.Int64)), ( + "val must be Float32, Int32, or Int64" + ) suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)] constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)] llvm.inline_asm( @@ -155,19 +160,27 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: """ tApA = cute.make_fragment( cute.make_layout( - (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + ( + cute.size(tAcA, mode=[0, 1]), + cute.size(tAcA, mode=[1]), + cute.size(tAcA, mode=[2]), + ), stride=(cute.size(tAcA, mode=[2]), 0, 1), ), cutlass.Boolean, ) for rest_v in cutlass.range_constexpr(tApA.shape[0]): for rest_k in cutlass.range_constexpr(tApA.shape[2]): - tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + tApA[rest_v, 0, rest_k] = cute.elem_less( + tAcA[(0, rest_v), 0, rest_k][1], limit + ) return tApA @dsl_user_op -def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: +def domain_offset_i64( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: """ Return a tensor whose iterator is offset by an Int64 byte offset computed from `coord` and the tensor's strides. @@ -287,7 +300,9 @@ def block_or_cluster_reduce( """Dispatch between block or cluster reduction depending on mbar_ptr.""" if cutlass.const_expr(mbar_ptr is None): return block_reduce(val, op, reduction_buffer, init_val=init_val) - return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase) + return cluster_reduce( + val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase + ) @cute.jit @@ -313,7 +328,9 @@ def row_reduce( val = x warp_op = { cute.ReductionOp.ADD: operator.add, - cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max, + cute.ReductionOp.MAX: cute.arch.fmax + if cutlass.const_expr(x.dtype == Float32) + else max, cute.ReductionOp.MIN: min, cute.ReductionOp.MUL: operator.mul, }[op] @@ -353,10 +370,16 @@ def get_sm_count(N: int, device: torch.device) -> int: RMSNorm kernels but lives entirely in this local module. 
""" sm_count_multiple = ( - 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) + 16 + if N <= 256 + else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) ) sm_count = torch.cuda.get_device_properties(device).multi_processor_count sm_count = ( - sm_count * sm_count_multiple if N <= 8192 else sm_count // 2 if N <= 16384 else sm_count * 2 + sm_count * sm_count_multiple + if N <= 8192 + else sm_count // 2 + if N <= 16384 + else sm_count * 2 ) return sm_count diff --git a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py index 92423d9..a96a4c7 100644 --- a/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py +++ b/oink/src/kernelagent_oink/blackwell/oink_custom_ops.py @@ -80,6 +80,7 @@ def _get_sm(device: torch.device | None = None) -> int: # RMSNorm (functional) # + @custom_op("oink::rmsnorm", mutates_args=()) def oink_rmsnorm( x: torch.Tensor, @@ -158,6 +159,7 @@ def oink_rmsnorm_fake( # Fused residual-add + RMSNorm (in-place, vLLM semantics) # + @custom_op("oink::fused_add_rms_norm", mutates_args=("x", "residual")) def oink_fused_add_rms_norm( x: torch.Tensor, @@ -174,7 +176,9 @@ def oink_fused_add_rms_norm( Returns: None (mutates `x` and `residual` in-place). """ - assert x.is_cuda and residual.is_cuda, "oink::fused_add_rms_norm requires CUDA tensors" + assert x.is_cuda and residual.is_cuda, ( + "oink::fused_add_rms_norm requires CUDA tensors" + ) assert x.shape == residual.shape, "x and residual must have the same shape" assert x.dtype == residual.dtype, "x and residual must have the same dtype" assert weight.dim() == 1, "weight must be 1D [N]" diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py index a77938c..c7fc1b3 100644 --- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py @@ -84,6 +84,7 @@ # pointer/scalar storage so concurrent callers don't race on in-place updates. _PTR_FAST_LAUNCH_TLS = threading.local() + def _env_flag(name: str, default: bool) -> bool: val = os.environ.get(name) if val is None: @@ -100,10 +101,16 @@ def _env_flag(name: str, default: bool) -> bool: # Fused-add RMSNorm schedule knobs (read once at import time; set env vars before # importing this module if you want to override). 
-_DIRECT_GMEM_POLICY = (os.environ.get("OINK_RMSNORM_DIRECT_GMEM", "auto").strip().lower() or "auto") -_COPY_BITS_POLICY = (os.environ.get("OINK_RMSNORM_COPY_BITS", "auto").strip().lower() or "auto") +_DIRECT_GMEM_POLICY = ( + os.environ.get("OINK_RMSNORM_DIRECT_GMEM", "auto").strip().lower() or "auto" +) +_COPY_BITS_POLICY = ( + os.environ.get("OINK_RMSNORM_COPY_BITS", "auto").strip().lower() or "auto" +) _ENABLE_CLUSTER_ILP = _env_flag("OINK_RMSNORM_ENABLE_CLUSTER_ILP", default=False) -_ENABLE_CLUSTER_ILP_UNSAFE = _env_flag("OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE", default=False) +_ENABLE_CLUSTER_ILP_UNSAFE = _env_flag( + "OINK_RMSNORM_ENABLE_CLUSTER_ILP_UNSAFE", default=False +) _ENABLE_TPR256 = _env_flag("OINK_RMSNORM_ENABLE_TPR256", default=False) _ENABLE_STAGE2 = _env_flag("OINK_RMSNORM_ENABLE_STAGE2", default=False) @@ -252,6 +259,7 @@ def run_probe(copy_bits: int, assumed_align: int): _CLUSTER_DIRECT_GMEM_MAX_COPY_BITS = max_bits return max_bits + def _parse_version_tuple(version: str) -> Tuple[int, int, int]: parts = version.split(".") nums: list[int] = [] @@ -275,7 +283,9 @@ def _cutlass_dsl_version() -> Optional[Tuple[int, int, int]]: # passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the # older signature for 4.3.2, but switch to a 4.3.4-compatible signature when we # detect 4.3.4+ (or when version detection is unavailable). -_KERNEL_ACCEPTS_LAYOUT_ARGS = _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4) +_KERNEL_ACCEPTS_LAYOUT_ARGS = ( + _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4) +) if _ENABLE_CLUSTER_ILP and not _ENABLE_CLUSTER_ILP_UNSAFE: # We have observed reproducible segfaults in some CuTeDSL builds when using @@ -427,7 +437,9 @@ def launch( self._last_x_ptr = x_ptr except AttributeError: self._disable_fast_launch() - self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps) + self._fallback_launch( + x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps + ) return if self._ptr_w is not None: @@ -438,7 +450,9 @@ def launch( self._last_w_ptr = w_ptr except AttributeError: self._disable_fast_launch() - self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps) + self._fallback_launch( + x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps + ) return out_ptr = out.data_ptr() @@ -448,7 +462,9 @@ def launch( self._last_out_ptr = out_ptr except AttributeError: self._disable_fast_launch() - self._fallback_launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps) + self._fallback_launch( + x=x, weight=weight, out=out, M=M, N=N, ld=ld, eps=eps + ) return if M != self._last_m: @@ -492,10 +508,19 @@ def _fallback_launch( # If the packed-args or runtime pointer mutation path stops working # (e.g. due to a CuTeDSL upgrade), fall back to the regular call path. 
dtype = TORCH2CUTE_DTYPE[x.dtype] - ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) - ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_x = rt.make_ptr( + dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_out = rt.make_ptr( + dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) ptr_w = ( - rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + rt.make_ptr( + dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) if weight is not None else None ) @@ -695,7 +720,15 @@ def _get_fast_ptr_rmsnorm_launcher( return None # Keyed by the compiled object identity so schedule changes (e.g. copy width, # async/staged variants, etc.) never alias in the fast-launch cache. - key = ("ptr_fast", id(compiled), N, dtype, device_index, int(stream_handle), has_weight) + key = ( + "ptr_fast", + id(compiled), + N, + dtype, + device_index, + int(stream_handle), + has_weight, + ) cache = _tls_fast_launch_cache() cached = cache.get(key) if cached is not None: @@ -705,7 +738,9 @@ def _get_fast_ptr_rmsnorm_launcher( ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) ptr_out = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) ptr_w = ( - rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) if has_weight else None + rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) + if has_weight + else None ) arg_m = _StableI32Arg(0) @@ -808,9 +843,15 @@ def _get_fast_ptr_fused_add_rmsnorm_launcher( if cached is not None: return cached # type: ignore[return-value] - ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) - ptr_res = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) - ptr_w = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) + ptr_x = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_res = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_w = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) arg_m = _StableI32Arg(0) arg_n = _StableI32Arg(N) @@ -884,6 +925,7 @@ def _get_fast_ptr_fused_add_rmsnorm_launcher( # Copy helpers (allow up to 256b) # ------------------------- + @cute.jit def get_copy_atom_bw( dtype: type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False @@ -892,6 +934,7 @@ def get_copy_atom_bw( max_bits = const_expr(128 if is_async else 256) num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width)) from cutlass.cute.nvgpu import cpasync + # Prefer GLOBAL cache policy for bulk streaming reads at large M copy_op = ( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL) @@ -1037,7 +1080,10 @@ def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout] tiler_mn = (cols_per_block, vecsize * num_blocks_N * tpr) tv_layout = cute.make_layout( ((tpr, cols_per_block), (vecsize, num_blocks_N)), - stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * tpr)), + stride=( + (vecsize * cols_per_block, 1), + (cols_per_block, cols_per_block * vecsize * tpr), + ), ) return tiler_mn, tv_layout @@ -1045,7 +1091,10 @@ def _smem_bytes(self, tiler_mn, num_warps) -> int: # smem for X tile (+ residual if present) + reduction buffers + 
mbar(s) return ( cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) - + self.stage * num_warps * self._cluster_n() * (self.reduction_dtype.width // 8) + + self.stage + * num_warps + * self._cluster_n() + * (self.reduction_dtype.width // 8) + self.stage * (cutlass.Int64.width // 8) ) @@ -1072,7 +1121,9 @@ def new_stride(t): ) mX, mRes, mO, mResO = [ - cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))) + cute.make_tensor( + t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)) + ) if const_expr(t is not None) else None for t in (mX, mRes, mO, mResO) @@ -1082,23 +1133,34 @@ def new_stride(t): copy_bits = int(self.copy_bits) tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits) - num_threads = cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._num_threads() + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._num_threads() + ) num_warps = num_threads // cute.arch.WARP_SIZE - threads_per_row = tv_layout.shape[0][0] if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._threads_per_row() + threads_per_row = ( + tv_layout.shape[0][0] + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._threads_per_row() + ) warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1) cluster_n = self._cluster_n() if const_expr(mW is not None): mW = cute.make_tensor( - mW.iterator, cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))) + mW.iterator, + cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), ) if const_expr(mB is not None): mB = cute.make_tensor( - mB.iterator, cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))) + mB.iterator, + cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), ) if const_expr(mRstd is not None): mRstd = cute.make_tensor( - mRstd.iterator, cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))) + mRstd.iterator, + cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))), ) # No SMEM reload mode switch; overlap is controlled in the K-loop path @@ -1114,11 +1176,14 @@ def new_stride(t): else 0 ) tile_bytes_res = ( - cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) * stage_bufs + cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) + * stage_bufs if const_expr(mRes is not None and not self.direct_gmem) else 0 ) - red_bytes = self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8) + red_bytes = ( + self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8) + ) # mbarriers are only allocated/used for cluster_n>1. Some CuTeDSL builds # require mbarrier state to be 16B-aligned in shared memory; account for # the alignment padding when computing dynamic smem bytes. 
@@ -1211,14 +1276,10 @@ def launch_from_ptrs( else None ) mW = ( - cute.make_tensor(ptr_w, layout_n) - if const_expr(ptr_w is not None) - else None + cute.make_tensor(ptr_w, layout_n) if const_expr(ptr_w is not None) else None ) mB = ( - cute.make_tensor(ptr_b, layout_n) - if const_expr(ptr_b is not None) - else None + cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None ) mRstd = ( cute.make_tensor(ptr_rstd, layout_m) @@ -1303,7 +1364,9 @@ def _kernel_impl( # Allocate one or two SMEM buffers depending on stage depth sX0 = ( smem.allocate_tensor( - mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32 + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, ) if const_expr(not self.direct_gmem) else None @@ -1319,7 +1382,9 @@ def _kernel_impl( ) sRes0 = ( smem.allocate_tensor( - mRes.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32 + mRes.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, ) if const_expr(mRes is not None and not self.direct_gmem) else None @@ -1339,7 +1404,9 @@ def _kernel_impl( (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), order=(1, 0, 2), ) - reduction_buffer = smem.allocate_tensor(self.reduction_dtype, red_layout, byte_alignment=4) + reduction_buffer = smem.allocate_tensor( + self.reduction_dtype, red_layout, byte_alignment=4 + ) if const_expr(cluster_n > 1): # Some CuTeDSL builds appear sensitive to the shared-memory alignment of # mbarrier state. `SmemAllocator.allocate_array` does not currently @@ -1360,8 +1427,12 @@ def _kernel_impl( # Tiled copy setup num_copy_elems_X = tv_layout.shape[1][0] - use_async = const_expr(self.use_async and self.N >= 1024 and not self.direct_gmem) - copy_atom = get_copy_atom_bw(mX.element_type, num_copy_elems_X, is_async=use_async) + use_async = const_expr( + self.use_async and self.N >= 1024 and not self.direct_gmem + ) + copy_atom = get_copy_atom_bw( + mX.element_type, num_copy_elems_X, is_async=use_async + ) thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx) # Tail predicate for the N dimension (when tile width > N). 
Reuse this @@ -1386,7 +1457,9 @@ def _kernel_impl( mW is not None and (self.direct_gmem or (mRes is None and mB is None)) ) if const_expr(prefetch_w_early): - gW = cute.local_tile(qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0)) + gW = cute.local_tile( + qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0) + ) tXgW = thr_copy.partition_S(gW) tXrW = cute.make_fragment_like(tXgW) if const_expr(not is_even_N_wb): @@ -1398,7 +1471,9 @@ def _kernel_impl( pred=tXp_wb, ) if const_expr(self.direct_gmem and mB is not None): - gB = cute.local_tile(qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0)) + gB = cute.local_tile( + qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0) + ) tXgB = thr_copy.partition_S(gB) tXrB = cute.make_fragment_like(tXgB) if const_expr(not is_even_N_wb): @@ -1414,7 +1489,9 @@ def _kernel_impl( self._init_cluster(tidx, mbar_ptr) mX_i, mRes_i, mO_i, mResO_i = [ - qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) if t is not None else None + qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) + if t is not None + else None for t in (mX, mRes, mO, mResO) ] mX_i, mRes_i, mO_i, mResO_i = [ @@ -1424,27 +1501,39 @@ def _kernel_impl( gX_i = cute.local_tile(mX_i, tiler_mn, (0, 0)) gO_i = cute.local_tile(mO_i, tiler_mn, (0, 0)) gRes_i = ( - cute.local_tile(mRes_i, tiler_mn, (0, 0)) if const_expr(mRes is not None) else None + cute.local_tile(mRes_i, tiler_mn, (0, 0)) + if const_expr(mRes is not None) + else None ) gResO_i = ( - cute.local_tile(mResO_i, tiler_mn, (0, 0)) if const_expr(mResO is not None) else None + cute.local_tile(mResO_i, tiler_mn, (0, 0)) + if const_expr(mResO is not None) + else None ) gRstd_i = ( - cute.local_tile(mRstd, tiler_mn, (bidx, 0)) if const_expr(mRstd is not None) else None + cute.local_tile(mRstd, tiler_mn, (bidx, 0)) + if const_expr(mRstd is not None) + else None ) cX_i = cute.local_tile(idX, tiler_mn, (bidx, 0)) # Common identity/row index partitions reused by both default and K-loop paths tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None] row_i = tXcX_i[0][0] - tXgRstd_i = thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None + tXgRstd_i = ( + thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None + ) # Stage-2 intra-row K-loop cp.async ping-pong (two tiles). This reduces # per-thread fragment size and can improve memory-latency hiding for # N=7168 at large M. It is enabled by setting `stage=2` when constructing # the RMSNormSM100 op (see `_fused_add_rmsnorm_forward_ptr_inplace`). if const_expr( - self.stage > 1 and not self.direct_gmem and use_async and cluster_n == 1 and shape[1] == 7168 + self.stage > 1 + and not self.direct_gmem + and use_async + and cluster_n == 1 + and shape[1] == 7168 ): vecsize = tv_layout.shape[1][0] tpr = threads_per_row @@ -1475,9 +1564,9 @@ def _kernel_impl( (tiler_mn[0], tiler_mn[0] * vecsize * tpr), ), ) - thr_copy_tile = cute.make_tiled_copy(copy_atom, tv_layout_tile, tiler_mn_tile).get_slice( - tidx - ) + thr_copy_tile = cute.make_tiled_copy( + copy_atom, tv_layout_tile, tiler_mn_tile + ).get_slice(tidx) # Accumulate per-thread partial sums across tiles; reduce once. 
sum_sq_thread = cute.Float32(0.0) @@ -1499,7 +1588,13 @@ def _kernel_impl( tXp_pong = tXp_0 if row_i < shape[0]: - copy_tiled(tXgX_0, tXsX_0, num_copy_elems=vecsize, is_async=True, pred=tXp_0) + copy_tiled( + tXgX_0, + tXsX_0, + num_copy_elems=vecsize, + is_async=True, + pred=tXp_0, + ) if const_expr(mRes is not None): gRes_0 = cute.local_tile( qutils.domain_offset_i64((0, k_off0), mRes_i), @@ -1538,19 +1633,27 @@ def _kernel_impl( if const_expr((t % 2) == 0): tXsX_n = thr_copy_tile.partition_D(sX1_tile) tXsRes_n = ( - thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None ) tXp_pong = tXp_n else: tXsX_n = thr_copy_tile.partition_D(sX0_tile) tXsRes_n = ( - thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None ) tXp_ping = tXp_n if row_i < shape[0]: copy_tiled( - tXgX_n, tXsX_n, num_copy_elems=vecsize, is_async=True, pred=tXp_n + tXgX_n, + tXsX_n, + num_copy_elems=vecsize, + is_async=True, + pred=tXp_n, ) if const_expr(mRes is not None): gRes_n = cute.local_tile( @@ -1574,25 +1677,35 @@ def _kernel_impl( if const_expr((t % 2) == 0): tXsX_cur = thr_copy_tile.partition_D(sX0_tile) tXsRes_cur = ( - thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None ) pred_cur = tXp_ping else: tXsX_cur = thr_copy_tile.partition_D(sX1_tile) tXsRes_cur = ( - thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None ) pred_cur = tXp_pong k_off = t * tile_n - gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, 0)) + gX_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mX_i), + tiler_mn_tile, + (0, 0), + ) tXgX_t = thr_copy_tile.partition_S(gX_t) tXrX_t = cute.make_fragment_like(tXgX_t) cute.autovec_copy(tXsX_cur, tXrX_t) x_t = tXrX_t.load().to(cute.Float32) if const_expr(mRes is not None): gRes_t = cute.local_tile( - qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, 0) + qutils.domain_offset_i64((0, k_off), mRes_i), + tiler_mn_tile, + (0, 0), ) tXgRes_t = thr_copy_tile.partition_S(gRes_t) tXrRes_t = cute.make_fragment_like(tXgRes_t) @@ -1639,29 +1752,41 @@ def _kernel_impl( for t in cutlass.range_constexpr(num_tiles): k_off = t * tile_n - cX_t = cute.local_tile(cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, 0)) + cX_t = cute.local_tile( + cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, 0) + ) tXc_t = thr_copy_tile.partition_S(cX_t) tXp_t = qutils.predicate_k(tXc_t, limit=limit_k) if const_expr((t % 2) == 0): tXsX_cur = thr_copy_tile.partition_D(sX0_tile) tXsRes_cur = ( - thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None ) else: tXsX_cur = thr_copy_tile.partition_D(sX1_tile) tXsRes_cur = ( - thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None ) - gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, 0)) + gX_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mX_i), + tiler_mn_tile, + (0, 0), + ) tXgX_t = thr_copy_tile.partition_S(gX_t) 
tXrX_t = cute.make_fragment_like(tXgX_t) cute.autovec_copy(tXsX_cur, tXrX_t) x_t = tXrX_t.load().to(cute.Float32) if const_expr(mRes is not None): gRes_t = cute.local_tile( - qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, 0) + qutils.domain_offset_i64((0, k_off), mRes_i), + tiler_mn_tile, + (0, 0), ) tXgRes_t = thr_copy_tile.partition_S(gRes_t) tXrRes_t = cute.make_fragment_like(tXgRes_t) @@ -1671,43 +1796,77 @@ def _kernel_impl( y_t = x_t * rstd if const_expr(mW is not None): gW_t = cute.local_tile( - qutils.domain_offset_i64((0, k_off), mW), tiler_mn_tile, (0, 0) + qutils.domain_offset_i64((0, k_off), mW), + tiler_mn_tile, + (0, 0), ) tWgW_t = thr_copy_tile.partition_S(gW_t) tWrW_t = cute.make_fragment_like(tWgW_t) - copy_tiled(tWgW_t, tWrW_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + copy_tiled( + tWgW_t, + tWrW_t, + num_copy_elems=vecsize, + is_async=False, + pred=tXp_t, + ) y_t = y_t * tWrW_t.load().to(cute.Float32) if const_expr(mB is not None): gB_t = cute.local_tile( - qutils.domain_offset_i64((0, k_off), mB), tiler_mn_tile, (0, 0) + qutils.domain_offset_i64((0, k_off), mB), + tiler_mn_tile, + (0, 0), ) tWgB_t = thr_copy_tile.partition_S(gB_t) tWrB_t = cute.make_fragment_like(tWgB_t) - copy_tiled(tWgB_t, tWrB_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + copy_tiled( + tWgB_t, + tWrB_t, + num_copy_elems=vecsize, + is_async=False, + pred=tXp_t, + ) y_t = y_t + tWrB_t.load().to(cute.Float32) - gO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mO_i), tiler_mn_tile, (0, 0)) + gO_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mO_i), + tiler_mn_tile, + (0, 0), + ) tXgO_t = thr_copy_tile.partition_D(gO_t) tXrO_t = cute.make_fragment_like(tXgO_t) tXrO_t.store(y_t.to(tXrO_t.element_type)) if row_i < shape[0]: - copy_tiled(tXrO_t, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + copy_tiled( + tXrO_t, + tXgO_t, + num_copy_elems=vecsize, + is_async=False, + pred=tXp_t, + ) return # Single-stage path: one-row-per-CTA tXgX_i = thr_copy.partition_S(gX_i) - tXgRes_i = thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None + tXgRes_i = ( + thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None + ) tXgO_i = thr_copy.partition_D(gO_i) - tXgResO_i = thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None + tXgResO_i = ( + thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None + ) # tXgRstd_i / tXcX_i / row_i prepared above is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n) tXpX_i = ( - qutils.predicate_k(thr_copy.partition_S(cX_i), limit=limit_k) if not is_even_N_i else None + qutils.predicate_k(thr_copy.partition_S(cX_i), limit=limit_k) + if not is_even_N_i + else None ) tXrX = cute.make_fragment_like(tXgX_i) - tXrRes = cute.make_fragment_like(tXgRes_i) if const_expr(mRes is not None) else None + tXrRes = ( + cute.make_fragment_like(tXgRes_i) if const_expr(mRes is not None) else None + ) if const_expr(self.direct_gmem): if const_expr(not is_even_N_i): tXrX.fill(0) @@ -1729,7 +1888,9 @@ def _kernel_impl( if row_i < shape[0]: cute.copy(copy_atom, tXgX_i, thr_copy.partition_D(sX0), pred=tXpX_i) if const_expr(mRes is not None): - cute.copy(copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i) + cute.copy( + copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i + ) if const_expr(use_async): cute.arch.cp_async_commit_group() cute.arch.cp_async_wait_group(0) @@ -1746,7 +1907,9 @@ def _kernel_impl( 
tXrResO.store(x_red.to(tXrResO.element_type)) if row_i < shape[0]: cute.copy( - get_copy_atom_bw(tXrResO.element_type, num_copy_elems_X, is_async=False), + get_copy_atom_bw( + tXrResO.element_type, num_copy_elems_X, is_async=False + ), tXrResO, tXgResO_i, pred=tXpX_i, @@ -1775,7 +1938,9 @@ def _kernel_impl( # pressure during the long-scoreboard reduction phase (helping occupancy # when registers are the limiting factor). if const_expr(mW is not None): - gW = cute.local_tile(qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0)) + gW = cute.local_tile( + qutils.domain_offset_i64((0, n_off), mW), tiler_mn, (0, 0) + ) tXgW = thr_copy.partition_S(gW) tXrW = cute.make_fragment_like(tXgW) if const_expr(not is_even_N_wb): @@ -1787,7 +1952,9 @@ def _kernel_impl( pred=tXp_wb, ) if const_expr(mB is not None): - gB = cute.local_tile(qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0)) + gB = cute.local_tile( + qutils.domain_offset_i64((0, n_off), mB), tiler_mn, (0, 0) + ) tXgB = thr_copy.partition_S(gB) tXrB = cute.make_fragment_like(tXgB) if const_expr(not is_even_N_wb): @@ -1818,6 +1985,7 @@ def _kernel_impl( ) if _KERNEL_ACCEPTS_LAYOUT_ARGS: + @cute.kernel def kernel( self, @@ -1853,6 +2021,7 @@ def kernel( threads_per_row, ) else: + @cute.kernel def kernel( self, @@ -2015,7 +2184,10 @@ def _rmsnorm_forward_ptr( if residual is not None: residual_out = torch.empty_strided( - residual.shape, residual.stride(), device=residual.device, dtype=residual.dtype + residual.shape, + residual.stride(), + device=residual.device, + dtype=residual.dtype, ) if store_rstd: rstd = torch.empty(M, device=x.device, dtype=torch.float32) @@ -2082,10 +2254,19 @@ def _rmsnorm_forward_ptr_into( if compiled is None: op = RMSNormSM100(N, dtype, stage=stage) ld_val = int(x.stride(0)) - ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) - ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_x = rt.make_ptr( + dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_out = rt.make_ptr( + dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) ptr_w = ( - rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + rt.make_ptr( + dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) if has_weight else None ) @@ -2122,10 +2303,19 @@ def _rmsnorm_forward_ptr_into( launcher.launch(x=x, weight=weight, out=out, M=M, N=N, ld=ld_val, eps=eps) return - ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) - ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_x = rt.make_ptr( + dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_out = rt.make_ptr( + dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) ptr_w = ( - rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + rt.make_ptr( + dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) if has_weight else None ) @@ -2167,33 +2357,55 @@ def _rmsnorm_forward_ptr_into( compiled = _PTR_COMPILE_CACHE.get(key) if compiled is None: op = RMSNormSM100(N, dtype, stage=stage) - ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) - ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_x = rt.make_ptr( + dtype, 
x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_out = rt.make_ptr( + dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) ptr_res = ( - rt.make_ptr(dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + rt.make_ptr( + dtype, + residual.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) if residual is not None else None ) ptr_res_out = ( rt.make_ptr( - dtype, residual_out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + residual_out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, ) if residual_out is not None else None ) ptr_w = ( - rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + rt.make_ptr( + dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) if weight is not None else None ) ptr_b = ( - rt.make_ptr(dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + rt.make_ptr( + dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) if bias is not None else None ) ptr_rstd = ( rt.make_ptr( - cutlass.Float32, rstd.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=4 + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, ) if rstd is not None else None @@ -2216,30 +2428,50 @@ def _rmsnorm_forward_ptr_into( Float32(eps), ) _PTR_COMPILE_CACHE[key] = compiled - ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) - ptr_out = rt.make_ptr(dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_x = rt.make_ptr( + dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_out = rt.make_ptr( + dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) ptr_res = ( - rt.make_ptr(dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + rt.make_ptr( + dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) if residual is not None else None ) ptr_res_out = ( - rt.make_ptr(dtype, residual_out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + rt.make_ptr( + dtype, + residual_out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) if residual_out is not None else None ) ptr_w = ( - rt.make_ptr(dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + rt.make_ptr( + dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) if weight is not None else None ) ptr_b = ( - rt.make_ptr(dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + rt.make_ptr( + dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) if bias is not None else None ) ptr_rstd = ( - rt.make_ptr(cutlass.Float32, rstd.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=4) + rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) if rstd is not None else None ) @@ -2296,7 +2528,9 @@ def _fused_add_rmsnorm_forward_ptr_inplace( # benchmark other models/shapes, you can override it with: # - OINK_RMSNORM_DIRECT_GMEM=0 (force staging/cp.async path) # - OINK_RMSNORM_DIRECT_GMEM=1 (force direct-gmem path) - direct_gmem = _direct_gmem_from_policy(default=bool(dtype.width == 16 and N == 7168)) + direct_gmem = _direct_gmem_from_policy( + default=bool(dtype.width == 16 and N == 7168) + ) use_async = not direct_gmem tpr_override: Optional[int] = None 
nt_override: Optional[int] = None @@ -2340,13 +2574,9 @@ def _fused_add_rmsnorm_forward_ptr_inplace( tpr_override = 256 nt_override = 256 - can_use_256 = bool( direct_gmem - and ( - direct_gmem_max_copy_bits is None - or direct_gmem_max_copy_bits >= 256 - ) + and (direct_gmem_max_copy_bits is None or direct_gmem_max_copy_bits >= 256) and dtype.width == 16 and (x.data_ptr() % 32) == 0 and (residual.data_ptr() % 32) == 0 @@ -2395,13 +2625,22 @@ def _fused_add_rmsnorm_forward_ptr_inplace( if cluster_n_override is not None: op._cluster_n_override = cluster_n_override # type: ignore[attr-defined] ptr_x = rt.make_ptr( - dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) ptr_res = rt.make_ptr( - dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + dtype, + residual.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) ptr_w = rt.make_ptr( - dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) stream = cuda.CUstream(stream_handle) ld_x = Int32(int(x.stride(0))) @@ -2444,12 +2683,20 @@ def _fused_add_rmsnorm_forward_ptr_inplace( # Fast-launch is disabled/unavailable (or CuTeDSL internals changed). Fall back # to calling the compiled function directly. - ptr_x = rt.make_ptr(dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align) + ptr_x = rt.make_ptr( + dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) ptr_res = rt.make_ptr( - dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + dtype, + residual.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) ptr_w = rt.make_ptr( - dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) stream = cuda.CUstream(stream_handle) ld_x = Int32(int(x.stride(0))) @@ -2486,7 +2733,12 @@ def rmsnorm_forward( rms2 = None # type: ignore[assignment] if rms2 is not None: y, rstd, residual_out = rms2.rmsnorm_forward_with_stage2( - x, weight=weight, bias=bias, residual=residual, eps=eps, store_rstd=store_rstd + x, + weight=weight, + bias=bias, + residual=residual, + eps=eps, + store_rstd=store_rstd, ) # Preserve stride contracts for torch.compile consistency, even # when using the optional stage-2 implementation. @@ -2519,7 +2771,9 @@ def rmsnorm_forward( # Preserve the input stride contract even on the fallback path so # torch.compile sees a consistent output layout across all branches. 
if y.stride() != x.stride(): - y_strided = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + y_strided = torch.empty_strided( + x.shape, x.stride(), device=x.device, dtype=x.dtype + ) y_strided.copy_(y) y = y_strided rstd = None From 4c9a826cb2b1a3d48fbffe156f10943b41ec8f80 Mon Sep 17 00:00:00 2001 From: Laura Wang <3700467+Laurawly@users.noreply.github.com> Date: Wed, 21 Jan 2026 11:06:37 -0800 Subject: [PATCH 4/8] oink: SM100 suite refresh (strict parity + quack-style benches)
- Switch correctness gate to PyTorch ref + record err stats
- Tighten Softmax/LayerNorm tolerances (Quack-like)
- Quack-style benchmark suite layout + SVG plots
- Packaging/README polish for publishability
--- oink/README.md | 94 +- oink/benchmarks/README.md | 146 + oink/benchmarks/benchmark/bench_utils.py | 255 ++ .../benchmark_cross_entropy_sm100.py | 426 +++ .../benchmark_fused_add_rmsnorm_sm100.py | 296 ++ .../benchmark/benchmark_hbm_roofline_sm100.py | 226 ++ .../benchmark/benchmark_layernorm_sm100.py | 393 +++ .../benchmark/benchmark_rmsnorm_bwd_sm100.py | 434 +++ .../benchmark/benchmark_rmsnorm_sm100.py | 337 ++ .../benchmark/benchmark_softmax_sm100.py | 292 ++ .../media/sm100_bf16_oink_vs_quack.svg | 2259 +++++++++++++ .../media/sm100_bf16_oink_vs_quack_dsv3.svg | 2600 +++++++++++++++ .../sm100_bf16_oink_vs_quack_dsv3_all.svg | 2936 ++++++++++++++++ ..._bf16_oink_vs_quack_dsv3_cross_entropy.svg | 1687 ++++++++++ ...bf16_oink_vs_quack_dsv3_with_layernorm.svg | 2720 +++++++++++++++ ...m100_bf16_oink_vs_quack_with_layernorm.svg | 2580 ++++++++++++++ .../media/sm100_fp16_oink_vs_quack.svg | 2280 +++++++++++++ .../media/sm100_fp16_oink_vs_quack_dsv3.svg | 2621 +++++++++++++++ .../sm100_fp16_oink_vs_quack_dsv3_all.svg | 2957 +++++++++++++++++ ..._fp16_oink_vs_quack_dsv3_cross_entropy.svg | 1708 ++++++++++ ...fp16_oink_vs_quack_dsv3_with_layernorm.svg | 2741 +++++++++++++++ ...m100_fp16_oink_vs_quack_with_layernorm.svg | 2601 +++++++++++++++ .../benchmarks/readme/plot_quack_style_svg.py | 431 +++ oink/benchmarks/readme/run_sm100_suite.py | 302 ++ oink/benchmarks/readme/summarize_results.py | 205 ++ oink/pyproject.toml | 24 +- oink/src/kernelagent_oink/__init__.py | 28 +- .../blackwell/cross_entropy.py | 1209 +++++++ .../kernelagent_oink/blackwell/layernorm.py | 1368 ++++++++ .../kernelagent_oink/blackwell/lite_quack.py | 1001 +++++- .../src/kernelagent_oink/blackwell/rmsnorm.py | 467 ++- .../blackwell/rmsnorm_with_stage2.py | 805 +++++ .../src/kernelagent_oink/blackwell/softmax.py | 749 +++++ 33 files changed, 39026 insertions(+), 152 deletions(-) create mode 100644 oink/benchmarks/README.md create mode 100644 oink/benchmarks/benchmark/bench_utils.py create mode 100644 oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py create mode 100644 oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py create mode 100644 oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py create mode 100644 oink/benchmarks/benchmark/benchmark_layernorm_sm100.py create mode 100644 oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py create mode 100644 oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py create mode 100644 oink/benchmarks/benchmark/benchmark_softmax_sm100.py create mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg create mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg create mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg create mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg create
mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg create mode 100644 oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg create mode 100644 oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg create mode 100644 oink/benchmarks/readme/plot_quack_style_svg.py create mode 100644 oink/benchmarks/readme/run_sm100_suite.py create mode 100644 oink/benchmarks/readme/summarize_results.py create mode 100644 oink/src/kernelagent_oink/blackwell/cross_entropy.py create mode 100644 oink/src/kernelagent_oink/blackwell/layernorm.py create mode 100644 oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py create mode 100644 oink/src/kernelagent_oink/blackwell/softmax.py diff --git a/oink/README.md b/oink/README.md index 427f69f..aeb0c09 100644 --- a/oink/README.md +++ b/oink/README.md @@ -1,13 +1,33 @@ -# KernelAgent Oink (vLLM plugin) +# KernelAgent-Oink -This subproject provides an **out-of-tree vLLM plugin** that registers -`torch.library.custom_op` entrypoints under the `oink::` namespace: +KernelAgent-Oink is a small **CuTeDSL (CUTLASS DSL) kernel library** for +**NVIDIA Blackwell (SM100 / GB200 / B200-class)**, bundled as a lightweight +Python package that can be used standalone or as a **vLLM general plugin**. -- `torch.ops.oink.rmsnorm` -- `torch.ops.oink.fused_add_rms_norm` +At the moment, the vLLM integration exposes the following `torch.library.custom_op` +entrypoints under the `oink::` namespace: -The implementation is backed by a CuTeDSL (CUTLASS) RMSNorm kernel tuned for -**NVIDIA Blackwell (SM100)**. +- `torch.ops.oink.rmsnorm(x, weight, eps) -> Tensor` +- `torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps) -> None` (in-place) + +The package also includes additional SM100 kernels used by the benchmark suite: +LayerNorm, Softmax (fwd+bwd), and CrossEntropy (fwd+bwd). + +## Requirements + +- GPU: **SM100** for the fast CuTeDSL paths. On other GPUs, Oink falls back to + reference PyTorch implementations for correctness. +- Python dependencies: + - `nvidia-cutlass-dsl` (CuTeDSL) + - `cuda-python` + - `torch` (provided by your environment / vLLM) + +Recommended env vars: + +```bash +export CUTE_DSL_ARCH=sm_100a +export PYTORCH_ALLOC_CONF=expandable_segments:True +``` ## Install (editable) @@ -17,22 +37,23 @@ From the `KernelAgent` repo root: pip install -e ./oink ``` -This plugin requires the CuTeDSL stack: +For running the in-repo benchmark suite / plots: ```bash -pip install nvidia-cutlass-dsl cuda-python +pip install -e "./oink[bench]" ``` -## Use with vLLM +## Usage + +### vLLM (general plugin) -1. Enable the vLLM integration: +1) Enable the plugin: ```bash export VLLM_USE_OINK_RMSNORM=1 ``` -2. Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` / -CUDA graphs. In Python: +2) Ensure vLLM keeps `rms_norm` as a custom op when using `torch.compile` / CUDA graphs: ```python from vllm import LLM @@ -45,13 +66,44 @@ llm = LLM( ) ``` -Without `+rms_norm`, Inductor may fuse RMSNorm into larger Triton kernels and -neither vLLM's CUDA RMSNorm nor Oink will run. 
+Without `+rms_norm`, Inductor may fuse RMSNorm into larger kernels and neither +vLLM’s CUDA RMSNorm nor Oink will run. + +### Direct PyTorch usage (manual op registration) + +For standalone use (outside vLLM), register the custom ops once: + +```python +import kernelagent_oink +import torch + +kernelagent_oink.register(force=True) + +x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16) +w = torch.randn(4096, device="cuda", dtype=torch.bfloat16) +y = torch.ops.oink.rmsnorm(x, w, 1e-6) +``` + +## Benchmarks + +The repo includes a Quack-style benchmark suite (tables + SVG plots) to compare +Oink against Quack on SM100 and to reproduce the reported speedups. + +- How to run + methodology: `oink/benchmarks/README.md` +- Pre-generated plots: `oink/benchmarks/media/` + +
+*Figure: SM100 BF16, Oink vs Quack (Quack-suite)*
+
+*Figure: SM100 BF16, Oink vs Quack (DSv3-like shapes)*
+
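+The fused residual-add RMSNorm path covered by
+`benchmark_fused_add_rmsnorm_sm100.py` is also exposed as a custom op once the
+ops are registered. A minimal sketch (shapes, dtype, and `eps` here are
+illustrative; both tensors are updated in place, with `x` receiving the
+normalized output and `residual` the updated residual sum, matching vLLM's
+in-place semantics):
+
+```python
+import torch
+import kernelagent_oink
+
+kernelagent_oink.register(force=True)  # register the oink:: custom ops
+
+x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
+residual = torch.randn_like(x)
+w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
+
+# Returns None; results are written back into x and residual.
+torch.ops.oink.fused_add_rms_norm(x, residual, w, 1e-6)
+```
+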
-## Notes +## Links -- This plugin is designed to be **safe to import even when disabled**; it only - registers ops when `VLLM_USE_OINK_RMSNORM` is truthy (`"1"` / `"true"`). -- The ops preserve **padded-row layouts** for 2D tensors (shape `[M, N]`, - `stride(1) == 1`, and potentially `stride(0) > N`), which is required for - `torch.compile` stride verification on some models (e.g., MLA padded inputs). +| What | Link | +|---|---| +| Quack (expert baseline) | https://github.com/Dao-AILab/quack | +| KernelAgent (agentic framework) | https://github.com/meta-pytorch/KernelAgent | +| vLLM PR (Oink RMSNorm integration) | https://github.com/vllm-project/vllm/pull/31828 | diff --git a/oink/benchmarks/README.md b/oink/benchmarks/README.md new file mode 100644 index 0000000..ceb7932 --- /dev/null +++ b/oink/benchmarks/README.md @@ -0,0 +1,146 @@ +# SM100 Benchmarks (KernelAgent-Oink vs Quack) + +This folder contains SM100 (GB200 / Blackwell) microbenchmarks for the Oink +CuTeDSL kernels vendored into KernelAgent, comparing against Quack’s SM100 +kernels where Quack provides an equivalent API. + +## Prereqs + +- GPU: **SM100** (`torch.cuda.get_device_capability() == (10, 0)`). +- Python deps in your environment: + - `torch` + - `nvidia-cutlass-dsl` (CuTeDSL) + - `cuda-python` + - `triton` (only for `triton.testing.do_bench`) + - `quack` (optional; only needed for Oink-vs-Quack comparisons) + +Recommended env vars: + +```bash +export PYTORCH_ALLOC_CONF=expandable_segments:True +export CUTE_DSL_ARCH=sm_100a +``` + +## Shape suites + +- **Quack-suite**: `(batch, seq) ∈ {1,4,8,16,32} × {8192,16384,32768,65536,131072}`, + with `hidden = 4096` so `M = batch * seq`, `N = 4096`. +- **DeepSeek-V3-like (DSv3)** + - RMSNorm / LayerNorm / Softmax: `M ∈ {4096, 16384, 65536}`, `N ∈ {6144, 7168, 8192}` + - Cross-entropy: `M ∈ {4096, 16384, 65536}`, `N ∈ {3072, 6144, 8192, 12288}` + +## Correctness gates + +By default, each script runs a per-shape `torch.testing.assert_close` check +vs a **pure-PyTorch reference** **before** emitting timing numbers. When Quack +is available for that op/path, the script also validates Quack vs the *same* +reference (so speedups can’t come from looser numerics). + +Disable with `--skip-verify` only for quick smoke tests. 
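+
+For reference, the RMSNorm gate is essentially the following pattern. This is a
+minimal sketch assuming an SM100 GPU with the ops registered: the `ref_rmsnorm`
+helper, the shape, and the tolerances below are illustrative, while the actual
+scripts evaluate the reference block-by-block over rows and also record
+streamed error statistics:
+
+```python
+import torch
+import kernelagent_oink
+
+kernelagent_oink.register(force=True)  # register the oink:: custom ops
+
+def ref_rmsnorm(x, w, eps=1e-6):
+    # Pure-PyTorch reference in float32, cast back to the activation dtype.
+    x_f32 = x.float()
+    rstd = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
+    return (x_f32 * rstd * w.float()).to(x.dtype)
+
+x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
+w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
+y = torch.ops.oink.rmsnorm(x, w, 1e-6)
+torch.testing.assert_close(y, ref_rmsnorm(x, w), atol=1e-2, rtol=1e-2)
+```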
+ +## Running benchmarks + +All scripts support: + +- `--quack-suite` or `--dsv3` (or `--configs MxN,...`) +- `--dtype {bf16,fp16,fp32}` +- `--iters ` and `--warmup-ms ` for kernel-only timing +- `--json ` and/or `--csv ` outputs (meta + rows) + +### One-command suite + +Run the full Quack-suite + DSv3 set (Oink vs Quack) and write all JSON artifacts +to a timestamped directory: + +```bash +python oink/benchmarks/readme/run_sm100_suite.py --dtype bf16 +``` + +Turn the JSON artifacts into Markdown tables (with geomean speedups): + +```bash +python oink/benchmarks/readme/summarize_results.py --in-dir /tmp/kernelagent_oink_sm100_suite_ \ + --out /tmp/kernelagent_oink_sm100_suite_summary.md +``` + +### Measured HBM roofline (STREAM-like) + +To contextualize the `*_tbps` numbers as a fraction of a *measured* bandwidth +ceiling (rather than a theoretical spec), run: + +```bash +CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op both --gb 2 \ + --json /tmp/hbm_roofline_sm100_bf16.json +``` + +### RMSNorm forward + +```bash +python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 200 --warmup-ms 25 \ + --json /tmp/oink_rmsnorm_fwd_quack_suite.json + +python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 200 --warmup-ms 25 \ + --json /tmp/oink_rmsnorm_fwd_dsv3.json + +# vLLM-style inference weights (weight dtype == activation dtype) +python oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py --dtype bf16 --weight-dtype same --quack-suite --iters 200 --warmup-ms 25 \ + --json /tmp/oink_rmsnorm_fwd_quack_suite_wsame.json +``` + +### Fused Add + RMSNorm (vLLM-style, in-place) + +This is a good "roofline case study" kernel (heavy read/write traffic, very little extra math): + +```bash +CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \ + --json /tmp/fused_add_rmsnorm_sm100_bf16.json +``` + +### RMSNorm backward + +```bash +python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --quack-suite --iters 100 --warmup-ms 25 \ + --csv /tmp/oink_rmsnorm_bwd_quack_suite.csv + +python oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py --dtype bf16 --weight-dtype fp32 --dsv3 --iters 100 --warmup-ms 25 \ + --csv /tmp/oink_rmsnorm_bwd_dsv3.csv +``` + +### Softmax (forward + backward) + +```bash +python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \ + --json /tmp/oink_softmax_fwd_bwd_quack_suite.json + +python oink/benchmarks/benchmark/benchmark_softmax_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \ + --json /tmp/oink_softmax_fwd_bwd_dsv3.json +``` + +### Cross-entropy (forward + backward) + +```bash +python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --quack-suite --iters 50 --warmup-ms 25 \ + --json /tmp/oink_cross_entropy_fwd_bwd_quack_suite.json + +python oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py --dtype bf16 --mode fwd_bwd --dsv3 --iters 50 --warmup-ms 25 \ + --json /tmp/oink_cross_entropy_fwd_bwd_dsv3.json +``` + +### LayerNorm forward + +```bash +python oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --quack-suite --iters 200 --warmup-ms 25 \ + --json /tmp/oink_layernorm_fwd_quack_suite.json + +python 
oink/benchmarks/benchmark/benchmark_layernorm_sm100.py --dtype bf16 --dsv3 --iters 200 --warmup-ms 25 \ + --json /tmp/oink_layernorm_fwd_dsv3.json +``` + +## Notes + +- These scripts intentionally avoid importing any external Oink checkout so the + results reflect the in-tree KernelAgent Oink kernels. +- For RMSNorm, the `rmsnorm_with_stage2` implementation is a **fallback** that + is only used when the pointer-based fast path cannot be used (e.g. when + `weight.dtype != x.dtype`, or when layouts/alignments are incompatible). You + can force it for A/B testing via `KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2=1`. diff --git a/oink/benchmarks/benchmark/bench_utils.py b/oink/benchmarks/benchmark/bench_utils.py new file mode 100644 index 0000000..0abb005 --- /dev/null +++ b/oink/benchmarks/benchmark/bench_utils.py @@ -0,0 +1,255 @@ +from __future__ import annotations + +import csv +import json +import math +import os +import subprocess +import sys +from dataclasses import asdict, dataclass +from datetime import datetime +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple + +import torch +from triton.testing import do_bench as triton_do_bench + + +@dataclass(frozen=True) +class DeviceMeta: + device: str + capability: Tuple[int, int] + torch: str + cuda: str + cute_dsl_arch: str + git_sha: str + timestamp: str + + +def _try_git_sha() -> str: + here = os.path.dirname(os.path.abspath(__file__)) + repo_root = os.path.abspath(os.path.join(here, "..", "..")) + try: + out = subprocess.check_output( + ["git", "rev-parse", "HEAD"], + cwd=repo_root, + stderr=subprocess.DEVNULL, + text=True, + ) + return out.strip() + except Exception: + return "" + + +def collect_device_meta(device: Optional[torch.device] = None) -> DeviceMeta: + if device is None: + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + timestamp = datetime.now().isoformat(timespec="seconds") + return DeviceMeta( + device=str(props.name), + capability=(int(props.major), int(props.minor)), + torch=str(torch.__version__), + cuda=str(getattr(torch.version, "cuda", "unknown")), + cute_dsl_arch=os.environ.get("CUTE_DSL_ARCH", ""), + git_sha=_try_git_sha(), + timestamp=timestamp, + ) + + +def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float: + """Approximate HBM peak bandwidth in GB/s for roofline fractions.""" + if device is None: + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + if sm >= 100: + return 8000.0 + return 2000.0 + + +def do_bench_triton(fn: Callable[[], Any], *, warmup_ms: int = 25, rep_ms: int = 100) -> float: + """Kernel-only timing consistent with the Oink benchmark harnesses.""" + return float(triton_do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="mean")) + + +def parse_dtype(s: str) -> torch.dtype: + s = s.lower() + if s == "bf16": + return torch.bfloat16 + if s == "fp16": + return torch.float16 + if s == "fp32": + return torch.float32 + raise ValueError(f"Unsupported dtype: {s}") + + +def parse_configs(s: str) -> List[Tuple[int, int]]: + out: List[Tuple[int, int]] = [] + for part in s.split(","): + m, n = part.lower().split("x") + out.append((int(m), int(n))) + return out + + +def quack_suite_configs() -> List[Tuple[int, int, int]]: + """Return (batch, seq, hidden) triples following Quack's common grid (hidden=4096).""" + batch_sizes = [1, 4, 8, 16, 32] + seq_lengths = [8192, 16384, 32768, 65536, 131072] + hidden = 4096 + cfgs: List[Tuple[int, int, int]] = [] + for bs in 
batch_sizes: + for sl in seq_lengths: + M = bs * sl + if M * hidden > (2**31): + continue + cfgs.append((bs, sl, hidden)) + return cfgs + + +def ensure_oink_src_on_path() -> None: + """Make the in-repo KernelAgent Oink package importable without an editable install.""" + here = os.path.dirname(os.path.abspath(__file__)) + oink_src = os.path.abspath(os.path.join(here, "..", "..", "src")) + if oink_src not in sys.path: + sys.path.insert(0, oink_src) + + +def write_csv(path: str, rows: Sequence[Dict[str, Any]]) -> None: + if not rows: + return + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + file_exists = os.path.exists(path) + with open(path, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=sorted(rows[0].keys())) + if not file_exists: + writer.writeheader() + for row in rows: + writer.writerow(row) + + +def write_json(path: str, meta: DeviceMeta, rows: Sequence[Dict[str, Any]], *, extra: Dict[str, Any] | None = None) -> None: + os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) + payload: Dict[str, Any] = { + "meta": {**asdict(meta), **(extra or {})}, + "rows": list(rows), + } + with open(path, "w") as f: + json.dump(payload, f, indent=2) + + +def iter_row_blocks(M: int, block_rows: int) -> Iterable[Tuple[int, int]]: + """Yield (start, end) row index ranges for a 2D (M, N) matrix. + + The intent is to make correctness references for large tensors tractable + without materializing full float32 intermediates. + """ + if M < 0: + raise ValueError(f"M must be non-negative, got {M}") + if block_rows <= 0: + raise ValueError(f"block_rows must be > 0, got {block_rows}") + for start in range(0, M, block_rows): + yield start, min(M, start + block_rows) + + +@dataclass +class ErrorStats: + """Numerical error stats between an output and a reference. + + Notes: + - `max_abs` and `rel_l2` are computed exactly (streamed). + - `p99_abs` is computed over a deterministic strided sample of abs error + values (to keep very large tensors tractable). + """ + + max_abs: float + p99_abs: float + rel_l2: float + p99_sample_elems: int + p99_sample_stride: int + + +class ErrorStatsAccumulator: + """Stream error stats over (output_block, ref_block) pairs. + + This is intended for large 2D tensors where we compute reference results + block-by-block to avoid materializing full float32 intermediates. + """ + + def __init__(self, *, total_elems: int, p99_target_samples: int = 1_000_000): + if total_elems <= 0: + raise ValueError(f"total_elems must be > 0, got {total_elems}") + if p99_target_samples <= 0: + raise ValueError(f"p99_target_samples must be > 0, got {p99_target_samples}") + self.total_elems = int(total_elems) + self.p99_target_samples = int(p99_target_samples) + # Deterministic strided sampling across the flattened tensor order. + self.sample_stride = max(1, self.total_elems // self.p99_target_samples) + self._global_offset = 0 + + self._max_abs = 0.0 + self._err_sq_sum = 0.0 + self._ref_sq_sum = 0.0 + self._abs_err_samples: List[torch.Tensor] = [] + + def update(self, out: torch.Tensor, ref: torch.Tensor) -> None: + if out.shape != ref.shape: + raise ValueError(f"shape mismatch: out={tuple(out.shape)} ref={tuple(ref.shape)}") + + # Compute error in float32 for stable reductions. + err_f32 = (out - ref).to(torch.float32) + abs_err = err_f32.abs() + + # Exact reductions. 
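+        # max_abs is an exact running maximum; the squared sums below are
+        # accumulated in float64 and combined in finalize() as
+        # rel_l2 = sqrt(sum(err^2)) / sqrt(sum(ref^2)).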
+ self._max_abs = max(self._max_abs, float(abs_err.max().item())) + self._err_sq_sum += float((err_f32 * err_f32).sum(dtype=torch.float64).item()) + ref_f32 = ref.to(torch.float32) + self._ref_sq_sum += float((ref_f32 * ref_f32).sum(dtype=torch.float64).item()) + + # Deterministic strided sample for p99_abs. + flat = abs_err.flatten() + block_elems = int(flat.numel()) + if block_elems <= 0: + return + + stride = int(self.sample_stride) + first = (-int(self._global_offset)) % stride + if first < block_elems: + idx = torch.arange(first, block_elems, step=stride, device=flat.device, dtype=torch.int64) + # Gather a modest number of values (≈ block_elems/stride). + vals = flat.index_select(0, idx).detach().to(device="cpu", dtype=torch.float32) + self._abs_err_samples.append(vals) + + self._global_offset += block_elems + + def finalize(self) -> ErrorStats: + if self._abs_err_samples: + samples = torch.cat(self._abs_err_samples, dim=0) + if samples.numel() > self.p99_target_samples: + samples = samples[: self.p99_target_samples] + p99 = float(torch.quantile(samples, 0.99).item()) if samples.numel() > 0 else 0.0 + sample_elems = int(samples.numel()) + else: + p99 = 0.0 + sample_elems = 0 + + denom = math.sqrt(self._ref_sq_sum) if self._ref_sq_sum > 0 else 0.0 + rel_l2 = (math.sqrt(self._err_sq_sum) / denom) if denom > 0 else 0.0 + + return ErrorStats( + max_abs=float(self._max_abs), + p99_abs=float(p99), + rel_l2=float(rel_l2), + p99_sample_elems=int(sample_elems), + p99_sample_stride=int(self.sample_stride), + ) + + +def error_stats_to_row(prefix: str, stats: ErrorStats) -> Dict[str, Any]: + """Flatten ErrorStats into JSON-friendly row fields.""" + return { + f"{prefix}_max_abs": float(stats.max_abs), + f"{prefix}_p99_abs": float(stats.p99_abs), + f"{prefix}_rel_l2": float(stats.rel_l2), + f"{prefix}_p99_sample_elems": int(stats.p99_sample_elems), + f"{prefix}_p99_sample_stride": int(stats.p99_sample_stride), + } diff --git a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py new file mode 100644 index 0000000..8bcac15 --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py @@ -0,0 +1,426 @@ +from __future__ import annotations + +import argparse +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_triton, + error_stats_to_row, + ensure_oink_src_on_path, + iter_row_blocks, + parse_configs, + parse_dtype, + quack_suite_configs, + write_csv, + write_json, +) + +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import cross_entropy as oink_ce # noqa: E402 + +try: + from quack.cross_entropy import cross_entropy_bwd as quack_ce_bwd # type: ignore + from quack.cross_entropy import cross_entropy_fwd as quack_ce_fwd # type: ignore +except Exception: + quack_ce_fwd = None + quack_ce_bwd = None + + +# Match Quack's unit-test defaults (tests/test_cross_entropy.py). +_VERIFY_TOL_LOSS = dict(atol=5e-5, rtol=1e-5) # float32 outputs (loss/lse) +_VERIFY_TOL_DX = { + torch.float32: dict(atol=5e-5, rtol=1e-5), + # FP16 `dx` is low-precision; allow ~1 ulp at typical magnitudes. 
+ torch.float16: dict(atol=1e-3, rtol=1e-3), + # BF16 `dx` is low-precision; allow ~1 ulp at typical magnitudes. + torch.bfloat16: dict(atol=1e-2, rtol=1e-2), +} + + +def bytes_io_model_ce( + M: int, + N: int, + dtype: torch.dtype, + *, + target_dtype: torch.dtype = torch.int64, + mode: str, +) -> int: + elem = torch.tensor(0, dtype=dtype).element_size() + t_elem = torch.tensor(0, dtype=target_dtype).element_size() + # Forward: + # read logits (M*N) + read target (M) + write loss (M fp32) + write lse (M fp32) + fwd = M * N * elem + M * t_elem + 2 * M * 4 + # Backward (reduction="none" path): + # read logits (M*N) + read target (M) + read dloss (M fp32) + read lse (M fp32) + write dx (M*N) + bwd = 2 * M * N * elem + M * t_elem + 2 * M * 4 + + if mode == "fwd": + return int(fwd) + if mode == "bwd": + return int(bwd) + if mode == "fwd_bwd": + # Logical IO for dx given (logits, target, dloss): read logits + read target + # + read dloss + write dx. (Intermediate lse/loss are implementation details.) + return int(2 * M * N * elem + M * t_elem + M * 4) + raise ValueError(f"Unsupported mode: {mode}") + + +def dsv3_configs() -> List[Tuple[int, int]]: + Ms = [4096, 16384, 65536] + Ns = [3072, 6144, 8192, 12288] + return [(m, n) for m in Ms for n in Ns] + + +def _verify_parity(logits: torch.Tensor, target: torch.Tensor, *, ignore_index: int) -> dict[str, object]: + dtype = logits.dtype + ref_block_rows = 512 + dloss = torch.randn(logits.size(0), device=logits.device, dtype=torch.float32) # upstream grad + + with torch.no_grad(): + loss_o, lse_o = oink_ce.cross_entropy_forward( + logits, target, ignore_index=ignore_index, reduction="none" + ) + dx_o = oink_ce.cross_entropy_backward(dloss, logits, target, lse_o, ignore_index=ignore_index) + dx_fused_o = oink_ce.cross_entropy_fwd_bwd( + dloss, + logits, + target, + ignore_index=ignore_index, + ) + + loss_q = None + lse_q = None + dx_q = None + if quack_ce_fwd is not None and quack_ce_bwd is not None: + loss_q, lse_q = quack_ce_fwd( + logits, + target, + target_logit=None, + ignore_index=ignore_index, + return_lse=True, + return_dx=False, + inplace_backward=False, + ) + dx_q = quack_ce_bwd( + logits, + target, + dloss, + lse_q, + ignore_index=ignore_index, + inplace_backward=False, + ) + + M = int(logits.shape[0]) + N = int(logits.shape[1]) + loss_acc_ours = ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000)) + lse_acc_ours = ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000)) + dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + dx_fused_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + loss_acc_quack = ( + ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000)) + if (quack_ce_fwd is not None and quack_ce_bwd is not None) + else None + ) + lse_acc_quack = ( + ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000)) + if (quack_ce_fwd is not None and quack_ce_bwd is not None) + else None + ) + dx_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if (quack_ce_fwd is not None and quack_ce_bwd is not None) + else None + ) + + # Match Quack tests: compare to a PyTorch reference computed on float32 logits. + # Chunk over rows so we don't materialize a full (M, N) float32 tensor. 
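+    # For each row block: recompute loss (F.cross_entropy, reduction="none"),
+    # lse (logsumexp), and dx (autograd) on the float32 logits, assert-close the
+    # Oink (and, when available, Quack) outputs against them, and stream the
+    # error statistics.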
+ for start, end in iter_row_blocks(M, ref_block_rows): + logits_f32 = logits[start:end].float().requires_grad_(True) + target_blk = target[start:end] + dloss_blk = dloss[start:end] + + loss_ref = torch.nn.functional.cross_entropy( + logits_f32, + target_blk, + reduction="none", + ignore_index=ignore_index, + ) + lse_ref = torch.logsumexp(logits_f32, dim=-1) + (dx_ref_f32,) = torch.autograd.grad(loss_ref, logits_f32, grad_outputs=dloss_blk) + dx_ref = dx_ref_f32.to(dtype) + + torch.testing.assert_close(loss_o[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS) + torch.testing.assert_close(lse_o[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS) + torch.testing.assert_close(dx_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype]) + torch.testing.assert_close(dx_fused_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype]) + loss_acc_ours.update(loss_o[start:end], loss_ref.detach()) + lse_acc_ours.update(lse_o[start:end], lse_ref.detach()) + dx_acc_ours.update(dx_o[start:end], dx_ref) + dx_fused_acc_ours.update(dx_fused_o[start:end], dx_ref) + + if loss_q is not None and lse_q is not None and dx_q is not None: + torch.testing.assert_close(loss_q[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS) + torch.testing.assert_close(lse_q[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS) + torch.testing.assert_close(dx_q[start:end], dx_ref, **_VERIFY_TOL_DX[dtype]) + assert loss_acc_quack is not None and lse_acc_quack is not None and dx_acc_quack is not None + loss_acc_quack.update(loss_q[start:end], loss_ref.detach()) + lse_acc_quack.update(lse_q[start:end], lse_ref.detach()) + dx_acc_quack.update(dx_q[start:end], dx_ref) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_loss", loss_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_lse", lse_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_dx_fused", dx_fused_acc_ours.finalize())) + if loss_acc_quack is not None and lse_acc_quack is not None and dx_acc_quack is not None: + stats.update(error_stats_to_row("quack_err_loss", loss_acc_quack.finalize())) + stats.update(error_stats_to_row("quack_err_lse", lse_acc_quack.finalize())) + stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize())) + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + *, + warmup_ms: int, + iters_ms: int, + mode: str, + verify: bool, + ignore_index: int, +) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]: + device = torch.device("cuda") + logits = 0.1 * torch.randn(M, N, device=device, dtype=dtype) + target = torch.randint(0, N, (M,), device=device, dtype=torch.int64) + # Sprinkle some ignore_index entries for robustness (and to match reduction semantics). 
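+    # The 0.01 threshold below marks roughly 1% of rows as ignored; for those rows the
+    # reference cross entropy yields zero loss and zero gradient, so each shape also
+    # exercises the ignore_index path.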
+ if ignore_index is not None: + mask = torch.rand(M, device=device) < 0.01 + target[mask] = int(ignore_index) + dloss = torch.randn(M, device=device, dtype=torch.float32) + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(logits, target, ignore_index=int(ignore_index)) + + bytes_io = bytes_io_model_ce(M, N, dtype, target_dtype=target.dtype, mode=mode) + + if mode == "fwd": + fn_oink = lambda: oink_ce.cross_entropy_forward( + logits, target, ignore_index=int(ignore_index), reduction="none" + ) + fn_quack = ( + None + if quack_ce_fwd is None + else ( + lambda: quack_ce_fwd( + logits, + target, + target_logit=None, + ignore_index=int(ignore_index), + return_lse=True, + return_dx=False, + inplace_backward=False, + ) + ) + ) + elif mode == "bwd": + with torch.no_grad(): + _loss_o, lse_o = oink_ce.cross_entropy_forward( + logits, target, ignore_index=int(ignore_index), reduction="none" + ) + if quack_ce_fwd is not None: + _loss_q, lse_q = quack_ce_fwd( + logits, + target, + target_logit=None, + ignore_index=int(ignore_index), + return_lse=True, + return_dx=False, + inplace_backward=False, + ) + else: + lse_q = None + fn_oink = lambda: oink_ce.cross_entropy_backward( + dloss, logits, target, lse_o, ignore_index=int(ignore_index) + ) + fn_quack = ( + None + if (quack_ce_bwd is None or lse_q is None) + else ( + lambda: quack_ce_bwd( + logits, + target, + dloss, + lse_q, + ignore_index=int(ignore_index), + inplace_backward=False, + ) + ) + ) + elif mode == "fwd_bwd": + fn_oink = lambda: oink_ce.cross_entropy_fwd_bwd( + dloss, + logits, + target, + ignore_index=int(ignore_index), + ) + fn_quack = ( + None + if (quack_ce_fwd is None or quack_ce_bwd is None) + else ( + lambda: quack_ce_bwd( + logits, + target, + dloss, + quack_ce_fwd( + logits, + target, + target_logit=None, + ignore_index=int(ignore_index), + return_lse=True, + return_dx=False, + inplace_backward=False, + )[1], + ignore_index=int(ignore_index), + inplace_backward=False, + ) + ) + ) + else: + raise ValueError(f"Unsupported mode: {mode}") + + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 + + if fn_quack is None: + return (ms_oink, gbps_oink), None, stats + + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 + return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + + p = argparse.ArgumentParser() + p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument("--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"]) + p.add_argument("--ignore-index", type=int, default=-100) + p.add_argument("--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only).") + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows") + p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)") + p.add_argument("--configs", type=str, default="1024x4096,8192x4096") + p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid (vocab=4096)") + 
p.add_argument("--dsv3", action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {3072,6144,8192,12288}") + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs PyTorch float32-logits cross entropy)", + ) + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + meta = collect_device_meta(device) + + rows_out: List[Dict[str, Any]] = [] + for (M, N) in cfgs: + print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", flush=True) + (ms_oink, gbps_oink), quack, stats = bench_single( + M=M, + N=N, + dtype=dtype, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + mode=str(args.mode), + verify=not args.skip_verify, + ignore_index=int(args.ignore_index), + ) + row: Dict[str, Any] = { + "M": M, + "N": N, + "dtype": args.dtype, + "mode": args.mode, + "ignore_index": int(args.ignore_index), + "ours_ms": ms_oink, + "ours_gbps": gbps_oink, + "ours_tbps": gbps_oink / 1000.0, + "ours_hbm_frac": gbps_oink / hbm_peak, + } + if quack is not None: + ms_q, gbps_q = quack + row.update( + { + "quack_ms": ms_q, + "quack_gbps": gbps_q, + "quack_tbps": gbps_q / 1000.0, + "speedup_vs_quack": ms_q / ms_oink, + } + ) + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + write_csv(args.csv, rows_out) + if args.json is not None: + write_json( + args.json, + meta, + rows_out, + extra={ + "method": "triton.testing.do_bench(mean)", + "warmup_ms": int(args.warmup_ms), + "rep_ms": int(args.iters), + "io_model_bytes": "mode-dependent; see bytes_io_model_ce in script", + }, + ) + + headers = ["M", "N", "mode", "ours_ms", "ours_tbps"] + if quack_ce_fwd is not None and quack_ce_bwd is not None: + headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows_out: + parts: List[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py new file mode 100644 index 0000000..b75f892 --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +""" +Benchmark fused_add_rmsnorm (in-place) on SM100. + +This matches vLLM's fused_add_rms_norm semantics: + z = x + residual (stored into residual) + y = RMSNorm(z, w) (stored into x) + +Why this exists: +- It is a common inference hot path (vLLM). +- It is strongly memory-bound (reads/writes two MxN tensors), making it a good + roofline case study for Blackwell. 
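+
+A rough PyTorch sketch of the in-place update (this mirrors the verification
+reference below, not the kernel implementation):
+
+    z = x + residual
+    rstd = torch.rsqrt(z.float().square().mean(-1, keepdim=True) + eps)
+    x.copy_(((z.float() * rstd) * w.float()).to(x.dtype))
+    residual.copy_(z)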
+ +Example: + CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --M 65536 --N 4096 \\ + --json /tmp/fused_add_rmsnorm_sm100_bf16.json + +DSv3 suite (Oink vs Quack, multi-shape): + CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --dsv3 \\ + --json /tmp/kernelagent_oink_sm100_suite_bf16/fused_add_rmsnorm_dsv3.json +""" + +import argparse +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_triton, + error_stats_to_row, + ensure_oink_src_on_path, + iter_row_blocks, + parse_dtype, + write_json, +) + +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402 + +_VERIFY_TOL = { + # Align with Quack's RMSNorm unit-test defaults (tests/test_rmsnorm.py). + torch.float32: dict(atol=1e-4, rtol=1e-3), + torch.float16: dict(atol=1e-2, rtol=1e-3), + torch.bfloat16: dict(atol=1e-1, rtol=1e-2), +} + +try: + # Use the low-level mutating custom op to avoid per-iteration allocations + # (critical for fair comparisons on small/medium M). + from quack.rmsnorm import _rmsnorm_fwd as quack_rmsnorm_fwd_mut # type: ignore +except Exception: + quack_rmsnorm_fwd_mut = None + + +def dsv3_configs() -> List[Tuple[int, int]]: + Ms = [4096, 16384, 65536] + Ns = [6144, 7168, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def bytes_io_model_fused_add_rmsnorm_inplace(M: int, N: int, dtype: torch.dtype) -> int: + elem = torch.tensor(0, dtype=dtype).element_size() + # Read x + read residual + write x + write residual + read weight + return int((4 * M * N + N) * elem) + + +def _verify_parity( + *, + x: torch.Tensor, + residual: torch.Tensor, + w: torch.Tensor, + eps: float, +) -> dict[str, object]: + tol = _VERIFY_TOL[x.dtype] + ref_block_rows = 4096 + M = int(x.shape[0]) + N = int(x.shape[1]) + + y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + z_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + y_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd_mut is not None else None + z_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd_mut is not None else None + + x_o = x.clone() + r_o = residual.clone() + out_q = None + res_out_q = None + with torch.no_grad(): + oink_rmsnorm.fused_add_rmsnorm_inplace_(x_o, r_o, w, eps=eps) + + if quack_rmsnorm_fwd_mut is not None: + out_q = torch.empty_like(x) + res_out_q = torch.empty_like(residual) + quack_rmsnorm_fwd_mut( + x, + w, + out_q, + None, # bias + None, # rstd + None, # mean + residual, + res_out_q, + eps, + False, # is_layernorm + ) + + # Pure-PyTorch reference (float32 accumulation), chunked over rows. 
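+    # Each block recomputes z = x + residual in the storage dtype and the normalized
+    # output with float32 intermediates, casting back to x.dtype before comparing.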
+ M = int(x.shape[0]) + w_f32 = w.float() + for start, end in iter_row_blocks(M, ref_block_rows): + z = x[start:end] + residual[start:end] + zf = z.float() + rstd = torch.rsqrt(zf.square().mean(dim=-1, keepdim=True) + eps) + y_ref = ((zf * rstd) * w_f32).to(x.dtype) + + torch.testing.assert_close(x_o[start:end], y_ref, **tol) + torch.testing.assert_close(r_o[start:end], z, **tol) + y_acc_ours.update(x_o[start:end], y_ref) + z_acc_ours.update(r_o[start:end], z) + if out_q is not None and res_out_q is not None: + torch.testing.assert_close(out_q[start:end], y_ref, **tol) + torch.testing.assert_close(res_out_q[start:end], z, **tol) + assert y_acc_quack is not None and z_acc_quack is not None + y_acc_quack.update(out_q[start:end], y_ref) + z_acc_quack.update(res_out_q[start:end], z) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_residual_out", z_acc_ours.finalize())) + if y_acc_quack is not None and z_acc_quack is not None: + stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize())) + stats.update(error_stats_to_row("quack_err_residual_out", z_acc_quack.finalize())) + return stats + + +def bench_one( + *, + M: int, + N: int, + dtype: torch.dtype, + warmup_ms: int, + iters_ms: int, + verify: bool, +) -> Dict[str, Any]: + device = torch.device("cuda") + x = torch.randn((M, N), device=device, dtype=dtype) + residual = torch.randn_like(x) + w = torch.randn((N,), device=device, dtype=dtype) + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(x=x, residual=residual, w=w, eps=1e-6) + + bytes_io = bytes_io_model_fused_add_rmsnorm_inplace(M, N, dtype) + + fn = lambda: oink_rmsnorm.fused_add_rmsnorm_inplace_(x, residual, w, eps=1e-6) + ms = do_bench_triton(fn, warmup_ms=warmup_ms, rep_ms=iters_ms) + + gbps = bytes_io / (ms * 1e-3) / 1e9 + tbps = gbps / 1000.0 + hbm_frac = gbps / detect_hbm_peak_gbps(device) + + row: Dict[str, Any] = dict( + M=int(M), + N=int(N), + dtype="bf16" if dtype is torch.bfloat16 else ("fp16" if dtype is torch.float16 else "fp32"), + ours_ms=float(ms), + ours_gbps=float(gbps), + ours_tbps=float(tbps), + ours_hbm_frac=float(hbm_frac), + ) + row.update(stats) + + if quack_rmsnorm_fwd_mut is not None: + out_q = torch.empty_like(x) + res_out_q = torch.empty_like(residual) + + fn_q = lambda: quack_rmsnorm_fwd_mut( + x, + w, + out_q, + None, # bias + None, # rstd + None, # mean + residual, + res_out_q, + 1e-6, + False, # is_layernorm + ) + ms_q = do_bench_triton(fn_q, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_q = bytes_io / (ms_q * 1e-3) / 1e9 + row.update( + dict( + quack_ms=float(ms_q), + quack_gbps=float(gbps_q), + quack_tbps=float(gbps_q / 1000.0), + speedup_vs_quack=float(ms_q / ms), + ) + ) + + return row + + +def _dtype_label(dtype: torch.dtype) -> str: + if dtype is torch.bfloat16: + return "bf16" + if dtype is torch.float16: + return "fp16" + return "fp32" + + +def _print_table(rows: List[Dict[str, Any]]) -> None: + if not rows: + return + headers = ["M", "N", "ours_ms", "ours_tbps"] + has_quack = any("quack_ms" in r for r in rows) + if has_quack: + headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows: + parts: List[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +def main() -> None: + p = argparse.ArgumentParser() + 
p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) + p.add_argument("--M", type=int, default=65536) + p.add_argument("--N", type=int, default=4096) + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", + ) + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument("--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)") + p.add_argument("--skip-verify", action="store_true") + p.add_argument("--json", type=str, default=None) + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + meta = collect_device_meta(torch.device("cuda")) + + cfgs = dsv3_configs() if bool(args.dsv3) else [(int(args.M), int(args.N))] + rows: List[Dict[str, Any]] = [] + for (M, N) in cfgs: + print(f"bench M={M:<8d} N={N:<6d} dtype={_dtype_label(dtype)} fused_add_rmsnorm ...", flush=True) + rows.append( + bench_one( + M=int(M), + N=int(N), + dtype=dtype, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + verify=not bool(args.skip_verify), + ) + ) + + _print_table(rows) + + if args.json: + write_json( + args.json, + meta, + rows, + extra=dict( + io_model_bytes="(4*M*N + N)*elem_size", + warmup_ms=int(args.warmup_ms), + rep_ms=int(args.iters), + method="triton.testing.do_bench(mean)", + note="Oink fused_add_rmsnorm_inplace_ vs Quack quack::_rmsnorm_fwd(residual=..., residual_out=...) when available", + ), + ) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py new file mode 100644 index 0000000..971a03c --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +""" +HBM roofline microbenchmark for SM100 (GB200 / Blackwell). + +This script measures a STREAM-like bandwidth ceiling using a simple Triton kernel +that performs a large contiguous copy (read + write) and/or triad (read + read + write) +over a large buffer. + +Why this exists: +- The benchmark harnesses for Oink ops report an "ours_tbps" derived from an IO model. +- For roofline discussions, comparing against a *measured* device bandwidth ceiling + is often more meaningful than quoting a marketing/theoretical spec. + +Example: + CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype bf16 --op copy --gb 2 + CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype fp16 --op triad --gb 2 +""" + +import argparse +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +import triton +import triton.language as tl + +# Reduce fragmentation pressure on busy GPUs. 
+os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +from bench_utils import ( # noqa: E402 + collect_device_meta, + do_bench_triton, + parse_dtype, + write_json, +) + + +@triton.jit +def _copy_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK + tl.arange(0, BLOCK) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask, other=0) + tl.store(y_ptr + offsets, x, mask=mask) + + +@triton.jit +def _triad_kernel( + x_ptr, + y_ptr, + n_elements, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK + tl.arange(0, BLOCK) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask, other=0) + y = tl.load(y_ptr + offsets, mask=mask, other=0) + tl.store(y_ptr + offsets, x + y, mask=mask) + + +def _bytes_moved(n_elements: int, elem_size: int, *, op: str) -> int: + if op == "copy": + return int(2 * n_elements * elem_size) # read x + write y + if op == "triad": + return int(3 * n_elements * elem_size) # read x + read y + write y + raise ValueError(f"Unsupported op: {op}") + + +def bench_one( + *, + n_elements: int, + dtype: torch.dtype, + op: str, + block: int, + num_warps: int, + warmup_ms: int, + iters_ms: int, +) -> Tuple[float, float]: + device = torch.device("cuda") + x = torch.empty((n_elements,), device=device, dtype=dtype) + y = torch.empty_like(x) + # Avoid pathological compression-friendly patterns (e.g. all-zeros) that can + # artificially inflate apparent bandwidth on some GPUs. Random-ish data is + # a closer match to ML workloads. + x.uniform_(-1, 1) + y.uniform_(-1, 1) + + grid = (triton.cdiv(n_elements, block),) + + if op == "copy": + launch = lambda: _copy_kernel[grid]( + x, + y, + n_elements, + BLOCK=block, + num_warps=num_warps, + num_stages=4, + ) + elif op == "triad": + launch = lambda: _triad_kernel[grid]( + x, + y, + n_elements, + BLOCK=block, + num_warps=num_warps, + num_stages=4, + ) + else: + raise ValueError(f"Unsupported op: {op}") + + # Force compilation out of the timed region. 
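+    # (The first call JIT-compiles the Triton kernel for this BLOCK/num_warps
+    # specialization, so do_bench only measures steady-state launches.)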
+ launch() + torch.cuda.synchronize() + + ms = do_bench_triton(launch, warmup_ms=warmup_ms, rep_ms=iters_ms) + moved = _bytes_moved(n_elements, x.element_size(), op=op) + tbps = moved / (ms * 1e-3) / 1e12 + return ms, tbps + + +def _print_summary(rows: List[Dict[str, Any]]) -> None: + if not rows: + return + best = max(rows, key=lambda r: float(r["tbps"])) + print("\nSummary (STREAM-like):") + print(f"- best_tbps: {best['tbps']:.3f} TB/s ({best['op']}, BLOCK={best['block']}, warps={best['num_warps']})") + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) + p.add_argument("--op", type=str, default="copy", choices=["copy", "triad", "both"]) + p.add_argument("--gb", type=float, default=2.0, help="Size per tensor in GB (default: 2)") + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument("--iters", type=int, default=100, help="rep_ms for do_bench (default: 100)") + p.add_argument("--json", type=str, default=None, help="Write JSON results to this path") + p.add_argument("--no-sweep", action="store_true", help="Disable tuning sweep; run a single config") + p.add_argument("--block", type=int, default=2048, help="BLOCK size when --no-sweep is set") + p.add_argument("--warps", type=int, default=8, help="num_warps when --no-sweep is set") + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + cap = (int(props.major), int(props.minor)) + if cap != (10, 0): + raise RuntimeError(f"Expected SM100 (10,0), got {cap} ({props.name})") + + elem_size = torch.tensor(0, dtype=dtype).element_size() + bytes_per_tensor = int(args.gb * (1024**3)) + n_elements = max(1, bytes_per_tensor // elem_size) + + ops: List[str] + if args.op == "both": + ops = ["copy", "triad"] + else: + ops = [args.op] + + if args.no_sweep: + sweep: List[Tuple[int, int]] = [(int(args.block), int(args.warps))] + else: + # A tiny hand-tuned sweep that keeps compile overhead reasonable. + sweep = [ + (1024, 4), + (1024, 8), + (2048, 4), + (2048, 8), + (4096, 8), + ] + + print(f"Running on {props.name} (SM{props.major}{props.minor})") + print(f"- dtype: {args.dtype} (elem={elem_size}B)") + print(f"- n_elements: {n_elements:,} (~{(n_elements * elem_size) / (1024**3):.2f} GiB per tensor)") + print(f"- ops: {ops}") + print(f"- sweep: {sweep}") + + meta = collect_device_meta(device) + rows: List[Dict[str, Any]] = [] + for op in ops: + for block, warps in sweep: + ms, tbps = bench_one( + n_elements=n_elements, + dtype=dtype, + op=op, + block=block, + num_warps=warps, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + ) + rows.append( + dict( + op=op, + dtype=str(args.dtype), + n_elements=int(n_elements), + elem_size_B=int(elem_size), + block=int(block), + num_warps=int(warps), + warmup_ms=int(args.warmup_ms), + rep_ms=int(args.iters), + ms=float(ms), + tbps=float(tbps), + ) + ) + print(f"- {op:5s} BLOCK={block:4d} warps={warps}: {tbps:.3f} TB/s ({ms:.4f} ms)") + + _print_summary(rows) + + if args.json: + # Write meta + detailed rows for reproducibility. 
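+        # Unlike the op benchmarks, `tbps` here is derived from the exact bytes the
+        # copy/triad kernels move, so it can be read directly as a measured ceiling.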
+ extra = dict( + bytes_model="copy:2*N*elem, triad:3*N*elem", + bytes_per_tensor=int(bytes_per_tensor), + gb_per_tensor=float(args.gb), + ) + write_json(args.json, meta, rows, extra=extra) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py new file mode 100644 index 0000000..778e3e2 --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py @@ -0,0 +1,393 @@ +from __future__ import annotations + +import argparse +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_triton, + error_stats_to_row, + ensure_oink_src_on_path, + iter_row_blocks, + parse_configs, + parse_dtype, + quack_suite_configs, + write_csv, + write_json, +) + +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import layernorm as oink_ln # noqa: E402 + +try: + # Quack exposes LayerNorm through the RMSNorm module (is_layernorm=True path). + from quack.rmsnorm import layernorm_fwd as quack_layernorm # type: ignore +except Exception: + quack_layernorm = None + +_VERIFY_TOL_Y = { + # Match Quack's unit-test defaults (tests/test_layernorm.py). + torch.float32: dict(atol=1e-4, rtol=1e-4), + torch.float16: dict(atol=1e-3, rtol=1e-3), + torch.bfloat16: dict(atol=1e-2, rtol=1e-2), +} + +# Quack checks rstd/mean (fp32) with a tighter fixed tolerance. 
+_VERIFY_TOL_STATS = dict(atol=6e-4, rtol=6e-4) + + +def bytes_io_model_layernorm( + M: int, + N: int, + dtype: torch.dtype, + *, + has_bias: bool, + return_rstd: bool, + return_mean: bool, + weight_dtype: torch.dtype = torch.float32, +) -> int: + elem = torch.tensor(0, dtype=dtype).element_size() + w_elem = torch.tensor(0, dtype=weight_dtype).element_size() + total = 0 + # Read x + write y + total += 2 * M * N * elem + # Read weight (+ optional bias) along feature dim + total += N * w_elem + if has_bias: + total += N * w_elem + # Optional per-row stats (fp32) + if return_rstd: + total += M * 4 + if return_mean: + total += M * 4 + return int(total) + + +def dsv3_configs() -> List[Tuple[int, int]]: + Ms = [4096, 16384, 65536] + Ns = [6144, 7168, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def _verify_parity( + x: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor | None, + *, + eps: float, + return_rstd: bool, + return_mean: bool, +) -> dict[str, object]: + tol_y = _VERIFY_TOL_Y[x.dtype] + ref_block_rows = 4096 + M = int(x.shape[0]) + N = int(x.shape[1]) + + y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + y_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) if (quack_layernorm is not None and b is None) else None + ) + with torch.no_grad(): + ours = oink_ln.layernorm( + x, + w, + bias=b, + eps=eps, + return_rstd=return_rstd, + return_mean=return_mean, + ) + quack = None + if quack_layernorm is not None and b is None: + quack = quack_layernorm( + x, + w, + eps=eps, + return_rstd=return_rstd, + return_mean=return_mean, + ) + torch.cuda.synchronize() + + def _unpack(out): + if return_rstd and return_mean: + y, rstd, mean = out + elif return_rstd and not return_mean: + y, rstd = out + mean = None + elif return_mean and not return_rstd: + y, mean = out + rstd = None + else: + y, rstd, mean = out, None, None + return y, rstd, mean + + y_o, rstd_o, mean_o = _unpack(ours) + y_q, rstd_q, mean_q = _unpack(quack) if quack is not None else (None, None, None) + + # Pure-PyTorch reference (float32 accumulation), matching Quack's unit tests: + # - compute ref output via F.layer_norm on float32 + # - compute mean/rstd from float32 input + rstd_ref_all = torch.empty((M,), device=x.device, dtype=torch.float32) if return_rstd else None + mean_ref_all = torch.empty((M,), device=x.device, dtype=torch.float32) if return_mean else None + + for start, end in iter_row_blocks(M, ref_block_rows): + x_f32 = x[start:end].float() + y_ref_f32 = torch.nn.functional.layer_norm(x_f32, w.shape, w, b, eps) + y_ref = y_ref_f32.to(x.dtype) + torch.testing.assert_close(y_o[start:end], y_ref, **tol_y) + y_acc_ours.update(y_o[start:end], y_ref) + if y_q is not None: + torch.testing.assert_close(y_q[start:end], y_ref, **tol_y) + assert y_acc_quack is not None + y_acc_quack.update(y_q[start:end], y_ref) + + # Per-row stats in fp32, as in Quack's tests. 
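+        # mean = E[x] and rstd = 1 / sqrt(E[(x - mean)^2] + eps), both taken over the
+        # feature dimension of the float32 copy of this row block.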
+ if return_rstd or return_mean: + mean_f32 = x_f32.mean(dim=-1) + if return_mean: + assert mean_ref_all is not None + mean_ref_all[start:end] = mean_f32 + if return_rstd: + var_f32 = ((x_f32 - mean_f32.unsqueeze(1)) ** 2).mean(dim=-1) + rstd_ref = 1.0 / torch.sqrt(var_f32 + eps) + assert rstd_ref_all is not None + rstd_ref_all[start:end] = rstd_ref + + assert rstd_o is not None + torch.testing.assert_close(rstd_o[start:end], rstd_ref, **_VERIFY_TOL_STATS) + if rstd_q is not None: + torch.testing.assert_close(rstd_q[start:end], rstd_ref, **_VERIFY_TOL_STATS) + + if return_mean: + mean_ref = mean_f32 + assert mean_o is not None + torch.testing.assert_close(mean_o[start:end], mean_ref, **_VERIFY_TOL_STATS) + if mean_q is not None: + torch.testing.assert_close(mean_q[start:end], mean_ref, **_VERIFY_TOL_STATS) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize())) + if y_acc_quack is not None: + stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize())) + + if return_rstd: + assert rstd_o is not None and rstd_ref_all is not None + rstd_acc_ours = ErrorStatsAccumulator( + total_elems=int(rstd_ref_all.numel()), p99_target_samples=int(rstd_ref_all.numel()) + ) + rstd_acc_ours.update(rstd_o, rstd_ref_all) + stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize())) + if rstd_q is not None: + rstd_acc_quack = ErrorStatsAccumulator( + total_elems=int(rstd_ref_all.numel()), p99_target_samples=int(rstd_ref_all.numel()) + ) + rstd_acc_quack.update(rstd_q, rstd_ref_all) + stats.update(error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize())) + + if return_mean: + assert mean_o is not None and mean_ref_all is not None + mean_acc_ours = ErrorStatsAccumulator( + total_elems=int(mean_ref_all.numel()), p99_target_samples=int(mean_ref_all.numel()) + ) + mean_acc_ours.update(mean_o, mean_ref_all) + stats.update(error_stats_to_row("ours_err_mean", mean_acc_ours.finalize())) + if mean_q is not None: + mean_acc_quack = ErrorStatsAccumulator( + total_elems=int(mean_ref_all.numel()), p99_target_samples=int(mean_ref_all.numel()) + ) + mean_acc_quack.update(mean_q, mean_ref_all) + stats.update(error_stats_to_row("quack_err_mean", mean_acc_quack.finalize())) + + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + *, + eps: float, + warmup_ms: int, + iters_ms: int, + verify: bool, + return_rstd: bool, + return_mean: bool, + has_bias: bool, +) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]: + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + w = torch.randn(N, device=device, dtype=torch.float32) + b = torch.randn(N, device=device, dtype=torch.float32) if has_bias else None + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(x, w, b, eps=eps, return_rstd=return_rstd, return_mean=return_mean) + + bytes_io = bytes_io_model_layernorm( + M, + N, + dtype, + has_bias=has_bias, + return_rstd=return_rstd, + return_mean=return_mean, + weight_dtype=w.dtype, + ) + + fn_oink = lambda: oink_ln.layernorm( + x, + w, + bias=b, + eps=eps, + return_rstd=return_rstd, + return_mean=return_mean, + ) + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 + + if quack_layernorm is None or has_bias: + return (ms_oink, gbps_oink), None, stats + + fn_quack = lambda: quack_layernorm( + x, + w, + eps=eps, + return_rstd=return_rstd, + return_mean=return_mean, + ) + ms_quack = 
do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 + return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + + p = argparse.ArgumentParser() + p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument("--eps", type=float, default=1e-6) + p.add_argument("--return-rstd", action="store_true") + p.add_argument("--return-mean", action="store_true") + p.add_argument("--with-bias", action="store_true", help="Benchmark bias path (Quack compare skipped)") + p.add_argument("--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only).") + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows") + p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)") + p.add_argument("--configs", type=str, default="1024x4096,8192x4096") + p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid (hidden=4096)") + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", + ) + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference; Quack compare skipped when bias is enabled)", + ) + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + eps = float(args.eps) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + meta = collect_device_meta(device) + + rows_out: List[Dict[str, Any]] = [] + for (M, N) in cfgs: + print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True) + (ms_oink, gbps_oink), quack, stats = bench_single( + M=M, + N=N, + dtype=dtype, + eps=eps, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + verify=not args.skip_verify, + return_rstd=bool(args.return_rstd), + return_mean=bool(args.return_mean), + has_bias=bool(args.with_bias), + ) + row: Dict[str, Any] = { + "M": M, + "N": N, + "dtype": args.dtype, + "eps": eps, + "return_rstd": bool(args.return_rstd), + "return_mean": bool(args.return_mean), + "with_bias": bool(args.with_bias), + "ours_ms": ms_oink, + "ours_gbps": gbps_oink, + "ours_tbps": gbps_oink / 1000.0, + "ours_hbm_frac": gbps_oink / hbm_peak, + } + if quack is not None: + ms_q, gbps_q = quack + row.update( + { + "quack_ms": ms_q, + "quack_gbps": gbps_q, + "quack_tbps": gbps_q / 1000.0, + "speedup_vs_quack": ms_q / ms_oink, + } + ) + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + write_csv(args.csv, rows_out) + if args.json is not None: + write_json( + args.json, + meta, + rows_out, + extra={ + "method": "triton.testing.do_bench(mean)", + "warmup_ms": int(args.warmup_ms), + "rep_ms": int(args.iters), + "io_model_bytes": "see bytes_io_model_layernorm in script", + }, + ) + + headers = ["M", "N", "ours_ms", "ours_tbps"] + if quack_layernorm is not None and (not args.with_bias): + headers += ["quack_ms", "quack_tbps", 
"speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows_out: + parts: List[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py new file mode 100644 index 0000000..01c390d --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py @@ -0,0 +1,434 @@ +from __future__ import annotations + +import argparse +import csv +import os +import sys +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +from triton.testing import do_bench as triton_do_bench + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +# Make the in-repo KernelAgent Oink package importable without an editable install. +_HERE = os.path.dirname(os.path.abspath(__file__)) +_OINK_SRC = os.path.abspath(os.path.join(_HERE, "..", "src")) +if _OINK_SRC not in sys.path: + sys.path.insert(0, _OINK_SRC) + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + error_stats_to_row, + iter_row_blocks, + write_json, +) +from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402 + +try: + from quack.rmsnorm import rmsnorm_bwd as quack_rmsnorm_bwd # type: ignore +except Exception: + quack_rmsnorm_bwd = None + +_VERIFY_TOL_DX = { + # Match Quack's unit-test defaults (tests/test_rmsnorm.py). + torch.float32: dict(atol=1e-4, rtol=1e-3), + torch.float16: dict(atol=1e-2, rtol=1e-3), + torch.bfloat16: dict(atol=1e-1, rtol=1e-2), +} + + +def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float: + """Approximate HBM peak bandwidth in GB/s for roofline fractions.""" + if device is None: + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + if sm >= 100: + return 8000.0 + return 2000.0 + + +@dataclass +class Result: + ms: float + gbps: float + + +def do_bench_triton(fn, warmup_ms: int = 25, rep_ms: int = 100) -> float: + # Kernel-only timing consistent with the existing Oink forward harness. + return float(triton_do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="mean")) + + +def bytes_io_model_bwd( + M: int, N: int, dtype: torch.dtype, *, weight_dtype: torch.dtype = torch.float32 +) -> int: + """A simple IO model for RMSNorm backward. + + This intentionally ignores partial-reduction scratch buffers (`dw_partial` / + `db_partial`) since those are highly implementation-specific and depend on + sm_count; we still report speedups and times regardless. 
+ """ + elem = torch.tensor(0, dtype=dtype).element_size() + w_elem = torch.tensor(0, dtype=weight_dtype).element_size() + # Read x + dout + write dx + total = 3 * M * N * elem + # Read weight + write dw + total += 2 * N * w_elem + # Read rstd (fp32 per row) + total += M * 4 + return int(total) + + +def parse_dtype(s: str) -> torch.dtype: + s = s.lower() + if s == "bf16": + return torch.bfloat16 + if s == "fp16": + return torch.float16 + if s == "fp32": + return torch.float32 + raise ValueError(f"Unsupported dtype: {s}") + + +def parse_configs(s: str) -> List[Tuple[int, int]]: + out: List[Tuple[int, int]] = [] + for part in s.split(","): + m, n = part.lower().split("x") + out.append((int(m), int(n))) + return out + + +def quack_suite_configs() -> List[Tuple[int, int, int]]: + """Return (batch, seq, hidden) triples following Quack's grid (hidden=4096).""" + batch_sizes = [1, 4, 8, 16, 32] + seq_lengths = [8192, 16384, 32768, 65536, 131072] + hidden = 4096 + cfgs: List[Tuple[int, int, int]] = [] + for bs in batch_sizes: + for sl in seq_lengths: + M = bs * sl + if M * hidden > (2**31): + continue + cfgs.append((bs, sl, hidden)) + return cfgs + + +def dsv3_configs() -> List[Tuple[int, int]]: + Ms = [4096, 16384, 65536] + Ns = [6144, 7168, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def _verify_parity( + x: torch.Tensor, + w: torch.Tensor, + dout: torch.Tensor, + rstd: torch.Tensor, + *, + has_bias: bool, + has_residual: bool, +) -> dict[str, object]: + tol_dx = _VERIFY_TOL_DX[x.dtype] + ref_block_rows = 1024 + M, N = int(x.shape[0]), int(x.shape[1]) + + dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + dx_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_bwd is not None else None + + with torch.no_grad(): + dx_oink, dw_oink, db_oink, dres_oink = oink_rmsnorm.rmsnorm_backward( + x, + w, + dout, + rstd, + dresidual_out=None, + has_bias=has_bias, + has_residual=has_residual, + ) + + dx_quack = None + dw_quack = None + db_quack = None + dres_quack = None + if quack_rmsnorm_bwd is not None: + dx_quack, dw_quack, db_quack, dres_quack = quack_rmsnorm_bwd( + x, + w, + dout, + rstd, + dresidual_out=None, + has_bias=has_bias, + has_residual=has_residual, + ) + torch.cuda.synchronize() + + # Pure-PyTorch reference, matching Quack's rmsnorm_bwd_ref (float32 math for x_hat). + # Chunk over rows to avoid materializing an (M, N) float32 tensor for large shapes. + dw_accum = torch.zeros((N,), device=x.device, dtype=torch.float32) + w_f32 = w.float() + for start, end in iter_row_blocks(M, ref_block_rows): + x_f32 = x[start:end].float() + rstd_blk = rstd[start:end] + x_hat = x_f32 * rstd_blk.unsqueeze(1) + # Match Quack/PyTorch reference behavior: gradient math uses float32 + # intermediates even when (x, w, dout) are bf16/fp16. 
+ dout_f32 = dout[start:end].float() + wdy = dout_f32 * w_f32 + c1 = (x_hat * wdy).mean(dim=-1, keepdim=True) + dx_ref = ((wdy - x_hat * c1) * rstd_blk.unsqueeze(1)).to(x.dtype) + + torch.testing.assert_close(dx_oink[start:end], dx_ref, **tol_dx) + dx_acc_ours.update(dx_oink[start:end], dx_ref) + if dx_quack is not None: + torch.testing.assert_close(dx_quack[start:end], dx_ref, **tol_dx) + assert dx_acc_quack is not None + dx_acc_quack.update(dx_quack[start:end], dx_ref) + + if dw_oink is not None: + dw_accum += (dout_f32 * x_hat).sum(dim=0) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize())) + if dx_acc_quack is not None: + stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize())) + + if dw_oink is not None: + dw_ref = dw_accum.to(w.dtype) + if w.dtype == torch.float32: + # Weight grad is sensitive to reduction order; use a slightly larger + # absolute tolerance in the suite harness (Quack's unit tests use + # smaller M, where dw is typically tighter). + dw_tol = dict(atol=2e-3, rtol=1e-3) + else: + # For fp16/bf16 weights, `dw` is low-precision and grows with M; use an + # ulp/magnitude-aware tolerance rather than a fixed epsilon. + dw_ref_f32 = dw_ref.to(torch.float32) + dw_oink_f32 = dw_oink.to(torch.float32) + scale = float(dw_ref_f32.abs().max().item()) + dw_atol = max(2.0 * torch.finfo(w.dtype).eps * scale, 1e-3) + dw_tol = dict(atol=dw_atol, rtol=1e-3) + torch.testing.assert_close(dw_oink_f32, dw_ref_f32, **dw_tol) + if dw_quack is not None: + torch.testing.assert_close(dw_quack.to(torch.float32), dw_ref_f32, **dw_tol) + dw_tol = None # handled above + if dw_tol is not None: + torch.testing.assert_close(dw_oink, dw_ref, **dw_tol) + if dw_quack is not None: + torch.testing.assert_close(dw_quack, dw_ref, **dw_tol) + + # Record weight-grad error stats (small, so exact p99 over the full vector). + dw_acc_ours = ErrorStatsAccumulator(total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel())) + dw_acc_ours.update(dw_oink, dw_ref) + stats.update(error_stats_to_row("ours_err_dw", dw_acc_ours.finalize())) + if dw_quack is not None: + dw_acc_quack = ErrorStatsAccumulator( + total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel()) + ) + dw_acc_quack.update(dw_quack, dw_ref) + stats.update(error_stats_to_row("quack_err_dw", dw_acc_quack.finalize())) + + assert db_oink is None and db_quack is None + assert dres_oink is None and dres_quack is None + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + weight_dtype: torch.dtype, + iters_ms: int, + eps: float, + warmup_ms: int, + verify: bool, +) -> Tuple[Result, Result | None, dict[str, object]]: + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + w = torch.randn(N, device=device, dtype=weight_dtype) + dout = torch.randn(M, N, device=device, dtype=dtype) + # rstd is fp32 per row; compute once outside the timed region. 
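+    # (rstd = 1 / sqrt(mean(x^2) + eps); both the Oink and Quack backward calls take it
+    # as an input instead of re-running the forward reduction.)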
+ with torch.no_grad(): + xf = x.float() + rstd = torch.rsqrt(xf.square().mean(dim=-1) + eps).to(torch.float32) + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(x, w, dout, rstd, has_bias=False, has_residual=False) + + fn_oink = lambda: oink_rmsnorm.rmsnorm_backward( + x, + w, + dout, + rstd, + dresidual_out=None, + has_bias=False, + has_residual=False, + ) + + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) + bytes_io = bytes_io_model_bwd(M, N, dtype, weight_dtype=w.dtype) + gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 + ours = Result(ms=ms_oink, gbps=gbps_oink) + + if quack_rmsnorm_bwd is None: + return ours, None, stats + + fn_quack = lambda: quack_rmsnorm_bwd( + x, + w, + dout, + rstd, + dresidual_out=None, + has_bias=False, + has_residual=False, + ) + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 + return ours, Result(ms=ms_quack, gbps=gbps_quack), stats + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + + p = argparse.ArgumentParser() + p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument( + "--weight-dtype", + type=str, + default="fp32", + choices=["same", "fp16", "bf16", "fp32"], + help="RMSNorm weight dtype. `same` matches activation dtype.", + ) + p.add_argument("--eps", type=float, default=1e-6) + p.add_argument( + "--iters", + type=int, + default=100, + help="Triton do_bench rep_ms (kernel-only).", + ) + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows") + p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)") + p.add_argument("--configs", type=str, default="1024x4096,8192x4096") + p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid") + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", + ) + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs a pure-PyTorch RMSNorm backward reference)", + ) + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + if args.weight_dtype == "same": + weight_dtype = dtype + else: + weight_dtype = parse_dtype(args.weight_dtype) + eps = float(args.eps) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + + rows_out: list[dict[str, object]] = [] + + for (M, N) in cfgs: + print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True) + ours, quack, stats = bench_single( + M=M, + N=N, + dtype=dtype, + weight_dtype=weight_dtype, + iters_ms=int(args.iters), + eps=eps, + warmup_ms=int(args.warmup_ms), + verify=not args.skip_verify, + ) + + row: dict[str, object] = { + "M": M, + "N": N, + "dtype": args.dtype, + "weight_dtype": args.weight_dtype, + "ours_ms": ours.ms, + "ours_gbps": ours.gbps, + "ours_tbps": ours.gbps / 1000.0, + "ours_hbm_frac": ours.gbps / hbm_peak, + } + if quack is not None: + row.update( + { + "quack_ms": 
quack.ms, + "quack_gbps": quack.gbps, + "quack_tbps": quack.gbps / 1000.0, + "speedup_vs_quack": quack.ms / ours.ms, + } + ) + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + file_exists = os.path.exists(args.csv) + with open(args.csv, "a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=sorted(row.keys())) + if not file_exists: + writer.writeheader() + writer.writerow(row) + + if args.json is not None: + meta = collect_device_meta(device) + write_json( + args.json, + meta, + rows_out, + extra={ + "method": "triton.testing.do_bench(mean)", + "warmup_ms": int(args.warmup_ms), + "rep_ms": int(args.iters), + "io_model_bytes": "see bytes_io_model_bwd in script", + "weight_dtype": str(args.weight_dtype), + }, + ) + + # Print a small summary table. + headers = ["M", "N", "dtype", "ours_ms", "ours_tbps", "ours_hbm_frac"] + if quack_rmsnorm_bwd is not None: + headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows_out: + parts: list[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py new file mode 100644 index 0000000..e55e9ff --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py @@ -0,0 +1,337 @@ +from __future__ import annotations + +import argparse +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_triton, + error_stats_to_row, + ensure_oink_src_on_path, + iter_row_blocks, + parse_configs, + parse_dtype, + quack_suite_configs, + write_csv, + write_json, +) + +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402 + +try: + from quack.rmsnorm import rmsnorm_fwd as quack_rmsnorm_fwd # type: ignore +except Exception: + quack_rmsnorm_fwd = None + +_VERIFY_TOL_Y = { + # Match Quack's unit-test defaults (tests/test_rmsnorm.py). + torch.float32: dict(atol=1e-4, rtol=1e-3), + torch.float16: dict(atol=1e-2, rtol=1e-3), + # NOTE: bf16 ulp grows with magnitude; a slightly larger rtol is more robust + # for the large-M suite shapes (and fused paths that can see larger values). + torch.bfloat16: dict(atol=1e-1, rtol=1e-2), +} + +_VERIFY_TOL_RSTD = { + torch.float32: dict(atol=1e-5, rtol=1e-5), + torch.float16: dict(atol=1e-3, rtol=1e-3), + torch.bfloat16: dict(atol=1e-3, rtol=1e-3), +} + + +def bytes_io_model_fwd( + M: int, N: int, dtype: torch.dtype, *, weight_dtype: torch.dtype = torch.float32 +) -> int: + elem = torch.tensor(0, dtype=dtype).element_size() + w_elem = torch.tensor(0, dtype=weight_dtype).element_size() + # Read x + write y + total = 2 * M * N * elem + # Read weight + total += N * w_elem + return int(total) + + +def dsv3_configs() -> List[Tuple[int, int]]: + # DSv3-ish hidden sizes used throughout the Oink/Quack SM100 suite tables. 
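+    # Cross product of three token counts and three hidden sizes (nine shapes total).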
+ Ms = [4096, 16384, 65536] + Ns = [6144, 7168, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def _verify_parity( + x: torch.Tensor, + w: torch.Tensor, + *, + eps: float, + store_rstd: bool, +) -> dict[str, object]: + tol_y = _VERIFY_TOL_Y[x.dtype] + tol_rstd = _VERIFY_TOL_RSTD[x.dtype] + ref_block_rows = 4096 + M = int(x.shape[0]) + N = int(x.shape[1]) + + y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + y_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd is not None else None + + with torch.no_grad(): + y_o, rstd_o, res_o = oink_rmsnorm.rmsnorm_forward( + x, + weight=w, + bias=None, + residual=None, + eps=eps, + store_rstd=store_rstd, + ) + y_q = None + rstd_q = None + if quack_rmsnorm_fwd is not None: + # Quack returns (out, residual_out, rstd). + y_q, res_q, rstd_q = quack_rmsnorm_fwd( + x, + w, + bias=None, + residual=None, + out_dtype=None, + residual_dtype=None, + eps=eps, + store_rstd=store_rstd, + ) + + # Pure-PyTorch reference (float32 accumulation), chunked over rows to avoid + # materializing an (M, N) float32 tensor for large Quack-suite shapes. + w_f32 = w.float() + rstd_ref = torch.empty((M,), device=x.device, dtype=torch.float32) + for start, end in iter_row_blocks(M, ref_block_rows): + x_f32 = x[start:end].float() + rstd_blk = torch.rsqrt(x_f32.square().mean(dim=-1) + eps) + rstd_ref[start:end] = rstd_blk + + y_ref_blk_f32 = (x_f32 * rstd_blk.unsqueeze(1)) * w_f32 + y_ref_blk = y_ref_blk_f32.to(x.dtype) + torch.testing.assert_close(y_o[start:end], y_ref_blk, **tol_y) + y_acc_ours.update(y_o[start:end], y_ref_blk) + if y_q is not None: + torch.testing.assert_close(y_q[start:end], y_ref_blk, **tol_y) + assert y_acc_quack is not None + y_acc_quack.update(y_q[start:end], y_ref_blk) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize())) + if y_acc_quack is not None: + stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize())) + + if store_rstd: + assert rstd_o is not None + torch.testing.assert_close(rstd_o, rstd_ref, **tol_rstd) + if y_q is not None: + assert rstd_q is not None + torch.testing.assert_close(rstd_q, rstd_ref, **tol_rstd) + # Stats for rstd are cheap (M elements); compute exact p99 over all rows. + rstd_acc_ours = ErrorStatsAccumulator(total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel())) + rstd_acc_ours.update(rstd_o, rstd_ref) + stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize())) + if rstd_q is not None: + rstd_acc_quack = ErrorStatsAccumulator( + total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel()) + ) + rstd_acc_quack.update(rstd_q, rstd_ref) + stats.update(error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize())) + # Residual output semantics differ slightly across implementations: + # - Oink returns `None` when residual is None. + # - Quack returns `x` as a safe alias in that case. + # + # For parity we focus on `y` (and optional `rstd`) for the residual=None path. 
+ assert res_o is None + if quack_rmsnorm_fwd is not None: + assert res_q is x + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + *, + weight_dtype: torch.dtype, + eps: float, + warmup_ms: int, + iters_ms: int, + verify: bool, + store_rstd: bool, +) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]: + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + w = torch.randn(N, device=device, dtype=weight_dtype) + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(x, w, eps=eps, store_rstd=store_rstd) + + bytes_io = bytes_io_model_fwd(M, N, dtype, weight_dtype=w.dtype) + + fn_oink = lambda: oink_rmsnorm.rmsnorm_forward( + x, + weight=w, + bias=None, + residual=None, + eps=eps, + store_rstd=store_rstd, + ) + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 + + if quack_rmsnorm_fwd is None: + return (ms_oink, gbps_oink), None, stats + + fn_quack = lambda: quack_rmsnorm_fwd( + x, + w, + bias=None, + residual=None, + out_dtype=None, + residual_dtype=None, + eps=eps, + store_rstd=store_rstd, + ) + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 + return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + + p = argparse.ArgumentParser() + p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument( + "--weight-dtype", + type=str, + default="fp32", + choices=["same", "fp16", "bf16", "fp32"], + help="RMSNorm weight dtype. 
`same` matches activation dtype (vLLM-style inference).", + ) + p.add_argument("--eps", type=float, default=1e-6) + p.add_argument("--store-rstd", action="store_true", help="Also write rstd (fp32 per row)") + p.add_argument("--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only).") + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows") + p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)") + p.add_argument("--configs", type=str, default="1024x4096,8192x4096") + p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid") + p.add_argument("--dsv3", action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}") + p.add_argument("--skip-verify", action="store_true", help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference)") + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + if args.weight_dtype == "same": + weight_dtype = dtype + else: + weight_dtype = parse_dtype(args.weight_dtype) + eps = float(args.eps) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + meta = collect_device_meta(device) + + rows_out: List[Dict[str, Any]] = [] + for (M, N) in cfgs: + print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True) + (ms_oink, gbps_oink), quack, stats = bench_single( + M=M, + N=N, + dtype=dtype, + weight_dtype=weight_dtype, + eps=eps, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + verify=not args.skip_verify, + store_rstd=bool(args.store_rstd), + ) + row: Dict[str, Any] = { + "M": M, + "N": N, + "dtype": args.dtype, + "weight_dtype": args.weight_dtype, + "eps": eps, + "store_rstd": bool(args.store_rstd), + "ours_ms": ms_oink, + "ours_gbps": gbps_oink, + "ours_tbps": gbps_oink / 1000.0, + "ours_hbm_frac": gbps_oink / hbm_peak, + } + if quack is not None: + ms_q, gbps_q = quack + row.update( + { + "quack_ms": ms_q, + "quack_gbps": gbps_q, + "quack_tbps": gbps_q / 1000.0, + "speedup_vs_quack": ms_q / ms_oink, + } + ) + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + write_csv(args.csv, rows_out) + if args.json is not None: + write_json( + args.json, + meta, + rows_out, + extra={ + "method": "triton.testing.do_bench(mean)", + "warmup_ms": int(args.warmup_ms), + "rep_ms": int(args.iters), + "io_model_bytes": "(2*M*N)*elem_size + N*weight_elem_size", + "store_rstd": bool(args.store_rstd), + "weight_dtype": str(args.weight_dtype), + }, + ) + + # Print a compact summary table. 
+ headers = ["M", "N", "ours_ms", "ours_tbps"] + if quack_rmsnorm_fwd is not None: + headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows_out: + parts: List[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py new file mode 100644 index 0000000..93c5af3 --- /dev/null +++ b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import argparse +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch + +# Reduce fragmentation pressure on busy GPUs. +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + +# Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. +os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + +from bench_utils import ( # noqa: E402 + ErrorStatsAccumulator, + collect_device_meta, + detect_hbm_peak_gbps, + do_bench_triton, + error_stats_to_row, + ensure_oink_src_on_path, + iter_row_blocks, + parse_configs, + parse_dtype, + quack_suite_configs, + write_csv, + write_json, +) + +ensure_oink_src_on_path() + +from kernelagent_oink.blackwell import softmax as oink_softmax # noqa: E402 + +try: + from quack.softmax import softmax_bwd as quack_softmax_bwd # type: ignore + from quack.softmax import softmax_fwd as quack_softmax_fwd # type: ignore +except Exception: + quack_softmax_fwd = None + quack_softmax_bwd = None + +_VERIFY_TOL = { + # Match Quack's unit-test defaults (tests/test_softmax.py). + torch.float32: dict(atol=1e-4, rtol=1e-4), + torch.float16: dict(atol=1e-3, rtol=1e-3), + torch.bfloat16: dict(atol=1e-2, rtol=1e-2), +} + + +def bytes_io_model_softmax(M: int, N: int, dtype: torch.dtype, *, mode: str) -> int: + elem = torch.tensor(0, dtype=dtype).element_size() + if mode == "fwd": + return int(2 * M * N * elem) # read x + write y + if mode == "bwd": + return int(3 * M * N * elem) # read dy + read y + write dx + if mode == "fwd_bwd": + # Logical IO for dx given (x, dy): read x + read dy + write dx. + # (The intermediate y=softmax(x) is an implementation detail and is + # intentionally not counted here.) 
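+        # Worked example (illustrative only): M=8192, N=4096, bf16 (2 B/elem)
+        # -> 3 * 8192 * 4096 * 2 B = 192 MiB of modeled traffic per call.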
+ return int(3 * M * N * elem) + raise ValueError(f"Unsupported mode: {mode}") + + +def dsv3_configs() -> List[Tuple[int, int]]: + Ms = [4096, 16384, 65536] + Ns = [6144, 7168, 8192] + return [(m, n) for m in Ms for n in Ns] + + +def _verify_parity(x: torch.Tensor) -> dict[str, object]: + tol = _VERIFY_TOL[x.dtype] + ref_block_rows = 4096 + dy = torch.randn_like(x) # upstream grad + + with torch.no_grad(): + y_o = oink_softmax.softmax_forward(x) + dx_o = oink_softmax.softmax_backward(dy, y_o) + dx_fused_o = oink_softmax.softmax_fwd_bwd(dy, x) + + y_q = None + dx_q = None + if quack_softmax_fwd is not None and quack_softmax_bwd is not None: + y_q = quack_softmax_fwd(x) + dx_q = quack_softmax_bwd(dy, y_q) + + M = int(x.shape[0]) + N = int(x.shape[1]) + y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + dx_fused_acc_ours = ErrorStatsAccumulator(total_elems=M * N) + y_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if (quack_softmax_fwd is not None and quack_softmax_bwd is not None) + else None + ) + dx_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if (quack_softmax_fwd is not None and quack_softmax_bwd is not None) + else None + ) + + # Match Quack tests: compare to PyTorch softmax refs (fwd+bwd), chunked. + for start, end in iter_row_blocks(M, ref_block_rows): + x_blk = x[start:end] + dy_blk = dy[start:end] + y_ref_blk = torch.softmax(x_blk, dim=-1) + dot = torch.sum(dy_blk * y_ref_blk, dim=-1, keepdim=True, dtype=torch.float32) + dx_ref_blk = (dy_blk - dot.to(dy_blk.dtype)) * y_ref_blk + + torch.testing.assert_close(y_o[start:end], y_ref_blk, **tol) + torch.testing.assert_close(dx_o[start:end], dx_ref_blk, **tol) + torch.testing.assert_close(dx_fused_o[start:end], dx_ref_blk, **tol) + y_acc_ours.update(y_o[start:end], y_ref_blk) + dx_acc_ours.update(dx_o[start:end], dx_ref_blk) + dx_fused_acc_ours.update(dx_fused_o[start:end], dx_ref_blk) + if y_q is not None and dx_q is not None: + torch.testing.assert_close(y_q[start:end], y_ref_blk, **tol) + torch.testing.assert_close(dx_q[start:end], dx_ref_blk, **tol) + assert y_acc_quack is not None and dx_acc_quack is not None + y_acc_quack.update(y_q[start:end], y_ref_blk) + dx_acc_quack.update(dx_q[start:end], dx_ref_blk) + + stats: dict[str, object] = {} + stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize())) + stats.update(error_stats_to_row("ours_err_dx_fused", dx_fused_acc_ours.finalize())) + if y_acc_quack is not None and dx_acc_quack is not None: + stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize())) + stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize())) + return stats + + +def bench_single( + M: int, + N: int, + dtype: torch.dtype, + *, + warmup_ms: int, + iters_ms: int, + mode: str, + verify: bool, +) -> Tuple[Tuple[float, float], Optional[Tuple[float, float]], dict[str, object]]: + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + dy = torch.randn_like(x) + + stats: dict[str, object] = {} + if verify: + stats = _verify_parity(x) + + bytes_io = bytes_io_model_softmax(M, N, dtype, mode=mode) + + if mode == "fwd": + fn_oink = lambda: oink_softmax.softmax_forward(x) + fn_quack = None if quack_softmax_fwd is None else (lambda: quack_softmax_fwd(x)) + elif mode == "bwd": + with torch.no_grad(): + y_o = oink_softmax.softmax_forward(x) + y_q = quack_softmax_fwd(x) if quack_softmax_fwd is not None else 
None + fn_oink = lambda: oink_softmax.softmax_backward(dy, y_o) + fn_quack = ( + None + if (quack_softmax_bwd is None or y_q is None) + else (lambda: quack_softmax_bwd(dy, y_q)) + ) + elif mode == "fwd_bwd": + fn_oink = lambda: oink_softmax.softmax_fwd_bwd(dy, x) + fn_quack = ( + None + if (quack_softmax_fwd is None or quack_softmax_bwd is None) + else (lambda: quack_softmax_bwd(dy, quack_softmax_fwd(x))) + ) + else: + raise ValueError(f"Unsupported mode: {mode}") + + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 + + if fn_quack is None: + return (ms_oink, gbps_oink), None, stats + + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) + gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 + return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats + + +def main() -> None: + if not torch.cuda.is_available(): + raise SystemExit("CUDA not available") + + torch.cuda.set_device(0) + device = torch.device("cuda") + props = torch.cuda.get_device_properties(device) + sm = props.major * 10 + props.minor + print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") + + p = argparse.ArgumentParser() + p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument("--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"]) + p.add_argument("--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only).") + p.add_argument("--warmup-ms", type=int, default=25) + p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows") + p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)") + p.add_argument("--configs", type=str, default="1024x4096,8192x4096") + p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid") + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", + ) + p.add_argument("--skip-verify", action="store_true", help="Skip correctness checks (Oink/Quack vs PyTorch softmax)") + args = p.parse_args() + + dtype = parse_dtype(args.dtype) + + if args.quack_suite: + cfgs = [(bs * sl, hidden) for (bs, sl, hidden) in quack_suite_configs()] + elif args.dsv3: + cfgs = dsv3_configs() + else: + cfgs = parse_configs(args.configs) + + hbm_peak = detect_hbm_peak_gbps(device) + meta = collect_device_meta(device) + + rows_out: List[Dict[str, Any]] = [] + for (M, N) in cfgs: + print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", flush=True) + (ms_oink, gbps_oink), quack, stats = bench_single( + M=M, + N=N, + dtype=dtype, + warmup_ms=int(args.warmup_ms), + iters_ms=int(args.iters), + mode=str(args.mode), + verify=not args.skip_verify, + ) + row: Dict[str, Any] = { + "M": M, + "N": N, + "dtype": args.dtype, + "mode": args.mode, + "ours_ms": ms_oink, + "ours_gbps": gbps_oink, + "ours_tbps": gbps_oink / 1000.0, + "ours_hbm_frac": gbps_oink / hbm_peak, + } + if quack is not None: + ms_q, gbps_q = quack + row.update( + { + "quack_ms": ms_q, + "quack_gbps": gbps_q, + "quack_tbps": gbps_q / 1000.0, + "speedup_vs_quack": ms_q / ms_oink, + } + ) + row.update(stats) + rows_out.append(row) + + if args.csv is not None: + write_csv(args.csv, rows_out) + if args.json is not None: + write_json( + args.json, + meta, + rows_out, + extra={ + "method": "triton.testing.do_bench(mean)", + "warmup_ms": int(args.warmup_ms), + "rep_ms": int(args.iters), + "io_model_bytes": 
"mode-dependent: fwd=2*M*N, bwd=3*M*N, fwd_bwd=3*M*N (all * elem_size; fwd_bwd counts logical x+dy+dx)", + }, + ) + + headers = ["M", "N", "mode", "ours_ms", "ours_tbps"] + if quack_softmax_fwd is not None and quack_softmax_bwd is not None: + headers += ["quack_ms", "quack_tbps", "speedup_vs_quack"] + print("\nSummary:") + print(" ".join(h.rjust(14) for h in headers)) + for r in rows_out: + parts: List[str] = [] + for h in headers: + v = r.get(h) + if isinstance(v, float): + parts.append(f"{v:14.4f}") + else: + parts.append(f"{str(v):>14}") + print(" ".join(parts)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg new file mode 100644 index 0000000..e32e3a7 --- /dev/null +++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg @@ -0,0 +1,2259 @@ + + + + + + + + 2026-01-12T23:31:37.117906 + image/svg+xml + + + Matplotlib v3.10.8, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg new file mode 100644 index 0000000..b70ba9b --- /dev/null +++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3.svg @@ -0,0 +1,2600 @@ + + + + + + + + 2026-01-12T20:27:29.562089 + image/svg+xml + + + Matplotlib v3.10.8, 
https://matplotlib.org/
+ [SVG path data omitted: Matplotlib v3.10.8 figure, created 2026-01-12T20:27:29]
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg
new file mode 100644
index 0000000..f5cd53c
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_all.svg
@@ -0,0 +1,2936 @@
+ [SVG path data omitted: Matplotlib v3.10.8 figure, created 2026-01-12T23:50:09]
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg
new file mode 100644
index 0000000..db39e3c
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_cross_entropy.svg
@@ -0,0 +1,1687 @@
+ [SVG path data omitted: Matplotlib v3.10.8 figure, created 2026-01-12T23:31:44]
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg
new file mode 100644
index 0000000..e8d4cc6
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_dsv3_with_layernorm.svg
@@ -0,0 +1,2720 @@
+ [SVG path data omitted: Matplotlib v3.10.8 figure, created 2026-01-08T16:35:17]
diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg
new file mode 100644
index 0000000..a5670bd
--- /dev/null
+++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg
@@ -0,0 +1,2580 @@
+ [SVG path data omitted: Matplotlib v3.10.8 figure, created 2026-01-12T23:31:33]
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg
new file mode 100644
index 0000000..0a021e9
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack.svg
@@ -0,0 +1,2280 @@
+ [SVG path data omitted: Matplotlib v3.10.8 figure, created 2026-01-12T23:31:38]
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg
new file mode 100644
index 0000000..9a58fde
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3.svg
@@ -0,0 +1,2621 @@
+ [SVG path data omitted: Matplotlib v3.10.8 figure, created 2026-01-12T20:27:30]
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg
new file mode 100644
index 0000000..bff56d0
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_all.svg
@@ -0,0 +1,2957 @@
+ [SVG path data omitted: Matplotlib v3.10.8 figure, created 2026-01-12T23:50:13]
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg
new file mode 100644
index 0000000..6a16fe8
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_cross_entropy.svg
@@ -0,0 +1,1708 @@
+ [SVG path data omitted: Matplotlib v3.10.8 figure, created 2026-01-12T23:31:46]
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg
new file mode 100644
index 0000000..242d013
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_dsv3_with_layernorm.svg
@@ -0,0 +1,2741 @@
+ [SVG path data omitted: Matplotlib v3.10.8 figure, created 2026-01-08T16:35:17]
diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg
new file mode 100644
index 0000000..dac54ac
--- /dev/null
+++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg
@@ -0,0 +1,2601 @@
+ [SVG path data omitted: Matplotlib v3.10.8 figure, created 2026-01-12T23:31:35]
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py new file mode 100644 index 0000000..1799f2e --- /dev/null +++ b/oink/benchmarks/readme/plot_quack_style_svg.py @@ -0,0 +1,431 @@ +from __future__ import annotations + +""" +Generate Quack-style SVG performance plots (Oink vs Quack) from the SM100 suite +JSON artifacts under `/tmp/kernelagent_oink_sm100_suite_{bf16,fp16}`. + +The intent is to match Quack's README visual style: + - 3 horizontal panels (suite-dependent): + - Quack-suite: RMSNorm / Softmax / CrossEntropy + - DSv3 (hidden-size): Fused Add+RMSNorm / Softmax / LayerNorm + - DSv3 (all ops, 4-panel): Fused Add+RMSNorm / Softmax / LayerNorm / CrossEntropy + - DSv3 CrossEntropy: CrossEntropy-only (single panel) + - y-axis: model memory bandwidth (GB/s) derived from an IO model + - x-axis: a small set of labeled (M, N) shape points + - thick lines + markers, dashed y-grid, compact legend + - optional horizontal roofline line (measured STREAM-like HBM peak) + +Example: + python oink/benchmarks/readme/plot_quack_style_svg.py \\ + --in-dir /tmp/kernelagent_oink_sm100_suite_bf16 \\ + --suite quack_suite \\ + --roofline-json /tmp/hbm_roofline_sm100_bf16.json \\ + --out oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg + +For completeness, we can also include LayerNorm as an extra panel (Quack's +own README plot does not include LayerNorm): + python oink/benchmarks/readme/plot_quack_style_svg.py \\ + --in-dir /tmp/kernelagent_oink_sm100_suite_bf16 \\ + --suite quack_suite \\ + --include-layernorm \\ + --roofline-json /tmp/hbm_roofline_sm100_bf16.json \\ + --out oink/benchmarks/media/sm100_bf16_oink_vs_quack_with_layernorm.svg + +Note on DSv3 suite: +- The DSv3 plot intentionally covers only the hidden-size ops (fused Add+RMSNorm, + Softmax, LayerNorm) which share the same `(M, N)` sweep. +- CrossEntropy in DSv3 uses a vocab-size-like `N` sweep and is plotted separately + via `--suite dsv3_cross_entropy` to avoid a mixed x-axis with gaps. +- For README embedding convenience, `--suite dsv3_all` renders a 4-panel + single-row figure where the CrossEntropy panel uses its own x-axis. +- The RMSNorm panel uses the real block primitive (fused residual-add + RMSNorm) + when available: `fused_add_rmsnorm_dsv3.json`. +""" + +import argparse +import json +import math +import os +from collections import defaultdict +from statistics import median +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple + + +def _load_json(path: str) -> Dict[str, Any]: + with open(path) as f: + return json.load(f) + + +def _fmt_k(v: int) -> str: + # Match Quack's x-axis labels: "32K" means 32768 (1024-based). 
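+    # e.g. _fmt_k(32768) -> "32K", _fmt_k(7168) -> "7K", _fmt_k(6000) -> "6000".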
+ if v % 1024 == 0: + return f"{v // 1024}K" + return str(v) + + +def _shape_label(m: int, n: int) -> str: + return f"({_fmt_k(m)}, {_fmt_k(n)})" + + +def _gbps_from_row(prefix: str, row: Mapping[str, Any]) -> Optional[float]: + # Prefer GB/s in the JSON if present; otherwise fall back to TB/s. + gbps_key = f"{prefix}_gbps" + tbps_key = f"{prefix}_tbps" + if gbps_key in row and row[gbps_key] is not None: + return float(row[gbps_key]) + if tbps_key in row and row[tbps_key] is not None: + return float(row[tbps_key]) * 1000.0 + return None + + +def _aggregate_by_shape(rows: Sequence[Mapping[str, Any]]) -> Dict[Tuple[int, int], Dict[str, float]]: + """Aggregate duplicate (M, N) rows using median (more robust than mean).""" + buckets: dict[tuple[int, int], dict[str, list[float]]] = defaultdict( + lambda: defaultdict(list) + ) + for r in rows: + m = int(r["M"]) + n = int(r["N"]) + ours = _gbps_from_row("ours", r) + quack = _gbps_from_row("quack", r) + if ours is not None: + buckets[(m, n)]["ours"].append(ours) + if quack is not None: + buckets[(m, n)]["quack"].append(quack) + + out: Dict[Tuple[int, int], Dict[str, float]] = {} + for k, vs in buckets.items(): + if not vs["ours"] or not vs["quack"]: + continue + out[k] = dict(ours=float(median(vs["ours"])), quack=float(median(vs["quack"]))) + return out + + +def _sort_shapes(shapes: Iterable[Tuple[int, int]]) -> List[Tuple[int, int]]: + # Sort by N then M to keep the x-axis stable across panels. + return sorted(set(shapes), key=lambda x: (x[1], x[0])) + + +def _read_roofline_gbps(path: str) -> float: + payload = _load_json(path) + rows = payload.get("rows", []) + best_tbps = max(float(r["tbps"]) for r in rows) + return best_tbps * 1000.0 + + +def _ensure_matplotlib(): + try: + import matplotlib as mpl # noqa: F401 + import matplotlib.pyplot as plt # noqa: F401 + except Exception as e: # pragma: no cover + raise SystemExit( + "matplotlib is required to generate SVG plots.\n" + "Install with: `python -m pip install matplotlib`" + ) from e + + +def _plot( + *, + panels: Sequence[Tuple[str, Dict[Tuple[int, int], Dict[str, float]]]], + roofline_gbps: Optional[float], + out_path: str, + title: str, + shape_policy: str, + per_panel_x: bool, +) -> None: + _ensure_matplotlib() + import matplotlib as mpl + import matplotlib.pyplot as plt + + mpl.rcParams.update( + { + # Quack-style: embed glyphs as paths for consistent rendering. + "svg.fonttype": "path", + "font.family": "DejaVu Sans", + "axes.titlesize": 18, + "axes.labelsize": 16, + "xtick.labelsize": 10, + "ytick.labelsize": 12, + } + ) + + # Colors roughly matching Quack's SVG palette. + COLOR_OINK = "#5ba3f5" + COLOR_QUACK = "#ff4444" + COLOR_ROOF = "#4d4d4d" + + fig, axes = plt.subplots( + nrows=1, + ncols=len(panels), + figsize=(6.0 * len(panels), 5.6), + constrained_layout=False, + sharey=True, + ) + if len(panels) == 1: + axes = [axes] + + max_y = 0.0 + for ax, (panel_title, data) in zip(axes, panels): + if per_panel_x: + shapes = _sort_shapes(data.keys()) + else: + # Quack-style plots use a single shared x-axis across panels. Prefer + # the intersection so every panel has a value at every x tick + # (cleaner than rendering gaps), and fall back to the union if the + # intersection is empty. 
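+            # Illustrative example: if one panel has shapes {(32K, 4K), (64K, 8K)}
+            # and another has only {(32K, 4K)}, the shared x-axis keeps just
+            # (32K, 4K); with `union`, (64K, 8K) would render as a gap (NaN) in
+            # the panel that lacks it.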
+ shape_sets = [set(d.keys()) for _n, d in panels] + if shape_policy in {"first", "primary"}: + shapes = _sort_shapes(shape_sets[0]) if shape_sets else [] + elif shape_policy == "intersection" and shape_sets: + common = set.intersection(*shape_sets) + shapes = _sort_shapes(common) if common else [] + elif shape_policy == "union": + shapes = _sort_shapes(s for _n, d in panels for s in d.keys()) + else: + raise ValueError(f"Unsupported shape_policy: {shape_policy}") + if not shapes: + shapes = _sort_shapes(s for _n, d in panels for s in d.keys()) + + x = list(range(len(shapes))) + x_labels = [_shape_label(m, n) for (m, n) in shapes] + + ours_y: List[float] = [] + quack_y: List[float] = [] + for s in shapes: + rec = data.get(s) + if rec is None: # only possible in shared-x mode with union + ours_y.append(math.nan) + quack_y.append(math.nan) + continue + ours_y.append(float(rec["ours"])) + quack_y.append(float(rec["quack"])) + max_y = max(max_y, *(v for v in ours_y if math.isfinite(v)), *(v for v in quack_y if math.isfinite(v))) + + ax.plot( + x, + ours_y, + marker="o", + linewidth=5, + markersize=7, + color=COLOR_OINK, + label="KernelAgent-Oink (ours)", + ) + ax.plot( + x, + quack_y, + marker="o", + linewidth=5, + markersize=7, + color=COLOR_QUACK, + label="Quack", + ) + if roofline_gbps is not None: + ax.axhline( + roofline_gbps, + color=COLOR_ROOF, + linewidth=3, + linestyle=(0, (4, 6)), + label="HBM peak (measured)" if ax is axes[0] else None, + ) + max_y = max(max_y, float(roofline_gbps)) + + ax.set_title(panel_title) + ax.set_xticks(x) + ax.set_xticklabels(x_labels, rotation=-45, ha="left") + if per_panel_x: + # DSv3 "all ops" figure: each panel has its own x-axis. Make the + # semantics explicit so readers don't assume the same `N` meaning + # across panels (CrossEntropy uses a classes/vocab-shard-like axis). + if "cross" in panel_title.lower(): + ax.set_xlabel("Shape (M, C classes)") + else: + ax.set_xlabel("Shape (M, N hidden)") + + # Quack-like dashed y-grid. + ax.grid(axis="y", linestyle=(0, (4, 7.2)), linewidth=0.8, color="#b0b0b0") + ax.set_axisbelow(True) + + # Light spines (Quack SVG uses a light gray frame). + for spine in ax.spines.values(): + spine.set_color("#d3d3d3") + spine.set_linewidth(1.5) + + axes[0].set_ylabel("Memory Bandwidth (GB/s)") + + # A little headroom above the tallest curve/roofline. + ymax = max_y * 1.08 if max_y > 0 else 1.0 + for ax in axes: + ax.set_ylim(0.0, ymax) + + # Tight layout for the axes area, reserving headroom for the suptitle and a + # shared legend. In some matplotlib versions, figure-level legends can + # overlap the middle panel title unless we reserve a slightly taller header + # band. + fig.tight_layout(rect=(0.0, 0.0, 1.0, 0.70)) + + # Single shared legend across the top (like Quack), but keep it inside the + # reserved header band so it doesn't overlap the middle panel title. + handles, labels = axes[0].get_legend_handles_labels() + # Quack's legend fits nicely in one row because their plots are 3-panel and + # therefore wide. For single-panel figures, a 3-column legend can overflow + # the canvas and get clipped in the SVG, so we stack it vertically. + legend_ncol = min(3, len(labels)) + legend_fontsize = 13 + if len(panels) == 1: + legend_ncol = 1 + legend_fontsize = 12 + fig.legend( + handles, + labels, + loc="upper center", + ncol=legend_ncol, + frameon=False, + bbox_to_anchor=(0.5, 0.91), + fontsize=legend_fontsize, + handlelength=2.5, + ) + # Single-panel figures (e.g. 
DSv3 CrossEntropy) are much narrower than the + # Quack-style 3-panel plots; use a slightly smaller suptitle font to avoid + # clipping in the exported SVG. + suptitle_fs = 22 if len(panels) > 1 else 18 + fig.suptitle(title, y=0.98, fontsize=suptitle_fs) + + out_path = os.path.abspath(out_path) + os.makedirs(os.path.dirname(out_path), exist_ok=True) + # Use a tight bounding box so rotated x tick labels and the figure-level + # legend don't get clipped in SVG exports (matplotlib can be fragile here + # across versions). + fig.savefig(out_path, format="svg", bbox_inches="tight", pad_inches=0.02) + plt.close(fig) + + +def _panel_files_for_suite(suite: str) -> List[Tuple[str, str]]: + if suite == "quack_suite": + return [ + ("RMSNorm (fp32 weight)", "rmsnorm_fwd_quack_suite_wfp32.json"), + ("Softmax (fwd+bwd)", "softmax_fwd_bwd_quack_suite.json"), + ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_quack_suite.json"), + ] + if suite == "dsv3": + return [ + ("Fused Add+RMSNorm (fwd)", "fused_add_rmsnorm_dsv3.json"), + ("Softmax (fwd+bwd)", "softmax_fwd_bwd_dsv3.json"), + ("LayerNorm (fwd)", "layernorm_fwd_dsv3.json"), + ] + if suite == "dsv3_all": + return [ + ("Fused Add+RMSNorm (fwd)", "fused_add_rmsnorm_dsv3.json"), + ("Softmax (fwd+bwd)", "softmax_fwd_bwd_dsv3.json"), + ("LayerNorm (fwd)", "layernorm_fwd_dsv3.json"), + ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_dsv3.json"), + ] + if suite == "dsv3_cross_entropy": + return [ + ("Cross-Entropy (fwd+bwd)", "cross_entropy_fwd_bwd_dsv3.json"), + ] + raise ValueError(f"Unsupported suite: {suite}") + + +def _layernorm_file_for_suite(suite: str) -> str: + if suite == "quack_suite": + return "layernorm_fwd_quack_suite.json" + raise ValueError(f"Unsupported suite: {suite}") + + +def main() -> None: + p = argparse.ArgumentParser( + description="Generate Quack-style SVG plots from KernelAgent-Oink suite JSONs." + ) + p.add_argument( + "--in-dir", type=str, required=True, help="Directory containing suite JSON outputs" + ) + p.add_argument( + "--suite", + type=str, + default="quack_suite", + choices=["quack_suite", "dsv3", "dsv3_all", "dsv3_cross_entropy"], + ) + p.add_argument( + "--include-layernorm", + action="store_true", + help="Add a LayerNorm (fwd) panel (only meaningful for `--suite quack_suite`).", + ) + p.add_argument( + "--shape-policy", + type=str, + default="intersection", + choices=["intersection", "union", "first"], + help=( + "How to pick x-axis shapes across panels. " + "`intersection` matches Quack-style (only shapes common to every panel). " + "`first` uses the first panel's shapes (keeps DSv3 N=7168 visible). " + "`union` includes every shape across panels (may create gaps)." 
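+            # Note: in the DSv3 suites the CrossEntropy panel sweeps a
+            # vocab-like N, so `intersection` would drop its shapes entirely;
+            # `--suite dsv3_all` therefore uses per-panel x-axes instead.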
+ ), + ) + p.add_argument("--roofline-json", type=str, default=None, help="Optional /tmp/hbm_roofline_sm100_*.json path") + p.add_argument("--out", type=str, required=True, help="Output SVG path") + p.add_argument("--title", type=str, default=None, help="Optional figure title override") + args = p.parse_args() + + in_dir = os.path.abspath(args.in_dir) + if not os.path.isdir(in_dir): + raise SystemExit(f"--in-dir is not a directory: {in_dir}") + + roofline_gbps = _read_roofline_gbps(args.roofline_json) if args.roofline_json else None + + panel_files = list(_panel_files_for_suite(str(args.suite))) + if args.include_layernorm: + if args.suite != "quack_suite": + raise SystemExit("--include-layernorm is only supported for `--suite quack_suite`.") + panel_files.append(("LayerNorm (fwd)", _layernorm_file_for_suite(str(args.suite)))) + + panels: List[Tuple[str, Dict[Tuple[int, int], Dict[str, float]]]] = [] + for panel_title, filename in panel_files: + path = os.path.join(in_dir, filename) + if not os.path.exists(path): + raise SystemExit(f"Missing required JSON: {path}") + payload = _load_json(path) + rows = payload.get("rows", []) + if not isinstance(rows, list): + rows = [] + panels.append((panel_title, _aggregate_by_shape(rows))) + + if args.title is not None: + title = str(args.title) + else: + # Try to infer dtype from the first panel's JSON. + first_json = os.path.join(in_dir, panel_files[0][1]) + payload = _load_json(first_json) + rows = payload.get("rows", []) + dtype = rows[0].get("dtype", "") if rows else "" + if args.suite == "quack_suite": + suite_name = "Quack-suite" + elif args.suite == "dsv3": + suite_name = "DSv3 (hidden-size ops)" + elif args.suite == "dsv3_all": + suite_name = "DSv3 (4 ops)" + elif args.suite == "dsv3_cross_entropy": + # Keep this short: this suite is rendered as a single panel, so the + # figure is much narrower than the 3-panel plots. 
+ suite_name = "DSv3 CrossEntropy" + else: + suite_name = str(args.suite) + suffix = " (+LayerNorm)" if (args.suite == "quack_suite" and args.include_layernorm) else "" + if args.suite == "dsv3_cross_entropy": + title = f"SM100 {dtype.upper()} — {suite_name}{suffix}" + else: + title = f"SM100 {dtype.upper()} Kernel Benchmarks (Oink vs Quack) — {suite_name}{suffix}" + + _plot( + panels=panels, + roofline_gbps=roofline_gbps, + out_path=str(args.out), + title=title, + shape_policy=str(args.shape_policy), + per_panel_x=(str(args.suite) == "dsv3_all"), + ) + print(f"Wrote: {os.path.abspath(args.out)}") + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py new file mode 100644 index 0000000..5ac1091 --- /dev/null +++ b/oink/benchmarks/readme/run_sm100_suite.py @@ -0,0 +1,302 @@ +from __future__ import annotations + +import argparse +import os +import subprocess +import sys +from datetime import datetime +from typing import List, Tuple + + +def _ts() -> str: + return datetime.now().strftime("%Y%m%d_%H%M%S") + + +def _run(cmd: List[str], *, dry_run: bool) -> None: + print("+", " ".join(cmd), flush=True) + if dry_run: + return + subprocess.run(cmd, check=True) + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument( + "--out-dir", + type=str, + default=None, + help="Directory to write JSON outputs (default: /tmp/kernelagent_oink_sm100_suite_)", + ) + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs PyTorch / pure-PyTorch references)", + ) + p.add_argument("--dry-run", action="store_true", help="Print commands without executing them") + args = p.parse_args() + + # Standardize env for standalone runs outside the vLLM plugin. 
+ os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") + os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") + + out_dir = args.out_dir or f"/tmp/kernelagent_oink_sm100_suite_{_ts()}" + os.makedirs(out_dir, exist_ok=True) + + here = os.path.dirname(os.path.abspath(__file__)) + bench_dir = os.path.abspath(os.path.join(here, "..", "benchmark")) + py = sys.executable + + def script(name: str) -> str: + return os.path.join(bench_dir, name) + + common = ["--dtype", args.dtype] + if args.skip_verify: + common = [*common, "--skip-verify"] + + runs: List[Tuple[str, List[str]]] = [ + ( + "rmsnorm_fwd_quack_suite_wfp32", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--quack-suite", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wfp32.json"), + ], + ), + ( + "rmsnorm_fwd_dsv3_wfp32", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--dsv3", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_dsv3_wfp32.json"), + ], + ), + ( + "rmsnorm_bwd_quack_suite_wfp32", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--quack-suite", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wfp32.json"), + ], + ), + ( + "rmsnorm_bwd_dsv3_wfp32", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "fp32", + "--dsv3", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_dsv3_wfp32.json"), + ], + ), + # vLLM inference-style RMSNorm (weight dtype == activation dtype). + ( + "rmsnorm_fwd_quack_suite_wsame", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "same", + "--quack-suite", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_quack_suite_wsame.json"), + ], + ), + ( + "rmsnorm_fwd_dsv3_wsame", + [ + py, + script("benchmark_rmsnorm_sm100.py"), + *common, + "--weight-dtype", + "same", + "--dsv3", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_fwd_dsv3_wsame.json"), + ], + ), + ( + "rmsnorm_bwd_quack_suite_wsame", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "same", + "--quack-suite", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_quack_suite_wsame.json"), + ], + ), + ( + "rmsnorm_bwd_dsv3_wsame", + [ + py, + script("benchmark_rmsnorm_bwd_sm100.py"), + *common, + "--weight-dtype", + "same", + "--dsv3", + "--iters", + "100", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "rmsnorm_bwd_dsv3_wsame.json"), + ], + ), + ( + "softmax_fwd_bwd_quack_suite", + [ + py, + script("benchmark_softmax_sm100.py"), + *common, + "--mode", + "fwd_bwd", + "--quack-suite", + "--iters", + "50", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "softmax_fwd_bwd_quack_suite.json"), + ], + ), + ( + "softmax_fwd_bwd_dsv3", + [ + py, + script("benchmark_softmax_sm100.py"), + *common, + "--mode", + "fwd_bwd", + "--dsv3", + "--iters", + "50", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "softmax_fwd_bwd_dsv3.json"), + ], + ), + ( + "cross_entropy_fwd_bwd_quack_suite", + [ + py, + script("benchmark_cross_entropy_sm100.py"), + *common, + "--mode", + "fwd_bwd", + "--quack-suite", + "--iters", + "50", + 
"--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "cross_entropy_fwd_bwd_quack_suite.json"), + ], + ), + ( + "cross_entropy_fwd_bwd_dsv3", + [ + py, + script("benchmark_cross_entropy_sm100.py"), + *common, + "--mode", + "fwd_bwd", + "--dsv3", + "--iters", + "50", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "cross_entropy_fwd_bwd_dsv3.json"), + ], + ), + ( + "layernorm_fwd_quack_suite", + [ + py, + script("benchmark_layernorm_sm100.py"), + *common, + "--quack-suite", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "layernorm_fwd_quack_suite.json"), + ], + ), + ( + "layernorm_fwd_dsv3", + [ + py, + script("benchmark_layernorm_sm100.py"), + *common, + "--dsv3", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "layernorm_fwd_dsv3.json"), + ], + ), + ] + + print(f"Writing results to: {out_dir}", flush=True) + for name, cmd in runs: + print(f"\n== {name} ==", flush=True) + _run(cmd, dry_run=bool(args.dry_run)) + + +if __name__ == "__main__": + main() diff --git a/oink/benchmarks/readme/summarize_results.py b/oink/benchmarks/readme/summarize_results.py new file mode 100644 index 0000000..70782dd --- /dev/null +++ b/oink/benchmarks/readme/summarize_results.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import argparse +import json +import math +import os +from typing import Any, Dict, Iterable, List, Optional, Sequence + + +def _load_json(path: str) -> Dict[str, Any]: + with open(path) as f: + return json.load(f) + + +def _fmt_cell(v: object) -> str: + if v is None: + return "" + if isinstance(v, float): + if math.isfinite(v): + av = abs(v) + # Use scientific notation for very small values so we don't render + # meaningful error stats as "0.0000". + if av != 0.0 and av < 1e-3: + return f"{v:.2e}" + return f"{v:.4f}" + return str(v) + return str(v) + + +def _md_table(rows: Sequence[Dict[str, Any]], columns: Sequence[str]) -> str: + header = "| " + " | ".join(columns) + " |" + sep = "|" + "|".join(["---"] * len(columns)) + "|" + lines = [header, sep] + for r in rows: + lines.append("| " + " | ".join(_fmt_cell(r.get(c)) for c in columns) + " |") + return "\n".join(lines) + + +def _pick_columns(rows: Sequence[Dict[str, Any]]) -> List[str]: + preferred = [ + "M", + "N", + "dtype", + "weight_dtype", + "mode", + "eps", + "store_rstd", + "return_rstd", + "return_mean", + "ignore_index", + "ours_ms", + "ours_tbps", + "ours_hbm_frac", + "quack_ms", + "quack_tbps", + "speedup_vs_quack", + ] + present = set().union(*(r.keys() for r in rows)) if rows else set() + cols = [c for c in preferred if c in present] + # Fall back to a stable sorted view if we missed everything (shouldn't happen). 
+ return cols or sorted(present) + + +def _geomean(values: Iterable[float]) -> Optional[float]: + logs: List[float] = [] + for v in values: + if v <= 0 or not math.isfinite(v): + continue + logs.append(math.log(v)) + if not logs: + return None + return math.exp(sum(logs) / len(logs)) + + +def _collect_error_prefixes(rows: Sequence[Dict[str, Any]]) -> List[str]: + """Infer error-stat prefixes like `ours_err_dx` from row keys.""" + prefixes: set[str] = set() + for r in rows: + for k in r.keys(): + if not isinstance(k, str): + continue + if not k.endswith("_max_abs"): + continue + if "err_" not in k: + continue + prefixes.add(k[: -len("_max_abs")]) + return sorted(prefixes) + + +def _summarize_error_stats(rows: Sequence[Dict[str, Any]]) -> str: + prefixes = _collect_error_prefixes(rows) + if not prefixes: + return "" + + out_rows: List[Dict[str, Any]] = [] + for pfx in prefixes: + # Per-prefix worst-case across rows. + max_abs_vals = [float(r[pfx + "_max_abs"]) for r in rows if (pfx + "_max_abs") in r] + p99_abs_vals = [float(r[pfx + "_p99_abs"]) for r in rows if (pfx + "_p99_abs") in r] + rel_l2_vals = [float(r[pfx + "_rel_l2"]) for r in rows if (pfx + "_rel_l2") in r] + if not max_abs_vals and not p99_abs_vals and not rel_l2_vals: + continue + out_rows.append( + { + "metric": pfx, + "max_abs (max over shapes)": max(max_abs_vals) if max_abs_vals else None, + "p99_abs (max over shapes)": max(p99_abs_vals) if p99_abs_vals else None, + "rel_l2 (max over shapes)": max(rel_l2_vals) if rel_l2_vals else None, + } + ) + + if not out_rows: + return "" + + cols = ["metric", "max_abs (max over shapes)", "p99_abs (max over shapes)", "rel_l2 (max over shapes)"] + return "\n".join(["", "### Error Stats (vs PyTorch ref)", "", _md_table(out_rows, cols), ""]) + + +def summarize_one(path: str) -> str: + payload = _load_json(path) + meta = payload.get("meta", {}) + rows = payload.get("rows", []) + if not isinstance(rows, list): + rows = [] + + cols = _pick_columns(rows) + parts: List[str] = [] + + base = os.path.basename(path) + parts.append(f"## `{base}`") + if meta: + device = meta.get("device") + cap = meta.get("capability") + torch_ver = meta.get("torch") + cuda_ver = meta.get("cuda") + git_sha = meta.get("git_sha") + ts = meta.get("timestamp") + parts.append("") + parts.append( + f"- device: `{device}` | capability: `{cap}` | torch: `{torch_ver}` | cuda: `{cuda_ver}` | git_sha: `{git_sha}` | timestamp: `{ts}`" + ) + method = meta.get("method") + if method is not None: + parts.append(f"- method: `{method}`") + if meta.get("warmup_ms") is not None and meta.get("rep_ms") is not None: + parts.append(f"- warmup_ms: `{meta.get('warmup_ms')}` | rep_ms: `{meta.get('rep_ms')}`") + + if rows: + parts.append("") + parts.append(_md_table(rows, cols)) + + speeds = [float(r["speedup_vs_quack"]) for r in rows if "speedup_vs_quack" in r] + gm = _geomean(speeds) + if gm is not None: + parts.append("") + parts.append(f"- geomean speedup vs Quack: `{gm:.3f}x` (over {len(speeds)} shapes)") + + err_block = _summarize_error_stats(rows) + if err_block: + parts.append(err_block.rstrip()) + else: + parts.append("") + parts.append("_No rows found in JSON._") + + parts.append("") + return "\n".join(parts) + + +def main() -> None: + p = argparse.ArgumentParser(description="Summarize KernelAgent-Oink benchmark JSONs into Markdown tables.") + p.add_argument("--in-dir", type=str, required=True, help="Directory containing benchmark JSON files") + p.add_argument("--out", type=str, default=None, help="Optional output markdown path 
(default: stdout)") + args = p.parse_args() + + in_dir = os.path.abspath(args.in_dir) + if not os.path.isdir(in_dir): + raise SystemExit(f"--in-dir is not a directory: {in_dir}") + + json_paths = sorted( + os.path.join(in_dir, name) for name in os.listdir(in_dir) if name.endswith(".json") + ) + if not json_paths: + raise SystemExit(f"No .json files found under: {in_dir}") + + out_parts: List[str] = [] + out_parts.append("# KernelAgent-Oink SM100 Benchmark Summary") + out_parts.append("") + out_parts.append(f"Input directory: `{in_dir}`") + out_parts.append("") + for path in json_paths: + out_parts.append(summarize_one(path)) + + text = "\n".join(out_parts).rstrip() + "\n" + if args.out is None: + print(text, end="") + return + + out_path = os.path.abspath(args.out) + os.makedirs(os.path.dirname(out_path), exist_ok=True) + with open(out_path, "w") as f: + f.write(text) + + +if __name__ == "__main__": + main() diff --git a/oink/pyproject.toml b/oink/pyproject.toml index a9ec306..0d19d6e 100644 --- a/oink/pyproject.toml +++ b/oink/pyproject.toml @@ -5,11 +5,26 @@ build-backend = "setuptools.build_meta" [project] name = "kernelagent-oink" version = "0.1.0" -description = "vLLM plugin that registers Oink Blackwell RMSNorm custom ops" +description = "CuTeDSL kernels for Blackwell (SM100), shipped as a vLLM plugin" readme = "README.md" requires-python = ">=3.10" license = {text = "Apache-2.0"} authors = [{name = "PyTorch Labs"}] +keywords = ["cuda", "cutlass", "cute", "cutedsl", "blackwell", "sm100", "vllm"] +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] + +[project.urls] +Repository = "https://github.com/meta-pytorch/KernelAgent" +Documentation = "https://github.com/meta-pytorch/KernelAgent/tree/main/oink" +Issues = "https://github.com/meta-pytorch/KernelAgent/issues" # Keep dependencies minimal, but include the CuTeDSL stack required by the # Blackwell RMSNorm implementation. @@ -21,6 +36,13 @@ dependencies = [ "cuda-python", ] +[project.optional-dependencies] +# Optional extras for running the in-repo benchmark suite (not needed for vLLM integration). +bench = [ + "matplotlib", + "triton", +] + [project.entry-points."vllm.general_plugins"] oink = "kernelagent_oink:register" diff --git a/oink/src/kernelagent_oink/__init__.py b/oink/src/kernelagent_oink/__init__.py index bbbd7c1..f5c36a6 100644 --- a/oink/src/kernelagent_oink/__init__.py +++ b/oink/src/kernelagent_oink/__init__.py @@ -12,6 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +KernelAgent-Oink: SM100 CuTeDSL kernels + optional vLLM plugin. + +This package can be loaded as a vLLM "general plugin" (entrypoint group +`vllm.general_plugins`). In that mode it registers Oink custom ops only when +explicitly enabled via an environment variable (so installing the package does +not change behavior by default). + +For standalone usage (outside vLLM), call `kernelagent_oink.register(force=True)` +to register the custom ops explicitly. +""" + from __future__ import annotations import logging @@ -48,11 +60,14 @@ def _compute_cutedsl_arch(major: int, minor: int) -> str: return f"sm_{major}{minor}{suffix}" -def register() -> None: - """vLLM plugin entrypoint. 
+def register(*, force: bool = False) -> None: + """Register Oink torch custom ops. + + - vLLM plugin mode (default): no-op unless `VLLM_USE_OINK_RMSNORM` is truthy. + - Standalone mode: pass `force=True` to register explicitly. - This function must be safe to call multiple times and must not raise. - vLLM executes it in multiple processes (engine + workers). + This function must be safe to call multiple times and must not raise. vLLM + executes it in multiple processes (engine + workers). """ global _OPS_REGISTERED @@ -60,8 +75,9 @@ def register() -> None: return # Gate on the vLLM integration flag so installing the package does not - # change behavior unless explicitly enabled. - if not _env_truthy("VLLM_USE_OINK_RMSNORM"): + # change behavior unless explicitly enabled. For standalone usage (outside + # vLLM), callers can pass force=True to register the ops explicitly. + if not force and not _env_truthy("VLLM_USE_OINK_RMSNORM"): return try: diff --git a/oink/src/kernelagent_oink/blackwell/cross_entropy.py b/oink/src/kernelagent_oink/blackwell/cross_entropy.py new file mode 100644 index 0000000..94f052f --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/cross_entropy.py @@ -0,0 +1,1209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Cross-entropy forward + backward kernels for SM100 (Blackwell) in CuteDSL. + +This module implements numerically stable cross-entropy over the last +dimension of 2D logits tensors `(M, N)` together with its backward pass, +targeting SM100 with Quack-style tiling, cp.async pipelines, and (for the +forward pass) optional cluster-wide online softmax reductions, but without +depending on the external `quack` package at runtime. + +Public APIs: + +- ``cross_entropy_forward(logits, target, ignore_index=-100, reduction="none")`` + returns ``(loss, lse)`` where ``loss`` follows the requested reduction and + ``lse`` is always per-example log-sum-exp (shape ``(M,)``). +- ``cross_entropy_backward(dloss, logits, target, lse, ignore_index=-100)`` + returns per-logit gradients ``dlogits`` matching PyTorch / + ``quack.cross_entropy_bwd`` semantics for ``reduction="none"``. +- ``cross_entropy(logits, target, ignore_index=-100, reduction="mean"|"sum"|"none")`` + is a convenience wrapper that mirrors ``torch.nn.functional.cross_entropy`` + reductions using the SM100 CuteDSL kernels for the forward pass. + +The kernels are self-contained and use only local helpers in +`kernelagent_oink.blackwell.lite_quack` plus CuTeDSL/CUTLASS. +""" + +from __future__ import annotations + +import importlib.metadata +import math +import os +import re +from typing import Literal, Optional, Type + +import torch +from torch import Tensor + +import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python + +# CuTeDSL caches generated MLIR into a tempdir under a global default +# (`/tmp/$USER/cutlass_python_cache`). 
The cache bytecode format can differ across +# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy +# warnings (and disables cache reuse). +if "CUTE_DSL_CACHE_DIR" not in os.environ: + try: + _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl") + except Exception: + _dsl_ver = "unknown" + _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver) + _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user" + _tmp = os.environ.get("TMPDIR") or "/tmp" + os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join( + _tmp, _user, f"cutlass_python_cache_{_dsl_ver}" + ) + +try: + import cutlass # type: ignore # noqa: F401 +except Exception as e: + raise ImportError( + "kernelagent_oink.blackwell.cross_entropy requires CuTeDSL's Python package " + "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)." + ) from e + +import cutlass.cute as cute +from cutlass import Boolean, Float32, Int32, const_expr +from cutlass.cute import runtime as rt +from cutlass.cute.runtime import from_dlpack + +from kernelagent_oink.blackwell.lite_quack import ( + _KERNEL_ACCEPTS_LAYOUT_ARGS, + TORCH2CUTE_DTYPE, + ReductionBase, + domain_offset_i64, + fill_oob, + online_softmax_reduce, + predicate_k, +) + +_FWD_COMPILE_CACHE: dict[tuple[type[cutlass.Numeric], int], cute.Kernel] = {} +_BWD_COMPILE_CACHE: dict[tuple[type[cutlass.Numeric], int], cute.Kernel] = {} +_PTR_FWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} +_PTR_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +def _convert_logits_2d(x: Tensor) -> cute.Tensor: + """Convert a 2D logits tensor (M, N) into a CuTe tensor. + + We assume 16-byte alignment and mark the layout compact and row-major + in the last dimension, matching the conventions used in the SM100 + softmax and RMSNorm kernels. + """ + assert x.dim() == 2, "Input logits must be 2D (M, N)" + return ( + from_dlpack(x.detach(), assumed_align=16) + .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1)) + ) + + +def _convert_1d(t: Tensor, assumed_align: int) -> cute.Tensor: + """Convert a 1D tensor with a fully dynamic layout.""" + assert t.dim() == 1, "Expected a 1D tensor" + return from_dlpack(t.detach(), assumed_align=assumed_align).mark_layout_dynamic() + + +class CrossEntropyFwdSM100(ReductionBase): + """SM100-tuned cross-entropy forward kernel. + + This mirrors the structure of ``quack.cross_entropy.CrossEntropy`` but + is simplified to always use the single-pass online softmax reduction and + never computes gradients inside the forward kernel. + """ + + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + # Use one stage with an Int64 reduction buffer packing (max, sum_exp) + # pairs via lite_quack.online_softmax_reduce. + super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64) + + def _calculate_threads_per_row(self) -> int: + N = self.N + return ( + 8 + if N <= 64 + else ( + 16 + if N <= 128 + else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))) + ) + ) + + def _set_cluster_n(self) -> None: + # Match Quack's cluster_n growth policy while keeping it explicit so + # we can tune SM100-specific shapes later if needed. 
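+        # (Worked example: bf16 logits with N = 256 * 1024 fall into the last
+        # bucket below, so 16 CTAs in a cluster cooperate on each row's
+        # reduction.)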
+ N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: # fp32 + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + @cute.jit + def __call__( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mLoss: cute.Tensor, # (M,) + mLSE: Optional[cute.Tensor], # (M,) + ignore_index: Int32, + stream: cuda.CUstream, + ) -> None: + assert mX.element_type == self.dtype + self._set_cluster_n() + # If N is not divisible by the full 128-bit vector width, step down + # to the largest compatible vector size as in Quack. + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) + num_threads = ( + cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + kernel = ( + self.kernel( + mX, + mTarget, + mLoss, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mTarget, + mLoss, + mLSE, + ignore_index, + ) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_logits: cute.Pointer, + ptr_target: cute.Pointer, + ptr_loss: cute.Pointer, + ptr_lse: cute.Pointer, + M: Int32, + ld: Int32, + ignore_index: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions.""" + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + layout_m = cute.make_layout((M,), stride=(1,)) + mX = cute.make_tensor(ptr_logits, layout_mn) + mTarget = cute.make_tensor(ptr_target, layout_m) + mLoss = cute.make_tensor(ptr_loss, layout_m) + mLSE = cute.make_tensor(ptr_lse, layout_m) + self.__call__(mX, mTarget, mLoss, mLSE, ignore_index, stream) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mLoss: cute.Tensor, # (M,) + mLSE: Optional[cute.Tensor], # (M,) + ignore_index: Int32, # Index to ignore in loss computation + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + if const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = const_expr(0) + + shape: cute.Shape = mX.shape + idX = cute.make_identity_tensor(shape) + + # Slice per-CTA region; use 64-bit indexing for large tensors. + mX_off = domain_offset_i64((bidx * tiler_mn[0], 0), mX) + gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y)) + cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + + # Copy setup: gmem -> smem via cp.async, 128-bit or narrower as needed. 
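+        # (The per-thread vector width comes from the TV layout: a full
+        # 128-bit cp.async moves 8 bf16 or 4 fp32 elements per thread; when
+        # gcd(N, 128 // dtype.width) is smaller, __call__ already stepped the
+        # copy width down to match.)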
+ num_copy_elems_X = tv_layout.shape[1][0] + num_copy_bits_X = mX.element_type.width * num_copy_elems_X + copy_atom_load_X = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + gX.element_type, + num_bits_per_copy=num_copy_bits_X, + ) + thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx) + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + tXrX = cute.make_fragment_like(tXgX) + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + row = tXcX[0][0] + target = Int32.zero + if row < shape[0]: + target = Int32(mTarget[row]) + + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + tXpX = ( + predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + if row < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + # Fill out-of-bounds values with -inf so they are ignored in max/sum. + if const_expr(not is_even_N): + fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + + should_ignore = Boolean(target == ignore_index) + + # Load the target logit if this row is not ignored. Use Int64 indexing + # to safely handle very large tensors. + target_logit = Float32.zero + if row < shape[0] and tXcX[0][1] == 0 and not should_ignore: + mX_row = domain_offset_i64((row, 0), mX) + target_logit = Float32(mX_row[0, target]) + + threads_per_row = tv_layout.shape[0][0] + max_x, denom, _ = online_softmax_reduce( + x, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + phase=None, + return_exp_x=False, + ) + + # Write loss and lse to gmem. Only one CTA in the cluster writes to + # avoid duplicate stores. + if ( + tXcX[0][1] == 0 + and row < shape[0] + and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + lse = max_x + cute.math.log(denom, fastmath=True) + loss_val = (lse - target_logit) if not should_ignore else Float32.zero + mLoss[row] = mLoss.element_type(loss_val) + if const_expr(mLSE is not None): + mLSE[row] = lse + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mLoss: cute.Tensor, # (M,) + mLSE: Optional[cute.Tensor], # (M,) + ignore_index: Int32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl( + mX, + mTarget, + mLoss, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mLoss: cute.Tensor, # (M,) + mLSE: Optional[cute.Tensor], # (M,) + ignore_index: Int32, + ) -> None: + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) + self._kernel_impl( + mX, + mTarget, + mLoss, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + + +class CrossEntropyBackwardSM100: + """SM100-tuned cross-entropy backward kernel. + + This is a direct port of ``quack.cross_entropy.CrossEntropyBackward`` to + the local lite_quack helpers, using cp.async tiling over the (M, N) + logits and broadcasting ``dloss`` / ``lse`` across the row dimension. 
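+
+    In softmax terms, each gradient element is
+    ``dlogits[i, j] = dloss[i] * (softmax(logits[i])[j] - (j == target[i]))``,
+    and rows whose target equals ``ignore_index`` contribute zero gradient
+    (their ``dloss`` is treated as 0 inside the kernel).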
+ """ + + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + self.dtype = dtype + self.N = N + + def _get_num_threads(self) -> int: + # Keep in sync with _get_tv_layout() (we tile N in 16k blocks). + N = min(self.N, 16384) + return 128 if N <= 16384 else 256 + + def _calculate_threads_per_row(self) -> int: + N = min(self.N, 16384) # We split by blocks of 16k in N. + return ( + 8 + if N <= 64 + else ( + 16 + if N <= 128 + else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))) + ) + ) + + def _get_tv_layout(self, num_copy_bits: int = 128) -> tuple[cute.Shape, cute.Layout]: + vecsize = num_copy_bits // self.dtype.width + assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}" + N = min(self.N, 16384) + num_threads = 128 if N <= 16384 else 256 + threads_per_row = self._calculate_threads_per_row() + cols_per_block = num_threads // threads_per_row + num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row) + tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) + tv_layout = cute.make_layout( + ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)), + stride=( + (vecsize * cols_per_block, 1), + (cols_per_block, cols_per_block * vecsize * threads_per_row), + ), + ) + return tiler_mn, tv_layout + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mTarget: cute.Tensor, + mDLoss: cute.Tensor, + mdX: cute.Tensor, + mLSE: cute.Tensor, + ignore_index: Int32, # Index to ignore in gradient computation + stream: cuda.CUstream, + ) -> None: + assert mX.element_type == self.dtype + assert mdX.element_type == self.dtype + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) + num_threads = ( + cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + ) + # Broadcast (M,) tensors along the N dimension with stride 0. 
+ mDLoss, mTarget, mLSE = [ + cute.make_tensor( + X.iterator, + cute.append(X.layout, cute.make_layout((self.N,), stride=(0,))), + ) + for X in (mDLoss, mTarget, mLSE) + ] + smem_size = cute.size_in_bytes( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + ) + kernel = ( + self.kernel( + mX, + mTarget, + mDLoss, + mdX, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mTarget, + mDLoss, + mdX, + mLSE, + ignore_index, + ) + ) + kernel.launch( + grid=[ + cute.ceil_div(mX.shape[0], tiler_mn[0]), + cute.ceil_div(mX.shape[1], tiler_mn[1]), + 1, + ], + block=[num_threads, 1, 1], + smem=smem_size, + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_logits: cute.Pointer, + ptr_target: cute.Pointer, + ptr_dloss: cute.Pointer, + ptr_dx: cute.Pointer, + ptr_lse: cute.Pointer, + M: Int32, + ld: Int32, + ignore_index: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions.""" + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + layout_m = cute.make_layout((M,), stride=(1,)) + mX = cute.make_tensor(ptr_logits, layout_mn) + mdX = cute.make_tensor(ptr_dx, layout_mn) + mTarget = cute.make_tensor(ptr_target, layout_m) + mDLoss = cute.make_tensor(ptr_dloss, layout_m) + mLSE = cute.make_tensor(ptr_lse, layout_m) + self.__call__(mX, mTarget, mDLoss, mdX, mLSE, ignore_index, stream) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + mLSE: cute.Tensor, # (M,) + ignore_index: Int32, # Index to ignore in gradient computation + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, _ = cute.arch.block_idx() + shape = mX.shape + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + + idX = cute.make_identity_tensor(shape) + mX_off, mdX_off = [ + domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX) + ] + gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX_off, mdX_off)] + cX = cute.local_tile(idX, tiler_mn, (bidx, bidy)) + + num_copy_elems_X = tv_layout.shape[1][0] + num_copy_bits_X = mX.element_type.width * num_copy_elems_X + copy_atom_load_X = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + gX.element_type, + num_bits_per_copy=num_copy_bits_X, + ) + copy_atom_store_dX = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gdX.element_type, + num_bits_per_copy=num_copy_bits_X, + ) + thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx) + thr_copy_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx) + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + tXcFull = thr_copy_X.partition_S(cX) + tXgdX = thr_copy_dX.partition_D(gdX) + + tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)] + + is_even_N = const_expr(shape[1] % tiler_mn[1] == 0) + row = tXcX[0][0] + tXpX = ( + predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + if not is_even_N + else None + ) + if row < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + if 
const_expr(not is_even_N): + fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + + target = Int32.zero + dloss = Float32.zero + lse = Float32.zero + if row < shape[0]: + target = Int32(mTarget[row]) + should_ignore = Boolean(target == ignore_index) + dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero + lse = Float32(mLSE[row]) + + log2_e = math.log2(math.e) + probs = cute.math.exp2(x * log2_e - (lse * log2_e), fastmath=True) + prob_shifted = probs - 1.0 + mask = cute.make_fragment_like(tXrX, cutlass.Boolean) + for i in cutlass.range(cute.size(tXcFull), unroll_full=True): + mask[i] = tXcFull[i][1] == target + grad = cute.where(mask.load(), prob_shifted, probs) + grad = grad * dloss + + tXrdX.store(grad.to(tXrdX.element_type)) + tXpdX = ( + predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1]) + if not is_even_N + else None + ) + if row < shape[0]: + cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + mLSE: cute.Tensor, # (M,) + ignore_index: Int32, # Index to ignore in gradient computation + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl( + mX, + mTarget, + mDLoss, + mdX, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + mLSE: cute.Tensor, # (M,) + ignore_index: Int32, # Index to ignore in gradient computation + ) -> None: + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) + self._kernel_impl( + mX, + mTarget, + mDLoss, + mdX, + mLSE, + ignore_index, + tv_layout, + tiler_mn, + ) + + +def cross_entropy_forward( + logits: Tensor, + target: Tensor, + ignore_index: int = -100, + reduction: Literal["none", "mean", "sum"] = "none", +) -> tuple[Tensor, Tensor]: + """SM100 CuteDSL cross-entropy forward pass. + + Args: + logits: Tensor of shape ``(M, N)`` on CUDA. + target: Tensor of shape ``(M,)`` with integer class indices. + ignore_index: Target value to ignore when computing the loss. + reduction: One of ``"none"``, ``"mean"``, or ``"sum"`` following + ``torch.nn.functional.cross_entropy`` semantics. + + Returns: + A tuple ``(loss, lse)`` where: + - ``loss`` has shape ``(M,)`` if ``reduction="none"`` or is a scalar + otherwise. + - ``lse`` is the per-example log-sum-exp with shape ``(M,)``. 
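+
+    Example (illustrative; assumes an SM100 GPU with the CuTeDSL stack):
+
+        >>> logits = torch.randn(8, 8192, device="cuda", dtype=torch.bfloat16)
+        >>> target = torch.randint(0, 8192, (8,), device="cuda")
+        >>> loss, lse = cross_entropy_forward(logits, target, reduction="none")
+        >>> loss.shape, lse.shape
+        (torch.Size([8]), torch.Size([8]))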
+ """ + assert logits.dim() == 2, "logits must be 2D (M, N)" + assert target.dim() == 1, "target must be 1D (M,)" + assert logits.shape[0] == target.shape[0], "Batch dimensions must match" + assert logits.is_cuda and target.is_cuda, "logits and target must be on CUDA device" + assert logits.dtype in TORCH2CUTE_DTYPE, "Unsupported logits dtype" + assert target.dtype in (torch.int32, torch.int64), "target must be int32 or int64" + + M, N = logits.shape + device = logits.device + dtype_cute = TORCH2CUTE_DTYPE[logits.dtype] + + loss = torch.empty(M, device=device, dtype=torch.float32) + lse = torch.empty(M, device=device, dtype=torch.float32) + + if _can_use_ptr_path_logits(logits) and _can_use_ptr_path_target(target): + _cross_entropy_forward_ptr_into( + logits=logits, + target=target, + loss=loss, + lse=lse, + ignore_index=int(ignore_index), + ) + if reduction == "none": + return loss, lse + with torch.no_grad(): + mask = target != ignore_index + if reduction == "sum": + reduced = loss.sum() + elif reduction == "mean": + valid = mask.sum() + if valid > 0: + reduced = loss[mask].sum() / valid.to(loss.dtype) + else: + reduced = loss.sum() * 0.0 + else: + raise ValueError( + f"Invalid reduction mode: {reduction}. Expected 'none', 'mean', or 'sum'." + ) + return reduced, lse + + mX = _convert_logits_2d(logits) + mTarget = _convert_1d(target.to(torch.int64), assumed_align=8) + mLoss = _convert_1d(loss, assumed_align=4) + mLSE = _convert_1d(lse, assumed_align=4) + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compile_key = (dtype_cute, N) + kernel = _FWD_COMPILE_CACHE.get(compile_key) + if kernel is None: + op = CrossEntropyFwdSM100(dtype_cute, N) + kernel = cute.compile( + op, + mX, + mTarget, + mLoss, + mLSE, + Int32(ignore_index), + current_stream, + ) + _FWD_COMPILE_CACHE[compile_key] = kernel + + kernel(mX, mTarget, mLoss, mLSE, Int32(ignore_index), current_stream) + + if reduction == "none": + return loss, lse + + with torch.no_grad(): + mask = target != ignore_index + if reduction == "sum": + reduced = loss.sum() + elif reduction == "mean": + valid = mask.sum() + if valid > 0: + reduced = loss[mask].sum() / valid.to(loss.dtype) + else: + reduced = loss.sum() * 0.0 + else: + raise ValueError( + f"Invalid reduction mode: {reduction}. Expected 'none', 'mean', or 'sum'." 
+ ) + return reduced, lse + + +def _cross_entropy_backward_sm100( + logits: Tensor, + target: Tensor, + dloss: Tensor, + lse: Tensor, + dx: Tensor, + ignore_index: int = -100, +) -> None: + """Internal SM100 cross-entropy backward dispatch using CuteDSL.""" + assert logits.dim() == 2, "logits must be 2D (M, N)" + assert target.dim() == 1, "target must be 1D (M,)" + assert dloss.dim() == 1, "dloss must be 1D (M,)" + assert lse.dim() == 1, "lse must be 1D (M,)" + assert logits.shape[0] == target.shape[0] == dloss.shape[0] == lse.shape[0], ( + "Batch dimensions must match" + ) + assert logits.is_cuda and target.is_cuda and dloss.is_cuda and lse.is_cuda, ( + "All tensors must be on CUDA device" + ) + assert logits.dtype in TORCH2CUTE_DTYPE, "Unsupported logits dtype" + assert target.dtype in (torch.int32, torch.int64), "target must be int32 or int64" + + M, N = logits.shape + dtype_cute = TORCH2CUTE_DTYPE[logits.dtype] + + if ( + _can_use_ptr_path_logits(logits) + and _can_use_ptr_path_logits(dx) + and _can_use_ptr_path_target(target) + and _can_use_ptr_path_f32_1d(dloss) + and _can_use_ptr_path_f32_1d(lse) + and logits.stride() == dx.stride() + ): + _cross_entropy_backward_ptr_into( + logits=logits, + target=target, + dloss=dloss, + lse=lse, + dx=dx, + ignore_index=int(ignore_index), + ) + return + + mX = _convert_logits_2d(logits) + mdX = _convert_logits_2d(dx) + mTarget = _convert_1d(target.to(torch.int64), assumed_align=8) + mDLoss = _convert_1d(dloss.to(torch.float32), assumed_align=4) + mLSE = _convert_1d(lse.to(torch.float32), assumed_align=4) + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compile_key = (dtype_cute, N) + kernel = _BWD_COMPILE_CACHE.get(compile_key) + if kernel is None: + op = CrossEntropyBackwardSM100(dtype_cute, N) + kernel = cute.compile( + op, + mX, + mTarget, + mDLoss, + mdX, + mLSE, + Int32(ignore_index), + current_stream, + ) + _BWD_COMPILE_CACHE[compile_key] = kernel + + kernel(mX, mTarget, mDLoss, mdX, mLSE, Int32(ignore_index), current_stream) + + +def _can_use_ptr_path_logits(x: Tensor) -> bool: + if not x.is_cuda or x.dim() != 2: + return False + if x.dtype not in TORCH2CUTE_DTYPE: + return False + if x.stride(1) != 1: + return False + if (x.data_ptr() % 16) != 0: + return False + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + divby = 128 // dtype_x.width + if (x.stride(0) % divby) != 0: + return False + return True + + +def _can_use_ptr_path_target(t: Tensor) -> bool: + if not t.is_cuda or t.dim() != 1: + return False + if t.dtype is not torch.int64: + return False + if not t.is_contiguous(): + return False + if t.stride(0) != 1: + return False + if (t.data_ptr() % 8) != 0: + return False + return True + + +def _can_use_ptr_path_f32_1d(t: Tensor) -> bool: + if not t.is_cuda or t.dim() != 1: + return False + if t.dtype is not torch.float32: + return False + if not t.is_contiguous(): + return False + if t.stride(0) != 1: + return False + if (t.data_ptr() % 4) != 0: + return False + return True + + +def _cross_entropy_forward_ptr_into( + *, + logits: Tensor, + target: Tensor, + loss: Tensor, + lse: Tensor, + ignore_index: int, +) -> None: + assert logits.is_cuda and logits.dim() == 2 + assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0] + assert target.dtype is torch.int64 + assert loss.is_cuda and loss.shape == (logits.shape[0],) and loss.dtype is torch.float32 + assert lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32 + + M, N = logits.shape + device_index = 
logits.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + + dtype_x = TORCH2CUTE_DTYPE[logits.dtype] + key = ("ptr_fwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_FWD_COMPILE_CACHE.get(key) + if compiled is None: + op = CrossEntropyFwdSM100(dtype_x, int(N)) + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_loss = rt.make_ptr( + cutlass.Float32, + loss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_lse = rt.make_ptr( + cutlass.Float32, + lse.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_logits, + ptr_target, + ptr_loss, + ptr_lse, + Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + _PTR_FWD_COMPILE_CACHE[key] = compiled + + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_loss = rt.make_ptr( + cutlass.Float32, + loss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_lse = rt.make_ptr( + cutlass.Float32, + lse.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + compiled( + ptr_logits, + ptr_target, + ptr_loss, + ptr_lse, + Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + + +def _cross_entropy_backward_ptr_into( + *, + logits: Tensor, + target: Tensor, + dloss: Tensor, + lse: Tensor, + dx: Tensor, + ignore_index: int, +) -> None: + assert logits.is_cuda and logits.dim() == 2 + assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0] + assert target.dtype is torch.int64 + assert dloss.is_cuda and dloss.shape == (logits.shape[0],) and dloss.dtype is torch.float32 + assert lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32 + assert dx.is_cuda and dx.shape == logits.shape and dx.dtype == logits.dtype + assert dx.stride() == logits.stride(), "Pointer path expects dx to match logits strides" + + M, N = logits.shape + device_index = logits.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + + dtype_x = TORCH2CUTE_DTYPE[logits.dtype] + key = ("ptr_bwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_BWD_COMPILE_CACHE.get(key) + if compiled is None: + op = CrossEntropyBackwardSM100(dtype_x, int(N)) + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_dloss = rt.make_ptr( + cutlass.Float32, + dloss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_lse = rt.make_ptr( + cutlass.Float32, + lse.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_logits, + ptr_target, + ptr_dloss, + ptr_dx, + ptr_lse, + 
Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + _PTR_BWD_COMPILE_CACHE[key] = compiled + + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_dloss = rt.make_ptr( + cutlass.Float32, + dloss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_lse = rt.make_ptr( + cutlass.Float32, + lse.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + compiled( + ptr_logits, + ptr_target, + ptr_dloss, + ptr_dx, + ptr_lse, + Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + + +def cross_entropy_backward( + dloss: Tensor, + logits: Tensor, + target: Tensor, + lse: Tensor, + ignore_index: int = -100, +) -> Tensor: + """SM100 CuteDSL cross-entropy backward pass. + + Args: + dloss: Upstream gradient of shape ``(M,)`` corresponding to + ``reduction="none"``. + logits: Input logits tensor of shape ``(M, N)``. + target: Integer class indices of shape ``(M,)``. + lse: Per-example log-sum-exp tensor of shape ``(M,)`` as returned + by :func:`cross_entropy_forward`. + ignore_index: Target value to ignore in gradient computation. + + Returns: + ``dlogits`` of shape ``(M, N)`` with the same dtype as ``logits``. + """ + assert logits.dim() == 2, "logits must be 2D (M, N)" + assert dloss.dim() == 1, "dloss must be 1D (M,)" + assert logits.size(0) == dloss.size(0), "Batch dimensions must match" + assert logits.is_cuda and dloss.is_cuda, "logits and dloss must be on CUDA device" + + dx = torch.empty_like(logits) + _cross_entropy_backward_sm100( + logits, + target, + dloss, + lse, + dx, + ignore_index=ignore_index, + ) + return dx + + +def cross_entropy( + logits: Tensor, + target: Tensor, + ignore_index: int = -100, + reduction: Literal["none", "mean", "sum"] = "mean", +) -> Tensor: + """Convenience wrapper mirroring ``torch.nn.functional.cross_entropy`` reductions. + + This uses :func:`cross_entropy_forward` under the hood but returns only + the reduced loss tensor. + """ + loss, _lse = cross_entropy_forward( + logits, + target, + ignore_index=ignore_index, + reduction="none", + ) + if reduction == "none": + return loss + mask = target != ignore_index + if reduction == "sum": + return loss.sum() + if reduction == "mean": + valid = mask.sum() + if valid > 0: + return loss[mask].sum() / valid.to(loss.dtype) + return loss.sum() * 0.0 + raise ValueError( + f"Invalid reduction mode: {reduction}. Expected one of 'none', 'mean', or 'sum'." + ) + + +def verify_cross_entropy_parity( + M: int, + N: int, + dtype: torch.dtype = torch.bfloat16, + ignore_index: int = -100, +) -> None: + """Compare SM100 CuteDSL cross-entropy against PyTorch for a single shape.""" + device = torch.device("cuda") + torch.manual_seed(0) + + logits = 0.1 * torch.randn(M, N, device=device, dtype=dtype) + logits.requires_grad_(True) + target = torch.randint(0, N, (M,), device=device, dtype=torch.int64) + + # Optionally sprinkle some ignore_index entries for robustness. 
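+    # (With the default ignore_index = -100 nothing is replaced here; pass a
+    # non-default value to also exercise the masked-row path.)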
+ if ignore_index != -100: + mask = torch.rand(M, device=device) < 0.1 + target[mask] = ignore_index + + loss, lse = cross_entropy_forward(logits, target, ignore_index=ignore_index, reduction="none") + + logits_ref = logits.detach().clone().requires_grad_() + target_ref = target.detach().clone() + loss_ref = torch.nn.functional.cross_entropy( + logits_ref.float(), + target_ref, + ignore_index=ignore_index, + reduction="none", + ) + + # Forward parity + if dtype in (torch.float16, torch.bfloat16): + atol = 5e-2 + rtol = 5e-2 + else: + atol = 1e-4 + rtol = 1e-4 + torch.testing.assert_close(loss, loss_ref, atol=atol, rtol=rtol) + + # Backward parity + dloss = torch.randn_like(loss_ref) + (dx_ref,) = torch.autograd.grad(loss_ref, logits_ref, grad_outputs=dloss) + dx = cross_entropy_backward(dloss, logits, target, lse, ignore_index=ignore_index) + torch.testing.assert_close(dx, dx_ref.to(logits.dtype), atol=atol, rtol=rtol) + + +if __name__ == "__main__": + # Minimal functional check when executed directly. For performance + # comparisons and detailed tuning, use the dedicated benchmark harness. + if not torch.cuda.is_available(): + print("CUDA not available; cross-entropy parity check skipped.") + raise SystemExit(0) + + M, N = 1024, 8192 + dtype = torch.bfloat16 + verify_cross_entropy_parity(M, N, dtype=dtype, ignore_index=-100) + print("SM100 cross-entropy CuteDSL parity check passed.") diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py new file mode 100644 index 0000000..05b11de --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/layernorm.py @@ -0,0 +1,1368 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +LayerNorm kernel for SM100 (Blackwell) in CuteDSL. + +This implementation: +- Mirrors Quack's LayerNorm tiling / cluster policy / cp.async pipeline + but uses only local helpers so that it does not depend on the external + `quack` package at runtime. +- Supports fp16 / bf16 / fp32 inputs with fp32 accumulation. +- Optionally writes out per-row `rstd` and `mean` buffers for reuse in + backward or fused kernels. + +Backward is implemented with dedicated CuteDSL kernels for input and +parameter gradients (dx, dweight, dbias), avoiding PyTorch autograd +while matching `torch.nn.functional.layer_norm`'s gradients numerically. +""" + +from __future__ import annotations + +import importlib.metadata +import os +import re +import operator +from typing import Optional, Tuple, Type + +import torch +from torch import Tensor + +import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python + +# CuTeDSL caches generated MLIR into a tempdir under a global default +# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across +# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy +# warnings (and disables cache reuse). 
+if "CUTE_DSL_CACHE_DIR" not in os.environ: + try: + _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl") + except Exception: + _dsl_ver = "unknown" + _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver) + _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user" + _tmp = os.environ.get("TMPDIR") or "/tmp" + os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join( + _tmp, _user, f"cutlass_python_cache_{_dsl_ver}" + ) + +try: + import cutlass # type: ignore # noqa: F401 +except Exception as e: + raise ImportError( + "kernelagent_oink.blackwell.layernorm requires CuTeDSL's Python package " + "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)." + ) from e + +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute import runtime as rt +from cutlass.cute.runtime import from_dlpack + +# Simple compile cache for the forward kernel +_COMPILE_CACHE: dict[Tuple[int, type[cutlass.Numeric], bool, bool, bool], object] = {} +_PTR_COMPILE_CACHE: dict[Tuple[object, ...], object] = {} + +# Backward compile caches: one for dx, one for parameter gradients. +_BWD_DX_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric]], object] = {} +_BWD_PARAM_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric], bool], object] = {} + +# Local helpers cloned from Quack via lite_quack so that this kernel does +# not depend on `quack` at runtime. +from kernelagent_oink.blackwell.lite_quack import ( + _KERNEL_ACCEPTS_LAYOUT_ARGS, + TORCH2CUTE_DTYPE, + ReductionBase as _ReductionBase, + convert_from_dlpack as convert_from_dlpack_cute, + domain_offset_i64, + get_sm_count, + predicate_k, + row_reduce, + warp_reduce, +) + + +def _convert_row_major(t: Tensor) -> cute.Tensor: + """ + Convert a 2D row-major torch.Tensor to a CuTeDSL tensor with a compact, + dynamic layout on the leading dimension. + """ + return from_dlpack(t.detach(), assumed_align=16).mark_compact_shape_dynamic( + mode=0, + stride_order=(0, 1), + ) + + +class LayerNormSM100(_ReductionBase): + """ + SM100 LayerNorm forward kernel. + + This mirrors `quack.layernorm.LayerNorm`'s schedule: + - Stage=2 pipeline: first pass computes mean, second pass computes + variance / rstd and normalization. + - Threads-per-row and cluster_n policy follow Quack's LayerNorm + heuristics to keep tensor-core friendly tiles across N. + - Optional `reload_from` hint enables reloading X from SMEM for large-N + shapes to shorten register lifetimes. + + Differences vs Quack: + - Bias is optional and supported directly in the kernel. + - Dtype mapping and reduction helpers come from `lite_quack`. + """ + + def __init__(self, dtype: type[cutlass.Numeric], N: int): + super().__init__(dtype, N, stage=2) # 2 stages for mean and var + # Default reload policy mirrors Quack: use SMEM reload only for + # very large hidden sizes. We keep this conservative for LayerNorm + # and tune primarily via threads-per-block / cluster_n. + self.reload_from: Optional[str] = None if N <= 16384 else "smem" + self.delay_w_load: bool = False + + def _calculate_threads_per_row(self) -> int: + # Match Quack's LayerNorm threads-per-row buckets. + N = self.N + return ( + 8 + if N <= 64 + else ( + 16 + if N <= 128 + else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))) + ) + ) + + def _set_cluster_n(self) -> None: + # Cluster_n policy mirrors quack.layernorm.LayerNorm._set_cluster_n. 
+ N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + cluster_n = ( + 1 + if N <= 32 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mB: Optional[cute.Tensor], + mO: cute.Tensor, + mRstd: Optional[cute.Tensor], + mMean: Optional[cute.Tensor], + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ): + assert mX.element_type == self.dtype + assert mO.element_type == self.dtype + + # Tiling and cluster policy (mirrors Quack LayerNorm). + self._set_cluster_n() + tiler_mn, tv_layout = self._get_tv_layout() + num_threads = ( + cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + + # Expand weight / bias to match tiler_mn[0] rows per CTA. + mW = cute.make_tensor( + mW.iterator, + cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), + ) + if const_expr(mB is not None): + mB = cute.make_tensor( + mB.iterator, + cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), + ) + if const_expr(mRstd is not None): + mRstd = cute.make_tensor( + mRstd.iterator, + cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))), + ) + if const_expr(mMean is not None): + mMean = cute.make_tensor( + mMean.iterator, + cute.append(mMean.layout, cute.make_layout((self.N,), stride=(0,))), + ) + + kernel = ( + self.kernel( + mX, + mW, + mB, + mO, + mRstd, + mMean, + eps, + tv_layout, + tiler_mn, + const_expr(self.reload_from), + const_expr(self.delay_w_load), + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mW, + mB, + mO, + mRstd, + mMean, + eps, + ) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[ + 1, + self.cluster_n, + 1, + ] + if const_expr(self.cluster_n > 1) + else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_w: cute.Pointer, + ptr_b: Optional[cute.Pointer], + ptr_out: cute.Pointer, + ptr_rstd: Optional[cute.Pointer], + ptr_mean: Optional[cute.Pointer], + M: Int32, + ld: Int32, + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions. + + This reconstructs cute.Tensor views from raw device pointers + explicit + layouts inside the JIT graph, reusing the tuned LayerNormSM100 schedule. + """ + # The kernel uses 128-bit vectorized copies for X. Mirror Quack's + # `divisibility=128 // dtype.width` contract so the compiler can + # prove alignment for cp.async. + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + # Match `mark_compact_shape_dynamic(mode=0, ...)`: M is dynamic, N is static. 
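+        # The row stride is the caller-provided leading dimension, so padded
+        # rows (ld > N) are supported without any repacking.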
+ layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + layout_n = cute.make_layout((self.N,), stride=(1,)) + layout_m = cute.make_layout((M,), stride=(1,)) + + mX = cute.make_tensor(ptr_x, layout_mn) + mO = cute.make_tensor(ptr_out, layout_mn) + mW = cute.make_tensor(ptr_w, layout_n) + mB = cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None + mRstd = ( + cute.make_tensor(ptr_rstd, layout_m) + if const_expr(ptr_rstd is not None) + else None + ) + mMean = ( + cute.make_tensor(ptr_mean, layout_m) + if const_expr(ptr_mean is not None) + else None + ) + + self.__call__(mX, mW, mB, mO, mRstd, mMean, stream, eps) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mB: Optional[cute.Tensor], + mO: cute.Tensor, + mRstd: Optional[cute.Tensor], + mMean: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + reload_from: cutlass.Constexpr, + delay_w_load: cutlass.Constexpr, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + if const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = const_expr(0) + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + # Slice for CTAs: use domain_offset_i64 to handle >2^31 elements. + mX, mO = [ + domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO) + ] + gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)] + cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) + gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) + gB = ( + cute.local_tile(mB, tiler_mn, (0, cluster_y)) + if const_expr(mB is not None) + else None + ) + gRstd = ( + cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y)) + if const_expr(mRstd is not None) + else None + ) + gMean = ( + cute.local_tile(mMean, tiler_mn, (bidx, cluster_y)) + if const_expr(mMean is not None) + else None + ) + + # Copy atoms for X / W / B / O. + copy_atom_load_X = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mX.element_type, + num_bits_per_copy=128, + ) + copy_atom_load_X_async = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mX.element_type, + num_bits_per_copy=128, + ) + copy_atom_load_WB = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mW.element_type, + num_bits_per_copy=128, + ) + copy_atom_store_O = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mO.element_type, + num_bits_per_copy=128, + ) + + thr_copy_X = cute.make_tiled_copy( + copy_atom_load_X_async, + tv_layout, + tiler_mn, + ).get_slice(tidx) + thr_copy_WB = cute.make_tiled_copy( + copy_atom_load_WB, + tv_layout, + tiler_mn, + ).get_slice(tidx) + thr_copy_O = cute.make_tiled_copy( + copy_atom_store_O, + tv_layout, + tiler_mn, + ).get_slice(tidx) + + tWgW = thr_copy_WB.partition_S(gW) + tBgB = ( + thr_copy_WB.partition_S(gB) + if const_expr(gB is not None) + else None + ) + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXgO = thr_copy_O.partition_D(gO) + tXrRstd = ( + thr_copy_O.partition_D(gRstd) + if const_expr(mRstd is not None) + else None + ) + tXrMean = ( + thr_copy_O.partition_D(gMean) + if const_expr(mMean is not None) + else None + ) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + + # Fragments for gmem->rmem. 
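+        # Weight/bias fragments are retiled into the X copy layout so the
+        # per-thread gamma/beta values line up element-wise with the x fragment.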
+ tWrW = cute.make_fragment_like(tWgW) + tBrB = ( + cute.make_fragment_like(tBgB) + if const_expr(mB is not None) + else None + ) + tXrW = thr_copy_X.retile(tWrW) + tXrB = ( + thr_copy_X.retile(tBrB) + if const_expr(mB is not None) + else None + ) + tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)] + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=False) + + tXpX = predicate_k( + thr_copy_X.partition_S(cX), + limit=shape[1], + ) + row = tXcX[0][0] + if row < shape[0]: + cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + + tWpW = predicate_k( + thr_copy_WB.partition_S(cX), + limit=shape[1], + ) + if const_expr(not delay_w_load): + cute.copy(copy_atom_load_WB, tWgW, tWrW, pred=tWpW) + if const_expr(mB is not None): + cute.copy(copy_atom_load_WB, tBgB, tBrB, pred=tWpW) + + cute.arch.cp_async_wait_group(0) + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + threads_per_row = tv_layout.shape[0][0] + sum_x = row_reduce( + x, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + hook_fn=( + cute.arch.cluster_wait + if const_expr(self.cluster_n > 1) + else None + ), + ) + mean = sum_x / shape[1] + + if const_expr(reload_from == "smem"): + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + elif const_expr(reload_from == "gmem"): + cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) + x = tXrX.load().to(Float32) + + sum_sq_x_sub_mean = row_reduce( + (x - mean) * (x - mean), + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 1], + mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + ) + rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True) + + if const_expr(mRstd is not None): + if ( + tXcX[0][1] == 0 + and row < shape[0] + and ( + self.cluster_n == 1 + or cute.arch.block_idx_in_cluster() == 0 + ) + ): + tXrRstd[0] = rstd + + if const_expr(mMean is not None): + if ( + tXcX[0][1] == 0 + and row < shape[0] + and ( + self.cluster_n == 1 + or cute.arch.block_idx_in_cluster() == 0 + ) + ): + tXrMean[0] = mean + + if const_expr(delay_w_load): + cute.copy(copy_atom_load_WB, tWgW, tWrW, pred=tWpW) + if const_expr(mB is not None): + cute.copy(copy_atom_load_WB, tBgB, tBrB, pred=tWpW) + + if const_expr(reload_from == "smem"): + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + elif const_expr(reload_from == "gmem"): + cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) + x = tXrX.load().to(Float32) + + x_hat = (x - mean) * rstd + w = tXrW.load().to(Float32) + y = x_hat * w + if const_expr(mB is not None): + b = tXrB.load().to(Float32) + y = y + b + + tXrO.store(y.to(tXrO.element_type)) + tOpO = predicate_k( + thr_copy_O.partition_S(cX), + limit=shape[1], + ) + if row < shape[0]: + cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: cute.Tensor, + mB: Optional[cute.Tensor], + mO: cute.Tensor, + mRstd: Optional[cute.Tensor], + mMean: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + reload_from: cutlass.Constexpr, + delay_w_load: cutlass.Constexpr, + ): + self._kernel_impl( + mX, + mW, + mB, + mO, + mRstd, + mMean, + eps, + tv_layout, + tiler_mn, + reload_from, + delay_w_load, + ) + else: + + @cute.kernel + def kernel( + self, + 
mX: cute.Tensor, + mW: cute.Tensor, + mB: Optional[cute.Tensor], + mO: cute.Tensor, + mRstd: Optional[cute.Tensor], + mMean: Optional[cute.Tensor], + eps: Float32, + ): + tiler_mn, tv_layout = self._get_tv_layout() + self._kernel_impl( + mX, + mW, + mB, + mO, + mRstd, + mMean, + eps, + tv_layout, + tiler_mn, + const_expr(self.reload_from), + const_expr(self.delay_w_load), + ) + + +# ----------------------------------------------------------------------------- +# Public Python API +# ----------------------------------------------------------------------------- + + +def layernorm( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + eps: float = 1e-6, + return_rstd: bool = False, + return_mean: bool = False, +): + """ + LayerNorm forward pass using the SM100 CuteDSL kernel. + + Args: + x: Input tensor of shape (M, N). + weight: Scale parameter of shape (N,), typically fp32. + bias: Optional bias parameter of shape (N,). + eps: Small value for numerical stability. + return_rstd: Whether to return per-row reciprocal std (shape (M,)). + return_mean: Whether to return per-row mean (shape (M,)). + """ + assert x.is_cuda and weight.is_cuda, "x and weight must be CUDA tensors" + assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand." + assert weight.dim() == 1, "weight must be 1D" + assert x.shape[1] == weight.shape[0], "Last dim of x must match weight.size(0)" + if bias is not None: + assert bias.is_cuda, "bias must be on CUDA" + assert bias.dim() == 1 and bias.shape[0] == weight.shape[0], ( + "bias must be 1D and match weight" + ) + + M, N = x.shape + dtype = TORCH2CUTE_DTYPE[x.dtype] + + rstd = torch.empty(M, device=x.device, dtype=torch.float32) if return_rstd else None + mean = torch.empty(M, device=x.device, dtype=torch.float32) if return_mean else None + + # Fast path: bypass DLPack conversions when the inputs are in the common + # contiguous row-major layout and weights/bias are fp32 (Quack-style). + if _can_use_ptr_path(x, weight, bias): + out = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + _layernorm_forward_ptr_into( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + eps=eps, + ) + if return_mean and return_rstd: + return out, rstd, mean + if return_rstd and not return_mean: + return out, rstd + if return_mean and not return_rstd: + return out, mean + return out + + out = torch.empty_like(x) + mX = _convert_row_major(x) + mO = _convert_row_major(out) + + # Weight/bias live in feature dimension (N). 
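+    # divisibility = 128 // 32 = 4 elements lets the compiler assume 128-bit
+    # vector loads of the fp32 weight/bias.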
+ mW = convert_from_dlpack_cute( + weight.detach(), + leading_dim=0, + alignment=16, + divisibility=128 // cutlass.Float32.width, + ) + mB = ( + convert_from_dlpack_cute( + bias.detach(), + leading_dim=0, + alignment=16, + divisibility=128 // cutlass.Float32.width, + ) + if bias is not None + else None + ) + + mRstd = ( + from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + if rstd is not None + else None + ) + mMean = ( + from_dlpack(mean.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + if mean is not None + else None + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + key = (N, dtype, mB is not None, mRstd is not None, mMean is not None) + compiled = _COMPILE_CACHE.get(key) + if compiled is None: + op = LayerNormSM100(dtype, N) + compiled = cute.compile( + op, + mX, + mW, + mB, + mO, + mRstd, + mMean, + stream, + Float32(eps), + ) + _COMPILE_CACHE[key] = compiled + + compiled( + mX, + mW, + mB, + mO, + mRstd, + mMean, + stream, + Float32(eps), + ) + + if return_mean and return_rstd: + return out, rstd, mean + if return_rstd and not return_mean: + return out, rstd + if return_mean and not return_rstd: + return out, mean + return out + + +def _can_use_ptr_path(x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> bool: + """Return True if we can safely use the pointer-based fast path. + + This is intentionally conservative: we target the common inference-like + layout (2D row-major with stride(1)==1) and Quack-style fp32 weights. + """ + if not x.is_cuda or x.dim() != 2: + return False + if x.stride(1) != 1: + return False + if not weight.is_cuda or weight.dim() != 1: + return False + if weight.dtype != torch.float32: + return False + if not weight.is_contiguous(): + return False + if bias is not None: + if not bias.is_cuda or bias.dim() != 1: + return False + if bias.dtype != torch.float32: + return False + if not bias.is_contiguous(): + return False + # Require 16B alignment for 128-bit vector copies (matches Quack's assumed_align=16). + if (x.data_ptr() % 16) != 0: + return False + if (weight.data_ptr() % 16) != 0: + return False + if bias is not None and (bias.data_ptr() % 16) != 0: + return False + # The kernel uses 128-bit vectorized loads; require the leading dimension + # to preserve 16B alignment for every row start. 
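+    # For fp16/bf16 this means stride(0) must be a multiple of 8 elements; for fp32, 4.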
+ dtype_x = TORCH2CUTE_DTYPE[x.dtype] + divby = 128 // dtype_x.width + if (x.stride(0) % divby) != 0: + return False + return True + + +def _layernorm_forward_ptr_into( + *, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + out: Tensor, + rstd: Optional[Tensor], + mean: Optional[Tensor], + eps: float, +) -> None: + """Launch the pointer-based LayerNorm kernel into preallocated outputs.""" + assert x.is_cuda and x.dim() == 2 + M, N = x.shape + assert weight.is_cuda and weight.dim() == 1 and weight.shape[0] == N + if bias is not None: + assert bias.is_cuda and bias.dim() == 1 and bias.shape[0] == N + assert out.is_cuda and out.shape == x.shape and out.dtype == x.dtype + assert out.stride() == x.stride(), "Pointer path expects out to match x strides" + if rstd is not None: + assert rstd.is_cuda and rstd.shape == (M,) and rstd.dtype == torch.float32 + if mean is not None: + assert mean.is_cuda and mean.shape == (M,) and mean.dtype == torch.float32 + + device_index = x.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + key = ( + "ptr", + int(N), + dtype_x, + bias is not None, + rstd is not None, + mean is not None, + int(device_index), + ) + compiled = _PTR_COMPILE_CACHE.get(key) + if compiled is None: + op = LayerNormSM100(dtype_x, int(N)) + ptr_x = rt.make_ptr( + dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_out = rt.make_ptr( + dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_w = rt.make_ptr( + cutlass.Float32, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + ptr_b = ( + rt.make_ptr( + cutlass.Float32, + bias.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + if bias is not None + else None + ) + ptr_rstd = ( + rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if rstd is not None + else None + ) + ptr_mean = ( + rt.make_ptr( + cutlass.Float32, + mean.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if mean is not None + else None + ) + ld = Int32(int(x.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_w, + ptr_b, + ptr_out, + ptr_rstd, + ptr_mean, + Int32(int(M)), + ld, + stream, + Float32(float(eps)), + ) + _PTR_COMPILE_CACHE[key] = compiled + + ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_out = rt.make_ptr( + dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_w = rt.make_ptr( + cutlass.Float32, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + ptr_b = ( + rt.make_ptr( + cutlass.Float32, + bias.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + if bias is not None + else None + ) + ptr_rstd = ( + rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if rstd is not None + else None + ) + ptr_mean = ( + rt.make_ptr( + cutlass.Float32, + mean.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if mean is not None + else None + ) + ld = Int32(int(x.stride(0))) + compiled( + ptr_x, + ptr_w, + ptr_b, + ptr_out, + ptr_rstd, + ptr_mean, + Int32(int(M)), + ld, + stream, + Float32(float(eps)), + ) + + +def layernorm_ref( + x: Tensor, + weight: Tensor, + bias: 
Optional[Tensor] = None, + eps: float = 1e-6, +) -> Tensor: + """ + Reference LayerNorm implemented via torch.nn.functional.layer_norm. + """ + x_f32 = x.float() + w = weight.float() + b = bias.float() if bias is not None else None + y = torch.nn.functional.layer_norm(x_f32, (x.shape[-1],), w, b, eps) + return y.to(x.dtype) + + +def _as_2d(x: Tensor) -> Tuple[Tensor, Tuple[int, ...]]: + if x.dim() == 2: + return x, x.shape + original_shape = x.shape + M = int(torch.prod(torch.tensor(original_shape[:-1])).item()) + N = original_shape[-1] + return x.reshape(M, N), original_shape + + +def _restore_shape(x: Tensor, shape: Tuple[int, ...]) -> Tensor: + return x.reshape(shape) + + +@cute.kernel +def _layernorm_backward_dx_kernel( + mX: cute.Tensor, + mW: cute.Tensor, + mdO: cute.Tensor, + mRstd: cute.Tensor, + mMean: cute.Tensor, + mdX: cute.Tensor, +): + """ + Simple CTA-per-row LayerNorm backward kernel for dx only. + + Each block processes one row of shape (N,), using block_threads threads. + It performs two passes over the row: + 1) Compute mean_wdy and mean_xhat_wdy in fp32. + 2) Compute dx using the standard LayerNorm backward formula: + dx = rstd * (wdy - mean_wdy - x_hat * mean_xhat_wdy), + where wdy = dy * gamma and x_hat = (x - mean) * rstd. + """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + block_threads = const_expr(256) + shape = mX.shape + M = shape[0] + N = shape[1] + + row = bidx + if row < M: + # Shared buffers for warp-level reductions across the block. + smem = cutlass.utils.SmemAllocator() + num_warps = const_expr(block_threads // cute.arch.WARP_SIZE) + warp_sums_layout = cute.make_layout((num_warps,), stride=(1,)) + warp_sums_wdy = smem.allocate_tensor(Float32, warp_sums_layout, byte_alignment=4) + warp_sums_xhatwdy = smem.allocate_tensor(Float32, warp_sums_layout, byte_alignment=4) + + lane = cute.arch.lane_idx() + warp_idx = cute.arch.warp_idx() + + rstd_val = mRstd[row].to(Float32) + mean_val = mMean[row].to(Float32) + + # Pass 1: compute local partial sums of wdy and x_hat*wdy. + local_wdy = Float32(0.0) + local_xhatwdy = Float32(0.0) + for col in cutlass.range(tidx, N, block_threads): + x_val = mX[row, col].to(Float32) + dy_val = mdO[row, col].to(Float32) + gamma = mW[col].to(Float32) + x_mu = x_val - mean_val + x_hat = x_mu * rstd_val + wdy = dy_val * gamma + local_wdy += wdy + local_xhatwdy += x_hat * wdy + + # Warp-level reduction, then block-level reduction via shared memory. + red_op = operator.add # type: ignore[assignment] + local_wdy = warp_reduce(local_wdy, red_op) + local_xhatwdy = warp_reduce(local_xhatwdy, red_op) + + if lane == 0: + warp_sums_wdy[warp_idx] = local_wdy + warp_sums_xhatwdy[warp_idx] = local_xhatwdy + + cute.arch.barrier() + + total_wdy = Float32(0.0) + total_xhatwdy = Float32(0.0) + if warp_idx == 0 and lane == 0: + for wi in cutlass.range_constexpr(num_warps): + total_wdy += warp_sums_wdy[wi] + total_xhatwdy += warp_sums_xhatwdy[wi] + # Store totals back into first slots for broadcast. + warp_sums_wdy[0] = total_wdy + warp_sums_xhatwdy[0] = total_xhatwdy + + cute.arch.barrier() + + total_wdy = warp_sums_wdy[0] + total_xhatwdy = warp_sums_xhatwdy[0] + inv_N = Float32(1.0 / float(N)) + mean_wdy = total_wdy * inv_N + mean_xhatwdy = total_xhatwdy * inv_N + + # Pass 2: compute dx and write back. 
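+        # Each thread strides through columns tidx, tidx + 256, ..., so both
+        # passes visit the same columns per thread.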
+ for col in cutlass.range(tidx, N, block_threads): + x_val = mX[row, col].to(Float32) + dy_val = mdO[row, col].to(Float32) + gamma = mW[col].to(Float32) + x_mu = x_val - mean_val + x_hat = x_mu * rstd_val + wdy = dy_val * gamma + dx_val = (wdy - mean_wdy - x_hat * mean_xhatwdy) * rstd_val + mdX[row, col] = dx_val.to(mdX.element_type) + + +@cute.jit +def _layernorm_backward_dx( + mX: cute.Tensor, + mW: cute.Tensor, + mdO: cute.Tensor, + mRstd: cute.Tensor, + mMean: cute.Tensor, + mdX: cute.Tensor, + stream: cuda.CUstream, +) -> None: + """ + JIT wrapper that launches the dx-only LayerNorm backward kernel. + One CTA processes one row of length N with 256 threads. + """ + M = mX.shape[0] + _layernorm_backward_dx_kernel( + mX, + mW, + mdO, + mRstd, + mMean, + mdX, + ).launch( + grid=[M, 1, 1], + block=[256, 1, 1], + stream=stream, + ) + + +@cute.kernel +def _layernorm_backward_param_kernel( + mX: cute.Tensor, + mdO: cute.Tensor, + mRstd: cute.Tensor, + mMean: cute.Tensor, + mdW_partial: Optional[cute.Tensor], + mdB_partial: Optional[cute.Tensor], + num_blocks: Int32, +) -> None: + """ + Parameter-gradient kernel for LayerNorm. + + Each CTA accumulates partial dweight/dbias over a stripe of rows: + - Grid dim X: num_blocks (sm_count-style persistent CTAs). + - Threads in a CTA partition the N dimension. + - For each assigned column, a thread streams over rows + row = blockIdx.x, blockIdx.x + num_blocks, ... + + This mirrors the persistent-CTA pattern used by RMSNorm backward, + but uses a simpler per-thread accumulation since columns are + independent. + """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + block_threads = const_expr(256) + M = mX.shape[0] + N = mX.shape[1] + + if bidx < num_blocks: + for col in cutlass.range(tidx, N, block_threads): + dw_local = Float32(0.0) + db_local = Float32(0.0) + for row in cutlass.range(bidx, M, num_blocks): + x_val = mX[row, col].to(Float32) + dy_val = mdO[row, col].to(Float32) + rstd_val = mRstd[row].to(Float32) + mean_val = mMean[row].to(Float32) + x_mu = x_val - mean_val + x_hat = x_mu * rstd_val + dw_local += dy_val * x_hat + db_local += dy_val + + if const_expr(mdW_partial is not None): + mdW_partial[bidx, col] = dw_local + if const_expr(mdB_partial is not None): + mdB_partial[bidx, col] = db_local + + +@cute.jit +def _layernorm_backward_param( + mX: cute.Tensor, + mdO: cute.Tensor, + mRstd: cute.Tensor, + mMean: cute.Tensor, + mdW_partial: Optional[cute.Tensor], + mdB_partial: Optional[cute.Tensor], + num_blocks: Int32, + stream: cuda.CUstream, +) -> None: + """ + JIT wrapper that launches the parameter-gradient kernel. + """ + _layernorm_backward_param_kernel( + mX, + mdO, + mRstd, + mMean, + mdW_partial, + mdB_partial, + num_blocks, + ).launch( + grid=[num_blocks, 1, 1], + block=[256, 1, 1], + stream=stream, + ) + + +def _layernorm_backward_dx_sm100( + dout_2d: Tensor, + x_2d: Tensor, + weight: Tensor, + rstd_1d: Tensor, + mean_1d: Tensor, + dx_2d: Tensor, +) -> None: + """ + Host-side helper to run the dx-only LayerNorm backward kernel. 
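+
+    Converts the tensors to CuTe views, compiles the dx kernel once per
+    (N, dtype) pair, and launches it on the current CUDA stream.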
+ """ + M, N = x_2d.shape + assert dout_2d.shape == (M, N) + assert rstd_1d.numel() == M + assert mean_1d.numel() == M + + dtype = TORCH2CUTE_DTYPE[x_2d.dtype] + + mX = _convert_row_major(x_2d) + mdO = _convert_row_major(dout_2d) + mdX = _convert_row_major(dx_2d) + + mW = convert_from_dlpack_cute( + weight.detach(), + leading_dim=0, + alignment=16, + divisibility=128 // cutlass.Float32.width, + ) + mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + key = (N, dtype) + compiled = _BWD_DX_COMPILE_CACHE.get(key) + if compiled is None: + compiled = cute.compile( + _layernorm_backward_dx, + mX, + mW, + mdO, + mRstd, + mMean, + mdX, + stream, + ) + _BWD_DX_COMPILE_CACHE[key] = compiled + + compiled( + mX, + mW, + mdO, + mRstd, + mMean, + mdX, + stream, + ) + + +def _layernorm_backward_params_sm100( + dout_2d: Tensor, + x_2d: Tensor, + rstd_1d: Tensor, + mean_1d: Tensor, + dw_partial: Optional[Tensor], + db_partial: Optional[Tensor], + sm_count: int, +) -> None: + """ + Host-side helper to run the parameter-gradient kernel that populates + dw_partial / db_partial of shape (sm_count, N). + """ + M, N = x_2d.shape + assert dout_2d.shape == (M, N) + assert rstd_1d.numel() == M + assert mean_1d.numel() == M + if dw_partial is None and db_partial is None: + return + + dtype = TORCH2CUTE_DTYPE[x_2d.dtype] + + mX = _convert_row_major(x_2d) + mdO = _convert_row_major(dout_2d) + mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + + mdW_partial = ( + from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) + if dw_partial is not None + else None + ) + mdB_partial = ( + from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) + if db_partial is not None + else None + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + has_bias = db_partial is not None + key = (N, dtype, has_bias) + compiled = _BWD_PARAM_COMPILE_CACHE.get(key) + if compiled is None: + compiled = cute.compile( + _layernorm_backward_param, + mX, + mdO, + mRstd, + mMean, + mdW_partial, + mdB_partial, + Int32(sm_count), + stream, + ) + _BWD_PARAM_COMPILE_CACHE[key] = compiled + + compiled( + mX, + mdO, + mRstd, + mMean, + mdW_partial, + mdB_partial, + Int32(sm_count), + stream, + ) + + +def layernorm_backward( + dout: Tensor, + x: Tensor, + weight: Tensor, + rstd: Tensor, + mean: Tensor, + bias: Optional[Tensor] = None, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """ + LayerNorm backward implemented in CuteDSL / CUTLASS. + + Computes gradients w.r.t. input, weight, and optional bias using + two kernels: + - A dx kernel (CTA-per-row) that streams over N. + - A parameter-gradient kernel that accumulates dw/db over a + persistent grid of CTAs across the M dimension. + """ + assert x.shape == dout.shape, "x and dout must have the same shape" + assert x.is_cuda and dout.is_cuda, "x and dout must be CUDA tensors" + assert weight.dim() == 1, "weight must be 1D" + if bias is not None: + assert bias.dim() == 1, "bias must be 1D" + + x_2d, orig_shape = _as_2d(x) + dout_2d, _ = _as_2d(dout) + M, N = x_2d.shape + + # Flatten to 2D for the kernels. 
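+    # rstd/mean are the per-row statistics saved by the forward pass (shape (M,)).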
+ mean_flat = mean.view(M) + rstd_flat = rstd.view(M) + + dx_2d = torch.empty_like(x_2d) + _layernorm_backward_dx_sm100( + dout_2d, + x_2d, + weight, + rstd_flat, + mean_flat, + dx_2d, + ) + + device = x.device + sm_count = get_sm_count(N, device, M=M, dtype=x.dtype) + + dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) + db_partial = ( + torch.empty(sm_count, N, device=device, dtype=torch.float32) + if bias is not None + else None + ) + + _layernorm_backward_params_sm100( + dout_2d, + x_2d, + rstd_flat, + mean_flat, + dw_partial, + db_partial, + sm_count, + ) + + dweight = dw_partial.sum(dim=0).to(weight.dtype) + dbias = db_partial.sum(dim=0).to(bias.dtype) if bias is not None else None + + dx = _restore_shape(dx_2d, orig_shape) + return dx, dweight, dbias + + +if __name__ == "__main__": + # Allow direct execution for a quick functional check. + if not torch.cuda.is_available(): + print("CUDA not available; LayerNormSM100 test skipped.") + raise SystemExit(0) + + device = "cuda" + M, N = 2048, 4096 + dtype = torch.bfloat16 + x = torch.randn(M, N, device=device, dtype=dtype) + w = torch.randn(N, device=device, dtype=torch.float32) + b = torch.randn(N, device=device, dtype=torch.float32) + + y_ref = layernorm_ref(x, w, b) + y, rstd, mean = layernorm(x, w, b, return_rstd=True, return_mean=True) + torch.testing.assert_close( + y, + y_ref, + atol=5e-2 if dtype != torch.float32 else 1e-5, + rtol=5e-2 if dtype != torch.float32 else 1e-5, + ) + + print("LayerNormSM100 forward correctness check passed.") diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py index 14ae723..1bc15b1 100644 --- a/oink/src/kernelagent_oink/blackwell/lite_quack.py +++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py @@ -13,21 +13,24 @@ # limitations under the License. """ -Lightweight local clone of the small subset of helpers that the SM100 +Lightweight local clone of the small subset of Quack helpers that the SM100 RMSNorm CuteDSL kernels depend on. This module intentionally avoids importing the `quack` package so that -Oink Blackwell kernels can run without Quack installed, while keeping -numerical behaviour and performance close to the original reference -implementations. +KernelAgent Oink SM100 kernels can run without Quack installed, while keeping +numerical behaviour and performance identical to the reference kernels. 
""" from __future__ import annotations import math import operator -from typing import Callable, Optional +import importlib.metadata +import re +from functools import partial +from typing import Callable, Optional, Tuple, Type +import cuda.bindings.driver as cuda # type: ignore import torch from torch import Tensor @@ -36,11 +39,39 @@ from cutlass import Float32, Int32, const_expr from cutlass.cute.runtime import from_dlpack from cutlass.cutlass_dsl import T, dsl_user_op -from cutlass._mlir.dialects import llvm +from cutlass._mlir.dialects import llvm, nvvm, vector + + +def _parse_version_tuple(version: str) -> tuple[int, int, int]: + parts = version.split(".") + nums: list[int] = [] + for part in parts[:3]: + match = re.match(r"^(\d+)", part) + nums.append(int(match.group(1)) if match is not None else 0) + while len(nums) < 3: + nums.append(0) + return nums[0], nums[1], nums[2] + + +def _cutlass_dsl_version() -> Optional[tuple[int, int, int]]: + try: + return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl")) + except Exception: + return None + + +_CUTLASS_DSL_VERSION = _cutlass_dsl_version() +# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around +# passing Layout/Shape/Constexpr objects into @cute.kernel functions). Keep the +# older signature for <4.3.4, but switch to a 4.3.4+ compatible signature when +# we detect 4.3.4+ (or when version detection is unavailable). +_KERNEL_ACCEPTS_LAYOUT_ARGS = ( + _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4) +) # ------------------------- -# Dtype mapping +# Dtype mapping (from quack.cute_dsl_utils) # ------------------------- TORCH2CUTE_DTYPE = { @@ -51,21 +82,15 @@ # ------------------------- -# Tensor conversion helpers +# Tensor conversion helpers (from quack.utils) # ------------------------- - def convert_from_dlpack( x: Tensor, leading_dim: int, alignment: int = 16, divisibility: int = 1, ) -> cute.Tensor: - """ - Wrap a torch.Tensor in a CuteDSL tensor with layout metadata that - matches the logical leading dimension and alignment/divisibility - constraints expected by SM100 kernels. - """ return ( from_dlpack(x, assumed_align=alignment) .mark_layout_dynamic(leading_dim=leading_dim) @@ -78,14 +103,12 @@ def convert_from_dlpack( # ------------------------- -# SM90/SM100 cluster helpers +# SM90/SM100 cluster helpers (from quack.utils) # ------------------------- @dsl_user_op -def elem_pointer( - x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None -) -> cute.Pointer: +def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) @@ -136,9 +159,7 @@ def store_shared_remote( ).ir_value() if const_expr(isinstance(val, float)): val = Float32(val) - assert isinstance(val, (Float32, Int32, cutlass.Int64)), ( - "val must be Float32, Int32, or Int64" - ) + assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64" suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)] constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)] llvm.inline_asm( @@ -154,37 +175,22 @@ def store_shared_remote( @cute.jit def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: - """ - Build a predicate tensor for the K dimension only. Values beyond - `limit` are masked out. - """ + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if". 
tApA = cute.make_fragment( cute.make_layout( - ( - cute.size(tAcA, mode=[0, 1]), - cute.size(tAcA, mode=[1]), - cute.size(tAcA, mode=[2]), - ), + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), stride=(cute.size(tAcA, mode=[2]), 0, 1), ), cutlass.Boolean, ) for rest_v in cutlass.range_constexpr(tApA.shape[0]): for rest_k in cutlass.range_constexpr(tApA.shape[2]): - tApA[rest_v, 0, rest_k] = cute.elem_less( - tAcA[(0, rest_v), 0, rest_k][1], limit - ) + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) return tApA @dsl_user_op -def domain_offset_i64( - coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None -) -> cute.Tensor: - """ - Return a tensor whose iterator is offset by an Int64 byte offset - computed from `coord` and the tensor's strides. - """ +def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) flat_stride = cute.flatten_to_tuple(tensor.stride) assert len(flat_coord_i64) == len(flat_stride), ( @@ -201,8 +207,81 @@ def domain_offset_i64( return cute.make_tensor(new_ptr, tensor.layout) +@dsl_user_op +def coord_offset_i64( + idx: cute.typing.Int, + tensor: cute.Tensor, + dim: int, + *, + loc=None, + ip=None, +) -> cute.Tensor: + offset = cutlass.Int64(idx) * cute.size(tensor.stride[dim]) + assert isinstance(tensor.iterator, cute.Pointer) + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +@cute.jit +def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cutlass.Numeric) -> None: + """Fill out-of-bounds values in shared memory tensor.""" + tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0]) + tXrX_fill.fill(fill_value) + for rest_v in cutlass.range_constexpr(const_expr(tXsX.shape[0][1])): + for rest_k in cutlass.range_constexpr(const_expr(tXsX.shape[2])): + if const_expr(tXpX is not None): + if not tXpX[rest_v, 0, rest_k]: + cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) + else: + cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) + + +@dsl_user_op +def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64: + """Pack two f32 values into a single i64. + + This mirrors quack.utils.f32x2_to_i64 and is used by online_softmax_reduce + to store (max, sum_exp) pairs in an Int64 reduction buffer. 
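+    The pair can be unpacked again with `i64_to_f32x2`.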
+ """ + vec_f32x2 = vector.from_elements( + T.vector(2, T.f32()), + (a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)), + loc=loc, + ip=ip, + ) + vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2, loc=loc, ip=ip) + res = cutlass.Int64( + vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip) + ) + return res + + +@dsl_user_op +def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + """Unpack a single i64 into two f32 values, inverse of f32x2_to_i64.""" + vec_i64x1 = vector.from_elements( + T.vector(1, T.i64()), + (c.ir_value(loc=loc, ip=ip),), + loc=loc, + ip=ip, + ) + vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1, loc=loc, ip=ip) + res0 = Float32( + vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip) + ) + res1 = Float32( + vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip) + ) + return res0, res1 + + # ------------------------- -# Reduction helpers +# Reduction helpers (from quack.reduce) # ------------------------- @@ -212,10 +291,6 @@ def warp_reduce( op: Callable, width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, ) -> cute.TensorSSA | cute.Numeric: - """ - Warp-level reduction for either scalar values or small TensorSSA - fragments. - """ if cutlass.const_expr(isinstance(val, cute.TensorSSA)): res = cute.make_fragment(val.shape, val.dtype) res.store(val) @@ -234,7 +309,7 @@ def block_reduce( reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0, ) -> cute.Numeric: - """Block-level reduction across warps.""" + """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row).""" lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() warps_per_row = cute.size(reduction_buffer.shape[1]) row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row @@ -256,10 +331,7 @@ def cluster_reduce( init_val: cute.Numeric = 0.0, phase: Optional[cutlass.Int32] = None, ) -> cute.Numeric: - """ - Cluster-wide reduction using shared memory and mbarrier. The - reduction_buffer has shape (rows_per_block, (warps_per_row, cluster_n)). - """ + """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n)).""" cta_rank_in_cluster = cute.arch.block_idx_in_cluster() lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape @@ -297,12 +369,10 @@ def block_or_cluster_reduce( phase: Optional[cutlass.Int32] = None, init_val: cute.Numeric = 0.0, ) -> cute.Numeric: - """Dispatch between block or cluster reduction depending on mbar_ptr.""" + """Perform either block or cluster reduction based on whether mbar_ptr is provided.""" if cutlass.const_expr(mbar_ptr is None): return block_reduce(val, op, reduction_buffer, init_val=init_val) - return cluster_reduce( - val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase - ) + return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase) @cute.jit @@ -316,21 +386,14 @@ def row_reduce( init_val: cute.Numeric = 0.0, hook_fn: Optional[Callable] = None, ) -> cute.Numeric: - """ - Row-wise reduction used by RMSNorm and similar kernels. - - reduction_buffer must have shape - (num_warps / warps_per_row, (warps_per_row, cluster_n)). 
- """ + """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n)).""" if cutlass.const_expr(isinstance(x, cute.TensorSSA)): val = x.reduce(op, init_val=init_val, reduction_profile=0) else: val = x warp_op = { cute.ReductionOp.ADD: operator.add, - cute.ReductionOp.MAX: cute.arch.fmax - if cutlass.const_expr(x.dtype == Float32) - else max, + cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max, cute.ReductionOp.MIN: min, cute.ReductionOp.MUL: operator.mul, }[op] @@ -358,28 +421,808 @@ def row_reduce( return val +@cute.jit +def row_reduce_add( + x: cute.TensorSSA | cute.Numeric, + threads_per_row: cutlass.Constexpr[int], + reduction_buffer: Optional[cute.Tensor] = None, + mbar_ptr: Optional[cute.Pointer] = None, + phase: Optional[cutlass.Int32] = None, + init_val: cute.Numeric = 0.0, + hook_fn: Optional[Callable] = None, +) -> cute.Numeric: + """Specialized row_reduce for ADD reductions. + + This mirrors row_reduce but hardcodes the ADD operation so we avoid + dynamic dispatch on the reduction op. It is used by bandwidth-bound + kernels like RMSNorm backward where the reduction is always ADD in + Float32. + """ + if cutlass.const_expr(isinstance(x, cute.TensorSSA)): + val = x.reduce(cute.ReductionOp.ADD, init_val=init_val, reduction_profile=0) + else: + val = x + val = warp_reduce( + val, + operator.add, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + if cutlass.const_expr(hook_fn is not None): + hook_fn() + if cutlass.const_expr(reduction_buffer is not None): + warps_per_row, cluster_n = reduction_buffer.shape[1] + assert cluster_n == 1 or mbar_ptr is not None, ( + "mbar_ptr must be provided for cluster reduction" + ) + if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1): + val = block_or_cluster_reduce( + val, + operator.add, + reduction_buffer, + mbar_ptr, + phase=phase, + init_val=init_val, + ) + return val + + +@cute.jit +def online_softmax_reduce( + x: cute.TensorSSA, + threads_per_row: cutlass.Constexpr[int], + reduction_buffer: Optional[cute.Tensor] = None, + mbar_ptr: Optional[cute.Pointer] = None, + hook_fn: Optional[Callable] = None, + phase: Optional[cutlass.Int32] = None, + return_exp_x: bool = False, +) -> tuple[Float32, Float32, Optional[cute.TensorSSA]]: + """Online softmax reduction over a row. 
+ + This mirrors quack.reduce.online_softmax_reduce and computes: + - max_x: row-wise maximum of x + - sum_exp_x: row-wise sum of exp(x - max_x) + - exp_x (optional): per-element exp(x - max_x_final) if return_exp_x is True + """ + assert x.dtype == Float32, "x must be of type Float32" + # reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2) + max_x = warp_reduce( + x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + cute.arch.fmax, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + log2_e = math.log2(math.e) + exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True) + sum_exp_x = warp_reduce( + exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0), + operator.add, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + if cutlass.const_expr(hook_fn is not None): + hook_fn() + if cutlass.const_expr(reduction_buffer is not None): + rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape + assert cluster_n == 1 or mbar_ptr is not None, ( + "mbar_ptr must be provided for cluster reduction" + ) + if cutlass.const_expr(warps_per_row > 1 or cluster_n > 1): + assert reduction_buffer.element_type == cutlass.Int64, ( + "reduction_buffer must be of type Int64" + ) + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + if cutlass.const_expr(mbar_ptr is None): + if lane_idx == 0: + reduction_buffer[row_idx, col_idx] = f32x2_to_i64(max_x, sum_exp_x) + cute.arch.barrier() + max_x_single_warp = -Float32.inf + sum_exp_x = 0.0 + if lane_idx < warps_per_row: + max_x_single_warp, sum_exp_x = i64_to_f32x2( + reduction_buffer[row_idx, lane_idx] + ) + max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax) + sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True) + sum_exp_x = warp_reduce(sum_exp_x, operator.add) + if cutlass.const_expr(return_exp_x): + exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True) + max_x = max_x_final + else: + cta_rank_in_cluster = cute.arch.block_idx_in_cluster() + if warp_idx == 0: + with cute.arch.elect_one(): + num_warps = rows_per_block * warps_per_row + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_ptr, + num_warps * cluster_n * reduction_buffer.element_type.width // 8, + ) + if lane_idx < cluster_n: + store_shared_remote( + f32x2_to_i64(max_x, sum_exp_x), + elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))), + mbar_ptr, + peer_cta_rank_in_cluster=lane_idx, + ) + cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0) + num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE) + max_x_single_warp = cute.make_fragment(num_iter, Float32) + max_x_single_warp.fill(-Float32.inf) + sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32) + sum_exp_x_single_warp.fill(0.0) + for i in cutlass.range_constexpr(num_iter): + idx = lane_idx + i * cute.arch.WARP_SIZE + if idx < cute.size(reduction_buffer, mode=[1]): + max_x_single_warp[i], sum_exp_x_single_warp[i] = i64_to_f32x2( + reduction_buffer[row_idx, idx] + ) + max_x_final = max_x_single_warp.load().reduce( + cute.ReductionOp.MAX, + init_val=-Float32.inf, + reduction_profile=0, + ) + max_x_final = warp_reduce(max_x_final, cute.arch.fmax) + sum_exp_x = 0.0 + for i in cutlass.range_constexpr(num_iter): + sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp( + max_x_single_warp[i] - max_x_final, + fastmath=True, + ) + sum_exp_x = warp_reduce(sum_exp_x, 
operator.add) + if cutlass.const_expr(return_exp_x): + exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True) + max_x = max_x_final + return max_x, sum_exp_x, (exp_x if cutlass.const_expr(return_exp_x) else None) + + +# ------------------------- +# Copy helpers (minimal subset of quack.copy_utils) +# ------------------------- + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], + num_copy_elems: int, + is_async: bool = False, + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + from cutlass.cute.nvgpu import cpasync + + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits, loc=loc, ip=ip) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async, loc=loc, ip=ip) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +# ------------------------- +# Reduction base (from quack.reduction_base) +# ------------------------- + + +class ReductionBase: + def __init__( + self, + dtype: Type[cutlass.Numeric], + N: int, + stage: int, + reduction_dtype: Type[cutlass.Numeric] = cutlass.Float32, + ): + self.dtype = dtype + self.N = N + self.stage = stage + self.reduction_dtype = reduction_dtype + + def _calculate_threads_per_row(self) -> int: + raise NotImplementedError() + + def _set_cluster_n(self) -> None: + self.cluster_n = 1 + + def _get_num_threads(self) -> int: + return 128 if self.N <= 16384 else 256 + + def _get_tv_layout(self, num_copy_bits: int = 128) -> Tuple[cute.Shape, cute.Layout]: + vecsize = num_copy_bits // self.dtype.width + assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}" + num_threads = self._get_num_threads() + assert num_threads % cute.arch.WARP_SIZE == 0 + + threads_per_row = self._calculate_threads_per_row() + self._set_cluster_n() + num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n) + cols_per_block = num_threads // threads_per_row + tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) + tv_layout = cute.make_layout( + ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)), + stride=( + (vecsize * cols_per_block, 1), + (cols_per_block, cols_per_block * vecsize * threads_per_row), + ), + ) + return tiler_mn, tv_layout + + def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int: + return ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) + + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8) + + self.stage * (cutlass.Int64.width // 8) + ) + + def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int) -> cute.Layout: + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1) + return cute.make_ordered_layout( + (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), + order=(1, 0, 2), + ) + + def _allocate_reduction_buffer_and_mbar( + self, + smem: cutlass.utils.SmemAllocator, + tv_layout: cute.Layout, + is_persistent: bool = False, + ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]: + reduction_buffer = smem.allocate_tensor( + self.reduction_dtype, + self._get_reduction_buffer_layout(tv_layout, 
self.cluster_n), + byte_alignment=4, + ) + if cutlass.const_expr(self.cluster_n > 1): + mbar_ptr = smem.allocate_array( + cutlass.Int64, + num_elems=self.stage if not is_persistent else self.stage * 2, + ) + else: + mbar_ptr = None + return reduction_buffer, mbar_ptr + + @cute.jit + def _initialize_cluster( + self, + tidx: cutlass.Int32, + mbar_ptr: Optional[cute.Pointer], + num_warps: int, + is_persistent: bool = False, + ) -> None: + if cutlass.const_expr(self.cluster_n > 1 and mbar_ptr is not None): + if tidx < self.stage: + cute.arch.mbarrier_init(mbar_ptr + tidx, 1) + if cutlass.const_expr(is_persistent): + cute.arch.mbarrier_init( + mbar_ptr + self.stage + tidx, + num_warps * self.cluster_n, + ) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + + # ------------------------- -# SM count helper +# RMSNorm backward base (from quack.rmsnorm.RMSNormBackward) # ------------------------- -def get_sm_count(N: int, device: torch.device) -> int: +class RMSNormBackward(ReductionBase): + def __init__(self, dtype: cutlass.Numeric, N: int): + # 2 stages for double buffering when computing mean of x_hat * wdy + super().__init__(dtype, N, stage=2, reduction_dtype=Float32) + self.reload_wdy = None if N <= 16 * 1024 else "smem" + if self.N > 128 * 1024 and self.dtype.width >= 32: + raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits") + + def _get_num_threads(self) -> int: + return 128 if self.N <= 4096 else 256 + + def _calculate_threads_per_row(self) -> int: + N = self.N + return ( + 8 + if N <= 64 + else ( + 16 + if N <= 128 + else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256))) + ) + ) + + def _set_cluster_n(self) -> None: + N = self.N + cluster_n = ( + 1 + if N <= 8 * 1024 + else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16))) + ) + self.cluster_n = cluster_n + + def _smem_size_in_bytes(self, tiler_mn, num_warps: int, do_dtype=None) -> int: + if do_dtype is None: + do_dtype = self.dtype + return ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 + + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2 + + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8) + + self.stage * (cutlass.Int64.width // 8) * 2 + ) + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + sm_count: Int32, + stream: cuda.CUstream, + ): + semistatic_shape = (*mX.shape[:-1], self.N) + + def new_stride(t): + return ( + cute.assume(t.stride[0], divby=128 // t.element_type.width), + t.stride[1], + ) + + mX, mdO, mdResO, mdX, mdRes = [ + cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))) + if const_expr(t is not None) + else None + for t in (mX, mdO, mdResO, mdX, mdRes) + ] + self._set_cluster_n() + largest_dtype_width = const_expr( + max( + mX.element_type.width, + mdO.element_type.width, + mdX.element_type.width, + mdResO.element_type.width if mdResO is not None else 0, + mdRes.element_type.width if mdRes is not None else 0, + ) + ) + tiler_mn, tv_layout = self._get_tv_layout( + num_copy_bits=128 // largest_dtype_width * mX.element_type.width + ) + num_threads = ( + cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + if 
const_expr(mW is not None): + mW_expanded_layout = cute.prepend( + mW.layout, + cute.make_layout((tiler_mn[0],), stride=(0,)), + ) + mW = cute.make_tensor(mW.iterator, mW_expanded_layout) + + num_blocks = sm_count + kernel = ( + self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes) + ) + kernel.launch( + grid=[num_blocks, self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type), + stream=stream, + ) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx_start, _, _ = cute.arch.block_idx() + gdim, _, _ = cute.arch.grid_dim() + if const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = const_expr(0) + + shape = mX.shape + M = shape[0] + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + + idX = cute.make_identity_tensor(shape) + + smem = cutlass.utils.SmemAllocator() + smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2)) + sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16) + sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, + tv_layout, + is_persistent=True, + ) + if const_expr(mbar_ptr is not None): + mbar_full_ptr, mbar_empty_ptr = mbar_ptr, mbar_ptr + 2 + else: + mbar_full_ptr, mbar_empty_ptr = None, None + + num_copy_elems_X = tv_layout.shape[1][0] + copy_atom_load_X = get_copy_atom(mX.element_type, num_copy_elems_X, is_async=False) + thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx) + copy_fn = partial(copy, num_copy_elems=num_copy_elems_X) + + gX, gdO, gdResO, gdX, gdRes, cX = [ + cute.local_tile(mT, tiler_mn, (None, cluster_y)) if mT is not None else None + for mT in (mX, mdO, mdResO, mdX, mdRes, idX) + ] + gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) if mW is not None else None + gdW, gdB = [ + cute.local_tile(mT, (1, tiler_mn[1]), (bidx_start, cluster_y)) + if const_expr(mT is not None) + else None + for mT in (mdW, mdB) + ] + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXgdO = thr_copy_X.partition_S(gdO) + tXsdO = thr_copy_X.partition_D(sdO) + tXgdX = thr_copy_X.partition_D(gdX) + if const_expr(mdResO is not None): + tXgdResO = thr_copy_X.partition_S(gdResO) + if const_expr(mdRes is not None): + tXgdRes = thr_copy_X.partition_D(gdRes) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None] + + tXrX, tXrdO, tXrdX = [ + cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX) + ] + tXrdResO = None + if const_expr(mdResO is not None): + tXrdResO = cute.make_fragment_like(tXgdResO[None, None, None, 0]) + tXrdRes = None + if const_expr(mdRes is not None): + tXrdRes = cute.make_fragment_like(tXgdRes[None, None, None, 0]) + + tXpX = ( + predicate_k(thr_copy_X.partition_S(cX[None, None, 0]), limit=shape[1]) + if not is_even_N + else None + ) + + tXgdW, tXrdW = None, None + tXgdB, tXrdB = 
None, None + if const_expr(mdW is not None): + tXgdW = thr_copy_X.partition_S(gdW) + tXrdW = cute.make_fragment_like(tXgdW, Float32) + if const_expr(mdB is not None): + tXgdB = thr_copy_X.partition_S(gdB) + tXrdB = cute.make_fragment_like(tXgdB, Float32) + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + + self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=True) + + tXrW = None + if const_expr(mW is not None): + tXgW = thr_copy_X.partition_S(gW) + tXrW = cute.make_fragment_like(tXgW) + if not is_even_N: + tXrW.fill(0.0) + copy_fn(tXgW, tXrW, pred=tXpX) + + row = tXcX[None, None, None, bidx_start][0][0] + if row < M: + tXgX_cur = coord_offset_i64(bidx_start, tXgX, dim=3)[None, None, None, 0] + tXgdO_cur = coord_offset_i64(bidx_start, tXgdO, dim=3)[None, None, None, 0] + copy_fn(tXgX_cur, tXsX[None, None, None, 0], pred=tXpX, is_async=True) + copy_fn(tXgdO_cur, tXsdO[None, None, None, 0], pred=tXpX, is_async=True) + elif tiler_mn[0] > 1: + fill_oob(tXsX[None, None, None, 0], None, fill_value=mX.element_type.zero) + fill_oob(tXsdO[None, None, None, 0], None, fill_value=mdO.element_type.zero) + cute.arch.cp_async_commit_group() + + if const_expr(self.cluster_n > 1): + cute.arch.cluster_wait() + + threads_per_row = tv_layout.shape[0][0] + if const_expr(mdW is not None): + tXrdW.fill(0.0) + if const_expr(mdB is not None): + tXrdB.fill(0.0) + stage = Int32(0) + producer_phase = Int32(1) + consumer_phase = Int32(0) + for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim): + row = tXcX[None, None, None, bidx][0][0] + if row + gdim * tiler_mn[0] < M: + tXgX_cur = coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0] + tXgdO_cur = coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0] + copy_fn(tXgX_cur, tXsX[None, None, None, stage ^ 1], pred=tXpX, is_async=True) + copy_fn(tXgdO_cur, tXsdO[None, None, None, stage ^ 1], pred=tXpX, is_async=True) + elif tiler_mn[0] > 1: + fill_oob( + tXsX[None, None, None, stage ^ 1], + None, + fill_value=mX.element_type.zero, + ) + fill_oob( + tXsdO[None, None, None, stage ^ 1], + None, + fill_value=mdO.element_type.zero, + ) + cute.arch.cp_async_commit_group() + rstd_val = cutlass.Float.zero + if row < M or tiler_mn[0] == 1: + rstd_val = mRstd[row] + if const_expr(mdResO is not None): + tXgdResO_cur = coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0] + if row < M or tiler_mn[0] == 1: + copy_fn(tXgdResO_cur, tXrdResO, pred=tXpX) + elif tiler_mn[0] > 1: + tXrdResO.fill(0.0) + cute.arch.cp_async_wait_group(1) + cute.autovec_copy(tXsX[None, None, None, stage], tXrX) + x = tXrX.load().to(cute.Float32) + cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO) + dout = tXrdO.load().to(cute.Float32) + x_hat = x * rstd_val + wdy = dout + if const_expr(mW is not None): + wdy *= tXrW.load().to(Float32) + if const_expr(self.cluster_n > 1): + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase) + mean_xhat_wdy = ( + row_reduce_add( + x_hat * wdy, + threads_per_row, + reduction_buffer[None, None, stage], + (mbar_full_ptr + stage if const_expr(self.cluster_n > 1) else None), + phase=consumer_phase, + init_val=0.0, + ) + / shape[1] + ) + + if const_expr(self.cluster_n > 1): + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.sync_warp() + lane_idx = cute.arch.lane_idx() + if lane_idx < self.cluster_n: + cute.arch.mbarrier_arrive( + mbar_empty_ptr + stage, + peer_cta_rank_in_cluster=lane_idx, + ) + + if 
const_expr(self.reload_wdy == "smem"): + cute.autovec_copy(tXsdO[None, None, None, stage], tXrdO) + dout = tXrdO.load().to(cute.Float32) + wdy = dout + if const_expr(mW is not None): + wdy *= tXrW.load().to(Float32) + + dx = (wdy - x_hat * mean_xhat_wdy) * rstd_val + if const_expr(mdResO is not None): + dx += tXrdResO.load().to(cute.Float32) + tXrdX.store(dx.to(tXrdX.element_type)) + if row < M or tiler_mn[0] == 1: + tXgdX_cur = coord_offset_i64(bidx, tXgdX, dim=3)[None, None, None, 0] + copy_fn(tXrdX, tXgdX_cur, pred=tXpX) + if const_expr(mdRes is not None): + tXrdRes.store(dx.to(tXrdRes.element_type)) + tXgdRes_cur = coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0] + if row < M or tiler_mn[0] == 1: + copy_fn(tXrdRes, tXgdRes_cur, pred=tXpX) + if const_expr(mdW is not None): + tXrdW.store(tXrdW.load() + dout * x_hat) + if const_expr(mdB is not None): + tXrdB.store(tXrdB.load() + dout) + + stage ^= 1 + if stage == 0: + consumer_phase ^= 1 + producer_phase ^= 1 + + if const_expr(tiler_mn[0] > 1): + if const_expr(mdW is not None): + sdW = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=cute.Float32), + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + ) + tXsdW = thr_copy_X.partition_D(sdW) + cute.arch.barrier() + row0 = tXcX[None, None, None, 0][0][0] + if row0 > 0: + cute.autovec_copy(tXrdW, tXsdW) + cute.arch.barrier() + if row0 == 0: + for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])): + tXrdW_other = cute.make_fragment_like(tXrdW) + tXsdW_other = cute.make_tensor( + tXsdW.iterator + i * sdW.stride[0], + tXsdW.layout, + ) + cute.autovec_copy(tXsdW_other, tXrdW_other) + tXrdW.store(tXrdW.load() + tXrdW_other.load()) + copy_fn(tXrdW, tXgdW, pred=tXpX) + cute.arch.barrier() + if const_expr(mdB is not None): + sdB = cute.make_tensor( + cute.recast_ptr(sX.iterator, dtype=cute.Float32), + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + ) + tXsdB = thr_copy_X.partition_D(sdB) + cute.arch.barrier() + row0 = tXcX[None, None, None, 0][0][0] + if row0 > 0: + cute.autovec_copy(tXrdB, tXsdB) + cute.arch.barrier() + if row0 == 0: + for i in cutlass.range_constexpr(1, const_expr(tiler_mn[0])): + tXrdB_other = cute.make_fragment_like(tXrdB) + tXsdB_other = cute.make_tensor( + tXsdB.iterator + i * sdB.stride[0], + tXsdB.layout, + ) + cute.autovec_copy(tXsdB_other, tXrdB_other) + tXrdB.store(tXrdB.load() + tXrdB_other.load()) + copy_fn(tXrdB, tXgdB, pred=tXpX) + else: + if const_expr(mdW is not None): + copy_fn(tXrdW, tXgdW, pred=tXpX) + if const_expr(mdB is not None): + copy_fn(tXrdB, tXgdB, pred=tXpX) + + if const_expr(self.cluster_n > 1): + stage ^= 1 + if stage == 0: + producer_phase ^= 1 + cute.arch.mbarrier_wait(mbar_empty_ptr + stage, producer_phase) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ): + self._kernel_impl( + mX, + mW, + mdO, + mdResO, + mRstd, + mdX, + mdW, + mdB, + mdRes, + tv_layout, + tiler_mn, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + ): + largest_dtype_width = const_expr( + max( + 
mX.element_type.width, + mdO.element_type.width, + mdX.element_type.width, + mdResO.element_type.width if mdResO is not None else 0, + mdRes.element_type.width if mdRes is not None else 0, + ) + ) + tiler_mn, tv_layout = self._get_tv_layout( + num_copy_bits=128 // largest_dtype_width * mX.element_type.width + ) + self._kernel_impl( + mX, + mW, + mdO, + mdResO, + mRstd, + mdX, + mdW, + mdB, + mdRes, + tv_layout, + tiler_mn, + ) + + +# ------------------------- +# SM count helper (from quack.rmsnorm._get_sm_count) +# ------------------------- + + +def get_sm_count( + N: int, + device: torch.device, + M: Optional[int] = None, + dtype: Optional[torch.dtype] = None, +) -> int: """ - Heuristic for the number of persistent CTAs (sm_count) based on N and - the GPU's SM count. This mirrors the behaviour used in Quack's - RMSNorm kernels but lives entirely in this local module. + SM count heuristic for reduction-style kernels. + + This starts from Quack's _get_sm_count policy and layers on SM100 / + DSv3-specific tuning so that: + - For DSv3-style shapes (large-M, N in {6144, 8192}, fp16/bf16), + sm_count is reduced for very large M to cut down the number of + dw_partial/db_partial rows that ever hit HBM. + - For Quack-suite hidden=4096, small-M shapes, sm_count is modestly + increased to improve SM occupancy, matching the existing SM100 + tuning used by both RMSNorm and LayerNorm. """ + props = torch.cuda.get_device_properties(device) + num_sms = props.multi_processor_count + sm_count_multiple = ( - 16 - if N <= 256 - else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) - ) - sm_count = torch.cuda.get_device_properties(device).multi_processor_count - sm_count = ( - sm_count * sm_count_multiple - if N <= 8192 - else sm_count // 2 - if N <= 16384 - else sm_count * 2 + 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) ) + sm_count = num_sms + if N <= 8192: + sm_count = sm_count * sm_count_multiple + elif N <= 16384: + sm_count = sm_count // 2 + else: + sm_count = sm_count * 2 + + # Quack-suite tuning: for small-M, hidden=4096 shapes (M<=8192) and + # 16-bit dtypes, increase sm_count to improve occupancy. This mirrors + # the existing SM100 RMSNorm/LayerNorm heuristics. + if ( + dtype in (torch.float16, torch.bfloat16) + and M is not None + and M <= 8192 + and N == 4096 + ): + sm_count = min(sm_count * 2, num_sms * 4) + return sm_count diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py index c7fc1b3..1e080a3 100644 --- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py @@ -76,6 +76,7 @@ import cutlass.cute as cute # noqa: E402 from cutlass import Float32, Int32, const_expr # noqa: E402 from cutlass.cute import runtime as rt # noqa: E402 +from cutlass.cute.runtime import from_dlpack # noqa: E402 # Simple compile cache declared early so direct execution works _PTR_COMPILE_CACHE = {} @@ -114,6 +115,14 @@ def _env_flag(name: str, default: bool) -> bool: _ENABLE_TPR256 = _env_flag("OINK_RMSNORM_ENABLE_TPR256", default=False) _ENABLE_STAGE2 = _env_flag("OINK_RMSNORM_ENABLE_STAGE2", default=False) +# Forward dispatch control: +# - Default behavior: use the pointer-based path when safe, otherwise fall back +# to the stage-2 module (then the torch reference). +# - If you want to force stage-2 even when the pointer path is available (for +# experimentation / A-B testing), set this env var **before** importing this +# module. 
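For A/B testing the two forward paths, the override defined on the next line is meant to be set in the environment before this module is first imported. A hypothetical session illustrating that ordering (assumes an SM100 GPU with the CuTeDSL stack installed, and the `weight`/`eps` keyword names used elsewhere in this file):

```python
import os

# The flag is read once, at import time, so it must be set before the import.
os.environ["KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2"] = "1"

import torch
from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.ones(4096, device="cuda", dtype=torch.bfloat16)
y, rstd, residual_out = oink_rmsnorm.rmsnorm_forward(x, weight=w, eps=1e-6)
```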
+_FORCE_RMSNORM_STAGE2_FWD = _env_flag("KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2", default=False) + # CuTeDSL stability probe for the experimental cluster_n>1 + direct-GMEM schedule. # # Some CuTeDSL builds segfault during JIT compilation when combining: @@ -918,7 +927,13 @@ def _get_fast_ptr_fused_add_rmsnorm_launcher( # NOTE: Avoid `from . import ...` imports here: CuTeDSL's AST preprocessor may # mishandle that form (module=None in the AST). Use fully-qualified imports. from kernelagent_oink.blackwell import lite_quack as qutils # noqa: E402 -from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce # noqa: E402 +from kernelagent_oink.blackwell.lite_quack import ( # noqa: E402 + TORCH2CUTE_DTYPE, + RMSNormBackward as BaseRMSNormBackward, + convert_from_dlpack as convert_from_dlpack_cute, + get_sm_count, + row_reduce, +) # ------------------------- @@ -2720,52 +2735,57 @@ def rmsnorm_forward( assert x.dim() == 2, "Use (M, N) tensor; flatten batch/seq beforehand." M, N = x.shape - # For DSv3 big-M outliers on SM100, keep using the dedicated - # stage-2 K-loop implementation, which is already tuned and - # parity-checked against the reference. - use_stage2_big_dsv3 = bool( - M >= 65536 and N in (6144, 8192) and x.dtype in (torch.float16, torch.bfloat16) - ) - if use_stage2_big_dsv3: - try: - import rmsnorm_with_stage2 as rms2 # type: ignore[import-not-found] - except Exception: - rms2 = None # type: ignore[assignment] - if rms2 is not None: - y, rstd, residual_out = rms2.rmsnorm_forward_with_stage2( - x, - weight=weight, - bias=bias, - residual=residual, - eps=eps, - store_rstd=store_rstd, - ) - # Preserve stride contracts for torch.compile consistency, even - # when using the optional stage-2 implementation. - if y.stride() != x.stride(): - y_strided = torch.empty_strided( - x.shape, x.stride(), device=x.device, dtype=x.dtype - ) - y_strided.copy_(y) - y = y_strided - if residual is not None and residual_out is not None: - if residual_out.stride() != residual.stride(): - residual_out_strided = torch.empty_strided( - residual.shape, - residual.stride(), - device=residual.device, - dtype=residual.dtype, - ) - residual_out_strided.copy_(residual_out) - residual_out = residual_out_strided - return y, rstd, residual_out - - # Default: use the pointer-based entry whenever we can represent the - # inputs as a row-major [M, N] view with stride(1) == 1. For rare layouts - # we can't safely express without DLPack, fall back to a torch reference. - if _can_use_ptr_path(x, weight, bias, residual): + # Fast path: use the pointer-based entry whenever we can represent the + # inputs as a row-major [M, N] view with stride(1) == 1 and dtype contracts + # are satisfied (vLLM uses this in inference). + # + # When the pointer path can't be used (e.g. float32 weights for Quack-style + # APIs, or non-standard layouts), fall back to the CuTeDSL stage-2 module + # before using the slow torch reference implementation. + force_stage2 = _FORCE_RMSNORM_STAGE2_FWD + + if not force_stage2 and _can_use_ptr_path(x, weight, bias, residual): return _rmsnorm_forward_ptr(x, weight, bias, residual, eps, store_rstd) + # CuTeDSL fallback for cases that aren't safe for the pointer path. + # Import lazily to keep vLLM plugin startup and common inference fast paths + # lightweight. 
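The stride contract that both the fast path and the fallbacks preserve (outputs keep the input's padded-row strides) can be shown in isolation. A minimal CPU-only sketch with toy sizes of the `empty_strided` + `copy_` pattern used throughout this file; the RMSNorm math here is only a stand-in reference:

```python
import torch

M, N, padded = 4, 96, 128                      # toy sizes
buf = torch.randn(M, padded, dtype=torch.bfloat16)
x = buf[:, :N]                                 # padded-row view: strides (128, 1)
ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)).to(x.dtype)

y = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype)
y.copy_(ref)                                   # materialize with x's exact strides
assert y.stride() == x.stride() == (128, 1)
```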
+ try: + import importlib + + rms2 = importlib.import_module( + ".rmsnorm_with_stage2", + package=__package__ or "kernelagent_oink.blackwell", + ) + except Exception: + rms2 = None # type: ignore[assignment] + if rms2 is not None: + y, rstd, residual_out = rms2.rmsnorm_forward_with_stage2( + x, + weight=weight, + bias=bias, + residual=residual, + eps=eps, + store_rstd=store_rstd, + ) + # Preserve stride contracts for torch.compile consistency, even + # when using the optional stage-2 implementation. + if y.stride() != x.stride(): + y_strided = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + y_strided.copy_(y) + y = y_strided + if residual is not None and residual_out is not None: + if residual_out.stride() != residual.stride(): + residual_out_strided = torch.empty_strided( + residual.shape, + residual.stride(), + device=residual.device, + dtype=residual.dtype, + ) + residual_out_strided.copy_(residual_out) + residual_out = residual_out_strided + return y, rstd, residual_out + # Safe fallback (correctness-first). This is expected to be rare in vLLM. y = rmsnorm_ref(x, weight, bias, residual, eps) # Preserve the input stride contract even on the fallback path so @@ -2910,6 +2930,363 @@ def fused_add_rmsnorm_inplace_( return None +# ------------------------- +# Backward kernel (SM100) +# ------------------------- + + +class RMSNormBackwardSM100(BaseRMSNormBackward): + """SM100-tuned RMSNorm backward. + + This is a thin wrapper around the generic `lite_quack.RMSNormBackward` + base implementation, with SM100-friendly tiling heuristics that mirror + the forward policy used by Oink. + """ + + def __init__(self, dtype: cutlass.Numeric, N: int): + super().__init__(dtype, N) + + def _get_num_threads(self) -> int: + # Keep 128 threads only up to N=4k; use 256 for larger rows to ensure + # threads_per_row <= num_threads across buckets. + try: + return self._nt_override # type: ignore[attr-defined] + except Exception: + return 128 if self.N <= 4096 else 256 + + def _calculate_threads_per_row(self) -> int: + # Mirror RMSNormSM100 forward's tiling. + N = self.N + if N <= 64: + return 8 + if N <= 128: + return 16 + if N <= 1024: + return 32 + if N <= 4096: + return 128 + if N <= 8192: + try: + return self._tpr_override # type: ignore[attr-defined] + except Exception: + return 128 + if N <= 16384: + return 256 + return 256 + + def _set_cluster_n(self) -> None: + # Reuse the SM100 forward cluster growth policy so large-N shapes can + # fan out across multiple CTAs in the same row. + try: + self.cluster_n = self._cluster_n_override # type: ignore[attr-defined] + return + except Exception: + pass + + N = self.N + if N <= 8192: + cluster_n = 1 + elif self.dtype.width == 16: + if N <= 16 * 1024: + cluster_n = 2 + elif N <= 32 * 1024: + cluster_n = 2 + elif N <= 64 * 1024: + cluster_n = 4 + elif N <= 128 * 1024: + cluster_n = 8 + else: + cluster_n = 16 + else: + if N <= 32 * 1024: + cluster_n = 1 + elif N <= 64 * 1024: + cluster_n = 2 + elif N <= 128 * 1024: + cluster_n = 4 + elif N <= 256 * 1024: + cluster_n = 8 + else: + cluster_n = 16 + self.cluster_n = cluster_n + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mdO: cute.Tensor, + mdResO: Optional[cute.Tensor], + mRstd: cute.Tensor, + mdX: cute.Tensor, + mdW: Optional[cute.Tensor], + mdRes: Optional[cute.Tensor], + mdB: Optional[cute.Tensor], + sm_count: Int32, + stream: cuda.CUstream, + ): + # Match forward's 32B alignment on the leading dimension to unlock + # wider vectorization when legal. 
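A worked example of the divisibility hint that `new_stride` below passes to `cute.assume`: the row stride is hinted to the compiler to be a multiple of one 32-byte (256-bit) unit, expressed in elements of the tensor's dtype.

```python
# 256 bits = 32 bytes, matching the "32B alignment" comment above.
for width_bits in (16, 32):        # bf16/fp16 vs fp32 element widths
    divby = 256 // width_bits      # elements per 32-byte unit
    print(width_bits, divby)       # 16 -> 16, 32 -> 8
```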
+ semistatic_shape = (*mX.shape[:-1], self.N) + + def new_stride(t): + return ( + cute.assume(t.stride[0], divby=256 // t.element_type.width), + t.stride[1], + ) + + mX, mdO, mdResO, mdX, mdRes = [ + cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))) + if const_expr(t is not None) + else None + for t in (mX, mdO, mdResO, mdX, mdRes) + ] + + self._set_cluster_n() + largest_dtype_width = const_expr( + max( + mX.element_type.width, + mdO.element_type.width, + mdX.element_type.width, + mdResO.element_type.width if mdResO is not None else 0, + mdRes.element_type.width if mdRes is not None else 0, + ) + ) + tiler_mn, tv_layout = self._get_tv_layout( + num_copy_bits=128 // largest_dtype_width * mX.element_type.width + ) + num_threads = ( + cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + if const_expr(mW is not None): + mW_expanded_layout = cute.prepend( + mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,)) + ) + mW = cute.make_tensor(mW.iterator, mW_expanded_layout) + + num_blocks = sm_count + kernel = ( + self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes) + ) + kernel.launch( + grid=[num_blocks, self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type), + stream=stream, + ) + + +_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +def _rmsnorm_bwd_sm100( + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, + dx: Tensor, + dw_partial: Optional[Tensor], + db_partial: Optional[Tensor] = None, + dresidual_out: Optional[Tensor] = None, + dresidual: Optional[Tensor] = None, + sm_count: Optional[int] = None, +) -> None: + """SM100-specific RMSNorm backward dispatch. + + Mirrors Quack's `quack.rmsnorm._rmsnorm_bwd`, but instantiates + `RMSNormBackwardSM100` (SM100-tuned heuristics). + """ + assert x.dim() == 2, "Input must be 2D" + assert x.is_cuda, "Input tensor must be on CUDA device" + assert x.dtype in (torch.float16, torch.bfloat16, torch.float32) + + if weight is not None: + assert weight.dim() == 1 + assert x.shape[-1] == weight.shape[0] + assert weight.is_cuda + assert weight.dtype in (torch.float32, torch.bfloat16, torch.float16) + if dresidual_out is not None: + assert dresidual_out.shape == x.shape + assert dresidual_out.is_cuda + assert dresidual_out.dtype in (torch.float16, torch.bfloat16, torch.float32) + if dresidual is not None: + assert dresidual.shape == x.shape + assert dresidual.is_cuda + assert dresidual.dtype in (torch.float16, torch.bfloat16, torch.float32) + + M, N = x.size(0), x.size(1) + device = x.device + if dw_partial is None and db_partial is None: + assert sm_count is not None + else: + sm_count = ( + dw_partial.shape[0] if dw_partial is not None else db_partial.shape[0] + ) + + # Match Quack's conversion strategy for activations/gradients: keep the + # (M, N) layout dynamic without enforcing additional compact-shape + # constraints. This reduces per-call Python overhead for small-M shapes. 
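Why `dw_partial`/`db_partial` are sized `[sm_count, N]`: each persistent CTA accumulates the weight-gradient contribution of the rows it visits into its own partial row, and the host reduces over the leading dimension afterwards (see `rmsnorm_backward` further down). A toy, single-threaded sketch of that two-stage reduction; the row assignment is only illustrative:

```python
import torch

sm_count, M, N = 4, 64, 8                  # toy sizes
x_hat = torch.randn(M, N)
dout = torch.randn(M, N)
dw_partial = torch.zeros(sm_count, N)
for cta in range(sm_count):
    for r in range(cta, M, sm_count):      # illustrative grid-stride row assignment
        dw_partial[cta] += dout[r] * x_hat[r]
dw = dw_partial.sum(dim=0)                 # host-side reduction, as in rmsnorm_backward
assert torch.allclose(dw, (dout * x_hat).sum(dim=0), atol=1e-4)
```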
+ convert_from_dlpack = lambda t: from_dlpack( # type: ignore[assignment] + t.detach(), + assumed_align=16, + ).mark_layout_dynamic(leading_dim=1) + x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [ + convert_from_dlpack(t) if t is not None else None + for t in (x, dout, dresidual_out, dx, dresidual) + ] + + if weight is not None: + weight_dtype = TORCH2CUTE_DTYPE[weight.dtype] + weight_tensor = convert_from_dlpack_cute( + weight.detach(), + leading_dim=0, + divisibility=128 // weight_dtype.width, + ) + else: + weight_tensor = None + + dw_partial_tensor = ( + from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) + if dw_partial is not None + else None + ) + db_partial_tensor = ( + from_dlpack(db_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) + if db_partial is not None + else None + ) + rstd_tensor = ( + from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + ) + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compile_key = ( + M, + N, + x_tensor.element_type, + weight_tensor.element_type if weight is not None else None, + db_partial.dtype if db_partial is not None else None, + dresidual.dtype if dresidual is not None else None, + dresidual_out.dtype if dresidual_out is not None else None, + ) + kernel = _BWD_COMPILE_CACHE.get(compile_key) + if kernel is None: + op = RMSNormBackwardSM100(x_tensor.element_type, N) + + # Shape-specific tuning overrides for DSv3-style N=8192 rows. + if isinstance(op, RMSNormBackwardSM100) and N == 8192: + if M >= 65536: + op._tpr_override = 256 # type: ignore[attr-defined] + op._nt_override = 256 # type: ignore[attr-defined] + elif M >= 16384: + op._tpr_override = 256 # type: ignore[attr-defined] + + kernel = cute.compile( + op, + x_tensor, + weight_tensor, + dout_tensor, + dres_out_tensor, + rstd_tensor, + dx_tensor, + dw_partial_tensor, + dres_tensor, + db_partial_tensor, + Int32(sm_count if sm_count is not None else 0), + current_stream, + ) + _BWD_COMPILE_CACHE[compile_key] = kernel + + kernel( + x_tensor, + weight_tensor, + dout_tensor, + dres_out_tensor, + rstd_tensor, + dx_tensor, + dw_partial_tensor, + dres_tensor, + db_partial_tensor, + Int32(sm_count if sm_count is not None else 0), + current_stream, + ) + + +def rmsnorm_backward( + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, + dresidual_out: Optional[Tensor] = None, + has_bias: bool = False, + has_residual: bool = False, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + """Public SM100 RMSNorm backward entry point. + + Signature mirrors `quack.rmsnorm.rmsnorm_bwd` for easy comparisons. + """ + device = x.device + M, N = x.size(0), x.size(1) + dx = torch.empty_like(x) + if dresidual_out is not None and dresidual_out.dtype != dx.dtype: + dresidual = torch.empty_like(x, dtype=dresidual_out.dtype) + else: + dresidual = None + + # Shared SM100 tuning policy (used by both RMSNorm and LayerNorm). + sm_count = get_sm_count(N, device, M=M, dtype=x.dtype) + + # Quack-suite smallest case (M=8192, N=4096) is extremely sensitive to + # Python/allocator overhead because the kernel itself is very fast. + # + # The default `lite_quack.get_sm_count` adds a small-M occupancy boost for + # N=4096, which increases `dw_partial` size and can amplify allocator + # pressure in benchmark/verify loops. Clamp to Quack's baseline policy + # (`sm_count = num_sms * 2` for N=4096) for this regime. 
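To make the interplay with `get_sm_count` (imported from `lite_quack`, shown earlier) concrete, here is the arithmetic for the M=8192, N=4096, bf16 case on a hypothetical 148-SM device; the SM count is for illustration only:

```python
num_sms = 148                      # hypothetical SM count; illustrative only
N, M = 4096, 8192
multiple = 16 if N <= 256 else 8 if N <= 1024 else 4 if N <= 2048 else 2 if N <= 4096 else 1
sm_count = num_sms * multiple if N <= 8192 else (num_sms // 2 if N <= 16384 else num_sms * 2)
sm_count = min(sm_count * 2, num_sms * 4)      # small-M, N=4096, 16-bit boost
print(sm_count)                                # 592
# The clamp applied right below caps this back to num_sms * 2 == 296.
```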
+ if N == 4096 and M <= 8192 and x.dtype in (torch.float16, torch.bfloat16): + try: + num_sms = torch.cuda.get_device_properties(device).multi_processor_count + sm_count = min(int(sm_count), int(num_sms) * 2) + except Exception: + pass + + if weight is not None: + dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) + else: + dw_partial = None + db_partial = ( + torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None + ) + + _rmsnorm_bwd_sm100( + x, + weight, + dout, + rstd, + dx, + dw_partial, + db_partial, + dresidual_out, + dresidual, + sm_count, + ) + + dw = dw_partial.sum(dim=0).to(weight.dtype) if weight is not None else None + db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None + if has_residual and dresidual is None: + dresidual = dx + return dx, dw, db, dresidual + + +# Quack-style alias for benchmarks +rmsnorm_bwd = rmsnorm_backward + + if __name__ == "__main__": # Minimal ad-hoc test (functionality only). For performance comparisons, use the benchmark harness. if not torch.cuda.is_available(): diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py new file mode 100644 index 0000000..b53da12 --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py @@ -0,0 +1,805 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +RMSNorm kernel for SM100 (Blackwell) in CuteDSL, with the experimental +stage-2 cp.async ping-pong path preserved for N≈6k/8k. + +This file is a fork of rmsnorm.py that keeps the K-loop cp.async path +behind `self.stage > 1` while the main implementation has been simplified +to a single-stage schedule. +""" + +from __future__ import annotations + +import importlib.metadata +import re +from typing import Optional, Tuple + +import torch +from torch import Tensor + +import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, const_expr +from cutlass.cute.runtime import from_dlpack + +from kernelagent_oink.blackwell import lite_quack as qutils +from kernelagent_oink.blackwell.lite_quack import TORCH2CUTE_DTYPE, row_reduce + +_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +def _parse_version_tuple(version: str) -> tuple[int, int, int]: + parts = version.split(".") + nums: list[int] = [] + for part in parts[:3]: + match = re.match(r"^(\d+)", part) + nums.append(int(match.group(1)) if match is not None else 0) + while len(nums) < 3: + nums.append(0) + return nums[0], nums[1], nums[2] + + +def _cutlass_dsl_version() -> Optional[tuple[int, int, int]]: + try: + return _parse_version_tuple(importlib.metadata.version("nvidia-cutlass-dsl")) + except Exception: + return None + + +_CUTLASS_DSL_VERSION = _cutlass_dsl_version() +# CuTeDSL 4.3.4 tightened some kernel argument expectations (notably around +# passing Layout/Shape/Constexpr objects into @cute.kernel functions). 
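For readers comparing against the kernel, a minimal plain-PyTorch sketch of the math that `rmsnorm_backward` above computes for the no-bias, no-residual case, with fp32 accumulation as in the kernel; `rmsnorm_backward_ref` is a name introduced here purely for illustration:

```python
import torch

def rmsnorm_backward_ref(x, w, dout, rstd):
    # rstd is the forward's 1/sqrt(mean(x^2, dim=-1) + eps), one value per row.
    x32, d32, w32 = x.float(), dout.float(), w.float()
    x_hat = x32 * rstd[:, None]
    wdy = d32 * w32
    mean_xhat_wdy = (x_hat * wdy).mean(dim=-1, keepdim=True)
    dx = (wdy - x_hat * mean_xhat_wdy) * rstd[:, None]
    dw = (d32 * x_hat).sum(dim=0)              # what the dw_partial rows accumulate
    return dx.to(x.dtype), dw.to(w.dtype)
```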
Keep the +# older signature for <4.3.4, but switch to a 4.3.4+ compatible signature when +# we detect 4.3.4+ (or when version detection is unavailable). +_KERNEL_ACCEPTS_LAYOUT_ARGS = ( + _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4) +) + + +@cute.jit +def get_copy_atom_bw( + dtype: type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False +) -> cute.CopyAtom: + max_bits = const_expr(128 if is_async else 256) + num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width)) + from cutlass.cute.nvgpu import cpasync + + copy_op = ( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL) + if is_async + else cute.nvgpu.CopyUniversalOp() + ) + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@cute.jit +def copy_tiled( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + num_copy_elems: int = 1, + is_async: bool = False, +) -> None: + atom = get_copy_atom_bw(src.element_type, num_copy_elems, is_async) + cute.copy(atom, src, dst, pred=pred) + + +class RMSNormSM100WithStage2: + def __init__(self, N: int, dtype: type[cutlass.Numeric], stage: Optional[int] = None): + self.N = N + self.dtype = dtype + self.stage = 1 if stage is None else stage + self.reduction_dtype = cutlass.Float32 + + def _threads_per_row(self) -> int: + N = self.N + if N <= 64: + return 8 + elif N <= 128: + return 16 + elif N <= 1024: + return 32 + elif N <= 4096: + return 128 + elif N <= 8192: + try: + return self._tpr_override # type: ignore[attr-defined] + except Exception: + return 128 + elif N <= 16384: + return 256 + else: + return 256 + + def _cluster_n(self) -> int: + N = self.N + if N <= 8192: + return 1 + if const_expr(self.dtype.width == 16): + if N <= 16 * 1024: + return 2 + elif N <= 32 * 1024: + return 2 + elif N <= 64 * 1024: + return 4 + elif N <= 128 * 1024: + return 8 + else: + return 16 + else: + if N <= 32 * 1024: + return 1 + elif N <= 64 * 1024: + return 2 + elif N <= 128 * 1024: + return 4 + elif N <= 256 * 1024: + return 8 + else: + return 16 + + def _num_threads(self) -> int: + try: + return self._nt_override # type: ignore[attr-defined] + except Exception: + return 128 if self.N <= 16384 else 256 + + def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout]: + vecsize = num_copy_bits // self.dtype.width + num_threads = self._num_threads() + assert num_threads % cute.arch.WARP_SIZE == 0 + tpr = self._threads_per_row() + cluster_n = self._cluster_n() + num_cols_vec = cute.ceil_div(self.N, vecsize) + num_blocks_N = cute.ceil_div(num_cols_vec, tpr * cluster_n) + cols_per_block = num_threads // tpr + tiler_mn = (cols_per_block, vecsize * num_blocks_N * tpr) + tv_layout = cute.make_layout( + ((tpr, cols_per_block), (vecsize, num_blocks_N)), + stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * tpr)), + ) + return tiler_mn, tv_layout + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + stream: cuda.CUstream, + eps: Float32 = 1e-6, + ): + semistatic_shape = (*mX.shape[:-1], self.N) + + def new_stride(t): + return ( + cute.assume(t.stride[0], divby=256 // t.element_type.width), + t.stride[1], + ) + + mX, mRes, mO, mResO = [ + cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))) + if const_expr(t is not None) + else None + for t in (mX, mRes, 
mO, mResO) + ] + assert mX.element_type == self.dtype + assert mO.element_type == self.dtype + + copy_bits = const_expr(128) + tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits) + num_threads = ( + cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + threads_per_row = ( + tv_layout.shape[0][0] if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._threads_per_row() + ) + warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1) + cluster_n = self._cluster_n() + + if const_expr(mW is not None): + mW = cute.make_tensor( + mW.iterator, cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))) + ) + if const_expr(mB is not None): + mB = cute.make_tensor( + mB.iterator, cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))) + ) + if const_expr(mRstd is not None): + mRstd = cute.make_tensor( + mRstd.iterator, cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))) + ) + + stage_bufs = 2 if self.stage > 1 else 1 + tile_bytes_x = cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs + tile_bytes_res = ( + cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) * stage_bufs + if const_expr(mRes is not None) + else 0 + ) + red_bytes = self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8) + mbar_bytes = self.stage * (cutlass.Int64.width // 8) + smem_bytes = tile_bytes_x + tile_bytes_res + red_bytes + mbar_bytes + + kernel = ( + self.kernel( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + const_expr(num_warps), + const_expr(warps_per_row), + const_expr(threads_per_row), + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + ) + ) + + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), cluster_n, 1], + block=[num_threads, 1, 1], + cluster=([1, cluster_n, 1] if const_expr(cluster_n > 1) else None), + smem=smem_bytes, + stream=stream, + ) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + num_warps: cutlass.Constexpr[int], + warps_per_row: cutlass.Constexpr[int], + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + cluster_n = self._cluster_n() + cluster_y = const_expr(0) if const_expr(cluster_n == 1) else cute.arch.block_idx()[1] + + smem = cutlass.utils.SmemAllocator() + sX0 = smem.allocate_tensor( + mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32 + ) + sX1 = ( + smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + if const_expr(self.stage > 1) + else None + ) + sRes0 = ( + smem.allocate_tensor( + mRes.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32 + ) + if const_expr(mRes is not None) + else None + ) + sRes1 = ( + smem.allocate_tensor( + mRes.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, + ) + if const_expr(mRes is not None and self.stage > 1) + else None + ) + + reduction_buffer, mbar_ptr = self._alloc_reduction_and_mbar(smem, num_warps, warps_per_row) + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + 
num_copy_elems_X = tv_layout.shape[1][0] + use_async = const_expr(self.N >= 1024) + copy_atom = get_copy_atom_bw(mX.element_type, num_copy_elems_X, is_async=use_async) + thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx) + + gW, gB = [ + cute.local_tile(t, tiler_mn, (0, cluster_y)) if const_expr(t is not None) else None + for t in (mW, mB) + ] + tXgW = thr_copy.partition_S(gW) if const_expr(mW is not None) else None + tXgB = thr_copy.partition_S(gB) if const_expr(mB is not None) else None + tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None + tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None + if const_expr(mW is not None): + cute.copy(get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False), tXgW, tXrW) + if const_expr(mB is not None): + cute.copy(get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False), tXgB, tXrB) + + self._init_cluster(tidx, mbar_ptr) + + mX_i, mRes_i, mO_i, mResO_i = [ + qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) if t is not None else None + for t in (mX, mRes, mO, mResO) + ] + gX_i = cute.local_tile(mX_i, tiler_mn, (0, cluster_y)) + gO_i = cute.local_tile(mO_i, tiler_mn, (0, cluster_y)) + gRes_i = ( + cute.local_tile(mRes_i, tiler_mn, (0, cluster_y)) if const_expr(mRes is not None) else None + ) + gResO_i = ( + cute.local_tile(mResO_i, tiler_mn, (0, cluster_y)) if const_expr(mResO is not None) else None + ) + gRstd_i = ( + cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y)) if const_expr(mRstd is not None) else None + ) + cX_i = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) + + tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None] + row_i = tXcX_i[0][0] + tXgRstd_i = thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None + + # Intra-row K-loop cp.async ping-pong (two-pass) for N≈6k/8k (stage=2) + if const_expr(self.stage > 1 and (shape[1] == 6144 or shape[1] == 8192)): + vecsize = tv_layout.shape[1][0] + tpr = threads_per_row + target_tile_n = const_expr(4096 if shape[1] == 6144 else 8192) + tile_factor = const_expr(target_tile_n // (vecsize * tpr)) + tile_n = vecsize * tpr * tile_factor + num_tiles = cute.ceil_div(shape[1], tile_n) + + tiler_mn_tile = (tiler_mn[0], tile_n) + sX0_tile = cute.local_tile(sX0, tiler_mn_tile, (0, 0)) + sX1_tile = cute.local_tile(sX1, tiler_mn_tile, (0, 0)) if const_expr(self.stage > 1) else None + sRes0_tile = ( + cute.local_tile(sRes0, tiler_mn_tile, (0, 0)) if const_expr(mRes is not None) else None + ) + sRes1_tile = ( + cute.local_tile(sRes1, tiler_mn_tile, (0, 0)) if const_expr(mRes is not None and self.stage > 1) else None + ) + + tv_layout_tile = cute.make_layout( + ((tpr, tiler_mn[0]), (vecsize, tile_factor)), + stride=((vecsize * tiler_mn[0], 1), (tiler_mn[0], tiler_mn[0] * vecsize * tpr)), + ) + thr_copy_tile = cute.make_tiled_copy(copy_atom, tv_layout_tile, tiler_mn_tile).get_slice(tidx) + + sum_sq_acc = cute.Float32(0.0) + k_off0 = const_expr(0) * tile_n + gX_0 = cute.local_tile(qutils.domain_offset_i64((0, k_off0), mX_i), tiler_mn_tile, (0, cluster_y)) + tXgX_0 = thr_copy_tile.partition_S(gX_0) + tXsX_0 = thr_copy_tile.partition_D(sX0_tile) + cX_0 = cute.local_tile(cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, cluster_y)) + tXc_0 = thr_copy_tile.partition_S(cX_0) + tXp_0 = qutils.predicate_k(tXc_0, limit=shape[1]) + tXp_ping = tXp_0 + tXp_pong = tXp_0 + if row_i < shape[0]: + copy_tiled(tXgX_0, tXsX_0, num_copy_elems=vecsize, is_async=use_async, pred=tXp_0) + if 
const_expr(mRes is not None): + gRes_0 = cute.local_tile(qutils.domain_offset_i64((0, k_off0), mRes_i), tiler_mn_tile, (0, cluster_y)) + tXgRes_0 = thr_copy_tile.partition_S(gRes_0) + tXsRes_0 = thr_copy_tile.partition_D(sRes0_tile) + copy_tiled(tXgRes_0, tXsRes_0, num_copy_elems=vecsize, is_async=use_async, pred=tXp_0) + if const_expr(use_async): + cute.arch.cp_async_commit_group() + + for t in cutlass.range_constexpr(num_tiles): + next_t = t + 1 + if next_t < num_tiles: + k_off_n = next_t * tile_n + gX_n = cute.local_tile(qutils.domain_offset_i64((0, k_off_n), mX_i), tiler_mn_tile, (0, cluster_y)) + tXgX_n = thr_copy_tile.partition_S(gX_n) + cX_n = cute.local_tile(cute.domain_offset((0, k_off_n), cX_i), tiler_mn_tile, (0, cluster_y)) + tXc_n = thr_copy_tile.partition_S(cX_n) + tXp_n = qutils.predicate_k(tXc_n, limit=shape[1]) + if const_expr((t % 2) == 0): + tXsX_n = thr_copy_tile.partition_D(sX1_tile) + tXsRes_n = ( + thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + ) + tXp_pong = tXp_n + else: + tXsX_n = thr_copy_tile.partition_D(sX0_tile) + tXsRes_n = ( + thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + ) + tXp_ping = tXp_n + if row_i < shape[0]: + copy_tiled(tXgX_n, tXsX_n, num_copy_elems=vecsize, is_async=use_async, pred=tXp_n) + if const_expr(mRes is not None): + gRes_n = cute.local_tile(qutils.domain_offset_i64((0, k_off_n), mRes_i), tiler_mn_tile, (0, cluster_y)) + tXgRes_n = thr_copy_tile.partition_S(gRes_n) + copy_tiled(tXgRes_n, tXsRes_n, num_copy_elems=vecsize, is_async=use_async, pred=tXp_n) + if const_expr(use_async): + cute.arch.cp_async_commit_group() + if const_expr(use_async): + cute.arch.cp_async_wait_group(1 if next_t < num_tiles else 0) + + if const_expr((t % 2) == 0): + tXsX_cur = thr_copy_tile.partition_D(sX0_tile) + tXsRes_cur = thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + pred_cur = tXp_ping + else: + tXsX_cur = thr_copy_tile.partition_D(sX1_tile) + tXsRes_cur = thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + pred_cur = tXp_pong + qutils.fill_oob(tXsX_cur, pred_cur, mX.element_type.zero) + if const_expr(mRes is not None): + qutils.fill_oob(tXsRes_cur, pred_cur, mRes.element_type.zero) + + k_off = t * tile_n + gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, cluster_y)) + tXgX_t = thr_copy_tile.partition_S(gX_t) + tXrX = cute.make_fragment_like(tXgX_t) + cute.autovec_copy(tXsX_cur, tXrX) + x = tXrX.load().to(cute.Float32) + if const_expr(mRes is not None): + gRes_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, cluster_y)) + tXgRes_t = thr_copy_tile.partition_S(gRes_t) + tXrRes = cute.make_fragment_like(tXgRes_t) + cute.autovec_copy(tXsRes_cur, tXrRes) + x += tXrRes.load().to(cute.Float32) + + if const_expr(mResO is not None): + gResO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mResO_i), tiler_mn_tile, (0, cluster_y)) + tXgResO_t = thr_copy_tile.partition_D(gResO_t) + tXrResO = cute.make_fragment_like(tXgResO_t) + tXrResO.store(x.to(tXrResO.element_type)) + if row_i < shape[0]: + copy_tiled(tXrResO, tXgResO_t, num_copy_elems=vecsize, is_async=False, pred=pred_cur) + + sum_sq_tile = row_reduce( + x * x, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + init_val=0.0, + hook_fn=(cute.arch.cluster_wait if const_expr(cluster_n > 1) else None), + ) + sum_sq_acc = sum_sq_acc + sum_sq_tile + + rstd = 
cute.math.rsqrt(sum_sq_acc / shape[1] + eps, fastmath=True) + if const_expr(mRstd is not None): + if ( + tXcX_i[0][1] == 0 + and row_i < shape[0] + and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + tXgRstd_i[0] = rstd + + for t in cutlass.range_constexpr(num_tiles): + k_off = t * tile_n + cX_t = cute.local_tile(cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, cluster_y)) + tXc_t = thr_copy_tile.partition_S(cX_t) + tXp_t = qutils.predicate_k(tXc_t, limit=shape[1]) + + if const_expr((t % 2) == 0): + tXsX_cur = thr_copy_tile.partition_D(sX0_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + ) + else: + tXsX_cur = thr_copy_tile.partition_D(sX1_tile) + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + ) + + qutils.fill_oob(tXsX_cur, tXp_t, mX.element_type.zero) + if const_expr(mRes is not None): + qutils.fill_oob(tXsRes_cur, tXp_t, mRes.element_type.zero) + + gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, cluster_y)) + tXgX_t = thr_copy_tile.partition_S(gX_t) + tXrX = cute.make_fragment_like(tXgX_t) + cute.autovec_copy(tXsX_cur, tXrX) + x = tXrX.load().to(cute.Float32) + if const_expr(mRes is not None): + gRes_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, cluster_y)) + tXgRes_t = thr_copy_tile.partition_S(gRes_t) + tXrRes = cute.make_fragment_like(tXgRes_t) + cute.autovec_copy(tXsRes_cur, tXrRes) + x += tXrRes.load().to(cute.Float32) + + y = x * rstd + if const_expr(mW is not None): + gW_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mW), tiler_mn_tile, (0, cluster_y)) + tWgW_t = thr_copy_tile.partition_S(gW_t) + tWrW_t = cute.make_fragment_like(tWgW_t) + copy_tiled(tWgW_t, tWrW_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + y = y * tWrW_t.load().to(cute.Float32) + if const_expr(mB is not None): + gB_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mB), tiler_mn_tile, (0, cluster_y)) + tWgB_t = thr_copy_tile.partition_S(gB_t) + tWrB_t = cute.make_fragment_like(tWgB_t) + copy_tiled(tWgB_t, tWrB_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + y = y + tWrB_t.load().to(cute.Float32) + + gO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mO_i), tiler_mn_tile, (0, cluster_y)) + tXgO_t = thr_copy_tile.partition_D(gO_t) + tXrO = cute.make_fragment_like(tXgO_t) + tXrO.store(y.to(tXrO.element_type)) + if row_i < shape[0]: + copy_tiled(tXrO, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + + return + + # Fallback: single-stage path identical to current rmsnorm.py + tXgX_i = thr_copy.partition_S(gX_i) + tXgRes_i = thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None + tXgO_i = thr_copy.partition_D(gO_i) + tXgResO_i = thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None + is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n) + tXpX_i = ( + qutils.predicate_k(thr_copy.partition_S(cX_i), limit=shape[1]) if not is_even_N_i else None + ) + + if row_i < shape[0]: + cute.copy(copy_atom, tXgX_i, thr_copy.partition_D(sX0), pred=tXpX_i) + if const_expr(mRes is not None): + cute.copy(copy_atom, tXgRes_i, thr_copy.partition_D(sRes0), pred=tXpX_i) + if const_expr(use_async): + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + tXrX = cute.make_fragment_like(tXgX_i) + cute.autovec_copy(thr_copy.partition_D(sX0), tXrX) + x = tXrX.load().to(cute.Float32) + if const_expr(mRes is not None): + 
tXrRes = cute.make_fragment_like(tXgRes_i) + cute.autovec_copy(thr_copy.partition_D(sRes0), tXrRes) + x += tXrRes.load().to(cute.Float32) + + if const_expr(mResO is not None): + tXrResO = cute.make_fragment_like(tXgResO_i) + tXrResO.store(x.to(tXrResO.element_type)) + if row_i < shape[0]: + cute.copy( + get_copy_atom_bw(tXrResO.element_type, num_copy_elems_X, is_async=False), + tXrResO, + tXgResO_i, + ) + + sum_sq = row_reduce( + x * x, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + init_val=0.0, + hook_fn=(cute.arch.cluster_wait if const_expr(cluster_n > 1) else None), + ) + rstd = cute.math.rsqrt(sum_sq / shape[1] + eps, fastmath=True) + + if const_expr(mRstd is not None): + if ( + tXcX_i[0][1] == 0 + and row_i < shape[0] + and (cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) + ): + tXgRstd_i[0] = rstd + + y = x * rstd + if const_expr(mW is not None): + y = y * tXrW.load().to(cute.Float32) + if const_expr(mB is not None): + y = y + tXrB.load().to(cute.Float32) + + tXrO = cute.make_fragment_like(tXgO_i) + tXrO.store(y.to(tXrO.element_type)) + if row_i < shape[0]: + cute.copy( + get_copy_atom_bw(tXrO.element_type, num_copy_elems_X, is_async=False), + tXrO, + tXgO_i, + ) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + num_warps: cutlass.Constexpr[int], + warps_per_row: cutlass.Constexpr[int], + threads_per_row: cutlass.Constexpr[int], + ): + self._kernel_impl( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + num_warps, + warps_per_row, + threads_per_row, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mW: Optional[cute.Tensor], + mB: Optional[cute.Tensor], + mRes: Optional[cute.Tensor], + mO: cute.Tensor, + mResO: Optional[cute.Tensor], + mRstd: Optional[cute.Tensor], + eps: Float32, + ): + copy_bits = const_expr(128) + tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits) + num_threads = self._num_threads() + num_warps = num_threads // cute.arch.WARP_SIZE + threads_per_row = self._threads_per_row() + warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1) + self._kernel_impl( + mX, + mW, + mB, + mRes, + mO, + mResO, + mRstd, + eps, + tv_layout, + tiler_mn, + const_expr(num_warps), + const_expr(warps_per_row), + const_expr(threads_per_row), + ) + + @cute.jit + def _alloc_reduction_and_mbar( + self, + smem: cutlass.utils.SmemAllocator, + num_warps: cutlass.Constexpr[int], + warps_per_row: cutlass.Constexpr[int], + ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]: + cluster_n = self._cluster_n() + red_layout = cute.make_ordered_layout( + (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), + order=(1, 0, 2), + ) + reduction_buffer = smem.allocate_tensor(self.reduction_dtype, red_layout, byte_alignment=4) + if const_expr(cluster_n > 1): + mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage) + else: + mbar_ptr = None + return reduction_buffer, mbar_ptr + + @cute.jit + def _init_cluster(self, tidx: cutlass.Int32, mbar_ptr: Optional[cute.Pointer]): + if const_expr(mbar_ptr is not None): + if tidx < self.stage: + cute.arch.mbarrier_init(mbar_ptr + tidx, 1) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + + +def 
rmsnorm_forward_with_stage2( + x: Tensor, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + residual: Optional[Tensor] = None, + eps: float = 1e-6, + store_rstd: bool = False, +) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + assert x.is_cuda + assert x.dim() == 2 + M, N = x.shape + dtype = TORCH2CUTE_DTYPE[x.dtype] + + convert_x = lambda t: from_dlpack(t.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1) + mX = convert_x(x) + mRes = convert_x(residual) if residual is not None else None + out = torch.empty_like(x, dtype=x.dtype) + mO = from_dlpack(out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1) + + mW = ( + from_dlpack(weight.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=0) + if weight is not None + else None + ) + mB = ( + from_dlpack(bias.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=0) + if bias is not None + else None + ) + if store_rstd: + rstd = torch.empty(M, device=x.device, dtype=torch.float32) + mRstd = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + else: + rstd = None + mRstd = None + + residual_out = None + mResO = None + if residual is not None: + residual_out = torch.empty_like(residual) + mResO = from_dlpack(residual_out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1) + + # Enable the intra-row cp.async K-loop only for DSv3-style large-N rows + # with very large M, where there is enough work per row to amortize the + # pipeline start-up cost. Mid-size M shapes are better served by the + # simpler single-stage schedule. + use_kloop = bool(M >= 65536 and N in (6144, 8192)) + stage = 2 if use_kloop else 1 + op = RMSNormSM100WithStage2(N, dtype, stage=stage) + if use_kloop: + op._tpr_override = 128 # type: ignore[attr-defined] + # Prefer 1 row/CTA at N=6144; keep 2 rows/CTA at N=8192 to match + # the original tuning there. + op._nt_override = (128 if N == 6144 else 256) # type: ignore[attr-defined] + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + key = ( + N, + dtype, + mRes is not None, + mW is not None, + mB is not None, + mResO is not None, + mRstd is not None, + stage, + ) + compiled = _COMPILE_CACHE.get(key) + if compiled is None: + compiled = cute.compile(op, mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps)) + _COMPILE_CACHE[key] = compiled + compiled(mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps)) + return out, rstd, residual_out diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py new file mode 100644 index 0000000..a2f2581 --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/softmax.py @@ -0,0 +1,749 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Softmax forward + backward kernels for SM100 (Blackwell) in CuteDSL. 
+ +This module implements numerically stable softmax over the last dimension of +2D tensors (M, N) and its backward pass, targeting SM100 with Quack-style +tiling, cp.async pipelines, and cluster reductions, but without depending on +the `quack` package at runtime. + +The kernels are self-contained and use only local helpers in +`kernelagent_oink.blackwell.lite_quack` plus CuTeDSL/CUTLASS. +""" + +from __future__ import annotations + +import importlib.metadata +import math +import os +import re +from typing import Optional, Type + +import torch +from torch import Tensor + +import cuda.bindings.driver as cuda # provided by NVIDIA cuda-python + +# CuTeDSL caches generated MLIR into a tempdir under a global default +# (`/tmp/$USER/cutlass_python_cache`). The cache bytecode format can differ across +# `nvidia-cutlass-dsl` versions, and cross-version cache sharing causes noisy +# warnings (and disables cache reuse). +if "CUTE_DSL_CACHE_DIR" not in os.environ: + try: + _dsl_ver = importlib.metadata.version("nvidia-cutlass-dsl") + except Exception: + _dsl_ver = "unknown" + _dsl_ver = re.sub(r"[^0-9A-Za-z]+", "_", _dsl_ver) + _user = os.environ.get("USER") or os.environ.get("USERNAME") or "user" + _tmp = os.environ.get("TMPDIR") or "/tmp" + os.environ["CUTE_DSL_CACHE_DIR"] = os.path.join( + _tmp, _user, f"cutlass_python_cache_{_dsl_ver}" + ) + +try: + import cutlass # type: ignore # noqa: F401 +except Exception as e: + raise ImportError( + "kernelagent_oink.blackwell.softmax requires CuTeDSL's Python package " + "(`cutlass`, typically provided by `nvidia-cutlass-dsl`)." + ) from e + +import cutlass.cute as cute +from cutlass import Float32, Int32, const_expr +from cutlass.cute import runtime as rt +from cutlass.cute.runtime import from_dlpack + +from kernelagent_oink.blackwell.lite_quack import ( + _KERNEL_ACCEPTS_LAYOUT_ARGS, + TORCH2CUTE_DTYPE, + ReductionBase, + domain_offset_i64, + fill_oob, + online_softmax_reduce, + predicate_k, + row_reduce, +) + +_FWD_COMPILE_CACHE: dict[tuple[Type[cutlass.Numeric], int], object] = {} +_BWD_COMPILE_CACHE: dict[tuple[Type[cutlass.Numeric], int], object] = {} +_PTR_FWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} +_PTR_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +class SoftmaxFwdSM100(ReductionBase): + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + # One-stage online reduction: pack (max, sum_exp) into Int64 reduction buffer. + super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64) + + def _calculate_threads_per_row(self) -> int: + # Match Quack's bucketed policy for Softmax. + N = self.N + if N <= 64: + return 8 + if N <= 128: + return 16 + if N <= 3072: + return 32 + if N <= 6144: + return 64 + if N <= 16384: + return 128 + return 256 + + def _set_cluster_n(self) -> None: + # Quack-style growth of cluster_n with N and dtype. + N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + cluster_n = ( + 1 + if N <= 32 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + @cute.jit + def __call__(self, mX: cute.Tensor, mO: cute.Tensor, stream: cuda.CUstream) -> None: + assert mX.element_type == self.dtype + assert mO.element_type == self.dtype + # Use the generic ReductionBase tiling with 128-bit vectorization. 
+ tiler_mn, tv_layout = self._get_tv_layout() + num_threads = ( + cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + kernel = ( + self.kernel(mX, mO, tv_layout, tiler_mn) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel(mX, mO) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_out: cute.Pointer, + M: Int32, + ld: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions. + + Reconstructs cute.Tensor views from raw pointers + explicit layouts + inside the JIT graph, matching the existing SM100 schedule. + """ + # Mirror Quack/LayerNorm contracts: assume 16B alignment and an LD that + # preserves 128-bit vectorized copies for every row start. + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + mX = cute.make_tensor(ptr_x, layout_mn) + mO = cute.make_tensor(ptr_out, layout_mn) + self.__call__(mX, mO, stream) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mO: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + if const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = const_expr(0) + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + # Slice per-CTA region; use 64-bit indexing for large tensors. + mX, mO = [domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)] + gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)] + cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + + # Copy atoms for gmem <-> smem and smem <-> gmem. + # Use 128-bit cp.async for global->shared and 128-bit vectorized stores. + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mX.element_type, + num_bits_per_copy=128, + ) + copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gO.element_type, + num_bits_per_copy=128, + ) + + thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx) + thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx) + + tXgX = thr_copy_load.partition_S(gX) + tXsX = thr_copy_load.partition_D(sX) + tXgO = thr_copy_store.partition_D(gO) + tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None] + + # Register fragments. + tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)] + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + # Predicate and cp.async pipeline for potential tail tiles. 
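One detail worth calling out for the uneven-N case handled next: out-of-bounds lanes are padded with `-inf` rather than zero (see the `fill_oob` call below), which keeps both the row max and the sum-of-exponentials reductions exact.

```python
import math

# exp(-inf) == 0, so padded columns contribute nothing to the max or the sum.
assert math.exp(float("-inf")) == 0.0
assert max(1.5, float("-inf")) == 1.5
```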
+ is_even_N = const_expr(self.N == tiler_mn[1] * self.cluster_n) + tXpX = ( + predicate_k(thr_copy_load.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_load, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + if const_expr(not is_even_N): + fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + threads_per_row = tv_layout.shape[0][0] + + # Online softmax reduction: compute max and sum_exp in a single pass, with + # optional cluster-wide aggregation via an Int64 reduction buffer. + max_x, denom, exp_x = online_softmax_reduce( + x, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + phase=None, + return_exp_x=True, + ) + + y = exp_x * cute.arch.rcp_approx(denom) + tXrO.store(y.to(tXrO.element_type)) + + tOpO = ( + predicate_k(thr_copy_store.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_store, tXrO, tXgO, pred=tOpO) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mO: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl(mX, mO, tv_layout, tiler_mn) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mO: cute.Tensor, + ) -> None: + tiler_mn, tv_layout = self._get_tv_layout() + self._kernel_impl(mX, mO, tv_layout, tiler_mn) + + +class SoftmaxBwdSM100(ReductionBase): + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + # One stage for dot(dy, y) per row. + super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Float32) + + def _calculate_threads_per_row(self) -> int: + # Match Quack backward softmax buckets. + N = self.N + if N <= 64: + return 8 + if N <= 128: + return 16 + if N <= 3072: + return 32 + if N <= 6144: + return 64 + if N <= 8192: + return 128 + return 256 + + def _set_cluster_n(self) -> None: + N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + cluster_n = ( + 1 + if N <= 32 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + def _get_num_threads(self) -> int: + # Slightly more aggressive threading for large N than the base class. + return 128 if self.N <= 8192 else 256 + + def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int: + # Store both y and dy tiles plus reduction buffers and mbarriers. + return ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 + + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8) + + self.stage * (cutlass.Int64.width // 8) + ) + + @cute.jit + def __call__( + self, + mdY: cute.Tensor, + mY: cute.Tensor, + mdX: cute.Tensor, + stream: cuda.CUstream, + ) -> None: + assert mdY.element_type == self.dtype + assert mY.element_type == self.dtype + assert mdX.element_type == self.dtype + # Use the generic ReductionBase tiling with 128-bit vectorization. 
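For orientation, the row-wise result the forward kernel above produces, written directly in PyTorch; the kernel fuses the max and the sum of exponentials into a single online pass (`online_softmax_reduce`), but the math is the same. `softmax_fwd_ref` is an illustrative name, not part of this module:

```python
import torch

def softmax_fwd_ref(x: torch.Tensor) -> torch.Tensor:
    x32 = x.float()
    m = x32.amax(dim=-1, keepdim=True)            # row max
    e = torch.exp(x32 - m)                        # shifted exponentials
    return (e / e.sum(dim=-1, keepdim=True)).to(x.dtype)
```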
+ tiler_mn, tv_layout = self._get_tv_layout() + num_threads = ( + cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + kernel = ( + self.kernel(mdY, mY, mdX, tv_layout, tiler_mn) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel(mdY, mY, mdX) + ) + kernel.launch( + grid=[cute.ceil_div(mdY.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_dy: cute.Pointer, + ptr_y: cute.Pointer, + ptr_dx: cute.Pointer, + M: Int32, + ld: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions.""" + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + mdY = cute.make_tensor(ptr_dy, layout_mn) + mY = cute.make_tensor(ptr_y, layout_mn) + mdX = cute.make_tensor(ptr_dx, layout_mn) + self.__call__(mdY, mY, mdX, stream) + + @cute.jit + def _kernel_impl( + self, + mdY: cute.Tensor, + mY: cute.Tensor, + mdX: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + if const_expr(self.cluster_n > 1): + cluster_y = cute.arch.block_idx()[1] + else: + cluster_y = const_expr(0) + + shape = mdY.shape + idX = cute.make_identity_tensor(shape) + + mdY, mY, mdX = [ + domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX) + ] + gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)] + cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) + + smem = cutlass.utils.SmemAllocator() + sdY = smem.allocate_tensor( + mdY.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + sY = smem.allocate_tensor( + mY.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mdY.element_type, + num_bits_per_copy=128, + ) + copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gdX.element_type, + num_bits_per_copy=128, + ) + + thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx) + thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx) + + tdYgdY = thr_copy_load.partition_S(gdY) + tdYsdY = thr_copy_load.partition_D(sdY) + tYgY = thr_copy_load.partition_S(gY) + tYsY = thr_copy_load.partition_D(sY) + tdXgdX = thr_copy_store.partition_D(gdX) + tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None] + + tdYrdY, tYrY, tdXrdX = [cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)] + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + is_even_N = const_expr(self.N == tiler_mn[1] * self.cluster_n) + tdYpdY = ( + predicate_k(thr_copy_load.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tdYpdY) + cute.copy(copy_atom_load, tYgY, tYsY, pred=tdYpdY) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + cute.autovec_copy(tdYsdY, tdYrdY) 
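+        # Stage the y tile the same way; the two fragments feed the per-row
+        # dot(dy, y) reduction below, from which dx = y * (dy - dot), i.e. the
+        # softmax Jacobian-vector product.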
+ cute.autovec_copy(tYsY, tYrY) + dy = tdYrdY.load().to(Float32) + y = tYrY.load().to(Float32) + + threads_per_row = tv_layout.shape[0][0] + dot = row_reduce( + dy * y, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + ) + + dx = y * (dy - dot) + tdXrdX.store(dx.to(tdXrdX.element_type)) + + tdXpdX = ( + predicate_k(thr_copy_store.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tdXpdX) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mdY: cute.Tensor, + mY: cute.Tensor, + mdX: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl(mdY, mY, mdX, tv_layout, tiler_mn) + else: + + @cute.kernel + def kernel( + self, + mdY: cute.Tensor, + mY: cute.Tensor, + mdX: cute.Tensor, + ) -> None: + tiler_mn, tv_layout = self._get_tv_layout() + self._kernel_impl(mdY, mY, mdX, tv_layout, tiler_mn) + + +def _convert_2d_tensor(x: Tensor) -> cute.Tensor: + # Match Quack's Softmax conversion exactly: assume 16B alignment and mark + # the shape compact with row-major stride order (0, 1), with mode=0 (batch). + # We intentionally do not call mark_layout_dynamic here to avoid the + # leading_dim stride==1 constraint used in RMSNorm. + return ( + from_dlpack(x.detach(), assumed_align=16) + .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1)) + ) + + +def _can_use_ptr_path_2d(x: Tensor) -> bool: + """Conservative guard for the pointer-based fast path.""" + if not x.is_cuda or x.dim() != 2: + return False + if x.dtype not in TORCH2CUTE_DTYPE: + return False + # Require row-major last-dim contiguous. + if x.stride(1) != 1: + return False + # Require 16B alignment (matches from_dlpack(..., assumed_align=16)). + if (x.data_ptr() % 16) != 0: + return False + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + divby = 128 // dtype_x.width + # Softmax uses ReductionBase default num_copy_bits=128, so N must be divisible. + if (x.shape[1] % divby) != 0: + return False + # Ensure each row start remains aligned for 128-bit vectorized copies. 
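+    # stride(0) is measured in elements, so divisibility by divby
+    # (= 128 bits / dtype width, e.g. 8 for bf16) keeps every row start
+    # 16-byte aligned even for padded-row layouts with stride(0) > N.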
+ if (x.stride(0) % divby) != 0: + return False + return True + + +def _softmax_forward_ptr_into(*, x: Tensor, out: Tensor) -> None: + """Launch the pointer-based Softmax forward kernel into preallocated `out`.""" + assert x.is_cuda and x.dim() == 2 + assert out.is_cuda and out.shape == x.shape and out.dtype == x.dtype + assert out.stride() == x.stride(), "Pointer path expects out to match x strides" + + M, N = x.shape + device_index = x.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + key = ("ptr_fwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_FWD_COMPILE_CACHE.get(key) + if compiled is None: + op = SoftmaxFwdSM100(dtype_x, int(N)) + ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_out = rt.make_ptr( + dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ld = Int32(int(x.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_out, + Int32(int(M)), + ld, + stream, + ) + _PTR_FWD_COMPILE_CACHE[key] = compiled + + ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_out = rt.make_ptr(dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + compiled(ptr_x, ptr_out, Int32(int(M)), Int32(int(x.stride(0))), stream) + + +def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None: + """Launch the pointer-based Softmax backward kernel into preallocated `dx`.""" + assert dy.is_cuda and dy.dim() == 2 + assert y.is_cuda and y.shape == dy.shape and y.dtype == dy.dtype + assert dx.is_cuda and dx.shape == dy.shape and dx.dtype == dy.dtype + assert dy.stride() == y.stride() == dx.stride(), "Pointer path expects matching strides" + + M, N = dy.shape + device_index = dy.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + + dtype_x = TORCH2CUTE_DTYPE[dy.dtype] + key = ("ptr_bwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_BWD_COMPILE_CACHE.get(key) + if compiled is None: + op = SoftmaxBwdSM100(dtype_x, int(N)) + ptr_dy = rt.make_ptr(dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_y = rt.make_ptr(dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ld = Int32(int(dy.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_dy, + ptr_y, + ptr_dx, + Int32(int(M)), + ld, + stream, + ) + _PTR_BWD_COMPILE_CACHE[key] = compiled + + ptr_dy = rt.make_ptr(dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_y = rt.make_ptr(dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + compiled(ptr_dy, ptr_y, ptr_dx, Int32(int(M)), Int32(int(dy.stride(0))), stream) + + +def softmax_forward(x: Tensor) -> Tensor: + """SM100 CuteDSL softmax forward pass: y = softmax(x, dim=-1).""" + assert x.dim() == 2, "Input must be 2D (M, N)" + assert x.is_cuda, "Input must be on CUDA device" + assert x.dtype in TORCH2CUTE_DTYPE, "Unsupported dtype" + + N = x.size(1) + dtype = TORCH2CUTE_DTYPE[x.dtype] + if _can_use_ptr_path_2d(x): + out = 
torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + _softmax_forward_ptr_into(x=x, out=out) + return out + + out = torch.empty_like(x) + + x_tensor = _convert_2d_tensor(x) + out_tensor = _convert_2d_tensor(out) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compile_key = (dtype, N) + kernel = _FWD_COMPILE_CACHE.get(compile_key) + if kernel is None: + op = SoftmaxFwdSM100(dtype, N) + kernel = cute.compile(op, x_tensor, out_tensor, current_stream) + _FWD_COMPILE_CACHE[compile_key] = kernel + kernel(x_tensor, out_tensor, current_stream) + return out + + +def softmax_backward(dy: Tensor, y: Tensor) -> Tensor: + """SM100 CuteDSL softmax backward pass.""" + assert dy.dim() == 2 and y.dim() == 2, "dy and y must be 2D (M, N)" + assert dy.shape == y.shape, "dy and y must have the same shape" + assert dy.is_cuda and y.is_cuda, "dy and y must be on CUDA device" + assert dy.dtype in TORCH2CUTE_DTYPE, "Unsupported dtype" + assert y.dtype == dy.dtype, "dy and y must have the same dtype" + + N = dy.size(1) + dtype = TORCH2CUTE_DTYPE[dy.dtype] + if _can_use_ptr_path_2d(dy) and _can_use_ptr_path_2d(y) and dy.stride() == y.stride(): + dx = torch.empty_strided(dy.shape, dy.stride(), device=dy.device, dtype=dy.dtype) + _softmax_backward_ptr_into(dy=dy, y=y, dx=dx) + return dx + + dx = torch.empty_like(dy) + + dy_tensor = _convert_2d_tensor(dy) + y_tensor = _convert_2d_tensor(y) + dx_tensor = _convert_2d_tensor(dx) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compile_key = (dtype, N) + kernel = _BWD_COMPILE_CACHE.get(compile_key) + if kernel is None: + op = SoftmaxBwdSM100(dtype, N) + kernel = cute.compile(op, dy_tensor, y_tensor, dx_tensor, current_stream) + _BWD_COMPILE_CACHE[compile_key] = kernel + kernel(dy_tensor, y_tensor, dx_tensor, current_stream) + return dx + + +class SoftmaxFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: Tensor) -> Tensor: + y = softmax_forward(x) + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy: Tensor) -> tuple[Tensor]: + (y,) = ctx.saved_tensors + dx = softmax_backward(dy, y) + return dx + + +def softmax(x: Tensor) -> Tensor: + """Autograd-friendly softmax using the SM100 CuteDSL kernel.""" + return SoftmaxFunction.apply(x) + + +def _torch_softmax_reference(x: Tensor) -> Tensor: + return torch.nn.functional.softmax(x, dim=-1) + + +def verify_softmax_parity( + M: int, + N: int, + dtype: torch.dtype = torch.bfloat16, + atol: float = 5e-2, + rtol: float = 5e-2, +) -> None: + """Compare SM100 CuteDSL softmax against PyTorch for a single shape.""" + device = torch.device("cuda") + x = torch.randn(M, N, device=device, dtype=dtype) + x.requires_grad_(True) + + # Forward parity + y_ref = _torch_softmax_reference(x) + y = softmax(x) + torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) + + # Backward parity + dy = torch.randn_like(y) + (dx_ref,) = torch.autograd.grad(y_ref, x, dy, retain_graph=False) + dx = softmax_backward(dy, y) + torch.testing.assert_close(dx, dx_ref, atol=atol, rtol=rtol) From 9b29732bb333a98ddbb5f750c1ff407d9050c5e7 Mon Sep 17 00:00:00 2001 From: Laura Wang <3700467+Laurawly@users.noreply.github.com> Date: Wed, 21 Jan 2026 20:06:57 -0800 Subject: [PATCH 5/8] oink: fix ruff lint --- .../benchmark_cross_entropy_sm100.py | 89 ++++++++++--------- .../benchmark_fused_add_rmsnorm_sm100.py | 36 ++++---- .../benchmark/benchmark_hbm_roofline_sm100.py | 42 +++++---- .../benchmark/benchmark_layernorm_sm100.py | 34 +++---- 
.../benchmark/benchmark_rmsnorm_bwd_sm100.py | 39 ++++---- .../benchmark/benchmark_rmsnorm_sm100.py | 40 +++++---- .../benchmark/benchmark_softmax_sm100.py | 42 ++++++--- .../benchmarks/readme/plot_quack_style_svg.py | 4 +- .../kernelagent_oink/blackwell/layernorm.py | 2 +- .../kernelagent_oink/blackwell/lite_quack.py | 2 +- .../src/kernelagent_oink/blackwell/rmsnorm.py | 13 +-- .../blackwell/rmsnorm_with_stage2.py | 10 ++- .../src/kernelagent_oink/blackwell/softmax.py | 3 +- 13 files changed, 199 insertions(+), 157 deletions(-) diff --git a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py index 8bcac15..18399c7 100644 --- a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py @@ -219,14 +219,16 @@ def bench_single( bytes_io = bytes_io_model_ce(M, N, dtype, target_dtype=target.dtype, mode=mode) if mode == "fwd": - fn_oink = lambda: oink_ce.cross_entropy_forward( - logits, target, ignore_index=int(ignore_index), reduction="none" - ) - fn_quack = ( - None - if quack_ce_fwd is None - else ( - lambda: quack_ce_fwd( + def fn_oink(): + return oink_ce.cross_entropy_forward( + logits, target, ignore_index=int(ignore_index), reduction="none" + ) + + fn_quack = None + if quack_ce_fwd is not None: + + def fn_quack(): + return quack_ce_fwd( logits, target, target_logit=None, @@ -235,8 +237,7 @@ def bench_single( return_dx=False, inplace_backward=False, ) - ) - ) + elif mode == "bwd": with torch.no_grad(): _loss_o, lse_o = oink_ce.cross_entropy_forward( @@ -254,14 +255,17 @@ def bench_single( ) else: lse_q = None - fn_oink = lambda: oink_ce.cross_entropy_backward( - dloss, logits, target, lse_o, ignore_index=int(ignore_index) - ) - fn_quack = ( - None - if (quack_ce_bwd is None or lse_q is None) - else ( - lambda: quack_ce_bwd( + + def fn_oink(): + return oink_ce.cross_entropy_backward( + dloss, logits, target, lse_o, ignore_index=int(ignore_index) + ) + + fn_quack = None + if quack_ce_bwd is not None and lse_q is not None: + + def fn_quack(): + return quack_ce_bwd( logits, target, dloss, @@ -269,37 +273,38 @@ def bench_single( ignore_index=int(ignore_index), inplace_backward=False, ) - ) - ) + elif mode == "fwd_bwd": - fn_oink = lambda: oink_ce.cross_entropy_fwd_bwd( - dloss, - logits, - target, - ignore_index=int(ignore_index), - ) - fn_quack = ( - None - if (quack_ce_fwd is None or quack_ce_bwd is None) - else ( - lambda: quack_ce_bwd( + def fn_oink(): + return oink_ce.cross_entropy_fwd_bwd( + dloss, + logits, + target, + ignore_index=int(ignore_index), + ) + + fn_quack = None + if quack_ce_fwd is not None and quack_ce_bwd is not None: + + def fn_quack(): + _loss_q, lse_q = quack_ce_fwd( + logits, + target, + target_logit=None, + ignore_index=int(ignore_index), + return_lse=True, + return_dx=False, + inplace_backward=False, + ) + return quack_ce_bwd( logits, target, dloss, - quack_ce_fwd( - logits, - target, - target_logit=None, - ignore_index=int(ignore_index), - return_lse=True, - return_dx=False, - inplace_backward=False, - )[1], + lse_q, ignore_index=int(ignore_index), inplace_backward=False, ) - ) - ) + else: raise ValueError(f"Unsupported mode: {mode}") diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py index b75f892..6418e61 100644 --- a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py +++ 
b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py @@ -1,5 +1,3 @@ -from __future__ import annotations - """ Benchmark fused_add_rmsnorm (in-place) on SM100. @@ -21,9 +19,11 @@ --json /tmp/kernelagent_oink_sm100_suite_bf16/fused_add_rmsnorm_dsv3.json """ +from __future__ import annotations + import argparse import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple import torch @@ -165,7 +165,9 @@ def bench_one( bytes_io = bytes_io_model_fused_add_rmsnorm_inplace(M, N, dtype) - fn = lambda: oink_rmsnorm.fused_add_rmsnorm_inplace_(x, residual, w, eps=1e-6) + def fn(): + oink_rmsnorm.fused_add_rmsnorm_inplace_(x, residual, w, eps=1e-6) + ms = do_bench_triton(fn, warmup_ms=warmup_ms, rep_ms=iters_ms) gbps = bytes_io / (ms * 1e-3) / 1e9 @@ -187,18 +189,20 @@ def bench_one( out_q = torch.empty_like(x) res_out_q = torch.empty_like(residual) - fn_q = lambda: quack_rmsnorm_fwd_mut( - x, - w, - out_q, - None, # bias - None, # rstd - None, # mean - residual, - res_out_q, - 1e-6, - False, # is_layernorm - ) + def fn_q(): + quack_rmsnorm_fwd_mut( + x, + w, + out_q, + None, # bias + None, # rstd + None, # mean + residual, + res_out_q, + 1e-6, + False, # is_layernorm + ) + ms_q = do_bench_triton(fn_q, warmup_ms=warmup_ms, rep_ms=iters_ms) gbps_q = bytes_io / (ms_q * 1e-3) / 1e9 row.update( diff --git a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py index 971a03c..8ec4bfd 100644 --- a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py @@ -1,5 +1,3 @@ -from __future__ import annotations - """ HBM roofline microbenchmark for SM100 (GB200 / Blackwell). @@ -17,9 +15,11 @@ CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py --dtype fp16 --op triad --gb 2 """ +from __future__ import annotations + import argparse import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple import torch import triton @@ -95,23 +95,27 @@ def bench_one( grid = (triton.cdiv(n_elements, block),) if op == "copy": - launch = lambda: _copy_kernel[grid]( - x, - y, - n_elements, - BLOCK=block, - num_warps=num_warps, - num_stages=4, - ) + def launch(): + _copy_kernel[grid]( + x, + y, + n_elements, + BLOCK=block, + num_warps=num_warps, + num_stages=4, + ) + elif op == "triad": - launch = lambda: _triad_kernel[grid]( - x, - y, - n_elements, - BLOCK=block, - num_warps=num_warps, - num_stages=4, - ) + def launch(): + _triad_kernel[grid]( + x, + y, + n_elements, + BLOCK=block, + num_warps=num_warps, + num_stages=4, + ) + else: raise ValueError(f"Unsupported op: {op}") diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py index 778e3e2..a9865d1 100644 --- a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py @@ -244,27 +244,31 @@ def bench_single( weight_dtype=w.dtype, ) - fn_oink = lambda: oink_ln.layernorm( - x, - w, - bias=b, - eps=eps, - return_rstd=return_rstd, - return_mean=return_mean, - ) + def fn_oink(): + return oink_ln.layernorm( + x, + w, + bias=b, + eps=eps, + return_rstd=return_rstd, + return_mean=return_mean, + ) + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 if quack_layernorm is None or has_bias: return (ms_oink, gbps_oink), None, stats - 
fn_quack = lambda: quack_layernorm( - x, - w, - eps=eps, - return_rstd=return_rstd, - return_mean=return_mean, - ) + def fn_quack(): + return quack_layernorm( + x, + w, + eps=eps, + return_rstd=return_rstd, + return_mean=return_mean, + ) + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py index 01c390d..4ba1c47 100644 --- a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py @@ -262,15 +262,16 @@ def bench_single( if verify: stats = _verify_parity(x, w, dout, rstd, has_bias=False, has_residual=False) - fn_oink = lambda: oink_rmsnorm.rmsnorm_backward( - x, - w, - dout, - rstd, - dresidual_out=None, - has_bias=False, - has_residual=False, - ) + def fn_oink(): + return oink_rmsnorm.rmsnorm_backward( + x, + w, + dout, + rstd, + dresidual_out=None, + has_bias=False, + has_residual=False, + ) ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) bytes_io = bytes_io_model_bwd(M, N, dtype, weight_dtype=w.dtype) @@ -280,15 +281,17 @@ def bench_single( if quack_rmsnorm_bwd is None: return ours, None, stats - fn_quack = lambda: quack_rmsnorm_bwd( - x, - w, - dout, - rstd, - dresidual_out=None, - has_bias=False, - has_residual=False, - ) + def fn_quack(): + return quack_rmsnorm_bwd( + x, + w, + dout, + rstd, + dresidual_out=None, + has_bias=False, + has_residual=False, + ) + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 return ours, Result(ms=ms_quack, gbps=gbps_quack), stats diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py index e55e9ff..20ed8ac 100644 --- a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py @@ -183,30 +183,34 @@ def bench_single( bytes_io = bytes_io_model_fwd(M, N, dtype, weight_dtype=w.dtype) - fn_oink = lambda: oink_rmsnorm.rmsnorm_forward( - x, - weight=w, - bias=None, - residual=None, - eps=eps, - store_rstd=store_rstd, - ) + def fn_oink(): + return oink_rmsnorm.rmsnorm_forward( + x, + weight=w, + bias=None, + residual=None, + eps=eps, + store_rstd=store_rstd, + ) + ms_oink = do_bench_triton(fn_oink, warmup_ms=warmup_ms, rep_ms=iters_ms) gbps_oink = bytes_io / (ms_oink * 1e-3) / 1e9 if quack_rmsnorm_fwd is None: return (ms_oink, gbps_oink), None, stats - fn_quack = lambda: quack_rmsnorm_fwd( - x, - w, - bias=None, - residual=None, - out_dtype=None, - residual_dtype=None, - eps=eps, - store_rstd=store_rstd, - ) + def fn_quack(): + return quack_rmsnorm_fwd( + x, + w, + bias=None, + residual=None, + out_dtype=None, + residual_dtype=None, + eps=eps, + store_rstd=store_rstd, + ) + ms_quack = do_bench_triton(fn_quack, warmup_ms=warmup_ms, rep_ms=iters_ms) gbps_quack = bytes_io / (ms_quack * 1e-3) / 1e9 return (ms_oink, gbps_oink), (ms_quack, gbps_quack), stats diff --git a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py index 93c5af3..7826efc 100644 --- a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py @@ -150,25 +150,39 @@ def bench_single( bytes_io = bytes_io_model_softmax(M, N, dtype, mode=mode) if mode == "fwd": 
- fn_oink = lambda: oink_softmax.softmax_forward(x) - fn_quack = None if quack_softmax_fwd is None else (lambda: quack_softmax_fwd(x)) + def fn_oink(): + return oink_softmax.softmax_forward(x) + + fn_quack = None + if quack_softmax_fwd is not None: + + def fn_quack(): + return quack_softmax_fwd(x) + elif mode == "bwd": with torch.no_grad(): y_o = oink_softmax.softmax_forward(x) y_q = quack_softmax_fwd(x) if quack_softmax_fwd is not None else None - fn_oink = lambda: oink_softmax.softmax_backward(dy, y_o) - fn_quack = ( - None - if (quack_softmax_bwd is None or y_q is None) - else (lambda: quack_softmax_bwd(dy, y_q)) - ) + + def fn_oink(): + return oink_softmax.softmax_backward(dy, y_o) + + fn_quack = None + if quack_softmax_bwd is not None and y_q is not None: + + def fn_quack(): + return quack_softmax_bwd(dy, y_q) + elif mode == "fwd_bwd": - fn_oink = lambda: oink_softmax.softmax_fwd_bwd(dy, x) - fn_quack = ( - None - if (quack_softmax_fwd is None or quack_softmax_bwd is None) - else (lambda: quack_softmax_bwd(dy, quack_softmax_fwd(x))) - ) + def fn_oink(): + return oink_softmax.softmax_fwd_bwd(dy, x) + + fn_quack = None + if quack_softmax_fwd is not None and quack_softmax_bwd is not None: + + def fn_quack(): + return quack_softmax_bwd(dy, quack_softmax_fwd(x)) + else: raise ValueError(f"Unsupported mode: {mode}") diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py index 1799f2e..c089b2b 100644 --- a/oink/benchmarks/readme/plot_quack_style_svg.py +++ b/oink/benchmarks/readme/plot_quack_style_svg.py @@ -1,5 +1,3 @@ -from __future__ import annotations - """ Generate Quack-style SVG performance plots (Oink vs Quack) from the SM100 suite JSON artifacts under `/tmp/kernelagent_oink_sm100_suite_{bf16,fp16}`. @@ -42,6 +40,8 @@ when available: `fused_add_rmsnorm_dsv3.json`. """ +from __future__ import annotations + import argparse import json import math diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py index 05b11de..0e4d640 100644 --- a/oink/src/kernelagent_oink/blackwell/layernorm.py +++ b/oink/src/kernelagent_oink/blackwell/layernorm.py @@ -80,7 +80,7 @@ # Local helpers cloned from Quack via lite_quack so that this kernel does # not depend on `quack` at runtime. 
-from kernelagent_oink.blackwell.lite_quack import ( +from kernelagent_oink.blackwell.lite_quack import ( # noqa: E402 _KERNEL_ACCEPTS_LAYOUT_ARGS, TORCH2CUTE_DTYPE, ReductionBase as _ReductionBase, diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py index 1bc15b1..590d773 100644 --- a/oink/src/kernelagent_oink/blackwell/lite_quack.py +++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py @@ -39,7 +39,7 @@ from cutlass import Float32, Int32, const_expr from cutlass.cute.runtime import from_dlpack from cutlass.cutlass_dsl import T, dsl_user_op -from cutlass._mlir.dialects import llvm, nvvm, vector +from cutlass._mlir.dialects import llvm, vector def _parse_version_tuple(version: str) -> tuple[int, int, int]: diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py index 1e080a3..9df9f16 100644 --- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py @@ -3119,7 +3119,6 @@ def _rmsnorm_bwd_sm100( assert dresidual.dtype in (torch.float16, torch.bfloat16, torch.float32) M, N = x.size(0), x.size(1) - device = x.device if dw_partial is None and db_partial is None: assert sm_count is not None else: @@ -3130,12 +3129,14 @@ def _rmsnorm_bwd_sm100( # Match Quack's conversion strategy for activations/gradients: keep the # (M, N) layout dynamic without enforcing additional compact-shape # constraints. This reduces per-call Python overhead for small-M shapes. - convert_from_dlpack = lambda t: from_dlpack( # type: ignore[assignment] - t.detach(), - assumed_align=16, - ).mark_layout_dynamic(leading_dim=1) + def _convert_mx(t: Tensor) -> cute.Tensor: + return from_dlpack( + t.detach(), + assumed_align=16, + ).mark_layout_dynamic(leading_dim=1) + x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [ - convert_from_dlpack(t) if t is not None else None + _convert_mx(t) if t is not None else None for t in (x, dout, dresidual_out, dx, dresidual) ] diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py index b53da12..fec5bf4 100644 --- a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py @@ -744,9 +744,13 @@ def rmsnorm_forward_with_stage2( M, N = x.shape dtype = TORCH2CUTE_DTYPE[x.dtype] - convert_x = lambda t: from_dlpack(t.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1) - mX = convert_x(x) - mRes = convert_x(residual) if residual is not None else None + def _convert_x(t: Tensor) -> cute.Tensor: + return from_dlpack( + t.detach(), assumed_align=32 + ).mark_layout_dynamic(leading_dim=1) + + mX = _convert_x(x) + mRes = _convert_x(residual) if residual is not None else None out = torch.empty_like(x, dtype=x.dtype) mO = from_dlpack(out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1) diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py index a2f2581..a8a2791 100644 --- a/oink/src/kernelagent_oink/blackwell/softmax.py +++ b/oink/src/kernelagent_oink/blackwell/softmax.py @@ -27,10 +27,9 @@ from __future__ import annotations import importlib.metadata -import math import os import re -from typing import Optional, Type +from typing import Type import torch from torch import Tensor From 0543b6f6948c8c110a72d7273ec343f8f897baaa Mon Sep 17 00:00:00 2001 From: Laura Wang <3700467+Laurawly@users.noreply.github.com> 
Date: Wed, 21 Jan 2026 20:08:22 -0800 Subject: [PATCH 6/8] oink: ruff format --- oink/benchmarks/benchmark/bench_utils.py | 34 +- .../benchmark_cross_entropy_sm100.py | 97 +++-- .../benchmark_fused_add_rmsnorm_sm100.py | 35 +- .../benchmark/benchmark_hbm_roofline_sm100.py | 44 ++- .../benchmark/benchmark_layernorm_sm100.py | 82 +++-- .../benchmark/benchmark_rmsnorm_bwd_sm100.py | 32 +- .../benchmark/benchmark_rmsnorm_sm100.py | 55 ++- .../benchmark/benchmark_softmax_sm100.py | 39 +- .../benchmarks/readme/plot_quack_style_svg.py | 44 ++- oink/benchmarks/readme/run_sm100_suite.py | 8 +- oink/benchmarks/readme/summarize_results.py | 61 +++- .../blackwell/cross_entropy.py | 81 ++++- .../kernelagent_oink/blackwell/layernorm.py | 90 +++-- .../kernelagent_oink/blackwell/lite_quack.py | 179 ++++++--- .../src/kernelagent_oink/blackwell/rmsnorm.py | 32 +- .../blackwell/rmsnorm_with_stage2.py | 342 ++++++++++++++---- .../src/kernelagent_oink/blackwell/softmax.py | 100 +++-- 17 files changed, 1019 insertions(+), 336 deletions(-) diff --git a/oink/benchmarks/benchmark/bench_utils.py b/oink/benchmarks/benchmark/bench_utils.py index 0abb005..0a9ae4b 100644 --- a/oink/benchmarks/benchmark/bench_utils.py +++ b/oink/benchmarks/benchmark/bench_utils.py @@ -67,7 +67,9 @@ def detect_hbm_peak_gbps(device: Optional[torch.device] = None) -> float: return 2000.0 -def do_bench_triton(fn: Callable[[], Any], *, warmup_ms: int = 25, rep_ms: int = 100) -> float: +def do_bench_triton( + fn: Callable[[], Any], *, warmup_ms: int = 25, rep_ms: int = 100 +) -> float: """Kernel-only timing consistent with the Oink benchmark harnesses.""" return float(triton_do_bench(fn, warmup=warmup_ms, rep=rep_ms, return_mode="mean")) @@ -127,7 +129,13 @@ def write_csv(path: str, rows: Sequence[Dict[str, Any]]) -> None: writer.writerow(row) -def write_json(path: str, meta: DeviceMeta, rows: Sequence[Dict[str, Any]], *, extra: Dict[str, Any] | None = None) -> None: +def write_json( + path: str, + meta: DeviceMeta, + rows: Sequence[Dict[str, Any]], + *, + extra: Dict[str, Any] | None = None, +) -> None: os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True) payload: Dict[str, Any] = { "meta": {**asdict(meta), **(extra or {})}, @@ -179,7 +187,9 @@ def __init__(self, *, total_elems: int, p99_target_samples: int = 1_000_000): if total_elems <= 0: raise ValueError(f"total_elems must be > 0, got {total_elems}") if p99_target_samples <= 0: - raise ValueError(f"p99_target_samples must be > 0, got {p99_target_samples}") + raise ValueError( + f"p99_target_samples must be > 0, got {p99_target_samples}" + ) self.total_elems = int(total_elems) self.p99_target_samples = int(p99_target_samples) # Deterministic strided sampling across the flattened tensor order. @@ -193,7 +203,9 @@ def __init__(self, *, total_elems: int, p99_target_samples: int = 1_000_000): def update(self, out: torch.Tensor, ref: torch.Tensor) -> None: if out.shape != ref.shape: - raise ValueError(f"shape mismatch: out={tuple(out.shape)} ref={tuple(ref.shape)}") + raise ValueError( + f"shape mismatch: out={tuple(out.shape)} ref={tuple(ref.shape)}" + ) # Compute error in float32 for stable reductions. 
err_f32 = (out - ref).to(torch.float32) @@ -214,9 +226,13 @@ def update(self, out: torch.Tensor, ref: torch.Tensor) -> None: stride = int(self.sample_stride) first = (-int(self._global_offset)) % stride if first < block_elems: - idx = torch.arange(first, block_elems, step=stride, device=flat.device, dtype=torch.int64) + idx = torch.arange( + first, block_elems, step=stride, device=flat.device, dtype=torch.int64 + ) # Gather a modest number of values (≈ block_elems/stride). - vals = flat.index_select(0, idx).detach().to(device="cpu", dtype=torch.float32) + vals = ( + flat.index_select(0, idx).detach().to(device="cpu", dtype=torch.float32) + ) self._abs_err_samples.append(vals) self._global_offset += block_elems @@ -226,7 +242,11 @@ def finalize(self) -> ErrorStats: samples = torch.cat(self._abs_err_samples, dim=0) if samples.numel() > self.p99_target_samples: samples = samples[: self.p99_target_samples] - p99 = float(torch.quantile(samples, 0.99).item()) if samples.numel() > 0 else 0.0 + p99 = ( + float(torch.quantile(samples, 0.99).item()) + if samples.numel() > 0 + else 0.0 + ) sample_elems = int(samples.numel()) else: p99 = 0.0 diff --git a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py index 18399c7..ff1a99b 100644 --- a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py @@ -84,16 +84,22 @@ def dsv3_configs() -> List[Tuple[int, int]]: return [(m, n) for m in Ms for n in Ns] -def _verify_parity(logits: torch.Tensor, target: torch.Tensor, *, ignore_index: int) -> dict[str, object]: +def _verify_parity( + logits: torch.Tensor, target: torch.Tensor, *, ignore_index: int +) -> dict[str, object]: dtype = logits.dtype ref_block_rows = 512 - dloss = torch.randn(logits.size(0), device=logits.device, dtype=torch.float32) # upstream grad + dloss = torch.randn( + logits.size(0), device=logits.device, dtype=torch.float32 + ) # upstream grad with torch.no_grad(): loss_o, lse_o = oink_ce.cross_entropy_forward( logits, target, ignore_index=ignore_index, reduction="none" ) - dx_o = oink_ce.cross_entropy_backward(dloss, logits, target, lse_o, ignore_index=ignore_index) + dx_o = oink_ce.cross_entropy_backward( + dloss, logits, target, lse_o, ignore_index=ignore_index + ) dx_fused_o = oink_ce.cross_entropy_fwd_bwd( dloss, logits, @@ -125,8 +131,12 @@ def _verify_parity(logits: torch.Tensor, target: torch.Tensor, *, ignore_index: M = int(logits.shape[0]) N = int(logits.shape[1]) - loss_acc_ours = ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000)) - lse_acc_ours = ErrorStatsAccumulator(total_elems=M, p99_target_samples=min(M, 1_000_000)) + loss_acc_ours = ErrorStatsAccumulator( + total_elems=M, p99_target_samples=min(M, 1_000_000) + ) + lse_acc_ours = ErrorStatsAccumulator( + total_elems=M, p99_target_samples=min(M, 1_000_000) + ) dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N) dx_fused_acc_ours = ErrorStatsAccumulator(total_elems=M * N) loss_acc_quack = ( @@ -159,23 +169,39 @@ def _verify_parity(logits: torch.Tensor, target: torch.Tensor, *, ignore_index: ignore_index=ignore_index, ) lse_ref = torch.logsumexp(logits_f32, dim=-1) - (dx_ref_f32,) = torch.autograd.grad(loss_ref, logits_f32, grad_outputs=dloss_blk) + (dx_ref_f32,) = torch.autograd.grad( + loss_ref, logits_f32, grad_outputs=dloss_blk + ) dx_ref = dx_ref_f32.to(dtype) - torch.testing.assert_close(loss_o[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS) - 
torch.testing.assert_close(lse_o[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS) + torch.testing.assert_close( + loss_o[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS + ) + torch.testing.assert_close( + lse_o[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS + ) torch.testing.assert_close(dx_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype]) - torch.testing.assert_close(dx_fused_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype]) + torch.testing.assert_close( + dx_fused_o[start:end], dx_ref, **_VERIFY_TOL_DX[dtype] + ) loss_acc_ours.update(loss_o[start:end], loss_ref.detach()) lse_acc_ours.update(lse_o[start:end], lse_ref.detach()) dx_acc_ours.update(dx_o[start:end], dx_ref) dx_fused_acc_ours.update(dx_fused_o[start:end], dx_ref) if loss_q is not None and lse_q is not None and dx_q is not None: - torch.testing.assert_close(loss_q[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS) - torch.testing.assert_close(lse_q[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS) + torch.testing.assert_close( + loss_q[start:end], loss_ref.detach(), **_VERIFY_TOL_LOSS + ) + torch.testing.assert_close( + lse_q[start:end], lse_ref.detach(), **_VERIFY_TOL_LOSS + ) torch.testing.assert_close(dx_q[start:end], dx_ref, **_VERIFY_TOL_DX[dtype]) - assert loss_acc_quack is not None and lse_acc_quack is not None and dx_acc_quack is not None + assert ( + loss_acc_quack is not None + and lse_acc_quack is not None + and dx_acc_quack is not None + ) loss_acc_quack.update(loss_q[start:end], loss_ref.detach()) lse_acc_quack.update(lse_q[start:end], lse_ref.detach()) dx_acc_quack.update(dx_q[start:end], dx_ref) @@ -185,7 +211,11 @@ def _verify_parity(logits: torch.Tensor, target: torch.Tensor, *, ignore_index: stats.update(error_stats_to_row("ours_err_lse", lse_acc_ours.finalize())) stats.update(error_stats_to_row("ours_err_dx", dx_acc_ours.finalize())) stats.update(error_stats_to_row("ours_err_dx_fused", dx_fused_acc_ours.finalize())) - if loss_acc_quack is not None and lse_acc_quack is not None and dx_acc_quack is not None: + if ( + loss_acc_quack is not None + and lse_acc_quack is not None + and dx_acc_quack is not None + ): stats.update(error_stats_to_row("quack_err_loss", loss_acc_quack.finalize())) stats.update(error_stats_to_row("quack_err_lse", lse_acc_quack.finalize())) stats.update(error_stats_to_row("quack_err_dx", dx_acc_quack.finalize())) @@ -219,6 +249,7 @@ def bench_single( bytes_io = bytes_io_model_ce(M, N, dtype, target_dtype=target.dtype, mode=mode) if mode == "fwd": + def fn_oink(): return oink_ce.cross_entropy_forward( logits, target, ignore_index=int(ignore_index), reduction="none" @@ -275,6 +306,7 @@ def fn_quack(): ) elif mode == "fwd_bwd": + def fn_oink(): return oink_ce.cross_entropy_fwd_bwd( dloss, @@ -330,16 +362,34 @@ def main() -> None: print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") p = argparse.ArgumentParser() - p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) - p.add_argument("--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"]) + p.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) + p.add_argument( + "--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"] + ) p.add_argument("--ignore-index", type=int, default=-100) - p.add_argument("--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only).") + p.add_argument( + "--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only)." 
+ ) p.add_argument("--warmup-ms", type=int, default=25) - p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows") - p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)") + p.add_argument( + "--csv", type=str, default=None, help="Optional CSV output path; appends rows" + ) + p.add_argument( + "--json", type=str, default=None, help="Optional JSON output path (meta + rows)" + ) p.add_argument("--configs", type=str, default="1024x4096,8192x4096") - p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid (vocab=4096)") - p.add_argument("--dsv3", action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {3072,6144,8192,12288}") + p.add_argument( + "--quack-suite", + action="store_true", + help="Run Quack-style batch/seq grid (vocab=4096)", + ) + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {3072,6144,8192,12288}", + ) p.add_argument( "--skip-verify", action="store_true", @@ -360,8 +410,11 @@ def main() -> None: meta = collect_device_meta(device) rows_out: List[Dict[str, Any]] = [] - for (M, N) in cfgs: - print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", flush=True) + for M, N in cfgs: + print( + f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", + flush=True, + ) (ms_oink, gbps_oink), quack, stats = bench_single( M=M, N=N, diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py index 6418e61..863712d 100644 --- a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py @@ -90,8 +90,16 @@ def _verify_parity( y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) z_acc_ours = ErrorStatsAccumulator(total_elems=M * N) - y_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd_mut is not None else None - z_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd_mut is not None else None + y_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if quack_rmsnorm_fwd_mut is not None + else None + ) + z_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if quack_rmsnorm_fwd_mut is not None + else None + ) x_o = x.clone() r_o = residual.clone() @@ -141,7 +149,9 @@ def _verify_parity( stats.update(error_stats_to_row("ours_err_residual_out", z_acc_ours.finalize())) if y_acc_quack is not None and z_acc_quack is not None: stats.update(error_stats_to_row("quack_err_y", y_acc_quack.finalize())) - stats.update(error_stats_to_row("quack_err_residual_out", z_acc_quack.finalize())) + stats.update( + error_stats_to_row("quack_err_residual_out", z_acc_quack.finalize()) + ) return stats @@ -177,7 +187,9 @@ def fn(): row: Dict[str, Any] = dict( M=int(M), N=int(N), - dtype="bf16" if dtype is torch.bfloat16 else ("fp16" if dtype is torch.float16 else "fp32"), + dtype="bf16" + if dtype is torch.bfloat16 + else ("fp16" if dtype is torch.float16 else "fp32"), ours_ms=float(ms), ours_gbps=float(gbps), ours_tbps=float(tbps), @@ -247,7 +259,9 @@ def _print_table(rows: List[Dict[str, Any]]) -> None: def main() -> None: p = argparse.ArgumentParser() - p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) + p.add_argument( + "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"] + ) p.add_argument("--M", type=int, default=65536) p.add_argument("--N", 
type=int, default=4096) p.add_argument( @@ -256,7 +270,9 @@ def main() -> None: help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", ) p.add_argument("--warmup-ms", type=int, default=25) - p.add_argument("--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)") + p.add_argument( + "--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)" + ) p.add_argument("--skip-verify", action="store_true") p.add_argument("--json", type=str, default=None) args = p.parse_args() @@ -266,8 +282,11 @@ def main() -> None: cfgs = dsv3_configs() if bool(args.dsv3) else [(int(args.M), int(args.N))] rows: List[Dict[str, Any]] = [] - for (M, N) in cfgs: - print(f"bench M={M:<8d} N={N:<6d} dtype={_dtype_label(dtype)} fused_add_rmsnorm ...", flush=True) + for M, N in cfgs: + print( + f"bench M={M:<8d} N={N:<6d} dtype={_dtype_label(dtype)} fused_add_rmsnorm ...", + flush=True, + ) rows.append( bench_one( M=int(M), diff --git a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py index 8ec4bfd..c22294e 100644 --- a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py @@ -95,6 +95,7 @@ def bench_one( grid = (triton.cdiv(n_elements, block),) if op == "copy": + def launch(): _copy_kernel[grid]( x, @@ -106,6 +107,7 @@ def launch(): ) elif op == "triad": + def launch(): _triad_kernel[grid]( x, @@ -134,20 +136,38 @@ def _print_summary(rows: List[Dict[str, Any]]) -> None: return best = max(rows, key=lambda r: float(r["tbps"])) print("\nSummary (STREAM-like):") - print(f"- best_tbps: {best['tbps']:.3f} TB/s ({best['op']}, BLOCK={best['block']}, warps={best['num_warps']})") + print( + f"- best_tbps: {best['tbps']:.3f} TB/s ({best['op']}, BLOCK={best['block']}, warps={best['num_warps']})" + ) def main() -> None: p = argparse.ArgumentParser() - p.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) + p.add_argument( + "--dtype", type=str, default="bf16", choices=["bf16", "fp16", "fp32"] + ) p.add_argument("--op", type=str, default="copy", choices=["copy", "triad", "both"]) - p.add_argument("--gb", type=float, default=2.0, help="Size per tensor in GB (default: 2)") + p.add_argument( + "--gb", type=float, default=2.0, help="Size per tensor in GB (default: 2)" + ) p.add_argument("--warmup-ms", type=int, default=25) - p.add_argument("--iters", type=int, default=100, help="rep_ms for do_bench (default: 100)") - p.add_argument("--json", type=str, default=None, help="Write JSON results to this path") - p.add_argument("--no-sweep", action="store_true", help="Disable tuning sweep; run a single config") - p.add_argument("--block", type=int, default=2048, help="BLOCK size when --no-sweep is set") - p.add_argument("--warps", type=int, default=8, help="num_warps when --no-sweep is set") + p.add_argument( + "--iters", type=int, default=100, help="rep_ms for do_bench (default: 100)" + ) + p.add_argument( + "--json", type=str, default=None, help="Write JSON results to this path" + ) + p.add_argument( + "--no-sweep", + action="store_true", + help="Disable tuning sweep; run a single config", + ) + p.add_argument( + "--block", type=int, default=2048, help="BLOCK size when --no-sweep is set" + ) + p.add_argument( + "--warps", type=int, default=8, help="num_warps when --no-sweep is set" + ) args = p.parse_args() dtype = parse_dtype(args.dtype) @@ -181,7 +201,9 @@ def main() -> None: print(f"Running on {props.name} 
(SM{props.major}{props.minor})") print(f"- dtype: {args.dtype} (elem={elem_size}B)") - print(f"- n_elements: {n_elements:,} (~{(n_elements * elem_size) / (1024**3):.2f} GiB per tensor)") + print( + f"- n_elements: {n_elements:,} (~{(n_elements * elem_size) / (1024**3):.2f} GiB per tensor)" + ) print(f"- ops: {ops}") print(f"- sweep: {sweep}") @@ -212,7 +234,9 @@ def main() -> None: tbps=float(tbps), ) ) - print(f"- {op:5s} BLOCK={block:4d} warps={warps}: {tbps:.3f} TB/s ({ms:.4f} ms)") + print( + f"- {op:5s} BLOCK={block:4d} warps={warps}: {tbps:.3f} TB/s ({ms:.4f} ms)" + ) _print_summary(rows) diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py index a9865d1..3c0e37d 100644 --- a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py @@ -97,7 +97,9 @@ def _verify_parity( y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) y_acc_quack = ( - ErrorStatsAccumulator(total_elems=M * N) if (quack_layernorm is not None and b is None) else None + ErrorStatsAccumulator(total_elems=M * N) + if (quack_layernorm is not None and b is None) + else None ) with torch.no_grad(): ours = oink_ln.layernorm( @@ -138,8 +140,12 @@ def _unpack(out): # Pure-PyTorch reference (float32 accumulation), matching Quack's unit tests: # - compute ref output via F.layer_norm on float32 # - compute mean/rstd from float32 input - rstd_ref_all = torch.empty((M,), device=x.device, dtype=torch.float32) if return_rstd else None - mean_ref_all = torch.empty((M,), device=x.device, dtype=torch.float32) if return_mean else None + rstd_ref_all = ( + torch.empty((M,), device=x.device, dtype=torch.float32) if return_rstd else None + ) + mean_ref_all = ( + torch.empty((M,), device=x.device, dtype=torch.float32) if return_mean else None + ) for start, end in iter_row_blocks(M, ref_block_rows): x_f32 = x[start:end].float() @@ -165,16 +171,24 @@ def _unpack(out): rstd_ref_all[start:end] = rstd_ref assert rstd_o is not None - torch.testing.assert_close(rstd_o[start:end], rstd_ref, **_VERIFY_TOL_STATS) + torch.testing.assert_close( + rstd_o[start:end], rstd_ref, **_VERIFY_TOL_STATS + ) if rstd_q is not None: - torch.testing.assert_close(rstd_q[start:end], rstd_ref, **_VERIFY_TOL_STATS) + torch.testing.assert_close( + rstd_q[start:end], rstd_ref, **_VERIFY_TOL_STATS + ) if return_mean: mean_ref = mean_f32 assert mean_o is not None - torch.testing.assert_close(mean_o[start:end], mean_ref, **_VERIFY_TOL_STATS) + torch.testing.assert_close( + mean_o[start:end], mean_ref, **_VERIFY_TOL_STATS + ) if mean_q is not None: - torch.testing.assert_close(mean_q[start:end], mean_ref, **_VERIFY_TOL_STATS) + torch.testing.assert_close( + mean_q[start:end], mean_ref, **_VERIFY_TOL_STATS + ) stats: dict[str, object] = {} stats.update(error_stats_to_row("ours_err_y", y_acc_ours.finalize())) @@ -184,30 +198,38 @@ def _unpack(out): if return_rstd: assert rstd_o is not None and rstd_ref_all is not None rstd_acc_ours = ErrorStatsAccumulator( - total_elems=int(rstd_ref_all.numel()), p99_target_samples=int(rstd_ref_all.numel()) + total_elems=int(rstd_ref_all.numel()), + p99_target_samples=int(rstd_ref_all.numel()), ) rstd_acc_ours.update(rstd_o, rstd_ref_all) stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize())) if rstd_q is not None: rstd_acc_quack = ErrorStatsAccumulator( - total_elems=int(rstd_ref_all.numel()), p99_target_samples=int(rstd_ref_all.numel()) + total_elems=int(rstd_ref_all.numel()), + 
p99_target_samples=int(rstd_ref_all.numel()), ) rstd_acc_quack.update(rstd_q, rstd_ref_all) - stats.update(error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize())) + stats.update( + error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize()) + ) if return_mean: assert mean_o is not None and mean_ref_all is not None mean_acc_ours = ErrorStatsAccumulator( - total_elems=int(mean_ref_all.numel()), p99_target_samples=int(mean_ref_all.numel()) + total_elems=int(mean_ref_all.numel()), + p99_target_samples=int(mean_ref_all.numel()), ) mean_acc_ours.update(mean_o, mean_ref_all) stats.update(error_stats_to_row("ours_err_mean", mean_acc_ours.finalize())) if mean_q is not None: mean_acc_quack = ErrorStatsAccumulator( - total_elems=int(mean_ref_all.numel()), p99_target_samples=int(mean_ref_all.numel()) + total_elems=int(mean_ref_all.numel()), + p99_target_samples=int(mean_ref_all.numel()), ) mean_acc_quack.update(mean_q, mean_ref_all) - stats.update(error_stats_to_row("quack_err_mean", mean_acc_quack.finalize())) + stats.update( + error_stats_to_row("quack_err_mean", mean_acc_quack.finalize()) + ) return stats @@ -232,7 +254,9 @@ def bench_single( stats: dict[str, object] = {} if verify: - stats = _verify_parity(x, w, b, eps=eps, return_rstd=return_rstd, return_mean=return_mean) + stats = _verify_parity( + x, w, b, eps=eps, return_rstd=return_rstd, return_mean=return_mean + ) bytes_io = bytes_io_model_layernorm( M, @@ -285,17 +309,33 @@ def main() -> None: print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") p = argparse.ArgumentParser() - p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) p.add_argument("--eps", type=float, default=1e-6) p.add_argument("--return-rstd", action="store_true") p.add_argument("--return-mean", action="store_true") - p.add_argument("--with-bias", action="store_true", help="Benchmark bias path (Quack compare skipped)") - p.add_argument("--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only).") + p.add_argument( + "--with-bias", + action="store_true", + help="Benchmark bias path (Quack compare skipped)", + ) + p.add_argument( + "--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only)." 
+ ) p.add_argument("--warmup-ms", type=int, default=25) - p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows") - p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)") + p.add_argument( + "--csv", type=str, default=None, help="Optional CSV output path; appends rows" + ) + p.add_argument( + "--json", type=str, default=None, help="Optional JSON output path (meta + rows)" + ) p.add_argument("--configs", type=str, default="1024x4096,8192x4096") - p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid (hidden=4096)") + p.add_argument( + "--quack-suite", + action="store_true", + help="Run Quack-style batch/seq grid (hidden=4096)", + ) p.add_argument( "--dsv3", action="store_true", @@ -322,7 +362,7 @@ def main() -> None: meta = collect_device_meta(device) rows_out: List[Dict[str, Any]] = [] - for (M, N) in cfgs: + for M, N in cfgs: print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True) (ms_oink, gbps_oink), quack, stats = bench_single( M=M, diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py index 4ba1c47..b9909e7 100644 --- a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py @@ -140,7 +140,11 @@ def _verify_parity( M, N = int(x.shape[0]), int(x.shape[1]) dx_acc_ours = ErrorStatsAccumulator(total_elems=M * N) - dx_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_bwd is not None else None + dx_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if quack_rmsnorm_bwd is not None + else None + ) with torch.no_grad(): dx_oink, dw_oink, db_oink, dres_oink = oink_rmsnorm.rmsnorm_backward( @@ -216,7 +220,9 @@ def _verify_parity( dw_tol = dict(atol=dw_atol, rtol=1e-3) torch.testing.assert_close(dw_oink_f32, dw_ref_f32, **dw_tol) if dw_quack is not None: - torch.testing.assert_close(dw_quack.to(torch.float32), dw_ref_f32, **dw_tol) + torch.testing.assert_close( + dw_quack.to(torch.float32), dw_ref_f32, **dw_tol + ) dw_tol = None # handled above if dw_tol is not None: torch.testing.assert_close(dw_oink, dw_ref, **dw_tol) @@ -224,7 +230,9 @@ def _verify_parity( torch.testing.assert_close(dw_quack, dw_ref, **dw_tol) # Record weight-grad error stats (small, so exact p99 over the full vector). 
- dw_acc_ours = ErrorStatsAccumulator(total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel())) + dw_acc_ours = ErrorStatsAccumulator( + total_elems=int(dw_ref.numel()), p99_target_samples=int(dw_ref.numel()) + ) dw_acc_ours.update(dw_oink, dw_ref) stats.update(error_stats_to_row("ours_err_dw", dw_acc_ours.finalize())) if dw_quack is not None: @@ -308,7 +316,9 @@ def main() -> None: print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") p = argparse.ArgumentParser() - p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) p.add_argument( "--weight-dtype", type=str, @@ -324,10 +334,16 @@ def main() -> None: help="Triton do_bench rep_ms (kernel-only).", ) p.add_argument("--warmup-ms", type=int, default=25) - p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows") - p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)") + p.add_argument( + "--csv", type=str, default=None, help="Optional CSV output path; appends rows" + ) + p.add_argument( + "--json", type=str, default=None, help="Optional JSON output path (meta + rows)" + ) p.add_argument("--configs", type=str, default="1024x4096,8192x4096") - p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid") + p.add_argument( + "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid" + ) p.add_argument( "--dsv3", action="store_true", @@ -358,7 +374,7 @@ def main() -> None: rows_out: list[dict[str, object]] = [] - for (M, N) in cfgs: + for M, N in cfgs: print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True) ours, quack, stats = bench_single( M=M, diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py index 20ed8ac..f4c8a5f 100644 --- a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py @@ -85,7 +85,11 @@ def _verify_parity( N = int(x.shape[1]) y_acc_ours = ErrorStatsAccumulator(total_elems=M * N) - y_acc_quack = ErrorStatsAccumulator(total_elems=M * N) if quack_rmsnorm_fwd is not None else None + y_acc_quack = ( + ErrorStatsAccumulator(total_elems=M * N) + if quack_rmsnorm_fwd is not None + else None + ) with torch.no_grad(): y_o, rstd_o, res_o = oink_rmsnorm.rmsnorm_forward( @@ -141,15 +145,20 @@ def _verify_parity( assert rstd_q is not None torch.testing.assert_close(rstd_q, rstd_ref, **tol_rstd) # Stats for rstd are cheap (M elements); compute exact p99 over all rows. 
- rstd_acc_ours = ErrorStatsAccumulator(total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel())) + rstd_acc_ours = ErrorStatsAccumulator( + total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel()) + ) rstd_acc_ours.update(rstd_o, rstd_ref) stats.update(error_stats_to_row("ours_err_rstd", rstd_acc_ours.finalize())) if rstd_q is not None: rstd_acc_quack = ErrorStatsAccumulator( - total_elems=int(rstd_ref.numel()), p99_target_samples=int(rstd_ref.numel()) + total_elems=int(rstd_ref.numel()), + p99_target_samples=int(rstd_ref.numel()), ) rstd_acc_quack.update(rstd_q, rstd_ref) - stats.update(error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize())) + stats.update( + error_stats_to_row("quack_err_rstd", rstd_acc_quack.finalize()) + ) # Residual output semantics differ slightly across implementations: # - Oink returns `None` when residual is None. # - Quack returns `x` as a safe alias in that case. @@ -227,7 +236,9 @@ def main() -> None: print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") p = argparse.ArgumentParser() - p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) p.add_argument( "--weight-dtype", type=str, @@ -236,15 +247,33 @@ def main() -> None: help="RMSNorm weight dtype. `same` matches activation dtype (vLLM-style inference).", ) p.add_argument("--eps", type=float, default=1e-6) - p.add_argument("--store-rstd", action="store_true", help="Also write rstd (fp32 per row)") - p.add_argument("--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only).") + p.add_argument( + "--store-rstd", action="store_true", help="Also write rstd (fp32 per row)" + ) + p.add_argument( + "--iters", type=int, default=100, help="Triton do_bench rep_ms (kernel-only)." 
+ ) p.add_argument("--warmup-ms", type=int, default=25) - p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows") - p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)") + p.add_argument( + "--csv", type=str, default=None, help="Optional CSV output path; appends rows" + ) + p.add_argument( + "--json", type=str, default=None, help="Optional JSON output path (meta + rows)" + ) p.add_argument("--configs", type=str, default="1024x4096,8192x4096") - p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid") - p.add_argument("--dsv3", action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}") - p.add_argument("--skip-verify", action="store_true", help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference)") + p.add_argument( + "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid" + ) + p.add_argument( + "--dsv3", + action="store_true", + help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", + ) + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs a pure-PyTorch reference)", + ) args = p.parse_args() dtype = parse_dtype(args.dtype) @@ -265,7 +294,7 @@ def main() -> None: meta = collect_device_meta(device) rows_out: List[Dict[str, Any]] = [] - for (M, N) in cfgs: + for M, N in cfgs: print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} ...", flush=True) (ms_oink, gbps_oink), quack, stats = bench_single( M=M, diff --git a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py index 7826efc..995b09f 100644 --- a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py @@ -150,6 +150,7 @@ def bench_single( bytes_io = bytes_io_model_softmax(M, N, dtype, mode=mode) if mode == "fwd": + def fn_oink(): return oink_softmax.softmax_forward(x) @@ -174,6 +175,7 @@ def fn_quack(): return quack_softmax_bwd(dy, y_q) elif mode == "fwd_bwd": + def fn_oink(): return oink_softmax.softmax_fwd_bwd(dy, x) @@ -208,20 +210,36 @@ def main() -> None: print(f"Running on {torch.cuda.get_device_name(device)} (SM{sm})") p = argparse.ArgumentParser() - p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) - p.add_argument("--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"]) - p.add_argument("--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only).") + p.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) + p.add_argument( + "--mode", type=str, default="fwd_bwd", choices=["fwd", "bwd", "fwd_bwd"] + ) + p.add_argument( + "--iters", type=int, default=50, help="Triton do_bench rep_ms (kernel-only)." 
+ ) p.add_argument("--warmup-ms", type=int, default=25) - p.add_argument("--csv", type=str, default=None, help="Optional CSV output path; appends rows") - p.add_argument("--json", type=str, default=None, help="Optional JSON output path (meta + rows)") + p.add_argument( + "--csv", type=str, default=None, help="Optional CSV output path; appends rows" + ) + p.add_argument( + "--json", type=str, default=None, help="Optional JSON output path (meta + rows)" + ) p.add_argument("--configs", type=str, default="1024x4096,8192x4096") - p.add_argument("--quack-suite", action="store_true", help="Run Quack-style batch/seq grid") + p.add_argument( + "--quack-suite", action="store_true", help="Run Quack-style batch/seq grid" + ) p.add_argument( "--dsv3", action="store_true", help="Run DSv3 set: M in {4096,16384,65536}, N in {6144,7168,8192}", ) - p.add_argument("--skip-verify", action="store_true", help="Skip correctness checks (Oink/Quack vs PyTorch softmax)") + p.add_argument( + "--skip-verify", + action="store_true", + help="Skip correctness checks (Oink/Quack vs PyTorch softmax)", + ) args = p.parse_args() dtype = parse_dtype(args.dtype) @@ -237,8 +255,11 @@ def main() -> None: meta = collect_device_meta(device) rows_out: List[Dict[str, Any]] = [] - for (M, N) in cfgs: - print(f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", flush=True) + for M, N in cfgs: + print( + f"bench M={M:<8d} N={N:<6d} dtype={args.dtype} mode={args.mode} ...", + flush=True, + ) (ms_oink, gbps_oink), quack, stats = bench_single( M=M, N=N, diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py index c089b2b..af76832 100644 --- a/oink/benchmarks/readme/plot_quack_style_svg.py +++ b/oink/benchmarks/readme/plot_quack_style_svg.py @@ -78,7 +78,9 @@ def _gbps_from_row(prefix: str, row: Mapping[str, Any]) -> Optional[float]: return None -def _aggregate_by_shape(rows: Sequence[Mapping[str, Any]]) -> Dict[Tuple[int, int], Dict[str, float]]: +def _aggregate_by_shape( + rows: Sequence[Mapping[str, Any]], +) -> Dict[Tuple[int, int], Dict[str, float]]: """Aggregate duplicate (M, N) rows using median (more robust than mean).""" buckets: dict[tuple[int, int], dict[str, list[float]]] = defaultdict( lambda: defaultdict(list) @@ -199,7 +201,11 @@ def _plot( continue ours_y.append(float(rec["ours"])) quack_y.append(float(rec["quack"])) - max_y = max(max_y, *(v for v in ours_y if math.isfinite(v)), *(v for v in quack_y if math.isfinite(v))) + max_y = max( + max_y, + *(v for v in ours_y if math.isfinite(v)), + *(v for v in quack_y if math.isfinite(v)), + ) ax.plot( x, @@ -337,7 +343,10 @@ def main() -> None: description="Generate Quack-style SVG plots from KernelAgent-Oink suite JSONs." ) p.add_argument( - "--in-dir", type=str, required=True, help="Directory containing suite JSON outputs" + "--in-dir", + type=str, + required=True, + help="Directory containing suite JSON outputs", ) p.add_argument( "--suite", @@ -362,22 +371,35 @@ def main() -> None: "`union` includes every shape across panels (may create gaps)." 
), ) - p.add_argument("--roofline-json", type=str, default=None, help="Optional /tmp/hbm_roofline_sm100_*.json path") + p.add_argument( + "--roofline-json", + type=str, + default=None, + help="Optional /tmp/hbm_roofline_sm100_*.json path", + ) p.add_argument("--out", type=str, required=True, help="Output SVG path") - p.add_argument("--title", type=str, default=None, help="Optional figure title override") + p.add_argument( + "--title", type=str, default=None, help="Optional figure title override" + ) args = p.parse_args() in_dir = os.path.abspath(args.in_dir) if not os.path.isdir(in_dir): raise SystemExit(f"--in-dir is not a directory: {in_dir}") - roofline_gbps = _read_roofline_gbps(args.roofline_json) if args.roofline_json else None + roofline_gbps = ( + _read_roofline_gbps(args.roofline_json) if args.roofline_json else None + ) panel_files = list(_panel_files_for_suite(str(args.suite))) if args.include_layernorm: if args.suite != "quack_suite": - raise SystemExit("--include-layernorm is only supported for `--suite quack_suite`.") - panel_files.append(("LayerNorm (fwd)", _layernorm_file_for_suite(str(args.suite)))) + raise SystemExit( + "--include-layernorm is only supported for `--suite quack_suite`." + ) + panel_files.append( + ("LayerNorm (fwd)", _layernorm_file_for_suite(str(args.suite))) + ) panels: List[Tuple[str, Dict[Tuple[int, int], Dict[str, float]]]] = [] for panel_title, filename in panel_files: @@ -410,7 +432,11 @@ def main() -> None: suite_name = "DSv3 CrossEntropy" else: suite_name = str(args.suite) - suffix = " (+LayerNorm)" if (args.suite == "quack_suite" and args.include_layernorm) else "" + suffix = ( + " (+LayerNorm)" + if (args.suite == "quack_suite" and args.include_layernorm) + else "" + ) if args.suite == "dsv3_cross_entropy": title = f"SM100 {dtype.upper()} — {suite_name}{suffix}" else: diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py index 5ac1091..c31d4b5 100644 --- a/oink/benchmarks/readme/run_sm100_suite.py +++ b/oink/benchmarks/readme/run_sm100_suite.py @@ -21,7 +21,9 @@ def _run(cmd: List[str], *, dry_run: bool) -> None: def main() -> None: p = argparse.ArgumentParser() - p.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"]) + p.add_argument( + "--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"] + ) p.add_argument( "--out-dir", type=str, @@ -33,7 +35,9 @@ def main() -> None: action="store_true", help="Skip correctness checks (Oink/Quack vs PyTorch / pure-PyTorch references)", ) - p.add_argument("--dry-run", action="store_true", help="Print commands without executing them") + p.add_argument( + "--dry-run", action="store_true", help="Print commands without executing them" + ) args = p.parse_args() # Standardize env for standalone runs outside the vLLM plugin. diff --git a/oink/benchmarks/readme/summarize_results.py b/oink/benchmarks/readme/summarize_results.py index 70782dd..29b768e 100644 --- a/oink/benchmarks/readme/summarize_results.py +++ b/oink/benchmarks/readme/summarize_results.py @@ -95,16 +95,26 @@ def _summarize_error_stats(rows: Sequence[Dict[str, Any]]) -> str: out_rows: List[Dict[str, Any]] = [] for pfx in prefixes: # Per-prefix worst-case across rows. 
- max_abs_vals = [float(r[pfx + "_max_abs"]) for r in rows if (pfx + "_max_abs") in r] - p99_abs_vals = [float(r[pfx + "_p99_abs"]) for r in rows if (pfx + "_p99_abs") in r] - rel_l2_vals = [float(r[pfx + "_rel_l2"]) for r in rows if (pfx + "_rel_l2") in r] + max_abs_vals = [ + float(r[pfx + "_max_abs"]) for r in rows if (pfx + "_max_abs") in r + ] + p99_abs_vals = [ + float(r[pfx + "_p99_abs"]) for r in rows if (pfx + "_p99_abs") in r + ] + rel_l2_vals = [ + float(r[pfx + "_rel_l2"]) for r in rows if (pfx + "_rel_l2") in r + ] if not max_abs_vals and not p99_abs_vals and not rel_l2_vals: continue out_rows.append( { "metric": pfx, - "max_abs (max over shapes)": max(max_abs_vals) if max_abs_vals else None, - "p99_abs (max over shapes)": max(p99_abs_vals) if p99_abs_vals else None, + "max_abs (max over shapes)": max(max_abs_vals) + if max_abs_vals + else None, + "p99_abs (max over shapes)": max(p99_abs_vals) + if p99_abs_vals + else None, "rel_l2 (max over shapes)": max(rel_l2_vals) if rel_l2_vals else None, } ) @@ -112,8 +122,15 @@ def _summarize_error_stats(rows: Sequence[Dict[str, Any]]) -> str: if not out_rows: return "" - cols = ["metric", "max_abs (max over shapes)", "p99_abs (max over shapes)", "rel_l2 (max over shapes)"] - return "\n".join(["", "### Error Stats (vs PyTorch ref)", "", _md_table(out_rows, cols), ""]) + cols = [ + "metric", + "max_abs (max over shapes)", + "p99_abs (max over shapes)", + "rel_l2 (max over shapes)", + ] + return "\n".join( + ["", "### Error Stats (vs PyTorch ref)", "", _md_table(out_rows, cols), ""] + ) def summarize_one(path: str) -> str: @@ -143,7 +160,9 @@ def summarize_one(path: str) -> str: if method is not None: parts.append(f"- method: `{method}`") if meta.get("warmup_ms") is not None and meta.get("rep_ms") is not None: - parts.append(f"- warmup_ms: `{meta.get('warmup_ms')}` | rep_ms: `{meta.get('rep_ms')}`") + parts.append( + f"- warmup_ms: `{meta.get('warmup_ms')}` | rep_ms: `{meta.get('rep_ms')}`" + ) if rows: parts.append("") @@ -153,7 +172,9 @@ def summarize_one(path: str) -> str: gm = _geomean(speeds) if gm is not None: parts.append("") - parts.append(f"- geomean speedup vs Quack: `{gm:.3f}x` (over {len(speeds)} shapes)") + parts.append( + f"- geomean speedup vs Quack: `{gm:.3f}x` (over {len(speeds)} shapes)" + ) err_block = _summarize_error_stats(rows) if err_block: @@ -167,9 +188,21 @@ def summarize_one(path: str) -> str: def main() -> None: - p = argparse.ArgumentParser(description="Summarize KernelAgent-Oink benchmark JSONs into Markdown tables.") - p.add_argument("--in-dir", type=str, required=True, help="Directory containing benchmark JSON files") - p.add_argument("--out", type=str, default=None, help="Optional output markdown path (default: stdout)") + p = argparse.ArgumentParser( + description="Summarize KernelAgent-Oink benchmark JSONs into Markdown tables." 
+ ) + p.add_argument( + "--in-dir", + type=str, + required=True, + help="Directory containing benchmark JSON files", + ) + p.add_argument( + "--out", + type=str, + default=None, + help="Optional output markdown path (default: stdout)", + ) args = p.parse_args() in_dir = os.path.abspath(args.in_dir) @@ -177,7 +210,9 @@ def main() -> None: raise SystemExit(f"--in-dir is not a directory: {in_dir}") json_paths = sorted( - os.path.join(in_dir, name) for name in os.listdir(in_dir) if name.endswith(".json") + os.path.join(in_dir, name) + for name in os.listdir(in_dir) + if name.endswith(".json") ) if not json_paths: raise SystemExit(f"No .json files found under: {in_dir}") diff --git a/oink/src/kernelagent_oink/blackwell/cross_entropy.py b/oink/src/kernelagent_oink/blackwell/cross_entropy.py index 94f052f..3e6eef1 100644 --- a/oink/src/kernelagent_oink/blackwell/cross_entropy.py +++ b/oink/src/kernelagent_oink/blackwell/cross_entropy.py @@ -103,9 +103,8 @@ def _convert_logits_2d(x: Tensor) -> cute.Tensor: softmax and RMSNorm kernels. """ assert x.dim() == 2, "Input logits must be 2D (M, N)" - return ( - from_dlpack(x.detach(), assumed_align=16) - .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1)) + return from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic( + mode=0, stride_order=(0, 1) ) @@ -136,7 +135,11 @@ def _calculate_threads_per_row(self) -> int: else ( 16 if N <= 128 - else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))) + else ( + 32 + if N <= 3072 + else (64 if N <= 6144 else (128 if N <= 16384 else 256)) + ) ) ) @@ -183,7 +186,9 @@ def __call__( num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) num_threads = ( - cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() ) num_warps = num_threads // cute.arch.WARP_SIZE kernel = ( @@ -267,7 +272,9 @@ def _kernel_impl( cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16, ) - reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout + ) # Copy setup: gmem -> smem via cp.async, 128-bit or narrower as needed. 
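# --- Editor's note (illustrative arithmetic only; not part of the patch) ----
# "128-bit or narrower" follows from the copy-width choice made in __call__:
#   num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width
# e.g. for bf16 (16-bit elements), 128 // 16 = 8 elements per full-width copy:
#   N = 4096 -> gcd(4096, 8) = 8 -> 8 * 16 = 128-bit cp.async copies
#   N = 4100 -> gcd(4100, 8) = 4 -> 4 * 16 = 64-bit (narrower) copies
# The concrete N values are editorial examples, not shapes used by the patch.
# ---------------------------------------------------------------------------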
num_copy_elems_X = tv_layout.shape[1][0] @@ -277,7 +284,9 @@ def _kernel_impl( gX.element_type, num_bits_per_copy=num_copy_bits_X, ) - thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx) + thr_copy_X = cute.make_tiled_copy( + copy_atom_load_X, tv_layout, tiler_mn + ).get_slice(tidx) tXgX = thr_copy_X.partition_S(gX) tXsX = thr_copy_X.partition_D(sX) @@ -414,13 +423,21 @@ def _calculate_threads_per_row(self) -> int: else ( 16 if N <= 128 - else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))) + else ( + 32 + if N <= 3072 + else (64 if N <= 6144 else (128 if N <= 16384 else 256)) + ) ) ) - def _get_tv_layout(self, num_copy_bits: int = 128) -> tuple[cute.Shape, cute.Layout]: + def _get_tv_layout( + self, num_copy_bits: int = 128 + ) -> tuple[cute.Shape, cute.Layout]: vecsize = num_copy_bits // self.dtype.width - assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}" + assert self.N % vecsize == 0, ( + f"Input N {self.N} is not divisible by vector size {vecsize}" + ) N = min(self.N, 16384) num_threads = 128 if N <= 16384 else 256 threads_per_row = self._calculate_threads_per_row() @@ -452,7 +469,9 @@ def __call__( num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) num_threads = ( - cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() ) # Broadcast (M,) tensors along the N dimension with stride 0. mDLoss, mTarget, mLSE = [ @@ -564,8 +583,12 @@ def _kernel_impl( gdX.element_type, num_bits_per_copy=num_copy_bits_X, ) - thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx) - thr_copy_dX = cute.make_tiled_copy(copy_atom_store_dX, tv_layout, tiler_mn).get_slice(tidx) + thr_copy_X = cute.make_tiled_copy( + copy_atom_load_X, tv_layout, tiler_mn + ).get_slice(tidx) + thr_copy_dX = cute.make_tiled_copy( + copy_atom_store_dX, tv_layout, tiler_mn + ).get_slice(tidx) tXgX = thr_copy_X.partition_S(gX) tXsX = thr_copy_X.partition_D(sX) @@ -898,8 +921,14 @@ def _cross_entropy_forward_ptr_into( assert logits.is_cuda and logits.dim() == 2 assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0] assert target.dtype is torch.int64 - assert loss.is_cuda and loss.shape == (logits.shape[0],) and loss.dtype is torch.float32 - assert lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32 + assert ( + loss.is_cuda + and loss.shape == (logits.shape[0],) + and loss.dtype is torch.float32 + ) + assert ( + lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32 + ) M, N = logits.shape device_index = logits.get_device() @@ -991,10 +1020,18 @@ def _cross_entropy_backward_ptr_into( assert logits.is_cuda and logits.dim() == 2 assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0] assert target.dtype is torch.int64 - assert dloss.is_cuda and dloss.shape == (logits.shape[0],) and dloss.dtype is torch.float32 - assert lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32 + assert ( + dloss.is_cuda + and dloss.shape == (logits.shape[0],) + and dloss.dtype is torch.float32 + ) + assert ( + lse.is_cuda and lse.shape == (logits.shape[0],) and lse.dtype is torch.float32 + ) assert dx.is_cuda and dx.shape == logits.shape and dx.dtype == logits.dtype - 
assert dx.stride() == logits.stride(), "Pointer path expects dx to match logits strides" + assert dx.stride() == logits.stride(), ( + "Pointer path expects dx to match logits strides" + ) M, N = logits.shape device_index = logits.get_device() @@ -1060,7 +1097,9 @@ def _cross_entropy_backward_ptr_into( mem_space=rt.AddressSpace.gmem, assumed_align=4, ) - ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) ptr_lse = rt.make_ptr( cutlass.Float32, lse.data_ptr(), @@ -1169,7 +1208,9 @@ def verify_cross_entropy_parity( mask = torch.rand(M, device=device) < 0.1 target[mask] = ignore_index - loss, lse = cross_entropy_forward(logits, target, ignore_index=ignore_index, reduction="none") + loss, lse = cross_entropy_forward( + logits, target, ignore_index=ignore_index, reduction="none" + ) logits_ref = logits.detach().clone().requires_grad_() target_ref = target.detach().clone() diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py index 0e4d640..67f67ce 100644 --- a/oink/src/kernelagent_oink/blackwell/layernorm.py +++ b/oink/src/kernelagent_oink/blackwell/layernorm.py @@ -138,7 +138,11 @@ def _calculate_threads_per_row(self) -> int: else ( 16 if N <= 128 - else (32 if N <= 3072 else (64 if N <= 6144 else (128 if N <= 16384 else 256))) + else ( + 32 + if N <= 3072 + else (64 if N <= 6144 else (128 if N <= 16384 else 256)) + ) ) ) @@ -186,7 +190,9 @@ def __call__( self._set_cluster_n() tiler_mn, tv_layout = self._get_tv_layout() num_threads = ( - cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() ) num_warps = num_threads // cute.arch.WARP_SIZE @@ -281,7 +287,9 @@ def launch_from_ptrs( mX = cute.make_tensor(ptr_x, layout_mn) mO = cute.make_tensor(ptr_out, layout_mn) mW = cute.make_tensor(ptr_w, layout_n) - mB = cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None + mB = ( + cute.make_tensor(ptr_b, layout_n) if const_expr(ptr_b is not None) else None + ) mRstd = ( cute.make_tensor(ptr_rstd, layout_m) if const_expr(ptr_rstd is not None) @@ -323,15 +331,15 @@ def _kernel_impl( cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16, ) - reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout + ) shape = mX.shape idX = cute.make_identity_tensor(shape) # Slice for CTAs: use domain_offset_i64 to handle >2^31 elements. 
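# --- Editor's note (illustrative arithmetic only; not part of the patch) ----
# Why 64-bit offsets: the per-CTA base offset is roughly bidx * tiler_mn[0]
# rows times the row stride (about N elements). With, say, M = N = 65536 the
# tensor holds 65536 * 65536 = 2^32 elements, past the signed 32-bit limit of
# 2^31 - 1, so the slice offsets are computed in Int64 via domain_offset_i64.
# The example shape is editorial, not one exercised by the patch.
# ---------------------------------------------------------------------------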
- mX, mO = [ - domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO) - ] + mX, mO = [domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)] gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)] cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) @@ -390,39 +398,23 @@ def _kernel_impl( ).get_slice(tidx) tWgW = thr_copy_WB.partition_S(gW) - tBgB = ( - thr_copy_WB.partition_S(gB) - if const_expr(gB is not None) - else None - ) + tBgB = thr_copy_WB.partition_S(gB) if const_expr(gB is not None) else None tXgX = thr_copy_X.partition_S(gX) tXsX = thr_copy_X.partition_D(sX) tXgO = thr_copy_O.partition_D(gO) tXrRstd = ( - thr_copy_O.partition_D(gRstd) - if const_expr(mRstd is not None) - else None + thr_copy_O.partition_D(gRstd) if const_expr(mRstd is not None) else None ) tXrMean = ( - thr_copy_O.partition_D(gMean) - if const_expr(mMean is not None) - else None + thr_copy_O.partition_D(gMean) if const_expr(mMean is not None) else None ) tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] # Fragments for gmem->rmem. tWrW = cute.make_fragment_like(tWgW) - tBrB = ( - cute.make_fragment_like(tBgB) - if const_expr(mB is not None) - else None - ) + tBrB = cute.make_fragment_like(tBgB) if const_expr(mB is not None) else None tXrW = thr_copy_X.retile(tWrW) - tXrB = ( - thr_copy_X.retile(tBrB) - if const_expr(mB is not None) - else None - ) + tXrB = thr_copy_X.retile(tBrB) if const_expr(mB is not None) else None tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)] num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE @@ -458,9 +450,7 @@ def _kernel_impl( mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None, init_val=0.0, hook_fn=( - cute.arch.cluster_wait - if const_expr(self.cluster_n > 1) - else None + cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None ), ) mean = sum_x / shape[1] @@ -486,10 +476,7 @@ def _kernel_impl( if ( tXcX[0][1] == 0 and row < shape[0] - and ( - self.cluster_n == 1 - or cute.arch.block_idx_in_cluster() == 0 - ) + and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) ): tXrRstd[0] = rstd @@ -497,10 +484,7 @@ def _kernel_impl( if ( tXcX[0][1] == 0 and row < shape[0] - and ( - self.cluster_n == 1 - or cute.arch.block_idx_in_cluster() == 0 - ) + and (self.cluster_n == 1 or cute.arch.block_idx_in_cluster() == 0) ): tXrMean[0] = mean @@ -861,7 +845,9 @@ def _layernorm_forward_ptr_into( ) _PTR_COMPILE_CACHE[key] = compiled - ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_x = rt.make_ptr( + dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) ptr_out = rt.make_ptr( dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 ) @@ -978,8 +964,12 @@ def _layernorm_backward_dx_kernel( smem = cutlass.utils.SmemAllocator() num_warps = const_expr(block_threads // cute.arch.WARP_SIZE) warp_sums_layout = cute.make_layout((num_warps,), stride=(1,)) - warp_sums_wdy = smem.allocate_tensor(Float32, warp_sums_layout, byte_alignment=4) - warp_sums_xhatwdy = smem.allocate_tensor(Float32, warp_sums_layout, byte_alignment=4) + warp_sums_wdy = smem.allocate_tensor( + Float32, warp_sums_layout, byte_alignment=4 + ) + warp_sums_xhatwdy = smem.allocate_tensor( + Float32, warp_sums_layout, byte_alignment=4 + ) lane = cute.arch.lane_idx() warp_idx = cute.arch.warp_idx() @@ -1177,8 +1167,12 @@ def _layernorm_backward_dx_sm100( alignment=16, divisibility=128 // 
cutlass.Float32.width, ) - mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) - mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) key = (N, dtype) @@ -1231,8 +1225,12 @@ def _layernorm_backward_params_sm100( mX = _convert_row_major(x_2d) mdO = _convert_row_major(dout_2d) - mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) - mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + mRstd = from_dlpack(rstd_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) + mMean = from_dlpack(mean_1d.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) mdW_partial = ( from_dlpack(dw_partial, assumed_align=16).mark_compact_shape_dynamic(mode=0) diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py index 590d773..e8ce93a 100644 --- a/oink/src/kernelagent_oink/blackwell/lite_quack.py +++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py @@ -85,6 +85,7 @@ def _cutlass_dsl_version() -> Optional[tuple[int, int, int]]: # Tensor conversion helpers (from quack.utils) # ------------------------- + def convert_from_dlpack( x: Tensor, leading_dim: int, @@ -108,7 +109,9 @@ def convert_from_dlpack( @dsl_user_op -def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: +def elem_pointer( + x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None +) -> cute.Pointer: return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) @@ -159,7 +162,9 @@ def store_shared_remote( ).ir_value() if const_expr(isinstance(val, float)): val = Float32(val) - assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64" + assert isinstance(val, (Float32, Int32, cutlass.Int64)), ( + "val must be Float32, Int32, or Int64" + ) suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)] constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)] llvm.inline_asm( @@ -178,19 +183,27 @@ def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if". 
tApA = cute.make_fragment( cute.make_layout( - (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + ( + cute.size(tAcA, mode=[0, 1]), + cute.size(tAcA, mode=[1]), + cute.size(tAcA, mode=[2]), + ), stride=(cute.size(tAcA, mode=[2]), 0, 1), ), cutlass.Boolean, ) for rest_v in cutlass.range_constexpr(tApA.shape[0]): for rest_k in cutlass.range_constexpr(tApA.shape[2]): - tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + tApA[rest_v, 0, rest_k] = cute.elem_less( + tAcA[(0, rest_v), 0, rest_k][1], limit + ) return tApA @dsl_user_op -def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: +def domain_offset_i64( + coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None +) -> cute.Tensor: flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) flat_stride = cute.flatten_to_tuple(tensor.stride) assert len(flat_coord_i64) == len(flat_stride), ( @@ -228,7 +241,9 @@ def coord_offset_i64( @cute.jit -def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cutlass.Numeric) -> None: +def fill_oob( + tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cutlass.Numeric +) -> None: """Fill out-of-bounds values in shared memory tensor.""" tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0]) tXrX_fill.fill(fill_value) @@ -256,7 +271,9 @@ def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64: ) vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2, loc=loc, ip=ip) res = cutlass.Int64( - vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip) + vector.extract( + vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip + ) ) return res @@ -272,10 +289,14 @@ def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float ) vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1, loc=loc, ip=ip) res0 = Float32( - vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip) + vector.extract( + vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip + ) ) res1 = Float32( - vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip) + vector.extract( + vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip + ) ) return res0, res1 @@ -372,7 +393,9 @@ def block_or_cluster_reduce( """Perform either block or cluster reduction based on whether mbar_ptr is provided.""" if cutlass.const_expr(mbar_ptr is None): return block_reduce(val, op, reduction_buffer, init_val=init_val) - return cluster_reduce(val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase) + return cluster_reduce( + val, op, reduction_buffer, mbar_ptr, init_val=init_val, phase=phase + ) @cute.jit @@ -393,7 +416,9 @@ def row_reduce( val = x warp_op = { cute.ReductionOp.ADD: operator.add, - cute.ReductionOp.MAX: cute.arch.fmax if cutlass.const_expr(x.dtype == Float32) else max, + cute.ReductionOp.MAX: cute.arch.fmax + if cutlass.const_expr(x.dtype == Float32) + else max, cute.ReductionOp.MIN: min, cute.ReductionOp.MUL: operator.mul, }[op] @@ -521,7 +546,9 @@ def online_softmax_reduce( reduction_buffer[row_idx, lane_idx] ) max_x_final = warp_reduce(max_x_single_warp, cute.arch.fmax) - sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True) + sum_exp_x *= cute.math.exp( + max_x_single_warp - max_x_final, fastmath=True + ) sum_exp_x = warp_reduce(sum_exp_x, operator.add) if cutlass.const_expr(return_exp_x): exp_x 
*= cute.math.exp(max_x - max_x_final, fastmath=True) @@ -533,16 +560,23 @@ def online_softmax_reduce( num_warps = rows_per_block * warps_per_row cute.arch.mbarrier_arrive_and_expect_tx( mbar_ptr, - num_warps * cluster_n * reduction_buffer.element_type.width // 8, + num_warps + * cluster_n + * reduction_buffer.element_type.width + // 8, ) if lane_idx < cluster_n: store_shared_remote( f32x2_to_i64(max_x, sum_exp_x), - elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))), + elem_pointer( + reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster)) + ), mbar_ptr, peer_cta_rank_in_cluster=lane_idx, ) - cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0) + cute.arch.mbarrier_wait( + mbar_ptr, phase=phase if phase is not None else 0 + ) num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE) max_x_single_warp = cute.make_fragment(num_iter, Float32) max_x_single_warp.fill(-Float32.inf) @@ -591,7 +625,9 @@ def get_copy_atom( num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() - return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits, loc=loc, ip=ip) + return cute.make_copy_atom( + copy_op, dtype, num_bits_per_copy=num_copy_bits, loc=loc, ip=ip + ) @dsl_user_op @@ -606,7 +642,9 @@ def copy( ip=None, **kwargs, ) -> None: - copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async, loc=loc, ip=ip) + copy_atom = get_copy_atom( + src.element_type, num_copy_elems, is_async, loc=loc, ip=ip + ) cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) @@ -637,15 +675,21 @@ def _set_cluster_n(self) -> None: def _get_num_threads(self) -> int: return 128 if self.N <= 16384 else 256 - def _get_tv_layout(self, num_copy_bits: int = 128) -> Tuple[cute.Shape, cute.Layout]: + def _get_tv_layout( + self, num_copy_bits: int = 128 + ) -> Tuple[cute.Shape, cute.Layout]: vecsize = num_copy_bits // self.dtype.width - assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}" + assert self.N % vecsize == 0, ( + f"Input N {self.N} is not divisible by vector size {vecsize}" + ) num_threads = self._get_num_threads() assert num_threads % cute.arch.WARP_SIZE == 0 threads_per_row = self._calculate_threads_per_row() self._set_cluster_n() - num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n) + num_blocks_N = cute.ceil_div( + self.N // vecsize, threads_per_row * self.cluster_n + ) cols_per_block = num_threads // threads_per_row tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) tv_layout = cute.make_layout( @@ -660,11 +704,16 @@ def _get_tv_layout(self, num_copy_bits: int = 128) -> Tuple[cute.Shape, cute.Lay def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int: return ( cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) - + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8) + + self.stage + * num_warps + * self.cluster_n + * (self.reduction_dtype.width // 8) + self.stage * (cutlass.Int64.width // 8) ) - def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int) -> cute.Layout: + def _get_reduction_buffer_layout( + self, tv_layout: cute.Layout, cluster_n: int + ) -> cute.Layout: num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1) return cute.make_ordered_layout( @@ -723,7 +772,9 @@ def __init__(self, dtype: 
cutlass.Numeric, N: int): super().__init__(dtype, N, stage=2, reduction_dtype=Float32) self.reload_wdy = None if N <= 16 * 1024 else "smem" if self.N > 128 * 1024 and self.dtype.width >= 32: - raise ValueError("RMSNormBackward does not support N > 128k with dtype >= 32 bits") + raise ValueError( + "RMSNormBackward does not support N > 128k with dtype >= 32 bits" + ) def _get_num_threads(self) -> int: return 128 if self.N <= 4096 else 256 @@ -736,7 +787,11 @@ def _calculate_threads_per_row(self) -> int: else ( 16 if N <= 128 - else (32 if N <= 256 else (64 if N <= 512 else (128 if N <= 4096 else 256))) + else ( + 32 + if N <= 256 + else (64 if N <= 512 else (128 if N <= 4096 else 256)) + ) ) ) @@ -745,7 +800,11 @@ def _set_cluster_n(self) -> None: cluster_n = ( 1 if N <= 8 * 1024 - else (2 if N <= 16 * 1024 else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16))) + else ( + 2 + if N <= 16 * 1024 + else (4 if N <= 32 * 1024 else (8 if N <= 64 * 1024 else 16)) + ) ) self.cluster_n = cluster_n @@ -755,7 +814,10 @@ def _smem_size_in_bytes(self, tiler_mn, num_warps: int, do_dtype=None) -> int: return ( cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 + cute.size_in_bytes(do_dtype, cute.make_layout(tiler_mn)) * 2 - + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8) + + self.stage + * num_warps + * self.cluster_n + * (self.reduction_dtype.width // 8) + self.stage * (cutlass.Int64.width // 8) * 2 ) @@ -783,7 +845,9 @@ def new_stride(t): ) mX, mdO, mdResO, mdX, mdRes = [ - cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))) + cute.make_tensor( + t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)) + ) if const_expr(t is not None) else None for t in (mX, mdO, mdResO, mdX, mdRes) @@ -802,7 +866,9 @@ def new_stride(t): num_copy_bits=128 // largest_dtype_width * mX.element_type.width ) num_threads = ( - cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() ) num_warps = num_threads // cute.arch.WARP_SIZE if const_expr(mW is not None): @@ -814,7 +880,9 @@ def new_stride(t): num_blocks = sm_count kernel = ( - self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn) + self.kernel( + mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn + ) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes) ) @@ -822,7 +890,9 @@ def new_stride(t): grid=[num_blocks, self.cluster_n, 1], block=[num_threads, 1, 1], cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None, - smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type), + smem=self._smem_size_in_bytes( + tiler_mn, num_warps, do_dtype=mdO.element_type + ), stream=stream, ) @@ -856,7 +926,9 @@ def _kernel_impl( idX = cute.make_identity_tensor(shape) smem = cutlass.utils.SmemAllocator() - smem_layout = cute.make_ordered_layout((tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2)) + smem_layout = cute.make_ordered_layout( + (tiler_mn[0], tiler_mn[1], 2), order=(1, 0, 2) + ) sX = smem.allocate_tensor(mX.element_type, smem_layout, byte_alignment=16) sdO = smem.allocate_tensor(mdO.element_type, smem_layout, byte_alignment=16) reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( @@ -870,8 +942,12 @@ def _kernel_impl( mbar_full_ptr, mbar_empty_ptr = None, None num_copy_elems_X = tv_layout.shape[1][0] - copy_atom_load_X = 
get_copy_atom(mX.element_type, num_copy_elems_X, is_async=False) - thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, tv_layout, tiler_mn).get_slice(tidx) + copy_atom_load_X = get_copy_atom( + mX.element_type, num_copy_elems_X, is_async=False + ) + thr_copy_X = cute.make_tiled_copy( + copy_atom_load_X, tv_layout, tiler_mn + ).get_slice(tidx) copy_fn = partial(copy, num_copy_elems=num_copy_elems_X) gX, gdO, gdResO, gdX, gdRes, cX = [ @@ -898,7 +974,8 @@ def _kernel_impl( tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None, None] tXrX, tXrdO, tXrdX = [ - cute.make_fragment_like(thr[None, None, None, 0]) for thr in (tXgX, tXgdO, tXgdX) + cute.make_fragment_like(thr[None, None, None, 0]) + for thr in (tXgX, tXgdO, tXgdX) ] tXrdResO = None if const_expr(mdResO is not None): @@ -959,10 +1036,24 @@ def _kernel_impl( for bidx in cutlass.range(bidx_start, cute.ceil_div(M, tiler_mn[0]), gdim): row = tXcX[None, None, None, bidx][0][0] if row + gdim * tiler_mn[0] < M: - tXgX_cur = coord_offset_i64(bidx + gdim, tXgX, dim=3)[None, None, None, 0] - tXgdO_cur = coord_offset_i64(bidx + gdim, tXgdO, dim=3)[None, None, None, 0] - copy_fn(tXgX_cur, tXsX[None, None, None, stage ^ 1], pred=tXpX, is_async=True) - copy_fn(tXgdO_cur, tXsdO[None, None, None, stage ^ 1], pred=tXpX, is_async=True) + tXgX_cur = coord_offset_i64(bidx + gdim, tXgX, dim=3)[ + None, None, None, 0 + ] + tXgdO_cur = coord_offset_i64(bidx + gdim, tXgdO, dim=3)[ + None, None, None, 0 + ] + copy_fn( + tXgX_cur, + tXsX[None, None, None, stage ^ 1], + pred=tXpX, + is_async=True, + ) + copy_fn( + tXgdO_cur, + tXsdO[None, None, None, stage ^ 1], + pred=tXpX, + is_async=True, + ) elif tiler_mn[0] > 1: fill_oob( tXsX[None, None, None, stage ^ 1], @@ -979,7 +1070,9 @@ def _kernel_impl( if row < M or tiler_mn[0] == 1: rstd_val = mRstd[row] if const_expr(mdResO is not None): - tXgdResO_cur = coord_offset_i64(bidx, tXgdResO, dim=3)[None, None, None, 0] + tXgdResO_cur = coord_offset_i64(bidx, tXgdResO, dim=3)[ + None, None, None, 0 + ] if row < M or tiler_mn[0] == 1: copy_fn(tXgdResO_cur, tXrdResO, pred=tXpX) elif tiler_mn[0] > 1: @@ -1036,7 +1129,9 @@ def _kernel_impl( copy_fn(tXrdX, tXgdX_cur, pred=tXpX) if const_expr(mdRes is not None): tXrdRes.store(dx.to(tXrdRes.element_type)) - tXgdRes_cur = coord_offset_i64(bidx, tXgdRes, dim=3)[None, None, None, 0] + tXgdRes_cur = coord_offset_i64(bidx, tXgdRes, dim=3)[ + None, None, None, 0 + ] if row < M or tiler_mn[0] == 1: copy_fn(tXrdRes, tXgdRes_cur, pred=tXpX) if const_expr(mdW is not None): @@ -1204,7 +1299,9 @@ def get_sm_count( num_sms = props.multi_processor_count sm_count_multiple = ( - 16 if N <= 256 else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) + 16 + if N <= 256 + else (8 if N <= 1024 else (4 if N <= 2048 else (2 if N <= 4096 else 1))) ) sm_count = num_sms if N <= 8192: diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py index 9df9f16..e921947 100644 --- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py @@ -121,7 +121,9 @@ def _env_flag(name: str, default: bool) -> bool: # - If you want to force stage-2 even when the pointer path is available (for # experimentation / A-B testing), set this env var **before** importing this # module. 
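# --- Editor's note (illustrative sketch; not part of the patch) -------------
# Forcing stage-2 for an A/B run, per the comment above: the flag is read at
# import time, so it must be set before this module is imported. Treating "1"
# as a truthy value is an assumption based on the plugin's other env flags.
import os

os.environ["KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2"] = "1"

# The flag is sampled while this import executes.
from kernelagent_oink.blackwell import rmsnorm  # noqa: E402
# ---------------------------------------------------------------------------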
-_FORCE_RMSNORM_STAGE2_FWD = _env_flag("KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2", default=False) +_FORCE_RMSNORM_STAGE2_FWD = _env_flag( + "KERNELAGENT_OINK_FORCE_RMSNORM_STAGE2", default=False +) # CuTeDSL stability probe for the experimental cluster_n>1 + direct-GMEM schedule. # @@ -2771,7 +2773,9 @@ def rmsnorm_forward( # Preserve stride contracts for torch.compile consistency, even # when using the optional stage-2 implementation. if y.stride() != x.stride(): - y_strided = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + y_strided = torch.empty_strided( + x.shape, x.stride(), device=x.device, dtype=x.dtype + ) y_strided.copy_(y) y = y_strided if residual is not None and residual_out is not None: @@ -3036,7 +3040,9 @@ def new_stride(t): ) mX, mdO, mdResO, mdX, mdRes = [ - cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))) + cute.make_tensor( + t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)) + ) if const_expr(t is not None) else None for t in (mX, mdO, mdResO, mdX, mdRes) @@ -3056,7 +3062,9 @@ def new_stride(t): num_copy_bits=128 // largest_dtype_width * mX.element_type.width ) num_threads = ( - cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() ) num_warps = num_threads // cute.arch.WARP_SIZE if const_expr(mW is not None): @@ -3067,7 +3075,9 @@ def new_stride(t): num_blocks = sm_count kernel = ( - self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn) + self.kernel( + mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes, tv_layout, tiler_mn + ) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self.kernel(mX, mW, mdO, mdResO, mRstd, mdX, mdW, mdB, mdRes) ) @@ -3075,7 +3085,9 @@ def new_stride(t): grid=[num_blocks, self.cluster_n, 1], block=[num_threads, 1, 1], cluster=[1, self.cluster_n, 1] if self.cluster_n > 1 else None, - smem=self._smem_size_in_bytes(tiler_mn, num_warps, do_dtype=mdO.element_type), + smem=self._smem_size_in_bytes( + tiler_mn, num_warps, do_dtype=mdO.element_type + ), stream=stream, ) @@ -3160,8 +3172,8 @@ def _convert_mx(t: Tensor) -> cute.Tensor: if db_partial is not None else None ) - rstd_tensor = ( - from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + rstd_tensor = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 ) current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) @@ -3261,7 +3273,9 @@ def rmsnorm_backward( else: dw_partial = None db_partial = ( - torch.empty(sm_count, N, device=device, dtype=torch.float32) if has_bias else None + torch.empty(sm_count, N, device=device, dtype=torch.float32) + if has_bias + else None ) _rmsnorm_bwd_sm100( diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py index fec5bf4..2b5b36d 100644 --- a/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm_with_stage2.py @@ -101,7 +101,9 @@ def copy_tiled( class RMSNormSM100WithStage2: - def __init__(self, N: int, dtype: type[cutlass.Numeric], stage: Optional[int] = None): + def __init__( + self, N: int, dtype: type[cutlass.Numeric], stage: Optional[int] = None + ): self.N = N self.dtype = dtype self.stage = 1 if stage is None else stage @@ -172,7 +174,10 @@ def _tv_layout(self, num_copy_bits: int = 256) -> Tuple[cute.Shape, cute.Layout] tiler_mn = 
(cols_per_block, vecsize * num_blocks_N * tpr) tv_layout = cute.make_layout( ((tpr, cols_per_block), (vecsize, num_blocks_N)), - stride=((vecsize * cols_per_block, 1), (cols_per_block, cols_per_block * vecsize * tpr)), + stride=( + (vecsize * cols_per_block, 1), + (cols_per_block, cols_per_block * vecsize * tpr), + ), ) return tiler_mn, tv_layout @@ -198,7 +203,9 @@ def new_stride(t): ) mX, mRes, mO, mResO = [ - cute.make_tensor(t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t))) + cute.make_tensor( + t.iterator, cute.make_layout(semistatic_shape, stride=new_stride(t)) + ) if const_expr(t is not None) else None for t in (mX, mRes, mO, mResO) @@ -209,36 +216,48 @@ def new_stride(t): copy_bits = const_expr(128) tiler_mn, tv_layout = self._tv_layout(num_copy_bits=copy_bits) num_threads = ( - cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._num_threads() + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._num_threads() ) num_warps = num_threads // cute.arch.WARP_SIZE threads_per_row = ( - tv_layout.shape[0][0] if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._threads_per_row() + tv_layout.shape[0][0] + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._threads_per_row() ) warps_per_row = max(threads_per_row // cute.arch.WARP_SIZE, 1) cluster_n = self._cluster_n() if const_expr(mW is not None): mW = cute.make_tensor( - mW.iterator, cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))) + mW.iterator, + cute.prepend(mW.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), ) if const_expr(mB is not None): mB = cute.make_tensor( - mB.iterator, cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))) + mB.iterator, + cute.prepend(mB.layout, cute.make_layout((tiler_mn[0],), stride=(0,))), ) if const_expr(mRstd is not None): mRstd = cute.make_tensor( - mRstd.iterator, cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))) + mRstd.iterator, + cute.append(mRstd.layout, cute.make_layout((self.N,), stride=(0,))), ) stage_bufs = 2 if self.stage > 1 else 1 - tile_bytes_x = cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs + tile_bytes_x = ( + cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * stage_bufs + ) tile_bytes_res = ( - cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) * stage_bufs + cute.size_in_bytes(mRes.element_type, cute.make_layout(tiler_mn)) + * stage_bufs if const_expr(mRes is not None) else 0 ) - red_bytes = self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8) + red_bytes = ( + self.stage * num_warps * cluster_n * (self.reduction_dtype.width // 8) + ) mbar_bytes = self.stage * (cutlass.Int64.width // 8) smem_bytes = tile_bytes_x + tile_bytes_res + red_bytes + mbar_bytes @@ -299,11 +318,15 @@ def _kernel_impl( tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() cluster_n = self._cluster_n() - cluster_y = const_expr(0) if const_expr(cluster_n == 1) else cute.arch.block_idx()[1] + cluster_y = ( + const_expr(0) if const_expr(cluster_n == 1) else cute.arch.block_idx()[1] + ) smem = cutlass.utils.SmemAllocator() sX0 = smem.allocate_tensor( - mX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32 + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, ) sX1 = ( smem.allocate_tensor( @@ -316,7 +339,9 @@ def _kernel_impl( ) sRes0 = ( smem.allocate_tensor( - mRes.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=32 + 
mRes.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=32, ) if const_expr(mRes is not None) else None @@ -331,18 +356,24 @@ def _kernel_impl( else None ) - reduction_buffer, mbar_ptr = self._alloc_reduction_and_mbar(smem, num_warps, warps_per_row) + reduction_buffer, mbar_ptr = self._alloc_reduction_and_mbar( + smem, num_warps, warps_per_row + ) shape = mX.shape idX = cute.make_identity_tensor(shape) num_copy_elems_X = tv_layout.shape[1][0] use_async = const_expr(self.N >= 1024) - copy_atom = get_copy_atom_bw(mX.element_type, num_copy_elems_X, is_async=use_async) + copy_atom = get_copy_atom_bw( + mX.element_type, num_copy_elems_X, is_async=use_async + ) thr_copy = cute.make_tiled_copy(copy_atom, tv_layout, tiler_mn).get_slice(tidx) gW, gB = [ - cute.local_tile(t, tiler_mn, (0, cluster_y)) if const_expr(t is not None) else None + cute.local_tile(t, tiler_mn, (0, cluster_y)) + if const_expr(t is not None) + else None for t in (mW, mB) ] tXgW = thr_copy.partition_S(gW) if const_expr(mW is not None) else None @@ -350,32 +381,50 @@ def _kernel_impl( tXrW = cute.make_fragment_like(tXgW) if const_expr(mW is not None) else None tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None if const_expr(mW is not None): - cute.copy(get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False), tXgW, tXrW) + cute.copy( + get_copy_atom_bw(mW.element_type, num_copy_elems_X, is_async=False), + tXgW, + tXrW, + ) if const_expr(mB is not None): - cute.copy(get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False), tXgB, tXrB) + cute.copy( + get_copy_atom_bw(mB.element_type, num_copy_elems_X, is_async=False), + tXgB, + tXrB, + ) self._init_cluster(tidx, mbar_ptr) mX_i, mRes_i, mO_i, mResO_i = [ - qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) if t is not None else None + qutils.domain_offset_i64((bidx * tiler_mn[0], 0), t) + if t is not None + else None for t in (mX, mRes, mO, mResO) ] gX_i = cute.local_tile(mX_i, tiler_mn, (0, cluster_y)) gO_i = cute.local_tile(mO_i, tiler_mn, (0, cluster_y)) gRes_i = ( - cute.local_tile(mRes_i, tiler_mn, (0, cluster_y)) if const_expr(mRes is not None) else None + cute.local_tile(mRes_i, tiler_mn, (0, cluster_y)) + if const_expr(mRes is not None) + else None ) gResO_i = ( - cute.local_tile(mResO_i, tiler_mn, (0, cluster_y)) if const_expr(mResO is not None) else None + cute.local_tile(mResO_i, tiler_mn, (0, cluster_y)) + if const_expr(mResO is not None) + else None ) gRstd_i = ( - cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y)) if const_expr(mRstd is not None) else None + cute.local_tile(mRstd, tiler_mn, (bidx, cluster_y)) + if const_expr(mRstd is not None) + else None ) cX_i = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) tXcX_i = thr_copy.partition_S(cX_i)[(0, None), None, None] row_i = tXcX_i[0][0] - tXgRstd_i = thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None + tXgRstd_i = ( + thr_copy.partition_D(gRstd_i) if const_expr(mRstd is not None) else None + ) # Intra-row K-loop cp.async ping-pong (two-pass) for N≈6k/8k (stage=2) if const_expr(self.stage > 1 and (shape[1] == 6144 or shape[1] == 8192)): @@ -388,37 +437,72 @@ def _kernel_impl( tiler_mn_tile = (tiler_mn[0], tile_n) sX0_tile = cute.local_tile(sX0, tiler_mn_tile, (0, 0)) - sX1_tile = cute.local_tile(sX1, tiler_mn_tile, (0, 0)) if const_expr(self.stage > 1) else None + sX1_tile = ( + cute.local_tile(sX1, tiler_mn_tile, (0, 0)) + if const_expr(self.stage > 1) + else None + ) sRes0_tile = ( - cute.local_tile(sRes0, 
tiler_mn_tile, (0, 0)) if const_expr(mRes is not None) else None + cute.local_tile(sRes0, tiler_mn_tile, (0, 0)) + if const_expr(mRes is not None) + else None ) sRes1_tile = ( - cute.local_tile(sRes1, tiler_mn_tile, (0, 0)) if const_expr(mRes is not None and self.stage > 1) else None + cute.local_tile(sRes1, tiler_mn_tile, (0, 0)) + if const_expr(mRes is not None and self.stage > 1) + else None ) tv_layout_tile = cute.make_layout( ((tpr, tiler_mn[0]), (vecsize, tile_factor)), - stride=((vecsize * tiler_mn[0], 1), (tiler_mn[0], tiler_mn[0] * vecsize * tpr)), + stride=( + (vecsize * tiler_mn[0], 1), + (tiler_mn[0], tiler_mn[0] * vecsize * tpr), + ), ) - thr_copy_tile = cute.make_tiled_copy(copy_atom, tv_layout_tile, tiler_mn_tile).get_slice(tidx) + thr_copy_tile = cute.make_tiled_copy( + copy_atom, tv_layout_tile, tiler_mn_tile + ).get_slice(tidx) sum_sq_acc = cute.Float32(0.0) k_off0 = const_expr(0) * tile_n - gX_0 = cute.local_tile(qutils.domain_offset_i64((0, k_off0), mX_i), tiler_mn_tile, (0, cluster_y)) + gX_0 = cute.local_tile( + qutils.domain_offset_i64((0, k_off0), mX_i), + tiler_mn_tile, + (0, cluster_y), + ) tXgX_0 = thr_copy_tile.partition_S(gX_0) tXsX_0 = thr_copy_tile.partition_D(sX0_tile) - cX_0 = cute.local_tile(cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, cluster_y)) + cX_0 = cute.local_tile( + cute.domain_offset((0, k_off0), cX_i), tiler_mn_tile, (0, cluster_y) + ) tXc_0 = thr_copy_tile.partition_S(cX_0) tXp_0 = qutils.predicate_k(tXc_0, limit=shape[1]) tXp_ping = tXp_0 tXp_pong = tXp_0 if row_i < shape[0]: - copy_tiled(tXgX_0, tXsX_0, num_copy_elems=vecsize, is_async=use_async, pred=tXp_0) + copy_tiled( + tXgX_0, + tXsX_0, + num_copy_elems=vecsize, + is_async=use_async, + pred=tXp_0, + ) if const_expr(mRes is not None): - gRes_0 = cute.local_tile(qutils.domain_offset_i64((0, k_off0), mRes_i), tiler_mn_tile, (0, cluster_y)) + gRes_0 = cute.local_tile( + qutils.domain_offset_i64((0, k_off0), mRes_i), + tiler_mn_tile, + (0, cluster_y), + ) tXgRes_0 = thr_copy_tile.partition_S(gRes_0) tXsRes_0 = thr_copy_tile.partition_D(sRes0_tile) - copy_tiled(tXgRes_0, tXsRes_0, num_copy_elems=vecsize, is_async=use_async, pred=tXp_0) + copy_tiled( + tXgRes_0, + tXsRes_0, + num_copy_elems=vecsize, + is_async=use_async, + pred=tXp_0, + ) if const_expr(use_async): cute.arch.cp_async_commit_group() @@ -426,29 +510,57 @@ def _kernel_impl( next_t = t + 1 if next_t < num_tiles: k_off_n = next_t * tile_n - gX_n = cute.local_tile(qutils.domain_offset_i64((0, k_off_n), mX_i), tiler_mn_tile, (0, cluster_y)) + gX_n = cute.local_tile( + qutils.domain_offset_i64((0, k_off_n), mX_i), + tiler_mn_tile, + (0, cluster_y), + ) tXgX_n = thr_copy_tile.partition_S(gX_n) - cX_n = cute.local_tile(cute.domain_offset((0, k_off_n), cX_i), tiler_mn_tile, (0, cluster_y)) + cX_n = cute.local_tile( + cute.domain_offset((0, k_off_n), cX_i), + tiler_mn_tile, + (0, cluster_y), + ) tXc_n = thr_copy_tile.partition_S(cX_n) tXp_n = qutils.predicate_k(tXc_n, limit=shape[1]) if const_expr((t % 2) == 0): tXsX_n = thr_copy_tile.partition_D(sX1_tile) tXsRes_n = ( - thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None ) tXp_pong = tXp_n else: tXsX_n = thr_copy_tile.partition_D(sX0_tile) tXsRes_n = ( - thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None ) tXp_ping = tXp_n if row_i < shape[0]: - 
copy_tiled(tXgX_n, tXsX_n, num_copy_elems=vecsize, is_async=use_async, pred=tXp_n) + copy_tiled( + tXgX_n, + tXsX_n, + num_copy_elems=vecsize, + is_async=use_async, + pred=tXp_n, + ) if const_expr(mRes is not None): - gRes_n = cute.local_tile(qutils.domain_offset_i64((0, k_off_n), mRes_i), tiler_mn_tile, (0, cluster_y)) + gRes_n = cute.local_tile( + qutils.domain_offset_i64((0, k_off_n), mRes_i), + tiler_mn_tile, + (0, cluster_y), + ) tXgRes_n = thr_copy_tile.partition_S(gRes_n) - copy_tiled(tXgRes_n, tXsRes_n, num_copy_elems=vecsize, is_async=use_async, pred=tXp_n) + copy_tiled( + tXgRes_n, + tXsRes_n, + num_copy_elems=vecsize, + is_async=use_async, + pred=tXp_n, + ) if const_expr(use_async): cute.arch.cp_async_commit_group() if const_expr(use_async): @@ -456,36 +568,62 @@ def _kernel_impl( if const_expr((t % 2) == 0): tXsX_cur = thr_copy_tile.partition_D(sX0_tile) - tXsRes_cur = thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None + ) pred_cur = tXp_ping else: tXsX_cur = thr_copy_tile.partition_D(sX1_tile) - tXsRes_cur = thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + tXsRes_cur = ( + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None + ) pred_cur = tXp_pong qutils.fill_oob(tXsX_cur, pred_cur, mX.element_type.zero) if const_expr(mRes is not None): qutils.fill_oob(tXsRes_cur, pred_cur, mRes.element_type.zero) k_off = t * tile_n - gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, cluster_y)) + gX_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mX_i), + tiler_mn_tile, + (0, cluster_y), + ) tXgX_t = thr_copy_tile.partition_S(gX_t) tXrX = cute.make_fragment_like(tXgX_t) cute.autovec_copy(tXsX_cur, tXrX) x = tXrX.load().to(cute.Float32) if const_expr(mRes is not None): - gRes_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, cluster_y)) + gRes_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mRes_i), + tiler_mn_tile, + (0, cluster_y), + ) tXgRes_t = thr_copy_tile.partition_S(gRes_t) tXrRes = cute.make_fragment_like(tXgRes_t) cute.autovec_copy(tXsRes_cur, tXrRes) x += tXrRes.load().to(cute.Float32) if const_expr(mResO is not None): - gResO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mResO_i), tiler_mn_tile, (0, cluster_y)) + gResO_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mResO_i), + tiler_mn_tile, + (0, cluster_y), + ) tXgResO_t = thr_copy_tile.partition_D(gResO_t) tXrResO = cute.make_fragment_like(tXgResO_t) tXrResO.store(x.to(tXrResO.element_type)) if row_i < shape[0]: - copy_tiled(tXrResO, tXgResO_t, num_copy_elems=vecsize, is_async=False, pred=pred_cur) + copy_tiled( + tXrResO, + tXgResO_t, + num_copy_elems=vecsize, + is_async=False, + pred=pred_cur, + ) sum_sq_tile = row_reduce( x * x, @@ -494,7 +632,9 @@ def _kernel_impl( reduction_buffer[None, None, 0], mbar_ptr, init_val=0.0, - hook_fn=(cute.arch.cluster_wait if const_expr(cluster_n > 1) else None), + hook_fn=( + cute.arch.cluster_wait if const_expr(cluster_n > 1) else None + ), ) sum_sq_acc = sum_sq_acc + sum_sq_tile @@ -509,32 +649,46 @@ def _kernel_impl( for t in cutlass.range_constexpr(num_tiles): k_off = t * tile_n - cX_t = cute.local_tile(cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, cluster_y)) + cX_t = cute.local_tile( + cute.domain_offset((0, k_off), cX_i), tiler_mn_tile, (0, cluster_y) + ) tXc_t = 
thr_copy_tile.partition_S(cX_t) tXp_t = qutils.predicate_k(tXc_t, limit=shape[1]) if const_expr((t % 2) == 0): tXsX_cur = thr_copy_tile.partition_D(sX0_tile) tXsRes_cur = ( - thr_copy_tile.partition_D(sRes0_tile) if const_expr(mRes is not None) else None + thr_copy_tile.partition_D(sRes0_tile) + if const_expr(mRes is not None) + else None ) else: tXsX_cur = thr_copy_tile.partition_D(sX1_tile) tXsRes_cur = ( - thr_copy_tile.partition_D(sRes1_tile) if const_expr(mRes is not None) else None + thr_copy_tile.partition_D(sRes1_tile) + if const_expr(mRes is not None) + else None ) qutils.fill_oob(tXsX_cur, tXp_t, mX.element_type.zero) if const_expr(mRes is not None): qutils.fill_oob(tXsRes_cur, tXp_t, mRes.element_type.zero) - gX_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mX_i), tiler_mn_tile, (0, cluster_y)) + gX_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mX_i), + tiler_mn_tile, + (0, cluster_y), + ) tXgX_t = thr_copy_tile.partition_S(gX_t) tXrX = cute.make_fragment_like(tXgX_t) cute.autovec_copy(tXsX_cur, tXrX) x = tXrX.load().to(cute.Float32) if const_expr(mRes is not None): - gRes_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mRes_i), tiler_mn_tile, (0, cluster_y)) + gRes_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mRes_i), + tiler_mn_tile, + (0, cluster_y), + ) tXgRes_t = thr_copy_tile.partition_S(gRes_t) tXrRes = cute.make_fragment_like(tXgRes_t) cute.autovec_copy(tXsRes_cur, tXrRes) @@ -542,35 +696,67 @@ def _kernel_impl( y = x * rstd if const_expr(mW is not None): - gW_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mW), tiler_mn_tile, (0, cluster_y)) + gW_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mW), + tiler_mn_tile, + (0, cluster_y), + ) tWgW_t = thr_copy_tile.partition_S(gW_t) tWrW_t = cute.make_fragment_like(tWgW_t) - copy_tiled(tWgW_t, tWrW_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + copy_tiled( + tWgW_t, + tWrW_t, + num_copy_elems=vecsize, + is_async=False, + pred=tXp_t, + ) y = y * tWrW_t.load().to(cute.Float32) if const_expr(mB is not None): - gB_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mB), tiler_mn_tile, (0, cluster_y)) + gB_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mB), + tiler_mn_tile, + (0, cluster_y), + ) tWgB_t = thr_copy_tile.partition_S(gB_t) tWrB_t = cute.make_fragment_like(tWgB_t) - copy_tiled(tWgB_t, tWrB_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + copy_tiled( + tWgB_t, + tWrB_t, + num_copy_elems=vecsize, + is_async=False, + pred=tXp_t, + ) y = y + tWrB_t.load().to(cute.Float32) - gO_t = cute.local_tile(qutils.domain_offset_i64((0, k_off), mO_i), tiler_mn_tile, (0, cluster_y)) + gO_t = cute.local_tile( + qutils.domain_offset_i64((0, k_off), mO_i), + tiler_mn_tile, + (0, cluster_y), + ) tXgO_t = thr_copy_tile.partition_D(gO_t) tXrO = cute.make_fragment_like(tXgO_t) tXrO.store(y.to(tXrO.element_type)) if row_i < shape[0]: - copy_tiled(tXrO, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t) + copy_tiled( + tXrO, tXgO_t, num_copy_elems=vecsize, is_async=False, pred=tXp_t + ) return # Fallback: single-stage path identical to current rmsnorm.py tXgX_i = thr_copy.partition_S(gX_i) - tXgRes_i = thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None + tXgRes_i = ( + thr_copy.partition_S(gRes_i) if const_expr(mRes is not None) else None + ) tXgO_i = thr_copy.partition_D(gO_i) - tXgResO_i = thr_copy.partition_D(gResO_i) if const_expr(mResO is not None) else None + tXgResO_i = ( + thr_copy.partition_D(gResO_i) 
if const_expr(mResO is not None) else None + ) is_even_N_i = const_expr(shape[1] == tiler_mn[1] * cluster_n) tXpX_i = ( - qutils.predicate_k(thr_copy.partition_S(cX_i), limit=shape[1]) if not is_even_N_i else None + qutils.predicate_k(thr_copy.partition_S(cX_i), limit=shape[1]) + if not is_even_N_i + else None ) if row_i < shape[0]: @@ -594,7 +780,9 @@ def _kernel_impl( tXrResO.store(x.to(tXrResO.element_type)) if row_i < shape[0]: cute.copy( - get_copy_atom_bw(tXrResO.element_type, num_copy_elems_X, is_async=False), + get_copy_atom_bw( + tXrResO.element_type, num_copy_elems_X, is_async=False + ), tXrResO, tXgResO_i, ) @@ -715,7 +903,9 @@ def _alloc_reduction_and_mbar( (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), order=(1, 0, 2), ) - reduction_buffer = smem.allocate_tensor(self.reduction_dtype, red_layout, byte_alignment=4) + reduction_buffer = smem.allocate_tensor( + self.reduction_dtype, red_layout, byte_alignment=4 + ) if const_expr(cluster_n > 1): mbar_ptr = smem.allocate_array(cutlass.Int64, num_elems=self.stage) else: @@ -745,9 +935,9 @@ def rmsnorm_forward_with_stage2( dtype = TORCH2CUTE_DTYPE[x.dtype] def _convert_x(t: Tensor) -> cute.Tensor: - return from_dlpack( - t.detach(), assumed_align=32 - ).mark_layout_dynamic(leading_dim=1) + return from_dlpack(t.detach(), assumed_align=32).mark_layout_dynamic( + leading_dim=1 + ) mX = _convert_x(x) mRes = _convert_x(residual) if residual is not None else None @@ -755,7 +945,9 @@ def _convert_x(t: Tensor) -> cute.Tensor: mO = from_dlpack(out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1) mW = ( - from_dlpack(weight.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=0) + from_dlpack(weight.detach(), assumed_align=32).mark_layout_dynamic( + leading_dim=0 + ) if weight is not None else None ) @@ -766,7 +958,9 @@ def _convert_x(t: Tensor) -> cute.Tensor: ) if store_rstd: rstd = torch.empty(M, device=x.device, dtype=torch.float32) - mRstd = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) + mRstd = from_dlpack(rstd.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=0 + ) else: rstd = None mRstd = None @@ -775,7 +969,9 @@ def _convert_x(t: Tensor) -> cute.Tensor: mResO = None if residual is not None: residual_out = torch.empty_like(residual) - mResO = from_dlpack(residual_out.detach(), assumed_align=32).mark_layout_dynamic(leading_dim=1) + mResO = from_dlpack( + residual_out.detach(), assumed_align=32 + ).mark_layout_dynamic(leading_dim=1) # Enable the intra-row cp.async K-loop only for DSv3-style large-N rows # with very large M, where there is enough work per row to amortize the @@ -788,7 +984,7 @@ def _convert_x(t: Tensor) -> cute.Tensor: op._tpr_override = 128 # type: ignore[attr-defined] # Prefer 1 row/CTA at N=6144; keep 2 rows/CTA at N=8192 to match # the original tuning there. 
- op._nt_override = (128 if N == 6144 else 256) # type: ignore[attr-defined] + op._nt_override = 128 if N == 6144 else 256 # type: ignore[attr-defined] stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) key = ( @@ -803,7 +999,9 @@ def _convert_x(t: Tensor) -> cute.Tensor: ) compiled = _COMPILE_CACHE.get(key) if compiled is None: - compiled = cute.compile(op, mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps)) + compiled = cute.compile( + op, mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps) + ) _COMPILE_CACHE[key] = compiled compiled(mX, mW, mB, mRes, mO, mResO, mRstd, stream, Float32(eps)) return out, rstd, residual_out diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py index a8a2791..6a7eb54 100644 --- a/oink/src/kernelagent_oink/blackwell/softmax.py +++ b/oink/src/kernelagent_oink/blackwell/softmax.py @@ -134,7 +134,9 @@ def __call__(self, mX: cute.Tensor, mO: cute.Tensor, stream: cuda.CUstream) -> N # Use the generic ReductionBase tiling with 128-bit vectorization. tiler_mn, tv_layout = self._get_tv_layout() num_threads = ( - cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() ) num_warps = num_threads // cute.arch.WARP_SIZE kernel = ( @@ -201,7 +203,9 @@ def _kernel_impl( cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16, ) - reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout + ) # Copy atoms for gmem <-> smem and smem <-> gmem. # Use 128-bit cp.async for global->shared and 128-bit vectorized stores. @@ -216,8 +220,12 @@ def _kernel_impl( num_bits_per_copy=128, ) - thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx) - thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx) + thr_copy_load = cute.make_tiled_copy( + copy_atom_load, tv_layout, tiler_mn + ).get_slice(tidx) + thr_copy_store = cute.make_tiled_copy( + copy_atom_store, tv_layout, tiler_mn + ).get_slice(tidx) tXgX = thr_copy_load.partition_S(gX) tXsX = thr_copy_load.partition_D(sX) @@ -349,7 +357,10 @@ def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int: # Store both y and dy tiles plus reduction buffers and mbarriers. return ( cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) * 2 - + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8) + + self.stage + * num_warps + * self.cluster_n + * (self.reduction_dtype.width // 8) + self.stage * (cutlass.Int64.width // 8) ) @@ -367,7 +378,9 @@ def __call__( # Use the generic ReductionBase tiling with 128-bit vectorization. 
tiler_mn, tv_layout = self._get_tv_layout() num_threads = ( - cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS else self._get_num_threads() + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() ) num_warps = num_threads // cute.arch.WARP_SIZE kernel = ( @@ -423,7 +436,9 @@ def _kernel_impl( mdY, mY, mdX = [ domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX) ] - gdY, gY, gdX = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX)] + gdY, gY, gdX = [ + cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX) + ] cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) smem = cutlass.utils.SmemAllocator() @@ -437,7 +452,9 @@ def _kernel_impl( cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16, ) - reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout + ) copy_atom_load = cute.make_copy_atom( cute.nvgpu.cpasync.CopyG2SOp(), @@ -450,8 +467,12 @@ def _kernel_impl( num_bits_per_copy=128, ) - thr_copy_load = cute.make_tiled_copy(copy_atom_load, tv_layout, tiler_mn).get_slice(tidx) - thr_copy_store = cute.make_tiled_copy(copy_atom_store, tv_layout, tiler_mn).get_slice(tidx) + thr_copy_load = cute.make_tiled_copy( + copy_atom_load, tv_layout, tiler_mn + ).get_slice(tidx) + thr_copy_store = cute.make_tiled_copy( + copy_atom_store, tv_layout, tiler_mn + ).get_slice(tidx) tdYgdY = thr_copy_load.partition_S(gdY) tdYsdY = thr_copy_load.partition_D(sdY) @@ -460,7 +481,9 @@ def _kernel_impl( tdXgdX = thr_copy_store.partition_D(gdX) tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None] - tdYrdY, tYrY, tdXrdX = [cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX)] + tdYrdY, tYrY, tdXrdX = [ + cute.make_fragment_like(thr) for thr in (tdYgdY, tYgY, tdXgdX) + ] num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE self._initialize_cluster(tidx, mbar_ptr, num_warps) @@ -535,9 +558,8 @@ def _convert_2d_tensor(x: Tensor) -> cute.Tensor: # the shape compact with row-major stride order (0, 1), with mode=0 (batch). # We intentionally do not call mark_layout_dynamic here to avoid the # leading_dim stride==1 constraint used in RMSNorm. 
- return ( - from_dlpack(x.detach(), assumed_align=16) - .mark_compact_shape_dynamic(mode=0, stride_order=(0, 1)) + return from_dlpack(x.detach(), assumed_align=16).mark_compact_shape_dynamic( + mode=0, stride_order=(0, 1) ) @@ -581,7 +603,9 @@ def _softmax_forward_ptr_into(*, x: Tensor, out: Tensor) -> None: compiled = _PTR_FWD_COMPILE_CACHE.get(key) if compiled is None: op = SoftmaxFwdSM100(dtype_x, int(N)) - ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_x = rt.make_ptr( + dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) ptr_out = rt.make_ptr( dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 ) @@ -596,8 +620,12 @@ def _softmax_forward_ptr_into(*, x: Tensor, out: Tensor) -> None: ) _PTR_FWD_COMPILE_CACHE[key] = compiled - ptr_x = rt.make_ptr(dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) - ptr_out = rt.make_ptr(dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_x = rt.make_ptr( + dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_out = rt.make_ptr( + dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) compiled(ptr_x, ptr_out, Int32(int(M)), Int32(int(x.stride(0))), stream) @@ -606,7 +634,9 @@ def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None: assert dy.is_cuda and dy.dim() == 2 assert y.is_cuda and y.shape == dy.shape and y.dtype == dy.dtype assert dx.is_cuda and dx.shape == dy.shape and dx.dtype == dy.dtype - assert dy.stride() == y.stride() == dx.stride(), "Pointer path expects matching strides" + assert dy.stride() == y.stride() == dx.stride(), ( + "Pointer path expects matching strides" + ) M, N = dy.shape device_index = dy.get_device() @@ -619,9 +649,15 @@ def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None: compiled = _PTR_BWD_COMPILE_CACHE.get(key) if compiled is None: op = SoftmaxBwdSM100(dtype_x, int(N)) - ptr_dy = rt.make_ptr(dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) - ptr_y = rt.make_ptr(dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) - ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_dy = rt.make_ptr( + dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_y = rt.make_ptr( + dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) ld = Int32(int(dy.stride(0))) compiled = cute.compile( op.launch_from_ptrs, @@ -634,9 +670,15 @@ def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None: ) _PTR_BWD_COMPILE_CACHE[key] = compiled - ptr_dy = rt.make_ptr(dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) - ptr_y = rt.make_ptr(dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) - ptr_dx = rt.make_ptr(dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16) + ptr_dy = rt.make_ptr( + dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_y = rt.make_ptr( + dtype_x, y.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) compiled(ptr_dy, ptr_y, ptr_dx, Int32(int(M)), Int32(int(dy.stride(0))), stream) @@ -679,8 +721,14 @@ def softmax_backward(dy: 
Tensor, y: Tensor) -> Tensor: N = dy.size(1) dtype = TORCH2CUTE_DTYPE[dy.dtype] - if _can_use_ptr_path_2d(dy) and _can_use_ptr_path_2d(y) and dy.stride() == y.stride(): - dx = torch.empty_strided(dy.shape, dy.stride(), device=dy.device, dtype=dy.dtype) + if ( + _can_use_ptr_path_2d(dy) + and _can_use_ptr_path_2d(y) + and dy.stride() == y.stride() + ): + dx = torch.empty_strided( + dy.shape, dy.stride(), device=dy.device, dtype=dy.dtype + ) _softmax_backward_ptr_into(dy=dy, y=y, dx=dx) return dx From 7e818eecf3448f9f8173ac4ad5154177fea830e8 Mon Sep 17 00:00:00 2001 From: Laura Wang <3700467+Laurawly@users.noreply.github.com> Date: Wed, 21 Jan 2026 20:11:32 -0800 Subject: [PATCH 7/8] oink: add license headers to benchmarks --- oink/benchmarks/benchmark/bench_utils.py | 14 ++++++++++++++ .../benchmark/benchmark_cross_entropy_sm100.py | 14 ++++++++++++++ .../benchmark/benchmark_fused_add_rmsnorm_sm100.py | 14 ++++++++++++++ .../benchmark/benchmark_hbm_roofline_sm100.py | 14 ++++++++++++++ .../benchmark/benchmark_layernorm_sm100.py | 14 ++++++++++++++ .../benchmark/benchmark_rmsnorm_bwd_sm100.py | 14 ++++++++++++++ .../benchmark/benchmark_rmsnorm_sm100.py | 14 ++++++++++++++ .../benchmark/benchmark_softmax_sm100.py | 14 ++++++++++++++ oink/benchmarks/readme/plot_quack_style_svg.py | 14 ++++++++++++++ oink/benchmarks/readme/run_sm100_suite.py | 14 ++++++++++++++ oink/benchmarks/readme/summarize_results.py | 14 ++++++++++++++ 11 files changed, 154 insertions(+) diff --git a/oink/benchmarks/benchmark/bench_utils.py b/oink/benchmarks/benchmark/bench_utils.py index 0a9ae4b..ef996ec 100644 --- a/oink/benchmarks/benchmark/bench_utils.py +++ b/oink/benchmarks/benchmark/bench_utils.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations import csv diff --git a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py index ff1a99b..3c8bf44 100644 --- a/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_cross_entropy_sm100.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from __future__ import annotations import argparse diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py index 863712d..8a0227b 100644 --- a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Benchmark fused_add_rmsnorm (in-place) on SM100. diff --git a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py index c22294e..22fb48d 100644 --- a/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_hbm_roofline_sm100.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ HBM roofline microbenchmark for SM100 (GB200 / Blackwell). diff --git a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py index 3c0e37d..20895b7 100644 --- a/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_layernorm_sm100.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations import argparse diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py index b9909e7..31b335b 100644 --- a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations import argparse diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py index f4c8a5f..39e6cd7 100644 --- a/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_sm100.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations import argparse diff --git a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py index 995b09f..a5b2b3c 100644 --- a/oink/benchmarks/benchmark/benchmark_softmax_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_softmax_sm100.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations import argparse diff --git a/oink/benchmarks/readme/plot_quack_style_svg.py b/oink/benchmarks/readme/plot_quack_style_svg.py index af76832..88eebdf 100644 --- a/oink/benchmarks/readme/plot_quack_style_svg.py +++ b/oink/benchmarks/readme/plot_quack_style_svg.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """ Generate Quack-style SVG performance plots (Oink vs Quack) from the SM100 suite JSON artifacts under `/tmp/kernelagent_oink_sm100_suite_{bf16,fp16}`. 
diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py index c31d4b5..fb9d603 100644 --- a/oink/benchmarks/readme/run_sm100_suite.py +++ b/oink/benchmarks/readme/run_sm100_suite.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations import argparse diff --git a/oink/benchmarks/readme/summarize_results.py b/oink/benchmarks/readme/summarize_results.py index 29b768e..684694d 100644 --- a/oink/benchmarks/readme/summarize_results.py +++ b/oink/benchmarks/readme/summarize_results.py @@ -1,3 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import annotations import argparse From 5d195d6c2d80b9415d8a5433c54494ccac769eac Mon Sep 17 00:00:00 2001 From: Laura Wang <3700467+Laurawly@users.noreply.github.com> Date: Thu, 22 Jan 2026 09:59:37 -0800 Subject: [PATCH 8/8] update --- oink/benchmarks/README.md | 6 + .../benchmark_fused_add_rmsnorm_sm100.py | 55 +- .../benchmark/benchmark_rmsnorm_bwd_sm100.py | 11 +- .../media/sm100_bf16_oink_vs_quack.svg | 350 ++-- .../media/sm100_bf16_oink_vs_quack_dsv3.svg | 714 +++---- .../sm100_bf16_oink_vs_quack_dsv3_all.svg | 570 +++--- ..._bf16_oink_vs_quack_dsv3_cross_entropy.svg | 180 +- ...bf16_oink_vs_quack_dsv3_with_layernorm.svg | 1742 ++++++++--------- ...m100_bf16_oink_vs_quack_with_layernorm.svg | 460 ++--- .../media/sm100_fp16_oink_vs_quack.svg | 350 ++-- .../media/sm100_fp16_oink_vs_quack_dsv3.svg | 714 +++---- .../sm100_fp16_oink_vs_quack_dsv3_all.svg | 570 +++--- ..._fp16_oink_vs_quack_dsv3_cross_entropy.svg | 180 +- ...fp16_oink_vs_quack_dsv3_with_layernorm.svg | 1742 ++++++++--------- ...m100_fp16_oink_vs_quack_with_layernorm.svg | 460 ++--- oink/benchmarks/readme/run_sm100_suite.py | 17 + .../blackwell/cross_entropy.py | 1045 +++++++++- .../kernelagent_oink/blackwell/fast_launch.py | 115 ++ .../kernelagent_oink/blackwell/layernorm.py | 845 ++++++-- .../kernelagent_oink/blackwell/lite_quack.py | 181 +- .../src/kernelagent_oink/blackwell/rmsnorm.py | 1287 +++++++++++- .../src/kernelagent_oink/blackwell/softmax.py | 834 +++++++- 22 files changed, 7996 insertions(+), 4432 deletions(-) create mode 100644 oink/src/kernelagent_oink/blackwell/fast_launch.py diff --git a/oink/benchmarks/README.md b/oink/benchmarks/README.md index ceb7932..a5c4676 100644 --- a/oink/benchmarks/README.md +++ b/oink/benchmarks/README.md @@ -96,6 +96,12 @@ 
CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsn --json /tmp/fused_add_rmsnorm_sm100_bf16.json ``` +Note on the Quack baseline: Oink exposes an **in-place** fused op (updates `x` and `residual`). +Quack’s fused kernel produces `out` and `residual_out` out-of-place, so by default the benchmark +times `quack::_rmsnorm_fwd` **plus** two explicit copies (`x.copy_(out)`, `residual.copy_(residual_out)`) +to match the in-place semantics (integration-realistic). Use `--quack-baseline kernel` to time only +the Quack fused kernel with preallocated outputs. + ### RMSNorm backward ```bash diff --git a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py index 8a0227b..1787d7d 100644 --- a/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py @@ -31,6 +31,15 @@ DSv3 suite (Oink vs Quack, multi-shape): CUDA_VISIBLE_DEVICES=0 python oink/benchmarks/benchmark/benchmark_fused_add_rmsnorm_sm100.py --dtype bf16 --dsv3 \\ --json /tmp/kernelagent_oink_sm100_suite_bf16/fused_add_rmsnorm_dsv3.json + +Quack baseline note: +- Oink exposes an **in-place** fused op (writes `x` and `residual` in-place). +- Quack provides an equivalent fused kernel, but typically returns `out` and + `residual_out` (out-of-place) and does not expose a public "update my input + buffers in-place" API. +- For integration realism (vLLM-style semantics) we default to timing: + Quack fused kernel + 2 explicit copies to apply the in-place updates + so the benchmark covers the full semantic cost. """ from __future__ import annotations @@ -177,6 +186,7 @@ def bench_one( warmup_ms: int, iters_ms: int, verify: bool, + quack_baseline: str, ) -> Dict[str, Any]: device = torch.device("cuda") x = torch.randn((M, N), device=device, dtype=dtype) @@ -212,23 +222,40 @@ def fn(): row.update(stats) if quack_rmsnorm_fwd_mut is not None: - out_q = torch.empty_like(x) - res_out_q = torch.empty_like(residual) + x_q = x.clone() + residual_q = residual.clone() + out_q = torch.empty_like(x_q) + res_out_q = torch.empty_like(residual_q) - def fn_q(): + def fn_q_kernel(): quack_rmsnorm_fwd_mut( - x, + x_q, w, out_q, None, # bias None, # rstd None, # mean - residual, + residual_q, res_out_q, 1e-6, False, # is_layernorm ) + if quack_baseline == "kernel": + fn_q = fn_q_kernel + elif quack_baseline == "kernel_inplace": + + def fn_q(): + fn_q_kernel() + # Apply the same in-place semantics as vLLM expects: + # - x is overwritten with y + # - residual is overwritten with z = x + residual + x_q.copy_(out_q) + residual_q.copy_(res_out_q) + + else: + raise ValueError(f"Unknown quack_baseline: {quack_baseline}") + ms_q = do_bench_triton(fn_q, warmup_ms=warmup_ms, rep_ms=iters_ms) gbps_q = bytes_io / (ms_q * 1e-3) / 1e9 row.update( @@ -287,6 +314,18 @@ def main() -> None: p.add_argument( "--iters", type=int, default=200, help="rep_ms for do_bench (default: 200)" ) + p.add_argument( + "--quack-baseline", + type=str, + default="kernel_inplace", + choices=["kernel", "kernel_inplace"], + help=( + "How to time Quack for the in-place fused op.\n" + "- kernel: Quack fused kernel only (preallocated out/residual_out).\n" + "- kernel_inplace: Quack fused kernel + 2 explicit copies to apply " + "in-place semantics (integration-realistic)." 
+ ), + ) p.add_argument("--skip-verify", action="store_true") p.add_argument("--json", type=str, default=None) args = p.parse_args() @@ -309,6 +348,7 @@ def main() -> None: warmup_ms=int(args.warmup_ms), iters_ms=int(args.iters), verify=not bool(args.skip_verify), + quack_baseline=str(args.quack_baseline), ) ) @@ -324,7 +364,10 @@ def main() -> None: warmup_ms=int(args.warmup_ms), rep_ms=int(args.iters), method="triton.testing.do_bench(mean)", - note="Oink fused_add_rmsnorm_inplace_ vs Quack quack::_rmsnorm_fwd(residual=..., residual_out=...) when available", + note=( + "Oink fused_add_rmsnorm_inplace_ vs Quack baseline " + f"({args.quack_baseline}) when available" + ), ), ) diff --git a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py index 31b335b..50ecb2e 100644 --- a/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py +++ b/oink/benchmarks/benchmark/benchmark_rmsnorm_bwd_sm100.py @@ -17,7 +17,6 @@ import argparse import csv import os -import sys from dataclasses import dataclass from typing import List, Optional, Tuple @@ -30,19 +29,17 @@ # Ensure SM100 (GB200) architecture is recognized by CuTeDSL when running outside vLLM. os.environ.setdefault("CUTE_DSL_ARCH", "sm_100a") -# Make the in-repo KernelAgent Oink package importable without an editable install. -_HERE = os.path.dirname(os.path.abspath(__file__)) -_OINK_SRC = os.path.abspath(os.path.join(_HERE, "..", "src")) -if _OINK_SRC not in sys.path: - sys.path.insert(0, _OINK_SRC) - from bench_utils import ( # noqa: E402 ErrorStatsAccumulator, collect_device_meta, + ensure_oink_src_on_path, error_stats_to_row, iter_row_blocks, write_json, ) + +ensure_oink_src_on_path() + from kernelagent_oink.blackwell import rmsnorm as oink_rmsnorm # noqa: E402 try: diff --git a/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg index e32e3a7..96b5b83 100644 --- a/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg +++ b/oink/benchmarks/media/sm100_bf16_oink_vs_quack.svg @@ -6,7 +6,7 @@ - 2026-01-12T23:31:37.117906 + 2026-01-22T03:16:57.722815 image/svg+xml @@ -41,12 +41,12 @@ z - - + @@ -176,7 +176,7 @@ z - + @@ -243,7 +243,7 @@ z - + @@ -322,7 +322,7 @@ z - + @@ -343,7 +343,7 @@ z - + @@ -365,7 +365,7 @@ z - + @@ -414,7 +414,7 @@ z - + @@ -439,16 +439,16 @@ z +" clip-path="url(#p68502969ea)" style="fill: none; stroke-dasharray: 3.2,5.76; stroke-dashoffset: 0; stroke: #b0b0b0; stroke-width: 0.8"/> - - + @@ -483,18 +483,18 @@ z - + - + - + @@ -504,18 +504,18 @@ L 424.416918 286.945749 - + - + - + @@ -525,18 +525,18 @@ L 424.416918 239.345028 - + - + - + @@ -944,16 +944,16 @@ z - + - - - - - - - - - + + + + + + + + - + - - - - - - - - - + + + + + + + + +" clip-path="url(#p68502969ea)" style="fill: none; stroke-dasharray: 12,18; stroke-dashoffset: 0; stroke: #4d4d4d; stroke-width: 3"/> - + @@ -1247,7 +1247,7 @@ z - + @@ -1268,7 +1268,7 @@ z - + @@ -1289,7 +1289,7 @@ z - + @@ -1310,7 +1310,7 @@ z - + @@ -1332,7 +1332,7 @@ z - + @@ -1354,7 +1354,7 @@ z - + @@ -1379,93 +1379,93 @@ z +" clip-path="url(#p8e212f52e1)" style="fill: none; stroke-dasharray: 3.2,5.76; stroke-dashoffset: 0; stroke: #b0b0b0; stroke-width: 0.8"/> - + - + - + - + - + - + - + - - - - - - - - - + + + + + + + + + - - - - - - - - - + + + + + + + + + +" clip-path="url(#p8e212f52e1)" style="fill: none; stroke-dasharray: 12,18; stroke-dashoffset: 0; stroke: #4d4d4d; stroke-width: 3"/> - + @@ -1601,7 +1601,7 @@ z - + @@ -1622,7 +1622,7 @@ z - + @@ -1643,7 +1643,7 @@ z 
[Regenerated SM100 Oink-vs-Quack benchmark plots under oink/benchmarks/media/: sm100_bf16_* and sm100_fp16_* SVGs (base, with_layernorm, dsv3, dsv3_all, dsv3_cross_entropy, dsv3_with_layernorm). The SVG diffs change only the embedded creation timestamps, clip-path IDs, and plotted path data.]
- - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + - - - + + - + - + - + @@ -1315,15 +1397,15 @@ z - - + + - + - + - + @@ -1336,15 +1418,15 @@ z - - + + - + - + - + @@ -1357,15 +1439,77 @@ z - - + + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + @@ -1377,15 +1521,15 @@ z - - + + - + - + - + @@ -1398,15 +1542,15 @@ z - - + + - + - + - + @@ -1422,164 +1566,176 @@ z - - + + - + - + - - + + - + - + - - + + - + - + - - + + - + - + - - + + - + - + - - + + - + - + - - + + - + - + - - + + - + - + - - - - - - - - - + + + + + + + + + + + + - - - - - - - - - + + + + + + + + + + + + - - + + - - - - - + - + - - - - + + - + - + - + @@ -1688,15 +1829,15 @@ z - - + + - + - + - + @@ -1709,15 +1850,15 @@ z - - + + - + - + - + @@ -1730,381 +1871,77 @@ z - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + - + - - + + - + - - - + + + - + - - + + - + - - - + + + - + - - + + - + - + - + @@ -2116,15 +1953,15 @@ z - - + + - + - + - + @@ -2137,15 +1974,15 @@ z - - + + - + - + - + @@ -2159,166 +1996,178 @@ z - - - - - - + + + + + + - + - - - + + + - + - + - - - + + + - + - + - - - + + + - + - + - - - + + + - + - + - - - + + + - + - + - - - + + + - + - + - - - + + + - + - + - - - - - - - - - + + + + + + + + + + + + - - - - - - - - - + + + + + + + + + + + + - - + + - - + - - + - - + - - + - + - + - - + + + - - + + + @@ -2587,47 +2447,70 @@ z - - - - - - - - - - - + + + + + + + + + + + + + + + + - - + - + - + - + - @@ -2656,18 +2539,18 @@ z - - + - + - + - + @@ -2675,15 +2558,15 @@ L 859.98375 39.937812 - - + - + - + - - - - + + - - + + - - + + diff --git a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg index dac54ac..5c849b5 100644 --- a/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg +++ b/oink/benchmarks/media/sm100_fp16_oink_vs_quack_with_layernorm.svg @@ -6,7 +6,7 @@ - 2026-01-12T23:31:35.225900 + 2026-01-22T03:17:09.483028 image/svg+xml @@ -41,12 +41,12 @@ z - - + @@ -176,7 +176,7 @@ z - + @@ -243,7 +243,7 @@ z - + @@ -322,7 +322,7 @@ z - + @@ -343,7 +343,7 @@ z - + @@ -365,7 +365,7 @@ z - + @@ -414,7 +414,7 @@ z - + @@ -439,16 +439,16 @@ z +" clip-path="url(#p1b738c7a2f)" style="fill: none; stroke-dasharray: 3.2,5.76; stroke-dashoffset: 0; stroke: #b0b0b0; stroke-width: 0.8"/> - - + @@ -483,18 +483,18 @@ z - + - + - + @@ -504,18 +504,18 @@ L 429.474812 286.915059 - + - + - + @@ -525,18 +525,18 @@ L 429.474812 239.283646 - + - + - + @@ -944,16 +944,16 @@ z - + - - - - - - - - - + + + + + + + + - + - - - - - - - - - + + + + + + + + +" clip-path="url(#p1b738c7a2f)" style="fill: none; stroke-dasharray: 12,18; stroke-dashoffset: 0; stroke: #4d4d4d; stroke-width: 3"/> - + @@ -1247,7 +1247,7 @@ z - + @@ -1268,7 +1268,7 @@ z - + @@ -1289,7 +1289,7 @@ z - + @@ -1310,7 +1310,7 @@ z - + @@ -1332,7 +1332,7 @@ z - + @@ -1354,7 +1354,7 @@ z - + @@ -1379,93 +1379,93 @@ z +" clip-path="url(#pbffb2bfe56)" style="fill: none; stroke-dasharray: 3.2,5.76; stroke-dashoffset: 0; stroke: #b0b0b0; stroke-width: 
0.8"/> - + - + - + - + - + - + - + - - - - - - - - - + + + + + + + + + - - - - - - - - - + + + + + + + + + +" clip-path="url(#pbffb2bfe56)" style="fill: none; stroke-dasharray: 12,18; stroke-dashoffset: 0; stroke: #4d4d4d; stroke-width: 3"/> - + @@ -1601,7 +1601,7 @@ z - + @@ -1622,7 +1622,7 @@ z - + @@ -1643,7 +1643,7 @@ z - + @@ -1664,7 +1664,7 @@ z - + @@ -1686,7 +1686,7 @@ z - + @@ -1708,7 +1708,7 @@ z - + @@ -1733,93 +1733,93 @@ z +" clip-path="url(#pdd076313b3)" style="fill: none; stroke-dasharray: 3.2,5.76; stroke-dashoffset: 0; stroke: #b0b0b0; stroke-width: 0.8"/> - + - + - + - + - + - + - + - - - - - - - - - + + + + + + + + + - - - - - - - - - + + + + + + + + + +" clip-path="url(#pdd076313b3)" style="fill: none; stroke-dasharray: 12,18; stroke-dashoffset: 0; stroke: #4d4d4d; stroke-width: 3"/> - + @@ -1948,7 +1948,7 @@ z - + @@ -1969,7 +1969,7 @@ z - + @@ -1990,7 +1990,7 @@ z - + @@ -2011,7 +2011,7 @@ z - + @@ -2033,7 +2033,7 @@ z - + @@ -2055,7 +2055,7 @@ z - + @@ -2080,93 +2080,93 @@ z +" clip-path="url(#p7fa67217a3)" style="fill: none; stroke-dasharray: 3.2,5.76; stroke-dashoffset: 0; stroke: #b0b0b0; stroke-width: 0.8"/> - + - + - + - + - + - + - + - - - - - - - - - + + + + + + + + + - - - - - - - - - + + + + + + + + + +" clip-path="url(#p7fa67217a3)" style="fill: none; stroke-dasharray: 12,18; stroke-dashoffset: 0; stroke: #4d4d4d; stroke-width: 3"/> - + @@ -2522,7 +2522,7 @@ L 835.955625 46.691969 L 852.205625 46.691969 " style="fill: none; stroke: #ff4444; stroke-width: 5; stroke-linecap: square"/> - + @@ -2585,16 +2585,16 @@ z - + - + - + - + diff --git a/oink/benchmarks/readme/run_sm100_suite.py b/oink/benchmarks/readme/run_sm100_suite.py index fb9d603..af33e38 100644 --- a/oink/benchmarks/readme/run_sm100_suite.py +++ b/oink/benchmarks/readme/run_sm100_suite.py @@ -210,6 +210,23 @@ def script(name: str) -> str: os.path.join(out_dir, "rmsnorm_bwd_dsv3_wsame.json"), ], ), + ( + "fused_add_rmsnorm_dsv3", + [ + py, + script("benchmark_fused_add_rmsnorm_sm100.py"), + *common, + "--dsv3", + "--quack-baseline", + "kernel_inplace", + "--iters", + "200", + "--warmup-ms", + "25", + "--json", + os.path.join(out_dir, "fused_add_rmsnorm_dsv3.json"), + ], + ), ( "softmax_fwd_bwd_quack_suite", [ diff --git a/oink/src/kernelagent_oink/blackwell/cross_entropy.py b/oink/src/kernelagent_oink/blackwell/cross_entropy.py index 3e6eef1..d8b37ea 100644 --- a/oink/src/kernelagent_oink/blackwell/cross_entropy.py +++ b/oink/src/kernelagent_oink/blackwell/cross_entropy.py @@ -79,11 +79,17 @@ from cutlass.cute import runtime as rt from cutlass.cute.runtime import from_dlpack +from kernelagent_oink.blackwell.fast_launch import ( + StableI32Arg, + disable_fast_launch, + fast_launch_enabled, + set_runtime_ptr, + tls_cache as _tls_fast_launch_cache, +) from kernelagent_oink.blackwell.lite_quack import ( _KERNEL_ACCEPTS_LAYOUT_ARGS, TORCH2CUTE_DTYPE, ReductionBase, - domain_offset_i64, fill_oob, online_softmax_reduce, predicate_k, @@ -93,6 +99,454 @@ _BWD_COMPILE_CACHE: dict[tuple[type[cutlass.Numeric], int], cute.Kernel] = {} _PTR_FWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} _PTR_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} +_PTR_FWDBWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +class _PtrCrossEntropyFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_logits: object, + ptr_target: object, + ptr_aux_a: object, + ptr_aux_b: object, + ptr_aux_c: object | None, + arg_m: StableI32Arg, + arg_ld: StableI32Arg, + 
arg_ignore_index: StableI32Arg, + stream: cuda.CUstream, + packed_args: object, + keepalive: tuple[object, ...], + logits_align: int, + target_align: int, + aux_a_align: int, + aux_b_align: int, + aux_c_align: int | None, + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_logits = ptr_logits + self._ptr_target = ptr_target + self._ptr_aux_a = ptr_aux_a + self._ptr_aux_b = ptr_aux_b + self._ptr_aux_c = ptr_aux_c + self._arg_m = arg_m + self._arg_ld = arg_ld + self._arg_ignore_index = arg_ignore_index + self._stream = stream + self._packed_args = packed_args + self._keepalive = keepalive + self._logits_align = int(logits_align) + self._target_align = int(target_align) + self._aux_a_align = int(aux_a_align) + self._aux_b_align = int(aux_b_align) + self._aux_c_align = int(aux_c_align) if aux_c_align is not None else None + + self._use_fast_launch = True + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_logits_ptr = -1 + self._last_target_ptr = -1 + self._last_aux_a_ptr = -1 + self._last_aux_b_ptr = -1 + self._last_aux_c_ptr = -1 + self._last_m = -1 + self._last_ld = -1 + self._last_ignore_index = None + + def launch( + self, + *, + logits_ptr: int, + target_ptr: int, + aux_a_ptr: int, + aux_b_ptr: int, + aux_c_ptr: int | None, + M: int, + ld: int, + ignore_index: int, + stream_handle: int, + dtype_logits: type[cutlass.Numeric], + aux_a_dtype: type[cutlass.Numeric], + aux_b_dtype: type[cutlass.Numeric], + aux_c_dtype: type[cutlass.Numeric] | None, + ) -> None: + if not fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + return + + if logits_ptr != self._last_logits_ptr: + try: + set_runtime_ptr(self._ptr_logits, logits_ptr) + self._last_logits_ptr = logits_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + return + + if target_ptr != self._last_target_ptr: + try: + set_runtime_ptr(self._ptr_target, target_ptr) + self._last_target_ptr = target_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + return + + if aux_a_ptr != self._last_aux_a_ptr: + try: + set_runtime_ptr(self._ptr_aux_a, aux_a_ptr) + self._last_aux_a_ptr = aux_a_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + 
return + + if aux_b_ptr != self._last_aux_b_ptr: + try: + set_runtime_ptr(self._ptr_aux_b, aux_b_ptr) + self._last_aux_b_ptr = aux_b_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + return + + if self._ptr_aux_c is not None and aux_c_ptr is not None: + if aux_c_ptr != self._last_aux_c_ptr: + try: + set_runtime_ptr(self._ptr_aux_c, aux_c_ptr) + self._last_aux_c_ptr = aux_c_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + logits_ptr=logits_ptr, + target_ptr=target_ptr, + aux_a_ptr=aux_a_ptr, + aux_b_ptr=aux_b_ptr, + aux_c_ptr=aux_c_ptr, + M=M, + ld=ld, + ignore_index=ignore_index, + stream_handle=stream_handle, + dtype_logits=dtype_logits, + aux_a_dtype=aux_a_dtype, + aux_b_dtype=aux_b_dtype, + aux_c_dtype=aux_c_dtype, + ) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld != self._last_ld: + self._arg_ld.set(ld) + self._last_ld = ld + if ignore_index != self._last_ignore_index: + self._arg_ignore_index.set(ignore_index) + self._last_ignore_index = int(ignore_index) + + if self._cuda_result is not None: + self._cuda_result.value = 0 + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + self._use_fast_launch = False + disable_fast_launch() + + def _fallback_launch( + self, + *, + logits_ptr: int, + target_ptr: int, + aux_a_ptr: int, + aux_b_ptr: int, + aux_c_ptr: int | None, + M: int, + ld: int, + ignore_index: int, + stream_handle: int, + dtype_logits: type[cutlass.Numeric], + aux_a_dtype: type[cutlass.Numeric], + aux_b_dtype: type[cutlass.Numeric], + aux_c_dtype: type[cutlass.Numeric] | None, + ) -> None: + stream = cuda.CUstream(int(stream_handle)) + ptr_logits = rt.make_ptr( + dtype_logits, + int(logits_ptr), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._logits_align, + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + int(target_ptr), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._target_align, + ) + ptr_aux_a = rt.make_ptr( + aux_a_dtype, + int(aux_a_ptr), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._aux_a_align, + ) + ptr_aux_b = rt.make_ptr( + aux_b_dtype, + int(aux_b_ptr), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._aux_b_align, + ) + if ( + self._ptr_aux_c is not None + and aux_c_ptr is not None + and aux_c_dtype is not None + ): + ptr_aux_c = rt.make_ptr( + aux_c_dtype, + int(aux_c_ptr), + mem_space=rt.AddressSpace.gmem, + assumed_align=int(self._aux_c_align or 0), + ) + self._compiled( + ptr_logits, + ptr_target, + ptr_aux_a, + ptr_aux_b, + ptr_aux_c, + Int32(int(M)), + Int32(int(ld)), + Int32(int(ignore_index)), + stream, + ) + else: + self._compiled( + ptr_logits, + ptr_target, + ptr_aux_a, + ptr_aux_b, + Int32(int(M)), + Int32(int(ld)), + Int32(int(ignore_index)), + stream, + ) + + +def _get_fast_ptr_cross_entropy_launcher( + *, + compiled: object, + dtype_logits: type[cutlass.Numeric], + N: int, + device_index: int, + stream_handle: 
int, + mode: Literal["fwd", "bwd", "fwd_bwd"], +) -> _PtrCrossEntropyFastLaunch | None: + if not fast_launch_enabled(): + return None + key = ( + f"ptr_fast_{mode}", + id(compiled), + int(N), + dtype_logits, + int(device_index), + int(stream_handle), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + ptr_logits = rt.make_ptr( + dtype_logits, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, 0, mem_space=rt.AddressSpace.gmem, assumed_align=8 + ) + if mode == "fwd": + ptr_aux_a = rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4 + ) # loss + ptr_aux_b = rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4 + ) # lse + ptr_aux_c = None + aux_align_b = 4 + aux_align_c = None + elif mode == "bwd": + ptr_aux_a = rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4 + ) # dloss + ptr_aux_b = rt.make_ptr( + dtype_logits, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) # dx + ptr_aux_c = rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4 + ) # lse + aux_align_b = 16 + aux_align_c = 4 + elif mode == "fwd_bwd": + ptr_aux_a = rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4 + ) # dloss + ptr_aux_b = rt.make_ptr( + dtype_logits, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) # dx + ptr_aux_c = None + aux_align_b = 16 + aux_align_c = None + else: + raise ValueError(f"Unsupported mode: {mode}") + + arg_m = StableI32Arg(0) + arg_ld = StableI32Arg(N) + arg_ignore_index = StableI32Arg(-100) + stream = cuda.CUstream(int(stream_handle)) + executor = compiled.to(device_index) # type: ignore[attr-defined] + + try: + if ptr_aux_c is not None: + exe_args, adapted_args = executor.generate_execution_args( + ptr_logits, + ptr_target, + ptr_aux_a, + ptr_aux_b, + ptr_aux_c, + arg_m, + arg_ld, + arg_ignore_index, + stream, + ) + else: + exe_args, adapted_args = executor.generate_execution_args( + ptr_logits, + ptr_target, + ptr_aux_a, + ptr_aux_b, + arg_m, + arg_ld, + arg_ignore_index, + stream, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + disable_fast_launch() + return None + + keepalive: tuple[object, ...] = ( + executor, + ptr_logits, + ptr_target, + ptr_aux_a, + ptr_aux_b, + ptr_aux_c, + arg_m, + arg_ld, + arg_ignore_index, + stream, + *adapted_args, + ) + launcher = _PtrCrossEntropyFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_logits=ptr_logits, + ptr_target=ptr_target, + ptr_aux_a=ptr_aux_a, + ptr_aux_b=ptr_aux_b, + ptr_aux_c=ptr_aux_c, + arg_m=arg_m, + arg_ld=arg_ld, + arg_ignore_index=arg_ignore_index, + stream=stream, + packed_args=packed_args, + keepalive=keepalive, + logits_align=16, + target_align=8, + aux_a_align=4, + aux_b_align=aux_align_b, + aux_c_align=aux_align_c, + ) + cache[key] = launcher + return launcher def _convert_logits_2d(x: Tensor) -> cute.Tensor: @@ -261,9 +715,10 @@ def _kernel_impl( shape: cute.Shape = mX.shape idX = cute.make_identity_tensor(shape) - # Slice per-CTA region; use 64-bit indexing for large tensors. - mX_off = domain_offset_i64((bidx * tiler_mn[0], 0), mX) - gX = cute.local_tile(mX_off, tiler_mn, (0, cluster_y)) + # Quack-style CTA tiling: let CuTe compute the CTA offsets directly. 
+ # (Avoids the extra 64-bit address arithmetic in `domain_offset_i64` on + # the common inference/benchmark sizes.) + gX = cute.local_tile(mX, tiler_mn, (bidx, cluster_y)) cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) smem = cutlass.utils.SmemAllocator() @@ -277,15 +732,28 @@ def _kernel_impl( ) # Copy setup: gmem -> smem via cp.async, 128-bit or narrower as needed. - num_copy_elems_X = tv_layout.shape[1][0] + num_copy_elems_X = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) num_copy_bits_X = mX.element_type.width * num_copy_elems_X copy_atom_load_X = cute.make_copy_atom( cute.nvgpu.cpasync.CopyG2SOp(), gX.element_type, num_bits_per_copy=num_copy_bits_X, ) - thr_copy_X = cute.make_tiled_copy( - copy_atom_load_X, tv_layout, tiler_mn + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems_X)) + thr_copy_X = cute.make_tiled_copy_tv( + copy_atom_load_X, thr_layout, val_layout ).get_slice(tidx) tXgX = thr_copy_X.partition_S(gX) @@ -321,14 +789,11 @@ def _kernel_impl( should_ignore = Boolean(target == ignore_index) - # Load the target logit if this row is not ignored. Use Int64 indexing - # to safely handle very large tensors. + # Load the target logit if this row is not ignored. target_logit = Float32.zero if row < shape[0] and tXcX[0][1] == 0 and not should_ignore: - mX_row = domain_offset_i64((row, 0), mX) - target_logit = Float32(mX_row[0, target]) + target_logit = Float32(mX[row, target]) - threads_per_row = tv_layout.shape[0][0] max_x, denom, _ = online_softmax_reduce( x, threads_per_row, @@ -398,6 +863,305 @@ def kernel( ) +class CrossEntropyFwdBwdSM100(ReductionBase): + """Fused cross-entropy forward+backward producing dx from (logits, target, dloss). + + This avoids materializing the intermediate `lse` (and loss) in global memory + when the only desired output is `dx` for `reduction="none"` semantics. 
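+
+    Per row, the gradient produced matches the standard softmax
+    cross-entropy backward (rows with ``target == ignore_index`` get a zero
+    gradient because their ``dloss`` is zeroed before scaling); in
+    pseudo-Python, derived from the kernel body below:
+
+        dx[j] = dloss * (softmax(x)[j] - (1.0 if j == target else 0.0))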
+ """ + + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64) + + def _calculate_threads_per_row(self) -> int: + N = self.N + return ( + 8 + if N <= 64 + else ( + 16 + if N <= 128 + else ( + 32 + if N <= 3072 + else (64 if N <= 6144 else (128 if N <= 16384 else 256)) + ) + ) + ) + + def _set_cluster_n(self) -> None: + N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + @cute.jit + def __call__( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + ignore_index: Int32, + stream: cuda.CUstream, + ) -> None: + assert mX.element_type == self.dtype + assert mdX.element_type == self.dtype + self._set_cluster_n() + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits) + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + kernel = ( + self.kernel( + mX, + mTarget, + mDLoss, + mdX, + ignore_index, + tv_layout, + tiler_mn, + ) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel( + mX, + mTarget, + mDLoss, + mdX, + ignore_index, + ) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_logits: cute.Pointer, + ptr_target: cute.Pointer, + ptr_dloss: cute.Pointer, + ptr_dx: cute.Pointer, + M: Int32, + ld: Int32, + ignore_index: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions.""" + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + layout_m = cute.make_layout((M,), stride=(1,)) + mX = cute.make_tensor(ptr_logits, layout_mn) + mdX = cute.make_tensor(ptr_dx, layout_mn) + mTarget = cute.make_tensor(ptr_target, layout_m) + mDLoss = cute.make_tensor(ptr_dloss, layout_m) + self.__call__(mX, mTarget, mDLoss, mdX, ignore_index, stream) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + ignore_index: Int32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + cluster_y = ( + const_expr(0) + if const_expr(self.cluster_n == 1) + else cute.arch.block_idx()[1] + ) + + shape: cute.Shape = mX.shape + idX = cute.make_identity_tensor(shape) + + gX, gdX, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mdX, idX) + ] + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar( + smem, tv_layout + ) + + num_copy_elems_X = ( + tv_layout.shape[1] + if 
const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + num_copy_bits_X = mX.element_type.width * num_copy_elems_X + copy_atom_load_X = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + gX.element_type, + num_bits_per_copy=num_copy_bits_X, + ) + copy_atom_store_dX = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gdX.element_type, + num_bits_per_copy=num_copy_bits_X, + ) + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems_X)) + thr_copy_X = cute.make_tiled_copy_tv( + copy_atom_load_X, thr_layout, val_layout + ).get_slice(tidx) + thr_copy_dX = cute.make_tiled_copy_tv( + copy_atom_store_dX, thr_layout, val_layout + ).get_slice(tidx) + + tXgX = thr_copy_X.partition_S(gX) + tXsX = thr_copy_X.partition_D(sX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + tXcFull = thr_copy_X.partition_S(cX) + tXgdX = thr_copy_dX.partition_D(gdX) + + tXrX, tXrdX = [cute.make_fragment_like(thr) for thr in (tXgX, tXgdX)] + + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + self._initialize_cluster(tidx, mbar_ptr, num_warps) + + row = tXcX[0][0] + target = Int32.zero + dloss = Float32.zero + if row < shape[0]: + target = Int32(mTarget[row]) + should_ignore = Boolean(target == ignore_index) + dloss = Float32(mDLoss[row]) if not should_ignore else Float32.zero + + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + tXpX = ( + predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + if row < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + if const_expr(not is_even_N): + fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + + _max_x, denom, exp_x = online_softmax_reduce( + x, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + phase=None, + return_exp_x=True, + ) + assert exp_x is not None + probs = exp_x * cute.arch.rcp_approx(denom) + prob_shifted = probs - 1.0 + + mask = cute.make_fragment_like(tXrX, cutlass.Boolean) + for i in cutlass.range(cute.size(tXcFull), unroll_full=True): + mask[i] = tXcFull[i][1] == target + grad = cute.where(mask.load(), prob_shifted, probs) + grad = grad * dloss + + tXrdX.store(grad.to(tXrdX.element_type)) + + tXpdX = ( + predicate_k(thr_copy_dX.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + if row < shape[0]: + cute.copy(copy_atom_store_dX, tXrdX, tXgdX, pred=tXpdX) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + ignore_index: Int32, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl( + mX, + mTarget, + mDLoss, + mdX, + ignore_index, + tv_layout, + tiler_mn, + ) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, # (M, N) + mTarget: cute.Tensor, # (M,) + mDLoss: cute.Tensor, # (M,) + mdX: cute.Tensor, # (M, N) + ignore_index: Int32, + ) -> None: + num_copy_bits = math.gcd(self.N, 128 // self.dtype.width) * self.dtype.width + tiler_mn, tv_layout = 
self._get_tv_layout(num_copy_bits=num_copy_bits) + self._kernel_impl( + mX, + mTarget, + mDLoss, + mdX, + ignore_index, + tv_layout, + tiler_mn, + ) + + class CrossEntropyBackwardSM100: """SM100-tuned cross-entropy backward kernel. @@ -565,13 +1329,17 @@ def _kernel_impl( ) idX = cute.make_identity_tensor(shape) - mX_off, mdX_off = [ - domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mdX) + # Quack-style CTA tiling: avoid extra 64-bit address arithmetic by + # letting CuTe compute the CTA offsets directly. + gX, gdX, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, bidy)) for mT in (mX, mdX, idX) ] - gX, gdX = [cute.local_tile(mT, tiler_mn, (0, bidy)) for mT in (mX_off, mdX_off)] - cX = cute.local_tile(idX, tiler_mn, (bidx, bidy)) - num_copy_elems_X = tv_layout.shape[1][0] + num_copy_elems_X = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) num_copy_bits_X = mX.element_type.width * num_copy_elems_X copy_atom_load_X = cute.make_copy_atom( cute.nvgpu.cpasync.CopyG2SOp(), @@ -934,7 +1702,8 @@ def _cross_entropy_forward_ptr_into( device_index = logits.get_device() if torch.cuda.current_device() != device_index: torch.cuda.set_device(device_index) - stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) dtype_x = TORCH2CUTE_DTYPE[logits.dtype] key = ("ptr_fwd", int(N), dtype_x, int(device_index)) @@ -975,6 +1744,32 @@ def _cross_entropy_forward_ptr_into( ) _PTR_FWD_COMPILE_CACHE[key] = compiled + launcher = _get_fast_ptr_cross_entropy_launcher( + compiled=compiled, + dtype_logits=dtype_x, + N=int(N), + device_index=int(device_index), + stream_handle=stream_handle, + mode="fwd", + ) + if launcher is not None: + launcher.launch( + logits_ptr=int(logits.data_ptr()), + target_ptr=int(target.data_ptr()), + aux_a_ptr=int(loss.data_ptr()), + aux_b_ptr=int(lse.data_ptr()), + aux_c_ptr=None, + M=int(M), + ld=int(logits.stride(0)), + ignore_index=int(ignore_index), + stream_handle=stream_handle, + dtype_logits=dtype_x, + aux_a_dtype=cutlass.Float32, + aux_b_dtype=cutlass.Float32, + aux_c_dtype=None, + ) + return + ptr_logits = rt.make_ptr( dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 ) @@ -1037,7 +1832,8 @@ def _cross_entropy_backward_ptr_into( device_index = logits.get_device() if torch.cuda.current_device() != device_index: torch.cuda.set_device(device_index) - stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) dtype_x = TORCH2CUTE_DTYPE[logits.dtype] key = ("ptr_bwd", int(N), dtype_x, int(device_index)) @@ -1082,6 +1878,32 @@ def _cross_entropy_backward_ptr_into( ) _PTR_BWD_COMPILE_CACHE[key] = compiled + launcher = _get_fast_ptr_cross_entropy_launcher( + compiled=compiled, + dtype_logits=dtype_x, + N=int(N), + device_index=int(device_index), + stream_handle=stream_handle, + mode="bwd", + ) + if launcher is not None: + launcher.launch( + logits_ptr=int(logits.data_ptr()), + target_ptr=int(target.data_ptr()), + aux_a_ptr=int(dloss.data_ptr()), + aux_b_ptr=int(dx.data_ptr()), + aux_c_ptr=int(lse.data_ptr()), + M=int(M), + ld=int(logits.stride(0)), + ignore_index=int(ignore_index), + stream_handle=stream_handle, + dtype_logits=dtype_x, + aux_a_dtype=cutlass.Float32, + aux_b_dtype=dtype_x, + aux_c_dtype=cutlass.Float32, + ) + return + ptr_logits = rt.make_ptr( dtype_x, 
logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 ) @@ -1119,6 +1941,127 @@ def _cross_entropy_backward_ptr_into( ) +def _cross_entropy_fwd_bwd_ptr_into( + *, + logits: Tensor, + target: Tensor, + dloss: Tensor, + dx: Tensor, + ignore_index: int, +) -> None: + """Launch the fused pointer-based cross-entropy fwd+bwd kernel into preallocated `dx`.""" + assert logits.is_cuda and logits.dim() == 2 + assert target.is_cuda and target.dim() == 1 and target.shape[0] == logits.shape[0] + assert target.dtype is torch.int64 + assert ( + dloss.is_cuda + and dloss.shape == (logits.shape[0],) + and dloss.dtype is torch.float32 + ) + assert dx.is_cuda and dx.shape == logits.shape and dx.dtype == logits.dtype + assert dx.stride() == logits.stride(), ( + "Pointer path expects dx to match logits strides" + ) + + M, N = logits.shape + device_index = logits.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + dtype_x = TORCH2CUTE_DTYPE[logits.dtype] + key = ("ptr_fwd_bwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_FWDBWD_COMPILE_CACHE.get(key) + if compiled is None: + op = CrossEntropyFwdBwdSM100(dtype_x, int(N)) + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_dloss = rt.make_ptr( + cutlass.Float32, + dloss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_logits, + ptr_target, + ptr_dloss, + ptr_dx, + Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + _PTR_FWDBWD_COMPILE_CACHE[key] = compiled + + launcher = _get_fast_ptr_cross_entropy_launcher( + compiled=compiled, + dtype_logits=dtype_x, + N=int(N), + device_index=int(device_index), + stream_handle=stream_handle, + mode="fwd_bwd", + ) + if launcher is not None: + launcher.launch( + logits_ptr=int(logits.data_ptr()), + target_ptr=int(target.data_ptr()), + aux_a_ptr=int(dloss.data_ptr()), + aux_b_ptr=int(dx.data_ptr()), + aux_c_ptr=None, + M=int(M), + ld=int(logits.stride(0)), + ignore_index=int(ignore_index), + stream_handle=stream_handle, + dtype_logits=dtype_x, + aux_a_dtype=cutlass.Float32, + aux_b_dtype=dtype_x, + aux_c_dtype=None, + ) + return + + ptr_logits = rt.make_ptr( + dtype_x, logits.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_target = rt.make_ptr( + cutlass.Int64, + target.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=8, + ) + ptr_dloss = rt.make_ptr( + cutlass.Float32, + dloss.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + compiled( + ptr_logits, + ptr_target, + ptr_dloss, + ptr_dx, + Int32(int(M)), + Int32(int(logits.stride(0))), + Int32(int(ignore_index)), + stream, + ) + + def cross_entropy_backward( dloss: Tensor, logits: Tensor, @@ -1157,6 +2100,70 @@ def cross_entropy_backward( return dx +def cross_entropy_fwd_bwd( + dloss: Tensor, + logits: Tensor, + target: Tensor, + ignore_index: int = -100, +) -> Tensor: + """Fused cross-entropy forward+backward producing ``dx`` for 
``reduction='none'``. + + Computes per-logit gradients ``dx`` given: + - ``logits``: (M, N) + - ``target``: (M,) + - ``dloss``: (M,) upstream gradients (float32 recommended) + + The fast path avoids materializing intermediate ``lse`` in global memory. + """ + assert logits.dim() == 2, "logits must be 2D (M, N)" + assert target.dim() == 1, "target must be 1D (M,)" + assert dloss.dim() == 1, "dloss must be 1D (M,)" + assert logits.shape[0] == target.shape[0] == dloss.shape[0], ( + "Batch dimensions must match" + ) + assert logits.is_cuda and target.is_cuda and dloss.is_cuda, ( + "All tensors must be on CUDA device" + ) + assert logits.dtype in TORCH2CUTE_DTYPE, "Unsupported logits dtype" + + dx = torch.empty_like(logits) + + if ( + _can_use_ptr_path_logits(logits) + and _can_use_ptr_path_logits(dx) + and _can_use_ptr_path_target(target) + and _can_use_ptr_path_f32_1d(dloss) + and logits.stride() == dx.stride() + ): + _cross_entropy_fwd_bwd_ptr_into( + logits=logits, + target=target, + dloss=dloss, + dx=dx, + ignore_index=int(ignore_index), + ) + return dx + + # Fallback: reuse the existing forward+backward kernels (DLPack path handles + # any necessary dtype conversions). + with torch.no_grad(): + _loss, lse = cross_entropy_forward( + logits, + target, + ignore_index=int(ignore_index), + reduction="none", + ) + _cross_entropy_backward_sm100( + logits, + target, + dloss, + lse, + dx, + ignore_index=int(ignore_index), + ) + return dx + + def cross_entropy( logits: Tensor, target: Tensor, diff --git a/oink/src/kernelagent_oink/blackwell/fast_launch.py b/oink/src/kernelagent_oink/blackwell/fast_launch.py new file mode 100644 index 0000000..9b288f2 --- /dev/null +++ b/oink/src/kernelagent_oink/blackwell/fast_launch.py @@ -0,0 +1,115 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Host-side fast-launch helpers for CuTeDSL pointer entrypoints. + +CuTeDSL's Python runtime typically marshals each kernel call by allocating +`Int32` / `Float32` wrappers and runtime `Pointer` descriptors per invocation. +For latency-sensitive cases (small/medium M), this overhead can dominate. + +These helpers provide: +- Stable scalar argument wrappers (`StableI32Arg`, `StableF32Arg`) that avoid + per-call ctypes allocations. +- In-place mutation of runtime pointer descriptors (`set_runtime_ptr`) so a + compiled kernel can be launched repeatedly with different raw device pointers + without rebuilding argument objects. +- A small thread-local cache to store packed args objects (when supported by the + installed CuTeDSL version). + +All of this relies on a few private-ish CuTeDSL internals. Callers must treat +fast-launch as an optional optimization and fall back to the normal launch +path if those internals are unavailable. 
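+
+Illustrative per-launch update pattern (a sketch of how the launchers in
+``cross_entropy.py`` / ``layernorm.py`` use these helpers; ``ptr_x``,
+``arg_m``, ``packed_args`` and ``capi_func`` come from a previously compiled
+CuTeDSL kernel and its executor, so the snippet is not runnable on its own):
+
+    set_runtime_ptr(ptr_x, x.data_ptr())  # retarget the pointer descriptor in place
+    arg_m.set(M)                          # update the Int32 scalar without reallocating
+    capi_func(packed_args)                # relaunch using the prebuilt packed args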
+""" + +from __future__ import annotations + +import ctypes +import os +import threading +from typing import Any + +_FAST_LAUNCH_TLS = threading.local() + + +def _env_flag(name: str, default: bool) -> bool: + val = os.environ.get(name) + if val is None: + return default + return val.strip().lower() not in {"0", "false", "no", "off", ""} + + +# Fast-launch uses internal CuTeDSL plumbing (packed args + pointer descriptors). +# Keep it enabled by default in our pinned environment, but allow disabling it +# via env var and auto-disable it if CuTeDSL internals change. +_ENABLE_FAST_LAUNCH = _env_flag("OINK_CUTEDSL_FAST_LAUNCH", default=True) +_FAST_LAUNCH_SUPPORTED = True + + +def fast_launch_enabled() -> bool: + return _ENABLE_FAST_LAUNCH and _FAST_LAUNCH_SUPPORTED + + +def disable_fast_launch() -> None: + global _FAST_LAUNCH_SUPPORTED + _FAST_LAUNCH_SUPPORTED = False + + +def tls_cache() -> dict[tuple[Any, ...], Any]: + cache = getattr(_FAST_LAUNCH_TLS, "cache", None) + if cache is None: + cache = {} + _FAST_LAUNCH_TLS.cache = cache + return cache + + +class StableI32Arg: + """A stable Int32 runtime arg (avoids per-call Int32().__c_pointers__ allocations).""" + + def __init__(self, value: int): + self._c_value = ctypes.c_int32(int(value)) + self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p) + + def set(self, value: int) -> None: + self._c_value.value = int(value) + + def __c_pointers__(self): + return [self._c_pointer] + + +class StableF32Arg: + """A stable Float32 runtime arg (avoids per-call Float32().__c_pointers__ allocations).""" + + def __init__(self, value: float): + self._c_value = ctypes.c_float(float(value)) + self._c_pointer = ctypes.cast(ctypes.pointer(self._c_value), ctypes.c_void_p) + + def set(self, value: float) -> None: + self._c_value.value = float(value) + + def __c_pointers__(self): + return [self._c_pointer] + + +def set_runtime_ptr(ptr: Any, device_ptr: int) -> None: + """Update a CuTeDSL runtime Pointer descriptor in-place. + + This relies on internal runtime pointer fields (`_desc`, `_pointer`, etc.). + If these internals change in a future CuTeDSL upgrade, this function may + raise AttributeError; callers should catch it and fall back. + """ + device_ptr = int(device_ptr) + ptr._pointer = device_ptr # type: ignore[attr-defined] + if getattr(ptr, "_c_pointer", None) is None: + ptr.__c_pointers__() # type: ignore[attr-defined] + ptr._desc.value = device_ptr # type: ignore[attr-defined] diff --git a/oink/src/kernelagent_oink/blackwell/layernorm.py b/oink/src/kernelagent_oink/blackwell/layernorm.py index 67f67ce..ada51ec 100644 --- a/oink/src/kernelagent_oink/blackwell/layernorm.py +++ b/oink/src/kernelagent_oink/blackwell/layernorm.py @@ -31,6 +31,7 @@ from __future__ import annotations import importlib.metadata +import math import os import re import operator @@ -70,27 +71,455 @@ from cutlass.cute import runtime as rt from cutlass.cute.runtime import from_dlpack -# Simple compile cache for the forward kernel -_COMPILE_CACHE: dict[Tuple[int, type[cutlass.Numeric], bool, bool, bool], object] = {} -_PTR_COMPILE_CACHE: dict[Tuple[object, ...], object] = {} - -# Backward compile caches: one for dx, one for parameter gradients. -_BWD_DX_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric]], object] = {} -_BWD_PARAM_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric], bool], object] = {} - -# Local helpers cloned from Quack via lite_quack so that this kernel does -# not depend on `quack` at runtime. 
-from kernelagent_oink.blackwell.lite_quack import ( # noqa: E402 +from kernelagent_oink.blackwell.lite_quack import ( _KERNEL_ACCEPTS_LAYOUT_ARGS, TORCH2CUTE_DTYPE, ReductionBase as _ReductionBase, convert_from_dlpack as convert_from_dlpack_cute, - domain_offset_i64, get_sm_count, predicate_k, row_reduce, warp_reduce, ) +from kernelagent_oink.blackwell.fast_launch import ( + StableF32Arg, + StableI32Arg, + disable_fast_launch, + fast_launch_enabled, + set_runtime_ptr, + tls_cache as _tls_fast_launch_cache, +) + +# Simple compile cache for the forward kernel +_COMPILE_CACHE: dict[Tuple[int, type[cutlass.Numeric], bool, bool, bool], object] = {} +_PTR_COMPILE_CACHE: dict[Tuple[object, ...], object] = {} + +# Backward compile caches: one for dx, one for parameter gradients. +_BWD_DX_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric]], object] = {} +_BWD_PARAM_COMPILE_CACHE: dict[Tuple[int, Type[cutlass.Numeric], bool], object] = {} + + +class _PtrLayernormFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_x: object, + ptr_w: object, + ptr_b: Optional[object], + ptr_out: object, + ptr_rstd: Optional[object], + ptr_mean: Optional[object], + arg_m: StableI32Arg, + arg_ld: StableI32Arg, + arg_eps: StableF32Arg, + stream: cuda.CUstream, + assumed_align_xo: int, + packed_args: object, + keepalive: tuple[object, ...], + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_x = ptr_x + self._ptr_w = ptr_w + self._ptr_b = ptr_b + self._ptr_out = ptr_out + self._ptr_rstd = ptr_rstd + self._ptr_mean = ptr_mean + self._arg_m = arg_m + self._arg_ld = arg_ld + self._arg_eps = arg_eps + self._stream = stream + self._assumed_align_xo = int(assumed_align_xo) + self._packed_args = packed_args + self._keepalive = keepalive + + self._use_fast_launch = True + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_x_ptr = -1 + self._last_w_ptr = -1 + self._last_b_ptr = -1 + self._last_out_ptr = -1 + self._last_rstd_ptr = -1 + self._last_mean_ptr = -1 + self._last_m = -1 + self._last_ld = -1 + self._last_eps = float("nan") + + def launch( + self, + *, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + out: Tensor, + rstd: Optional[Tensor], + mean: Optional[Tensor], + M: int, + ld: int, + eps: float, + ) -> None: + if not fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + x_ptr = x.data_ptr() + if x_ptr != self._last_x_ptr: + try: + set_runtime_ptr(self._ptr_x, x_ptr) + self._last_x_ptr = x_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + w_ptr = weight.data_ptr() + if w_ptr != self._last_w_ptr: + try: + set_runtime_ptr(self._ptr_w, w_ptr) + self._last_w_ptr = w_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + if self._ptr_b is not None and bias is not None: + b_ptr = bias.data_ptr() + if b_ptr != self._last_b_ptr: + try: + set_runtime_ptr(self._ptr_b, b_ptr) + self._last_b_ptr = b_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, 
+ ld=ld, + eps=eps, + ) + return + + out_ptr = out.data_ptr() + if out_ptr != self._last_out_ptr: + try: + set_runtime_ptr(self._ptr_out, out_ptr) + self._last_out_ptr = out_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + if self._ptr_rstd is not None and rstd is not None: + rstd_ptr = rstd.data_ptr() + if rstd_ptr != self._last_rstd_ptr: + try: + set_runtime_ptr(self._ptr_rstd, rstd_ptr) + self._last_rstd_ptr = rstd_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + if self._ptr_mean is not None and mean is not None: + mean_ptr = mean.data_ptr() + if mean_ptr != self._last_mean_ptr: + try: + set_runtime_ptr(self._ptr_mean, mean_ptr) + self._last_mean_ptr = mean_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=M, + ld=ld, + eps=eps, + ) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld != self._last_ld: + self._arg_ld.set(ld) + self._last_ld = ld + if eps != self._last_eps: + self._arg_eps.set(eps) + self._last_eps = eps + + if self._cuda_result is not None: + self._cuda_result.value = 0 + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + self._use_fast_launch = False + disable_fast_launch() + + def _fallback_launch( + self, + *, + x: Tensor, + weight: Tensor, + bias: Optional[Tensor], + out: Tensor, + rstd: Optional[Tensor], + mean: Optional[Tensor], + M: int, + ld: int, + eps: float, + ) -> None: + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + ptr_x = rt.make_ptr( + dtype_x, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_xo, + ) + ptr_out = rt.make_ptr( + dtype_x, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_xo, + ) + ptr_w = rt.make_ptr( + cutlass.Float32, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + ptr_b = ( + rt.make_ptr( + cutlass.Float32, + bias.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=16, + ) + if bias is not None + else None + ) + ptr_rstd = ( + rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if rstd is not None + else None + ) + ptr_mean = ( + rt.make_ptr( + cutlass.Float32, + mean.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=4, + ) + if mean is not None + else None + ) + self._compiled( + ptr_x, + ptr_w, + ptr_b, + ptr_out, + ptr_rstd, + ptr_mean, + Int32(int(M)), + Int32(int(ld)), + stream, + Float32(float(eps)), + ) + + +def _get_fast_ptr_layernorm_launcher( + *, + compiled: object, + N: int, + dtype_x: type[cutlass.Numeric], + has_bias: bool, + has_rstd: bool, + has_mean: bool, + device_index: int, + stream_handle: int, + assumed_align_xo: int, + eps: float, +) -> Optional[_PtrLayernormFastLaunch]: + if not 
fast_launch_enabled(): + return None + key = ( + "ptr_fast", + id(compiled), + int(N), + dtype_x, + bool(has_bias), + bool(has_rstd), + bool(has_mean), + int(device_index), + int(stream_handle), + int(assumed_align_xo), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + ptr_x = rt.make_ptr( + dtype_x, 0, mem_space=rt.AddressSpace.gmem, assumed_align=int(assumed_align_xo) + ) + ptr_out = rt.make_ptr( + dtype_x, 0, mem_space=rt.AddressSpace.gmem, assumed_align=int(assumed_align_xo) + ) + ptr_w = rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_b = ( + rt.make_ptr( + cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + if has_bias + else None + ) + ptr_rstd = ( + rt.make_ptr(cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4) + if has_rstd + else None + ) + ptr_mean = ( + rt.make_ptr(cutlass.Float32, 0, mem_space=rt.AddressSpace.gmem, assumed_align=4) + if has_mean + else None + ) + + arg_m = StableI32Arg(0) + arg_ld = StableI32Arg(N) + arg_eps = StableF32Arg(eps) + stream = cuda.CUstream(int(stream_handle)) + executor = compiled.to(device_index) # type: ignore[attr-defined] + + try: + exe_args, adapted_args = executor.generate_execution_args( + ptr_x, + ptr_w, + ptr_b, + ptr_out, + ptr_rstd, + ptr_mean, + arg_m, + arg_ld, + stream, + arg_eps, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + disable_fast_launch() + return None + + keepalive: tuple[object, ...] = ( + executor, + ptr_x, + ptr_w, + ptr_b, + ptr_out, + ptr_rstd, + ptr_mean, + arg_m, + arg_ld, + arg_eps, + stream, + *adapted_args, + ) + launcher = _PtrLayernormFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_x=ptr_x, + ptr_w=ptr_w, + ptr_b=ptr_b, + ptr_out=ptr_out, + ptr_rstd=ptr_rstd, + ptr_mean=ptr_mean, + arg_m=arg_m, + arg_ld=arg_ld, + arg_eps=arg_eps, + stream=stream, + assumed_align_xo=int(assumed_align_xo), + packed_args=packed_args, + keepalive=keepalive, + ) + cache[key] = launcher + return launcher def _convert_row_major(t: Tensor) -> cute.Tensor: @@ -121,17 +550,42 @@ class LayerNormSM100(_ReductionBase): - Dtype mapping and reduction helpers come from `lite_quack`. """ - def __init__(self, dtype: type[cutlass.Numeric], N: int): + def __init__( + self, + dtype: type[cutlass.Numeric], + N: int, + *, + copy_bits_x: Optional[int] = None, + direct_gmem: bool = False, + ): super().__init__(dtype, N, stage=2) # 2 stages for mean and var # Default reload policy mirrors Quack: use SMEM reload only for # very large hidden sizes. We keep this conservative for LayerNorm # and tune primarily via threads-per-block / cluster_n. self.reload_from: Optional[str] = None if N <= 16384 else "smem" - self.delay_w_load: bool = False + # SM100 tuning: for DSv3 hidden sizes where we fuse mean+var stats, + # delay loading fp32 weights/bias until after the reductions to lower + # register pressure. 
+ self.delay_w_load: bool = bool(N in (4096, 6144, 7168, 8192)) + self.copy_bits_x: Optional[int] = ( + int(copy_bits_x) if copy_bits_x is not None else None + ) + self.direct_gmem: bool = bool(direct_gmem) + + def _get_num_threads(self) -> int: + nt = getattr(self, "_nt_override", None) + if nt is not None: + return int(nt) + return super()._get_num_threads() def _calculate_threads_per_row(self) -> int: + tpr = getattr(self, "_tpr_override", None) + if tpr is not None: + return int(tpr) # Match Quack's LayerNorm threads-per-row buckets. N = self.N + if N in (4096, 6144): + return 128 return ( 8 if N <= 64 @@ -188,7 +642,24 @@ def __call__( # Tiling and cluster policy (mirrors Quack LayerNorm). self._set_cluster_n() - tiler_mn, tv_layout = self._get_tv_layout() + largest_dtype_width = const_expr( + max( + t.element_type.width + for t in (mX, mW, mB, mO, mRstd, mMean) + if t is not None + ) + ) + # Match Quack's unified RMSNorm/LayerNorm kernel: pick vecsize based on + # the widest dtype participating in the op (e.g. fp32 weights => fp16 + # X uses 64b vectorization). + vecsize = math.gcd(self.N, 128 // largest_dtype_width) + default_copy_bits_x = vecsize * self.dtype.width + num_copy_bits_x = ( + int(self.copy_bits_x) + if self.copy_bits_x is not None + else default_copy_bits_x + ) + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits_x) num_threads = ( cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS @@ -275,10 +746,14 @@ def launch_from_ptrs( This reconstructs cute.Tensor views from raw device pointers + explicit layouts inside the JIT graph, reusing the tuned LayerNormSM100 schedule. """ - # The kernel uses 128-bit vectorized copies for X. Mirror Quack's - # `divisibility=128 // dtype.width` contract so the compiler can - # prove alignment for cp.async. - ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + # Mirror Quack-style divisibility contracts so the compiler can prove + # alignment for vectorized loads/stores (and cp.async when enabled). + divby = ( + int(self.copy_bits_x) // self.dtype.width + if const_expr(self.copy_bits_x is not None) + else (128 // self.dtype.width) + ) + ld_assumed = cute.assume(ld, divby=divby) # Match `mark_compact_shape_dynamic(mode=0, ...)`: M is dynamic, N is static. layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) layout_n = cute.make_layout((self.N,), stride=(1,)) @@ -338,9 +813,10 @@ def _kernel_impl( shape = mX.shape idX = cute.make_identity_tensor(shape) - # Slice for CTAs: use domain_offset_i64 to handle >2^31 elements. - mX, mO = [domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)] - gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)] + # Quack-style CTA tiling: let CuTe compute the CTA offsets directly. + # (Avoids the extra 64-bit address arithmetic in `domain_offset_i64` on + # the common inference/benchmark sizes.) + gX, gO = [cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO)] cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) gW = cute.local_tile(mW, tiler_mn, (0, cluster_y)) gB = ( @@ -359,118 +835,160 @@ def _kernel_impl( else None ) - # Copy atoms for X / W / B / O. + # Copy atoms for X / W / B / O (mirror Quack's vector-size contract). 
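+    # Example (illustrative): bf16 X with 8 values per thread gives
+    # num_copy_bits_x = 16 * 8 = 128; with 16 values it would be 256, but the
+    # cp.async atom below stays capped at min(128, 256) = 128 bits per copy.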
+ num_copy_elems_x = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + num_copy_bits_x = mX.element_type.width * num_copy_elems_x + num_copy_bits_x_async = const_expr(min(128, num_copy_bits_x)) copy_atom_load_X = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), mX.element_type, - num_bits_per_copy=128, + num_bits_per_copy=num_copy_bits_x, ) copy_atom_load_X_async = cute.make_copy_atom( cute.nvgpu.cpasync.CopyG2SOp(), mX.element_type, - num_bits_per_copy=128, + num_bits_per_copy=num_copy_bits_x_async, + ) + num_copy_bits_wb = const_expr( + min(128, mW.element_type.width * num_copy_elems_x) ) copy_atom_load_WB = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), mW.element_type, - num_bits_per_copy=128, + num_bits_per_copy=num_copy_bits_wb, ) copy_atom_store_O = cute.make_copy_atom( cute.nvgpu.CopyUniversalOp(), mO.element_type, - num_bits_per_copy=128, + num_bits_per_copy=num_copy_bits_x, ) - thr_copy_X = cute.make_tiled_copy( - copy_atom_load_X_async, - tv_layout, - tiler_mn, - ).get_slice(tidx) - thr_copy_WB = cute.make_tiled_copy( - copy_atom_load_WB, - tv_layout, - tiler_mn, - ).get_slice(tidx) - thr_copy_O = cute.make_tiled_copy( - copy_atom_store_O, - tv_layout, - tiler_mn, + # Quack-style partitioning: use `make_tiled_copy_tv` (2D thread/value + # layout) and let partitioning over the CTA tile handle the N loop. + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems_x)) + thr_copy = cute.make_tiled_copy_tv( + copy_atom_load_X, thr_layout, val_layout ).get_slice(tidx) - tWgW = thr_copy_WB.partition_S(gW) - tBgB = thr_copy_WB.partition_S(gB) if const_expr(gB is not None) else None - tXgX = thr_copy_X.partition_S(gX) - tXsX = thr_copy_X.partition_D(sX) - tXgO = thr_copy_O.partition_D(gO) - tXrRstd = ( - thr_copy_O.partition_D(gRstd) if const_expr(mRstd is not None) else None - ) - tXrMean = ( - thr_copy_O.partition_D(gMean) if const_expr(mMean is not None) else None - ) - tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + tXgX = thr_copy.partition_S(gX) + tXsX = thr_copy.partition_D(sX) + tXgO = thr_copy.partition_D(gO) + tXgW = thr_copy.partition_S(gW) + tXgB = thr_copy.partition_S(gB) if const_expr(gB is not None) else None + tXrRstd = thr_copy.partition_D(gRstd) if const_expr(mRstd is not None) else None + tXrMean = thr_copy.partition_D(gMean) if const_expr(mMean is not None) else None + tXcX = thr_copy.partition_S(cX)[(0, None), None, None] # Fragments for gmem->rmem. 
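+    # W/B share the X thread/value partitioning and predicate, so no separate
+    # thr_copy_WB (or retile into X-shaped fragments) is needed here.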
- tWrW = cute.make_fragment_like(tWgW) - tBrB = cute.make_fragment_like(tBgB) if const_expr(mB is not None) else None - tXrW = thr_copy_X.retile(tWrW) - tXrB = thr_copy_X.retile(tBrB) if const_expr(mB is not None) else None + tXrW = cute.make_fragment_like(tXgW) + tXrB = cute.make_fragment_like(tXgB) if const_expr(mB is not None) else None tXrX, tXrO = [cute.make_fragment_like(thr) for thr in (tXgX, tXgO)] num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE self._initialize_cluster(tidx, mbar_ptr, num_warps, is_persistent=False) - tXpX = predicate_k( - thr_copy_X.partition_S(cX), - limit=shape[1], + is_even_N = const_expr(shape[1] == tiler_mn[1] * self.cluster_n) + tXpX = ( + None if is_even_N else predicate_k(thr_copy.partition_S(cX), limit=shape[1]) ) row = tXcX[0][0] - if row < shape[0]: - cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX) - cute.arch.cp_async_commit_group() + if const_expr(not self.direct_gmem): + if row < shape[0]: + cute.copy(copy_atom_load_X_async, tXgX, tXsX, pred=tXpX) + cute.arch.cp_async_commit_group() - tWpW = predicate_k( - thr_copy_WB.partition_S(cX), - limit=shape[1], - ) if const_expr(not delay_w_load): - cute.copy(copy_atom_load_WB, tWgW, tWrW, pred=tWpW) + cute.copy(copy_atom_load_WB, tXgW, tXrW, pred=tXpX) if const_expr(mB is not None): - cute.copy(copy_atom_load_WB, tBgB, tBrB, pred=tWpW) - - cute.arch.cp_async_wait_group(0) - cute.autovec_copy(tXsX, tXrX) - x = tXrX.load().to(Float32) - threads_per_row = tv_layout.shape[0][0] - sum_x = row_reduce( - x, - cute.ReductionOp.ADD, - threads_per_row, - reduction_buffer[None, None, 0], - mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None, - init_val=0.0, - hook_fn=( - cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None - ), - ) - mean = sum_x / shape[1] + cute.copy(copy_atom_load_WB, tXgB, tXrB, pred=tXpX) - if const_expr(reload_from == "smem"): + if const_expr(not self.direct_gmem): + cute.arch.cp_async_wait_group(0) cute.autovec_copy(tXsX, tXrX) - x = tXrX.load().to(Float32) - elif const_expr(reload_from == "gmem"): - cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) - x = tXrX.load().to(Float32) - - sum_sq_x_sub_mean = row_reduce( - (x - mean) * (x - mean), - cute.ReductionOp.ADD, - threads_per_row, - reduction_buffer[None, None, 1], - mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None, - init_val=0.0, - ) - rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True) + else: + if row < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) + x = tXrX.load().to(Float32) + if const_expr(self.cluster_n == 1 and self.N in (4096, 6144, 7168, 8192)): + # SM100 tuning for DSv3 hidden sizes: + # Compute (sum_x, sum_x2) together so we can derive mean + variance + # without a second reduction pass (and without re-materializing + # x-mean for the variance reduction). 
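+            # Uses the single-pass identity var = E[x^2] - (E[x])^2; the
+            # fmax(var, 0.0) below guards against small negative values caused
+            # by floating-point cancellation.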
+ sum_x = x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0) + sum_x2 = (x * x).reduce( + cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0 + ) + sum_x = warp_reduce( + sum_x, + operator.add, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + sum_x2 = warp_reduce( + sum_x2, + operator.add, + width=min(threads_per_row, cute.arch.WARP_SIZE), + ) + warps_per_row, cluster_n = reduction_buffer.shape[1] + if const_expr(warps_per_row > 1 or cluster_n > 1): + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + if lane_idx == 0: + reduction_buffer[row_idx, col_idx, 0] = sum_x + reduction_buffer[row_idx, col_idx, 1] = sum_x2 + cute.arch.barrier() + block_sum_x = 0.0 + block_sum_x2 = 0.0 + if lane_idx < warps_per_row: + block_sum_x = reduction_buffer[row_idx, lane_idx, 0] + block_sum_x2 = reduction_buffer[row_idx, lane_idx, 1] + sum_x = warp_reduce(block_sum_x, operator.add) + sum_x2 = warp_reduce(block_sum_x2, operator.add) + mean = sum_x / shape[1] + var = sum_x2 / shape[1] - mean * mean + var = cute.arch.fmax(var, 0.0) + rstd = cute.math.rsqrt(var + eps, fastmath=True) + else: + sum_x = row_reduce( + x, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + mbar_ptr + 0 if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + hook_fn=( + cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None + ), + ) + mean = sum_x / shape[1] + + if const_expr(reload_from == "smem"): + cute.autovec_copy(tXsX, tXrX) + x = tXrX.load().to(Float32) + elif const_expr(reload_from == "gmem"): + cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) + x = tXrX.load().to(Float32) + + sum_sq_x_sub_mean = row_reduce( + (x - mean) * (x - mean), + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 1], + mbar_ptr + 1 if const_expr(self.cluster_n > 1) else None, + init_val=0.0, + ) + rstd = cute.math.rsqrt(sum_sq_x_sub_mean / shape[1] + eps, fastmath=True) if const_expr(mRstd is not None): if ( @@ -489,9 +1007,9 @@ def _kernel_impl( tXrMean[0] = mean if const_expr(delay_w_load): - cute.copy(copy_atom_load_WB, tWgW, tWrW, pred=tWpW) + cute.copy(copy_atom_load_WB, tXgW, tXrW, pred=tXpX) if const_expr(mB is not None): - cute.copy(copy_atom_load_WB, tBgB, tBrB, pred=tWpW) + cute.copy(copy_atom_load_WB, tXgB, tXrB, pred=tXpX) if const_expr(reload_from == "smem"): cute.autovec_copy(tXsX, tXrX) @@ -508,12 +1026,8 @@ def _kernel_impl( y = y + b tXrO.store(y.to(tXrO.element_type)) - tOpO = predicate_k( - thr_copy_O.partition_S(cX), - limit=shape[1], - ) if row < shape[0]: - cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tOpO) + cute.copy(copy_atom_store_O, tXrO, tXgO, pred=tXpX) if _KERNEL_ACCEPTS_LAYOUT_ARGS: @@ -558,7 +1072,24 @@ def kernel( mMean: Optional[cute.Tensor], eps: Float32, ): - tiler_mn, tv_layout = self._get_tv_layout() + largest_dtype_width = const_expr( + max( + mX.element_type.width, + mW.element_type.width, + mB.element_type.width if const_expr(mB is not None) else 0, + mO.element_type.width, + mRstd.element_type.width if const_expr(mRstd is not None) else 0, + mMean.element_type.width if const_expr(mMean is not None) else 0, + ) + ) + vecsize = math.gcd(self.N, 128 // largest_dtype_width) + default_copy_bits_x = vecsize * mX.element_type.width + num_copy_bits_x = ( + int(self.copy_bits_x) + if const_expr(self.copy_bits_x is not None) + else default_copy_bits_x + ) + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=num_copy_bits_x) 
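+        # Example (illustrative): bf16 X with fp32 weights gives
+        # largest_dtype_width = 32, vecsize = gcd(N, 128 // 32) = 4, and
+        # default_copy_bits_x = 4 * 16 = 64, i.e. 64-bit activation vectors so
+        # the fp32 weight/bias path can stay at 128 bits.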
self._kernel_impl( mX, mW, @@ -775,6 +1306,37 @@ def _layernorm_forward_ptr_into( stream = cuda.CUstream(stream_handle) dtype_x = TORCH2CUTE_DTYPE[x.dtype] + # Keep the pointer path aligned with Quack's LayerNorm schedule: + # - <=128b vectorization (cp.async-compatible) + # - shared-memory staging for X (gmem->smem->rmem) to amortize global latency + direct_gmem = False + copy_bits_x: Optional[int] = None + assumed_align_xo = 16 + + # DSv3 hidden sizes are often latency-bound on small M. For these N buckets, + # a direct-GMEM schedule (skip gmem->smem cp.async) can reduce overhead. + # + # Keep the Quack-like staged path for large M where cp.async overlap tends to win. + if dtype_x.width == 16: + # DSv3 default hidden size (7168) is a common inference hot shape and + # benefits from the lower-overhead direct-GMEM path on this SM100. + if N == 7168 and M <= 65536: + direct_gmem = True + elif N == 8192 and M <= 16384: + direct_gmem = True + + # DSv3 smallest point (M=4096, N=7168) is latency-sensitive. Increasing + # per-row parallelism improves the reduction path and consistently beats + # Quack on this machine. + tpr_override: Optional[int] = None + nt_override: Optional[int] = None + if dtype_x.width == 16 and N == 7168 and M <= 4096: + tpr_override = 224 + nt_override = 224 + + # NOTE: We previously experimented with a direct-GMEM + 256b vectorized + # schedule for N=4096, but it was consistently slower on this GB200. + # Keep the pointer path on the Quack-like staged (cp.async) schedule. key = ( "ptr", int(N), @@ -782,16 +1344,36 @@ def _layernorm_forward_ptr_into( bias is not None, rstd is not None, mean is not None, + bool(direct_gmem), + int(copy_bits_x) if copy_bits_x is not None else None, + tpr_override, + nt_override, + int(assumed_align_xo), int(device_index), ) compiled = _PTR_COMPILE_CACHE.get(key) if compiled is None: - op = LayerNormSM100(dtype_x, int(N)) + op = LayerNormSM100( + dtype_x, + int(N), + copy_bits_x=copy_bits_x, + direct_gmem=direct_gmem, + ) + if tpr_override is not None: + op._tpr_override = tpr_override # type: ignore[attr-defined] + if nt_override is not None: + op._nt_override = nt_override # type: ignore[attr-defined] ptr_x = rt.make_ptr( - dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype_x, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_xo, ) ptr_out = rt.make_ptr( - dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype_x, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_xo, ) ptr_w = rt.make_ptr( cutlass.Float32, @@ -845,11 +1427,44 @@ def _layernorm_forward_ptr_into( ) _PTR_COMPILE_CACHE[key] = compiled + launcher = _get_fast_ptr_layernorm_launcher( + compiled=compiled, + N=int(N), + dtype_x=dtype_x, + has_bias=bias is not None, + has_rstd=rstd is not None, + has_mean=mean is not None, + device_index=int(device_index), + stream_handle=stream_handle, + assumed_align_xo=int(assumed_align_xo), + eps=float(eps), + ) + ld_val = int(x.stride(0)) + if launcher is not None: + launcher.launch( + x=x, + weight=weight, + bias=bias, + out=out, + rstd=rstd, + mean=mean, + M=int(M), + ld=ld_val, + eps=float(eps), + ) + return + ptr_x = rt.make_ptr( - dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype_x, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_xo, ) ptr_out = rt.make_ptr( - dtype_x, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype_x, + out.data_ptr(), + 
mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_xo, ) ptr_w = rt.make_ptr( cutlass.Float32, @@ -887,7 +1502,7 @@ def _layernorm_forward_ptr_into( if mean is not None else None ) - ld = Int32(int(x.stride(0))) + ld = Int32(ld_val) compiled( ptr_x, ptr_w, diff --git a/oink/src/kernelagent_oink/blackwell/lite_quack.py b/oink/src/kernelagent_oink/blackwell/lite_quack.py index e8ce93a..c7402d8 100644 --- a/oink/src/kernelagent_oink/blackwell/lite_quack.py +++ b/oink/src/kernelagent_oink/blackwell/lite_quack.py @@ -39,7 +39,7 @@ from cutlass import Float32, Int32, const_expr from cutlass.cute.runtime import from_dlpack from cutlass.cutlass_dsl import T, dsl_user_op -from cutlass._mlir.dialects import llvm, vector +from cutlass._mlir.dialects import llvm, nvvm, vector def _parse_version_tuple(version: str) -> tuple[int, int, int]: @@ -69,6 +69,24 @@ def _cutlass_dsl_version() -> Optional[tuple[int, int, int]]: _CUTLASS_DSL_VERSION is not None and _CUTLASS_DSL_VERSION < (4, 3, 4) ) +# Cache device properties lookups (notably `multi_processor_count`) since some +# dispatch paths call `get_sm_count` inside tight benchmark loops. +_DEVICE_NUM_SMS_CACHE: dict[int, int] = {} + + +def get_num_sms(device: torch.device) -> int: + """Return the number of SMs for a CUDA device (cached).""" + device_index = device.index + if device_index is None: + device_index = torch.cuda.current_device() + device_index = int(device_index) + cached = _DEVICE_NUM_SMS_CACHE.get(device_index) + if cached is not None: + return cached + num_sms = int(torch.cuda.get_device_properties(device_index).multi_processor_count) + _DEVICE_NUM_SMS_CACHE[device_index] = num_sms + return num_sms + # ------------------------- # Dtype mapping (from quack.cute_dsl_utils) @@ -178,6 +196,43 @@ def store_shared_remote( ) +@dsl_user_op +def atomic_add_f32( + a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None +) -> Float32: + """Atomic add into global memory (float32).""" + return nvvm.atomicrmw( + res=T.f32(), + op=nvvm.AtomicOpKind.FADD, + ptr=gmem_ptr.llvm_ptr, + a=Float32(a).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + +@cute.jit +def atomic_add_tensor_f32( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, +) -> None: + """Atomic-add a register fragment into a GMEM tile (float32).""" + if const_expr(pred is None): + for i in cutlass.range_constexpr(cute.size(src.shape)): + coord = cute.idx2crd(i, src.shape) + atomic_add_f32(src[i], elem_pointer(dst, coord)) + else: + for i in cutlass.range_constexpr(cute.size(src.shape)): + # CuTeDSL 4.3.4+ disallows introducing new tuple-typed values inside + # a dynamic `if`. Compute `coord` unconditionally, then predicate the + # atomic update. + coord = cute.idx2crd(i, src.shape) + if pred[i]: + atomic_add_f32(src[i], elem_pointer(dst, coord)) + + @cute.jit def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor: # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if". 
@@ -318,9 +373,7 @@ def warp_reduce( for i in cutlass.range_constexpr(cute.size(val.shape)): res[i] = warp_reduce(res[i], op, width) return res.load() - for i in cutlass.range_constexpr(int(math.log2(width))): - val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i)) - return val + return cute.arch.warp_reduction(val, op, threads_in_group=width) @cute.jit @@ -623,7 +676,10 @@ def get_copy_atom( ) -> cute.CopyAtom: from cutlass.cute.nvgpu import cpasync - num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + # cp.async is limited to 128b per op; synchronous vectorized copies can go wider. + max_bits = const_expr(128 if is_async else 256) + num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width)) + # Match Quack's default cp.async cache policy (leave cache_mode unspecified). copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() return cute.make_copy_atom( copy_op, dtype, num_bits_per_copy=num_copy_bits, loc=loc, ip=ip @@ -678,6 +734,20 @@ def _get_num_threads(self) -> int: def _get_tv_layout( self, num_copy_bits: int = 128 ) -> Tuple[cute.Shape, cute.Layout]: + """Return (tiler_mn, tv_layout) for SM100 reduction kernels. + + This intentionally mirrors Quack's `ReductionBase._get_tiled_copy(...)`: + - `tiler_mn` spans the full N range for the CTA, including any "K-loop" + repeats (`num_blocks_N`). + - `tv_layout` is the *tiled* thread/value layout used by CuTe's copy + partitioning (does **not** bake in `num_blocks_N`), matching + `quack.copy_utils.tiled_copy_2d(...).layout_tv_tiled`. + """ + if num_copy_bits > 128: + raise ValueError( + f"num_copy_bits={num_copy_bits} exceeds 128b; Quack-style SM100 reduction " + "tiling assumes <=128b vectorization (cp.async and common CopyAtoms)." + ) vecsize = num_copy_bits // self.dtype.width assert self.N % vecsize == 0, ( f"Input N {self.N} is not divisible by vector size {vecsize}" @@ -692,30 +762,56 @@ def _get_tv_layout( ) cols_per_block = num_threads // threads_per_row tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) - tv_layout = cute.make_layout( - ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)), - stride=( - (vecsize * cols_per_block, 1), - (cols_per_block, cols_per_block * vecsize * threads_per_row), - ), + + # Construct the same tv layout that Quack gets from `tiled_copy_2d(...).layout_tv_tiled`. + copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype, + num_bits_per_copy=num_copy_bits, + ) + thr_layout = cute.make_ordered_layout( + (cols_per_block, threads_per_row), + order=(1, 0), ) + val_layout = cute.make_layout((1, vecsize)) + tv_layout = cute.make_tiled_copy_tv( + copy_atom, thr_layout, val_layout + ).layout_tv_tiled return tiler_mn, tv_layout def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int: - return ( - cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn)) - + self.stage - * num_warps - * self.cluster_n - * (self.reduction_dtype.width // 8) - + self.stage * (cutlass.Int64.width // 8) + # Mirror the allocation order used by the SM100 reduction kernels: + # 1) sX (byte_alignment=16) + # 2) reduction_buffer (byte_alignment=8) + # 3) mbar_ptr (Int64, 8B) + # + # CuTeDSL's SmemAllocator may insert padding between allocations to satisfy + # alignment. Be conservative and round up offsets accordingly so we never + # under-allocate dynamic shared memory. 
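+        # Example (illustrative): sx_bytes = 100 is treated as placing the
+        # reduction buffer at offset 112 (next 16B boundary) rather than 100,
+        # so the returned size over-approximates instead of under-allocating.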
+ + def _align_up(x: int, align: int) -> int: + return ((x + align - 1) // align) * align + + sx_bytes = int(cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))) + reduction_bytes = int( + self.stage * num_warps * self.cluster_n * (self.reduction_dtype.width // 8) ) + mbar_bytes = int(self.stage * (cutlass.Int64.width // 8)) + + offset = _align_up(sx_bytes, 16) + offset = _align_up(offset, 8) + reduction_bytes + offset = _align_up(offset, 8) + mbar_bytes + return int(offset) def _get_reduction_buffer_layout( self, tv_layout: cute.Layout, cluster_n: int ) -> cute.Layout: num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE - warps_per_row = max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1) + warps_per_row = ( + num_warps + if cutlass.const_expr(cute.rank(tv_layout.shape[0]) == 1) + else max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1) + ) return cute.make_ordered_layout( (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), order=(1, 0, 2), @@ -730,7 +826,7 @@ def _allocate_reduction_buffer_and_mbar( reduction_buffer = smem.allocate_tensor( self.reduction_dtype, self._get_reduction_buffer_layout(tv_layout, self.cluster_n), - byte_alignment=4, + byte_alignment=8, ) if cutlass.const_expr(self.cluster_n > 1): mbar_ptr = smem.allocate_array( @@ -771,6 +867,9 @@ def __init__(self, dtype: cutlass.Numeric, N: int): # 2 stages for double buffering when computing mean of x_hat * wdy super().__init__(dtype, N, stage=2, reduction_dtype=Float32) self.reload_wdy = None if N <= 16 * 1024 else "smem" + # Optional optimization: atomically accumulate mdW into a single (N,) + # buffer instead of writing an (sm_count, N) partial buffer + torch.sum. + self.atomic_dw = False if self.N > 128 * 1024 and self.dtype.width >= 32: raise ValueError( "RMSNormBackward does not support N > 128k with dtype >= 32 bits" @@ -856,15 +955,18 @@ def new_stride(t): largest_dtype_width = const_expr( max( mX.element_type.width, + mW.element_type.width if mW is not None else 0, mdO.element_type.width, mdX.element_type.width, mdResO.element_type.width if mdResO is not None else 0, mdRes.element_type.width if mdRes is not None else 0, ) ) - tiler_mn, tv_layout = self._get_tv_layout( - num_copy_bits=128 // largest_dtype_width * mX.element_type.width - ) + # Quack-style policy: cap the *largest* dtype to 128b, then scale the + # activation copy width down proportionally (e.g. fp16 + fp32-weight + # => 64b activation vectors so the fp32 path stays at 128b). 
+ num_copy_bits = const_expr(128 // largest_dtype_width * mX.element_type.width) + tiler_mn, tv_layout = self._get_tv_layout(num_copy_bits=int(num_copy_bits)) num_threads = ( cute.size(tv_layout, mode=[0]) if _KERNEL_ACCEPTS_LAYOUT_ARGS @@ -941,12 +1043,25 @@ def _kernel_impl( else: mbar_full_ptr, mbar_empty_ptr = None, None - num_copy_elems_X = tv_layout.shape[1][0] + num_copy_elems_X = ( + tv_layout.shape[1] + if cutlass.const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if cutlass.const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) copy_atom_load_X = get_copy_atom( mX.element_type, num_copy_elems_X, is_async=False ) - thr_copy_X = cute.make_tiled_copy( - copy_atom_load_X, tv_layout, tiler_mn + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems_X)) + thr_copy_X = cute.make_tiled_copy_tv( + copy_atom_load_X, thr_layout, val_layout ).get_slice(tidx) copy_fn = partial(copy, num_copy_elems=num_copy_elems_X) @@ -1025,7 +1140,6 @@ def _kernel_impl( if const_expr(self.cluster_n > 1): cute.arch.cluster_wait() - threads_per_row = tv_layout.shape[0][0] if const_expr(mdW is not None): tXrdW.fill(0.0) if const_expr(mdB is not None): @@ -1165,7 +1279,10 @@ def _kernel_impl( ) cute.autovec_copy(tXsdW_other, tXrdW_other) tXrdW.store(tXrdW.load() + tXrdW_other.load()) - copy_fn(tXrdW, tXgdW, pred=tXpX) + if const_expr(self.atomic_dw): + atomic_add_tensor_f32(tXrdW, tXgdW, pred=tXpX) + else: + copy_fn(tXrdW, tXgdW, pred=tXpX) cute.arch.barrier() if const_expr(mdB is not None): sdB = cute.make_tensor( @@ -1190,7 +1307,10 @@ def _kernel_impl( copy_fn(tXrdB, tXgdB, pred=tXpX) else: if const_expr(mdW is not None): - copy_fn(tXrdW, tXgdW, pred=tXpX) + if const_expr(self.atomic_dw): + atomic_add_tensor_f32(tXrdW, tXgdW, pred=tXpX) + else: + copy_fn(tXrdW, tXgdW, pred=tXpX) if const_expr(mdB is not None): copy_fn(tXrdB, tXgdB, pred=tXpX) @@ -1295,8 +1415,7 @@ def get_sm_count( increased to improve SM occupancy, matching the existing SM100 tuning used by both RMSNorm and LayerNorm. """ - props = torch.cuda.get_device_properties(device) - num_sms = props.multi_processor_count + num_sms = get_num_sms(device) sm_count_multiple = ( 16 diff --git a/oink/src/kernelagent_oink/blackwell/rmsnorm.py b/oink/src/kernelagent_oink/blackwell/rmsnorm.py index e921947..252df6a 100644 --- a/oink/src/kernelagent_oink/blackwell/rmsnorm.py +++ b/oink/src/kernelagent_oink/blackwell/rmsnorm.py @@ -86,6 +86,35 @@ _PTR_FAST_LAUNCH_TLS = threading.local() +# Cache a (1, sm_count) fp32 ones row used for GEMM-based dw/db partial reductions. +# +# On SM100, `dw_partial.sum(dim=0)` can be a double-digit microsecond tail for +# Quack-suite small shapes (e.g. M=8192, N=4096). A cached GEMM-based reduction +# is consistently faster and avoids per-call allocation overhead. 
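+#
+# Usage sketch (illustrative; `dev` and `sm_count` are placeholders):
+#     dw_partial = torch.empty((sm_count, N), device=dev, dtype=torch.float32)
+#     ...                      # backward kernel writes per-SM partial sums
+#     dw = _reduce_partial_sum_fp32(dw_partial, device_index=dev.index)
+#     # same result as dw_partial.sum(dim=0)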
+_DW_REDUCE_ONES_CACHE: dict[tuple[int, int], Tensor] = {} + + +def _get_dw_reduce_ones(device_index: int, sm_count: int) -> Tensor: + key = (int(device_index), int(sm_count)) + ones = _DW_REDUCE_ONES_CACHE.get(key) + if ones is None or ones.shape != (1, sm_count) or ones.device.index != device_index: + ones = torch.ones( + (1, sm_count), + device=torch.device("cuda", device_index), + dtype=torch.float32, + ) + _DW_REDUCE_ONES_CACHE[key] = ones + return ones + + +def _reduce_partial_sum_fp32(partial: Tensor, *, device_index: int) -> Tensor: + """Reduce a (sm_count, N) fp32 partial buffer into an (N,) fp32 result.""" + assert partial.dtype is torch.float32 + assert partial.dim() == 2 + ones = _get_dw_reduce_ones(device_index, int(partial.shape[0])) + return torch.mm(ones, partial).squeeze(0) + + def _env_flag(name: str, default: bool) -> bool: val = os.environ.get(name) if val is None: @@ -213,6 +242,13 @@ def _probe_cluster_direct_gmem_max_copy_bits() -> int: """ env = os.environ.copy() + # The probe runs in a fresh subprocess, so it won't inherit any + # benchmark-harness sys.path tweaks. Ensure the in-tree Oink source is + # importable so `import kernelagent_oink...` works reliably. + oink_src = os.path.abspath(os.path.join(_HERE, "..", "..")) + if os.path.isdir(oink_src): + py_path = env.get("PYTHONPATH") + env["PYTHONPATH"] = oink_src + (os.pathsep + py_path if py_path else "") env["PYTHONNOUSERSITE"] = "1" def run_probe(copy_bits: int, assumed_align: int): @@ -325,6 +361,8 @@ def _direct_gmem_from_policy(*, default: bool) -> bool: def _copy_bits_from_policy(*, default: int, can_use_256: bool) -> int: """Resolve copy width (in bits) from the (import-time) policy string.""" + if _COPY_BITS_POLICY in {"64"}: + return 64 if _COPY_BITS_POLICY in {"128"}: return 128 if _COPY_BITS_POLICY in {"256"} and can_use_256: @@ -398,6 +436,8 @@ def __init__( arg_ld: _StableI32Arg, arg_eps: _StableF32Arg, stream: cuda.CUstream, + assumed_align: int, + weight_dtype: Optional[type[cutlass.Numeric]], packed_args: object, keepalive: tuple[object, ...], ): @@ -412,6 +452,8 @@ def __init__( self._arg_ld = arg_ld self._arg_eps = arg_eps self._stream = stream + self._assumed_align = int(assumed_align) + self._weight_dtype = weight_dtype self._packed_args = packed_args self._keepalive = keepalive @@ -520,17 +562,23 @@ def _fallback_launch( # (e.g. due to a CuTeDSL upgrade), fall back to the regular call path. 
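+        # This path rebuilds the rt.make_ptr descriptors on every call; it is
+        # correct but slower than the fast-launch path, which only patches
+        # pointer values and scalar args inside the pre-packed argument buffer.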
dtype = TORCH2CUTE_DTYPE[x.dtype] ptr_x = rt.make_ptr( - dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, ) ptr_out = rt.make_ptr( - dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, ) ptr_w = ( rt.make_ptr( - dtype, + self._weight_dtype or dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, - assumed_align=16, + assumed_align=self._assumed_align, ) if weight is not None else None @@ -717,14 +765,477 @@ def _fallback_launch( ) +class _PtrRmsnormBwdFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_x: object, + ptr_w: Optional[object], + ptr_dout: object, + ptr_rstd: object, + ptr_dx: object, + ptr_dw_partial: Optional[object], + arg_m: _StableI32Arg, + arg_n: _StableI32Arg, + arg_ld: _StableI32Arg, + arg_sm_count: _StableI32Arg, + stream: cuda.CUstream, + assumed_align_x: int, + assumed_align_w: int, + assumed_align_dw: int, + weight_dtype: Optional[type[cutlass.Numeric]], + packed_args: object, + keepalive: tuple[object, ...], + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_x = ptr_x + self._ptr_w = ptr_w + self._ptr_dout = ptr_dout + self._ptr_rstd = ptr_rstd + self._ptr_dx = ptr_dx + self._ptr_dw_partial = ptr_dw_partial + self._arg_m = arg_m + self._arg_n = arg_n + self._arg_ld = arg_ld + self._arg_sm_count = arg_sm_count + self._stream = stream + self._assumed_align_x = int(assumed_align_x) + self._assumed_align_w = int(assumed_align_w) + self._assumed_align_dw = int(assumed_align_dw) + self._weight_dtype = weight_dtype + self._packed_args = packed_args + self._keepalive = keepalive + + self._use_fast_launch = True + self._cuda_result = getattr(executor, "cuda_result", None) + + self._last_x_ptr = -1 + self._last_w_ptr = -1 + self._last_dout_ptr = -1 + self._last_rstd_ptr = -1 + self._last_dx_ptr = -1 + self._last_dw_ptr = -1 + self._last_m = -1 + self._last_ld = -1 + self._last_sm_count = -1 + + def launch( + self, + *, + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, + dx: Tensor, + dw_partial: Optional[Tensor], + M: int, + N: int, + ld: int, + sm_count: int, + ) -> None: + if not _fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + x_ptr = x.data_ptr() + if x_ptr != self._last_x_ptr: + try: + _set_runtime_ptr(self._ptr_x, x_ptr) + self._last_x_ptr = x_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + if self._ptr_w is not None: + w_ptr = weight.data_ptr() # type: ignore[union-attr] + if w_ptr != self._last_w_ptr: + try: + _set_runtime_ptr(self._ptr_w, w_ptr) + self._last_w_ptr = w_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + dout_ptr = dout.data_ptr() + if dout_ptr != self._last_dout_ptr: + try: + _set_runtime_ptr(self._ptr_dout, dout_ptr) + self._last_dout_ptr = dout_ptr + except 
AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + rstd_ptr = rstd.data_ptr() + if rstd_ptr != self._last_rstd_ptr: + try: + _set_runtime_ptr(self._ptr_rstd, rstd_ptr) + self._last_rstd_ptr = rstd_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + dx_ptr = dx.data_ptr() + if dx_ptr != self._last_dx_ptr: + try: + _set_runtime_ptr(self._ptr_dx, dx_ptr) + self._last_dx_ptr = dx_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + if self._ptr_dw_partial is not None: + dw_ptr = dw_partial.data_ptr() # type: ignore[union-attr] + if dw_ptr != self._last_dw_ptr: + try: + _set_runtime_ptr(self._ptr_dw_partial, dw_ptr) + self._last_dw_ptr = dw_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld, + sm_count=sm_count, + ) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld != self._last_ld: + self._arg_ld.set(ld) + self._last_ld = ld + if sm_count != self._last_sm_count: + self._arg_sm_count.set(sm_count) + self._last_sm_count = sm_count + + if self._cuda_result is not None: + self._cuda_result.value = 0 + + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + global _FAST_LAUNCH_SUPPORTED + self._use_fast_launch = False + _FAST_LAUNCH_SUPPORTED = False + + def _fallback_launch( + self, + *, + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, + dx: Tensor, + dw_partial: Optional[Tensor], + M: int, + N: int, + ld: int, + sm_count: int, + ) -> None: + dtype = TORCH2CUTE_DTYPE[x.dtype] + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_x, + ) + ptr_dout = rt.make_ptr( + dtype, + dout.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_x, + ) + ptr_dx = rt.make_ptr( + dtype, + dx.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_x, + ) + ptr_rstd = rt.make_ptr( + TORCH2CUTE_DTYPE[rstd.dtype], + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_x, + ) + ptr_w = ( + rt.make_ptr( + self._weight_dtype or dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_w, + ) + if weight is not None + else None + ) + ptr_dw_partial = ( + rt.make_ptr( + TORCH2CUTE_DTYPE[dw_partial.dtype], + dw_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align_dw, + ) + if dw_partial is not None + else None + ) + self._compiled( + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_dx, + ptr_dw_partial, + Int32(M), + Int32(N), + Int32(ld), + Int32(sm_count), + self._stream, + ) + + +def 
_get_fast_ptr_rmsnorm_bwd_launcher( + *, + compiled: object, + dtype: type[cutlass.Numeric], + weight_dtype: Optional[type[cutlass.Numeric]], + N: int, + device_index: int, + stream_handle: int, + has_weight: bool, + has_dw_partial: bool, + assumed_align_x: int, + assumed_align_w: int, + assumed_align_dw: int, +) -> Optional[_PtrRmsnormBwdFastLaunch]: + if not _fast_launch_enabled(): + return None + key = ( + "ptr_bwd_fast", + id(compiled), + N, + dtype, + weight_dtype, + device_index, + int(stream_handle), + has_weight, + has_dw_partial, + int(assumed_align_x), + int(assumed_align_w), + int(assumed_align_dw), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + assumed_align_x = int(assumed_align_x) + assumed_align_w = int(assumed_align_w) + assumed_align_dw = int(assumed_align_dw) + + ptr_x = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align_x + ) + ptr_w = ( + rt.make_ptr( + weight_dtype or dtype, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_w, + ) + if has_weight + else None + ) + ptr_dout = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align_x + ) + ptr_rstd = rt.make_ptr( + cutlass.Float32, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dx = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align_x + ) + ptr_dw_partial = ( + rt.make_ptr( + cutlass.Float32, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + if has_dw_partial + else None + ) + + arg_m = _StableI32Arg(0) + arg_n = _StableI32Arg(N) + arg_ld = _StableI32Arg(N) + arg_sm_count = _StableI32Arg(0) + stream = cuda.CUstream(int(stream_handle)) + + executor = compiled.to(device_index) # type: ignore[attr-defined] + try: + exe_args, adapted_args = executor.generate_execution_args( + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_dx, + ptr_dw_partial, + arg_m, + arg_n, + arg_ld, + arg_sm_count, + stream, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + global _FAST_LAUNCH_SUPPORTED + _FAST_LAUNCH_SUPPORTED = False + return None + + keepalive: tuple[object, ...] 
= ( + executor, + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_dx, + ptr_dw_partial, + arg_m, + arg_n, + arg_ld, + arg_sm_count, + stream, + *adapted_args, + ) + + launcher = _PtrRmsnormBwdFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_x=ptr_x, + ptr_w=ptr_w, + ptr_dout=ptr_dout, + ptr_rstd=ptr_rstd, + ptr_dx=ptr_dx, + ptr_dw_partial=ptr_dw_partial, + arg_m=arg_m, + arg_n=arg_n, + arg_ld=arg_ld, + arg_sm_count=arg_sm_count, + stream=stream, + assumed_align_x=assumed_align_x, + assumed_align_w=assumed_align_w, + assumed_align_dw=assumed_align_dw, + weight_dtype=weight_dtype if has_weight else None, + packed_args=packed_args, + keepalive=keepalive, + ) + cache[key] = launcher + return launcher + + def _get_fast_ptr_rmsnorm_launcher( *, compiled: object, dtype: type[cutlass.Numeric], + weight_dtype: Optional[type[cutlass.Numeric]] = None, N: int, device_index: int, stream_handle: int, has_weight: bool, + assumed_align: int = 16, eps: float, ) -> Optional[_PtrRmsnormFastLaunch]: if not _fast_launch_enabled(): @@ -736,9 +1247,11 @@ def _get_fast_ptr_rmsnorm_launcher( id(compiled), N, dtype, + weight_dtype, device_index, int(stream_handle), has_weight, + int(assumed_align), ) cache = _tls_fast_launch_cache() cached = cache.get(key) @@ -746,10 +1259,20 @@ def _get_fast_ptr_rmsnorm_launcher( return cached # type: ignore[return-value] # Create stable runtime args and pointer descriptors once. - ptr_x = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) - ptr_out = rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) + assumed_align = int(assumed_align) + ptr_x = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_out = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) ptr_w = ( - rt.make_ptr(dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=16) + rt.make_ptr( + weight_dtype or dtype, + 0, + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, + ) if has_weight else None ) @@ -813,6 +1336,8 @@ def _get_fast_ptr_rmsnorm_launcher( arg_ld=arg_ld, arg_eps=arg_eps, stream=stream, + assumed_align=assumed_align, + weight_dtype=weight_dtype if has_weight else None, packed_args=packed_args, keepalive=keepalive, ) @@ -952,7 +1477,7 @@ def get_copy_atom_bw( num_copy_bits = const_expr(min(max_bits, num_copy_elems * dtype.width)) from cutlass.cute.nvgpu import cpasync - # Prefer GLOBAL cache policy for bulk streaming reads at large M + # Prefer GLOBAL cache policy for bulk streaming reads at large M. copy_op = ( cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL) if is_async @@ -1012,14 +1537,25 @@ def _threads_per_row(self) -> int: if N == 1536 and self.dtype.width == 16: return 96 # DSv3 default hidden size (7168). Choose a threads-per-row that matches - # the selected vector width to avoid padded work: - # - 128b copies (vec=8 for bf16/fp16): 7168/8 = 896 = 7 * 128 -> tpr=128 - # - 256b copies (vec=16 for bf16/fp16): 7168/16 = 448 = 2 * 224 -> tpr=224 + # the selected vector width to avoid padded work. Using 224 threads/row + # yields exact tiles for all supported copy widths we use on SM100: + # - 64b copies (vec=4 for bf16/fp16): 7168/4 = 1792 = 8 * 224 + # - 128b copies (vec=8 for bf16/fp16): 7168/8 = 896 = 4 * 224 + # - 256b copies (vec=16 for bf16/fp16): 7168/16 = 448 = 2 * 224 # - # The fused direct-GMEM path often uses 256b copies on 32B-aligned - # tensors, while the non-fused path defaults to 128b copies. 
if N == 7168 and self.dtype.width == 16: - return 224 if self.copy_bits >= 256 else 128 + return 224 + # DSv3-ish N buckets (6144/8192): use larger threads/row so each thread + # holds fewer elements in registers. For 256b vectors, pick a threads/row + # that yields an exact tile without padding. + if self.dtype.width == 16: + if N == 6144: + if self.copy_bits >= 256: + return 192 + if self.copy_bits <= 128: + return 256 + if N == 8192: + return 256 # For small-N, use at least one full warp per row. The kernel # implementation assumes one row per CTA; returning <32 here can # produce multi-row tiles (cols_per_block > 1) which is not supported. @@ -1079,7 +1615,15 @@ def _num_threads(self) -> int: if self.N == 1536 and self.dtype.width == 16: return 96 if self.N == 7168 and self.dtype.width == 16: - return 224 if self.copy_bits >= 256 else 128 + return 224 + if self.dtype.width == 16: + if self.N == 6144: + if self.copy_bits >= 256: + return 192 + if self.copy_bits <= 128: + return 256 + if self.N == 8192: + return 256 if self.N <= 1024: return 32 return 128 if self.N <= 16384 else 256 @@ -2105,7 +2649,13 @@ def _can_use_ptr_path( if residual is not None and residual.dtype != x.dtype: return False if weight is not None and weight.dtype != x.dtype: - return False + # Allow the common "Quack-style" API where weights are fp32 even when + # activations are bf16/fp16. The pointer path constructs a weight tensor + # view with the correct element type (fp32) inside the compiled graph. + if weight.dtype is not torch.float32: + return False + if x.dtype not in (torch.float16, torch.bfloat16): + return False if bias is not None and bias.dtype != x.dtype: return False # The kernel assumes `ld` satisfies a divisibility constraint used by @@ -2128,8 +2678,15 @@ def _can_use_ptr_path( return False if bias is not None and not bias.is_contiguous(): return False - if weight is not None and (weight.data_ptr() % 16) != 0: - return False + if weight is not None: + # For fp32 weights we use 256b universal copies (32B) by default. + # Require 32B alignment so the compiler can safely vectorize loads. + if weight.dtype is torch.float32: + if (weight.data_ptr() % 32) != 0: + return False + else: + if (weight.data_ptr() % 16) != 0: + return False if bias is not None and (bias.data_ptr() % 16) != 0: return False return True @@ -2148,29 +2705,92 @@ def _can_use_ptr_path_fused_add_inplace( """ if x.stride(1) != 1: return False - if residual.dtype != x.dtype: + if residual.dtype != x.dtype: + return False + if weight.dtype != x.dtype: + return False + if residual.stride(1) != 1: + return False + if not residual.is_contiguous(): + return False + if not weight.is_contiguous(): + return False + + dtype = TORCH2CUTE_DTYPE[x.dtype] + divby = 256 // dtype.width + if (x.stride(0) % divby) != 0: + return False + if (residual.stride(0) % divby) != 0: + return False + + if (x.data_ptr() % 16) != 0: + return False + if (residual.data_ptr() % 16) != 0: + return False + if (weight.data_ptr() % 16) != 0: + return False + return True + + +def _can_use_ptr_path_bwd( + x: Tensor, + weight: Optional[Tensor], + dout: Tensor, + rstd: Tensor, +) -> bool: + """Fast-path precondition for the pointer-based RMSNorm backward entry. 
+ + This path is only used for the common Quack-style signature: + - no bias gradient + - no residual / dresidual_out + - weight is either the same dtype as x, or fp32 for bf16/fp16 activations + """ + if x.dim() != 2 or dout.dim() != 2: + return False + if rstd.dim() != 1: + return False + if x.shape != dout.shape: return False - if weight.dtype != x.dtype: + if rstd.numel() != x.shape[0]: return False - if residual.stride(1) != 1: + # SM100 backward kernel assumes N is divisible by 8 (for 256b fp32 stores + # into dw_partial rows). + if (x.shape[1] % 8) != 0: return False - if not residual.is_contiguous(): + if x.stride(1) != 1 or dout.stride(1) != 1: + return False + if dout.stride(0) != x.stride(0): + return False + if dout.dtype != x.dtype: + return False + if rstd.dtype != torch.float32 or not rstd.is_contiguous(): + return False + if weight is None: + return False + if weight.dim() != 1 or weight.shape[0] != x.shape[1]: return False if not weight.is_contiguous(): return False + if weight.dtype != x.dtype: + if weight.dtype is not torch.float32: + return False + if x.dtype not in (torch.float16, torch.bfloat16): + return False dtype = TORCH2CUTE_DTYPE[x.dtype] divby = 256 // dtype.width if (x.stride(0) % divby) != 0: return False - if (residual.stride(0) % divby) != 0: - return False if (x.data_ptr() % 16) != 0: return False - if (residual.data_ptr() % 16) != 0: + if (dout.data_ptr() % 16) != 0: return False - if (weight.data_ptr() % 16) != 0: + # Torch CUDA allocations are typically >=256B aligned, but keep the check + # explicit so we never assume tighter alignment than is true. + if (rstd.data_ptr() % 4) != 0: + return False + if (weight.data_ptr() % (32 if weight.dtype is torch.float32 else 16)) != 0: return False return True @@ -2254,35 +2874,92 @@ def _rmsnorm_forward_ptr_into( stream_handle = int(torch.cuda.current_stream().cuda_stream) has_weight = weight is not None + weight_dtype = TORCH2CUTE_DTYPE[weight.dtype] if has_weight else None + + # Schedule selection (pointer fast path). + # + # Goals: + # - Keep vLLM inference fast path (contiguous/padded row-major) fast. + # - Enable higher vector widths when all participating pointers are 32B-aligned. + # - Prefer direct-GMEM for SM100-friendly hidden sizes to reduce SMEM/barrier + # overhead, especially for small/medium-M cases. + direct_gmem = _direct_gmem_from_policy( + default=bool(dtype.width == 16 and N in {4096, 6144, 7168, 8192}) + ) + use_async = not direct_gmem + + can_use_256 = bool( + dtype.width == 16 + and (x.data_ptr() % 32) == 0 + and (out.data_ptr() % 32) == 0 + and (not has_weight or (weight.data_ptr() % 32) == 0) # type: ignore[union-attr] + ) + default_copy_bits = 256 if can_use_256 else 128 + # Quack-style fp32-weight policy: cap the *widest* dtype to 128b, so when + # weights are fp32 we use 64b activation vectors (helps register pressure). 
+ if dtype.width == 16 and weight_dtype is not None and weight_dtype.width == 32: + default_copy_bits = 64 + copy_bits = _copy_bits_from_policy( + default=default_copy_bits, can_use_256=can_use_256 + ) + assumed_align = 32 if copy_bits >= 256 else 16 + stage = 1 + if ( + _ENABLE_STAGE2 + and dtype.width == 16 + and N == 7168 + and (not direct_gmem) + and M >= 4096 + ): + stage = 2 + compiled_key = ( "ptr", N, dtype, + weight_dtype, False, # residual has_weight, False, # bias False, # residual_out False, # rstd stage, + int(copy_bits), + bool(use_async), + bool(direct_gmem), + int(assumed_align), device_index, ) compiled = _PTR_COMPILE_CACHE.get(compiled_key) if compiled is None: - op = RMSNormSM100(N, dtype, stage=stage) + op = RMSNormSM100( + N, + dtype, + stage=stage, + copy_bits=int(copy_bits), + use_async=bool(use_async), + direct_gmem=bool(direct_gmem), + ) ld_val = int(x.stride(0)) ptr_x = rt.make_ptr( - dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) ptr_out = rt.make_ptr( - dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) ptr_w = ( rt.make_ptr( - dtype, + weight_dtype or dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, - assumed_align=16, + assumed_align=assumed_align, ) if has_weight else None @@ -2309,10 +2986,12 @@ def _rmsnorm_forward_ptr_into( launcher = _get_fast_ptr_rmsnorm_launcher( compiled=compiled, dtype=dtype, + weight_dtype=weight_dtype, N=N, device_index=device_index, stream_handle=stream_handle, has_weight=has_weight, + assumed_align=assumed_align, eps=eps, ) ld_val = int(x.stride(0)) @@ -2321,17 +3000,23 @@ def _rmsnorm_forward_ptr_into( return ptr_x = rt.make_ptr( - dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) ptr_out = rt.make_ptr( - dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) ptr_w = ( rt.make_ptr( - dtype, + weight_dtype or dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, - assumed_align=16, + assumed_align=assumed_align, ) if has_weight else None @@ -2354,8 +3039,42 @@ def _rmsnorm_forward_ptr_into( ) return - # Fallback: general path (supports bias/residual/rstd, but is slower to launch). + # General path (supports bias/residual/rstd, but is slower to launch). + # + # Keep the same schedule-selection policy as the fast path so correctness-only + # features (bias/residual/rstd) don't accidentally fall off a performance cliff. 
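+    # The resolved schedule knobs (copy_bits, use_async, direct_gmem,
+    # assumed_align) are folded into the compile-cache key below, so each
+    # distinct schedule compiles its own kernel rather than reusing a
+    # mismatched one.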
+ weight_dtype = TORCH2CUTE_DTYPE[weight.dtype] if weight is not None else None + direct_gmem = _direct_gmem_from_policy( + default=bool(dtype.width == 16 and N in {4096, 6144, 7168, 8192}) + ) + use_async = not direct_gmem + can_use_256 = bool( + dtype.width == 16 + and (x.data_ptr() % 32) == 0 + and (out.data_ptr() % 32) == 0 + and (weight is None or (weight.data_ptr() % 32) == 0) + and (bias is None or (bias.data_ptr() % 32) == 0) + and (residual is None or (residual.data_ptr() % 32) == 0) + and (residual_out is None or (residual_out.data_ptr() % 32) == 0) + ) + default_copy_bits = 256 if can_use_256 else 128 + if dtype.width == 16 and weight_dtype is not None and weight_dtype.width == 32: + default_copy_bits = 64 + copy_bits = _copy_bits_from_policy( + default=default_copy_bits, can_use_256=can_use_256 + ) + assumed_align = 32 if copy_bits >= 256 else 16 + stage = 1 + if ( + _ENABLE_STAGE2 + and dtype.width == 16 + and N == 7168 + and (not direct_gmem) + and M >= 4096 + ): + stage = 2 + if torch.cuda.current_device() != device_index: torch.cuda.set_device(device_index) stream_handle = int(torch.cuda.current_stream().cuda_stream) @@ -2363,29 +3082,47 @@ def _rmsnorm_forward_ptr_into( "ptr", N, dtype, + weight_dtype, residual is not None, weight is not None, bias is not None, residual_out is not None, rstd is not None, stage, + int(copy_bits), + bool(use_async), + bool(direct_gmem), + int(assumed_align), device_index, ) compiled = _PTR_COMPILE_CACHE.get(key) if compiled is None: - op = RMSNormSM100(N, dtype, stage=stage) + op = RMSNormSM100( + N, + dtype, + stage=stage, + copy_bits=int(copy_bits), + use_async=bool(use_async), + direct_gmem=bool(direct_gmem), + ) ptr_x = rt.make_ptr( - dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) ptr_out = rt.make_ptr( - dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) ptr_res = ( rt.make_ptr( dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, - assumed_align=16, + assumed_align=assumed_align, ) if residual is not None else None @@ -2395,24 +3132,27 @@ def _rmsnorm_forward_ptr_into( dtype, residual_out.data_ptr(), mem_space=rt.AddressSpace.gmem, - assumed_align=16, + assumed_align=assumed_align, ) if residual_out is not None else None ) ptr_w = ( rt.make_ptr( - dtype, + weight_dtype or dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, - assumed_align=16, + assumed_align=assumed_align, ) if weight is not None else None ) ptr_b = ( rt.make_ptr( - dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + bias.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) if bias is not None else None @@ -2446,14 +3186,20 @@ def _rmsnorm_forward_ptr_into( ) _PTR_COMPILE_CACHE[key] = compiled ptr_x = rt.make_ptr( - dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align ) ptr_out = rt.make_ptr( - dtype, out.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + out.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) ptr_res = ( rt.make_ptr( - dtype, residual.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + residual.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) if residual is not 
None else None @@ -2463,21 +3209,27 @@ def _rmsnorm_forward_ptr_into( dtype, residual_out.data_ptr(), mem_space=rt.AddressSpace.gmem, - assumed_align=16, + assumed_align=assumed_align, ) if residual_out is not None else None ) ptr_w = ( rt.make_ptr( - dtype, weight.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + weight_dtype or dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) if weight is not None else None ) ptr_b = ( rt.make_ptr( - dtype, bias.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + dtype, + bias.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align, ) if bias is not None else None @@ -2545,8 +3297,11 @@ def _fused_add_rmsnorm_forward_ptr_inplace( # benchmark other models/shapes, you can override it with: # - OINK_RMSNORM_DIRECT_GMEM=0 (force staging/cp.async path) # - OINK_RMSNORM_DIRECT_GMEM=1 (force direct-gmem path) + # Default direct-GMEM policy: + # - small/medium M: direct-GMEM reduces staging/barrier overhead + # - large M: staged cp.async tends to win on sustained bandwidth direct_gmem = _direct_gmem_from_policy( - default=bool(dtype.width == 16 and N == 7168) + default=bool(dtype.width == 16 and N == 7168 and M <= 16384) ) use_async = not direct_gmem tpr_override: Optional[int] = None @@ -2743,10 +3498,13 @@ def rmsnorm_forward( # # When the pointer path can't be used (e.g. float32 weights for Quack-style # APIs, or non-standard layouts), fall back to the CuTeDSL stage-2 module - # before using the slow torch reference implementation. + # (ported from `/tmp/oink_main/Blackwell`) before using the slow torch + # reference implementation. force_stage2 = _FORCE_RMSNORM_STAGE2_FWD - if not force_stage2 and _can_use_ptr_path(x, weight, bias, residual): + use_ptr = (not force_stage2) and _can_use_ptr_path(x, weight, bias, residual) + + if use_ptr: return _rmsnorm_forward_ptr(x, weight, bias, residual, eps, store_rstd) # CuTeDSL fallback for cases that aren't safe for the pointer path. @@ -2950,6 +3708,83 @@ class RMSNormBackwardSM100(BaseRMSNormBackward): def __init__(self, dtype: cutlass.Numeric, N: int): super().__init__(dtype, N) + def _get_num_threads(self) -> int: + nt = getattr(self, "_nt_override", None) + if nt is not None: + return int(nt) + return super()._get_num_threads() + + def _calculate_threads_per_row(self) -> int: + tpr = getattr(self, "_tpr_override", None) + if tpr is not None: + return int(tpr) + return super()._calculate_threads_per_row() + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_w: cute.Pointer, + ptr_dout: cute.Pointer, + ptr_rstd: cute.Pointer, + ptr_dx: cute.Pointer, + ptr_dw_partial: cute.Pointer, + M: Int32, + N_dyn: Int32, + ld: Int32, + sm_count: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack conversions. + + This is the performance-critical path used by the benchmark harness + (and any future training integrations) for the common case: + - weight gradient enabled (dw_partial is provided) + - no bias/residual gradients + """ + # Weight-grad stores use vectorized float32 copies. For the SM100 + # schedule we want to allow up to 256b (8x f32) stores, which requires + # the leading dimension to be divisible by 8 to prove 32B alignment for + # every row in `dw_partial`. 
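+        # e.g. 8 fp32 elements = 32 bytes per store; with N % 8 == 0 every row
+        # offset of a row-major (sm_count, N) fp32 buffer is a multiple of 32B,
+        # so 256-bit stores remain aligned for all rows.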
+ N_assumed = cute.assume(N_dyn, divby=8) + + layout_mn = cute.make_layout((M, N_assumed), stride=(ld, 1)) + layout_n = cute.make_layout((N_assumed,), stride=(1,)) + layout_m = cute.make_layout((M,), stride=(1,)) + # Default: write a full (sm_count, N) partial buffer (Quack-style), + # then reduce on the host with `torch.sum(dim=0)`. + # + # Optional: atomic-reduce directly into a single (N,) buffer by using + # a broadcasted leading dimension (stride0 = 0). This avoids the extra + # reduction kernel launch and is primarily used for tiny-M regimes. + if const_expr(self.atomic_dw): + layout_partial = cute.make_layout((sm_count, N_assumed), stride=(0, 1)) + else: + layout_partial = cute.make_layout( + (sm_count, N_assumed), stride=(N_assumed, 1) + ) + + mX = cute.make_tensor(ptr_x, layout_mn) + mW = cute.make_tensor(ptr_w, layout_n) + mdO = cute.make_tensor(ptr_dout, layout_mn) + mRstd = cute.make_tensor(ptr_rstd, layout_m) + mdX = cute.make_tensor(ptr_dx, layout_mn) + mdW = cute.make_tensor(ptr_dw_partial, layout_partial) + + self.__call__( + mX, + mW, + mdO, + None, # dresidual_out + mRstd, + mdX, + mdW, + None, # dresidual + None, # db_partial + sm_count, + stream, + ) + def _get_num_threads(self) -> int: # Keep 128 threads only up to N=4k; use 256 for larger rows to ensure # threads_per_row <= num_threads across buckets. @@ -2959,24 +3794,22 @@ def _get_num_threads(self) -> int: return 128 if self.N <= 4096 else 256 def _calculate_threads_per_row(self) -> int: - # Mirror RMSNormSM100 forward's tiling. + try: + return self._tpr_override # type: ignore[attr-defined] + except Exception: + pass + # Match Quack's backward tiling: use 256 threads/row for N > 4096. + # + # The earlier "mirror forward" policy (128 threads/row for N<=8192) + # regresses DSv3 backward at N=6144/7168/8192 on SM100. N = self.N - if N <= 64: - return 8 - if N <= 128: - return 16 - if N <= 1024: - return 32 - if N <= 4096: - return 128 - if N <= 8192: - try: - return self._tpr_override # type: ignore[attr-defined] - except Exception: - return 128 - if N <= 16384: + for limit, threads in [(64, 8), (128, 16), (256, 32), (512, 64), (4096, 128)]: + if N <= limit: + return threads + try: + return self._tpr_override # type: ignore[attr-defined] + except Exception: return 256 - return 256 def _set_cluster_n(self) -> None: # Reuse the SM100 forward cluster growth policy so large-N shapes can @@ -3093,6 +3926,7 @@ def new_stride(t): _BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} +_BWD_PTR_COMPILE_CACHE: dict[tuple[object, ...], object] = {} def _rmsnorm_bwd_sm100( @@ -3141,14 +3975,13 @@ def _rmsnorm_bwd_sm100( # Match Quack's conversion strategy for activations/gradients: keep the # (M, N) layout dynamic without enforcing additional compact-shape # constraints. This reduces per-call Python overhead for small-M shapes. 
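+ # Roughly, `mark_layout_dynamic(leading_dim=1)` keeps the unit-stride column
+ # mode static while the row extent/stride stays dynamic, so a single compiled
+ # kernel can serve different M and leading dimensions.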
- def _convert_mx(t: Tensor) -> cute.Tensor: - return from_dlpack( - t.detach(), - assumed_align=16, - ).mark_layout_dynamic(leading_dim=1) + def _convert_layout_dynamic(t: Tensor) -> cute.Tensor: + return from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) x_tensor, dout_tensor, dres_out_tensor, dx_tensor, dres_tensor = [ - _convert_mx(t) if t is not None else None + _convert_layout_dynamic(t) if t is not None else None for t in (x, dout, dresidual_out, dx, dresidual) ] @@ -3230,6 +4063,227 @@ def _convert_mx(t: Tensor) -> cute.Tensor: ) +def _rmsnorm_bwd_sm100_ptr( + x: Tensor, + weight: Tensor, + dout: Tensor, + rstd: Tensor, + dx: Tensor, + dw_partial: Tensor, + sm_count: int, + *, + atomic_dw: bool = False, +) -> None: + """Pointer-based SM100 RMSNorm backward launch (no DLPack conversions). + + When `atomic_dw=True`, `dw_partial` is treated as a single (N,) fp32 buffer + and the kernel atomically accumulates weight gradients into it (avoids the + extra `dw_partial.sum(dim=0)` reduction kernel). + """ + assert _can_use_ptr_path_bwd(x, weight, dout, rstd) + assert dx.shape == x.shape + assert dx.dtype == x.dtype + assert dw_partial.dtype == torch.float32 + + M, N = x.size(0), x.size(1) + if atomic_dw: + assert dw_partial.dim() == 1 and dw_partial.numel() == N + assert dw_partial.is_contiguous() + else: + assert dw_partial.dim() == 2 and dw_partial.shape[1] == N + device_index = x.get_device() + dtype = TORCH2CUTE_DTYPE[x.dtype] + weight_dtype = TORCH2CUTE_DTYPE[weight.dtype] + assumed_align_x = 16 + assumed_align_w = 32 if weight.dtype is torch.float32 else 16 + assumed_align_dw = 32 + assert (dw_partial.data_ptr() % assumed_align_dw) == 0 + + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + ld_val = int(x.stride(0)) + key = ( + "bwd_ptr", + N, + dtype, + weight_dtype, + int(assumed_align_x), + int(assumed_align_w), + int(assumed_align_dw), + device_index, + bool(atomic_dw), + ) + compiled = _BWD_PTR_COMPILE_CACHE.get(key) + if compiled is None: + op = RMSNormBackwardSM100(dtype, N) + op.atomic_dw = bool(atomic_dw) + # 16-bit activations + 16-bit weights (vLLM-style) backward at N=4096: + # Use a 1-row/CTA schedule with 256 threads/row. This reduces per-thread + # work and improves bandwidth on large-M shapes on SM100. + if ( + (not atomic_dw) + and N == 4096 + and dtype.width == 16 + and weight_dtype.width == 16 + ): + op._tpr_override = 256 # type: ignore[attr-defined] + op._nt_override = 256 # type: ignore[attr-defined] + # 16-bit activations + fp32 weights backward at N=4096: + # Use a 256-thread schedule (tpr=256) to improve throughput. + if ( + (not atomic_dw) + and N == 4096 + and dtype.width == 16 + and weight_dtype is cutlass.Float32 + ): + op._tpr_override = 256 # type: ignore[attr-defined] + op._nt_override = 256 # type: ignore[attr-defined] + # FP16 + fp32-weight DSv3 backward: Quack's default (1 row/CTA with + # 256 threads/row) underperforms. Use a 2-rows/CTA schedule (256 threads + # total, 128 threads/row) to improve memory-level parallelism. 
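+ # These per-shape overrides are picked up by the `_tpr_override` /
+ # `_nt_override` hooks on RMSNormBackwardSM100 (see `_get_num_threads` and
+ # `_calculate_threads_per_row` above), so they only affect kernels compiled
+ # under this cache key.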
+ if ( + (not atomic_dw) + and N == 6144 + and dtype is cutlass.Float16 + and weight_dtype is cutlass.Float32 + ): + op._tpr_override = 128 # type: ignore[attr-defined] + op._nt_override = 256 # type: ignore[attr-defined] + + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_w = rt.make_ptr( + weight_dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_w, + ) + ptr_dout = rt.make_ptr( + dtype, + dout.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_rstd = rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dx = rt.make_ptr( + dtype, + dx.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dw = rt.make_ptr( + cutlass.Float32, + dw_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_dx, + ptr_dw, + Int32(M), + Int32(N), + Int32(ld_val), + Int32(int(sm_count)), + stream, + ) + _BWD_PTR_COMPILE_CACHE[key] = compiled + + launcher = _get_fast_ptr_rmsnorm_bwd_launcher( + compiled=compiled, + dtype=dtype, + weight_dtype=weight_dtype, + N=N, + device_index=device_index, + stream_handle=stream_handle, + has_weight=True, + has_dw_partial=True, + assumed_align_x=assumed_align_x, + assumed_align_w=assumed_align_w, + assumed_align_dw=assumed_align_dw, + ) + if launcher is not None: + launcher.launch( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + M=M, + N=N, + ld=ld_val, + sm_count=int(sm_count), + ) + return + + ptr_x = rt.make_ptr( + dtype, + x.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_w = rt.make_ptr( + weight_dtype, + weight.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_w, + ) + ptr_dout = rt.make_ptr( + dtype, + dout.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_rstd = rt.make_ptr( + cutlass.Float32, + rstd.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dx = rt.make_ptr( + dtype, + dx.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_x, + ) + ptr_dw = rt.make_ptr( + cutlass.Float32, + dw_partial.data_ptr(), + mem_space=rt.AddressSpace.gmem, + assumed_align=assumed_align_dw, + ) + compiled( + ptr_x, + ptr_w, + ptr_dout, + ptr_rstd, + ptr_dx, + ptr_dw, + Int32(M), + Int32(N), + Int32(ld_val), + Int32(int(sm_count)), + stream, + ) + + def rmsnorm_backward( x: Tensor, weight: Optional[Tensor], @@ -3262,14 +4316,32 @@ def rmsnorm_backward( # pressure in benchmark/verify loops. Clamp to Quack's baseline policy # (`sm_count = num_sms * 2` for N=4096) for this regime. if N == 4096 and M <= 8192 and x.dtype in (torch.float16, torch.bfloat16): - try: - num_sms = torch.cuda.get_device_properties(device).multi_processor_count - sm_count = min(int(sm_count), int(num_sms) * 2) - except Exception: - pass + num_sms = qutils.get_num_sms(device) + sm_count = min(int(sm_count), int(num_sms) * 2) + + use_atomic_dw = False + # DSv3 backward (N=6144/7168/8192) is dominated by the (sm_count, N) partial + # write + reduction for dW. Use the atomic-dW path to accumulate directly + # into a single (N,) fp32 buffer (no separate reduction kernel). 
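+ # For intuition, the quantity being accumulated is the standard RMSNorm
+ # weight gradient; a hypothetical eager sketch in fp32 math would be:
+ #   dw = (dout.float() * x.float() * rstd.float().unsqueeze(1)).sum(dim=0)
+ # The atomic path folds this cross-CTA sum into the kernel instead of
+ # materializing (sm_count, N) partials and reducing them afterwards.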
+ if ( + weight is not None + and (not has_bias) + and (not has_residual) + and dresidual_out is None + and dresidual is None + and N == 8192 + and weight.dtype is torch.float32 + and M >= 65536 + and x.dtype in (torch.float16, torch.bfloat16) + and _can_use_ptr_path_bwd(x, weight, dout, rstd) + ): + use_atomic_dw = True if weight is not None: - dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) + if use_atomic_dw: + dw_partial = torch.zeros(N, device=device, dtype=torch.float32) + else: + dw_partial = torch.empty(sm_count, N, device=device, dtype=torch.float32) else: dw_partial = None db_partial = ( @@ -3278,20 +4350,47 @@ def rmsnorm_backward( else None ) - _rmsnorm_bwd_sm100( - x, - weight, - dout, - rstd, - dx, - dw_partial, - db_partial, - dresidual_out, - dresidual, - sm_count, - ) + if ( + weight is not None + and dw_partial is not None + and (not has_bias) + and (not has_residual) + and dresidual_out is None + and dresidual is None + and _can_use_ptr_path_bwd(x, weight, dout, rstd) + ): + _rmsnorm_bwd_sm100_ptr( + x=x, + weight=weight, + dout=dout, + rstd=rstd, + dx=dx, + dw_partial=dw_partial, + sm_count=int(sm_count), + atomic_dw=bool(use_atomic_dw), + ) + else: + _rmsnorm_bwd_sm100( + x, + weight, + dout, + rstd, + dx, + dw_partial, + db_partial, + dresidual_out, + dresidual, + sm_count, + ) - dw = dw_partial.sum(dim=0).to(weight.dtype) if weight is not None else None + if weight is not None and dw_partial is not None: + if use_atomic_dw: + dw_fp32 = dw_partial + else: + dw_fp32 = _reduce_partial_sum_fp32(dw_partial, device_index=x.get_device()) + dw = dw_fp32 if weight.dtype is torch.float32 else dw_fp32.to(weight.dtype) + else: + dw = None db = db_partial.sum(dim=0).to(weight.dtype) if has_bias else None if has_residual and dresidual is None: dresidual = dx diff --git a/oink/src/kernelagent_oink/blackwell/softmax.py b/oink/src/kernelagent_oink/blackwell/softmax.py index 6a7eb54..394ab48 100644 --- a/oink/src/kernelagent_oink/blackwell/softmax.py +++ b/oink/src/kernelagent_oink/blackwell/softmax.py @@ -65,11 +65,17 @@ from cutlass.cute import runtime as rt from cutlass.cute.runtime import from_dlpack +from kernelagent_oink.blackwell.fast_launch import ( + StableI32Arg, + disable_fast_launch, + fast_launch_enabled, + set_runtime_ptr, + tls_cache as _tls_fast_launch_cache, +) from kernelagent_oink.blackwell.lite_quack import ( _KERNEL_ACCEPTS_LAYOUT_ARGS, TORCH2CUTE_DTYPE, ReductionBase, - domain_offset_i64, fill_oob, online_softmax_reduce, predicate_k, @@ -80,6 +86,275 @@ _BWD_COMPILE_CACHE: dict[tuple[Type[cutlass.Numeric], int], object] = {} _PTR_FWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} _PTR_BWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} +_PTR_FWDBWD_COMPILE_CACHE: dict[tuple[object, ...], object] = {} + + +class _PtrSoftmaxFastLaunch: + def __init__( + self, + *, + compiled: object, + executor: object, + capi_func: object, + ptr_a: object, + ptr_b: object, + ptr_c: object | None, + arg_m: StableI32Arg, + arg_ld: StableI32Arg, + stream: cuda.CUstream, + assumed_align: int, + packed_args: object, + keepalive: tuple[object, ...], + ): + self._compiled = compiled + self._executor = executor + self._capi_func = capi_func + self._ptr_a = ptr_a + self._ptr_b = ptr_b + self._ptr_c = ptr_c + self._arg_m = arg_m + self._arg_ld = arg_ld + self._stream = stream + self._assumed_align = int(assumed_align) + self._packed_args = packed_args + self._keepalive = keepalive + + self._use_fast_launch = True + self._cuda_result = getattr(executor, 
"cuda_result", None) + + self._last_a_ptr = -1 + self._last_b_ptr = -1 + self._last_c_ptr = -1 + self._last_m = -1 + self._last_ld = -1 + + def launch( + self, + *, + a_ptr: int, + b_ptr: int, + c_ptr: int | None, + M: int, + ld: int, + stream_handle: int, + dtype: type[cutlass.Numeric], + ) -> None: + if not fast_launch_enabled() or not self._use_fast_launch: + self._fallback_launch( + a_ptr=a_ptr, + b_ptr=b_ptr, + c_ptr=c_ptr, + M=M, + ld=ld, + stream_handle=stream_handle, + dtype=dtype, + ) + return + + if a_ptr != self._last_a_ptr: + try: + set_runtime_ptr(self._ptr_a, a_ptr) + self._last_a_ptr = a_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + a_ptr=a_ptr, + b_ptr=b_ptr, + c_ptr=c_ptr, + M=M, + ld=ld, + stream_handle=stream_handle, + dtype=dtype, + ) + return + + if b_ptr != self._last_b_ptr: + try: + set_runtime_ptr(self._ptr_b, b_ptr) + self._last_b_ptr = b_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + a_ptr=a_ptr, + b_ptr=b_ptr, + c_ptr=c_ptr, + M=M, + ld=ld, + stream_handle=stream_handle, + dtype=dtype, + ) + return + + if self._ptr_c is not None and c_ptr is not None: + if c_ptr != self._last_c_ptr: + try: + set_runtime_ptr(self._ptr_c, c_ptr) + self._last_c_ptr = c_ptr + except AttributeError: + self._disable_fast_launch() + self._fallback_launch( + a_ptr=a_ptr, + b_ptr=b_ptr, + c_ptr=c_ptr, + M=M, + ld=ld, + stream_handle=stream_handle, + dtype=dtype, + ) + return + + if M != self._last_m: + self._arg_m.set(M) + self._last_m = M + if ld != self._last_ld: + self._arg_ld.set(ld) + self._last_ld = ld + + if self._cuda_result is not None: + self._cuda_result.value = 0 + ret = self._capi_func(self._packed_args) # type: ignore[misc] + if ret != 0: + raise RuntimeError(f"CuTeDSL capi_func returned non-zero: {ret}") + if self._cuda_result is not None: + err = int(self._cuda_result.value) + if err != 0: + raise RuntimeError(f"CuTeDSL kernel launch failed (cuda_result={err})") + + def _disable_fast_launch(self) -> None: + self._use_fast_launch = False + disable_fast_launch() + + def _fallback_launch( + self, + *, + a_ptr: int, + b_ptr: int, + c_ptr: int | None, + M: int, + ld: int, + stream_handle: int, + dtype: type[cutlass.Numeric], + ) -> None: + stream = cuda.CUstream(int(stream_handle)) + ptr_a = rt.make_ptr( + dtype, + a_ptr, + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + ptr_b = rt.make_ptr( + dtype, + b_ptr, + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + if self._ptr_c is not None and c_ptr is not None: + ptr_c = rt.make_ptr( + dtype, + c_ptr, + mem_space=rt.AddressSpace.gmem, + assumed_align=self._assumed_align, + ) + self._compiled(ptr_a, ptr_b, ptr_c, Int32(int(M)), Int32(int(ld)), stream) + else: + self._compiled(ptr_a, ptr_b, Int32(int(M)), Int32(int(ld)), stream) + + +def _get_fast_ptr_softmax_launcher( + *, + compiled: object, + dtype: type[cutlass.Numeric], + N: int, + device_index: int, + stream_handle: int, + assumed_align: int, + is_bwd: bool, +) -> _PtrSoftmaxFastLaunch | None: + if not fast_launch_enabled(): + return None + key = ( + "ptr_fast_bwd" if is_bwd else "ptr_fast_fwd", + id(compiled), + int(N), + dtype, + int(device_index), + int(stream_handle), + int(assumed_align), + ) + cache = _tls_fast_launch_cache() + cached = cache.get(key) + if cached is not None: + return cached # type: ignore[return-value] + + assumed_align = int(assumed_align) + ptr_a = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, 
assumed_align=assumed_align + ) + ptr_b = rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + ptr_c = ( + rt.make_ptr( + dtype, 0, mem_space=rt.AddressSpace.gmem, assumed_align=assumed_align + ) + if is_bwd + else None + ) + + arg_m = StableI32Arg(0) + arg_ld = StableI32Arg(N) + stream = cuda.CUstream(int(stream_handle)) + executor = compiled.to(device_index) # type: ignore[attr-defined] + try: + if ptr_c is not None: + exe_args, adapted_args = executor.generate_execution_args( + ptr_a, + ptr_b, + ptr_c, + arg_m, + arg_ld, + stream, + ) + else: + exe_args, adapted_args = executor.generate_execution_args( + ptr_a, + ptr_b, + arg_m, + arg_ld, + stream, + ) + packed_args = executor._get_invoke_packed_args(list(exe_args)) # type: ignore[attr-defined] + capi_func = compiled.capi_func # type: ignore[attr-defined] + except AttributeError: + disable_fast_launch() + return None + + keepalive: tuple[object, ...] = ( + executor, + ptr_a, + ptr_b, + ptr_c, + arg_m, + arg_ld, + stream, + *adapted_args, + ) + launcher = _PtrSoftmaxFastLaunch( + compiled=compiled, + executor=executor, + capi_func=capi_func, + ptr_a=ptr_a, + ptr_b=ptr_b, + ptr_c=ptr_c, + arg_m=arg_m, + arg_ld=arg_ld, + stream=stream, + assumed_align=assumed_align, + packed_args=packed_args, + keepalive=keepalive, + ) + cache[key] = launcher + return launcher class SoftmaxFwdSM100(ReductionBase): @@ -87,9 +362,23 @@ def __init__(self, dtype: Type[cutlass.Numeric], N: int): # One-stage online reduction: pack (max, sum_exp) into Int64 reduction buffer. super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64) + def _get_num_threads(self) -> int: + # SM100 tuning note: + # For N=4096, we use 32 threads per row (1 warp) and run 1 row per CTA + # (32 threads total). This keeps the reduction fully warp-local and + # improves throughput on this GB200 versus Quack's default 2-rows-per-CTA + # schedule with 64 threads per row (4 warps total). + if self.N == 4096: + return 32 + return super()._get_num_threads() + def _calculate_threads_per_row(self) -> int: # Match Quack's bucketed policy for Softmax. N = self.N + if N == 4096: + return 32 + if N == 6144: + return 128 if N <= 64: return 8 if N <= 128: @@ -192,10 +481,10 @@ def _kernel_impl( shape = mX.shape idX = cute.make_identity_tensor(shape) - # Slice per-CTA region; use 64-bit indexing for large tensors. - mX, mO = [domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mX, mO)] - gX, gO = [cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mX, mO)] - cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) + # Quack-style CTA tiling. 
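+ # gX / gO are this CTA's (rows, cols) tiles of the input/output; cX carries
+ # the matching global coordinates and is used for the row/column bounds
+ # predication below.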
+ gX, gO, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) for mT in (mX, mO, idX) + ] smem = cutlass.utils.SmemAllocator() sX = smem.allocate_tensor( @@ -220,11 +509,25 @@ def _kernel_impl( num_bits_per_copy=128, ) - thr_copy_load = cute.make_tiled_copy( - copy_atom_load, tv_layout, tiler_mn + num_copy_elems = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems)) + thr_copy_load = cute.make_tiled_copy_tv( + copy_atom_load, thr_layout, val_layout ).get_slice(tidx) - thr_copy_store = cute.make_tiled_copy( - copy_atom_store, tv_layout, tiler_mn + thr_copy_store = cute.make_tiled_copy_tv( + copy_atom_store, thr_layout, val_layout ).get_slice(tidx) tXgX = thr_copy_load.partition_S(gX) @@ -256,7 +559,6 @@ def _kernel_impl( cute.autovec_copy(tXsX, tXrX) x = tXrX.load().to(Float32) - threads_per_row = tv_layout.shape[0][0] # Online softmax reduction: compute max and sum_exp in a single pass, with # optional cluster-wide aggregation via an Int64 reduction buffer. @@ -313,6 +615,8 @@ def __init__(self, dtype: Type[cutlass.Numeric], N: int): def _calculate_threads_per_row(self) -> int: # Match Quack backward softmax buckets. N = self.N + if N in (4096, 6144): + return 128 if N <= 64: return 8 if N <= 128: @@ -433,13 +737,10 @@ def _kernel_impl( shape = mdY.shape idX = cute.make_identity_tensor(shape) - mdY, mY, mdX = [ - domain_offset_i64((bidx * tiler_mn[0], 0), mT) for mT in (mdY, mY, mdX) - ] - gdY, gY, gdX = [ - cute.local_tile(mT, tiler_mn, (0, cluster_y)) for mT in (mdY, mY, mdX) + gdY, gY, gdX, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) + for mT in (mdY, mY, mdX, idX) ] - cX = cute.local_tile(idX, tiler_mn, (bidx, cluster_y)) smem = cutlass.utils.SmemAllocator() sdY = smem.allocate_tensor( @@ -467,11 +768,25 @@ def _kernel_impl( num_bits_per_copy=128, ) - thr_copy_load = cute.make_tiled_copy( - copy_atom_load, tv_layout, tiler_mn + num_copy_elems = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems)) + thr_copy_load = cute.make_tiled_copy_tv( + copy_atom_load, thr_layout, val_layout ).get_slice(tidx) - thr_copy_store = cute.make_tiled_copy( - copy_atom_store, tv_layout, tiler_mn + thr_copy_store = cute.make_tiled_copy_tv( + copy_atom_store, thr_layout, val_layout ).get_slice(tidx) tdYgdY = thr_copy_load.partition_S(gdY) @@ -505,8 +820,6 @@ def _kernel_impl( cute.autovec_copy(tYsY, tYrY) dy = tdYrdY.load().to(Float32) y = tYrY.load().to(Float32) - - threads_per_row = tv_layout.shape[0][0] dot = row_reduce( dy * y, cute.ReductionOp.ADD, @@ -553,6 +866,335 @@ def kernel( self._kernel_impl(mdY, mY, mdX, tv_layout, tiler_mn) +class SoftmaxFwdBwdSM100(ReductionBase): + """Fused softmax forward+backward producing dx from (x, dy). 
+ + Computes: + y = softmax(x) + dot = sum(dy * y) + dx = y * (dy - dot) + + This avoids materializing the intermediate `y` in global memory, which is + the dominant overhead in a naive `softmax_backward(dy, softmax_forward(x))` + composition. + """ + + def __init__(self, dtype: Type[cutlass.Numeric], N: int): + # Online softmax reduction uses an Int64 reduction buffer packing + # (max, sum_exp) pairs. We allocate a separate Float32 reduction buffer + # for dot(dy, y). + super().__init__(dtype, N, stage=1, reduction_dtype=cutlass.Int64) + + def _calculate_threads_per_row(self) -> int: + # Favor the backward bucket policy (better for the dot reduction). + N = self.N + if N in (4096, 6144): + return 128 + if N <= 64: + return 8 + if N <= 128: + return 16 + if N <= 3072: + return 32 + if N <= 6144: + return 64 + if N <= 8192: + return 128 + return 256 + + def _set_cluster_n(self) -> None: + # Quack-style growth of cluster_n with N and dtype. + N = self.N + if const_expr(self.dtype.width == 16): + cluster_n = ( + 1 + if N <= 16 * 1024 + else ( + 2 + if N <= 32 * 1024 + else (4 if N <= 64 * 1024 else (8 if N <= 128 * 1024 else 16)) + ) + ) + else: + cluster_n = ( + 1 + if N <= 32 * 1024 + else ( + 2 + if N <= 64 * 1024 + else (4 if N <= 128 * 1024 else (8 if N <= 256 * 1024 else 16)) + ) + ) + self.cluster_n = cluster_n + + def _get_num_threads(self) -> int: + # Keep in sync with _calculate_threads_per_row. + return 128 if self.N <= 8192 else 256 + + def _smem_size_in_bytes(self, tiler_mn, num_warps: int) -> int: + # Allocation order: + # 1) sX (16B aligned) + # 2) sdY (16B aligned) + # 3) reduction_buffer_stats (8B aligned) + # 4) reduction_buffer_dot (8B aligned) + # 5) optional mbarrier array (8B aligned) + def _align_up(x: int, align: int) -> int: + return ((x + align - 1) // align) * align + + tile_bytes = int(cute.size_in_bytes(self.dtype, cute.make_layout(tiler_mn))) + reduction_stats_bytes = int( + num_warps * self.cluster_n * (cutlass.Int64.width // 8) + ) + reduction_dot_bytes = int( + num_warps * self.cluster_n * (cutlass.Float32.width // 8) + ) + mbar_bytes = ( + int(2 * (cutlass.Int64.width // 8)) if const_expr(self.cluster_n > 1) else 0 + ) + + offset = _align_up(tile_bytes, 16) + offset = _align_up(offset, 16) + tile_bytes + offset = _align_up(offset, 8) + reduction_stats_bytes + offset = _align_up(offset, 8) + reduction_dot_bytes + offset = _align_up(offset, 8) + mbar_bytes + return int(offset) + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mdY: cute.Tensor, + mdX: cute.Tensor, + stream: cuda.CUstream, + ) -> None: + assert mX.element_type == self.dtype + assert mdY.element_type == self.dtype + assert mdX.element_type == self.dtype + tiler_mn, tv_layout = self._get_tv_layout() + num_threads = ( + cute.size(tv_layout, mode=[0]) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self._get_num_threads() + ) + num_warps = num_threads // cute.arch.WARP_SIZE + kernel = ( + self.kernel(mX, mdY, mdX, tv_layout, tiler_mn) + if _KERNEL_ACCEPTS_LAYOUT_ARGS + else self.kernel(mX, mdY, mdX) + ) + kernel.launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), self.cluster_n, 1], + block=[num_threads, 1, 1], + cluster=[1, self.cluster_n, 1] if const_expr(self.cluster_n > 1) else None, + smem=self._smem_size_in_bytes(tiler_mn, num_warps), + stream=stream, + ) + + @cute.jit + def launch_from_ptrs( + self, + ptr_x: cute.Pointer, + ptr_dy: cute.Pointer, + ptr_dx: cute.Pointer, + M: Int32, + ld: Int32, + stream: cuda.CUstream, + ) -> None: + """Pointer-based entrypoint that bypasses DLPack 
conversions.""" + ld_assumed = cute.assume(ld, divby=128 // self.dtype.width) + layout_mn = cute.make_layout((M, self.N), stride=(ld_assumed, 1)) + mX = cute.make_tensor(ptr_x, layout_mn) + mdY = cute.make_tensor(ptr_dy, layout_mn) + mdX = cute.make_tensor(ptr_dx, layout_mn) + self.__call__(mX, mdY, mdX, stream) + + @cute.jit + def _kernel_impl( + self, + mX: cute.Tensor, + mdY: cute.Tensor, + mdX: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + cluster_y = ( + const_expr(0) + if const_expr(self.cluster_n == 1) + else cute.arch.block_idx()[1] + ) + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + gX, gdY, gdX, cX = [ + cute.local_tile(mT, tiler_mn, (bidx, cluster_y)) + for mT in (mX, mdY, mdX, idX) + ] + + smem = cutlass.utils.SmemAllocator() + sX = smem.allocate_tensor( + mX.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + sdY = smem.allocate_tensor( + mdY.element_type, + cute.make_ordered_layout(tiler_mn, order=(1, 0)), + byte_alignment=16, + ) + + reduction_layout = self._get_reduction_buffer_layout(tv_layout, self.cluster_n) + reduction_buffer_stats = smem.allocate_tensor( + cutlass.Int64, reduction_layout, byte_alignment=8 + ) + reduction_buffer_dot = smem.allocate_tensor( + cutlass.Float32, reduction_layout, byte_alignment=8 + ) + + if const_expr(self.cluster_n > 1): + mbar_ptr_base = smem.allocate_array(cutlass.Int64, num_elems=2) + mbar_ptr_stats = mbar_ptr_base + mbar_ptr_dot = mbar_ptr_base + Int32(1) + else: + mbar_ptr_stats = None + mbar_ptr_dot = None + + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mX.element_type, + num_bits_per_copy=128, + ) + copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gdX.element_type, + num_bits_per_copy=128, + ) + + num_copy_elems = ( + tv_layout.shape[1] + if const_expr(cute.rank(tv_layout.shape[1]) == 1) + else tv_layout.shape[1][0] + ) + threads_per_row = ( + tv_layout.shape[0] + if const_expr(cute.rank(tv_layout.shape[0]) == 1) + else tv_layout.shape[0][0] + ) + thr_layout = cute.make_ordered_layout( + (tiler_mn[0], threads_per_row), order=(1, 0) + ) + val_layout = cute.make_layout((1, num_copy_elems)) + thr_copy_load = cute.make_tiled_copy_tv( + copy_atom_load, thr_layout, val_layout + ).get_slice(tidx) + thr_copy_store = cute.make_tiled_copy_tv( + copy_atom_store, thr_layout, val_layout + ).get_slice(tidx) + + tXgX = thr_copy_load.partition_S(gX) + tXsX = thr_copy_load.partition_D(sX) + tdYgdY = thr_copy_load.partition_S(gdY) + tdYsdY = thr_copy_load.partition_D(sdY) + tdXgdX = thr_copy_store.partition_D(gdX) + tXcX = thr_copy_load.partition_S(cX)[(0, None), None, None] + + tXrX, tdYrdY, tdXrdX = [ + cute.make_fragment_like(thr) for thr in (tXgX, tdYgdY, tdXgdX) + ] + + if const_expr( + self.cluster_n > 1 + and mbar_ptr_stats is not None + and mbar_ptr_dot is not None + ): + if tidx < 2: + cute.arch.mbarrier_init(mbar_ptr_stats + tidx, 1) + cute.arch.mbarrier_init_fence() + cute.arch.cluster_arrive_relaxed() + + is_even_N = const_expr(self.N == tiler_mn[1] * self.cluster_n) + tXpX = ( + predicate_k(thr_copy_load.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_load, tXgX, tXsX, pred=tXpX) + cute.copy(copy_atom_load, tdYgdY, tdYsdY, pred=tXpX) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + + if const_expr(not is_even_N): + 
fill_oob(tXsX, tXpX, -tXsX.element_type.inf) + fill_oob(tdYsdY, tXpX, 0.0) + + cute.autovec_copy(tXsX, tXrX) + cute.autovec_copy(tdYsdY, tdYrdY) + x = tXrX.load().to(Float32) + dy = tdYrdY.load().to(Float32) + + _, denom, exp_x = online_softmax_reduce( + x, + threads_per_row, + reduction_buffer_stats[None, None, 0], + mbar_ptr_stats, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + phase=None, + return_exp_x=True, + ) + assert exp_x is not None + y = exp_x * cute.arch.rcp_approx(denom) + + dot = row_reduce( + dy * y, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer_dot[None, None, 0], + mbar_ptr_dot, + phase=None, + init_val=0.0, + hook_fn=cute.arch.cluster_wait if const_expr(self.cluster_n > 1) else None, + ) + + dx = y * (dy - dot) + tdXrdX.store(dx.to(tdXrdX.element_type)) + + tOpO = ( + predicate_k(thr_copy_store.partition_S(cX), limit=shape[1]) + if const_expr(not is_even_N) + else None + ) + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_store, tdXrdX, tdXgdX, pred=tOpO) + + if _KERNEL_ACCEPTS_LAYOUT_ARGS: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mdY: cute.Tensor, + mdX: cute.Tensor, + tv_layout: cute.Layout, + tiler_mn: cute.Shape, + ) -> None: + self._kernel_impl(mX, mdY, mdX, tv_layout, tiler_mn) + else: + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mdY: cute.Tensor, + mdX: cute.Tensor, + ) -> None: + tiler_mn, tv_layout = self._get_tv_layout() + self._kernel_impl(mX, mdY, mdX, tv_layout, tiler_mn) + + def _convert_2d_tensor(x: Tensor) -> cute.Tensor: # Match Quack's Softmax conversion exactly: assume 16B alignment and mark # the shape compact with row-major stride order (0, 1), with mode=0 (batch). @@ -596,7 +1238,8 @@ def _softmax_forward_ptr_into(*, x: Tensor, out: Tensor) -> None: device_index = x.get_device() if torch.cuda.current_device() != device_index: torch.cuda.set_device(device_index) - stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) dtype_x = TORCH2CUTE_DTYPE[x.dtype] key = ("ptr_fwd", int(N), dtype_x, int(device_index)) @@ -620,6 +1263,27 @@ def _softmax_forward_ptr_into(*, x: Tensor, out: Tensor) -> None: ) _PTR_FWD_COMPILE_CACHE[key] = compiled + launcher = _get_fast_ptr_softmax_launcher( + compiled=compiled, + dtype=dtype_x, + N=int(N), + device_index=int(device_index), + stream_handle=stream_handle, + assumed_align=16, + is_bwd=False, + ) + if launcher is not None: + launcher.launch( + a_ptr=int(x.data_ptr()), + b_ptr=int(out.data_ptr()), + c_ptr=None, + M=int(M), + ld=int(x.stride(0)), + stream_handle=stream_handle, + dtype=dtype_x, + ) + return + ptr_x = rt.make_ptr( dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 ) @@ -642,7 +1306,8 @@ def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None: device_index = dy.get_device() if torch.cuda.current_device() != device_index: torch.cuda.set_device(device_index) - stream = cuda.CUstream(int(torch.cuda.current_stream().cuda_stream)) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) dtype_x = TORCH2CUTE_DTYPE[dy.dtype] key = ("ptr_bwd", int(N), dtype_x, int(device_index)) @@ -670,6 +1335,27 @@ def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None: ) _PTR_BWD_COMPILE_CACHE[key] = compiled + launcher = _get_fast_ptr_softmax_launcher( + compiled=compiled, + dtype=dtype_x, + N=int(N), + 
device_index=int(device_index), + stream_handle=stream_handle, + assumed_align=16, + is_bwd=True, + ) + if launcher is not None: + launcher.launch( + a_ptr=int(dy.data_ptr()), + b_ptr=int(y.data_ptr()), + c_ptr=int(dx.data_ptr()), + M=int(M), + ld=int(dy.stride(0)), + stream_handle=stream_handle, + dtype=dtype_x, + ) + return + ptr_dy = rt.make_ptr( dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 ) @@ -682,6 +1368,81 @@ def _softmax_backward_ptr_into(*, dy: Tensor, y: Tensor, dx: Tensor) -> None: compiled(ptr_dy, ptr_y, ptr_dx, Int32(int(M)), Int32(int(dy.stride(0))), stream) +def _softmax_fwd_bwd_ptr_into(*, x: Tensor, dy: Tensor, dx: Tensor) -> None: + """Launch the fused pointer-based Softmax fwd+bwd kernel into preallocated `dx`.""" + assert x.is_cuda and x.dim() == 2 + assert dy.is_cuda and dy.shape == x.shape and dy.dtype == x.dtype + assert dx.is_cuda and dx.shape == x.shape and dx.dtype == x.dtype + assert x.stride() == dy.stride() == dx.stride(), ( + "Pointer path expects matching strides" + ) + + M, N = x.shape + device_index = x.get_device() + if torch.cuda.current_device() != device_index: + torch.cuda.set_device(device_index) + stream_handle = int(torch.cuda.current_stream().cuda_stream) + stream = cuda.CUstream(stream_handle) + + dtype_x = TORCH2CUTE_DTYPE[x.dtype] + key = ("ptr_fwd_bwd", int(N), dtype_x, int(device_index)) + compiled = _PTR_FWDBWD_COMPILE_CACHE.get(key) + if compiled is None: + op = SoftmaxFwdBwdSM100(dtype_x, int(N)) + ptr_x = rt.make_ptr( + dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dy = rt.make_ptr( + dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ld = Int32(int(x.stride(0))) + compiled = cute.compile( + op.launch_from_ptrs, + ptr_x, + ptr_dy, + ptr_dx, + Int32(int(M)), + ld, + stream, + ) + _PTR_FWDBWD_COMPILE_CACHE[key] = compiled + + launcher = _get_fast_ptr_softmax_launcher( + compiled=compiled, + dtype=dtype_x, + N=int(N), + device_index=int(device_index), + stream_handle=stream_handle, + assumed_align=16, + is_bwd=True, + ) + if launcher is not None: + launcher.launch( + a_ptr=int(x.data_ptr()), + b_ptr=int(dy.data_ptr()), + c_ptr=int(dx.data_ptr()), + M=int(M), + ld=int(x.stride(0)), + stream_handle=stream_handle, + dtype=dtype_x, + ) + return + + ptr_x = rt.make_ptr( + dtype_x, x.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dy = rt.make_ptr( + dtype_x, dy.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + ptr_dx = rt.make_ptr( + dtype_x, dx.data_ptr(), mem_space=rt.AddressSpace.gmem, assumed_align=16 + ) + compiled(ptr_x, ptr_dy, ptr_dx, Int32(int(M)), Int32(int(x.stride(0))), stream) + + def softmax_forward(x: Tensor) -> Tensor: """SM100 CuteDSL softmax forward pass: y = softmax(x, dim=-1).""" assert x.dim() == 2, "Input must be 2D (M, N)" @@ -749,6 +1510,31 @@ def softmax_backward(dy: Tensor, y: Tensor) -> Tensor: return dx +def softmax_fwd_bwd(dy: Tensor, x: Tensor) -> Tensor: + """Fused softmax forward+backward producing ``dx`` from ``(x, dy)``. + + This is intended for benchmarks and training-like use-cases where the + intermediate ``y = softmax(x)`` is not needed outside the backward pass. 
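+
+     For reference, the result matches (up to floating-point error) the eager
+     PyTorch composition::
+
+         y = torch.softmax(x, dim=-1)
+         dx = y * (dy - (dy * y).sum(dim=-1, keepdim=True))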
+ """ + assert x.dim() == 2 and dy.dim() == 2, "x and dy must be 2D (M, N)" + assert x.shape == dy.shape, "x and dy must have the same shape" + assert x.is_cuda and dy.is_cuda, "x and dy must be on CUDA device" + assert x.dtype in TORCH2CUTE_DTYPE, "Unsupported dtype" + assert dy.dtype == x.dtype, "x and dy must have the same dtype" + + if ( + _can_use_ptr_path_2d(x) + and _can_use_ptr_path_2d(dy) + and x.stride() == dy.stride() + ): + dx = torch.empty_strided(x.shape, x.stride(), device=x.device, dtype=x.dtype) + _softmax_fwd_bwd_ptr_into(x=x, dy=dy, dx=dx) + return dx + + with torch.no_grad(): + return softmax_backward(dy, softmax_forward(x)) + + class SoftmaxFunction(torch.autograd.Function): @staticmethod def forward(ctx, x: Tensor) -> Tensor: