From c8515d0d0a088cddfc24b17416452c2accf985bd Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 23 Jan 2026 14:08:13 +0000 Subject: [PATCH 01/11] do not use dataset for local saving --- verifiers/envs/environment.py | 4 +- verifiers/utils/eval_utils.py | 165 +++++++++++++++++++--------------- 2 files changed, 93 insertions(+), 76 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index e8842dba4..6953951d5 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -52,7 +52,7 @@ ) from verifiers.utils.async_utils import maybe_retry, maybe_semaphore from verifiers.utils.error_utils import ErrorChain -from verifiers.utils.eval_utils import make_dataset, save_rollout_results +from verifiers.utils.eval_utils import build_results, save_rollout_results from verifiers.utils.message_utils import ( strip_nones_from_content, ) @@ -1235,7 +1235,7 @@ def set_score_rollouts(self, score_rollouts: bool) -> None: """Set the score rollouts flag for this environment.""" self.score_rollouts = score_rollouts - make_dataset = staticmethod(make_dataset) + make_dataset = staticmethod(build_results) # backwards compatibility _EnvT = TypeVar("_EnvT", bound=Environment) diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 8035967c6..02deb4d72 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -4,9 +4,8 @@ import logging import time from collections import Counter, defaultdict -from contextlib import contextmanager from pathlib import Path -from typing import cast +from typing import Any, cast try: import tomllib # type: ignore[import-not-found] @@ -14,8 +13,7 @@ import tomli as tomllib # type: ignore[import-not-found] import numpy as np -from datasets import Dataset, disable_progress_bar, enable_progress_bar -from datasets.utils import logging as ds_logging +from datasets import Dataset import verifiers as vf from verifiers.types import ( @@ -513,14 +511,6 @@ def on_log(message: str) -> None: display.print_final_summary() -def sanitize_metadata(metadata: GenerateMetadata) -> dict: - metadata_dict = dict(metadata) - metadata_dict.pop("path_to_save") - metadata_dict.pop("date") - - return metadata_dict - - def get_hf_hub_dataset_name(results: GenerateOutputs) -> str: metadata = results["metadata"] dataset_name = ( @@ -535,65 +525,89 @@ def get_hf_hub_dataset_name(results: GenerateOutputs) -> str: return dataset_name -def make_dataset(results: GenerateOutputs, **kwargs) -> Dataset: - clean_prompts = [messages_to_printable(p) for p in results["prompt"]] - clean_prompts = [sanitize_tool_calls(p) for p in clean_prompts] - clean_completions = [messages_to_printable(c) for c in results["completion"]] - clean_completions = [sanitize_tool_calls(c) for c in clean_completions] - save_info = any(info != {} for info in results["info"]) - save_answer = any(answer != "" for answer in results["answer"]) - errors = [s.get("error") for s in results["state"]] - results_dict = { - "example_id": results["example_id"], - "prompt": clean_prompts, - "completion": clean_completions, - "task": results["task"], - "reward": results["reward"], - "error": [repr(e) if e is not None else None for e in errors], - "generation_ms": [s["timing"]["generation_ms"] for s in results["state"]], - "scoring_ms": [s["timing"]["scoring_ms"] for s in results["state"]], - "total_ms": [s["timing"]["total_ms"] for s in results["state"]], - } - if save_info: - results_dict["info"] = results["info"] - if save_answer: - results_dict["answer"] = results["answer"] - for k in results["metrics"]: - v = results["metrics"][k] - results_dict[k] = v - - # Add selected state columns if specified - state_columns = results["metadata"]["state_columns"] - if state_columns: - for col in state_columns: - if col == "responses": - results_dict[col] = [ - [r.model_dump() for r in s.get(col, [])] for s in results["state"] - ] - else: - results_dict[col] = [s.get(col) for s in results["state"]] - - return Dataset.from_dict(results_dict) - - -@contextmanager -def quiet_datasets(): - prev_level = ds_logging.get_verbosity() - ds_logging.set_verbosity(ds_logging.WARNING) - disable_progress_bar() - try: - yield - finally: - ds_logging.set_verbosity(prev_level) - enable_progress_bar() +def to_col_order(list_of_dicts: list[dict[str, Any]]) -> dict[str, list[Any]]: + """Converts a list of dicts (row-ordered dataset) to a dict of lists (col-ordered dataset).""" + return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0]} -def save_to_disk(dataset: Dataset, metadata_dict: dict, path_to_save: Path): - path_to_save.parent.mkdir(parents=True, exist_ok=True) - with quiet_datasets(): - dataset.to_json(path_to_save / "results.jsonl") - with open(path_to_save / "metadata.json", "w") as f: - json.dump(metadata_dict, f) +def to_row_order(dict_of_lists: dict[str, list[Any]]) -> list[dict[str, Any]]: + """Converts a dict of lists (col-ordered dataset) to a list of dicts (row-ordered dataset).""" + return [ + {k: v for k, v in zip(dict_of_lists.keys(), values)} + for values in zip(*dict_of_lists.values()) + ] + + +def build_results(results: GenerateOutputs) -> list[dict]: + """Builds list of results to save to disk from GenerateOutputs.""" + raw_results_list = to_row_order(cast(dict[str, list[Any]], results)) + metadata = cast(GenerateMetadata, results["metadata"]) + state_columns = metadata.get("state_columns", []) + results_list = [] + for raw_result in raw_results_list: + try: + clean_prompt = sanitize_tool_calls( + messages_to_printable(raw_result["prompt"]) + ) + clean_completion = sanitize_tool_calls( + messages_to_printable(raw_result["completion"]) + ) + + result_dict = { + "example_id": raw_result["example_id"], + "prompt": clean_prompt, + "completion": clean_completion, + "task": raw_result["task"], + "reward": raw_result["reward"], + "error": raw_result["state"].get("error"), + "generation_ms": raw_result["state"]["timing"]["generation_ms"], + "scoring_ms": raw_result["state"]["timing"]["scoring_ms"], + "total_ms": raw_result["state"]["timing"]["total_ms"], + **{k: v for k, v in raw_result["state"]["metrics"].items()}, + } + + if raw_result.get("info"): + result_dict["info"] = raw_result["info"] + if raw_result.get("answer"): + result_dict["answer"] = raw_result["answer"] + + # add selected state columns if specified + for col in state_columns: + result_dict[col] = raw_result["state"].get(col) + + results_list.append(result_dict) + except Exception as e: + logger.error( + f"Skipping saving result for example {raw_result['example_id']}: {repr(e)}" + ) + + return results_list + + +def build_metadata(metadata: GenerateMetadata) -> dict: + """Builds metadata dict to save from GenerateMetadata.""" + metadata_dict = dict(metadata) + metadata_dict.pop("path_to_save") + metadata_dict.pop("date") + + return metadata_dict + + +def save_to_disk(results_dict: list[dict], metadata_dict: dict, path_to_save: Path): + path_to_save.mkdir(parents=True, exist_ok=True) + + def save_results(results_list: list[dict], path: Path): + with open(path, "w") as f: + for result in results_list: + json.dump(result, f) + f.write("\n") + + def save_metadata(metadata_dict: dict, path: Path): + with open(path, "w") as f: + json.dump(metadata_dict, f) + + save_results(results_dict, path_to_save / "results.jsonl") + save_metadata(metadata_dict, path_to_save / "metadata.json") def save_rollout_results( @@ -603,11 +617,14 @@ def save_rollout_results( ): path_to_save = results["metadata"]["path_to_save"] path_to_save.parent.mkdir(parents=True, exist_ok=True) - dataset = make_dataset(results) - metadata_dict = sanitize_metadata(results["metadata"]) - save_to_disk(dataset, metadata_dict, path_to_save) + results_list = build_results(results) + metadata_dict = build_metadata(results["metadata"]) + save_to_disk(results_list, metadata_dict, path_to_save) logger.info(f"Results saved to {path_to_save}") if push_to_hf_hub: dataset_name = hf_hub_dataset_name or get_hf_hub_dataset_name(results) - dataset.push_to_hub(dataset_name) - logger.info(f"Dataset saved to Hugging Face Hub: {dataset_name}") + try: + Dataset.from_list(results_list).push_to_hub(dataset_name) + logger.info(f"Dataset saved to Hugging Face Hub: {dataset_name}") + except Exception as e: + logger.error(f"Error pushing dataset to Hugging Face Hub: {e}") From dc6e80efacba9c5ea79eacc8cacd79448725cd2e Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 23 Jan 2026 14:26:00 +0000 Subject: [PATCH 02/11] correctly convert generate outputs --- verifiers/utils/eval_utils.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 02deb4d72..187f76ce6 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -525,22 +525,36 @@ def get_hf_hub_dataset_name(results: GenerateOutputs) -> str: return dataset_name -def to_col_order(list_of_dicts: list[dict[str, Any]]) -> dict[str, list[Any]]: +def to_col_order( + list_of_dicts: list[dict[str, Any]], ignore_keys: list[str] = [] +) -> dict[str, list[Any]]: """Converts a list of dicts (row-ordered dataset) to a dict of lists (col-ordered dataset).""" - return {k: [d[k] for d in list_of_dicts] for k in list_of_dicts[0]} + list_keys = [k for k in list_of_dicts[0].keys() if k not in ignore_keys] + return {k: [d[k] for d in list_of_dicts] for k in list_keys} -def to_row_order(dict_of_lists: dict[str, list[Any]]) -> list[dict[str, Any]]: +def to_row_order( + dict_of_lists: dict[str, list[Any]], ignore_keys: list[str] = [] +) -> list[dict[str, Any]]: """Converts a dict of lists (col-ordered dataset) to a list of dicts (row-ordered dataset).""" - return [ - {k: v for k, v in zip(dict_of_lists.keys(), values)} - for values in zip(*dict_of_lists.values()) - ] + list_keys = [k for k in dict_of_lists.keys() if k not in ignore_keys] + list_values = [dict_of_lists[k] for k in list_keys] + return [{k: v for k, v in zip(list_keys, values)} for values in zip(*list_values)] def build_results(results: GenerateOutputs) -> list[dict]: """Builds list of results to save to disk from GenerateOutputs.""" - raw_results_list = to_row_order(cast(dict[str, list[Any]], results)) + + def get_results_list(results: GenerateOutputs) -> list[dict]: + """Converts GenerateOutputs to a list of dicts.""" + results_list = to_row_order(results, ignore_keys=["metrics", "metadata"]) # type: ignore + metrics_list = to_row_order(results["metrics"]) # type: ignore + return [ + {**result_dict, **metrics_dict} + for result_dict, metrics_dict in zip(results_list, metrics_list) + ] + + raw_results_list = get_results_list(results) metadata = cast(GenerateMetadata, results["metadata"]) state_columns = metadata.get("state_columns", []) results_list = [] From ff7649571815c79f255544e2f642839ee98c5225 Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 23 Jan 2026 15:29:18 +0000 Subject: [PATCH 03/11] rewrite generate outputs to row-order --- verifiers/envs/environment.py | 118 ++++----------- verifiers/types.py | 30 ++-- verifiers/utils/eval_utils.py | 262 +++++++++++++--------------------- verifiers/utils/save_utils.py | 99 +++++++++++++ 4 files changed, 247 insertions(+), 262 deletions(-) create mode 100644 verifiers/utils/save_utils.py diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 6953951d5..76c1db005 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -10,7 +10,6 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from copy import deepcopy -from datetime import datetime from pathlib import Path from typing import ( TYPE_CHECKING, @@ -37,7 +36,6 @@ ChatCompletionToolParam, ChatMessage, DatasetBuilder, - GenerateMetadata, GenerateOutputs, LogCallback, Messages, @@ -52,11 +50,14 @@ ) from verifiers.utils.async_utils import maybe_retry, maybe_semaphore from verifiers.utils.error_utils import ErrorChain -from verifiers.utils.eval_utils import build_results, save_rollout_results +from verifiers.utils.eval_utils import sanitize_rollouts, save_generate_outputs from verifiers.utils.message_utils import ( strip_nones_from_content, ) -from verifiers.utils.path_utils import get_results_path +from verifiers.utils.save_utils import ( + states_to_generate_metadata, + states_to_rollout_outputs, +) from verifiers.utils.token_utils import ( get_prompt_ids, prepare_sampling_args_for_token_prompts, @@ -794,93 +795,30 @@ async def run_group( await self.rubric.dummy_score_group(group_states) return list(group_states) - def _prepare_rollout_results( + def _build_generate_outputs( self, all_states: list[State], model: str, client: AsyncOpenAI, state_columns: list[str] | None, results_path: Path | None, - gen_sampling_args: SamplingArgs, + sampling_args: SamplingArgs, start_time: float, ) -> GenerateOutputs: """Prepare GenerateOutputs from a list of completed states.""" - # Determine path_to_save - if results_path is None: - path_to_save = get_results_path(self.env_id, model) - else: - path_to_save = results_path - prompts = [state["prompt"] for state in all_states] - completions = [state.get("completion") for state in all_states] - answers = [state.get("answer", "") for state in all_states] - tasks = [state.get("task", "default") for state in all_states] - infos = [state.get("info", {}) for state in all_states] - example_ids = [state.get("example_id", 0) for state in all_states] - rewards = [state.get("reward", 0.0) for state in all_states] - stop_conditions = [state.get("stop_condition", None) for state in all_states] - is_truncated = [state.get("is_truncated", False) for state in all_states] - - metrics: dict[str, list[float]] = {} - for state in all_states: - if state.get("metrics"): - for metric_name, metric_value in state["metrics"].items(): - if metric_name not in metrics: - metrics[metric_name] = [] - metrics[metric_name].append(metric_value) - - num_unique_examples = len(set(example_ids)) if example_ids else 0 - rollouts_per_example = ( - len(all_states) // num_unique_examples if num_unique_examples > 0 else 1 - ) - - def _tools_key(tools: list | None) -> str: - if not tools: - return "" - return str(sorted([t.get("function", {}).get("name", "") for t in tools])) - - all_tools = [state.get("oai_tools") for state in all_states] - unique_tool_sets = set(_tools_key(t) for t in all_tools) - tools_vary = len(unique_tool_sets) > 1 - - if tools_vary: - metadata_tools = None - else: - metadata_tools = next((t for t in all_tools if t), self.oai_tools) - - metadata = GenerateMetadata( - env_id=self.env_id, - env_args=self.env_args, - model=model, - base_url=str(client.base_url) if hasattr(client, "base_url") else "", - num_examples=num_unique_examples, - rollouts_per_example=rollouts_per_example, - sampling_args=gen_sampling_args, - date=datetime.now().isoformat(), - time_ms=(time.time() - start_time) * 1000.0, - avg_reward=sum(rewards) / len(rewards) if rewards else 0.0, - avg_metrics={ - name: sum(values) / len(values) if values else 0.0 - for name, values in metrics.items() - }, - state_columns=state_columns or [], - path_to_save=path_to_save, - tools=metadata_tools, - ) - - return GenerateOutputs( - prompt=prompts, - completion=completions, - answer=answers, - state=all_states, - task=tasks, - info=infos, - example_id=example_ids, - reward=rewards, - metrics=metrics, - stop_conditions=stop_conditions, - is_truncated=is_truncated, - metadata=metadata, + rollouts = states_to_rollout_outputs(all_states, state_columns or []) + metadata = states_to_generate_metadata( + self.env_id, + self.env_args, + model, + client, + all_states, + state_columns, + sampling_args, + start_time, + results_path, ) + return GenerateOutputs(rollouts=rollouts, metadata=metadata) async def generate( self, @@ -1017,7 +955,7 @@ async def generate( and save_every > 0 and groups_or_rollouts_completed % save_every == 0 ): - temp_results = self._prepare_rollout_results( + intermediate_outputs = self._build_generate_outputs( all_states, model, client, @@ -1027,9 +965,9 @@ async def generate( start_time, ) self.logger.debug( - f"Saving intermediate results to {temp_results['metadata']['path_to_save']}" + f"Saving intermediate results to {intermediate_outputs['metadata']['path_to_save']}" ) - save_rollout_results(temp_results) + save_generate_outputs(intermediate_outputs) finally: if pbar is not None: pbar.close() @@ -1037,7 +975,7 @@ async def generate( # sort by example_id to ensure deterministic ordering regardless of completion order all_states.sort(key=lambda s: s.get("example_id", 0)) - results = self._prepare_rollout_results( + outputs = self._build_generate_outputs( all_states, model, client, @@ -1049,11 +987,11 @@ async def generate( # save if requested if save_results: - save_rollout_results(results, push_to_hf_hub, hf_hub_dataset_name) + save_generate_outputs(outputs) if on_log is not None: - on_log(f"Saved final results to {results['metadata']['path_to_save']}") + on_log(f"Saved final outputs to {outputs['metadata']['path_to_save']}") - return results + return outputs def generate_sync( self, @@ -1235,7 +1173,9 @@ def set_score_rollouts(self, score_rollouts: bool) -> None: """Set the score rollouts flag for this environment.""" self.score_rollouts = score_rollouts - make_dataset = staticmethod(build_results) # backwards compatibility + make_dataset = staticmethod( + lambda x: Dataset.from_list(sanitize_rollouts(x["rollouts"])) + ) # backwards compatibility _EnvT = TypeVar("_EnvT", bound=Environment) diff --git a/verifiers/types.py b/verifiers/types.py index 17f5b0243..5ce238f91 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -167,20 +167,28 @@ class GenerateMetadata(TypedDict): tools: list[ChatCompletionToolParam] | None +class RolloutOutput(TypedDict): + """TypedDict for generation outputs.""" + + example_id: int + prompt: Messages + completion: Messages + answer: str + task: str + info: Info + tools: dict + reward: float + metrics: dict[str, float] + stop_condition: str | None + is_truncated: bool + timing: RolloutTiming + error: Error | None + + class GenerateOutputs(TypedDict): """TypedDict for generation outputs.""" - prompt: list[Messages] - completion: list[Messages] - answer: list[str] - state: list[State] - task: list[str] - info: list[Info] - example_id: list[int] - reward: list[float] - metrics: dict[str, list[float]] - stop_conditions: list[str | None] - is_truncated: list[bool] + rollouts: list[RolloutOutput] metadata: GenerateMetadata diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 187f76ce6..2dfda28d6 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -5,7 +5,9 @@ import time from collections import Counter, defaultdict from pathlib import Path -from typing import Any, cast +from typing import cast + +from verifiers.utils.save_utils import to_col_order try: import tomllib # type: ignore[import-not-found] @@ -24,6 +26,7 @@ GenerateOutputs, LogCallback, ProgressCallback, + RolloutOutput, StartCallback, State, ) @@ -179,54 +182,35 @@ def load_toml_config(path: Path) -> list[dict]: return merged_eval_list -def get_results_by_task(results: GenerateOutputs) -> dict[str, GenerateOutputs]: - """Group results by task name. - - Args: - results: The GenerateOutputs from an evaluation run. +def get_task_outputs(results: GenerateOutputs, task: str) -> GenerateOutputs: + """Get only the rollouts for a given task.""" + rollouts = [r for r in results["rollouts"] if r["task"] == task] + return GenerateOutputs( + rollouts=rollouts, + metadata=results["metadata"], # duplicate metadata + ) - Returns: - A dictionary mapping task names to their corresponding GenerateOutputs. - """ - task_indices: dict[str, list[int]] = {} - for i, task in enumerate(results["task"]): - if task not in task_indices: - task_indices[task] = [] - task_indices[task].append(i) - - task_results: dict[str, GenerateOutputs] = {} - for task, indices in task_indices.items(): - task_results[task] = GenerateOutputs( - prompt=[results["prompt"][i] for i in indices], - completion=[results["completion"][i] for i in indices], - answer=[results["answer"][i] for i in indices], - state=[results["state"][i] for i in indices], - task=[results["task"][i] for i in indices], - info=[results["info"][i] for i in indices], - example_id=[results["example_id"][i] for i in indices], - reward=[results["reward"][i] for i in indices], - metrics={k: [v[i] for i in indices] for k, v in results["metrics"].items()}, - stop_conditions=[results["stop_conditions"][i] for i in indices], - is_truncated=[results["is_truncated"][i] for i in indices], - metadata=results["metadata"], - ) - return task_results +def print_rewards(outputs: GenerateOutputs): + metadata = outputs["metadata"] + n = metadata["num_examples"] + r = metadata["rollouts_per_example"] -def print_rewards(results: GenerateOutputs): + rewards = [r["reward"] for r in outputs["rollouts"]] print("Rewards:") print( - f"reward: avg - {sum(results['reward']) / len(results['reward']):.3f}, std - {np.std(results['reward']):.3f}" + f"reward: avg - {sum(rewards) / len(rewards):.3f}, std - {np.std(rewards):.3f}" ) - r = results["metadata"]["rollouts_per_example"] - n = len(results["reward"]) // r # results are sorted by example_id, so rollout i is at indices [i, i+r, i+2r, ...] for i in range(r): - trials = [round(results["reward"][i + (j * r)], 3) for j in range(n)] + trials = [round(rewards[i + (j * r)], 3) for j in range(n)] out = f"r{i + 1}: {trials}" print(out) - for k in results["metrics"]: - v = results["metrics"][k] + + metrics = [r["metrics"] for r in outputs["rollouts"]] + metrics_col = to_col_order(metrics) + for k in metrics_col.keys(): + v = metrics_col[k] print(f"{k}: avg - {sum(v) / len(v):.3f}, std - {np.std(v):.3f}") for i in range(r): trials = [round(v[i + (j * r)], 3) for j in range(n)] @@ -234,16 +218,18 @@ def print_rewards(results: GenerateOutputs): print(out) -def print_info(results: GenerateOutputs): +def print_info(outputs: GenerateOutputs): + is_truncated = [r["is_truncated"] for r in outputs["rollouts"]] print("Info:") print( - f"is_truncated: avg - {np.mean(results['is_truncated']):.3f}, std - {np.std(results['is_truncated']):.3f}" + f"is_truncated: avg - {np.mean(is_truncated):.3f}, std - {np.std(is_truncated):.3f}" ) - counter = Counter(results["stop_conditions"]) + stop_conditions = [r["stop_condition"] for r in outputs["rollouts"]] + counter = Counter(stop_conditions) print( f"stop_conditions: {', '.join([f'{k}: {v / counter.total():.3f}' for k, v in counter.items()])}" ) - errors = [s.get("error") for s in results["state"]] + errors = [r.get("error") for r in outputs["rollouts"]] has_errors = [e is not None for e in errors] if any(has_errors): print( @@ -256,13 +242,13 @@ def print_info(results: GenerateOutputs): print(f" - {repr(error_chain)}: {count / counter.total():.3f}") -def print_timing(results: GenerateOutputs): +def print_timing(outputs: GenerateOutputs): print("Timing:") - generation_ms_arr = np.array( - [s["timing"]["generation_ms"] for s in results["state"]] - ) - scoring_ms_arr = np.array([s["timing"]["scoring_ms"] for s in results["state"]]) - total_ms_arr = np.array([s["timing"]["total_ms"] for s in results["state"]]) + timing = [r["timing"] for r in outputs["rollouts"]] + timing_col = to_col_order(cast(list[dict], timing)) + generation_ms_arr = np.array(timing_col["generation_ms"]) + scoring_ms_arr = np.array(timing_col["scoring_ms"]) + total_ms_arr = np.array(timing_col["total_ms"]) generation_arr = generation_ms_arr / 1000 scoring_arr = scoring_ms_arr / 1000 total_arr = total_ms_arr / 1000 @@ -278,43 +264,45 @@ def print_timing(results: GenerateOutputs): ) -def print_results( - results: GenerateOutputs, - num_samples: int = 1, -): - assert results["metadata"] is not None +def print_results(outputs: GenerateOutputs, num_samples: int = 1): + assert outputs["metadata"] is not None print("--- Evaluation ---") - print(f"Environment: {results['metadata']['env_id']}") - print(f"Model: {results['metadata']['model']}") - print(f"Provider: {results['metadata']['base_url']}") - print(f"Examples: {results['metadata']['num_examples']}") - print(f"Rollouts per example: {results['metadata']['rollouts_per_example']}") + print(f"Environment: {outputs['metadata']['env_id']}") + print(f"Model: {outputs['metadata']['model']}") + print(f"Provider: {outputs['metadata']['base_url']}") + print(f"Examples: {outputs['metadata']['num_examples']}") + print(f"Rollouts per example: {outputs['metadata']['rollouts_per_example']}") print("--- Example ---") - printable_prompts = [messages_to_printable(p) for p in results["prompt"]] - printable_completions = [messages_to_printable(c) for c in results["completion"]] - errors = [s.get("error") for s in results["state"]] + printable_prompts = [ + messages_to_printable(r["prompt"]) for r in outputs["rollouts"] + ] + printable_completions = [ + messages_to_printable(r["completion"]) for r in outputs["rollouts"] + ] + rewards = [r["reward"] for r in outputs["rollouts"]] + errors = [r.get("error") for r in outputs["rollouts"]] print_prompt_completions_sample( printable_prompts, printable_completions, errors, - results["reward"], + rewards, step=0, num_samples=num_samples, ) print("--- All ---") - print_rewards(results) - print_info(results) - print_timing(results) - - num_tasks = len(set(results["task"])) - if num_tasks > 1: - task_results = get_results_by_task(results) - for task, task_results in task_results.items(): + print_rewards(outputs) + print_info(outputs) + print_timing(outputs) + + tasks = set([r["task"] for r in outputs["rollouts"]]) + if len(tasks) > 1: + for task in tasks: + task_outputs = get_task_outputs(outputs, task) print(f"\n--- {task} ---") - print_rewards(task_results) - print_info(task_results) - print_timing(task_results) + print_rewards(task_outputs) + print_info(task_outputs) + print_timing(task_outputs) async def run_evaluation( @@ -345,7 +333,7 @@ async def run_evaluation( ) # disable tqdm when callbacks are provided (TUI handles progress display) use_tqdm = config.use_tqdm and on_progress is None - results = await vf_env.evaluate( + outputs = await vf_env.evaluate( client=client, model=config.model, sampling_args=config.sampling_args, @@ -368,7 +356,12 @@ async def run_evaluation( on_log=on_log, ) - return results + if config.save_results: + save_generate_outputs( + outputs, config.save_to_hf_hub, config.hf_hub_dataset_name + ) + + return outputs async def run_evaluations(config: EvalRunConfig) -> None: @@ -511,8 +504,9 @@ def on_log(message: str) -> None: display.print_final_summary() -def get_hf_hub_dataset_name(results: GenerateOutputs) -> str: - metadata = results["metadata"] +def get_hf_hub_dataset_name(outputs: GenerateOutputs) -> str: + """Auto-generates a dataset name.""" + metadata = outputs["metadata"] dataset_name = ( metadata["env_id"] + "_" @@ -525,81 +519,25 @@ def get_hf_hub_dataset_name(results: GenerateOutputs) -> str: return dataset_name -def to_col_order( - list_of_dicts: list[dict[str, Any]], ignore_keys: list[str] = [] -) -> dict[str, list[Any]]: - """Converts a list of dicts (row-ordered dataset) to a dict of lists (col-ordered dataset).""" - list_keys = [k for k in list_of_dicts[0].keys() if k not in ignore_keys] - return {k: [d[k] for d in list_of_dicts] for k in list_keys} - - -def to_row_order( - dict_of_lists: dict[str, list[Any]], ignore_keys: list[str] = [] -) -> list[dict[str, Any]]: - """Converts a dict of lists (col-ordered dataset) to a list of dicts (row-ordered dataset).""" - list_keys = [k for k in dict_of_lists.keys() if k not in ignore_keys] - list_values = [dict_of_lists[k] for k in list_keys] - return [{k: v for k, v in zip(list_keys, values)} for values in zip(*list_values)] - - -def build_results(results: GenerateOutputs) -> list[dict]: - """Builds list of results to save to disk from GenerateOutputs.""" - - def get_results_list(results: GenerateOutputs) -> list[dict]: - """Converts GenerateOutputs to a list of dicts.""" - results_list = to_row_order(results, ignore_keys=["metrics", "metadata"]) # type: ignore - metrics_list = to_row_order(results["metrics"]) # type: ignore - return [ - {**result_dict, **metrics_dict} - for result_dict, metrics_dict in zip(results_list, metrics_list) - ] - - raw_results_list = get_results_list(results) - metadata = cast(GenerateMetadata, results["metadata"]) - state_columns = metadata.get("state_columns", []) - results_list = [] - for raw_result in raw_results_list: - try: - clean_prompt = sanitize_tool_calls( - messages_to_printable(raw_result["prompt"]) - ) - clean_completion = sanitize_tool_calls( - messages_to_printable(raw_result["completion"]) - ) +def sanitize_rollouts(rollouts: list[RolloutOutput]) -> list[dict]: + """Sanitizes a list of rollouts before saving to disk.""" - result_dict = { - "example_id": raw_result["example_id"], - "prompt": clean_prompt, - "completion": clean_completion, - "task": raw_result["task"], - "reward": raw_result["reward"], - "error": raw_result["state"].get("error"), - "generation_ms": raw_result["state"]["timing"]["generation_ms"], - "scoring_ms": raw_result["state"]["timing"]["scoring_ms"], - "total_ms": raw_result["state"]["timing"]["total_ms"], - **{k: v for k, v in raw_result["state"]["metrics"].items()}, - } - - if raw_result.get("info"): - result_dict["info"] = raw_result["info"] - if raw_result.get("answer"): - result_dict["answer"] = raw_result["answer"] - - # add selected state columns if specified - for col in state_columns: - result_dict[col] = raw_result["state"].get(col) - - results_list.append(result_dict) - except Exception as e: - logger.error( - f"Skipping saving result for example {raw_result['example_id']}: {repr(e)}" - ) + def sanitize_rollout(rollout: RolloutOutput) -> dict: + sanitized_rollout = dict(rollout) + sanitized_rollout["prompt"] = sanitize_tool_calls( + messages_to_printable(rollout["prompt"]) + ) + sanitized_rollout["completion"] = sanitize_tool_calls( + messages_to_printable(rollout["completion"]) + ) + return sanitized_rollout + + return [sanitize_rollout(rollout) for rollout in rollouts] - return results_list +def sanitize_metadata(metadata: GenerateMetadata) -> dict: + """Sanitizes metadata before saving to disk.""" -def build_metadata(metadata: GenerateMetadata) -> dict: - """Builds metadata dict to save from GenerateMetadata.""" metadata_dict = dict(metadata) metadata_dict.pop("path_to_save") metadata_dict.pop("date") @@ -607,8 +545,9 @@ def build_metadata(metadata: GenerateMetadata) -> dict: return metadata_dict -def save_to_disk(results_dict: list[dict], metadata_dict: dict, path_to_save: Path): - path_to_save.mkdir(parents=True, exist_ok=True) +def save_to_disk(rollouts: list[dict], metadata: dict, path: Path): + """Saves (sanitized) rollouts and metadata to disk.""" + path.mkdir(parents=True, exist_ok=True) def save_results(results_list: list[dict], path: Path): with open(path, "w") as f: @@ -620,25 +559,24 @@ def save_metadata(metadata_dict: dict, path: Path): with open(path, "w") as f: json.dump(metadata_dict, f) - save_results(results_dict, path_to_save / "results.jsonl") - save_metadata(metadata_dict, path_to_save / "metadata.json") + save_results(rollouts, path / "results.jsonl") + save_metadata(metadata, path / "metadata.json") -def save_rollout_results( - results: GenerateOutputs, +def save_generate_outputs( + outputs: GenerateOutputs, push_to_hf_hub: bool = False, hf_hub_dataset_name: str | None = None, ): - path_to_save = results["metadata"]["path_to_save"] - path_to_save.parent.mkdir(parents=True, exist_ok=True) - results_list = build_results(results) - metadata_dict = build_metadata(results["metadata"]) - save_to_disk(results_list, metadata_dict, path_to_save) + path_to_save = outputs["metadata"]["path_to_save"] + sanitized_rollouts = sanitize_rollouts(outputs["rollouts"]) + sanitized_metadata = sanitize_metadata(outputs["metadata"]) + save_to_disk(sanitized_rollouts, sanitized_metadata, path_to_save) logger.info(f"Results saved to {path_to_save}") if push_to_hf_hub: - dataset_name = hf_hub_dataset_name or get_hf_hub_dataset_name(results) + dataset_name = hf_hub_dataset_name or get_hf_hub_dataset_name(outputs) try: - Dataset.from_list(results_list).push_to_hub(dataset_name) + Dataset.from_list(sanitized_rollouts).push_to_hub(dataset_name) logger.info(f"Dataset saved to Hugging Face Hub: {dataset_name}") except Exception as e: logger.error(f"Error pushing dataset to Hugging Face Hub: {e}") diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py new file mode 100644 index 000000000..d9e98c0e0 --- /dev/null +++ b/verifiers/utils/save_utils.py @@ -0,0 +1,99 @@ +import time +from collections import defaultdict +from datetime import datetime +from pathlib import Path +from typing import Any + +from openai import AsyncOpenAI + +from verifiers.types import GenerateMetadata, RolloutOutput, SamplingArgs, State +from verifiers.utils.path_utils import get_results_path + + +def to_col_order(list_of_dicts: list[dict[str, Any]]) -> dict[str, list[float]]: + return {k: [m[k] for m in list_of_dicts] for k in list_of_dicts[0].keys()} + + +def state_to_rollout_output( + state: State, state_columns: list[str] = [] +) -> RolloutOutput: + rollout_output = RolloutOutput( + example_id=state.get("example_id", 0), + prompt=state.get("prompt"), + completion=state.get("completion"), + answer=state.get("answer", ""), + task=state.get("task", "default"), + info=state.get("info", {}), + tools=state.get("oai_tools", {}), + reward=state.get("reward", 0.0), + metrics=state.get("metrics", {}), + stop_condition=state.get("stop_condition", None), + is_truncated=state.get("is_truncated", False), + timing=state.get("timing", {}), + error=state.get("error", None), + ) + for col in state_columns: + rollout_output[col] = state.get(col) + + return rollout_output + + +def states_to_rollout_outputs( + states: list[State], state_columns: list[str] = [] +) -> list[RolloutOutput]: + return [state_to_rollout_output(state, state_columns) for state in states] + + +def states_to_generate_metadata( + env_id: str, + env_args: dict, + model: str, + client: AsyncOpenAI, + states: list[State], + state_columns: list[str] | None, + sampling_args: SamplingArgs, + start_time: float, + results_path: Path | None, +) -> GenerateMetadata: + base_url = str(client.base_url) if hasattr(client, "base_url") else "" + rewards = [s.get("reward", 0.0) for s in states] + avg_reward = sum(rewards) / len(rewards) if rewards else 0.0 + + metrics: dict[str, list[float]] = defaultdict(list) + for state in states: + if state.get("metrics"): + for metric_name, metric_value in state["metrics"].items(): + metrics[metric_name].append(metric_value) + avg_metrics = {k: sum(v) / len(v) if v else 0.0 for k, v in metrics.items()} + + example_ids = [s.get("example_id", 0) for s in states] + num_examples = len(set(example_ids)) if example_ids else 0 + rollouts_per_example = len(states) // num_examples if num_examples > 0 else 1 + + path_to_save = results_path or get_results_path(env_id, model) + + def tools_key(tools: list | None) -> str: + if not tools: + return "" + return str(sorted([t.get("function", {}).get("name", "") for t in tools])) + + all_tools = [s.get("oai_tools") for s in states] + unique_tools = set(tools_key(t) for t in all_tools) + tools = next((t for t in all_tools if t), None) if len(unique_tools) == 1 else None + + return GenerateMetadata( + env_id=env_id, + env_args=env_args, + model=model, + base_url=base_url, + num_examples=num_examples, + rollouts_per_example=rollouts_per_example, + sampling_args=sampling_args, + date=datetime.now().isoformat(), + time_ms=(time.time() - start_time) * 1000.0, + avg_reward=avg_reward, + avg_metrics=avg_metrics, + state_columns=state_columns or [], + path_to_save=path_to_save, + tools=tools, + ) From d2352c5efbab66d24b5093ce7e2e0c06e306d7ff Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 23 Jan 2026 15:31:30 +0000 Subject: [PATCH 04/11] only save once --- verifiers/envs/environment.py | 12 ++++++------ verifiers/utils/eval_utils.py | 7 +------ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 76c1db005..c3a1afbe3 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -833,7 +833,7 @@ async def generate( state_columns: list[str] | None = None, save_results: bool = False, save_every: int = -1, - push_to_hf_hub: bool = False, + save_to_hf_hub: bool = False, hf_hub_dataset_name: str | None = None, use_tqdm: bool = True, independent_scoring: bool = False, @@ -987,7 +987,7 @@ async def generate( # save if requested if save_results: - save_generate_outputs(outputs) + save_generate_outputs(outputs, save_to_hf_hub, hf_hub_dataset_name) if on_log is not None: on_log(f"Saved final outputs to {outputs['metadata']['path_to_save']}") @@ -1054,7 +1054,7 @@ async def evaluate( state_columns: list[str] | None = None, save_results: bool = False, save_every: int = -1, - push_to_hf_hub: bool = False, + save_to_hf_hub: bool = False, hf_hub_dataset_name: str | None = None, use_tqdm: bool = True, independent_scoring: bool = False, @@ -1080,7 +1080,7 @@ async def evaluate( state_columns=state_columns, save_results=save_results, save_every=save_every, - push_to_hf_hub=push_to_hf_hub, + save_to_hf_hub=save_to_hf_hub, hf_hub_dataset_name=hf_hub_dataset_name, use_tqdm=use_tqdm, independent_scoring=independent_scoring, @@ -1105,7 +1105,7 @@ def evaluate_sync( state_columns: list[str] | None = None, save_results: bool = False, save_every: int = -1, - push_to_hf_hub: bool = False, + save_to_hf_hub: bool = False, hf_hub_dataset_name: str | None = None, independent_scoring: bool = False, max_retries: int = 0, @@ -1126,7 +1126,7 @@ def evaluate_sync( state_columns=state_columns, save_results=save_results, save_every=save_every, - push_to_hf_hub=push_to_hf_hub, + save_to_hf_hub=save_to_hf_hub, hf_hub_dataset_name=hf_hub_dataset_name, independent_scoring=independent_scoring, max_retries=max_retries, diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 2dfda28d6..d9e15d5ea 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -346,7 +346,7 @@ async def run_evaluation( state_columns=config.state_columns, save_results=config.save_results, save_every=config.save_every, - push_to_hf_hub=config.save_to_hf_hub, + save_to_hf_hub=config.save_to_hf_hub, hf_hub_dataset_name=config.hf_hub_dataset_name, use_tqdm=use_tqdm, independent_scoring=config.independent_scoring, @@ -356,11 +356,6 @@ async def run_evaluation( on_log=on_log, ) - if config.save_results: - save_generate_outputs( - outputs, config.save_to_hf_hub, config.hf_hub_dataset_name - ) - return outputs From b6fe5eff967e51a1aa726bbf9a7fdfae812ba4cf Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 23 Jan 2026 15:35:52 +0000 Subject: [PATCH 05/11] repr errors in saved outputs --- verifiers/utils/eval_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index d9e15d5ea..4c17075cd 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -525,6 +525,7 @@ def sanitize_rollout(rollout: RolloutOutput) -> dict: sanitized_rollout["completion"] = sanitize_tool_calls( messages_to_printable(rollout["completion"]) ) + sanitized_rollout["error"] = repr(rollout.get("error")) return sanitized_rollout return [sanitize_rollout(rollout) for rollout in rollouts] From d089d2c8b5bacb82079aa2a1349872e1de60da8f Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Fri, 23 Jan 2026 15:40:48 +0000 Subject: [PATCH 06/11] move to save_utils --- verifiers/envs/environment.py | 3 +- verifiers/utils/eval_utils.py | 94 +++------------------------------- verifiers/utils/save_utils.py | 96 +++++++++++++++++++++++++++++++++-- 3 files changed, 100 insertions(+), 93 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index c3a1afbe3..0ebc205fb 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -50,11 +50,12 @@ ) from verifiers.utils.async_utils import maybe_retry, maybe_semaphore from verifiers.utils.error_utils import ErrorChain -from verifiers.utils.eval_utils import sanitize_rollouts, save_generate_outputs from verifiers.utils.message_utils import ( strip_nones_from_content, ) from verifiers.utils.save_utils import ( + sanitize_rollouts, + save_generate_outputs, states_to_generate_metadata, states_to_rollout_outputs, ) diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 4c17075cd..2426886e1 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -1,13 +1,10 @@ import asyncio import importlib.util -import json import logging import time from collections import Counter, defaultdict from pathlib import Path -from typing import cast - -from verifiers.utils.save_utils import to_col_order +from typing import Any, cast try: import tomllib # type: ignore[import-not-found] @@ -15,18 +12,15 @@ import tomli as tomllib # type: ignore[import-not-found] import numpy as np -from datasets import Dataset import verifiers as vf from verifiers.types import ( Endpoints, EvalConfig, EvalRunConfig, - GenerateMetadata, GenerateOutputs, LogCallback, ProgressCallback, - RolloutOutput, StartCallback, State, ) @@ -34,7 +28,7 @@ from verifiers.utils.client_utils import setup_client from verifiers.utils.error_utils import ErrorChain from verifiers.utils.logging_utils import print_prompt_completions_sample, print_time -from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls +from verifiers.utils.message_utils import messages_to_printable from verifiers.utils.path_utils import get_eval_results_path logger = logging.getLogger(__name__) @@ -182,6 +176,11 @@ def load_toml_config(path: Path) -> list[dict]: return merged_eval_list +def to_col_order(list_of_dicts: list[dict[str, Any]]) -> dict[str, list[float]]: + """Convert a list of dictionaries to a dictionary of lists, ordered by the keys of the first dictionary.""" + return {k: [m[k] for m in list_of_dicts] for k in list_of_dicts[0].keys()} + + def get_task_outputs(results: GenerateOutputs, task: str) -> GenerateOutputs: """Get only the rollouts for a given task.""" rollouts = [r for r in results["rollouts"] if r["task"] == task] @@ -497,82 +496,3 @@ def on_log(message: str) -> None: # print final summary after exit display.print_final_summary() - - -def get_hf_hub_dataset_name(outputs: GenerateOutputs) -> str: - """Auto-generates a dataset name.""" - metadata = outputs["metadata"] - dataset_name = ( - metadata["env_id"] - + "_" - + metadata["model"].replace("/", "_") - + "_n" - + str(metadata["num_examples"]) - + "_r" - + str(metadata["rollouts_per_example"]) - ) - return dataset_name - - -def sanitize_rollouts(rollouts: list[RolloutOutput]) -> list[dict]: - """Sanitizes a list of rollouts before saving to disk.""" - - def sanitize_rollout(rollout: RolloutOutput) -> dict: - sanitized_rollout = dict(rollout) - sanitized_rollout["prompt"] = sanitize_tool_calls( - messages_to_printable(rollout["prompt"]) - ) - sanitized_rollout["completion"] = sanitize_tool_calls( - messages_to_printable(rollout["completion"]) - ) - sanitized_rollout["error"] = repr(rollout.get("error")) - return sanitized_rollout - - return [sanitize_rollout(rollout) for rollout in rollouts] - - -def sanitize_metadata(metadata: GenerateMetadata) -> dict: - """Sanitizes metadata before saving to disk.""" - - metadata_dict = dict(metadata) - metadata_dict.pop("path_to_save") - metadata_dict.pop("date") - - return metadata_dict - - -def save_to_disk(rollouts: list[dict], metadata: dict, path: Path): - """Saves (sanitized) rollouts and metadata to disk.""" - path.mkdir(parents=True, exist_ok=True) - - def save_results(results_list: list[dict], path: Path): - with open(path, "w") as f: - for result in results_list: - json.dump(result, f) - f.write("\n") - - def save_metadata(metadata_dict: dict, path: Path): - with open(path, "w") as f: - json.dump(metadata_dict, f) - - save_results(rollouts, path / "results.jsonl") - save_metadata(metadata, path / "metadata.json") - - -def save_generate_outputs( - outputs: GenerateOutputs, - push_to_hf_hub: bool = False, - hf_hub_dataset_name: str | None = None, -): - path_to_save = outputs["metadata"]["path_to_save"] - sanitized_rollouts = sanitize_rollouts(outputs["rollouts"]) - sanitized_metadata = sanitize_metadata(outputs["metadata"]) - save_to_disk(sanitized_rollouts, sanitized_metadata, path_to_save) - logger.info(f"Results saved to {path_to_save}") - if push_to_hf_hub: - dataset_name = hf_hub_dataset_name or get_hf_hub_dataset_name(outputs) - try: - Dataset.from_list(sanitized_rollouts).push_to_hub(dataset_name) - logger.info(f"Dataset saved to Hugging Face Hub: {dataset_name}") - except Exception as e: - logger.error(f"Error pushing dataset to Hugging Face Hub: {e}") diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index d9e98c0e0..f3dd80ea8 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -1,17 +1,24 @@ +import json +import logging import time from collections import defaultdict from datetime import datetime from pathlib import Path -from typing import Any +from datasets import Dataset from openai import AsyncOpenAI -from verifiers.types import GenerateMetadata, RolloutOutput, SamplingArgs, State +from verifiers.types import ( + GenerateMetadata, + GenerateOutputs, + RolloutOutput, + SamplingArgs, + State, +) +from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls from verifiers.utils.path_utils import get_results_path - -def to_col_order(list_of_dicts: list[dict[str, Any]]) -> dict[str, list[float]]: - return {k: [m[k] for m in list_of_dicts] for k in list_of_dicts[0].keys()} +logger = logging.getLogger(__name__) def state_to_rollout_output( @@ -97,3 +104,82 @@ def tools_key(tools: list | None) -> str: path_to_save=path_to_save, tools=tools, ) + + +def get_hf_hub_dataset_name(outputs: GenerateOutputs) -> str: + """Auto-generates a dataset name.""" + metadata = outputs["metadata"] + dataset_name = ( + metadata["env_id"] + + "_" + + metadata["model"].replace("/", "_") + + "_n" + + str(metadata["num_examples"]) + + "_r" + + str(metadata["rollouts_per_example"]) + ) + return dataset_name + + +def sanitize_rollouts(rollouts: list[RolloutOutput]) -> list[dict]: + """Sanitizes a list of rollouts before saving to disk.""" + + def sanitize_rollout(rollout: RolloutOutput) -> dict: + sanitized_rollout = dict(rollout) + sanitized_rollout["prompt"] = sanitize_tool_calls( + messages_to_printable(rollout["prompt"]) + ) + sanitized_rollout["completion"] = sanitize_tool_calls( + messages_to_printable(rollout["completion"]) + ) + sanitized_rollout["error"] = repr(rollout.get("error")) + return sanitized_rollout + + return [sanitize_rollout(rollout) for rollout in rollouts] + + +def sanitize_metadata(metadata: GenerateMetadata) -> dict: + """Sanitizes metadata before saving to disk.""" + + metadata_dict = dict(metadata) + metadata_dict.pop("path_to_save") + metadata_dict.pop("date") + + return metadata_dict + + +def save_to_disk(rollouts: list[dict], metadata: dict, path: Path): + """Saves (sanitized) rollouts and metadata to disk.""" + path.mkdir(parents=True, exist_ok=True) + + def save_results(results_list: list[dict], path: Path): + with open(path, "w") as f: + for result in results_list: + json.dump(result, f) + f.write("\n") + + def save_metadata(metadata_dict: dict, path: Path): + with open(path, "w") as f: + json.dump(metadata_dict, f) + + save_results(rollouts, path / "results.jsonl") + save_metadata(metadata, path / "metadata.json") + + +def save_generate_outputs( + outputs: GenerateOutputs, + push_to_hf_hub: bool = False, + hf_hub_dataset_name: str | None = None, +): + path_to_save = outputs["metadata"]["path_to_save"] + sanitized_rollouts = sanitize_rollouts(outputs["rollouts"]) + sanitized_metadata = sanitize_metadata(outputs["metadata"]) + save_to_disk(sanitized_rollouts, sanitized_metadata, path_to_save) + logger.info(f"Results saved to {path_to_save}") + if push_to_hf_hub: + dataset_name = hf_hub_dataset_name or get_hf_hub_dataset_name(outputs) + try: + Dataset.from_list(sanitized_rollouts).push_to_hub(dataset_name) + logger.info(f"Dataset saved to Hugging Face Hub: {dataset_name}") + except Exception as e: + logger.error(f"Error pushing dataset to Hugging Face Hub: {e}") From 51f3adc5e891ef5156b50c7a6ef9337ad3c1bd98 Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Sun, 25 Jan 2026 12:23:13 +0000 Subject: [PATCH 07/11] try..except around saving utils --- verifiers/utils/save_utils.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index f3dd80ea8..6e79c0b7c 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -152,18 +152,27 @@ def save_to_disk(rollouts: list[dict], metadata: dict, path: Path): """Saves (sanitized) rollouts and metadata to disk.""" path.mkdir(parents=True, exist_ok=True) - def save_results(results_list: list[dict], path: Path): - with open(path, "w") as f: - for result in results_list: - json.dump(result, f) - f.write("\n") + def save_results(results_list: list[dict], results_path: Path): + with open(results_path, "w") as f: + for idx, result in enumerate(results_list): + example_id = result.get("example_id") or "unknown" + try: + json.dump(result, f) + f.write("\n") + except Exception as e: + logger.error( + f"Failed to save rollout with index {idx} ({example_id=}): {e}" + ) + + def save_metadata(metadata_dict: dict, metadata_path: Path): + with open(metadata_path, "w") as f: + try: + json.dump(metadata_dict, f) + except Exception as e: + logger.error(f"Failed to save metadata: {e}") - def save_metadata(metadata_dict: dict, path: Path): - with open(path, "w") as f: - json.dump(metadata_dict, f) - - save_results(rollouts, path / "results.jsonl") save_metadata(metadata, path / "metadata.json") + save_results(rollouts, path / "results.jsonl") def save_generate_outputs( From c5da4463a620653e91b6db44e24bf6d9f84b0e9a Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Sun, 25 Jan 2026 12:24:43 +0000 Subject: [PATCH 08/11] flatten metrics --- verifiers/utils/save_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index 6e79c0b7c..55e0c4edc 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -133,6 +133,10 @@ def sanitize_rollout(rollout: RolloutOutput) -> dict: messages_to_printable(rollout["completion"]) ) sanitized_rollout["error"] = repr(rollout.get("error")) + rollout_metrics = rollout.get("metrics", {}) + for k, v in rollout_metrics.items(): + sanitized_rollout[k] = v + return sanitized_rollout return [sanitize_rollout(rollout) for rollout in rollouts] From 86201b2e0b3bb3a53a2b1b9945790d457217f6af Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Sun, 25 Jan 2026 12:26:18 +0000 Subject: [PATCH 09/11] pop answer and info if not present --- verifiers/utils/save_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index 55e0c4edc..9ef9db61f 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -133,6 +133,10 @@ def sanitize_rollout(rollout: RolloutOutput) -> dict: messages_to_printable(rollout["completion"]) ) sanitized_rollout["error"] = repr(rollout.get("error")) + if not rollout.get("answer"): + sanitized_rollout.pop("answer") + if not rollout.get("info"): + sanitized_rollout.pop("info") rollout_metrics = rollout.get("metrics", {}) for k, v in rollout_metrics.items(): sanitized_rollout[k] = v From 713acb948c6aa2493d78d9e2943ca4fb235a6795 Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Sun, 25 Jan 2026 12:28:21 +0000 Subject: [PATCH 10/11] some more comments --- verifiers/utils/save_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index 9ef9db61f..06710e8ee 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -24,6 +24,7 @@ def state_to_rollout_output( state: State, state_columns: list[str] = [] ) -> RolloutOutput: + """Converts a state to a rollout output. Adds state columns to the output.""" rollout_output = RolloutOutput( example_id=state.get("example_id", 0), prompt=state.get("prompt"), @@ -48,6 +49,7 @@ def state_to_rollout_output( def states_to_rollout_outputs( states: list[State], state_columns: list[str] = [] ) -> list[RolloutOutput]: + """Converts a list of states to a list of rollout outputs.""" return [state_to_rollout_output(state, state_columns) for state in states] @@ -62,6 +64,7 @@ def states_to_generate_metadata( start_time: float, results_path: Path | None, ) -> GenerateMetadata: + """Converts a list of states to generate metadata.""" base_url = str(client.base_url) if hasattr(client, "base_url") else "" rewards = [s.get("reward", 0.0) for s in states] avg_reward = sum(rewards) / len(rewards) if rewards else 0.0 @@ -126,17 +129,21 @@ def sanitize_rollouts(rollouts: list[RolloutOutput]) -> list[dict]: def sanitize_rollout(rollout: RolloutOutput) -> dict: sanitized_rollout = dict(rollout) + # sanitize messages sanitized_rollout["prompt"] = sanitize_tool_calls( messages_to_printable(rollout["prompt"]) ) sanitized_rollout["completion"] = sanitize_tool_calls( messages_to_printable(rollout["completion"]) ) + # use str repr for error sanitized_rollout["error"] = repr(rollout.get("error")) + # only include optional fields if present if not rollout.get("answer"): sanitized_rollout.pop("answer") if not rollout.get("info"): sanitized_rollout.pop("info") + # flatten metrics rollout_metrics = rollout.get("metrics", {}) for k, v in rollout_metrics.items(): sanitized_rollout[k] = v From 005b15f2e2c48a13c5f230d817536acf29e90110 Mon Sep 17 00:00:00 2001 From: Mika Senghaas Date: Sun, 25 Jan 2026 12:30:48 +0000 Subject: [PATCH 11/11] deprecate ProcessedOutputs --- docs/reference.md | 15 --------------- verifiers/types.py | 14 +------------- 2 files changed, 1 insertion(+), 28 deletions(-) diff --git a/docs/reference.md b/docs/reference.md index eaad7c95d..14a12db38 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -209,21 +209,6 @@ class RolloutScores(TypedDict): metrics: dict[str, list[float]] ``` -### ProcessedOutputs - -```python -class ProcessedOutputs(TypedDict): - prompt_ids: list[list[int]] - prompt_mask: list[list[int]] - completion_ids: list[list[int]] - completion_mask: list[list[int]] - completion_logprobs: list[list[float]] - rewards: list[float] - is_truncated: list[bool] -``` - -Tokenized outputs for training. - --- ## Classes diff --git a/verifiers/types.py b/verifiers/types.py index 5ce238f91..e1be1d3d2 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -176,7 +176,7 @@ class RolloutOutput(TypedDict): answer: str task: str info: Info - tools: dict + tools: list[ChatCompletionToolParam] | None reward: float metrics: dict[str, float] stop_condition: str | None @@ -206,18 +206,6 @@ class RolloutScores(TypedDict): metrics: dict[str, list[float]] -class ProcessedOutputs(TypedDict): - """TypedDict for processed outputs.""" - - prompt_ids: list[list[int]] - prompt_mask: list[list[int]] - completion_ids: list[list[int]] - completion_mask: list[list[int]] - completion_logprobs: list[list[float]] - rewards: list[float] - is_truncated: list[bool] - - Endpoint = TypedDict("Endpoint", {"key": str, "url": str, "model": str}) Endpoints = dict[str, Endpoint]