From 27a0ef1a458b168dce5094aa0df76f3931375a09 Mon Sep 17 00:00:00 2001 From: AdrianoDev Date: Mon, 4 May 2026 15:24:44 +0200 Subject: [PATCH] feat: coarse_stride per sub-sampling top-level Nuovo kernel JIT _jit_score_bitmap_rescored_strided: valuta solo pixel su griglia stride x stride al top della piramide. NMS + fase full-res recuperano precisione. Speed-up ~stride^2 sulla fase coarse, specie su scene grandi (1920x1080). Co-Authored-By: Claude Opus 4.7 (1M context) --- pm2d/_jit_kernels.py | 91 +++++++++++++++++++++++++++++++++++++++----- pm2d/line_matcher.py | 8 +++- 2 files changed, 88 insertions(+), 11 deletions(-) diff --git a/pm2d/_jit_kernels.py b/pm2d/_jit_kernels.py index e06d5d1..4546c99 100644 --- a/pm2d/_jit_kernels.py +++ b/pm2d/_jit_kernels.py @@ -110,6 +110,63 @@ if HAS_NUMBA: acc[y, x] *= inv return acc + @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False) + def _jit_score_bitmap_rescored_strided( + spread: np.ndarray, + dx: np.ndarray, dy: np.ndarray, bins: np.ndarray, + bit_active: np.uint8, + bg: np.ndarray, + stride: nb.int32, + ) -> np.ndarray: + """Variante con sub-sampling: valuta solo pixel su griglia stride×stride. + Score restituito ha stessa shape (H, W); celle non valutate = 0. + + 4× speed-up con stride=2 (NMS recupera precisione in full-res). + Numba prange richiede step costante: itero su indici griglia e + moltiplico per stride dentro il body. + """ + H, W = spread.shape + N = dx.shape[0] + acc = np.zeros((H, W), dtype=np.float32) + ny = (H + stride - 1) // stride + nx = (W + stride - 1) // stride + for yi in nb.prange(ny): + y = yi * stride + for i in range(N): + b = bins[i] + mask = np.uint8(1) << b + if (bit_active & mask) == 0: + continue + ddy = dy[i] + yy = y + ddy + if yy < 0 or yy >= H: + continue + ddx = dx[i] + x_lo = 0 if ddx >= 0 else -ddx + x_hi = W if ddx <= 0 else W - ddx + rem = x_lo % stride + if rem != 0: + x_lo += stride - rem + x = x_lo + while x < x_hi: + if spread[yy, x + ddx] & mask: + acc[y, x] += 1.0 + x += stride + if N > 0: + inv = 1.0 / N + for yi in nb.prange(ny): + y = yi * stride + for xi in range(nx): + x = xi * stride + v = acc[y, x] * inv + bgv = bg[y, x] + if bgv < 1.0: + r = (v - bgv) / (1.0 - bgv + 1e-6) + acc[y, x] = r if r > 0.0 else 0.0 + else: + acc[y, x] = 0.0 + return acc + @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False) def _jit_score_bitmap_rescored( spread: np.ndarray, # uint8 (H, W) @@ -185,6 +242,9 @@ if HAS_NUMBA: _jit_score_bitmap(spread, dx, dy, b, np.uint8(0xFF)) bg = np.zeros((32, 32), dtype=np.float32) _jit_score_bitmap_rescored(spread, dx, dy, b, np.uint8(0xFF), bg) + _jit_score_bitmap_rescored_strided( + spread, dx, dy, b, np.uint8(0xFF), bg, np.int32(2), + ) _jit_popcount_density(spread) else: # pragma: no cover @@ -198,6 +258,9 @@ else: # pragma: no cover def _jit_score_bitmap_rescored(spread, dx, dy, bins, bit_active, bg): raise RuntimeError("numba non disponibile") + def _jit_score_bitmap_rescored_strided(spread, dx, dy, bins, bit_active, bg, stride): + raise RuntimeError("numba non disponibile") + def _jit_popcount_density(spread): raise RuntimeError("numba non disponibile") @@ -228,19 +291,29 @@ def score_bitmap( def score_bitmap_rescored( spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray, - bit_active: int, bg: np.ndarray, + bit_active: int, bg: np.ndarray, stride: int = 1, ) -> np.ndarray: - """Score bitmap + rescore fusi in un solo pass (JIT).""" + """Score bitmap + rescore fusi in un solo pass (JIT). + + stride > 1: valuta solo pixel su griglia stride×stride. Le celle non + valutate restano 0 nello score map. Pensato per coarse-pass al top + della piramide; il refinement full-res poi recupera precisione. + """ if HAS_NUMBA and len(dx) > 0: + spread_c = np.ascontiguousarray(spread, dtype=np.uint8) + dx_c = np.ascontiguousarray(dx, dtype=np.int32) + dy_c = np.ascontiguousarray(dy, dtype=np.int32) + bins_c = np.ascontiguousarray(bins, dtype=np.int8) + bg_c = np.ascontiguousarray(bg, dtype=np.float32) + if stride > 1: + return _jit_score_bitmap_rescored_strided( + spread_c, dx_c, dy_c, bins_c, np.uint8(bit_active), bg_c, + np.int32(stride), + ) return _jit_score_bitmap_rescored( - np.ascontiguousarray(spread, dtype=np.uint8), - np.ascontiguousarray(dx, dtype=np.int32), - np.ascontiguousarray(dy, dtype=np.int32), - np.ascontiguousarray(bins, dtype=np.int8), - np.uint8(bit_active), - np.ascontiguousarray(bg, dtype=np.float32), + spread_c, dx_c, dy_c, bins_c, np.uint8(bit_active), bg_c, ) - # Fallback: chiamate separate + # Fallback: chiamate separate (stride ignorato in fallback) score = score_bitmap(spread, dx, dy, bins, bit_active) out = (score - bg) / (1.0 - bg + 1e-6) return np.maximum(0.0, out).astype(np.float32) diff --git a/pm2d/line_matcher.py b/pm2d/line_matcher.py index e5f212a..0e3fc9b 100644 --- a/pm2d/line_matcher.py +++ b/pm2d/line_matcher.py @@ -573,6 +573,7 @@ class LineShapeMatcher: verify_ncc: bool = True, verify_threshold: float = 0.4, coarse_angle_factor: int = 2, + coarse_stride: int = 1, scale_penalty: float = 0.0, ) -> list[Match]: """ @@ -645,13 +646,16 @@ class LineShapeMatcher: end = min(n, i + half + 1) neighbor_map[vi_c] = vi_sorted[start:end] - # Pruning varianti via top-level (parallelizzato) - solo coarse + # Pruning varianti via top-level (parallelizzato) - solo coarse. + # coarse_stride > 1: valuta solo 1 pixel ogni stride, ~stride² speed-up. + cs = max(1, int(coarse_stride)) + def _top_score(vi: int) -> tuple[int, float]: var = self.variants[vi] lvl = var.levels[min(top, len(var.levels) - 1)] score = _jit_score_bitmap_rescored( spread_top, lvl.dx, lvl.dy, lvl.bin, bit_active_top, - bg_cache_top[var.scale], + bg_cache_top[var.scale], stride=cs, ) return vi, float(score.max()) if score.size else -1.0