merge: greediness (greedy kernel as an alternative to strided rescore)

This commit is contained in:
2026-05-04 15:45:15 +02:00
2 changed files with 101 additions and 7 deletions
+17 -7
View File
@@ -40,6 +40,7 @@ from pm2d._jit_kernels import (
score_by_shift as _jit_score_by_shift,
score_bitmap as _jit_score_bitmap,
score_bitmap_rescored as _jit_score_bitmap_rescored,
score_bitmap_greedy as _jit_score_bitmap_greedy,
popcount_density as _jit_popcount,
HAS_NUMBA,
)
@@ -722,6 +723,7 @@ class LineShapeMatcher:
pyramid_propagate: bool = True,
propagate_topk: int = 8,
refine_pose_joint: bool = False,
greediness: float = 0.0,
) -> list[Match]:
"""
scale_penalty: se > 0, riduce lo score per match a scala diversa da 1.0:
@@ -813,23 +815,31 @@ class LineShapeMatcher:
neighbor_map[vi_c] = vi_sorted[start:end]
# Pruning varianti via top-level (parallelizzato).
# coarse_stride > 1: valuta solo 1 pixel ogni stride, ~stride^2 speed-up.
# pyramid_propagate=True: ritorna top-K picchi per restringere full-res.
# coarse_stride > 1: 1 pixel ogni stride (~stride^2 speed-up).
# pyramid_propagate=True: top-K picchi per restringere full-res.
# greediness > 0: kernel greedy con early-exit (alternativo a rescore).
cs = max(1, int(coarse_stride))
peaks_by_vi: dict[int, list[tuple[int, int, float]]] = {}
use_greedy_top = greediness > 0.0
def _top_score(vi: int) -> tuple[int, float]:
var = self.variants[vi]
lvl = var.levels[min(top, len(var.levels) - 1)]
score = _jit_score_bitmap_rescored(
spread_top, lvl.dx, lvl.dy, lvl.bin, bit_active_top,
bg_cache_top[var.scale], stride=cs,
)
if use_greedy_top:
# Greedy non supporta stride né rescore bg
score = _jit_score_bitmap_greedy(
spread_top, lvl.dx, lvl.dy, lvl.bin, bit_active_top,
top_thresh, greediness,
)
else:
score = _jit_score_bitmap_rescored(
spread_top, lvl.dx, lvl.dy, lvl.bin, bit_active_top,
bg_cache_top[var.scale], stride=cs,
)
if score.size == 0:
return vi, -1.0
best = float(score.max())
if pyramid_propagate and best > 0:
# Top-K posizioni > top_thresh (max propagate_topk)
flat = score.ravel()
k = min(propagate_topk, flat.size)
idx = np.argpartition(-flat, k - 1)[:k]