forked from kristaller486/RuQualBench
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_judge.py
More file actions
292 lines (244 loc) · 10.9 KB
/
evaluate_judge.py
File metadata and controls
292 lines (244 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import argparse
import asyncio
import json
import logging
import os
import random
import statistics
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import litellm
from dotenv import load_dotenv
from tqdm.asyncio import tqdm as atqdm

# Import BenchmarkV1
from benchmark.v1 import BenchmarkV1
# Load .env first so every setting below (including LITELLM_SSL_VERIFY) can come from it.
load_dotenv()

# Configure litellm.
# TLS certificate verification stays disabled by default (the previous hard-coded
# behaviour); set LITELLM_SSL_VERIFY to 1/true/yes/on to re-enable it.
# NOTE(review): with verification off, judge API traffic is open to MITM — consider
# flipping the default once deployments no longer depend on it.
litellm.ssl_verify = os.getenv("LITELLM_SSL_VERIFY", "false").strip().lower() in ("1", "true", "yes", "on")

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

# Judge whose logged grades serve as the reference; matched as a case-insensitive substring.
REFERENCE_JUDGE_MODEL = "gemini-2.5-pro"
# Seed for the reproducible random choice of samples.
SEED = 42
# Number of samples re-graded by the judge under evaluation.
SAMPLE_SIZE = 100
# Maximum number of matching log files that are scanned for samples.
LOGS_LIMIT = 100
def get_reference_samples() -> List[Dict[str, Any]]:
    """
    Collect samples graded by the reference judge for re-evaluation.

    Scans ``logs/benchmark_*.json`` in filename order and keeps the first
    ``LOGS_LIMIT`` logs whose ``config.judge_model`` contains
    ``REFERENCE_JUDGE_MODEL`` (case-insensitive). From those logs, every result
    that has an answer and no error becomes a sample. Each sample holds the
    dialog, the answer, and the reference judge's scores and explanations.

    Returns:
        ``SAMPLE_SIZE`` samples chosen reproducibly with ``SEED``, or every
        collected sample (with a warning) when fewer are available.

    Raises:
        FileNotFoundError: if the ``logs`` directory does not exist.
        ValueError: if no log was graded by the reference judge.
    """
    logs_dir = Path("logs")
    if not logs_dir.exists():
        raise FileNotFoundError("Logs directory not found")

    # Filenames look like benchmark_YYYY-MM-DD_HH-MM-SS_run_X_dataset.json, so
    # sorting by name is chronological and the oldest logs are taken first.
    log_files = sorted(logs_dir.glob("benchmark_*.json"), key=lambda p: p.name)

    reference = REFERENCE_JUDGE_MODEL.lower()
    # (config, log data) pairs for logs graded by the reference judge.
    valid_logs = []
    logger.info(f"Scanning logs for judge '{REFERENCE_JUDGE_MODEL}'...")
    for log_file in log_files:
        try:
            with open(log_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # `or` defaults also cover explicit nulls in the log. Those would otherwise
            # raise AttributeError and be reported as a misleading read error.
            config = data.get("config") or {}
            judge_model = config.get("judge_model") or ""
            # Substring match: recorded names may carry provider prefixes/suffixes
            # or be spelled slightly differently by the user.
            if reference in judge_model.lower():
                valid_logs.append((config, data))
                if len(valid_logs) >= LOGS_LIMIT:
                    break
        except Exception as e:
            # Best-effort scan: unreadable or malformed logs are skipped.
            logger.warning(f"Error reading {log_file}: {e}")

    logger.info(f"Found {len(valid_logs)} logs matching the reference judge.")
    if not valid_logs:
        raise ValueError(f"No logs found with judge model '{REFERENCE_JUDGE_MODEL}'")

    all_samples = []
    for config, log_data in valid_logs:
        for res in log_data.get("results") or []:
            # Only successfully answered dialogs can be re-judged.
            if res.get("error") is not None or res.get("answer") is None:
                continue
            all_samples.append({
                "dialog_id": res.get("dialog_id"),
                "dialog": res.get("dialog"),
                "answer": res.get("answer"),
                "original_judge_model": config.get("judge_model"),
                "original_scores": {
                    "critical_mistakes": res.get("critical_mistakes", 0),
                    "mistakes": res.get("mistakes", 0),
                    "additional_mistakes": res.get("additional_mistakes", 0),
                    "explanation_critical_mistakes": res.get("explanation_critical_mistakes", []),
                    "explanation_mistakes": res.get("explanation_mistakes", []),
                    "explanation_additional_mistakes": res.get("explanation_additional_mistakes", []),
                },
                "source_log": config.get("timestamp"),
            })

    logger.info(f"Collected {len(all_samples)} valid samples from these logs.")
    if len(all_samples) < SAMPLE_SIZE:
        logger.warning(f"Available samples ({len(all_samples)}) is less than requested size ({SAMPLE_SIZE}). Using all available.")
        return all_samples

    # A private seeded generator yields the same selection as seeding the module RNG,
    # without resetting the process-wide random state that other code relies on.
    return random.Random(SEED).sample(all_samples, SAMPLE_SIZE)
async def evaluate_sample(benchmark: BenchmarkV1, sample: Dict[str, Any], semaphore: asyncio.Semaphore) -> Dict[str, Any]:
    """
    Re-grade one reference sample with the judge configured on ``benchmark``.

    Calls ``benchmark._judge_answer(dialog, answer, semaphore)``, then compares
    its mistake counts with the reference judge's.

    Args:
        benchmark: Benchmark instance configured with the judge under evaluation.
        sample: A sample as built by ``get_reference_samples``. Uses the keys
            ``dialog``, ``answer``, ``dialog_id`` and ``original_scores``.
        semaphore: Passed through to ``_judge_answer``; shared by all samples.

    Returns:
        A dict with these keys:

        - ``sample_id``.
        - ``original_scores``.
        - ``new_scores``: the three counts plus the three explanation lists.
        - ``diff``: new minus original, for each count.
        - ``error``.

        On any exception, ``new_scores`` and ``diff`` are None and ``error``
        holds the exception text; the exception is not re-raised.
    """
    count_keys = ("critical_mistakes", "mistakes", "additional_mistakes")
    explanation_keys = (
        "explanation_critical_mistakes",
        "explanation_mistakes",
        "explanation_additional_mistakes",
    )
    try:
        # Reuses BenchmarkV1's private judging routine so grading matches the benchmark.
        judge_result = await benchmark._judge_answer(
            sample["dialog"],
            sample["answer"],
            semaphore
        )
        original = sample["original_scores"]
        new_scores = {key: judge_result[key] for key in count_keys + explanation_keys}
        diff = {key: judge_result[key] - original[key] for key in count_keys}
        return {
            # dialog_id is used as the ID, though it might not be unique across datasets.
            "sample_id": sample["dialog_id"],
            "original_scores": original,
            "new_scores": new_scores,
            "diff": diff,
            "error": None
        }
    except Exception as e:
        # Include the sample id so a failure can be traced back to its dialog.
        logger.error(f"Error evaluating sample {sample.get('dialog_id')}: {e}")
        return {
            "sample_id": sample["dialog_id"],
            "original_scores": sample["original_scores"],
            "new_scores": None,
            "diff": None,
            "error": str(e)
        }
async def _evaluate_all(benchmark: "BenchmarkV1", samples: List[Dict[str, Any]], semaphore: asyncio.Semaphore) -> List[Dict[str, Any]]:
    """Grade all samples concurrently with a progress bar; results keep the order of ``samples``."""
    async def _indexed(index: int, sample: Dict[str, Any]):
        # Tag each result with its position: as_completed yields in completion order.
        return index, await evaluate_sample(benchmark, sample, semaphore)

    results: List[Any] = [None] * len(samples)
    tasks = [_indexed(i, sample) for i, sample in enumerate(samples)]
    for coro in atqdm.as_completed(tasks, desc="Evaluating", total=len(tasks)):
        index, res = await coro
        results[index] = res
    return results


def _aggregate_diffs(valid_results: List[Dict[str, Any]]):
    """
    Summarize the score differences of successful evaluations.

    Returns a pair ``(total_diff, comparison_stats)``:

    - ``total_diff`` maps each score field to the summed diff (new minus reference).
    - ``comparison_stats`` maps each short label to counts of samples where the
      new judge found more, fewer ("less") or the same number of errors.
    """
    # Short report label -> score field name used in each result's "diff" dict.
    fields = {
        "critical": "critical_mistakes",
        "mistakes": "mistakes",
        "additional": "additional_mistakes",
    }
    total_diff = {field: 0 for field in fields.values()}
    comparison_stats = {label: {"more": 0, "less": 0, "same": 0} for label in fields}
    for r in valid_results:
        for label, field in fields.items():
            diff = r["diff"][field]
            total_diff[field] += diff
            if diff > 0:
                comparison_stats[label]["more"] += 1
            elif diff < 0:
                comparison_stats[label]["less"] += 1
            else:
                comparison_stats[label]["same"] += 1
    return total_diff, comparison_stats


def _print_report(judge_model: str, samples_count: int, total_diff: Dict[str, Any], comparison_stats: Dict[str, Dict[str, int]]) -> None:
    """Print the human-readable comparison report to stdout."""
    print("\n" + "="*60)
    print(f"JUDGE EVALUATION REPORT: {judge_model}")
    print("="*60)
    print(f"Reference Judge: {REFERENCE_JUDGE_MODEL}")
    print(f"Samples: {samples_count}")
    print("-" * 60)
    print("DIFFERENCES (New Judge - Reference Judge):")
    print(f"Total Critical Mistakes Diff: {total_diff['critical_mistakes']:+d}")
    print(f"Total Mistakes Diff: {total_diff['mistakes']:+d}")
    print(f"Total Additional Mistakes Diff: {total_diff['additional_mistakes']:+d}")
    print("-" * 60)
    print("DETAILED COMPARISON (Count of samples):")
    print(f"{'Type':<15} | {'More Errors':<12} | {'Less Errors':<12} | {'Same':<12}")
    print("-" * 60)
    for key in ["critical", "mistakes", "additional"]:
        stats = comparison_stats[key]
        print(f"{key.capitalize():<15} | {stats['more']:<12} | {stats['less']:<12} | {stats['same']:<12}")
    print("="*60)


async def run_evaluation(judge_model: str, extra_body: Optional[Dict[str, Any]] = None):
    """
    Re-grade the reference samples with ``judge_model`` and compare against the reference judge.

    Steps:

    1. Collect samples via ``get_reference_samples``.
    2. Grade each sample with ``BenchmarkV1``. Concurrency is bounded by the
       ``JUDGE_MODEL_MAX_WORKERS`` env var (default 10).
    3. Aggregate the differences (new judge minus reference).
    4. Write a JSON report to ``judge_evals/`` and print a summary table.

    Args:
        judge_model: Model name passed to ``BenchmarkV1`` as the judge.
        extra_body: Optional extra request body forwarded to ``BenchmarkV1``.

    If sample collection fails, the error is logged and the function returns
    without evaluating anything or writing a report.
    """
    # 1. Get reference samples
    logger.info("Step 1: Collecting reference samples...")
    try:
        samples = get_reference_samples()
    except Exception as e:
        logger.error(f"Failed to collect samples: {e}")
        return

    # 2. Setup evaluation
    # BenchmarkV1 is created only for its judge configuration and judging method.
    benchmark = BenchmarkV1(
        dataset_name="lite",  # Placeholder, not used for single sample evaluation
        judge_model_name=judge_model,
        extra_body=extra_body
    )
    judge_max_workers = int(os.getenv("JUDGE_MODEL_MAX_WORKERS", "10"))
    semaphore = asyncio.Semaphore(judge_max_workers)

    logger.info(f"Step 2: Evaluating {len(samples)} samples with judge '{judge_model}'...")
    results = await _evaluate_all(benchmark, samples, semaphore)

    # 3. Analyze results
    valid_results = [r for r in results if r["error"] is None]
    failed_count = len(results) - len(valid_results)
    if failed_count > 0:
        logger.warning(f"{failed_count} evaluations failed.")
    total_diff, comparison_stats = _aggregate_diffs(valid_results)

    # One timestamp, so the summary and the output filename always agree.
    now = datetime.now()
    summary = {
        "judge_model": judge_model,
        "reference_judge": REFERENCE_JUDGE_MODEL,
        "samples_count": len(samples),
        "successful_evals": len(valid_results),
        "total_diff": total_diff,
        "comparison_stats": comparison_stats,
        "timestamp": now.isoformat()
    }

    # 4. Save results
    output_dir = Path("judge_evals")
    output_dir.mkdir(exist_ok=True)
    timestamp_str = now.strftime("%Y-%m-%d_%H-%M-%S")
    # Model names such as "provider/model:tag" are not valid filename parts.
    safe_model_name = judge_model.replace("/", "_").replace(":", "_")
    output_file = output_dir / f"judge_eval_{timestamp_str}_{safe_model_name}.json"
    output_data = {
        "summary": summary,
        "detailed_results": results
    }
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, ensure_ascii=False, indent=2)
    logger.info(f"Results saved to {output_file}")

    _print_report(judge_model, len(samples), total_diff, comparison_stats)
def main():
    """
    Command-line entry point: parse the arguments and run the judge evaluation.

    Usage: ``evaluate_judge.py JUDGE_MODEL [--extra-body JSON]``.

    ``--extra-body`` must be a JSON object. Invalid JSON, or JSON that is not
    an object, aborts through ``argparse``: usage goes to stderr and the exit
    status is 2, instead of exiting silently with status 0.
    """
    parser = argparse.ArgumentParser(description="Evaluate a judge model against a reference judge")
    parser.add_argument("judge_model", type=str, help="Name of the judge model to evaluate")
    parser.add_argument("--extra-body", type=str, help="JSON string for extra_body parameter")
    args = parser.parse_args()

    extra_body = None
    if args.extra_body:
        try:
            extra_body = json.loads(args.extra_body)
        except json.JSONDecodeError as e:
            parser.error(f"Invalid JSON for --extra-body: {e}")  # raises SystemExit(2)
        # extra_body is merged into the request body, so only a JSON object makes sense.
        if not isinstance(extra_body, dict):
            parser.error("--extra-body must be a JSON object")

    asyncio.run(run_evaluation(args.judge_model, extra_body))


if __name__ == "__main__":
    main()