diff --git a/docs/reference.md b/docs/reference.md
index eaad7c95d..14a12db38 100644
--- a/docs/reference.md
+++ b/docs/reference.md
@@ -209,21 +209,6 @@ class RolloutScores(TypedDict):
     metrics: dict[str, list[float]]
 ```
 
-### ProcessedOutputs
-
-```python
-class ProcessedOutputs(TypedDict):
-    prompt_ids: list[list[int]]
-    prompt_mask: list[list[int]]
-    completion_ids: list[list[int]]
-    completion_mask: list[list[int]]
-    completion_logprobs: list[list[float]]
-    rewards: list[float]
-    is_truncated: list[bool]
-```
-
-Tokenized outputs for training.
-
 ---
 
 ## Classes
diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py
index e8842dba4..0ebc205fb 100644
--- a/verifiers/envs/environment.py
+++ b/verifiers/envs/environment.py
@@ -10,7 +10,6 @@
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
-from datetime import datetime
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
@@ -37,7 +36,6 @@
     ChatCompletionToolParam,
     ChatMessage,
     DatasetBuilder,
-    GenerateMetadata,
     GenerateOutputs,
     LogCallback,
     Messages,
@@ -52,11 +50,15 @@
 )
 from verifiers.utils.async_utils import maybe_retry, maybe_semaphore
 from verifiers.utils.error_utils import ErrorChain
-from verifiers.utils.eval_utils import make_dataset, save_rollout_results
 from verifiers.utils.message_utils import (
     strip_nones_from_content,
 )
-from verifiers.utils.path_utils import get_results_path
+from verifiers.utils.save_utils import (
+    sanitize_rollouts,
+    save_generate_outputs,
+    states_to_generate_metadata,
+    states_to_rollout_outputs,
+)
 from verifiers.utils.token_utils import (
     get_prompt_ids,
     prepare_sampling_args_for_token_prompts,
@@ -794,93 +796,30 @@ async def run_group(
         await self.rubric.dummy_score_group(group_states)
         return list(group_states)
 
-    def _prepare_rollout_results(
+    def _build_generate_outputs(
         self,
         all_states: list[State],
         model: str,
         client: AsyncOpenAI,
         state_columns: list[str] | None,
         results_path: Path | None,
-        gen_sampling_args: SamplingArgs,
+        sampling_args: SamplingArgs,
         start_time: float,
     ) -> GenerateOutputs:
         """Prepare GenerateOutputs from a list of completed states."""
-        # Determine path_to_save
-        if results_path is None:
-            path_to_save = get_results_path(self.env_id, model)
-        else:
-            path_to_save = results_path
-        prompts = [state["prompt"] for state in all_states]
-        completions = [state.get("completion") for state in all_states]
-        answers = [state.get("answer", "") for state in all_states]
-        tasks = [state.get("task", "default") for state in all_states]
-        infos = [state.get("info", {}) for state in all_states]
-        example_ids = [state.get("example_id", 0) for state in all_states]
-        rewards = [state.get("reward", 0.0) for state in all_states]
-        stop_conditions = [state.get("stop_condition", None) for state in all_states]
-        is_truncated = [state.get("is_truncated", False) for state in all_states]
-
-        metrics: dict[str, list[float]] = {}
-        for state in all_states:
-            if state.get("metrics"):
-                for metric_name, metric_value in state["metrics"].items():
-                    if metric_name not in metrics:
-                        metrics[metric_name] = []
-                    metrics[metric_name].append(metric_value)
-
-        num_unique_examples = len(set(example_ids)) if example_ids else 0
-        rollouts_per_example = (
-            len(all_states) // num_unique_examples if num_unique_examples > 0 else 1
-        )
-
-        def _tools_key(tools: list | None) -> str:
-            if not tools:
-                return ""
-            return str(sorted([t.get("function", {}).get("name", "") for t in tools]))
-
-        all_tools = [state.get("oai_tools") for state in all_states]
-        unique_tool_sets = set(_tools_key(t) for t in all_tools)
-        tools_vary = len(unique_tool_sets) > 1
-
-        if tools_vary:
-            metadata_tools = None
-        else:
-            metadata_tools = next((t for t in all_tools if t), self.oai_tools)
-
-        metadata = GenerateMetadata(
-            env_id=self.env_id,
-            env_args=self.env_args,
-            model=model,
-            base_url=str(client.base_url) if hasattr(client, "base_url") else "",
-            num_examples=num_unique_examples,
-            rollouts_per_example=rollouts_per_example,
-            sampling_args=gen_sampling_args,
-            date=datetime.now().isoformat(),
-            time_ms=(time.time() - start_time) * 1000.0,
-            avg_reward=sum(rewards) / len(rewards) if rewards else 0.0,
-            avg_metrics={
-                name: sum(values) / len(values) if values else 0.0
-                for name, values in metrics.items()
-            },
-            state_columns=state_columns or [],
-            path_to_save=path_to_save,
-            tools=metadata_tools,
-        )
-
-        return GenerateOutputs(
-            prompt=prompts,
-            completion=completions,
-            answer=answers,
-            state=all_states,
-            task=tasks,
-            info=infos,
-            example_id=example_ids,
-            reward=rewards,
-            metrics=metrics,
-            stop_conditions=stop_conditions,
-            is_truncated=is_truncated,
-            metadata=metadata,
+        rollouts = states_to_rollout_outputs(all_states, state_columns or [])
+        metadata = states_to_generate_metadata(
+            self.env_id,
+            self.env_args,
+            model,
+            client,
+            all_states,
+            state_columns,
+            sampling_args,
+            start_time,
+            results_path,
         )
+        return GenerateOutputs(rollouts=rollouts, metadata=metadata)
 
     async def generate(
         self,
@@ -895,7 +834,7 @@ async def generate(
         state_columns: list[str] | None = None,
         save_results: bool = False,
         save_every: int = -1,
-        push_to_hf_hub: bool = False,
+        save_to_hf_hub: bool = False,
         hf_hub_dataset_name: str | None = None,
         use_tqdm: bool = True,
         independent_scoring: bool = False,
@@ -1017,7 +956,7 @@ async def generate(
                     and save_every > 0
                     and groups_or_rollouts_completed % save_every == 0
                 ):
-                    temp_results = self._prepare_rollout_results(
+                    intermediate_outputs = self._build_generate_outputs(
                         all_states,
                         model,
                         client,
@@ -1027,9 +966,9 @@ async def generate(
                         start_time,
                     )
                     self.logger.debug(
-                        f"Saving intermediate results to {temp_results['metadata']['path_to_save']}"
+                        f"Saving intermediate results to {intermediate_outputs['metadata']['path_to_save']}"
                     )
-                    save_rollout_results(temp_results)
+                    save_generate_outputs(intermediate_outputs)
         finally:
             if pbar is not None:
                 pbar.close()
@@ -1037,7 +976,7 @@ async def generate(
 
         # sort by example_id to ensure deterministic ordering regardless of completion order
         all_states.sort(key=lambda s: s.get("example_id", 0))
-        results = self._prepare_rollout_results(
+        outputs = self._build_generate_outputs(
            all_states,
            model,
            client,
@@ -1049,11 +988,11 @@ async def generate(
 
         # save if requested
         if save_results:
-            save_rollout_results(results, push_to_hf_hub, hf_hub_dataset_name)
+            save_generate_outputs(outputs, save_to_hf_hub, hf_hub_dataset_name)
             if on_log is not None:
-                on_log(f"Saved final results to {results['metadata']['path_to_save']}")
+                on_log(f"Saved final outputs to {outputs['metadata']['path_to_save']}")
 
-        return results
+        return outputs
 
     def generate_sync(
         self,
@@ -1116,7 +1055,7 @@ async def evaluate(
         state_columns: list[str] | None = None,
         save_results: bool = False,
         save_every: int = -1,
-        push_to_hf_hub: bool = False,
+        save_to_hf_hub: bool = False,
         hf_hub_dataset_name: str | None = None,
         use_tqdm: bool = True,
         independent_scoring: bool = False,
@@ -1142,7 +1081,7 @@ async def evaluate(
             state_columns=state_columns,
             save_results=save_results,
             save_every=save_every,
-            push_to_hf_hub=push_to_hf_hub,
+            save_to_hf_hub=save_to_hf_hub,
             hf_hub_dataset_name=hf_hub_dataset_name,
             use_tqdm=use_tqdm,
             independent_scoring=independent_scoring,
@@ -1167,7 +1106,7 @@ def evaluate_sync(
         state_columns: list[str] | None = None,
         save_results: bool = False,
         save_every: int = -1,
-        push_to_hf_hub: bool = False,
+        save_to_hf_hub: bool = False,
         hf_hub_dataset_name: str | None = None,
         independent_scoring: bool = False,
         max_retries: int = 0,
@@ -1188,7 +1127,7 @@ def evaluate_sync(
             state_columns=state_columns,
             save_results=save_results,
             save_every=save_every,
-            push_to_hf_hub=push_to_hf_hub,
+            save_to_hf_hub=save_to_hf_hub,
             hf_hub_dataset_name=hf_hub_dataset_name,
             independent_scoring=independent_scoring,
             max_retries=max_retries,
@@ -1235,7 +1174,9 @@ def set_score_rollouts(self, score_rollouts: bool) -> None:
         """Set the score rollouts flag for this environment."""
         self.score_rollouts = score_rollouts
 
-    make_dataset = staticmethod(make_dataset)
+    make_dataset = staticmethod(
+        lambda x: Dataset.from_list(sanitize_rollouts(x["rollouts"]))
+    )
     # backwards compatibility
 
 _EnvT = TypeVar("_EnvT", bound=Environment)
diff --git a/verifiers/types.py b/verifiers/types.py
index 17f5b0243..e1be1d3d2 100644
--- a/verifiers/types.py
+++ b/verifiers/types.py
@@ -167,20 +167,28 @@ class GenerateMetadata(TypedDict):
     tools: list[ChatCompletionToolParam] | None
 
 
+class RolloutOutput(TypedDict):
+    """TypedDict for a single rollout output."""
+
+    example_id: int
+    prompt: Messages
+    completion: Messages
+    answer: str
+    task: str
+    info: Info
+    tools: list[ChatCompletionToolParam] | None
+    reward: float
+    metrics: dict[str, float]
+    stop_condition: str | None
+    is_truncated: bool
+    timing: RolloutTiming
+    error: Error | None
+
+
 class GenerateOutputs(TypedDict):
     """TypedDict for generation outputs."""
 
-    prompt: list[Messages]
-    completion: list[Messages]
-    answer: list[str]
-    state: list[State]
-    task: list[str]
-    info: list[Info]
-    example_id: list[int]
-    reward: list[float]
-    metrics: dict[str, list[float]]
-    stop_conditions: list[str | None]
-    is_truncated: list[bool]
+    rollouts: list[RolloutOutput]
     metadata: GenerateMetadata
 
 
@@ -198,18 +206,6 @@ class RolloutScores(TypedDict):
     metrics: dict[str, list[float]]
 
 
-class ProcessedOutputs(TypedDict):
-    """TypedDict for processed outputs."""
-
-    prompt_ids: list[list[int]]
-    prompt_mask: list[list[int]]
-    completion_ids: list[list[int]]
-    completion_mask: list[list[int]]
-    completion_logprobs: list[list[float]]
-    rewards: list[float]
-    is_truncated: list[bool]
-
-
 Endpoint = TypedDict("Endpoint", {"key": str, "url": str, "model": str})
 
 Endpoints = dict[str, Endpoint]
diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py
index 8035967c6..2426886e1 100644
--- a/verifiers/utils/eval_utils.py
+++ b/verifiers/utils/eval_utils.py
@@ -1,12 +1,10 @@
 import asyncio
 import importlib.util
-import json
 import logging
 import time
 from collections import Counter, defaultdict
-from contextlib import contextmanager
 from pathlib import Path
-from typing import cast
+from typing import Any, cast
 
 try:
     import tomllib  # type: ignore[import-not-found]
@@ -14,15 +12,12 @@
     import tomli as tomllib  # type: ignore[import-not-found]
 
 import numpy as np
-from datasets import Dataset, disable_progress_bar, enable_progress_bar
-from datasets.utils import logging as ds_logging
 
 import verifiers as vf
 from verifiers.types import (
     Endpoints,
     EvalConfig,
     EvalRunConfig,
-    GenerateMetadata,
     GenerateOutputs,
     LogCallback,
     ProgressCallback,
@@ -33,7 +28,7 @@
 from verifiers.utils.client_utils import setup_client
 from verifiers.utils.error_utils import ErrorChain
 from verifiers.utils.logging_utils import print_prompt_completions_sample, print_time
-from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls
+from verifiers.utils.message_utils import messages_to_printable
 from verifiers.utils.path_utils import get_eval_results_path
 
 logger = logging.getLogger(__name__)
@@ -181,54 +176,40 @@ def load_toml_config(path: Path) -> list[dict]:
     return merged_eval_list
 
 
-def get_results_by_task(results: GenerateOutputs) -> dict[str, GenerateOutputs]:
-    """Group results by task name.
+def to_col_order(list_of_dicts: list[dict[str, Any]]) -> dict[str, list[float]]:
+    """Convert a list of dictionaries to a dictionary of lists, ordered by the keys of the first dictionary."""
+    return {k: [m[k] for m in list_of_dicts] for k in list_of_dicts[0].keys()}
 
-    Args:
-        results: The GenerateOutputs from an evaluation run.
 
-    Returns:
-        A dictionary mapping task names to their corresponding GenerateOutputs.
-    """
-    task_indices: dict[str, list[int]] = {}
-    for i, task in enumerate(results["task"]):
-        if task not in task_indices:
-            task_indices[task] = []
-        task_indices[task].append(i)
-
-    task_results: dict[str, GenerateOutputs] = {}
-    for task, indices in task_indices.items():
-        task_results[task] = GenerateOutputs(
-            prompt=[results["prompt"][i] for i in indices],
-            completion=[results["completion"][i] for i in indices],
-            answer=[results["answer"][i] for i in indices],
-            state=[results["state"][i] for i in indices],
-            task=[results["task"][i] for i in indices],
-            info=[results["info"][i] for i in indices],
-            example_id=[results["example_id"][i] for i in indices],
-            reward=[results["reward"][i] for i in indices],
-            metrics={k: [v[i] for i in indices] for k, v in results["metrics"].items()},
-            stop_conditions=[results["stop_conditions"][i] for i in indices],
-            is_truncated=[results["is_truncated"][i] for i in indices],
-            metadata=results["metadata"],
-        )
-    return task_results
+def get_task_outputs(results: GenerateOutputs, task: str) -> GenerateOutputs:
+    """Get only the rollouts for a given task."""
+    rollouts = [r for r in results["rollouts"] if r["task"] == task]
+    return GenerateOutputs(
+        rollouts=rollouts,
+        metadata=results["metadata"],  # duplicate metadata
+    )
 
 
-def print_rewards(results: GenerateOutputs):
+def print_rewards(outputs: GenerateOutputs):
+    metadata = outputs["metadata"]
+    r = metadata["rollouts_per_example"]
+    n = len(outputs["rollouts"]) // r
+
+    rewards = [r["reward"] for r in outputs["rollouts"]]
     print("Rewards:")
     print(
-        f"reward: avg - {sum(results['reward']) / len(results['reward']):.3f}, std - {np.std(results['reward']):.3f}"
+        f"reward: avg - {sum(rewards) / len(rewards):.3f}, std - {np.std(rewards):.3f}"
     )
-    r = results["metadata"]["rollouts_per_example"]
-    n = len(results["reward"]) // r
     # results are sorted by example_id, so rollout i is at indices [i, i+r, i+2r, ...]
     for i in range(r):
-        trials = [round(results["reward"][i + (j * r)], 3) for j in range(n)]
+        trials = [round(rewards[i + (j * r)], 3) for j in range(n)]
         out = f"r{i + 1}: {trials}"
         print(out)
-    for k in results["metrics"]:
-        v = results["metrics"][k]
+
+    metrics = [r["metrics"] for r in outputs["rollouts"]]
+    metrics_col = to_col_order(metrics)
+    for k in metrics_col.keys():
+        v = metrics_col[k]
         print(f"{k}: avg - {sum(v) / len(v):.3f}, std - {np.std(v):.3f}")
         for i in range(r):
             trials = [round(v[i + (j * r)], 3) for j in range(n)]
@@ -236,16 +217,18 @@ def print_rewards(results: GenerateOutputs):
             print(out)
 
 
-def print_info(results: GenerateOutputs):
+def print_info(outputs: GenerateOutputs):
+    is_truncated = [r["is_truncated"] for r in outputs["rollouts"]]
     print("Info:")
     print(
-        f"is_truncated: avg - {np.mean(results['is_truncated']):.3f}, std - {np.std(results['is_truncated']):.3f}"
+        f"is_truncated: avg - {np.mean(is_truncated):.3f}, std - {np.std(is_truncated):.3f}"
     )
-    counter = Counter(results["stop_conditions"])
+    stop_conditions = [r["stop_condition"] for r in outputs["rollouts"]]
+    counter = Counter(stop_conditions)
     print(
         f"stop_conditions: {', '.join([f'{k}: {v / counter.total():.3f}' for k, v in counter.items()])}"
     )
-    errors = [s.get("error") for s in results["state"]]
+    errors = [r.get("error") for r in outputs["rollouts"]]
     has_errors = [e is not None for e in errors]
     if any(has_errors):
         print(
@@ -258,13 +241,13 @@ def print_info(results: GenerateOutputs):
             print(f" - {repr(error_chain)}: {count / counter.total():.3f}")
 
 
-def print_timing(results: GenerateOutputs):
+def print_timing(outputs: GenerateOutputs):
     print("Timing:")
-    generation_ms_arr = np.array(
-        [s["timing"]["generation_ms"] for s in results["state"]]
-    )
-    scoring_ms_arr = np.array([s["timing"]["scoring_ms"] for s in results["state"]])
-    total_ms_arr = np.array([s["timing"]["total_ms"] for s in results["state"]])
+    timing = [r["timing"] for r in outputs["rollouts"]]
+    timing_col = to_col_order(cast(list[dict], timing))
+    generation_ms_arr = np.array(timing_col["generation_ms"])
+    scoring_ms_arr = np.array(timing_col["scoring_ms"])
+    total_ms_arr = np.array(timing_col["total_ms"])
     generation_arr = generation_ms_arr / 1000
     scoring_arr = scoring_ms_arr / 1000
     total_arr = total_ms_arr / 1000
@@ -280,43 +263,45 @@
     )
 
 
-def print_results(
-    results: GenerateOutputs,
-    num_samples: int = 1,
-):
-    assert results["metadata"] is not None
+def print_results(outputs: GenerateOutputs, num_samples: int = 1):
+    assert outputs["metadata"] is not None
 
     print("--- Evaluation ---")
-    print(f"Environment: {results['metadata']['env_id']}")
-    print(f"Model: {results['metadata']['model']}")
-    print(f"Provider: {results['metadata']['base_url']}")
-    print(f"Examples: {results['metadata']['num_examples']}")
-    print(f"Rollouts per example: {results['metadata']['rollouts_per_example']}")
+    print(f"Environment: {outputs['metadata']['env_id']}")
+    print(f"Model: {outputs['metadata']['model']}")
+    print(f"Provider: {outputs['metadata']['base_url']}")
+    print(f"Examples: {outputs['metadata']['num_examples']}")
+    print(f"Rollouts per example: {outputs['metadata']['rollouts_per_example']}")
     print("--- Example ---")
-    printable_prompts = [messages_to_printable(p) for p in results["prompt"]]
-    printable_completions = [messages_to_printable(c) for c in results["completion"]]
-    errors = [s.get("error") for s in results["state"]]
+    printable_prompts = [
outputs["rollouts"] + ] + printable_completions = [ + messages_to_printable(r["completion"]) for r in outputs["rollouts"] + ] + rewards = [r["reward"] for r in outputs["rollouts"]] + errors = [r.get("error") for r in outputs["rollouts"]] print_prompt_completions_sample( printable_prompts, printable_completions, errors, - results["reward"], + rewards, step=0, num_samples=num_samples, ) print("--- All ---") - print_rewards(results) - print_info(results) - print_timing(results) - - num_tasks = len(set(results["task"])) - if num_tasks > 1: - task_results = get_results_by_task(results) - for task, task_results in task_results.items(): + print_rewards(outputs) + print_info(outputs) + print_timing(outputs) + + tasks = set([r["task"] for r in outputs["rollouts"]]) + if len(tasks) > 1: + for task in tasks: + task_outputs = get_task_outputs(outputs, task) print(f"\n--- {task} ---") - print_rewards(task_results) - print_info(task_results) - print_timing(task_results) + print_rewards(task_outputs) + print_info(task_outputs) + print_timing(task_outputs) async def run_evaluation( @@ -347,7 +332,7 @@ async def run_evaluation( ) # disable tqdm when callbacks are provided (TUI handles progress display) use_tqdm = config.use_tqdm and on_progress is None - results = await vf_env.evaluate( + outputs = await vf_env.evaluate( client=client, model=config.model, sampling_args=config.sampling_args, @@ -360,7 +345,7 @@ async def run_evaluation( state_columns=config.state_columns, save_results=config.save_results, save_every=config.save_every, - push_to_hf_hub=config.save_to_hf_hub, + save_to_hf_hub=config.save_to_hf_hub, hf_hub_dataset_name=config.hf_hub_dataset_name, use_tqdm=use_tqdm, independent_scoring=config.independent_scoring, @@ -370,7 +355,7 @@ async def run_evaluation( on_log=on_log, ) - return results + return outputs async def run_evaluations(config: EvalRunConfig) -> None: @@ -511,103 +496,3 @@ def on_log(message: str) -> None: # print final summary after exit display.print_final_summary() - - -def sanitize_metadata(metadata: GenerateMetadata) -> dict: - metadata_dict = dict(metadata) - metadata_dict.pop("path_to_save") - metadata_dict.pop("date") - - return metadata_dict - - -def get_hf_hub_dataset_name(results: GenerateOutputs) -> str: - metadata = results["metadata"] - dataset_name = ( - metadata["env_id"] - + "_" - + metadata["model"].replace("/", "_") - + "_n" - + str(metadata["num_examples"]) - + "_r" - + str(metadata["rollouts_per_example"]) - ) - return dataset_name - - -def make_dataset(results: GenerateOutputs, **kwargs) -> Dataset: - clean_prompts = [messages_to_printable(p) for p in results["prompt"]] - clean_prompts = [sanitize_tool_calls(p) for p in clean_prompts] - clean_completions = [messages_to_printable(c) for c in results["completion"]] - clean_completions = [sanitize_tool_calls(c) for c in clean_completions] - save_info = any(info != {} for info in results["info"]) - save_answer = any(answer != "" for answer in results["answer"]) - errors = [s.get("error") for s in results["state"]] - results_dict = { - "example_id": results["example_id"], - "prompt": clean_prompts, - "completion": clean_completions, - "task": results["task"], - "reward": results["reward"], - "error": [repr(e) if e is not None else None for e in errors], - "generation_ms": [s["timing"]["generation_ms"] for s in results["state"]], - "scoring_ms": [s["timing"]["scoring_ms"] for s in results["state"]], - "total_ms": [s["timing"]["total_ms"] for s in results["state"]], - } - if save_info: - results_dict["info"] = 
results["info"] - if save_answer: - results_dict["answer"] = results["answer"] - for k in results["metrics"]: - v = results["metrics"][k] - results_dict[k] = v - - # Add selected state columns if specified - state_columns = results["metadata"]["state_columns"] - if state_columns: - for col in state_columns: - if col == "responses": - results_dict[col] = [ - [r.model_dump() for r in s.get(col, [])] for s in results["state"] - ] - else: - results_dict[col] = [s.get(col) for s in results["state"]] - - return Dataset.from_dict(results_dict) - - -@contextmanager -def quiet_datasets(): - prev_level = ds_logging.get_verbosity() - ds_logging.set_verbosity(ds_logging.WARNING) - disable_progress_bar() - try: - yield - finally: - ds_logging.set_verbosity(prev_level) - enable_progress_bar() - - -def save_to_disk(dataset: Dataset, metadata_dict: dict, path_to_save: Path): - path_to_save.parent.mkdir(parents=True, exist_ok=True) - with quiet_datasets(): - dataset.to_json(path_to_save / "results.jsonl") - with open(path_to_save / "metadata.json", "w") as f: - json.dump(metadata_dict, f) - - -def save_rollout_results( - results: GenerateOutputs, - push_to_hf_hub: bool = False, - hf_hub_dataset_name: str | None = None, -): - path_to_save = results["metadata"]["path_to_save"] - path_to_save.parent.mkdir(parents=True, exist_ok=True) - dataset = make_dataset(results) - metadata_dict = sanitize_metadata(results["metadata"]) - save_to_disk(dataset, metadata_dict, path_to_save) - logger.info(f"Results saved to {path_to_save}") - if push_to_hf_hub: - dataset_name = hf_hub_dataset_name or get_hf_hub_dataset_name(results) - dataset.push_to_hub(dataset_name) - logger.info(f"Dataset saved to Hugging Face Hub: {dataset_name}") diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py new file mode 100644 index 000000000..06710e8ee --- /dev/null +++ b/verifiers/utils/save_utils.py @@ -0,0 +1,209 @@ +import json +import logging +import time +from collections import defaultdict +from datetime import datetime +from pathlib import Path + +from datasets import Dataset +from openai import AsyncOpenAI + +from verifiers.types import ( + GenerateMetadata, + GenerateOutputs, + RolloutOutput, + SamplingArgs, + State, +) +from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls +from verifiers.utils.path_utils import get_results_path + +logger = logging.getLogger(__name__) + + +def state_to_rollout_output( + state: State, state_columns: list[str] = [] +) -> RolloutOutput: + """Converts a state to a rollout output. 
+    """Converts a state to a rollout output. Adds state columns to the output."""
+    rollout_output = RolloutOutput(
+        example_id=state.get("example_id", 0),
+        prompt=state.get("prompt"),
+        completion=state.get("completion"),
+        answer=state.get("answer", ""),
+        task=state.get("task", "default"),
+        info=state.get("info", {}),
+        tools=state.get("oai_tools"),
+        reward=state.get("reward", 0.0),
+        metrics=state.get("metrics", {}),
+        stop_condition=state.get("stop_condition", None),
+        is_truncated=state.get("is_truncated", False),
+        timing=state.get("timing", {}),
+        error=state.get("error", None),
+    )
+    for col in state_columns:
+        rollout_output[col] = state.get(col)
+
+    return rollout_output
+
+
+def states_to_rollout_outputs(
+    states: list[State], state_columns: list[str] = []
+) -> list[RolloutOutput]:
+    """Converts a list of states to a list of rollout outputs."""
+    return [state_to_rollout_output(state, state_columns) for state in states]
+
+
+def states_to_generate_metadata(
+    env_id: str,
+    env_args: dict,
+    model: str,
+    client: AsyncOpenAI,
+    states: list[State],
+    state_columns: list[str] | None,
+    sampling_args: SamplingArgs,
+    start_time: float,
+    results_path: Path | None,
+) -> GenerateMetadata:
+    """Converts a list of states to generate metadata."""
+    base_url = str(client.base_url) if hasattr(client, "base_url") else ""
+    rewards = [s.get("reward", 0.0) for s in states]
+    avg_reward = sum(rewards) / len(rewards) if rewards else 0.0
+
+    metrics: dict[str, list[float]] = defaultdict(list)
+    for state in states:
+        if state.get("metrics"):
+            for metric_name, metric_value in state["metrics"].items():
+                metrics[metric_name].append(metric_value)
+    avg_metrics = {k: sum(v) / len(v) if v else 0.0 for k, v in metrics.items()}
+
+    example_ids = [s.get("example_id", 0) for s in states]
+    num_examples = len(set(example_ids)) if example_ids else 0
+    rollouts_per_example = len(states) // num_examples if num_examples > 0 else 1
+
+    path_to_save = results_path or get_results_path(env_id, model)
+
+    def tools_key(tools: list | None) -> str:
+        if not tools:
+            return ""
+        return str(sorted([t.get("function", {}).get("name", "") for t in tools]))
+
+    all_tools = [s.get("oai_tools") for s in states]
+    unique_tools = set(tools_key(t) for t in all_tools)
+    tools = next((t for t in all_tools if t), None) if len(unique_tools) == 1 else None
+
+    return GenerateMetadata(
+        env_id=env_id,
+        env_args=env_args,
+        model=model,
+        base_url=base_url,
+        num_examples=num_examples,
+        rollouts_per_example=rollouts_per_example,
+        sampling_args=sampling_args,
+        date=datetime.now().isoformat(),
+        time_ms=(time.time() - start_time) * 1000.0,
+        avg_reward=avg_reward,
+        avg_metrics=avg_metrics,
+        state_columns=state_columns or [],
+        path_to_save=path_to_save,
+        tools=tools,
+    )
+
+
+def get_hf_hub_dataset_name(outputs: GenerateOutputs) -> str:
+    """Auto-generates a dataset name."""
+    metadata = outputs["metadata"]
+    dataset_name = (
+        metadata["env_id"]
+        + "_"
+        + metadata["model"].replace("/", "_")
+        + "_n"
+        + str(metadata["num_examples"])
+        + "_r"
+        + str(metadata["rollouts_per_example"])
+    )
+    return dataset_name
+
+
+def sanitize_rollouts(rollouts: list[RolloutOutput]) -> list[dict]:
+    """Sanitizes a list of rollouts before saving to disk."""
+
+    def sanitize_rollout(rollout: RolloutOutput) -> dict:
+        sanitized_rollout = dict(rollout)
+        # sanitize messages
+        sanitized_rollout["prompt"] = sanitize_tool_calls(
+            messages_to_printable(rollout["prompt"])
+        )
+        sanitized_rollout["completion"] = sanitize_tool_calls(
+            messages_to_printable(rollout["completion"])
+        )
+        # use str repr for error
+        sanitized_rollout["error"] = repr(rollout.get("error"))
+        # only include optional fields if present
+        if not rollout.get("answer"):
+            sanitized_rollout.pop("answer")
+        if not rollout.get("info"):
+            sanitized_rollout.pop("info")
+        # flatten metrics
+        rollout_metrics = rollout.get("metrics", {})
+        for k, v in rollout_metrics.items():
+            sanitized_rollout[k] = v
+
+        return sanitized_rollout
+
+    return [sanitize_rollout(rollout) for rollout in rollouts]
+
+
+def sanitize_metadata(metadata: GenerateMetadata) -> dict:
+    """Sanitizes metadata before saving to disk."""
+
+    metadata_dict = dict(metadata)
+    metadata_dict.pop("path_to_save")
+    metadata_dict.pop("date")
+
+    return metadata_dict
+
+
+def save_to_disk(rollouts: list[dict], metadata: dict, path: Path):
+    """Saves (sanitized) rollouts and metadata to disk."""
+    path.mkdir(parents=True, exist_ok=True)
+
+    def save_results(results_list: list[dict], results_path: Path):
+        with open(results_path, "w") as f:
+            for idx, result in enumerate(results_list):
+                example_id = result.get("example_id", "unknown")
+                try:
+                    json.dump(result, f)
+                    f.write("\n")
+                except Exception as e:
+                    logger.error(
+                        f"Failed to save rollout with index {idx} ({example_id=}): {e}"
+                    )
+
+    def save_metadata(metadata_dict: dict, metadata_path: Path):
+        with open(metadata_path, "w") as f:
+            try:
+                json.dump(metadata_dict, f)
+            except Exception as e:
+                logger.error(f"Failed to save metadata: {e}")
+
+    save_metadata(metadata, path / "metadata.json")
+    save_results(rollouts, path / "results.jsonl")
+
+
+def save_generate_outputs(
+    outputs: GenerateOutputs,
+    push_to_hf_hub: bool = False,
+    hf_hub_dataset_name: str | None = None,
+):
+    path_to_save = outputs["metadata"]["path_to_save"]
+    sanitized_rollouts = sanitize_rollouts(outputs["rollouts"])
+    sanitized_metadata = sanitize_metadata(outputs["metadata"])
+    save_to_disk(sanitized_rollouts, sanitized_metadata, path_to_save)
+    logger.info(f"Results saved to {path_to_save}")
+    if push_to_hf_hub:
+        dataset_name = hf_hub_dataset_name or get_hf_hub_dataset_name(outputs)
+        try:
+            Dataset.from_list(sanitized_rollouts).push_to_hub(dataset_name)
+            logger.info(f"Dataset saved to Hugging Face Hub: {dataset_name}")
+        except Exception as e:
+            logger.error(f"Error pushing dataset to Hugging Face Hub: {e}")
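
Reviewer note (not part of the diff): a minimal sketch of how the reshaped, rollout-centric outputs are consumed after this change. The environment id, model name, and client configuration below are placeholders, and the `evaluate()` keyword usage is an assumption based on the surrounding code rather than a verified signature.

```python
# Illustrative only: consumes the new GenerateOutputs shape introduced in this PR.
import asyncio

from openai import AsyncOpenAI

import verifiers as vf
from verifiers.utils.save_utils import save_generate_outputs


async def main() -> None:
    env = vf.load_environment("example-env")  # hypothetical environment id
    client = AsyncOpenAI()  # assumes an OpenAI-compatible endpoint is configured

    outputs = await env.evaluate(client=client, model="gpt-4.1-mini")  # placeholder model

    # Each entry is a RolloutOutput TypedDict rather than a column in parallel lists.
    for rollout in outputs["rollouts"]:
        print(rollout["example_id"], rollout["task"], rollout["reward"])

    # Writes results.jsonl and metadata.json under metadata["path_to_save"];
    # push_to_hf_hub=True would also push the sanitized rollouts as a dataset.
    save_generate_outputs(outputs)


asyncio.run(main())
```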