From 48c9b0c21988a50506b521b1a961818dff2f525e Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Tue, 27 Jan 2026 00:03:51 -0800 Subject: [PATCH 1/2] Introducing Database module to kernel_opt --- .../kernel_opt/database/__init__.py | 18 ++ kernel_perf_agent/kernel_opt/database/base.py | 199 ++++++++++++++++++ .../database/code_samples/matadd.py | 75 +++++++ .../database/code_samples/matadd_perst.py | 89 ++++++++ .../code_samples/matadd_tma_device.py | 95 +++++++++ .../database/code_samples/matadd_tma_host.py | 71 +++++++ .../database/code_samples/matmul.py | 98 +++++++++ .../database/code_samples/matmul_sw.py | 105 +++++++++ .../database/code_samples/matmul_tma_host.py | 79 +++++++ .../database/docs/experimental_tma.md | 169 +++++++++++++++ .../kernel_opt/database/docs/on_device_tma.py | 56 +++++ .../kernel_opt/database/docs/on_host_tma.py | 49 +++++ .../kernel_opt/database/docs/persistence.py | 43 ++++ .../kernel_opt/database/docs/pid_swizzle.py | 37 ++++ .../kernel_opt/database/docs/tma.md | 150 +++++++++++++ 15 files changed, 1333 insertions(+) create mode 100644 kernel_perf_agent/kernel_opt/database/__init__.py create mode 100644 kernel_perf_agent/kernel_opt/database/base.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matmul.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py create mode 100644 kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md create mode 100644 kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/persistence.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py create mode 100644 kernel_perf_agent/kernel_opt/database/docs/tma.md diff --git a/kernel_perf_agent/kernel_opt/database/__init__.py b/kernel_perf_agent/kernel_opt/database/__init__.py new file mode 100644 index 0000000..b214283 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Database package.""" + +# Database package +__all__ = [] diff --git a/kernel_perf_agent/kernel_opt/database/base.py b/kernel_perf_agent/kernel_opt/database/base.py new file mode 100644 index 0000000..ffb4a0c --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/base.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from kernel_perf_agent.kernel_opt.database.docs import ( + on_device_tma, + on_host_tma, + persistence, + pid_swizzle, +) + + +class OptNode: + def __init__(self, level: int, dsl: str, opt_desc: str) -> None: + """Initialize the optimization node with the given level, description, and DSL. + :param level: int, Level in the tree + :param dsl: str, DSL used in the node + :param opt_desc: str, Description of the optimization + :param opt_parents: List[str], Parent nodes description + :param opt_children: List[OptNode], Children nodes + """ + + self.level = level # int, Level in the tree + self.dsl = dsl + self.opt_desc = opt_desc # str, Root node description + self.opt_parents = [] # List[str], Parent nodes description + self.opt_children = [] # List[OptNode], Children nodes + + def add_children(self, child_nodes): + """Adds a child node to the current node.""" + self.opt_children.extend(child_nodes) + + def remove_children(self, child_nodes): + """Removes a child node from the current node.""" + for child in child_nodes: + if child in self.opt_children: + self.opt_children.remove(child) + + def add_parents(self, parent_nodes): + """Adds a child node to the current node.""" + self.opt_parents.extend(parent_nodes) + + def remove_parents(self, parent_nodes): + """Removes a child node from the current node.""" + for parent in parent_nodes: + if parent in self.opt_parents: + self.opt_parents.remove(parent) + + def __repr__(self): + """String representation of the node for easy printing.""" + return f"OptNode at level {self.level}: ({self.opt_desc})" + + +class OptHierarchy: + def __init__(self) -> None: + """Initialize the optimization hierarchy with the root node.""" + self.root = OptNode(level=0, dsl="text", opt_desc="root") + + def get_root(self): + return self.root + + def hard_initialize(self, common_path) -> None: + """Hard initialize the hierarchy with pre-programmed database.""" + + # Level 1 nodes - Latency, Memory, Utilization bottlenecks + optnode_latency = OptNode( + level=1, + dsl="text", + opt_desc="""To optimize compute-bound kernels, we employ techniques to reduce kernel execution latency, including: + - Persistent programming style to minimize kernel launch overhead + - Software pipelining to improve instruction-level parallelism and reduce execution time + """, + ) + optnode_memory = OptNode( + level=1, + dsl="text", + opt_desc="""To optimize memory-bound kernels, we employ techniques to improve performance, including: + - PID swizzling to enhance L2 cache locality + - Leveraging new architecture features, such as Tensor Memory Accelerator (TMA) to overlap memory transfers + with compute operations + """, + ) + optnode_utilization = OptNode( + level=1, + dsl="text", + opt_desc="""To optimize kernels that are not fully utilizing hardware resources, we employ techniques + to increase resource utilization and occupancy rates, including: + - Leveraging Tensor Memory Accelerator (TMA) to overlap memory transfers 
with compute operations + - Enabling warp specializations to improve instruction-level parallelism and reduce register pressure + - Autotuning to identify and apply optimal kernel configurations that maximize resource usage + """, + ) + level_1_opts = [optnode_latency, optnode_memory, optnode_utilization] + self.root.add_children(level_1_opts) + optnode_latency.add_parents([self.root]) + optnode_memory.add_parents([self.root]) + optnode_utilization.add_parents([self.root]) + + # Level 2 nodes - TMA, PID swizzling, persistent programming style + optnode_host_TMA = OptNode( + level=2, dsl="text", opt_desc=on_host_tma.ON_HOST_TMA + ) + optnode_device_TMA = OptNode( + level=2, dsl="text", opt_desc=on_device_tma.ON_DEVICE_TMA + ) + optnode_PID_swizzling = OptNode( + level=2, dsl="text", opt_desc=pid_swizzle.PID_SWIZZLE + ) + optnode_persistence = OptNode( + level=2, dsl="text", opt_desc=persistence.PERSISTENCE + ) + + optnode_latency.add_children([optnode_persistence]) + optnode_memory.add_children( + [ + optnode_host_TMA, + optnode_device_TMA, + optnode_PID_swizzling, + optnode_persistence, + ] + ) + optnode_utilization.add_children([optnode_persistence]) + + optnode_host_TMA.add_parents([optnode_memory]) + optnode_device_TMA.add_parents([optnode_memory]) + optnode_PID_swizzling.add_parents([optnode_memory]) + optnode_persistence.add_parents( + [optnode_latency, optnode_memory, optnode_utilization] + ) + + # Level 3 nodes - code example of each kernel + # common_path="../kernel_opt/database/code_samples/" + optnode_matmul = OptNode( + level=3, dsl="triton", opt_desc=Path(common_path / "matmul.py").read_text() + ) + optnode_matmul_pid_swizzling = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matmul_sw.py").read_text(), + ) + optnode_matmul_tma_host = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matmul_tma_host.py").read_text(), + ) + optnode_matadd = OptNode( + level=3, dsl="triton", opt_desc=Path(common_path / "matadd.py").read_text() + ) + optnode_matadd_persistence = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matadd_perst.py").read_text(), + ) + optnode_matadd_tma_host = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matadd_tma_host.py").read_text(), + ) + optnode_matadd_tma_device = OptNode( + level=3, + dsl="triton", + opt_desc=Path(common_path / "matadd_tma_device.py").read_text(), + ) + + optnode_host_TMA.add_children( + [ + optnode_matmul, + optnode_matmul_tma_host, + optnode_matadd, + optnode_matadd_tma_host, + ] + ) + optnode_device_TMA.add_children([optnode_matadd, optnode_matadd_tma_device]) + optnode_PID_swizzling.add_children( + [optnode_matmul, optnode_matmul_pid_swizzling] + ) + optnode_persistence.add_children([optnode_matadd, optnode_matadd_persistence]) + + optnode_matmul.add_parents([optnode_host_TMA, optnode_PID_swizzling]) + optnode_matmul_pid_swizzling.add_parents([optnode_PID_swizzling]) + optnode_matmul_tma_host.add_parents([optnode_host_TMA]) + optnode_matadd.add_parents( + [optnode_host_TMA, optnode_device_TMA, optnode_persistence] + ) + optnode_matadd_persistence.add_parents([optnode_persistence]) + optnode_matadd_tma_host.add_parents([optnode_host_TMA]) + optnode_matadd_tma_device.add_parents([optnode_device_TMA]) diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py new file mode 100644 index 0000000..ec71868 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py @@ -0,0 
+1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================ unoptimized matadd ================================= +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + M, + N, + stride_m, + stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # Range of pointers for loading the block of A and B. + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + output_ptrs = output_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + + x = tl.load(x_ptrs, mask=data_mask, other=0.0) + y = tl.load(y_ptrs, mask=data_mask, other=0.0) + output = x + y + tl.store(output_ptrs, output, mask=data_mask) + + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( + x, + y, + output, + M, + N, + x.stride(0), + x.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py new file mode 100644 index 0000000..99b4c79 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ===================== matadd with persistent programming style ================== +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + M, + N, + stride_m, + stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + NUM_SMS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + num_tiles = num_pid_m * num_pid_n + + # iterate over the program id with a stride of the total number of blocks + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + + # Range of pointers for loading the block of A and B. + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) + output_ptrs = output_ptr + ( + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + ) + data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + + x = tl.load(x_ptrs, mask=data_mask, other=0.0) + y = tl.load(y_ptrs, mask=data_mask, other=0.0) + output = x + y + tl.store(output_ptrs, output, mask=data_mask) + + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + # Get the number of streaming multiprocessors and use it to launch a fixed number of blocks + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + grid = lambda meta: ( + min( + NUM_SMS, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ), + ) + add_kernel[grid]( + x, + y, + output, + M, + N, + x.stride(0), + x.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + NUM_SMS=NUM_SMS, + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py new file mode 100644 index 0000000..274e172 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ======== matadd with on-device Tensor Memory Accelerator (TMA) integration ========== +from typing import Optional + +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel( + x_ptr, + y_ptr, + output_ptr, + M, + N, + stride_m, + stride_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # device TMA + x_desc = tl.make_tensor_descriptor( + x_ptr, + shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + y_desc = tl.make_tensor_descriptor( + y_ptr, + shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + output_desc = tl.make_tensor_descriptor( + output_ptr, + shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], + ) + + x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + y = y_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + output = x + y + output_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], output) + + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( + x, + y, + output, + M, + N, + x.stride(0), + x.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py new file mode 100644 index 0000000..7b65f1b --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ======== matadd with on-host Tensor Memory Accelerator (TMA) integration ========== +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def add_kernel( + x_desc, + y_desc, + output_desc, + M, + N, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + y = y_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + output = x + y + output_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], output) + + +def add(x: torch.Tensor, y: torch.Tensor): + M, N = x.shape + output = torch.empty((M, N), device=x.device, dtype=torch.float16) + + # TMA descriptors for loading A, B and storing C + x_desc = TensorDescriptor(x, x.shape, x.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N]) + y_desc = TensorDescriptor(y, y.shape, y.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N]) + output_desc = TensorDescriptor( + output, output.shape, output.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + + grid = lambda meta: ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( + x_desc, + y_desc, + output_desc, + M, + N, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return output diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py new file mode 100644 index 0000000..08f1225 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================ unoptimized matmul ================================= +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + # Range of pointers for loading the block of A and B. 
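+    # The `% M` / `% N` wrap keeps edge-tile row/column indices in bounds; the wrapped
+    # rows/columns are computed redundantly but masked out at the final store, and the
+    # K dimension is masked inside the loop below.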
+ offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + return c diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py new file mode 100644 index 0000000..61ca813 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ==================== matmul with PID swizzling ================================= +import torch +import triton +import triton.language as tl + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 128 +GROUP_SIZE_M = 8 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + # Range of pointers for loading the block of A and B. 
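+    # The swizzled pid_m/pid_n above only change which output tile this program owns
+    # (grouping tiles so nearby programs reuse the same A rows / B columns in L2);
+    # the per-tile pointer arithmetic below is identical to the unswizzled kernel.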
+ offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + ) + return c diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py new file mode 100644 index 0000000..0e2b0af --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# ======== matmul with on-host Tensor Memory Accelerator (TMA) integration ========== +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +BLOCK_SIZE_M = 128 +BLOCK_SIZE_N = 128 +BLOCK_SIZE_K = 128 +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def matmul_kernel( + a_desc, + b_desc, + c_desc, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = a_desc.load([pid_m * BLOCK_SIZE_M, k * BLOCK_SIZE_K]) # TMA load of A + b = b_desc.load([k * BLOCK_SIZE_K, pid_n * BLOCK_SIZE_N]) # TMA load of B + accumulator = tl.dot(a, b, accumulator) + c = accumulator.to(tl.float16) + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c) + + +def matmul(a, b): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + + # TMA descriptors for loading A, B and storing C + a_desc = TensorDescriptor(a, a.shape, a.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_K]) + b_desc = TensorDescriptor(b, b.shape, b.stride(), [BLOCK_SIZE_K, BLOCK_SIZE_N]) + c_desc = TensorDescriptor(c, c.shape, c.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N]) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a_desc, + b_desc, + c_desc, + M, + N, + K, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + return c diff --git a/kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md b/kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md new file mode 100644 index 0000000..7edbad0 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/experimental_tma.md @@ -0,0 +1,169 @@ + + +# Triton Tutorial: How to integrate NV TMA into kernels +## Background +TMA is a hardware unit introduced by the NV Hopper GPU. It takes over some of the data transfer work from softwares and thus improves the performance by freeing up warps or reducing register pressures etc. In practice, Triton kernel authors can update the kernel code by simply replacing `tl.load` and `tl.store` with TMA API calls to get this performance boost. + +## TMA APIs +TMA API is going through changes (from experimental to official) on upstream Triton. While we’re working out a plan to migrate, we’ll support the “old” experimental API that’s currently being used in our fbsource codebase. This tutorial will be based on the experimental API. + +TMA data load/store needs a TMA tensor descriptor object. The descriptor will describe the tensor address, strides, shapes etc. of the tensor to be copied (treat it as the CUDA `TensorMap` object). The descriptor itself needs to be stored somewhere. Depending on where we initialize the descriptor, we have two types of descriptors: on-host and on-device. The former allocates the memory on host memory, initializes descriptors there and then copies them by value to GMEM. The latter will allocate a big chunk of memory on GMEM, and then have each program to find their own offset and initialize descriptors there. 
+
+To leverage TMA, we need to decide between on-host and on-device descriptors. That decision is a topic of its own; here we quickly highlight a few key differences:
+- Only on-device descriptors can handle dynamic shapes where not all programs are handling the same box size, which is typical in kernels like Jagged Flash Attention or HSTU. The reason is that on-host descriptors are initialized before kernel launch, while on-device ones are initialized inside the kernel, where the box size is known.
+- Torch Inductor, especially AOTI, currently only supports on-device descriptors
+- On-device descriptors are initialized by every kernel program in the grid while on-host ones are initialized by host code, so on-device descriptors likely take more compute resources
+- The current on-device descriptor implementation (experimental API) might take more global memory because the number of programs is not necessarily known when allocating the memory chunk for descriptors (e.g. it may depend on an auto-tuned BLOCK_SIZE_M), so we need to be conservative and allocate more memory
+
+Note: neither of these two types of TMA is necessarily faster than the other. It depends on the actual use case.
+
+For the sake of this tutorial we will start with on-device descriptors, and we will use the example of copying 2D tensors, as it is the most common case.
+
+With those premises, here are the APIs to call:
+
+- Allocate a memory chunk to store descriptors on the host:
+```
+TMA_DESC_SIZE = 128  # size in bytes used by a single descriptor, tunable
+NUM_DESC_PER_PROGRAM = ...  # how many different tensors to load/store by each program. e.g. 3 for GEMM `C=AB`, 4 for HSTU Q,K,V,O tensors
+NUM_OF_PROGRAMS = ...  # same as specified in kernel `grid`. If grid size is related to auto tune config, use a reasonable upper bound by hard coding "minimal block M size" etc. for now.
+workspace = torch.empty(
+    TMA_DESC_SIZE * NUM_DESC_PER_PROGRAM * NUM_OF_PROGRAMS,
+    dtype=torch.uint8,
+    device="cuda",)
+# then pass `workspace` to kernel
+```
+- Initialize the descriptor object:
+```
+desc_ptr = workspace + TMA_DESC_SIZE * NUM_DESC_PER_PROGRAM * <program id> + TMA_DESC_SIZE * <in-program descriptor index>  # <in-program descriptor index> is in range [0, NUM_DESC_PER_PROGRAM)
+
+tl.extra.cuda.experimental_device_tensormap_create2d(
+    desc_ptr=desc_ptr,
+    global_address=<tensor>,  # tensor to load into or store from
+    load_size=[BOX_SIZE_0, BOX_SIZE_1],  # size of the 2D box to copy
+    global_size=[GLOBAL_SIZE_0, GLOBAL_SIZE_1],  # this defines a "global box" in GMEM. TMA load/store won't go over this boundary if global_size is not divisible by load_size. e.g. assuming GLOBAL_SIZE_0 == 1.5 * BLOCK_SIZE_0 and GLOBAL_SIZE_1 == BLOCK_SIZE_1, then: for a TMA load, the second box will return a tensor of size (BLOCK_SIZE_0, BLOCK_SIZE_1) but the second half of the tensor is all 0; for a TMA store, the second box will only have its first half written to GMEM.
+    element_ty=<element type>,  # usually <tensor>.dtype.element_ty
+)
+```
+- Acquire a fence on a TensorMap/descriptor object:
+```
+tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(<desc_ptr>)
+```
+- Load data from GMEM to SMEM:
+```
+x = tl._experimental_descriptor_load(
+    <desc_ptr>,  # initialized, and fence acquired, as above
+    [OFFSET_0, OFFSET_1],  # offset in the "global box" where the 2D loading box starts
+    [BOX_SIZE_0, BOX_SIZE_1],  # keep the same as the descriptor's `load_size`
+    <element type>,
+)
+```
+- Store data from SMEM to GMEM:
+```
+tl._experimental_descriptor_store(
+    <desc_ptr>,  # initialized, and fence acquired, as above
+    <tensor>,  # the tensor to be stored to GMEM
+    [OFFSET_0, OFFSET_1],  # offset in the "global box" where the 2D store box starts
+)
+```
+
+## Example
+### Store
+Let's assume we have the following non-TMA store code:
+
+```
+start_m = pid * BLOCK_M
+offs_m = start_m + tl.arange(0, BLOCK_M)
+offs_v_d = tl.arange(0, BLOCK_D_V)
+off_o = Out + seq_start * stride_om + off_h * stride_oh  # TMA will use Out as the global address, and include seq_start * stride_om + off_h * stride_oh as part of the offsets
+out_ptrs = off_o + offs_m[:, None] * stride_om + offs_v_d[None, :]
+tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None])
+
+# Essentially, it tries to store the tensor `acc` into this box:
+# Out[
+#     (seq_start + pid * BLOCK_M : seq_start + (pid+1) * BLOCK_M),
+#     (off_h * stride_oh : off_h * stride_oh + BLOCK_D_V)
+# ]
+# In other words, it's a box of size (BLOCK_M, BLOCK_D_V) starting at [seq_start + pid * BLOCK_M, off_h * stride_oh]. This will be the basis for our TMA desc init and load/store op.
+# And the rows with dim0 larger than (seq_start + seq_len) will be masked. Note that (seq_start + seq_len) == seq_end, which we'll use in the TMA store below.
+```
+The equivalent TMA store code would be:
+```
+# pyre-ignore [20]
+tl.extra.cuda.experimental_device_tensormap_create2d(
+    desc_ptr=device_desc_o,
+    global_address=Out,  # Out is of shape (L, H, DimV)
+    load_size=[BLOCK_M, BLOCK_D_V],  # box size as explained in the comments above
+    global_size=[seq_end.to(tl.int32), H * DimV],  # this eliminates the need for `mask`; TMA automatically takes care of boundaries.
+    element_ty=Out.dtype.element_ty,
+)
+# pyre-ignore [20]
+tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(device_desc_o)
+tl._experimental_descriptor_store(
+    device_desc_o,
+    acc,  # acc needs to be cast to the right dtype
+    [  # offset as explained in the comments above (where the box starts)
+        (seq_start + pid * BLOCK_M).to(tl.int32),
+        (off_h * stride_oh).to(tl.int32),
+    ],
+)
+```
+### Load
+Assume we have this non-TMA load code:
+```
+Q_block_ptr = tl.make_block_ptr(
+    base=Q + off_h * stride_qh + seq_start * stride_qm,
+    shape=(seq_len, BLOCK_D_Q),
+    strides=(stride_qm, 1),
+    offsets=(start_m, 0),
+    block_shape=(BLOCK_M, BLOCK_D_Q),
+    order=(1, 0),
+)
+q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
+
+# Essentially, this tries to load this box into q:
+# Q[
+#     (seq_start + start_m : seq_start + start_m + BLOCK_M),
+#     (off_h * stride_qh : off_h * stride_qh + BLOCK_D_Q)
+# ]
+# In other words, it's a box of size (BLOCK_M, BLOCK_D_Q) starting at [seq_start + start_m, off_h * stride_qh]. This will be the basis for our TMA desc init and load/store op.
+# And the rows with dim0 larger than seq_len will be filled with zero, with the shape of q always being (BLOCK_M, BLOCK_D_Q).
+``` +The equivalent TMA load code will be: +``` +# pyre-ignore [20] +tl.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=device_desc_q, + global_address=Q, # shape (L, H, DimQ) + load_size=[BLOCK_M,BLOCK_D_Q], #box size as explained in comments above + global_size=[seq_end.to(tl.int32), H * DimQ], # seq_end == seq_start + seq_len + element_ty=Q.dtype.element_ty, + ) +# pyre-ignore [20] + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(device_desc_q) + + +q = tl._experimental_descriptor_load( + device_desc_q, + [ #offset as explained in comments above (where the box starts at) + (seq_start + start_m).to(tl.int32), + (off_h * stride_qh).to(tl.int32), + ], + [BLOCK_M,BLOCK_D_Q], + Q.dtype.element_ty, + ) +``` diff --git a/kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py b/kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py new file mode 100644 index 0000000..45d1deb --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/on_device_tma.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ON_DEVICE_TMA = """ +============================= On-Device Tensor Memory Accelerator (TMA) =================================== +## What is TMA? +The Tensor Memory Accelerator (TMA) is a hardware feature introduced in NVIDIA Hopper GPUs +for performing asynchronous memory copies between a GPU's global memory (GMEM) and the +shared memory (SMEM) of its thread blocks (i.e., CTAs). TMA offloads some of the data +transfer work from software, thereby improving performance by overlapping memory transfers +with computation, freeing up warps, and reducing register pressure. + +## On-Device TMA: +TMA data load/store operations require a TMA tensor descriptor object. This descriptor +specifies the tensor's address, strides, shapes, and other attributes necessary for the +copy operation. TMA descriptors can be initialized on the device. On-device descriptors +allocate a large chunk of memory in GMEM, and each program have to find its own offset +and initialize descriptors there. + +## How to integrate on-device TMA into a Triton program? +To enable on-device TMA in a Triton program, we need to add support from both the host and kernel programs. +In the host program, a global memory allocation is needed by adding the following function: +``` +def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + +triton.set_allocator(alloc_fn) +``` +In addition, we need to import the method `from typing import Optional`. +In the kernel program, instead of loading and storing a tensor block with a range of pointers, +we declare a TMA descriptor for each tensor and then use the descriptor to load and store the tensor in blocks. 
+An example of a TMA descriptor declaration is +``` +x_desc = tl.make_tensor_descriptor( + x_ptr, # the pointer to the tensor + shape=[M, N], # the shape of the tensor + strides=[stride_m, stride_n], # the stride of the tensor + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N], # the block size of each TMA load/store +) +``` +An example of the TMA load is +``` +x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) # the start offset of the TMA load +``` +""" diff --git a/kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py b/kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py new file mode 100644 index 0000000..31b169d --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/on_host_tma.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ON_HOST_TMA = """ +============================= On-Host Tensor Memory Accelerator (TMA) =================================== +## What is TMA? +The Tensor Memory Accelerator (TMA) is a hardware feature introduced in NVIDIA Hopper GPUs +for performing asynchronous memory copies between a GPU's global memory (GMEM) and the +shared memory (SMEM) of its thread blocks (i.e., CTAs). TMA offloads some of the data +transfer work from software, thereby improving performance by overlapping memory transfers +with computation, freeing up warps, and reducing register pressure. + +## On-Host TMA: +TMA data load/store operations require a TMA tensor descriptor object. This descriptor +specifies the tensor's address, strides, shapes, and other attributes necessary for the +copy operation. TMA descriptors can be initialized on the host. On-host descriptors +allocate memory in the host memory, initialize the descriptors there, and then copy +them by value to GMEM. + +## How to integrate on-host TMA into a Triton program? +To enable on-host TMA in a Triton program, we need to add support on both the host and kernel programs. +In the host program, we allocate a TMA descriptor for each tensor and pass the descriptor as an argument to the kernel. +An example of a TMA descriptor declaration is +``` +x_desc = TensorDescriptor( + x, # the pointer to the tensor + x.shape, # the shape of the tensor + x.stride(), # the stride of the tensor + [BLOCK_SIZE_M, BLOCK_SIZE_N] # the block size of each TMA load/store +) +``` +And in addition, we need to import the method `from triton.tools.tensor_descriptor import TensorDescriptor`. +In the kernel program, instead of loading and storing a tensor block with a range of pointers, +we use the TMA descriptor to load and store the tensor in blocks. 
An example of the TMA load is +``` +x = x_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) # the start offset of the TMA load +``` +""" diff --git a/kernel_perf_agent/kernel_opt/database/docs/persistence.py b/kernel_perf_agent/kernel_opt/database/docs/persistence.py new file mode 100644 index 0000000..755743e --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/persistence.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +PERSISTENCE = """ +================================ Persistent Programming Style ======================================= +## What it is: +The persistent programming style in GPU is a kernel design pattern where a fixed number of +blocks is launched, typically equal to the number of streaming multiprocessors (SMs), +instead of launching blocks proportional to the problem size. This pattern is particularly effective +for large-scale computations where the problem size exceeds the GPU's parallel capacity. + +## Traditional Approach: +In an unoptimized Triton GPU kernel, the number of blocks launched is dependent on the input size, +typically calculated as `triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]` +in the grid argument. +Each block processes exactly one tile of work, and the number of blocks can be much larger +than the available hardware resources. + +## Persistent Approach: +In a persistent style implementation, a fixed number of blocks is launched, which can be the number +of streaming multiprocessors (SMs) on the GPU by calling `torch.cuda.get_device_properties("cuda").multi_processor_count`. +In the kernel code, each block iterates over the program ID with a stride equal to the total number of blocks, +ensuring that the computation is completed by a fixed number of blocks. +These blocks "persist" and loop until all work is completed. + +## Advantages: +* Better resource utilization: Matches hardware capabilities exactly +* Reduced launch overhead: Fewer kernel launches for large problems +* Improved occupancy: Keeps all SMs busy throughout execution +* Better cache locality: Blocks can reuse data across multiple iterations +* Load balancing: Work is distributed more evenly across SMs +""" diff --git a/kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py b/kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py new file mode 100644 index 0000000..acb14ed --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/pid_swizzle.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +PID_SWIZZLE = """ +===================================== PID Swizzling =========================================== +## What it is: +PID swizzling is a GPU optimization technique used in Triton programming that remaps +program identifiers (`pid_m` and `pid_n`) to create better memory access patterns, +specifically for L2 cache locality. This technique is commonly used in high-performance GPU kernels, +particularly for GEMM (General Matrix Multiply) operations in frameworks like Triton. + +## Traditional Approach: +The program launch order matters as it affects the L2 cache hit rate. +In an unoptimized GPU kernel, each program instance computes a [BLOCK_SIZE_M, BLOCK_SIZE_N] +block of the output tensor, and the program identifiers are arranged in a simple row-major ordering +by `pid_m = pid // num_pid_n` and `pid_n = pid % num_pid_n`. +This creates poor cache locality because adjacent programs access memory locations that are far apart. + +## PID Swizzling Approach: +PID swizzling forms "super-grouping" of programs with a fixed row size `GROUP_SIZE_M`. +The number of programs in a group is `GROUP_SIZE_M * num_pid_n`. +The `group_id` is calculated by dividing the program id by the number of programs in a group. +If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the row size of the last group is smaller +and can be calculated by subtracting `GROUP_SIZE_M * group_id` from `num_pid_m`. +The programs within a group are arranged in a column-major order. +""" diff --git a/kernel_perf_agent/kernel_opt/database/docs/tma.md b/kernel_perf_agent/kernel_opt/database/docs/tma.md new file mode 100644 index 0000000..89087d9 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/database/docs/tma.md @@ -0,0 +1,150 @@ +**TMA (Tensor Memory Accelerator)** is a hardware feature in NVIDIA GPUs that accelerates memory transfers for tensor operations by providing more efficient block-based memory access patterns. + +What is TMA? +------------ + +TMA replaces traditional pointer-based memory access with **tensor descriptors** that describe the entire tensor layout, enabling the GPU hardware to optimize memory transfers automatically. + +Benefits of TMA: +---------------- + +* **Hardware-accelerated memory transfers** +* **Better memory coalescing** +* **Reduced memory access overhead** +* **Simplified memory access patterns** + +How to Add TMA to Triton Code +----------------------------- + +There are two approaches: **Host-side TMA** and **Device-side TMA**. + +### 1. Host-side TMA Implementation + +**Host-side setup:** + +``` +from triton.tools.tensor_descriptor import TensorDescriptor + +def matmul_with_tma(a, b): + # Create TMA descriptors on host + a_desc = TensorDescriptor( + a, # the tensor + a.shape, # tensor shape + a.stride(), # tensor strides + [BLOCK_SIZE_M, BLOCK_SIZE_K] # block size for TMA operations + ) + + b_desc = TensorDescriptor( + b, + b.shape, + b.stride(), + [BLOCK_SIZE_K, BLOCK_SIZE_N] + ) + + c_desc = TensorDescriptor( + c, + c.shape, + c.stride(), + [BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + + # Pass descriptors to kernel + kernel[grid](a_desc, b_desc, c_desc, ...) 
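+
+    # Note: this sketch assumes the output buffer was allocated before building c_desc,
+    # e.g. c = torch.empty((M, N), device=a.device, dtype=torch.float16)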
+``` + +**Kernel-side usage:** + +``` +@triton.jit +def matmul_kernel(a_desc, b_desc, c_desc, ...): + pid = tl.program_id(axis=0) + # Calculate tile positions + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # Load using TMA descriptors + a = a_desc.load([pid_m * BLOCK_SIZE_M, 0]) # offset coordinates + b = b_desc.load([0, pid_n * BLOCK_SIZE_N]) + + # Compute + accumulator = tl.dot(a, b) + + # Store using TMA descriptor + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], accumulator) +``` + +### 2. Device-side TMA Implementation + +**Host-side setup:** + +``` +from typing import Optional + +def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + +# Set custom allocator for TMA +triton.set_allocator(alloc_fn) +``` + +**Kernel-side usage:** + +``` +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, ...): + # Create TMA descriptors in kernel + a_desc = tl.make_tensor_descriptor( + a_ptr, # pointer to tensor + shape=[M, K], # tensor shape + strides=[stride_am, stride_ak], # tensor strides + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K] # TMA block size + ) + + b_desc = tl.make_tensor_descriptor( + b_ptr, + shape=[K, N], + strides=[stride_bk, stride_bn], + block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N] + ) + + c_desc = tl.make_tensor_descriptor( + c_ptr, + shape=[M, N], + strides=[stride_cm, stride_cn], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N] + ) + + # Use descriptors for memory operations + pid = tl.program_id(axis=0) + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + # Load blocks using TMA + a = a_desc.load([pid_m * BLOCK_SIZE_M, 0]) + b = b_desc.load([0, pid_n * BLOCK_SIZE_N]) + + # Compute and store + result = tl.dot(a, b) + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], result) +``` + +Key Differences from Traditional Approach: +------------------------------------------ + +**Traditional:** + +``` +# Manual pointer arithmetic +offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) +a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak +a = tl.load(a_ptrs, mask=...) +``` + +**TMA:** + +``` +# Descriptor-based access +a = a_desc.load([pid_m * BLOCK_SIZE_M, k_offset]) +``` + +TMA simplifies memory access patterns and leverages hardware acceleration for better performance in tensor operations. 
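+
+Putting It Together
+-------------------
+
+The fragments above omit some boilerplate (the grid setup, the `tl.constexpr` block sizes, and the `num_pid_n` computation). The sketch below is one way to assemble a complete device-side TMA kernel; it is adapted from `code_samples/matadd_tma_device.py` in this module and assumes a Triton version that provides `tl.make_tensor_descriptor` and `triton.set_allocator`.
+
+```
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+BLOCK_M = 128
+BLOCK_N = 128
+
+
+@triton.jit
+def add_kernel(
+    x_ptr, y_ptr, out_ptr, M, N, stride_m, stride_n,
+    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    # Device-side TMA: each program builds its own descriptors.
+    x_desc = tl.make_tensor_descriptor(
+        x_ptr, shape=[M, N], strides=[stride_m, stride_n], block_shape=[BLOCK_M, BLOCK_N]
+    )
+    y_desc = tl.make_tensor_descriptor(
+        y_ptr, shape=[M, N], strides=[stride_m, stride_n], block_shape=[BLOCK_M, BLOCK_N]
+    )
+    out_desc = tl.make_tensor_descriptor(
+        out_ptr, shape=[M, N], strides=[stride_m, stride_n], block_shape=[BLOCK_M, BLOCK_N]
+    )
+
+    # Block-based load, compute, and store through the descriptors.
+    x = x_desc.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
+    y = y_desc.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
+    out_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], x + y)
+
+
+def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    M, N = x.shape
+    out = torch.empty((M, N), device=x.device, dtype=x.dtype)
+
+    # On-device descriptors need a global-memory scratch allocation.
+    def alloc_fn(size: int, alignment: int, stream: Optional[int]):
+        return torch.empty(size, device="cuda", dtype=torch.int8)
+
+    triton.set_allocator(alloc_fn)
+
+    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)
+    add_kernel[grid](
+        x, y, out, M, N, x.stride(0), x.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
+    )
+    return out
+```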
From 84708fd8bca0ea81dc6e321f5f21d8f27ce4bb0f Mon Sep 17 00:00:00 2001 From: Kaiming Cheng Date: Tue, 27 Jan 2026 00:37:17 -0800 Subject: [PATCH 2/2] fix ruff --- .../kernel_opt/database/code_samples/matadd.py | 10 ++++++---- .../database/code_samples/matadd_perst.py | 16 ++++++++++------ .../database/code_samples/matadd_tma_device.py | 10 ++++++---- .../database/code_samples/matadd_tma_host.py | 10 ++++++---- .../kernel_opt/database/code_samples/matmul.py | 16 +++++++++------- .../database/code_samples/matmul_sw.py | 8 +++++--- .../database/code_samples/matmul_tma_host.py | 10 ++++++---- 7 files changed, 48 insertions(+), 32 deletions(-) diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py index ec71868..20cf664 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd.py @@ -35,7 +35,7 @@ def add_kernel( BLOCK_SIZE_N: tl.constexpr, ): pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n @@ -58,9 +58,11 @@ def add(x: torch.Tensor, y: torch.Tensor): M, N = x.shape output = torch.empty((M, N), device=x.device, dtype=torch.float16) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( x, y, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py index 99b4c79..33015ea 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_perst.py @@ -68,12 +68,16 @@ def add(x: torch.Tensor, y: torch.Tensor): # Get the number of streaming multiprocessors and use it to launch a fixed number of blocks NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - grid = lambda meta: ( - min( - NUM_SMS, - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ), - ) + + def grid(meta): + return ( + min( + NUM_SMS, + triton.cdiv(M, meta["BLOCK_SIZE_M"]) + * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ), + ) + add_kernel[grid]( x, y, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py index 274e172..3d4d6c1 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_device.py @@ -37,7 +37,7 @@ def add_kernel( BLOCK_SIZE_N: tl.constexpr, ): pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n @@ -78,9 +78,11 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): triton.set_allocator(alloc_fn) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( x, y, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py 
b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py index 7b65f1b..81b4761 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matadd_tma_host.py @@ -34,7 +34,7 @@ def add_kernel( BLOCK_SIZE_N: tl.constexpr, ): pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n @@ -56,9 +56,11 @@ def add(x: torch.Tensor, y: torch.Tensor): output, output.shape, output.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N] ) - grid = lambda meta: ( - triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), - ) + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + add_kernel[grid]( x_desc, y_desc, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py index 08f1225..7f78258 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul.py @@ -1,11 +1,11 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# + # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# + # http://www.apache.org/licenses/LICENSE-2.0 -# + # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -42,7 +42,7 @@ def matmul_kernel( BLOCK_SIZE_K: tl.constexpr, ): pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n @@ -75,9 +75,11 @@ def matmul(a, b): K, N = b.shape c = torch.empty((M, N), device=a.device, dtype=torch.float16) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( a, b, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py index 61ca813..1e2e468 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_sw.py @@ -81,9 +81,11 @@ def matmul(a, b): K, N = b.shape c = torch.empty((M, N), device=a.device, dtype=torch.float16) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( a, b, diff --git a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py index 0e2b0af..dc4f414 100644 --- a/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py +++ b/kernel_perf_agent/kernel_opt/database/code_samples/matmul_tma_host.py @@ -37,7 +37,7 @@ def matmul_kernel( BLOCK_SIZE_K: tl.constexpr, ): pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, 
BLOCK_SIZE_N) pid_m = pid // num_pid_n pid_n = pid % num_pid_n @@ -62,9 +62,11 @@ def matmul(a, b): b_desc = TensorDescriptor(b, b.shape, b.stride(), [BLOCK_SIZE_K, BLOCK_SIZE_N]) c_desc = TensorDescriptor(c, c.shape, c.stride(), [BLOCK_SIZE_M, BLOCK_SIZE_N]) - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - ) + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( a_desc, b_desc,