@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap_greedy(
    spread: np.ndarray,
    dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: np.uint8,
    min_score: nb.float32,
    greediness: nb.float32,
) -> np.ndarray:
    """Score every pixel of a spread bitmap with a greedy early exit.

    For each pixel the N features are visited in order. Feature ``i`` is a
    hit when bit ``bins[i]`` is enabled in ``bit_active``, the shifted
    position ``(y + dy[i], x + dx[i])`` lies inside the image, and that same
    bit is set in ``spread`` at the shifted position. No background rescore
    is applied.

    After every miss the loop checks whether the required hit count
    ``min_required = greediness * min_score * N`` is still reachable; it
    stops as soon as ``hits + remaining_features < min_required``.

    greediness = 0 -> the bound is 0, so the loop never exits early.
    greediness = 1 -> exit as soon as hits + remaining < min_score * N.
    NOTE(review): the original author quotes 0.7-0.9 as typical values with
    a 2-4x speed-up; that figure is not verifiable from this block.

    Returns:
        float32 array with the shape of ``spread`` holding ``hits / N`` per
        pixel. For a pixel abandoned early this is the partial count reached
        when the loop stopped. All zeros when there are no features.
    """
    rows, cols = spread.shape
    n_feat = dx.shape[0]
    acc = np.zeros((rows, cols), dtype=np.float32)
    if n_feat == 0:
        return acc
    min_req = greediness * min_score * n_feat
    inv_n = nb.float32(1.0 / n_feat)
    for y in nb.prange(rows):
        for x in range(cols):
            hits = 0
            for i in range(n_feat):
                mask = np.uint8(1) << bins[i]
                # A feature scores only if its bin is active, its shifted
                # position is inside the image and the bit is set there.
                hit = False
                if (bit_active & mask) != 0:
                    yy = y + dy[i]
                    if 0 <= yy < rows:
                        xx = x + dx[i]
                        if 0 <= xx < cols:
                            if spread[yy, xx] & mask:
                                hit = True
                if hit:
                    hits += 1
                elif hits + (n_feat - i - 1) < min_req:
                    # Even if every remaining feature hit, the required
                    # count could not be reached: give up on this pixel.
                    break
            acc[y, x] = nb.float32(hits) * inv_n
    return acc
def score_bitmap_greedy(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int, min_score: float, greediness: float,
) -> np.ndarray:
    """Score a spread bitmap with the early-exit greedy kernel.

    Intended for an aggressive coarse pass. No background rescore is
    applied, so the original author recommends it for low-clutter scenes or
    for quickly mass-pruning variants at the top level.

    When numba is unavailable or ``dx`` is empty, the call is forwarded to
    ``score_bitmap``; in that case ``min_score`` and ``greediness`` are not
    used.

    Returns:
        Whatever the selected kernel returns for the given inputs.
    """
    if not HAS_NUMBA or len(dx) == 0:
        # Fallback: base kernel, which has no early exit.
        return score_bitmap(spread, dx, dy, bins, bit_active)

    # Coerce every argument to the dtype / scalar type passed to the kernel.
    spread_u8 = np.ascontiguousarray(spread, dtype=np.uint8)
    dx_i32 = np.ascontiguousarray(dx, dtype=np.int32)
    dy_i32 = np.ascontiguousarray(dy, dtype=np.int32)
    bins_i8 = np.ascontiguousarray(bins, dtype=np.int8)
    active_mask = np.uint8(bit_active)
    min_score_f32 = np.float32(min_score)
    greediness_f32 = np.float32(greediness)
    return _jit_score_bitmap_greedy(
        spread_u8, dx_i32, dy_i32, bins_i8,
        active_mask, min_score_f32, greediness_f32,
    )
class LineShapeMatcher: pyramid_propagate: bool = True, propagate_topk: int = 8, refine_pose_joint: bool = False, + greediness: float = 0.0, ) -> list[Match]: """ scale_penalty: se > 0, riduce lo score per match a scala diversa da 1.0: @@ -813,23 +815,31 @@ class LineShapeMatcher: neighbor_map[vi_c] = vi_sorted[start:end] # Pruning varianti via top-level (parallelizzato). - # coarse_stride > 1: valuta solo 1 pixel ogni stride, ~stride^2 speed-up. - # pyramid_propagate=True: ritorna top-K picchi per restringere full-res. + # coarse_stride > 1: 1 pixel ogni stride (~stride^2 speed-up). + # pyramid_propagate=True: top-K picchi per restringere full-res. + # greediness > 0: kernel greedy con early-exit (alternativo a rescore). cs = max(1, int(coarse_stride)) peaks_by_vi: dict[int, list[tuple[int, int, float]]] = {} + use_greedy_top = greediness > 0.0 def _top_score(vi: int) -> tuple[int, float]: var = self.variants[vi] lvl = var.levels[min(top, len(var.levels) - 1)] - score = _jit_score_bitmap_rescored( - spread_top, lvl.dx, lvl.dy, lvl.bin, bit_active_top, - bg_cache_top[var.scale], stride=cs, - ) + if use_greedy_top: + # Greedy non supporta stride né rescore bg + score = _jit_score_bitmap_greedy( + spread_top, lvl.dx, lvl.dy, lvl.bin, bit_active_top, + top_thresh, greediness, + ) + else: + score = _jit_score_bitmap_rescored( + spread_top, lvl.dx, lvl.dy, lvl.bin, bit_active_top, + bg_cache_top[var.scale], stride=cs, + ) if score.size == 0: return vi, -1.0 best = float(score.max()) if pyramid_propagate and best > 0: - # Top-K posizioni > top_thresh (max propagate_topk) flat = score.ravel() k = min(propagate_topk, flat.size) idx = np.argpartition(-flat, k - 1)[:k]