Skip to content
Open
25 changes: 13 additions & 12 deletions examples/notebooks/grpo_functiongemma_browsergym_openenv.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@
"\n",
"max_steps=10\n",
"\n",
"def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:\n",
"async def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:\n",
" \"\"\"Async rollout function - TRL handles the event loop automatically.\"\"\"\n",
" episode_prompt_ids: list[list[int]] = []\n",
" episode_completion_ids: list[list[int]] = []\n",
" episode_logprobs: list[list[float]] = []\n",
Expand All @@ -206,7 +207,7 @@
"\n",
" for i, prompt_text in enumerate(prompts):\n",
" print(f\"[DEBUG] Processing prompt {i + 1}/{len(prompts)}\")\n",
" episode = rollout_once(\n",
" episode = await rollout_once(\n",
" trainer=trainer,\n",
" env=client,\n",
" tokenizer=trainer.processing_class,\n",
Expand Down Expand Up @@ -261,15 +262,15 @@
"from browsergym_env import BrowserGymAction\n",
"from transformers import AutoTokenizer\n",
"\n",
"def rollout_once(\n",
"async def rollout_once(\n",
" trainer: GRPOTrainer,\n",
" env: BrowserGymEnv,\n",
" tokenizer: AutoTokenizer,\n",
" dataset_prompt: str,\n",
" max_steps: int,\n",
") -> dict[str, list]:\n",
" \"\"\"Run one episode and collect training data (text-only, no screenshots).\"\"\"\n",
" result = env.reset()\n",
" result = await env.reset()\n",
" observation = result.observation\n",
"\n",
" prompt_ids: list[int] = []\n",
Expand Down Expand Up @@ -314,7 +315,7 @@
" print(f\"Step {step_num + 1}: {action_str}\")\n",
"\n",
" # Take action in environment\n",
" result = env.step(BrowserGymAction(action_str=action_str))\n",
" result = await env.step(BrowserGymAction(action_str=action_str))\n",
" observation = result.observation\n",
"\n",
" # Track rewards\n",
Expand Down Expand Up @@ -546,7 +547,7 @@
},
"outputs": [
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"/tmp/ipython-input-3830121904.py:1: UserWarning: You are importing from 'rollout_func', which is an experimental feature. This API may change or be removed at any time without prior notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.\n",
Expand All @@ -570,7 +571,7 @@
"output_type": "display_data"
},
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 4/4 [00:00<00:00, 19.64it/s]\n"
Expand All @@ -596,7 +597,7 @@
},
"outputs": [
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': 2, 'pad_token_id': 0}.\n"
Expand Down Expand Up @@ -678,7 +679,7 @@
]
},
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:liger_kernel.transformers.model.gemma3:It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.\n",
Expand Down Expand Up @@ -1608,7 +1609,7 @@
"output_type": "display_data"
},
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"No files have been modified since last commit. Skipping to prevent empty commit.\n",
Expand Down Expand Up @@ -1700,7 +1701,7 @@
"output_type": "display_data"
},
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"No files have been modified since last commit. Skipping to prevent empty commit.\n",
Expand All @@ -1716,7 +1717,7 @@
"CommitInfo(commit_url='https://huggingface.co/sergiopaniego/browsergym-grpo-functiongemma-270m-it/commit/a17de133c28ca7fddfcb2694c32f2791de5ddbe6', commit_message='End of training', commit_description='', oid='a17de133c28ca7fddfcb2694c32f2791de5ddbe6', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sergiopaniego/browsergym-grpo-functiongemma-270m-it', endpoint='https://huggingface.co', repo_type='model', repo_id='sergiopaniego/browsergym-grpo-functiongemma-270m-it'), pr_revision=None, pr_num=None)"
]
},
"execution_count": 12,
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
Expand Down
Loading
Loading