15 changes: 0 additions & 15 deletions docs/reference.md
@@ -209,21 +209,6 @@ class RolloutScores(TypedDict):
metrics: dict[str, list[float]]
```

### ProcessedOutputs

```python
class ProcessedOutputs(TypedDict):
prompt_ids: list[list[int]]
prompt_mask: list[list[int]]
completion_ids: list[list[int]]
completion_mask: list[list[int]]
completion_logprobs: list[list[float]]
rewards: list[float]
is_truncated: list[bool]
```

Tokenized outputs for training.

---

## Classes
129 changes: 35 additions & 94 deletions verifiers/envs/environment.py
@@ -10,7 +10,6 @@
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -37,7 +36,6 @@
ChatCompletionToolParam,
ChatMessage,
DatasetBuilder,
GenerateMetadata,
GenerateOutputs,
LogCallback,
Messages,
@@ -52,11 +50,15 @@
)
from verifiers.utils.async_utils import maybe_retry, maybe_semaphore
from verifiers.utils.error_utils import ErrorChain
from verifiers.utils.eval_utils import make_dataset, save_rollout_results
from verifiers.utils.message_utils import (
strip_nones_from_content,
)
from verifiers.utils.path_utils import get_results_path
from verifiers.utils.save_utils import (
sanitize_rollouts,
save_generate_outputs,
states_to_generate_metadata,
states_to_rollout_outputs,
)
from verifiers.utils.token_utils import (
get_prompt_ids,
prepare_sampling_args_for_token_prompts,
@@ -794,93 +796,30 @@ async def run_group(
await self.rubric.dummy_score_group(group_states)
return list(group_states)

def _prepare_rollout_results(
def _build_generate_outputs(
self,
all_states: list[State],
model: str,
client: AsyncOpenAI,
state_columns: list[str] | None,
results_path: Path | None,
gen_sampling_args: SamplingArgs,
sampling_args: SamplingArgs,
start_time: float,
) -> GenerateOutputs:
"""Prepare GenerateOutputs from a list of completed states."""
# Determine path_to_save
if results_path is None:
path_to_save = get_results_path(self.env_id, model)
else:
path_to_save = results_path
prompts = [state["prompt"] for state in all_states]
completions = [state.get("completion") for state in all_states]
answers = [state.get("answer", "") for state in all_states]
tasks = [state.get("task", "default") for state in all_states]
infos = [state.get("info", {}) for state in all_states]
example_ids = [state.get("example_id", 0) for state in all_states]
rewards = [state.get("reward", 0.0) for state in all_states]
stop_conditions = [state.get("stop_condition", None) for state in all_states]
is_truncated = [state.get("is_truncated", False) for state in all_states]

metrics: dict[str, list[float]] = {}
for state in all_states:
if state.get("metrics"):
for metric_name, metric_value in state["metrics"].items():
if metric_name not in metrics:
metrics[metric_name] = []
metrics[metric_name].append(metric_value)

num_unique_examples = len(set(example_ids)) if example_ids else 0
rollouts_per_example = (
len(all_states) // num_unique_examples if num_unique_examples > 0 else 1
)

def _tools_key(tools: list | None) -> str:
if not tools:
return ""
return str(sorted([t.get("function", {}).get("name", "") for t in tools]))

all_tools = [state.get("oai_tools") for state in all_states]
unique_tool_sets = set(_tools_key(t) for t in all_tools)
tools_vary = len(unique_tool_sets) > 1

if tools_vary:
metadata_tools = None
else:
metadata_tools = next((t for t in all_tools if t), self.oai_tools)

metadata = GenerateMetadata(
env_id=self.env_id,
env_args=self.env_args,
model=model,
base_url=str(client.base_url) if hasattr(client, "base_url") else "",
num_examples=num_unique_examples,
rollouts_per_example=rollouts_per_example,
sampling_args=gen_sampling_args,
date=datetime.now().isoformat(),
time_ms=(time.time() - start_time) * 1000.0,
avg_reward=sum(rewards) / len(rewards) if rewards else 0.0,
avg_metrics={
name: sum(values) / len(values) if values else 0.0
for name, values in metrics.items()
},
state_columns=state_columns or [],
path_to_save=path_to_save,
tools=metadata_tools,
)

return GenerateOutputs(
prompt=prompts,
completion=completions,
answer=answers,
state=all_states,
task=tasks,
info=infos,
example_id=example_ids,
reward=rewards,
metrics=metrics,
stop_conditions=stop_conditions,
is_truncated=is_truncated,
metadata=metadata,
rollouts = states_to_rollout_outputs(all_states, state_columns or [])
metadata = states_to_generate_metadata(
self.env_id,
self.env_args,
model,
client,
all_states,
state_columns,
sampling_args,
start_time,
results_path,
)
return GenerateOutputs(rollouts=rollouts, metadata=metadata)

async def generate(
self,
@@ -895,7 +834,7 @@ async def generate(
state_columns: list[str] | None = None,
save_results: bool = False,
save_every: int = -1,
push_to_hf_hub: bool = False,
save_to_hf_hub: bool = False,
hf_hub_dataset_name: str | None = None,
use_tqdm: bool = True,
independent_scoring: bool = False,
@@ -1017,7 +956,7 @@ async def generate(
and save_every > 0
and groups_or_rollouts_completed % save_every == 0
):
temp_results = self._prepare_rollout_results(
intermediate_outputs = self._build_generate_outputs(
all_states,
model,
client,
@@ -1027,17 +966,17 @@ async def generate(
start_time,
)
self.logger.debug(
f"Saving intermediate results to {temp_results['metadata']['path_to_save']}"
f"Saving intermediate results to {intermediate_outputs['metadata']['path_to_save']}"
)
save_rollout_results(temp_results)
save_generate_outputs(intermediate_outputs)
finally:
if pbar is not None:
pbar.close()

# sort by example_id to ensure deterministic ordering regardless of completion order
all_states.sort(key=lambda s: s.get("example_id", 0))

results = self._prepare_rollout_results(
outputs = self._build_generate_outputs(
all_states,
model,
client,
@@ -1049,11 +988,11 @@

# save if requested
if save_results:
save_rollout_results(results, push_to_hf_hub, hf_hub_dataset_name)
save_generate_outputs(outputs, save_to_hf_hub, hf_hub_dataset_name)
if on_log is not None:
on_log(f"Saved final results to {results['metadata']['path_to_save']}")
on_log(f"Saved final outputs to {outputs['metadata']['path_to_save']}")

return results
return outputs

def generate_sync(
self,
@@ -1116,7 +1055,7 @@ async def evaluate(
state_columns: list[str] | None = None,
save_results: bool = False,
save_every: int = -1,
push_to_hf_hub: bool = False,
save_to_hf_hub: bool = False,
hf_hub_dataset_name: str | None = None,
use_tqdm: bool = True,
independent_scoring: bool = False,
@@ -1142,7 +1081,7 @@ async def evaluate(
state_columns=state_columns,
save_results=save_results,
save_every=save_every,
push_to_hf_hub=push_to_hf_hub,
save_to_hf_hub=save_to_hf_hub,
hf_hub_dataset_name=hf_hub_dataset_name,
use_tqdm=use_tqdm,
independent_scoring=independent_scoring,
@@ -1167,7 +1106,7 @@ def evaluate_sync(
state_columns: list[str] | None = None,
save_results: bool = False,
save_every: int = -1,
push_to_hf_hub: bool = False,
save_to_hf_hub: bool = False,
hf_hub_dataset_name: str | None = None,
independent_scoring: bool = False,
max_retries: int = 0,
@@ -1188,7 +1127,7 @@ def evaluate_sync(
state_columns=state_columns,
save_results=save_results,
save_every=save_every,
push_to_hf_hub=push_to_hf_hub,
save_to_hf_hub=save_to_hf_hub,
hf_hub_dataset_name=hf_hub_dataset_name,
independent_scoring=independent_scoring,
max_retries=max_retries,
@@ -1235,7 +1174,9 @@ def set_score_rollouts(self, score_rollouts: bool) -> None:
"""Set the score rollouts flag for this environment."""
self.score_rollouts = score_rollouts

make_dataset = staticmethod(make_dataset)
make_dataset = staticmethod(
lambda x: Dataset.from_list(sanitize_rollouts(x["rollouts"]))
) # backwards compatibility


_EnvT = TypeVar("_EnvT", bound=Environment)
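Note: `_prepare_rollout_results` is replaced by `_build_generate_outputs`, which delegates per-rollout assembly to `states_to_rollout_outputs` and metadata assembly to `states_to_generate_metadata`, and the `push_to_hf_hub` flag is renamed to `save_to_hf_hub` in `generate`, `evaluate`, and `evaluate_sync`. The sketch below covers only the caller-side rename; the `(client, model)` positional order, the `Environment` import path, and the endpoint, model, and dataset names are assumptions, not taken from this diff.

```python
# Sketch of the flag rename from a caller's perspective (names and URLs are
# placeholders; the positional (client, model) argument order is assumed).
from openai import AsyncOpenAI

from verifiers.envs.environment import Environment


def run_eval(env: Environment, model: str = "my-model") -> None:
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    outputs = env.evaluate_sync(
        client,
        model,
        save_results=True,
        save_to_hf_hub=True,             # was push_to_hf_hub before this change
        hf_hub_dataset_name="org/name",  # placeholder HF dataset repo id
    )
    # Final results are still written under metadata["path_to_save"].
    print(outputs["metadata"]["path_to_save"])
```

Intermediate saves triggered by `save_every` go through the same `save_generate_outputs` path as the final save.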
42 changes: 19 additions & 23 deletions verifiers/types.py
@@ -167,20 +167,28 @@ class GenerateMetadata(TypedDict):
tools: list[ChatCompletionToolParam] | None


class RolloutOutput(TypedDict):
"""TypedDict for generation outputs."""

example_id: int
prompt: Messages
completion: Messages
answer: str
task: str
info: Info
tools: list[ChatCompletionToolParam] | None
reward: float
metrics: dict[str, float]
stop_condition: str | None
is_truncated: bool
timing: RolloutTiming
error: Error | None


class GenerateOutputs(TypedDict):
"""TypedDict for generation outputs."""

prompt: list[Messages]
completion: list[Messages]
answer: list[str]
state: list[State]
task: list[str]
info: list[Info]
example_id: list[int]
reward: list[float]
metrics: dict[str, list[float]]
stop_conditions: list[str | None]
is_truncated: list[bool]
rollouts: list[RolloutOutput]
metadata: GenerateMetadata


@@ -198,18 +206,6 @@ class RolloutScores(TypedDict):
metrics: dict[str, list[float]]


class ProcessedOutputs(TypedDict):
"""TypedDict for processed outputs."""

prompt_ids: list[list[int]]
prompt_mask: list[list[int]]
completion_ids: list[list[int]]
completion_mask: list[list[int]]
completion_logprobs: list[list[float]]
rewards: list[float]
is_truncated: list[bool]


Endpoint = TypedDict("Endpoint", {"key": str, "url": str, "model": str})
Endpoints = dict[str, Endpoint]

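Note: `GenerateOutputs` moves from parallel column lists (`prompt`, `completion`, `reward`, ...) to a single row-oriented `rollouts: list[RolloutOutput]`, and `ProcessedOutputs` is dropped. A minimal sketch of how a downstream consumer might rebuild column-style views from the new shape; the `to_columns` helper is illustrative only and not part of this change.

```python
# Illustrative adapter from the new row-oriented GenerateOutputs back to
# column-style lists (not part of this change).
from verifiers.types import GenerateOutputs


def to_columns(outputs: GenerateOutputs) -> dict[str, list]:
    rollouts = outputs["rollouts"]
    columns: dict[str, list] = {
        # Previously top-level lists on GenerateOutputs, now one value per rollout.
        "reward": [r["reward"] for r in rollouts],
        "is_truncated": [r["is_truncated"] for r in rollouts],
        "stop_condition": [r["stop_condition"] for r in rollouts],
    }
    # Per-metric lists, mirroring the old metrics: dict[str, list[float]] field.
    metrics: dict[str, list[float]] = {}
    for rollout in rollouts:
        for name, value in rollout["metrics"].items():
            metrics.setdefault(name, []).append(value)
    columns.update(metrics)
    return columns
```

The backwards-compatible `make_dataset` shim in `environment.py` takes the same route, building a `Dataset` from `sanitize_rollouts(x["rollouts"])`.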