From 3ca157e4050d9a6c582354f80862addfffb3563f Mon Sep 17 00:00:00 2001 From: eligotts <78387377+eligotts@users.noreply.github.com> Date: Tue, 20 Jan 2026 16:27:31 -0800 Subject: [PATCH 1/2] implemented protocol wrapper on top of environment. protocol owns dataset, and environments are registered within protocols. protocol exposes a spawn method to allow recursive running of environments --- verifiers/__init__.py | 10 + verifiers/envs/actor.py | 19 + .../envs/experimental/spawning_rlm_env.py | 757 ++++++++++++++++++ verifiers/envs/multiagent_env.py | 232 ++++++ verifiers/envs/protocol.py | 361 +++++++++ verifiers/rl/trainer/__init__.py | 2 + .../rl/trainer/multiagent_orchestrator.py | 215 +++++ verifiers/rubrics/multiagent_rubric.py | 121 +++ verifiers/types.py | 4 +- 9 files changed, 1720 insertions(+), 1 deletion(-) create mode 100644 verifiers/envs/actor.py create mode 100644 verifiers/envs/experimental/spawning_rlm_env.py create mode 100644 verifiers/envs/multiagent_env.py create mode 100644 verifiers/envs/protocol.py create mode 100644 verifiers/rl/trainer/multiagent_orchestrator.py create mode 100644 verifiers/rubrics/multiagent_rubric.py diff --git a/verifiers/__init__.py b/verifiers/__init__.py index 535a16870..01d4a5f28 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -18,6 +18,9 @@ from .envs.environment import Environment # noqa # isort: skip from .envs.multiturn_env import MultiTurnEnv # noqa # isort: skip from .envs.tool_env import ToolEnv # noqa # isort: skip +from .envs.actor import Actor # noqa # isort: skip +from .envs.protocol import Protocol # noqa # isort: skip +from .envs.multiagent_env import MultiAgentEnv, SingleTurnMAEnv # noqa # isort: skip # main imports from .envs.env_group import EnvGroup @@ -62,12 +65,16 @@ "MCPEnv", "Environment", "MultiTurnEnv", + "MultiAgentEnv", + "SingleTurnMAEnv", "SingleTurnEnv", "PythonEnv", "SandboxEnv", "StatefulToolEnv", "ToolEnv", "EnvGroup", + "Actor", + "Protocol", "extract_boxed_answer", "extract_hash_answer", "load_example_dataset", @@ -80,6 +87,7 @@ "get_model_and_tokenizer", "RLTrainer", "RLConfig", + "MultiAgentOrchestrator", "GRPOTrainer", "GRPOConfig", "grpo_defaults", @@ -94,6 +102,7 @@ "get_model_and_tokenizer": "verifiers.rl.trainer.utils:get_model_and_tokenizer", "RLConfig": "verifiers.rl.trainer:RLConfig", "RLTrainer": "verifiers.rl.trainer:RLTrainer", + "MultiAgentOrchestrator": "verifiers.rl.trainer:MultiAgentOrchestrator", "GRPOTrainer": "verifiers.rl.trainer:GRPOTrainer", "GRPOConfig": "verifiers.rl.trainer:GRPOConfig", "grpo_defaults": "verifiers.rl.trainer:grpo_defaults", @@ -135,6 +144,7 @@ def __getattr__(name: str): from .rl.trainer import ( # noqa: F401 GRPOConfig, GRPOTrainer, + MultiAgentOrchestrator, RLConfig, RLTrainer, grpo_defaults, diff --git a/verifiers/envs/actor.py b/verifiers/envs/actor.py new file mode 100644 index 000000000..10e84d28c --- /dev/null +++ b/verifiers/envs/actor.py @@ -0,0 +1,19 @@ +""" +Actor: A trainable entity in multi-agent environments. + +Actors are registered to a Protocol and define the system prompt +used when making model calls. +""" +from dataclasses import dataclass + + +@dataclass +class Actor: + """ + A trainable actor. Registered to Protocol. + + The system_prompt is used when this actor makes model calls. 
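+
+    Illustrative usage (the actor id and prompt here are hypothetical):
+
+        solver = Actor(id="solver", system_prompt="You are a careful math solver.")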
+ """ + + id: str + system_prompt: str = "" diff --git a/verifiers/envs/experimental/spawning_rlm_env.py b/verifiers/envs/experimental/spawning_rlm_env.py new file mode 100644 index 000000000..754bb8849 --- /dev/null +++ b/verifiers/envs/experimental/spawning_rlm_env.py @@ -0,0 +1,757 @@ +""" +Spawning Language Model (RLM) Environment with Protocol/MultiAgentEnv support. + +Demonstrates the spawning pattern: +1. Give RLM an initial task +2. It can run code in REPL (call_python_repl) +3. It has a tool to spawn sub-instances (spawn_rlm via protocol.spawn()) +4. Sub-RLMs get parent context plus task +5. Scored by MultiAgentRubric for reward propagation +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import signal +import subprocess +import tempfile +import textwrap +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, cast + +from openai.types.chat import ChatCompletionAssistantMessageParam + +import verifiers as vf +from verifiers.envs.actor import Actor +from verifiers.envs.multiagent_env import MultiAgentEnv +from verifiers.rubrics.multiagent_rubric import MultiAgentRubric +from verifiers.types import Messages, RolloutInput, State +from verifiers.utils.async_utils import maybe_await +from verifiers.utils.tool_utils import convert_func_to_oai_tool + +if TYPE_CHECKING: + from verifiers.envs.protocol import Protocol + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Data Classes +# ============================================================================= + + +@dataclass(frozen=True) +class RLMWorkerPaths: + base_dir: str + command_fifo: str + response_fifo: str + ready_flag: str + worker_path: str + context_file: str + answer_file: str + + +@dataclass +class LocalRLMSession: + rollout_id: str + temp_dir: tempfile.TemporaryDirectory + paths: RLMWorkerPaths + worker_process: subprocess.Popen | None = None + + +# ============================================================================= +# Worker Script - Simple REPL with extra_data and answer +# ============================================================================= + + +_RLM_WORKER_SCRIPT = textwrap.dedent( + ''' + import ast + import contextlib + import io + import json + import os + import traceback + from pathlib import Path + + COMMAND_FIFO = "{command_fifo}" + RESPONSE_FIFO = "{response_fifo}" + READY_FLAG = "{ready_flag}" + CONTEXT_FILE = "{context_file}" + ANSWER_FILE = "{answer_file}" + + def ensure_fifo(path: str) -> None: + if os.path.exists(path): + os.remove(path) + os.mkfifo(path) + + for fifo_path in (COMMAND_FIFO, RESPONSE_FIFO): + ensure_fifo(fifo_path) + + # Load extra_data from context file + extra_data = None + if Path(CONTEXT_FILE).exists(): + with open(CONTEXT_FILE, "r", encoding="utf-8") as f: + extra_data = json.load(f).get("extra_data") + + # Initialize answer + answer = {{"ready": False, "content": ""}} + if Path(ANSWER_FILE).exists(): + with open(ANSWER_FILE, "r", encoding="utf-8") as f: + answer = json.load(f) + + # Execution namespace + namespace: dict[str, object] = {{ + "__name__": "__main__", + "__builtins__": __builtins__, + "extra_data": extra_data, + "answer": answer, + }} + + # Signal ready + Path(READY_FLAG).write_text("ready", encoding="utf-8") + + execution_count = 0 + + while True: + with open(COMMAND_FIFO, "r", encoding="utf-8") as cmd_file: + payload = cmd_file.read() + if not payload: + 
continue + request = json.loads(payload) + if request.get("shutdown"): + break + + code = request.get("code", "") + seq = request.get("seq", 0) + execution_count += 1 + + result = {{ + "status": "ok", + "stdout": "", + "stderr": "", + "result": None, + "execution_count": execution_count, + "seq": seq, + "answer": namespace.get("answer", {{"ready": False, "content": ""}}), + }} + + stdout_buf = io.StringIO() + stderr_buf = io.StringIO() + + try: + with contextlib.redirect_stdout(stdout_buf), contextlib.redirect_stderr(stderr_buf): + module_ast = ast.parse(code, mode="exec") + body = list(module_ast.body) + trailing_expr = None + if body and isinstance(body[-1], ast.Expr): + trailing_expr = body.pop() + if body: + exec(compile(ast.Module(body=body, type_ignores=[]), "", "exec"), namespace, namespace) + if trailing_expr is not None: + value = eval(compile(ast.Expression(trailing_expr.value), "", "eval"), namespace, namespace) + if value is not None: + result["result"] = repr(value) + except Exception: + result["status"] = "error" + result["result"] = traceback.format_exc() + + result["stdout"] = stdout_buf.getvalue() + result["stderr"] = stderr_buf.getvalue() + result["answer"] = namespace.get("answer", {{"ready": False, "content": ""}}) + + # Persist answer + with open(ANSWER_FILE, "w", encoding="utf-8") as f: + json.dump(result["answer"], f) + + with open(RESPONSE_FIFO, "w", encoding="utf-8") as resp_file: + resp_file.write(json.dumps(result)) + ''' +) + + +# ============================================================================= +# System Prompt +# ============================================================================= + + +_RLM_SYSTEM_PROMPT = """You are an RLM (Recursive Language Model) - an iterative Python REPL where you explore data step by step. + +## Tools + +- `call_python_repl`: Execute Python code. State persists across calls. +- `spawn_rlm`: Spawn sub-RLMs in parallel. Pass a list of tasks, get back JSON results. Use this to break down complex tasks into smaller sub-tasks. + +## Available Variables + +- `extra_data`: Input data to process +- `answer`: Dict with `answer["content"]` (your answer) and `answer["ready"]` (set True to finish) + +## Workflow + +1. Explore: `print(type(extra_data))` - see what you have +2. Process: Work step by step, see output before continuing +3. Answer: `answer["content"] = "result"` then `answer["ready"] = True` + +**Important:** Never set `answer["ready"] = True` until you've seen execution output. 
+""" + + +# ============================================================================= +# Helper Functions +# ============================================================================= + + +def _build_worker_paths(base_dir: str) -> RLMWorkerPaths: + return RLMWorkerPaths( + base_dir=base_dir, + command_fifo=os.path.join(base_dir, "rlm_cmd"), + response_fifo=os.path.join(base_dir, "rlm_res"), + ready_flag=os.path.join(base_dir, "rlm_ready"), + worker_path=os.path.join(base_dir, "rlm_worker.py"), + context_file=os.path.join(base_dir, "rlm_context.json"), + answer_file=os.path.join(base_dir, "rlm_answer.json"), + ) + + +def _render_worker_script(paths: RLMWorkerPaths) -> str: + return _RLM_WORKER_SCRIPT.format( + command_fifo=paths.command_fifo, + response_fifo=paths.response_fifo, + ready_flag=paths.ready_flag, + context_file=paths.context_file, + answer_file=paths.answer_file, + ) + + +# ============================================================================= +# Local Executor - Subprocess with FIFO communication +# ============================================================================= + + +class LocalRLMExecutor: + def __init__(self, code_timeout: int = 120): + self.code_timeout = code_timeout + self._sessions: dict[str, LocalRLMSession] = {} + + async def setup(self, state: State, extra_data: Any) -> None: + """Create temp dir, write context, start worker.""" + rollout_id = state["rollout_id"] + temp_dir = tempfile.TemporaryDirectory(prefix=f"rlm_{rollout_id}_") + paths = _build_worker_paths(temp_dir.name) + + session = LocalRLMSession( + rollout_id=rollout_id, + temp_dir=temp_dir, + paths=paths, + ) + self._sessions[rollout_id] = session + + # Write context file + Path(paths.context_file).write_text( + json.dumps({"extra_data": extra_data}), encoding="utf-8" + ) + Path(paths.answer_file).write_text( + json.dumps({"ready": False, "content": ""}), encoding="utf-8" + ) + + # Write and start worker + worker_script = _render_worker_script(paths) + Path(paths.worker_path).write_text(worker_script, encoding="utf-8") + + process = subprocess.Popen( + ["python", "-u", paths.worker_path], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + session.worker_process = process + + # Wait for ready + await self._wait_for_ready(session) + + async def _wait_for_ready(self, session: LocalRLMSession, timeout: int = 30) -> None: + start = asyncio.get_event_loop().time() + while True: + if Path(session.paths.ready_flag).exists(): + return + if session.worker_process and session.worker_process.poll() is not None: + raise vf.SandboxError("Worker exited before ready") + if asyncio.get_event_loop().time() - start > timeout: + raise vf.SandboxError("Worker failed to start") + await asyncio.sleep(0.1) + + async def execute(self, code: str, state: State) -> dict[str, Any]: + """Send code to worker, get result.""" + session = self._sessions.get(state["rollout_id"]) + if not session or not session.worker_process: + raise vf.SandboxError("Session not initialized") + if session.worker_process.poll() is not None: + raise vf.SandboxError("Worker process not running") + + seq = state.get("_exec_seq", 0) + 1 + state["_exec_seq"] = seq + + def _do_io() -> str: + payload = json.dumps({"code": code, "seq": seq}) + with open(session.paths.command_fifo, "w", encoding="utf-8") as f: + f.write(payload) + with open(session.paths.response_fifo, "r", encoding="utf-8") as f: + return f.read() + + try: + raw = await asyncio.wait_for( + asyncio.to_thread(_do_io), + 
timeout=self.code_timeout, + ) + except asyncio.TimeoutError: + return { + "status": "error", + "result": f"Code execution timed out after {self.code_timeout}s", + "answer": {"ready": False, "content": ""}, + } + + try: + return json.loads(raw) + except json.JSONDecodeError: + return { + "status": "error", + "result": f"Failed to parse response: {raw[:200]}", + "answer": {"ready": False, "content": ""}, + } + + async def read_answer(self, state: State) -> str: + session = self._sessions.get(state.get("rollout_id", "")) + if not session: + return "" + try: + content = Path(session.paths.answer_file).read_text(encoding="utf-8") + return json.loads(content).get("content", "") + except Exception: + return "" + + async def cleanup(self, state: State) -> None: + rollout_id = state.get("rollout_id") + if not rollout_id: + return + session = self._sessions.pop(rollout_id, None) + if not session: + return + self._stop_worker(session) + session.temp_dir.cleanup() + + def _stop_worker(self, session: LocalRLMSession) -> None: + if not session.worker_process: + return + try: + if os.name != "nt": + os.killpg(session.worker_process.pid, signal.SIGTERM) + else: + session.worker_process.terminate() + session.worker_process.wait(timeout=5) + except Exception: + try: + if os.name != "nt": + os.killpg(session.worker_process.pid, signal.SIGKILL) + else: + session.worker_process.kill() + except Exception: + pass + session.worker_process = None + + async def teardown(self) -> None: + for session in list(self._sessions.values()): + self._stop_worker(session) + session.temp_dir.cleanup() + self._sessions.clear() + + +# ============================================================================= +# MultiAgentToolEnv - Tool handling on MultiAgentEnv +# ============================================================================= + + +class MultiAgentToolEnv(MultiAgentEnv): + """ToolEnv that extends MultiAgentEnv.""" + + def __init__( + self, + tools: list[Callable] | None = None, + max_turns: int = 10, + **kwargs, + ): + self.tools: list[Callable] = tools or [] + self.max_turns = max_turns + self.oai_tools = [convert_func_to_oai_tool(tool) for tool in self.tools] + self.tool_map: dict[str, Callable] = { + getattr(tool, "__name__", tool.__class__.__name__): tool + for tool in self.tools + } + self.skipped_args: dict[str, list[str]] = {} + + super().__init__(oai_tools=self.oai_tools, max_turns=max_turns, **kwargs) + + def add_tool(self, tool: Callable, args_to_skip: list[str] | None = None): + """Add a tool, optionally hiding arguments from the agent's view.""" + args_to_skip = args_to_skip or [] + self.tools.append(tool) + + # Build OAI tool schema, filtering skipped args + import inspect + sig = inspect.signature(tool) + filtered_params = [ + p for n, p in sig.parameters.items() + if n not in args_to_skip and n != "self" + ] + filtered_sig = sig.replace(parameters=filtered_params) + + def wrapper(*args, **kw): + return tool(*args, **kw) + + wrapper.__name__ = tool.__name__ + wrapper.__doc__ = tool.__doc__ + wrapper.__signature__ = filtered_sig + wrapper.__annotations__ = { + k: v for k, v in getattr(tool, "__annotations__", {}).items() + if k not in args_to_skip + } + + oai_tool = convert_func_to_oai_tool(wrapper) + if self.oai_tools is None: + self.oai_tools = [] + self.oai_tools.append(oai_tool) + + tool_name = tool.__name__ + self.tool_map[tool_name] = tool + self.skipped_args[tool_name] = args_to_skip + + def update_tool_args( + self, tool_name: str, tool_args: dict, messages: Messages, state: State, **kwargs 
+ ) -> dict: + """Override to inject state-based args.""" + return tool_args + + @vf.stop + async def no_tools_called(self, state: State) -> bool: + if len(state["trajectory"]) == 0: + return False + last_msg = state["trajectory"][-1]["completion"][-1] + return last_msg["role"] == "assistant" and "tool_calls" not in last_msg + + async def call_tool( + self, tool_name: str, tool_args: dict, tool_call_id: str, **kwargs + ) -> vf.Message: + tool_func = self.tool_map[tool_name] + result = await maybe_await(tool_func, **tool_args) + return cast(vf.Message, {"role": "tool", "content": str(result), "tool_call_id": tool_call_id}) + + async def env_response(self, messages: Messages, state: State, **kwargs) -> Messages: + assert isinstance(messages, list) and "tool_calls" in messages[-1] + tool_messages = [] + last_msg = cast(ChatCompletionAssistantMessageParam, messages[-1]) + + for tool_call in last_msg.get("tool_calls", []): + tool_call_id: str = tool_call.get("id", "") + try: + tool_name: str = tool_call.get("function", {}).get("name", "") + tool_args: dict = json.loads(tool_call.get("function", {}).get("arguments", "")) + except Exception as e: + tool_messages.append(cast(vf.Message, { + "role": "tool", "content": str(e), "tool_call_id": tool_call_id + })) + continue + + tool_args = self.update_tool_args(tool_name, tool_args, messages, state, **kwargs) + try: + tool_msg = await self.call_tool(tool_name, tool_args, tool_call_id) + tool_messages.append(tool_msg) + except Exception as e: + tool_messages.append(cast(vf.Message, { + "role": "tool", "content": str(e), "tool_call_id": tool_call_id + })) + + return tool_messages + + def get_initial_actor(self, state: State) -> str: + return "rlm" + + def get_next_actor(self, state: State) -> str: + return "rlm" + + +# ============================================================================= +# SpawningRLMEnv +# ============================================================================= + + +class SpawningRLMEnv(MultiAgentToolEnv): + """ + RLM environment using Protocol.spawn() for spawning sub-agents. 
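+    See create_spawning_rlm_protocol() at the bottom of this module for how the
+    environment, its Actor, and the MultiAgentRubric are wired into a Protocol.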
+ + Core pattern: + - call_python_repl: Execute code in persistent REPL + - spawn_rlm: Spawn sub-RLMs via protocol.spawn() + - MultiAgentRubric propagates rewards through execution tree + """ + + # Multi agent specific fields + name = "spawning_rlm" + # We only have one actor, the RLM + actors = ["rlm"] + protocol: "Protocol" + # Only have one actor, so just hardcode it as the rlm + current_actor: str = "rlm" + + def __init__( + self, + max_iterations: int = 50, + max_output_length: int = 8192, + code_timeout: int = 120, + system_prompt: str | None = None, + **kwargs, + ): + self.max_iterations = max_iterations + self.max_output_length = max_output_length + self.code_timeout = code_timeout + self.custom_system_prompt = system_prompt + + super().__init__(max_turns=max_iterations, **kwargs) + + self._executor = LocalRLMExecutor(code_timeout=code_timeout) + self.add_tool(self.call_python_repl, args_to_skip=["state"]) + self.add_tool(self.spawn_rlm, args_to_skip=["state"]) + + # ========================================================================= + # State Management + # ========================================================================= + + async def setup_state(self, state: State, **kwargs) -> State: + state = await super().setup_state(state, **kwargs) + + state["rollout_id"] = f"rlm_{uuid.uuid4().hex[:8]}" + state["_exec_seq"] = 0 + + # Get context from RolloutInput.extra_data + input_extra = state.get("input", {}).get("extra_data", {}) + state["rlm_context"] = input_extra.get("rlm_context", {}) + state["is_child"] = input_extra.get("is_child", False) + state["child_states"] = [] + + # Get extra_data from dataset row + # Dataset columns are in state["input"], but State only forwards specific fields + input_data = state.get("input", {}) + extra_data = input_data.get("context") or input_data.get("extra_data") + if extra_data is None: + # Fallback: check info dict + info = state.get("info", {}) + if isinstance(info, dict): + extra_data = info.get("context", info.get("extra_data")) + + # Build system prompt + state["rlm_system_prompt"] = self.custom_system_prompt or _RLM_SYSTEM_PROMPT + + # Start worker + await self._executor.setup(state, extra_data) + + return state + + def update_tool_args( + self, tool_name: str, tool_args: dict, messages: Messages, state: State, **kwargs + ) -> dict: + if tool_name in ("call_python_repl", "spawn_rlm"): + return {**tool_args, "state": state} + return tool_args + + # ========================================================================= + # Prompt Building + # ========================================================================= + + async def get_prompt_messages(self, state: State) -> Messages: + if len(state["trajectory"]) == 0: + prompt = state.get("prompt", []) + if isinstance(prompt, str): + prompt = [{"role": "user", "content": prompt}] + + messages = list(prompt) + system_prompt = state.get("rlm_system_prompt", _RLM_SYSTEM_PROMPT) + + if not messages or messages[0].get("role") != "system": + messages.insert(0, {"role": "system", "content": system_prompt}) + + return cast(Messages, messages) + return await super().get_prompt_messages(state) + + # ========================================================================= + # Tools + # ========================================================================= + + async def call_python_repl(self, code: str, state: State) -> str: + """ + Execute Python code in a persistent REPL. 
+ + Available: + - `extra_data`: Input data to process + - `answer["content"]`: Set your answer here + - `answer["ready"]`: Set True to finish + + Args: + code: Python code to execute + """ + result = await self._executor.execute(code, state) + output = self._format_output(result) + + # Check if answer ready + answer = result.get("answer", {}) + if answer.get("ready"): + state["final_answer"] = answer.get("content", "") + + return output + + def _format_output(self, result: dict[str, Any]) -> str: + parts = [] + if result.get("stdout"): + parts.append(result["stdout"].rstrip()) + if result.get("stderr"): + parts.append(f"stderr:\n{result['stderr'].rstrip()}") + + status = result.get("status") + res = result.get("result") + if status == "error" and res: + parts.append(res.rstrip()) + elif status == "ok" and res: + parts.append(f"Out[{result.get('execution_count', 0)}]: {res}") + + output = "\n".join(parts) if parts else "(no output)" + if len(output) > self.max_output_length: + output = output[:self.max_output_length] + "\n... [truncated]" + return output + + async def spawn_rlm(self, tasks: list[str], state: State) -> str: + """ + Spawn sub-RLMs to solve tasks in parallel. + + Args: + tasks: List of task descriptions to spawn (runs concurrently) + + Returns: + JSON array of results: [{"task": "...", "answer": "..."}, ...] + """ + if not hasattr(self, "protocol") or self.protocol is None: + return json.dumps([{"error": "Protocol not configured"}]) + + parent_ctx = state.get("rlm_context", {}) + child_inputs = [ + RolloutInput( + prompt=[{"role": "user", "content": t}], + example_id=state.get("example_id", 0), + task=self.name, + extra_data={ + "rlm_context": { + **parent_ctx, + "parent_task": t, + "depth": parent_ctx.get("depth", 0) + 1, + }, + "is_child": True, + }, + ) + for t in tasks + ] + + try: + child_states = await self.protocol.spawn(child_inputs, score=False) + results = [] + for t, cs in zip(tasks, child_states): + if cs.get("final_answer"): + state.setdefault("child_states", []).append(cs) + results.append({"task": t, "answer": str(cs["final_answer"])}) + else: + results.append({"task": t, "error": str(cs.get("error", "No answer"))}) + return json.dumps(results) + except Exception as e: + logger.error(f"spawn_rlm failed: {e}") + return json.dumps([{"error": str(e)}]) + + # ========================================================================= + # Stop Conditions + # ========================================================================= + + @vf.stop + async def answer_ready(self, state: State) -> bool: + return "final_answer" in state + + # ========================================================================= + # Cleanup + # ========================================================================= + + @vf.cleanup + async def cleanup_state(self, state: State): + await self._executor.cleanup(state) + + @vf.teardown + async def teardown_executor(self): + await self._executor.teardown() + + +# ============================================================================= +# Scoring +# ============================================================================= + + +async def exact_answer(state: State, answer: str, **_kwargs) -> float: + """Score: 1.0 if final_answer matches expected answer.""" + final = str(state.get("final_answer") or "").strip() + expected = str(answer).strip() + return 1.0 if final == expected else 0.0 + + +# ============================================================================= +# Protocol Factory +# 
============================================================================= + + +def create_spawning_rlm_protocol( + dataset=None, + eval_dataset=None, + system_prompt: str | None = None, + **env_kwargs, +) -> "Protocol": + """ + Create Protocol for spawning RLM usage. + + The dataset is registered on the Protocol, not the environment. + This follows the multi-agent pattern where Protocol owns the dataset. + + Args: + dataset: Training dataset (registered on Protocol) + eval_dataset: Evaluation dataset (registered on Protocol) + system_prompt: Custom system prompt for the RLM actor + **env_kwargs: Additional arguments passed to SpawningRLMEnv + """ + from verifiers.envs.protocol import Protocol + + rlm_actor = Actor( + id="rlm", + system_prompt=system_prompt or _RLM_SYSTEM_PROMPT, + ) + + rubric = MultiAgentRubric(funcs=[exact_answer]) + + # Environment doesn't need dataset - Protocol owns it + rlm_env = SpawningRLMEnv( + rubric=rubric, + system_prompt=system_prompt, + **env_kwargs, + ) + + # Dataset registered on Protocol + return Protocol( + actors=[rlm_actor], + envs=[rlm_env], + dataset=dataset, + eval_dataset=eval_dataset, + ) diff --git a/verifiers/envs/multiagent_env.py b/verifiers/envs/multiagent_env.py new file mode 100644 index 000000000..4da6a36a4 --- /dev/null +++ b/verifiers/envs/multiagent_env.py @@ -0,0 +1,232 @@ +""" +MultiAgentEnv: Multi-turn environment with multiple actors. + +Extends MultiTurnEnv to add actor management - tracking which actor +is currently active, and rewriting system prompts accordingly. + +Dataset lives on the Protocol, not the environment. MultiAgentEnv provides +a dummy dataset to satisfy the base class requirement. +""" +from abc import abstractmethod +from typing import TYPE_CHECKING + +from datasets import Dataset +from openai import AsyncOpenAI + +import verifiers as vf +from verifiers.types import ( + Messages, + ModelResponse, + RolloutInput, + SamplingArgs, + State, + TrajectoryStep, +) +from verifiers.utils.message_utils import concat_messages +from verifiers.utils.response_utils import ( + parse_is_truncated, + parse_response_messages, + parse_response_tokens, +) + +from .multiturn_env import MultiTurnEnv + +if TYPE_CHECKING: + from .protocol import Protocol + + +def _dummy_dataset() -> Dataset: + """Create a minimal dummy dataset to satisfy base class requirement.""" + return Dataset.from_dict({ + "example_id": [0], + "prompt": [[{"role": "user", "content": "dummy"}]], + "answer": [""], + }) + + +class MultiAgentEnv(MultiTurnEnv): + """ + Multi-turn environment with multiple actors. + + Extends MultiTurnEnv to add: + - self.current_actor tracking + - get_initial_actor() / get_next_actor() for turn management + - System prompt rewriting based on current actor (naive approach, needs more sophisticated rearchitecture) + - actor_id stored in TrajectoryStep.extras + + Subclasses must implement: + - get_initial_actor(state) -> str + - get_next_actor(state) -> str + - env_response(messages, state) -> Messages (inherited from MultiTurnEnv) + """ + + # Subclasses declare which actors they use + actors: list[str] = [] + + # Current actor tracked as instance field (set during rollout) + current_actor: str = "" + + # Protocol reference (injected by Protocol.__init__) + protocol: "Protocol" + + def __init__(self, **kwargs): + """ + Initialize MultiAgentEnv with a dummy dataset. + + Dataset lives on the Protocol, not the environment. The dummy dataset + satisfies the base class requirement but is never actually used. 
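+
+        A two-actor subclass might alternate turns like this (illustrative
+        sketch; the actor ids are hypothetical):
+
+            actors = ["solver", "verifier"]
+
+            def get_initial_actor(self, state):
+                return "solver"
+
+            def get_next_actor(self, state):
+                return "verifier" if self.current_actor == "solver" else "solver"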
+ """ + # Inject dummy dataset if none provided (protocol manages the real dataset) + if "dataset" not in kwargs and "eval_dataset" not in kwargs: + kwargs["dataset"] = _dummy_dataset() + super().__init__(**kwargs) + + @abstractmethod + def get_initial_actor(self, state: State) -> str: + """Return the actor ID that starts the rollout.""" + pass + + @abstractmethod + def get_next_actor(self, state: State) -> str: + """Return the actor ID for the next turn.""" + pass + + async def get_prompt_messages(self, state: State) -> Messages: + """ + Build prompt messages, rewriting system prompt for current actor. + + Gets system prompt from protocol.get_actor(self.current_actor). + """ + # Get base messages from parent logic + if len(state["trajectory"]) == 0: + messages = list(state["prompt"]) # copy + else: + prev_turn_prompt = state["trajectory"][-1]["prompt"] + prev_turn_completion = state["trajectory"][-1]["completion"] + messages = concat_messages([prev_turn_prompt, prev_turn_completion]) + env_response = await self.env_response(messages, state) + messages = concat_messages([messages, env_response]) + + # Rewrite system prompt for current actor + actor = self.protocol.get_actor(self.current_actor) + + if messages and messages[0].get("role") == "system": + # Replace existing system prompt + messages[0] = {"role": "system", "content": actor.system_prompt} + elif actor.system_prompt: + # Prepend system prompt + messages = [{"role": "system", "content": actor.system_prompt}] + messages + + return messages + + async def add_model_response( + self, + state: State, + prompt_messages: Messages, + response: ModelResponse, + ): + """Add model response to trajectory, storing actor_id in extras.""" + completion_messages = await parse_response_messages(response, self.message_type) + response_is_truncated = await parse_is_truncated(response, self.message_type) + tokens = await parse_response_tokens( + response, self.message_type, self.max_seq_len + ) + is_truncated = response_is_truncated or ( + tokens is not None and bool(tokens.get("is_truncated")) + ) + + trajectory_step = TrajectoryStep( + prompt=prompt_messages, + completion=completion_messages, + response=response, + tokens=tokens, + reward=None, + advantage=None, + is_truncated=is_truncated, + trajectory_id=state["trajectory_id"], + extras={"actor_id": self.current_actor}, # Store actor_id in extras + ) + await self.add_trajectory_step(state, trajectory_step) + + async def rollout( + self, + input: RolloutInput, + client: AsyncOpenAI, + model: str, + sampling_args: SamplingArgs | None = None, + ) -> State: + """ + Multi-agent rollout loop. + + Similar to MultiTurnEnv.rollout() but: + 1. Sets self.current_actor at start via get_initial_actor() + 2. 
Updates self.current_actor after each turn via get_next_actor() + """ + state = await self.init_state(input, client, model, sampling_args) + + try: + state = await self.setup_state(state) + except vf.Error as e: + state["error"] = e + return state + + # Initialize current actor (instance field) + self.current_actor = self.get_initial_actor(state) + + while not await self.is_completed(state): + try: + prompt_messages = await self.get_prompt_messages(state) + if state.get("final_env_response") is not None: + continue + response = await self.get_model_response(state, prompt_messages) + await self.add_model_response(state, prompt_messages, response) + + # Update actor for next turn + self.current_actor = self.get_next_actor(state) + + except vf.Error as e: + if isinstance(e, vf.OverlongPromptError): + state["prompt_too_long"] = True + state["is_truncated"] = True + else: + state["error"] = e + + await self.render_completion(state) + return state + + +class SingleTurnMAEnv(MultiAgentEnv): + """ + Single-turn multi-agent environment with exactly one actor. + + Similar to how SingleTurnEnv simplifies MultiTurnEnv, this class simplifies + MultiAgentEnv for single-turn, single-actor use cases: + - max_turns=1 (only one model response) + - env_response raises NotImplementedError (no multi-turn interaction) + - get_initial_actor/get_next_actor return the single declared actor + + Subclasses must declare exactly one actor in the `actors` list. + """ + + def __init__(self, **kwargs): + """Initialize SingleTurnMAEnv with max_turns=1.""" + super().__init__(max_turns=1, **kwargs) + # Validate single actor requirement + if len(self.actors) != 1: + raise ValueError( + f"SingleTurnMAEnv requires exactly one actor, got {len(self.actors)}: {self.actors}" + ) + + async def env_response( + self, messages: Messages, state: State, **kwargs + ) -> Messages: + """Not implemented - single turn means no environment responses.""" + raise NotImplementedError("env_response is not implemented for SingleTurnMAEnv") + + def get_initial_actor(self, state: State) -> str: + """Return the single declared actor.""" + return self.actors[0] + + def get_next_actor(self, state: State) -> str: + """Return the single declared actor (though this should never be called).""" + return self.actors[0] diff --git a/verifiers/envs/protocol.py b/verifiers/envs/protocol.py new file mode 100644 index 000000000..0e0fdcf7b --- /dev/null +++ b/verifiers/envs/protocol.py @@ -0,0 +1,361 @@ +""" +Protocol: Orchestrates multiple environments and actors for composable training. + +The Protocol holds actors, environments, and the dataset. It enables cross-environment +calls via spawn() and handling initial scheduling via generate(). + +Dataset registration: The Protocol owns the dataset, not individual environments. +This allows multi-agent scenarios where the dataset is shared across environments. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, List + +from datasets import Dataset +from openai import AsyncOpenAI + +from verifiers.types import DatasetBuilder, RolloutInput, SamplingArgs, State +from verifiers.utils.async_utils import maybe_semaphore + +from .actor import Actor + +if TYPE_CHECKING: + from .environment import Environment + + +class Protocol: + """ + Holds actors, environments, and the dataset. Enables cross-env calls and handles generation scheduling. + + The Protocol owns the dataset, not individual environments. This allows multi-agent + scenarios where the dataset is shared across environments. 
+ + Usage: + # Define actors + solver = Actor(id="solver", system_prompt="You are a math solver...") + verifier = Actor(id="verifier", system_prompt="You verify solutions...") + + # Define environments with their actors (no dataset needed) + math_env = MathEnv(actors=["solver"]) + verify_env = VerifyEnv(actors=["verifier"]) + + # Create protocol with dataset + protocol = Protocol( + actors=[solver, verifier], + envs=[math_env, verify_env], + dataset=my_dataset, # Dataset registered here + ) + + # Get inputs from protocol's dataset + inputs = protocol.get_inputs(n=100, rollouts_per_example=4) + states = await protocol.generate(inputs, client, model) + + # Within math_env's env_response, spawn child rollouts: + child_states = await self.protocol.spawn([sub_input1, sub_input2], score=True) + state["child_states"].extend(child_states) + """ + + def __init__( + self, + actors: list[Actor], + envs: list["Environment"], + dataset: Dataset | DatasetBuilder | None = None, + eval_dataset: Dataset | DatasetBuilder | None = None, + ): + """ + Initialize protocol with actors, environments, and optional dataset. + + - Builds actor dict: {actor.id: actor} + - Builds env dict: {env.name: env} + - Validates each env's actors exist in protocol + - Injects self into each environment + - Registers dataset (owned by Protocol, not environments) + + Args: + actors: List of Actor instances to register + envs: List of Environment instances to register + dataset: Training dataset (Dataset or callable that returns Dataset) + eval_dataset: Evaluation dataset (Dataset or callable that returns Dataset) + """ + # Register actors + self._actors: dict[str, Actor] = {} + for actor in actors: + if actor.id in self._actors: + raise ValueError(f"Duplicate actor id: {actor.id}") + self._actors[actor.id] = actor + + # Register environments + self._envs: dict[str, "Environment"] = {} + for env in envs: + name = getattr(env, "name", env.__class__.__name__) + if name in self._envs: + raise ValueError(f"Duplicate environment name: {name}") + + # Validate env's actors exist in protocol + env_actors = getattr(env, "actors", []) + for actor_id in env_actors: + if actor_id not in self._actors: + raise ValueError( + f"Environment '{name}' references unknown actor '{actor_id}'. " + f"Available actors: {list(self._actors.keys())}" + ) + + self._envs[name] = env + # Inject protocol reference + env.protocol = self + + # Dataset registration (owned by Protocol) + self._dataset: Dataset | None = None + self._eval_dataset: Dataset | None = None + + if dataset is not None: + if callable(dataset): + self._dataset_source: DatasetBuilder | None = dataset + else: + self._dataset_source = lambda ds=dataset: ds + self._build_dataset() # Eagerly build for raw datasets + else: + self._dataset_source = None + + if eval_dataset is not None: + if callable(eval_dataset): + self._eval_dataset_source: DatasetBuilder | None = eval_dataset + else: + self._eval_dataset_source = lambda ds=eval_dataset: ds + self._build_eval_dataset() # Eagerly build for raw datasets + else: + self._eval_dataset_source = None + + # Context stored during generate() for spawn() to use + self._client: AsyncOpenAI | None = None + self._model: str | None = None + self._sampling_args: SamplingArgs | None = None + self._gen_sem = None + self._score_sem = None + + def get_actor(self, actor_id: str) -> Actor: + """Get actor by id.""" + if actor_id not in self._actors: + raise KeyError( + f"Actor '{actor_id}' not found. 
Available: {list(self._actors.keys())}" + ) + return self._actors[actor_id] + + def get_env(self, name: str) -> "Environment": + """Get environment by name.""" + if name not in self._envs: + raise KeyError( + f"Environment '{name}' not found. Available: {list(self._envs.keys())}" + ) + return self._envs[name] + + @property + def actors(self) -> dict[str, Actor]: + """All registered actors.""" + return self._actors + + @property + def envs(self) -> dict[str, "Environment"]: + """All registered environments.""" + return self._envs + + # Dataset management methods + + def _build_dataset(self) -> Dataset | None: + """Build and cache the training dataset from source if needed.""" + if self._dataset is not None: + return self._dataset + if self._dataset_source is None: + return None + self._dataset = self._dataset_source() + return self._dataset + + def _build_eval_dataset(self) -> Dataset | None: + """Build and cache the evaluation dataset from source if needed.""" + if self._eval_dataset is not None: + return self._eval_dataset + if self._eval_dataset_source is None: + return None + self._eval_dataset = self._eval_dataset_source() + return self._eval_dataset + + def get_dataset(self, n: int = -1, seed: int | None = None) -> Dataset: + """Get the training dataset, optionally shuffled and truncated.""" + self._build_dataset() + if self._dataset is None: + raise ValueError("dataset is not set on Protocol") + dataset = self._dataset + if seed is not None: + dataset = dataset.shuffle(seed=seed) + if n > 0: + n = min(n, len(dataset)) + return dataset.select(range(n)) + return dataset + + def get_eval_dataset(self, n: int = -1, seed: int | None = None) -> Dataset: + """Get the evaluation dataset, optionally shuffled and truncated.""" + self._build_eval_dataset() + if self._eval_dataset is None: + # Fall back to train dataset + return self.get_dataset(n, seed) + dataset = self._eval_dataset + if seed is not None: + dataset = dataset.shuffle(seed=seed) + if n > 0: + n = min(n, len(dataset)) + return dataset.select(range(n)) + return dataset + + def get_inputs( + self, n: int = -1, rollouts_per_example: int = 1, seed: int | None = None + ) -> List[RolloutInput]: + """Get training inputs from the dataset.""" + dataset = self.get_dataset(n=n, seed=seed) + if rollouts_per_example > 1: + dataset = dataset.repeat(rollouts_per_example) + return dataset.to_list() + + def get_eval_inputs( + self, n: int = -1, rollouts_per_example: int = 1, seed: int | None = None + ) -> List[RolloutInput]: + """Get evaluation inputs from the dataset.""" + dataset = self.get_eval_dataset(n=n, seed=seed) + if rollouts_per_example > 1: + dataset = dataset.repeat(rollouts_per_example) + return dataset.to_list() + + async def generate( + self, + inputs: list[RolloutInput], + client: AsyncOpenAI, + model: str, + sampling_args: SamplingArgs | None = None, + max_concurrent: int = -1, + ) -> list[State]: + """ + Generate rollouts for a batch of inputs. + + This is where initial scheduling happens - determines which env + handles each input and dispatches accordingly. + + Called by MultiAgentOrchestrator.generate_batch(). 
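+
+        Minimal usage sketch (counts and model name are illustrative):
+
+            inputs = protocol.get_inputs(n=8, rollouts_per_example=2)
+            states = await protocol.generate(inputs, client, model="my-model")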
+ """ + # Store context for spawn() calls during this generation + self._client = client + self._model = model + self._sampling_args = sampling_args + self._gen_sem = await maybe_semaphore(max_concurrent) + self._score_sem = await maybe_semaphore(max_concurrent) + + # Group inputs by target environment + by_env: dict[str, list[RolloutInput]] = {} + for inp in inputs: + env_name = inp.get("task") or self._get_default_env() + by_env.setdefault(env_name, []).append(inp) + + # Run each environment's generate() + all_states: list[State] = [] + for env_name, env_inputs in by_env.items(): + env = self.get_env(env_name) + results = await env.generate( + env_inputs, + client=client, + model=model, + sampling_args=sampling_args, + max_concurrent=max_concurrent, + ) + all_states.extend(results["state"]) + + # Flatten: collect trainable child_states recursively + return self._flatten_trainable(all_states) + + def _get_default_env(self) -> str: + """Return first registered environment as default.""" + return next(iter(self._envs.keys())) + + def _flatten_trainable(self, states: list[State]) -> list[State]: + """Recursively collect all trainable states including children.""" + result: list[State] = [] + for state in states: + result.append(state) + child_states = state.get("child_states", []) + if child_states: + result.extend(self._flatten_trainable(child_states)) + return result + + async def spawn( + self, + inputs: list[RolloutInput], + score: bool = True, + ) -> list[State]: + """ + Spawn child rollouts in sibling environments. + + Routes each input to its target environment based on the `task` field, + then calls run_group() for each environment. + Uses context stored by the enclosing generate() call. + + Args: + inputs: List of rollout inputs (task field determines target env) + score: Whether to score the group (passed through to run_group) + + Returns: + List of completed states from the child rollouts + """ + if self._client is None or self._model is None: + raise RuntimeError( + "spawn() can only be called within a generate() context. " + "Ensure you're calling spawn() from within an env_response or rollout." + ) + + # Group inputs by target environment (using task field) + by_env: dict[str, list[RolloutInput]] = {} + for inp in inputs: + env_name = inp.get("task") or self._get_default_env() + by_env.setdefault(env_name, []).append(inp) + + # Run each environment's run_group() + all_states: list[State] = [] + for env_name, env_inputs in by_env.items(): + env = self.get_env(env_name) + child_states = await env.run_group( + env_inputs, + client=self._client, + model=self._model, + gen_sampling_args=self._sampling_args or {}, + gen_sem=self._gen_sem, + score_sem=self._score_sem, + score=score, + ) + all_states.extend(child_states) + + return all_states + + async def evaluate( + self, + client: AsyncOpenAI, + model: str, + sampling_args: SamplingArgs | None = None, + num_examples: int = -1, + rollouts_per_example: int = 1, + max_concurrent: int = -1, + seed: int | None = None, + ) -> list[State]: + """ + Evaluate model on the Protocol's evaluation dataset. + + Gets inputs from the Protocol's eval dataset (or train dataset if no eval + dataset is set) and runs generate() to produce rollout states. 
+ """ + inputs = self.get_eval_inputs( + n=num_examples, + rollouts_per_example=rollouts_per_example, + seed=seed, + ) + return await self.generate( + inputs=inputs, + client=client, + model=model, + sampling_args=sampling_args, + max_concurrent=max_concurrent, + ) diff --git a/verifiers/rl/trainer/__init__.py b/verifiers/rl/trainer/__init__.py index 3fef17b46..3c1f4846d 100644 --- a/verifiers/rl/trainer/__init__.py +++ b/verifiers/rl/trainer/__init__.py @@ -3,6 +3,7 @@ import torch._dynamo from .config import RLConfig +from .multiagent_orchestrator import MultiAgentOrchestrator from .trainer import RLTrainer torch._dynamo.config.suppress_errors = True @@ -30,6 +31,7 @@ def lora_defaults(**kwargs): __all__ = [ "RLConfig", "RLTrainer", + "MultiAgentOrchestrator", "GRPOTrainer", "GRPOConfig", "grpo_defaults", diff --git a/verifiers/rl/trainer/multiagent_orchestrator.py b/verifiers/rl/trainer/multiagent_orchestrator.py new file mode 100644 index 000000000..7d39500db --- /dev/null +++ b/verifiers/rl/trainer/multiagent_orchestrator.py @@ -0,0 +1,215 @@ +""" +MultiAgentOrchestrator: Thin wrapper on Orchestrator that uses Protocol for generation. + +Delegates batch generation to Protocol.generate() which handles multi-environment +orchestration and child state flattening. + +Dataset is owned by Protocol, not environments. +""" +import time +from typing import Any + +import numpy as np +from datasets import Dataset + +from verifiers.envs.protocol import Protocol + +from .orchestrator import Batch, Microbatch, Orchestrator + + +class MultiAgentOrchestrator(Orchestrator): + """ + Thin wrapper on Orchestrator that delegates to Protocol for generation. + + Instead of a single Environment, takes a Protocol which holds multiple + environments and actors. generate_batch() calls protocol.generate() + which handles scheduling across environments and flattening child states. + + Dataset is owned by Protocol, not environments. + """ + + def __init__( + self, + protocol: Protocol, + **kwargs, + ): + self.protocol = protocol + + # Pick first env as default for parent Orchestrator + first_env = next(iter(protocol.envs.values())) + super().__init__(env=first_env, **kwargs) + + # Override: filter Protocol's dataset instead of environment's + # (parent __init__ filters env.dataset, we need to filter protocol's) + max_length = self.max_prompt_len + + def filter_by_prompt_length(example, processing_class): + prompt = example["prompt"] + if isinstance(prompt, list): + prompt_text = processing_class.apply_chat_template( + prompt, tokenize=False, add_generation_prompt=True + ) + else: + prompt_text = prompt + prompt_ids = processing_class.encode(prompt_text) + return len(prompt_ids) <= max_length + + # Filter Protocol's dataset + if self.protocol._dataset is not None: + self.protocol._dataset = self.protocol.get_dataset().filter( + filter_by_prompt_length, + fn_kwargs={"processing_class": self.processing_class}, + ) + + def get_dataset_slice(self, batch_id: int) -> Dataset: + """Get dataset slice from Protocol's dataset for a given batch id.""" + num_rows = self.prompts_per_batch + dataset = self.protocol.get_dataset() + total_rows = len(dataset) + if total_rows == 0: + raise ValueError("Protocol dataset is empty") + offset = (batch_id * num_rows) % total_rows + indices = [(offset + i) % total_rows for i in range(num_rows)] + return dataset.select(indices) + + async def generate_batch(self, batch_id: int) -> Batch: + """ + Generate batch by delegating to protocol.generate(). + + Flow: + 1. 
MultiAgentOrchestrator.generate_batch() calls protocol.generate() + 2. protocol.generate() calls env.generate() for each environment + 3. env.generate() runs rollouts, which may call protocol.spawn() for children + 4. protocol.generate() flattens all states (including child_states) + """ + self.is_generating = True + assert self.client is not None + start_time = time.time() + + # Get dataset slice and repeat for rollouts + batch_ds = self.get_dataset_slice(batch_id) + repeated_ds = batch_ds.repeat(self.rollouts_per_example) + inputs = repeated_ds.to_list() + + # Use protocol.generate() instead of env.generate() + # This returns flattened list of all trainable states + all_states = await self.protocol.generate( + inputs, + client=self.client, + model=self.model_name, + sampling_args=self.sampling_args, + max_concurrent=self.max_concurrent, + ) + + self.is_generating = False + wall_clock_s = time.time() - start_time + + # Process trajectories - each step becomes a separate training example + prompt_ids: list[list[int]] = [] + prompt_mask: list[list[int]] = [] + completion_ids: list[list[int]] = [] + completion_mask: list[list[int]] = [] + completion_logprobs: list[list[float]] = [] + advantages: list[float] = [] + + for state in all_states: + trajectory = state.get("trajectory", []) + for step in trajectory: + tokens = step.get("tokens") + if tokens is None: + continue + prompt_ids.append(tokens["prompt_ids"]) + prompt_mask.append(tokens["prompt_mask"]) + completion_ids.append(tokens["completion_ids"]) + completion_mask.append(tokens["completion_mask"]) + completion_logprobs.append(tokens["completion_logprobs"]) + advantages.append(step.get("advantage", 0.0)) + + # Build rewards_dict from rollout-level data (for logging only) + rewards = [state.get("reward", 0.0) for state in all_states] + rewards_dict: dict[str, list[float]] = {"reward": rewards} + + # Collect metrics + metrics_dict: dict[str, float] = {} + if rewards: + rewards_arr = np.asarray(rewards, dtype=np.float32) + metrics_dict["reward"] = float(rewards_arr.mean()) + metrics_dict["reward/std"] = float(rewards_arr.std()) + + if advantages: + adv_arr = np.asarray(advantages, dtype=np.float32) + metrics_dict["advantage/absmean"] = float(np.abs(adv_arr).mean()) + + completion_lengths = [len(ids) for ids in completion_ids] + if completion_lengths: + completion_lengths_arr = np.asarray(completion_lengths, dtype=np.float32) + metrics_dict["tokens/completion"] = float(completion_lengths_arr.mean()) + + completion_mask_lengths = np.asarray( + [sum(mask) for mask in completion_mask], + dtype=np.float32, + ) + valid_tokens = completion_mask_lengths.sum() + total_tokens = completion_lengths_arr.sum() + if total_tokens > 0: + masked_fraction = 1.0 - (valid_tokens / total_tokens) + metrics_dict["tokens/masked_fraction"] = float(masked_fraction) + + metrics_dict["wall_clock/generate_s"] = float(wall_clock_s) + + # Collect errors and completions for logging + errors = [state.get("error") for state in all_states] + completions = [state.get("completion") for state in all_states] + prompts = [state.get("prompt") for state in all_states] + + # Build per-process microbatches + N = len(advantages) + per_proc = N // self.num_processes if self.num_processes > 0 else N + microbatches: list[list[Microbatch]] = [] + items_per_process: list[int] = [] + + for proc in range(self.num_processes): + ps = proc * per_proc + pe = ps + per_proc + proc_mbs: list[Microbatch] = [] + proc_item_total = 0 + for s in range(ps, pe, self.micro_batch_size): + e = min(s + 
self.micro_batch_size, pe) + ids_chunk = [prompt_ids[i] + completion_ids[i] for i in range(s, e)] + mask_chunk = [prompt_mask[i] + completion_mask[i] for i in range(s, e)] + logprobs_chunk = [ + [0.0] * len(prompt_mask[i]) + completion_logprobs[i] + for i in range(s, e) + ] + lengths = [len(mask) for mask in mask_chunk] + adv_chunk = [ + [advantages[i]] * lengths[idx] + for idx, i in enumerate(list(range(s, e))) + ] + mb_items = sum(sum(mask) for mask in mask_chunk) + microbatch = Microbatch( + input_ids=ids_chunk, + loss_mask=mask_chunk, + sampling_logprobs=logprobs_chunk, + advantages=adv_chunk, + items=mb_items, + ) + proc_item_total += mb_items + proc_mbs.append(microbatch) + microbatches.append(proc_mbs) + items_per_process.append(proc_item_total) + + global_item_count = sum(items_per_process) + + return Batch( + batch_id=batch_id, + microbatches=microbatches, + items_per_process=items_per_process, + global_item_count=global_item_count, + generation_time=wall_clock_s, + rewards_dict=rewards_dict, + completions=completions, + prompts=prompts, + errors=errors, + metrics_dict=metrics_dict, + ) diff --git a/verifiers/rubrics/multiagent_rubric.py b/verifiers/rubrics/multiagent_rubric.py new file mode 100644 index 000000000..6d2560800 --- /dev/null +++ b/verifiers/rubrics/multiagent_rubric.py @@ -0,0 +1,121 @@ +""" +MultiAgentRubric: Rubric that propagates rewards to child_states. + +Extends Rubric.score_group to traverse the state tree and apply +the root state's reward/advantage to all nested trajectory steps. +""" +import asyncio +import time +from typing import AsyncContextManager, cast + +from verifiers.types import GroupRewardFunc, RewardFunc, State + +from .rubric import Rubric + + +class MultiAgentRubric(Rubric): + """ + Rubric that propagates rewards to child_states. + + When scoring a group, after computing rewards for top-level states, + traverses into child_states and applies the same reward/advantage + to all trajectory steps in the tree. + """ + + async def score_group(self, states: list[State], score_sem: AsyncContextManager): + """ + Score a group of rollouts together, propagating to child_states. + + All reward functions are executed in order, parallelizing across states. + After scoring, propagates the reward/advantage to all nested child_states. 
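+
+        For each state i this amounts to (matching the loop below):
+
+            reward_i    = sum_k weight_k * func_k(state_i)
+            advantage_i = reward_i - mean_j(reward_j)
+
+        which is then written onto trajectory steps that have no reward or
+        advantage yet and, recursively, onto all nested child_states.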
+ """ + start_time = time.time() + num_states = len(states) + if num_states == 0: + self.logger.warning("No states to score") + return + aggregated_rewards = [0.0] * num_states + aggregated_metrics: dict[str, list[float]] = {} + + # process functions in order + for func, weight in zip(self.funcs, self.weights): + is_group = self._is_group_func(func) + if is_group: + # GroupRewardFunc: score all states together + group_func = cast(GroupRewardFunc, func) + scores = await self._call_group_reward_func( + group_func, states, score_sem=score_sem + ) + func_name = func.__name__ + if func_name not in aggregated_metrics: + aggregated_metrics[func_name] = [0.0] * num_states + for i in range(num_states): + score_value = scores[i] + aggregated_rewards[i] += score_value * weight + aggregated_metrics[func_name][i] = score_value + else: + reward_func = cast(RewardFunc, func) + score_tasks = [ + self._call_individual_reward_func( + reward_func, state, score_sem=score_sem + ) + for state in states + ] + scores = await asyncio.gather(*score_tasks) + + func_name = func.__name__ + if func_name not in aggregated_metrics: + aggregated_metrics[func_name] = [0.0] * num_states + for i in range(num_states): + score_value = scores[i] + aggregated_rewards[i] += score_value * weight + aggregated_metrics[func_name][i] = score_value + + # update states with aggregated results + end_time = time.time() + scoring_ms = (end_time - start_time) * 1000 + avg_reward = sum(aggregated_rewards) / num_states + for i, state in enumerate(states): + state["reward"] = aggregated_rewards[i] + state["advantage"] = aggregated_rewards[i] - avg_reward + for t in state["trajectory"]: + if t["advantage"] is None: + t["advantage"] = state["advantage"] + if t["reward"] is None: + t["reward"] = state["reward"] + state["metrics"] = { + func_name: values[i] for func_name, values in aggregated_metrics.items() + } + state["timing"]["scoring_ms"] = scoring_ms + state["timing"]["total_ms"] += state["timing"]["scoring_ms"] + + # Propagate reward/advantage to all child_states + self._propagate_to_children( + state.get("child_states", []), + state["reward"], + state["advantage"], + ) + + def _propagate_to_children( + self, + child_states: list[State], + reward: float, + advantage: float, + ): + """ + Recursively apply reward/advantage to all trajectory steps in child_states. 
+ """ + for child in child_states: + child["reward"] = reward + child["advantage"] = advantage + for t in child.get("trajectory", []): + if t["advantage"] is None: + t["advantage"] = advantage + if t["reward"] is None: + t["reward"] = reward + # Recurse into nested children + self._propagate_to_children( + child.get("child_states", []), + reward, + advantage, + ) diff --git a/verifiers/types.py b/verifiers/types.py index 7c45e9609..3653249f4 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -72,7 +72,7 @@ class TrajectoryStep(TypedDict): advantage: float | None is_truncated: bool trajectory_id: str - extras: dict[str, Any] + extras: dict[str, Any] # actor_id stored here for multi-agent class BaseRolloutInput(TypedDict): @@ -114,6 +114,8 @@ class State(dict): metrics: dict[str, float] | None timing: RolloutTiming | None error: Error | None + # multi-agent / composable training + child_states: list["State"] | None # Child rollout states from protocol.spawn() def __getitem__(self, key: str) -> Any: # forward to input if exists From 48f6698098b6fbbf29263950f836786c78dfcd02 Mon Sep 17 00:00:00 2001 From: eligotts <78387377+eligotts@users.noreply.github.com> Date: Tue, 20 Jan 2026 17:32:39 -0800 Subject: [PATCH 2/2] call run_rollout directly rather than run_group --- verifiers/envs/protocol.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/verifiers/envs/protocol.py b/verifiers/envs/protocol.py index 0e0fdcf7b..0fe11509c 100644 --- a/verifiers/envs/protocol.py +++ b/verifiers/envs/protocol.py @@ -9,13 +9,14 @@ """ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING, List from datasets import Dataset from openai import AsyncOpenAI from verifiers.types import DatasetBuilder, RolloutInput, SamplingArgs, State -from verifiers.utils.async_utils import maybe_semaphore +from verifiers.utils.async_utils import NullAsyncContext, maybe_semaphore from .actor import Actor @@ -292,12 +293,12 @@ async def spawn( Spawn child rollouts in sibling environments. Routes each input to its target environment based on the `task` field, - then calls run_group() for each environment. + then runs rollouts in parallel using asyncio.gather. Uses context stored by the enclosing generate() call. Args: inputs: List of rollout inputs (task field determines target env) - score: Whether to score the group (passed through to run_group) + score: Whether to score the rollouts after completion Returns: List of completed states from the child rollouts @@ -308,28 +309,24 @@ async def spawn( "Ensure you're calling spawn() from within an env_response or rollout." 
) - # Group inputs by target environment (using task field) - by_env: dict[str, list[RolloutInput]] = {} - for inp in inputs: - env_name = inp.get("task") or self._get_default_env() - by_env.setdefault(env_name, []).append(inp) + # Use NullAsyncContext for children to allow parallel execution + null_sem = NullAsyncContext() - # Run each environment's run_group() - all_states: list[State] = [] - for env_name, env_inputs in by_env.items(): - env = self.get_env(env_name) - child_states = await env.run_group( - env_inputs, + # Run all rollouts in parallel + all_states = await asyncio.gather(*( + self.get_env(inp.get("task") or self._get_default_env()).run_rollout( + inp, client=self._client, model=self._model, gen_sampling_args=self._sampling_args or {}, - gen_sem=self._gen_sem, - score_sem=self._score_sem, + gen_sem=null_sem, + score_sem=null_sem, score=score, ) - all_states.extend(child_states) + for inp in inputs + )) - return all_states + return list(all_states) async def evaluate( self,