122 changes: 122 additions & 0 deletions scripts/shinka_evolve/evaluate_bridge.py
@@ -0,0 +1,122 @@
import argparse
import sys
import os
import torch
import json
import traceback

# Add repo root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from src.dataset import construct_kernelbench_dataset
from src.eval import eval_kernel_against_ref
from src.utils import read_file, set_gpu_arch

def get_current_gpu_arch():
"""Dynamically detect GPU architecture to set correct NVCC flags."""
if not torch.cuda.is_available():
return ["Volta"] # Fallback

cap = torch.cuda.get_device_capability(0)
major, minor = cap

if major == 9: return ["Hopper"]
if major == 8 and minor == 9: return ["Ada"]
if major == 8: return ["Ampere"]
if major == 7 and minor == 5: return ["Turing"]
if major == 7: return ["Volta"]

return ["Ampere"] # Default fallback for modern cards

def main(program_path, results_dir, level, problem_id):
os.makedirs(results_dir, exist_ok=True)

if not os.path.exists(program_path):
write_results(results_dir, False, "File not found", {})
return

device = torch.device("cuda:0")

# FIX: Dynamic Architecture Detection
# Prevents "no kernel image is available" errors on H100/A100
current_arch = get_current_gpu_arch()
set_gpu_arch(current_arch)

# Load Reference
try:
dataset = construct_kernelbench_dataset(level)
ref_arch_src = read_file(dataset[problem_id - 1])
except Exception as e:
write_results(results_dir, False, f"Ref Load Error: {e}", {})
return

# Load Candidate
with open(program_path, 'r') as f:
custom_model_src = f.read()

metrics = {"combined_score": 0.0, "text_feedback": ""}

# FIX: Create a unique build directory for this specific evaluation run
# This prevents race conditions on the global torch_extensions lock file
# if parallel jobs are running.
jit_build_dir = os.path.join(results_dir, "jit_build")
os.makedirs(jit_build_dir, exist_ok=True)

try:
# Run Eval
result = eval_kernel_against_ref(
original_model_src=ref_arch_src,
custom_model_src=custom_model_src,
seed_num=42,
num_correct_trials=5,
num_perf_trials=100,
measure_performance=True,
timing_method="cuda_event",
device=device,
check_for_excessive_speedup=True,
excessive_speedup_threshold=50.0,
build_dir=jit_build_dir # <--- Critical for concurrency
)

if not result.compiled:
msg = result.metadata.get('compilation_error', 'Unknown Error')
# Hint for the user if they messed up ordering
if "name 'matmul' is not defined" in msg or "NameError" in msg:
msg += "\n\nHINT: You must define your CUDA kernel variables BEFORE the class ModelNew uses them."

metrics["text_feedback"] = f"Compilation/Runtime Error:\n{msg}"
write_results(results_dir, False, "Compilation Failed", metrics)

elif not result.correctness:
metrics["text_feedback"] = f"Incorrect Output. Max Diff: {result.metadata.get('max_difference', 'N/A')}"
write_results(results_dir, False, "Incorrect", metrics)

else:
runtime = max(result.runtime, 1e-9)
ref_runtime = max(result.ref_runtime, 1e-9)
speedup = ref_runtime / runtime
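            # speedup > 1.0 means the candidate kernel is faster than the reference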

metrics["combined_score"] = float(speedup)
metrics["public"] = {"speedup": float(speedup), "runtime_ms": float(runtime)}
metrics["text_feedback"] = f"Success! Speedup: {speedup:.2f}x"
write_results(results_dir, True, None, metrics)

    except Exception as e:
        # Report the full traceback so the search loop gets actionable feedback
        metrics["text_feedback"] = f"Harness Error:\n{traceback.format_exc()}"
        write_results(results_dir, False, str(e), metrics)

def write_results(results_dir, correct, error_msg, metrics):
with open(os.path.join(results_dir, "correct.json"), "w") as f:
json.dump({"correct": correct, "error": error_msg}, f, indent=4)
with open(os.path.join(results_dir, "metrics.json"), "w") as f:
json.dump(metrics, f, indent=4)
print(f"Eval Done. Correct: {correct}, Score: {metrics.get('combined_score', 0)}")

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--program_path", type=str, required=True)
parser.add_argument("--results_dir", type=str, required=True)
parser.add_argument("--level", type=int, required=True)
parser.add_argument("--problem_id", type=int, required=True)
args = parser.parse_args()
main(args.program_path, args.results_dir, args.level, args.problem_id)
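# Example invocation (hypothetical paths), run from the repo root:
#   python scripts/shinka_evolve/evaluate_bridge.py \
#       --program_path runs/debug/candidate.py \
#       --results_dir runs/debug/eval_0 \
#       --level 1 --problem_id 1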
78 changes: 78 additions & 0 deletions scripts/shinka_evolve/make_seed.py
@@ -0,0 +1,78 @@
import os
import sys
import re

# Add repo root to path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from src.dataset import construct_kernelbench_dataset
from src.utils import read_file

# Template that ensures proper ordering
TEMPLATE = """
import torch
import torch.nn as nn
from torch.utils.cpp_extension import load_inline

# SHINKA_EVOLVE_TEMPLATE
# This file is auto-generated to seed the evolution.

# EVOLVE-BLOCK-START
import torch
import torch.nn as nn

# --- INSERT CUSTOM CUDA KERNEL HERE ---
# (Define source strings and call load_inline BEFORE the class definition)

{content}

def get_inputs():
# Helper to generate inputs on CUDA
return [x.cuda() for x in {get_inputs_name}()]

def get_init_inputs():
return {get_init_inputs_name}()
# EVOLVE-BLOCK-END
"""

def create_seed_file(level, problem_id, output_path):
dataset = construct_kernelbench_dataset(level)
    # KernelBench's dataset list is 0-indexed; problem_id is 1-indexed
ref_path = dataset[problem_id - 1]
ref_src = read_file(ref_path)

# 1. Rename Class
content = re.sub(r'class Model\s*\(', 'class ModelNew(', ref_src)
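    # e.g. "class Model(nn.Module):" -> "class ModelNew(nn.Module):"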

# 2. Fix super() call to be generic (handles name change)
content = re.sub(r'super\s*\([^\)]+\)\.__init__\(\)', 'super().__init__()', content)
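    # e.g. "super(Model, self).__init__()" -> "super().__init__()"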

    # 3. Rename the original helper functions so the TEMPLATE wrappers can call them
    # FIX: use regex for robust matching of the function definitions
    # (handles "def get_inputs():" and "def get_inputs( ):" etc.)
content = re.sub(
r"def\s+get_inputs\s*\(\s*\):",
"def _original_get_inputs():",
content
)
content = re.sub(
r"def\s+get_init_inputs\s*\(\s*\):",
"def _original_get_init_inputs():",
content
)

seed_content = TEMPLATE.format(
content=content,
get_inputs_name="_original_get_inputs",
get_init_inputs_name="_original_get_init_inputs"
)

os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w") as f:
f.write(seed_content)

print(f"Created seed file at {output_path}")

if __name__ == "__main__":
# Example usage
# create_seed_file(1, 1, "runs/debug_seed/initial.py")
pass
119 changes: 119 additions & 0 deletions scripts/shinka_evolve/run_search.py
@@ -0,0 +1,119 @@
import argparse
import os
import sys

# Add repo root to path to import src
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from shinka.core import EvolutionRunner, EvolutionConfig
from shinka.database import DatabaseConfig
from shinka.launch import LocalJobConfig
from scripts.shinka_evolve.make_seed import create_seed_file
from src.prompt_constructor_toml import get_prompt_for_backend

def build_kernelbench_sys_msg():
# Get standard KernelBench prompt context
base_prompt = get_prompt_for_backend(
ref_arch_src="",
backend="cuda",
option="few_shot",
precision="fp32",
include_hardware=True,
gpu_name="H100"
)

# Remove empty placeholder
base_prompt = base_prompt.replace("You are given the following architecture:\n\n\n\n\n", "")

sys_msg = f"You are a world-class CUDA optimization expert.\n\n{base_prompt}"

# ADD CRITICAL ORDERING INSTRUCTIONS
sys_msg += """

# CRITICAL INSTRUCTIONS FOR SHINKA EVOLUTION

1. **ORDERING IS VITAL:** Python executes top-to-bottom.
- You MUST define your `cuda_source` string FIRST.
- You MUST define `cpp_source` SECOND.
- You MUST call `load_inline` THIRD (assigning it to a variable like `my_kernel`).
- You MUST define `class ModelNew` LAST.

2. **PLACEMENT:**
- Look for the comment `# --- INSERT CUSTOM CUDA KERNEL HERE ---`.
- Place your C++ string definitions and `load_inline` call there.

3. **CLASS USAGE:**
- Inside `ModelNew.__init__`, assign the global kernel object to `self` (e.g., `self.kernel = my_kernel`).
- Do NOT define `load_inline` inside the class methods.

4. **SYNTAX:**
- Use `super().__init__()` in the constructor.

5. **NO BOILERPLATE:**
- Do NOT include `if __name__ == "__main__":` blocks, argument parsing, or test runners.
- Provide ONLY the kernel definition and the `ModelNew` class.
"""

return sys_msg
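# A minimal sketch (hypothetical names, placeholder kernel body) of the layout
# these instructions ask the LLM to produce; not a working kernel:
#
#   cuda_source = r"""
#   #include <torch/extension.h>
#   __global__ void my_kernel(/* ... */) { /* ... */ }
#   torch::Tensor my_op(torch::Tensor x) { /* allocate output, launch my_kernel */ }
#   """
#   cpp_source = "torch::Tensor my_op(torch::Tensor x);"
#   my_kernel = load_inline(name="my_kernel", cpp_sources=cpp_source,
#                           cuda_sources=cuda_source, functions=["my_op"])
#
#   class ModelNew(nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.kernel = my_kernel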

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--level", type=int, default=1)
parser.add_argument("--problem_id", type=int, default=1)
parser.add_argument("--model", type=str, default="gpt-4o-2024-08-06")
parser.add_argument("--generations", type=int, default=10)
# New arguments for suite integration
parser.add_argument("--results_root", type=str, default="runs")
parser.add_argument("--max_parallel_jobs", type=int, default=1)
args = parser.parse_args()

# 1. Setup Workspace
# Organize by Level -> Problem
problem_slug = f"L{args.level}_P{args.problem_id}"
run_name = f"{problem_slug}_{args.model.replace('/', '_')}"

# If results_root is provided (from suite), nest it there
results_dir = os.path.join(args.results_root, run_name)
os.makedirs(results_dir, exist_ok=True)

print(f"Starting search for {problem_slug} in {results_dir}")

# 2. Generate Initial Program
init_path = os.path.join(results_dir, "initial.py")
if not os.path.exists(init_path):
create_seed_file(args.level, args.problem_id, init_path)

# 3. Configure Shinka
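    # extra_cmd_args are presumably forwarded to the eval program as
    # --level / --problem_id flags, matching evaluate_bridge.py's argparse.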
job_config = LocalJobConfig(
eval_program_path="scripts/shinka_evolve/evaluate_bridge.py",
extra_cmd_args={
"level": args.level,
"problem_id": args.problem_id
}
)

db_config = DatabaseConfig(
db_path="evolution.db",
parent_selection_strategy="weighted",
archive_size=20
)

evo_config = EvolutionConfig(
        task_sys_msg=build_kernelbench_sys_msg(),
num_generations=args.generations,
max_parallel_jobs=args.max_parallel_jobs,
init_program_path=init_path,
results_dir=results_dir,
llm_models=[args.model],
use_text_feedback=True,
        # Use full rewrites only: ordering mistakes are hard to repair with diff patches
patch_types=["full"],
patch_type_probs=[1.0],
language="python"
)

runner = EvolutionRunner(evo_config, job_config, db_config)
runner.run()

if __name__ == "__main__":
main()
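# Example invocation, using the defaults shown above:
#   python scripts/shinka_evolve/run_search.py --level 1 --problem_id 1 \
#       --model gpt-4o-2024-08-06 --generations 10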