diff --git a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb index e19df898d54..c940ae34761 100644 --- a/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb +++ b/examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb @@ -196,7 +196,8 @@ "\n", "max_steps=10\n", "\n", - "def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:\n", + "async def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:\n", + " \"\"\"Async rollout function - TRL handles the event loop automatically.\"\"\"\n", " episode_prompt_ids: list[list[int]] = []\n", " episode_completion_ids: list[list[int]] = []\n", " episode_logprobs: list[list[float]] = []\n", @@ -206,7 +207,7 @@ "\n", " for i, prompt_text in enumerate(prompts):\n", " print(f\"[DEBUG] Processing prompt {i + 1}/{len(prompts)}\")\n", - " episode = rollout_once(\n", + " episode = await rollout_once(\n", " trainer=trainer,\n", " env=client,\n", " tokenizer=trainer.processing_class,\n", @@ -261,7 +262,7 @@ "from browsergym_env import BrowserGymAction\n", "from transformers import AutoTokenizer\n", "\n", - "def rollout_once(\n", + "async def rollout_once(\n", " trainer: GRPOTrainer,\n", " env: BrowserGymEnv,\n", " tokenizer: AutoTokenizer,\n", @@ -269,7 +270,7 @@ " max_steps: int,\n", ") -> dict[str, list]:\n", " \"\"\"Run one episode and collect training data (text-only, no screenshots).\"\"\"\n", - " result = env.reset()\n", + " result = await env.reset()\n", " observation = result.observation\n", "\n", " prompt_ids: list[int] = []\n", @@ -314,7 +315,7 @@ " print(f\"Step {step_num + 1}: {action_str}\")\n", "\n", " # Take action in environment\n", - " result = env.step(BrowserGymAction(action_str=action_str))\n", + " result = await env.step(BrowserGymAction(action_str=action_str))\n", " observation = result.observation\n", "\n", " # Track rewards\n", @@ -546,7 +547,7 @@ }, "outputs": [ { - "name": 
"stderr", + "name": "stdout", "output_type": "stream", "text": [ "/tmp/ipython-input-3830121904.py:1: UserWarning: You are importing from 'rollout_func', which is an experimental feature. This API may change or be removed at any time without prior notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.\n", @@ -570,7 +571,7 @@ "output_type": "display_data" }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ "Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 4/4 [00:00<00:00, 19.64it/s]\n" @@ -596,7 +597,7 @@ }, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': 2, 'pad_token_id': 0}.\n" @@ -678,7 +679,7 @@ ] }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ "WARNING:liger_kernel.transformers.model.gemma3:It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('', attn_implementation='eager')`.\n", @@ -1608,7 +1609,7 @@ "output_type": "display_data" }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ "No files have been modified since last commit. Skipping to prevent empty commit.\n", @@ -1700,7 +1701,7 @@ "output_type": "display_data" }, { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ "No files have been modified since last commit. 
Skipping to prevent empty commit.\n", @@ -1716,7 +1717,7 @@ "CommitInfo(commit_url='https://huggingface.co/sergiopaniego/browsergym-grpo-functiongemma-270m-it/commit/a17de133c28ca7fddfcb2694c32f2791de5ddbe6', commit_message='End of training', commit_description='', oid='a17de133c28ca7fddfcb2694c32f2791de5ddbe6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sergiopaniego/browsergym-grpo-functiongemma-270m-it', endpoint='https://huggingface.co', repo_type='model', repo_id='sergiopaniego/browsergym-grpo-functiongemma-270m-it'), pr_revision=None, pr_num=None)" ] }, - "execution_count": 12, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/notebooks/openenv_sudoku_grpo.ipynb b/examples/notebooks/openenv_sudoku_grpo.ipynb index 7c2120a8d23..53bff822ff3 100644 --- a/examples/notebooks/openenv_sudoku_grpo.ipynb +++ b/examples/notebooks/openenv_sudoku_grpo.ipynb @@ -1,2577 +1,2578 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "lSR2nwdJg962" - }, - "source": [ - "# OpenEnv Sudoku with GRPO using TRL\n", - "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_sudoku_grpo.ipynb)\n", - "\n", - "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n", - "\n", - "With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can train a model that learns to **play Sudoku**, through interaction and reinforcement.\n", - "\n", - "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! 
\n", - "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", - "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n", - "- [OpenEnv](https://github.com/meta-pytorch/OpenEnv)\n", - "\n", - "An **agentic environment** is a setting where a model can take actions, observe outcomes, and adjust its behavior based on feedback, similar to how humans learn from trial and error.\n", - "In this case, the agent interacts with the **Sudoku** environment through the [**OpenEnv**](https://github.com/meta-pytorch/OpenEnv) framework, which standardizes multi-agent and RL-style text environments.\n", - "\n", - "Sudoku is a classic logic-based puzzle where the objective is to fill a **9×9 grid** so that. Each **row**, **column**, and **3×3 subgrid** contains all digits from **1 to 9** exactly once.\n", - "\n", - "This structured yet challenging setup makes Sudoku an excellent benchmark for reasoning and decision-making tasks.\n", - "\n", - "We'll fine-tune a model using **GRPO** (Group Relative Policy Optimization) via TRL.\n", - "The training loop follows these steps:\n", - "\n", - "1. The agent **generates guesses** based on the current game state.\n", - "2. The environment **evaluates the guess** and returns structured feedback.\n", - "3. The agent **updates its policy** using reward signals to improve future decisions.\n", - "\n", - "Over time, the model learns to make increasingly valid and efficient Sudoku moves.\n", - "\n", - "## Install dependencies\n", - "\n", - "We'll start by installing **TRL**, which automatically includes the main dependencies like **Transformers**. \n", - "We'll also install the **OpenEnv** framework (for the environment) via the HF Space we will use as environment server ([openenv/sudoku](https://huggingface.co/spaces/openenv/sudoku)), **trackio** (for logging and monitoring training runs), and **vLLM** (for efficient generation)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mHmE7GhRKyJj" - }, - "outputs": [], - "source": [ - "!pip install -Uq trl[vllm] trackio git+https://huggingface.co/spaces/openenv/sudoku liger-kernel" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Inxeq6ZGpRno" - }, - "source": [ - "### Log in to Hugging Face\n", - "\n", - "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JRd5fGR-KyJk" - }, - "outputs": [], - "source": [ - "from huggingface_hub import notebook_login\n", - "\n", - "notebook_login()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O3kr38TGm_hb" - }, - "source": [ - "## Initialize the OpenEnv TextArena Environment\n", - "\n", - "Let's begin by setting up the environment that will be used throughout training.\n", - "\n", - "For this example, we will use the **TextArena** environment provided by **OpenEnv**, which exposes a familiar **Gymnasium-style API** (`reset()`, `step()`, etc.) to simplify interaction and integration with reinforcement learning pipelines.\n", - "\n", - "Specifically, we will connect to a **remote TextArena instance** that hosts a **Sudoku environment**, available at [openenv/sudoku](https://huggingface.co/spaces/openenv/sudoku).\n", - "\n", - "This setup allows us to interact with the environment without needing to run the backend locally.\n", - "\n", - "> ⚠️ **Note:** Hosted environments on the Hugging Face Hub have limited concurrency. 
\n", - "> For improved stability, higher throughput, or parallel experiments, it is recommended to **duplicate the Space into your own account**.\n", - "\n", - "For more information, refer to the [TRL-OpenEnv documentation](https://huggingface.co/docs/trl/main/en/openenv).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "P6O03louKyJk" - }, - "outputs": [], - "source": [ - "from textarena_env import TextArenaEnv\n", - "\n", - "space_url = \"https://openenv-sudoku.hf.space\"\n", - "client = TextArenaEnv(base_url=space_url)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EqfDavDQnD_5" - }, - "source": [ - "## Create Rollout Function with Helpers\n", - "\n", - "The **rollout function** defines how the agent interacts with the environment during GRPO training.\n", - "It is responsible for generating model outputs, collecting feedback (rewards), and returning all the information needed for policy optimization.\n", - "\n", - "In this setup:\n", - "- The function is called automatically by the **GRPOTrainer** at each training step.\n", - "- It uses the trainer's `generate_rollout_completions()` method for efficient generation with **vLLM** in colocate mode.\n", - "- Each rollout represents a full interaction loop: the model makes guesses, receives feedback from the Sudoku environment, and updates its policy based on reward signals.\n", - "\n", - "Rewards track different aspects of the agent's performance, while helper functions like `rollout_once` handle a single episode of interaction, keeping the main `rollout_func` clean and modular.\n", - "\n", - "This modular approach allows GRPO to efficiently sample, evaluate, and improve the model's guessing strategy through reinforcement learning.\n", - "\n", - "First, we define the `system_prompt` that guides the model's behavior as an expert Sudoku solver with strategic reasoning and structured responses." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pi1JGoUBKyJk" - }, - "outputs": [], - "source": [ - "# @title System prompt (click to expand)\n", - "SYSTEM_PROMPT = \"\"\"You are an expert Sudoku player with deep knowledge of logical deduction strategies and number placement techniques.\n", - "\n", - "## GAME RULES\n", - "\n", - "1. The puzzle is a 9x9 grid divided into nine 3x3 subgrids (boxes)\n", - "2. Some cells are pre-filled with numbers 1-9\n", - "3. You must fill in the empty cells (shown as '.') with numbers 1-9\n", - "4. Each row must contain numbers 1-9 without repetition\n", - "5. Each column must contain numbers 1-9 without repetition\n", - "6. Each 3x3 subgrid must contain numbers 1-9 without repetition\n", - "7. You cannot overwrite pre-filled cells\n", - "8. Invalid moves result in penalties (-1 reward)\n", - "\n", - "## RESPONSE FORMAT\n", - "\n", - "**CRITICAL: Output ONLY the move, nothing else. No text, no explanation.**\n", - "\n", - "Format: [row col number]\n", - "\n", - "Examples:\n", - "- [5 3 7] → places 7 in row 5, column 3\n", - "- [1 2 4] → places 4 in row 1, column 2\n", - "\n", - "## STRATEGIC APPROACH\n", - "\n", - "Do not repeat the same move twice.\n", - "\n", - "### Basic Strategies\n", - "- **Naked Singles**: If a cell has only one possible candidate, fill it in immediately.\n", - "- **Hidden Singles**: If a number can only go in one cell within a row, column, or box, place it there.\n", - "- **Scanning**: Look at each row, column, and box to find where specific numbers can go.\n", - "\n", - "### Intermediate Strategies\n", - "- **Naked Pairs/Triples**: When two/three cells in a unit contain only the same candidates, eliminate those from other cells.\n", - "- **Hidden Pairs/Triples**: When numbers only appear in specific cells within a unit, those cells can only contain those numbers.\n", - "- **Pointing Pairs**: When a candidate in a box is restricted to a single row/column, eliminate it 
elsewhere.\n", - "\n", - "### Solving Process\n", - "1. Start by scanning the entire grid to identify easy fills (cells with few candidates)\n", - "2. Look for rows, columns, or boxes with many numbers already placed\n", - "3. Fill all naked singles first\n", - "4. Then look for hidden singles in each row, column, and box\n", - "5. Apply more advanced techniques as needed\n", - "\n", - "### Common Pitfalls to Avoid\n", - "- Don't guess randomly - Sudoku is pure logic\n", - "- Don't overlook any constraint (row, column, or box)\n", - "- Don't try to overwrite pre-filled cells\n", - "- Don't place invalid numbers (must be 1-9)\n", - "- Don't use invalid coordinates (must be 1-9)\n", - "- Don't repeat a move that was already made\n", - "\n", - "## EXAMPLES\n", - "\n", - "### Example 1: Naked Single\n", - "If row 3, column 4 can only contain the number 5:\n", - "[3 4 5]\n", - "\n", - "### Example 2: Hidden Single\n", - "If the number 8 can only go in one cell in row 1:\n", - "[1 7 8]\n", - "\n", - "### Example 3: Row Analysis\n", - "Row 2 is missing only value 5, and column 8 is the empty cell:\n", - "[2 8 5]\n", - "\n", - "### Example 4: Box Analysis\n", - "In the center box, only one cell can contain 9:\n", - "[5 5 9]\n", - "\n", - "## BOARD READING\n", - "\n", - "The board is displayed as a 9x9 grid:\n", - "- Numbers 1-9 are pre-filled or already placed\n", - "- Empty cells are shown as '.'\n", - "- Rows are labeled R1-R9 (top to bottom)\n", - "- Columns are labeled C1-C9 (left to right)\n", - "\n", - "Example board representation:\n", - "```\n", - " C1 C2 C3 C4 C5 C6 C7 C8 C9\n", - "R1 . 8 9 | 1 . . | . 3 7\n", - "R2 2 7 1 | 9 4 3 | 6 . 8\n", - "R3 . 6 5 | . 2 7 | 4 9 .\n", - " - - - - - - - - - - - - - - - -\n", - "R4 . . . | 7 8 . | 9 2 3\n", - "R5 . 9 2 | . 5 6 | . . 4\n", - "R6 7 3 8 | . . 2 | 1 . .\n", - " - - - - - - - - - - - - - - - -\n", - "R7 8 4 . | . . 9 | 5 . .\n", - "R8 5 . . | 6 . 8 | 3 4 9\n", - "R9 9 . 
6 | 5 3 4 | 8 7 2\n", - "```\n", - "\n", - "## COORDINATE REFERENCE\n", - "\n", - "Row indices (top to bottom): 1, 2, 3, 4, 5, 6, 7, 8, 9\n", - "Column indices (left to right): 1, 2, 3, 4, 5, 6, 7, 8, 9\n", - "\n", - "Subgrid layout:\n", - "```\n", - "Subgrid 1 | Subgrid 2 | Subgrid 3\n", - " (R1-R3) (R1-R3) (R1-R3)\n", - " (C1-C3) (C4-C6) (C7-C9)\n", - "----------+-----------+----------\n", - "Subgrid 4 | Subgrid 5 | Subgrid 6\n", - " (R4-R6) (R4-R6) (R4-R6)\n", - " (C1-C3) (C4-C6) (C7-C9)\n", - "----------+-----------+----------\n", - "Subgrid 7 | Subgrid 8 | Subgrid 9\n", - " (R7-R9) (R7-R9) (R7-R9)\n", - " (C1-C3) (C4-C6) (C7-C9)\n", - "```\n", - "\n", - "## IMPORTANT CONSTRAINTS\n", - "\n", - "- Coordinates are 1-indexed (1-9 for both row and column)\n", - "- Numbers must be 1-9\n", - "- One move per response\n", - "- Must be a valid move (no rule violations)\n", - "- Never repeat a previous move\n", - "\n", - "## YOUR GOAL\n", - "\n", - "Output ONLY your move in the format [row col number]. No explanation, no reasoning, just the move.\n", - "\"\"\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Vi1rFey39GUl" - }, - "source": [ - "Now, let's define the `rollout_func`.\n", - "\n", - "This function manages the interaction between the model and the Sudoku environment. \n", - "For each prompt in the batch, it runs a full episode, collecting both the model's outputs and the corresponding rewards. These results are then used by GRPO to optimize the agent's policy.\n", - "\n", - "Each game allows the model to make **up to 100 turns**, giving it multiple chances to solve the puzzle.\n", - "We have different difficulty levels available: `'easy'`, `'medium'`, and `'hard'`. The level affects the amount of information provided in the prompt. 
Higher difficulties give less guidance.\n", - "\n", - "For the **easy** level, the Qwen/Qwen3-1.7B model is sufficient to solve the puzzles efficiently in a Colab notebook.\n", - "For **medium** or **hard** levels, a larger or more advanced model would likely be needed." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "wMQQoQ_UKyJl" - }, - "outputs": [], - "source": [ - "from trl import GRPOTrainer\n", - "\n", - "max_turns = 100\n", - "debug = False # Activate for detailed logs during training\n", - "difficulty=\"easy\"\n", - "\n", - "def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:\n", - " all_prompt_ids = []\n", - " all_completion_ids = []\n", - " all_logprobs = []\n", - " all_correct = []\n", - " all_valid = []\n", - " all_empty_cell = []\n", - " all_repetition = []\n", - " all_progress = []\n", - "\n", - " for _ in prompts:\n", - " episode = rollout_once(\n", - " trainer=trainer,\n", - " env=client,\n", - " tokenizer=trainer.processing_class,\n", - " system_prompt=SYSTEM_PROMPT,\n", - " max_turns=max_turns,\n", - " debug=debug,\n", - " difficulty=difficulty,\n", - " )\n", - " all_prompt_ids.append(episode[\"prompt_ids\"])\n", - " all_completion_ids.append(episode[\"completion_ids\"])\n", - " all_logprobs.append(episode[\"logprobs\"])\n", - " all_correct.append(episode[\"correct_reward\"])\n", - " all_valid.append(episode[\"valid_move_reward\"])\n", - " all_empty_cell.append(episode[\"empty_cell_reward\"])\n", - " all_repetition.append(episode[\"repetition_reward\"])\n", - " all_progress.append(episode[\"progress_reward\"])\n", - "\n", - " return {\n", - " \"prompt_ids\": all_prompt_ids,\n", - " \"completion_ids\": all_completion_ids,\n", - " \"logprobs\": all_logprobs,\n", - " \"correct_reward\": all_correct,\n", - " \"valid_move_reward\": all_valid,\n", - " \"empty_cell_reward\": all_empty_cell,\n", - " \"repetition_reward\": all_repetition,\n", - " \"progress_reward\": all_progress,\n", - " 
}\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ioUHdIxr9ZQO" - }, - "source": [ - "### Define `rollout_once`\n", - "\n", - "The `rollout_once` function runs **a single interaction loop** between the model and the Sudoku environment using the trainer's generation method. \n", - "It executes one mini-episode, from generating a guess to receiving and processing feedback.\n", - "\n", - "Step-by-step:\n", - "\n", - "1. **Environment reset:** Start a new game session and initialize the observation.\n", - "2. **Prompt construction:** Combine the system prompt, current state, and user messages to form the model input.\n", - "3. **Generation:** Use `trl.experimental.openenv.generate_rollout_completions()` to efficiently produce the model's guess.\n", - "4. **Feedback extraction:** Parse the environment's response with helpers like `extract_sudoku_move()` and `extract_feedback()`.\n", - "5. **Reward calculation:** Compute rewards based on correctness, valid moves, empty cell moves, repeated moves, and progress.\n", - "6. **Return structured rollout data:** Includes prompt and completion IDs, log probabilities, and all reward components.\n", - "\n", - "This design allows each episode to be processed independently while providing detailed feedback for the **GRPO training loop**." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "AZim6XzEKyJl" - }, - "outputs": [], - "source": [ - "from trl.experimental.openenv import generate_rollout_completions\n", - "from textarena_env import TextArenaAction\n", - "from transformers import AutoTokenizer\n", - "from collections import defaultdict\n", - "\n", - "\n", - "def rollout_once(\n", - " trainer: GRPOTrainer,\n", - " env: TextArenaEnv,\n", - " tokenizer: AutoTokenizer,\n", - " system_prompt: str,\n", - " max_turns: int,\n", - " debug: bool = False,\n", - " difficulty: str = \"hard\",\n", - ") -> dict[str, list]:\n", - " result = env.reset()\n", - " observation = result.observation\n", - "\n", - " # Only store the LAST turn for backprop (much more efficient!)\n", - " last_turn_data: dict | None = None\n", - "\n", - " valid_move_scores: list[float] = []\n", - " empty_cell_scores: list[float] = []\n", - " correct_scores: list[float] = []\n", - " repetition_scores: list[float] = []\n", - "\n", - " move_counts: defaultdict[str, int] = defaultdict(int)\n", - "\n", - " # Track successful and failed moves for summary\n", - " successful_moves: list[str] = []\n", - " failed_moves: list[str] = []\n", - "\n", - " # Extract initial board state\n", - " last_board_state = \"\"\n", - " initial_filled = 0\n", - " for message in observation.messages:\n", - " if message.content and is_valid_board_state(message.content):\n", - " last_board_state = message.content\n", - " initial_filled = count_filled_cells(last_board_state)\n", - " break\n", - "\n", - " max_filled = initial_filled # Track max progress\n", - "\n", - " for turn in range(max_turns):\n", - " if result.done:\n", - " break\n", - "\n", - " # Build COMPACT prompt (saves tokens!)\n", - " user_prompt = make_compact_prompt(\n", - " board=last_board_state,\n", - " step=turn + 1,\n", - " successful_moves=successful_moves,\n", - " failed_moves=failed_moves,\n", - " difficulty=difficulty,\n", - " )\n", - " messages = [\n", - " 
{\"role\": \"system\", \"content\": system_prompt},\n", - " {\"role\": \"user\", \"content\": user_prompt},\n", - " ]\n", - " prompt_text = tokenizer.apply_chat_template(\n", - " messages, add_generation_prompt=True, tokenize=False, enable_thinking=False # `enable_thinking` is usable for the current model but could need to be updated for other models\n", - " )\n", - "\n", - " if debug:\n", - " print(f\"\\n{'=' * 60}\")\n", - " print(f\"STEP {turn + 1}\")\n", - " print(f\"{'=' * 60}\")\n", - " print(f\"USER PROMPT:\\n{user_prompt}\")\n", - " print(f\"{'=' * 60}\")\n", - "\n", - " # Generate\n", - " rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]\n", - "\n", - " # Store ONLY this turn's data (replace previous)\n", - " last_turn_data = {\n", - " \"prompt_ids\": rollout_outputs[\"prompt_ids\"],\n", - " \"completion_ids\": rollout_outputs[\"completion_ids\"],\n", - " \"logprobs\": rollout_outputs[\"logprobs\"],\n", - " }\n", - "\n", - " if debug:\n", - " step_tokens = len(rollout_outputs[\"prompt_ids\"]) + len(rollout_outputs[\"completion_ids\"])\n", - " print(f\"TOKENS: this_step={step_tokens} (only last turn used for backprop)\")\n", - "\n", - " completion_text = rollout_outputs.get(\"text\") or tokenizer.decode(\n", - " rollout_outputs[\"completion_ids\"], skip_special_tokens=True\n", - " )\n", - "\n", - " # Extract move\n", - " move = extract_sudoku_move(completion_text)\n", - "\n", - " if debug:\n", - " print(f\"MODEL OUTPUT: {completion_text}\")\n", - " print(f\"EXTRACTED MOVE: {move}\")\n", - "\n", - " # Step environment\n", - " result = env.step(TextArenaAction(message=move))\n", - " observation = result.observation\n", - " correct_score = float(result.reward or 0.0)\n", - "\n", - " # Get feedback\n", - " feedback = extract_feedback(observation)\n", - "\n", - " # Get environment response\n", - " env_response = \"\"\n", - " for msg in observation.messages:\n", - " if msg.sender_id == -1: # Environment message\n", - " env_response = 
msg.content\n", - " break\n", - "\n", - " if debug:\n", - " print(\n", - " f\"ENV RESPONSE: {env_response[:200]}...\"\n", - " if len(env_response) > 200\n", - " else f\"ENV RESPONSE: {env_response}\"\n", - " )\n", - " print(f\"VALID: {feedback['valid_move']}, WARNING: {feedback['got_warning']}, REWARD: {correct_score}\")\n", - "\n", - " # Calculate empty_cell_score\n", - " if last_board_state and move:\n", - " targets_empty = check_move_targets_empty_cell(move, last_board_state)\n", - " empty_cell_score = 1.0 if targets_empty else -1.0\n", - " else:\n", - " empty_cell_score = 0.0\n", - "\n", - " # Calculate valid_move_score and repetition_score\n", - " is_new_move = move_counts[move] == 0\n", - " repetition_count = move_counts[move]\n", - " move_counts[move] += 1\n", - "\n", - " # Exponential penalty for repetitions: -2^(n-1) capped at -10\n", - " # 1st repeat: -1, 2nd: -2, 3rd: -4, 4th+: -10 (capped)\n", - " if repetition_count > 0:\n", - " repetition_score = -min(2 ** (repetition_count - 1), 10.0)\n", - " else:\n", - " repetition_score = 0.0\n", - "\n", - " if debug:\n", - " print(\n", - " f\"SCORES: empty_cell={empty_cell_score}, is_new={is_new_move}, repetitions={repetition_count}, rep_penalty={repetition_score}\"\n", - " )\n", - "\n", - " if not debug:\n", - " print(f\"Step {turn + 1}: {move}\")\n", - "\n", - " if feedback[\"valid_move\"] and is_new_move:\n", - " valid_move_score = 1.0\n", - " if move:\n", - " successful_moves.append(move) # Track for summary\n", - " elif feedback[\"got_warning\"]:\n", - " valid_move_score = -0.5\n", - " if move:\n", - " failed_moves.append(move) # Track for summary\n", - " else:\n", - " valid_move_score = 0.0\n", - "\n", - " # Update board state and track progress\n", - " if feedback[\"board_state\"] and is_valid_board_state(feedback[\"board_state\"]):\n", - " last_board_state = feedback[\"board_state\"]\n", - " current_filled = count_filled_cells(last_board_state)\n", - " if current_filled > max_filled:\n", - " max_filled = 
current_filled\n", - "\n", - " valid_move_scores.append(valid_move_score)\n", - " empty_cell_scores.append(empty_cell_score)\n", - " correct_scores.append(correct_score)\n", - " repetition_scores.append(repetition_score)\n", - "\n", - " # Aggregate rewards\n", - " correct_reward = correct_scores[-1] if correct_scores else 0.0\n", - " valid_move_reward = sum(valid_move_scores) / len(valid_move_scores) if valid_move_scores else 0.0\n", - " empty_cell_reward = sum(empty_cell_scores) / len(empty_cell_scores) if empty_cell_scores else 0.0\n", - " repetition_reward = sum(repetition_scores) / len(repetition_scores) if repetition_scores else 0.0\n", - "\n", - " # Progress reward: how many cells we filled beyond initial state (normalized to 0-1)\n", - " # 81 total cells, so (max_filled - initial_filled) / (81 - initial_filled) gives progress\n", - " remaining_to_fill = 81 - initial_filled\n", - " if remaining_to_fill > 0:\n", - " progress_reward = (max_filled - initial_filled) / remaining_to_fill\n", - " else:\n", - " progress_reward = 1.0 # Already complete\n", - "\n", - " # Use ONLY last turn for backpropagation (much more efficient!)\n", - " if last_turn_data:\n", - " prompt_ids = last_turn_data[\"prompt_ids\"]\n", - " completion_ids = last_turn_data[\"completion_ids\"]\n", - " logprobs = last_turn_data[\"logprobs\"]\n", - " else:\n", - " prompt_ids = []\n", - " completion_ids = []\n", - " logprobs = []\n", - "\n", - " total_tokens = len(prompt_ids) + len(completion_ids)\n", - " cells_filled = max_filled - initial_filled\n", - " print(\n", - " f\"Episode: empty_cell={empty_cell_reward:.2f}, valid={valid_move_reward:.2f}, \"\n", - " f\"repetition={repetition_reward:.2f}, progress={progress_reward:.2f} ({cells_filled} cells), \"\n", - " f\"correct={correct_reward:.2f}, tokens={total_tokens}\"\n", - " )\n", - "\n", - " return {\n", - " \"prompt_ids\": prompt_ids,\n", - " \"completion_ids\": completion_ids,\n", - " \"logprobs\": logprobs,\n", - " \"correct_reward\": 
correct_reward,\n", - " \"valid_move_reward\": valid_move_reward,\n", - " \"empty_cell_reward\": empty_cell_reward,\n", - " \"repetition_reward\": repetition_reward,\n", - " \"progress_reward\": progress_reward,\n", - " }" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MDJKMQ__8qzj" - }, - "source": [ - "### Helper Functions\n", - "\n", - "These utility functions are used within `rollout_once` to process the environment and model outputs:\n", - "\n", - "- **`extract_sudoku_move`**: Extract a Sudoku move `[row, col, number]` from text. \n", - "- **`is_valid_board_state`**: Check if a string represents a valid Sudoku board. \n", - "- **`parse_board`**: Convert a board string into a 9×9 grid (with `0` for empty cells). \n", - "- **`count_filled_cells`**: Count the number of filled cells in the board. \n", - "- **`get_valid_numbers`**: Get the valid numbers for a specific cell according to Sudoku rules. \n", - "- **`extract_empty_cells_with_candidates`**: Identify empty cells along with their valid candidate numbers. \n", - "- **`extract_empty_cells`**: List all empty cells `(row, col)` from a board string. \n", - "- **`extract_board_only`**: Extract just the Sudoku grid from a message. \n", - "- **`make_compact_prompt`**: Create a concise prompt with only essential information to save tokens. \n", - "- **`check_move_targets_empty_cell`**: Verify if a proposed move targets an empty cell on the board. \n", - "- **`extract_feedback`**: Extract structured feedback from the environment's observation." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0f9RqHh7KyJl" - }, - "outputs": [], - "source": [ - "# @title Helpers (click to expand)\n", - "import re\n", - "\n", - "def extract_sudoku_move(text: str) -> str:\n", - " \"\"\"Extract a Sudoku move [row col number] from text.\"\"\"\n", - " # Try with spaces\n", - " match = re.search(r\"\\[(\\d)\\s+(\\d)\\s+(\\d)\\]\", text)\n", - " if match:\n", - " row, col, num = match.groups()\n", - " return f\"[{row} {col} {num}]\"\n", - "\n", - " # Try without spaces\n", - " match = re.search(r\"\\[(\\d)(\\d)(\\d)\\]\", text)\n", - " if match:\n", - " row, col, num = match.groups()\n", - " return f\"[{row} {col} {num}]\"\n", - "\n", - " return \"\" # Handled by the environment: missing/invalid moves trigger a \"wrong movement\" message affecting rewards\n", - "\n", - "\n", - "def is_valid_board_state(board_str: str) -> bool:\n", - " \"\"\"Check if the string contains an actual Sudoku board.\"\"\"\n", - " return \"R1\" in board_str and \"R9\" in board_str and \"|\" in board_str\n", - "\n", - "\n", - "def parse_board(board_str: str) -> list[list[int]]:\n", - " \"\"\"Parse board string into 9x9 grid (0 = empty).\"\"\"\n", - " grid = [[0] * 9 for _ in range(9)]\n", - " if not is_valid_board_state(board_str):\n", - " return grid\n", - "\n", - " for line in board_str.split(\"\\n\"):\n", - " line_stripped = line.strip()\n", - " if line_stripped and line_stripped[0] == \"R\" and len(line_stripped) > 1 and line_stripped[1].isdigit():\n", - " row = int(line_stripped[1]) - 1 # 0-indexed\n", - " cell_part = line_stripped[2:]\n", - " col = 0\n", - " for char in cell_part:\n", - " if char == \".\":\n", - " grid[row][col] = 0\n", - " col += 1\n", - " elif char.isdigit():\n", - " grid[row][col] = int(char)\n", - " col += 1\n", - " return grid\n", - "\n", - "\n", - "def count_filled_cells(board_str: str) -> int:\n", - " \"\"\"Count the number of filled cells in the board.\"\"\"\n", - " if not 
is_valid_board_state(board_str):\n", - " return 0\n", - " grid = parse_board(board_str)\n", - " return sum(1 for row in grid for cell in row if cell != 0)\n", - "\n", - "\n", - "def get_valid_numbers(grid: list[list[int]], row: int, col: int) -> set[int]:\n", - " \"\"\"Get valid numbers for a cell based on Sudoku rules.\"\"\"\n", - " if grid[row][col] != 0:\n", - " return set()\n", - "\n", - " used = set()\n", - "\n", - " # Check row\n", - " for c in range(9):\n", - " if grid[row][c] != 0:\n", - " used.add(grid[row][c])\n", - "\n", - " # Check column\n", - " for r in range(9):\n", - " if grid[r][col] != 0:\n", - " used.add(grid[r][col])\n", - "\n", - " # Check 3x3 box\n", - " box_row, box_col = 3 * (row // 3), 3 * (col // 3)\n", - " for r in range(box_row, box_row + 3):\n", - " for c in range(box_col, box_col + 3):\n", - " if grid[r][c] != 0:\n", - " used.add(grid[r][c])\n", - "\n", - " return set(range(1, 10)) - used\n", - "\n", - "\n", - "def extract_empty_cells_with_candidates(\n", - " board_str: str, sort_by_difficulty: bool = True\n", - ") -> list[tuple[int, int, set[int]]]:\n", - " \"\"\"Extract empty cells with their valid candidate numbers.\n", - "\n", - " Args:\n", - " sort_by_difficulty: If True, sort by number of candidates (easiest first).\n", - " If False, keep natural order (top-left to bottom-right).\n", - " \"\"\"\n", - " grid = parse_board(board_str)\n", - " cells_with_candidates = []\n", - "\n", - " for row in range(9):\n", - " for col in range(9):\n", - " if grid[row][col] == 0:\n", - " candidates = get_valid_numbers(grid, row, col)\n", - " cells_with_candidates.append((row + 1, col + 1, candidates)) # 1-indexed\n", - "\n", - " if sort_by_difficulty:\n", - " # Sort by number of candidates (easiest first = naked singles)\n", - " cells_with_candidates.sort(key=lambda x: len(x[2]))\n", - "\n", - " return cells_with_candidates\n", - "\n", - "\n", - "def extract_empty_cells(board_str: str) -> list[tuple[int, int]]:\n", - " \"\"\"Extract list of empty 
cells (row, col) from board string.\"\"\"\n", - " empty_cells = []\n", - " if not is_valid_board_state(board_str):\n", - " return empty_cells\n", - "\n", - " for line in board_str.split(\"\\n\"):\n", - " line_stripped = line.strip()\n", - " if line_stripped and line_stripped[0] == \"R\" and len(line_stripped) > 1 and line_stripped[1].isdigit():\n", - " row = int(line_stripped[1])\n", - " cell_part = line_stripped[2:]\n", - " col = 0\n", - " for char in cell_part:\n", - " if char == \".\":\n", - " col += 1\n", - " empty_cells.append((row, col))\n", - " elif char.isdigit():\n", - " col += 1\n", - " return empty_cells\n", - "\n", - "\n", - "def extract_board_only(text: str) -> str:\n", - " \"\"\"Extract just the Sudoku grid from a message.\"\"\"\n", - " if not text:\n", - " return \"\"\n", - "\n", - " lines = text.split(\"\\n\")\n", - " board_lines = []\n", - " in_board = False\n", - "\n", - " for line in lines:\n", - " stripped = line.strip()\n", - " if stripped.startswith(\"C1\") or (\n", - " stripped and stripped[0] == \"R\" and len(stripped) > 1 and stripped[1].isdigit()\n", - " ):\n", - " in_board = True\n", - " if in_board and (stripped.startswith(\"-\") or stripped.startswith(\"R\") or stripped.startswith(\"C1\")):\n", - " board_lines.append(line)\n", - " elif (\n", - " in_board\n", - " and stripped\n", - " and not stripped.startswith(\"-\")\n", - " and not (stripped[0] == \"R\" and len(stripped) > 1 and stripped[1].isdigit())\n", - " ):\n", - " break\n", - "\n", - " return \"\\n\".join(board_lines) if board_lines else \"\"\n", - "\n", - "\n", - "def make_compact_prompt(\n", - " board: str,\n", - " step: int,\n", - " successful_moves: list[str],\n", - " failed_moves: list[str],\n", - " difficulty: str = \"hard\",\n", - ") -> str:\n", - " \"\"\"Create a compact prompt with only essential info (saves tokens!).\n", - "\n", - " Args:\n", - " difficulty: Training difficulty level:\n", - " - \"easy\": Show guaranteed moves (naked singles) + other options\n", - " - 
\"medium\": Only show other options (hints where to look, not exact answers)\n", - " - \"hard\": No hints (model must learn Sudoku rules by itself)\n", - " \"\"\"\n", - "\n", - " # Summary line\n", - " cells_filled = len(successful_moves)\n", - " summary = f\"Step {step}. Progress: {cells_filled} cells filled.\"\n", - "\n", - " # Board (only show the grid, stripped down)\n", - " board_only = extract_board_only(board) if board else \"No board available.\"\n", - "\n", - " # Moves already tried (for learning what NOT to do)\n", - " tried_moves_hint = \"\"\n", - " all_tried = successful_moves + failed_moves\n", - " if all_tried:\n", - " tried_moves_hint = f\"\\n\\n⚠️ MOVES ALREADY TRIED (do not repeat): {', '.join(all_tried)}\"\n", - "\n", - " # Hints based on difficulty\n", - " hints = \"\"\n", - " if difficulty == \"easy\" and board:\n", - " # Easy: sorted by difficulty, show guaranteed moves + other easy options\n", - " cells_with_candidates = extract_empty_cells_with_candidates(board, sort_by_difficulty=True)\n", - " if cells_with_candidates:\n", - " guaranteed = []\n", - " other_hints = []\n", - " for row, col, candidates in cells_with_candidates[:10]:\n", - " if len(candidates) == 1:\n", - " num = list(candidates)[0]\n", - " guaranteed.append(f\"[{row} {col} {num}]\")\n", - " elif len(candidates) <= 3:\n", - " nums = \",\".join(str(n) for n in sorted(candidates))\n", - " other_hints.append(f\"({row},{col})→{nums}\")\n", - "\n", - " if guaranteed:\n", - " hints = f\"\\n\\n🎯 GUARANTEED MOVES: {', '.join(guaranteed[:5])}\"\n", - " if other_hints:\n", - " hints += f\"\\nOther options: {' | '.join(other_hints[:5])}\"\n", - "\n", - " elif difficulty == \"medium\" and board:\n", - " # Medium: NOT sorted, just show empty cells with candidates (no ordering hints)\n", - " cells_with_candidates = extract_empty_cells_with_candidates(board, sort_by_difficulty=False)\n", - " if cells_with_candidates:\n", - " cell_hints = []\n", - " for row, col, candidates in 
cells_with_candidates[:10]:\n", - " nums = \",\".join(str(n) for n in sorted(candidates))\n", - " cell_hints.append(f\"({row},{col})→{nums}\")\n", - " if cell_hints:\n", - " hints = f\"\\n\\nEmpty cells: {' | '.join(cell_hints)}\"\n", - "\n", - " return f\"{summary}\\n\\nBoard:\\n{board_only}{tried_moves_hint}{hints}\\n\\nYour move:\"\n", - "\n", - "\n", - "def check_move_targets_empty_cell(move: str, board_str: str) -> bool:\n", - " \"\"\"Check if the move targets an empty cell on the board.\"\"\"\n", - " if not move or not board_str:\n", - " return False\n", - "\n", - " match = re.search(r\"\\[(\\d)\\s+(\\d)\\s+(\\d)\\]\", move)\n", - " if not match:\n", - " return False\n", - "\n", - " row, col = int(match.group(1)), int(match.group(2))\n", - " empty_cells = extract_empty_cells(board_str)\n", - " return (row, col) in empty_cells\n", - "\n", - "\n", - "def extract_feedback(observation) -> dict:\n", - " \"\"\"Extract feedback from environment observation.\"\"\"\n", - " feedback = {\"valid_move\": True, \"got_warning\": False, \"board_state\": \"\"}\n", - "\n", - " if not observation or not observation.messages:\n", - " return feedback\n", - "\n", - " for message in observation.messages:\n", - " content = message.content.lower() if message.content else \"\"\n", - "\n", - " if any(kw in content for kw in [\"invalid\", \"error\", \"cannot\", \"already\", \"violation\", \"lost\"]):\n", - " feedback[\"valid_move\"] = False\n", - " if \"please resubmit\" in content or \"avoid penalties\" in content:\n", - " feedback[\"got_warning\"] = True\n", - "\n", - " if message.content and \"|\" in message.content and \"R1\" in message.content:\n", - " feedback[\"board_state\"] = message.content\n", - "\n", - " return feedback" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Oek3JhcWnKhw" - }, - "source": [ - "## Define Reward Functions\n", - "\n", - "To guide the agent's learning, we define reward functions that convert the environment's feedback into numeric 
signals.\n", - "Each function captures a specific aspect of performance in the **Sudoku** game:\n", - "\n", - "- **`reward_empty_cell`**: Reward for targeting empty cells, encouraging the agent to pick valid positions first.\n", - "- **`reward_valid_moves`**: Reward for making moves that comply with Sudoku rules.\n", - "- **`reward_correct`**: Reward for correctly placing numbers, contributing to solving the puzzle.\n", - "- **`reward_repetition`**: Penalty for repeating moves in the same cell.\n", - "- **`reward_progress`**: Reward for filling more cells on the board, indicating overall progress." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "TPe4XL89KyJl" - }, - "outputs": [], - "source": [ - "def reward_empty_cell(completions: list[str], **kwargs) -> list[float]:\n", - " \"\"\"Reward for targeting empty cells (learn to pick valid positions first).\"\"\"\n", - " rewards = kwargs.get(\"empty_cell_reward\")\n", - " if rewards is None:\n", - " return [0.0 for _ in completions]\n", - " return [float(r) for r in rewards]\n", - "\n", - "\n", - "def reward_valid_moves(completions: list[str], **kwargs) -> list[float]:\n", - " \"\"\"Reward for making valid moves.\"\"\"\n", - " rewards = kwargs.get(\"valid_move_reward\")\n", - " if rewards is None:\n", - " return [0.0 for _ in completions]\n", - " return [float(r) for r in rewards]\n", - "\n", - "\n", - "def reward_correct(completions: list[str], **kwargs) -> list[float]:\n", - " \"\"\"Reward for solving the puzzle.\"\"\"\n", - " rewards = kwargs.get(\"correct_reward\")\n", - " if rewards is None:\n", - " return [0.0 for _ in completions]\n", - " return [float(r) for r in rewards]\n", - "\n", - "\n", - "def reward_repetition(completions: list[str], **kwargs) -> list[float]:\n", - " \"\"\"Penalty for repeating moves.\"\"\"\n", - " rewards = kwargs.get(\"repetition_reward\")\n", - " if rewards is None:\n", - " return [0.0 for _ in completions]\n", - " return [float(r) for r in 
rewards]\n", - "\n", - "\n", - "def reward_progress(completions: list[str], **kwargs) -> list[float]:\n", - " \"\"\"Reward for filling more cells in the board.\"\"\"\n", - " rewards = kwargs.get(\"progress_reward\")\n", - " if rewards is None:\n", - " return [0.0 for _ in completions]\n", - " return [float(r) for r in rewards]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "66ZsrLplm07U" - }, - "source": [ - "## Load the Custom Dataset\n", - "\n", - "The dataset is built using repeated prompts to control the total number of training episodes.\n", - "\n", - "Each entry in the dataset triggers **one rollout episode** during training. \n", - "The `dataset_prompt` provides the initial instruction to the model at the start of each episode, ensuring consistent guidance and context for task execution." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zV7C_t1GKyJm" - }, - "outputs": [], - "source": [ - "from datasets import Dataset\n", - "\n", - "dataset_prompt = \"Play Sudoku like an expert.\"\n", - "dataset_size = 30\n", - "\n", - "dataset = Dataset.from_dict({\"prompt\": [dataset_prompt] * dataset_size})" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-mvka-96m3I7" - }, - "source": [ - "## Fine-tune using TRL and the GRPOTrainer\n", - "\n", - "The next step is to define the GRPOConfig, which sets all key training parameters.\n", - "\n", - "This configuration determines how the model interacts with vLLM, handles memory and computation, and records training metrics and logs for monitoring the fine-tuning process." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4BP-aBcVKyJm" - }, - "outputs": [], - "source": [ - "from trl import GRPOConfig\n", - "\n", - "output_dir = \"sudoku-grpo-qwen3\"\n", - "\n", - "grpo_config = GRPOConfig(\n", - " use_vllm=True, # Use vLLM engine for fast and efficient inference\n", - " vllm_mode=\"colocate\", # Run vLLM generation on the same GPU as training\n", - " vllm_gpu_memory_utilization=0.1, # Fraction of GPU memory allocated to vLLM\n", - " vllm_max_model_length=2560, # Maximum context length for vLLM generations\n", - "\n", - " output_dir=output_dir, # Directory to save model checkpoints and logs\n", - " num_train_epochs=1, # Number of training epochs\n", - " learning_rate=5e-6, # Initial learning rate\n", - "\n", - " #weight_decay=args.weight_decay, # Optional weight decay for optimizer\n", - " gradient_accumulation_steps=8, # Accumulate gradients over multiple steps to simulate larger batch size\n", - " per_device_train_batch_size=1, # Batch size per device (GPU)\n", - " warmup_steps=20, # Number of warmup steps for learning rate scheduler\n", - " num_generations=8, # Number of rollouts generated per prompt\n", - " max_completion_length=8, # Maximum length of generated completions\n", - "\n", - " logging_steps=1, # Log metrics every N steps\n", - " save_strategy=\"steps\", # Save checkpoints based on steps\n", - " save_steps=10, # Save every N steps\n", - "\n", - " report_to=\"trackio\", # Reporting backend for tracking experiments\n", - " trackio_space_id=output_dir, # Trackio space ID to log metrics\n", - "\n", - " use_liger_kernel=False, # Enable Liger kernel optimizations for faster training\n", - " # chat_template_kwargs={\"enable_thinking\": False}, # Optional template args for model reasoning. 
We manage this in the rollout function\n", - "\n", - " temperature=0.8,\n", - " top_k=10,\n", - "\n", - " model_init_kwargs={\n", - " \"use_cache\": False,\n", - " }\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "a1taGmD--0Y4" - }, - "source": [ - "## Create `GRPOTrainer` and Start Training\n", - "\n", - "Next, we initialize the `GRPOTrainer`, which handles the full reinforcement learning loop.\n", - "\n", - "It requires the **model**, **reward functions**, **rollout function**, and **dataset** defined earlier. \n", - "Here, we use **Qwen/Qwen3-1.7B**, a smaller version of the Qwen3 models. This model is sufficient for training on the \"easy\" difficulty Sudoku setting. \n", - "For \"medium\" or \"hard\" difficulty, a larger model would be needed, but this setup fits well in Colab with the current configuration.\n", - "\n", - "The trainer coordinates:\n", - "- Interaction between the model and the environment \n", - "- Application of reward signals \n", - "- Policy updates based on feedback\n", - "\n", - "Finally, calling `trainer.train()` starts the fine-tuning process, allowing the model to learn to solve Sudoku through repeated feedback and iteration." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "O-aKk1EwKyJm" - }, - "outputs": [], - "source": [ - "model_name = \"Qwen/Qwen3-1.7B\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "referenced_widgets": [ - "c75c6199d00d42b88a0ef49f650317bf", - "114a42d7d0a74a7a81dad02c21cf41b2" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "lSR2nwdJg962" + }, + "source": [ + "# OpenEnv Sudoku with GRPO using TRL\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_sudoku_grpo.ipynb)\n", + "\n", + "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n", + "\n", + "With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can train a model that learns to **play Sudoku**, through interaction and reinforcement.\n", + "\n", + "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n", + "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", + "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n", + "- [OpenEnv](https://github.com/meta-pytorch/OpenEnv)\n", + "\n", + "An **agentic environment** is a setting where a model can take actions, observe outcomes, and adjust its behavior based on feedback, similar to how humans learn from trial and error.\n", + "In this case, the agent interacts with the **Sudoku** environment through the [**OpenEnv**](https://github.com/meta-pytorch/OpenEnv) framework, which standardizes multi-agent and RL-style text environments.\n", + "\n", + "Sudoku is a classic logic-based puzzle where the objective is to fill a **9×9 grid** so that the following rule holds.
Each **row**, **column**, and **3×3 subgrid** contains all digits from **1 to 9** exactly once.\n", + "\n", + "This structured yet challenging setup makes Sudoku an excellent benchmark for reasoning and decision-making tasks.\n", + "\n", + "We'll fine-tune a model using **GRPO** (Group Relative Policy Optimization) via TRL.\n", + "The training loop follows these steps:\n", + "\n", + "1. The agent **generates guesses** based on the current game state.\n", + "2. The environment **evaluates the guess** and returns structured feedback.\n", + "3. The agent **updates its policy** using reward signals to improve future decisions.\n", + "\n", + "Over time, the model learns to make increasingly valid and efficient Sudoku moves.\n", + "\n", + "## Install dependencies\n", + "\n", + "We'll start by installing **TRL**, which automatically includes the main dependencies like **Transformers**. \n", + "We'll also install the **OpenEnv** framework (for the environment) via the HF Space we will use as environment server ([openenv/sudoku](https://huggingface.co/spaces/openenv/sudoku)), **trackio** (for logging and monitoring training runs), and **vLLM** (for efficient generation)." + ] }, - "id": "cQP77cFYKyJm", - "outputId": "8ec8a2c5-6e64-4c88-a99b-3b54e2f0f1c5" - }, - "outputs": [ { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c75c6199d00d42b88a0ef49f650317bf", - "version_major": 2, - "version_minor": 0 + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mHmE7GhRKyJj" }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00 ⚠️ **Note:** Hosted environments on the Hugging Face Hub have limited concurrency. 
\n", + "> For improved stability, higher throughput, or parallel experiments, it is recommended to **duplicate the Space into your own account**.\n", + "\n", + "For more information, refer to the [TRL-OpenEnv documentation](https://huggingface.co/docs/trl/main/en/openenv).\n" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "GPU = NVIDIA A100-SXM4-40GB. Max memory = 39.557 GB.\n", - "10.371 GB of memory reserved.\n" - ] - } - ], - "source": [ - "import torch\n", - "gpu_stats = torch.cuda.get_device_properties(0)\n", - "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", - "\n", - "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", - "print(f\"{start_gpu_memory} GB of memory reserved.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "McgHZH-XA1EK" - }, - "source": [ - "And train!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-O4RJlyBKyJm", - "outputId": "1f65963d-4a41-4fb4-af47-bcc38e8c8de9" - }, - "outputs": [ + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "P6O03louKyJk" + }, + "outputs": [], + "source": [ + "from textarena_env import TextArenaEnv\n", + "\n", + "space_url = \"https://openenv-sudoku.hf.space\"\n", + "client = TextArenaEnv(base_url=space_url)" + ] + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. 
Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.\n" - ] + "cell_type": "markdown", + "metadata": { + "id": "EqfDavDQnD_5" + }, + "source": [ + "## Create Rollout Function with Helpers\n", + "\n", + "The **rollout function** defines how the agent interacts with the environment during GRPO training.\n", + "It is responsible for generating model outputs, collecting feedback (rewards), and returning all the information needed for policy optimization.\n", + "\n", + "In this setup:\n", + "- The function is called automatically by the **GRPOTrainer** at each training step.\n", + "- It uses the trainer's `generate_rollout_completions()` method for efficient generation with **vLLM** in colocate mode.\n", + "- Each rollout represents a full interaction loop: the model makes guesses, receives feedback from the Sudoku environment, and updates its policy based on reward signals.\n", + "\n", + "Rewards track different aspects of the agent's performance, while helper functions like `rollout_once` handle a single episode of interaction, keeping the main `rollout_func` clean and modular.\n", + "\n", + "This modular approach allows GRPO to efficiently sample, evaluate, and improve the model's guessing strategy through reinforcement learning.\n", + "\n", + "First, we define the `system_prompt` that guides the model's behavior as an expert Sudoku solver with strategic reasoning and structured responses." 
+ ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Trackio project initialized: huggingface\n", - "* Trackio metrics will be synced to Hugging Face Dataset: sergiopaniego/sudoku-grpo-qwen3-dataset\n", - "* Creating new space: https://huggingface.co/spaces/sergiopaniego/sudoku-grpo-qwen3\n", - "* View dashboard by going to: https://sergiopaniego-sudoku-grpo-qwen3.hf.space/\n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pi1JGoUBKyJk" + }, + "outputs": [], + "source": [ + "# @title System prompt (click to expand)\n", + "SYSTEM_PROMPT = \"\"\"You are an expert Sudoku player with deep knowledge of logical deduction strategies and number placement techniques.\n", + "\n", + "## GAME RULES\n", + "\n", + "1. The puzzle is a 9x9 grid divided into nine 3x3 subgrids (boxes)\n", + "2. Some cells are pre-filled with numbers 1-9\n", + "3. You must fill in the empty cells (shown as '.') with numbers 1-9\n", + "4. Each row must contain numbers 1-9 without repetition\n", + "5. Each column must contain numbers 1-9 without repetition\n", + "6. Each 3x3 subgrid must contain numbers 1-9 without repetition\n", + "7. You cannot overwrite pre-filled cells\n", + "8. Invalid moves result in penalties (-1 reward)\n", + "\n", + "## RESPONSE FORMAT\n", + "\n", + "**CRITICAL: Output ONLY the move, nothing else. 
No text, no explanation.**\n", + "\n", + "Format: [row col number]\n", + "\n", + "Examples:\n", + "- [5 3 7] → places 7 in row 5, column 3\n", + "- [1 2 4] → places 4 in row 1, column 2\n", + "\n", + "## STRATEGIC APPROACH\n", + "\n", + "Do not repeat the same move twice.\n", + "\n", + "### Basic Strategies\n", + "- **Naked Singles**: If a cell has only one possible candidate, fill it in immediately.\n", + "- **Hidden Singles**: If a number can only go in one cell within a row, column, or box, place it there.\n", + "- **Scanning**: Look at each row, column, and box to find where specific numbers can go.\n", + "\n", + "### Intermediate Strategies\n", + "- **Naked Pairs/Triples**: When two/three cells in a unit contain only the same candidates, eliminate those from other cells.\n", + "- **Hidden Pairs/Triples**: When numbers only appear in specific cells within a unit, those cells can only contain those numbers.\n", + "- **Pointing Pairs**: When a candidate in a box is restricted to a single row/column, eliminate it elsewhere.\n", + "\n", + "### Solving Process\n", + "1. Start by scanning the entire grid to identify easy fills (cells with few candidates)\n", + "2. Look for rows, columns, or boxes with many numbers already placed\n", + "3. Fill all naked singles first\n", + "4. Then look for hidden singles in each row, column, and box\n", + "5. 
Apply more advanced techniques as needed\n", + "\n", + "### Common Pitfalls to Avoid\n", + "- Don't guess randomly - Sudoku is pure logic\n", + "- Don't overlook any constraint (row, column, or box)\n", + "- Don't try to overwrite pre-filled cells\n", + "- Don't place invalid numbers (must be 1-9)\n", + "- Don't use invalid coordinates (must be 1-9)\n", + "- Don't repeat a move that was already made\n", + "\n", + "## EXAMPLES\n", + "\n", + "### Example 1: Naked Single\n", + "If row 3, column 4 can only contain the number 5:\n", + "[3 4 5]\n", + "\n", + "### Example 2: Hidden Single\n", + "If the number 8 can only go in one cell in row 1:\n", + "[1 7 8]\n", + "\n", + "### Example 3: Row Analysis\n", + "Row 2 is missing only value 5, and column 8 is the empty cell:\n", + "[2 8 5]\n", + "\n", + "### Example 4: Box Analysis\n", + "In the center box, only one cell can contain 9:\n", + "[5 5 9]\n", + "\n", + "## BOARD READING\n", + "\n", + "The board is displayed as a 9x9 grid:\n", + "- Numbers 1-9 are pre-filled or already placed\n", + "- Empty cells are shown as '.'\n", + "- Rows are labeled R1-R9 (top to bottom)\n", + "- Columns are labeled C1-C9 (left to right)\n", + "\n", + "Example board representation:\n", + "```\n", + " C1 C2 C3 C4 C5 C6 C7 C8 C9\n", + "R1 . 8 9 | 1 . . | . 3 7\n", + "R2 2 7 1 | 9 4 3 | 6 . 8\n", + "R3 . 6 5 | . 2 7 | 4 9 .\n", + " - - - - - - - - - - - - - - - -\n", + "R4 . . . | 7 8 . | 9 2 3\n", + "R5 . 9 2 | . 5 6 | . . 4\n", + "R6 7 3 8 | . . 2 | 1 . .\n", + " - - - - - - - - - - - - - - - -\n", + "R7 8 4 . | . . 9 | 5 . .\n", + "R8 5 . . | 6 . 8 | 3 4 9\n", + "R9 9 . 
6 | 5 3 4 | 8 7 2\n", + "```\n", + "\n", + "## COORDINATE REFERENCE\n", + "\n", + "Row indices (top to bottom): 1, 2, 3, 4, 5, 6, 7, 8, 9\n", + "Column indices (left to right): 1, 2, 3, 4, 5, 6, 7, 8, 9\n", + "\n", + "Subgrid layout:\n", + "```\n", + "Subgrid 1 | Subgrid 2 | Subgrid 3\n", + " (R1-R3) (R1-R3) (R1-R3)\n", + " (C1-C3) (C4-C6) (C7-C9)\n", + "----------+-----------+----------\n", + "Subgrid 4 | Subgrid 5 | Subgrid 6\n", + " (R4-R6) (R4-R6) (R4-R6)\n", + " (C1-C3) (C4-C6) (C7-C9)\n", + "----------+-----------+----------\n", + "Subgrid 7 | Subgrid 8 | Subgrid 9\n", + " (R7-R9) (R7-R9) (R7-R9)\n", + " (C1-C3) (C4-C6) (C7-C9)\n", + "```\n", + "\n", + "## IMPORTANT CONSTRAINTS\n", + "\n", + "- Coordinates are 1-indexed (1-9 for both row and column)\n", + "- Numbers must be 1-9\n", + "- One move per response\n", + "- Must be a valid move (no rule violations)\n", + "- Never repeat a previous move\n", + "\n", + "## YOUR GOAL\n", + "\n", + "Output ONLY your move in the format [row col number]. No explanation, no reasoning, just the move.\n", + "\"\"\"" + ] }, { - "data": { - "text/html": [ - "
" - ], - "text/plain": [ - "" + "cell_type": "markdown", + "metadata": { + "id": "Vi1rFey39GUl" + }, + "source": [ + "Now, let's define the `rollout_func`.\n", + "\n", + "This function manages the interaction between the model and the Sudoku environment. \n", + "For each prompt in the batch, it runs a full episode, collecting both the model's outputs and the corresponding rewards. These results are then used by GRPO to optimize the agent's policy.\n", + "\n", + "Each game allows the model to make **up to 100 turns**, giving it multiple chances to solve the puzzle.\n", + "We have different difficulty levels available: `'easy'`, `'medium'`, and `'hard'`. The level affects the amount of information provided in the prompt. Higher difficulties give less guidance.\n", + "\n", + "For the **easy** level, the Qwen/Qwen3-1.7B model is sufficient to solve the puzzles efficiently in a Colab notebook.\n", + "For **medium** or **hard** levels, a larger or more advanced model would likely be needed." ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Created new run: sergiopaniego-1767361842\n", - "Step 1: [1 1 6]\n", - "Step 2: [2 7 4]\n", - "Step 3: [3 4 8]\n", - "Step 4: [3 3 2]\n", - "Step 5: [4 8 6]\n", - "Step 6: [4 8 6]\n", - "Episode: empty_cell=1.00, valid=0.58, repetition=-0.17, progress=0.13 (4 cells), correct=-1.00, tokens=1860\n", - "Step 1: [1 1 7]\n", - "Step 2: [2 4 6]\n", - "Step 3: [1 7 8]\n", - "Step 4: [2 4 6]\n", - "Step 5: [2 4 4]\n", - "Step 6: [2 4 6]\n", - "Step 7: [2 4 6]\n", - "Episode: empty_cell=0.43, valid=0.21, repetition=-1.00, progress=0.10 (3 cells), correct=-1.00, tokens=1866\n", - "Step 1: [1 1 2]\n", - "Step 2: [1 1 2]\n", - "Episode: empty_cell=-1.00, valid=-0.25, repetition=-0.50, progress=0.00 (0 cells), correct=-1.00, tokens=1826\n", - "\n", - "# ... 
Output truncated for readability (see Trackio dashboard for full logs) ...\n", - "\n", - "Step 1: [1 7 6]\n", - "Step 2: [1 9 2]\n", - "Step 3: [2 6 5]\n", - "Step 4: [2 1 3]\n", - "Step 5: [2 2 2]\n", - "Step 6: [3 1 4]\n", - "Step 7: [3 2 6]\n", - "Step 8: [2 9 7]\n", - "Step 9: [3 4 8]\n", - "Step 10: [3 8 9]\n", - "Step 11: [1 3 5]\n", - "Step 12: [2 9 4]\n", - "Step 13: [4 3 8]\n", - "Step 14: [3 8 4]\n", - "Step 15: [2 9 4]\n", - "Episode: empty_cell=0.60, valid=0.73, repetition=-0.07, progress=0.40 (12 cells), correct=-1.00, tokens=1931\n", - "Step 1: [2 8 1]\n", - "Step 2: [2 5 2]\n", - "Step 3: [3 4 6]\n", - "Step 4: [1 5 1]\n", - "Step 5: [2 6 4]\n", - "Step 6: [3 7 4]\n", - "Step 7: [4 3 6]\n", - "Step 8: [5 1 2]\n", - "Step 9: [1 1 4]\n", - "Step 10: [4 3 6]\n", - "Step 11: [4 3 6]\n", - "Step 12: [4 3 6]\n", - "Step 13: [4 3 6]\n", - "Step 14: [4 1 9]\n", - "Step 15: [4 2 4]\n", - "Step 16: [7 8 5]\n", - "Step 17: [7 8 5]\n", - "Episode: empty_cell=0.53, valid=0.62, repetition=-0.94, progress=0.37 (11 cells), correct=-1.00, tokens=1916\n" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wMQQoQ_UKyJl" + }, + "outputs": [], + "source": [ + "from trl import GRPOTrainer\n", + "\n", + "max_turns = 100\n", + "debug = False # Activate for detailed logs during training\n", + "difficulty=\"easy\"\n", + "\n", + "async def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:\n", + " \"\"\"Async rollout function - TRL handles the event loop automatically.\"\"\"\n", + " all_prompt_ids = []\n", + " all_completion_ids = []\n", + " all_logprobs = []\n", + " all_correct = []\n", + " all_valid = []\n", + " all_empty_cell = []\n", + " all_repetition = []\n", + " all_progress = []\n", + "\n", + " for _ in prompts:\n", + " episode = await rollout_once(\n", + " trainer=trainer,\n", + " env=client,\n", + " tokenizer=trainer.processing_class,\n", + " system_prompt=SYSTEM_PROMPT,\n", + " max_turns=max_turns,\n", + " 
debug=debug,\n", + " difficulty=difficulty,\n", + " )\n", + " all_prompt_ids.append(episode[\"prompt_ids\"])\n", + " all_completion_ids.append(episode[\"completion_ids\"])\n", + " all_logprobs.append(episode[\"logprobs\"])\n", + " all_correct.append(episode[\"correct_reward\"])\n", + " all_valid.append(episode[\"valid_move_reward\"])\n", + " all_empty_cell.append(episode[\"empty_cell_reward\"])\n", + " all_repetition.append(episode[\"repetition_reward\"])\n", + " all_progress.append(episode[\"progress_reward\"])\n", + "\n", + " return {\n", + " \"prompt_ids\": all_prompt_ids,\n", + " \"completion_ids\": all_completion_ids,\n", + " \"logprobs\": all_logprobs,\n", + " \"correct_reward\": all_correct,\n", + " \"valid_move_reward\": all_valid,\n", + " \"empty_cell_reward\": all_empty_cell,\n", + " \"repetition_reward\": all_repetition,\n", + " \"progress_reward\": all_progress,\n", + " }\n" + ] }, { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " \n", - " [30/30 26:13, Epoch 1/1]\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
StepTraining Loss
1-0.113800
2-0.001800
3-0.051300
4-0.012800
50.012200
60.045600
7-0.104800
8-0.093600
90.182400
10-0.027000
110.042300
12-0.052400
13-0.100100
14-0.074400
15-0.105500
160.125200
17-0.016900
180.119100
190.081800
200.003300
210.024400
22-0.038700
230.000000
240.000000
250.000000
260.000000
270.000000
280.000000
290.000000
300.000000

" - ], - "text/plain": [ - "" + "cell_type": "markdown", + "metadata": { + "id": "ioUHdIxr9ZQO" + }, + "source": [ + "### Define `rollout_once`\n", + "\n", + "The `rollout_once` function runs **a single interaction loop** between the model and the Sudoku environment using the trainer's generation method. \n", + "It executes one mini-episode, from generating a guess to receiving and processing feedback.\n", + "\n", + "Step-by-step:\n", + "\n", + "1. **Environment reset:** Start a new game session and initialize the observation.\n", + "2. **Prompt construction:** Combine the system prompt, current state, and user messages to form the model input.\n", + "3. **Generation:** Use `trl.experimental.openenv.generate_rollout_completions()` to efficiently produce the model's guess.\n", + "4. **Feedback extraction:** Parse the environment's response with helpers like `extract_sudoku_move()` and `extract_feedback()`.\n", + "5. **Reward calculation:** Compute rewards based on correctness, valid moves, empty cell moves, repeated moves, and progress.\n", + "6. **Return structured rollout data:** Includes prompt and completion IDs, log probabilities, and all reward components.\n", + "\n", + "This design allows each episode to be processed independently while providing detailed feedback for the **GRPO training loop**." 
] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Step 1: [1 3 8]\n", - "Step 2: [1 9 2]\n", - "Step 3: [3 4 2]\n", - "Step 4: [3 8 5]\n", - "Step 5: [1 2 4]\n", - "Step 6: [7 5 1]\n", - "Step 7: [7 5 1]\n", - "Episode: empty_cell=0.43, valid=0.64, repetition=-0.14, progress=0.17 (5 cells), correct=-1.00, tokens=1881\n", - "Step 1: [1 2 9]\n", - "Step 2: [1 6 5]\n", - "Step 3: [2 8 9]\n", - "Step 4: [2 9 7]\n", - "Step 5: [3 4 6]\n", - "Step 6: [3 3 3]\n", - "Step 7: [3 6 2]\n", - "Step 8: [3 5 9]\n", - "Step 9: [4 7 8]\n", - "Step 10: [5 2 2]\n", - "Step 11: [5 1 9]\n", - "Step 12: [3 4 6]\n", - "Step 13: [4 7 1]\n", - "Step 14: [5 3 7]\n", - "Step 15: [5 5 6]\n", - "Step 16: [6 8 4]\n", - "Step 17: [6 1 5]\n", - "Step 18: [6 5 7]\n", - "Step 19: [6 6 1]\n", - "Step 20: [7 2 4]\n", - "Step 21: [7 3 8]\n", - "Step 22: [7 6 6]\n", - "Step 23: [7 7 7]\n", - "Step 24: [8 3 5]\n", - "Step 25: [8 5 2]\n", - "Step 26: [8 6 4]\n", - "Step 27: [8 9 1]\n", - "Step 28: [9 1 2]\n", - "Step 29: [9 2 3]\n", - "Step 30: [9 3 9]\n", - "Step 31: [9 4 8]\n", - "Step 32: [9 7 4]\n", - "Episode: empty_cell=0.88, valid=0.92, repetition=-0.03, progress=1.00 (30 cells), correct=1.00, tokens=2035\n", - "\n", - "# ... 
Output truncated for readability (see Trackio dashboard for full logs) ...\n", - "\n", - "Step 1: [3 6 4]\n", - "Step 2: [2 3 7]\n", - "Step 3: [4 9 2]\n", - "Step 4: [5 4 7]\n", - "Step 5: [3 9 8]\n", - "Step 6: [4 6 9]\n", - "Step 7: [5 5 1]\n", - "Step 8: [5 6 2]\n", - "Step 9: [6 3 2]\n", - "Step 10: [6 8 8]\n", - "Step 11: [5 8 5]\n", - "Step 12: [5 2 8]\n", - "Step 13: [5 1 9]\n", - "Step 14: [6 7 6]\n", - "Step 15: [6 5 4]\n", - "Step 16: [4 5 6]\n", - "Step 17: [6 4 3]\n", - "Step 18: [7 4 4]\n", - "Step 19: [7 7 8]\n", - "Step 20: [7 1 3]\n", - "Step 21: [9 2 6]\n", - "Step 22: [2 2 4]\n", - "Step 23: [3 7 7]\n", - "Step 24: [4 1 4]\n", - "Step 25: [4 2 5]\n", - "Step 26: [9 1 8]\n", - "Step 27: [1 2 3]\n", - "Step 28: [2 1 6]\n", - "Step 29: [1 7 4]\n", - "Step 30: [9 7 9]\n", - "Episode: empty_cell=1.00, valid=1.00, repetition=0.00, progress=1.00 (30 cells), correct=1.00, tokens=2028\n", - "Step 1: [3 3 7]\n", - "Step 2: [2 1 9]\n", - "Step 3: [3 4 1]\n", - "Step 4: [3 6 2]\n", - "Step 5: [4 3 1]\n", - "Step 6: [4 2 6]\n", - "Step 7: [4 7 8]\n", - "Step 8: [4 6 5]\n", - "Step 9: [4 8 7]\n", - "Step 10: [3 8 6]\n", - "Step 11: [2 5 8]\n", - "Step 12: [6 5 7]\n", - "Step 13: [6 1 8]\n", - "Step 14: [5 7 3]\n", - "Step 15: [7 6 9]\n", - "Step 16: [7 7 2]\n", - "Step 17: [6 6 3]\n", - "Step 18: [8 2 4]\n", - "Step 19: [8 4 5]\n", - "Step 20: [8 6 1]\n", - "Step 21: [5 6 8]\n", - "Step 22: [5 4 6]\n", - "Step 23: [5 5 1]\n", - "Step 24: [8 5 2]\n", - "Step 25: [9 4 8]\n", - "Step 26: [9 5 6]\n", - "Step 27: [4 5 4]\n", - "Step 28: [1 9 8]\n", - "Step 29: [7 9 7]\n", - "Episode: empty_cell=1.00, valid=1.00, repetition=0.00, progress=1.00 (29 cells), correct=1.00, tokens=2020\n", - "* Run finished. 
Uploading logs to Trackio (please wait...)\n" - ] - } - ], - "source": [ - "trainer_stats = trainer.train()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gF-mr-gfAtkp" - }, - "source": [ - "Show memory stats after training" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "qM8tW2pdKyJm" - }, - "outputs": [], - "source": [ - "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", - "used_memory_for_training = round(used_memory - start_gpu_memory, 3)\n", - "used_percentage = round(used_memory / max_memory * 100, 3)\n", - "training_memory_percentage = round(used_memory_for_training / max_memory * 100, 3)\n", - "\n", - "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", - "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", - "print(f\"Peak reserved memory = {used_memory} GB.\")\n", - "print(f\"Peak reserved memory for training = {used_memory_for_training} GB.\")\n", - "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", - "print(f\"Peak reserved memory for training % of max memory = {training_memory_percentage} %.\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BZj4IG9ZBAix" - }, - "source": [ - "In this step, the fine-tuned model is saved locally and uploaded to the Hugging Face Hub using the configured account credentials." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "xJV-NZTmKyJm" - }, - "outputs": [], - "source": [ - "client.close()\n", - "trainer.save_model(output_dir)\n", - "trainer.push_to_hub()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X-6lB52GAl_u" - }, - "source": [ - "## Load the Fine-Tuned Model and Run Inference\n", - "\n", - "Now let's test our fine-tuned model by loading the **adapter** and running **inference**. 
\n", - "We begin by loading the **base model**, attaching the adapter, and obtaining the final fine-tuned model ready for evaluation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "referenced_widgets": [ - "d686d3933bef4ea3a9fb58193495e970", - "4ce63e90903a4f60be6694de976e7127", - "a1003f172b954e218fdb00539d79a7d1", - "859e82d390204d5d8a763bd61b356ae4", - "b54f50facca141639f3412b86cfc433d", - "ec156a0cf90a42059f2335d4bae0628e", - "fd9f8dcdf82d4e39a9c72d0e25464c28", - "8696f66460244e55b9d67fd2fe9d6e51", - "6acb03abb643408b8f802704f00674f5", - "e13e54f756874f78af67a25564d64375", - "5e29181a5ec1421d9f2eefbf7529d363", - "fa420825fdab47d6aea18cd352dd9ef1", - "9a460882ac834c8a90d77eec9a2c34ea", - "7bf3166759f64fcf8968a6d282f8df85" - ] + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AZim6XzEKyJl" + }, + "outputs": [], + "source": [ + "from trl.experimental.openenv import generate_rollout_completions\n", + "from textarena_env import TextArenaAction\n", + "from transformers import AutoTokenizer\n", + "from collections import defaultdict\n", + "\n", + "\n", + "async def rollout_once(\n", + " trainer: GRPOTrainer,\n", + " env: TextArenaEnv,\n", + " tokenizer: AutoTokenizer,\n", + " system_prompt: str,\n", + " max_turns: int,\n", + " debug: bool = False,\n", + " difficulty: str = \"hard\",\n", + ") -> dict[str, list]:\n", + " result = await env.reset()\n", + " observation = result.observation\n", + "\n", + " # Only store the LAST turn for backprop (much more efficient!)\n", + " last_turn_data: dict | None = None\n", + "\n", + " valid_move_scores: list[float] = []\n", + " empty_cell_scores: list[float] = []\n", + " correct_scores: list[float] = []\n", + " repetition_scores: list[float] = []\n", + "\n", + " move_counts: defaultdict[str, int] = defaultdict(int)\n", + "\n", + " # Track successful and failed moves for summary\n", + " successful_moves: list[str] = []\n", + " failed_moves: list[str] = []\n", 
+ "\n", + " # Extract initial board state\n", + " last_board_state = \"\"\n", + " initial_filled = 0\n", + " for message in observation.messages:\n", + " if message.content and is_valid_board_state(message.content):\n", + " last_board_state = message.content\n", + " initial_filled = count_filled_cells(last_board_state)\n", + " break\n", + "\n", + " max_filled = initial_filled # Track max progress\n", + "\n", + " for turn in range(max_turns):\n", + " if result.done:\n", + " break\n", + "\n", + " # Build COMPACT prompt (saves tokens!)\n", + " user_prompt = make_compact_prompt(\n", + " board=last_board_state,\n", + " step=turn + 1,\n", + " successful_moves=successful_moves,\n", + " failed_moves=failed_moves,\n", + " difficulty=difficulty,\n", + " )\n", + " messages = [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": user_prompt},\n", + " ]\n", + " prompt_text = tokenizer.apply_chat_template(\n", + " messages, add_generation_prompt=True, tokenize=False, enable_thinking=False # `enable_thinking` is usable for the current model but could need to be updated for other models\n", + " )\n", + "\n", + " if debug:\n", + " print(f\"\\n{'=' * 60}\")\n", + " print(f\"STEP {turn + 1}\")\n", + " print(f\"{'=' * 60}\")\n", + " print(f\"USER PROMPT:\\n{user_prompt}\")\n", + " print(f\"{'=' * 60}\")\n", + "\n", + " # Generate\n", + " rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]\n", + "\n", + " # Store ONLY this turn's data (replace previous)\n", + " last_turn_data = {\n", + " \"prompt_ids\": rollout_outputs[\"prompt_ids\"],\n", + " \"completion_ids\": rollout_outputs[\"completion_ids\"],\n", + " \"logprobs\": rollout_outputs[\"logprobs\"],\n", + " }\n", + "\n", + " if debug:\n", + " step_tokens = len(rollout_outputs[\"prompt_ids\"]) + len(rollout_outputs[\"completion_ids\"])\n", + " print(f\"TOKENS: this_step={step_tokens} (only last turn used for backprop)\")\n", + "\n", + " completion_text = 
rollout_outputs.get(\"text\") or tokenizer.decode(\n", + " rollout_outputs[\"completion_ids\"], skip_special_tokens=True\n", + " )\n", + "\n", + " # Extract move\n", + " move = extract_sudoku_move(completion_text)\n", + "\n", + " if debug:\n", + " print(f\"MODEL OUTPUT: {completion_text}\")\n", + " print(f\"EXTRACTED MOVE: {move}\")\n", + "\n", + " # Step environment\n", + " result = await env.step(TextArenaAction(message=move))\n", + " observation = result.observation\n", + " correct_score = float(result.reward or 0.0)\n", + "\n", + " # Get feedback\n", + " feedback = extract_feedback(observation)\n", + "\n", + " # Get environment response\n", + " env_response = \"\"\n", + " for msg in observation.messages:\n", + " if msg.sender_id == -1: # Environment message\n", + " env_response = msg.content\n", + " break\n", + "\n", + " if debug:\n", + " print(\n", + " f\"ENV RESPONSE: {env_response[:200]}...\"\n", + " if len(env_response) > 200\n", + " else f\"ENV RESPONSE: {env_response}\"\n", + " )\n", + " print(f\"VALID: {feedback['valid_move']}, WARNING: {feedback['got_warning']}, REWARD: {correct_score}\")\n", + "\n", + " # Calculate empty_cell_score\n", + " if last_board_state and move:\n", + " targets_empty = check_move_targets_empty_cell(move, last_board_state)\n", + " empty_cell_score = 1.0 if targets_empty else -1.0\n", + " else:\n", + " empty_cell_score = 0.0\n", + "\n", + " # Calculate valid_move_score and repetition_score\n", + " is_new_move = move_counts[move] == 0\n", + " repetition_count = move_counts[move]\n", + " move_counts[move] += 1\n", + "\n", + " # Exponential penalty for repetitions: -2^(n-1) capped at -10\n", + " # 1st repeat: -1, 2nd: -2, 3rd: -4, 4th: -8, 5th+: -10 (capped)\n", + " if repetition_count > 0:\n", + " repetition_score = -min(2 ** (repetition_count - 1), 10.0)\n", + " else:\n", + " repetition_score = 0.0\n", + "\n", + " if debug:\n", + " print(\n", + " f\"SCORES: empty_cell={empty_cell_score}, is_new={is_new_move}, repetitions={repetition_count}, 
rep_penalty={repetition_score}\"\n", + " )\n", + "\n", + " if not debug:\n", + " print(f\"Step {turn + 1}: {move}\")\n", + "\n", + " if feedback[\"valid_move\"] and is_new_move:\n", + " valid_move_score = 1.0\n", + " if move:\n", + " successful_moves.append(move) # Track for summary\n", + " elif feedback[\"got_warning\"]:\n", + " valid_move_score = -0.5\n", + " if move:\n", + " failed_moves.append(move) # Track for summary\n", + " else:\n", + " valid_move_score = 0.0\n", + "\n", + " # Update board state and track progress\n", + " if feedback[\"board_state\"] and is_valid_board_state(feedback[\"board_state\"]):\n", + " last_board_state = feedback[\"board_state\"]\n", + " current_filled = count_filled_cells(last_board_state)\n", + " if current_filled > max_filled:\n", + " max_filled = current_filled\n", + "\n", + " valid_move_scores.append(valid_move_score)\n", + " empty_cell_scores.append(empty_cell_score)\n", + " correct_scores.append(correct_score)\n", + " repetition_scores.append(repetition_score)\n", + "\n", + " # Aggregate rewards\n", + " correct_reward = correct_scores[-1] if correct_scores else 0.0\n", + " valid_move_reward = sum(valid_move_scores) / len(valid_move_scores) if valid_move_scores else 0.0\n", + " empty_cell_reward = sum(empty_cell_scores) / len(empty_cell_scores) if empty_cell_scores else 0.0\n", + " repetition_reward = sum(repetition_scores) / len(repetition_scores) if repetition_scores else 0.0\n", + "\n", + " # Progress reward: how many cells we filled beyond initial state (normalized to 0-1)\n", + " # 81 total cells, so (max_filled - initial_filled) / (81 - initial_filled) gives progress\n", + " remaining_to_fill = 81 - initial_filled\n", + " if remaining_to_fill > 0:\n", + " progress_reward = (max_filled - initial_filled) / remaining_to_fill\n", + " else:\n", + " progress_reward = 1.0 # Already complete\n", + "\n", + " # Use ONLY last turn for backpropagation (much more efficient!)\n", + " if last_turn_data:\n", + " prompt_ids = 
last_turn_data[\"prompt_ids\"]\n", + " completion_ids = last_turn_data[\"completion_ids\"]\n", + " logprobs = last_turn_data[\"logprobs\"]\n", + " else:\n", + " prompt_ids = []\n", + " completion_ids = []\n", + " logprobs = []\n", + "\n", + " total_tokens = len(prompt_ids) + len(completion_ids)\n", + " cells_filled = max_filled - initial_filled\n", + " print(\n", + " f\"Episode: empty_cell={empty_cell_reward:.2f}, valid={valid_move_reward:.2f}, \"\n", + " f\"repetition={repetition_reward:.2f}, progress={progress_reward:.2f} ({cells_filled} cells), \"\n", + " f\"correct={correct_reward:.2f}, tokens={total_tokens}\"\n", + " )\n", + "\n", + " return {\n", + " \"prompt_ids\": prompt_ids,\n", + " \"completion_ids\": completion_ids,\n", + " \"logprobs\": logprobs,\n", + " \"correct_reward\": correct_reward,\n", + " \"valid_move_reward\": valid_move_reward,\n", + " \"empty_cell_reward\": empty_cell_reward,\n", + " \"repetition_reward\": repetition_reward,\n", + " \"progress_reward\": progress_reward,\n", + " }" + ] }, - "id": "-Vu--VueKyJm", - "outputId": "399ccb1e-45bf-4305-ae9d-97edece48b53" - }, - "outputs": [ { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d686d3933bef4ea3a9fb58193495e970", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "MDJKMQ__8qzj" }, - "text/plain": [ - "config.json: 0.00B [00:00, ?B/s]" + "source": [ + "### Helper Functions\n", + "\n", + "These utility functions are used within `rollout_once` to process the environment and model outputs:\n", + "\n", + "- **`extract_sudoku_move`**: Extract a Sudoku move `[row, col, number]` from text. \n", + "- **`is_valid_board_state`**: Check if a string represents a valid Sudoku board. \n", + "- **`parse_board`**: Convert a board string into a 9×9 grid (with `0` for empty cells). \n", + "- **`count_filled_cells`**: Count the number of filled cells in the board. 
\n", + "- **`get_valid_numbers`**: Get the valid numbers for a specific cell according to Sudoku rules. \n", + "- **`extract_empty_cells_with_candidates`**: Identify empty cells along with their valid candidate numbers. \n", + "- **`extract_empty_cells`**: List all empty cells `(row, col)` from a board string. \n", + "- **`extract_board_only`**: Extract just the Sudoku grid from a message. \n", + "- **`make_compact_prompt`**: Create a concise prompt with only essential information to save tokens. \n", + "- **`check_move_targets_empty_cell`**: Verify if a proposed move targets an empty cell on the board. \n", + "- **`extract_feedback`**: Extract structured feedback from the environment's observation." ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4ce63e90903a4f60be6694de976e7127", - "version_major": 2, - "version_minor": 0 + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0f9RqHh7KyJl" }, - "text/plain": [ - "model.safetensors.index.json: 0.00B [00:00, ?B/s]" + "outputs": [], + "source": [ + "# @title Helpers (click to expand)\n", + "import re\n", + "\n", + "def extract_sudoku_move(text: str) -> str:\n", + " \"\"\"Extract a Sudoku move [row col number] from text.\"\"\"\n", + " # Try with spaces\n", + " match = re.search(r\"\\[(\\d)\\s+(\\d)\\s+(\\d)\\]\", text)\n", + " if match:\n", + " row, col, num = match.groups()\n", + " return f\"[{row} {col} {num}]\"\n", + "\n", + " # Try without spaces\n", + " match = re.search(r\"\\[(\\d)(\\d)(\\d)\\]\", text)\n", + " if match:\n", + " row, col, num = match.groups()\n", + " return f\"[{row} {col} {num}]\"\n", + "\n", + " return \"\" # Handled by the environment: missing/invalid moves trigger a \"wrong movement\" message affecting rewards\n", + "\n", + "\n", + "def is_valid_board_state(board_str: str) -> bool:\n", + " \"\"\"Check if the string contains an actual Sudoku board.\"\"\"\n", + " return \"R1\" in 
board_str and \"R9\" in board_str and \"|\" in board_str\n", + "\n", + "\n", + "def parse_board(board_str: str) -> list[list[int]]:\n", + " \"\"\"Parse board string into 9x9 grid (0 = empty).\"\"\"\n", + " grid = [[0] * 9 for _ in range(9)]\n", + " if not is_valid_board_state(board_str):\n", + " return grid\n", + "\n", + " for line in board_str.split(\"\\n\"):\n", + " line_stripped = line.strip()\n", + " if line_stripped and line_stripped[0] == \"R\" and len(line_stripped) > 1 and line_stripped[1].isdigit():\n", + " row = int(line_stripped[1]) - 1 # 0-indexed\n", + " cell_part = line_stripped[2:]\n", + " col = 0\n", + " for char in cell_part:\n", + " if char == \".\":\n", + " grid[row][col] = 0\n", + " col += 1\n", + " elif char.isdigit():\n", + " grid[row][col] = int(char)\n", + " col += 1\n", + " return grid\n", + "\n", + "\n", + "def count_filled_cells(board_str: str) -> int:\n", + " \"\"\"Count the number of filled cells in the board.\"\"\"\n", + " if not is_valid_board_state(board_str):\n", + " return 0\n", + " grid = parse_board(board_str)\n", + " return sum(1 for row in grid for cell in row if cell != 0)\n", + "\n", + "\n", + "def get_valid_numbers(grid: list[list[int]], row: int, col: int) -> set[int]:\n", + " \"\"\"Get valid numbers for a cell based on Sudoku rules.\"\"\"\n", + " if grid[row][col] != 0:\n", + " return set()\n", + "\n", + " used = set()\n", + "\n", + " # Check row\n", + " for c in range(9):\n", + " if grid[row][c] != 0:\n", + " used.add(grid[row][c])\n", + "\n", + " # Check column\n", + " for r in range(9):\n", + " if grid[r][col] != 0:\n", + " used.add(grid[r][col])\n", + "\n", + " # Check 3x3 box\n", + " box_row, box_col = 3 * (row // 3), 3 * (col // 3)\n", + " for r in range(box_row, box_row + 3):\n", + " for c in range(box_col, box_col + 3):\n", + " if grid[r][c] != 0:\n", + " used.add(grid[r][c])\n", + "\n", + " return set(range(1, 10)) - used\n", + "\n", + "\n", + "def extract_empty_cells_with_candidates(\n", + " board_str: str, 
sort_by_difficulty: bool = True\n", + ") -> list[tuple[int, int, set[int]]]:\n", + " \"\"\"Extract empty cells with their valid candidate numbers.\n", + "\n", + " Args:\n", + " sort_by_difficulty: If True, sort by number of candidates (easiest first).\n", + " If False, keep natural order (top-left to bottom-right).\n", + " \"\"\"\n", + " grid = parse_board(board_str)\n", + " cells_with_candidates = []\n", + "\n", + " for row in range(9):\n", + " for col in range(9):\n", + " if grid[row][col] == 0:\n", + " candidates = get_valid_numbers(grid, row, col)\n", + " cells_with_candidates.append((row + 1, col + 1, candidates)) # 1-indexed\n", + "\n", + " if sort_by_difficulty:\n", + " # Sort by number of candidates (easiest first = naked singles)\n", + " cells_with_candidates.sort(key=lambda x: len(x[2]))\n", + "\n", + " return cells_with_candidates\n", + "\n", + "\n", + "def extract_empty_cells(board_str: str) -> list[tuple[int, int]]:\n", + " \"\"\"Extract list of empty cells (row, col) from board string.\"\"\"\n", + " empty_cells = []\n", + " if not is_valid_board_state(board_str):\n", + " return empty_cells\n", + "\n", + " for line in board_str.split(\"\\n\"):\n", + " line_stripped = line.strip()\n", + " if line_stripped and line_stripped[0] == \"R\" and len(line_stripped) > 1 and line_stripped[1].isdigit():\n", + " row = int(line_stripped[1])\n", + " cell_part = line_stripped[2:]\n", + " col = 0\n", + " for char in cell_part:\n", + " if char == \".\":\n", + " col += 1\n", + " empty_cells.append((row, col))\n", + " elif char.isdigit():\n", + " col += 1\n", + " return empty_cells\n", + "\n", + "\n", + "def extract_board_only(text: str) -> str:\n", + " \"\"\"Extract just the Sudoku grid from a message.\"\"\"\n", + " if not text:\n", + " return \"\"\n", + "\n", + " lines = text.split(\"\\n\")\n", + " board_lines = []\n", + " in_board = False\n", + "\n", + " for line in lines:\n", + " stripped = line.strip()\n", + " if stripped.startswith(\"C1\") or (\n", + " stripped and 
stripped[0] == \"R\" and len(stripped) > 1 and stripped[1].isdigit()\n", + " ):\n", + " in_board = True\n", + " if in_board and (stripped.startswith(\"-\") or stripped.startswith(\"R\") or stripped.startswith(\"C1\")):\n", + " board_lines.append(line)\n", + " elif (\n", + " in_board\n", + " and stripped\n", + " and not stripped.startswith(\"-\")\n", + " and not (stripped[0] == \"R\" and len(stripped) > 1 and stripped[1].isdigit())\n", + " ):\n", + " break\n", + "\n", + " return \"\\n\".join(board_lines) if board_lines else \"\"\n", + "\n", + "\n", + "def make_compact_prompt(\n", + " board: str,\n", + " step: int,\n", + " successful_moves: list[str],\n", + " failed_moves: list[str],\n", + " difficulty: str = \"hard\",\n", + ") -> str:\n", + " \"\"\"Create a compact prompt with only essential info (saves tokens!).\n", + "\n", + " Args:\n", + " difficulty: Training difficulty level:\n", + " - \"easy\": Show guaranteed moves (naked singles) + other options\n", + " - \"medium\": Only show other options (hints where to look, not exact answers)\n", + " - \"hard\": No hints (model must learn Sudoku rules by itself)\n", + " \"\"\"\n", + "\n", + " # Summary line\n", + " cells_filled = len(successful_moves)\n", + " summary = f\"Step {step}. 
Progress: {cells_filled} cells filled.\"\n", + "\n", + " # Board (only show the grid, stripped down)\n", + " board_only = extract_board_only(board) if board else \"No board available.\"\n", + "\n", + " # Moves already tried (for learning what NOT to do)\n", + " tried_moves_hint = \"\"\n", + " all_tried = successful_moves + failed_moves\n", + " if all_tried:\n", + " tried_moves_hint = f\"\\n\\n⚠️ MOVES ALREADY TRIED (do not repeat): {', '.join(all_tried)}\"\n", + "\n", + " # Hints based on difficulty\n", + " hints = \"\"\n", + " if difficulty == \"easy\" and board:\n", + " # Easy: sorted by difficulty, show guaranteed moves + other easy options\n", + " cells_with_candidates = extract_empty_cells_with_candidates(board, sort_by_difficulty=True)\n", + " if cells_with_candidates:\n", + " guaranteed = []\n", + " other_hints = []\n", + " for row, col, candidates in cells_with_candidates[:10]:\n", + " if len(candidates) == 1:\n", + " num = list(candidates)[0]\n", + " guaranteed.append(f\"[{row} {col} {num}]\")\n", + " elif len(candidates) <= 3:\n", + " nums = \",\".join(str(n) for n in sorted(candidates))\n", + " other_hints.append(f\"({row},{col})→{nums}\")\n", + "\n", + " if guaranteed:\n", + " hints = f\"\\n\\n🎯 GUARANTEED MOVES: {', '.join(guaranteed[:5])}\"\n", + " if other_hints:\n", + " hints += f\"\\nOther options: {' | '.join(other_hints[:5])}\"\n", + "\n", + " elif difficulty == \"medium\" and board:\n", + " # Medium: NOT sorted, just show empty cells with candidates (no ordering hints)\n", + " cells_with_candidates = extract_empty_cells_with_candidates(board, sort_by_difficulty=False)\n", + " if cells_with_candidates:\n", + " cell_hints = []\n", + " for row, col, candidates in cells_with_candidates[:10]:\n", + " nums = \",\".join(str(n) for n in sorted(candidates))\n", + " cell_hints.append(f\"({row},{col})→{nums}\")\n", + " if cell_hints:\n", + " hints = f\"\\n\\nEmpty cells: {' | '.join(cell_hints)}\"\n", + "\n", + " return 
f\"{summary}\\n\\nBoard:\\n{board_only}{tried_moves_hint}{hints}\\n\\nYour move:\"\n", + "\n", + "\n", + "def check_move_targets_empty_cell(move: str, board_str: str) -> bool:\n", + " \"\"\"Check if the move targets an empty cell on the board.\"\"\"\n", + " if not move or not board_str:\n", + " return False\n", + "\n", + " match = re.search(r\"\\[(\\d)\\s+(\\d)\\s+(\\d)\\]\", move)\n", + " if not match:\n", + " return False\n", + "\n", + " row, col = int(match.group(1)), int(match.group(2))\n", + " empty_cells = extract_empty_cells(board_str)\n", + " return (row, col) in empty_cells\n", + "\n", + "\n", + "def extract_feedback(observation) -> dict:\n", + " \"\"\"Extract feedback from environment observation.\"\"\"\n", + " feedback = {\"valid_move\": True, \"got_warning\": False, \"board_state\": \"\"}\n", + "\n", + " if not observation or not observation.messages:\n", + " return feedback\n", + "\n", + " for message in observation.messages:\n", + " content = message.content.lower() if message.content else \"\"\n", + "\n", + " if any(kw in content for kw in [\"invalid\", \"error\", \"cannot\", \"already\", \"violation\", \"lost\"]):\n", + " feedback[\"valid_move\"] = False\n", + " if \"please resubmit\" in content or \"avoid penalties\" in content:\n", + " feedback[\"got_warning\"] = True\n", + "\n", + " if message.content and \"|\" in message.content and \"R1\" in message.content:\n", + " feedback[\"board_state\"] = message.content\n", + "\n", + " return feedback" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a1003f172b954e218fdb00539d79a7d1", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "Oek3JhcWnKhw" }, - "text/plain": [ - "Fetching 2 files: 0%| | 0/2 [00:00 list[float]:\n", + " \"\"\"Reward for targeting empty cells (learn to pick valid positions first).\"\"\"\n", + " rewards = kwargs.get(\"empty_cell_reward\")\n", + " if 
rewards is None:\n", + " return [0.0 for _ in completions]\n", + " return [float(r) for r in rewards]\n", + "\n", + "\n", + "def reward_valid_moves(completions: list[str], **kwargs) -> list[float]:\n", + " \"\"\"Reward for making valid moves.\"\"\"\n", + " rewards = kwargs.get(\"valid_move_reward\")\n", + " if rewards is None:\n", + " return [0.0 for _ in completions]\n", + " return [float(r) for r in rewards]\n", + "\n", + "\n", + "def reward_correct(completions: list[str], **kwargs) -> list[float]:\n", + " \"\"\"Reward for solving the puzzle.\"\"\"\n", + " rewards = kwargs.get(\"correct_reward\")\n", + " if rewards is None:\n", + " return [0.0 for _ in completions]\n", + " return [float(r) for r in rewards]\n", + "\n", + "\n", + "def reward_repetition(completions: list[str], **kwargs) -> list[float]:\n", + " \"\"\"Penalty for repeating moves.\"\"\"\n", + " rewards = kwargs.get(\"repetition_reward\")\n", + " if rewards is None:\n", + " return [0.0 for _ in completions]\n", + " return [float(r) for r in rewards]\n", + "\n", + "\n", + "def reward_progress(completions: list[str], **kwargs) -> list[float]:\n", + " \"\"\"Reward for filling more cells in the board.\"\"\"\n", + " rewards = kwargs.get(\"progress_reward\")\n", + " if rewards is None:\n", + " return [0.0 for _ in completions]\n", + " return [float(r) for r in rewards]" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b54f50facca141639f3412b86cfc433d", - "version_major": 2, - "version_minor": 0 + "cell_type": "markdown", + "metadata": { + "id": "66ZsrLplm07U" }, - "text/plain": [ - "model-00001-of-00002.safetensors: 0%| | 0.00/4.97G [00:00" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Created new run: sergiopaniego-1767361842\n", + "Step 1: [1 1 6]\n", + "Step 2: [2 7 4]\n", + "Step 3: [3 4 8]\n", + "Step 4: 
[3 3 2]\n", + "Step 5: [4 8 6]\n", + "Step 6: [4 8 6]\n", + "Episode: empty_cell=1.00, valid=0.58, repetition=-0.17, progress=0.13 (4 cells), correct=-1.00, tokens=1860\n", + "Step 1: [1 1 7]\n", + "Step 2: [2 4 6]\n", + "Step 3: [1 7 8]\n", + "Step 4: [2 4 6]\n", + "Step 5: [2 4 4]\n", + "Step 6: [2 4 6]\n", + "Step 7: [2 4 6]\n", + "Episode: empty_cell=0.43, valid=0.21, repetition=-1.00, progress=0.10 (3 cells), correct=-1.00, tokens=1866\n", + "Step 1: [1 1 2]\n", + "Step 2: [1 1 2]\n", + "Episode: empty_cell=-1.00, valid=-0.25, repetition=-0.50, progress=0.00 (0 cells), correct=-1.00, tokens=1826\n", + "\n", + "# ... Output truncated for readability (see Trackio dashboard for full logs) ...\n", + "\n", + "Step 1: [1 7 6]\n", + "Step 2: [1 9 2]\n", + "Step 3: [2 6 5]\n", + "Step 4: [2 1 3]\n", + "Step 5: [2 2 2]\n", + "Step 6: [3 1 4]\n", + "Step 7: [3 2 6]\n", + "Step 8: [2 9 7]\n", + "Step 9: [3 4 8]\n", + "Step 10: [3 8 9]\n", + "Step 11: [1 3 5]\n", + "Step 12: [2 9 4]\n", + "Step 13: [4 3 8]\n", + "Step 14: [3 8 4]\n", + "Step 15: [2 9 4]\n", + "Episode: empty_cell=0.60, valid=0.73, repetition=-0.07, progress=0.40 (12 cells), correct=-1.00, tokens=1931\n", + "Step 1: [2 8 1]\n", + "Step 2: [2 5 2]\n", + "Step 3: [3 4 6]\n", + "Step 4: [1 5 1]\n", + "Step 5: [2 6 4]\n", + "Step 6: [3 7 4]\n", + "Step 7: [4 3 6]\n", + "Step 8: [5 1 2]\n", + "Step 9: [1 1 4]\n", + "Step 10: [4 3 6]\n", + "Step 11: [4 3 6]\n", + "Step 12: [4 3 6]\n", + "Step 13: [4 3 6]\n", + "Step 14: [4 1 9]\n", + "Step 15: [4 2 4]\n", + "Step 16: [7 8 5]\n", + "Step 17: [7 8 5]\n", + "Episode: empty_cell=0.53, valid=0.62, repetition=-0.94, progress=0.37 (11 cells), correct=-1.00, tokens=1916\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [30/30 26:13, Epoch 1/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
1-0.113800
2-0.001800
3-0.051300
4-0.012800
50.012200
60.045600
7-0.104800
8-0.093600
90.182400
10-0.027000
110.042300
12-0.052400
13-0.100100
14-0.074400
15-0.105500
160.125200
17-0.016900
180.119100
190.081800
200.003300
210.024400
22-0.038700
230.000000
240.000000
250.000000
260.000000
270.000000
280.000000
290.000000
300.000000

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 1: [1 3 8]\n", + "Step 2: [1 9 2]\n", + "Step 3: [3 4 2]\n", + "Step 4: [3 8 5]\n", + "Step 5: [1 2 4]\n", + "Step 6: [7 5 1]\n", + "Step 7: [7 5 1]\n", + "Episode: empty_cell=0.43, valid=0.64, repetition=-0.14, progress=0.17 (5 cells), correct=-1.00, tokens=1881\n", + "Step 1: [1 2 9]\n", + "Step 2: [1 6 5]\n", + "Step 3: [2 8 9]\n", + "Step 4: [2 9 7]\n", + "Step 5: [3 4 6]\n", + "Step 6: [3 3 3]\n", + "Step 7: [3 6 2]\n", + "Step 8: [3 5 9]\n", + "Step 9: [4 7 8]\n", + "Step 10: [5 2 2]\n", + "Step 11: [5 1 9]\n", + "Step 12: [3 4 6]\n", + "Step 13: [4 7 1]\n", + "Step 14: [5 3 7]\n", + "Step 15: [5 5 6]\n", + "Step 16: [6 8 4]\n", + "Step 17: [6 1 5]\n", + "Step 18: [6 5 7]\n", + "Step 19: [6 6 1]\n", + "Step 20: [7 2 4]\n", + "Step 21: [7 3 8]\n", + "Step 22: [7 6 6]\n", + "Step 23: [7 7 7]\n", + "Step 24: [8 3 5]\n", + "Step 25: [8 5 2]\n", + "Step 26: [8 6 4]\n", + "Step 27: [8 9 1]\n", + "Step 28: [9 1 2]\n", + "Step 29: [9 2 3]\n", + "Step 30: [9 3 9]\n", + "Step 31: [9 4 8]\n", + "Step 32: [9 7 4]\n", + "Episode: empty_cell=0.88, valid=0.92, repetition=-0.03, progress=1.00 (30 cells), correct=1.00, tokens=2035\n", + "\n", + "# ... 
Output truncated for readability (see Trackio dashboard for full logs) ...\n", + "\n", + "Step 1: [3 6 4]\n", + "Step 2: [2 3 7]\n", + "Step 3: [4 9 2]\n", + "Step 4: [5 4 7]\n", + "Step 5: [3 9 8]\n", + "Step 6: [4 6 9]\n", + "Step 7: [5 5 1]\n", + "Step 8: [5 6 2]\n", + "Step 9: [6 3 2]\n", + "Step 10: [6 8 8]\n", + "Step 11: [5 8 5]\n", + "Step 12: [5 2 8]\n", + "Step 13: [5 1 9]\n", + "Step 14: [6 7 6]\n", + "Step 15: [6 5 4]\n", + "Step 16: [4 5 6]\n", + "Step 17: [6 4 3]\n", + "Step 18: [7 4 4]\n", + "Step 19: [7 7 8]\n", + "Step 20: [7 1 3]\n", + "Step 21: [9 2 6]\n", + "Step 22: [2 2 4]\n", + "Step 23: [3 7 7]\n", + "Step 24: [4 1 4]\n", + "Step 25: [4 2 5]\n", + "Step 26: [9 1 8]\n", + "Step 27: [1 2 3]\n", + "Step 28: [2 1 6]\n", + "Step 29: [1 7 4]\n", + "Step 30: [9 7 9]\n", + "Episode: empty_cell=1.00, valid=1.00, repetition=0.00, progress=1.00 (30 cells), correct=1.00, tokens=2028\n", + "Step 1: [3 3 7]\n", + "Step 2: [2 1 9]\n", + "Step 3: [3 4 1]\n", + "Step 4: [3 6 2]\n", + "Step 5: [4 3 1]\n", + "Step 6: [4 2 6]\n", + "Step 7: [4 7 8]\n", + "Step 8: [4 6 5]\n", + "Step 9: [4 8 7]\n", + "Step 10: [3 8 6]\n", + "Step 11: [2 5 8]\n", + "Step 12: [6 5 7]\n", + "Step 13: [6 1 8]\n", + "Step 14: [5 7 3]\n", + "Step 15: [7 6 9]\n", + "Step 16: [7 7 2]\n", + "Step 17: [6 6 3]\n", + "Step 18: [8 2 4]\n", + "Step 19: [8 4 5]\n", + "Step 20: [8 6 1]\n", + "Step 21: [5 6 8]\n", + "Step 22: [5 4 6]\n", + "Step 23: [5 5 1]\n", + "Step 24: [8 5 2]\n", + "Step 25: [9 4 8]\n", + "Step 26: [9 5 6]\n", + "Step 27: [4 5 4]\n", + "Step 28: [1 9 8]\n", + "Step 29: [7 9 7]\n", + "Episode: empty_cell=1.00, valid=1.00, repetition=0.00, progress=1.00 (29 cells), correct=1.00, tokens=2020\n", + "* Run finished. 
Uploading logs to Trackio (please wait...)\n" + ] + } + ], + "source": [ + "trainer_stats = trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gF-mr-gfAtkp" + }, + "source": [ + "Show memory stats after training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qM8tW2pdKyJm" + }, + "outputs": [], + "source": [ + "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "used_memory_for_training = round(used_memory - start_gpu_memory, 3)\n", + "used_percentage = round(used_memory / max_memory * 100, 3)\n", + "training_memory_percentage = round(used_memory_for_training / max_memory * 100, 3)\n", + "\n", + "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", + "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", + "print(f\"Peak reserved memory = {used_memory} GB.\")\n", + "print(f\"Peak reserved memory for training = {used_memory_for_training} GB.\")\n", + "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", + "print(f\"Peak reserved memory for training % of max memory = {training_memory_percentage} %.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BZj4IG9ZBAix" + }, + "source": [ + "In this step, the fine-tuned model is saved locally and uploaded to the Hugging Face Hub using the configured account credentials." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xJV-NZTmKyJm" + }, + "outputs": [], + "source": [ + "client.close()\n", + "trainer.save_model(output_dir)\n", + "trainer.push_to_hub()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "X-6lB52GAl_u" + }, + "source": [ + "## Load the Fine-Tuned Model and Run Inference\n", + "\n", + "Now let's test our fine-tuned model by loading the **adapter** and running **inference**. 
\n", + "We begin by loading the **base model**, attaching the adapter, and obtaining the final fine-tuned model ready for evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "referenced_widgets": [ + "d686d3933bef4ea3a9fb58193495e970", + "4ce63e90903a4f60be6694de976e7127", + "a1003f172b954e218fdb00539d79a7d1", + "859e82d390204d5d8a763bd61b356ae4", + "b54f50facca141639f3412b86cfc433d", + "ec156a0cf90a42059f2335d4bae0628e", + "fd9f8dcdf82d4e39a9c72d0e25464c28", + "8696f66460244e55b9d67fd2fe9d6e51", + "6acb03abb643408b8f802704f00674f5", + "e13e54f756874f78af67a25564d64375", + "5e29181a5ec1421d9f2eefbf7529d363", + "fa420825fdab47d6aea18cd352dd9ef1", + "9a460882ac834c8a90d77eec9a2c34ea", + "7bf3166759f64fcf8968a6d282f8df85" + ] + }, + "id": "-Vu--VueKyJm", + "outputId": "399ccb1e-45bf-4305-ae9d-97edece48b53" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d686d3933bef4ea3a9fb58193495e970", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "config.json: 0.00B [00:00, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4ce63e90903a4f60be6694de976e7127", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "model.safetensors.index.json: 0.00B [00:00, ?B/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a1003f172b954e218fdb00539d79a7d1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 2 files: 0%| | 0/2 [00:00 str: return "noop()" -def rollout_once( +async def rollout_once( trainer: GRPOTrainer, env: BrowserGymEnv, tokenizer: AutoTokenizer, @@ -355,7 +355,7 @@ def rollout_once( debug: bool = False, ) -> dict[str, list]: """Run one episode and collect training data.""" - result = env.reset() + result = await 
env.reset() observation = result.observation prompt_ids: list[int] = [] @@ -421,7 +421,7 @@ def rollout_once( print(f"Step {step_num + 1}: {action_str}") # Take action in environment - result = env.step(BrowserGymAction(action_str=action_str)) + result = await env.step(BrowserGymAction(action_str=action_str)) observation = result.observation # Track rewards @@ -531,7 +531,8 @@ def main() -> None: grpo_config.run_name = args.run_name or f"run-{timestamp}" grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}" - def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + async def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + """Async rollout function - TRL handles the event loop automatically.""" episode_prompt_ids: list[list[int]] = [] episode_completion_ids: list[list[int]] = [] episode_logprobs: list[list[float]] = [] @@ -542,7 +543,7 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: for i, prompt_text in enumerate(prompts): print(f"[DEBUG] Processing prompt {i + 1}/{len(prompts)}") - episode = rollout_once( + episode = await rollout_once( trainer=trainer, env=client, tokenizer=tokenizer, diff --git a/examples/scripts/openenv/browsergym_llm.py b/examples/scripts/openenv/browsergym_llm.py index 0f57fd23f3b..f4005883f87 100644 --- a/examples/scripts/openenv/browsergym_llm.py +++ b/examples/scripts/openenv/browsergym_llm.py @@ -309,7 +309,7 @@ def parse_action(response_text: str) -> str: return "noop()" -def rollout_once( +async def rollout_once( trainer: GRPOTrainer, env: BrowserGymEnv, tokenizer: AutoTokenizer, @@ -318,7 +318,7 @@ def rollout_once( debug: bool = False, ) -> dict[str, list]: """Run one episode and collect training data (text-only, no screenshots).""" - result = env.reset() + result = await env.reset() observation = result.observation prompt_ids: list[int] = [] @@ -364,7 +364,7 @@ def rollout_once( print(f"Step {step_num + 1}: {action_str}") # 
Take action in environment - result = env.step(BrowserGymAction(action_str=action_str)) + result = await env.step(BrowserGymAction(action_str=action_str)) observation = result.observation # Track rewards @@ -451,7 +451,8 @@ def main() -> None: grpo_config.run_name = args.run_name or f"run-{timestamp}" grpo_config.project = args.project or f"group-{sanitize_name(args.model_id)}" - def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + async def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + """Async rollout function - TRL handles the event loop automatically.""" episode_prompt_ids: list[list[int]] = [] episode_completion_ids: list[list[int]] = [] episode_logprobs: list[list[float]] = [] @@ -463,7 +464,7 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: for i, prompt_text in enumerate(prompts): if args.debug: print(f"[DEBUG] Processing prompt {i + 1}/{len(prompts)}") - episode = rollout_once( + episode = await rollout_once( trainer=trainer, env=client, tokenizer=trainer.processing_class, diff --git a/examples/scripts/openenv/sudoku.py b/examples/scripts/openenv/sudoku.py index 5f5ec44b8d1..871ab248aa3 100644 --- a/examples/scripts/openenv/sudoku.py +++ b/examples/scripts/openenv/sudoku.py @@ -460,7 +460,7 @@ def extract_feedback(observation) -> dict: # --------------------------------------------------------------------------- -def rollout_once( +async def rollout_once( trainer: GRPOTrainer, env: TextArenaEnv, tokenizer: AutoTokenizer, @@ -470,7 +470,7 @@ def rollout_once( difficulty: str = "hard", api_delay: float = 0.0, ) -> dict[str, list]: - result = env.reset() + result = await env.reset() time.sleep(api_delay) # Avoid rate limiting observation = result.observation @@ -552,7 +552,7 @@ def rollout_once( print(f"EXTRACTED MOVE: {move}") # Step environment - result = env.step(TextArenaAction(message=move)) + result = await env.step(TextArenaAction(message=move)) 
time.sleep(api_delay) # Avoid rate limiting observation = result.observation correct_score = float(result.reward or 0.0) @@ -774,7 +774,8 @@ def main() -> None: grpo_config.trackio_space_id = args.trackio_space_id grpo_config.gradient_checkpointing = args.gradient_checkpointing - def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + async def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + """Async rollout function - TRL handles the event loop automatically.""" all_prompt_ids = [] all_completion_ids = [] all_logprobs = [] @@ -785,7 +786,7 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: all_progress = [] for _ in prompts: - episode = rollout_once( + episode = await rollout_once( trainer=trainer, env=client, tokenizer=trainer.processing_class, diff --git a/examples/scripts/openenv/wordle.py b/examples/scripts/openenv/wordle.py index df1fa529e5d..3e91055099e 100644 --- a/examples/scripts/openenv/wordle.py +++ b/examples/scripts/openenv/wordle.py @@ -314,7 +314,7 @@ def make_user_prompt(prompt_text: str, messages: Iterable[TextArenaMessage]) -> return f"Conversation so far:\n{history_section}\n\nReply with your next guess enclosed in square brackets." 
-def rollout_once( +async def rollout_once( trainer: GRPOTrainer, env: TextArenaEnv, tokenizer: AutoTokenizer, @@ -323,7 +323,7 @@ def rollout_once( max_turns: int, max_new_tokens: int = 16, ) -> dict[str, list]: - result = env.reset() + result = await env.reset() observation = result.observation prompt_ids: list[int] = [] @@ -396,7 +396,7 @@ def rollout_once( guess = extract_guess(completion_text) model_outputs.append(completion_text.strip()) # Store raw model output for format reward - result = env.step(TextArenaAction(message=guess)) + result = await env.step(TextArenaAction(message=guess)) raw_rewards.append(float(result.reward or 0.0)) observation = result.observation @@ -550,7 +550,8 @@ def main() -> None: grpo_config.project = args.project or f"wordle-grpo-{sanitize_name(args.model_id)}-{timestamp}" grpo_config.trackio_space_id = args.trackio_space_id - def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + async def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: + """Async rollout function - TRL handles the event loop automatically.""" episode_prompt_ids: list[list[int]] = [] episode_completion_ids: list[list[int]] = [] episode_logprobs: list[list[float]] = [] @@ -560,7 +561,7 @@ def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]: format_rewards: list[float] = [] for prompt_text in prompts: - episode = rollout_once( + episode = await rollout_once( trainer=trainer, env=client, tokenizer=tokenizer, diff --git a/trl/generation/vllm_generation.py b/trl/generation/vllm_generation.py index 489e817eb0a..262e0e967ca 100644 --- a/trl/generation/vllm_generation.py +++ b/trl/generation/vllm_generation.py @@ -14,6 +14,8 @@ """vLLM-based generation backend for TRL trainers.""" +import asyncio +import inspect import json import os from collections.abc import Callable @@ -52,6 +54,45 @@ import bitsandbytes as bnb +# Persistent event loop for scripts - preserves async resources (e.g., 
# websockets) across calls
_async_rollout_loop = None


def run_async_safely(coro):
    """
    Run a coroutine to completion from synchronous code and return its result.

    Works from any context (notebooks, Colab, scripts, etc.):

    - If an event loop is already running in this thread: applies `nest_asyncio` (must be installed) and runs the
      coroutine with `asyncio.run`.
    - If no event loop is running: runs the coroutine on a module-level persistent loop, which is created on first
      use and re-created if it has been closed, so async resources are preserved across calls.

    Args:
        coro (coroutine):
            Coroutine object to execute, e.g. the result of calling an `async def` rollout function.

    Returns:
        The value returned by `coro`.

    Raises:
        RuntimeError:
            If an event loop is already running and `nest_asyncio` is not installed. The underlying `ImportError`
            is attached as `__cause__`.
    """
    global _async_rollout_loop

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    else:
        loop_is_running = True

    if loop_is_running:
        # Already in async context (e.g., notebook) - apply nest_asyncio automatically
        try:
            import nest_asyncio
        except ImportError as err:
            # The coroutine will never be scheduled on this path; close it so Python does not emit a
            # "coroutine ... was never awaited" RuntimeWarning on top of the error raised below.
            if inspect.iscoroutine(coro):
                coro.close()
            raise RuntimeError(
                "An event loop is already running (e.g., in a notebook). "
                "Please install nest_asyncio (`pip install nest_asyncio`) to enable async rollout functions."
            ) from err

        # NOTE(review): this relies on nest_asyncio patching asyncio so that asyncio.run may be nested inside the
        # running loop, as described in the nest_asyncio documentation.
        nest_asyncio.apply()
        return asyncio.run(coro)

    # No running loop - use a persistent loop to preserve async resources
    if _async_rollout_loop is None or _async_rollout_loop.is_closed():
        _async_rollout_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(_async_rollout_loop)

    return _async_rollout_loop.run_until_complete(coro)


# Unchanged diff context that follows this addition in trl/generation/vllm_generation.py:
# class VLLMGeneration:
#     """Handles vLLM-based generation for trainers.
@@ -547,7 +588,11 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte apply_chat_template({"prompt": p}, processing_class, **chat_template_kwargs)["prompt"] for p in rollout_prompts ] - output = rollout_func(rollout_prompts) + # Support both sync and async rollout functions: + if inspect.iscoroutinefunction(rollout_func): + output = run_async_safely(rollout_func(rollout_prompts)) + else: + output = rollout_func(rollout_prompts) else: if is_conversational({"prompt": ordered_set_of_prompts[0]}): output = self.vllm_client.chat( @@ -603,7 +648,13 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte apply_chat_template({"prompt": prompt}, processing_class, **chat_template_kwargs)["prompt"] for prompt in rollout_prompts ] - output = rollout_func(rollout_prompts) + # Support both sync and async rollout functions: + if inspect.iscoroutinefunction(rollout_func): + # Handle async rollout_func + output = run_async_safely(rollout_func(rollout_prompts)) + else: + # Handle sync rollout_func + output = rollout_func(rollout_prompts) required_keys = {"prompt_ids", "completion_ids", "logprobs"} extra_fields = {k: v for k, v in output.items() if k not in required_keys} prompt_ids = output["prompt_ids"] diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 297c664237e..309bbe48ac5 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -640,9 +640,15 @@ def cast_outputs_to_original_dtype(module, args, output): # Wrap rollout_func to capture trainer context if provided rollout_func = None if self.rollout_func is not None: - - def rollout_func(prompts): - return self.rollout_func(prompts, self) + # Check if the user's rollout_func is async + if inspect.iscoroutinefunction(self.rollout_func): + # Wrap async function to pass trainer context + async def rollout_func(prompts): + return await self.rollout_func(prompts, self) + else: + # Wrap sync function to pass trainer context + def 
rollout_func(prompts): + return self.rollout_func(prompts, self) self.vllm_generation = VLLMGeneration( model=self.model,