diff --git a/environments/rlm_secrets/README.md b/environments/rlm_secrets/README.md index 5762587b7..5048cc318 100644 --- a/environments/rlm_secrets/README.md +++ b/environments/rlm_secrets/README.md @@ -72,7 +72,6 @@ Both reward functions have equal weight (0.5 each): | Parameter | Default | Description | |-----------|---------|-------------| | `num_train_examples` | 100 | Training puzzles | -| `num_eval_examples` | 20 | Evaluation puzzles | | `num_files` | 4 | Files per puzzle | | `max_turns` | 50 | Max REPL iterations | | `sub_tool_max_turns` | 3 | Max tool turns for sub-LLMs | @@ -80,6 +79,9 @@ Both reward functions have equal weight (0.5 each): | `code_execution_timeout` | 120 | Bash execution timeout (seconds) | | `**kwargs` | - | Passed on `RLMEnv.__init__` | +Note: The eval dataset is not built separately. For evaluation, re-instantiate the +environment with a different `seed` to generate a new synthetic split. + ## Why This Environment? This environment is specifically designed to test RLM capabilities: diff --git a/environments/rlm_secrets/rlm_secrets.py b/environments/rlm_secrets/rlm_secrets.py index daa038f5c..7d4d0a55f 100644 --- a/environments/rlm_secrets/rlm_secrets.py +++ b/environments/rlm_secrets/rlm_secrets.py @@ -318,6 +318,7 @@ def build_dataset( Dataset with prompt, answer, and info columns """ rows = [] + task_name = "rlm-secrets" for i in range(num_examples): puzzle = generate_puzzle(num_files=num_files) @@ -359,9 +360,11 @@ def build_dataset( rows.append( { + "example_id": i, "prompt": prompt, "answer": str(puzzle["correct_position"]), "info": {"puzzle": puzzle}, + "task": task_name, } ) @@ -443,7 +446,6 @@ async def correct_filesystem_state(state: State) -> float: def load_environment( num_train_examples: int = 100, - num_eval_examples: int = 20, num_files: int = 4, max_turns: int = 50, seed: int | None = None, @@ -458,7 +460,6 @@ def load_environment( Args: num_train_examples: Number of training puzzle instances - num_eval_examples: Number of evaluation puzzle instances num_files: Number of files per puzzle (default: 4) max_turns: Maximum REPL iterations (default: 50) seed: Random seed for dataset generation @@ -477,11 +478,6 @@ def load_environment( num_files=num_files, ) - eval_dataset = build_dataset( - num_examples=num_eval_examples, - num_files=num_files, - ) - rubric = vf.Rubric( funcs=[correct_answer, correct_filesystem_state], weights=[0.5, 0.5], @@ -489,7 +485,6 @@ def load_environment( return RLMSecretsEnv( dataset=train_dataset, - eval_dataset=eval_dataset, num_files=num_files, repl_language=repl_language, rubric=rubric, diff --git a/tests/test_rlm_env.py b/tests/test_rlm_env.py index 9c4a85274..38897afcf 100644 --- a/tests/test_rlm_env.py +++ b/tests/test_rlm_env.py @@ -1036,6 +1036,7 @@ class TestSubLLMRequestPaths: async def test_interleaved_uses_tokens_endpoint(self, rlm_env): mock_client = MagicMock() mock_response = MagicMock() + mock_response.choices = [MagicMock()] mock_client.post = AsyncMock(return_value=mock_response) mock_client.chat.completions.create = AsyncMock() @@ -1066,6 +1067,103 @@ async def test_interleaved_uses_tokens_endpoint(self, rlm_env): assert "max_tokens" not in body mock_client.chat.completions.create.assert_not_called() + @pytest.mark.asyncio + async def test_sub_llm_normalizes_messages(self, rlm_env): + mock_client = MagicMock() + mock_message = MagicMock() + mock_message.tool_calls = None + mock_message.content = "ok" + mock_response = MagicMock() + mock_response.choices = [MagicMock(message=mock_message)] + 
mock_client.chat.completions.create = AsyncMock(return_value=mock_response) + + rlm_env.interleaved_rollouts = False + messages = [ + {"role": "user", "content": {"type": "text", "text": "hello"}}, + {"role": "user", "content": {"role": "user", "content": "inner"}}, + ] + state = {} + + await rlm_env._call_sub_llm_api(state, mock_client, "gpt-4", messages) + + args, kwargs = mock_client.chat.completions.create.call_args + assert args == () + sent_messages = kwargs["messages"] + assert sent_messages[0]["content"] == [{"type": "text", "text": "hello"}] + assert sent_messages[1]["content"] == "inner" + + @pytest.mark.asyncio + async def test_interleaved_sub_llm_uses_incremental_prompt_ids( + self, rlm_env_with_sub_tools + ): + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock() + + mock_tool_call = MagicMock() + mock_tool_call.id = "call_1" + mock_tool_call.function.name = "sample_tool" + mock_tool_call.function.arguments = '{"x": 2, "y": 3}' + + mock_message1 = MagicMock() + mock_message1.tool_calls = [mock_tool_call] + mock_message1.content = None + + mock_message2 = MagicMock() + mock_message2.tool_calls = None + mock_message2.content = "done" + + response1 = MagicMock() + response1.choices = [MagicMock(message=mock_message1)] + response2 = MagicMock() + response2.choices = [MagicMock(message=mock_message2)] + + mock_client.post = AsyncMock(side_effect=[response1, response2]) + + rlm_env_with_sub_tools.interleaved_rollouts = True + messages = [{"role": "user", "content": "Add 2 and 3"}] + state = {"sampling_args": {"max_tokens": 7}} + + token_payload = { + "prompt_ids": [1], + "prompt_mask": [0], + "completion_ids": [2], + "completion_mask": [1], + "completion_logprobs": [0.0], + "overlong_prompt": False, + "is_truncated": False, + } + + with ( + patch( + "verifiers.envs.experimental.rlm_env.tokenize_vllm", + new=AsyncMock(return_value=[1, 2, 3]), + ) as mock_tokenize, + patch( + "verifiers.envs.experimental.rlm_env.get_prompt_ids", + new=AsyncMock(return_value=[4, 5, 6]), + ) as mock_get_prompt_ids, + patch( + "verifiers.envs.experimental.rlm_env.parse_response_tokens", + new=AsyncMock(return_value=token_payload), + ), + patch( + "verifiers.envs.experimental.rlm_env.parse_response_messages", + new=AsyncMock(return_value=[{"role": "assistant", "content": "ok"}]), + ), + patch( + "verifiers.envs.experimental.rlm_env.parse_is_truncated", + new=AsyncMock(return_value=False), + ), + ): + await rlm_env_with_sub_tools._run_sub_llm( + state, mock_client, "gpt-4", messages + ) + + assert mock_client.post.await_count == 2 + mock_tokenize.assert_awaited_once() + mock_get_prompt_ids.assert_awaited_once() + mock_client.chat.completions.create.assert_not_called() + # ============================================================================= # 8. 
Root Tool Serialization (pickle) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index e8842dba4..8e324f5ab 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -456,6 +456,40 @@ def normalize_sampling_args(sampling_args: SamplingArgs) -> SamplingArgs: sampling_args.pop("max_completion_tokens") return {k: v for k, v in sampling_args.items() if v is not None} + client, model, oai_tools, sampling_args, message_type = resolve_optional_args( + client, model, oai_tools, sampling_args, message_type + ) + sampling_args = normalize_sampling_args(sampling_args) + if self.interleaved_rollouts: + sampling_args = prepare_sampling_args_for_token_prompts(sampling_args) + + prompt_ids: list[int] | None = None + if self.interleaved_rollouts and len(state["trajectory"]) > 0: + prompt_ids = await get_prompt_ids(state, prompt, client) + + return await self._call_model_api( + client=client, + model=model, + prompt=prompt, + oai_tools=oai_tools, + sampling_args=sampling_args, + message_type=message_type, + prompt_ids=prompt_ids, + ) + + async def _call_model_api( + self, + *, + client: AsyncOpenAI, + model: str, + prompt: Messages, + oai_tools: list[ChatCompletionToolParam] | None, + sampling_args: SamplingArgs, + message_type: MessageType, + prompt_ids: list[int] | None = None, + ) -> ModelResponse: + """Shared low-level model call used by main and sub-LLM paths.""" + def handle_overlong_prompt(func): """Decorator to handle overlong prompt errors from the model API.""" @@ -487,7 +521,7 @@ async def wrapper(*args, **kwargs): return wrapper @handle_overlong_prompt - async def get_model_response_with_messages( + async def call_with_messages( client: AsyncOpenAI, model: str, prompt: Messages, @@ -547,7 +581,7 @@ async def get_model_response_with_messages( return response @handle_overlong_prompt - async def get_model_response_with_tokens( + async def call_with_tokens( client: AsyncOpenAI, model: str, prompt: Messages, @@ -581,16 +615,8 @@ async def get_model_response_with_tokens( cast_to=ChatCompletion, ) - client, model, oai_tools, sampling_args, message_type = resolve_optional_args( - client, model, oai_tools, sampling_args, message_type - ) - sampling_args = normalize_sampling_args(sampling_args) - if self.interleaved_rollouts: - sampling_args = prepare_sampling_args_for_token_prompts(sampling_args) - - if self.interleaved_rollouts and len(state["trajectory"]) > 0: - prompt_ids = await get_prompt_ids(state, prompt, client) - response = await get_model_response_with_tokens( + if prompt_ids is not None: + response = await call_with_tokens( client=client, model=model, prompt=prompt, @@ -600,7 +626,7 @@ async def get_model_response_with_tokens( message_type=message_type, ) else: - response = await get_model_response_with_messages( + response = await call_with_messages( client=client, model=model, prompt=prompt, diff --git a/verifiers/envs/experimental/rlm_env.py b/verifiers/envs/experimental/rlm_env.py index ec516865d..b066dd919 100644 --- a/verifiers/envs/experimental/rlm_env.py +++ b/verifiers/envs/experimental/rlm_env.py @@ -51,14 +51,18 @@ from typing import TypedDict from aiohttp import web -from openai.types.chat import ChatCompletion, ChatCompletionFunctionToolParam +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletionFunctionToolParam from prime_tunnel import Tunnel import verifiers as vf from verifiers.types import ( + ChatCompletionToolParam, ChatMessage, ChatMessages, Messages, + MessageType, ModelResponse, + SamplingArgs, 
State, TrajectoryStep, ) @@ -73,6 +77,7 @@ from verifiers.utils.tool_utils import convert_func_to_oai_tool from verifiers.utils.token_utils import ( prepare_sampling_args_for_token_prompts, + get_prompt_ids, tokenize_vllm, ) from verifiers.utils.sandbox_exec_utils import SandboxExecutorMixin @@ -3078,52 +3083,38 @@ async def _call_sub_llm_api( model: str, messages: ChatMessages, tools: list | None = None, + *, + sub_state: State | None = None, ) -> Any | None: """Make a single sub-LLM API call matching main-model request mode.""" normalized_messages = self._normalize_message_content(messages) sampling_args = self._prepare_sub_llm_sampling_args( state, interleaved=self.interleaved_rollouts ) - payload: dict[str, Any] = { - "model": model, - "messages": normalized_messages, - "tools": tools, - } try: + prompt_ids: list[int] | None = None if self.interleaved_rollouts: - extra_body = sampling_args.pop("extra_body", {}) - prompt_ids = await tokenize_vllm( + if sub_state is not None and sub_state.get("trajectory"): + prompt_ids = await get_prompt_ids( + sub_state, normalized_messages, client + ) + else: + prompt_ids = await tokenize_vllm( + client=client, + messages=normalized_messages, + tools=tools, + model=model, + ) + return await asyncio.wait_for( + self._call_model_api( client=client, - messages=normalized_messages, - tools=tools, model=model, - ) - payload = { - "model": model, - "messages": normalized_messages, - "tools": tools, - "tokens": prompt_ids, - **sampling_args, - **extra_body, - } - return await asyncio.wait_for( - client.post( - "/chat/completions/tokens", - body=payload, - cast_to=ChatCompletion, - ), - timeout=self.sub_llm_api_timeout, - ) - payload = { - "model": model, - "messages": normalized_messages, - "tools": tools, - **sampling_args, - } - return await asyncio.wait_for( - client.chat.completions.create( - **payload, + prompt=normalized_messages, + oai_tools=tools, + sampling_args=sampling_args, + message_type="chat", + prompt_ids=prompt_ids, ), timeout=self.sub_llm_api_timeout, ) @@ -3158,6 +3149,19 @@ async def _run_sub_llm( self, state: State, client: Any, model: str, messages: ChatMessages ) -> SubLLMResult: """Run a sub-LLM call, with optional tool-calling loop.""" + sub_state: State | None = None + if self.interleaved_rollouts: + # Track a minimal sub-LLM trajectory so get_prompt_ids() can compute + # incremental prompt_ids (same interleaving strategy as the main LLM). + # This sub_state is only for tokenization continuity and is not added + # to the main trajectory or used for scoring. 
+ sub_state = State() + sub_state["trajectory"] = [] + sub_state["client"] = client + sub_state["model"] = model + sub_state["oai_tools"] = self.sub_oai_tools or [] + sub_state["sampling_args"] = state.get("sampling_args") + # Fast path: no tools configured - single LLM call if not self.sub_tools: response = await self._call_sub_llm_api(state, client, model, messages) @@ -3192,10 +3196,16 @@ async def _run_sub_llm( for _ in range(self.sub_tool_max_turns): num_turns += 1 - prompt_snapshot = [cast(ChatMessage, dict(m)) for m in current_messages] + normalized_messages = self._normalize_message_content(current_messages) + prompt_snapshot = [cast(ChatMessage, dict(m)) for m in normalized_messages] response = await self._call_sub_llm_api( - state, client, model, current_messages, tools + state, + client, + model, + normalized_messages, + tools, + sub_state=sub_state, ) if response is None: return self._make_timeout_result( @@ -3206,6 +3216,32 @@ async def _run_sub_llm( num_turns, ) + if sub_state is not None: + tokens = await parse_response_tokens(response, "chat", self.max_seq_len) + if tokens is None: + sub_state = None + else: + completion_messages = await parse_response_messages( + response, "chat" + ) + response_is_truncated = await parse_is_truncated(response, "chat") + is_truncated = response_is_truncated or bool( + tokens.get("is_truncated") + ) + sub_state["trajectory"].append( + TrajectoryStep( + prompt=cast(Messages, prompt_snapshot), + completion=completion_messages, + response=response, + tokens=tokens, + reward=None, + advantage=None, + is_truncated=is_truncated, + trajectory_id="sub_llm_local", + extras={"is_sub_llm_call": True, "sub_state_only": True}, + ) + ) + prompt_tokens, completion_tokens = _extract_tokens_from_response(response) total_prompt_tokens += prompt_tokens total_completion_tokens += completion_tokens @@ -3260,8 +3296,15 @@ async def _run_sub_llm( ) ) - prompt_snapshot = [cast(ChatMessage, dict(m)) for m in current_messages] - response = await self._call_sub_llm_api(state, client, model, current_messages) + normalized_messages = self._normalize_message_content(current_messages) + prompt_snapshot = [cast(ChatMessage, dict(m)) for m in normalized_messages] + response = await self._call_sub_llm_api( + state, + client, + model, + normalized_messages, + sub_state=sub_state, + ) if response is None: return self._make_timeout_result( turns, @@ -4162,6 +4205,124 @@ async def add_trajectory_step(self, state: State, trajectory_step: TrajectorySte # MultiTurnEnv Interface # ========================================================================= + async def get_model_response( + self, + state: State, + prompt: Messages, + client: AsyncOpenAI | None = None, + model: str | None = None, + oai_tools: list[ChatCompletionToolParam] | None = None, + sampling_args: SamplingArgs | None = None, + message_type: MessageType | None = None, + ) -> ModelResponse: + """ + Override to keep interleaved prompt_id computation scoped to main-LLM steps. + + RLMEnv injects sub-LLM turns into state["trajectory"] for training. If we + feed that mixed trajectory into get_prompt_ids, the "previous turn" + becomes a sub-LLM step and the main prompt looks unrelated. That produces + an empty env_response and breaks /tokenize. We avoid that by filtering + trajectory steps to the main trajectory_id only. 
+ """ + + def resolve_optional_args( + client: AsyncOpenAI | None, + model: str | None, + oai_tools: list[ChatCompletionToolParam] | None, + sampling_args: SamplingArgs | None, + message_type: MessageType | None, + ) -> tuple[ + AsyncOpenAI, + str, + list[ChatCompletionToolParam] | None, + SamplingArgs, + MessageType, + ]: + client = client or state["client"] + model = model or state["model"] + assert client is not None and model is not None + oai_tools = oai_tools or state["oai_tools"] + sampling_args = cast( + SamplingArgs, sampling_args or state["sampling_args"] or {} + ) + message_type = message_type or self.message_type + return client, model, oai_tools, sampling_args, message_type + + def normalize_sampling_args(sampling_args: SamplingArgs) -> SamplingArgs: + if "max_tokens" in sampling_args: + if sampling_args["max_tokens"] is None: + sampling_args.pop("max_tokens") + elif message_type == "chat": + sampling_args["max_completion_tokens"] = sampling_args.pop( + "max_tokens" + ) + if ( + "max_completion_tokens" in sampling_args + and sampling_args["max_completion_tokens"] is None + ): + sampling_args.pop("max_completion_tokens") + return {k: v for k, v in sampling_args.items() if v is not None} + + client, model, oai_tools, sampling_args, message_type = resolve_optional_args( + client, model, oai_tools, sampling_args, message_type + ) + sampling_args = normalize_sampling_args(sampling_args) + if self.interleaved_rollouts: + sampling_args = prepare_sampling_args_for_token_prompts(sampling_args) + + prompt_ids: list[int] | None = None + if ( + self.interleaved_rollouts + and message_type == "chat" + and len(state["trajectory"]) > 0 + ): + main_trajectory_id = state.get("trajectory_id") + main_steps = [ + step + for step in state["trajectory"] + if step.get("trajectory_id") == main_trajectory_id + ] + if main_steps: + # Do not mutate the original state; build a minimal view for + # prompt_id computation that excludes sub-LLM turns. + prompt_state = State() + prompt_state["trajectory"] = main_steps + prompt_state["client"] = client + prompt_state["model"] = model + prompt_state["oai_tools"] = oai_tools or [] + # Reuse cached suffix ids if available to avoid extra tokenization. + if "_cached_suffix_ids" in state: + prompt_state["_cached_suffix_ids"] = state["_cached_suffix_ids"] + # Keep sub-LLM debug context for empty-env-response logging. + if "_last_sub_llm_root_call" in state: + prompt_state["_last_sub_llm_root_call"] = state[ + "_last_sub_llm_root_call" + ] + prompt_ids = await get_prompt_ids(prompt_state, prompt, client) + # If suffix ids were computed on the temporary state, persist them + # on the main state to avoid re-tokenizing every turn. + if "_cached_suffix_ids" in prompt_state: + state["_cached_suffix_ids"] = prompt_state["_cached_suffix_ids"] + else: + # If no main steps are present (should be rare), fall back to + # full-tokenize so we still use the tokens endpoint. + prompt_ids = await tokenize_vllm( + client=client, + messages=prompt, + tools=oai_tools, + model=model, + ) + + return await self._call_model_api( + client=client, + model=model, + prompt=prompt, + oai_tools=oai_tools, + sampling_args=sampling_args, + message_type=message_type, + prompt_ids=prompt_ids, + ) + async def get_prompt_messages(self, state: State) -> Messages: """Build prompt messages, adding system prompt with tool docs on first turn.""" if len(state["trajectory"]) == 0:
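A minimal usage sketch for the README change above, which removes the built-in eval split (illustrative only, not part of this diff). It assumes the `load_environment` signature shown in environments/rlm_secrets/rlm_secrets.py (`num_train_examples`, `num_files`, `seed`); the import path below is hypothetical.

# Hypothetical sketch: build separate train and eval splits by re-instantiating
# the environment with different seeds, since eval_dataset is no longer built
# inside load_environment().
from environments.rlm_secrets.rlm_secrets import load_environment  # hypothetical import path

train_env = load_environment(num_train_examples=100, num_files=4, seed=0)
eval_env = load_environment(num_train_examples=20, num_files=4, seed=1)

# Each call generates its own synthetic puzzles; using distinct seeds keeps the
# two splits from being generated identically.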