Simplified Thunderkittens Port #107
@@ -115,10 +115,27 @@ uv run python scripts/generate_and_eval_single_sample.py dataset_src=huggingface

**What you might need to modify**
* **`gpu_arch`** - Depending on your GPU, you might need to adjust the `gpu_arch` argument to reflect your hardware.
* **`precision`** - You can specify the tensor precision with `precision=fp32`. Currently all of our reported results are `fp32`, but we have added support for `fp16` & `bf16`.
* **`backend`** - We also support other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support these DSLs: `cuda`, `triton`, `cute`, `tilelang`, `thunderkittens`.

Check the config fields for a comprehensive set of options. Note that we provide the model with a one-shot example by default along with the minimum set of info; you can check out other prompt settings or construct your own in `src/prompt_constructor_toml.py`.
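For example, a single-sample generation and evaluation run with the ThunderKittens backend might look like the following (a sketch; any problem-selection arguments your setup requires are omitted here):

```bash
uv run python scripts/generate_and_eval_single_sample.py \
    dataset_src=huggingface \
    backend=thunderkittens \
    precision=bf16
```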
### Running ThunderKittens Locally
If you plan on running `scripts/generate_and_eval_single_sample.py` with `backend=thunderkittens`, make sure to `git clone` the ThunderKittens repo and set the following environment variable to point to your local ThunderKittens directory:

```bash
export THUNDERKITTENS_ROOT=/Users/willychan/Desktop/projects/KernelBench/ThunderKittens
```
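To sanity-check the path, you can confirm the ThunderKittens headers are visible (a quick check, assuming a standard checkout with headers under `include/`):

```bash
test -f "$THUNDERKITTENS_ROOT/include/kittens.cuh" && echo "ThunderKittens headers found"
```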
As seen in `src/kernelbench/prompts/model_new_ex_add_thunderkittens.py`, the generated kernels should include the following line:

```python
tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")
```
This allows the kernel to include the right TK primitives.

*NOTE*: Right now, all generated ThunderKittens kernels are required to use the BF16 datatype. FP16 support is TBD.
**Collaborator:** should be good for now! Most TK kernels from the TK paper were done in BF16. We will add a code path to restrict it.
### Run on all problems
@@ -70,8 +70,13 @@

```python
        "g++-10",
        "clang"
    )
    .uv_sync(uv_project_dir=REPO_TOP_DIR)
    .env({"PYTHONPATH": "/root/src"})
    .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens")
    .env({
        "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
        "PYTHONPATH": "/root/src:/root"
    })
    .add_local_dir(SRC_DIR, remote_path="/root/src")
    .add_local_dir(KERNELBENCH_DIR, remote_path="/root/KernelBench")  # must be last
)
```
@@ -747,6 +752,15 @@ def main(config: EvalConfig):

```python
    """
    print(f"Starting Batch Eval with config: {config}")

    # Handle backend-specific settings
    backend = config.backend.lower()

    # thunderkittens requires bf16 and H100 GPU
    if backend == "thunderkittens":
        config.precision = "bf16"
        config.gpu = "H100"
```
**Collaborator:** technically Blackwell and H200 too, but we can worry about that later (I will address that in the enforcing H100 vs H200 PR).
```python
        print(f"[ThunderKittens] Auto-configured: precision=bf16, gpu=H100")

    # Check if CUDA is available (only for local mode)
    if config.eval_mode == "local":
        if not torch.cuda.is_available():
```
@@ -0,0 +1,22 @@

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, a, b):
        return a + b


def get_inputs():
    # Use shapes compatible with ThunderKittens 16x16 tiles, bf16 dtype
    a = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
    b = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
    return [a, b]


def get_init_inputs():
    return []
```
@@ -0,0 +1,122 @@

```python
import os
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline

tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")
tk_include_path = os.path.join(tk_root, "include")
tk_prototype_path = os.path.join(tk_root, "prototype")

extra_include_paths = [tk_root, tk_include_path]
if os.path.isdir(tk_prototype_path):
    extra_include_paths.append(tk_prototype_path)

thunderkittens_add_source = """
#include "kittens.cuh"
#include <torch/extension.h>
using namespace kittens;
constexpr int BLOCK_SIZE = 16;
#define NUM_WORKERS (1)
#define NUM_THREADS (NUM_WORKERS * kittens::WARP_THREADS)
struct add_globals {
    using sub_tile = st_bf<BLOCK_SIZE, BLOCK_SIZE>;
    using tile_gl = gl<bf16, 1, 1, -1, -1, sub_tile>;
    tile_gl A;
    tile_gl B;
    tile_gl C;
};
__global__ void add_tk(const __grid_constant__ add_globals g) {
    extern __shared__ alignment_dummy __shm[];
    shared_allocator al((int*)&__shm[0]);
    st_bf<BLOCK_SIZE, BLOCK_SIZE> &As = al.allocate<st_bf<BLOCK_SIZE, BLOCK_SIZE>>();
    st_bf<BLOCK_SIZE, BLOCK_SIZE> &Bs = al.allocate<st_bf<BLOCK_SIZE, BLOCK_SIZE>>();
    rt_bf<BLOCK_SIZE, BLOCK_SIZE> A_reg;
    rt_bf<BLOCK_SIZE, BLOCK_SIZE> B_reg;
    rt_bf<BLOCK_SIZE, BLOCK_SIZE> C_reg;
    int col = blockIdx.x;
    int row = blockIdx.y;
    // Load A and B tiles from global to shared
    kittens::warp::load(As, g.A, {0, 0, row, col});
    kittens::warp::load(Bs, g.B, {0, 0, row, col});
    __syncthreads();
    // Load from shared to register
    kittens::warp::load(A_reg, As);
    kittens::warp::load(B_reg, Bs);
    __syncthreads();
    // Element-wise add: C = A + B
    kittens::warp::add(C_reg, A_reg, B_reg);
    __syncthreads();
    // Store result back to global
    kittens::warp::store(g.C, C_reg, {0, 0, row, col});
}
torch::Tensor thunderkittens_add_cuda(torch::Tensor A, torch::Tensor B) {
    TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
    TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
    TORCH_CHECK(A.dtype() == torch::kBFloat16, "A must be bfloat16");
    TORCH_CHECK(B.dtype() == torch::kBFloat16, "B must be bfloat16");
    int M = A.size(0);
    int N = A.size(1);
    auto C = torch::empty_like(A);
    using tile_gl = add_globals::tile_gl;
    tile_gl a_arg{(bf16*)A.data_ptr(), nullptr, nullptr, (size_t)M, (size_t)N};
    tile_gl b_arg{(bf16*)B.data_ptr(), nullptr, nullptr, (size_t)M, (size_t)N};
    tile_gl c_arg{(bf16*)C.data_ptr(), nullptr, nullptr, (size_t)M, (size_t)N};
    add_globals g{a_arg, b_arg, c_arg};
    dim3 blocks((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
    unsigned long mem_size = 50480;
    cudaFuncSetAttribute(add_tk, cudaFuncAttributeMaxDynamicSharedMemorySize, mem_size);
    add_tk<<<blocks, NUM_THREADS, mem_size>>>(g);
    return C;
}
"""

thunderkittens_add_cpp_source = """
torch::Tensor thunderkittens_add_cuda(torch::Tensor A, torch::Tensor B);
"""

thunderkittens_add = load_inline(
    name="thunderkittens_add",
    cpp_sources=thunderkittens_add_cpp_source,
    cuda_sources=thunderkittens_add_source,
    functions=["thunderkittens_add_cuda"],
    verbose=True,
    extra_include_paths=extra_include_paths,
    extra_cflags=["-std=c++20", "-O3", "-DNDEBUG"],
    extra_ldflags=["-lcuda"],
    extra_cuda_cflags=[
        "-std=c++20",
        "-O3",
        "-DNDEBUG",
        "-arch=sm_90a",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "-DKITTENS_HOPPER",
        "-DKITTENS_BLACKWELL",
        "-diag-suppress=20012",
```
**Collaborator:** that's a smart way of doing it.
| "-Xcompiler", "-fPIC", | ||
| "-U__CUDA_NO_HALF_OPERATORS__", | ||
| "-U__CUDA_NO_HALF_CONVERSIONS__", | ||
| "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", | ||
| "-U__CUDA_NO_HALF2_OPERATORS__", | ||
| ], | ||
| ) | ||
|
|
||
|
|
||
```python
class ModelNew(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.thunderkittens_add = thunderkittens_add

    def forward(self, a, b):
        return self.thunderkittens_add.thunderkittens_add_cuda(a, b)
```
**Comment:** fix to be generic and I will add some comments, etc.