Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions aggregate_token_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import os
import json
import glob
import argparse
from collections import defaultdict

def deep_aggregate(target, source):
"""Recursively aggregates (sums numbers, extends lists) source into target."""
if not isinstance(source, dict):
return target

for k, v in source.items():
if k not in target:
if isinstance(v, dict):
target[k] = deep_aggregate({}, v)
elif isinstance(v, list):
target[k] = list(v)
elif isinstance(v, (int, float)):
target[k] = v
else:
assert False, f"Value is not a dict, list, int, or float: {v}"
else:
if isinstance(v, dict):
if not isinstance(target[k], dict):
target[k] = deep_aggregate({}, v)
else:
deep_aggregate(target[k], v)
elif isinstance(v, list):
if not isinstance(target[k], list):
target[k] = list(v)
else:
target[k].extend(v)
elif isinstance(v, (int, float)):
if not isinstance(target[k], (int, float)):
target[k] = v
else:
target[k] += v
else:
assert False, f"Value is not a dict, list, int, or float: {v}"
return target

def process_run_dir(directory, output_filename):
"""Processes directories containing run_*.json files."""
run_files = glob.glob(os.path.join(directory, "run_*.json"))
if not run_files:
return

print(f"Processing run directory: {directory} ({len(run_files)} files)")
qid_aggregated_stats = defaultdict(dict)

for file_path in run_files:
try:
with open(file_path, 'r') as f:
data = json.load(f)

qid = data.get("query_id") or data.get("qid")
token_stats_list = data.get("token_stats", [])

if qid is None:
continue

qid_str = str(qid)
target = qid_aggregated_stats[qid_str]
for turn_stats in token_stats_list:
if isinstance(turn_stats, dict):
deep_aggregate(target, turn_stats)
except Exception as e:
print(f" Error processing {file_path}: {e}")

write_output(qid_aggregated_stats, os.path.join(directory, output_filename))

def process_invocation_dir(directory, output_filename):
"""Processes directories containing invocation history JSON files."""
json_files = glob.glob(os.path.join(directory, "*.json"))
if not json_files:
return

json_files = [f for f in json_files if not f.endswith(output_filename)]
if not json_files:
return

print(f"Processing invocation history directory: {directory} ({len(json_files)} files)")

qid_to_history = {}

for file_path in sorted(json_files):
try:
with open(file_path, 'r') as f:
data = json.load(f)

assert isinstance(data, list), f"Data is not a list: {data}"
for entry in data:
assert isinstance(entry, dict), f"Entry is not a dict: {entry}"
query = entry.get("query", {})
qid = query.get("qid") or query.get("query_id")
assert qid is not None, f"QID is not found in {entry}"
history = entry.get("invocations_history")
assert history is not None, f"History is not found in {entry}"
qid_to_history[str(qid)] = history
except Exception as e:
print(f" Error processing {file_path}: {e}")

qid_aggregated_stats = defaultdict(dict)
for qid_str, history in qid_to_history.items():
assert isinstance(history, list), f"History is not a list: {history}"

target = qid_aggregated_stats[qid_str]
for item in history:
assert isinstance(item, dict), f"Item is not a dict: {item}"
token_usage = item.get("token_usage") or item.get("token_stats")
if token_usage and isinstance(token_usage, dict):
deep_aggregate(target, token_usage)

write_output(qid_aggregated_stats, os.path.join(directory, output_filename))

def strip_lists(data):
"""Recursively removes list values from a dictionary."""
if isinstance(data, dict):
return {k: strip_lists(v) for k, v in data.items() if not isinstance(v, list)}
return data

def write_output(qid_aggregated_stats, output_file):
if not qid_aggregated_stats:
return

# Aggregate across ALL queries for the summary line
all_stats = {}
for stats in qid_aggregated_stats.values():
deep_aggregate(all_stats, stats)
all_stats_stripped = strip_lists(all_stats)

with open(output_file, 'w') as f:
sorted_qids = sorted(qid_aggregated_stats.keys(), key=lambda x: int(x) if x.isdigit() else x)
for qid in sorted_qids:
line = {
"qid": qid,
"token_stats": qid_aggregated_stats[qid]
}
f.write(json.dumps(line) + "\n")

# Add a summary line that aggregates across all queries
# Note: Using "aid": "all" as requested
f.write(json.dumps({"aid": "all", "token_stats": all_stats_stripped}) + "\n")
print(f" Aggregated stats written to {output_file}")

def main():
parser = argparse.ArgumentParser(description="Aggregate token stats from run files and invocation histories.")
parser.add_argument("path", nargs="?", default=".", help="Root directory or specific directory to process.")
parser.add_argument("--output", "-o", default="aggregated_token_stats.jsonl", help="Output JSONL filename.")
parser.add_argument("--recursive", "-r", action="store_true", help="Recursively search for directories to process.")
args = parser.parse_args()

root_path = os.path.abspath(args.path)

if args.recursive:
for root, dirs, files in os.walk(root_path):
folder_name = os.path.basename(root).lower()
if "invocat" in folder_name and "history" in folder_name:
process_invocation_dir(root, args.output)
else:
if any(f.startswith("run_") and f.endswith(".json") for f in files):
# Only process run dir if it's not an invocation history dir's parent
# (Wait, actually, it's fine either way if they don't overlap in file patterns)
process_run_dir(root, args.output)
else:
folder_name = os.path.basename(root_path).lower()
if "invocat" in folder_name and "history" in folder_name:
process_invocation_dir(root_path, args.output)
else:
process_run_dir(root_path, args.output)

if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"python-dotenv>=1.1.1",
"qwen-agent[gui,mcp,rag]==0.0.27",
"qwen-omni-utils==0.0.8",
"rank-llm>=0.21.0",
"rich>=14.0.0",
"tevatron",
"torchvision",
Expand Down
4 changes: 2 additions & 2 deletions scripts_evaluation/evaluate_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,9 @@ def main():
output_dir = mirror_directory_structure(input_dir, eval_dir)
print(f"Evaluations will be saved to {output_dir}")

json_files = list(input_dir.glob("*.json"))
json_files = list(input_dir.glob("run_*.json"))
if not json_files:
print(f"No JSON files found in {input_dir}")
print(f"No JSON files starting with 'run_' found in {input_dir}")
return

print(f"Found {len(json_files)} JSON files to evaluate")
Expand Down
Loading