-
Notifications
You must be signed in to change notification settings - Fork 28
Add Knowledge Database to Kernel optimization #85
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Database package.""" | ||
|
|
||
| # Database package | ||
| __all__ = [] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,199 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
|
||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
|
|
||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from pathlib import Path | ||
|
|
||
| from kernel_perf_agent.kernel_opt.database.docs import ( | ||
| on_device_tma, | ||
| on_host_tma, | ||
| persistence, | ||
| pid_swizzle, | ||
| ) | ||
|
|
||
|
|
||
| class OptNode: | ||
| def __init__(self, level: int, dsl: str, opt_desc: str) -> None: | ||
| """Initialize the optimization node with the given level, description, and DSL. | ||
| :param level: int, Level in the tree | ||
| :param dsl: str, DSL used in the node | ||
| :param opt_desc: str, Description of the optimization | ||
| :param opt_parents: List[str], Parent nodes description | ||
| :param opt_children: List[OptNode], Children nodes | ||
| """ | ||
|
|
||
| self.level = level # int, Level in the tree | ||
| self.dsl = dsl | ||
| self.opt_desc = opt_desc # str, Root node description | ||
| self.opt_parents = [] # List[str], Parent nodes description | ||
| self.opt_children = [] # List[OptNode], Children nodes | ||
|
|
||
| def add_children(self, child_nodes): | ||
| """Adds a child node to the current node.""" | ||
| self.opt_children.extend(child_nodes) | ||
|
|
||
| def remove_children(self, child_nodes): | ||
| """Removes a child node from the current node.""" | ||
| for child in child_nodes: | ||
| if child in self.opt_children: | ||
| self.opt_children.remove(child) | ||
|
|
||
| def add_parents(self, parent_nodes): | ||
| """Adds a child node to the current node.""" | ||
| self.opt_parents.extend(parent_nodes) | ||
|
|
||
| def remove_parents(self, parent_nodes): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this for any reason? |
||
| """Removes a child node from the current node.""" | ||
| for parent in parent_nodes: | ||
| if parent in self.opt_parents: | ||
| self.opt_parents.remove(parent) | ||
|
|
||
| def __repr__(self): | ||
| """String representation of the node for easy printing.""" | ||
| return f"OptNode at level {self.level}: ({self.opt_desc})" | ||
|
|
||
|
|
||
| class OptHierarchy: | ||
| def __init__(self) -> None: | ||
| """Initialize the optimization hierarchy with the root node.""" | ||
| self.root = OptNode(level=0, dsl="text", opt_desc="root") | ||
|
|
||
| def get_root(self): | ||
| return self.root | ||
|
|
||
| def hard_initialize(self, common_path) -> None: | ||
| """Hard initialize the hierarchy with pre-programmed database.""" | ||
|
|
||
| # Level 1 nodes - Latency, Memory, Utilization bottlenecks | ||
| optnode_latency = OptNode( | ||
| level=1, | ||
| dsl="text", | ||
| opt_desc="""To optimize compute-bound kernels, we employ techniques to reduce kernel execution latency, including: | ||
| - Persistent programming style to minimize kernel launch overhead | ||
| - Software pipelining to improve instruction-level parallelism and reduce execution time | ||
| """, | ||
| ) | ||
| optnode_memory = OptNode( | ||
| level=1, | ||
| dsl="text", | ||
| opt_desc="""To optimize memory-bound kernels, we employ techniques to improve performance, including: | ||
| - PID swizzling to enhance L2 cache locality | ||
| - Leveraging new architecture features, such as Tensor Memory Accelerator (TMA) to overlap memory transfers | ||
| with compute operations | ||
| """, | ||
| ) | ||
| optnode_utilization = OptNode( | ||
| level=1, | ||
| dsl="text", | ||
| opt_desc="""To optimize kernels that are not fully utilizing hardware resources, we employ techniques | ||
| to increase resource utilization and occupancy rates, including: | ||
| - Leveraging Tensor Memory Accelerator (TMA) to overlap memory transfers with compute operations | ||
| - Enabling warp specializations to improve instruction-level parallelism and reduce register pressure | ||
| - Autotuning to identify and apply optimal kernel configurations that maximize resource usage | ||
| """, | ||
| ) | ||
| level_1_opts = [optnode_latency, optnode_memory, optnode_utilization] | ||
| self.root.add_children(level_1_opts) | ||
| optnode_latency.add_parents([self.root]) | ||
| optnode_memory.add_parents([self.root]) | ||
| optnode_utilization.add_parents([self.root]) | ||
|
Comment on lines
+105
to
+109
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: For legibility can we add a helper like It's easy to parse here, but level3 is a harder to parse |
||
|
|
||
| # Level 2 nodes - TMA, PID swizzling, persistent programming style | ||
| optnode_host_TMA = OptNode( | ||
| level=2, dsl="text", opt_desc=on_host_tma.ON_HOST_TMA | ||
| ) | ||
| optnode_device_TMA = OptNode( | ||
| level=2, dsl="text", opt_desc=on_device_tma.ON_DEVICE_TMA | ||
| ) | ||
| optnode_PID_swizzling = OptNode( | ||
| level=2, dsl="text", opt_desc=pid_swizzle.PID_SWIZZLE | ||
| ) | ||
| optnode_persistence = OptNode( | ||
| level=2, dsl="text", opt_desc=persistence.PERSISTENCE | ||
| ) | ||
|
|
||
| optnode_latency.add_children([optnode_persistence]) | ||
| optnode_memory.add_children( | ||
| [ | ||
| optnode_host_TMA, | ||
| optnode_device_TMA, | ||
| optnode_PID_swizzling, | ||
| optnode_persistence, | ||
| ] | ||
| ) | ||
| optnode_utilization.add_children([optnode_persistence]) | ||
|
|
||
| optnode_host_TMA.add_parents([optnode_memory]) | ||
| optnode_device_TMA.add_parents([optnode_memory]) | ||
| optnode_PID_swizzling.add_parents([optnode_memory]) | ||
| optnode_persistence.add_parents( | ||
| [optnode_latency, optnode_memory, optnode_utilization] | ||
| ) | ||
|
|
||
| # Level 3 nodes - code example of each kernel | ||
| # common_path="../kernel_opt/database/code_samples/" | ||
| optnode_matmul = OptNode( | ||
| level=3, dsl="triton", opt_desc=Path(common_path / "matmul.py").read_text() | ||
| ) | ||
| optnode_matmul_pid_swizzling = OptNode( | ||
| level=3, | ||
| dsl="triton", | ||
| opt_desc=Path(common_path / "matmul_sw.py").read_text(), | ||
| ) | ||
| optnode_matmul_tma_host = OptNode( | ||
| level=3, | ||
| dsl="triton", | ||
| opt_desc=Path(common_path / "matmul_tma_host.py").read_text(), | ||
| ) | ||
| optnode_matadd = OptNode( | ||
| level=3, dsl="triton", opt_desc=Path(common_path / "matadd.py").read_text() | ||
| ) | ||
| optnode_matadd_persistence = OptNode( | ||
| level=3, | ||
| dsl="triton", | ||
| opt_desc=Path(common_path / "matadd_perst.py").read_text(), | ||
| ) | ||
| optnode_matadd_tma_host = OptNode( | ||
| level=3, | ||
| dsl="triton", | ||
| opt_desc=Path(common_path / "matadd_tma_host.py").read_text(), | ||
| ) | ||
| optnode_matadd_tma_device = OptNode( | ||
| level=3, | ||
| dsl="triton", | ||
| opt_desc=Path(common_path / "matadd_tma_device.py").read_text(), | ||
| ) | ||
|
|
||
| optnode_host_TMA.add_children( | ||
| [ | ||
| optnode_matmul, | ||
| optnode_matmul_tma_host, | ||
| optnode_matadd, | ||
| optnode_matadd_tma_host, | ||
| ] | ||
| ) | ||
| optnode_device_TMA.add_children([optnode_matadd, optnode_matadd_tma_device]) | ||
| optnode_PID_swizzling.add_children( | ||
| [optnode_matmul, optnode_matmul_pid_swizzling] | ||
| ) | ||
| optnode_persistence.add_children([optnode_matadd, optnode_matadd_persistence]) | ||
|
|
||
| optnode_matmul.add_parents([optnode_host_TMA, optnode_PID_swizzling]) | ||
| optnode_matmul_pid_swizzling.add_parents([optnode_PID_swizzling]) | ||
| optnode_matmul_tma_host.add_parents([optnode_host_TMA]) | ||
| optnode_matadd.add_parents( | ||
| [optnode_host_TMA, optnode_device_TMA, optnode_persistence] | ||
| ) | ||
| optnode_matadd_persistence.add_parents([optnode_persistence]) | ||
| optnode_matadd_tma_host.add_parents([optnode_host_TMA]) | ||
| optnode_matadd_tma_device.add_parents([optnode_device_TMA]) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # ============================ unoptimized matadd ================================= | ||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| BLOCK_SIZE_M = 128 | ||
| BLOCK_SIZE_N = 128 | ||
| DEVICE = triton.runtime.driver.active.get_active_torch_device() | ||
|
|
||
|
|
||
| @triton.jit | ||
| def add_kernel( | ||
| x_ptr, | ||
| y_ptr, | ||
| output_ptr, | ||
| M, | ||
| N, | ||
| stride_m, | ||
| stride_n, | ||
| BLOCK_SIZE_M: tl.constexpr, | ||
| BLOCK_SIZE_N: tl.constexpr, | ||
| ): | ||
| pid = tl.program_id(axis=0) | ||
| _num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) | ||
| num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) | ||
| pid_m = pid // num_pid_n | ||
| pid_n = pid % num_pid_n | ||
|
|
||
| # Range of pointers for loading the block of A and B. | ||
| offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) | ||
| offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) | ||
| x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) | ||
| y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) | ||
| output_ptrs = output_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) | ||
| data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) | ||
|
|
||
| x = tl.load(x_ptrs, mask=data_mask, other=0.0) | ||
| y = tl.load(y_ptrs, mask=data_mask, other=0.0) | ||
| output = x + y | ||
| tl.store(output_ptrs, output, mask=data_mask) | ||
|
|
||
|
|
||
| def add(x: torch.Tensor, y: torch.Tensor): | ||
| M, N = x.shape | ||
| output = torch.empty((M, N), device=x.device, dtype=torch.float16) | ||
|
|
||
| def grid(meta): | ||
| return ( | ||
| triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), | ||
| ) | ||
|
|
||
| add_kernel[grid]( | ||
| x, | ||
| y, | ||
| output, | ||
| M, | ||
| N, | ||
| x.stride(0), | ||
| x.stride(1), | ||
| BLOCK_SIZE_M=BLOCK_SIZE_M, | ||
| BLOCK_SIZE_N=BLOCK_SIZE_N, | ||
| ) | ||
| return output |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| # ===================== matadd with persistent programming style ================== | ||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| BLOCK_SIZE_M = 128 | ||
| BLOCK_SIZE_N = 128 | ||
| DEVICE = triton.runtime.driver.active.get_active_torch_device() | ||
|
|
||
|
|
||
| @triton.jit | ||
| def add_kernel( | ||
| x_ptr, | ||
| y_ptr, | ||
| output_ptr, | ||
| M, | ||
| N, | ||
| stride_m, | ||
| stride_n, | ||
| BLOCK_SIZE_M: tl.constexpr, | ||
| BLOCK_SIZE_N: tl.constexpr, | ||
| NUM_SMS: tl.constexpr, | ||
| ): | ||
| start_pid = tl.program_id(axis=0) | ||
| num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) | ||
| num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) | ||
|
|
||
| num_tiles = num_pid_m * num_pid_n | ||
|
|
||
| # iterate over the program id with a stride of the total number of blocks | ||
| for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): | ||
| pid_m = tile_id // num_pid_n | ||
| pid_n = tile_id % num_pid_n | ||
|
|
||
| # Range of pointers for loading the block of A and B. | ||
| offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) | ||
| offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) | ||
| x_ptrs = x_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) | ||
| y_ptrs = y_ptr + (offs_m[:, None] * stride_m + offs_n[None, :] * stride_n) | ||
| output_ptrs = output_ptr + ( | ||
| offs_m[:, None] * stride_m + offs_n[None, :] * stride_n | ||
| ) | ||
| data_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) | ||
|
|
||
| x = tl.load(x_ptrs, mask=data_mask, other=0.0) | ||
| y = tl.load(y_ptrs, mask=data_mask, other=0.0) | ||
| output = x + y | ||
| tl.store(output_ptrs, output, mask=data_mask) | ||
|
|
||
|
|
||
| def add(x: torch.Tensor, y: torch.Tensor): | ||
| M, N = x.shape | ||
| output = torch.empty((M, N), device=x.device, dtype=torch.float16) | ||
|
|
||
| # Get the number of streaming multiprocessors and use it to launch a fixed number of blocks | ||
| NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count | ||
|
|
||
| def grid(meta): | ||
| return ( | ||
| min( | ||
| NUM_SMS, | ||
| triton.cdiv(M, meta["BLOCK_SIZE_M"]) | ||
| * triton.cdiv(N, meta["BLOCK_SIZE_N"]), | ||
| ), | ||
| ) | ||
|
|
||
| add_kernel[grid]( | ||
| x, | ||
| y, | ||
| output, | ||
| M, | ||
| N, | ||
| x.stride(0), | ||
| x.stride(1), | ||
| BLOCK_SIZE_M=BLOCK_SIZE_M, | ||
| BLOCK_SIZE_N=BLOCK_SIZE_N, | ||
| NUM_SMS=NUM_SMS, | ||
| ) | ||
| return output |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment on this is the same as
add_childrenLet's also add type hints to the args