diff --git a/README.md b/README.md
index 1f83bf1c..da701a16 100644
--- a/README.md
+++ b/README.md
@@ -115,10 +115,34 @@ uv run python scripts/generate_and_eval_single_sample.py dataset_src=huggingface
 
 **What you might need to modify**
 * **`gpu_arch`** - Depending on your GPU, you might need to adjust the `gpu_arch` argument to reflect your hardware.
 * **`precision`** - You can specify the tensor precision with `precision=fp32`. Currently all of our reported results are `fp32`, but we also support `fp16` & `bf16`.
-* **`backend`** - We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support DSLs: `cuda`, `triton`, `cute`, `tilelang`.
+* **`backend`** - We also support GPU programming languages beyond `cuda`; simply specify one, e.g. `backend=triton`. For now we support the DSLs `cuda`, `triton`, `cute`, `tilelang`, and `thunderkittens`.
 
 Check the config fields for the full set of options. Note that we provide the model with a one-shot example by default, along with the minimum set of info; you can check out other prompt settings or construct your own in `src/prompt_constructor_toml.py`.
 
+### Running ThunderKittens Locally
+If you plan on running `scripts/generate_and_eval_single_sample.py` with `backend=thunderkittens`, make sure to clone the ThunderKittens repo and set the following environment variable to point at your local ThunderKittens checkout:
+
+```bash
+export THUNDERKITTENS_ROOT=/path/to/ThunderKittens
+```
+
+As seen in `src/kernelbench/prompts/model_new_ex_add_thunderkittens.py`, generated kernels should contain the following line:
+
+```python
+tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")
+```
+
+This allows the kernel to include the right TK primitives.
+
+*NOTE*: Right now, all generated ThunderKittens kernels are required to use the BF16 datatype. FP16 support is TBD.
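+
+For example, a single-sample ThunderKittens run might look like this (the `level` and `problem_id` values below are placeholders; pick any problem):
+
+```bash
+export THUNDERKITTENS_ROOT=/path/to/ThunderKittens
+uv run python scripts/generate_and_eval_single_sample.py dataset_src=huggingface level=1 problem_id=1 backend=thunderkittens precision=bf16
+```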
+
 ### Run on all problems
 ```bash
diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py
index a71c3478..3bd5fb1d 100644
--- a/scripts/eval_from_generations.py
+++ b/scripts/eval_from_generations.py
@@ -70,8 +70,13 @@
         "g++-10",
         "clang"
     )
+    .uv_sync(uv_project_dir=REPO_TOP_DIR)
-    .env({"PYTHONPATH": "/root/src"})
+    .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens")
+    .env({
+        "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
+        "PYTHONPATH": "/root/src:/root"
+    })
     .add_local_dir(SRC_DIR, remote_path="/root/src")
     .add_local_dir(KERNELBENCH_DIR, remote_path="/root/KernelBench")  # must be last
 )
@@ -747,6 +752,15 @@ def main(config: EvalConfig):
     """
     print(f"Starting Batch Eval with config: {config}")
 
+    # Handle backend-specific settings
+    backend = config.backend.lower()
+
+    # ThunderKittens requires bf16 and an H100 GPU
+    if backend == "thunderkittens":
+        config.precision = "bf16"
+        config.gpu = "H100"
+        print("[ThunderKittens] Auto-configured: precision=bf16, gpu=H100")
+
     # Check if CUDA is available (only for local mode)
     if config.eval_mode == "local":
         if not torch.cuda.is_available():
diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py
index 082a0e93..c42ea66a 100644
--- a/scripts/generate_and_eval_single_sample.py
+++ b/scripts/generate_and_eval_single_sample.py
@@ -200,7 +200,7 @@ def main(config: EvalConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
     config.include_hardware_info = include_hardware
 
-    supported_backends = {"cuda", "triton", "tilelang", "cute"}
+    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
@@ -210,6 +210,9 @@
     if backend == "tilelang":
         config.precision = "fp16"  # tilelang only operates with fp16
     config.hardware_gpu_name = config.hardware_gpu_name or getattr(config, "gpu", None)
+
+    if backend == "thunderkittens":
+        config.precision = "bf16"  # ThunderKittens kernels are bf16-only for now
 
     if not custom_prompt_key:
         if prompt_option not in valid_prompt_options:
diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py
index 6b249248..d8dae68f 100644
--- a/scripts/generate_and_eval_single_sample_modal.py
+++ b/scripts/generate_and_eval_single_sample_modal.py
@@ -103,8 +103,13 @@ def __repr__(self):
         "g++-10",
         "clang"  # note i skip a step
     )
+    .uv_sync(uv_project_dir=REPO_TOP_DIR, extras=["gpu"])
-    .env({"PYTHONPATH": "/root/src"})
+    .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens")
+    .env({
+        "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
+        "PYTHONPATH": "/root:/root/src"
+    })
     .add_local_dir(SRC_DIR, remote_path="/root/src")  # must be last
 )
@@ -218,7 +223,7 @@ def main(config: EvalConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
     config.include_hardware_info = include_hardware
 
-    supported_backends = {"cuda", "triton", "tilelang", "cute"}
+    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
@@ -229,6 +234,11 @@
     if backend == "tilelang":
         config.precision = "fp16"
     config.hardware_gpu_name = config.hardware_gpu_name or getattr(config, "gpu", None)
+
+    # ThunderKittens defaults to bf16; also default the GPU to H100
+    if backend == "thunderkittens":
+        config.precision = "bf16"
+        config.gpu = "H100"
 
     if not custom_prompt_key:
         if prompt_option not in valid_prompt_options:
"thunderkittens": + config.precision = "bf16" + config.gpu = "H100" if not custom_prompt_key: if prompt_option not in valid_prompt_options: diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index eb65a210..312a9545 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -239,7 +239,7 @@ def main(config: GenerationConfig): include_hardware = include_hardware.lower() in ["true", "1", "yes"] config.include_hardware_info = include_hardware - supported_backends = {"cuda", "triton", "cute", "tilelang"} + supported_backends = {"cuda", "triton", "cute", "tilelang", "thunderkittens"} backend = config.backend.lower() if backend not in supported_backends: raise ValueError( @@ -248,6 +248,8 @@ def main(config: GenerationConfig): config.backend = backend if backend == "tilelang": config.precision = "fp16" + if backend == "thunderkittens": + config.precision = "bf16" config.prompt_option = str(config.prompt_option).lower() valid_prompt_options = {"zero_shot", "one_shot", "few_shot"} diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py index 0bcd8e37..37ab9732 100644 --- a/scripts/run_and_check.py +++ b/scripts/run_and_check.py @@ -39,7 +39,11 @@ modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10") .apt_install("git", "gcc-10", "g++-10", "clang") .uv_sync(uv_project_dir=REPO_TOP_PATH) - .env({"PYTHONPATH": "/root/src:/root/scripts"}) + .run_commands("git clone -b tk-v2 https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens") + .env({ + "THUNDERKITTENS_ROOT": "/root/ThunderKittens", + "PYTHONPATH": "/root:/root/src:/root/scripts" + }) .add_local_dir(SRC_DIR, remote_path="/root/src") .add_local_dir(SCRIPTS_DIR, remote_path="/root/scripts") .add_local_dir(KERNELBENCH_DIR, remote_path="/root/KernelBench") # must be last diff --git a/src/kernelbench/prompts/model_ex_add_thunderkittens.py b/src/kernelbench/prompts/model_ex_add_thunderkittens.py new file mode 100644 index 00000000..8575537b --- /dev/null +++ b/src/kernelbench/prompts/model_ex_add_thunderkittens.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn + + +class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, a, b): + return a + b + + +def get_inputs(): + # Use shapes compatible with ThunderKittens 16x16 tiles, bf16 dtype + a = torch.randn(128, 128, dtype=torch.bfloat16).cuda() + b = torch.randn(128, 128, dtype=torch.bfloat16).cuda() + return [a, b] + + +def get_init_inputs(): + return [] + diff --git a/src/kernelbench/prompts/model_new_ex_add_thunderkittens.py b/src/kernelbench/prompts/model_new_ex_add_thunderkittens.py new file mode 100644 index 00000000..33cc3ef9 --- /dev/null +++ b/src/kernelbench/prompts/model_new_ex_add_thunderkittens.py @@ -0,0 +1,122 @@ +import os +import torch +import torch.nn as nn +from torch.utils.cpp_extension import load_inline + +tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens") +tk_include_path = os.path.join(tk_root, "include") +tk_prototype_path = os.path.join(tk_root, "prototype") + +extra_include_paths = [tk_root, tk_include_path] +if os.path.isdir(tk_prototype_path): + extra_include_paths.append(tk_prototype_path) + +thunderkittens_add_source = """ +#include "kittens.cuh" +#include + +using namespace kittens; + +constexpr int BLOCK_SIZE = 16; + +#define NUM_WORKERS (1) +#define NUM_THREADS (NUM_WORKERS * kittens::WARP_THREADS) + +struct add_globals { + using sub_tile = st_bf; + using tile_gl = gl; + tile_gl A; + tile_gl B; + tile_gl C; +}; + +__global__ void 
diff --git a/src/kernelbench/prompts/model_new_ex_add_thunderkittens.py b/src/kernelbench/prompts/model_new_ex_add_thunderkittens.py
new file mode 100644
index 00000000..33cc3ef9
--- /dev/null
+++ b/src/kernelbench/prompts/model_new_ex_add_thunderkittens.py
@@ -0,0 +1,132 @@
+import os
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")
+tk_include_path = os.path.join(tk_root, "include")
+tk_prototype_path = os.path.join(tk_root, "prototype")
+
+extra_include_paths = [tk_root, tk_include_path]
+if os.path.isdir(tk_prototype_path):
+    extra_include_paths.append(tk_prototype_path)
+
+thunderkittens_add_source = """
+#include "kittens.cuh"
+#include <torch/extension.h>
+
+using namespace kittens;
+
+constexpr int BLOCK_SIZE = 16;
+
+#define NUM_WORKERS (1)
+#define NUM_THREADS (NUM_WORKERS * kittens::WARP_THREADS)
+
+struct add_globals {
+    using sub_tile = st_bf<BLOCK_SIZE, BLOCK_SIZE>;
+    using tile_gl = gl<bf16, 1, 1, -1, -1, sub_tile>;
+    tile_gl A;
+    tile_gl B;
+    tile_gl C;
+};
+
+__global__ void add_tk(const __grid_constant__ add_globals g) {
+    extern __shared__ alignment_dummy __shm[];
+    shared_allocator al((int*)&__shm[0]);
+    st_bf<BLOCK_SIZE, BLOCK_SIZE> &As = al.allocate<st_bf<BLOCK_SIZE, BLOCK_SIZE>>();
+    st_bf<BLOCK_SIZE, BLOCK_SIZE> &Bs = al.allocate<st_bf<BLOCK_SIZE, BLOCK_SIZE>>();
+    rt_bf<BLOCK_SIZE, BLOCK_SIZE> A_reg;
+    rt_bf<BLOCK_SIZE, BLOCK_SIZE> B_reg;
+    rt_bf<BLOCK_SIZE, BLOCK_SIZE> C_reg;
+    int col = blockIdx.x;
+    int row = blockIdx.y;
+    // Load A and B tiles from global to shared
+    kittens::warp::load(As, g.A, {0, 0, row, col});
+    kittens::warp::load(Bs, g.B, {0, 0, row, col});
+    __syncthreads();
+    // Load from shared to register
+    kittens::warp::load(A_reg, As);
+    kittens::warp::load(B_reg, Bs);
+    __syncthreads();
+    // Element-wise add: C = A + B
+    kittens::warp::add(C_reg, A_reg, B_reg);
+    __syncthreads();
+    // Store result back to global
+    kittens::warp::store(g.C, C_reg, {0, 0, row, col});
+}
+
+torch::Tensor thunderkittens_add_cuda(torch::Tensor A, torch::Tensor B) {
+    TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
+    TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
+    TORCH_CHECK(A.dtype() == torch::kBFloat16, "A must be bfloat16");
+    TORCH_CHECK(B.dtype() == torch::kBFloat16, "B must be bfloat16");
+
+    int M = A.size(0);
+    int N = A.size(1);
+
+    auto C = torch::empty_like(A);
+
+    using tile_gl = add_globals::tile_gl;
+    tile_gl a_arg{(bf16*)A.data_ptr(), nullptr, nullptr, (size_t)M, (size_t)N};
+    tile_gl b_arg{(bf16*)B.data_ptr(), nullptr, nullptr, (size_t)M, (size_t)N};
+    tile_gl c_arg{(bf16*)C.data_ptr(), nullptr, nullptr, (size_t)M, (size_t)N};
+    add_globals g{a_arg, b_arg, c_arg};
+
+    dim3 blocks((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    unsigned long mem_size = 50480;
+    cudaFuncSetAttribute(add_tk, cudaFuncAttributeMaxDynamicSharedMemorySize, mem_size);
+    add_tk<<<blocks, NUM_THREADS, mem_size>>>(g);
+
+    return C;
+}
+"""
+
+thunderkittens_add_cpp_source = """
+torch::Tensor thunderkittens_add_cuda(torch::Tensor A, torch::Tensor B);
+"""
+
+thunderkittens_add = load_inline(
+    name="thunderkittens_add",
+    cpp_sources=thunderkittens_add_cpp_source,
+    cuda_sources=thunderkittens_add_source,
+    functions=["thunderkittens_add_cuda"],
+    verbose=True,
+    extra_include_paths=extra_include_paths,
+    extra_cflags=["-std=c++20", "-O3", "-DNDEBUG"],
+    extra_ldflags=["-lcuda"],
+    extra_cuda_cflags=[
+        "-std=c++20",
+        "-O3",
+        "-DNDEBUG",
+        "-arch=sm_90a",
+        "--expt-relaxed-constexpr",
+        "--expt-extended-lambda",
+        "-DKITTENS_HOPPER",
+        "-DKITTENS_BLACKWELL",
+        "-diag-suppress=20012",
+        "-Xcompiler", "-fPIC",
+        "-U__CUDA_NO_HALF_OPERATORS__",
+        "-U__CUDA_NO_HALF_CONVERSIONS__",
+        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+        "-U__CUDA_NO_HALF2_OPERATORS__",
+    ],
+)
+
+
+class ModelNew(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.thunderkittens_add = thunderkittens_add
+
+    def forward(self, a, b):
+        return self.thunderkittens_add.thunderkittens_add_cuda(a, b)
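+
+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative; assumes an H100-class GPU and a local
+    # ThunderKittens checkout pointed to by THUNDERKITTENS_ROOT):
+    a = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
+    b = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
+    out = ModelNew()(a, b)
+    torch.testing.assert_close(out.float(), (a + b).float(), rtol=1e-2, atol=1e-2)
+    print("ThunderKittens add matches eager PyTorch")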
diff --git a/src/kernelbench/prompts/prompts.toml b/src/kernelbench/prompts/prompts.toml
index acd5f678..2768aa11 100644
--- a/src/kernelbench/prompts/prompts.toml
+++ b/src/kernelbench/prompts/prompts.toml
@@ -49,6 +49,11 @@ backend_display = "TileLang kernels"
 one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_tilelang.py"
 # No few_shot_examples - will use one-shot when few_shot option is selected
 
+[backends.thunderkittens]
+backend_display = "ThunderKittens kernels"
+one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_thunderkittens.py"
+# No few_shot_examples - will use one-shot when few_shot option is selected
+
 # -------------------------------------------------------------------------
 # Precision: Precision-specific configuration
 # -------------------------------------------------------------------------