diff --git a/.gitignore b/.gitignore index 84c949d..2f2fa79 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ SlurmOutput/ temp/ .vscode/ *.ncu-rep +artifacts \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 8a37b06..92f5532 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,18 +5,21 @@ FROM continuumio/miniconda3 WORKDIR /app # Create the conda environment -COPY environment.yml /app/environment.yml +COPY . /app RUN conda env create -f /app/environment.yml -# Initialize conda in bash shell -RUN echo "source activate llmcompass_ae" > ~/.bashrc +# Do not rely on `source activate` in non-interactive shells; set PATH to the env's bin ENV PATH /opt/conda/envs/llmcompass_ae/bin:$PATH -# Clone your GitHub repository -RUN git clone https://github.com/HenryChang213/LLMCompass_ISCA_AE.git /app/LLMCompass_ISCA_AE -RUN cd /app/LLMCompass_ISCA_AE && git submodule init && git submodule update --recursive +# Install lightweight Python deps for the API server inside the conda env +RUN /opt/conda/envs/llmcompass_ae/bin/pip install \ + fastapi \ + "uvicorn[standard]" \ + aiosqlite \ + requests -# Expose the port your app runs on +# Expose the port your app runs on and run uvicorn as entrypoint EXPOSE 8000 - +# Start the FastAPI server using Uvicorn when the container launches +CMD ["uvicorn", "backend_app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/backend_app/README.md b/backend_app/README.md new file mode 100644 index 0000000..f8fcf80 --- /dev/null +++ b/backend_app/README.md @@ -0,0 +1,157 @@ +# LLMCompass backend_app + +This directory contains a minimal backend HTTP service (FastAPI + simulation scheduler). +This README explains how to build and run the backend (Docker-only), how to call the API, +and how to extend the codebase with new synchronous simulators. 
+ +## Prerequisites +- Docker (required runtime) +- Python 3.8+ (for development/testing inside the image) + +## Docker (build & run) + +The backend is supported to run only via Docker. Build the image from the repository root: + +```bash +sudo docker build -t llmcompass-backend . +``` + +Run the docker container, which exposes 8000 to host for backend interaction: + +```bash +sudo docker run --rm -p 8000:8000 llmcompass-backend +``` + +## Environment variables +- `API_PORT` — port used by tests/uvicorn inside the container (default: 8000) +- `API_URL` — if set in tests, the test suite will target this external URL instead of starting a local server +- `ARTIFACT_DIR` — directory where tests write artifacts (default: `artifacts/`) + +## HTTP API endpoints +- `GET /health` — health check, returns `{status: "ok"}` +- `GET /supported_ops` — list of supported operations (e.g. `matmul`, `gelu`) +- `POST /tasks` — submit a simulation task (returns `task_id`) +- `GET /tasks/{task_id}` — query task status and result + +### Usage + +1. **Submit a task**: POST the payload to the /tasks endpoint (for example `requests.post(f"{BASE}/tasks", json=payload, timeout=5)`). On success (HTTP 200) parse the response JSON and read the returned `task_id`. +2. **Poll the task status**: GET `/tasks/{task_id}` (for example `requests.get(f"{BASE}/tasks/{task_id}", timeout=5)`) until the returned `status` is `done` (or `failed`). Handle non-200 responses and timeouts as needed. +3. **Parse the result**: when the task `status` is `done`, inspect the `result` object. The `result.status` field indicates whether the simulation succeeded (e.g., `success`) or failed. The `time_taken` field is the simulation duration in seconds. Other fields in `result` (for example `metadata`) describe the simulated kernel and the target system. + +Below is an example task payload for a matmul operation. 
When constructing payload, ensure the input_dim follows the simulator’s expected format for this operation (e.g., a list of two two-element dimension arrays). + +```json +{ + "kernel_name": "itest_matmul", + "op": "matmul", + "input_dim": [[1, 2048], [2048, 7168]], + "dtype": ["c10::BFloat16", "c10::BFloat16"], + "system_key": "A100_4_fp16" +} +``` + +Below is an example of a completed matmul task, illustrating the response fields and their structure: +```json +{ + "task_id": "089d0b13-2ef9-43e1-bde3-44ed7219e959", + "status": "done", + "result": { + "status": "success", + "output": { + "summary": "matmul simulated" + }, + "simulated_time": 1.4408317802844531e-05 + }, + "user_submitted_request": { + "kernel_name": "itest_matmul_M_1", + "op": "matmul", + "input_dim": [ + [ + 1, + 2048 + ], + [ + 2048, + 7168 + ] + ], + "dtype": [ + "c10::BFloat16", + "c10::BFloat16" + ], + "system_key": "A100_4_fp16" + }, + "created_at": "2025-09-17T02:23:11.777675", + "updated_at": "2025-09-17T02:23:11.778457" +} +``` + +For failed tasks, the `result` includes a `failure_reason` object to aid root‑cause analysis. It contains the fields `error` (message) and `error_code`. Example: +```json +{ + ... + "result": { + "status": "failed", + "output": null, + "simulated_time": null, + "failure_reason": { + "error": "unsupported op - no generic simulator available", + "error_code": "UNSUPPORTED_OP" + } + }, + ... +} +``` + +Task states: `queued`, `running`, `done`, `failed` (scheduler/worker-dependent). When a task is submitted it enters the `queued` state. A free worker picks up the next queued task and the state transitions to `running`. After execution finishes, the task becomes `done` on success or `failed` on error. + +## Code layout and runtime flow + +Key modules: +- `backend_app/scheduler.py` — async entry points and dispatcher (`simulate_kernel_trace`, `process_kernel_simulation_task`). 
+- `backend_app/sim_utils.py` — shared helpers: dtype mapping, tensor construction, unified failure response helper `_make_failure`. +- `backend_app/sync_simulators.py` — synchronous `_simulate_*` implementations (e.g. `_simulate_matmul_sync`) and `_select_sync_simulator`. + +Runtime flow (simplified): +1. Worker receives a task, constructs `kernel_task` dict and calls `process_kernel_simulation_task`. +2. `process_kernel_simulation_task` calls `simulate_kernel_trace` (async). +3. `simulate_kernel_trace` selects a synchronous implementation via `_select_sync_simulator` and runs it in a thread using `asyncio.to_thread` to avoid blocking the event loop. +4. Synchronous implementations perform compile/simulate and return a standardized dict: `{status, output, simulated_time}` on success; failure responses additionally carry a `metadata` object with error details. + +All failure responses are created via `_make_failure(error, error_code)` to keep format consistent. + +## Adding a new synchronous simulator + +1. Implement a new `_simulate_<op>_sync` function in `backend_app/sync_simulators.py` with the same signature as existing ones: + +```py +def _simulate_conv_sync(kernel_name, input_dim, dtype_str, system_key=None): + # use backend_app.sim_utils helpers: _map_dtype, _make_tensor, _make_failure + ... + return {"status": "success", ...} +``` + +2. Update `_select_sync_simulator` in the same file to return your function when appropriate (e.g. `if "conv" in kn:`). + +3. Optionally add the op keyword to `get_supported_ops()` in `backend_app/sim_utils.py`. + +4. Add unit and/or integration tests to cover happy-path and failure cases. + +5. If new dependencies are required, update `requirements.txt` / `pyproject.toml` and Dockerfile. + +Important: keep the synchronous implementation's return schema consistent so the async wrapper can handle it uniformly. + +## Error codes + +Error codes are currently plain strings (e.g. `INVALID_INPUT`, `NO_SYSTEM`, `SIMULATOR_ERROR`). 
+ +## Tests + +Run tests from the repository root inside the Docker container or a development image: + +```bash +pytest tests/ +``` + +Integration tests write artifacts to the directory specified by `ARTIFACT_DIR` to aid debugging. \ No newline at end of file diff --git a/backend_app/main.py b/backend_app/main.py new file mode 100644 index 0000000..617ef0d --- /dev/null +++ b/backend_app/main.py @@ -0,0 +1,218 @@ +import uuid +import asyncio +import json +import datetime +import os +from typing import List, Union, Optional, Any +from contextlib import asynccontextmanager +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from backend_app.scheduler import simulate_kernel_trace, process_kernel_simulation_task + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # in-memory tasks store (not persisted): lost on restart + app.state.tasks = {} + app.state.tasks_lock = asyncio.Lock() + + # create queue and start background workers + app.state.queue = asyncio.Queue() + # number of concurrent background consumers (in-process). 
Use environment var WORKER_COUNT + try: + worker_count = int(os.environ.get("WORKER_COUNT", "32")) + except Exception: + worker_count = 1 + app.state.worker_tasks = [] + for i in range(max(1, worker_count)): + app.state.worker_tasks.append( + asyncio.create_task(worker_loop(app.state.queue, app.state.tasks, app.state.tasks_lock, worker_id=i)) + ) + + try: + yield + finally: + # shutdown: cancel background worker + workers = getattr(app.state, "worker_tasks", None) or [] + for worker in workers: + worker.cancel() + for worker in workers: + try: + await worker + except asyncio.CancelledError: + pass + + +app = FastAPI(title="LLMCompass Kernel Simulator", lifespan=lifespan) + + +class KernelTask(BaseModel): + kernel_name: str + op: str + input_dim: Optional[Any] = None + # some clients send a single dtype string, others send a list of dtype strings + dtype: Optional[Union[str, List[str]]] = "fp32" + # optional system key + system_key: Optional[str] = None + + +async def worker_loop(queue: asyncio.Queue, tasks: dict, lock: asyncio.Lock, worker_id: int = 0): + while True: + task_id = await queue.get() + try: + async with lock: + entry = tasks.get(task_id) + if not entry: + queue.task_done() + continue + payload = entry["payload"] + # mark as running and record worker id / start time + tasks[task_id]["status"] = "running" + tasks[task_id]["worker"] = worker_id + tasks[task_id]["started_at"] = datetime.datetime.utcnow().isoformat() + # process (outside lock) + result = await process_kernel_simulation_task(payload) + async with lock: + if task_id in tasks: + tasks[task_id]["status"] = "done" + tasks[task_id]["result"] = result + tasks[task_id][ + "updated_at" + ] = datetime.datetime.utcnow().isoformat() + except Exception as e: + async with lock: + if task_id in tasks: + tasks[task_id]["status"] = "failed" + tasks[task_id]["result"] = {"error": str(e)} + tasks[task_id][ + "updated_at" + ] = datetime.datetime.utcnow().isoformat() + finally: + queue.task_done() + + 
+@app.post("/tasks") +async def create_task(t: KernelTask, wait: bool = False, timeout: float = 30.0): + """ + Create a kernel simulation task. + If `wait` is False (default) the task is queued and returns immediately with status queued. + If `wait` is True the request will block up to `timeout` seconds and return the final status/result. + """ + task_id = str(uuid.uuid4()) + payload = t.dict() + created_at = datetime.datetime.utcnow().isoformat() + + # insert into in-memory store + async with app.state.tasks_lock: + app.state.tasks[task_id] = { + "payload": payload, + "status": "queued", + "result": None, + "created_at": created_at, + "updated_at": created_at, + } + + if not wait: + # enqueue for background processing + await app.state.queue.put(task_id) + return {"task_id": task_id, "status": "queued"} + + # synchronous path: process inline with timeout + try: + # mark as running for synchronous (wait) path + async with app.state.tasks_lock: + if task_id in app.state.tasks: + app.state.tasks[task_id]["status"] = "running" + app.state.tasks[task_id]["worker"] = "inline" + app.state.tasks[task_id]["started_at"] = datetime.datetime.utcnow().isoformat() + + result = await asyncio.wait_for( + process_kernel_simulation_task(payload), timeout=timeout + ) + except asyncio.TimeoutError: + # leave as queued for background worker to pick up later + return { + "task_id": task_id, + "status": "timeout", + "message": f"processing did not finish within {timeout}s", + } + except Exception as e: + # update in-memory store as failed + async with app.state.tasks_lock: + if task_id in app.state.tasks: + app.state.tasks[task_id]["status"] = "failed" + app.state.tasks[task_id]["result"] = {"error": str(e)} + app.state.tasks[task_id][ + "updated_at" + ] = datetime.datetime.utcnow().isoformat() + raise HTTPException(status_code=500, detail=str(e)) + + # write result into in-memory store and return + async with app.state.tasks_lock: + if task_id in app.state.tasks: + 
app.state.tasks[task_id]["status"] = "done" + app.state.tasks[task_id]["result"] = result + app.state.tasks[task_id][ + "updated_at" + ] = datetime.datetime.utcnow().isoformat() + + return {"task_id": task_id, "status": "done", "result": result} + + +@app.get("/supported_ops") +async def supported_ops(): + from backend_app.sim_utils import get_supported_ops + + return {"supported_ops": get_supported_ops()} + + +@app.get("/tasks/{task_id}") +async def get_task(task_id: str): + async with app.state.tasks_lock: + entry = app.state.tasks.get(task_id) + if not entry: + raise HTTPException(status_code=404, detail="task not found") + status = entry.get("status") + result = entry.get("result") + payload = entry.get("payload") + created_at = entry.get("created_at") + updated_at = entry.get("updated_at") + + return { + "task_id": task_id, + "status": status, + "result": result, + "user_submitted_request": payload, + "created_at": created_at, + "updated_at": updated_at, + } + + +@app.get("/health") +async def health(): + # provide richer health info: queue size, worker tasks status, and task counts + queue = getattr(app.state, "queue", None) + workers = getattr(app.state, "worker_tasks", None) or [] + tasks_store = getattr(app.state, "tasks", None) or {} + + # summarize task states + counts = {"queued": 0, "running": 0, "done": 0, "failed": 0} + for entry in tasks_store.values(): + st = entry.get("status") + if st in counts: + counts[st] += 1 + + worker_info = [] + for w in workers: + try: + worker_info.append({"done": w.done(), "cancelled": w.cancelled()}) + except Exception: + worker_info.append({"done": None, "cancelled": None}) + + return { + "status": "ok", + "queue_length": queue.qsize() if queue is not None else None, + "worker_count": len(workers), + "workers": worker_info, + "task_counts": counts, + } diff --git a/backend_app/scheduler.py b/backend_app/scheduler.py new file mode 100644 index 0000000..4be1532 --- /dev/null +++ b/backend_app/scheduler.py @@ -0,0 +1,66 
@@ +import asyncio +import time +from typing import Any, Dict + +# top-level async dispatcher imports helpers and sync simulators from smaller modules +from backend_app.sim_utils import _make_failure +from backend_app.sync_simulators import _select_sync_simulator + + +async def simulate_kernel_trace( + kernel_name: str, op: str, input_dim: list[list], dtype: list[str], system_key: str +) -> Dict[str, Any]: + """ + Dispatch to real software_model simulations. Blocking compile_and_simulate calls are executed + in a thread via asyncio.to_thread so the event loop is not blocked. + Returns standardized result: {status, output, simulated_time} + """ + # prefer op if provided (more specific), else fall back to kernel_name + selector = op if op else kernel_name + simulator = _select_sync_simulator(selector) + # run simulator in thread; pass system_key when simulator expects it + result = await asyncio.to_thread( + simulator, kernel_name, input_dim, dtype, system_key + ) + + # simulator returns dict; validate and propagate structured failure + if not isinstance(result, dict): + return _make_failure( + "simulator returned invalid result type", "SIMULATOR_ERROR" + ) + # ensure failure results have error_code when possible + if result.get("status") == "failed": + md = result.setdefault("metadata", {}) + md.setdefault("error_code", md.get("error_code", "SIMULATOR_FAILED")) + return result + + +async def process_kernel_simulation_task(kernel_task: Dict[str, Any]) -> Dict[str, Any]: + """ + Public entrypoint used by the worker. Calls simulate_kernel_trace and normalizes output. 
+ """ + kernel_name = kernel_task.get("kernel_name", "") + op = kernel_task.get("op", "") + input_dim = kernel_task.get("input_dim", []) + dtype = kernel_task.get("dtype", []) + system_key = kernel_task.get("system_key") + start = time.time() + res = await simulate_kernel_trace( + kernel_name, op, input_dim, dtype, system_key=system_key + ) + end = time.time() + # normalize to expected schema + if res.get("status") == "failed": + out = { + "status": "failed", + "output": None, + "simulated_time": None, + "failure_reason": res.get("metadata", {}), + } + else: + out = { + "status": res.get("status"), + "output": res.get("output"), + "simulated_time": res.get("simulated_time"), + } + return out diff --git a/backend_app/sim_utils.py b/backend_app/sim_utils.py new file mode 100644 index 0000000..1f7e085 --- /dev/null +++ b/backend_app/sim_utils.py @@ -0,0 +1,43 @@ +from typing import Any + +# import minimal software/hardware helpers used by simulators +from software_model.utils import Tensor, data_type_dict + + +def _map_dtype(dtype_str: str): + s = dtype_str.lower() + if "fp16" in s or "float16" in s: + return data_type_dict.get("fp16") + else: + return None + + +def _make_tensor(shape, dtype_obj): + try: + return Tensor(list(shape), dtype_obj) + except Exception: + return None + + +# centralized failure helper to avoid repeated dict literals +def _make_failure(error: str, error_code: str): + return { + "status": "failed", + "output": None, + "simulated_time": None, + "metadata": { + "error": error, + "error_code": error_code, + }, + } + + +def _make_missing_system(system_key: str): + return _make_failure( + f"missing system configuration '{system_key}'", "NO_SYSTEM" + ) + + +def get_supported_ops() -> list: + """Return list of supported op keywords for routing/diagnostics.""" + return ["gelu", "layernorm", "matmul", "softmax"] diff --git a/backend_app/sync_simulators.py b/backend_app/sync_simulators.py new file mode 100644 index 0000000..e05a79f --- /dev/null +++ 
b/backend_app/sync_simulators.py @@ -0,0 +1,218 @@ +from typing import Any, Dict + +from backend_app.sim_utils import ( + _map_dtype, + _make_tensor, + _make_failure, + _make_missing_system, +) +from software_model.matmul import Matmul, BatchedMatmul +from software_model.softmax import Softmax +from software_model.layernorm import LayerNorm +from software_model.gelu import GeLU +from hardware_model.system import system_dict + + +def _simulate_matmul_sync( + kernel_name: str, + input_dim: list[list], + dtype_str: list[str], + system_key: str = None, +) -> Dict[str, Any]: + dt = _map_dtype(dtype_str[0]) + A_shape = input_dim[0] + B_shape = input_dim[1] + + if dt is None: + return _make_failure( + "invalid or unsupported dtype", "INVALID_INPUT" + ) + elif A_shape is None or B_shape is None: + return _make_failure("invalid input dimension", "INVALID_INPUT") + + op = Matmul(dt) + A = _make_tensor(A_shape, dt) + B = _make_tensor(B_shape, dt) + _ = op(A, B) + + if not system_key: + return _make_failure("no valid system_key provided", "NO_SYSTEM") + system = system_dict.get(system_key) + if system is None: + return _make_missing_system(system_key) + + device = system.device + latency = op.compile_and_simulate(device, compile_mode="heuristic-GPU") + return { + "status": "success", + "output": {"summary": "matmul simulated"}, + "simulated_time": float(latency), + } + + +def _simulate_bmm_sync( + kernel_name: str, + input_dim: list[list], + dtype_str: list[str], + system_key: str = None, +) -> Dict[str, Any]: + dt = _map_dtype(dtype_str[0]) + A_shape = input_dim[0] + B_shape = input_dim[1] + + if dt is None: + return _make_failure( + "invalid or unsupported dtype", "INVALID_INPUT" + ) + elif A_shape is None or B_shape is None: + return _make_failure("invalid input dimension", "INVALID_INPUT") + + op = BatchedMatmul(dt) + A = _make_tensor(A_shape, dt) + B = _make_tensor(B_shape, dt) + _ = op(A, B) + + if not system_key: + return _make_failure("no valid system_key provided", 
"NO_SYSTEM") + system = system_dict.get(system_key) + if system is None: + return _make_missing_system(system_key) + + device = system.device + latency = op.compile_and_simulate(device, compile_mode="heuristic-GPU") + return { + "status": "success", + "output": {"summary": "bmm simulated"}, + "simulated_time": float(latency), + } + + +def _simulate_layernorm_sync( + kernel_name: str, + input_dim: list, + dtype_str: str, + system_key: str = None, +) -> Dict[str, Any]: + dt = _map_dtype(dtype_str) + A_shape = input_dim + + if dt is None: + return _make_failure( + "invalid or unsupported dtype", "INVALID_INPUT" + ) + elif A_shape is None: + return _make_failure("invalid input dimension", "INVALID_INPUT") + + op = LayerNorm(dt) + A = _make_tensor(A_shape, dt) + _ = op(A) + + if not system_key: + return _make_failure("no valid system_key provided", "NO_SYSTEM") + system = system_dict.get(system_key) + if system is None: + return _make_missing_system(system_key) + + device = system.device + latency = op.compile_and_simulate(device, compile_mode="heuristic-GPU") + return { + "status": "success", + "output": {"summary": "LayerNorm simulated"}, + "simulated_time": float(latency), + } + + +def _simulate_gelu_sync( + kernel_name: str, + input_dim: list, + dtype_str: str, + system_key: str = None, +) -> Dict[str, Any]: + dt = _map_dtype(dtype_str) + A_shape = input_dim + + if dt is None: + return _make_failure( + "invalid or unsupported dtype", "INVALID_INPUT" + ) + elif A_shape is None: + return _make_failure("invalid input dimension", "INVALID_INPUT") + + op = GeLU(dt) + A = _make_tensor(A_shape, dt) + _ = op(A) + + if not system_key: + return _make_failure("no valid system_key provided", "NO_SYSTEM") + system = system_dict.get(system_key) + if system is None: + return _make_missing_system(system_key) + + device = system.device + latency = op.compile_and_simulate(device, compile_mode="heuristic-GPU") + return { + "status": "success", + "output": {"summary": "GeLU 
simulated"}, + "simulated_time": float(latency), + } + + +def _simulate_softmax_sync( + kernel_name: str, + input_dim: list, + dtype_str: str, + system_key: str = None, +) -> Dict[str, Any]: + dt = _map_dtype(dtype_str) + A_shape = input_dim + + if dt is None: + return _make_failure( + "invalid or unsupported dtype", "INVALID_INPUT" + ) + elif A_shape is None: + return _make_failure("invalid input dimension", "INVALID_INPUT") + + op = Softmax(dt) + A = _make_tensor(A_shape, dt) + _ = op(A) + + if not system_key: + return _make_failure("no valid system_key provided", "NO_SYSTEM") + system = system_dict.get(system_key) + if system is None: + return _make_missing_system(system_key) + + device = system.device + latency = op.compile_and_simulate(device, compile_mode="heuristic-GPU") + return { + "status": "success", + "output": {"summary": "Softmax simulated"}, + "simulated_time": float(latency), + } + + +def _simulate_fail( + kernel_name: str, _input_dim=None, _dtype_str: str = "", system_key: str = None, +) -> Dict[str, Any]: + return _make_failure( + "unsupported op - no generic simulator available", "UNSUPPORTED_OP" + ) + + +def _select_sync_simulator(kernel_name: str): + if not kernel_name: + return _simulate_fail + kn = kernel_name.lower() + if kn == "matmul": + return _simulate_matmul_sync + elif kn == "bmm": + return _simulate_bmm_sync + elif kn == "layernorm": + return _simulate_layernorm_sync + elif kn == "gelu": + return _simulate_gelu_sync + elif kn == "softmax": + return _simulate_softmax_sync + # conv and other ops are not supported unless explicitly implemented + return _simulate_fail diff --git a/environment.yml b/environment.yml index b168460..e7bfae2 100644 --- a/environment.yml +++ b/environment.yml @@ -1,12 +1,12 @@ name: llmcompass_ae channels: - - pytorch - defaults dependencies: - python=3.9 - - pytorch - pip: - - scalesim + - torch==2.5.1 + - scalesim==2.0.2 - matplotlib - seaborn - - scipy \ No newline at end of file + - scipy + - pytest \ No 
newline at end of file diff --git a/tests/test_api_integration.py b/tests/test_api_integration.py new file mode 100644 index 0000000..cc62b11 --- /dev/null +++ b/tests/test_api_integration.py @@ -0,0 +1,429 @@ +import os +import time +import subprocess +import sys +import signal +import json +from pathlib import Path + +import pytest +import requests + + +# Server config: if API_URL is set we will target that and not start a server. +SERVER_HOST = "127.0.0.1" +SERVER_PORT = int(os.environ.get("API_PORT", "8000")) +BASE = os.environ.get("API_URL", f"http://{SERVER_HOST}:{SERVER_PORT}") + +# artifacts directory for intermediate results +ARTIFACT_DIR = Path(os.environ.get("ARTIFACT_DIR", "artifacts")) + + +def _ensure_artifacts_dir(): + ARTIFACT_DIR.mkdir(parents=True, exist_ok=True) + + +@pytest.fixture(scope="session", autouse=True) +def server(): + """Start a uvicorn server for the test session when API_URL is not set. + + If API_URL is provided (pointing to an external service), the fixture does nothing. + """ + _ensure_artifacts_dir() + if os.environ.get("API_URL"): + # External server provided, do not start local uvicorn. 
+ yield + return + + cmd = [ + sys.executable, + "-m", + "uvicorn", + "backend_app.main:app", + "--host", + SERVER_HOST, + "--port", + str(SERVER_PORT), + ] + + # redirect uvicorn output to artifact files + out_path = ARTIFACT_DIR / "uvicorn.out" + err_path = ARTIFACT_DIR / "uvicorn.err" + fout = open(out_path, "wb") + ferr = open(err_path, "wb") + proc = subprocess.Popen(cmd, stdout=fout, stderr=ferr, preexec_fn=None) + + # wait for health endpoint + deadline = time.time() + 10 + while time.time() < deadline: + try: + r = requests.get(f"http://{SERVER_HOST}:{SERVER_PORT}/health", timeout=1) + if r.status_code == 200: + break + except Exception: + pass + time.sleep(0.1) + else: + # failed to start in time; capture stderr for debugging + try: + ferr.flush() + with open(err_path, "rb") as f: + err = f.read() + except Exception: + proc.kill() + err = b"(no stderr available)" + fout.close() + ferr.close() + raise RuntimeError( + f"uvicorn failed to start in time. stderr:\n{err.decode(errors='ignore')}" + ) + + try: + yield + finally: + # terminate the server process + proc.terminate() + try: + proc.wait(timeout=5) + except Exception: + proc.kill() + proc.wait() + fout.close() + ferr.close() + + +def _url(path: str) -> str: + return BASE.rstrip("/") + path + + +def test_health(): + r = requests.get(_url("/health"), timeout=5) + with open(ARTIFACT_DIR / "health.json", "w") as f: + json.dump({"status_code": r.status_code, "body": r.json()}, f, indent=2) + assert r.status_code == 200 + assert r.json().get("status") == "ok" + + +def test_supported_ops(): + r = requests.get(_url("/supported_ops"), timeout=5) + with open(ARTIFACT_DIR / "supported_ops.json", "w") as f: + json.dump({"status_code": r.status_code, "body": r.json()}, f, indent=2) + assert r.status_code == 200 + j = r.json() + assert isinstance(j.get("supported_ops"), list) + + +@pytest.mark.parametrize( + "matmul_payload", + [ + { + "kernel_name": "itest_matmul_M_1", + "op": "matmul", + "input_dim": [[1, 2048], 
[2048, 7168]], + "dtype": ["c10::BFloat16", "c10::BFloat16"], + "system_key": "A100_4_fp16", + }, + { + "kernel_name": "itest_matmul_M_128", + "op": "matmul", + "input_dim": [[128, 128], [128, 128]], + "dtype": ["c10::BFloat16", "c10::BFloat16"], + "system_key": "A100_4_fp16", + }, + { + "kernel_name": "itest_matmul_fp8_unsupported", + "op": "matmul", + "input_dim": [[128, 128], [128, 128]], + "dtype": ["c10::Float8_e4m3fn", "c10::Float8_e4m3fn"], + "system_key": "A100_4_fp16", + }, + ], +) +def test_create_task_and_poll_matmul(matmul_payload): + """Submit a matmul task and poll for completion; save artifacts for debugging.""" + payload = matmul_payload + + with open(ARTIFACT_DIR / "matmul_create_task_request.json", "w") as f: + json.dump( + {"url": _url("/tasks"), "payload": payload, "op": "matmul"}, f, indent=2 + ) + + r = requests.post(_url("/tasks"), json=payload, timeout=5) + with open(ARTIFACT_DIR / "matmul_create_task_response.json", "w") as f: + try: + body = r.json() + except Exception: + body = {"text": r.text} + json.dump({"status_code": r.status_code, "body": body}, f, indent=2) + + assert r.status_code == 200 + j = r.json() + task_id = j.get("task_id") + assert task_id + + # poll briefly for terminal status + deadline = time.time() + 20 + last = None + while time.time() < deadline: + r = requests.get(_url(f"/tasks/{task_id}"), timeout=5) + if r.status_code == 200: + info = r.json() + status = info.get("status") + with open(ARTIFACT_DIR / f"task_{task_id}_poll_matmul.json", "w") as f: + json.dump({"status_code": r.status_code, "body": info}, f, indent=2) + if status == "done": + assert "result" in info + simulated_time = info.get("result", {}).get("simulated_time") + if payload.get("kernel_name") == "itest_matmul_fp8_unsupported": + # this kernel is expected to be unsupported + assert ( + info.get("result", {}).get("failure_reason", {}).get("error_code") == "INVALID_INPUT" + ), f"error_code={info.get('result', {}).get('failure_reason', {}).get('error_code')}" + elif 
payload.get("kernel_name") == "itest_matmul_M_1": + # this kernel is expected to be very fast + assert ( + simulated_time == 1.4408317802844531e-05 + ), f"simulated_time={simulated_time}" + elif payload.get("kernel_name") == "itest_matmul_M_128": + # this kernel is expected to be fast + assert ( + simulated_time == 1.1276595744680851e-07 + ), f"simulated_time={simulated_time}" + break + last = status + time.sleep(1) + + assert last in ("done", "failed", "queued", None) + + +@pytest.mark.parametrize( + "bmm_payload", + [ + { + "kernel_name": "itest_bmm_M_1", + "op": "bmm", + # batch=1, m=1, k=2048 ; batch=1, k=2048, n=7168 + "input_dim": [[1, 1, 2048], [1, 2048, 7168]], + "dtype": ["c10::BFloat16", "c10::BFloat16"], + "system_key": "A100_4_fp16", + }, + { + "kernel_name": "itest_bmm_fp8_unsupported", + "op": "bmm", + "input_dim": [[128, 128, 128], [128, 128, 128]], + "dtype": ["c10::Float8_e4m3fn", "c10::Float8_e4m3fn"], + "system_key": "A100_4_fp16", + }, + ], +) +def test_create_task_and_poll_bmm(bmm_payload): + """Submit a batched-matmul (bmm) task and poll for completion; save artifacts for debugging.""" + payload = bmm_payload + + with open(ARTIFACT_DIR / "bmm_create_task_request.json", "w") as f: + json.dump({"url": _url("/tasks"), "payload": payload, "op": "bmm"}, f, indent=2) + + r = requests.post(_url("/tasks"), json=payload, timeout=5) + with open(ARTIFACT_DIR / "bmm_create_task_response.json", "w") as f: + try: + body = r.json() + except Exception: + body = {"text": r.text} + json.dump({"status_code": r.status_code, "body": body}, f, indent=2) + + assert r.status_code == 200 + j = r.json() + task_id = j.get("task_id") + assert task_id + + # poll briefly for terminal status + deadline = time.time() + 20 + last = None + while time.time() < deadline: + r = requests.get(_url(f"/tasks/{task_id}"), timeout=5) + if r.status_code == 200: + info = r.json() + status = info.get("status") + with open(ARTIFACT_DIR / f"task_{task_id}_poll_bmm.json", "w") as f: + 
json.dump({"status_code": r.status_code, "body": info}, f, indent=2) + if status == "done": + assert "result" in info + simulated_time = info.get("result", {}).get("simulated_time") + if payload.get("kernel_name") == "itest_bmm_fp8_unsupported": + # this kernel is expected to be unsupported + assert ( + info.get("result", {}).get("failure_reason", {}).get("error_code") + == "INVALID_INPUT" + ) + else: + # for supported kernels we expect a numeric simulation time + assert simulated_time is not None + assert isinstance(simulated_time, (int, float)) + break + last = status + time.sleep(1) + + assert last in ("done", "failed", "queued", None) + + +@pytest.mark.parametrize( + "payload", + [ + { + "kernel_name": "itest_gelu_default", + "op": "gelu", + "input_dim": [1024], + "dtype": "fp16", + "system_key": "A100_4_fp16", + }, + ], +) +def test_create_task_and_poll_gelu(payload): + """Template test for gelu: submit and poll for completion.""" + with open(ARTIFACT_DIR / "gelu_create_task_request.json", "w") as f: + json.dump( + {"url": _url("/tasks"), "payload": payload, "op": "gelu"}, f, indent=2 + ) + + r = requests.post(_url("/tasks"), json=payload, timeout=5) + with open(ARTIFACT_DIR / "gelu_create_task_response.json", "w") as f: + try: + body = r.json() + except Exception: + body = {"text": r.text} + json.dump({"status_code": r.status_code, "body": body}, f, indent=2) + + assert r.status_code == 200 + j = r.json() + task_id = j.get("task_id") + assert task_id + + deadline = time.time() + 20 + last = None + while time.time() < deadline: + r = requests.get(_url(f"/tasks/{task_id}"), timeout=5) + if r.status_code == 200: + info = r.json() + status = info.get("status") + with open(ARTIFACT_DIR / f"task_{task_id}_poll_gelu.json", "w") as f: + json.dump({"status_code": r.status_code, "body": info}, f, indent=2) + if status == "done": + assert "result" in info + simulated_time = info.get("result", {}).get("simulated_time") + assert simulated_time is not None + break + last = status + 
time.sleep(1) + assert last in ("done", "failed", "queued", None) + + +@pytest.mark.parametrize( + "payload", + [ + { + "kernel_name": "itest_layernorm_default", + "op": "layernorm", + "input_dim": [1, 1024, 7168], + "dtype": "fp16", + "system_key": "A100_4_fp16", + }, + { + "kernel_name": "itest_layernorm_unsupported", + "op": "layernorm", + "input_dim": [1, 1024, 7168], + "dtype": "fp8", + "system_key": "A100_4_fp16", + }, + ], +) +def test_create_task_and_poll_layernorm(payload): + """Template test for layernorm: submit and poll for completion.""" + with open(ARTIFACT_DIR / "layernorm_create_task_request.json", "w") as f: + json.dump( + {"url": _url("/tasks"), "payload": payload, "op": "layernorm"}, f, indent=2 + ) + + r = requests.post(_url("/tasks"), json=payload, timeout=5) + with open(ARTIFACT_DIR / "layernorm_create_task_response.json", "w") as f: + try: + body = r.json() + except Exception: + body = {"text": r.text} + json.dump({"status_code": r.status_code, "body": body}, f, indent=2) + + assert r.status_code == 200 + j = r.json() + task_id = j.get("task_id") + assert task_id + + deadline = time.time() + 20 + last = None + while time.time() < deadline: + r = requests.get(_url(f"/tasks/{task_id}"), timeout=5) + if r.status_code == 200: + info = r.json() + status = info.get("status") + with open(ARTIFACT_DIR / f"task_{task_id}_poll_layernorm.json", "w") as f: + json.dump({"status_code": r.status_code, "body": info}, f, indent=2) + if status == "done": + assert "result" in info + simulated_time = info.get("result", {}).get("time_taken") + if payload.get("kernel_name") == "itest_layernorm_default": + assert simulated_time is not None + elif payload.get("kernel_name") == "itest_layernorm_unsupported": + assert simulated_time is None + break + last = status + time.sleep(1) + assert last in ("done", "failed", "queued", None) + + +@pytest.mark.parametrize( + "payload", + [ + { + "kernel_name": "itest_softmax_default", + "op": "softmax", + "input_dim": [64, 128], 
+ "dtype": "fp16", + "system_key": "A100_4_fp16", + }, + ], +) +def test_create_task_and_poll_softmax(payload): + """Template test for softmax: submit and poll for completion.""" + with open(ARTIFACT_DIR / "softmax_create_task_request.json", "w") as f: + json.dump( + {"url": _url("/tasks"), "payload": payload, "op": "softmax"}, f, indent=2 + ) + + r = requests.post(_url("/tasks"), json=payload, timeout=5) + with open(ARTIFACT_DIR / "softmax_create_task_response.json", "w") as f: + try: + body = r.json() + except Exception: + body = {"text": r.text} + json.dump({"status_code": r.status_code, "body": body}, f, indent=2) + + assert r.status_code == 200 + j = r.json() + task_id = j.get("task_id") + assert task_id + + deadline = time.time() + 20 + last = None + while time.time() < deadline: + r = requests.get(_url(f"/tasks/{task_id}"), timeout=5) + if r.status_code == 200: + info = r.json() + status = info.get("status") + with open(ARTIFACT_DIR / f"task_{task_id}_poll_softmax.json", "w") as f: + json.dump({"status_code": r.status_code, "body": info}, f, indent=2) + if status == "done": + assert "result" in info + simulated_time = info.get("result", {}).get("time_taken") + assert simulated_time is not None + break + last = status + time.sleep(1) + assert last in ("done", "failed", "queued", None)