diff --git a/.env.example b/.env.example index 2150e6d9..b31c503f 100644 --- a/.env.example +++ b/.env.example @@ -9,7 +9,7 @@ OPENAI_API_KEY=sk-... ANTHROPIC_API_KEY=sk-ant-api03-... # Google Gemini -GEMINI_API_KEY=... +GEMINI_API_KEY= # DeepSeek DEEPSEEK_API_KEY=sk-... diff --git a/scripts/benchmark_eval_analysis.py b/scripts/benchmark_eval_analysis.py index ec950b03..4df4e30b 100644 --- a/scripts/benchmark_eval_analysis.py +++ b/scripts/benchmark_eval_analysis.py @@ -53,7 +53,7 @@ def patch(eval_results, dataset): """ Patch the eval results with the dataset """ - for pid in range(1, len(dataset) + 1): + for pid in dataset.get_problem_ids(): if str(pid) not in eval_results: eval_results[str(pid)] = { "sample_id": 0, @@ -161,19 +161,40 @@ def analyze_greedy_eval(run_name, hardware, baseline, level, ) # Extract the speedup values - is_correct = np.array([entry["correctness"] for entry in eval_results.values()]) - baseline_speed = np.array( - [entry["mean"] for entry in baseline_results[f"level{level}"].values()] - ) - actual_speed = np.array([entry["runtime"] for entry in eval_results.values()]) + is_correct_list = [] + baseline_speed_list = [] + actual_speed_list = [] + + # Sort problem IDs to ensure consistent order + sorted_pids = sorted(dataset.get_problem_ids()) + + for pid in sorted_pids: + # Get eval result + if str(pid) not in eval_results: + print(f"Warning: Problem {pid} not found in eval results") + continue + eval_entry = eval_results[str(pid)] + + # Get baseline result + problem = dataset.get_problem_by_id(pid) + problem_name = problem.name + + if problem_name not in baseline_results[f"level{level}"]: + print(f"Warning: Problem {problem_name} not found in baseline results") + continue + + baseline_entry = baseline_results[f"level{level}"][problem_name] + + is_correct_list.append(eval_entry["correctness"]) + actual_speed_list.append(eval_entry["runtime"]) + baseline_speed_list.append(baseline_entry["mean"]) + + is_correct = np.array(is_correct_list) + baseline_speed = np.array(baseline_speed_list) + actual_speed = np.array(actual_speed_list) n = len(is_correct) - assert ( - len(baseline_speed) == n - ), "Baseline speedup values do not match the number of eval results" - assert ( - len(actual_speed) == n - ), "Actual speedup values do not match the number of eval results" + print(f"Aligned {n} problems for analysis") # Calculate the metrics gmsr_correct = geometric_mean_speed_ratio_correct_only( diff --git a/scripts/eval_from_generations.py b/scripts/eval_from_generations.py index 9fd8d745..247410f3 100644 --- a/scripts/eval_from_generations.py +++ b/scripts/eval_from_generations.py @@ -12,7 +12,6 @@ import pydra import torch -from datasets import load_dataset from pydra import Config, REQUIRED # Import only what we need @@ -255,36 +254,17 @@ def evaluate_single_sample_modal( def fetch_ref_arch_from_problem_id( - dataset, problem_id: int, dataset_src: str + dataset, problem_id: int, dataset_src: str = None ) -> str | None: """ - Fetch reference architecture from problem directory - Either from Hugging Face or Local Dataset + Fetch reference architecture from problem directory. + Uses the unified dataset interface. + + Note: dataset_src parameter is kept for backward compatibility but ignored + since the dataset object already handles both sources. 
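    A minimal usage sketch, assuming a local KernelBench checkout and the unified
    dataset factory from kernelbench.dataset:

        dataset = construct_kernelbench_dataset(level=1, source="local")
        ref_arch_src = fetch_ref_arch_from_problem_id(dataset, problem_id=1)
        # ref_arch_src is the problem's source code string (problem.code)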
""" - if dataset_src == "huggingface": - curr_problem_row = dataset.filter( - lambda x: x["problem_id"] == problem_id, num_proc=None, desc=None - ) - ref_arch_src = curr_problem_row["code"][0] - problem_name = curr_problem_row["name"][0] - - elif dataset_src == "local": - problem_idx_in_dataset = ( - problem_id - 1 - ) # due to dataset list being 0-indexed locally - ref_arch_path = dataset[problem_idx_in_dataset] - - problem_name = os.path.basename(ref_arch_path) - ref_arch_src = read_file(ref_arch_path) - - # verify - # Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py") - problem_number = int(problem_name.split("_")[0]) - assert ( - problem_number == problem_id - ), f"Problem number in filename ({problem_number}) does not match config problem_id ({problem_id})" - - return ref_arch_src + problem = dataset.get_problem_by_id(problem_id) + return problem.code def fetch_kernel_from_disk( @@ -822,57 +802,48 @@ def main(config: EvalConfig): if mp.get_start_method(allow_none=True) is None: mp.set_start_method("spawn") - # Dataset Configurations - if config.dataset_src == "huggingface": - dataset = load_dataset(config.dataset_name) - curr_level_dataset = dataset[f"level_{config.level}"] - elif config.dataset_src == "local": - curr_level_dataset = construct_kernelbench_dataset(config.level) - - num_problems_in_level = len(curr_level_dataset) - - # Determine which problem IDs to evaluate - # you can either specify a list of problem IDs (prioritize) or a subset range - # NOTE: later once the dataset PR is in we will link the representative subset as a built-in preset too - if config.problem_ids is not None: - # Use specific problem IDs if provided - problem_id_list = config.problem_ids - for pid in problem_id_list: - assert 1 <= pid <= num_problems_in_level, f"Problem ID {pid} out of range for Level {config.level}" - elif config.subset == (None, None): - problem_id_list = list(range(1, num_problems_in_level + 1)) + # Dataset Configurations - Unified loading + dataset = construct_kernelbench_dataset( + level=config.level, + source=config.dataset_src, + dataset_name=config.dataset_name, + ) + + all_problem_ids = dataset.get_problem_ids() + + if config.subset == (None, None): + problem_ids_to_run = all_problem_ids else: - assert ( - config.subset[0] >= 1 and config.subset[1] <= num_problems_in_level - ), f"Subset range {config.subset} out of range for Level {config.level}" - problem_id_list = list(range(config.subset[0], config.subset[1] + 1)) + start, end = config.subset + problem_ids_to_run = [pid for pid in all_problem_ids if start <= pid <= end] + if not problem_ids_to_run: + print(f"Warning: No problems found in subset range {config.subset}") print( - f"Evaluating {config.num_samples_per_problem} sample(s) each for level {config.level} problems: {problem_id_list}" + f"Evaluating {config.num_samples_per_problem} sample(s) each for level {config.level} problems: {problem_ids_to_run}" ) run_dir = os.path.join(config.runs_dir, config.run_name) eval_file_path = os.path.join(run_dir, f"eval_results.json") # To Debug - # single_eval_example(config, curr_level_dataset, run_dir, eval_file_path) + # single_eval_example(config, dataset, run_dir, eval_file_path) total_work = [] - for problem_id in problem_id_list: + for problem_id in problem_ids_to_run: for sample_id in range(config.num_samples_per_problem): if not check_if_eval_exists_local(problem_id, sample_id, eval_file_path): total_work.append((problem_id, sample_id)) print( f"Start evaluation on 
{len(total_work)} unevaluated samples" - f" for problems: {problem_id_list}" + f" in range: {problem_ids_to_run}" ) # Build Cache on CPU as that is faster (only for local mode) if config.build_cache and config.eval_mode == "local": compile.batch_compile(total_work, config.to_dict()) - # Batch Eval on multiple GPUs in parallel - batch_eval(total_work, config, curr_level_dataset, run_dir, eval_file_path) + batch_eval(total_work, config, dataset, run_dir, eval_file_path) # Calculate pass@k metrics if multiple samples per problem were evaluated if config.num_samples_per_problem > 1: diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py index c42ea66a..5308c26e 100644 --- a/scripts/generate_and_eval_single_sample.py +++ b/scripts/generate_and_eval_single_sample.py @@ -5,16 +5,12 @@ import json import modal -from datasets import load_dataset - -from kernelbench.dataset import construct_kernelbench_dataset from kernelbench.eval import eval_kernel_against_ref from kernelbench.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt from kernelbench.utils import ( create_inference_server_from_presets, extract_first_code, query_server, - read_file, set_gpu_arch, ) from kernelbench.eval import get_torch_dtype_from_string @@ -116,13 +112,14 @@ def main(config: EvalConfig): print(f"Starting Eval with config: {config}") - # Configurations - - if config.dataset_src == "huggingface": - dataset = load_dataset(config.dataset_name) - curr_level_dataset = dataset[f"level_{config.level}"] - elif config.dataset_src == "local": - curr_level_dataset = construct_kernelbench_dataset(config.level) + # Configurations - Unified dataset loading (works for both HF and local) + from kernelbench.dataset import construct_kernelbench_dataset + + dataset = construct_kernelbench_dataset( + level=config.level, + source=config.dataset_src, + dataset_name=config.dataset_name, + ) if config.gpu_arch: set_gpu_arch(config.gpu_arch) # otherwise build for all architectures @@ -131,41 +128,16 @@ def main(config: EvalConfig): os.makedirs(config.logdir, exist_ok=True) # Problem Checks - num_problems = len(curr_level_dataset) + num_problems = len(dataset) print(f"Number of problems in Level {config.level}: {num_problems}") print( f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}" ) - assert ( - config.problem_id <= num_problems - ), f"Problem ID {config.problem_id} out of range for Level {config.level}" - - # TODO: refactor dataset fetching logic to be as clean as posisble. - # 1. Fetch Problem - if config.dataset_src == "huggingface": - - curr_problem_row = curr_level_dataset.filter( - lambda x: x["problem_id"] == config.problem_id - ) - ref_arch_src = curr_problem_row["code"][0] - problem_name = curr_problem_row["name"][0] - - elif config.dataset_src == "local": - problem_idx_in_dataset = ( - config.problem_id - 1 - ) # due to dataset list being 0-indexed locally - ref_arch_path = curr_level_dataset[problem_idx_in_dataset] - - problem_name = os.path.basename(ref_arch_path) - ref_arch_src = read_file(ref_arch_path) - # import pdb; pdb.set_trace() - - # Extract problem number from problem name (e.g. 
"1" from "1_Square_matrix_multiplication_.py") - problem_number = int(problem_name.split("_")[0]) - assert ( - problem_number == config.problem_id - ), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" + # Fetch problem - unified interface, no branching needed + problem = dataset.get_problem_by_id(config.problem_id) + ref_arch_src = problem.code + problem_name = problem.name # 2. Generate Sample # Create inference function with config parameters diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py index d8dae68f..4b62533c 100644 --- a/scripts/generate_and_eval_single_sample_modal.py +++ b/scripts/generate_and_eval_single_sample_modal.py @@ -11,10 +11,8 @@ import json import modal -from datasets import load_dataset - -#from src.dataset import construct_kernelbench_dataset -from kernelbench.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets +from kernelbench.dataset import construct_kernelbench_dataset +from kernelbench.utils import extract_first_code, query_server, set_gpu_arch, create_inference_server_from_presets app = modal.App("eval_single_sample") @@ -155,41 +153,25 @@ def main(config: EvalConfig): print(f"Starting Eval with config: {config}") - # Configurations - - if config.dataset_src == "huggingface": - dataset = load_dataset(config.dataset_name) - curr_level_dataset = dataset[f"level_{config.level}"] + # Configurations - Unified dataset loading (works for both HF and local) + dataset = construct_kernelbench_dataset( + level=config.level, + source=config.dataset_src, + dataset_name=config.dataset_name, + ) if config.log: os.makedirs(config.logdir, exist_ok=True) # Problem Checks - num_problems = len(curr_level_dataset) + num_problems = len(dataset) print(f"Number of problems in Level {config.level}: {num_problems}") print(f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}") - assert config.problem_id <= num_problems, f"Problem ID {config.problem_id} out of range for Level {config.level}" - - - # 1. Fetch Problem - if config.dataset_src == "huggingface": - - curr_problem_row = curr_level_dataset.filter(lambda x: x["problem_id"] == config.problem_id) - ref_arch_src = curr_problem_row["code"][0] - problem_name = curr_problem_row["name"][0] - - elif config.dataset_src == "local": - problem_idx_in_dataset = config.problem_id - 1 # due to dataset list being 0-indexed locally - ref_arch_path = curr_level_dataset[problem_idx_in_dataset] - - problem_name = os.path.basename(ref_arch_path) - ref_arch_src = read_file(ref_arch_path) - # import pdb; pdb.set_trace() - - # Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py") - problem_number = int(problem_name.split("_")[0]) - assert problem_number == config.problem_id, f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" + # Fetch problem - unified interface, no branching needed + problem = dataset.get_problem_by_id(config.problem_id) + ref_arch_src = problem.code + problem_name = problem.name # 2. 
Generate Sample diff --git a/scripts/generate_baseline_time.py b/scripts/generate_baseline_time.py index 2d992d1e..a2c96a03 100644 --- a/scripts/generate_baseline_time.py +++ b/scripts/generate_baseline_time.py @@ -9,7 +9,7 @@ get_timing_function, get_timing_stats, ) -from kernelbench.dataset import construct_problem_dataset_from_problem_dir +from kernelbench.dataset import construct_kernelbench_dataset, fetch_ref_arch_from_dataset from kernelbench.utils import read_file import os import json @@ -48,32 +48,6 @@ TIMING_DIR = os.path.join(REPO_TOP_PATH, "results", "timing") -def fetch_ref_arch_from_dataset(dataset: list[str], - problem_id: int) -> tuple[str, str, str]: - """ - Fetch the reference architecture from the problem directory - problem_id should be logical index (1-indexed), matching the problem_id in the problem_name - - Returns: - ref_arch_path: str, the path to the reference architecture - ref_arch_name: str, the name of the reference architecture - ref_arch_src: str, the source code of the reference architecture - """ - ref_arch_path = None - - for file in dataset: - if file.split("/")[-1].split("_")[0] == str(problem_id): - ref_arch_path = file - break - if ref_arch_path is None: - raise ValueError(f"No reference architecture found for problem_id {problem_id}") - - ref_arch_src = read_file(ref_arch_path) - - ref_arch_name = ref_arch_path.split("/")[-1] - return (ref_arch_path, ref_arch_name, ref_arch_src) - - def measure_program_time( ref_arch_name: str, ref_arch_src: str, @@ -149,12 +123,11 @@ def record_baseline_times(use_torch_compile: bool = False, json_results = {} for level in [1, 2, 3]: - PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level)) - dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR) + dataset = construct_kernelbench_dataset(level) json_results[f"level{level}"] = {} num_problems = len(dataset) - for problem_id in tqdm(range(1, num_problems + 1)): + for problem_id in tqdm(dataset.get_problem_ids()): ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset(dataset, problem_id) runtime_stats = measure_program_time( ref_arch_name=ref_arch_name, @@ -180,8 +153,7 @@ def test_measure_particular_program(level_num: int, problem_id: int): """ device = torch.device("cuda:0") - PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level_num)) - dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR) + dataset = construct_kernelbench_dataset(level_num) ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset(dataset, problem_id) @@ -255,7 +227,7 @@ def get_time_old(level_num, problem_id, num_trials=100, torch_compile=False): ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id( level_num, problem_id, with_name=True ) - ref_arch_name = ref_arch_name.split("/")[-1] + ref_arch_name = os.path.basename(ref_arch_name) context = {} Model, get_init_inputs, get_inputs = load_original_model_and_inputs( ref_arch_src, context diff --git a/scripts/generate_baseline_time_modal.py b/scripts/generate_baseline_time_modal.py index 37bda8ae..22bdb1b0 100644 --- a/scripts/generate_baseline_time_modal.py +++ b/scripts/generate_baseline_time_modal.py @@ -9,7 +9,7 @@ time_execution_with_cuda_event, get_timing_stats, ) -from kernelbench.dataset import construct_problem_dataset_from_problem_dir +from kernelbench.dataset import construct_kernelbench_dataset, fetch_ref_arch_from_dataset from kernelbench.utils import read_file import os import json @@ -125,31 +125,6 @@ def write_batch_to_json(entries_to_write: 
list, f_path: str): print(f"[INFO] Wrote {len(entries_to_write)} entries to {f_path}") -def fetch_ref_arch_from_dataset(dataset: list[str], - problem_id: int) -> tuple[str, str, str]: - """ - Fetch the reference architecture from the problem directory - problem_id should be logical index (1-indexed), matching the problem_id in the problem_name - - Returns: - ref_arch_path: str, the path to the reference architecture - ref_arch_name: str, the name of the reference architecture - ref_arch_src: str, the source code of the reference architecture - """ - ref_arch_path = None - - for file in dataset: - if file.split("/")[-1].split("_")[0] == str(problem_id): - ref_arch_path = file - break - if ref_arch_path is None: - raise ValueError(f"No reference architecture found for problem_id {problem_id}") - - ref_arch_src = read_file(ref_arch_path) - - ref_arch_name = ref_arch_path.split("/")[-1] - return (ref_arch_path, ref_arch_name, ref_arch_src) - @app.cls(image=image, scaledown_window=5) class EvalFunc: @@ -223,10 +198,9 @@ def record_baseline_times(config: BaselineConfig, json_results = [] level = config.level - PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level)) - dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR) + dataset = construct_kernelbench_dataset(level) num_problems = len(dataset) - total_work = [(i, *fetch_ref_arch_from_dataset(dataset, i)) for i in list(range(1, num_problems + 1))] + total_work = [(i, *fetch_ref_arch_from_dataset(dataset, i)) for i in dataset.get_problem_ids()] batch_size = config.num_gpu_devices print(f"[Modal] Processing {len(total_work)} problems in parallel batches of {batch_size}") @@ -330,7 +304,7 @@ def get_time_old(level_num, problem_id, num_trials=100, torch_compile=False): ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id( level_num, problem_id, with_name=True ) - ref_arch_name = ref_arch_name.split("/")[-1] + ref_arch_name = os.path.basename(ref_arch_name) context = {} Model, get_init_inputs, get_inputs = load_original_model_and_inputs( ref_arch_src, context diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py index 312a9545..6618df1d 100644 --- a/scripts/generate_samples.py +++ b/scripts/generate_samples.py @@ -5,7 +5,6 @@ import pydra import torch -from datasets import load_dataset from pydra import Config, REQUIRED from kernelbench.dataset import construct_kernelbench_dataset @@ -15,7 +14,6 @@ create_inference_server_from_presets, extract_first_code, maybe_multithread, - read_file, set_gpu_arch, ) @@ -45,7 +43,7 @@ def __init__(self): self.subset = ( None, None, - ) # (problem_id, problem_name), these are the logical index + ) # (start_id, end_id), both inclusive - logical 1-indexed IDs self.run_name = REQUIRED # name of the run @@ -105,29 +103,10 @@ def generate_sample_single( inference_server: callable, run_dir: str, ) -> bool: - # 1. Fetch Problem - if config.dataset_src == "huggingface": - curr_problem_row = dataset.filter( - lambda x: x["problem_id"] == work.problem_id, desc=None - ) - - ref_arch_src = curr_problem_row["code"][0] - problem_name = curr_problem_row["name"][0] - - elif config.dataset_src == "local": - problem_idx_in_dataset = ( - work.problem_id - 1 - ) # due to dataset list being 0-indexed locally - ref_arch_path = dataset[problem_idx_in_dataset] - - problem_name = os.path.basename(ref_arch_path) - ref_arch_src = read_file(ref_arch_path) - - # Extract problem number from problem name (e.g. 
"1" from "1_Square_matrix_multiplication_.py") - problem_number = int(problem_name.split("_")[0]) - assert ( - problem_number == work.problem_id - ), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" + # 1. Fetch Problem - unified interface + problem = dataset.get_problem_by_id(work.problem_id) + ref_arch_src = problem.code + problem_name = problem.name if config.custom_prompt_key: custom_prompt = get_custom_prompt( @@ -265,25 +244,25 @@ def main(config: GenerationConfig): print(f"Starting Batch Generation with config: {config}") - # Dataset Configurations - if config.dataset_src == "huggingface": - dataset = load_dataset(config.dataset_name) - curr_level_dataset = dataset[f"level_{config.level}"] - elif config.dataset_src == "local": - curr_level_dataset = construct_kernelbench_dataset(config.level) + # Dataset Configurations - Unified loading + dataset = construct_kernelbench_dataset( + level=config.level, + source=config.dataset_src, + dataset_name=config.dataset_name, + ) - num_problems_in_level = len(curr_level_dataset) + all_problem_ids = dataset.get_problem_ids() if config.subset == (None, None): - problem_id_range = range(1, num_problems_in_level) + problem_ids_to_run = all_problem_ids else: - assert ( - config.subset[0] >= 1 and config.subset[1] <= num_problems_in_level - ), f"Subset range {config.subset} out of range for Level {config.level}" - problem_id_range = range(config.subset[0], config.subset[1]) + start, end = config.subset + problem_ids_to_run = [pid for pid in all_problem_ids if start <= pid <= end] + if not problem_ids_to_run: + print(f"Warning: No problems found in subset range {config.subset}") print( - f"Generating {config.num_samples} sample(s) each for level {config.level} problems: {problem_id_range}" + f"Generating {config.num_samples} sample(s) each for level {config.level} problems: {problem_ids_to_run}" ) # set up run directory @@ -302,9 +281,7 @@ def main(config: GenerationConfig): problems_to_run = [] total_problems = 0 already_completed = 0 - for problem_id in range( - problem_id_range.start, problem_id_range.stop + 1 - ): # end index is inclusive + for problem_id in problem_ids_to_run: for sample_id in range(config.num_samples): total_problems += 1 if not check_kernel_exists(run_dir, config.level, problem_id, sample_id): @@ -338,7 +315,7 @@ def main(config: GenerationConfig): time_interval=config.api_query_interval, # extra args config=config, - dataset=curr_level_dataset, + dataset=dataset, inference_server=inference_server, run_dir=run_dir, ) diff --git a/scripts/inspect_baseline.py b/scripts/inspect_baseline.py index 9f9f6b7c..aeac959d 100644 --- a/scripts/inspect_baseline.py +++ b/scripts/inspect_baseline.py @@ -5,12 +5,10 @@ import numpy as np from kernelbench.eval import ( load_original_model_and_inputs, - time_execution_with_cuda_event, - get_timing_stats, set_seed, fetch_ref_arch_from_problem_id, ) -from kernelbench.dataset import construct_problem_dataset_from_problem_dir +from kernelbench.dataset import construct_kernelbench_dataset import os, sys import logging import json @@ -93,15 +91,15 @@ def emit(self, record): separator("") def fetch_ref_arch_from_level_problem_id(level_num, problem_id, with_name=False): - PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level_num)) - dataset = construct_problem_dataset_from_problem_dir(PROBLEM_DIR) + dataset = construct_kernelbench_dataset(level_num) return fetch_ref_arch_from_problem_id(problem_id, dataset, with_name) def 
inspect_torch_compile_triton(level_num, problem_id): ref_arch_name, ref_arch_src = fetch_ref_arch_from_level_problem_id( level_num, problem_id, with_name=True ) - ref_arch_name = ref_arch_name.split("/")[-1] + # Extract filename from path (works for both full paths and filenames) + ref_arch_name = os.path.basename(ref_arch_name) context = {} Model, get_init_inputs, get_inputs = load_original_model_and_inputs( ref_arch_src, context @@ -116,7 +114,8 @@ def inspect_baseline_torch_compile(level_num, problem_id): level_num, problem_id, with_name=True ) - ref_arch_name = ref_arch_name.split("/")[-1] + # Extract filename from path (works for both full paths and filenames) + ref_arch_name = os.path.basename(ref_arch_name) context = {} Model, get_init_inputs, get_inputs = load_original_model_and_inputs( ref_arch_src, context diff --git a/scripts/inspect_triton.py b/scripts/inspect_triton.py index f76e1950..56fe6a23 100644 --- a/scripts/inspect_triton.py +++ b/scripts/inspect_triton.py @@ -21,41 +21,28 @@ from kernelbench.eval import ( load_custom_model, load_original_model_and_inputs, - time_execution_with_cuda_event, get_timing_stats, set_seed, ) -def fetch_ref_arch_from_dataset(dataset: list[str], - problem_id: int) -> tuple[str, str, str]: - """ - Fetch the reference architecture from the problem directory - problem_id should be logical index (1-indexed), matching the problem_id in the problem_name +from kernelbench.timing import time_execution_with_cuda_event +from kernelbench.dataset import construct_kernelbench_dataset, fetch_ref_arch_from_dataset, BaseDataset - Returns: - ref_arch_path: str, the path to the reference architecture - ref_arch_name: str, the name of the reference architecture - ref_arch_src: str, the source code of the reference architecture - """ - ref_arch_path = None - - for file in dataset: - if file.split("/")[-1].split("_")[0] == str(problem_id): - ref_arch_path = file - break - if ref_arch_path is None: - raise ValueError(f"No reference architecture found for problem_id {problem_id}") - - ref_arch_src = read_file(ref_arch_path) - - ref_arch_name = ref_arch_path.split("/")[-1] - return (ref_arch_path, ref_arch_name, ref_arch_src) - - -def run_profile_and_save_trace(dataset: list[str], problem_id: int, num_trials=10): - """ - Helper function to get Torch Profile of a problem - # TODO: Fix up this function + +def run_profile_and_save_trace( + dataset: BaseDataset, + problem_id: int, + num_trials: int = 10 +) -> None: + """Helper function to get Torch Profile of a problem. + + Args: + dataset: BaseDataset object + problem_id: Problem ID to profile + num_trials: Number of profiling trials to run (default: 10) + + Note: + Saves trace files to 'trace_non_compiled.json' and 'trace_compiled.json' """ ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset( dataset, problem_id @@ -120,12 +107,19 @@ def run_profile_and_save_trace(dataset: list[str], problem_id: int, num_trials=1 # except Exception as e: # print(f"[Eval] Error in Measuring Performance: {e}") -def get_torch_compile_triton(level_num, problem_id): - """ - Get the triton code generated by torch compile for a particular problem +def get_torch_compile_triton(level_num: int, problem_id: int) -> str: + """Get the triton code generated by torch compile for a particular problem. 
+ + Args: + level_num: KernelBench level (1, 2, or 3) + problem_id: Problem ID to inspect + + Returns: + str: Name of the reference architecture """ + dataset = construct_kernelbench_dataset(level_num) ref_arch_path, ref_arch_name, ref_arch_src = fetch_ref_arch_from_dataset( - dataset, problem_id, with_name=True + dataset, problem_id ) context = {} # import pdb; pdb.set_trace() diff --git a/scripts/run_and_check.py b/scripts/run_and_check.py index 37ab9732..2e080b2f 100644 --- a/scripts/run_and_check.py +++ b/scripts/run_and_check.py @@ -3,7 +3,6 @@ import pydra from pydra import REQUIRED, Config import os -from datasets import load_dataset import modal from kernelbench import eval as kernel_eval @@ -91,6 +90,7 @@ def __init__(self): # ref_origin is local, specify local file path self.ref_arch_src_path = "" # ref_origin is kernelbench, specify level and problem id + self.dataset_src = "huggingface" # either huggingface or local self.dataset_name = "ScalingIntelligence/KernelBench" self.level = "" self.problem_id = "" @@ -255,27 +255,23 @@ def main(config: ScriptConfig): if config.ref_origin == "local": assert config.ref_arch_src_path != "", "ref_arch_src_path is required" ref_arch_src = read_file(config.ref_arch_src_path) + print(f"Loaded reference from local file: {config.ref_arch_src_path}") elif config.ref_origin == "kernelbench": - assert config.dataset_name != "", "dataset_name is required" + from kernelbench.dataset import construct_kernelbench_dataset + assert config.level != "", "level is required" assert config.problem_id != "", "problem_id is required" - - # for now use the HuggingFace dataset - dataset = load_dataset(config.dataset_name) - curr_level_dataset = dataset[f"level_{config.level}"] - - curr_problem_row = curr_level_dataset.filter(lambda x: x["problem_id"] == config.problem_id) - ref_arch_src = curr_problem_row["code"][0] - problem_name = curr_problem_row["name"][0] - - problem_number = int(problem_name.split("_")[0]) - assert problem_number == config.problem_id, f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})" - - print(f"Fetched problem {config.problem_id} from KernelBench level {config.level}: {problem_name}") - - - else: - raise ValueError("Invalid ref_origin") + + # Unified interface - same code for huggingface and local! 
+ dataset = construct_kernelbench_dataset( + level=int(config.level), + source=config.dataset_src, + dataset_name=config.dataset_name, + ) + problem = dataset.get_problem_by_id(int(config.problem_id)) + ref_arch_src = problem.code + + print(f"Fetched problem {problem.problem_id} from KernelBench level {problem.level}: {problem.name}") kernel_src = read_file(config.kernel_src_path) diff --git a/scripts/verify_bench.py b/scripts/verify_bench.py index 5fdc6862..369a25e2 100644 --- a/scripts/verify_bench.py +++ b/scripts/verify_bench.py @@ -71,37 +71,46 @@ def run(Model, NewModel, get_inputs, get_init_inputs, seed=1012): return check_correctness(Model, NewModel, get_inputs, get_init_inputs, seed) -def run_all(directory): - print(f"Running {directory}") +from kernelbench.dataset import construct_kernelbench_dataset + +def run_all(level): + print(f"Running Level {level}") + dataset = construct_kernelbench_dataset(level) total = 0 passed = 0 fail_tests = [] - abs_path = os.path.abspath(directory) - for filename in os.listdir(abs_path): - if filename.endswith(".py"): - total += 1 - module_name = filename[:-3] # Remove .py extension - try: - # Dynamically import the module - spec = importlib.util.spec_from_file_location( - module_name, os.path.join(abs_path, filename) + + for problem in dataset: + total += 1 + module_name = problem.name.replace(".py", "") + try: + problem_path = getattr(problem, "path", None) + if not problem_path: + raise ValueError( + f"Problem '{module_name}' does not have a local file path; " + "verify_bench.py only supports local datasets." ) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - # Get the required attributes from the module - Model = getattr(module, "Model") - get_inputs = getattr(module, "get_inputs") - get_init_inputs = getattr(module, "get_init_inputs") - assert run(Model, Model, get_inputs, get_init_inputs) - passed += 1 - except Exception as e: - fail_tests.append(module_name) - print(f"{directory}: {passed}/{total} passed") + # Dynamically import the module + spec = importlib.util.spec_from_file_location( + module_name, problem_path + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # Get the required attributes from the module + Model = getattr(module, "Model") + get_inputs = getattr(module, "get_inputs") + get_init_inputs = getattr(module, "get_init_inputs") + assert run(Model, Model, get_inputs, get_init_inputs) + passed += 1 + except Exception as e: + print(f"Failed {module_name}: {e}") + fail_tests.append(module_name) + print(f"Level {level}: {passed}/{total} passed") if len(fail_tests) > 0: print(f"Failed tests: {fail_tests}") if __name__ == "__main__": - run_all(KERNEL_BENCH_PATH + "/level1") - run_all(KERNEL_BENCH_PATH + "/level2") - run_all(KERNEL_BENCH_PATH + "/level3") + run_all(1) + run_all(2) + run_all(3) diff --git a/src/kernelbench/dataset.py b/src/kernelbench/dataset.py index 29bea818..7f0e8b22 100644 --- a/src/kernelbench/dataset.py +++ b/src/kernelbench/dataset.py @@ -1,7 +1,13 @@ ################################################################################ -# Helpers for Dataset +# Unified Dataset Abstraction for KernelBench +# +# Supports both local filesystem and HuggingFace datasets through a unified +# interface. All problem access is by logical problem_id (1-indexed). 
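#
# Quick-start sketch (illustrative; uses only the factory and Problem fields
# defined below, with level-1 problem 1 as the example):
#
#     from kernelbench.dataset import construct_kernelbench_dataset
#
#     dataset = construct_kernelbench_dataset(level=1, source="local")   # or source="huggingface"
#     problem = dataset.get_problem_by_id(1)
#     problem.name   # "1_Square_matrix_multiplication_.py"
#     problem.code   # reference architecture source string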
################################################################################ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Iterator, Optional import os import random import re @@ -16,86 +22,491 @@ KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench") -def assign_problem_hash(problem_path: str) -> list[int]: - """ - Assign a unique hash to a problem in the dataset +################################################################################ +# Problem Dataclass +################################################################################ + +@dataclass +class Problem: + """Unified representation of a KernelBench problem. + + Attributes: + problem_id: 1-indexed logical ID (matches filename prefix) + name: Filename, e.g., "1_Square_matrix_multiplication_.py" + code: The actual source code + level: KernelBench level (1, 2, 3, or custom) + path: Local filesystem path (None if from HuggingFace) + metadata: Extra metadata for future use (e.g., categories, difficulty) + + Note: + Code is loaded eagerly when the dataset is constructed (~500KB for level 1). + If memory becomes a concern for very large datasets, this could be refactored + to lazy loading where code is only read when Problem.code is accessed. """ - with open(problem_path, "r") as f: - problem_src = f.read() - return get_code_hash(problem_src) + problem_id: int + name: str + code: str + level: int + path: Optional[str] = None + metadata: Optional[dict] = None + @property + def hash(self) -> str: + """Compute code hash for problem identification. + + The hash ignores comments and whitespace, so functionally + equivalent code produces the same hash. Useful for: + - Deduplication across dataset versions + - Tracking problem identity when code formatting changes + - Comparing local vs HuggingFace versions + """ + return get_code_hash(self.code) -def get_code_hash(problem_src: str) -> str: - """ - Assign a unique hash to some piece of code - Important to strip out the comments and whitespace as they are not functionally part of the code - """ - # Remove multi-line comments first - problem_src = re.sub(r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'', "", problem_src) + +################################################################################ +# Hash Utilities +################################################################################ + +def get_code_hash(code: str) -> str: + """Compute a unique hash for code, ignoring comments and whitespace.""" + # Remove multi-line comments + code = re.sub(r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'', "", code) # Remove inline comments and all whitespace - cleaned_problem_src = re.sub(r"#.*$|\s+", "", problem_src, flags=re.MULTILINE) - # hash only on code - return hashlib.md5(cleaned_problem_src.encode()).hexdigest() + cleaned = re.sub(r"#.*$|\s+", "", code, flags=re.MULTILINE) + return hashlib.md5(cleaned.encode()).hexdigest() -def construct_problem_dataset_from_problem_dir(problem_dir: str) -> list[str]: +################################################################################ +# Base Dataset Abstract Class +################################################################################ + +class BaseDataset(ABC): + """Abstract base for all KernelBench datasets. + + Provides a unified interface for accessing problems by ID, + iteration, and length. 
""" - Construct a list of relative paths to all the python files in the problem directory - Sorted by the numerical prefix of the filenames + + @property + @abstractmethod + def level(self) -> int: + """Return the KernelBench level.""" + pass + + @abstractmethod + def __len__(self) -> int: + """Return the number of problems in the dataset.""" + pass + + @abstractmethod + def __iter__(self) -> Iterator[Problem]: + """Iterate over Problem objects in the dataset.""" + pass + + @abstractmethod + def get_problem_by_id(self, problem_id: int) -> Problem: + """Get problem by 1-indexed logical ID.""" + pass + + @abstractmethod + def get_problem_ids(self) -> list[int]: + """Get sorted list of all problem IDs in the dataset.""" + pass + + + def subset( + self, + problem_ids: Optional[list[int]] = None, + id_range: Optional[tuple[int, int]] = None, + ) -> "BaseDataset": + """Create a subset by problem IDs. + + Args: + problem_ids: Specific problem IDs to include (e.g., [1, 3, 5]) + id_range: (start_id, end_id) inclusive range of problem IDs + + Returns: + New dataset with only the specified problems + + Example: + >>> dataset.subset(problem_ids=[1, 3, 5]) + >>> dataset.subset(id_range=(1, 10)) + """ + raise NotImplementedError("Subclasses should implement subset()") + + def sample(self, n: int, seed: int = 42) -> "BaseDataset": + """Get a random sample of N problems. + + Args: + n: Number of problems to sample + seed: Random seed for reproducibility + + Returns: + New dataset with N randomly selected problems + """ + all_ids = self.get_problem_ids() + n = min(n, len(all_ids)) + random.seed(seed) + sampled_ids = random.sample(all_ids, n) + return self.subset(problem_ids=sorted(sampled_ids)) + + def get_representative_subset(self) -> "BaseDataset": + """Get a curated representative subset for quick iteration. + + Returns a diverse subset covering different problem categories + (matmul, conv, norms, etc.). Useful for testing. + """ + rep_ids = { + 1: [1, 3, 6, 18, 23, 26, 33, 36, 40, 42, 48, 54, 57, 65, 77, 82, 86, 87], + 2: [1, 2, 8, 18, 23, 28, 33, 43], + 3: [1, 5, 8, 11, 20, 21, 33, 38, 43], + } + + if self.level not in rep_ids: + raise ValueError(f"No representative subset for level {self.level}") + + available_ids = set(self.get_problem_ids()) + subset_ids = [pid for pid in rep_ids[self.level] if pid in available_ids] + + return self.subset(problem_ids=subset_ids) + + +################################################################################ +# Local Filesystem Dataset +################################################################################ + +class LocalKernelBenchDataset(BaseDataset): + """Dataset backed by local filesystem. + + Loads problems from KernelBench/level{N}/*.py + Flexible for any level number (1, 2, 3, or custom levels). """ - DATASET = [] - for file_name in os.listdir(problem_dir): - if file_name.endswith(".py"): - # TODO: revisit later to satisfy eval harnes - relative_path = os.path.join(problem_dir, file_name) - DATASET.append(relative_path) + def __init__( + self, + level: int, + base_path: str = KERNEL_BENCH_PATH, + problem_ids: Optional[list[int]] = None, + id_range: Optional[tuple[int, int]] = None, + ): + """Initialize local dataset. 
+ + Args: + level: KernelBench level (any positive integer) + base_path: Path to KernelBench directory + problem_ids: Optional list of specific problem IDs to include + id_range: Optional (start_id, end_id) inclusive range + """ + if level < 1: + raise ValueError(f"level must be >= 1, got {level}") - # Sort the DATASET based on the numerical prefix of the filenames - DATASET.sort(key=lambda x: int(os.path.basename(x).split("_")[0])) + self._level = level + self._base_path = base_path + self._problems: dict[int, Problem] = {} + + # Build filter set from problem_ids and/or id_range + self._filter_ids = self._build_filter_set(problem_ids, id_range) + self._load_problems() - return DATASET + def _build_filter_set( + self, + problem_ids: Optional[list[int]], + id_range: Optional[tuple[int, int]], + ) -> Optional[set[int]]: + """Build a set of IDs to filter by, or None for no filtering.""" + if problem_ids is None and id_range is None: + return None + + filter_set = set() + if problem_ids: + filter_set.update(problem_ids) + if id_range: + start, end = id_range + filter_set.update(range(start, end + 1)) + return filter_set + @property + def level(self) -> int: + return self._level -def construct_kernelbench_dataset(level: int) -> list[str]: - return construct_problem_dataset_from_problem_dir( - os.path.join(KERNEL_BENCH_PATH, f"level{level}") - ) + def _load_problems(self): + problem_dir = os.path.join(self._base_path, f"level{self._level}") + + if not os.path.exists(problem_dir): + raise FileNotFoundError(f"Problem directory not found: {problem_dir}") + + for file_name in os.listdir(problem_dir): + if not file_name.endswith(".py"): + continue + + try: + problem_id = int(file_name.split("_")[0]) + except (ValueError, IndexError): + continue + + # Apply filter if specified + if self._filter_ids is not None and problem_id not in self._filter_ids: + continue + + path = os.path.join(problem_dir, file_name) + with open(path, "r") as f: + code = f.read() + + self._problems[problem_id] = Problem( + problem_id=problem_id, + name=file_name, + code=code, + level=self._level, + path=path, + ) + + def get_problem_by_id(self, problem_id: int) -> Problem: + if problem_id not in self._problems: + raise ValueError(f"Problem ID {problem_id} not found in dataset") + return self._problems[problem_id] + + def get_problem_ids(self) -> list[int]: + return sorted(self._problems.keys()) + + def __len__(self) -> int: + return len(self._problems) + def __iter__(self) -> Iterator[Problem]: + for pid in self.get_problem_ids(): + yield self._problems[pid] + + def __repr__(self) -> str: + return f"LocalKernelBenchDataset(level={self._level}, problems={len(self)})" + + def subset( + self, + problem_ids: Optional[list[int]] = None, + id_range: Optional[tuple[int, int]] = None, + ) -> "LocalKernelBenchDataset": + """Create a subset of this dataset.""" + return LocalKernelBenchDataset( + level=self._level, + base_path=self._base_path, + problem_ids=problem_ids, + id_range=id_range, + ) -KERNELBENCH_LEVEL_1_DATASET = construct_kernelbench_dataset(level=1) -KERNELBENCH_LEVEL_2_DATASET = construct_kernelbench_dataset(level=2) -KERNELBENCH_LEVEL_3_DATASET = construct_kernelbench_dataset(level=3) ################################################################################ -# Eval on Subsets of KernelBench +# HuggingFace Dataset ################################################################################ +class HuggingFaceKernelBenchDataset(BaseDataset): + """Dataset backed by HuggingFace datasets.""" -def 
get_kernelbench_subset( - level: int, num_subset_problems: int = 10, random_seed: int = 42 -) -> tuple[list[str], list[int]]: + def __init__( + self, + level: int, + dataset_name: str = "ScalingIntelligence/KernelBench", + problem_ids: Optional[list[int]] = None, + id_range: Optional[tuple[int, int]] = None, + ): + """Initialize HuggingFace dataset. + + Args: + level: KernelBench level (1, 2, or 3) + dataset_name: HuggingFace dataset identifier + problem_ids: Optional list of specific problem IDs to include + id_range: Optional (start_id, end_id) inclusive range + """ + if level not in [1, 2, 3]: + raise ValueError(f"HuggingFace dataset only has levels 1, 2, 3, got {level}") + + self._level = level + self._dataset_name = dataset_name + self._problems: dict[int, Problem] = {} + self._filter_ids = self._build_filter_set(problem_ids, id_range) + self._load_dataset() + + def _build_filter_set( + self, + problem_ids: Optional[list[int]], + id_range: Optional[tuple[int, int]], + ) -> Optional[set[int]]: + """Build a set of IDs to filter by, or None for no filtering.""" + if problem_ids is None and id_range is None: + return None + + filter_set = set() + if problem_ids: + filter_set.update(problem_ids) + if id_range: + start, end = id_range + filter_set.update(range(start, end + 1)) + return filter_set + + @property + def level(self) -> int: + return self._level + + def _load_dataset(self): + from datasets import load_dataset + + split_name = f"level_{self._level}" + hf_dataset = load_dataset(self._dataset_name, split=split_name) + + for row in hf_dataset: + problem_id = row["problem_id"] + + if self._filter_ids is not None and problem_id not in self._filter_ids: + continue + + self._problems[problem_id] = Problem( + problem_id=problem_id, + name=row["name"], + code=row["code"], + level=self._level, + path=None, + ) + + def get_problem_by_id(self, problem_id: int) -> Problem: + if problem_id not in self._problems: + raise ValueError(f"Problem ID {problem_id} not found in dataset") + return self._problems[problem_id] + + def get_problem_ids(self) -> list[int]: + return sorted(self._problems.keys()) + + def __len__(self) -> int: + return len(self._problems) + + def __iter__(self) -> Iterator[Problem]: + for pid in self.get_problem_ids(): + yield self._problems[pid] + + def __repr__(self) -> str: + return f"HuggingFaceKernelBenchDataset(level={self._level}, problems={len(self)})" + + def subset( + self, + problem_ids: Optional[list[int]] = None, + id_range: Optional[tuple[int, int]] = None, + ) -> "HuggingFaceKernelBenchDataset": + """Create a subset of this dataset.""" + return HuggingFaceKernelBenchDataset( + level=self._level, + dataset_name=self._dataset_name, + problem_ids=problem_ids, + id_range=id_range, + ) + + +################################################################################ +# Factory Function +################################################################################ + +def construct_kernelbench_dataset( + level: int, + source: str = "local", + dataset_name: str = "ScalingIntelligence/KernelBench", + base_path: str = KERNEL_BENCH_PATH, + problem_ids: Optional[list[int]] = None, + id_range: Optional[tuple[int, int]] = None, +) -> BaseDataset: + """Construct a KernelBench dataset for a specific level. 
+ + Args: + level: KernelBench level (1, 2, 3, or custom for local) + source: "local" for filesystem, "huggingface" for HF datasets + dataset_name: HuggingFace dataset identifier (if source="huggingface") + base_path: Path to KernelBench directory (if source="local") + problem_ids: Optional list of specific problem IDs to include + id_range: Optional (start_id, end_id) inclusive range + + Returns: + BaseDataset instance for the specified level + + Examples: + # Local filesystem (default) + >>> dataset = construct_kernelbench_dataset(level=1, source="local") + >>> len(dataset) + 100 + + # HuggingFace + >>> dataset = construct_kernelbench_dataset(level=1, source="huggingface") + >>> len(dataset) + 100 + + # Filter by specific IDs + >>> dataset = construct_kernelbench_dataset(level=1, problem_ids=[1, 3, 5]) + >>> dataset.get_problem_ids() + [1, 3, 5] + + # Filter by range + >>> dataset = construct_kernelbench_dataset(level=1, id_range=(1, 10)) + >>> dataset.get_problem_ids() + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + # Access problems + >>> problem = dataset.get_problem_by_id(1) + >>> problem.name + '1_Square_matrix_multiplication_.py' + >>> problem.code[:50] + 'import torch...' """ - Get a random subset of problems from the KernelBench dataset + if source == "local": + return LocalKernelBenchDataset(level, base_path, problem_ids, id_range) + elif source == "huggingface": + return HuggingFaceKernelBenchDataset(level, dataset_name, problem_ids, id_range) + else: + raise ValueError(f"Unknown source: {source}. Must be 'local' or 'huggingface'") + + +################################################################################ +# Convenience Functions +################################################################################ + +def fetch_ref_arch_from_dataset( + dataset: BaseDataset, + problem_id: int, +) -> tuple[Optional[str], str, str]: + """Fetch reference architecture from dataset. + + Returns: + (path, name, code) - path is None for HuggingFace datasets """ + problem = dataset.get_problem_by_id(problem_id) + return (problem.path, problem.name, problem.code) - full_dataset = construct_kernelbench_dataset(level) + +def get_kernelbench_subset( + level: int, + num_subset_problems: int = 10, + random_seed: int = 42, + source: str = "local", +) -> tuple[BaseDataset, list[int]]: + """Get a random subset of problems. 
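    A small illustrative call (subset membership depends on the seed):

        >>> subset_ds, subset_ids = get_kernelbench_subset(level=1, num_subset_problems=5)
        >>> len(subset_ds) == len(subset_ids) == 5
        True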
+ + Returns: + (subset_dataset, subset_problem_ids) + """ + full_dataset = construct_kernelbench_dataset(level, source=source) + all_ids = full_dataset.get_problem_ids() random.seed(random_seed) - num_subset_problems = min(num_subset_problems, len(full_dataset)) - subset_indices = random.sample(range(len(full_dataset)), num_subset_problems) + num_subset_problems = min(num_subset_problems, len(all_ids)) + subset_ids = sorted(random.sample(all_ids, num_subset_problems)) - subset = sorted([full_dataset[i] for i in subset_indices]) - return subset, subset_indices + subset_dataset = construct_kernelbench_dataset( + level=level, + source=source, + problem_ids=subset_ids, + ) + return subset_dataset, subset_ids ################################################################################ -# Representative subsets of KernelBench -# use this if you want to iterate on methods without the hassle of running the full dataset -# problem_ids are 1-indexed (logical index) +# Representative Subsets of KernelBench +# Use these for quick iteration without running the full dataset ################################################################################ -level1_representative_subset = [ +# Level 1: Basic operators - matmul, activations, norms, pooling, convolutions +LEVEL1_REPRESENTATIVE_SUBSET = [ "1_Square_matrix_multiplication_.py", "3_Batched_matrix_multiplication.py", "6_Matmul_with_large_K_dimension_.py", @@ -115,10 +526,10 @@ def get_kernelbench_subset( "86_conv_depthwise_separable_2D.py", "87_conv_pointwise_2D.py", ] +LEVEL1_REPRESENTATIVE_IDS = [1, 3, 6, 18, 23, 26, 33, 36, 40, 42, 48, 54, 57, 65, 77, 82, 86, 87] -level1_representative_subset_problem_ids = [1, 3, 6, 18, 23, 26, 33, 36, 40, 42, 48, 54, 57, 65, 77, 82, 86, 87] - -level2_representative_subset = [ +# Level 2: Fused operators - multi-op fusion patterns +LEVEL2_REPRESENTATIVE_SUBSET = [ "1_Conv2D_ReLU_BiasAdd.py", "2_ConvTranspose2d_BiasAdd_Clamp_Scaling_Clamp_Divide.py", "8_Conv3d_Divide_Max_GlobalAvgPool_BiasAdd_Sum.py", @@ -128,10 +539,10 @@ def get_kernelbench_subset( "33_Gemm_Scale_BatchNorm.py", "43_Conv3d_Max_LogSumExp_ReLU.py", ] +LEVEL2_REPRESENTATIVE_IDS = [1, 2, 8, 18, 23, 28, 33, 43] -level2_representative_subset_problem_ids = [1, 2, 8, 18, 23, 28, 33, 43] - -level3_representative_subset = [ +# Level 3: Full models - MLP, CNN architectures, RNNs, Transformers +LEVEL3_REPRESENTATIVE_SUBSET = [ "1_MLP.py", "5_AlexNet.py", "8_ResNetBasicBlock.py", @@ -142,5 +553,29 @@ def get_kernelbench_subset( "38_LTSMBidirectional.py", "43_MinGPTCausalAttention.py", ] +LEVEL3_REPRESENTATIVE_IDS = [1, 5, 8, 11, 20, 21, 33, 38, 43] -level3_representative_subset_problem_ids = [1, 5, 8, 11, 20, 33, 38, 43] \ No newline at end of file + +def get_representative_dataset(level: int, source: str = "local") -> BaseDataset: + """Get a representative subset dataset for quick iteration. 
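    For instance (assuming a local checkout containing all level-2 problems):

        >>> rep = get_representative_dataset(level=2, source="local")
        >>> rep.get_problem_ids()
        [1, 2, 8, 18, 23, 28, 33, 43]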
+ + Args: + level: 1, 2, or 3 + source: "local" or "huggingface" + + Returns: + Dataset containing only representative problems + """ + id_map = { + 1: LEVEL1_REPRESENTATIVE_IDS, + 2: LEVEL2_REPRESENTATIVE_IDS, + 3: LEVEL3_REPRESENTATIVE_IDS, + } + if level not in id_map: + raise ValueError(f"No representative subset for level {level}") + + return construct_kernelbench_dataset( + level=level, + source=source, + problem_ids=id_map[level], + ) \ No newline at end of file diff --git a/src/kernelbench/eval.py b/src/kernelbench/eval.py index 023b4f88..47f59793 100644 --- a/src/kernelbench/eval.py +++ b/src/kernelbench/eval.py @@ -21,7 +21,7 @@ import torch.nn as nn from pydantic import BaseModel -from . import utils, timing +from . import timing, dataset REPO_TOP_PATH = os.path.abspath( os.path.join( @@ -39,30 +39,27 @@ def get_error_name(e: Exception) -> str: return f"{e.__class__.__module__}.{e.__class__.__name__}" -def fetch_ref_arch_from_problem_id(problem_id, problems, with_name=False) -> str: +def fetch_ref_arch_from_problem_id(problem_id: int, dataset: "BaseDataset", with_name=False) -> Union[str, tuple[str, str]]: """ - Fetches the reference architecture in string for a given problem_id + Fetches the reference architecture for a given problem_id from the dataset. """ if isinstance(problem_id, str): problem_id = int(problem_id) - problem_path = problems[problem_id] - - # problem_path = os.path.join(REPO_ROOT_PATH, problem) - if not os.path.exists(problem_path): - raise FileNotFoundError(f"Problem file at {problem_path} does not exist.") - - ref_arch = utils.read_file(problem_path) + problem = dataset.get_problem_by_id(problem_id) + ref_arch = problem.code + if not with_name: return ref_arch else: - return (problem_path, ref_arch) + # Use problem.name as fallback when path is None (e.g., for HuggingFace datasets) + name = problem.path if problem.path is not None else problem.name + return (name, ref_arch) def fetch_ref_arch_from_level_problem_id(level, problem_id, with_name=False): - PROBLEM_DIR = os.path.join(KERNEL_BENCH_PATH, "level" + str(level)) - dataset = utils.construct_problem_dataset_from_problem_dir(PROBLEM_DIR) - return fetch_ref_arch_from_problem_id(problem_id, dataset, with_name) + kb_dataset = dataset.construct_kernelbench_dataset(level) + return fetch_ref_arch_from_problem_id(problem_id, kb_dataset, with_name) def set_seed(seed: int): @@ -531,6 +528,18 @@ def eval_kernel_against_ref( compiled=False, metadata=metadata ) # skip further steps + # Check if ModelNew was successfully loaded (load_custom_model returns None on syntax errors) + if ModelNew is None: + print( + "Failed to load custom model: Syntax error or ModelNew not found in generated code. Record as compilation failure." 
+ ) + metadata["compilation_error_name"] = "SyntaxError" + metadata["compilation_error"] = "Syntax error in custom generated code or ModelNew not found" + graceful_eval_cleanup(context, device, tempfile) + return KernelExecResult( + compiled=False, metadata=metadata + ) # skip further steps + # at this point we passed compilation try: with torch.no_grad(): @@ -880,8 +889,7 @@ def convert_to_serializable(obj): return converted_metadata - # if __name__ == "__main__": # fetch_kernel_from_database("kernelbench_prompt_v2_level_2", 1, 1, "http://localhost:9091") # print(fetch_ref_arch_from_level_problem_id("2", 1, with_name=True)) -# fetch_baseline_time("level1", 0, ["1_Square_matrix_multiplication_.py"], "tests/baseline_time_matx3.json") \ No newline at end of file +# Note: fetch_baseline_time is available in kernelbench.timing module \ No newline at end of file diff --git a/src/kernelbench/frameworks.py b/src/kernelbench/frameworks.py index d102e35a..a7f1a2a0 100644 --- a/src/kernelbench/frameworks.py +++ b/src/kernelbench/frameworks.py @@ -24,7 +24,6 @@ from dotenv import load_dotenv load_dotenv() -# from datasets import load_dataset import numpy as np from contextlib import contextmanager from collections import defaultdict diff --git a/src/kernelbench/timing.py b/src/kernelbench/timing.py index b930a890..6b0db010 100644 --- a/src/kernelbench/timing.py +++ b/src/kernelbench/timing.py @@ -2,7 +2,7 @@ import json import numpy as np import time -from typing import Any +from typing import Any, Optional import os ################################################################################ @@ -393,8 +393,8 @@ def time_execution_with_host_time( # tools to help compute speedup and other time ######################################################### def fetch_baseline_time( - level_name: str, problem_id: int, dataset: list[str], baseline_time_filepath: str -) -> dict: + level_name: str, problem_id: int, dataset: "BaseDataset", baseline_time_filepath: str +) -> Optional[float]: """ Fetch the baseline time from the time @@ -409,8 +409,8 @@ def fetch_baseline_time( with open(baseline_time_filepath, "r") as f: baseline_json = json.load(f) - # TODO: replace with the new Dataset object that Omar will merge in - problem_name = dataset[problem_id].split("/")[-1] + problem = dataset.get_problem_by_id(problem_id) + problem_name = problem.name baseline_time = baseline_json[level_name].get(problem_name, None) return baseline_time diff --git a/src/kernelbench/unit_tests/test_dataset.py b/src/kernelbench/unit_tests/test_dataset.py index 22b0612d..b8468503 100644 --- a/src/kernelbench/unit_tests/test_dataset.py +++ b/src/kernelbench/unit_tests/test_dataset.py @@ -1,19 +1,29 @@ - -import pytest -from kernelbench.dataset import get_code_hash - """ -Usage -pytest test_dataset.py +Unit tests for the KernelBench dataset module. 
+ +Usage: + pytest src/kernelbench/unit_tests/test_dataset.py -v """ +import pytest +from kernelbench.dataset import ( + get_code_hash, + construct_kernelbench_dataset, + fetch_ref_arch_from_dataset, + Problem, + BaseDataset, + LocalKernelBenchDataset, + LEVEL1_REPRESENTATIVE_IDS, +) -def test_get_code_hash(): - """ - Test collision and equivalence checking - """ - code_snippet_batch_1_v1 = """ +################################################################################ +# Hash Tests +################################################################################ + +def test_get_code_hash_ignores_comments(): + """Hash should be equal for semantically equivalent code with different comments.""" + code_v1 = """ import torch # This is for a single batch ''' @@ -22,7 +32,7 @@ def test_get_code_hash(): B = 1 """ - code_snippet_batch_1_v2 = """ + code_v2 = """ import torch ''' More problem descriptions (updated) @@ -31,18 +41,291 @@ def test_get_code_hash(): B = 1 """ + + assert get_code_hash(code_v1) == get_code_hash(code_v2) - code_snippet_batch_64 = """ + +def test_get_code_hash_different_for_different_code(): + """Hash should differ for code with actual differences.""" + code_batch_1 = """ + import torch + B = 1 + """ + + code_batch_64 = """ import torch - # This is for a single batch - ''' - Some random multi-line comment - ''' B = 64 """ + + assert get_code_hash(code_batch_1) != get_code_hash(code_batch_64) + + +################################################################################ +# Dataset Construction Tests +################################################################################ + +def test_construct_local_dataset(): + """Test constructing a local dataset.""" + dataset = construct_kernelbench_dataset(level=1, source="local") + + assert isinstance(dataset, BaseDataset) + assert isinstance(dataset, LocalKernelBenchDataset) + assert dataset.level == 1 + assert len(dataset) > 0 + + +def test_construct_dataset_invalid_level(): + """Test that invalid level raises ValueError.""" + with pytest.raises(ValueError): + construct_kernelbench_dataset(level=0, source="local") + + with pytest.raises(ValueError): + construct_kernelbench_dataset(level=-1, source="local") + + +def test_construct_dataset_invalid_source(): + """Test that invalid source raises ValueError.""" + with pytest.raises(ValueError): + construct_kernelbench_dataset(level=1, source="invalid") + + +################################################################################ +# Problem Access Tests +################################################################################ + +def test_get_problem_by_id(): + """Test getting a problem by its logical ID.""" + dataset = construct_kernelbench_dataset(level=1) + + problem = dataset.get_problem_by_id(1) + + assert isinstance(problem, Problem) + assert problem.problem_id == 1 + assert problem.name.startswith("1_") + assert problem.level == 1 + assert len(problem.code) > 0 + assert problem.path is not None + + +def test_get_problem_by_id_not_found(): + """Test that non-existent ID raises ValueError.""" + dataset = construct_kernelbench_dataset(level=1) + + with pytest.raises(ValueError, match="not found"): + dataset.get_problem_by_id(9999) + + +def test_get_problem_ids(): + """Test getting list of all problem IDs.""" + dataset = construct_kernelbench_dataset(level=1) + + ids = dataset.get_problem_ids() + + assert isinstance(ids, list) + assert len(ids) > 0 + assert ids == sorted(ids) # should be sorted + assert 1 in ids + + +def test_problem_hash(): + """Test 
that Problem.hash property works.""" + dataset = construct_kernelbench_dataset(level=1) + problem = dataset.get_problem_by_id(1) + + hash_value = problem.hash + + assert isinstance(hash_value, str) + assert len(hash_value) == 32 # MD5 hex digest length + + +################################################################################ +# Subset Tests +################################################################################ - assert get_code_hash(code_snippet_batch_1_v1) == get_code_hash(code_snippet_batch_1_v2), \ - "Hash should be equal for semantically equivalent code with different comments" +def test_subset_by_ids(): + """Test creating a subset by specific problem IDs.""" + dataset = construct_kernelbench_dataset(level=1) + subset = dataset.subset(problem_ids=[1, 3, 5]) - assert get_code_hash(code_snippet_batch_1_v1) != get_code_hash(code_snippet_batch_64), \ - "Hash should differ for code with different batch sizes" \ No newline at end of file + assert len(subset) == 3 + assert subset.get_problem_ids() == [1, 3, 5] + + +def test_subset_by_range(): + """Test creating a subset by ID range.""" + dataset = construct_kernelbench_dataset(level=1) + subset = dataset.subset(id_range=(1, 5)) + + assert len(subset) == 5 + assert subset.get_problem_ids() == [1, 2, 3, 4, 5] + + +def test_sample_random(): + """Test random sampling from dataset.""" + dataset = construct_kernelbench_dataset(level=1) + + sample1 = dataset.sample(n=5, seed=42) + sample2 = dataset.sample(n=5, seed=42) + sample3 = dataset.sample(n=5, seed=123) + + assert len(sample1) == 5 + assert sample1.get_problem_ids() == sample2.get_problem_ids() # same seed + assert sample1.get_problem_ids() != sample3.get_problem_ids() # different seed + + +def test_get_representative_subset(): + """Test getting the representative subset.""" + dataset = construct_kernelbench_dataset(level=1) + rep = dataset.get_representative_subset() + + assert len(rep) == len(LEVEL1_REPRESENTATIVE_IDS) + assert rep.get_problem_ids() == LEVEL1_REPRESENTATIVE_IDS + + +################################################################################ +# Iterator Tests +################################################################################ + +def test_dataset_iteration(): + """Test iterating over dataset.""" + dataset = construct_kernelbench_dataset(level=1, problem_ids=[1, 2, 3]) + + problems = list(dataset) + + assert len(problems) == 3 + assert all(isinstance(p, Problem) for p in problems) + assert [p.problem_id for p in problems] == [1, 2, 3] + + + +def test_dataset_len(): + """Test len() on dataset.""" + dataset = construct_kernelbench_dataset(level=1, problem_ids=[1, 2, 3]) + assert len(dataset) == 3 + + +################################################################################ +# Compatibility Tests +################################################################################ + +def test_fetch_ref_arch_from_dataset(): + """Test the backward-compatible fetch function.""" + dataset = construct_kernelbench_dataset(level=1) + + path, name, code = fetch_ref_arch_from_dataset(dataset, problem_id=1) + + assert path is not None + assert name.startswith("1_") + assert len(code) > 0 + + # Should match direct problem access + problem = dataset.get_problem_by_id(1) + assert path == problem.path + assert name == problem.name + assert code == problem.code + + +################################################################################ +# Multiple Levels Tests 
+################################################################################ + +def test_all_levels_load(): + """Test that all standard levels can be loaded.""" + for level in [1, 2, 3]: + dataset = construct_kernelbench_dataset(level=level) + assert len(dataset) > 0 + assert dataset.level == level + + +################################################################################ +# HuggingFace Tests (requires network) +################################################################################ + +@pytest.mark.slow +def test_huggingface_dataset_loads(): + """Test that HuggingFace dataset can be loaded.""" + dataset = construct_kernelbench_dataset(level=1, source="huggingface") + + assert len(dataset) > 0 + assert dataset.level == 1 + + problem = dataset.get_problem_by_id(1) + assert problem.problem_id == 1 + assert problem.name.startswith("1_") + assert len(problem.code) > 0 + assert problem.path is None # HF has no local path + + +@pytest.mark.slow +def test_local_and_huggingface_parity(): + """Test that local and HuggingFace datasets have the same content.""" + local_ds = construct_kernelbench_dataset(level=1, source="local") + hf_ds = construct_kernelbench_dataset(level=1, source="huggingface") + + # Same number of problems + assert len(local_ds) == len(hf_ds), "Local and HF should have same number of problems" + + # Same problem IDs + assert local_ds.get_problem_ids() == hf_ds.get_problem_ids(), "Problem IDs should match" + + # Check a few problems have same content + for pid in [1, 10, 50]: + local_p = local_ds.get_problem_by_id(pid) + hf_p = hf_ds.get_problem_by_id(pid) + + # Names may differ slightly (HF may not have .py extension) + # But the base name (without extension) should match + local_base = local_p.name.replace(".py", "") + hf_base = hf_p.name.replace(".py", "") + assert local_base == hf_base, f"Problem {pid} name mismatch: {local_p.name} vs {hf_p.name}" + + # Code and hash should match exactly + assert local_p.code == hf_p.code, f"Problem {pid} code mismatch" + assert local_p.hash == hf_p.hash, f"Problem {pid} hash mismatch" + + +@pytest.mark.slow +def test_huggingface_subset(): + """Test that HuggingFace dataset supports subsetting.""" + hf_ds = construct_kernelbench_dataset(level=1, source="huggingface") + + subset = hf_ds.subset(problem_ids=[1, 2, 3]) + assert len(subset) == 3 + assert subset.get_problem_ids() == [1, 2, 3] + + +@pytest.mark.slow +def test_huggingface_representative(): + """Test that HuggingFace dataset supports representative subset.""" + hf_ds = construct_kernelbench_dataset(level=1, source="huggingface") + rep = hf_ds.get_representative_subset() + + assert len(rep) == len(LEVEL1_REPRESENTATIVE_IDS) + + +@pytest.mark.slow +def test_unified_interface_behavior(): + """Test that both sources behave identically through the unified interface.""" + for source in ["local", "huggingface"]: + dataset = construct_kernelbench_dataset(level=1, source=source) + + # All these operations should work the same + assert len(dataset) > 0 + assert 1 in dataset.get_problem_ids() + + problem = dataset.get_problem_by_id(1) + assert isinstance(problem, Problem) + assert problem.problem_id == 1 + assert problem.level == 1 + + # Iteration works + count = 0 + for p in dataset: + count += 1 + if count >= 3: + break + assert count == 3 + + # Subset works + subset = dataset.subset(problem_ids=[1, 2]) + assert len(subset) == 2 \ No newline at end of file diff --git a/src/kernelbench/utils.py b/src/kernelbench/utils.py index 2ace37cd..abe067b1 100644 --- 
a/src/kernelbench/utils.py +++ b/src/kernelbench/utils.py @@ -20,7 +20,6 @@ from openai import OpenAI from litellm import completion -# from datasets import load_dataset import numpy as np from contextlib import contextmanager from collections import defaultdict
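
Reviewer note (not part of the patch): below is a minimal, illustrative sketch of the unified dataset interface that this diff migrates callers to. It only uses calls exercised in src/kernelbench/unit_tests/test_dataset.py above, and it assumes the kernelbench package is importable with the local problem files present; treat it as an assumption-laden example, not an authoritative API reference.

    # Illustrative usage only -- mirrors the API covered by test_dataset.py; not part of this diff.
    from kernelbench.dataset import construct_kernelbench_dataset

    # Unified loading: the same call is meant to work for source="local" or source="huggingface".
    dataset = construct_kernelbench_dataset(level=1, source="local")

    # Logical problem IDs replace the old 0-indexed path list.
    for pid in dataset.get_problem_ids()[:3]:
        problem = dataset.get_problem_by_id(pid)
        print(pid, problem.name, problem.level, len(problem.code))

    # Subsets keep the same interface, so downstream eval code runs on them unchanged.
    small = dataset.subset(problem_ids=[1, 2, 3])
    representative = dataset.get_representative_subset()
    print(len(small), len(representative))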