diff --git a/EVAL.md b/EVAL.md index c6f70fb0..b9880560 100644 --- a/EVAL.md +++ b/EVAL.md @@ -36,6 +36,9 @@ We have (and continue to) implement various approaches to conduct kernel timing Check out `timing.py` to see available timing methods and `src/unit_tests/test_eval_timing.py` to test out various timing methods (including leveraging `cuda_event` marker, Triton `do_bench`, `host_time` E2E time). @palic and team is working on a blogpost explaining the different tradeoffs soon. +### Checkers +There are potentially many ways model might reward hack and we would like to catch the known ways through checkers [experimental and WIP]. We start with `kernel_static_checker.py`, which is a regex-based checker on the genenrated code against set of rules. We plan to add AST-based, LM-as-a-judge, and more runtime checks in the future. We welcome suggestions and contributions here. + ### Unit Tests with Adversarial Examples We've included some unit tests for the eval script in `src/unit_tests/test_eval_adversarial.py`. These tests run adversarial kernels (see `src/unit_tests/test_kernels/`) that contain examples of reward hacking that we've seen from LLMs and ensures that the eval script catches them, either by failing their correctness checks or flagging them for excessive speedups. Examples include: - Reusing computations cached during the PyTorch reference diff --git a/README.md b/README.md index da701a16..3686574a 100644 --- a/README.md +++ b/README.md @@ -117,24 +117,10 @@ uv run python scripts/generate_and_eval_single_sample.py dataset_src=huggingface * **`precision`** - You can specify the precision of tensor by `precision=fp32`. Currently all of our reported results are `fp32` but we added support for `fp16` & `bf16`. * **`backend`** - We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support DSLs: `cuda`, `triton`, `cute`, `tilelang`, `thunderkittens`. -Check the config fields for comprehensive set of options. Note we provide the model with a one-shot example by default along with the minimum set of info; you can check out other prompt settings or construct your own in `src/prompt_constructor_toml.py`. - -### Running Thunderkittens Locally -If you plan on using `scripts/generate_and_eval_single_sample.py` using `backend=thunderkittens`, make sure to git clone the ThunderKittens repo and you set the following environment variable to point to your local ThunderKittens directory: - -```bash -export THUNDERKITTENS_ROOT=/Users/willychan/Desktop/projects/KernelBench/ThunderKittens -``` -As seen in `src/kernelbench/prompts/model_new_ex_add_thunderkittens.py`, the generated kernels should have the following line: +Note on setting up ThunderKittens (TK) locally: to use `backend=thunderkittens`, you need to git clone the ThunderKittens repo and set the following environment variable to point to your local ThunderKittens directory, `export THUNDERKITTENS_ROOT=`, and all ThunderKitten programs as shown in the [example](src/kernelbench/prompts/model_new_ex_add_thunderkittens.py), should contain `tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")`, which enable the kernel to include the right TK primitives. In addition, we only support BF16 for TK right now. -```bash -tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens") -``` - -This allows the kernel to include the right TK primitives. - -*NOTE*: Right now, all generated ThunderKittens kernels are required to be in datatype format BF16. FP16 support is TBD. +Check the config fields for comprehensive set of options. Note we provide the model with a one-shot example by default along with the minimum set of info; you can check out other prompt settings or construct your own in `src/prompt_constructor_toml.py`. ### Run on all problems diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index c42ea66a..0b571792 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -83,6 +83,8 @@ def __init__(self): self.hardware_gpu_name = None self.custom_prompt_key = None + self.check_kernel = True # [experimental] optional static checker catching potential hacking patterns + def verbose_logging(self): self.log = True self.log_prompt = True @@ -260,6 +262,19 @@ def main(config: EvalConfig): custom_kernel is not None ), f"Custom {config.backend} kernel code generation failed" + # Optional: static code checker for kernel code using regex matching + # NOTE: by no means is this checker complete, but it could help catch some potential hacks + if config.check_kernel: + from kernelbench.kernel_static_checker import validate_kernel_static + static_check_status, errors, warnings = validate_kernel_static( + custom_kernel, + backend=config.backend, + precision=config.precision, + ) + assert static_check_status, f"Static check failed for level {config.level} problem {config.problem_id}. Errors: {errors}. Warnings: {warnings}" + if warnings: + print(f"Static check warnings for level {config.level} problem {config.problem_id}: {warnings}") + # this should be optional if config.log: with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f: diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index d8dae68f..2a8e6aba 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -80,6 +80,8 @@ def __init__(self): self.hardware_gpu_name = None self.custom_prompt_key = None + self.check_kernel = True # [experimental] optional static checker catching potential hacking patterns + def verbose_logging(self): self.log = True self.log_prompt = True @@ -283,6 +285,19 @@ def main(config: EvalConfig): # check LLM is able to generate custom kernel code assert custom_kernel is not None, f"Custom {config.backend} kernel code generation failed" + # Optional: static code checker for kernel code using regex matching + # NOTE: by no means is this checker complete, but it could help catch some potential hacks + if config.check_kernel: + from kernelbench.kernel_static_checker import validate_kernel_static + static_check_status, errors, warnings = validate_kernel_static( + custom_kernel, + backend=config.backend, + precision=config.precision, + ) + assert static_check_status, f"Static check failed for level {config.level} problem {config.problem_id}. Errors: {errors}. Warnings: {warnings}" + if warnings: + print(f"Static check warnings for level {config.level} problem {config.problem_id}: {warnings}") + # this should be optional if config.log: with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f: diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index 312a9545..40d718f8 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -18,6 +18,7 @@ read_file, set_gpu_arch, ) +from kernelbench.kernel_static_checker import validate_kernel_static """ Batch Generate Samples for Particular Level @@ -84,6 +85,8 @@ def __init__(self): self.hardware_gpu_name = None self.custom_prompt_key = None + self.check_kernel = True # [experimental] optional static checker catching potential hacking patterns + def greedy(self): # For greedy decoding, epsecially baseline eval self.greedy_sample = True @@ -162,6 +165,19 @@ def generate_sample_single( # check LLM is able to generate custom CUDA code assert custom_kernel is not None, "Custom CUDA code generation failed" + # Optional: we provide a static code checker for kernel code using regex matching + # NOTE: by no means, is this checker complete, but it might could help catch some potential hacks and issues + if config.check_kernel: + static_check_status, error, warnings = validate_kernel_static(custom_kernel, + backend=config.backend, + precision=config.precision, + # uses the default set of forbidden and warning patterns, + # you could adapt the patterns to your own setting (degree of banning cuda stream, allowing some torch ops) + ) + assert static_check_status, f"Static check failed for sample {work.sample_id} for problem {problem_number}: {problem_name}. Error: {error}. Warnings: {warnings}" + if warnings: + print(f"Static check warnings for sample {work.sample_id} for problem {problem_number}: {problem_name}. Warnings: {warnings}") + if config.verbose: print( f"Generated sample {work.sample_id} for problem {problem_number}: {problem_name}" diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py index 37ab9732..8d150b7c 100644 --- a/scripts/run_and_check.py +++ b/scripts/run_and_check.py @@ -10,6 +10,7 @@ from kernelbench import utils as kernel_utils from scripts.generate_baseline_time import measure_program_time from kernelbench.utils import read_file +from kernelbench.kernel_static_checker import validate_kernel_static # Modal setup app = modal.App("run_and_check") @@ -120,6 +121,8 @@ def __init__(self): self.precision = "fp32" self.backend = "cuda" + self.check_kernel = True # [experimental] optional static checker catching potential hacking patterns + def __repr__(self): return f"ScriptConfig({self.to_dict()})" @@ -279,6 +282,18 @@ def main(config: ScriptConfig): kernel_src = read_file(config.kernel_src_path) + # Optional: static code checker for kernel code using regex matching + # NOTE: by no means is this checker complete, but it could help catch some potential hacks + if config.check_kernel: + static_check_status, errors, warnings = validate_kernel_static( + kernel_src, + backend=config.backend, + precision=config.precision, + ) + assert static_check_status, f"Static check failed. Errors: {errors}. Warnings: {warnings}" + if warnings: + print(f"[WARN] Static check warnings: {warnings}") + # Start Evaluation assert config.eval_mode in ["local", "modal"], "eval_mode must be either 'local' or 'modal'" diff --git a/src/kernelbench/kernel_static_checker.py b/src/kernelbench/kernel_static_checker.py new file mode 100644 index 00000000..c8832a1a --- /dev/null +++ b/src/kernelbench/kernel_static_checker.py @@ -0,0 +1,659 @@ +""" +Kernel Static Checker - Pattern-based validation for GPU kernel code. + +The goal is flag reward hacking patterns (both strictly prohibited and possible ones). +through statically examining the code. + +In the future we can add +- AST-based detections +- LM as a judge checker + +Warning: This list is by no means complete and nor this is not a replacement for runtime checks. +We welcome feedback and contributions as community find new ways of hacks. + +- Bypass hacks (PyTorch wrapping, try-except fallback, inheritance bypass) +- Disallow some high-level torch operations (depends on the settings) +- Backend implementation requirements, that CUDA or DSL features must be used + +Usage: + result = validate_kernel_static(code, backend="cuda") + will return a tuple (valid, errors, warnings) +""" + +import re +from typing import List, Tuple, Dict, Any, Optional, Callable, Union + +def _strip_comments(code: str) -> str: + """Remove # and // comments from code.""" + lines = [] + for line in code.split('\n'): + if '#' in line: + line = line[:line.index('#')] + if '//' in line: + line = line[:line.index('//')] + lines.append(line) + return '\n'.join(lines) + + +# ============================================================================= +# BYPASS CHECKS - Strictly Prohibited +# some of this is from Kevin RL Paper (arxiv:2507.11948) +# ============================================================================= + +# --- Try-Except Fallback --- +# Rationale: Models wrap incomplete CUDA in exception handlers that fall back to PyTorch. +# This allows them to pass tests without actually implementing the kernel. +TRY_EXCEPT_PATTERNS = [r"\btry\s*:", r"\bexcept\s*:", r"\bexcept\s+\w+"] + +# --- Pass Statement / Inheritance Bypass --- +# Rationale: Model inherits from reference class and uses 'pass' to do nothing, +# effectively just calling the parent implementation. +PASS_PATTERN = r"\bpass\b" + +def check_code_bypass(code: str) -> Tuple[bool, str]: + """ + Check for code bypass patterns (strictly prohibited). + 1. Try-Except Fallback: Models wrap incomplete CUDA in exception handlers + that fall back to PyTorch when custom code fails. + 2. Pass Statement: Models inherit from reference and use 'pass' to do nothing, + effectively calling parent implementation. + Uses word boundary for 'pass' to avoid matching 'passed', 'bypass', etc. + """ + code = _strip_comments(code) + + # Check for try-except fallback + for pattern in TRY_EXCEPT_PATTERNS: + if re.search(pattern, code): + return (True, "Contains try-except block (potential fallback bypass)") + + # Check for pass statement + if re.search(PASS_PATTERN, code): + return (True, "Contains 'pass' statement (inheritance bypass)") + + return (False, "") + +# Since KernelBench problems uses PyTorch as a reference, there could be settigs where +# Model generated code +# 1. Replaces some (not all) ops with custom kernels, others are kept in Torch +# --> More practical from a performance perspective (ie. make better systems) as you want to use whatever makes the best system for your use case. +# 2. All compuational ops must be replaced with custom kernels +# --> Could be helpful from an eval (model ability on transpile + optimization) / RL training perspective +# Depends the setting you use, you can move the checks below (pytorch_wrap, torch_computation_ops) +# from WARNING to STRICT + +# --- PyTorch NN Module Wrapping --- +# Allows: nn.Module, nn.Parameter, nn.ParameterList, nn.ParameterDict, +# nn.ModuleList, nn.ModuleDict, nn.init (needed for model structure) +# Blocks: nn.Linear, nn.Conv2d, nn.ReLU, etc. (compute layers) +PYTORCH_DISALLOWED_NN_PATTERN = r'torch\.nn\.(?!(Module|parameter|Parameter|ParameterList|ParameterDict|ModuleList|ModuleDict|init)\b)' + +def check_pytorch_wrap(code: str) -> Tuple[bool, str]: + """ + Check for PyTorch nn module usage (nn.Linear, nn.Conv2d, etc.). + + Allows containers (nn.Module, nn.Parameter, nn.init) needed for model structure. + Blocks compute layers (nn.Linear, nn.Conv2d, nn.ReLU, etc.). + """ + code = _strip_comments(code) + if re.search(PYTORCH_DISALLOWED_NN_PATTERN, code): + return (True, "Uses torch.nn compute layer (only containers, Parameter, init allowed)") + return (False, "") + + +# --- Torch Computation Operations --- +# Rationale: These are high-level PyTorch ops that conduct computation. +# Using them directly defeats the purpose of writing custom kernels. +# Includes both torch.* and F.* (torch.nn.functional) patterns. +TORCH_COMPUTATION_OPS = [ + # Matrix operations + "torch.mm", "torch.bmm", "torch.matmul", "torch.einsum", + # Convolutions + "torch.conv1d", "torch.conv2d", "torch.conv3d", "torch.conv", + "torch.conv_transpose1d", "torch.conv_transpose2d", "torch.conv_transpose3d", + # Pooling + "torch.avg_pool1d", "torch.avg_pool2d", "torch.avg_pool3d", + "torch.max_pool1d", "torch.max_pool2d", "torch.max_pool3d", + "torch.adaptive_avg_pool1d", "torch.adaptive_avg_pool2d", "torch.adaptive_avg_pool3d", + "torch.adaptive_max_pool1d", "torch.adaptive_max_pool2d", "torch.adaptive_max_pool3d", + # Activations + "torch.relu", "torch.hardtanh", "torch.elu", "torch.selu", + "torch.leaky_relu", "torch.gelu", "torch.softsign", "torch.softplus", + "torch.softmax", "torch.log_softmax", "torch.tanh", "torch.sigmoid", + "torch.hardsigmoid", "torch.silu", "torch.mish", + # Normalization + "torch.batch_norm", "torch.group_norm", "torch.layer_norm", + "torch.instance_norm", "torch.rms_norm", "torch.normalize", + # Linear & Loss + "torch.linear", "torch.cross_entropy", "torch.kl_div", "torch.mse_loss", + "torch.huber_loss", "torch.triplet_margin_loss", "torch.cosine_similarity", + # Others + "torch.logsumexp", "torch.clamp", "torch.dropout", +] + +# F.* patterns (torch.nn.functional equivalents) +TORCH_FUNCTIONAL_PATTERNS = [ + r"torch\.nn\.functional\.\w+", # torch.nn.functional.* + r"\bnn\.functional\.\w+", # nn.functional.* + r"\bF\.(conv|linear|relu|gelu|softmax|batch_norm|layer_norm|dropout|max_pool|avg_pool)", +] + +def check_torch_computation_ops(code: str) -> Tuple[bool, str]: + """ + Check for high-level torch computation operations. + + Matches both torch.* ops (torch.matmul) and F.* ops (F.relu). + This check is optional/taste-based. Configure as needed. + """ + code = _strip_comments(code) + + # Check torch.* ops + torch_pattern = r'\b(' + '|'.join(re.escape(f) for f in TORCH_COMPUTATION_OPS) + r')(?=\s*\(|\s|$)' + match = re.search(torch_pattern, code) + if match: + return (True, f"Uses torch computation op: {match.group(0)}") + + # Check F.* / nn.functional ops + for pattern in TORCH_FUNCTIONAL_PATTERNS: + match = re.search(pattern, code) + if match: + return (True, f"Uses torch.nn.functional op: {match.group(0)}") + + return (False, "") + +# ============================================================================= +# Backend Specific Checks +# ============================================================================= + +# <========= CUDA CHECKS =========> +# Rationale: Valid CUDA kernels must have __global__ (kernel definition) and +# use load_inline or cpp_extension (PyTorch's inline compilation). +CUDA_COMPILE_PATTERNS = ["load_inline", "cpp_extension"] + +def check_cuda_impl(code: str) -> Tuple[bool, str]: + """ + Check for valid CUDA kernel implementation. + + Requirements: + - Must have __global__ void kernel_name (kernel definition) + - Must have load_inline or cpp_extension (PyTorch inline compilation) + """ + code = _strip_comments(code) + if "__global__" not in code: + return (True, "Missing __global__ kernel definition") + if not any(p in code for p in CUDA_COMPILE_PATTERNS): + return (True, "Missing load_inline or cpp_extension for compilation") + return (False, "") + +# <========= TRITON CHECKS =========> +# Rationale: Triton kernels are compiled from @triton.jit decorated functions. +# They must use tl.* operations (tl.load, tl.store, etc.) for actual kernel work. +TRITON_JIT_PATTERN = r"@triton\.(jit|autotune)" +TRITON_OPS_PATTERN = r"\btl\.\w+" + +def check_triton_impl(code: str) -> Tuple[bool, str]: + """ + Check for valid Triton kernel implementation. + + Requirements: + - Must have @triton.jit or @triton.autotune decorator + - Must have tl.* operations (enforces actual Triton code, not wrapper) + + Note: Triton's compiler itself prevents PyTorch ops inside @triton.jit. + """ + code = _strip_comments(code) + if not re.search(TRITON_JIT_PATTERN, code): + return (True, "Missing @triton.jit or @triton.autotune") + if not re.search(TRITON_OPS_PATTERN, code): + return (True, "No tl.* operations found in Triton kernel") + return (False, "") + + +# <========= THUNDERKITTENS CHECKS =========> +# Rationale: ThunderKittens uses warp/warpgroup primitives and tile abstractions. +# Valid TK code must have namespace patterns and tile declarations. +TK_WARP_PATTERNS = [ + r"kittens::warp\b", r"kittens::warpgroup\b", + r"::warpgroup::", r"::warp::", r"warpgroup::", r"warp::" +] +TK_TILE_PATTERN = r"(?:kittens::)?(?:st|rt)_\w+\s*<[^>]+>" + +def check_tk_impl(code: str) -> Tuple[bool, str]: + """ + Check for valid ThunderKittens kernel implementation. + + Requirements: + - Must have warp/warpgroup namespace patterns (kittens::warp, etc.) + - Must have tile declarations (st_bf<...>, rt_fl<...>, etc.) + + TODO: Add producer-consumer pattern check for complex kernels. + """ + code = _strip_comments(code) + if not any(re.search(p, code) for p in TK_WARP_PATTERNS): + return (True, "Missing ThunderKittens warp/warpgroup patterns") + if not re.search(TK_TILE_PATTERN, code): + return (True, "Missing ThunderKittens tile declarations (st_*/rt_*)") + return (False, "") + + +# <========= CUTE/CUTLASS CHECKS =========> +# CUTLASS uses cute:: namespace for tensor operations +# Check: https://github.com/NVIDIA/cutlass +CUTE_PATTERNS = [ + r"cute::", # cute:: namespace (CuTe library) + r"cutlass::", # cutlass:: namespace + r"from cutlass", # Python CUTLASS bindings +] + +def check_cute_impl(code: str) -> Tuple[bool, str]: + """Check for valid CUTLASS/CuTe kernel implementation.""" + code = _strip_comments(code) + if not any(p in code for p in ["cute::", "cutlass::", "from cutlass"]): + return (True, "Missing cute:: or cutlass:: namespace") + return (False, "") + + +# <========= TILELANG CHECKS =========> +# TileLang uses TVM's T.prim_func decorator +# https://github.com/tile-ai/tilelang +TILELANG_PATTERNS = [ + r"@T\.prim_func", # TVM primitive function decorator + r"tvm\.build", # TVM build call + r"T\.grid", # TileLang grid +] + +def check_tilelang_impl(code: str) -> Tuple[bool, str]: + """Check for valid TileLang kernel implementation.""" + code = _strip_comments(code) + if not re.search(r"@T\.prim_func", code): + return (True, "Missing @T.prim_func decorator") + return (False, "") + + +# ============================================================================= +# TIMING MANIPULATION CHECKS - Reward Hacking Patterns +# From adversarial hack PR and DeepReinforce blog +# ============================================================================= + +# <========= STREAM INJECTION CHECKS =========> +# Rationale: Models may use CUDA streams to defer computation, manipulate timing, +# or execute operations asynchronously to cheat benchmark measurements. +STREAM_PATTERNS = [ + r"torch\.cuda\.Stream\s*\(", # torch.cuda.Stream() + r"cuda\.Stream\s*\(", # cuda.Stream() (if imported) + r"with\s+torch\.cuda\.stream", # with torch.cuda.stream(...) + r"\.wait_stream\s*\(", # .wait_stream() method + r"\.record_stream\s*\(", # .record_stream() method +] + +def check_stream_injection(code: str) -> Tuple[bool, str]: + """ + Check for CUDA stream injection patterns. + + Detects code that uses CUDA streams to potentially manipulate timing: + 1. Stream creation: torch.cuda.Stream(), cuda.Stream() + 2. Stream context managers: with torch.cuda.stream(...) + 3. Stream synchronization: .wait_stream(), .record_stream() + + Rationale: Streams can defer computation or change execution order, + potentially affecting benchmark timing measurements. + """ + code = _strip_comments(code) + + for pattern in STREAM_PATTERNS: + if re.search(pattern, code): + if "wait_stream" in pattern or "record_stream" in pattern: + return (True, "Uses stream synchronization (potential timing manipulation)") + elif "with" in pattern: + return (True, "Uses stream context manager (potential timing manipulation)") + else: + return (True, "Uses CUDA streams (potential timing manipulation)") + + return (False, "") + + +# <========= THREAD INJECTION CHECKS =========> +# Rationale: Models may use threading to parallelize work or manipulate execution +# order in ways that could affect benchmark timing. +THREAD_PATTERNS = [ + r"threading\.Thread\s*\(", # threading.Thread() + r"import\s+threading", # import threading + r"from\s+threading\s+import", # from threading import ... + r"multiprocessing\.(Process|Pool|Manager|Queue|Pipe)", + r"import\s+multiprocessing", # import multiprocessing + r"concurrent\.futures", # concurrent.futures (thread pools) + r"ThreadPoolExecutor", # ThreadPoolExecutor + r"ProcessPoolExecutor", # ProcessPoolExecutor +] + +def check_thread_injection(code: str) -> Tuple[bool, str]: + """ + Check for thread/multiprocessing injection patterns. + + Detects code that uses threading or multiprocessing: + 1. Thread creation: threading.Thread() + 2. Threading imports: import threading + 3. Multiprocessing: Process, Pool, Manager, Queue, Pipe + 4. Concurrent futures: ThreadPoolExecutor, ProcessPoolExecutor + + Rationale: Threading can defer computation or change execution order, + affecting timing measurements. + + Note: ProcessPoolExecutor might be legitimate in eval code but should + not appear in kernel implementations. + """ + code = _strip_comments(code) + + for pattern in THREAD_PATTERNS: + if re.search(pattern, code): + if "multiprocessing" in pattern: + return (True, "Uses multiprocessing (potential timing manipulation)") + elif "concurrent" in pattern or "Executor" in pattern: + return (True, "Uses concurrent futures (potential timing manipulation)") + else: + return (True, "Uses threading (potential timing manipulation)") + + return (False, "") + + +# <========= LAZY EVALUATION CHECKS =========> +# Rationale: Models may create fake/lazy tensors that don't actually compute +# anything, passing correctness checks without real implementation. +LAZY_TENSOR_PATTERNS = [ + r"_make_subclass", # torch.Tensor._make_subclass (common lazy hack) + r"class\s+\w+.*\(torch\.Tensor\)", # Custom tensor subclasses + r"class\s+\w+.*\(Tensor\)", # Custom tensor subclasses (imported Tensor) + r"torch\.Tensor\.__new__", # Direct tensor construction (potential lazy) +] + +def check_lazy_eval(code: str) -> Tuple[bool, str]: + """ + Check for lazy tensor creation patterns. + + Detects patterns commonly used to create lazy/fake tensors: + 1. _make_subclass: Common way to create custom tensor subclasses + 2. Custom tensor subclasses: Classes inheriting from torch.Tensor + 3. Direct tensor construction: torch.Tensor.__new__ manipulation + + Rationale: Lazy tensors can pass correctness checks without actually + computing anything, which is a form of reward hacking. + """ + code = _strip_comments(code) + + for pattern in LAZY_TENSOR_PATTERNS: + if re.search(pattern, code): + if "_make_subclass" in pattern: + return (True, "Uses _make_subclass (potential lazy tensor hack)") + elif "class" in pattern: + return (True, "Defines custom tensor subclass (potential lazy tensor hack)") + else: + return (True, "Uses direct tensor construction (potential lazy tensor hack)") + + return (False, "") + + +# <========= Timing Monkey Patch CHECKS =========> +# Rationale: Models may monkey-patch torch timing functions to fake benchmark results. +# This detects static patterns where timing functions are reassigned. +# especially when relying on timing markers like cuda.Event or time.perf_counter. +TIMING_EVENT_PATCH_PATTERNS = [ + r"torch\.cuda\.Event\.record\s*=", # torch.cuda.Event.record = ... + r"torch\.cuda\.Event\.elapsed_time\s*=", # torch.cuda.Event.elapsed_time = ... + r"torch\.cuda\.synchronize\s*=", # torch.cuda.synchronize = ... + r"torch\.cuda\.Event\s*=", # torch.cuda.Event = ... (replacing class) + r"time\.perf_counter\s*=", # time.perf_counter = ... + r"time\.time\s*=", # time.time = ... +] + +def check_timing_event_patch(code: str) -> Tuple[bool, str]: + """ + Check for monkey patching of timing functions (static analysis). + + Detects code that reassigns torch or time functions used for benchmarking: + 1. torch.cuda.Event.record = ... + 2. torch.cuda.Event.elapsed_time = ... + 3. torch.cuda.synchronize = ... + 4. time.perf_counter = ... + + Rationale: Monkey patching timing functions can fake benchmark results, + making slow code appear fast. + """ + code = _strip_comments(code) + + for pattern in TIMING_EVENT_PATCH_PATTERNS: + if re.search(pattern, code): + return (True, "Reassigns timing function (monkey patch detected)") + + return (False, "") + + +# <========= PRECISION DOWNGRADE CHECKS =========> +# Rationale: Models may downgrade precision internally (FP32 -> FP16) +# to make kernels faster while still returning FP32 output. This is a reward +# hacking pattern where models cheat on timing while appearing correct. +# This might be a valid optimization, but is still worth flagging as a warning. +# +# We use explicit, high-confidence patterns that indicate intentional precision +# downgrading. These patterns have minimal false positives and clear semantic intent. + +# Specific patterns that indicate FP32 -> FP16 precision downgrading +FP32_TO_FP16_PATTERNS = [ + # ========== CUDA / CUDA C++ ========== + # 1.1 Explicit float -> half intrinsics (⭐ gold standard) + # __float2half(f), __float2half_rn(f) + r"__float2half(_rn)?\s*\(", + + # 1.2 Explicit C-style cast to __half + # (__half)f + r"\(\s*__half\s*\)\s*[\w\->\.]+", + + # 1.3 static_cast / static_cast<__half> + # static_cast(f), static_cast<__half>(f) + r"static_cast\s*<\s*(__half|half)\s*>\s*\(", + + # ========== Triton (Python) ========== + # 2.1 Explicit tl.astype(..., tl.float16) (⭐ best signal) + # tl.astype(x, tl.float16) + r"tl\.astype\s*\(\s*[^,]+,\s*tl\.float16\s*\)", + + # ========== CUTLASS ========== + # 3.1 NumericConverter float -> half (⭐ extremely reliable) + # NumericConverter + r"NumericConverter\s*<\s*half_t\s*,\s*float\s*>", + + # 3.2 Epilogue output type = FP16 + # LinearCombination + r"LinearCombination\s*<\s*half_t", + + # ========== CUTE (when used directly) ========== + # 4.1 Explicit type conversion utility + # type_convert(x) + r"type_convert\s*<\s*half_t\s*>\s*\(", + + # ========== ThunderKittens ========== + # 5.1 Explicit half types in cast contexts + # tk::half(x) + r"tk::half\s*\(", + + # ========== Additional gaping patterns ========== + # PyTorch explicit casts (common in inline CUDA contexts) + r"\.half\s*\(", # .half() -> FP16 + r"\.float16\s*\(", # .float16() -> FP16 + r"\.to\s*\(\s*torch\.(float16|half)\b", # .to(torch.float16) + r"\.to\s*\(\s*dtype\s*=\s*torch\.(float16|half)\b", # .to(dtype=torch.float16) +] + +def check_precision_downgrade(code: str, precision: str = "fp32") -> Tuple[bool, str]: + """ + Check for FP32 -> FP16 precision downgrading patterns in kernel code. + + Detects code that downgrades precision from FP32 to FP16 internally, + which can make kernels faster but with reduced precision. This is a reward + hacking pattern where models cheat on timing while appearing correct. + + Currently only checks for FP32 -> FP16 downgrades. Other precision combinations + (e.g., FP16 -> BF16) are not flagged. + + Args: + code: Kernel source code + precision: Required precision - only checks when "fp32" or "float32" + + Returns: + (True, error_message) if FP32 -> FP16 downgrade detected + (False, "") if no downgrade detected + + Examples of detected patterns: + - .half(), .float16() + - .to(torch.float16), .to(torch.half) + - dtype=torch.float16 + - __half, half2 (CUDA) + - tl.float16 (Triton) + """ + code = _strip_comments(code) + precision = precision.lower() + + # Normalize precision to standard form + precision_map = {"fp32": "fp32", "float32": "fp32", "fp16": "fp16", "bf16": "bf16", "bfloat16": "bf16"} + precision = precision_map.get(precision, precision) + + # Only check for FP32 -> FP16 downgrades + if precision != "fp32": + return (False, "") + + # Check for FP16 patterns + for pattern in FP32_TO_FP16_PATTERNS: + if re.search(pattern, code): + return (True, "Precision downgrade detected: required FP32 but code uses FP16") + + return (False, "") + +# ============================================================================= +# In the future, we can add a AST-based checker and a LM-as-a-judge checker +# ============================================================================= + + +# ============================================================================= +# REGISTRY & PRESETS +# ============================================================================= + +# Check functions can take either (code) or (code, precision) arguments +# Most checks take only code, but precision-dependent checks take both +CHECK_FUNCTIONS: Dict[str, Union[Callable[[str], Tuple[bool, str]], Callable[[str, str], Tuple[bool, str]]]] = { + # Bypass checks (strict) + "code_bypass": check_code_bypass, + "pytorch_wrap": check_pytorch_wrap, + "timing_event_patch": check_timing_event_patch, # clearly malicious + + # Torch ops (depends on your setups) + "torch_computation_ops": check_torch_computation_ops, + + # Timing manipulation checks (usually warnings) + "stream_injection": check_stream_injection, + "thread_injection": check_thread_injection, + "lazy_eval": check_lazy_eval, + "precision_downgrade": check_precision_downgrade, # precision-dependent + + # Backend-specific implementation checks + # should be strict + "cuda_impl": check_cuda_impl, + "triton_impl": check_triton_impl, + "tk_impl": check_tk_impl, + "cute_impl": check_cute_impl, + "tilelang_impl": check_tilelang_impl, +} + +# Checks that require additional parameters beyond just code +PRECISION_DEPENDENT_CHECKS = {"precision_downgrade"} + +# Here are some presets for you to use +# You are welcome to adapt them to your settings +# These checks are NECESSARY for all kernels (strict = error) +STRICT_CHECKS = [ + "code_bypass", + "timing_event_patch", + "thread_injection", + "lazy_eval", +] + +# Backend-specific checks are added later at entry point +# per backend implementation check, usually strict +BACKEND_IMPL_CHECK = { + "cuda": "cuda_impl", + "triton": "triton_impl", + "thunderkittens": "tk_impl", + "cute": "cute_impl", + "cutlass": "cute_impl", # alias + "tilelang": "tilelang_impl", +} + +# These are optional checks (by user's decision) - flagged as warnings +# Move to STRICT_CHECKS if you want to enforce them +WARNING_CHECKS: List[str] = [ + # up to user to allow program to still have some torch computation ops + "pytorch_wrap", + "torch_computation_ops", + "stream_injection", # could have legitimate uses (async ops), but should be careful! + "precision_downgrade", # precision downgrading - can be intentional but often a hack +] + + +# ============================================================================= +# MAIN ENTRY POINT +# ============================================================================= + +def validate_kernel_static( + code: str, + backend: str = "cuda", + precision: str = "fp16", + forbidden: Optional[List[str]] = None, + warnings: Optional[List[str]] = None, +) -> Tuple[bool, List[str], List[str]]: + """ + Validate kernel code through statically inspecting the code + We configure the checks against check groups that we have provided for common hacks. + Note we do not guarantee that all checks are exhaustive. This is also only on the static level. + + Args: + code: Kernel source code + backend: "cuda", "triton", or "thunderkittens" + precision: "fp16", "fp32", or "bf16" (for future precision checks) + forbidden: Check categories that cause errors (default: STRICT_CHECKS) + warnings: Check categories that cause warnings (default: WARNING_CHECKS) + + Returns: + (valid, errors, warnings) + valid: bool + errors: List[str] + warnings: List[str] + """ + # Copy defaults to avoid mutating global lists + forbidden_checks = list(forbidden) if forbidden is not None else list(STRICT_CHECKS) + warning_checks = list(warnings) if warnings is not None else list(WARNING_CHECKS) + + # Add backend implementation check if specified + if backend in BACKEND_IMPL_CHECK: + impl_check = BACKEND_IMPL_CHECK[backend] + if impl_check not in forbidden_checks: + forbidden_checks.append(impl_check) + + # Aggregate results + errors: List[str] = [] + warnings_list: List[str] = [] + + for check_name in set(forbidden_checks + warning_checks): + if check_name not in CHECK_FUNCTIONS: + continue + + # Handle precision-dependent checks + if check_name in PRECISION_DEPENDENT_CHECKS: + has_issue, msg = CHECK_FUNCTIONS[check_name](code, precision) + else: + has_issue, msg = CHECK_FUNCTIONS[check_name](code) + + if has_issue: + if check_name in forbidden_checks: + errors.append(msg) + else: + warnings_list.append(msg) + + valid = len(errors) == 0 # valid if no errors + return valid, errors, warnings_list diff --git a/src/kernelbench/unit_tests/test_precision.py b/src/kernelbench/unit_tests/test_precision.py new file mode 100644 index 00000000..1e8a2299 --- /dev/null +++ b/src/kernelbench/unit_tests/test_precision.py @@ -0,0 +1,395 @@ +""" +Unit tests for precision downgrade detection in FP32 kernels. + +Tests the check_precision_downgrade function to ensure it correctly detects +various patterns that indicate precision downgrading (FP32 -> FP16), +which is a reward hacking pattern. + +Note: The current implementation only checks for FP32 -> FP16 downgrades, +not BF16 downgrades. + +Run with pytest: + pytest src/kernelbench/unit_tests/test_precision.py -v + or + uv run pytest src/kernelbench/unit_tests/test_precision.py -v +""" + +import os +import sys +import pytest + +# Add src directory to path for imports (consistent with other test files) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +from kernelbench.kernel_static_checker import check_precision_downgrade + + +# ============================================================================ +# Test Cases for FP32 -> FP16 Precision Downgrades (PyTorch patterns) +# ============================================================================ + +def test_fp32_half_method(): + """Test detection of .half() method call downgrading FP32 to FP16.""" + code = """ + def forward(self, x): + x = x.half() + return x * 2 + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect .half() method call" + assert "FP16" in message + + +def test_fp32_float16_method(): + """Test detection of .float16() method call.""" + code = """ + x = input_tensor.float16() + return x + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect .float16() method call" + assert "FP16" in message + + +def test_fp32_to_torch_half(): + """Test detection of .to(torch.half) pattern.""" + code = """ + def forward(self, x): + x = x.to(torch.half) + return x + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect .to(torch.half)" + assert "FP16" in message + + +def test_fp32_to_torch_float16(): + """Test detection of .to(torch.float16) pattern.""" + code = """ + x = x.to(torch.float16) + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect .to(torch.float16)" + assert "FP16" in message + + +def test_fp32_to_dtype_float16(): + """Test detection of .to(dtype=torch.float16) pattern.""" + code = """ + def forward(self, x): + return x.to(dtype=torch.float16) + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect .to(dtype=torch.float16)" + assert "FP16" in message + + +def test_fp32_to_dtype_half(): + """Test detection of .to(dtype=torch.half) pattern.""" + code = """ + x.to(dtype=torch.half) + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect .to(dtype=torch.half)" + assert "FP16" in message + + +# ============================================================================ +# Test Cases for FP32 -> FP16 Precision Downgrades (CUDA patterns) +# ============================================================================ + +def test_fp32_cuda_float2half(): + """Test detection of CUDA __float2half intrinsic.""" + code = """ + __global__ void kernel(float* input, __half* output) { + output[0] = __float2half(input[0]); + } + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect CUDA __float2half" + assert "FP16" in message + + +def test_fp32_cuda_float2half_rn(): + """Test detection of CUDA __float2half_rn intrinsic.""" + code = """ + __half result = __float2half_rn(value); + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect CUDA __float2half_rn" + assert "FP16" in message + + +def test_fp32_cuda_cast_to_half(): + """Test detection of CUDA C-style cast to __half.""" + code = """ + __half h = (__half)float_value; + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect CUDA cast to __half" + assert "FP16" in message + + +def test_fp32_cuda_static_cast_half(): + """Test detection of CUDA static_cast.""" + code = """ + __half h = static_cast<__half>(float_value); + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect CUDA static_cast<__half>" + assert "FP16" in message + + +# ============================================================================ +# Test Cases for FP32 -> FP16 Precision Downgrades (Triton patterns) +# ============================================================================ + +def test_fp32_triton_astype_float16(): + """Test detection of Triton tl.astype(..., tl.float16).""" + code = """ + @triton.jit + def kernel(X, Y): + x = tl.load(X) + x_fp16 = tl.astype(x, tl.float16) + tl.store(Y, x_fp16) + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect Triton tl.astype(..., tl.float16)" + assert "FP16" in message + + +# ============================================================================ +# Test Cases for FP32 -> FP16 Precision Downgrades (CUTLASS patterns) +# ============================================================================ + +def test_fp32_cutlass_numeric_converter(): + """Test detection of CUTLASS NumericConverter.""" + code = """ + NumericConverter converter; + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect CUTLASS NumericConverter" + assert "FP16" in message + + +def test_fp32_cutlass_linear_combination(): + """Test detection of CUTLASS LinearCombination.""" + code = """ + LinearCombination epilogue; + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect CUTLASS LinearCombination" + assert "FP16" in message + + +def test_fp32_cute_type_convert(): + """Test detection of CuTe type_convert.""" + code = """ + auto result = type_convert(input); + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect CuTe type_convert" + assert "FP16" in message + + +# ============================================================================ +# Test Cases for FP32 -> FP16 Precision Downgrades (ThunderKittens patterns) +# ============================================================================ + +def test_fp32_thunderkittens_half(): + """Test detection of ThunderKittens tk::half().""" + code = """ + using namespace tk; + auto h = tk::half(value); + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect ThunderKittens tk::half()" + assert "FP16" in message + + +# ============================================================================ +# Test Cases for Valid FP32 Code (Should NOT be detected) +# ============================================================================ + +def test_fp32_legitimate_code_no_downgrade(): + """Test that legitimate FP32 code is not flagged.""" + code = """ + def forward(self, x): + # Legitimate FP32 operations + x = x * 2.0 + y = torch.matmul(x, x.t()) + return y.float() # This is fine, float() returns float32 + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is False, "Should not detect downgrade in legitimate FP32 code" + assert message == "" + + +def test_fp32_with_comments_mentioning_half(): + """Test that comments mentioning half precision don't trigger false positives.""" + code = """ + def forward(self, x): + # Note: This should use FP32, not FP16 + # Don't use .half() here! + x = x * 2.0 + return x + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is False, "Comments should not trigger detection" + + +def test_fp32_string_literals(): + """Test that string literals mentioning precision don't trigger false positives.""" + code = """ + def forward(self, x): + error_msg = "dtype=torch.float16 is not allowed" + print(error_msg) + return x + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is False, "String literals should not trigger detection" + + +def test_fp32_function_names_containing_half(): + """Test that function names containing 'half' don't trigger false positives.""" + code = """ + def compute_half_tensor(self, x): + # Function name contains 'half' but doesn't downgrade precision + return x * 0.5 + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is False, "Function names should not trigger detection" + + +# ============================================================================ +# Test Cases for Edge Cases +# ============================================================================ + +def test_fp32_case_insensitive_precision(): + """Test that precision parameter is case-insensitive.""" + code = """ + x = x.half() + """ + detected1, _ = check_precision_downgrade(code, precision="FP32") + detected2, _ = check_precision_downgrade(code, precision="fp32") + detected3, _ = check_precision_downgrade(code, precision="Fp32") + assert detected1 == detected2 == detected3 == True, "Should handle case-insensitive precision" + + +def test_fp32_alternative_precision_names(): + """Test that alternative precision names are normalized correctly.""" + code = """ + x = x.half() + """ + detected1, _ = check_precision_downgrade(code, precision="float32") + detected2, _ = check_precision_downgrade(code, precision="fp32") + assert detected1 == detected2 == True, "Should normalize float32 to fp32" + + +def test_fp32_non_fp32_precision_skips_check(): + """Test that non-FP32 precision skips the check (implementation only checks FP32).""" + code = """ + x = x.half() + """ + detected, message = check_precision_downgrade(code, precision="fp16") + assert detected is False, "Should skip check for non-FP32 precision" + assert message == "" + + +def test_fp32_unknown_precision_skips_check(): + """Test that unknown precision skips the check.""" + code = """ + x = x.half() + """ + detected, message = check_precision_downgrade(code, precision="int8") + assert detected is False, "Should skip check for unknown precision" + assert message == "" + + +def test_fp32_empty_code(): + """Test that empty code doesn't trigger detection.""" + code = "" + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is False, "Empty code should not trigger detection" + + +def test_fp32_whitespace_only(): + """Test that whitespace-only code doesn't trigger detection.""" + code = " \n\n\t \n " + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is False, "Whitespace-only code should not trigger detection" + + +# ============================================================================ +# Test Cases for Real-World Scenarios +# ============================================================================ + +def test_fp32_matmul_with_hidden_downgrade(): + """Test detection in a realistic matmul kernel that downgrades internally.""" + code = """ + @triton.jit + def matmul_kernel(A, B, C, M, N, K): + # Compute matrix multiplication + a = tl.load(A) + b = tl.load(B) + + # Sneaky precision downgrade + a = tl.astype(a, tl.float16) + b = tl.astype(b, tl.float16) + + c = tl.dot(a, b) + tl.store(C, c) + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect hidden precision downgrade in matmul" + assert "FP16" in message + + +def test_fp32_cuda_kernel_with_half(): + """Test detection in a realistic CUDA kernel.""" + code = """ + __global__ void add_kernel(float* a, float* b, float* c, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + __half a_half = __float2half(a[idx]); + __half b_half = __float2half(b[idx]); + c[idx] = __half2float(a_half + b_half); + } + } + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect CUDA kernel using __float2half" + assert "FP16" in message + + +def test_fp32_complex_code_with_downgrade(): + """Test detection in complex code with multiple operations.""" + code = """ + def forward(self, x, y): + # Some preprocessing + x = x * 2.0 + y = y + 1.0 + + # Main computation with downgrade + x = x.to(dtype=torch.float16) + result = torch.matmul(x, y) + + # Post-processing + result = result * 3.0 + return result + """ + detected, message = check_precision_downgrade(code, precision="fp32") + assert detected is True, "Should detect downgrade in complex code" + assert "FP16" in message + + +# ============================================================================ +# Note on BF16 Tests +# ============================================================================ +# The current implementation only checks for FP32 -> FP16 downgrades. +# BF16 downgrade detection is not yet implemented. These tests document +# expected behavior when BF16 support is added in the future. + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/src/kernelbench/unit_tests/test_static_checker.py b/src/kernelbench/unit_tests/test_static_checker.py new file mode 100644 index 00000000..d4411fff --- /dev/null +++ b/src/kernelbench/unit_tests/test_static_checker.py @@ -0,0 +1,166 @@ +""" +Tests for kernel_static_checker.py + +Validates that the static checker correctly identifies: +- Valid DSL kernels (no false positives) +- Known adversarial/hack patterns (no false negatives) + +We welcome contributions to improve the static checker and providing adversarial kernels. + + +Run with: uv run pytest src/kernelbench/unit_tests/test_static_checker.py -v +""" + +import pytest +from pathlib import Path +from kernelbench.kernel_static_checker import validate_kernel_static + + +# ============================================================================= +# Fixtures - Common paths and helpers +# ============================================================================= + +@pytest.fixture +def prompts_dir(): + """Path to DSL example kernels.""" + return Path(__file__).parent.parent / "prompts" + +@pytest.fixture +def test_kernels_dir(): + """Path to adversarial test kernels.""" + return Path(__file__).parent / "test_kernels" + + +def read_kernel(path: Path) -> str: + """Read kernel code from file, skip if not found.""" + if not path.exists(): + pytest.skip(f"Kernel file not found: {path}") + return path.read_text() + + +# ============================================================================= +# Valid DSL Kernels - Should Pass (No False Positives) +# These are real, correct kernels from src/kernelbench/prompts/ +# ============================================================================= + +def test_cuda_example_valid(prompts_dir): + """Real CUDA kernel example should pass with default settings.""" + code = read_kernel(prompts_dir / "model_new_ex_add.py") + valid, errors, warnings = validate_kernel_static(code, backend="cuda") + # May have warnings (F import), but should be valid + assert valid or "import F" in str(warnings), f"CUDA example should pass: {errors}" + + +def test_triton_example_valid(prompts_dir): + """Real Triton kernel example should pass.""" + code = read_kernel(prompts_dir / "model_new_ex_add_triton.py") + valid, errors, warnings = validate_kernel_static(code, backend="triton") + assert valid or len(warnings) > 0, f"Triton example should pass: {errors}" + + +def test_cute_example_valid(prompts_dir): + """Real CuTe/CUTLASS kernel example should pass.""" + code = read_kernel(prompts_dir / "model_new_ex_add_cute.py") + valid, errors, warnings = validate_kernel_static(code, backend="cute") + assert valid or len(warnings) > 0, f"CuTe example should pass: {errors}" + + +def test_tilelang_example_valid(prompts_dir): + """Real TileLang kernel example should pass.""" + code = read_kernel(prompts_dir / "model_new_ex_add_tilelang.py") + valid, errors, warnings = validate_kernel_static(code, backend="tilelang") + assert valid or len(warnings) > 0, f"TileLang example should pass: {errors}" + + +# ============================================================================= +# Adversarial Kernels - Should Detect Issues +# These are known hack patterns from test_kernels/ +# ============================================================================= + +def test_stream_kernel_flagged(test_kernels_dir): + """Non-default stream kernel should trigger stream_injection warning.""" + code = read_kernel(test_kernels_dir / "non_default_stream_kernel.py") + valid, errors, warnings = validate_kernel_static(code, backend="cuda") + # Stream injection is in warnings by default + all_messages = errors + warnings + has_stream_warning = any("stream" in msg.lower() for msg in all_messages) + # Note: The CUDA code is in a string literal, so static checker may not catch it + # This test documents the limitation + + +def test_result_reuse_kernel_flagged(test_kernels_dir): + """Result reuse (empty tensor) kernel - static checker can't catch this.""" + code = read_kernel(test_kernels_dir / "result_reuse_kernel.py") + valid, errors, warnings = validate_kernel_static(code, backend="cuda") + # This is a runtime check, static checker won't catch it + # Just verify it doesn't crash + + +def test_zero_out_kernel_flagged(test_kernels_dir): + """Zero-out kernel - static checker can't catch this.""" + code = read_kernel(test_kernels_dir / "zero_out_kernel.py") + valid, errors, warnings = validate_kernel_static(code, backend="cuda") + # This is a correctness issue, not detectable statically + + +# ============================================================================= +# Hack Patterns - Synthetic Examples +# ============================================================================= + +def test_bypass_try_except(): + """Try-except fallback should be flagged as error.""" + code = """ +try: + result = custom_kernel(x) +except: + result = torch.matmul(x, w) # Fallback to torch +""" + valid, errors, warnings = validate_kernel_static(code, backend="cuda") + assert not valid, "Try-except should be flagged" + assert any("try-except" in e.lower() for e in errors) + + +def test_bypass_pass_statement(): + """Pass statement (inheritance bypass) should be flagged.""" + code = """ +class ModelNew(Model): + def forward(self, x): + pass # Does nothing, inherits parent +""" + valid, errors, warnings = validate_kernel_static(code, backend="cuda") + assert not valid, "Pass statement should be flagged" + assert any("pass" in e.lower() for e in errors) + + +def test_lazy_eval_make_subclass(): + """_make_subclass (lazy tensor hack) should be flagged.""" + code = """ +fake_tensor = torch.Tensor._make_subclass(FakeTensor, real_tensor) +""" + valid, errors, warnings = validate_kernel_static(code, backend="cuda") + assert not valid, "_make_subclass should be flagged" + + +def test_timing_monkey_patch(): + """Monkey patching timing functions should be flagged.""" + code = """ +# Override timing to fake benchmarks +torch.cuda.Event.elapsed_time = lambda self, end: 0.001 +""" + valid, errors, warnings = validate_kernel_static(code, backend="cuda") + assert not valid, "Timing monkey patch should be flagged" + + +def test_thread_injection(): + """Threading in kernel code should be flagged.""" + code = """ +import threading +t = threading.Thread(target=background_work) +t.start() +""" + valid, errors, warnings = validate_kernel_static(code, backend="cuda") + assert not valid, "Threading should be flagged" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/src/kernelbench/unit_tests/test_validate_kernel_static.py b/src/kernelbench/unit_tests/test_validate_kernel_static.py new file mode 100644 index 00000000..55d97a2c --- /dev/null +++ b/src/kernelbench/unit_tests/test_validate_kernel_static.py @@ -0,0 +1,423 @@ +""" +Unit tests for validate_kernel_static function. + +Tests the main entry point function to ensure it correctly: +- Passes precision to precision-dependent checks +- Categorizes errors vs warnings correctly +- Handles backend-specific checks +- Respects forbidden/warnings parameters +- Returns correct output format + +Run with pytest: + pytest src/kernelbench/unit_tests/test_validate_kernel_static.py -v + or + uv run pytest src/kernelbench/unit_tests/test_validate_kernel_static.py -v +""" + +import os +import sys +import pytest + +# Add src directory to path for imports (consistent with other test files) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +from kernelbench.kernel_static_checker import validate_kernel_static + + +# ============================================================================ +# Test Basic Function Signature and Return Values +# ============================================================================ + +def test_validate_kernel_static_returns_tuple(): + """Test that validate_kernel_static returns a tuple of (valid, errors, warnings).""" + code = "x = 1 + 1" + result = validate_kernel_static(code) + + assert isinstance(result, tuple), "Should return a tuple" + assert len(result) == 3, "Should return (valid, errors, warnings)" + valid, errors, warnings = result + assert isinstance(valid, bool), "First element should be bool" + assert isinstance(errors, list), "Second element should be list" + assert isinstance(warnings, list), "Third element should be list" + + +def test_validate_kernel_static_defaults(): + """Test that validate_kernel_static works with default parameters.""" + code = "x = 1 + 1" + valid, errors, warnings = validate_kernel_static(code) + + # Should work without errors for simple valid code + assert isinstance(valid, bool) + assert isinstance(errors, list) + assert isinstance(warnings, list) + + +# ============================================================================ +# Test Precision Parameter Passing +# ============================================================================ + +def test_precision_passed_to_precision_checker_fp32(): + """Test that precision parameter is correctly passed to precision-dependent checks.""" + # Code with FP32 -> FP16 downgrade + code = """ + def forward(self, x): + x = x.half() # Precision downgrade + return x + """ + + # With fp32 precision, should detect downgrade (as warning by default) + valid, errors, warnings = validate_kernel_static(code, precision="fp32") + + # Check that precision downgrade was detected (should be in warnings by default) + all_messages = errors + warnings + has_precision_warning = any("precision" in msg.lower() or "fp16" in msg.lower() + for msg in all_messages) + assert has_precision_warning, "Should detect precision downgrade with fp32 precision" + + +def test_precision_passed_to_precision_checker_fp16(): + """Test that fp16 precision skips FP32 -> FP16 downgrade check.""" + # Code with FP32 -> FP16 downgrade + code = """ + def forward(self, x): + x = x.half() # Precision downgrade + return x + """ + + # With fp16 precision, precision downgrade check should be skipped + valid, errors, warnings = validate_kernel_static(code, precision="fp16") + + # Should not detect precision downgrade (check is skipped for non-FP32) + all_messages = errors + warnings + has_precision_warning = any("precision" in msg.lower() or "fp16" in msg.lower() + for msg in all_messages) + # This is expected - the check only runs for fp32 + # So for fp16, it won't flag this + + +def test_precision_case_insensitive(): + """Test that precision parameter is case-insensitive.""" + code = """ + def forward(self, x): + x = x.half() + return x + """ + + # Test different case variations + result1 = validate_kernel_static(code, precision="FP32") + result2 = validate_kernel_static(code, precision="fp32") + result3 = validate_kernel_static(code, precision="Fp32") + + # All should produce the same result + assert result1 == result2 == result3, "Precision should be case-insensitive" + + +def test_precision_alternative_names(): + """Test that alternative precision names are normalized.""" + code = """ + def forward(self, x): + x = x.half() + return x + """ + + # float32 should be normalized to fp32 + result1 = validate_kernel_static(code, precision="float32") + result2 = validate_kernel_static(code, precision="fp32") + + assert result1 == result2, "float32 should be normalized to fp32" + + +# ============================================================================ +# Test Error vs Warning Categorization +# ============================================================================ + +def test_strict_checks_are_errors(): + """Test that strict checks (like code_bypass) produce errors.""" + code = """ + try: + result = custom_kernel(x) + except: + result = torch.matmul(x, w) # Fallback to torch + """ + + valid, errors, warnings = validate_kernel_static(code) + + assert not valid, "Code with strict violations should be invalid" + assert len(errors) > 0, "Strict checks should produce errors, not warnings" + assert any("try-except" in e.lower() or "bypass" in e.lower() + for e in errors), "Should flag bypass in errors" + + +def test_warning_checks_are_warnings(): + """Test that warning checks produce warnings, not errors.""" + code = """ + def forward(self, x): + x = x.half() # Precision downgrade - in warnings by default + return x + """ + + # Test with default settings - precision_downgrade should be in warnings + valid, errors, warnings = validate_kernel_static( + code, + precision="fp32" + # Using defaults - precision_downgrade is in WARNING_CHECKS + ) + + # Check that precision downgrade message is in warnings (if detected) + # Note: backend impl checks might add errors, but precision should be in warnings + precision_warnings = [w for w in warnings if "precision" in w.lower() or "fp16" in w.lower()] + precision_errors = [e for e in errors if "precision" in e.lower() or "fp16" in e.lower()] + + if precision_warnings or precision_errors: + # If precision downgrade is detected, it should be in warnings, not errors + assert len(precision_warnings) > 0, "Precision downgrade should be in warnings (default)" + assert len(precision_errors) == 0, "Precision downgrade should not be in errors (default)" + + +def test_custom_forbidden_checks(): + """Test that custom forbidden checks produce errors.""" + code = """ + def forward(self, x): + x = x.half() # Precision downgrade + return x + """ + + # Make precision_downgrade a forbidden check (error) instead of warning + valid, errors, warnings = validate_kernel_static( + code, + precision="fp32", + forbidden=["precision_downgrade"] + ) + + assert not valid, "Should be invalid when precision_downgrade is forbidden" + assert len(errors) > 0, "Forbidden checks should produce errors" + assert any("precision" in e.lower() or "fp16" in e.lower() + for e in errors), "Should flag precision downgrade in errors" + + +def test_custom_warnings_list(): + """Test that custom warnings list works.""" + code = """ + try: + result = custom_kernel(x) + except: + result = torch.matmul(x, w) + """ + + # Move code_bypass to warnings instead of errors + # Use a backend that won't add strict impl checks + valid, errors, warnings = validate_kernel_static( + code, + backend="cuda", # Explicit backend + forbidden=[], # No forbidden checks + warnings=["code_bypass"] # Make bypass a warning + ) + + # Note: Backend might add impl checks, so we check that code_bypass + # appears in warnings (not errors) if it's detected + all_messages = errors + warnings + bypass_messages = [msg for msg in all_messages if "bypass" in msg.lower() or "try-except" in msg.lower()] + + if bypass_messages: + # If bypass is detected, it should be in warnings, not errors + bypass_in_warnings = any(msg in warnings for msg in bypass_messages) + assert bypass_in_warnings, "Bypass should be in warnings when specified as warning" + + +# ============================================================================ +# Test Backend Parameter Handling +# ============================================================================ + +def test_backend_adds_impl_check(): + """Test that backend parameter adds appropriate implementation check.""" + code = """ + # This code doesn't have CUDA implementation + def forward(self, x): + return x * 2 + """ + + valid, errors, warnings = validate_kernel_static(code, backend="cuda") + + # Should check for CUDA implementation (cuda_impl check) + # The exact behavior depends on what cuda_impl check does, + # but we can verify the backend parameter is processed + assert isinstance(valid, bool) + assert isinstance(errors, list) + assert isinstance(warnings, list) + + +def test_different_backends(): + """Test that different backends are handled correctly.""" + code = """ + def forward(self, x): + return x * 2 + """ + + # Test multiple backends + backends = ["cuda", "triton", "thunderkittens", "cute", "tilelang"] + + for backend in backends: + valid, errors, warnings = validate_kernel_static(code, backend=backend) + assert isinstance(valid, bool) + assert isinstance(errors, list) + assert isinstance(warnings, list) + + +# ============================================================================ +# Test Edge Cases +# ============================================================================ + +def test_empty_code(): + """Test handling of empty code.""" + code = "" + + valid, errors, warnings = validate_kernel_static(code) + + assert isinstance(valid, bool) + assert isinstance(errors, list) + assert isinstance(warnings, list) + + +def test_whitespace_only_code(): + """Test handling of whitespace-only code.""" + code = " \n\n\t \n " + + valid, errors, warnings = validate_kernel_static(code) + + assert isinstance(valid, bool) + assert isinstance(errors, list) + assert isinstance(warnings, list) + + +def test_unknown_check_name(): + """Test that unknown check names are ignored.""" + code = "x = 1" + + # Should not crash with unknown check names + valid, errors, warnings = validate_kernel_static( + code, + forbidden=["unknown_check_that_doesnt_exist"], + warnings=["another_unknown_check"] + ) + + assert isinstance(valid, bool) + assert isinstance(errors, list) + assert isinstance(warnings, list) + + +def test_multiple_precision_dependent_checks(): + """Test that multiple precision-dependent checks work (if any exist in future).""" + code = """ + def forward(self, x): + x = x.half() + return x + """ + + # Currently only precision_downgrade is precision-dependent + valid, errors, warnings = validate_kernel_static(code, precision="fp32") + + assert isinstance(valid, bool) + assert isinstance(errors, list) + assert isinstance(warnings, list) + + +# ============================================================================ +# Test Integration: Precision + Backend + Custom Checks +# ============================================================================ + +def test_integration_precision_backend_forbidden(): + """Test integration of precision, backend, and custom forbidden checks.""" + code = """ + def forward(self, x): + x = x.half() # Precision downgrade + return x + """ + + valid, errors, warnings = validate_kernel_static( + code, + backend="cuda", + precision="fp32", + forbidden=["precision_downgrade"] + ) + + assert not valid, "Should be invalid with precision downgrade as forbidden" + assert len(errors) > 0, "Should have errors" + assert any("precision" in e.lower() or "fp16" in e.lower() + for e in errors), "Should flag precision downgrade" + + +def test_integration_all_parameters(): + """Test with all parameters specified.""" + code = """ + def forward(self, x): + return x * 2.0 + """ + + valid, errors, warnings = validate_kernel_static( + code, + backend="triton", + precision="fp16", + forbidden=["code_bypass"], + warnings=["precision_downgrade"] + ) + + assert isinstance(valid, bool) + assert isinstance(errors, list) + assert isinstance(warnings, list) + + +# ============================================================================ +# Test Precision Check Integration +# ============================================================================ + +def test_precision_check_in_warnings_by_default(): + """Test that precision_downgrade is in warnings by default.""" + code = """ + def forward(self, x): + x = x.half() + return x + """ + + valid, errors, warnings = validate_kernel_static(code, precision="fp32") + + # precision_downgrade should be in WARNING_CHECKS by default + # So it should produce warnings, not errors + all_messages = errors + warnings + has_precision_msg = any("precision" in msg.lower() or "fp16" in msg.lower() + for msg in all_messages) + + if has_precision_msg: + # If detected, should be in warnings, not errors (by default) + precision_in_warnings = any("precision" in msg.lower() or "fp16" in msg.lower() + for msg in warnings) + assert precision_in_warnings, "Precision downgrade should be in warnings by default" + + +def test_precision_check_respects_forbidden(): + """Test that precision_downgrade respects forbidden parameter.""" + code = """ + def forward(self, x): + x = x.half() + return x + """ + + # Make precision_downgrade forbidden + valid, errors, warnings = validate_kernel_static( + code, + precision="fp32", + forbidden=["precision_downgrade"], + warnings=[] # Remove from warnings + ) + + # Should produce errors, not warnings + has_precision_in_errors = any("precision" in msg.lower() or "fp16" in msg.lower() + for msg in errors) + + if has_precision_in_errors: + assert not valid, "Should be invalid when precision downgrade is forbidden" + assert len(errors) > 0, "Should have errors" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) +