diff --git a/EVAL.md b/EVAL.md index b9880560..faebda6f 100644 --- a/EVAL.md +++ b/EVAL.md @@ -36,6 +36,9 @@ We have (and continue to) implement various approaches to conduct kernel timing Check out `timing.py` to see available timing methods and `src/unit_tests/test_eval_timing.py` to test out various timing methods (including leveraging `cuda_event` marker, Triton `do_bench`, `host_time` E2E time). @palic and team is working on a blogpost explaining the different tradeoffs soon. +### Profiling +We have experimental profiling support leveraging NVIDIA NCU in `profile.py`. + ### Checkers There are potentially many ways model might reward hack and we would like to catch the known ways through checkers [experimental and WIP]. We start with `kernel_static_checker.py`, which is a regex-based checker on the genenrated code against set of rules. We plan to add AST-based, LM-as-a-judge, and more runtime checks in the future. We welcome suggestions and contributions here. diff --git a/README.md b/README.md index 3686574a..7343e73b 100644 --- a/README.md +++ b/README.md @@ -66,13 +66,15 @@ We organize the repo into the following structure: KernelBench/ ├── assets/ ├── KernelBench/ # Benchmark dataset files -├── src/ # KernelBench logic code +├── src/kernelbench/ # KernelBench logic code │ ├── unit_tests/ │ ├── prompts/ │ ├── .... ├── scripts/ # helpful scripts to run the benchmark ├── results/ # baseline times across hardware ├── runs/ # where your runs will be stored +├── notebooks/ # example notebooks for analysis +├── pyproject.toml # Project configuration and dependencies ``` ## 🔧 Set up diff --git a/pyproject.toml b/pyproject.toml index 85a690c5..bed37150 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ gpu = [ "nvidia-cutlass-dsl", "tilelang", "cupy-cuda12x", + "nsight-python", ] dev = [ "pytest", @@ -51,4 +52,7 @@ dev = [ [tool.setuptools.packages.find] where = ["src"] -include = ["kernelbench*"] \ No newline at end of file +include = ["kernelbench*"] + +[tool.setuptools.package-data] +kernelbench = ["prompts/**/*"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 8bf9a48f..07603a86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,6 +23,7 @@ ninja>=1.13.0 cupy-cuda12x==13.6.0 tomli>=2.3.0 tabulate>=0.9.0 +nsight-python # Numerics einops>=0.8.1 diff --git a/scripts/generate_baseline_time_modal.py b/scripts/generate_baseline_time_modal.py index 22bdb1b0..e9c8428e 100644 --- a/scripts/generate_baseline_time_modal.py +++ b/scripts/generate_baseline_time_modal.py @@ -6,7 +6,7 @@ fetch_ref_arch_from_problem_id, ) from kernelbench.timing import ( - time_execution_with_cuda_event, + get_timing_function, get_timing_stats, ) from kernelbench.dataset import construct_kernelbench_dataset, fetch_ref_arch_from_dataset @@ -134,6 +134,7 @@ def measure_program_time( ref_arch_name: str, ref_arch_src: str, num_trials: int = 100, + timing_method: str="cuda_event", use_torch_compile: bool = False, torch_compile_backend: str="inductor", torch_compile_options: str="default", @@ -173,9 +174,16 @@ def measure_program_time( print(f"Using PyTorch Eager Execution on {ref_arch_name}") model = model.cuda(device=device) + timing_func = get_timing_function(timing_method) torch.cuda.synchronize(device=device) - elapsed_times = time_execution_with_cuda_event( - model, inputs, num_trials=num_trials, verbose=verbose, device=device + elapsed_times = timing_func( + model, + inputs, + num_warmup=3, # or any default you prefer + num_trials=num_trials, + 
discard_first=1, # or 0 to include first trial + verbose=verbose, + device=device, ) runtime_stats = get_timing_stats(elapsed_times, device=device) @@ -220,6 +228,7 @@ def record_baseline_times(config: BaselineConfig, ref_arch_name=ref_arch_name, ref_arch_src=ref_arch_src, num_trials=config.num_trials, + timing_method="cuda_event", use_torch_compile=use_torch_compile, torch_compile_backend=torch_compile_backend, torch_compile_options=torch_compile_options, diff --git a/scripts/get_baseline_time_single_problem.py b/scripts/get_baseline_time_single_problem.py index 1613f4b0..91e1472f 100644 --- a/scripts/get_baseline_time_single_problem.py +++ b/scripts/get_baseline_time_single_problem.py @@ -2,16 +2,17 @@ import numpy as np from kernelbench.eval import ( load_original_model_and_inputs, - time_execution_with_cuda_event, - get_timing_stats, set_seed, fetch_ref_arch_from_problem_id, ) +from kernelbench.timing import get_timing_function, get_timing_stats + def measure_program_time( ref_arch_name: str, ref_arch_src: str, num_trials: int = 100, + timing_method: str="cuda_event", use_torch_compile: bool = False, torch_compile_backend: str="inductor", torch_compile_options: str="default", @@ -52,8 +53,9 @@ def measure_program_time( model = model.cuda(device=device) torch.cuda.synchronize(device=device) - elapsed_times = time_execution_with_cuda_event( - model, *inputs, num_trials=num_trials, verbose=verbose, device=device + timing_func = get_timing_function(timing_method) + elapsed_times = timing_func( + model, inputs, num_warmup=3, num_trials=num_trials, discard_first=1, verbose=verbose, device=device ) runtime_stats = get_timing_stats(elapsed_times, device=device) @@ -87,5 +89,4 @@ def get_inputs(): def get_init_inputs(): return [] # No special initialization inputs needed """ - print(measure_program_time(ref_arch_name, ref_arch_src, use_torch_compile=False)) - print(measure_program_time(ref_arch_name, ref_arch_src, use_torch_compile=True)) \ No newline at end of file + print(measure_program_time(ref_arch_name, ref_arch_src, use_torch_compile=False, timing_method="cuda_event")) \ No newline at end of file diff --git a/scripts/inspect_triton.py b/scripts/inspect_triton.py index 56fe6a23..3333dd91 100644 --- a/scripts/inspect_triton.py +++ b/scripts/inspect_triton.py @@ -161,7 +161,9 @@ def get_torch_compile_triton(level_num: int, problem_id: int) -> str: torch.cuda.synchronize(device=device) - elapsed_times = time_execution_with_cuda_event( + timing_method = "cuda_event" # use cuda event for timing here + time_func_cuda_event = get_timing_function(timing_method) + elapsed_times = time_func_cuda_event( model, *inputs, num_trials=1, verbose=False, device=device ) runtime_stats = get_timing_stats(elapsed_times, device=device) diff --git a/src/kernelbench/profile.py b/src/kernelbench/profile.py new file mode 100644 index 00000000..8326324e --- /dev/null +++ b/src/kernelbench/profile.py @@ -0,0 +1,438 @@ +""" +Nsight Profiling Module for KernelBench +======================================== + +This module provides GPU profiling capabilities using NVIDIA Nsight Compute (ncu). +It allows collecting hardware-level metrics from kernels. + +NOTE: this is an experimental module, not part of the default eval pass. +You need hardware counter access (this usually requires sudo). +We only support local mode with this feature; it is not available on Modal.
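Stepping back to the timing changes in the two scripts above: callers now resolve a timing backend by name instead of calling `time_execution_with_cuda_event` directly. A minimal sketch of that call pattern, assuming the `get_timing_function` / `get_timing_stats` signatures shown in this patch (`model` and `inputs` are placeholders for a loaded module and its input list):

```python
import torch
from kernelbench.timing import get_timing_function, get_timing_stats

def time_model(model, inputs, method: str = "cuda_event", device=torch.device("cuda:0")):
    # Methods referenced in this patch: "cuda_event", "host_time", "do_bench",
    # "do_bench_impl", and the new "nsight_python_time".
    timing_func = get_timing_function(method)
    elapsed_times = timing_func(
        model,
        inputs,
        num_warmup=3,
        num_trials=100,
        discard_first=1,
        verbose=False,
        device=device,
    )
    # "nsight_python_time" returns a single averaged entry rather than one entry
    # per trial, but it is still a list, so get_timing_stats works downstream.
    return get_timing_stats(elapsed_times, device=device)
```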
+ +Key Features: +- Profile arbitrary PyTorch functions with hardware metrics +- Profile KernelBench models (ModelNew) with automatic setup/teardown +- Combine metrics from multi-kernel operations (common in PyTorch) + +Requirements: +- NVIDIA Nsight Compute CLI (ncu) must be installed and in PATH +- nsight-python package + +Common Metrics: +- gpu__time_duration.sum: Total GPU time in nanoseconds +- sm__cycles_elapsed.sum: Total SM cycles elapsed +- sm__cycles_active.avg: Average active cycles per SM + +Reference: https://docs.nvidia.com/nsight-python +""" + +import os +from shutil import which + +import torch + +# ============================================================================= +# Nsight Availability Check +# ============================================================================= + +try: + import nsight + NSIGHT_AVAILABLE = True +except ImportError: + NSIGHT_AVAILABLE = False + + +# ============================================================================= +# Utility Functions +# ============================================================================= + +def check_ncu_available() -> bool: + """ + Check if NVIDIA Nsight Compute CLI (ncu) is available in PATH. + + The ncu command-line tool is required for collecting GPU hardware metrics. + It's typically installed with the CUDA Toolkit or NVIDIA Nsight Compute. + + Returns: + True if ncu is found in PATH, False otherwise. + """ + return which('ncu') is not None + + +# ============================================================================= +# Core Profiling Functions +# ============================================================================= + +def profile_with_nsight(func, metrics=None, num_trials=1): + """ + Profile a PyTorch function and collect hardware metrics. + + Handles complexity: + - Setting up the Nsight kernel analyzer + - Combining metrics when PyTorch ops launch multiple CUDA kernels + - Extracting results from Nsight's DataFrame format + + Args: + func: A callable (no arguments) that executes the code to profile. + Typically a closure that captures the model and inputs. + metrics: List of Nsight metric names to collect. If None, defaults to + ['sm__cycles_active.avg']. Can also pass a single string. + num_trials: Number of times to run the function for averaging. + + Returns: + Dictionary mapping metric names to their values (float). + Returns None for metrics that couldn't be collected. + + Example: + >>> def my_kernel(): + ... return torch.matmul(a, b) + >>> results = profile_with_nsight(my_kernel, ['gpu__time_duration.sum']) + >>> print(results['gpu__time_duration.sum']) # Time in nanoseconds + + Raises: + RuntimeError: If nsight-python is not installed. + """ + if not NSIGHT_AVAILABLE: + raise RuntimeError( + "nsight-python not available." + ) + + # Normalize metrics to a list + if metrics is None: + metrics = ['sm__cycles_active.avg'] + elif isinstance(metrics, str): + metrics = [metrics] + + # Define the profiled function with Nsight decorator + # NOTE: PyTorch operations often launch multiple CUDA kernels (e.g., a matmul + # might have separate kernels for the computation and memory operations). + # We use combine_kernel_metrics to sum these together for a single measurement. 
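As a side note on the reducer described in the comment above, here is a small standalone illustration (an editorial sketch of the same lambda passed as `combine_kernel_metrics`, not additional patch code): it folds per-kernel metric values pairwise, treating a missing value as zero.

```python
# Same reduction as the combine_kernel_metrics lambda, shown standalone.
def combine(a, b):
    return (0 if a is None else a) + (0 if b is None else b)

assert combine(None, 5.0) == 5.0  # a kernel that did not report the metric adds nothing
assert combine(2.0, 3.0) == 5.0   # values from the multiple kernels of one op are summed
```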
+ @nsight.analyze.kernel( + metrics=metrics, + runs=num_trials, + configs=[(0,)], # Use default GPU config + combine_kernel_metrics=lambda a, b: (0 if a is None else a) + (0 if b is None else b), + ) + def profiled(_): + # The nsight.annotate context marks the region we care about + with nsight.annotate("kernel"): + return func() + + try: + # Run profiling - this invokes ncu under the hood + result = profiled() + + # Convert results to DataFrame + df = result.to_dataframe() if result else None + if df is None or df.empty: + return {m: None for m in metrics} + + # Nsight returns a DataFrame with columns like: + # - 'Metric': The metric name (e.g., 'gpu__time_duration.sum') + # - 'AvgValue': The measured value + # We need to find these columns (names may vary slightly) + metric_col = next((c for c in df.columns if c.lower() == 'metric'), None) + value_col = next((c for c in df.columns if 'value' in c.lower()), None) + + if not metric_col or not value_col: + return {m: None for m in metrics} + + # Build a dictionary of all metrics in the DataFrame + metric_dict = { + row[metric_col]: float(row[value_col]) + for _, row in df.iterrows() + } + + # Return only the requested metrics (None if not found) + return {m: metric_dict.get(m) for m in metrics} + + except Exception as e: + print(f"Error profiling: {e}") + return {m: None for m in metrics} + + +def profile_kernelbench_model_with_nsight( + custom_model_src: str, + ref_model_src: str = None, + metrics: list = None, + num_trials: int = 1, + seed: int = 42, + device: torch.device = None, + backend: str = "cuda", + precision: torch.dtype = torch.float32, + build_dir: str = None, + verbose: bool = False, +) -> dict: + """ + Profile a KernelBench model (ModelNew) using Nsight hardware metrics. + + This is the high-level profiling function designed for KernelBench workflows. + It handles the full lifecycle: + 1. Load and compile the custom model from source code + 2. Generate inputs using the model's get_inputs() function + 3. Profile the forward pass with Nsight + 4. Clean up resources + + IMPORTANT: This function assumes the model has already been validated for + correctness via eval. No correctness checking is performed here. + + Args: + custom_model_src: Python source code string containing the ModelNew class. + ref_model_src: Optional source code for the reference model. Used to get + get_inputs() and get_init_inputs() if they're not in + custom_model_src. If None, uses custom_model_src. + metrics: List of Nsight metrics to collect. Defaults to ['sm__cycles_active.avg']. + num_trials: Number of profiling runs for averaging. Default: 1. + seed: Random seed for reproducible input generation. Default: 42. + device: CUDA device to run on. Default: cuda:0. + backend: Compilation backend ('cuda', 'triton', 'tilelang', 'cute'). + precision: torch.dtype for computation. Default: torch.float32. + build_dir: Directory for compiled kernel artifacts. Default: None. + verbose: Print progress messages. Default: False. + + Returns: + Dictionary mapping metric names to their measured values. + Values are None if the metric couldn't be collected. + + Example: + >>> results = profile_kernelbench_model_with_nsight( + ... custom_model_src=my_model_code, + ... ref_model_src=ref_model_code, + ... metrics=['gpu__time_duration.sum', 'sm__cycles_elapsed.sum'], + ... verbose=True + ... 
) + >>> print(f"GPU time: {results['gpu__time_duration.sum']} ns") + """ + # Import eval utilities (deferred to avoid circular imports) + from kernelbench.eval import ( + load_custom_model, + load_custom_model_with_tempfile, + load_original_model_and_inputs, + _process_input_tensor, + set_seed, + graceful_eval_cleanup, + ) + + # Set defaults + device = device or torch.device("cuda:0") + if metrics is None: + metrics = ['sm__cycles_active.avg'] + elif isinstance(metrics, str): + metrics = [metrics] + + torch.cuda.set_device(device) + + # ------------------------------------------------------------------------- + # Step 1: Load input generation functions from model source + # ------------------------------------------------------------------------- + # The model source should define get_inputs() and get_init_inputs() functions + # that return the tensors needed to run the model. + input_source = ref_model_src or custom_model_src + context = {} + _, get_init_inputs, get_inputs = load_original_model_and_inputs(input_source, context) + + # Generate initialization inputs (for model constructor) + set_seed(seed) + init_inputs = [ + _process_input_tensor(x, device, backend, precision) + for x in get_init_inputs() + ] + + # ------------------------------------------------------------------------- + # Step 2: Load and compile the custom model + # ------------------------------------------------------------------------- + if verbose: + print("[Profile] Loading and compiling custom model...") + + # Enable CUDA Device-Side Assertions for better error messages + os.environ["TORCH_USE_CUDA_DSA"] = "1" + tempfile = None + + # Different backends require different loading mechanisms + if backend.lower() in ["triton", "tilelang", "cute"]: + # These backends need a temp file for proper module loading + ModelNew, tempfile = load_custom_model_with_tempfile( + custom_model_src, entry_point="ModelNew" + ) + else: + # Standard CUDA backend + ModelNew = load_custom_model(custom_model_src, {}, build_dir) + + torch.cuda.synchronize(device=device) + + # ------------------------------------------------------------------------- + # Step 3: Instantiate the model + # ------------------------------------------------------------------------- + with torch.no_grad(): + set_seed(seed) + custom_model = ModelNew(*init_inputs) + custom_model = custom_model.to(device=device, dtype=precision) + torch.cuda.synchronize(device=device) + + if verbose: + print("[Profile] Model instantiated successfully") + + # ------------------------------------------------------------------------- + # Step 4: Profile the forward pass + # ------------------------------------------------------------------------- + # Generate forward pass inputs + set_seed(seed) + inputs = [ + _process_input_tensor(x, device, backend, precision) + for x in get_inputs() + ] + + if verbose: + print(f"[Profile] Profiling with nsight (metrics: {metrics})...") + + # Create a closure for the forward pass + def model_forward(): + with torch.no_grad(): + return custom_model(*inputs) + + # Run profiling + metric_values = profile_with_nsight( + model_forward, + metrics=metrics, + num_trials=num_trials + ) + + if verbose: + print("[Profile] Profiling completed successfully") + + # ------------------------------------------------------------------------- + # Step 5: Cleanup + # ------------------------------------------------------------------------- + graceful_eval_cleanup(context, device, tempfile) + + return metric_values + + +# 
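For callers of the function above, a hedged usage sketch; it mirrors the availability guard used later in `test_eval_timing.py`, and `custom_src` / `ref_src` are placeholders for source strings read elsewhere (for example with `read_file`):

```python
import torch
from kernelbench.profile import (
    NSIGHT_AVAILABLE,
    check_ncu_available,
    profile_kernelbench_model_with_nsight,
)

if NSIGHT_AVAILABLE and check_ncu_available() and torch.cuda.is_available():
    results = profile_kernelbench_model_with_nsight(
        custom_model_src=custom_src,   # placeholder: generated ModelNew source
        ref_model_src=ref_src,         # placeholder: reference Model source
        metrics=["gpu__time_duration.sum", "sm__cycles_elapsed.sum"],
        num_trials=1,
        backend="cuda",
        precision=torch.float32,
        verbose=True,
    )
    print(results)  # maps metric name -> value (None if it could not be collected)
else:
    print("Skipping Nsight profiling: ncu and/or nsight-python not available")
```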
============================================================================= +# Examples and Tests +# ============================================================================= + +def example_ncu_python_profile(): + """ + Simple example demonstrating how to profile a basic matrix multiplication. + + This shows the minimal setup needed to use profile_with_nsight(). + """ + print("Creating test tensors...") + a = torch.randn(256, 256, device="cuda") + b = torch.randn(256, 256, device="cuda") + + # Create a closure that captures the tensors + def matmul_kernel(): + return a @ b + + print("Running nsight profiling...") + + metric_values = profile_with_nsight( + matmul_kernel, + metrics=[ + 'sm__cycles_active.avg', # Average active cycles per SM + 'sm__cycles_elapsed.sum', # Total cycles elapsed + 'smsp__inst_executed_pipe_tensor_op_hmma.sum', # Tensor core ops + ], + num_trials=1, + ) + + print("\nProfiling results:") + for metric_name, value in metric_values.items(): + print(f" {metric_name}: {value}") + + +def test_flash_attention_profile(): + """ + Test profiling a Flash Attention model from the KernelBench examples. + + This demonstrates the full workflow of profiling a KernelBench model + using profile_kernelbench_model_with_nsight(). + """ + from kernelbench.utils import read_file + + # Locate the example model files + REPO_ROOT = os.path.dirname(__file__) + ref_model_path = os.path.join( + REPO_ROOT, "prompts/few_shot/model_ex_flash_attn.py" + ) + custom_model_path = os.path.join( + REPO_ROOT, "prompts/few_shot/model_new_ex_flash_attn.py" + ) + + print("[Test] Reading model source files...") + ref_model_src = read_file(ref_model_path) + custom_model_src = read_file(custom_model_path) + + print("[Test] Starting profiling with nsight...") + + metrics = profile_kernelbench_model_with_nsight( + custom_model_src=custom_model_src, + ref_model_src=ref_model_src, + metrics=[ + 'gpu__time_duration.sum', # Total GPU execution time (ns) + 'sm__cycles_elapsed.sum', # Total SM cycles + ], + seed=42, + backend="cuda", + precision=torch.float32, + verbose=True + ) + + print("\n[Test] Profiling results:") + print("=" * 60) + for metric_name, value in metrics.items(): + if value is not None: + print(f" {metric_name}: {value:,.0f}") + else: + print(f" {metric_name}: ") + print("=" * 60) + + return metrics + + +# Optional: Decorated benchmark function for direct use with nsight +if NSIGHT_AVAILABLE: + @nsight.analyze.kernel + def benchmark_matmul(n): + """ + Standard benchmark following nsight-python documentation style. + + This shows how to use the @nsight.analyze.kernel decorator directly + for simple benchmarking scenarios. 
+ """ + a = torch.randn(n, n, device="cuda") + b = torch.randn(n, n, device="cuda") + with nsight.annotate("matmul"): + c = a @ b + return c + + +# ============================================================================= +# Main Entry Point +# ============================================================================= + +if __name__ == "__main__": + # Verify prerequisites + if not check_ncu_available(): + print("ERROR: ncu not found in PATH.") + print("Install NVIDIA Nsight Compute from:") + print(" https://developer.nvidia.com/nsight-compute") + exit(1) + + if not torch.cuda.is_available(): + print("ERROR: CUDA is not available.") + exit(1) + + print("=" * 60) + print("Running Nsight Profiling Examples") + print("=" * 60) + + # Run the simple example first + print("\n--- Example: Basic Matrix Multiplication ---\n") + example_ncu_python_profile() + + # Run the full KernelBench model test + print("\n--- Test: Flash Attention Model ---\n") + test_flash_attention_profile() diff --git a/src/kernelbench/prompt_constructor_toml.py b/src/kernelbench/prompt_constructor_toml.py index 7b8b1bdf..4349a74d 100644 --- a/src/kernelbench/prompt_constructor_toml.py +++ b/src/kernelbench/prompt_constructor_toml.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional -from kernelbench.utils import read_file +from kernelbench.utils import read_file, get_package_resource_path, resolve_path, REPO_TOP_PATH """ TOML-based prompt constructor for managing prompt templates and configurations. @@ -14,11 +14,9 @@ You can easily check some of the prompt templates we have provided and create your own. """ -REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) -PROMPTS_TOML = os.path.join(REPO_TOP_PATH, "src/kernelbench/prompts/prompts.toml") - -assert os.path.exists(PROMPTS_TOML), f"Prompts.toml not found at {PROMPTS_TOML}" -GPU_SPECS_PY = "src/kernelbench/prompts/hardware/gpu_specs.py" +# Resolve paths using the helper from utils +PROMPTS_TOML = get_package_resource_path("prompts/prompts.toml") +GPU_SPECS_PY = get_package_resource_path("prompts/hardware/gpu_specs.py") HARDWARE_COMPONENT_KEYS = [ "hardware_header", "hardware_specs", @@ -26,12 +24,6 @@ "hardware_best_practices", ] -def _abs_path(rel: str) -> str: - """Convert relative path to absolute path from repo root.""" - if os.path.isabs(rel): - return rel - return os.path.join(REPO_TOP_PATH, rel) - @dataclass class PromptConfig: """ @@ -255,17 +247,17 @@ def render_example_entry(input_code: str, output_code: str, example_label: str) # Use multiple examples (true few-shot) examples_intro = intro_few_shot for i, (input_path, output_path) in enumerate(few_shot_examples, 1): - input_code = read_file(_abs_path(input_path)) - output_code = read_file(_abs_path(output_path)) + input_code = read_file(resolve_path(input_path)) + output_code = read_file(resolve_path(output_path)) examples_entries.append( render_example_entry(input_code, output_code, f"Example {i}:") ) else: # Fall back to one-shot - ex_arch_path = _abs_path( + ex_arch_path = resolve_path( backend_data.get("few_shot_example_arch") or shared.get("few_shot_example_arch") ) - ex_new_path = _abs_path(backend_data["one_shot_new_arch"]) + ex_new_path = resolve_path(backend_data["one_shot_new_arch"]) input_code = read_file(ex_arch_path) output_code = read_file(ex_new_path) examples_entries.append( @@ -274,10 +266,10 @@ def render_example_entry(input_code: str, output_code: str, example_label: str) elif requires_example == "one_shot": # 
Always use one-shot - ex_arch_path = _abs_path( + ex_arch_path = resolve_path( backend_data.get("few_shot_example_arch") or shared.get("few_shot_example_arch") ) - ex_new_path = _abs_path(backend_data["one_shot_new_arch"]) + ex_new_path = resolve_path(backend_data["one_shot_new_arch"]) input_code = read_file(ex_arch_path) output_code = read_file(ex_new_path) examples_entries.append( @@ -296,7 +288,7 @@ def render_example_entry(input_code: str, output_code: str, example_label: str) raise ValueError( f"Hardware info requested for option '{option}'; provide gpu_specs_py and gpu_name" ) - context = {**context, **_gpu_context_from_gpu_specs(_abs_path(gpu_specs_py), gpu_name)} + context = {**context, **_gpu_context_from_gpu_specs(resolve_path(gpu_specs_py), gpu_name)} # Builds the prompt from the components in the toml file. prompt_parts = [] @@ -416,10 +408,10 @@ def test_prompt(): generation. Customize the reference architecture or custom_prompt_key if you want to try different inputs. """ - REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) ref_arch_src = read_file(os.path.join(REPO_TOP_PATH, "KernelBench", "level1", "1_Square_matrix_multiplication_.py")) assert len(ref_arch_src) > 0, "ref_arch_src is empty" - + + print("Testing prompt construction...") scratch_dir = os.path.join(REPO_TOP_PATH, "scratch") # baseline prompt baseline_prompt = get_prompt_for_backend( diff --git a/src/kernelbench/timing.py b/src/kernelbench/timing.py index 6b0db010..83908391 100644 --- a/src/kernelbench/timing.py +++ b/src/kernelbench/timing.py @@ -77,6 +77,8 @@ def get_timing_function( return time_execution_with_do_bench_impl case "host_time": return time_execution_with_host_time + case "nsight_python_time": + return time_execution_with_nsight_python # we might add other methods in the future case _: raise ValueError(f"Unsupported timing method: {method}") @@ -86,7 +88,6 @@ def get_timing_function( NOTE: we have a WIP blogpost on this topic covering the various timing approaches """ - def time_execution_with_cuda_event( kernel_fn: callable, args: list[Any], @@ -388,6 +389,74 @@ def time_execution_with_host_time( return elapsed_times +def time_execution_with_nsight_python( + kernel_fn: callable, + args: list[Any], + num_warmup: int = 3, + num_trials: int = 10, + discard_first: int = 1, # not used here + verbose: bool = True, + device: torch.device | None = None) -> list[float]: + """ + Time a CUDA kernel function using nsight-python. + + Note: nsight returns an average time across num_trials runs. + Returns a list with a single value (average time) for API consistency. + GPU time from nsight is in nanoseconds, converted to milliseconds. 
+ + Returns: + List containing one float: average elapsed time in milliseconds + """ + + from kernelbench.profile import profile_with_nsight + + if device is None: + if verbose: + print(f"Using current device: {torch.cuda.current_device()}") + device = torch.cuda.current_device() + + with torch.cuda.device(device): + # Warm ups + for _ in range(num_warmup): + kernel_fn(*args) + torch.cuda.synchronize(device=device) + + # Clear cache for cold start + torch.cuda.empty_cache() + clear_l2_cache(device=device) + + if verbose: + print(f"[Profiling] Using device: {device} {torch.cuda.get_device_name(device)}, warm up {num_warmup}, trials {num_trials}") + + # Profile with nsight - returns average time in nanoseconds + # Wrap kernel function + def wrapped_kernel(): + return kernel_fn(*args) + + # Profile with nsight, use gpu_time_duration.sum metric for GPU time + metric_values = profile_with_nsight( + wrapped_kernel, + metrics=["gpu__time_duration.sum"], + num_trials=num_trials + ) + + # Convert from nanoseconds to milliseconds + gpu_time_ns = metric_values.get("gpu__time_duration.sum") + if gpu_time_ns is None: + raise RuntimeError("Failed to get GPU time from nsight") + + # Convert nanoseconds to milliseconds + # nsight returns average across num_trials, so we return a single value in a list + gpu_time_ms = gpu_time_ns / 1_000_000.0 + + if verbose: + print(f"Average GPU time: {gpu_time_ms:.3f} ms (across {num_trials} trials)") + + # NOTE: nsight only returns average time across num_trials, so we return a single value in a list + # it did run num_trials times, but we only return the average (1 item) + # Return list with single average value for API consistency + return [gpu_time_ms] + ######################################################## # Timing stats # tools to help compute speedup and other time @@ -439,3 +508,4 @@ def get_timing_stats(elapsed_times: list[float], device: torch.device = None) -> stats["device"] = str(device) # for debugging return stats + diff --git a/src/kernelbench/unit_tests/test_eval_timing.py b/src/kernelbench/unit_tests/test_eval_timing.py index 8be8824b..470c4254 100644 --- a/src/kernelbench/unit_tests/test_eval_timing.py +++ b/src/kernelbench/unit_tests/test_eval_timing.py @@ -5,12 +5,13 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from kernelbench import timing +from kernelbench.profile import NSIGHT_AVAILABLE, check_ncu_available """ Test Timing We want to systematically study different timing methodologies. """ -REPO_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +REPO_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) # use exampls in the few shot directory EXAMPLES_PATH = os.path.join(REPO_PATH, "src", "kernelbench", "prompts", "few_shot") @@ -78,14 +79,60 @@ def matmul_kernel(a, b): # test all currently available timing methods def run_all_timing_tests(device="cuda"): timing_methods = ["cuda_event", "host_time", "do_bench", "do_bench_impl"] - # timing_methods = ["cuda_event", "do_bench_impl"] for timing_method in timing_methods: _run_timing_smoke_test_matmul(timing_method, device=device) +def run_nsight_timing_test(device="cuda"): + """ + Run nsight-python timing test if available. + + Nsight requires: + - nsight-python package installed + - ncu (Nsight Compute CLI) in PATH + + Compares nsight GPU time against cuda_event timing for the same matmul operation. 
+ """ + if not NSIGHT_AVAILABLE: + print("[SKIP] nsight-python not installed") + return None + + if not check_ncu_available(): + print("[SKIP] ncu not found in PATH") + return None + + print("\n" + "=" * 60) + print("Running nsight-python timing benchmark") + print("=" * 60) + + # Run nsight timing + print("\n--- nsight_python_time ---") + _run_timing_smoke_test_matmul("nsight_python_time", device=device) + + # Run cuda_event for comparison + print("\n--- cuda_event (for comparison) ---") + _run_timing_smoke_test_matmul("cuda_event", device=device) + + print("\n" + "=" * 60) + print("nsight-python timing benchmark complete") + print("=" * 60) + + +def run_all_timing_tests_with_nsight(device="cuda"): + """ + Run all timing methods including nsight-python if available. + """ + # Standard timing methods (always available) + run_all_timing_tests(device=device) + + # Nsight timing (requires ncu + nsight-python) + run_nsight_timing_test(device=device) + +# select a free GPU here or set CUDA_VISIBLE_DEVICES test_device = torch.device("cuda:5") run_all_timing_tests(test_device) +run_nsight_timing_test(test_device) diff --git a/src/kernelbench/utils.py b/src/kernelbench/utils.py index abe067b1..cf8b0ad8 100644 --- a/src/kernelbench/utils.py +++ b/src/kernelbench/utils.py @@ -15,6 +15,7 @@ import os import json from tqdm import tqdm +from importlib.resources import files, as_file # API clients from openai import OpenAI @@ -273,6 +274,59 @@ def read_file(file_path) -> str: return "" +######################################################## +# Path Resolution Helpers +######################################################## +REPO_TOP_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) + +def get_package_resource_path(relative_path: str) -> str: + """ + Get absolute path to a kernelbench package resource. + Works for all three usage modes: + - Running from repo directly + - As a git submodule + - As an installed pip/uv dependency + + Args: + relative_path: Path relative to kernelbench/, e.g. "prompts/prompts.toml" + """ + # Try importlib.resources first (installed package) + try: + resource = files("kernelbench").joinpath(relative_path) + with as_file(resource) as path: + if path.exists(): + return str(path) + except (TypeError, FileNotFoundError): + pass + + # Try repo path (running from source / submodule) + repo_path = os.path.join(REPO_TOP_PATH, "src/kernelbench", relative_path) + if os.path.exists(repo_path): + return repo_path + + raise FileNotFoundError(f"Could not find resource: {relative_path}") + + +def resolve_path(rel: str) -> str: + """ + Resolve a relative path to absolute. Handles paths like "src/kernelbench/prompts/..." + from prompts.toml which reference files relative to repo root. + """ + if os.path.isabs(rel): + return rel + + # Convert "src/kernelbench/..." paths to package-relative + if rel.startswith("src/kernelbench/"): + return get_package_resource_path(rel[len("src/kernelbench/"):]) + + # Otherwise treat as repo-relative + repo_path = os.path.join(REPO_TOP_PATH, rel) + if os.path.exists(repo_path): + return repo_path + + raise FileNotFoundError(f"Could not resolve path: {rel}") + + def print_messages(messages): for message in messages: print(message["role"])
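Finally, a short sketch of how the two new path helpers in `utils.py` are intended to be called, mirroring the call sites in `prompt_constructor_toml.py` above (the preview length is arbitrary):

```python
from kernelbench.utils import get_package_resource_path, resolve_path, read_file

# Package-relative lookup: works from a source checkout, a git submodule,
# or an installed package (importlib.resources first, repo path as fallback).
prompts_toml = get_package_resource_path("prompts/prompts.toml")

# resolve_path() accepts the "src/kernelbench/..." style paths stored in
# prompts.toml and maps them onto the same package-resource lookup.
gpu_specs_py = resolve_path("src/kernelbench/prompts/hardware/gpu_specs.py")

print(read_file(prompts_toml)[:200])  # arbitrary preview length
print(gpu_specs_py)
```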