diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py new file mode 100644 index 0000000..34e8023 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Diagnose Prompt Module for Hardware Bottleneck Analysis. + +This module provides prompt building utilities for the Judge LLM that +analyzes NCU profiling metrics to identify performance bottlenecks. +""" + +from .gpu_specs import get_gpu_specs +from .judger_prompts import ( + build_judge_optimization_prompt, + extract_judge_response, + validate_judge_response, +) + +__all__ = [ + "get_gpu_specs", + "build_judge_optimization_prompt", + "extract_judge_response", + "validate_judge_response", +] diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py new file mode 100644 index 0000000..e465116 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GPU Specifications Database for Bottleneck Analysis + +This module provides GPU hardware specifications needed for performance analysis +and bottleneck identification. It includes peak compute performance, memory bandwidth, +cache sizes, and SM counts for common NVIDIA GPUs. + +""" + +import subprocess +from typing import Any + +from kernel_perf_agent.kernel_opt.diagnose_prompt.gpu_specs_database import ( + GPU_SPECS_DATABASE, +) + +__all__ = ["GPU_SPECS_DATABASE", "query_gpu_name", "get_gpu_specs"] + + +def query_gpu_name() -> str | None: + """ + Query GPU name using nvidia-smi. + + Returns: + GPU name string, or None if query fails + """ + try: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # Take only the first GPU (nvidia-smi returns one line per GPU) + gpu_name = result.stdout.strip().split("\n")[0].strip() + return gpu_name + except (subprocess.TimeoutExpired, FileNotFoundError, Exception): + pass + return None + + +def get_gpu_specs(gpu_name: str | None = None) -> dict[str, Any]: + """ + Get GPU specifications for bottleneck analysis. + + This function returns hardware specifications needed for performance analysis, + including peak compute performance, memory bandwidth, cache sizes, and SM counts. + + Args: + gpu_name: GPU name (if None, auto-detect with nvidia-smi) + + Returns: + Dictionary with GPU specifications containing: + - name: GPU name + - architecture: GPU architecture (e.g., "Ampere", "Hopper") + - peak_fp32_tflops: Peak FP32 compute performance in TFLOPS + - peak_fp16_tflops: Peak FP16 compute performance in TFLOPS + - peak_bf16_tflops: Peak BF16 compute performance in TFLOPS (0 if not supported) + - peak_memory_bw_gbps: Peak memory bandwidth in GB/s + - sm_count: Number of streaming multiprocessors + - max_threads_per_sm: Maximum threads per SM + - l1_cache_kb: L1 cache size in KB per SM + - l2_cache_mb: Total L2 cache size in MB + - memory_gb: Total GPU memory in GB + - memory_type: Memory type (e.g., "HBM2e", "GDDR6X") + + Examples: + >>> specs = get_gpu_specs() # Auto-detect + >>> print(f"Peak BW: {specs['peak_memory_bw_gbps']} GB/s") + + >>> specs = get_gpu_specs("NVIDIA A100") + >>> print(f"SM Count: {specs['sm_count']}") + """ + # Auto-detect if not provided + if gpu_name is None: + gpu_name = query_gpu_name() + + # Return default if detection failed + if gpu_name is None: + print("⚠️ GPU auto-detection failed, using A100 specs as fallback") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + # Try exact match + if gpu_name in GPU_SPECS_DATABASE: + return GPU_SPECS_DATABASE[gpu_name].copy() + + # Try fuzzy match (contains or partial match) + gpu_name_lower = gpu_name.lower() + for key, specs in GPU_SPECS_DATABASE.items(): + key_lower = key.lower() + # Check if either name contains the other + if gpu_name_lower in key_lower or key_lower in gpu_name_lower: + print(f"ℹ️ Matched '{gpu_name}' to '{key}' (fuzzy match)") + return specs.copy() + + # Fallback to A100 specs with warning + print(f"⚠️ Unknown GPU: '{gpu_name}', using A100 specs as fallback") + print(f" Available GPUs: {', '.join(GPU_SPECS_DATABASE.keys())}") + return GPU_SPECS_DATABASE["NVIDIA A100"].copy() + + +if __name__ == "__main__": + print("GPU Specifications Module") + print("=" * 60) + + # Auto-detect GPU + detected_name = query_gpu_name() + if detected_name: + print(f"\nDetected GPU: {detected_name}") + else: + print("\nNo GPU detected (nvidia-smi not available)") + exit() + + # Get specs + specs = get_gpu_specs() + print( + f"\nUsing specs for: {specs['name']} ({specs.get('architecture', 'Unknown')})" + ) + print(f" - Peak Memory Bandwidth: {specs['peak_memory_bw_gbps']} GB/s") + print(f" - Peak FP32 Performance: {specs['peak_fp32_tflops']} TFLOPS") + print(f" - SM Count: {specs['sm_count']}") + + # Show all available GPUs + print(f"\n{'=' * 60}") + print("Available GPU specifications in database:") + for gpu_name in sorted(GPU_SPECS_DATABASE.keys()): + print(f" - {gpu_name}") diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py new file mode 100644 index 0000000..d5e4586 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/gpu_specs_database.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +GPU Specifications Database + +This module contains the GPU hardware specifications database used for +performance analysis and bottleneck identification. Separated into its +own file to allow easier module overriding. + +Sources: NVIDIA official specifications, manufacturer datasheets +""" + +GPU_SPECS_DATABASE: dict[str, dict[str, object]] = { + "NVIDIA A100": { + "name": "NVIDIA A100", + "architecture": "Ampere", + "peak_fp32_tflops": 19.5, + "peak_fp16_tflops": 312.0, + "peak_bf16_tflops": 312.0, + "peak_memory_bw_gbps": 1555, + "sm_count": 108, + "max_threads_per_sm": 2048, + "l1_cache_kb": 192, + "l2_cache_mb": 40, + "memory_gb": 40, + "memory_type": "HBM2e", + }, + "NVIDIA H100": { + "name": "NVIDIA H100", + "architecture": "Hopper", + "peak_fp32_tflops": 51.0, + "peak_fp16_tflops": 989.0, + "peak_bf16_tflops": 989.0, + "peak_memory_bw_gbps": 3352, + "sm_count": 132, + "max_threads_per_sm": 2048, + "l1_cache_kb": 256, + "l2_cache_mb": 50, + "memory_gb": 80, + "memory_type": "HBM3", + }, + "NVIDIA RTX 4090": { + "name": "NVIDIA RTX 4090", + "architecture": "Ada Lovelace", + "peak_fp32_tflops": 82.6, + "peak_fp16_tflops": 165.0, + "peak_bf16_tflops": 165.0, + "peak_memory_bw_gbps": 1008, + "sm_count": 128, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 72, + "memory_gb": 24, + "memory_type": "GDDR6X", + }, + "NVIDIA RTX 5080": { + "name": "NVIDIA RTX 5080", + "architecture": "Blackwell", + "peak_fp32_tflops": 57.0, + "peak_fp16_tflops": 114.0, + "peak_bf16_tflops": 114.0, + "peak_memory_bw_gbps": 960, + "sm_count": 84, + "max_threads_per_sm": 1536, + "l1_cache_kb": 128, + "l2_cache_mb": 64, + "memory_gb": 16, + "memory_type": "GDDR7", + }, +} diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py new file mode 100644 index 0000000..e07749e --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/judger_prompts.py @@ -0,0 +1,343 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Prompt Builder for Hardware Bottleneck Diagnosis + +This module provides prompt templates and builder functions for the Judge LLM +that analyzes NCU profiling metrics to identify performance bottlenecks and +provide specific optimization recommendations. + +The Judge uses a dual-bottleneck framework based on NCU hardware profiling: +- bottleneck_1 (Primary): Highest-impact performance issue +- bottleneck_2 (Secondary): Different category issue that also limits performance + +Both bottlenecks are selected from NCU hardware profiling categories: +- memory-bound +- compute-bound +- occupancy-limited +- latency-bound + +Metric definitions are in metric_schema.py. +""" + +from typing import Any, Callable + +from .metric_schema import GPU_MEMORY_FIELDS, GPU_SPEC_FIELDS, NCU_METRIC_SECTIONS + + +# ============================================================================= +# Section Renderers +# ============================================================================= + + +def render_problem_description(problem_description: str) -> list[str]: + """Render the problem description section.""" + return ["## Problem Description", "", problem_description] + + +def render_kernel_code(kernel_code: str, language: str = "python") -> list[str]: + """Render the kernel code section with syntax highlighting.""" + return ["", "## Current Kernel Code", "", f"```{language}", kernel_code, "```"] + + +def render_gpu_specs(gpu_specs: dict[str, Any]) -> list[str]: + """Render the GPU hardware specifications section.""" + lines = ["", "## GPU Hardware Specifications", ""] + + for label, key, unit in GPU_SPEC_FIELDS: + value = gpu_specs.get(key, "N/A") + if value == "N/A": + lines.append(f"- **{label}:** N/A") + else: + lines.append(f"- **{label}:** {value}{unit}") + + for label, size_key, type_key, size_unit in GPU_MEMORY_FIELDS: + size_value = gpu_specs.get(size_key, "N/A") + type_value = gpu_specs.get(type_key, "") + lines.append(f"- **{label}:** {size_value}{size_unit} {type_value}") + + return lines + + +def render_ncu_metrics( + ncu_metrics: dict[str, Any], + get_metric_fn: Callable[[str, str], str], +) -> list[str]: + """Render the NCU profiling metrics section.""" + lines = ["", "## NCU Profiling Metrics"] + + for section_name, metrics in NCU_METRIC_SECTIONS.items(): + lines.append("") + lines.append(f"### {section_name}") + for label, key, unit in metrics: + value = get_metric_fn(key, "N/A") + lines.append(f"- **{label}:** {value}{unit}") + + return lines + + +def render_task_instructions() -> list[str]: + """Render the task instructions section for dual-bottleneck analysis.""" + return [ + "", + "## Your Task", + "", + "Identify exactly TWO distinct bottlenecks from the NCU profiling metrics above:", + "1. **Bottleneck 1 (Primary)**: The highest-impact performance issue", + "2. **Bottleneck 2 (Secondary)**: A different category issue that also limits performance", + "", + "For each bottleneck, cite 3-4 specific metrics that reveal the issue, " + "and recommend ONE actionable optimization.", + "", + "**Be surgical and metrics-driven.** Return JSON in the format specified in the system prompt.", + ] + + +def create_metric_getter(kernel_metrics: dict[str, Any]) -> Callable[[str, str], str]: + """Create a metric getter function for a specific kernel's metrics.""" + + def get_metric(key: str, default: str = "N/A") -> str: + val = kernel_metrics.get(key, default) + if isinstance(val, (int, float)): + return f"{val:.2f}" + return str(val) + + return get_metric + + +# ============================================================================= +# Bottleneck Analysis +# ============================================================================= + + +# System prompt for the Judge LLM (Dual-Bottleneck NCU Analysis) +JUDGE_SYSTEM_PROMPT = """You are a senior GPU performance engineer. Analyze the target GPU spec, the current kernel, and the Nsight Compute (NCU) profiling metrics. Identify EXACTLY TWO DISTINCT bottlenecks from the hardware profiling data, and propose specific optimization methods for each. Be surgical and metrics-driven. + +## Bottleneck Categories (NCU Hardware Profiling) + +Analyze fundamental resource utilization using NCU profiling data: + +## Bottleneck Categories (Indicators Only) +- **memory-bound**: High DRAM throughput (>60%), low L1/L2 hit rates (<70%), high memory stalls (>30%) +- **compute-bound**: Low DRAM throughput (<40%), high compute utilization (>60%), low memory stalls (<20%) +- **occupancy-limited**: Low warp active (<50%), high register usage (>100/thread), shared memory pressure (>80%) +- **latency-bound**: High total stalls (>40%), memory dependency stalls dominate, long scoreboard stalls + +- Return EXACTLY TWO DISTINCT bottlenecks with DIFFERENT categories +- Both bottlenecks must be from: {memory-bound, compute-bound, occupancy-limited, latency-bound} +- For each bottleneck, cite 3-4 specific NCU metric values that reveal the issue +- Propose ONE actionable optimization method per bottleneck +- Keep fields brief; avoid lists of alternatives, disclaimers, or generic advice + +## Output Format (JSON - STRICT) + +```json +{ + "bottleneck_1": { + "category": "", + "root_cause": "", + "suggestion": "", + "priority_metrics": ["", "", ""], + "expected_improvement": "" + }, + "bottleneck_2": { + "category": "", + "root_cause": "", + "suggestion": "", + "priority_metrics": ["", "", ""], + "expected_improvement": "" + } +} +``` + +## Important Notes + +- bottleneck_1 is the PRIMARY (highest-impact) issue +- bottleneck_2 is the SECONDARY issue (different category from bottleneck_1) +- They should be independently addressable (fixing one doesn't automatically fix the other) + +Follow the Rules exactly. Return JSON in the specified format. +""" + + +def build_judge_optimization_prompt( + kernel_code: str, + problem_description: str, + ncu_metrics: dict[str, Any], + gpu_specs: dict[str, Any], +) -> tuple[str, str]: + """ + Build system and user prompts for Judge to analyze bottleneck. + + This function constructs detailed prompts for the Judge LLM that include: + - The kernel code being analyzed + - The original problem description + - Complete NCU profiling metrics + - GPU hardware specifications + + Args: + kernel_code: Current Triton kernel code + problem_description: Original problem description + ncu_metrics: NCU profiling metrics as a dictionary (from metrics_to_prompt) + gpu_specs: GPU specifications (from get_gpu_specs) + + Returns: + Tuple of (system_prompt, user_prompt) + + Example: + >>> sys_prompt, user_prompt = build_judge_optimization_prompt( + ... kernel_code=kernel_code, + ... problem_description=problem_desc, + ... ncu_metrics=ncu_metrics, + ... gpu_specs=gpu_specs, + ... ) + >>> response = llm.call([ + ... {"role": "system", "content": sys_prompt}, + ... {"role": "user", "content": user_prompt} + ... ]) + """ + if not ncu_metrics: + raise ValueError("NCU metrics are empty - cannot build judge prompt") + + # Extract first kernel's metrics for the metric getter + first_kernel = list(ncu_metrics.values())[0] + get_metric_fn = create_metric_getter(first_kernel) + + # Build user prompt using modular section renderers + parts: list[str] = [] + + # Compose sections using renderers + parts.extend(render_problem_description(problem_description)) + parts.extend(render_kernel_code(kernel_code)) + parts.extend(render_gpu_specs(gpu_specs)) + parts.extend(render_ncu_metrics(ncu_metrics, get_metric_fn)) + parts.extend(render_task_instructions()) + + user_prompt = "\n".join(parts) + return JUDGE_SYSTEM_PROMPT, user_prompt + + +def extract_judge_response(response_text: str) -> dict[str, Any] | None: + """ + Extract and parse JSON from Judge LLM response. + + This function handles various response formats and provides fallback strategies + for robust JSON extraction. Expects dual-bottleneck format with bottleneck_1 + and bottleneck_2 fields. + + Args: + response_text: Raw text response from Judge LLM + + Returns: + Parsed JSON dictionary with bottleneck_1 and bottleneck_2, + or None if extraction fails + + Example: + >>> response = llm.call(judge_prompts) + >>> analysis = extract_judge_response(response) + >>> if analysis: + ... print(f"Bottleneck 1: {analysis['bottleneck_1']['category']}") + ... print(f"Bottleneck 2: {analysis['bottleneck_2']['category']}") + """ + import json + import re + + # Strategy 1: Find JSON in code block + match = re.search(r"```json\s*(\{.*?\})\s*```", response_text, re.DOTALL) + if match: + try: + data = json.loads(match.group(1)) + if "bottleneck_1" in data and "bottleneck_2" in data: + return data + except json.JSONDecodeError: + pass + + # Strategy 2: Find first { ... } block with "bottleneck_1" field + match = re.search(r'\{[^}]*"bottleneck_1"[^}]*\}', response_text, re.DOTALL) + if match: + try: + # Extract the full JSON object (may be nested) + start_pos = response_text.find("{", match.start()) + brace_count = 0 + end_pos = start_pos + + for i in range(start_pos, len(response_text)): + if response_text[i] == "{": + brace_count += 1 + elif response_text[i] == "}": + brace_count -= 1 + if brace_count == 0: + end_pos = i + 1 + break + + json_str = response_text[start_pos:end_pos] + data = json.loads(json_str) + if "bottleneck_1" in data and "bottleneck_2" in data: + return data + except (json.JSONDecodeError, ValueError): + pass + + # Strategy 3: Find any JSON object with dual-bottleneck structure + match = re.search( + r'\{\s*"bottleneck_1"\s*:\s*\{.*?\}\s*,\s*"bottleneck_2"\s*:\s*\{.*?\}\s*\}', + response_text, + re.DOTALL, + ) + if match: + try: + return json.loads(match.group(0)) + except json.JSONDecodeError: + pass + + # Return None if all strategies fail + return None + + +def validate_judge_response(analysis: dict[str, Any]) -> bool: + """Validate that Judge response contains required dual-bottleneck fields.""" + if "bottleneck_1" not in analysis or "bottleneck_2" not in analysis: + return False + return _validate_bottleneck_entry( + analysis["bottleneck_1"] + ) and _validate_bottleneck_entry(analysis["bottleneck_2"]) + + +VALID_CATEGORIES = { + "memory-bound", + "compute-bound", + "occupancy-limited", + "latency-bound", +} + + +def _validate_bottleneck_entry(bottleneck: dict[str, Any]) -> bool: + """Validate a single bottleneck entry.""" + required = [ + "category", + "root_cause", + "suggestion", + "priority_metrics", + "expected_improvement", + ] + if not all(f in bottleneck for f in required): + return False + if bottleneck["category"] not in VALID_CATEGORIES: + return False + if not isinstance(bottleneck["priority_metrics"], list): + return False + for f in ["root_cause", "suggestion", "expected_improvement"]: + if not isinstance(bottleneck[f], str) or len(bottleneck[f].strip()) < 5: + return False + return True diff --git a/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py b/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py new file mode 100644 index 0000000..64d1d67 --- /dev/null +++ b/kernel_perf_agent/kernel_opt/diagnose_prompt/metric_schema.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Metric Schema Definitions for NCU Profiling and GPU Specifications. + +This module provides the single source of truth for: +- NCU profiling metric definitions (keys, labels, units) +- GPU specification field definitions + +Schema Format: List of tuples (display_label, key, unit_suffix) +- display_label: Human-readable name shown in prompts +- key: NCU metric key or GPU spec dictionary key +- unit_suffix: Unit to append after value (e.g., "%", " GB/s", " bytes") +""" + +from typing import Dict, List, Tuple + +# Type alias for metric definition: (label, key, unit) +MetricDef = Tuple[str, str, str] + +# ============================================================================= +# GPU Specification Fields +# ============================================================================= + +GPU_SPEC_FIELDS: List[MetricDef] = [ + ("Name", "name", ""), + ("Architecture", "architecture", ""), + ("Peak Memory Bandwidth", "peak_memory_bw_gbps", " GB/s"), + ("Peak FP32 Performance", "peak_fp32_tflops", " TFLOPS"), + ("Peak FP16 Performance", "peak_fp16_tflops", " TFLOPS"), + ("SM Count", "sm_count", ""), + ("Max Threads per SM", "max_threads_per_sm", ""), + ("L1 Cache per SM", "l1_cache_kb", " KB"), + ("L2 Cache (Total)", "l2_cache_mb", " MB"), +] + +# Special case: Memory Size has two fields combined +GPU_MEMORY_FIELDS: List[Tuple[str, str, str, str]] = [ + # (label, size_key, type_key, size_unit) + ("Memory Size", "memory_gb", "memory_type", " GB"), +] + +# ============================================================================= +# NCU Profiling Metric Sections +# ============================================================================= + +NCU_METRIC_SECTIONS: Dict[str, List[MetricDef]] = { + "SM & Compute Utilization": [ + ("SM Cycles Active", "sm__cycles_active.avg", ""), + ("Warp Active", "sm__warps_active.avg.pct_of_peak_sustained_active", "%"), + ("Total Instructions Executed", "sm__inst_executed.sum", ""), + ( + "Tensor Core Utilization", + "sm__inst_executed_pipe_tensor.avg.pct_of_peak_sustained_active", + "%", + ), + ( + "Tensor Core Pipeline Active", + "sm__pipe_tensor_cycles_active.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ], + "Memory Bandwidth & Cache": [ + ( + "DRAM Throughput", + "dram__throughput.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ("DRAM Bandwidth", "dram__bytes.sum.per_second", " bytes/sec"), + ( + "GPU DRAM Throughput", + "gpu__dram_throughput.avg.pct_of_peak_sustained_elapsed", + "%", + ), + ("DRAM Bytes Read", "dram__bytes_read.sum", " bytes"), + ("DRAM Bytes Write", "dram__bytes_write.sum", " bytes"), + ("L1 Cache Hit Rate", "l1tex__t_sector_hit_rate.pct", "%"), + ( + "L1 Throughput", + "l1tex__throughput.avg.pct_of_peak_sustained_active", + "%", + ), + ("L2 Cache Hit Rate", "lts__t_sector_hit_rate.pct", "%"), + ( + "L2 Throughput", + "lts__throughput.avg.pct_of_peak_sustained_active", + "%", + ), + ], + "Memory Access Patterns": [ + ( + "Memory Coalescing", + "smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct", + "%", + ), + ( + "Branch Uniformity", + "smsp__sass_average_branch_targets_threads_uniform.pct", + "%", + ), + ], + "Occupancy & Resources": [ + ("Occupancy Limited By Blocks", "launch__occupancy_limit_blocks", ""), + ("Occupancy Limited By Registers", "launch__occupancy_limit_registers", ""), + ( + "Occupancy Limited By Shared Memory", + "launch__occupancy_limit_shared_mem", + "", + ), + ("Registers per Thread", "launch__registers_per_thread", ""), + ( + "Shared Memory per Block", + "launch__shared_mem_per_block_allocated", + " bytes", + ), + ], + "Stall Metrics (Warp Issue Stalls)": [ + ( + "Short Scoreboard Stalls", + "smsp__warp_issue_stalled_short_scoreboard_per_warp_active.pct", + "%", + ), + ( + "Long Scoreboard Stalls", + "smsp__warp_issue_stalled_long_scoreboard_per_warp_active.pct", + "%", + ), + ( + "Barrier Stalls", + "smsp__warp_issue_stalled_barrier_per_warp_active.pct", + "%", + ), + ( + "Branch Resolving Stalls", + "smsp__warp_issue_stalled_branch_resolving_per_warp_active.pct", + "%", + ), + ], +} diff --git a/triton_kernel_agent/prompt_manager.py b/triton_kernel_agent/prompt_manager.py index 7c22009..2b6ccbe 100644 --- a/triton_kernel_agent/prompt_manager.py +++ b/triton_kernel_agent/prompt_manager.py @@ -15,6 +15,7 @@ """Prompt Manager for handling Jinja2 templates.""" from pathlib import Path +from typing import Any from triton_kernel_agent.platform_config import PlatformConfig, get_platform @@ -88,6 +89,7 @@ def _load_templates(self): "test_generation": "test_generation.j2", "kernel_generation": "kernel_generation.j2", "kernel_refinement": "kernel_refinement.j2", + "kernel_optimization": "kernel_optimization.j2", "triton_guidelines": "triton_guidelines.j2", } @@ -188,6 +190,59 @@ def render_kernel_refinement_prompt( kernel_guidance=self.target_platform.kernel_guidance, ) + def render_kernel_optimization_prompt( + self, + kernel_code: str, + problem_description: str, + bottleneck_analysis: dict[str, Any], + bottleneck_id: int = 1, + gpu_specs: dict[str, Any] | None = None, + pytorch_baseline_ms: float | None = None, + error_feedback: str | None = None, + ) -> str: + """ + Render the kernel optimization prompt based on bottleneck analysis. + + Args: + kernel_code: Current kernel code to optimize + problem_description: Problem description + bottleneck_analysis: Dual-bottleneck analysis with bottleneck_1 and bottleneck_2 + bottleneck_id: Which bottleneck to focus on (1 or 2) + gpu_specs: GPU specifications dict + pytorch_baseline_ms: PyTorch baseline time in ms + error_feedback: Error feedback from previous failed attempt + + Returns: + Rendered prompt string + """ + template = self.templates["kernel_optimization"] + + # Select bottleneck + if bottleneck_id == 2: + bottleneck = bottleneck_analysis.get("bottleneck_2", {}) + bottleneck_label = "Bottleneck 2 (Secondary)" + else: + bottleneck = bottleneck_analysis.get("bottleneck_1", {}) + bottleneck_label = "Bottleneck 1 (Primary)" + + # Calculate target time if baseline provided + target_ms = None + if pytorch_baseline_ms and pytorch_baseline_ms != float("inf"): + target_ms = pytorch_baseline_ms * 0.8 + + return template.render( + kernel_code=kernel_code, + problem_description=problem_description, + bottleneck=bottleneck, + bottleneck_label=bottleneck_label, + gpu_specs=gpu_specs, + pytorch_baseline_ms=pytorch_baseline_ms + if pytorch_baseline_ms != float("inf") + else None, + target_ms=target_ms, + error_feedback=error_feedback, + ) + def render_triton_guidelines(self) -> str: """ Render the Triton guidelines. diff --git a/triton_kernel_agent/templates/kernel_optimization.j2 b/triton_kernel_agent/templates/kernel_optimization.j2 new file mode 100644 index 0000000..bef1c64 --- /dev/null +++ b/triton_kernel_agent/templates/kernel_optimization.j2 @@ -0,0 +1,89 @@ +{# +Copyright (c) Meta Platforms, Inc. and affiliates. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +#} + +TASK: Optimize the following Triton kernel based on hardware profiling analysis to achieve better performance. + +{% if gpu_specs %} +TARGET GPU: +{% if gpu_specs.name %}- GPU: {{ gpu_specs.name }} +{% endif %} +{% if gpu_specs.architecture %}- Architecture: {{ gpu_specs.architecture }} +{% endif %} +{% if gpu_specs.peak_memory_bw_gbps %}- Peak Memory Bandwidth: {{ gpu_specs.peak_memory_bw_gbps }} GB/s +{% endif %} +{% if gpu_specs.peak_fp32_tflops %}- Peak FP32: {{ gpu_specs.peak_fp32_tflops }} TFLOPS +{% endif %} +{% if gpu_specs.peak_fp16_tflops %}- Peak FP16: {{ gpu_specs.peak_fp16_tflops }} TFLOPS +{% endif %} +{% if gpu_specs.peak_bf16_tflops %}- Peak BF16: {{ gpu_specs.peak_bf16_tflops }} TFLOPS +{% endif %} +{% if gpu_specs.sm_count %}- SM Count: {{ gpu_specs.sm_count }} +{% endif %} +{% if gpu_specs.max_threads_per_sm %}- Max Threads per SM: {{ gpu_specs.max_threads_per_sm }} +{% endif %} +{% if gpu_specs.l1_cache_kb %}- L1 Cache per SM: {{ gpu_specs.l1_cache_kb }} KB +{% endif %} +{% if gpu_specs.l2_cache_mb %}- L2 Cache (Total): {{ gpu_specs.l2_cache_mb }} MB +{% endif %} +{% if gpu_specs.memory_gb %}- Memory: {{ gpu_specs.memory_gb }} GB {{ gpu_specs.memory_type | default('') }} +{% endif %} + +{% endif %} +PROBLEM DESCRIPTION: +{{ problem_description }} +{% if pytorch_baseline_ms %} +PyTorch Eager baseline: {{ "%.4f"|format(pytorch_baseline_ms) }} ms +{% endif %} + +CURRENT KERNEL IMPLEMENTATION: +```python +{{ kernel_code }} +``` + +OPTIMIZATION STRATEGY ({{ bottleneck_label }}): +The hardware profiling (NCU) analysis identified the following bottleneck: +- Category: {{ bottleneck.category | default('unknown') }} +- Root Cause: {{ bottleneck.root_cause | default('N/A') }} +- Suggested Optimization: {{ bottleneck.suggestion | default('N/A') }} +- Expected Improvement: {{ bottleneck.expected_improvement | default('N/A') }} + +{% if error_feedback %} +PREVIOUS ATTEMPT FAILED: +{{ error_feedback }} + +{% endif %} +PERFORMANCE TARGET: +{% if target_ms %} +- Achieve at least 1.25x speedup vs PyTorch Eager (target: <= {{ "%.4f"|format(target_ms) }} ms) +{% else %} +- Achieve 20-100% performance improvement over baseline +{% endif %} +- Maintain numerical correctness (atol=1e-4 or rtol=1e-4) +- Preserve public API (same inputs/outputs, shapes, dtypes) + +CRITICAL REQUIREMENTS: +1. Apply the optimization strategy described above to address the identified bottleneck +2. The implementation must be a complete, valid Python file +3. The main function must be named 'kernel_function' that wraps the actual Triton kernel +4. Focus on the specific optimization while maintaining correctness +5. Keep the wrapper free of PyTorch compute primitives + +OUTPUT FORMAT: +1. Output complete optimized kernel code in ```python blocks +2. Include only: imports, Triton kernel (@triton.jit), wrapper function (kernel_function) +3. No testing code, benchmarks, or explanatory comments + +Generate the complete optimized kernel implementation: