From 1e5f474ac2c21c8b0d13b0363a71d3f574fd6005 Mon Sep 17 00:00:00 2001 From: wjh581 Date: Sun, 25 Jan 2026 05:37:47 -0500 Subject: [PATCH] Add multi-agent support: Actor, Protocol, MultiAgentEnv, MultiAgentRubric --- .../rock_paper_scissors/pyproject.toml | 20 + .../rock_paper_scissors.py | 369 +++++++++ environments/twenty_questions/pyproject.toml | 19 + .../twenty_questions/twenty_questions.py | 389 ++++++++++ verifiers/__init__.py | 16 + verifiers/envs/actor.py | 50 ++ verifiers/envs/multiagent_env.py | 720 ++++++++++++++++++ verifiers/envs/protocol.py | 368 +++++++++ verifiers/rl/trainer/__init__.py | 2 + .../rl/trainer/multiagent_orchestrator.py | 327 ++++++++ verifiers/rubrics/multiagent_rubric.py | 270 +++++++ verifiers/utils/eval_utils.py | 4 + 12 files changed, 2554 insertions(+) create mode 100644 environments/rock_paper_scissors/pyproject.toml create mode 100644 environments/rock_paper_scissors/rock_paper_scissors.py create mode 100644 environments/twenty_questions/pyproject.toml create mode 100644 environments/twenty_questions/twenty_questions.py create mode 100644 verifiers/envs/actor.py create mode 100644 verifiers/envs/multiagent_env.py create mode 100644 verifiers/envs/protocol.py create mode 100644 verifiers/rl/trainer/multiagent_orchestrator.py create mode 100644 verifiers/rubrics/multiagent_rubric.py diff --git a/environments/rock_paper_scissors/pyproject.toml b/environments/rock_paper_scissors/pyproject.toml new file mode 100644 index 000000000..aa03d251c --- /dev/null +++ b/environments/rock_paper_scissors/pyproject.toml @@ -0,0 +1,20 @@ +[project] +name = "rock-paper-scissors" +description = "Rock-Paper-Scissors with simultaneous moves" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "verifiers>=0.1.9", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["rock_paper_scissors.py", "pyproject.toml"] + +[tool.verifiers.eval] +num_examples = 3 +rollouts_per_example = 1 +num_rounds = 3 diff --git a/environments/rock_paper_scissors/rock_paper_scissors.py b/environments/rock_paper_scissors/rock_paper_scissors.py new file mode 100644 index 000000000..137f12136 --- /dev/null +++ b/environments/rock_paper_scissors/rock_paper_scissors.py @@ -0,0 +1,369 @@ +""" +Rock-Paper-Scissors: Multi-agent environment with simultaneous moves. + +This environment demonstrates: +- Simultaneous moves via get_active_actors() returning both players +- Custom rollout loop (both players act each round, not alternating) +- Per-actor reward functions (competitive scoring) +- Round-based game with history tracking + +Game flow: +1. Both players see the round number and previous results +2. Both make their choice (simultaneously from game perspective) +3. Round is resolved, scores updated +4. Repeat for num_rounds +5. Split into per-actor states for scoring +""" + +from datasets import Dataset + +import verifiers as vf +from verifiers import Actor, MultiAgentEnv, MultiAgentRubric, Protocol +from verifiers.types import Messages, State + + +# ============================================================================= +# Actors +# ============================================================================= + +PLAYER1 = Actor( + id="player1", + system_prompt="""You are Player 1 in Rock-Paper-Scissors. + +Choose ONE of: rock, paper, or scissors + +Output ONLY your choice (one word, lowercase). 
Nothing else.""", + max_tokens=10, + is_trainable=True, +) + +PLAYER2 = Actor( + id="player2", + system_prompt="""You are Player 2 in Rock-Paper-Scissors. + +Choose ONE of: rock, paper, or scissors + +Output ONLY your choice (one word, lowercase). Nothing else.""", + max_tokens=10, + is_trainable=True, +) + + +# ============================================================================= +# Environment +# ============================================================================= + +class RockPaperScissorsEnv(MultiAgentEnv): + """Rock-Paper-Scissors with simultaneous moves.""" + + name = "rock_paper_scissors" + actors = ["player1", "player2"] + + def __init__(self, num_rounds: int = 3, **kwargs): + super().__init__(**kwargs) + self.num_rounds = num_rounds + + # ------------------------------------------------------------------------- + # Turn Management + # Required by MultiAgentEnv but not really used here - we override rollout() + # to use get_active_actors() for simultaneous play instead. + # ------------------------------------------------------------------------- + + def get_initial_actor(self, state: State) -> str: + return "player1" + + def get_next_actor(self, state: State) -> str: + return "player1" + + def get_active_actors(self, state: State) -> list[str]: + """Both players act simultaneously each round.""" + return ["player1", "player2"] + + # ------------------------------------------------------------------------- + # Stop Condition + # ------------------------------------------------------------------------- + + @vf.stop + async def game_over(self, state: State) -> bool: + """Stop after all rounds played.""" + return state.get("extras", {}).get("round", 0) >= self.num_rounds + + # ------------------------------------------------------------------------- + # State Setup + # ------------------------------------------------------------------------- + + async def setup_state(self, state: State) -> State: + """Initialize RPS-specific game state.""" + state = await super().setup_state(state) + state["extras"]["round"] = 0 # Current round (0-indexed during play) + state["extras"]["p1_score"] = 0 # Player 1 win count + state["extras"]["p2_score"] = 0 # Player 2 win count + state["extras"]["history"] = [] # List of (p1_choice, p2_choice, result) + state["extras"]["p1_choice"] = None # Temp storage for current round + state["extras"]["p2_choice"] = None # Temp storage for current round + return state + + # ------------------------------------------------------------------------- + # Prompt Building + # ------------------------------------------------------------------------- + + async def get_prompt_messages(self, state: State) -> Messages: + """ + Build fresh prompt for current actor each round. + + Overrides base class because RPS needs clean, summarized prompts + rather than accumulated raw conversation. 
+ """ + current_actor_id = state["extras"]["current_actor_id"] + actor = self.get_actor(current_actor_id) + round_num = state["extras"]["round"] + 1 + + # Build history from this player's perspective ("You" vs "Opponent") + history = state["extras"]["history"] + history_str = "" + if history: + history_str = "\n\nPrevious rounds:\n" + for i, (p1, p2, result) in enumerate(history, 1): + you = p1 if current_actor_id == "player1" else p2 + opponent = p2 if current_actor_id == "player1" else p1 + history_str += f" Round {i}: You={you}, Opponent={opponent} → {result}\n" + + return [ + {"role": "system", "content": actor.system_prompt}, + {"role": "user", "content": f"Round {round_num} of {self.num_rounds}. Make your choice!{history_str}"} + ] + + # ------------------------------------------------------------------------- + # Environment Response (Game Logic) + # ------------------------------------------------------------------------- + + async def env_response(self, messages: Messages, state: State, **kwargs) -> Messages: + """ + Process each actor's choice and resolve round when both have played. + + Called after each actor's turn: + - First call: stores player1's choice + - Second call: stores player2's choice, resolves round + + Returns [] because we build fresh prompts in get_prompt_messages(). + """ + if not state["trajectory"]: + return [] + + # Get the last completion + last_step = state["trajectory"][-1] + last_completion = last_step.get("completion", []) + if not last_completion: + return [] + + # Parse the choice from model output + content = last_completion[-1].get("content", "").lower().strip() if isinstance(last_completion[-1], dict) else str(last_completion[-1]).lower().strip() + choice = self._parse_choice(content) + + # Store choice for the actor who just played + actor_id = last_step.get("extras", {}).get("actor_id", state["extras"]["current_actor_id"]) + if actor_id == "player1": + state["extras"]["p1_choice"] = choice + else: + state["extras"]["p2_choice"] = choice + + # If both have chosen, resolve the round + p1_choice = state["extras"]["p1_choice"] + p2_choice = state["extras"]["p2_choice"] + + if p1_choice and p2_choice: + winner = self._determine_winner(p1_choice, p2_choice) + + if winner == "player1": + state["extras"]["p1_score"] += 1 + result = "Player 1 wins" + elif winner == "player2": + state["extras"]["p2_score"] += 1 + result = "Player 2 wins" + else: + result = "Tie" + + # Record result and reset for next round + state["extras"]["history"].append((p1_choice, p2_choice, result)) + state["extras"]["round"] += 1 + state["extras"]["p1_choice"] = None + state["extras"]["p2_choice"] = None + + print(f" [Round {state['extras']['round']}] {p1_choice} vs {p2_choice} → {result}") + + return [] + + # ------------------------------------------------------------------------- + # Helper Functions + # ------------------------------------------------------------------------- + + def _parse_choice(self, text: str) -> str: + """Extract rock/paper/scissors from model output. Defaults to rock.""" + text = text.lower() + if "rock" in text: + return "rock" + elif "paper" in text: + return "paper" + elif "scissors" in text: + return "scissors" + return "rock" + + def _determine_winner(self, p1: str, p2: str) -> str | None: + """Determine winner using standard RPS rules. 
Returns None for tie.""" + if p1 == p2: + return None + wins = {"rock": "scissors", "paper": "rock", "scissors": "paper"} + if wins.get(p1) == p2: + return "player1" + return "player2" + + # ------------------------------------------------------------------------- + # Custom Rollout (Simultaneous Moves) + # ------------------------------------------------------------------------- + + async def rollout( + self, + input, + client, + model, + sampling_args=None, + ) -> State: + """ + Custom rollout with simultaneous moves. + + Overrides base class because RPS has both players act each round, + rather than strict alternation. Uses get_active_actors() to get + both players, then loops through them each round. + """ + state = await self.init_state(input, client, model, sampling_args) + state = await self.setup_state(state) + + while not await self.is_completed(state): + active_actors = self.get_active_actors(state) + + for actor_id in active_actors: + # Set who's currently playing + state["extras"]["current_actor_id"] = actor_id + + try: + # Build prompt for this actor + prompt_messages = await self.get_prompt_messages(state) + + # Get actor's sampling settings + actor = self.get_actor(actor_id) + merged_args = actor.merge_sampling_args(sampling_args or {}) + + # Call the model + response = await self.get_model_response(state, prompt_messages, sampling_args=merged_args) + + # Store in trajectory (tagged with actor_id) + await self.add_model_response(state, prompt_messages, response) + + # Process choice and maybe resolve round + await self.env_response([], state) + + except vf.Error as e: + state["error"] = e + break + + await self.render_completion(state) + + # Split into per-actor states for scoring + state["child_states"] = self.create_actor_states(state) + + return state + + +# ============================================================================= +# Rubric (Scoring) +# ============================================================================= + +def create_rubric() -> MultiAgentRubric: + """ + Create competitive rubric with per-actor rewards. + + Each player's reward = win rate (wins / total_rounds). + Creates competitive dynamic: one player's gain ≈ other's loss. + """ + rubric = MultiAgentRubric() + + def player1_reward(state: State, **kwargs) -> float: + """Player 1 reward = win rate.""" + extras = state.get("extras", {}) + p1_score = extras.get("p1_score", 0) + total_rounds = extras.get("round", 1) + return p1_score / total_rounds if total_rounds > 0 else 0.0 + + def player2_reward(state: State, **kwargs) -> float: + """Player 2 reward = win rate.""" + extras = state.get("extras", {}) + p2_score = extras.get("p2_score", 0) + total_rounds = extras.get("round", 1) + return p2_score / total_rounds if total_rounds > 0 else 0.0 + + def rounds_played_metric(state: State, **kwargs) -> float: + """Track rounds played (metric only, weight=0).""" + return float(state.get("extras", {}).get("round", 0)) + + rubric.add_actor_reward_func("player1", player1_reward, weight=1.0) + rubric.add_actor_reward_func("player2", player2_reward, weight=1.0) + rubric.add_reward_func(rounds_played_metric, weight=0.0) + + return rubric + + +# ============================================================================= +# Dataset +# ============================================================================= + +def create_dataset() -> Dataset: + """ + Create dataset for RPS games. + + The prompt is a placeholder - RPS builds its own prompts via + get_prompt_messages(). 
Each row represents one game to play. + """ + return Dataset.from_list([ + { + "example_id": i, + "prompt": [{"role": "user", "content": "play"}], + "answer": "", + "task": "rock_paper_scissors" + } + for i in range(10) + ]) + + +# ============================================================================= +# Environment Loader +# ============================================================================= + +def load_environment( + num_rounds: int = 3, + num_examples: int = -1, +) -> RockPaperScissorsEnv: + """ + Factory function to create a fully configured RPS environment. + + Args: + num_rounds: Rounds per game (default 3) + num_examples: Number of games to run (-1 = all 10) + + Returns: + Ready-to-use RockPaperScissorsEnv + """ + dataset = create_dataset() + if num_examples > 0: + dataset = dataset.select(range(min(num_examples, len(dataset)))) + + env = RockPaperScissorsEnv( + num_rounds=num_rounds, + rubric=create_rubric(), + max_turns=num_rounds * 2 + 2, # 2 turns per round + buffer + dataset=dataset, + ) + + # Wire actors to environment via Protocol + Protocol(actors=[PLAYER1, PLAYER2], envs=[env]) + + return env diff --git a/environments/twenty_questions/pyproject.toml b/environments/twenty_questions/pyproject.toml new file mode 100644 index 000000000..9d3b174c8 --- /dev/null +++ b/environments/twenty_questions/pyproject.toml @@ -0,0 +1,19 @@ +[project] +name = "twenty-questions" +description = "20 Questions multi-agent guessing game" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "verifiers>=0.1.9", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["twenty_questions.py", "pyproject.toml"] + +[tool.verifiers.eval] +num_examples = 5 +rollouts_per_example = 1 diff --git a/environments/twenty_questions/twenty_questions.py b/environments/twenty_questions/twenty_questions.py new file mode 100644 index 000000000..6e77677a0 --- /dev/null +++ b/environments/twenty_questions/twenty_questions.py @@ -0,0 +1,389 @@ +""" +20 Questions: A simple multi-agent guessing game. + +This environment demonstrates: +- Alternating turns via get_next_actor() (standard turn-based flow) +- Asymmetric actors (one trainable, one frozen) +- Multiple stop conditions (win or max questions) +- Conversation relay between actors via env_response() + +Game flow: +1. Guesser receives category hint and asks first question +2. Thinker (with secret word) answers yes/no +3. Alternate until guesser wins or runs out of questions +4. Only guesser is trained - rewarded for winning quickly +""" + +import re +from datasets import Dataset + +import verifiers as vf +from verifiers import Actor, MultiAgentEnv, MultiAgentRubric, Protocol +from verifiers.types import Messages, State + + +# ============================================================================= +# Actors +# ============================================================================= + +THINKER = Actor( + id="thinker", + system_prompt="""You are the Thinker in 20 Questions. You have a SECRET WORD. + +Rules: +1. Answer questions with ONLY "Yes" or "No" +2. Be honest and consistent +3. If asked to guess, confirm with "Correct!" or "No, try again" + +Format your response as exactly one of: +- Yes +- No +- Correct! +- No, try again""", + max_tokens=20, + is_trainable=False, # Frozen - just follows rules, not trained +) + +GUESSER = Actor( + id="guesser", + system_prompt="""You are the Guesser in 20 Questions. Try to figure out the secret word. + +Rules: +1. 
Ask yes/no questions to narrow down possibilities +2. When ready to guess, say "Is it [your guess]?" +3. You have 20 questions maximum + +Good strategy: Start broad (Is it alive? Is it man-made?) then narrow down. + +Format: Just ask your question directly.""", + max_tokens=50, + is_trainable=True, # This is the agent we're training +) + + +# ============================================================================= +# Environment +# ============================================================================= + +class TwentyQuestionsEnv(MultiAgentEnv): + """ + 20 Questions game environment. + + Uses standard alternating turn flow (unlike RPS which uses simultaneous moves). + The Guesser asks questions, Thinker answers, until win or question limit. + """ + + name = "twenty_questions" + actors = ["thinker", "guesser"] + + def __init__(self, max_questions: int = 20, **kwargs): + """ + Initialize environment. + + Args: + max_questions: Maximum questions before game ends (default 20) + **kwargs: Passed to parent (rubric, dataset, max_turns, etc.) + """ + super().__init__(**kwargs) + self.max_questions = max_questions + + # ------------------------------------------------------------------------- + # Turn Management + # Uses standard alternating flow - no custom rollout needed + # ------------------------------------------------------------------------- + + def get_initial_actor(self, state: State) -> str: + """Guesser asks first question.""" + return "guesser" + + def get_next_actor(self, state: State) -> str: + """Alternate between guesser and thinker each turn.""" + current = state["extras"]["current_actor_id"] + return "thinker" if current == "guesser" else "guesser" + + # ------------------------------------------------------------------------- + # Stop Conditions + # Two ways to end: win or run out of questions + # ------------------------------------------------------------------------- + + @vf.stop + async def game_won(self, state: State) -> bool: + """Stop if guesser guessed correctly.""" + return state.get("extras", {}).get("won", False) + + @vf.stop + async def max_questions_reached(self, state: State) -> bool: + """Stop after max questions (guesser loses).""" + return state.get("extras", {}).get("question_count", 0) >= self.max_questions + + # ------------------------------------------------------------------------- + # State Setup + # ------------------------------------------------------------------------- + + async def setup_state(self, state: State) -> State: + """ + Initialize game state with secret word from dataset. + + Sets up: + - secret_word: The word guesser must discover + - question_count: Progress toward limit + - won: Victory flag (checked by game_won stop condition) + - questions: History for debugging/analysis + """ + state = await super().setup_state(state) + + # Get secret word from dataset's "answer" field + secret_word = state["input"].get("answer", "dog") + state["extras"]["secret_word"] = secret_word.lower() + state["extras"]["question_count"] = 0 + state["extras"]["won"] = False + state["extras"]["questions"] = [] + + return state + + # ------------------------------------------------------------------------- + # Prompt Building + # ------------------------------------------------------------------------- + + async def get_initial_messages(self, state: State) -> Messages: + """ + Create opening prompt for guesser. + + Includes category hint (animal/object/food) from dataset + to make the game tractable. 
+ """ + category = state["input"].get("info", {}).get("category", "thing") + return [ + {"role": "system", "content": GUESSER.system_prompt}, + {"role": "user", "content": f"I'm thinking of a {category}. Ask your first question!"} + ] + + # ------------------------------------------------------------------------- + # Environment Response (Game Logic) + # ------------------------------------------------------------------------- + + async def env_response(self, messages: Messages, state: State, **kwargs) -> Messages: + """ + Process last response and build prompt for next actor. + + Called after each actor speaks: + - After Guesser: Check for correct guess, prompt Thinker to answer + - After Thinker: Relay answer to Guesser, check for game over + + Returns the messages to show the next actor. + """ + if not state["trajectory"]: + return [] + + # Extract what was just said + last_step = state["trajectory"][-1] + last_completion = last_step.get("completion", []) + if not last_completion: + return [] + + content = last_completion[-1].get("content", "") if isinstance(last_completion[-1], dict) else str(last_completion[-1]) + content_lower = content.lower().strip() + secret = state["extras"]["secret_word"] + + # Get who just spoke (from trajectory, not current_actor which is NEXT speaker) + last_actor = last_step.get("extras", {}).get("actor_id", "") + + if last_actor == "guesser": + # ----------------------------------------------------------------- + # Guesser just asked a question + # ----------------------------------------------------------------- + state["extras"]["question_count"] += 1 + state["extras"]["questions"].append(content) + + # Check if it's a guess - look for "is it [something]?" patterns + guess_match = re.search(r"is it .*?([a-zA-Z]+)\s*\??$", content_lower) + if guess_match: + guess = guess_match.group(1).lower() + # Check exact match or if secret appears anywhere in question + if guess == secret or secret in content_lower: + state["extras"]["won"] = True + state["final_env_response"] = [{"role": "user", "content": "Correct! You win!"}] + return [] + + # Build prompt for Thinker (inject secret word so it can answer) + return [ + {"role": "system", "content": THINKER.system_prompt + f"\n\nYour secret word is: {secret}"}, + {"role": "user", "content": f"Question {state['extras']['question_count']}: {content}"} + ] + + else: + # ----------------------------------------------------------------- + # Thinker just answered + # ----------------------------------------------------------------- + # Safety check if already won + if state["extras"].get("won", False): + return [] + + # Check if Thinker confirmed a correct guess + if "correct" in content_lower: + state["extras"]["won"] = True + state["final_env_response"] = [{"role": "user", "content": "Correct! You win!"}] + return [] + + # Check if max questions reached (game over, guesser loses) + if state["extras"]["question_count"] >= self.max_questions: + state["final_env_response"] = [ + {"role": "user", "content": f"Game over! The word was: {secret}"} + ] + return [] + + # Relay answer to Guesser with remaining count + remaining = self.max_questions - state["extras"]["question_count"] + return [ + {"role": "user", "content": f"Answer: {content}\n\nYou have {remaining} questions left. 
Ask another question or make a guess!"} + ] + + # ------------------------------------------------------------------------- + # Rollout + # ------------------------------------------------------------------------- + + async def rollout(self, input, client, model, sampling_args=None) -> State: + """ + Run the game and split into per-actor states. + + Uses parent's standard alternating rollout (no custom loop needed). + Just adds per-actor state splitting at the end for scoring. + """ + # Use parent rollout (handles turn alternation, stop conditions) + state = await super().rollout(input, client, model, sampling_args) + + # Split into per-actor states for proper per-actor scoring + state["child_states"] = self.create_actor_states(state) + + return state + + +# ============================================================================= +# Rubric (Scoring) +# ============================================================================= + +def create_rubric() -> MultiAgentRubric: + """ + Create rubric - only guesser is trained, rewarded for winning fast. + + Reward structure: + - Win in 1 question: 1.0 (maximum) + - Win in 20 questions: 0.1 (minimum win reward) + - Lose: 0.0 + + Thinker has is_trainable=False, so no reward needed for it. + """ + rubric = MultiAgentRubric() + + def guesser_reward(state, **kwargs) -> float: + """Reward guesser for winning quickly. Faster = higher reward.""" + extras = state.get("extras", {}) + won = extras.get("won", False) + questions = extras.get("question_count", 20) + max_q = 20 + + if won: + # Linear scale from 1.0 (1 question) to 0.1 (20 questions) + return 1.0 - 0.9 * (questions - 1) / (max_q - 1) + else: + return 0.0 + + def game_length_metric(state, **kwargs) -> float: + """Track how many questions were asked (metric only, not trained on).""" + return float(state.get("extras", {}).get("question_count", 0)) + + def win_rate_metric(state, **kwargs) -> float: + """Track win rate (metric only, not trained on).""" + return 1.0 if state.get("extras", {}).get("won", False) else 0.0 + + # Guesser reward - the only trainable actor + rubric.add_actor_reward_func("guesser", guesser_reward, weight=1.0) + # Thinker is frozen (is_trainable=False), no reward function needed + + # Metrics for logging (weight=0.0 means not used in training) + rubric.add_reward_func(game_length_metric, weight=0.0) + rubric.add_reward_func(win_rate_metric, weight=0.0) + + return rubric + + +# ============================================================================= +# Dataset +# ============================================================================= + +def create_dataset() -> Dataset: + """ + Create dataset of secret words across categories. + + Each row contains: + - prompt: Initial messages for Guesser + - answer: The secret word (used in setup_state) + - info.category: Hint category (animal/object/food) + - example_id: Unique identifier + - task: Environment name for routing + """ + def make_prompt(category: str) -> list: + return [ + {"role": "system", "content": GUESSER.system_prompt}, + {"role": "user", "content": f"I'm thinking of a {category}. You have 20 questions to guess what it is. 
Ask your first question!"} + ] + + items = [ + # Animals (4) + {"prompt": make_prompt("animal"), "answer": "dog", "info": {"category": "animal"}, "example_id": 0, "task": "twenty_questions"}, + {"prompt": make_prompt("animal"), "answer": "cat", "info": {"category": "animal"}, "example_id": 1, "task": "twenty_questions"}, + {"prompt": make_prompt("animal"), "answer": "elephant", "info": {"category": "animal"}, "example_id": 2, "task": "twenty_questions"}, + {"prompt": make_prompt("animal"), "answer": "penguin", "info": {"category": "animal"}, "example_id": 3, "task": "twenty_questions"}, + # Objects (4) + {"prompt": make_prompt("object"), "answer": "chair", "info": {"category": "object"}, "example_id": 4, "task": "twenty_questions"}, + {"prompt": make_prompt("object"), "answer": "book", "info": {"category": "object"}, "example_id": 5, "task": "twenty_questions"}, + {"prompt": make_prompt("object"), "answer": "computer", "info": {"category": "object"}, "example_id": 6, "task": "twenty_questions"}, + {"prompt": make_prompt("object"), "answer": "bicycle", "info": {"category": "object"}, "example_id": 7, "task": "twenty_questions"}, + # Food (2) + {"prompt": make_prompt("food"), "answer": "pizza", "info": {"category": "food"}, "example_id": 8, "task": "twenty_questions"}, + {"prompt": make_prompt("food"), "answer": "apple", "info": {"category": "food"}, "example_id": 9, "task": "twenty_questions"}, + ] + return Dataset.from_list(items) + + +# ============================================================================= +# Environment Loader +# ============================================================================= + +def load_environment( + max_questions: int = 20, + num_examples: int = -1, +) -> TwentyQuestionsEnv: + """ + Factory function to create a fully configured 20 Questions environment. 
+ + Args: + max_questions: Questions before game ends (default 20) + num_examples: Number of games to run (-1 = all 10) + + Returns: + Ready-to-use TwentyQuestionsEnv + """ + dataset = create_dataset() + if num_examples > 0: + dataset = dataset.select(range(min(num_examples, len(dataset)))) + + rubric = create_rubric() + + env = TwentyQuestionsEnv( + max_questions=max_questions, + rubric=rubric, + max_turns=max_questions * 2 + 2, # 2 turns per question + buffer + dataset=dataset, + ) + + # Wire actors to environment via Protocol + # Protocol constructor registers itself with env (env.protocol = protocol) + # This enables env.get_actor() and env.protocol.spawn() + Protocol( + actors=[THINKER, GUESSER], + envs=[env], + dataset=dataset, + ) + + return env diff --git a/verifiers/__init__.py b/verifiers/__init__.py index 535a16870..2362b65da 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -19,6 +19,12 @@ from .envs.multiturn_env import MultiTurnEnv # noqa # isort: skip from .envs.tool_env import ToolEnv # noqa # isort: skip +# Multi-agent support +from .envs.actor import Actor # noqa # isort: skip +from .envs.protocol import EpisodeRequest, GenerateResult, Protocol # noqa # isort: skip +from .envs.multiagent_env import MultiAgentEnv # noqa # isort: skip +from .rubrics.multiagent_rubric import MultiAgentRubric # noqa # isort: skip + # main imports from .envs.env_group import EnvGroup from .envs.singleturn_env import SingleTurnEnv @@ -54,6 +60,13 @@ "JudgeRubric", "RubricGroup", "MathRubric", + "MultiAgentRubric", + # Multi-agent support + "Actor", + "EpisodeRequest", + "GenerateResult", + "Protocol", + "MultiAgentEnv", "TextArenaEnv", "ReasoningGymEnv", "GymEnv", @@ -80,6 +93,7 @@ "get_model_and_tokenizer", "RLTrainer", "RLConfig", + "MultiAgentOrchestrator", "GRPOTrainer", "GRPOConfig", "grpo_defaults", @@ -94,6 +108,7 @@ "get_model_and_tokenizer": "verifiers.rl.trainer.utils:get_model_and_tokenizer", "RLConfig": "verifiers.rl.trainer:RLConfig", "RLTrainer": "verifiers.rl.trainer:RLTrainer", + "MultiAgentOrchestrator": "verifiers.rl.trainer:MultiAgentOrchestrator", "GRPOTrainer": "verifiers.rl.trainer:GRPOTrainer", "GRPOConfig": "verifiers.rl.trainer:GRPOConfig", "grpo_defaults": "verifiers.rl.trainer:grpo_defaults", @@ -135,6 +150,7 @@ def __getattr__(name: str): from .rl.trainer import ( # noqa: F401 GRPOConfig, GRPOTrainer, + MultiAgentOrchestrator, RLConfig, RLTrainer, grpo_defaults, diff --git a/verifiers/envs/actor.py b/verifiers/envs/actor.py new file mode 100644 index 000000000..69b291d89 --- /dev/null +++ b/verifiers/envs/actor.py @@ -0,0 +1,50 @@ +""" +Actor: A trainable entity with distinct identity (system prompt) in multi-agent environments. +""" + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from verifiers.types import State + + +@dataclass +class Actor: + """A trainable actor with distinct system prompt. 
Set is_trainable=False for frozen actors.""" + + id: str + system_prompt: str = "" + max_tokens: int = 4096 + is_trainable: bool = True + sampling_args: dict[str, Any] = field(default_factory=dict) + + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, other: object) -> bool: + if isinstance(other, Actor): + return self.id == other.id + return False + + def __repr__(self) -> str: + trainable_str = "trainable" if self.is_trainable else "frozen" + return f"Actor(id={self.id!r}, {trainable_str})" + + def get_system_message(self) -> dict[str, str] | None: + """Return system message dict or None if no system prompt.""" + if self.system_prompt: + return {"role": "system", "content": self.system_prompt} + return None + + def merge_sampling_args(self, base_args: dict[str, Any]) -> dict[str, Any]: + """Merge actor's sampling args with base args (actor takes precedence).""" + merged = dict(base_args) + merged.update(self.sampling_args) + if self.max_tokens: + merged["max_tokens"] = self.max_tokens + return merged + + def filter_state(self, state: "State") -> "State": + """Filter state to what this actor can see. Override for hidden info games.""" + return state diff --git a/verifiers/envs/multiagent_env.py b/verifiers/envs/multiagent_env.py new file mode 100644 index 000000000..d887e2793 --- /dev/null +++ b/verifiers/envs/multiagent_env.py @@ -0,0 +1,720 @@ +""" +Multi-agent environment with turn order management and hierarchical spawning. + +This module provides the base class for multi-agent RL environments, extending +MultiTurnEnv with support for: +- Multiple actors with distinct system prompts and sampling args +- Turn order management via get_initial_actor() / get_next_actor() +- Per-actor trajectory tagging for credit assignment +- Per-actor state splitting for individual reward computation +- Hierarchical episode spawning via Protocol.spawn() + +Key concepts: +- Actor: A trainable entity with its own system prompt (defined in actor.py) +- Protocol: Wires actors to environments, enables spawning (defined in protocol.py) +- State splitting: One game state → multiple actor states for per-actor rewards + +""" + +from __future__ import annotations + +import uuid +from abc import abstractmethod +from typing import TYPE_CHECKING, Any + +from datasets import Dataset +from openai import AsyncOpenAI + +import verifiers as vf +from verifiers.envs.multiturn_env import MultiTurnEnv + +from verifiers.envs.protocol import GenerateResult +from verifiers.types import ( + Messages, + RolloutInput, + SamplingArgs, + State, + TrajectoryStep, +) +from verifiers.utils.message_utils import concat_messages +from verifiers.utils.async_utils import maybe_semaphore +from verifiers.utils.eval_utils import save_rollout_results +from verifiers.utils.response_utils import ( + parse_is_truncated, + parse_response_messages, + parse_response_tokens, +) + +if TYPE_CHECKING: + from verifiers.envs.actor import Actor + from verifiers.envs.protocol import Protocol + + +def _dummy_dataset() -> Dataset: + """ + Create a placeholder dataset for environments that don't specify one. + + The real dataset is typically owned by Protocol. This prevents errors + when MultiTurnEnv requires a dataset but one isn't provided. 
+ """ + return Dataset.from_dict({ + "example_id": [0], + "prompt": [[{"role": "user", "content": "dummy"}]], + "answer": [""], + }) + + +# ============================================================================= +# MultiAgentEnv Base Class +# ============================================================================= + +class MultiAgentEnv(MultiTurnEnv): + """ + Base class for multi-agent environments. + + Subclasses must implement: + - get_initial_actor(): Who goes first + - get_next_actor(): Who goes next (for alternating turns) + - env_response(): Game logic between turns + + The Protocol reference is injected by Protocol.__init__ when wiring + actors to environments. + """ + + # ------------------------------------------------------------------------- + # Class Attributes + # ------------------------------------------------------------------------- + + # List of actor IDs this environment uses (e.g., ["player1", "player2"]) + # Subclasses should override this + actors: list[str] = [] + + # Injected by Protocol.__init__ - provides actor lookup and spawning + protocol: "Protocol" + + def __init__(self, **kwargs): + """ + Initialize with dummy dataset if none provided. + + The parent class (MultiTurnEnv) requires a dataset, but for multi-agent + environments the Protocol often owns the real dataset. + """ + if "dataset" not in kwargs and "eval_dataset" not in kwargs: + kwargs["dataset"] = _dummy_dataset() + super().__init__(**kwargs) + + # ------------------------------------------------------------------------- + # Turn Management (Abstract Methods) + # ------------------------------------------------------------------------- + + @abstractmethod + def get_initial_actor(self, state: State) -> str: + """ + Return the actor ID that starts the rollout. + + Example: return "guesser" for Twenty Questions + """ + pass + + @abstractmethod + def get_next_actor(self, state: State) -> str: + """ + Return the actor ID for the next turn. + + Example: return "thinker" if current == "guesser" else "guesser" + """ + pass + + def get_active_actors(self, state: State) -> list[str]: + """ + Return actor IDs that can act this turn. + + Default: Single actor (standard alternating turns). + Override for simultaneous moves (e.g., RPS returns ["player1", "player2"]). + """ + return [self.get_next_actor(state)] + + def get_actor(self, actor_id: str) -> "Actor": + """Get an actor by ID from Protocol.""" + return self.protocol.get_actor(actor_id) + + # ------------------------------------------------------------------------- + # State Setup + # ------------------------------------------------------------------------- + + async def setup_state(self, state: State) -> State: + """ + Initialize multi-agent state fields. + + Sets up state["extras"] with: + - current_actor_id: Who is currently speaking (set in rollout) + - actor_history: List of (actor_id, turn_index) for credit assignment + - episode_id: Unique ID for this rollout + - parent_episode_id: Links to parent if this is a spawned child + + Also initializes state["child_states"] for per-actor state splitting. 
+ """ + state = await super().setup_state(state) + + state["child_states"] = [] + state["extras"] = { + "current_actor_id": None, # Set in rollout() after setup + "actor_history": [], # Tracks who spoke at each turn + "episode_id": state.get("trajectory_id", uuid.uuid4().hex), + "parent_episode_id": None, # Set if spawned from parent episode + } + + return state + + # ------------------------------------------------------------------------- + # Trajectory Management + # ------------------------------------------------------------------------- + + async def add_trajectory_step( + self, state: State, trajectory_step: TrajectoryStep + ) -> None: + """ + Add trajectory step, tagging with current actor for credit assignment. + + This tagging is critical for create_actor_states() which filters + the trajectory by actor_id to split into per-actor states. + """ + current_actor_id = state["extras"]["current_actor_id"] + + if current_actor_id: + # Tag this step with who generated it + if "extras" not in trajectory_step: + trajectory_step["extras"] = {} + trajectory_step["extras"]["actor_id"] = current_actor_id + + # Record in history: "actor X spoke at turn Y" + turn_index = len(state["trajectory"]) + state["extras"]["actor_history"].append((current_actor_id, turn_index)) + + await super().add_trajectory_step(state, trajectory_step) + + async def add_model_response( + self, + state: State, + prompt_messages: Messages, + response: Any, + ) -> None: + """ + Parse API response and add to trajectory with actor_id tag. + + Extracts completion text, truncation status, and token data, + then creates a TrajectoryStep tagged with the current actor. + """ + current_actor_id = state["extras"]["current_actor_id"] + + # Parse the raw API response + completion_messages = await parse_response_messages(response, self.message_type) + response_is_truncated = await parse_is_truncated(response, self.message_type) + tokens = await parse_response_tokens( + response, self.message_type, self.max_seq_len + ) + is_truncated = response_is_truncated or ( + tokens is not None and bool(tokens.get("is_truncated")) + ) + + # Build trajectory step with actor tag + trajectory_step = TrajectoryStep( + prompt=prompt_messages, + completion=completion_messages, + response=response, + tokens=tokens, + reward=None, # Filled by rubric later + advantage=None, # Filled by GRPO later + is_truncated=is_truncated, + trajectory_id=state["trajectory_id"], + extras={"actor_id": current_actor_id}, # Critical for splitting + ) + await self.add_trajectory_step(state, trajectory_step) + + # ------------------------------------------------------------------------- + # Direct Actor Invocation + # ------------------------------------------------------------------------- + + async def call_actor( + self, + actor_id: str, + messages: Messages, + state: State, + sampling_args: SamplingArgs | None = None, + ) -> GenerateResult: + """ + Directly call a specific actor outside the standard rollout loop. + + Useful for custom rollouts (like RPS simultaneous moves) or + spawning patterns. Handles system prompt injection and sampling + args merging automatically. 
+ """ + actor = self.get_actor(actor_id) + + # Mark who is speaking + state["extras"]["current_actor_id"] = actor_id + + # Build messages with actor's system prompt + system_msg = actor.get_system_message() + actor_messages = [system_msg] if system_msg else [] + actor_messages.extend(messages) + + # Merge sampling args: environment → actor → call overrides + merged_args = actor.merge_sampling_args(state.get("sampling_args") or {}) + if sampling_args: + merged_args.update(sampling_args) + + # Call model and record in trajectory + response = await self.get_model_response(state, actor_messages, sampling_args=merged_args) + await self.add_model_response(state, actor_messages, response) + + return GenerateResult( + actor_id=actor_id, + state=state, + is_trainable=actor.is_trainable, + episode_id=state["extras"]["episode_id"], + parent_episode_id=state["extras"]["parent_episode_id"], + ) + + # ------------------------------------------------------------------------- + # Prompt Building + # ------------------------------------------------------------------------- + + async def get_prompt_messages(self, state: State) -> Messages: + """ + Build prompt messages, injecting current actor's system prompt. + + For each turn: + 1. First turn: Use initial prompt from dataset + 2. Later turns: Concatenate previous turn + env_response() + 3. Always: Replace/prepend system prompt for current actor + + This ensures each actor sees their own instructions regardless + of what accumulated in the conversation context. + """ + current_actor_id = state["extras"]["current_actor_id"] + + # Build base messages + if len(state["trajectory"]) == 0: + # First turn: start with dataset prompt + messages = list(state["prompt"]) # Copy to avoid mutation + else: + # Later turns: build from previous turn + env_response + prev_turn_prompt = state["trajectory"][-1]["prompt"] + prev_turn_completion = state["trajectory"][-1]["completion"] + messages = concat_messages([prev_turn_prompt, prev_turn_completion]) + env_response = await self.env_response(messages, state) + messages = concat_messages([messages, env_response]) + + # Inject current actor's system prompt + actor = self.get_actor(current_actor_id) + system_prompt = actor.system_prompt + + if messages and messages[0].get("role") == "system": + # Replace existing system prompt with actor's + messages[0] = {"role": "system", "content": system_prompt} + elif system_prompt: + # Prepend actor's system prompt + messages = [{"role": "system", "content": system_prompt}] + messages + + return messages + + # ------------------------------------------------------------------------- + # Main Rollout Loop + # ------------------------------------------------------------------------- + + async def rollout( + self, + input: RolloutInput, + client: AsyncOpenAI, + model: str, + sampling_args: SamplingArgs | None = None, + ) -> State: + """ + Standard alternating-turn rollout using get_initial_actor/get_next_actor. + + Flow: + 1. init_state() - create base state from input + 2. setup_state() - initialize multi-agent fields + 3. Loop until stop condition: + a. get_prompt_messages() - build prompt for current actor + b. get_model_response() - call the model + c. add_model_response() - store in trajectory (tagged) + d. get_next_actor() - determine next speaker + 4. render_completion() - finalize state + + For simultaneous moves (like RPS), override this entirely and use + get_active_actors() instead. 
+ """ + state = await self.init_state(input, client, model, sampling_args) + + try: + state = await self.setup_state(state) + except vf.Error as e: + state["error"] = e + return state + + # Set first actor + initial_actor_id = self.get_initial_actor(state) + state["extras"]["current_actor_id"] = initial_actor_id + + # Main loop + while not await self.is_completed(state): + try: + current_actor_id = state["extras"]["current_actor_id"] + + # Build prompt and check for early termination + prompt_messages = await self.get_prompt_messages(state) + if state.get("final_env_response") is not None: + break + + # Get actor-specific sampling args and call model + actor = self.get_actor(current_actor_id) + merged_args = actor.merge_sampling_args(sampling_args or {}) + + response = await self.get_model_response(state, prompt_messages, sampling_args=merged_args) + await self.add_model_response(state, prompt_messages, response) + + # Advance to next actor + next_actor_id = self.get_next_actor(state) + state["extras"]["current_actor_id"] = next_actor_id + + except vf.Error as e: + if isinstance(e, vf.OverlongPromptError): + state["prompt_too_long"] = True + state["is_truncated"] = True + else: + state["error"] = e + break + + await self.render_completion(state) + return state + + # ------------------------------------------------------------------------- + # Abstract: Game Logic + # ------------------------------------------------------------------------- + + @abstractmethod + async def env_response( + self, messages: Messages, state: State, **kwargs + ) -> Messages: + """ + Generate environment response between turns. + + This is where game logic lives: + - Process the last actor's output + - Update game state (scores, flags, etc.) + - Build the prompt for the next actor + + Return [] if no additional messages needed. + Set state["final_env_response"] to terminate early. + """ + pass + + # ------------------------------------------------------------------------- + # Actor Helpers + # ------------------------------------------------------------------------- + + def get_trainable_actors(self) -> list["Actor"]: + """Get actors with is_trainable=True (will be trained).""" + return [a for a in self.protocol.actors.values() if a.is_trainable] + + def get_frozen_actors(self) -> list["Actor"]: + """Get actors with is_trainable=False (frozen, not trained).""" + return [a for a in self.protocol.actors.values() if not a.is_trainable] + + # ------------------------------------------------------------------------- + # Per-Actor State Creation + # ------------------------------------------------------------------------- + # + # After a game completes, we split the single game state into per-actor + # states for individual reward computation and GRPO advantage calculation. + # + # Example: RPS game with 6 turns + # Full trajectory: [p1, p2, p1, p2, p1, p2] + # Player1 state: trajectory=[p1, p1, p1], prompt="You are Player 1..." + # Player2 state: trajectory=[p2, p2, p2], prompt="You are Player 2..." + # ------------------------------------------------------------------------- + + # Fields shared by reference across all actor states + # NOTE: "input" is deliberately NOT shared because State.__getitem__/__setitem__ + # forward reads/writes for INPUT_FIELDS (prompt, answer, etc.) to input[key]. + # If we shared input, all actor_states would read the same prompt. 
+ SHARED_STATE_FIELDS = { + "client", # AsyncOpenAI API client + "model", # Model name string (e.g., "gpt-4o-mini") + "timing", # Performance metrics (generation_ms, scoring_ms) + "trajectory_id", # Unique rollout identifier + } + + def create_actor_state( + self, + parent_state: State, + actor_id: str, + actor_trajectory: list[TrajectoryStep], + ) -> State: + """ + Create a child state for a specific actor from a parent state. + + This splits a multi-actor game state into per-actor states for: + - Per-actor reward computation (via MultiAgentRubric) + - GRPO advantage calculation per actor + - Training only specific actors (is_trainable filtering) + + Args: + parent_state: The full game state with all actors' turns + actor_id: The actor this state is for (e.g., "guesser", "player1") + actor_trajectory: Only this actor's trajectory steps (filtered) + + Returns: + A new State with shared fields referenced and actor-specific fields fresh + """ + actor_state = State() + + # Copy shared fields by reference (not duplicated in memory) + for key in parent_state.keys(): + if key in self.SHARED_STATE_FIELDS: + actor_state[key] = parent_state[key] + + # Copy INPUT_FIELDS using dict.__setitem__ to bypass State forwarding + # State.__setitem__ forwards writes for INPUT_FIELDS to input[key] if + # input exists. We bypass this to store directly on actor_state. + dict.__setitem__(actor_state, "answer", parent_state.get("answer", "")) + dict.__setitem__(actor_state, "task", parent_state.get("task", "")) + dict.__setitem__(actor_state, "example_id", parent_state.get("example_id", 0)) + dict.__setitem__(actor_state, "info", parent_state.get("info", {})) + + # Set actor-specific trajectory (filtered to just this actor's steps) + actor_state["trajectory"] = actor_trajectory + + # Copy extras but override actor_id to mark whose state this is + actor_state["extras"] = { + **parent_state.get("extras", {}), + "current_actor_id": actor_id, + } + + # Fresh fields for scoring (will be computed by rubric) + actor_state["child_states"] = [] + actor_state["reward"] = None + actor_state["advantage"] = None + actor_state["metrics"] = None + + # Extract actor-specific prompt and completion + if actor_trajectory: + # Prompt: Find the LAST system message (actor's own prompt) + # The raw prompt may contain accumulated context from other actors + raw_prompt = actor_trajectory[0].get("prompt", []) + prompt_ref = raw_prompt + for i in range(len(raw_prompt) - 1, -1, -1): + if raw_prompt[i].get("role") == "system": + prompt_ref = raw_prompt[i:] # From last system message onward + break + dict.__setitem__(actor_state, "prompt", prompt_ref) + + # Completion: Collect all responses across all turns + all_completions = [] + for step in actor_trajectory: + step_completion = step.get("completion", []) + all_completions.extend(step_completion) + dict.__setitem__(actor_state, "completion", all_completions) + else: + # No trajectory for this actor - use parent's prompt + dict.__setitem__(actor_state, "prompt", parent_state.get("prompt", [])) + dict.__setitem__(actor_state, "completion", []) + + return actor_state + + def create_actor_states(self, state: State, actor_ids: list[str] | None = None) -> list[State]: + """ + Split a parent state into per-actor child states. + + Filters the full trajectory by actor_id (set in add_trajectory_step), + then creates a state for each actor with their filtered trajectory. + + Args: + state: The full game state with all actors' turns + actor_ids: List of actor IDs to create states for. 
+ Defaults to self.actors if not provided. + + Returns: + List of per-actor states, one for each actor_id + """ + if actor_ids is None: + actor_ids = self.actors + + actor_states = [] + for actor_id in actor_ids: + # Filter trajectory to only this actor's steps + actor_trajectory = [ + step for step in state.get("trajectory", []) + if step.get("extras", {}).get("actor_id") == actor_id + ] + + new_state = self.create_actor_state(state, actor_id, actor_trajectory) + actor_states.append(new_state) + + return actor_states + + # ------------------------------------------------------------------------- + # Result Building + # ------------------------------------------------------------------------- + + def build_generate_result(self, state: State) -> GenerateResult: + """ + Build a GenerateResult tree from completed state with children attached. + + Creates a hierarchical structure linking parent episodes to children, + useful for hierarchical credit assignment (e.g., Proposer-Solver). + """ + extras = state["extras"] + + # Determine root actor (first in history, or first declared) + actor_history = extras["actor_history"] + root_actor_id = actor_history[0][0] if actor_history else (self.actors[0] if self.actors else "unknown") + + # Get trainability (default True if actor not found) + try: + is_trainable = self.get_actor(root_actor_id).is_trainable + except KeyError: + is_trainable = True + + # Create root result + root_result = GenerateResult( + actor_id=root_actor_id, + state=state, + is_trainable=is_trainable, + episode_id=extras["episode_id"], + parent_episode_id=extras["parent_episode_id"], + ) + + # Attach child states as GenerateResults linked to parent + parent_episode_id = extras["episode_id"] + for child_state in state["child_states"]: + child_extras = child_state.get("extras", {}) + child_actor_id = child_extras.get("current_actor_id", "unknown") + + try: + child_trainable = self.get_actor(child_actor_id).is_trainable + except KeyError: + child_trainable = True + + child_result = GenerateResult( + actor_id=child_actor_id, + state=child_state, + is_trainable=child_trainable, + episode_id=child_extras.get("episode_id", uuid.uuid4().hex), + parent_episode_id=parent_episode_id, + ) + root_result.add_child(child_result) + + return root_result + + # ------------------------------------------------------------------------- + # Generate Override (Flattening & Per-Actor Scoring) + # ------------------------------------------------------------------------- + + async def generate(self, inputs, client, model, **kwargs): + """ + Generate rollouts, flatten child_states, then score per-actor. + + The parent's generate() returns one state per game. For multi-agent, + we need one state per actor per game for proper per-actor rewards. + + Flow: + 1. Intercept save/score options (we'll handle them after flattening) + 2. Run parent's generate() to get game states + 3. Flatten: Replace each game state with its child_states (per-actor) + 4. Score all flattened states together (proper GRPO grouping) + 5. Rebuild result arrays to match flattened states + 6. 
Update metadata and save + + Before: result["state"] = [game1, game2] + After: result["state"] = [g1_actor1, g1_actor2, g2_actor1, g2_actor2] + """ + # Intercept options - we'll handle them after flattening + original_save_results = kwargs.pop("save_results", False) + push_to_hf_hub = kwargs.pop("push_to_hf_hub", False) + hf_hub_dataset_name = kwargs.pop("hf_hub_dataset_name", None) + + # Disable parent's scoring - we'll score after flattening + original_score_rollouts = getattr(self, 'score_rollouts', True) + self.score_rollouts = False + + try: + result = await super().generate(inputs, client, model, save_results=False, **kwargs) + finally: + self.score_rollouts = original_score_rollouts + + # Flatten: replace parent states with their child_states + original_states = result.get("state", []) + flattened_states = [] + needs_flatten = False + + for state in original_states: + child_states = state.get("child_states", []) + if child_states: + flattened_states.extend(child_states) + needs_flatten = True + else: + flattened_states.append(state) + + # Early exit if no flattening occurred + if not needs_flatten: + # Still need to score since we disabled it above + if original_states and self.rubric and original_score_rollouts: + score_sem = await maybe_semaphore(-1) + await self.rubric.score_group(original_states, score_sem=score_sem) + result["reward"] = [s.get("reward", 0.0) for s in original_states] + if original_save_results: + save_rollout_results(result, push_to_hf_hub, hf_hub_dataset_name) + return result + + # Score all flattened states together for proper GRPO grouping + if flattened_states and self.rubric and original_score_rollouts: + score_sem = await maybe_semaphore(-1) + await self.rubric.score_group(flattened_states, score_sem=score_sem) + + # Rebuild all result columns from flattened states + prompts, completions, answers, tasks, example_ids, rewards, infos, actor_ids = [], [], [], [], [], [], [], [] + metrics: dict[str, list[float]] = {} + + for s in flattened_states: + prompts.append(s.get("prompt", [])) + completions.append(s.get("completion")) + answers.append(s.get("answer", "")) + tasks.append(s.get("task", "default")) + example_ids.append(s.get("example_id", 0)) + rewards.append(s.get("reward", 0.0)) + infos.append(s.get("info", {})) + actor_ids.append(s.get("extras", {}).get("current_actor_id", "unknown")) + + state_metrics = s.get("metrics") + if state_metrics: + for name, value in state_metrics.items(): + if name not in metrics: + metrics[name] = [] + metrics[name].append(value) + + # Update result with flattened data + result["state"] = flattened_states + result["prompt"] = prompts + result["completion"] = completions + result["answer"] = answers + result["task"] = tasks + result["example_id"] = example_ids + result["reward"] = rewards + result["info"] = infos + result["actor_id"] = actor_ids # New field for multi-agent + result["metrics"] = metrics + + # Update metadata to reflect flattened counts + if "metadata" in result and rewards: + num_examples = len(set(example_ids)) + result["metadata"]["avg_reward"] = sum(rewards) / len(rewards) + result["metadata"]["num_examples"] = num_examples + result["metadata"]["rollouts_per_example"] = len(flattened_states) // num_examples if num_examples > 0 else 1 + + if original_save_results: + save_rollout_results(result, push_to_hf_hub, hf_hub_dataset_name) + + return result diff --git a/verifiers/envs/protocol.py b/verifiers/envs/protocol.py new file mode 100644 index 000000000..7f75147ec --- /dev/null +++ 
b/verifiers/envs/protocol.py @@ -0,0 +1,368 @@ +""" +Protocol: Orchestrates multiple environments and actors for composable training. + +- Protocol: Top-level coordinator owning actors, environments, and dataset +- EpisodeRequest: Request to spawn a child episode +- GenerateResult: Result from an episode, supporting tree structures +""" + +from __future__ import annotations + +import asyncio +import contextvars +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, List + +from datasets import Dataset +from openai import AsyncOpenAI + +from verifiers.types import DatasetBuilder, RolloutInput, SamplingArgs, State +from verifiers.utils.async_utils import maybe_semaphore + +from .actor import Actor + +if TYPE_CHECKING: + from .environment import Environment + +# Context variables for task-local storage during generate(). +# Enables: (1) concurrent generate() calls without interference, +# (2) simplified spawn() API - environments can call self.protocol.spawn(inputs) +# without passing client/model explicitly (retrieved from context). +_ctx_client: contextvars.ContextVar[AsyncOpenAI | None] = contextvars.ContextVar("client", default=None) +_ctx_model: contextvars.ContextVar[str | None] = contextvars.ContextVar("model", default=None) +_ctx_sampling_args: contextvars.ContextVar[SamplingArgs | None] = contextvars.ContextVar("sampling_args", default=None) + + + +@dataclass +class EpisodeRequest: + """Request to spawn a child episode with env_id, artifact, and is_trainable flag.""" + + env_id: str + artifact: Any = None + is_trainable: bool = True + meta: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + """Generate unique episode ID if not provided.""" + if "episode_id" not in self.meta: + self.meta["episode_id"] = uuid.uuid4().hex + + +@dataclass +class GenerateResult: + """Result from an episode, supporting hierarchical trees with children.""" + + actor_id: str + state: "State" + children: list["GenerateResult"] = field(default_factory=list) + is_trainable: bool = True + episode_id: str = field(default_factory=lambda: uuid.uuid4().hex) + parent_episode_id: str | None = None + + def flatten(self) -> list["GenerateResult"]: + """Flatten tree into list of all results (depth-first).""" + results = [self] + for child in self.children: + results.extend(child.flatten()) + return results + + def get_trainable(self) -> list["GenerateResult"]: + """Get only trainable results (is_trainable=True) from tree.""" + return [r for r in self.flatten() if r.is_trainable] + + def get_children_by_actor(self, actor_id: str) -> list["GenerateResult"]: + """Get direct children with a specific actor ID.""" + return [c for c in self.children if c.actor_id == actor_id] + + def add_child(self, child: "GenerateResult") -> None: + """Add a child result and set its parent reference.""" + child.parent_episode_id = self.episode_id + self.children.append(child) + + def get_rewards_by_actor(self) -> dict[str, list[float]]: + """Collect rewards grouped by actor ID from all descendants.""" + rewards: dict[str, list[float]] = {} + for result in self.flatten(): + actor_id = result.actor_id + reward = result.state.get("reward", 0.0) or 0.0 + if actor_id not in rewards: + rewards[actor_id] = [] + rewards[actor_id].append(reward) + return rewards + + def __repr__(self) -> str: + n_children = len(self.children) + trainable_str = "trainable" if self.is_trainable else "frozen" + children_str = f", {n_children} children" if n_children else "" + return 
f"GenerateResult(actor={self.actor_id!r}, {trainable_str}{children_str})" + + +class Protocol: + """Top-level coordinator owning actors, environments, and dataset.""" + + def __init__( + self, + actors: list[Actor], + envs: list["Environment"], + dataset: Dataset | DatasetBuilder | None = None, + eval_dataset: Dataset | DatasetBuilder | None = None, + ): + # Register actors + self._actors: dict[str, Actor] = {} + for actor in actors: + if actor.id in self._actors: + raise ValueError(f"Duplicate actor id: {actor.id}") + self._actors[actor.id] = actor + + # Register environments + self._envs: dict[str, "Environment"] = {} + for env in envs: + name = getattr(env, "name", env.__class__.__name__) + if name in self._envs: + raise ValueError(f"Duplicate environment name: {name}") + self._envs[name] = env + # Inject protocol reference into environment + env.protocol = self + + # Dataset registration (owned by Protocol) + self._dataset: Dataset | None = None + self._eval_dataset: Dataset | None = None + + if dataset is not None: + if callable(dataset): + self._dataset_source: DatasetBuilder | None = dataset + else: + self._dataset_source = lambda ds=dataset: ds + self._build_dataset() + else: + self._dataset_source = None + + if eval_dataset is not None: + if callable(eval_dataset): + self._eval_dataset_source: DatasetBuilder | None = eval_dataset + else: + self._eval_dataset_source = lambda ds=eval_dataset: ds + self._build_eval_dataset() + else: + self._eval_dataset_source = None + + def get_actor(self, actor_id: str) -> Actor: + """Get actor by id.""" + if actor_id not in self._actors: + raise KeyError( + f"Actor '{actor_id}' not found. Available: {list(self._actors.keys())}" + ) + return self._actors[actor_id] + + def get_env(self, name: str) -> "Environment": + """Get environment by name.""" + if name not in self._envs: + raise KeyError( + f"Environment '{name}' not found. 
Available: {list(self._envs.keys())}" + ) + return self._envs[name] + + @property + def actors(self) -> dict[str, Actor]: + """All registered actors.""" + return self._actors + + @property + def envs(self) -> dict[str, "Environment"]: + """All registered environments.""" + return self._envs + + # Dataset management + + def _build_dataset(self) -> Dataset | None: + """Build and cache the training dataset from source.""" + if self._dataset is not None: + return self._dataset + if self._dataset_source is None: + return None + self._dataset = self._dataset_source() + return self._dataset + + def _build_eval_dataset(self) -> Dataset | None: + """Build and cache the evaluation dataset from source.""" + if self._eval_dataset is not None: + return self._eval_dataset + if self._eval_dataset_source is None: + return None + self._eval_dataset = self._eval_dataset_source() + return self._eval_dataset + + def get_dataset(self, n: int = -1, seed: int | None = None) -> Dataset: + """Get the training dataset, optionally shuffled and truncated.""" + self._build_dataset() + if self._dataset is None: + raise ValueError("Dataset is not set on Protocol") + dataset = self._dataset + if seed is not None: + dataset = dataset.shuffle(seed=seed) + if n > 0: + n = min(n, len(dataset)) + return dataset.select(range(n)) + return dataset + + def get_eval_dataset(self, n: int = -1, seed: int | None = None) -> Dataset: + """Get the evaluation dataset, optionally shuffled and truncated.""" + self._build_eval_dataset() + if self._eval_dataset is None: + return self.get_dataset(n, seed) + dataset = self._eval_dataset + if seed is not None: + dataset = dataset.shuffle(seed=seed) + if n > 0: + n = min(n, len(dataset)) + return dataset.select(range(n)) + return dataset + + def get_inputs( + self, n: int = -1, rollouts_per_example: int = 1, seed: int | None = None + ) -> List[RolloutInput]: + """Get training inputs from the dataset.""" + dataset = self.get_dataset(n=n, seed=seed) + inputs = dataset.to_list() + if rollouts_per_example > 1: + inputs = inputs * rollouts_per_example + return inputs + + def get_eval_inputs( + self, n: int = -1, rollouts_per_example: int = 1, seed: int | None = None + ) -> List[RolloutInput]: + """Get evaluation inputs from the dataset.""" + dataset = self.get_eval_dataset(n=n, seed=seed) + inputs = dataset.to_list() + if rollouts_per_example > 1: + inputs = inputs * rollouts_per_example + return inputs + + async def generate( + self, + inputs: list[RolloutInput], + client: AsyncOpenAI, + model: str, + sampling_args: SamplingArgs | None = None, + max_concurrent: int = -1, + ) -> list[State]: + """Generate rollouts, dispatching to environments based on input['task'].""" + # Store context in task-local variables for spawn() calls + # Using contextvars allows concurrent generate() calls on same Protocol + _ctx_client.set(client) + _ctx_model.set(model) + _ctx_sampling_args.set(sampling_args) + + # Group inputs by target environment + by_env: dict[str, list[RolloutInput]] = {} + for inp in inputs: + env_name = inp.get("task") or self._get_default_env() + by_env.setdefault(env_name, []).append(inp) + + # Run each environment's generate() + all_states: list[State] = [] + for env_name, env_inputs in by_env.items(): + env = self.get_env(env_name) + results = await env.generate( + env_inputs, + client=client, + model=model, + sampling_args=sampling_args, + max_concurrent=max_concurrent, + ) + all_states.extend(results["state"]) + + # Flatten: collect trainable child_states recursively + return 
self._flatten_states(all_states) + + def _get_default_env(self) -> str: + """Return first registered environment as default.""" + return next(iter(self._envs.keys())) + + def _flatten_states(self, states: list[State]) -> list[State]: + """Recursively collect all states including children.""" + result: list[State] = [] + for state in states: + result.append(state) + child_states = state.get("child_states", []) + if child_states: + result.extend(self._flatten_states(child_states)) + return result + + async def spawn( + self, + inputs: list[RolloutInput], + score: bool = True, + client: AsyncOpenAI | None = None, + model: str | None = None, + sampling_args: SamplingArgs | None = None, + ) -> list[State]: + """Spawn child rollouts from within an environment. + + Can pass client/model explicitly, or they'll be read from context vars + (set by Protocol.generate()). + """ + # Try explicit params first, then context vars + client = client or _ctx_client.get() + model = model or _ctx_model.get() + sampling_args = sampling_args or _ctx_sampling_args.get() + + if client is None or model is None: + raise RuntimeError( + "spawn() requires client and model. Either pass them explicitly " + "or call spawn() from within a Protocol.generate() context." + ) + + # Run all rollouts in parallel + tasks = [] + for inp in inputs: + env_name = inp.get("task") or self._get_default_env() + env = self.get_env(env_name) + tasks.append( + env.rollout( + inp, + client=client, + model=model, + sampling_args=sampling_args, + ) + ) + + all_states = await asyncio.gather(*tasks) + + # Score rollouts if requested + if score: + from verifiers.utils.async_utils import maybe_semaphore + score_sem = await maybe_semaphore(-1) # No concurrency limit + for inp, state in zip(inputs, all_states): + env_name = inp.get("task") or self._get_default_env() + env = self.get_env(env_name) + if env.rubric: + await env.rubric.score_rollout(state, score_sem=score_sem) + + return list(all_states) + + async def evaluate( + self, + client: AsyncOpenAI, + model: str, + sampling_args: SamplingArgs | None = None, + num_examples: int = -1, + rollouts_per_example: int = 1, + max_concurrent: int = -1, + seed: int | None = None, + ) -> list[State]: + """Evaluate model on the Protocol's evaluation dataset.""" + inputs = self.get_eval_inputs( + n=num_examples, + rollouts_per_example=rollouts_per_example, + seed=seed, + ) + return await self.generate( + inputs=inputs, + client=client, + model=model, + sampling_args=sampling_args, + max_concurrent=max_concurrent, + ) diff --git a/verifiers/rl/trainer/__init__.py b/verifiers/rl/trainer/__init__.py index 3fef17b46..3c1f4846d 100644 --- a/verifiers/rl/trainer/__init__.py +++ b/verifiers/rl/trainer/__init__.py @@ -3,6 +3,7 @@ import torch._dynamo from .config import RLConfig +from .multiagent_orchestrator import MultiAgentOrchestrator from .trainer import RLTrainer torch._dynamo.config.suppress_errors = True @@ -30,6 +31,7 @@ def lora_defaults(**kwargs): __all__ = [ "RLConfig", "RLTrainer", + "MultiAgentOrchestrator", "GRPOTrainer", "GRPOConfig", "grpo_defaults", diff --git a/verifiers/rl/trainer/multiagent_orchestrator.py b/verifiers/rl/trainer/multiagent_orchestrator.py new file mode 100644 index 000000000..6cf26f98c --- /dev/null +++ b/verifiers/rl/trainer/multiagent_orchestrator.py @@ -0,0 +1,327 @@ +""" +MultiAgentOrchestrator: Training integration for multi-agent environments. + +This orchestrator wraps a Protocol to enable multi-agent and multi-environment +training. 
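+A typical wiring (illustrative sketch: rps_env, ds, and the trailing kwargs are
+placeholders, while PLAYER1/PLAYER2 are the rock_paper_scissors actors added in
+this patch; kwargs beyond protocol follow the base Orchestrator):
+
+    protocol = Protocol(actors=[PLAYER1, PLAYER2], envs=[rps_env], dataset=ds)
+    orchestrator = MultiAgentOrchestrator(protocol=protocol, ...)
+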
It delegates batch generation to Protocol.generate() which handles: +- Routing examples to correct environments (via 'task' field) +- Multi-actor turn management +- Child episode spawning +- Per-actor state flattening for credit assignment + +Key differences from base Orchestrator: +- Uses Protocol's unified dataset instead of single env's dataset +- Calls protocol.generate() instead of env.generate() +- Receives flattened states (per-actor) with pre-computed advantages + +Flow: +1. get_dataset_slice() pulls from Protocol's dataset +2. generate_batch() calls protocol.generate() for rollouts +3. Protocol returns flattened per-actor states with advantages +4. This class packages them into microbatches for training + +""" + +import time +from typing import Any + +import numpy as np +from datasets import Dataset + +from verifiers.envs.protocol import Protocol + +from .orchestrator import Batch, Microbatch, Orchestrator + + +# ============================================================================= +# MultiAgentOrchestrator +# ============================================================================= + +class MultiAgentOrchestrator(Orchestrator): + """ + Orchestrator that delegates to Protocol for multi-agent generation. + + Extends base Orchestrator but overrides: + - get_dataset_slice(): Use Protocol's dataset instead of env's + - generate_batch(): Use protocol.generate() instead of env.generate() + + All other functionality (tokenizer, batch sizes, client setup) inherited + from parent Orchestrator. + """ + + # ------------------------------------------------------------------------- + # Initialization + # ------------------------------------------------------------------------- + + def __init__( + self, + protocol: Protocol, + **kwargs, + ): + """ + Initialize orchestrator with Protocol. + + Args: + protocol: Protocol containing actors, environments, and dataset + **kwargs: Passed to parent Orchestrator (client, model_name, + sampling_args, batch sizes, etc.) + """ + self.protocol = protocol + + # Parent Orchestrator requires an env parameter for initialization. + # We use the first env from Protocol - parent uses it to set up + # tokenizer and other config, but we override the actual generation. + first_env = next(iter(protocol.envs.values())) + super().__init__(env=first_env, **kwargs) + + # ---- Filter Protocol's dataset by prompt length ---- + # Parent's __init__ filters self.env.dataset, but we use Protocol's + # dataset instead, so we need to apply the same filtering here. + max_length = self.max_prompt_len + + def filter_by_prompt_length(example, processing_class): + """Keep only examples whose prompts fit in context.""" + prompt = example["prompt"] + if isinstance(prompt, list): + # Chat format - apply template to get full text + prompt_text = processing_class.apply_chat_template( + prompt, tokenize=False, add_generation_prompt=True + ) + else: + prompt_text = prompt + prompt_ids = processing_class.encode(prompt_text) + return len(prompt_ids) <= max_length + + if self.protocol._dataset is not None: + self.protocol._dataset = self.protocol.get_dataset().filter( + filter_by_prompt_length, + fn_kwargs={"processing_class": self.processing_class}, + ) + + # ------------------------------------------------------------------------- + # Dataset Access + # ------------------------------------------------------------------------- + + def get_dataset_slice(self, batch_id: int) -> Dataset: + """ + Get dataset slice from Protocol's unified dataset. 
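+ For example (illustrative numbers): with prompts_per_batch=8 and a 20-row
+ dataset, batch_id=3 gives offset (3 * 8) % 20 = 4 and selects rows 4-11;
+ indices wrap modulo the dataset length on other batches.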
+ + Overrides parent to use Protocol's dataset instead of env's dataset. + This is necessary because Protocol's dataset may contain examples for + multiple environments (routed by 'task' field). + + Args: + batch_id: Which batch to get (determines offset into dataset) + + Returns: + Dataset slice with prompts_per_batch examples + + """ + num_rows = self.prompts_per_batch + dataset = self.protocol.get_dataset() + total_rows = len(dataset) + + if total_rows == 0: + raise ValueError("Protocol dataset is empty") + + # Calculate offset with wraparound for continuous training + offset = (batch_id * num_rows) % total_rows + indices = [(offset + i) % total_rows for i in range(num_rows)] + + return dataset.select(indices) + + # ------------------------------------------------------------------------- + # Batch Generation + # ------------------------------------------------------------------------- + + async def generate_batch(self, batch_id: int) -> Batch: + """ + Generate training batch using protocol.generate() for multi-agent support. + + Overrides parent to use Protocol instead of single environment. + Protocol handles: + - Routing examples to correct environments + - Multi-actor turn management + - Child episode spawning + - Per-actor state flattening with pre-computed advantages + + Args: + batch_id: Batch identifier (determines dataset slice) + + Returns: + Batch object containing microbatches ready for training + + Flow: + 1. Get dataset slice and repeat for GRPO rollouts + 2. Call protocol.generate() → flattened per-actor states + 3. Extract training data from trajectories + 4. Collect metrics for logging + 5. Package into microbatches for distributed training + """ + self.is_generating = True + assert self.client is not None + start_time = time.time() + + # ==== Step 1: Prepare inputs ==== + # Get batch of examples and repeat each for multiple GRPO rollouts + # e.g., 8 examples × 4 rollouts = 32 inputs + batch_ds = self.get_dataset_slice(batch_id) + repeated_ds = batch_ds.repeat(self.rollouts_per_example) + inputs = repeated_ds.to_list() + + # ==== Step 2: Run rollouts via Protocol ==== + # Protocol.generate() returns FLATTENED states: + # - Original game states are replaced by per-actor child states + # - Each state has pre-computed reward and advantage + # - Advantages computed per-actor within GRPO groups + all_states = await self.protocol.generate( + inputs, + client=self.client, + model=self.model_name, + sampling_args=self.sampling_args, + max_concurrent=self.max_concurrent, + ) + + self.is_generating = False + wall_clock_s = time.time() - start_time + + # ==== Step 3: Extract training data from trajectories ==== + # Each trajectory step = one model call = one training example + # Multi-agent states have trajectory filtered to single actor's turns + prompt_ids: list[list[int]] = [] + prompt_mask: list[list[int]] = [] + completion_ids: list[list[int]] = [] + completion_mask: list[list[int]] = [] + completion_logprobs: list[list[float]] = [] + advantages: list[float] = [] + + for state in all_states: + trajectory = state.get("trajectory", []) + for step in trajectory: + # Skip steps without tokenized data (e.g., env-only turns) + tokens = step.get("tokens") + if tokens is None: + continue + + # Tokenized prompt and completion for this turn + prompt_ids.append(tokens["prompt_ids"]) + prompt_mask.append(tokens["prompt_mask"]) + completion_ids.append(tokens["completion_ids"]) + completion_mask.append(tokens["completion_mask"]) + + # Log probs from sampling (for importance weighting in 
GRPO) + completion_logprobs.append(tokens["completion_logprobs"]) + + # Advantage already computed per-actor during scoring + advantages.append(step.get("advantage", 0.0)) + + # ==== Step 4: Collect metrics for logging ==== + # Rewards per state (for logging, not training) + rewards = [state.get("reward", 0.0) for state in all_states] + rewards_dict: dict[str, list[float]] = {"reward": rewards} + + metrics_dict: dict[str, float] = {} + + # Reward statistics + if rewards: + rewards_arr = np.asarray(rewards, dtype=np.float32) + metrics_dict["reward"] = float(rewards_arr.mean()) + metrics_dict["reward/std"] = float(rewards_arr.std()) + + # Advantage statistics (should be mean ~0 after GRPO normalization) + if advantages: + adv_arr = np.asarray(advantages, dtype=np.float32) + metrics_dict["advantage/absmean"] = float(np.abs(adv_arr).mean()) + + # Token statistics + completion_lengths = [len(ids) for ids in completion_ids] + if completion_lengths: + completion_lengths_arr = np.asarray(completion_lengths, dtype=np.float32) + metrics_dict["tokens/completion"] = float(completion_lengths_arr.mean()) + + # Calculate fraction of tokens that are masked (padding) + completion_mask_lengths = np.asarray( + [sum(mask) for mask in completion_mask], + dtype=np.float32, + ) + valid_tokens = completion_mask_lengths.sum() + total_tokens = completion_lengths_arr.sum() + if total_tokens > 0: + masked_fraction = 1.0 - (valid_tokens / total_tokens) + metrics_dict["tokens/masked_fraction"] = float(masked_fraction) + + metrics_dict["wall_clock/generate_s"] = float(wall_clock_s) + + # Collect raw data for logging/debugging + errors = [state.get("error") for state in all_states] + completions = [state.get("completion") for state in all_states] + prompts = [state.get("prompt") for state in all_states] + + # ==== Step 5: Build microbatches for distributed training ==== + # Split training examples across GPU processes, then into microbatches + # + + N = len(advantages) # Total training examples (trajectory steps) + per_proc = N // self.num_processes if self.num_processes > 0 else N + microbatches: list[list[Microbatch]] = [] + items_per_process: list[int] = [] + + for proc in range(self.num_processes): + # Index range for this process + ps = proc * per_proc # process start + pe = ps + per_proc # process end + + proc_mbs: list[Microbatch] = [] + proc_item_total = 0 + + # Split process's examples into microbatches + for s in range(ps, pe, self.micro_batch_size): + e = min(s + self.micro_batch_size, pe) + + # Combine prompt + completion into single sequence for training + ids_chunk = [prompt_ids[i] + completion_ids[i] for i in range(s, e)] + mask_chunk = [prompt_mask[i] + completion_mask[i] for i in range(s, e)] + + # Log probs: zeros for prompt (no loss), actual for completion + logprobs_chunk = [ + [0.0] * len(prompt_mask[i]) + completion_logprobs[i] + for i in range(s, e) + ] + + # Expand scalar advantage to per-token (same value repeated) + lengths = [len(mask) for mask in mask_chunk] + adv_chunk = [ + [advantages[i]] * lengths[idx] + for idx, i in enumerate(range(s, e)) + ] + + # Count valid (non-masked) tokens for normalization + mb_items = sum(sum(mask) for mask in mask_chunk) + + microbatch = Microbatch( + input_ids=ids_chunk, + loss_mask=mask_chunk, + sampling_logprobs=logprobs_chunk, + advantages=adv_chunk, + items=mb_items, + ) + proc_item_total += mb_items + proc_mbs.append(microbatch) + + microbatches.append(proc_mbs) + items_per_process.append(proc_item_total) + + global_item_count = sum(items_per_process) + + 
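+ # Worked example of the packing above (illustrative): N=32 trajectory steps
+ # with num_processes=4 gives per_proc=8; micro_batch_size=2 then yields 4
+ # microbatches of 2 sequences per process. Any remainder steps
+ # (N % num_processes) are not assigned to any process and are dropped.
+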
# ==== Return complete batch ==== + return Batch( + batch_id=batch_id, + microbatches=microbatches, + items_per_process=items_per_process, + global_item_count=global_item_count, + generation_time=wall_clock_s, + rewards_dict=rewards_dict, + completions=completions, # For logging + prompts=prompts, # For logging + errors=errors, # For debugging + metrics_dict=metrics_dict, + ) diff --git a/verifiers/rubrics/multiagent_rubric.py b/verifiers/rubrics/multiagent_rubric.py new file mode 100644 index 000000000..a77d97ada --- /dev/null +++ b/verifiers/rubrics/multiagent_rubric.py @@ -0,0 +1,270 @@ +""" +Multi-agent rubric with per-actor rewards and hierarchical credit assignment. + +Extends Rubric with: +- Per-actor reward functions (different rewards for different actors) +- Per-actor GRPO advantages (within-actor normalization) +- Bottom-up hierarchical scoring (children scored first, parents use child results) +""" + +from __future__ import annotations + +import asyncio +from collections import defaultdict +from typing import TYPE_CHECKING, AsyncContextManager, Awaitable, Callable + +import verifiers as vf +from verifiers.rubrics.rubric import Rubric +from verifiers.types import RewardFunc, State + +if TYPE_CHECKING: + from verifiers.envs.protocol import GenerateResult + +# Type alias for hierarchical reward functions +HierarchicalRewardFunc = Callable[["GenerateResult", list["GenerateResult"]], Awaitable[float] | float] + + +class MultiAgentRubric(Rubric): + """ + Rubric with per-actor rewards and hierarchical credit assignment. + + GRPO advantages are computed within actor groups (solver vs solver), + not across all actors, preventing unfair comparisons. + """ + + def __init__( + self, + funcs: list[RewardFunc] | None = None, + weights: list[float] | None = None, + parser: vf.Parser | None = None, + ): + super().__init__(funcs=funcs, weights=weights, parser=parser) + + # Per-actor reward functions: actor_id -> [(func, weight), ...] 
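+ # e.g. after add_actor_reward_func("player1", p1_score) this holds
+ # {"player1": [(p1_score, 1.0)]} (illustrative; p1_score is a placeholder func)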
+ self.actor_reward_funcs: dict[str, list[tuple[RewardFunc, float]]] = defaultdict(list) + + # Hierarchical reward functions for parent-child relationships + self.hierarchical_reward_funcs: list[tuple[HierarchicalRewardFunc, float]] = [] + + def add_actor_reward_func( + self, + actor_id: str, + func: RewardFunc, + weight: float = 1.0, + ) -> None: + """Add a reward function specific to an actor.""" + self.actor_reward_funcs[actor_id].append((func, weight)) + + def add_actor_metric( + self, + actor_id: str, + func: RewardFunc, + ) -> None: + """Add a metric (zero-weight reward) for logging without affecting reward.""" + self.add_actor_reward_func(actor_id, func, weight=0.0) + + def add_hierarchical_reward_func( + self, + func: HierarchicalRewardFunc, + weight: float = 1.0, + ) -> None: + """Add a function that computes parent reward based on child results.""" + self.hierarchical_reward_funcs.append((func, weight)) + + + def get_actor_id_from_state(self, state: State) -> str | None: + """Extract actor ID from state (checks extras, actor_history, trajectory).""" + # Check extras first (primary location) + extras = state.get("extras", {}) + if "current_actor_id" in extras: + return extras["current_actor_id"] + + # Check actor_history (take first actor if available) + actor_history = extras.get("actor_history", []) + if actor_history: + return actor_history[0][0] + + # Check trajectory steps + trajectory = state.get("trajectory", []) + for step in trajectory: + step_extras = step.get("extras", {}) + if "actor_id" in step_extras: + return step_extras["actor_id"] + + return None + + async def _compute_actor_reward( + self, + state: State, + actor_id: str, + score_sem: AsyncContextManager, + ) -> tuple[float, dict[str, float]]: + """Compute reward using actor-specific + global reward functions. + + Evaluates only the reward functions registered for this actor, plus any + global reward functions from the base Rubric; each raw score is recorded + in the metrics dict, and their weighted sum is returned as the reward.
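+
+ Example (illustrative): for actor_id="player1" with one registered func
+ named p1_win scoring 1.0 at weight 1.0 and no global funcs, this returns
+ (1.0, {"p1_win": 1.0}).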
+ """ + total_reward = 0.0 + metrics: dict[str, float] = {} + + # Only compute rewards for the current actor (not all actors) + actor_funcs = self.actor_reward_funcs.get(actor_id, []) + for func, weight in actor_funcs: + try: + score = await self._call_individual_reward_func(func, state, score_sem) + score = score if score is not None else 0.0 # Handle None + metrics[func.__name__] = score + total_reward += score * weight + except Exception as e: + self.logger.error(f"Error in actor reward func {func.__name__}: {e}") + metrics[func.__name__] = 0.0 + + # Also compute global reward functions + for func, weight in zip(self.funcs, self.weights): + if not self._is_group_func(func): + try: + score = await self._call_individual_reward_func(func, state, score_sem) + score = score if score is not None else 0.0 # Handle None + total_reward += score * weight + metrics[func.__name__] = score + except Exception as e: + self.logger.error(f"Error in global reward func {func.__name__}: {e}") + metrics[func.__name__] = 0.0 + + return total_reward, metrics + + async def score_group( + self, + states: list[State], + score_sem: AsyncContextManager, + ) -> None: + """Score with per-actor GRPO advantages (solver vs solver, not vs proposer).""" + if not states: + self.logger.warning("No states to score") + return + + # Extract actor_ids once (cache for reuse) + actor_ids = [self.get_actor_id_from_state(s) or "default" for s in states] + + # Compute individual rewards in parallel + reward_tasks = [ + self._compute_actor_reward(state, actor_id, score_sem) + for state, actor_id in zip(states, actor_ids) + ] + results = await asyncio.gather(*reward_tasks) + + # Apply rewards and group by actor + actor_groups: dict[str, list[State]] = defaultdict(list) + for state, actor_id, (reward, metrics) in zip(states, actor_ids, results): + state["reward"] = reward + state["metrics"] = metrics + actor_groups[actor_id].append(state) + + # Compute GRPO advantages per-actor group + for actor_id, actor_states in actor_groups.items(): + # Compute mean reward for this actor group + actor_rewards = [s["reward"] for s in actor_states] + mean_reward = sum(actor_rewards) / len(actor_rewards) + + # Compute advantages relative to actor mean + for state in actor_states: + advantage = state["reward"] - mean_reward + state["advantage"] = advantage + + # Propagate to trajectory steps + for step in state.get("trajectory", []): + if step.get("advantage") is None: + step["advantage"] = advantage + if step.get("reward") is None: + step["reward"] = state["reward"] + + async def score_hierarchical( + self, + root_results: list["GenerateResult"], + score_sem: AsyncContextManager, + ) -> None: + """Score bottom-up: children first, then parents using hierarchical reward funcs.""" + from verifiers.envs.protocol import GenerateResult + + # Collect all results in tree order (children before parents) + all_results: list[GenerateResult] = [] + for root in root_results: + # Depth-first traversal, children first + all_results.extend(self._collect_bottom_up(root)) + + # Score each result + for result in all_results: + state = result.state + + # Get base reward from actor-specific scoring + actor_id = result.actor_id + reward, metrics = await self._compute_actor_reward(state, actor_id, score_sem) + + # Add hierarchical rewards based on children + if result.children and self.hierarchical_reward_funcs: + from verifiers.utils.async_utils import maybe_await + for func, weight in self.hierarchical_reward_funcs: + try: + hier_score = await maybe_await(func, result, 
result.children) + reward += hier_score * weight + metrics[f"hierarchical/{func.__name__}"] = hier_score + except Exception as e: + self.logger.error(f"Error in hierarchical func {func.__name__}: {e}") + + state["reward"] = reward + state["metrics"] = metrics + + # Compute per-actor, per-level advantages + await self._compute_hierarchical_advantages(root_results) + + def _collect_bottom_up(self, result: "GenerateResult") -> list["GenerateResult"]: + """Collect results in bottom-up order (children before parents).""" + collected = [] + for child in result.children: + collected.extend(self._collect_bottom_up(child)) + collected.append(result) + return collected + + async def _compute_hierarchical_advantages( + self, + root_results: list["GenerateResult"], + ) -> None: + """Compute GRPO advantages grouped by (actor_id, depth).""" + from verifiers.envs.protocol import GenerateResult + + # Group by (actor_id, depth) + groups: dict[tuple[str, int], list[GenerateResult]] = defaultdict(list) + + def collect_with_depth(result: GenerateResult, depth: int) -> None: + groups[(result.actor_id, depth)].append(result) + for child in result.children: + collect_with_depth(child, depth + 1) + + for root in root_results: + collect_with_depth(root, 0) + + # Compute advantages per group + for (actor_id, depth), group_results in groups.items(): + rewards = [r.state.get("reward", 0.0) for r in group_results] + mean_reward = sum(rewards) / len(rewards) + + for result in group_results: + advantage = result.state.get("reward", 0.0) - mean_reward + result.state["advantage"] = advantage + + # Propagate to trajectory + for step in result.state.get("trajectory", []): + if step.get("advantage") is None: + step["advantage"] = advantage + if step.get("reward") is None: + step["reward"] = result.state.get("reward", 0.0) + + def get_trainable_results( + self, + root_results: list["GenerateResult"], + ) -> list["GenerateResult"]: + """Extract only trainable results (is_trainable=True) from result trees.""" + trainable = [] + for root in root_results: + trainable.extend(root.get_trainable()) + return trainable diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 8035967c6..5d14beeae 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -562,6 +562,10 @@ def make_dataset(results: GenerateOutputs, **kwargs) -> Dataset: v = results["metrics"][k] results_dict[k] = v + # Add actor_id column for multi-agent environments + if "actor_id" in results and results["actor_id"]: + results_dict["actor_id"] = results["actor_id"] + # Add selected state columns if specified state_columns = results["metadata"]["state_columns"] if state_columns: