From 6704d66cd523d948defdc24c199dce2ffb22a680 Mon Sep 17 00:00:00 2001 From: AdrianoDev Date: Mon, 4 May 2026 15:35:51 +0200 Subject: [PATCH] feat: kernel JIT batch top-max-per-variant (opt-in) Nuovo kernel _jit_top_max_per_variant: prange esterno sulle varianti invece di n_vars chiamate JIT separate via ThreadPoolExecutor. Wrapper Python top_max_per_variant prepara buffer flat (offsets + dx_flat/dy_flat/bins_flat) e bg per scala. Default batch_top=False perche su benchmark realistici (Linux 13 core, 72-180 varianti) ThreadPoolExecutor + kernel singolo che rilascia GIL e gia ottimale. Path batch_top=True utile come opzione per scenari con n_vars >>> n_threads o overhead chiamate JIT dominante. Co-Authored-By: Claude Opus 4.7 (1M context) --- pm2d/_jit_kernels.py | 114 +++++++++++++++++++++++++++++++++++++++++++ pm2d/line_matcher.py | 22 ++++++++- 2 files changed, 135 insertions(+), 1 deletion(-) diff --git a/pm2d/_jit_kernels.py b/pm2d/_jit_kernels.py index e06d5d1..7db5753 100644 --- a/pm2d/_jit_kernels.py +++ b/pm2d/_jit_kernels.py @@ -159,6 +159,63 @@ if HAS_NUMBA: acc[y, x] = 0.0 return acc + @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False) + def _jit_top_max_per_variant( + spread: np.ndarray, # uint8 (H, W) + dx_flat: np.ndarray, # int32 (sum_N,) + dy_flat: np.ndarray, # int32 (sum_N,) + bins_flat: np.ndarray, # int8 (sum_N,) + offsets: np.ndarray, # int32 (n_vars+1,) prefix sum + bit_active: np.uint8, + bg_per_variant: np.ndarray, # float32 (n_vars, H, W) - 1 per scala + scale_idx: np.ndarray, # int32 (n_vars,) idx in bg_per_variant + ) -> np.ndarray: + """Batch: per ogni variante calcola max score (rescored bg), ritorna + array float32 (n_vars,). Parallelismo prange ESTERNO sulle varianti + elimina overhead di n_vars chiamate JIT separate (avg ~20us per + chiamata su template piccoli) + pool thread Python. + + Pensato per fase TOP del pruning quando n_vars >> n_threads. + """ + n_vars = offsets.shape[0] - 1 + H, W = spread.shape + out = np.zeros(n_vars, dtype=np.float32) + for vi in nb.prange(n_vars): + i0 = offsets[vi]; i1 = offsets[vi + 1] + N = i1 - i0 + if N == 0: + out[vi] = -1.0 + continue + si = scale_idx[vi] + inv = nb.float32(1.0 / N) + best = nb.float32(-1.0) + for y in range(H): + for x in range(W): + s = nb.float32(0.0) + for k in range(N): + b = bins_flat[i0 + k] + mask = np.uint8(1) << b + if (bit_active & mask) == 0: + continue + ddy = dy_flat[i0 + k] + yy = y + ddy + if yy < 0 or yy >= H: + continue + ddx = dx_flat[i0 + k] + xx = x + ddx + if xx < 0 or xx >= W: + continue + if spread[yy, xx] & mask: + s += nb.float32(1.0) + s *= inv + bgv = bg_per_variant[si, y, x] + if bgv < 1.0: + r = (s - bgv) / (1.0 - bgv + 1e-6) + if r > best: + best = r + out[vi] = best if best > 0.0 else 0.0 + return out + @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False) def _jit_popcount_density(spread: np.ndarray) -> np.ndarray: """Conta bit set per pixel: ritorna (H, W) float32 in [0..8].""" @@ -185,6 +242,12 @@ if HAS_NUMBA: _jit_score_bitmap(spread, dx, dy, b, np.uint8(0xFF)) bg = np.zeros((32, 32), dtype=np.float32) _jit_score_bitmap_rescored(spread, dx, dy, b, np.uint8(0xFF), bg) + offsets = np.array([0, 1], dtype=np.int32) + scale_idx = np.zeros(1, dtype=np.int32) + bg_pv = np.zeros((1, 32, 32), dtype=np.float32) + _jit_top_max_per_variant( + spread, dx, dy, b, offsets, np.uint8(0xFF), bg_pv, scale_idx, + ) _jit_popcount_density(spread) else: # pragma: no cover @@ -198,6 +261,12 @@ else: # pragma: no cover def _jit_score_bitmap_rescored(spread, dx, dy, bins, bit_active, bg): raise RuntimeError("numba non disponibile") + def _jit_top_max_per_variant( + spread, dx_flat, dy_flat, bins_flat, offsets, bit_active, + bg_per_variant, scale_idx, + ): + raise RuntimeError("numba non disponibile") + def _jit_popcount_density(spread): raise RuntimeError("numba non disponibile") @@ -246,6 +315,51 @@ def score_bitmap_rescored( return np.maximum(0.0, out).astype(np.float32) +def top_max_per_variant( + spread: np.ndarray, + dx_list: list, dy_list: list, bin_list: list, + bg_per_scale: dict, + variant_scales: list, + bit_active: int, +) -> np.ndarray: + """Wrapper: prepara buffer flat e chiama kernel batch su tutte le varianti. + + Parallelismo Numba prange-esterno sulle varianti (n_vars >> n_threads + tipicamente per top-pruning) → meglio del thread-pool Python che paga + overhead di n_vars chiamate JIT separate. + """ + if not HAS_NUMBA or len(dx_list) == 0: + return np.array([], dtype=np.float32) + n_vars = len(dx_list) + sizes = [len(d) for d in dx_list] + offsets = np.zeros(n_vars + 1, dtype=np.int32) + offsets[1:] = np.cumsum(sizes) + total = int(offsets[-1]) + dx_flat = np.empty(total, dtype=np.int32) + dy_flat = np.empty(total, dtype=np.int32) + bins_flat = np.empty(total, dtype=np.int8) + for vi, (dx, dy, bn) in enumerate(zip(dx_list, dy_list, bin_list)): + i0 = int(offsets[vi]); i1 = int(offsets[vi + 1]) + dx_flat[i0:i1] = dx + dy_flat[i0:i1] = dy + bins_flat[i0:i1] = bn + # bg per variante: indicizzato per scala + scales_unique = sorted(bg_per_scale.keys()) + scale_to_idx = {s: i for i, s in enumerate(scales_unique)} + H, W = spread.shape + bg_pv = np.empty((len(scales_unique), H, W), dtype=np.float32) + for s, idx in scale_to_idx.items(): + bg_pv[idx] = bg_per_scale[s] + scale_idx = np.array( + [scale_to_idx[s] for s in variant_scales], dtype=np.int32, + ) + return _jit_top_max_per_variant( + np.ascontiguousarray(spread, dtype=np.uint8), + dx_flat, dy_flat, bins_flat, offsets, np.uint8(bit_active), + bg_pv, scale_idx, + ) + + def popcount_density(spread: np.ndarray) -> np.ndarray: if HAS_NUMBA: return _jit_popcount_density(np.ascontiguousarray(spread, dtype=np.uint8)) diff --git a/pm2d/line_matcher.py b/pm2d/line_matcher.py index e5f212a..8e5527e 100644 --- a/pm2d/line_matcher.py +++ b/pm2d/line_matcher.py @@ -40,6 +40,7 @@ from pm2d._jit_kernels import ( score_by_shift as _jit_score_by_shift, score_bitmap as _jit_score_bitmap, score_bitmap_rescored as _jit_score_bitmap_rescored, + top_max_per_variant as _jit_top_max_per_variant, popcount_density as _jit_popcount, HAS_NUMBA, ) @@ -574,6 +575,7 @@ class LineShapeMatcher: verify_threshold: float = 0.4, coarse_angle_factor: int = 2, scale_penalty: float = 0.0, + batch_top: bool = False, ) -> list[Match]: """ scale_penalty: se > 0, riduce lo score per match a scala diversa da 1.0: @@ -657,7 +659,25 @@ class LineShapeMatcher: kept_coarse: list[tuple[int, float]] = [] all_top_scores: list[tuple[int, float]] = [] - if self.n_threads > 1 and len(coarse_idx_list) > 1: + # batch_top: usa kernel batch single-call con prange-esterno su + # varianti. Vince su threadpool quando n_vars >> n_threads e quando + # H*W top e' piccolo (overhead chiamate JIT > costo kernel). + if (batch_top and HAS_NUMBA and len(coarse_idx_list) > 4): + dx_l = []; dy_l = []; bn_l = []; vs_l = [] + for vi in coarse_idx_list: + var = self.variants[vi] + lvl = var.levels[min(top, len(var.levels) - 1)] + dx_l.append(lvl.dx); dy_l.append(lvl.dy); bn_l.append(lvl.bin) + vs_l.append(var.scale) + scores_arr = _jit_top_max_per_variant( + spread_top, dx_l, dy_l, bn_l, bg_cache_top, vs_l, + bit_active_top, + ) + for vi, best in zip(coarse_idx_list, scores_arr.tolist()): + all_top_scores.append((vi, best)) + if best >= top_thresh: + kept_coarse.append((vi, best)) + elif self.n_threads > 1 and len(coarse_idx_list) > 1: with ThreadPoolExecutor(max_workers=self.n_threads) as ex: for vi, best in ex.map(_top_score, coarse_idx_list): all_top_scores.append((vi, best))