From a8eb03c3c9f9050e996a706cfaa74e614335e731 Mon Sep 17 00:00:00 2001 From: Aoyu Zhang Date: Wed, 28 Jan 2026 06:15:59 +0000 Subject: [PATCH] feat: add CompilerConfig dataclass for structured compiler options Replace string-based compiler arguments with a type-safe CompilerConfig dataclass that provides discoverability and easy customization. Changes: - Add CompilerConfig dataclass with fields for common neuronx-cc options (lnc, model_type, auto_cast, enable_mixed_precision_accumulation, etc.) - Add factory methods for_nkipy() and for_nki() with appropriate defaults - Add get_default_compiler_args() helper to inspect default settings - Add compiler_config parameter to @baremetal_jit, baremetal_run_traced_kernel, and DeviceKernel.compile_and_load() - Export CompilerConfig and get_default_compiler_args from nkipy.runtime - Add tutorial section demonstrating CompilerConfig usage Backward compatible: legacy additional_compiler_args parameter still works. --- docs/tutorials/01_simple.ipynb | 399 ++++++++++++----------- nkipy/src/nkipy/core/compile.py | 163 ++++++++- nkipy/src/nkipy/runtime/__init__.py | 4 + nkipy/src/nkipy/runtime/decorators.py | 20 +- nkipy/src/nkipy/runtime/device_kernel.py | 44 ++- nkipy/src/nkipy/runtime/execute.py | 41 ++- 6 files changed, 445 insertions(+), 226 deletions(-) diff --git a/docs/tutorials/01_simple.ipynb b/docs/tutorials/01_simple.ipynb index 2ab3534..767f114 100644 --- a/docs/tutorials/01_simple.ipynb +++ b/docs/tutorials/01_simple.ipynb @@ -1,199 +1,202 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "273d6f72-dcf1-4bf7-8538-39c6aed09036", - "metadata": {}, - "source": [ - "# Simple NKIPy Tutorial\n", - "\n", - "This tutorial uses a simple softmax NKIPy kernel to go through how NKIPy works.\n", - "\n", - "We will cover:\n", - "\n", - "- Defining a NKIPy kernel\n", - "- Run it as NumPy function\n", - "- Trace and run in simulation mode\n", - "- Compile it and run it on Trainium hardware" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "976309b5", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "\n", - "from nkipy.core.trace import NKIPyKernel\n", - "from nkipy.core.compile import lower_to_nki\n", - "from nkipy.runtime.execute import simulate_traced_kernel, baremetal_run_traced_kernel" - ] - }, - { - "cell_type": "markdown", - "id": "a1368557-8c80-4b7d-9cb0-6f317bf61b5b", - "metadata": {}, - "source": [ - "## Defining A NKIPy Kernel\n", - "\n", - "A NKIPy looks like a NumPy kernel. \n", - "It supports a subset of NumPy and Python syntax." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "690ede8c", - "metadata": {}, - "outputs": [], - "source": [ - "def softmax_kernel(x):\n", - " exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))\n", - " sum_x = np.sum(exp_x, axis=-1, keepdims=True)\n", - " return exp_x / sum_x" - ] - }, - { - "cell_type": "markdown", - "id": "c172321f-67af-48aa-83fb-e9e49c93eec4", - "metadata": {}, - "source": [ - "## Running a NKIPy Kernel as a NumPy function" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "42015e75", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input is [[0.9495564 0.39231038]\n", - " [0.05852599 0.9262922 ]]\n", - "NumPy output is [[0.635815 0.36418492]\n", - " [0.2957193 0.7042807 ]]\n" - ] - } - ], - "source": [ - "# NKIPy is NumPy-like, and in most cases, NumPy compatible\n", - "# So, we can run NKIPy kernel directly as NumPy\n", - "x = np.random.rand(2, 2).astype(np.float32)\n", - "print(f\"Input is {x}\")\n", - "\n", - "out_numpy = softmax_kernel(x)\n", - "print(f\"NumPy output is {out_numpy}\")" - ] - }, - { - "cell_type": "markdown", - "id": "4c6b5597-61c4-40e8-9ce1-75573e38331b", - "metadata": {}, - "source": [ - "## Tracing a NKIPy Kernel" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "9d15c533", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# To run NKIPy kernels on Trainium, we need to trace as a NKIPyKernel with the `trace` wrapper\n", - "traced_kernel = NKIPyKernel.trace(softmax_kernel)" - ] - }, - { - "cell_type": "markdown", - "id": "3c8809ee-5d41-4f0c-a23a-78a6ba54cedc", - "metadata": {}, - "source": [ - "## Running the Traced Kernel with Simulation" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "8f1e4610", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Is the simulated output the same as NumPy? True\n" - ] - } - ], - "source": [ - "out_nkipy = simulate_traced_kernel(traced_kernel, x)\n", - "print(f\"Is the simulated output the same as NumPy? {np.allclose(out_nkipy, out_numpy)}\")" - ] - }, - { - "cell_type": "markdown", - "id": "5e49b9bf-7323-4eac-825e-91abe3347e02", - "metadata": {}, - "source": [ - "## Running it On Trainium Hardware" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "7c3c6a0a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Is the output the same as NumPy? True\n" - ] - } - ], - "source": [ - "# NKIPy kernel can be compiled to binary (NEFF) and execute on real hardware!\n", - "# The baremetal wrapper is used to execute the compiled binary on Trainium hardware\n", - "# in baremetal mode (without framework support)\n", - "out_baremetal = baremetal_run_traced_kernel(traced_kernel, x)\n", - "print(f\"Is the output the same as NumPy? {np.allclose(out_baremetal, out_numpy)}\")" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "-all", - "formats": "ipynb,auto:light", - "main_language": "python", - "notebook_metadata_filter": "-all" - }, - "kernelspec": { - "display_name": "venv_nkipy_local", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "cells": [ + { + "cell_type": "markdown", + "id": "273d6f72-dcf1-4bf7-8538-39c6aed09036", + "metadata": {}, + "source": "# Simple NKIPy Tutorial\n\nThis tutorial uses a simple softmax NKIPy kernel to go through how NKIPy works.\n\nWe will cover:\n\n- Defining a NKIPy kernel\n- Run it as NumPy function\n- Trace and run in simulation mode\n- Compile it and run it on Trainium hardware\n- Customize compiler options with CompilerConfig" + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "976309b5", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "from nkipy.core.trace import NKIPyKernel\n", + "from nkipy.core.compile import lower_to_nki\n", + "from nkipy.runtime.execute import simulate_traced_kernel, baremetal_run_traced_kernel" + ] + }, + { + "cell_type": "markdown", + "id": "a1368557-8c80-4b7d-9cb0-6f317bf61b5b", + "metadata": {}, + "source": [ + "## Defining A NKIPy Kernel\n", + "\n", + "A NKIPy looks like a NumPy kernel. \n", + "It supports a subset of NumPy and Python syntax." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "690ede8c", + "metadata": {}, + "outputs": [], + "source": [ + "def softmax_kernel(x):\n", + " exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))\n", + " sum_x = np.sum(exp_x, axis=-1, keepdims=True)\n", + " return exp_x / sum_x" + ] + }, + { + "cell_type": "markdown", + "id": "c172321f-67af-48aa-83fb-e9e49c93eec4", + "metadata": {}, + "source": [ + "## Running a NKIPy Kernel as a NumPy function" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "42015e75", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input is [[0.9495564 0.39231038]\n", + " [0.05852599 0.9262922 ]]\n", + "NumPy output is [[0.635815 0.36418492]\n", + " [0.2957193 0.7042807 ]]\n" + ] + } + ], + "source": [ + "# NKIPy is NumPy-like, and in most cases, NumPy compatible\n", + "# So, we can run NKIPy kernel directly as NumPy\n", + "x = np.random.rand(2, 2).astype(np.float32)\n", + "print(f\"Input is {x}\")\n", + "\n", + "out_numpy = softmax_kernel(x)\n", + "print(f\"NumPy output is {out_numpy}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4c6b5597-61c4-40e8-9ce1-75573e38331b", + "metadata": {}, + "source": [ + "## Tracing a NKIPy Kernel" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9d15c533", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# To run NKIPy kernels on Trainium, we need to trace as a NKIPyKernel with the `trace` wrapper\n", + "traced_kernel = NKIPyKernel.trace(softmax_kernel)" + ] + }, + { + "cell_type": "markdown", + "id": "3c8809ee-5d41-4f0c-a23a-78a6ba54cedc", + "metadata": {}, + "source": [ + "## Running the Traced Kernel with Simulation" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8f1e4610", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Is the simulated output the same as NumPy? True\n" + ] + } + ], + "source": [ + "out_nkipy = simulate_traced_kernel(traced_kernel, x)\n", + "print(f\"Is the simulated output the same as NumPy? {np.allclose(out_nkipy, out_numpy)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "5e49b9bf-7323-4eac-825e-91abe3347e02", + "metadata": {}, + "source": [ + "## Running it On Trainium Hardware" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7c3c6a0a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Is the output the same as NumPy? True\n" + ] + } + ], + "source": [ + "# NKIPy kernel can be compiled to binary (NEFF) and execute on real hardware!\n", + "# The baremetal wrapper is used to execute the compiled binary on Trainium hardware\n", + "# in baremetal mode (without framework support)\n", + "out_baremetal = baremetal_run_traced_kernel(traced_kernel, x)\n", + "print(f\"Is the output the same as NumPy? {np.allclose(out_baremetal, out_numpy)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "274lm7aiz4s", + "source": "## Customizing Compiler Options\n\nNKIPy provides a `CompilerConfig` class to configure the neuronx-cc compiler with type-safe options.\nThis is useful for performance tuning and enabling advanced features.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "21dx5ny8mnl", + "source": "from nkipy.core.compile import CompilerConfig, get_default_compiler_args\n\n# View the default compiler arguments\nprint(f\"Default args: {get_default_compiler_args()}\")\n\n# View default configuration for NKIPy kernels\nconfig = CompilerConfig.for_nkipy()\nprint(f\"NKIPy preset: {config.to_args()}\")\n\n# View default configuration for NKI kernels\nconfig = CompilerConfig.for_nki()\nprint(f\"NKI preset: {config.to_args()}\")\n\n# Customize options - e.g., for transformer models with 2 NeuronCores\nconfig = CompilerConfig.for_nkipy(\n lnc=2,\n model_type=\"transformer\",\n enable_mixed_precision_accumulation=True,\n)\nprint(f\"Custom config: {config.to_args()}\")\n\n# Use custom config with baremetal_run_traced_kernel\n# out = baremetal_run_traced_kernel(traced_kernel, x, compiler_config=config)", + "metadata": {}, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "ipynb,auto:light", + "main_language": "python", + "notebook_metadata_filter": "-all" + }, + "kernelspec": { + "display_name": "venv_nkipy_local", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/nkipy/src/nkipy/core/compile.py b/nkipy/src/nkipy/core/compile.py index 0d3fe5a..dc78d35 100644 --- a/nkipy/src/nkipy/core/compile.py +++ b/nkipy/src/nkipy/core/compile.py @@ -9,7 +9,7 @@ import subprocess import sys import tempfile -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from pathlib import Path from typing import List, Optional @@ -34,17 +34,158 @@ def _set_build_dir(build_dir): _DEFAULT_BUILD_DIR = build_dir -# Compiler arguments -DEFAULT_ADDITIONAL_COMPILER_ARGS = "--lnc 1" -NKIPY_KERNEL_ADDITIONAL_COMPILER_ARGS = "--internal-tensorizer-opt-level=2" -NKI_KERNEL_ADDITIONAL_COMPILER_ARGS = "--internal-tensorizer-opt-level=nki" +@dataclass +class CompilerConfig: + """Structured compiler configuration for neuronx-cc. + + Provides type-safe configuration with sensible defaults for NKIPy and NKI kernels. + Use factory methods `for_nkipy()` and `for_nki()` for preset configurations. + + See neuronx-cc documentation for full option details: + https://awsdocs-neuron.readthedocs-hosted.com/en/latest/compiler/neuronx-cc/api-reference-guide/neuron-compiler-cli-reference-guide.html + + Example: + # Use preset with overrides + config = CompilerConfig.for_nkipy(model_type="transformer") + + # Or create custom config + config = CompilerConfig(lnc=2, model_type="transformer") + + # View the resulting args + print(config.to_args()) + """ + + # Core options + lnc: int = 1 # Logical NeuronCore config (1 or 2) + model_type: Optional[str] = None # "generic", "transformer", "unet-inference" + + # Precision options + auto_cast: Optional[str] = None # "none", "matmult", "all" + auto_cast_type: Optional[str] = None # "bf16", "fp16", "tf32", "fp8_e4m3" + enable_mixed_precision_accumulation: Optional[bool] = None # None = default + enable_saturate_infinity: bool = False + + # Performance options + optlevel: Optional[int] = None # 1 (fast), 2 (balanced), 3 (maximum) + enable_fast_context_switch: bool = False + enable_fast_loading_neuron_binaries: bool = False + + # Arbitrary extra arguments (for options not covered above) + extra_args: List[str] = field(default_factory=list) + + def to_args(self) -> str: + """Convert configuration to compiler argument string.""" + args = [] + + # Core options + args.append(f"--lnc {self.lnc}") + if self.model_type: + args.append(f"--model-type {self.model_type}") + + # Precision options + if self.auto_cast: + args.append(f"--auto-cast {self.auto_cast}") + if self.auto_cast_type: + args.append(f"--auto-cast-type {self.auto_cast_type}") + if self.enable_mixed_precision_accumulation is True: + args.append("--enable-mixed-precision-accumulation") + elif self.enable_mixed_precision_accumulation is False: + args.append("--disable-mixed-precision-accumulation") + if self.enable_saturate_infinity: + args.append("--enable-saturate-infinity") + + # Performance options + if self.optlevel is not None: + args.append(f"-O{self.optlevel}") + if self.enable_fast_context_switch: + args.append("--enable-fast-context-switch") + if self.enable_fast_loading_neuron_binaries: + args.append("--enable-fast-loading-neuron-binaries") + + # Extra args + args.extend(self.extra_args) + + return " ".join(args) + + @classmethod + def for_nkipy( + cls, + lnc: int = 1, + model_type: Optional[str] = None, + auto_cast: Optional[str] = None, + auto_cast_type: Optional[str] = None, + enable_mixed_precision_accumulation: Optional[bool] = None, + enable_saturate_infinity: bool = False, + optlevel: Optional[int] = None, + enable_fast_context_switch: bool = False, + enable_fast_loading_neuron_binaries: bool = False, + extra_args: Optional[List[str]] = None, + ) -> "CompilerConfig": + """Create configuration preset for NKIPy kernels. + + Default settings: + - lnc=1 + """ + return cls( + lnc=lnc, + model_type=model_type, + auto_cast=auto_cast, + auto_cast_type=auto_cast_type, + enable_mixed_precision_accumulation=enable_mixed_precision_accumulation, + enable_saturate_infinity=enable_saturate_infinity, + optlevel=optlevel, + enable_fast_context_switch=enable_fast_context_switch, + enable_fast_loading_neuron_binaries=enable_fast_loading_neuron_binaries, + extra_args=extra_args or [], + ) + + @classmethod + def for_nki( + cls, + lnc: int = 1, + model_type: Optional[str] = None, + auto_cast: Optional[str] = None, + auto_cast_type: Optional[str] = None, + enable_mixed_precision_accumulation: Optional[bool] = None, + enable_saturate_infinity: bool = False, + optlevel: Optional[int] = None, + enable_fast_context_switch: bool = False, + enable_fast_loading_neuron_binaries: bool = False, + extra_args: Optional[List[str]] = None, + ) -> "CompilerConfig": + """Create configuration preset for NKI kernels. + + Default settings: + - lnc=1 + """ + return cls( + lnc=lnc, + model_type=model_type, + auto_cast=auto_cast, + auto_cast_type=auto_cast_type, + enable_mixed_precision_accumulation=enable_mixed_precision_accumulation, + enable_saturate_infinity=enable_saturate_infinity, + optlevel=optlevel, + enable_fast_context_switch=enable_fast_context_switch, + enable_fast_loading_neuron_binaries=enable_fast_loading_neuron_binaries, + extra_args=extra_args or [], + ) + + +def get_default_compiler_args() -> str: + """Return the default compiler arguments string for NKIPy kernels. + + Useful for debugging and understanding what args will be passed to neuronx-cc. + + Returns: + The default compiler arguments as a string. + """ + return CompilerConfig.for_nkipy().to_args() + -nkipy_compiler_args = ( - DEFAULT_ADDITIONAL_COMPILER_ARGS + " " + NKIPY_KERNEL_ADDITIONAL_COMPILER_ARGS -) -nki_compiler_args = ( - DEFAULT_ADDITIONAL_COMPILER_ARGS + " " + NKI_KERNEL_ADDITIONAL_COMPILER_ARGS -) +# Legacy compatibility - computed from CompilerConfig +nkipy_compiler_args = CompilerConfig.for_nkipy().to_args() +nki_compiler_args = CompilerConfig.for_nki().to_args() class CompilationTarget(Enum): diff --git a/nkipy/src/nkipy/runtime/__init__.py b/nkipy/src/nkipy/runtime/__init__.py index 13cb9a4..26d52a8 100644 --- a/nkipy/src/nkipy/runtime/__init__.py +++ b/nkipy/src/nkipy/runtime/__init__.py @@ -1,5 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from nkipy.core.compile import CompilerConfig, get_default_compiler_args + from .decorators import baremetal_jit, simulate_jit from .execute import baremetal_run_traced_kernel, simulate_traced_kernel from .utils import is_neuron_compatible @@ -15,10 +17,12 @@ __all__ = [ "BaremetalExecutor", "CompiledKernel", + "CompilerConfig", "DeviceKernel", "DeviceTensor", "baremetal_jit", "baremetal_run_traced_kernel", + "get_default_compiler_args", "is_neuron_compatible", "simulate_jit", "simulate_traced_kernel", diff --git a/nkipy/src/nkipy/runtime/decorators.py b/nkipy/src/nkipy/runtime/decorators.py index ba53cbd..7e8937c 100644 --- a/nkipy/src/nkipy/runtime/decorators.py +++ b/nkipy/src/nkipy/runtime/decorators.py @@ -1,9 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import functools +from typing import Optional from nkipy.core import compile -from nkipy.core.compile import trace +from nkipy.core.compile import CompilerConfig, trace from nkipy.runtime.execute import baremetal_run_traced_kernel, simulate_traced_kernel @@ -45,6 +46,7 @@ def wrapper(*args, **kwargs): def baremetal_jit( kernel_func=None, *, + compiler_config: Optional[CompilerConfig] = None, additional_compiler_args="", target=compile.CompilationTarget.DEFAULT, ): @@ -56,7 +58,11 @@ def baremetal_jit( Args: kernel_func: The kernel function to decorate (when used without parentheses) - additional_compiler_args: Additional arguments to pass to the compiler + compiler_config: Structured compiler configuration (recommended). + Use CompilerConfig.for_nkipy() or CompilerConfig.for_nki() for presets. + additional_compiler_args: Additional arguments to pass to the compiler (legacy). + If both compiler_config and additional_compiler_args are provided, + additional_compiler_args will be appended to the config's args. target: Compilation target (default: CompilationTarget.DEFAULT) Returns: @@ -70,8 +76,13 @@ def my_kernel(A, B): # Compiles on first call with this signature result = my_kernel(input_a, input_b) - # Or with compiler args: - @baremetal_jit(additional_compiler_args="--lnc 1") + # With structured config (recommended): + @baremetal_jit(compiler_config=CompilerConfig.for_nkipy(model_type="transformer")) + def my_kernel(A, B): + return A @ B + + # Legacy string args (still supported): + @baremetal_jit(additional_compiler_args="--model-type transformer") def my_kernel(A, B): return A @ B """ @@ -85,6 +96,7 @@ def wrapper(*args, **kwargs): return baremetal_run_traced_kernel( traced_kernel, *args, + compiler_config=compiler_config, additional_compiler_args=additional_compiler_args, target=target, **kwargs, diff --git a/nkipy/src/nkipy/runtime/device_kernel.py b/nkipy/src/nkipy/runtime/device_kernel.py index 460ef6b..defb0cf 100644 --- a/nkipy/src/nkipy/runtime/device_kernel.py +++ b/nkipy/src/nkipy/runtime/device_kernel.py @@ -5,9 +5,16 @@ import shutil import time import types +from typing import Optional from nkipy.core import compile -from nkipy.core.compile import CompilationTarget, _get_build_dir, compile_to_neff, trace +from nkipy.core.compile import ( + CompilationTarget, + CompilerConfig, + _get_build_dir, + compile_to_neff, + trace, +) from nkipy.core.logger import get_logger from nkipy.core.trace import NKIPyKernel from nkipy.runtime import device_tensor @@ -45,6 +52,7 @@ def compile_and_load( kernel, *args, name=None, + compiler_config: Optional[CompilerConfig] = None, additional_compiler_args=None, use_cached_if_exists=True, build_dir=None, @@ -56,14 +64,32 @@ def compile_and_load( Args: kernel: The kernel function to compile name: Optional name for the kernel. If None, uses kernel.__name__ - additional_compiler_args: Optional additional compiler arguments to append + compiler_config: Structured compiler configuration (recommended). + Use CompilerConfig.for_nkipy() or CompilerConfig.for_nki() for presets. + If not provided, auto-detects kernel type and uses appropriate preset. + additional_compiler_args: Optional additional compiler arguments (legacy). + If both compiler_config and additional_compiler_args are provided, + the additional args will be appended. use_cached_if_exists: If True, use cached neff if it exists. build_dir: Overriding the build directory for the kernel target: Compilation target for the kernel - \*args, \*\*kwargs: Arguments for specialization (numpy array or DeviceTensor) + *args, **kwargs: Arguments for specialization (numpy array or DeviceTensor) Returns: DeviceKernel: A DeviceKernel instance with the compiled kernel + + Example: + # With structured config (recommended): + kernel = DeviceKernel.compile_and_load( + my_func, x, w, + compiler_config=CompilerConfig.for_nkipy(model_type="transformer") + ) + + # Legacy string args (still supported): + kernel = DeviceKernel.compile_and_load( + my_func, x, w, + additional_compiler_args="--model-type transformer" + ) """ if name is None: # FIXME: this is likely to introduce unexpected conflict @@ -115,13 +141,21 @@ def compile_and_load( if isinstance(kernel, types.FunctionType): # Treat untraced function as NKIPy traced_kernel = trace(kernel) - compiler_args = compile.nkipy_compiler_args + is_nkipy = True elif isinstance(kernel, NKIPyKernel): traced_kernel = kernel - compiler_args = compile.nkipy_compiler_args + is_nkipy = True else: logger.info("Continue as NKI kernel") traced_kernel = kernel + is_nkipy = False + + # Resolve compiler args: compiler_config takes precedence, else auto-detect + if compiler_config is not None: + compiler_args = compiler_config.to_args() + elif is_nkipy: + compiler_args = compile.nkipy_compiler_args + else: compiler_args = compile.nki_compiler_args # Append user-provided additional compiler args if any diff --git a/nkipy/src/nkipy/runtime/execute.py b/nkipy/src/nkipy/runtime/execute.py index 4dc81a1..4f4b800 100644 --- a/nkipy/src/nkipy/runtime/execute.py +++ b/nkipy/src/nkipy/runtime/execute.py @@ -4,10 +4,12 @@ import os import shutil +from typing import Optional import numpy as np from nkipy.core import compile +from nkipy.core.compile import CompilerConfig from nkipy.core.ops._registry import set_backend from nkipy.core.trace import NKIPyKernel @@ -59,10 +61,29 @@ def baremetal_run_traced_kernel( *args, artifacts_dir=None, save_trace=False, + compiler_config: Optional[CompilerConfig] = None, additional_compiler_args="", target=compile.CompilationTarget.DEFAULT, **kwargs, ): + """Execute a traced kernel on Trainium hardware. + + Args: + kernel: The traced kernel to execute. + *args: Positional arguments for the kernel. + artifacts_dir: Directory to save compilation artifacts. + save_trace: Whether to save execution trace. + compiler_config: Structured compiler configuration (recommended). + If not provided, auto-detects kernel type and uses appropriate preset. + additional_compiler_args: Additional arguments to append (legacy). + If both compiler_config and additional_compiler_args are provided, + additional_compiler_args will be appended. + target: Compilation target (default: auto-detect). + **kwargs: Keyword arguments for the kernel. + + Returns: + The kernel output(s). + """ if not _RUNTIME_AVAILABLE: raise RuntimeError( "Runtime is not available. Please install Spike to use this function." @@ -87,15 +108,19 @@ def baremetal_run_traced_kernel( name = kernel.__name__ build_dir = artifacts_dir if artifacts_dir else f"{compile._get_build_dir()}/{name}" - if isinstance(kernel, compile.NKIPyKernel): - additional_compiler_args = ( - compile.nkipy_compiler_args + " " + additional_compiler_args - ) + + # Resolve compiler args: compiler_config takes precedence, else auto-detect + if compiler_config is not None: + final_compiler_args = compiler_config.to_args() + elif isinstance(kernel, compile.NKIPyKernel): + final_compiler_args = compile.nkipy_compiler_args else: # assume is NKI - additional_compiler_args = ( - compile.nki_compiler_args + " " + additional_compiler_args - ) + final_compiler_args = compile.nki_compiler_args + + # Append legacy additional_compiler_args if provided + if additional_compiler_args: + final_compiler_args = final_compiler_args + " " + additional_compiler_args # always clean the build dir in baremetal mode if os.path.exists(build_dir): @@ -106,7 +131,7 @@ def baremetal_run_traced_kernel( output_dir=build_dir, neff_name=f"{name}.neff", save_artifacts=True, - additional_compiler_args=additional_compiler_args, + additional_compiler_args=final_compiler_args, target=target, )