merge: batch_top variant-parallel kernel

This commit is contained in:
2026-05-04 15:46:17 +02:00
2 changed files with 135 additions and 1 deletions
+21 -1
View File
@@ -41,6 +41,7 @@ from pm2d._jit_kernels import (
score_bitmap as _jit_score_bitmap,
score_bitmap_rescored as _jit_score_bitmap_rescored,
score_bitmap_greedy as _jit_score_bitmap_greedy,
top_max_per_variant as _jit_top_max_per_variant,
popcount_density as _jit_popcount,
HAS_NUMBA,
)
@@ -724,6 +725,7 @@ class LineShapeMatcher:
propagate_topk: int = 8,
refine_pose_joint: bool = False,
greediness: float = 0.0,
batch_top: bool = False,
) -> list[Match]:
"""
scale_penalty: se > 0, riduce lo score per match a scala diversa da 1.0:
@@ -855,7 +857,25 @@ class LineShapeMatcher:
kept_coarse: list[tuple[int, float]] = []
all_top_scores: list[tuple[int, float]] = []
if self.n_threads > 1 and len(coarse_idx_list) > 1:
# batch_top: usa kernel batch single-call con prange-esterno su
# varianti. Vince su threadpool quando n_vars >> n_threads e quando
# H*W top e' piccolo (overhead chiamate JIT > costo kernel).
if (batch_top and HAS_NUMBA and len(coarse_idx_list) > 4):
dx_l = []; dy_l = []; bn_l = []; vs_l = []
for vi in coarse_idx_list:
var = self.variants[vi]
lvl = var.levels[min(top, len(var.levels) - 1)]
dx_l.append(lvl.dx); dy_l.append(lvl.dy); bn_l.append(lvl.bin)
vs_l.append(var.scale)
scores_arr = _jit_top_max_per_variant(
spread_top, dx_l, dy_l, bn_l, bg_cache_top, vs_l,
bit_active_top,
)
for vi, best in zip(coarse_idx_list, scores_arr.tolist()):
all_top_scores.append((vi, best))
if best >= top_thresh:
kept_coarse.append((vi, best))
elif self.n_threads > 1 and len(coarse_idx_list) > 1:
with ThreadPoolExecutor(max_workers=self.n_threads) as ex:
for vi, best in ex.map(_top_score, coarse_idx_list):
all_top_scores.append((vi, best))