diff --git a/pm2d/_jit_kernels.py b/pm2d/_jit_kernels.py
index a0f1aa3..afa56c5 100644
--- a/pm2d/_jit_kernels.py
+++ b/pm2d/_jit_kernels.py
@@ -271,6 +271,63 @@ if HAS_NUMBA:
                 acc[y, x] = 0.0
         return acc
 
+    @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
+    def _jit_top_max_per_variant(
+        spread: np.ndarray,  # uint8 (H, W)
+        dx_flat: np.ndarray,  # int32 (sum_N,)
+        dy_flat: np.ndarray,  # int32 (sum_N,)
+        bins_flat: np.ndarray,  # int8 (sum_N,)
+        offsets: np.ndarray,  # int32 (n_vars+1,) prefix sum
+        bit_active: np.uint8,
+        bg_per_variant: np.ndarray,  # float32 (n_scales, H, W), one slice per scale
+        scale_idx: np.ndarray,  # int32 (n_vars,) index into bg_per_variant
+    ) -> np.ndarray:
+        """Batch: computes the max score (bg-rescored) for each variant;
+        returns a float32 array (n_vars,). The OUTER prange over the
+        variants removes the overhead of n_vars separate JIT calls
+        (avg ~20us per call on small templates) plus a Python thread pool.
+
+        Meant for the TOP pruning stage, where n_vars >> n_threads.
+        """
+        n_vars = offsets.shape[0] - 1
+        H, W = spread.shape
+        out = np.zeros(n_vars, dtype=np.float32)
+        for vi in nb.prange(n_vars):
+            i0 = offsets[vi]; i1 = offsets[vi + 1]
+            N = i1 - i0
+            if N == 0:
+                out[vi] = -1.0
+                continue
+            si = scale_idx[vi]
+            inv = nb.float32(1.0 / N)
+            best = nb.float32(-1.0)
+            for y in range(H):
+                for x in range(W):
+                    s = nb.float32(0.0)
+                    for k in range(N):
+                        b = bins_flat[i0 + k]
+                        mask = np.uint8(1) << b
+                        if (bit_active & mask) == 0:
+                            continue
+                        ddy = dy_flat[i0 + k]
+                        yy = y + ddy
+                        if yy < 0 or yy >= H:
+                            continue
+                        ddx = dx_flat[i0 + k]
+                        xx = x + ddx
+                        if xx < 0 or xx >= W:
+                            continue
+                        if spread[yy, xx] & mask:
+                            s += nb.float32(1.0)
+                    s *= inv
+                    bgv = bg_per_variant[si, y, x]
+                    if bgv < 1.0:
+                        r = (s - bgv) / (1.0 - bgv + 1e-6)
+                        if r > best:
+                            best = r
+            out[vi] = best if best > 0.0 else 0.0
+        return out
+
     @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
     def _jit_popcount_density(spread: np.ndarray) -> np.ndarray:
         """Conta bit set per pixel: ritorna (H, W) float32 in [0..8]."""
@@ -304,6 +361,12 @@ if HAS_NUMBA:
         spread, dx, dy, b, np.uint8(0xFF),
         np.float32(0.5), np.float32(0.8),
     )
+    offsets = np.array([0, 1], dtype=np.int32)
+    scale_idx = np.zeros(1, dtype=np.int32)
+    bg_pv = np.zeros((1, 32, 32), dtype=np.float32)
+    _jit_top_max_per_variant(
+        spread, dx, dy, b, offsets, np.uint8(0xFF), bg_pv, scale_idx,
+    )
     _jit_popcount_density(spread)
 
 else:  # pragma: no cover
@@ -323,6 +386,12 @@ else:  # pragma: no cover
     def _jit_score_bitmap_greedy(spread, dx, dy, bins, bit_active, min_score, greediness):
         raise RuntimeError("numba non disponibile")
 
+    def _jit_top_max_per_variant(
+        spread, dx_flat, dy_flat, bins_flat, offsets, bit_active,
+        bg_per_variant, scale_idx,
+    ):
+        raise RuntimeError("numba non disponibile")
+
     def _jit_popcount_density(spread):
         raise RuntimeError("numba non disponibile")
 
@@ -403,6 +472,51 @@ def score_bitmap_greedy(
     return score_bitmap(spread, dx, dy, bins, bit_active)
 
 
+def top_max_per_variant(
+    spread: np.ndarray,
+    dx_list: list, dy_list: list, bin_list: list,
+    bg_per_scale: dict,
+    variant_scales: list,
+    bit_active: int,
+) -> np.ndarray:
+    """Wrapper: packs flat buffers and calls the batch kernel on all variants.
+
+    Numba outer-prange parallelism over the variants (typically
+    n_vars >> n_threads for top-pruning) beats the Python thread pool,
+    which pays the overhead of n_vars separate JIT calls.
+    """
+    if not HAS_NUMBA or len(dx_list) == 0:
+        return np.array([], dtype=np.float32)
+    n_vars = len(dx_list)
+    sizes = [len(d) for d in dx_list]
+    offsets = np.zeros(n_vars + 1, dtype=np.int32)
+    offsets[1:] = np.cumsum(sizes)
+    total = int(offsets[-1])
+    dx_flat = np.empty(total, dtype=np.int32)
+    dy_flat = np.empty(total, dtype=np.int32)
+    bins_flat = np.empty(total, dtype=np.int8)
+    for vi, (dx, dy, bn) in enumerate(zip(dx_list, dy_list, bin_list)):
+        i0 = int(offsets[vi]); i1 = int(offsets[vi + 1])
+        dx_flat[i0:i1] = dx
+        dy_flat[i0:i1] = dy
+        bins_flat[i0:i1] = bn
+    # bg per variant: indexed by scale
+    scales_unique = sorted(bg_per_scale.keys())
+    scale_to_idx = {s: i for i, s in enumerate(scales_unique)}
+    H, W = spread.shape
+    bg_pv = np.empty((len(scales_unique), H, W), dtype=np.float32)
+    for s, idx in scale_to_idx.items():
+        bg_pv[idx] = bg_per_scale[s]
+    scale_idx = np.array(
+        [scale_to_idx[s] for s in variant_scales], dtype=np.int32,
+    )
+    return _jit_top_max_per_variant(
+        np.ascontiguousarray(spread, dtype=np.uint8),
+        dx_flat, dy_flat, bins_flat, offsets, np.uint8(bit_active),
+        bg_pv, scale_idx,
+    )
+
+
 def popcount_density(spread: np.ndarray) -> np.ndarray:
     if HAS_NUMBA:
         return _jit_popcount_density(np.ascontiguousarray(spread, dtype=np.uint8))
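For reference, the flat layout that top_max_per_variant builds for the kernel is a CSR-style packing: variant vi owns the half-open slice [offsets[vi], offsets[vi+1]) of the flattened buffers. A minimal NumPy-only sketch with toy data (the values here are illustrative, not part of the patch):

    import numpy as np

    # Toy variants with 2, 3, and 0 template points; the empty one
    # exercises the N == 0 branch the kernel maps to the -1.0 sentinel.
    dx_list = [
        np.array([0, 1], dtype=np.int32),
        np.array([-1, 0, 2], dtype=np.int32),
        np.array([], dtype=np.int32),
    ]
    sizes = [len(d) for d in dx_list]
    offsets = np.zeros(len(dx_list) + 1, dtype=np.int32)
    offsets[1:] = np.cumsum(sizes)                 # -> [0, 2, 5, 5]
    dx_flat = np.concatenate(dx_list).astype(np.int32)

    # Each prange iteration in _jit_top_max_per_variant reads exactly
    # one such slice, so variants never share mutable state.
    for vi in range(len(dx_list)):
        i0, i1 = int(offsets[vi]), int(offsets[vi + 1])
        print(vi, dx_flat[i0:i1])

This is why the outer prange is safe: every iteration writes only out[vi] and reads a disjoint slice of the flat buffers.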
+ """ + if not HAS_NUMBA or len(dx_list) == 0: + return np.array([], dtype=np.float32) + n_vars = len(dx_list) + sizes = [len(d) for d in dx_list] + offsets = np.zeros(n_vars + 1, dtype=np.int32) + offsets[1:] = np.cumsum(sizes) + total = int(offsets[-1]) + dx_flat = np.empty(total, dtype=np.int32) + dy_flat = np.empty(total, dtype=np.int32) + bins_flat = np.empty(total, dtype=np.int8) + for vi, (dx, dy, bn) in enumerate(zip(dx_list, dy_list, bin_list)): + i0 = int(offsets[vi]); i1 = int(offsets[vi + 1]) + dx_flat[i0:i1] = dx + dy_flat[i0:i1] = dy + bins_flat[i0:i1] = bn + # bg per variante: indicizzato per scala + scales_unique = sorted(bg_per_scale.keys()) + scale_to_idx = {s: i for i, s in enumerate(scales_unique)} + H, W = spread.shape + bg_pv = np.empty((len(scales_unique), H, W), dtype=np.float32) + for s, idx in scale_to_idx.items(): + bg_pv[idx] = bg_per_scale[s] + scale_idx = np.array( + [scale_to_idx[s] for s in variant_scales], dtype=np.int32, + ) + return _jit_top_max_per_variant( + np.ascontiguousarray(spread, dtype=np.uint8), + dx_flat, dy_flat, bins_flat, offsets, np.uint8(bit_active), + bg_pv, scale_idx, + ) + + def popcount_density(spread: np.ndarray) -> np.ndarray: if HAS_NUMBA: return _jit_popcount_density(np.ascontiguousarray(spread, dtype=np.uint8)) diff --git a/pm2d/line_matcher.py b/pm2d/line_matcher.py index 9271bdc..91e1c2c 100644 --- a/pm2d/line_matcher.py +++ b/pm2d/line_matcher.py @@ -41,6 +41,7 @@ from pm2d._jit_kernels import ( score_bitmap as _jit_score_bitmap, score_bitmap_rescored as _jit_score_bitmap_rescored, score_bitmap_greedy as _jit_score_bitmap_greedy, + top_max_per_variant as _jit_top_max_per_variant, popcount_density as _jit_popcount, HAS_NUMBA, ) @@ -724,6 +725,7 @@ class LineShapeMatcher: propagate_topk: int = 8, refine_pose_joint: bool = False, greediness: float = 0.0, + batch_top: bool = False, ) -> list[Match]: """ scale_penalty: se > 0, riduce lo score per match a scala diversa da 1.0: @@ -855,7 +857,25 @@ class LineShapeMatcher: kept_coarse: list[tuple[int, float]] = [] all_top_scores: list[tuple[int, float]] = [] - if self.n_threads > 1 and len(coarse_idx_list) > 1: + # batch_top: usa kernel batch single-call con prange-esterno su + # varianti. Vince su threadpool quando n_vars >> n_threads e quando + # H*W top e' piccolo (overhead chiamate JIT > costo kernel). + if (batch_top and HAS_NUMBA and len(coarse_idx_list) > 4): + dx_l = []; dy_l = []; bn_l = []; vs_l = [] + for vi in coarse_idx_list: + var = self.variants[vi] + lvl = var.levels[min(top, len(var.levels) - 1)] + dx_l.append(lvl.dx); dy_l.append(lvl.dy); bn_l.append(lvl.bin) + vs_l.append(var.scale) + scores_arr = _jit_top_max_per_variant( + spread_top, dx_l, dy_l, bn_l, bg_cache_top, vs_l, + bit_active_top, + ) + for vi, best in zip(coarse_idx_list, scores_arr.tolist()): + all_top_scores.append((vi, best)) + if best >= top_thresh: + kept_coarse.append((vi, best)) + elif self.n_threads > 1 and len(coarse_idx_list) > 1: with ThreadPoolExecutor(max_workers=self.n_threads) as ex: for vi, best in ex.map(_top_score, coarse_idx_list): all_top_scores.append((vi, best))