From 610e7667bc9c9ded8f6e745cd0e207705fdf5167 Mon Sep 17 00:00:00 2001 From: Xinyu Jiang Date: Fri, 19 Dec 2025 16:58:08 +0000 Subject: [PATCH 1/6] Add TerminalBench eval scaffold --- examples/eval/tb/README.md | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 examples/eval/tb/README.md diff --git a/examples/eval/tb/README.md b/examples/eval/tb/README.md new file mode 100644 index 000000000..a4508800b --- /dev/null +++ b/examples/eval/tb/README.md @@ -0,0 +1,8 @@ +# TerminalBench Evaluation (WIP) + +This directory will contain TerminalBench evaluation integration for Slime. + +TODO: +- Run tb from Slime +- Parse metrics +- Log results to wandb \ No newline at end of file From e98c71324cc82fc4e9731d541d799a2b88a13ab2 Mon Sep 17 00:00:00 2001 From: Zhiyao Jiang Date: Mon, 22 Dec 2025 20:43:45 -0500 Subject: [PATCH 2/6] feat(eval): add Terminal Bench eval delegate - Integrates **Terminal Bench** as an eval delegate for **Slime**, enabling evaluation via an external TB server. - Adds a minimal **smoke eval config** and an example **Qwen3-8B** launch script for quick end-to-end testing. - Provides client/server support for submitting eval jobs, polling status, and collecting metrics from Terminal Bench. Co-authored-by: Zhiyao Jiang --- examples/eval/eval_delegate.py | 10 + examples/eval/scripts/eval_tb_smoke.yaml | 22 + examples/eval/scripts/run_eval_tb_qwen.sh | 176 ++++++++ examples/eval/tb/README.md | 8 - examples/eval/terminal_bench/README.md | 174 ++++++++ examples/eval/terminal_bench/__init__.py | 1 + .../terminal_bench/config/local_cluster.yaml | 12 + examples/eval/terminal_bench/requirements.txt | 3 + examples/eval/terminal_bench/tb_client.py | 101 +++++ examples/eval/terminal_bench/tb_config.py | 57 +++ examples/eval/terminal_bench/tb_server.py | 415 ++++++++++++++++++ 11 files changed, 971 insertions(+), 8 deletions(-) create mode 100644 examples/eval/scripts/eval_tb_smoke.yaml create mode 100644 examples/eval/scripts/run_eval_tb_qwen.sh delete mode 100644 examples/eval/tb/README.md create mode 100644 examples/eval/terminal_bench/README.md create mode 100644 examples/eval/terminal_bench/__init__.py create mode 100644 examples/eval/terminal_bench/config/local_cluster.yaml create mode 100644 examples/eval/terminal_bench/requirements.txt create mode 100644 examples/eval/terminal_bench/tb_client.py create mode 100644 examples/eval/terminal_bench/tb_config.py create mode 100644 examples/eval/terminal_bench/tb_server.py diff --git a/examples/eval/eval_delegate.py b/examples/eval/eval_delegate.py index c52c1c9c2..cdea6c158 100644 --- a/examples/eval/eval_delegate.py +++ b/examples/eval/eval_delegate.py @@ -91,6 +91,12 @@ def _rebuild_delegate_config( env_cfg = build_skills_eval_env_config(args, env, defaults) if env_cfg is not None: envs.append(env_cfg) + elif env_name == "terminal_bench": + from examples.eval.terminal_bench.tb_config import build_terminal_bench_config + + env_cfg = build_terminal_bench_config(args, env, defaults) + if env_cfg is not None: + envs.append(env_cfg) else: raise ValueError(f"Unknown delegate environment: {env_name}") return envs @@ -151,6 +157,10 @@ def _create_delegate(env_cfg: EvalEnvConfig, router_addr: str): from examples.eval.nemo_skills.skills_client import SkillsEvalClient return SkillsEvalClient.from_config(env_cfg, router_addr) + elif env_name == "terminal_bench": + from examples.eval.terminal_bench.tb_client import TerminalBenchClient + + return TerminalBenchClient.from_config(env_cfg, router_addr) logger.warning("No delegate client registered for environment: %s", env_name) return None diff --git a/examples/eval/scripts/eval_tb_smoke.yaml b/examples/eval/scripts/eval_tb_smoke.yaml new file mode 100644 index 000000000..90be2ade6 --- /dev/null +++ b/examples/eval/scripts/eval_tb_smoke.yaml @@ -0,0 +1,22 @@ +eval: + defaults: + n_samples_per_eval_prompt: 1 + # temperature: 0.6 + top_p: 0.95 + top_k: -1 + max_response_len: 24576 + datasets: # minimal smoke eval to keep eval-only path happy + - name: smoke + path: /mnt/data/zhiyao/tb_evaluation/tb_eval_smoke/eval_smoke.jsonl + rm_type: deepscaler + delegate: + - name: terminal_bench + type: examples.eval.terminal_bench.tb_config.build_terminal_bench_config + url: http://172.17.0.1:9052 + timeout_secs: 14400 + max_retries: 1 + model_name: openai/qwen3-8b + api_base: http://127.0.1.1:30005/v1 + dataset_path: /mnt/data/xinyu/program/slime-tb/terminal-bench/tasks + task_id: hello-world + n_concurrent: 8 diff --git a/examples/eval/scripts/run_eval_tb_qwen.sh b/examples/eval/scripts/run_eval_tb_qwen.sh new file mode 100644 index 000000000..8682661a3 --- /dev/null +++ b/examples/eval/scripts/run_eval_tb_qwen.sh @@ -0,0 +1,176 @@ +#!/bin/bash + +# Example launcher that reuses the Qwen3-8B recipe but delegates evaluation to an +# external Terminal Bench server via the eval_delegate_rollout wrapper. + +# Clean up any stale processes from a previous run. +pkill -9 sglang +sleep 3 +ray stop --force +pkill -9 ray +pkill -9 python +sleep 3 +pkill -9 ray +pkill -9 python + +set -ex + +export PYTHONBUFFERED=16 + +NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) +if [ "$NVLINK_COUNT" -gt 0 ]; then + HAS_NVLINK=1 +else + HAS_NVLINK=0 +fi +echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)" + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../../.." &>/dev/null && pwd)" +source "${REPO_ROOT}/scripts/models/qwen3-8B.sh" + +# Store eval/delegate settings in a YAML config similar to examples/eval_multi_task. +EVAL_CONFIG_PATH=${TB_EVAL_CONFIG_PATH:-"${REPO_ROOT}/examples/eval/scripts/eval_tb_smoke.yaml"} + +DEBUG_ARGS=( + --debug-rollout-only +) + +CKPT_ARGS=( + --hf-checkpoint /mnt/data/xinyu/OpenThinker-Agent-v1 + # --ref-load /mnt/data/xinyu/OpenThinker-Agent-v1 + # --hf-checkpoint /root/shared/Qwen3-8B + # --ref-load /root/shared/Qwen3-8B_torch_dist + # --load /root/shared/Qwen3-8B_slime/ + # --save /root/shared/Qwen3-8B_slime/ + # --save-interval 20 +) + +ROLLOUT_ARGS=( + --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl + --input-key prompt + --label-key label + --apply-chat-template + --rollout-shuffle + --rm-type deepscaler + # --num-rollout 3000 + --num-rollout 1 + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 8192 + --rollout-temperature 0.8 + --global-batch-size 256 + --balance-data +) + +EVAL_ARGS=( + # --eval-interval 5 + --eval-interval 1 + --eval-config "${EVAL_CONFIG_PATH}" + --eval-function-path examples.eval.eval_delegate_rollout.generate_rollout +) + +PERF_ARGS=( + --tensor-model-parallel-size 1 + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 +) + +GRPO_ARGS=( + # --advantage-estimator grpo + # --use-kl-loss + # --kl-loss-coef 0.00 + # --kl-loss-type low_var_kl + # --entropy-coef 0.00 + # --eps-clip 0.2 + # --eps-clip-high 0.28 +) + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 +) + +WANDB_ARGS=( + # --use-wandb + # --wandb-project slime-eval + # --wandb-group qwen3-8b-eval + # --wandb-key ${WANDB_KEY} + --wandb-mode disabled +) + +# ROUTER_IP=$(hostname -I | awk '{print $1}') + +SGLANG_ARGS=( + # --use-slime-router + --rollout-num-gpus-per-engine 1 + --sglang-mem-fraction-static 0.7 + # --sglang-cuda-graph-max-bs 16 + # set up sglang router + # --sglang-router-ip "${ROUTER_IP}" + --sglang-router-port 30005 +) + +MISC_ARGS=( + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +# export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export MASTER_ADDR=${MASTER_ADDR:-"10.102.22.21"} +export CUDA_VISIBLE_DEVICES=6,7 + +unset RAY_ADDRESS RAY_REDIS_ADDRESS RAY_GCS_ADDRESS +export RAY_TMPDIR=/tmp/ray_zhiyao +ray start --head --node-ip-address ${MASTER_ADDR} --port 6380 --num-gpus 2 \ + --disable-usage-stats \ + --dashboard-host=0.0.0.0 \ + --dashboard-port=8266 \ + --dashboard-agent-listen-port 52366 \ + --dashboard-agent-grpc-port 52367 \ + --runtime-env-agent-port 52368 + + +RUNTIME_ENV_JSON="{ + \"env_vars\": { + \"PYTHONPATH\": \"/root/Megatron-LM/\", + \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\" + } +}" + +sleep 5 + +ray job submit --address="http://${MASTER_ADDR}:8266" \ + --working-dir "${REPO_ROOT}" \ + --runtime-env-json="${RUNTIME_ENV_JSON}" \ + -- python3 train.py \ + --actor-num-nodes 1 \ + --actor-num-gpus-per-node 2 \ + --colocate \ + ${MODEL_ARGS[@]} \ + ${CKPT_ARGS[@]} \ + ${ROLLOUT_ARGS[@]} \ + ${OPTIMIZER_ARGS[@]} \ + ${GRPO_ARGS[@]} \ + ${WANDB_ARGS[@]} \ + ${PERF_ARGS[@]} \ + ${EVAL_ARGS[@]} \ + ${SGLANG_ARGS[@]} \ + ${DEBUG_ARGS[@]} \ + ${MISC_ARGS[@]} \ No newline at end of file diff --git a/examples/eval/tb/README.md b/examples/eval/tb/README.md deleted file mode 100644 index a4508800b..000000000 --- a/examples/eval/tb/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# TerminalBench Evaluation (WIP) - -This directory will contain TerminalBench evaluation integration for Slime. - -TODO: -- Run tb from Slime -- Parse metrics -- Log results to wandb \ No newline at end of file diff --git a/examples/eval/terminal_bench/README.md b/examples/eval/terminal_bench/README.md new file mode 100644 index 000000000..fcde58620 --- /dev/null +++ b/examples/eval/terminal_bench/README.md @@ -0,0 +1,174 @@ +# Terminal Bench Eval (Slime) + +This folder wires Terminal Bench (TB) into Slime as an eval delegate. The TB +run happens on the host via the `tb` CLI, and Slime reads back `accuracy` and +`n_resolved`. + +This guide is written for ML/algorithm folks who just want it to run. + +## What runs where + +- Slime runs your training/eval loop. +- Slime calls the TB delegate client. +- The TB delegate server (`tb_server.py`) runs `tb run ...` on the host. +- The server reads the latest TB JSON results and returns metrics to Slime. + +## Prereqs + +This setup assumes `slime/` and `terminal-bench/` are sibling directories under +`/mnt/data/xinyu/program/slime-tb`, and you only need one venv at +`/mnt/data/xinyu/program/slime-tb/.venv`. + +1) A working OpenAI-compatible inference endpoint, e.g.: + - `http://127.0.0.1:30001/v1` + +2) Terminal Bench installed and its `tb` CLI available. + - Use the same venv for both Slime and TerminalBench. + +3) TerminalBench Eval Server dependencies (Slime side only). + - Install with: + ```bash + uv pip install -r ../slime/examples/eval/terminal_bench/requirements.txt + ``` + - Do not assume these dependencies exist in `terminal-bench/`. + - Do not put this `requirements.txt` under `terminal-bench/`. + +4) A Slime eval config file that includes `eval.datasets`. + - Slime requires at least one dataset under `eval.datasets`. + - You can reuse your existing eval config; just add the delegate section. + +## Step 1: Start the inference server (sglang) + +Example: + +```bash +python3 -m sglang.launch_server \ + --model-path /data/models/OpenThinker-Agent-v1 \ + --served-model-name openai/qwen3-8b \ + --port 30001 \ + --host 0.0.0.0 +``` + +Notes: +- The `served-model-name` should match what TB sends (`openai/`). +- If your model name is different, update `model_name` in the delegate config. + +## Step 2: Start the TB server + +Run on the host (same machine where `tb` works): + +```bash +cd /mnt/data/xinyu/program/slime-tb/terminal-bench + +python slime/examples/eval/terminal_bench/tb_server.py \ + --host 0.0.0.0 --port 9050 \ + --output-root /tmp/tb-eval +``` + +What it does: +- Uses `OPENAI_API_KEY=EMPTY` +- Runs `tb run -a terminus-2 -m openai/ ... --n-concurrent 8` +- Waits for completion, then returns `accuracy` and `n_resolved` + +## Step 3: Quick sanity check (curl, async) + +Run a single task (e.g. `hello-world`). The server returns a `job_id` +immediately, then you poll the status endpoint. + +```bash +# Submit job +curl -X POST http://localhost:9050/evaluate \ + -H 'Content-Type: application/json' \ + -d '{"model_name":"qwen3-8b","api_base":"http://127.0.0.1:30001/v1","dataset_path":"/mnt/data/xinyu/program/slime-tb/terminal-bench/tasks","task_id":"hello-world","n_concurrent":1}' +``` + +The response includes `job_id` and `status_url`, for example: + +```json +{"job_id":"...","status":"queued","status_url":"/status/", ...} +``` + +Poll status until `completed`: + +```bash +curl http://localhost:9050/status/ +``` + +Where to check outputs: +- Logs: `/mnt/data/xinyu/program/slime-tb/tb_eval_logs/.log` +- Results: `/tmp/tb-eval//results.json` + +## Step 4: Configure Slime eval + +You need an eval config. Example: + +```yaml +eval: + # Slime still needs normal eval datasets (can be any small one). + datasets: + - name: aime + path: /root/datasets/aime-2024/aime-2024.jsonl + rm_type: math + + # TB delegate config. + delegate: + - name: terminal_bench + url: http://localhost:9050 # "/evaluate" auto-added if missing + timeout_secs: 1200 # 20 minutes + model_name: qwen3-8b + api_base: http://127.0.0.1:30001/v1 + dataset_path: /mnt/data/xinyu/program/slime-tb/terminal-bench/tasks + n_tasks: 10 + n_concurrent: 1 + # Optional: run specific tasks instead of n_tasks + # task_ids: ["hello-world"] + # task_id: "hello-world" +``` + +Notes: +- `model_name` is auto-normalized to `openai/` if you omit the prefix. +- The TB client auto-adds `/evaluate` if you give a bare host:port. +- `task_id` / `task_ids` overrides `n_tasks` when provided. +- `dataset_path` lets you run from any working directory. + +## Step 5: Tell Slime to use the delegate rollout + +Add this to your training/eval command: + +```bash +--eval-config /path/to/your_eval_config.yaml \ +--eval-function-path examples.eval.eval_delegate_rollout.generate_rollout +``` + +This makes Slime call the TB delegate during evaluation. + +## Quick sanity check (eval-only) + +If you just want to verify the TB integration, run a quick eval-only pass +(you still need your normal Slime args for model/data/etc.): + +```bash +python slime/train.py \ + --num-rollout 0 \ + --eval-interval 1 \ + --eval-config /path/to/your_eval_config.yaml \ + --eval-function-path examples.eval.eval_delegate_rollout.generate_rollout \ + ...other required args... +``` + +## Common gotchas + +- 404 from TB server: use `url: http://localhost:9050` or `.../evaluate`. +- Timeouts: keep `timeout_secs` large (TB tasks can compile code). +- No TB metrics: check `/tmp/tb-eval//results.json` and poll `/status/`. +- No output in terminal: tail the log at `/mnt/data/xinyu/program/slime-tb/tb_eval_logs/.log`. + +## Reference: the CLI command it runs + +The server is aligned with: + +```bash +OPENAI_API_KEY=EMPTY tb run -a terminus-2 -m openai/qwen3-8b \ + --agent-kwarg api_base=http://127.0.0.1:30001/v1 \ + --n-concurrent 1 +``` diff --git a/examples/eval/terminal_bench/__init__.py b/examples/eval/terminal_bench/__init__.py new file mode 100644 index 000000000..6ba998ca7 --- /dev/null +++ b/examples/eval/terminal_bench/__init__.py @@ -0,0 +1 @@ +"""NeMo Skills evaluation helpers.""" diff --git a/examples/eval/terminal_bench/config/local_cluster.yaml b/examples/eval/terminal_bench/config/local_cluster.yaml new file mode 100644 index 000000000..c4a824a1b --- /dev/null +++ b/examples/eval/terminal_bench/config/local_cluster.yaml @@ -0,0 +1,12 @@ +# Minimal Terminal Bench delegate config for running on the host (no containers). + +type: examples.eval.terminal_bench.tb_config.build_terminal_bench_config + +name: terminal_bench +url: http://localhost:9050/evaluate +timeout_secs: 1200 + +model_name: qwen3-8b +api_base: http://172.17.0.1:30001/v1 +n_tasks: 10 +n_concurrent: 4 diff --git a/examples/eval/terminal_bench/requirements.txt b/examples/eval/terminal_bench/requirements.txt new file mode 100644 index 000000000..1a0006c93 --- /dev/null +++ b/examples/eval/terminal_bench/requirements.txt @@ -0,0 +1,3 @@ +flask +omegaconf +requests diff --git a/examples/eval/terminal_bench/tb_client.py b/examples/eval/terminal_bench/tb_client.py new file mode 100644 index 000000000..38e50ab74 --- /dev/null +++ b/examples/eval/terminal_bench/tb_client.py @@ -0,0 +1,101 @@ +import logging +import time +from typing import Any + +import requests +from examples.eval.eval_delegate import EvalClient, EvalDelegateError +from examples.eval.terminal_bench.tb_config import TerminalBenchConfig + +logger = logging.getLogger(__name__) + + +class TerminalBenchClient(EvalClient): + """HTTP client that proxies evaluation requests to the Terminal Bench server.""" + + def __init__(self, config: TerminalBenchConfig, router_url: str): + super().__init__(config.name or "terminal_bench") + self._config = config + endpoint = (config.url or "").rstrip("/") + if endpoint.endswith("/evaluate"): + base_endpoint = endpoint[: -len("/evaluate")] + else: + base_endpoint = endpoint + self._endpoint = f"{base_endpoint}/evaluate" if base_endpoint else "" + self._status_endpoint = f"{base_endpoint}/status" if base_endpoint else "" + self._timeout_secs = float(config.timeout_secs) + self._max_retries = max(1, int(config.max_retries)) + self._headers = dict(config.headers or {}) + self._session = requests.Session() + + @classmethod + def from_config(cls, config: TerminalBenchConfig, router_url: str): + if not config.url: + return None + return cls(config, router_url) + + def evaluate(self, args, rollout_id: int) -> tuple[dict[str, Any], dict[str, Any]]: + payload = self._build_payload(args, rollout_id) + response = self._request(payload) + metrics = response.get("raw_metrics", {}) + return metrics, response + + def _build_payload(self, args, rollout_id: int) -> dict[str, Any]: + payload = { + "model_name": self._config.model_name, + "api_base": self._config.api_base, + "n_tasks": self._config.n_tasks, + "n_concurrent": self._config.n_concurrent, + } + if self._config.dataset_path: + payload["dataset_path"] = self._config.dataset_path + if self._config.task_ids: + payload["task_ids"] = list(self._config.task_ids) + return payload + + def _request(self, payload: dict[str, Any]) -> dict[str, Any]: + last_error: Exception | None = None + for attempt in range(1, self._max_retries + 1): + try: + response = self._session.post( + self._endpoint, + json=payload, + timeout=self._timeout_secs, + headers=self._headers, + ) + response.raise_for_status() + if not response.content: + return {} + body = response.json() + if body.get("status") == "completed": + return body + job_id = body.get("job_id") + if not job_id: + return body + return self._poll_status(job_id) + except requests.RequestException as exc: + last_error = exc + logger.warning( + "Terminal Bench delegate request failed (attempt %s/%s): %s", attempt, self._max_retries, exc + ) + if attempt < self._max_retries: + time.sleep(min(2**attempt, 30)) + raise EvalDelegateError("Terminal Bench evaluation request failed") from last_error + + def _poll_status(self, job_id: str) -> dict[str, Any]: + status_url = f"{self._status_endpoint}/{job_id}" + deadline = time.time() + self._timeout_secs + while time.time() < deadline: + response = self._session.get(status_url, timeout=min(self._timeout_secs, 30), headers=self._headers) + response.raise_for_status() + if not response.content: + time.sleep(2) + continue + body = response.json() + status = body.get("status") + if status == "completed": + return body + if status == "failed": + error = body.get("error") or "Terminal Bench job failed" + raise EvalDelegateError(error) + time.sleep(2) + raise EvalDelegateError("Terminal Bench evaluation timed out") diff --git a/examples/eval/terminal_bench/tb_config.py b/examples/eval/terminal_bench/tb_config.py new file mode 100644 index 000000000..fbd96e7b5 --- /dev/null +++ b/examples/eval/terminal_bench/tb_config.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any + +from examples.eval.eval_delegate import EvalEnvConfig + + +@dataclass +class TerminalBenchConfig(EvalEnvConfig): + """Environment configuration shared by the Terminal Bench client/server.""" + + model_name: str = "qwen3-8b" + api_base: str = "http://172.17.0.1:30001/v1" + n_tasks: int = 10 + n_concurrent: int = 4 + dataset_path: str | None = None + task_id: str | None = None + task_ids: list[str] = field(default_factory=list) + + @classmethod + def parse(cls, args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]) -> TerminalBenchConfig: + clean_raw = dict(raw_env_config or {}) + clean_raw.pop("type", None) + base_cfg: TerminalBenchConfig = super().parse(clean_raw, defaults) + model_name = clean_raw.get("model_name") + if model_name is not None: + base_cfg.model_name = str(model_name) + api_base = clean_raw.get("api_base") + if api_base is not None: + base_cfg.api_base = str(api_base) + n_tasks = clean_raw.get("n_tasks") + if n_tasks is not None: + base_cfg.n_tasks = int(n_tasks) + n_concurrent = clean_raw.get("n_concurrent") + if n_concurrent is not None: + base_cfg.n_concurrent = int(n_concurrent) + dataset_path = clean_raw.get("dataset_path") + if dataset_path is not None: + base_cfg.dataset_path = str(dataset_path) + task_id = clean_raw.get("task_id") + if task_id is not None: + base_cfg.task_id = str(task_id) + task_ids = clean_raw.get("task_ids") + if task_ids is None: + task_ids = task_id + if task_ids is not None: + if isinstance(task_ids, (list, tuple)): + base_cfg.task_ids = [str(item) for item in task_ids if item] + else: + base_cfg.task_ids = [str(task_ids)] + return base_cfg + + +def build_terminal_bench_config(args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]): + return TerminalBenchConfig.parse(args, raw_env_config, defaults) diff --git a/examples/eval/terminal_bench/tb_server.py b/examples/eval/terminal_bench/tb_server.py new file mode 100644 index 000000000..7543f36ca --- /dev/null +++ b/examples/eval/terminal_bench/tb_server.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 +""" +Simple HTTP server that proxies Slime evaluation requests to the `tb run` +command shipped with Terminal Bench. + +Usage: + python examples/eval/terminal_bench/tb_server.py \ + --host 0.0.0.0 --port 9050 \ + --output-root /opt/tb-eval + +Slime (or Slime-compatible runners) should POST the payload described in +`EvalRequestPayload` to http://:/evaluate. The server blocks until +`tb run` finishes, then returns aggregated metrics along with paths to the +generated artifacts (logs + raw metrics). +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import shlex +import subprocess +import sys +import threading +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +REPO_ROOT = Path(__file__).resolve().parents[3] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from flask import Flask, jsonify, request +from omegaconf import OmegaConf +from omegaconf.errors import OmegaConfBaseException + +logger = logging.getLogger("terminal_bench_server") +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") + + +# --------------------------------------------------------------------------- +# Request payload helpers +# --------------------------------------------------------------------------- + + +@dataclass +class EvalRequestPayload: + model_name: str = "" + api_base: str = "" + n_tasks: int | None = None + n_concurrent: int | None = None + dataset_path: str | None = None + task_ids: list[str] | None = None + task_id: str | None = None + + +@dataclass +class JobRecord: + job_id: str + status: str + run_id: str + command: str + output_dir: str + log_path: str + raw_metrics: dict[str, Any] | None = None + error: str | None = None + created_at: float = field(default_factory=time.time) + started_at: float | None = None + finished_at: float | None = None + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "job_id": self.job_id, + "status": self.status, + "run_id": self.run_id, + "command": self.command, + "output_dir": self.output_dir, + "log_path": self.log_path, + "created_at": self.created_at, + "started_at": self.started_at, + "finished_at": self.finished_at, + } + if self.raw_metrics is not None: + payload["raw_metrics"] = self.raw_metrics + if self.error: + payload["error"] = self.error + return payload + + +# --------------------------------------------------------------------------- +# Configuration + command helpers +# --------------------------------------------------------------------------- + + +def _normalize_model_name(model_name: str) -> str: + name = (model_name or "").strip() + if not name: + return "" + if "/" in name: + return name + return f"openai/{name}" + + +@dataclass +class ServerConfig: + output_root: Path + + @classmethod + def from_args(cls, args: argparse.Namespace) -> "ServerConfig": + return cls(output_root=Path(args.output_root).expanduser().resolve()) + + +class TerminalBenchEvaluator: + def __init__(self, config: ServerConfig): + self._config = config + self._lock = threading.Lock() + self._jobs_lock = threading.Lock() + self._jobs: dict[str, JobRecord] = {} + self._config.output_root.mkdir(parents=True, exist_ok=True) + self._log_root = REPO_ROOT / "tb_eval_logs" + self._log_root.mkdir(parents=True, exist_ok=True) + + def evaluate(self, payload: EvalRequestPayload) -> dict[str, Any]: + if not payload.model_name: + raise ValueError("Missing `model_name` in request payload.") + if not payload.api_base: + raise ValueError("Missing `api_base` in request payload.") + + job_id = uuid.uuid4().hex + run_id = f"{int(time.time())}-{job_id[:8]}" + run_dir = self._config.output_root / run_id + + command = self._build_command(payload, run_id) + command_str = " ".join(shlex.quote(part) for part in command) + log_path = self._log_root / f"{run_id}.log" + + record = JobRecord( + job_id=job_id, + status="queued", + run_id=run_id, + command=command_str, + output_dir=str(run_dir), + log_path=str(log_path), + ) + with self._jobs_lock: + self._jobs[job_id] = record + + thread = threading.Thread( + target=self._run_job, + args=(job_id, payload, run_dir, command, log_path), + daemon=True, + ) + thread.start() + + return { + "job_id": job_id, + "status": "queued", + "status_url": f"/status/{job_id}", + "run_id": run_id, + "command": command_str, + "output_dir": str(run_dir), + "log_path": str(log_path), + } + + def _run_job( + self, + job_id: str, + payload: EvalRequestPayload, + run_dir: Path, + command: list[str], + log_path: Path, + ) -> None: + with self._jobs_lock: + record = self._jobs.get(job_id) + if record is None: + return + record.status = "running" + record.started_at = time.time() + + env = self._build_env() + logger.info("Starting Terminal Bench run: %s", " ".join(shlex.quote(part) for part in command)) + try: + with self._lock: + self._run_command(command, env=env, log_path=log_path) + metrics = self._collect_metrics(run_dir) + with self._jobs_lock: + record = self._jobs.get(job_id) + if record is None: + return + record.status = "completed" + record.raw_metrics = metrics + record.finished_at = time.time() + except Exception as exc: # noqa: BLE001 + with self._jobs_lock: + record = self._jobs.get(job_id) + if record is None: + return + record.status = "failed" + record.error = str(exc) + record.finished_at = time.time() + + def get_job_status(self, job_id: str) -> dict[str, Any] | None: + with self._jobs_lock: + record = self._jobs.get(job_id) + if record is None: + return None + return record.to_dict() + + def _build_command(self, payload: EvalRequestPayload, run_id: str) -> list[str]: + # 1. Normalize model name (add openai/ prefix) + model_name = _normalize_model_name(payload.model_name) + + cmd = [ + "tb", + "run", + "-a", + "terminus-2", # Added Agent flag + "--output-path", + str(self._config.output_root), + "--run-id", + run_id, + ] + + # 2. Add model + if model_name: + cmd.extend(["--model", model_name]) + + # 3. Add Agent kwargs (Use api_base exactly like the CLI command) + if payload.api_base: + cmd.extend(["--agent-kwarg", f"api_base={payload.api_base}"]) + + # 4. Add n_tasks if present + task_ids = [] + if payload.task_ids: + task_ids.extend([str(item) for item in payload.task_ids if item]) + if payload.task_id: + task_ids.append(str(payload.task_id)) + + if payload.dataset_path: + cmd.extend(["--dataset-path", payload.dataset_path]) + + if task_ids: + for task_id in task_ids: + cmd.extend(["--task-id", task_id]) + elif payload.n_tasks is not None: + cmd.extend(["--n-tasks", str(payload.n_tasks)]) + + # 5. Add concurrency + n_concurrent = payload.n_concurrent + if n_concurrent is None: + n_concurrent = 1 + cmd.extend(["--n-concurrent", str(n_concurrent)]) + + return cmd + + def _build_env(self) -> dict[str, str]: + env = os.environ.copy() + # Inject env var to simulate "OPENAI_API_KEY=EMPTY" + env["OPENAI_API_KEY"] = "EMPTY" + return env + + @staticmethod + def _run_command(cmd: list[str], *, env: dict[str, str], log_path: Path): + with open(log_path, "w", encoding="utf-8") as log_file: + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + bufsize=1, + ) + assert process.stdout is not None + for line in process.stdout: + log_file.write(line) + log_file.flush() + sys.stdout.write(line) + sys.stdout.flush() + retcode = process.wait() + if retcode != 0: + with open(log_path, encoding="utf-8", errors="ignore") as log_file: + tail = "".join(log_file.readlines()[-200:]) + raise RuntimeError(f"`tb run` failed with exit code {retcode}. See {log_path}\n{tail}") + + @staticmethod + def _collect_metrics(run_dir: Path) -> dict[str, Any]: + metrics_path = run_dir / "results.json" + if not metrics_path.exists(): + logger.warning("Results file missing at %s", metrics_path) + return {} + + metrics = TerminalBenchEvaluator._extract_metrics(metrics_path) + if not metrics: + logger.warning("No accuracy/n_resolved metrics found in %s", metrics_path) + return metrics + + @staticmethod + def _extract_metrics(metrics_path: Path) -> dict[str, Any]: + try: + with open(metrics_path, encoding="utf-8") as fp: + metrics_data = json.load(fp) + except json.JSONDecodeError as exc: + logger.warning("Failed to parse %s: %s", metrics_path, exc) + return {} + + accuracy = metrics_data.get("accuracy") + n_resolved = metrics_data.get("n_resolved") + + if accuracy is None or n_resolved is None: + results = metrics_data.get("results") + if isinstance(results, list): + resolved = sum(1 for result in results if result.get("is_resolved")) + total = len(results) + if n_resolved is None: + n_resolved = resolved + if accuracy is None: + accuracy = resolved / total if total else 0.0 + + if accuracy is None or n_resolved is None: + return {} + + metrics: dict[str, Any] = {} + if accuracy is not None: + try: + metrics["accuracy"] = float(accuracy) + except (TypeError, ValueError): + logger.warning("Non-numeric accuracy in %s: %r", metrics_path, accuracy) + if n_resolved is not None: + try: + metrics["n_resolved"] = int(n_resolved) + except (TypeError, ValueError): + logger.warning("Non-numeric n_resolved in %s: %r", metrics_path, n_resolved) + if "accuracy" not in metrics or "n_resolved" not in metrics: + return {} + return metrics + + +# --------------------------------------------------------------------------- +# HTTP server +# --------------------------------------------------------------------------- + + +def build_app(evaluator: TerminalBenchEvaluator) -> Flask: + app = Flask(__name__) + + @app.get("/health") + def health_check(): + return jsonify({"status": "ok"}) + + @app.post("/evaluate") + def evaluate_endpoint(): + try: + raw_payload = request.get_json(force=True, silent=False) + cfg = OmegaConf.merge( + OmegaConf.structured(EvalRequestPayload), + OmegaConf.create(raw_payload or {}), + ) + payload = OmegaConf.to_object(cfg) + result = evaluator.evaluate(payload) + return jsonify(result) + except OmegaConfBaseException as exc: + logger.exception("Invalid request payload") + return jsonify({"error": str(exc)}), 400 + except Exception as exc: # noqa: BLE001 + logger.exception("Evaluation failed") + return jsonify({"error": str(exc)}), 500 + + @app.get("/status/") + def status_endpoint(job_id: str): + status = evaluator.get_job_status(job_id) + if status is None: + return jsonify({"error": "job not found"}), 404 + return jsonify(status) + + return app + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run the Terminal Bench evaluation HTTP server.") + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=9050) + parser.add_argument( + "--output-root", + type=str, + default="./terminal-bench-output", + help="Directory to store `tb run` outputs.", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + config = ServerConfig.from_args(args) + evaluator = TerminalBenchEvaluator(config) + app = build_app(evaluator) + logger.info( + "Starting Terminal Bench evaluation server on %s:%s (output root=%s)", + args.host, + args.port, + config.output_root, + ) + app.run(host=args.host, port=args.port) + + +if __name__ == "__main__": + main() From d99048ba49f34b63716af5174eae4654f9a03615 Mon Sep 17 00:00:00 2001 From: Zhiyao Jiang Date: Fri, 26 Dec 2025 00:22:49 -0500 Subject: [PATCH 3/6] successfully integrate tb in slime delegate eval with train Co-authored-by: Zhiyao Jiang Co-authored-by: Xinyu Jiang --- examples/eval/scripts/eval_tb_smoke.yaml | 23 ++++--- examples/eval/scripts/run_eval_tb_qwen.sh | 58 +++++++----------- examples/eval/terminal_bench/tb_client.py | 3 + examples/eval/terminal_bench/tb_config.py | 11 +++- examples/eval/terminal_bench/tb_server.py | 74 +++++++++++++++-------- 5 files changed, 95 insertions(+), 74 deletions(-) diff --git a/examples/eval/scripts/eval_tb_smoke.yaml b/examples/eval/scripts/eval_tb_smoke.yaml index 90be2ade6..8ff19959f 100644 --- a/examples/eval/scripts/eval_tb_smoke.yaml +++ b/examples/eval/scripts/eval_tb_smoke.yaml @@ -5,18 +5,25 @@ eval: top_p: 0.95 top_k: -1 max_response_len: 24576 - datasets: # minimal smoke eval to keep eval-only path happy - - name: smoke - path: /mnt/data/zhiyao/tb_evaluation/tb_eval_smoke/eval_smoke.jsonl - rm_type: deepscaler + datasets: # these eval tasks go through slime dataset config and default rollout function (slime.rollout.sglang_rollout.generate_rollout) + - name: gpqa # huggingface-cli download --repo-type dataset zyzshishui0627/gpqa_diamond --local-dir /root/gpqa + path: /root/gpqa/gpqa_eval.jsonl + rm_type: gpqa + n_samples_per_eval_prompt: 2 + - name: ifbench # huggingface-cli download --repo-type dataset zyzshishui0627/IFBench --local-dir /root/ifbench + path: /root/ifbench/IFBench_eval.jsonl + rm_type: ifbench + n_samples_per_eval_prompt: 1 delegate: - name: terminal_bench - type: examples.eval.terminal_bench.tb_config.build_terminal_bench_config + # type: examples.eval.terminal_bench.tb_config.build_terminal_bench_config url: http://172.17.0.1:9052 - timeout_secs: 14400 - max_retries: 1 - model_name: openai/qwen3-8b + timeout_secs: 86400 # 24 hours + max_retries: 1 # HTTP request retries from Slime to the TB server + model_name: qwen3-8b api_base: http://127.0.1.1:30005/v1 dataset_path: /mnt/data/xinyu/program/slime-tb/terminal-bench/tasks task_id: hello-world + # n_tasks: 10 + n_attempts: 2 # TB task-level retries (per task within tb run) n_concurrent: 8 diff --git a/examples/eval/scripts/run_eval_tb_qwen.sh b/examples/eval/scripts/run_eval_tb_qwen.sh index 8682661a3..fb98cb4d4 100644 --- a/examples/eval/scripts/run_eval_tb_qwen.sh +++ b/examples/eval/scripts/run_eval_tb_qwen.sh @@ -17,6 +17,9 @@ set -ex export PYTHONBUFFERED=16 +MODEL_DIR="${MODEL_DIR:-/mnt/data/xinyu}" +export MODEL_DIR + NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) if [ "$NVLINK_COUNT" -gt 0 ]; then HAS_NVLINK=1 @@ -32,18 +35,12 @@ source "${REPO_ROOT}/scripts/models/qwen3-8B.sh" # Store eval/delegate settings in a YAML config similar to examples/eval_multi_task. EVAL_CONFIG_PATH=${TB_EVAL_CONFIG_PATH:-"${REPO_ROOT}/examples/eval/scripts/eval_tb_smoke.yaml"} -DEBUG_ARGS=( - --debug-rollout-only -) - CKPT_ARGS=( - --hf-checkpoint /mnt/data/xinyu/OpenThinker-Agent-v1 - # --ref-load /mnt/data/xinyu/OpenThinker-Agent-v1 - # --hf-checkpoint /root/shared/Qwen3-8B - # --ref-load /root/shared/Qwen3-8B_torch_dist - # --load /root/shared/Qwen3-8B_slime/ - # --save /root/shared/Qwen3-8B_slime/ - # --save-interval 20 + --hf-checkpoint ${MODEL_DIR}/OpenThinker-Agent-v1 + --ref-load ${MODEL_DIR}/OpenThinker-Agent-v1_torch_dist + # --load ${MODEL_DIR}/OpenThinker-Agent-v1_slime/ + --save ${MODEL_DIR}/OpenThinker-Agent-v1_slime/ + --save-interval 20 ) ROLLOUT_ARGS=( @@ -86,13 +83,13 @@ PERF_ARGS=( ) GRPO_ARGS=( - # --advantage-estimator grpo - # --use-kl-loss - # --kl-loss-coef 0.00 - # --kl-loss-type low_var_kl - # --entropy-coef 0.00 - # --eps-clip 0.2 - # --eps-clip-high 0.28 + --advantage-estimator grpo + --use-kl-loss + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 ) OPTIMIZER_ARGS=( @@ -105,22 +102,15 @@ OPTIMIZER_ARGS=( ) WANDB_ARGS=( - # --use-wandb - # --wandb-project slime-eval - # --wandb-group qwen3-8b-eval - # --wandb-key ${WANDB_KEY} - --wandb-mode disabled + --use-wandb + --wandb-project slime-eval + --wandb-group qwen3-8b-eval + --wandb-key ${WANDB_KEY} ) -# ROUTER_IP=$(hostname -I | awk '{print $1}') - SGLANG_ARGS=( - # --use-slime-router --rollout-num-gpus-per-engine 1 --sglang-mem-fraction-static 0.7 - # --sglang-cuda-graph-max-bs 16 - # set up sglang router - # --sglang-router-ip "${ROUTER_IP}" --sglang-router-port 30005 ) @@ -132,12 +122,9 @@ MISC_ARGS=( --attention-backend flash ) -# export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -export MASTER_ADDR=${MASTER_ADDR:-"10.102.22.21"} +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} export CUDA_VISIBLE_DEVICES=6,7 -unset RAY_ADDRESS RAY_REDIS_ADDRESS RAY_GCS_ADDRESS -export RAY_TMPDIR=/tmp/ray_zhiyao ray start --head --node-ip-address ${MASTER_ADDR} --port 6380 --num-gpus 2 \ --disable-usage-stats \ --dashboard-host=0.0.0.0 \ @@ -154,8 +141,6 @@ RUNTIME_ENV_JSON="{ } }" -sleep 5 - ray job submit --address="http://${MASTER_ADDR}:8266" \ --working-dir "${REPO_ROOT}" \ --runtime-env-json="${RUNTIME_ENV_JSON}" \ @@ -172,5 +157,4 @@ ray job submit --address="http://${MASTER_ADDR}:8266" \ ${PERF_ARGS[@]} \ ${EVAL_ARGS[@]} \ ${SGLANG_ARGS[@]} \ - ${DEBUG_ARGS[@]} \ - ${MISC_ARGS[@]} \ No newline at end of file + ${MISC_ARGS[@]} diff --git a/examples/eval/terminal_bench/tb_client.py b/examples/eval/terminal_bench/tb_client.py index 38e50ab74..2a93b7161 100644 --- a/examples/eval/terminal_bench/tb_client.py +++ b/examples/eval/terminal_bench/tb_client.py @@ -45,11 +45,14 @@ def _build_payload(self, args, rollout_id: int) -> dict[str, Any]: "api_base": self._config.api_base, "n_tasks": self._config.n_tasks, "n_concurrent": self._config.n_concurrent, + "metric_prefix": self._config.name, } if self._config.dataset_path: payload["dataset_path"] = self._config.dataset_path if self._config.task_ids: payload["task_ids"] = list(self._config.task_ids) + if self._config.n_attempts is not None: + payload["n_attempts"] = self._config.n_attempts return payload def _request(self, payload: dict[str, Any]) -> dict[str, Any]: diff --git a/examples/eval/terminal_bench/tb_config.py b/examples/eval/terminal_bench/tb_config.py index fbd96e7b5..3c401d9e1 100644 --- a/examples/eval/terminal_bench/tb_config.py +++ b/examples/eval/terminal_bench/tb_config.py @@ -12,12 +12,14 @@ class TerminalBenchConfig(EvalEnvConfig): """Environment configuration shared by the Terminal Bench client/server.""" model_name: str = "qwen3-8b" - api_base: str = "http://172.17.0.1:30001/v1" - n_tasks: int = 10 - n_concurrent: int = 4 + api_base: str = "http://127.0.1.1:30001/v1" dataset_path: str | None = None + n_tasks: int | None = None task_id: str | None = None task_ids: list[str] = field(default_factory=list) + n_attempts: int | None = None + n_concurrent: int = 8 + @classmethod def parse(cls, args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]) -> TerminalBenchConfig: @@ -30,6 +32,9 @@ def parse(cls, args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, A api_base = clean_raw.get("api_base") if api_base is not None: base_cfg.api_base = str(api_base) + n_attempts = clean_raw.get("n_attempts") + if n_attempts is not None: + base_cfg.n_attempts = int(n_attempts) n_tasks = clean_raw.get("n_tasks") if n_tasks is not None: base_cfg.n_tasks = int(n_tasks) diff --git a/examples/eval/terminal_bench/tb_server.py b/examples/eval/terminal_bench/tb_server.py index 7543f36ca..0a75a9128 100644 --- a/examples/eval/terminal_bench/tb_server.py +++ b/examples/eval/terminal_bench/tb_server.py @@ -26,6 +26,7 @@ import threading import time import uuid +import statistics from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -56,6 +57,8 @@ class EvalRequestPayload: dataset_path: str | None = None task_ids: list[str] | None = None task_id: str | None = None + n_attempts: int | None = None + metric_prefix: str | None = None @dataclass @@ -187,6 +190,8 @@ def _run_job( with self._lock: self._run_command(command, env=env, log_path=log_path) metrics = self._collect_metrics(run_dir) + if payload.metric_prefix: + metrics = {payload.metric_prefix: metrics} with self._jobs_lock: record = self._jobs.get(job_id) if record is None: @@ -242,6 +247,9 @@ def _build_command(self, payload: EvalRequestPayload, run_id: str) -> list[str]: if payload.dataset_path: cmd.extend(["--dataset-path", payload.dataset_path]) + + if payload.n_attempts is not None: + cmd.extend(["--n-attempts", str(payload.n_attempts)]) if task_ids: for task_id in task_ids: @@ -307,35 +315,49 @@ def _extract_metrics(metrics_path: Path) -> dict[str, Any]: logger.warning("Failed to parse %s: %s", metrics_path, exc) return {} + metrics: dict[str, Any] = {} + + # core metrics accuracy = metrics_data.get("accuracy") - n_resolved = metrics_data.get("n_resolved") + if isinstance(accuracy, (int, float)): + metrics["accuracy"] = float(accuracy) - if accuracy is None or n_resolved is None: - results = metrics_data.get("results") - if isinstance(results, list): - resolved = sum(1 for result in results if result.get("is_resolved")) - total = len(results) - if n_resolved is None: - n_resolved = resolved - if accuracy is None: - accuracy = resolved / total if total else 0.0 - - if accuracy is None or n_resolved is None: - return {} + n_resolved = metrics_data.get("n_resolved") + if isinstance(n_resolved, (int, float)): + metrics["n_resolved"] = int(n_resolved) + + n_unresolved = metrics_data.get("n_unresolved") + if isinstance(n_unresolved, (int, float)): + metrics["n_unresolved"] = int(n_unresolved) + + # pass@k flatten + pass_at_k = metrics_data.get("pass_at_k") + if isinstance(pass_at_k, dict): + for k, v in pass_at_k.items(): + if isinstance(v, (int, float)): + metrics[f"pass_at_k/{k}"] = float(v) + + # token stats from per-task results + results = metrics_data.get("results") + if isinstance(results, list): + input_tokens = [ + r.get("total_input_tokens") + for r in results + if isinstance(r, dict) and isinstance(r.get("total_input_tokens"), (int, float)) + ] + output_tokens = [ + r.get("total_output_tokens") + for r in results + if isinstance(r, dict) and isinstance(r.get("total_output_tokens"), (int, float)) + ] + + if input_tokens: + metrics["total_input_tokens_mean"] = float(statistics.mean(input_tokens)) + metrics["total_input_tokens_median"] = float(statistics.median(input_tokens)) + if output_tokens: + metrics["total_output_tokens_mean"] = float(statistics.mean(output_tokens)) + metrics["total_output_tokens_median"] = float(statistics.median(output_tokens)) - metrics: dict[str, Any] = {} - if accuracy is not None: - try: - metrics["accuracy"] = float(accuracy) - except (TypeError, ValueError): - logger.warning("Non-numeric accuracy in %s: %r", metrics_path, accuracy) - if n_resolved is not None: - try: - metrics["n_resolved"] = int(n_resolved) - except (TypeError, ValueError): - logger.warning("Non-numeric n_resolved in %s: %r", metrics_path, n_resolved) - if "accuracy" not in metrics or "n_resolved" not in metrics: - return {} return metrics From 75540cecb5abf2ed7e67ac8b9f340ffd190f3496 Mon Sep 17 00:00:00 2001 From: Xinyu Jiang Date: Fri, 26 Dec 2025 21:53:44 -0500 Subject: [PATCH 4/6] write quick-start for slime + tb delegate eval Co-authored-by: Xinyu Jiang Co-authored-by: Zhiyao Jiang --- examples/eval/terminal_bench/README.md | 181 +++++++------------------ 1 file changed, 51 insertions(+), 130 deletions(-) diff --git a/examples/eval/terminal_bench/README.md b/examples/eval/terminal_bench/README.md index fcde58620..4e749e2a4 100644 --- a/examples/eval/terminal_bench/README.md +++ b/examples/eval/terminal_bench/README.md @@ -1,174 +1,95 @@ # Terminal Bench Eval (Slime) -This folder wires Terminal Bench (TB) into Slime as an eval delegate. The TB -run happens on the host via the `tb` CLI, and Slime reads back `accuracy` and -`n_resolved`. - -This guide is written for ML/algorithm folks who just want it to run. +This folder wires Terminal Bench (TB) into Slime as an eval delegate. The TB run happens on the host via the `tb` CLI, and Slime reads back aggregated metrics such as `accuracy`, `n_resolved`, `n_unresolved`, `pass_at_k/*`, and token stats like `total_input_tokens_mean/median` and `total_output_tokens_mean/median`. ## What runs where -- Slime runs your training/eval loop. +- Slime runs your training/eval loop inside the Docker container. - Slime calls the TB delegate client. - The TB delegate server (`tb_server.py`) runs `tb run ...` on the host. - The server reads the latest TB JSON results and returns metrics to Slime. ## Prereqs -This setup assumes `slime/` and `terminal-bench/` are sibling directories under -`/mnt/data/xinyu/program/slime-tb`, and you only need one venv at -`/mnt/data/xinyu/program/slime-tb/.venv`. - -1) A working OpenAI-compatible inference endpoint, e.g.: - - `http://127.0.0.1:30001/v1` - -2) Terminal Bench installed and its `tb` CLI available. - - Use the same venv for both Slime and TerminalBench. - -3) TerminalBench Eval Server dependencies (Slime side only). - - Install with: - ```bash - uv pip install -r ../slime/examples/eval/terminal_bench/requirements.txt - ``` - - Do not assume these dependencies exist in `terminal-bench/`. - - Do not put this `requirements.txt` under `terminal-bench/`. - -4) A Slime eval config file that includes `eval.datasets`. +1) Docker with GPU access. +2) `uv` installed on the host. +3) Terminal Bench installed and its `tb` CLI available on the machine that runs + `tb_server.py`. +4) The Slime repo available on the machine that runs `tb_server.py`. +5) A Slime eval config file that includes `eval.datasets`. - Slime requires at least one dataset under `eval.datasets`. - You can reuse your existing eval config; just add the delegate section. -## Step 1: Start the inference server (sglang) - -Example: +## 1) Get the code (host) ```bash -python3 -m sglang.launch_server \ - --model-path /data/models/OpenThinker-Agent-v1 \ - --served-model-name openai/qwen3-8b \ - --port 30001 \ - --host 0.0.0.0 +git clone --branch xinyu/quick_start https://github.com/XinyuJiangCMU/slime.git +git clone https://github.com/laude-institute/terminal-bench ``` -Notes: -- The `served-model-name` should match what TB sends (`openai/`). -- If your model name is different, update `model_name` in the delegate config. - -## Step 2: Start the TB server - -Run on the host (same machine where `tb` works): +## 2) Launch the Slime container ```bash -cd /mnt/data/xinyu/program/slime-tb/terminal-bench - -python slime/examples/eval/terminal_bench/tb_server.py \ - --host 0.0.0.0 --port 9050 \ - --output-root /tmp/tb-eval +docker run \ + -itd \ + --gpus all \ + --shm-size 32g \ + --network host \ + --ipc=host \ + --privileged \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + --ulimit nofile=65536:65536 \ + -v ~/.cache:/root/.cache \ + -v $(pwd)/slime:/opt/slime \ + -v $(pwd)/terminal-bench:/opt/terminal-bench \ + --name \ + slimerl/slime:latest \ + /bin/bash ``` -What it does: -- Uses `OPENAI_API_KEY=EMPTY` -- Runs `tb run -a terminus-2 -m openai/ ... --n-concurrent 8` -- Waits for completion, then returns `accuracy` and `n_resolved` - -## Step 3: Quick sanity check (curl, async) - -Run a single task (e.g. `hello-world`). The server returns a `job_id` -immediately, then you poll the status endpoint. +## 3) Inside the Slime container ```bash -# Submit job -curl -X POST http://localhost:9050/evaluate \ - -H 'Content-Type: application/json' \ - -d '{"model_name":"qwen3-8b","api_base":"http://127.0.0.1:30001/v1","dataset_path":"/mnt/data/xinyu/program/slime-tb/terminal-bench/tasks","task_id":"hello-world","n_concurrent":1}' +docker exec -it /bin/bash ``` -The response includes `job_id` and `status_url`, for example: - -```json -{"job_id":"...","status":"queued","status_url":"/status/", ...} -``` +## 4) Terminal Bench environment (host) -Poll status until `completed`: +Run on the machine that will host `tb_server.py` (where you cloned both repos): ```bash -curl http://localhost:9050/status/ -``` +uv venv --python 3.13 .venv +source .venv/bin/activate -Where to check outputs: -- Logs: `/mnt/data/xinyu/program/slime-tb/tb_eval_logs/.log` -- Results: `/tmp/tb-eval//results.json` - -## Step 4: Configure Slime eval - -You need an eval config. Example: - -```yaml -eval: - # Slime still needs normal eval datasets (can be any small one). - datasets: - - name: aime - path: /root/datasets/aime-2024/aime-2024.jsonl - rm_type: math - - # TB delegate config. - delegate: - - name: terminal_bench - url: http://localhost:9050 # "/evaluate" auto-added if missing - timeout_secs: 1200 # 20 minutes - model_name: qwen3-8b - api_base: http://127.0.0.1:30001/v1 - dataset_path: /mnt/data/xinyu/program/slime-tb/terminal-bench/tasks - n_tasks: 10 - n_concurrent: 1 - # Optional: run specific tasks instead of n_tasks - # task_ids: ["hello-world"] - # task_id: "hello-world" +uv pip install terminal-bench/. +uv pip install -r slime/examples/eval/terminal_bench/requirements.txt ``` Notes: -- `model_name` is auto-normalized to `openai/` if you omit the prefix. -- The TB client auto-adds `/evaluate` if you give a bare host:port. -- `task_id` / `task_ids` overrides `n_tasks` when provided. -- `dataset_path` lets you run from any working directory. - -## Step 5: Tell Slime to use the delegate rollout - -Add this to your training/eval command: - -```bash ---eval-config /path/to/your_eval_config.yaml \ ---eval-function-path examples.eval.eval_delegate_rollout.generate_rollout -``` - -This makes Slime call the TB delegate during evaluation. +- Use your local repo paths if they are not `./slime` and `./terminal-bench`. -## Quick sanity check (eval-only) +## 5) Start the TB server -If you just want to verify the TB integration, run a quick eval-only pass -(you still need your normal Slime args for model/data/etc.): +Run on the host (same machine where `tb` works): ```bash -python slime/train.py \ - --num-rollout 0 \ - --eval-interval 1 \ - --eval-config /path/to/your_eval_config.yaml \ - --eval-function-path examples.eval.eval_delegate_rollout.generate_rollout \ - ...other required args... +python slime/examples/eval/terminal_bench/tb_server.py \ + --host 0.0.0.0 --port 9051 \ + --output-root tb_eval_output ``` -## Common gotchas - -- 404 from TB server: use `url: http://localhost:9050` or `.../evaluate`. -- Timeouts: keep `timeout_secs` large (TB tasks can compile code). -- No TB metrics: check `/tmp/tb-eval//results.json` and poll `/status/`. -- No output in terminal: tail the log at `/mnt/data/xinyu/program/slime-tb/tb_eval_logs/.log`. +What it does: +- Uses `OPENAI_API_KEY=EMPTY` +- Runs `tb run -a terminus-2 -m openai/ ... --n-concurrent 8` +- Waits for completion, then returns `accuracy`, `n_resolved`, + `n_unresolved`, `pass_at_k/*`, and token stats such as + `total_input_tokens_mean/median` and `total_output_tokens_mean/median` -## Reference: the CLI command it runs +## 6) Run the eval script (example) -The server is aligned with: +If you use the provided Qwen eval launcher: ```bash -OPENAI_API_KEY=EMPTY tb run -a terminus-2 -m openai/qwen3-8b \ - --agent-kwarg api_base=http://127.0.0.1:30001/v1 \ - --n-concurrent 1 +bash slime/examples/eval/scripts/run_eval_tb_qwen.sh 2>&1 | tee run.log ``` From 98b5ce42ac29e5684eea51db95cb5d8900fcedf9 Mon Sep 17 00:00:00 2001 From: Zhiyao Jiang Date: Wed, 31 Dec 2025 15:52:11 -0500 Subject: [PATCH 5/6] modify code and quick-start based on review comments Co-authored-by: Zhiyao Jiang Co-authored-by: Xinyu Jiang --- examples/eval/{ => nemo_skills}/README.md | 0 ...val_tb_smoke.yaml => eval_tb_example.yaml} | 16 +++--- ...un_eval_tb_qwen.sh => run-eval-tb-qwen.sh} | 17 ++++--- examples/eval/terminal_bench/README.md | 47 +++++++++++++++-- .../terminal_bench/config/local_cluster.yaml | 12 ----- examples/eval/terminal_bench/tb_config.py | 51 ++++++++----------- examples/eval/terminal_bench/tb_server.py | 17 +++---- 7 files changed, 87 insertions(+), 73 deletions(-) rename examples/eval/{ => nemo_skills}/README.md (100%) rename examples/eval/scripts/{eval_tb_smoke.yaml => eval_tb_example.yaml} (70%) rename examples/eval/scripts/{run_eval_tb_qwen.sh => run-eval-tb-qwen.sh} (91%) delete mode 100644 examples/eval/terminal_bench/config/local_cluster.yaml diff --git a/examples/eval/README.md b/examples/eval/nemo_skills/README.md similarity index 100% rename from examples/eval/README.md rename to examples/eval/nemo_skills/README.md diff --git a/examples/eval/scripts/eval_tb_smoke.yaml b/examples/eval/scripts/eval_tb_example.yaml similarity index 70% rename from examples/eval/scripts/eval_tb_smoke.yaml rename to examples/eval/scripts/eval_tb_example.yaml index 8ff19959f..5104ae6e1 100644 --- a/examples/eval/scripts/eval_tb_smoke.yaml +++ b/examples/eval/scripts/eval_tb_example.yaml @@ -1,7 +1,7 @@ eval: defaults: n_samples_per_eval_prompt: 1 - # temperature: 0.6 + temperature: 0.6 top_p: 0.95 top_k: -1 max_response_len: 24576 @@ -16,14 +16,14 @@ eval: n_samples_per_eval_prompt: 1 delegate: - name: terminal_bench - # type: examples.eval.terminal_bench.tb_config.build_terminal_bench_config - url: http://172.17.0.1:9052 + url: http://172.17.0.1:9051 # Port must match the TB server running on the host machine timeout_secs: 86400 # 24 hours max_retries: 1 # HTTP request retries from Slime to the TB server model_name: qwen3-8b - api_base: http://127.0.1.1:30005/v1 - dataset_path: /mnt/data/xinyu/program/slime-tb/terminal-bench/tasks - task_id: hello-world + api_base: http://127.0.0.1:30005/v1 # Port must match the sglang router port set in run-eval-tb-qwen.sh + dataset_path: /mnt/data/xinyu/program/slime-tb/terminal-bench/tasks # Dataset path on the host machine + # task_ids: + # - hello-world # n_tasks: 10 - n_attempts: 2 # TB task-level retries (per task within tb run) - n_concurrent: 8 + n_attempts: 1 # TB task-level retries (per task within tb run) + n_concurrent: 8 \ No newline at end of file diff --git a/examples/eval/scripts/run_eval_tb_qwen.sh b/examples/eval/scripts/run-eval-tb-qwen.sh similarity index 91% rename from examples/eval/scripts/run_eval_tb_qwen.sh rename to examples/eval/scripts/run-eval-tb-qwen.sh index fb98cb4d4..589c24501 100644 --- a/examples/eval/scripts/run_eval_tb_qwen.sh +++ b/examples/eval/scripts/run-eval-tb-qwen.sh @@ -16,8 +16,9 @@ pkill -9 python set -ex export PYTHONBUFFERED=16 +export SLIME_HOST_IP=${SLIME_HOST_IP:-"127.0.0.1"} -MODEL_DIR="${MODEL_DIR:-/mnt/data/xinyu}" +MODEL_DIR="${MODEL_DIR:-/root/.cache}" export MODEL_DIR NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l) @@ -33,10 +34,10 @@ REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../../.." &>/dev/null && pwd)" source "${REPO_ROOT}/scripts/models/qwen3-8B.sh" # Store eval/delegate settings in a YAML config similar to examples/eval_multi_task. -EVAL_CONFIG_PATH=${TB_EVAL_CONFIG_PATH:-"${REPO_ROOT}/examples/eval/scripts/eval_tb_smoke.yaml"} +EVAL_CONFIG_PATH=${TB_EVAL_CONFIG_PATH:-"${REPO_ROOT}/examples/eval/scripts/eval_tb_example.yaml"} CKPT_ARGS=( - --hf-checkpoint ${MODEL_DIR}/OpenThinker-Agent-v1 + --hf-checkpoint ${MODEL_DIR}/OpenThinker-Agent-v1 # huggingface-cli download open-thoughts/OpenThinker-Agent-v1 --ref-load ${MODEL_DIR}/OpenThinker-Agent-v1_torch_dist # --load ${MODEL_DIR}/OpenThinker-Agent-v1_slime/ --save ${MODEL_DIR}/OpenThinker-Agent-v1_slime/ @@ -50,8 +51,8 @@ ROLLOUT_ARGS=( --apply-chat-template --rollout-shuffle --rm-type deepscaler - # --num-rollout 3000 - --num-rollout 1 + --num-rollout 3000 + # --num-rollout 1 --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 @@ -61,8 +62,8 @@ ROLLOUT_ARGS=( ) EVAL_ARGS=( - # --eval-interval 5 - --eval-interval 1 + --eval-interval 5 + # --eval-interval 1 --eval-config "${EVAL_CONFIG_PATH}" --eval-function-path examples.eval.eval_delegate_rollout.generate_rollout ) @@ -123,7 +124,7 @@ MISC_ARGS=( ) export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} -export CUDA_VISIBLE_DEVICES=6,7 +export CUDA_VISIBLE_DEVICES=0,1 ray start --head --node-ip-address ${MASTER_ADDR} --port 6380 --num-gpus 2 \ --disable-usage-stats \ diff --git a/examples/eval/terminal_bench/README.md b/examples/eval/terminal_bench/README.md index 4e749e2a4..2a59dccdb 100644 --- a/examples/eval/terminal_bench/README.md +++ b/examples/eval/terminal_bench/README.md @@ -23,7 +23,7 @@ This folder wires Terminal Bench (TB) into Slime as an eval delegate. The TB run ## 1) Get the code (host) ```bash -git clone --branch xinyu/quick_start https://github.com/XinyuJiangCMU/slime.git +git clone https://github.com/THUDM/slime.git git clone https://github.com/laude-institute/terminal-bench ``` @@ -69,7 +69,7 @@ uv pip install -r slime/examples/eval/terminal_bench/requirements.txt Notes: - Use your local repo paths if they are not `./slime` and `./terminal-bench`. -## 5) Start the TB server +## 5) Start the Terminal Bench server Run on the host (same machine where `tb` works): @@ -88,8 +88,47 @@ What it does: ## 6) Run the eval script (example) -If you use the provided Qwen eval launcher: +If you use the provided Qwen eval launcher (`run-eval-tb-qwen.sh`), follow the steps below to run Terminal-Bench evaluation. + +First, update the `dataset_path` in `eval_tb_example.yaml` to the local path of `terminal-bench/tasks` on your host (not an internal Docker-only path). + +Then download the HuggingFace model checkpoint inside the Slime container: + +```bash +huggingface-cli download open-thoughts/OpenThinker-Agent-v1 \ +--local-dir /root/.cache/OpenThinker-Agent-v1 +``` + +After downloading, convert the HuggingFace checkpoint to Slime's torch distributed format. From the Slime root directory, run: ```bash -bash slime/examples/eval/scripts/run_eval_tb_qwen.sh 2>&1 | tee run.log +cd /opt/slime +source scripts/models/qwen3-8B.sh + +PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/.cache/OpenThinker-Agent-v1 \ + --save /root/.cache/OpenThinker-Agent-v1_torch_dist ``` + +Finally, run the following command inside the Slime container: + +```bash +bash slime/examples/eval/scripts/run-eval-tb-qwen.sh 2>&1 | tee run.log +``` + +For convenience, you can restrict the evaluation scope in `eval_tb_example.yaml`, either by specifying a single task or multiple tasks (`task_ids`), or by limiting the number of tasks via `n_tasks`. + +## 7) Common Issues + +When running Slime inside a Docker container with `--network host`, Ray may encounter port conflicts due to shared networking with the host. + +In some cases, this manifests as Ray failing to start or reporting Redis- or session-related errors. This can usually be resolved by explicitly assigning unused ports when starting the Ray head node, for example by setting a non-default `--port` and `--dashboard-port`. + +In more severe cases, Ray job submission may fail with errors indicating that no available agent can accept jobs. This typically happens when the dashboard agent or runtime environment agent ports are also in conflict. In such situations, explicitly specifying the agent-related ports (e.g. `--dashboard-agent-listen-port`, `--dashboard-agent-grpc-port`, and `--runtime-env-agent-port`) when starting Ray can resolve the issue. + +If the TB server cannot connect to the Slime server through the sglang router, check which address is actually listening on the router port (e.g. 30005 in this example) and update the `api_base` in `eval_tb_example.yaml` accordingly: + +```bash +ss -lntp | grep 30005 +``` \ No newline at end of file diff --git a/examples/eval/terminal_bench/config/local_cluster.yaml b/examples/eval/terminal_bench/config/local_cluster.yaml deleted file mode 100644 index c4a824a1b..000000000 --- a/examples/eval/terminal_bench/config/local_cluster.yaml +++ /dev/null @@ -1,12 +0,0 @@ -# Minimal Terminal Bench delegate config for running on the host (no containers). - -type: examples.eval.terminal_bench.tb_config.build_terminal_bench_config - -name: terminal_bench -url: http://localhost:9050/evaluate -timeout_secs: 1200 - -model_name: qwen3-8b -api_base: http://172.17.0.1:30001/v1 -n_tasks: 10 -n_concurrent: 4 diff --git a/examples/eval/terminal_bench/tb_config.py b/examples/eval/terminal_bench/tb_config.py index 3c401d9e1..adb4f2c30 100644 --- a/examples/eval/terminal_bench/tb_config.py +++ b/examples/eval/terminal_bench/tb_config.py @@ -15,48 +15,39 @@ class TerminalBenchConfig(EvalEnvConfig): api_base: str = "http://127.0.1.1:30001/v1" dataset_path: str | None = None n_tasks: int | None = None - task_id: str | None = None task_ids: list[str] = field(default_factory=list) n_attempts: int | None = None n_concurrent: int = 8 - @classmethod def parse(cls, args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]) -> TerminalBenchConfig: clean_raw = dict(raw_env_config or {}) clean_raw.pop("type", None) base_cfg: TerminalBenchConfig = super().parse(clean_raw, defaults) - model_name = clean_raw.get("model_name") - if model_name is not None: - base_cfg.model_name = str(model_name) - api_base = clean_raw.get("api_base") - if api_base is not None: - base_cfg.api_base = str(api_base) - n_attempts = clean_raw.get("n_attempts") - if n_attempts is not None: - base_cfg.n_attempts = int(n_attempts) - n_tasks = clean_raw.get("n_tasks") - if n_tasks is not None: - base_cfg.n_tasks = int(n_tasks) - n_concurrent = clean_raw.get("n_concurrent") - if n_concurrent is not None: - base_cfg.n_concurrent = int(n_concurrent) - dataset_path = clean_raw.get("dataset_path") - if dataset_path is not None: - base_cfg.dataset_path = str(dataset_path) - task_id = clean_raw.get("task_id") - if task_id is not None: - base_cfg.task_id = str(task_id) + + field_casts = { + "model_name": str, + "api_base": str, + "n_attempts": int, + "n_tasks": int, + "n_concurrent": int, + "dataset_path": str, + } + + for key, caster in field_casts.items(): + value = clean_raw.get(key) + if value is not None: + setattr(base_cfg, key, caster(value)) + task_ids = clean_raw.get("task_ids") - if task_ids is None: - task_ids = task_id - if task_ids is not None: - if isinstance(task_ids, (list, tuple)): - base_cfg.task_ids = [str(item) for item in task_ids if item] - else: - base_cfg.task_ids = [str(task_ids)] + if isinstance(task_ids, (list, tuple)): + base_cfg.task_ids = [str(item) for item in task_ids if item] + elif task_ids is not None: + raise ValueError("task_ids must be a list") + return base_cfg + def build_terminal_bench_config(args, raw_env_config: Mapping[str, Any], defaults: Mapping[str, Any]): return TerminalBenchConfig.parse(args, raw_env_config, defaults) diff --git a/examples/eval/terminal_bench/tb_server.py b/examples/eval/terminal_bench/tb_server.py index 0a75a9128..a43537faa 100644 --- a/examples/eval/terminal_bench/tb_server.py +++ b/examples/eval/terminal_bench/tb_server.py @@ -56,7 +56,6 @@ class EvalRequestPayload: n_concurrent: int | None = None dataset_path: str | None = None task_ids: list[str] | None = None - task_id: str | None = None n_attempts: int | None = None metric_prefix: str | None = None @@ -124,7 +123,7 @@ def __init__(self, config: ServerConfig): self._jobs_lock = threading.Lock() self._jobs: dict[str, JobRecord] = {} self._config.output_root.mkdir(parents=True, exist_ok=True) - self._log_root = REPO_ROOT / "tb_eval_logs" + self._log_root = REPO_ROOT.parent / "tb_eval_logs" self._log_root.mkdir(parents=True, exist_ok=True) def evaluate(self, payload: EvalRequestPayload) -> dict[str, Any]: @@ -237,20 +236,17 @@ def _build_command(self, payload: EvalRequestPayload, run_id: str) -> list[str]: # 3. Add Agent kwargs (Use api_base exactly like the CLI command) if payload.api_base: cmd.extend(["--agent-kwarg", f"api_base={payload.api_base}"]) - - # 4. Add n_tasks if present - task_ids = [] - if payload.task_ids: - task_ids.extend([str(item) for item in payload.task_ids if item]) - if payload.task_id: - task_ids.append(str(payload.task_id)) - + if payload.dataset_path: cmd.extend(["--dataset-path", payload.dataset_path]) if payload.n_attempts is not None: cmd.extend(["--n-attempts", str(payload.n_attempts)]) + # 4. Add n_tasks if present + task_ids = [] + if payload.task_ids: + task_ids.extend([str(item) for item in payload.task_ids if item]) if task_ids: for task_id in task_ids: cmd.extend(["--task-id", task_id]) @@ -405,7 +401,6 @@ def status_endpoint(job_id: str): # Entry point # --------------------------------------------------------------------------- - def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run the Terminal Bench evaluation HTTP server.") parser.add_argument("--host", type=str, default="0.0.0.0") From 1fd519d96f627bd3bb0e223395a7fba7ab2968b3 Mon Sep 17 00:00:00 2001 From: Xinyu Jiang Date: Thu, 1 Jan 2026 00:06:05 -0500 Subject: [PATCH 6/6] add README-cn.md Co-authored-by: Zhiyao Jiang Co-authored-by: Xinyu Jiang --- examples/eval/scripts/run-eval-tb-qwen.sh | 6 +- examples/eval/terminal_bench/README-cn.md | 122 ++++++++++++++++++++++ examples/eval/terminal_bench/README.md | 33 +++--- examples/eval/terminal_bench/__init__.py | 2 +- 4 files changed, 139 insertions(+), 24 deletions(-) create mode 100644 examples/eval/terminal_bench/README-cn.md diff --git a/examples/eval/scripts/run-eval-tb-qwen.sh b/examples/eval/scripts/run-eval-tb-qwen.sh index 589c24501..67434f8ec 100644 --- a/examples/eval/scripts/run-eval-tb-qwen.sh +++ b/examples/eval/scripts/run-eval-tb-qwen.sh @@ -52,7 +52,6 @@ ROLLOUT_ARGS=( --rollout-shuffle --rm-type deepscaler --num-rollout 3000 - # --num-rollout 1 --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 @@ -63,14 +62,13 @@ ROLLOUT_ARGS=( EVAL_ARGS=( --eval-interval 5 - # --eval-interval 1 --eval-config "${EVAL_CONFIG_PATH}" --eval-function-path examples.eval.eval_delegate_rollout.generate_rollout ) PERF_ARGS=( --tensor-model-parallel-size 1 - --pipeline-model-parallel-size 1 + --pipeline-model-parallel-size 1 --context-parallel-size 1 --expert-model-parallel-size 1 --expert-tensor-parallel-size 1 @@ -106,7 +104,7 @@ WANDB_ARGS=( --use-wandb --wandb-project slime-eval --wandb-group qwen3-8b-eval - --wandb-key ${WANDB_KEY} + --wandb-key ${WANDB_KEY} # export WANDB_KEY="your_key" ) SGLANG_ARGS=( diff --git a/examples/eval/terminal_bench/README-cn.md b/examples/eval/terminal_bench/README-cn.md new file mode 100644 index 000000000..057a945b2 --- /dev/null +++ b/examples/eval/terminal_bench/README-cn.md @@ -0,0 +1,122 @@ +# Terminal Bench 评估集成 + +本目录将 Terminal Bench (TB) 封装为 Slime 的评估委托(Eval Delegate)。评估过程在宿主机(Host)上通过 `tb` CLI 执行,Slime 负责读取并汇总各项指标,包括 `accuracy`、`n_resolved`、`n_unresolved`、`pass_at_k/*` 以及 Token 统计数据(如 `total_input_tokens_mean/median` 和 `total_output_tokens_mean/median`)。 + +## 运行架构 + +* **Slime 内部**:运行训练/评估主循环;调用 TB delegate client。 +* **宿主机(Host)**:运行 TB delegate server (`tb_server.py`),由其执行 `tb run ...`。 +* **Server逻辑**:读取最新的 TB JSON 结果并将各项指标返回给 Slime。 + +## 1) 获取代码 (宿主机) + +```bash +mkdir slime-tb +cd slime-tb +git clone https://github.com/THUDM/slime.git +git clone https://github.com/laude-institute/terminal-bench +``` + +## 2) 启动 Slime 容器 + +```bash +docker run \ + -itd \ + --gpus all \ + --shm-size 32g \ + --network host \ + --ipc=host \ + --privileged \ + --ulimit memlock=-1 \ + --ulimit stack=67108864 \ + --ulimit nofile=65536:65536 \ + -v /mnt/data/.cache:/root/.cache \ + -v $(pwd):/shared/slime-tb \ + --name \ + slimerl/slime:latest \ + /bin/bash +``` + +## 3) 进入 Slime 容器 + +```bash +docker exec -it /bin/bash +``` + +## 4) 配置 Terminal Bench 环境 (宿主机) + +在运行 `tb_server.py` 的宿主机上执行: + +```bash +# 在宿主机终端执行(非 Docker 内部) +uv venv --python 3.13 .venv +source .venv/bin/activate +uv pip install terminal-bench/. +uv pip install -r slime/examples/eval/terminal_bench/requirements.txt +``` + +*如果仓库路径不是 `./slime` 和 `./terminal-bench`,请根据实际路径调整。* + +## 5) 启动 Terminal Bench server + +在宿主机上启动(即 `tb` 命令可用的环境): + +```bash +python slime/examples/eval/terminal_bench/tb_server.py \ + --host 0.0.0.0 --port 9051 \ + --output-root tb_eval_output +``` + +**该脚本的功能:** + +* 默认设置 `OPENAI_API_KEY=EMPTY`。 +* 执行 `tb run -a terminus-2 -m openai/ ... --n-concurrent 8`。 +* 等待运行完成后,返回 `accuracy`、`pass_at_k` 以及 Token 消耗等统计数据。 + +## 6) 运行评估脚本 (示例) + +如果使用提供的 Qwen 评估启动脚本 (`run-eval-tb-qwen.sh`),请按以下步骤操作: + +**更新路径**:将 `eval_tb_example.yaml` 中的 `dataset_path` 修改为宿主机上 `terminal-bench/tasks` 的**绝对路径**(注意不是 Docker 内部路径)。 + +**下载模型**:在 Slime 容器内下载 HuggingFace 权重: +```bash +huggingface-cli download open-thoughts/OpenThinker-Agent-v1 \ +--local-dir /root/.cache/OpenThinker-Agent-v1 +``` + +**格式转换**:将 HuggingFace 权重转换为 Slime 的 torch distributed 格式。在 Slime 根目录下执行: +```bash +cd /shared/slime-tb/slime +source scripts/models/qwen3-8B.sh + +export PYTHONPATH=/root/Megatron-LM:/shared/slime-tb/slime + +python tools/convert_hf_to_torch_dist.py \ + ${MODEL_ARGS[@]} \ + --hf-checkpoint /root/.cache/OpenThinker-Agent-v1 \ + --save /root/.cache/OpenThinker-Agent-v1_torch_dist +``` + +**开始评估**:在 Slime 容器内运行: +```bash +bash slime/examples/eval/scripts/run-eval-tb-qwen.sh 2>&1 | tee run.log +``` + +*为了快速测试,可以在 `eval_tb_example.yaml` 中通过 `task_ids` 指定特定任务,或通过 `n_tasks` 限制评估任务的数量。* + +## 7) 常见问题 + +当在 Docker 容器中使用 `--network host` 运行 Slime 时,Ray 可能由于与宿主机共享网络而出现端口冲突。 + +这会导致 Ray 启动失败,或报 Redis/会话相关错误。通常可以在启动 Ray head 时显式指定未占用端口来解决,比如设置非默认的 `--port` 和 `--dashboard-port`。 + +有时甚至会导致 Ray job 提交失败,提示没有可用 agent 接受任务。这通常是 dashboard agent 或 runtime env agent 的端口也发生冲突。此时可在启动 Ray 时指定这些端口(如 `--dashboard-agent-listen-port`、`--dashboard-agent-grpc-port`、`--runtime-env-agent-port`)来解决。 + +如果 TB server无法通过 sglang router 连接到 Slime(`InternalServerError`),请检查 router 端口(例如 30005)实际监听的地址,并更新 `eval_tb_example.yaml` 中的 `api_base`: + +```bash +ss -lntp | grep 30005 +``` + +TB server开始接受请求后,可能会在输出中看到 `Parser warnings`、`Context length exceeded`、`Command 1 should end with newline`、`Harness execution failed`等。这些是Terminal Bench 的警告,如果正常运行可以忽略。 \ No newline at end of file diff --git a/examples/eval/terminal_bench/README.md b/examples/eval/terminal_bench/README.md index 2a59dccdb..125bb1756 100644 --- a/examples/eval/terminal_bench/README.md +++ b/examples/eval/terminal_bench/README.md @@ -1,4 +1,4 @@ -# Terminal Bench Eval (Slime) +# Terminal Bench Eval This folder wires Terminal Bench (TB) into Slime as an eval delegate. The TB run happens on the host via the `tb` CLI, and Slime reads back aggregated metrics such as `accuracy`, `n_resolved`, `n_unresolved`, `pass_at_k/*`, and token stats like `total_input_tokens_mean/median` and `total_output_tokens_mean/median`. @@ -9,20 +9,11 @@ This folder wires Terminal Bench (TB) into Slime as an eval delegate. The TB run - The TB delegate server (`tb_server.py`) runs `tb run ...` on the host. - The server reads the latest TB JSON results and returns metrics to Slime. -## Prereqs - -1) Docker with GPU access. -2) `uv` installed on the host. -3) Terminal Bench installed and its `tb` CLI available on the machine that runs - `tb_server.py`. -4) The Slime repo available on the machine that runs `tb_server.py`. -5) A Slime eval config file that includes `eval.datasets`. - - Slime requires at least one dataset under `eval.datasets`. - - You can reuse your existing eval config; just add the delegate section. - ## 1) Get the code (host) ```bash +mkdir slime-tb +cd slime-tb git clone https://github.com/THUDM/slime.git git clone https://github.com/laude-institute/terminal-bench ``` @@ -40,9 +31,8 @@ docker run \ --ulimit memlock=-1 \ --ulimit stack=67108864 \ --ulimit nofile=65536:65536 \ - -v ~/.cache:/root/.cache \ - -v $(pwd)/slime:/opt/slime \ - -v $(pwd)/terminal-bench:/opt/terminal-bench \ + -v /mnt/data/.cache:/root/.cache \ + -v $(pwd):/shared/slime-tb \ --name \ slimerl/slime:latest \ /bin/bash @@ -59,6 +49,7 @@ docker exec -it /bin/bash Run on the machine that will host `tb_server.py` (where you cloned both repos): ```bash +# Host machine terminal (outside Docker) uv venv --python 3.13 .venv source .venv/bin/activate @@ -102,10 +93,12 @@ huggingface-cli download open-thoughts/OpenThinker-Agent-v1 \ After downloading, convert the HuggingFace checkpoint to Slime's torch distributed format. From the Slime root directory, run: ```bash -cd /opt/slime +cd /shared/slime-tb/slime source scripts/models/qwen3-8B.sh -PYTHONPATH=/root/Megatron-LM python tools/convert_hf_to_torch_dist.py \ +export PYTHONPATH=/root/Megatron-LM:/shared/slime-tb/slime + +python tools/convert_hf_to_torch_dist.py \ ${MODEL_ARGS[@]} \ --hf-checkpoint /root/.cache/OpenThinker-Agent-v1 \ --save /root/.cache/OpenThinker-Agent-v1_torch_dist @@ -127,8 +120,10 @@ In some cases, this manifests as Ray failing to start or reporting Redis- or ses In more severe cases, Ray job submission may fail with errors indicating that no available agent can accept jobs. This typically happens when the dashboard agent or runtime environment agent ports are also in conflict. In such situations, explicitly specifying the agent-related ports (e.g. `--dashboard-agent-listen-port`, `--dashboard-agent-grpc-port`, and `--runtime-env-agent-port`) when starting Ray can resolve the issue. -If the TB server cannot connect to the Slime server through the sglang router, check which address is actually listening on the router port (e.g. 30005 in this example) and update the `api_base` in `eval_tb_example.yaml` accordingly: +If the TB server cannot connect to the Slime server through the sglang router (`InternalServerError`), check which address is actually listening on the router port (e.g. 30005 in this example) and update the `api_base` in `eval_tb_example.yaml` accordingly: ```bash ss -lntp | grep 30005 -``` \ No newline at end of file +``` + +You may see `Parser warnings`, `Context length exceeded`, `Command 1 should end with newline`, `Harness execution failed` in `tb_server.py` logs. They are warnings from Terminal Bench and can be ignored if runs proceed normally. \ No newline at end of file diff --git a/examples/eval/terminal_bench/__init__.py b/examples/eval/terminal_bench/__init__.py index 6ba998ca7..6d2704250 100644 --- a/examples/eval/terminal_bench/__init__.py +++ b/examples/eval/terminal_bench/__init__.py @@ -1 +1 @@ -"""NeMo Skills evaluation helpers.""" +"""Terminal Bench evaluation helpers."""