|
| 1 | +import json |
| 2 | +import numpy as np |
| 3 | +import argparse |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | + |
def load_jsonl(path):
    """Load a JSON Lines file into a list of decoded records.

    Args:
        path: Location of the JSONL file (``str`` or ``Path``).

    Returns:
        A list holding one decoded JSON value per non-blank line, in file
        order. An empty list is returned when the file does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is extended with the file path and line number.
    """
    records = []
    try:
        # JSON text is UTF-8 by specification; do not depend on the locale.
        handle = open(path, "r", encoding="utf-8")
    except FileNotFoundError:
        # A missing file is treated as "no data" rather than an error.
        return records

    with handle:
        for line_no, raw in enumerate(handle, start=1):
            text = raw.strip()
            if not text:
                # Blank lines (commonly a trailing newline) carry no record.
                continue
            try:
                records.append(json.loads(text))
            except json.JSONDecodeError as exc:
                # Keep the exception type but say where the bad line lives.
                raise json.JSONDecodeError(
                    f"{exc.msg} (in {path}, line {line_no})", exc.doc, exc.pos
                ) from exc
    return records
| 15 | + |
| 16 | + |
def compute_pck(gt_pose, pred_pose, threshold=0.05):
    """Compute the Percentage of Correct Keypoints (PCK) for one pose pair.

    A keypoint counts as correct when the Euclidean distance between the
    ground-truth and predicted coordinates is strictly below ``threshold``.

    NOTE: no normalisation is applied here. The threshold is compared
    directly against the raw coordinate distance, so both poses must
    already be expressed in normalised coordinates (e.g. relative to the
    bbox diagonal) for the default threshold to be meaningful.

    Args:
        gt_pose: Ground-truth keypoints, one coordinate row per keypoint.
        pred_pose: Predicted keypoints with the same shape as ``gt_pose``.
        threshold: Maximum (exclusive) distance for a keypoint to count as
            correct.

    Returns:
        Fraction of correct keypoints as a ``float`` in ``[0, 1]``.
        ``0.0`` is returned when either pose contains no keypoints.

    Raises:
        ValueError: If the two poses do not have the same shape.
    """
    gt = np.asarray(gt_pose, dtype=float)
    pred = np.asarray(pred_pose, dtype=float)

    # An empty pose has no keypoint that could be correct; returning early
    # also avoids an axis error on a 1-D empty array.
    if gt.size == 0 or pred.size == 0:
        return 0.0

    # Reject mismatched poses explicitly so NumPy broadcasting cannot turn
    # them into a meaningless score.
    if gt.shape != pred.shape:
        raise ValueError(
            f"pose shape mismatch: gt {gt.shape} vs pred {pred.shape}"
        )

    # Per-keypoint L2 distance between ground truth and prediction.
    distances = np.linalg.norm(gt - pred, axis=1)
    return float(np.mean(distances < threshold))
| 29 | + |
| 30 | + |
def run_benchmark(gt_path, pred_path, tolerance_ms=33):
    """Compare manual ground truth against perception output and print a report.

    Each ground-truth sample is paired with the player prediction that has
    the same ``track_id`` and the closest ``t_ms`` inside the tolerance
    window. Matched samples feed two metrics: ID accuracy (from
    ``id_verification`` samples) and average PCK (from samples where both
    sides carry a non-empty ``pose_2d``).

    Args:
        gt_path: JSONL file of ground-truth samples. Every sample must
            provide ``t_ms`` and ``track_id``.
        pred_path: JSONL file of predicted events. Only records whose
            ``kind`` is ``"player"`` are considered.
        tolerance_ms: Exclusive upper bound on ``|t_pred - t_gt|`` for a
            prediction to match. The default of 33 is roughly one frame
            at 30 fps.

    Returns:
        A dict with ``total_gt_samples``, ``matches_found``,
        ``id_accuracy`` and ``avg_pck``, or ``None`` when either input
        file is missing or empty.
    """
    print(f"[BENCHMARK] Comparing {gt_path} vs {pred_path}...")

    gt_data = load_jsonl(gt_path)
    pred_data = load_jsonl(pred_path)

    if not gt_data or not pred_data:
        print("[ERROR] Missing input data for benchmarking.")
        return None

    # Index player predictions by track id once, so each ground-truth
    # sample only scans its own track instead of every prediction.
    preds_by_track = {}
    for event in pred_data:
        if event.get("kind") != "player":
            continue
        if event.get("t_ms") is None or "track_id" not in event:
            # Player events without timing/identity cannot be matched.
            continue
        preds_by_track.setdefault(event["track_id"], []).append(event)

    stats = {
        "total_gt_samples": len(gt_data),
        "matches_found": 0,
        "id_accuracy": 0.0,
        "avg_pck": 0.0,
    }

    pck_scores = []
    id_checked = 0  # matched samples that are ID verifications
    id_correct = 0  # ...of which the annotator marked the ID as correct

    for gt in gt_data:
        t_ms = gt["t_ms"]

        # For now, track_id equality is trusted for the audit; among the
        # predictions of that track pick the one closest in time.
        candidates = [
            p
            for p in preds_by_track.get(gt["track_id"], ())
            if abs(p["t_ms"] - t_ms) < tolerance_ms
        ]
        if not candidates:
            continue
        match = min(candidates, key=lambda p: abs(p["t_ms"] - t_ms))

        stats["matches_found"] += 1

        if gt.get("type") == "id_verification":
            id_checked += 1
            if gt.get("label") == "correct":
                id_correct += 1

        # PCK needs a usable pose on both sides; null/empty poses are skipped.
        gt_pose = gt.get("pose_2d")
        pred_pose = match.get("pose_2d")
        if gt_pose and pred_pose:
            pck_scores.append(compute_pck(gt_pose, pred_pose))

    # Only ID-verification samples can be right or wrong about identity,
    # so they alone form the denominator of the accuracy.
    if id_checked > 0:
        stats["id_accuracy"] = id_correct / id_checked
    if pck_scores:
        stats["avg_pck"] = float(np.mean(pck_scores))

    print("\n--- PERCEPTION BENCHMARK REPORT ---")
    print(f"GT Samples: {stats['total_gt_samples']}")
    print(f"Matches: {stats['matches_found']}")
    print(f"ID Acc: {stats['id_accuracy']:.2%}")
    print(f"Avg PCK: {stats['avg_pck']:.2%}")
    print("-----------------------------------\n")

    return stats
| 93 | + |
| 94 | + |
if __name__ == "__main__":
    # Command-line entry point: both inputs are optional flags that fall
    # back to the default ground-truth and prediction JSONL locations.
    cli = argparse.ArgumentParser()
    for flag, fallback in (
        ("--gt", "data/training/manual_gt.jsonl"),
        ("--pred", "data/intelligent_game_dna.jsonl"),
    ):
        cli.add_argument(flag, default=fallback)
    options = cli.parse_args()

    run_benchmark(options.gt, options.pred)
0 commit comments