cli.py
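"""Command-line entry point for PromptGuard.

Runs a suite of prompt YAML files against a set of model configurations and
writes a pass/fail report with semantic, tone, and safety regressions to
promptguard/report/report.json.

Example invocation (paths are illustrative):
    python cli.py --suite prompts/ --configs model_configs.json
"""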

import argparse
import json
import os

import yaml

from promptguard.runners.run_models import run_model
from promptguard.diff.semantic_diff import check_semantic_drift
from promptguard.diff.tone_diff import check_tone_drift
from promptguard.diff.safety_diff import check_safety_regression


def main():
    parser = argparse.ArgumentParser(description="PromptGuard: CI for LLM behavior.")
    parser.add_argument("--suite", type=str, required=True, help="Path to the directory containing prompt YAML files.")
    parser.add_argument("--configs", type=str, required=True, help="Path to the model configurations JSON file.")
    args = parser.parse_args()

    print("Running PromptGuard with:")
    print(f" Prompt Suite: {args.suite}")
    print(f" Model Configs: {args.configs}")
    prompts = []
    yaml_files = [f for f in os.listdir(args.suite) if f.endswith(".yaml")]
    for filename in yaml_files:
        filepath = os.path.join(args.suite, filename)
        with open(filepath, 'r') as f:
            prompts.extend(yaml.safe_load(f))
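
    # Load the model configurations the suite will be run against.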
    with open(args.configs, 'r') as f:
        model_configs = json.load(f)

    print(f"Loaded {len(prompts)} prompts from {len(yaml_files)} files.")
    print(f"Loaded {len(model_configs)} model configurations.")
    final_report = []
    for prompt_data in prompts:
        prompt_id = prompt_data["id"]
        prompt_text = prompt_data["prompt"]
        expected = prompt_data["expected"]
        for config in model_configs:
            model_output = run_model(prompt_text, config)
            regressions = []
            regressions.extend(check_semantic_drift(expected, model_output))

            # Debugging tone drift
            if prompt_id == "tax_explanation_v1" and "tone" in expected:
                from promptguard.diff.tone_diff import classify_tone  # Import here to avoid circular dependency
                expected_tone = expected["tone"]
                actual_tone = classify_tone(model_output)
                print(f"\nDEBUG: Prompt ID: {prompt_id}, Config: {config}")
                print(f" Expected Tone: '{expected_tone}'")
                print(f" Actual Tone: '{actual_tone}'")
                print(f" Model Output: '{model_output}'")

            regressions.extend(check_tone_drift(expected, model_output))
            regressions.extend(check_safety_regression(expected, model_output))
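
            # A prompt/config pair passes only if no check reported a regression.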
status = "PASS" if not regressions else "FAIL"
final_report.append({
"prompt_id": prompt_id,
"config": config,
"output": model_output, # Include output for debugging/context
"regressions": regressions,
"status": status
})
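
    # Persist the full report as JSON so it can be inspected or archived in CI.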
    report_path = "promptguard/report/report.json"
    os.makedirs(os.path.dirname(report_path), exist_ok=True)
    with open(report_path, 'w') as f:
        json.dump(final_report, f, indent=2)
    print(f"\nReport generated at {report_path}")

    # Optionally print a summary of failures
    failures = [item for item in final_report if item["status"] == "FAIL"]
    if failures:
        print(f"\n--- {len(failures)} Failing Tests ---")
        for fail in failures:
            print(f"Prompt ID: {fail['prompt_id']}, Config: {fail['config']}, Status: {fail['status']}")
            for reg in fail['regressions']:
                print(f" - {reg['type']}: {reg['reason']} (Severity: {reg['severity']})")
    else:
        print("\nAll tests passed!")


if __name__ == "__main__":
    main()