diff --git a/scripts/shinka_evolve/__init__.py b/scripts/shinka_evolve/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/scripts/shinka_evolve/evaluate_bridge.py b/scripts/shinka_evolve/evaluate_bridge.py
new file mode 100644
index 00000000..164e97c6
--- /dev/null
+++ b/scripts/shinka_evolve/evaluate_bridge.py
@@ -0,0 +1,122 @@
+import argparse
+import sys
+import os
+import torch
+import json
+import traceback
+
+# Add repo root to path
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+
+from src.dataset import construct_kernelbench_dataset
+from src.eval import eval_kernel_against_ref
+from src.utils import read_file, set_gpu_arch
+
+def get_current_gpu_arch():
+    """Dynamically detect GPU architecture to set correct NVCC flags."""
+    if not torch.cuda.is_available():
+        return ["Volta"]  # Fallback
+
+    cap = torch.cuda.get_device_capability(0)
+    major, minor = cap
+
+    if major == 9: return ["Hopper"]
+    if major == 8 and minor == 9: return ["Ada"]
+    if major == 8: return ["Ampere"]
+    if major == 7 and minor == 5: return ["Turing"]
+    if major == 7: return ["Volta"]
+
+    return ["Ampere"]  # Default fallback for modern cards
+
+def main(program_path, results_dir, level, problem_id):
+    os.makedirs(results_dir, exist_ok=True)
+
+    if not os.path.exists(program_path):
+        write_results(results_dir, False, "File not found", {})
+        return
+
+    device = torch.device("cuda:0")
+
+    # FIX: Dynamic Architecture Detection
+    # Prevents "no kernel image is available" errors on H100/A100
+    current_arch = get_current_gpu_arch()
+    set_gpu_arch(current_arch)
+
+    # Load Reference
+    try:
+        dataset = construct_kernelbench_dataset(level)
+        ref_arch_src = read_file(dataset[problem_id - 1])
+    except Exception as e:
+        write_results(results_dir, False, f"Ref Load Error: {e}", {})
+        return
+
+    # Load Candidate
+    with open(program_path, 'r') as f:
+        custom_model_src = f.read()
+
+    metrics = {"combined_score": 0.0, "text_feedback": ""}
+
+    # FIX: Create a unique build directory for this specific evaluation run
+    # This prevents race conditions on the global torch_extensions lock file
+    # if parallel jobs are running.
+    jit_build_dir = os.path.join(results_dir, "jit_build")
+    os.makedirs(jit_build_dir, exist_ok=True)
+
+    try:
+        # Run Eval
+        result = eval_kernel_against_ref(
+            original_model_src=ref_arch_src,
+            custom_model_src=custom_model_src,
+            seed_num=42,
+            num_correct_trials=5,
+            num_perf_trials=100,
+            measure_performance=True,
+            timing_method="cuda_event",
+            device=device,
+            check_for_excessive_speedup=True,
+            excessive_speedup_threshold=50.0,
+            build_dir=jit_build_dir  # <--- Critical for concurrency
+        )
+
+        if not result.compiled:
+            msg = result.metadata.get('compilation_error', 'Unknown Error')
+            # Hint for the user if they messed up ordering
+            if "name 'matmul' is not defined" in msg or "NameError" in msg:
+                msg += "\n\nHINT: You must define your CUDA kernel variables BEFORE the class ModelNew uses them."
+
+            metrics["text_feedback"] = f"Compilation/Runtime Error:\n{msg}"
+            write_results(results_dir, False, "Compilation Failed", metrics)
+
+        elif not result.correctness:
+            metrics["text_feedback"] = f"Incorrect Output. Max Diff: {result.metadata.get('max_difference', 'N/A')}"
+            write_results(results_dir, False, "Incorrect", metrics)
+
+        else:
+            runtime = max(result.runtime, 1e-9)
+            ref_runtime = max(result.ref_runtime, 1e-9)
+            speedup = ref_runtime / runtime
+
+            metrics["combined_score"] = float(speedup)
+            metrics["public"] = {"speedup": float(speedup), "runtime_ms": float(runtime)}
+            metrics["text_feedback"] = f"Success! Speedup: {speedup:.2f}x"
+            write_results(results_dir, True, None, metrics)
+
+    except Exception as e:
+        metrics["text_feedback"] = f"Harness Error:\n{str(e)}"
+        write_results(results_dir, False, str(e), metrics)
+
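+# For reference, on a successful evaluation the two files written by
+# write_results() below look roughly like this (values are illustrative):
+#
+#   correct.json:  {"correct": true, "error": null}
+#   metrics.json:  {"combined_score": 1.83,
+#                   "public": {"speedup": 1.83, "runtime_ms": 0.42},
+#                   "text_feedback": "Success! Speedup: 1.83x"}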
+def write_results(results_dir, correct, error_msg, metrics):
+    with open(os.path.join(results_dir, "correct.json"), "w") as f:
+        json.dump({"correct": correct, "error": error_msg}, f, indent=4)
+    with open(os.path.join(results_dir, "metrics.json"), "w") as f:
+        json.dump(metrics, f, indent=4)
+    print(f"Eval Done. Correct: {correct}, Score: {metrics.get('combined_score', 0)}")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--program_path", type=str, required=True)
+    parser.add_argument("--results_dir", type=str, required=True)
+    parser.add_argument("--level", type=int, required=True)
+    parser.add_argument("--problem_id", type=int, required=True)
+    args = parser.parse_args()
+    main(args.program_path, args.results_dir, args.level, args.problem_id)
\ No newline at end of file
diff --git a/scripts/shinka_evolve/make_seed.py b/scripts/shinka_evolve/make_seed.py
new file mode 100644
index 00000000..1614a5b0
--- /dev/null
+++ b/scripts/shinka_evolve/make_seed.py
@@ -0,0 +1,78 @@
+import os
+import sys
+import re
+
+# Add repo root to path
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+from src.dataset import construct_kernelbench_dataset
+from src.utils import read_file
+
+# Template that ensures proper ordering
+TEMPLATE = """
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+# SHINKA_EVOLVE_TEMPLATE
+# This file is auto-generated to seed the evolution.
+
+# EVOLVE-BLOCK-START
+import torch
+import torch.nn as nn
+
+# --- INSERT CUSTOM CUDA KERNEL HERE ---
+# (Define source strings and call load_inline BEFORE the class definition)
+
+{content}
+
+def get_inputs():
+    # Helper to generate inputs on CUDA
+    return [x.cuda() for x in {get_inputs_name}()]
+
+def get_init_inputs():
+    return {get_init_inputs_name}()
+# EVOLVE-BLOCK-END
+"""
+
+def create_seed_file(level, problem_id, output_path):
+    dataset = construct_kernelbench_dataset(level)
+    # KernelBench uses 0-indexed list, problem_id is 1-indexed
+    ref_path = dataset[problem_id - 1]
+    ref_src = read_file(ref_path)
+
+    # 1. Rename Class
+    content = re.sub(r'class Model\s*\(', 'class ModelNew(', ref_src)
+
+    # 2. Fix super() call to be generic (handles name change)
+    content = re.sub(r'super\s*\([^\)]+\)\.__init__\(\)', 'super().__init__()', content)
+
+    # 3. Rename the original input helpers so the template wrappers can call them
+    # FIX: Use Regex for robust replacement of function definitions
+    # Handles "def get_inputs():" and "def get_inputs( ):" etc.
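+    # e.g. "def get_inputs():"      -> "def _original_get_inputs():"
+    #      "def get_init_inputs():" -> "def _original_get_init_inputs():"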
+    content = re.sub(
+        r"def\s+get_inputs\s*\(\s*\):",
+        "def _original_get_inputs():",
+        content
+    )
+    content = re.sub(
+        r"def\s+get_init_inputs\s*\(\s*\):",
+        "def _original_get_init_inputs():",
+        content
+    )
+
+    seed_content = TEMPLATE.format(
+        content=content,
+        get_inputs_name="_original_get_inputs",
+        get_init_inputs_name="_original_get_init_inputs"
+    )
+
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+    with open(output_path, "w") as f:
+        f.write(seed_content)
+
+    print(f"Created seed file at {output_path}")
+
+if __name__ == "__main__":
+    # Example usage
+    # create_seed_file(1, 1, "runs/debug_seed/initial.py")
+    pass
\ No newline at end of file
diff --git a/scripts/shinka_evolve/run_search.py b/scripts/shinka_evolve/run_search.py
new file mode 100644
index 00000000..93f115ce
--- /dev/null
+++ b/scripts/shinka_evolve/run_search.py
@@ -0,0 +1,119 @@
+import argparse
+import os
+import sys
+
+# Add repo root to path to import src
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+
+from shinka.core import EvolutionRunner, EvolutionConfig
+from shinka.database import DatabaseConfig
+from shinka.launch import LocalJobConfig
+from scripts.shinka_evolve.make_seed import create_seed_file
+from src.prompt_constructor_toml import get_prompt_for_backend
+
+def build_kernelbench_sys_msg(model_name):
+    # Get standard KernelBench prompt context
+    base_prompt = get_prompt_for_backend(
+        ref_arch_src="",
+        backend="cuda",
+        option="few_shot",
+        precision="fp32",
+        include_hardware=True,
+        gpu_name="H100"
+    )
+
+    # Remove empty placeholder
+    base_prompt = base_prompt.replace("You are given the following architecture:\n\n\n\n\n", "")
+
+    sys_msg = f"You are a world-class CUDA optimization expert.\n\n{base_prompt}"
+
+    # ADD CRITICAL ORDERING INSTRUCTIONS
+    sys_msg += """
+
+    # CRITICAL INSTRUCTIONS FOR SHINKA EVOLUTION
+
+    1. **ORDERING IS VITAL:** Python executes top-to-bottom.
+       - You MUST define your `cuda_source` string FIRST.
+       - You MUST define `cpp_source` SECOND.
+       - You MUST call `load_inline` THIRD (assigning it to a variable like `my_kernel`).
+       - You MUST define `class ModelNew` LAST.
+
+    2. **PLACEMENT:**
+       - Look for the comment `# --- INSERT CUSTOM CUDA KERNEL HERE ---`.
+       - Place your C++ string definitions and `load_inline` call there.
+
+    3. **CLASS USAGE:**
+       - Inside `ModelNew.__init__`, assign the global kernel object to `self` (e.g., `self.kernel = my_kernel`).
+       - Do NOT call `load_inline` inside the class methods.
+
+    4. **SYNTAX:**
+       - Use `super().__init__()` in the constructor.
+
+    5. **NO BOILERPLATE:**
+       - Do NOT include `if __name__ == "__main__":` blocks, argument parsing, or test runners.
+       - Provide ONLY the kernel definition and the `ModelNew` class.
+    """
+
+    return sys_msg
+
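+# For illustration only: the layout the instructions above ask the LLM to produce
+# looks roughly like this (hypothetical names, sketched here as a comment):
+#
+#   cuda_source = r"""__global__ void my_op_kernel(...) { ... }"""
+#   cpp_source = "torch::Tensor my_op(torch::Tensor x);"
+#   my_kernel = load_inline(name="my_op", cpp_sources=cpp_source,
+#                           cuda_sources=cuda_source, functions=["my_op"])
+#
+#   class ModelNew(nn.Module):
+#       def __init__(self):
+#           super().__init__()
+#           self.kernel = my_kernel
+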
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--level", type=int, default=1)
+    parser.add_argument("--problem_id", type=int, default=1)
+    parser.add_argument("--model", type=str, default="gpt-4o-2024-08-06")
+    parser.add_argument("--generations", type=int, default=10)
+    # New arguments for suite integration
+    parser.add_argument("--results_root", type=str, default="runs")
+    parser.add_argument("--max_parallel_jobs", type=int, default=1)
+    args = parser.parse_args()
+
+    # 1. Setup Workspace
+    # Organize by Level -> Problem
+    problem_slug = f"L{args.level}_P{args.problem_id}"
+    run_name = f"{problem_slug}_{args.model.replace('/', '_')}"
+
+    # If results_root is provided (from suite), nest it there
+    results_dir = os.path.join(args.results_root, run_name)
+    os.makedirs(results_dir, exist_ok=True)
+
+    print(f"Starting search for {problem_slug} in {results_dir}")
+
+    # 2. Generate Initial Program
+    init_path = os.path.join(results_dir, "initial.py")
+    if not os.path.exists(init_path):
+        create_seed_file(args.level, args.problem_id, init_path)
+
+    # 3. Configure Shinka
+    job_config = LocalJobConfig(
+        eval_program_path="scripts/shinka_evolve/evaluate_bridge.py",
+        extra_cmd_args={
+            "level": args.level,
+            "problem_id": args.problem_id
+        }
+    )
+
+    db_config = DatabaseConfig(
+        db_path="evolution.db",
+        parent_selection_strategy="weighted",
+        archive_size=20
+    )
+
+    evo_config = EvolutionConfig(
+        task_sys_msg=build_kernelbench_sys_msg(args.model),
+        num_generations=args.generations,
+        max_parallel_jobs=args.max_parallel_jobs,
+        init_program_path=init_path,
+        results_dir=results_dir,
+        llm_models=[args.model],
+        use_text_feedback=True,
+        # Use Full Rewrites only, because ordering mistakes are hard to fix with Diff patches
+        patch_types=["full"],
+        patch_type_probs=[1.0],
+        language="python"
+    )
+
+    runner = EvolutionRunner(evo_config, job_config, db_config)
+    runner.run()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/scripts/shinka_evolve/run_suite.py b/scripts/shinka_evolve/run_suite.py
new file mode 100644
index 00000000..4b22ec47
--- /dev/null
+++ b/scripts/shinka_evolve/run_suite.py
@@ -0,0 +1,129 @@
+import argparse
+import subprocess
+import sys
+import time
+import os
+import datetime
+from concurrent.futures import ThreadPoolExecutor
+
+# From src/dataset.py - Level 1 Representative Subset
+# These cover Matmul, Softmax, Norms, Convs, etc.
+LEVEL_1_SUBSET = [
+    23,  # Softmax (High chance of speedup via fusion)
+    26,  # GELU (High chance of speedup)
+    40,  # LayerNorm (High chance of speedup)
+    33,  # BatchNorm
+    1,   # Square Matmul (Hard)
+    3,   # Batched Matmul
+    6,   # Matmul Large K
+    18,  # Matmul Transposed
+    36,  # RMSNorm
+    42,  # Max Pool 2D
+    48,  # Mean Reduction
+    54,  # Conv 3D
+    57,  # Conv Transposed 2D
+    82,  # Conv Depthwise
+    # Add more level 1 problems if desired, or loop 1-100
+]
+
+LEVEL_2_SUBSET = [
+    1,   # Conv2D + ReLU + BiasAdd (Great candidate for fusion)
+    2,   # ConvTranspose + Bias + Scaling
+    18,  # Matmul + Sum + Max + AvgPool (Complex fusion)
+    28,  # BMM + InstanceNorm + Sum
+    33,  # Gemm + Scale + BatchNorm
+]
+
+def run_problem(gpu_id, problem_id, args, results_root):
+    print(f"[GPU {gpu_id}] Starting Level 2 Problem {problem_id}...")
+
+    log_file = os.path.join(results_root, f"log_P{problem_id}.txt")
+
+    cmd = [
+        sys.executable, "scripts/shinka_evolve/run_search.py",
+        "--level", "2",  # The suite currently targets LEVEL_2_SUBSET; level is hardcoded here
+        "--problem_id", str(problem_id),
+        "--model", args.model,
+        "--generations", str(args.generations),
+        "--results_root", results_root,
+        "--max_parallel_jobs", str(args.jobs_per_gpu)
+    ]
+
+    env = os.environ.copy()
+    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+
+    start_time = time.time()
+
+    with open(log_file, "w") as f:
+        try:
+            # We use check=True to raise an exception on failure
+            subprocess.run(
+                cmd,
+                env=env,
+                check=True,
+                stdout=f,
+                stderr=subprocess.STDOUT
+            )
+            status = "āœ… Success"
+        except subprocess.CalledProcessError:
+            status = "āŒ Failed"
+        except Exception as e:
+            status = f"šŸ’„ Error: {e}"
+
+    duration = time.time() - start_time
+    print(f"[GPU {gpu_id}] Finished P{problem_id}: {status} ({duration:.1f}s)")
+    return status
+
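+# For reference, the subprocess call above amounts to roughly this shell command
+# (illustrative values; CUDA_VISIBLE_DEVICES pins the child process to one GPU):
+#
+#   CUDA_VISIBLE_DEVICES=5 python scripts/shinka_evolve/run_search.py \
+#       --level 2 --problem_id 1 --model gpt-5-mini --generations 10 \
+#       --results_root runs/suite_lvl2_gpt-5-mini_<timestamp> --max_parallel_jobs 6
+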
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="gpt-5-mini")
+    parser.add_argument("--generations", type=int, default=10)
+    parser.add_argument("--gpus", type=str, default="5,6,7", help="Comma separated list of GPU IDs")
+    parser.add_argument("--jobs_per_gpu", type=int, default=6, help="Parallel evals inside Shinka per GPU")
+    args = parser.parse_args()
+
+    # Create a timestamped root directory for this suite run
+    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+    results_root = f"runs/suite_lvl2_{args.model.replace('/', '_')}_{timestamp}"
+    os.makedirs(results_root, exist_ok=True)
+
+    gpu_list = [int(x) for x in args.gpus.split(",")]
+
+    print(f"šŸš€ Starting ShinkaEvolve Suite")
+    print(f"šŸ¤– Model: {args.model}")
+    print(f"šŸ–„ļø GPUs: {gpu_list}")
+    print(f"šŸ“‚ Results: {results_root}")
+    print(f"šŸŽÆ Targets: {len(LEVEL_2_SUBSET)} problems")
+    print("-" * 50)
+
+    # Create a queue of problems.
+    # We use ThreadPoolExecutor: the worker threads just manage the subprocess calls.
+    # The actual heavy lifting is done by the OS scheduling the Python subprocesses onto the GPUs.
+
+    with ThreadPoolExecutor(max_workers=len(gpu_list)) as executor:
+        # We need to map problems to GPUs as they become free.
+        # This simple approach launches N futures where N = num_gpus.
+        # Each future pulls from a shared queue.
+
+        problem_queue = list(LEVEL_2_SUBSET)
+
+        def gpu_worker(gpu_id):
+            while problem_queue:
+                # Simple thread-safe pop (list.pop is atomic under the GIL)
+                try:
+                    pid = problem_queue.pop(0)
+                except IndexError:
+                    break
+                run_problem(gpu_id, pid, args, results_root)
+
+        # Launch one worker per GPU
+        futures = [executor.submit(gpu_worker, gpu) for gpu in gpu_list]
+
+        # Wait for all to finish
+        for f in futures:
+            f.result()
+
+    print(f"\nšŸ Suite completed. Check {results_root} for results.")
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
index f4fac580..6bb92838 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -153,7 +153,8 @@ def query_server(
     if "openai/" not in model_name.lower() and "gpt" not in model_name.lower():
         completion_kwargs["top_k"] = top_k
 
-    response = completion(**completion_kwargs)
+    # FIX: Added drop_params=True to handle gpt-5-mini/reasoning models that reject top_p
+    response = completion(**completion_kwargs, drop_params=True)
 
     # output processing
     if num_completions == 1: