diff --git a/pm2d/eval.py b/pm2d/eval.py
new file mode 100644
index 0000000..ddad80d
--- /dev/null
+++ b/pm2d/eval.py
@@ -0,0 +1,289 @@
+"""CLI validation harness for LineShapeMatcher.
+
+Usage:
+    python -m pm2d.eval dataset.json [options]
+
+Dataset format (JSON):
+    {
+      "template": "path/to/template.png",
+      "mask": "path/to/mask.png",          # optional
+      "params": {                          # optional, overrides for matcher init
+        "use_polarity": true,
+        "angle_step_deg": 5,
+        ...
+      },
+      "find_params": {                     # optional, forwarded to find()
+        "min_score": 0.6,
+        "use_soft_score": true,
+        ...
+      },
+      "scenes": [
+        {
+          "image": "path/to/scene1.png",
+          "ground_truth": [
+            {"cx": 320.0, "cy": 240.0, "angle_deg": 12.0,
+             "scale": 1.0, "tolerance_px": 5.0,
+             "tolerance_deg": 3.0}
+          ]
+        }
+      ]
+    }
+
+Paths in the dataset are resolved relative to the dataset file's directory.
+
+Output: precision/recall/IoU/timing report for every scene, plus aggregates.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import math
+import sys
+import time
+from pathlib import Path
+
+import cv2
+import numpy as np
+
+from pm2d.line_matcher import LineShapeMatcher, _poly_iou, _oriented_bbox_polygon
+
+
+def _load_image(path: str | Path) -> np.ndarray:
+    """Read an image from disk and return it as a 3-channel BGR array.
+
+    Grayscale files are expanded to BGR and the alpha plane of 4-channel
+    files is dropped; 3-channel files are returned as read.
+
+    Raises:
+        FileNotFoundError: if OpenCV cannot read ``path``.
+    """
+    img = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
+    if img is None:
+        raise FileNotFoundError(f"Immagine non trovata: {path}")
+    if img.ndim == 2:
+        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+    if img.shape[2] == 4:
+        # IMREAD_UNCHANGED keeps the alpha channel (e.g. RGBA PNGs); strip it
+        # so downstream code always receives three channels.
+        return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
+    return img
+
+
+def _gt_to_poly(gt: dict, tw: int, th: int) -> np.ndarray:
+    """Build the oriented bounding-box polygon of a ground-truth entry.
+
+    The template size ``tw`` x ``th`` is multiplied by the entry's optional
+    ``scale`` (default 1.0).
+    """
+    s = float(gt.get("scale", 1.0))
+    return _oriented_bbox_polygon(
+        float(gt["cx"]), float(gt["cy"]),
+        tw * s, th * s, float(gt["angle_deg"]),
+    )
+
+
+def _match_to_gt(match, gt: dict, tw: int, th: int,
+                 iou_thr: float = 0.3) -> bool:
+    """Return True when ``match`` corresponds to the ground truth ``gt``.
+
+    A match is accepted when either
+      * its centre is within ``tolerance_px`` (default 5.0) of the GT centre
+        AND its angle is within ``tolerance_deg`` (default 3.0), or
+      * the IoU of the two bounding polygons is >= ``iou_thr`` (fallback for
+        poses with wide tolerances).
+    """
+    tol_px = float(gt.get("tolerance_px", 5.0))
+    tol_deg = float(gt.get("tolerance_deg", 3.0))
+    dist = math.hypot(match.cx - float(gt["cx"]), match.cy - float(gt["cy"]))
+    # Wrap the angular difference into [-180, 180) before taking its magnitude.
+    da = abs((match.angle_deg - float(gt["angle_deg"]) + 180) % 360 - 180)
+    if dist <= tol_px and da <= tol_deg:
+        return True
+    # IoU fallback
+    poly_gt = _gt_to_poly(gt, tw, th)
+    return bool(_poly_iou(match.bbox_poly, poly_gt) >= iou_thr)
+
+
+def evaluate_scene(matcher: LineShapeMatcher, scene_bgr: np.ndarray,
+                   gt_list: list[dict], find_params: dict,
+                   tw: int, th: int) -> dict:
+    """Run ``matcher.find`` on one scene and compute TP/FP/FN.
+
+    Matches are visited in the order returned by ``find``; each one claims
+    the first still-unclaimed ground truth it agrees with, so every ground
+    truth is counted at most once.
+
+    Returns:
+        dict with ``n_matches``, ``n_gt``, ``tp``, ``fp``, ``fn``,
+        ``find_time_s``, ``iou_mean`` (mean IoU over true positives, 0.0 when
+        there are none) and ``diag`` (``matcher.get_last_diag()`` if the
+        matcher provides it, else None).
+    """
+    # perf_counter is monotonic, so the measured interval cannot be skewed
+    # by wall-clock adjustments the way time.time() can.
+    t0 = time.perf_counter()
+    matches = matcher.find(scene_bgr, **find_params)
+    elapsed = time.perf_counter() - t0
+
+    gt_matched = [False] * len(gt_list)
+    tp_ious: list[float] = []
+    for m in matches:
+        for j, gt in enumerate(gt_list):
+            if gt_matched[j] or not _match_to_gt(m, gt, tw, th):
+                continue
+            gt_matched[j] = True
+            # IoU is recorded for the report even when the pose test matched.
+            tp_ious.append(_poly_iou(m.bbox_poly, _gt_to_poly(gt, tw, th)))
+            break
+    tp = len(tp_ious)
+    return {
+        "n_matches": len(matches),
+        "n_gt": len(gt_list),
+        "tp": tp,
+        "fp": len(matches) - tp,
+        "fn": len(gt_list) - sum(gt_matched),
+        "find_time_s": elapsed,
+        "iou_mean": float(np.mean(tp_ious)) if tp_ious else 0.0,
+        "diag": (matcher.get_last_diag()
+                 if hasattr(matcher, "get_last_diag") else None),
+    }
+
+
+def run(dataset_path: str, scene_filter: str | None = None,
+        verbose: bool = False) -> dict:
+    """Evaluate every scene of a dataset and return the aggregate report.
+
+    Args:
+        dataset_path: path of the JSON dataset (see the module docstring).
+        scene_filter: when set, only scenes whose ``image`` contains this
+            substring are evaluated.
+        verbose: also print the matcher diagnostics of each scene when the
+            matcher exposes them.
+
+    Returns:
+        dict with ``precision``, ``recall``, ``f1``, ``tp``, ``fp``, ``fn``,
+        ``total_time_s``, ``n_scenes`` and ``per_scene`` (one
+        ``evaluate_scene`` result per scene plus its ``scene`` name).
+
+    Raises:
+        FileNotFoundError: if the template or a scene image cannot be read.
+    """
+    ds_path = Path(dataset_path)
+    base = ds_path.parent
+    # JSON is UTF-8 by specification; do not depend on the locale encoding.
+    with open(ds_path, encoding="utf-8") as f:
+        ds = json.load(f)
+
+    template = _load_image(base / ds["template"])
+    mask = None
+    if ds.get("mask"):
+        mask_path = base / ds["mask"]
+        mask_img = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
+        if mask_img is None:
+            # Training still proceeds without a mask, but an unreadable mask
+            # changes the results, so it must not go unnoticed.
+            print(f"ATTENZIONE: maschera non leggibile, ignorata: {mask_path}",
+                  file=sys.stderr)
+        else:
+            mask = (mask_img > 128).astype(np.uint8) * 255
+    init_params = ds.get("params", {})
+    find_params = ds.get("find_params", {})
+
+    matcher = LineShapeMatcher(**init_params)
+    n_var = matcher.train(template, mask=mask)
+    tw, th = matcher.template_size
+    print(f"Template: {ds['template']} ({tw}x{th}), {n_var} varianti")
+    print(f"Param matcher: {init_params}")
+    print(f"Param find: {find_params}")
+    print()
+
+    scenes = ds["scenes"]
+    if scene_filter:
+        scenes = [s for s in scenes if scene_filter in s["image"]]
+
+    rows = []
+    tot_tp = tot_fp = tot_fn = 0
+    tot_time = 0.0
+    for sc in scenes:
+        scene = _load_image(base / sc["image"])
+        gt = sc.get("ground_truth", [])
+        result = evaluate_scene(matcher, scene, gt, find_params, tw, th)
+        rows.append({"scene": sc["image"], **result})
+        tot_tp += result["tp"]
+        tot_fp += result["fp"]
+        tot_fn += result["fn"]
+        tot_time += result["find_time_s"]
+        # max(1, ...) keeps the ratios defined (as 0) when a denominator is 0.
+        prec = result["tp"] / max(1, result["tp"] + result["fp"])
+        rec = result["tp"] / max(1, result["tp"] + result["fn"])
+        line = (f" {sc['image']:30s} "
+                f"TP={result['tp']} FP={result['fp']} FN={result['fn']} "
+                f"P={prec:.2f} R={rec:.2f} "
+                f"IoU={result['iou_mean']:.2f} "
+                f"t={result['find_time_s']*1000:.0f}ms")
+        print(line)
+        if verbose and result["diag"] and hasattr(matcher, "_format_diag"):
+            print(f" diag: {matcher._format_diag(result['diag'])}")
+
+    # Aggregates
+    precision = tot_tp / max(1, tot_tp + tot_fp)
+    recall = tot_tp / max(1, tot_tp + tot_fn)
+    f1 = 2 * precision * recall / max(1e-9, precision + recall)
+    print()
+    print(f"AGGREGATO: precision={precision:.3f} recall={recall:.3f} "
+          f"F1={f1:.3f} TP={tot_tp} FP={tot_fp} FN={tot_fn}")
+    print(f"TIME: total={tot_time:.2f}s avg={tot_time / max(1, len(scenes)) * 1000:.0f}ms/scene")
+
+    return {
+        "precision": precision, "recall": recall, "f1": f1,
+        "tp": tot_tp, "fp": tot_fp, "fn": tot_fn,
+        "total_time_s": tot_time, "n_scenes": len(scenes),
+        "per_scene": rows,
+    }
+
+
+def _json_default(obj):
+    """Convert NumPy scalars/arrays for ``json.dump``; reject anything else."""
+    if isinstance(obj, np.generic):
+        return obj.item()
+    if isinstance(obj, np.ndarray):
+        return obj.tolist()
+    raise TypeError(
+        f"Object of type {type(obj).__name__} is not JSON serializable"
+    )
+
+
+def main(argv: list[str] | None = None) -> int:
+    """Command-line entry point.
+
+    Parses ``argv`` (``sys.argv[1:]`` when None), runs the evaluation and
+    optionally writes the report as JSON.
+
+    Returns:
+        Process exit code: 0 when the aggregate F1 is above 0.5, else 1.
+    """
+    p = argparse.ArgumentParser(
+        description="pm2d-eval: validation harness per LineShapeMatcher"
+    )
+    p.add_argument("dataset", help="JSON dataset (template + scenes + GT)")
+    p.add_argument("--scene-filter", default=None,
+                   help="Filtro substring sui nomi scena (debug)")
+    p.add_argument("--verbose", "-v", action="store_true",
+                   help="Stampa diag dict per ogni scena")
+    p.add_argument("--out", default=None,
+                   help="Salva report JSON su file")
+    args = p.parse_args(argv)
+    report = run(args.dataset, scene_filter=args.scene_filter,
+                 verbose=args.verbose)
+    if args.out:
+        with open(args.out, "w", encoding="utf-8") as f:
+            # ``diag`` comes from the matcher and may hold NumPy values,
+            # which the stock encoder rejects.
+            json.dump(report, f, indent=2, default=_json_default)
+        print(f"Report salvato: {args.out}")
+    return 0 if report["f1"] > 0.5 else 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/pyproject.toml b/pyproject.toml
index 14ea27c..3849c69 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,9 @@ dependencies = [
     "uvicorn[standard]>=0.34",
 ]
 
+[project.scripts]
+pm2d-eval = "pm2d.eval:main"
+
 [dependency-groups]
 dev = [
     "httpx>=0.28.1",