diff --git a/pm2d/line_matcher.py b/pm2d/line_matcher.py index a5195df..ece8db4 100644 --- a/pm2d/line_matcher.py +++ b/pm2d/line_matcher.py @@ -49,6 +49,27 @@ from pm2d._jit_kernels import ( N_BINS = 8 # orientamenti quantizzati modulo π +def _poly_iou(p1: np.ndarray, p2: np.ndarray) -> float: + """IoU tra due poligoni convessi (4 vertici, float32) via cv2.intersectConvexConvex. + + Usa OpenCV (cv2.intersectConvexConvex) per intersezione esatta: + ritorna area intersezione / area unione. Robusto a rotazioni + qualsiasi (anti-orarie/orarie) - cv2 normalizza orientamento. + """ + a1 = float(cv2.contourArea(p1)) + a2 = float(cv2.contourArea(p2)) + if a1 <= 0 or a2 <= 0: + return 0.0 + inter_area, _ = cv2.intersectConvexConvex( + p1.astype(np.float32), p2.astype(np.float32), + ) + inter_area = float(inter_area) + if inter_area <= 0: + return 0.0 + union = a1 + a2 - inter_area + return inter_area / union if union > 0 else 0.0 + + def _oriented_bbox_polygon( cx: float, cy: float, w: float, h: float, angle_deg: float, ) -> np.ndarray: @@ -762,6 +783,7 @@ class LineShapeMatcher: refine_pose_joint: bool = False, greediness: float = 0.0, batch_top: bool = False, + nms_iou_threshold: float = 0.3, ) -> list[Match]: """ scale_penalty: se > 0, riduce lo score per match a scala diversa da 1.0: @@ -1122,12 +1144,21 @@ class LineShapeMatcher: score_f = float(score_f) * max( 0.0, 1.0 - scale_penalty * abs(var.scale - 1.0) ) - # NMS post-refine: refine puo spostare il match di nms_radius; - # ricontrollo overlap su match gia accettati per evitare - # duplicati (stesso oggetto trovato da varianti angolari diverse). + # NMS post-refine cross-variant: usa IoU bbox-poligonale invece + # di sola distanza centro. Due match orientati diversi ma vicini + # (pezzi adiacenti) NON vengono fusi se l'overlap reale e basso; + # due match dello stesso pezzo (centri uguali, rotazione simile) + # hanno IoU alto e vengono droppati. + # Fallback distanza centro per match con bbox degenere. dup = False for k in kept: - if (k.cx - cx_out) ** 2 + (k.cy - cy_out) ** 2 < r2: + iou = _poly_iou(k.bbox_poly, poly) + if iou > nms_iou_threshold: + dup = True + break + # Sicurezza: centri molto vicini (dentro nms_radius/2) + # sempre fusi, anche con orientamenti molto diversi. + if (k.cx - cx_out) ** 2 + (k.cy - cy_out) ** 2 < (r2 / 4.0): dup = True break if dup: