6704d66cd5
Nuovo kernel _jit_top_max_per_variant: prange esterno sulle varianti invece di n_vars chiamate JIT separate via ThreadPoolExecutor. Wrapper Python top_max_per_variant prepara buffer flat (offsets + dx_flat/dy_flat/bins_flat) e bg per scala. Default batch_top=False perché su benchmark realistici (Linux 13 core, 72-180 varianti) ThreadPoolExecutor + kernel singolo che rilascia il GIL è già ottimale. Path batch_top=True utile come opzione per scenari con n_vars >>> n_threads o overhead delle chiamate JIT dominante. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
391 lines
14 KiB
Python
391 lines
14 KiB
Python
"""Numba JIT kernels per hot-path del matching.
|
||
|
||
Fallback automatico a numpy se numba non disponibile o kernel fallisce.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import numpy as np
|
||
|
||
try:
|
||
import numba as nb
|
||
HAS_NUMBA = True
|
||
except ImportError: # pragma: no cover
|
||
HAS_NUMBA = False
|
||
|
||
|
||
def _numpy_score_by_shift(
    resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bin_has_data: np.ndarray | None = None,
) -> np.ndarray:
    """Pure-NumPy fallback that averages shifted response planes.

    Each feature ``i`` adds the plane ``resp[bins[i]]`` displaced by
    ``(dx[i], dy[i])`` to an accumulator; the sum is finally divided by the
    total number of features, skipped ones included.
    (The original note says this mirrors ``LineShapeMatcher._score_by_shift``,
    which is not visible from this file.)

    Args:
        resp: Response maps indexed ``[bin, y, x]`` (must be 3-D).
        dx: Per-feature column offsets.
        dy: Per-feature row offsets.
        bins: Per-feature index into the first axis of ``resp``.
        bin_has_data: Optional per-bin flags; a feature whose bin is flagged
            falsy is skipped. ``None`` keeps every feature.

    Returns:
        A float32 array of shape ``(H, W)``.
    """
    height, width = resp.shape[1:]
    total = len(dx)
    score = np.zeros((height, width), dtype=np.float32)
    for idx in range(total):
        bin_id = int(bins[idx])
        if bin_has_data is not None and not bin_has_data[bin_id]:
            continue
        shift_x, shift_y = int(dx[idx]), int(dy[idx])
        # Source window: the part of the plane still in bounds once shifted.
        src_top, src_bottom = max(0, shift_y), min(height, height + shift_y)
        src_left, src_right = max(0, shift_x), min(width, width + shift_x)
        if src_top >= src_bottom or src_left >= src_right:
            continue
        score[
            src_top - shift_y:src_bottom - shift_y,
            src_left - shift_x:src_right - shift_x,
        ] += resp[bin_id, src_top:src_bottom, src_left:src_right]
    if total > 0:
        score /= total
    return score
|
||
|
||
|
||
if HAS_NUMBA:
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_by_shift(
    resp: np.ndarray,        # float32 (N_BINS, H, W)
    dx: np.ndarray,          # int32 (N,)
    dy: np.ndarray,          # int32 (N,)
    bins: np.ndarray,        # int8 (N,)
    bin_active: np.ndarray,  # bool_ (N_BINS,)
) -> np.ndarray:
    """Average the shifted response planes of all features (Numba kernel).

    For every pixel ``(y, x)`` the result is the sum over features ``i`` of
    ``resp[bins[i], y + dy[i], x + dx[i]]`` divided by the feature count N.
    Features whose bin is off in ``bin_active`` and samples that fall outside
    the map add nothing but still count in the divisor. Output rows are
    distributed with ``prange``; each iteration writes only its own row.
    """
    height = resp.shape[1]
    width = resp.shape[2]
    n_feat = dx.shape[0]
    score = np.zeros((height, width), dtype=np.float32)
    for row in nb.prange(height):
        for k in range(n_feat):
            bin_id = bins[k]
            if bin_active[bin_id]:
                src_row = row + dy[k]
                if 0 <= src_row < height:
                    shift = dx[k]
                    # Clip columns so that col + shift stays inside the map.
                    col_start = max(0, -shift)
                    col_stop = min(width, width - shift)
                    for col in range(col_start, col_stop):
                        score[row, col] += resp[bin_id, src_row, col + shift]
    if n_feat > 0:
        scale = 1.0 / n_feat
        for row in nb.prange(height):
            for col in range(width):
                score[row, col] *= scale
    return score
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap(
    spread: np.ndarray,   # uint8 (H, W), bit b set = bin b present
    dx: np.ndarray,       # int32 (N,)
    dy: np.ndarray,       # int32 (N,)
    bins: np.ndarray,     # int8 (N,) bin of each feature
    bit_active: np.uint8,  # bitmask of the bins to take into account
) -> np.ndarray:
    """Fraction of features whose bin bit is set at the shifted position.

    ``score[y, x] = (sum_i [bit bins[i] set in spread[y + dy_i, x + dx_i]]) / N``

    Features whose bit is cleared in ``bit_active`` and samples outside the
    bitmap contribute zero but remain in the divisor N. One uint8 bitmap
    stands in for eight float32 response planes (32x less memory).
    """
    height, width = spread.shape
    n_feat = dx.shape[0]
    hits = np.zeros((height, width), dtype=np.float32)
    for row in nb.prange(height):
        for k in range(n_feat):
            bit = np.uint8(1) << bins[k]
            if (bit_active & bit) != 0:
                src_row = row + dy[k]
                if 0 <= src_row < height:
                    shift = dx[k]
                    # Clip columns so that col + shift stays inside the bitmap.
                    col_start = max(0, -shift)
                    col_stop = min(width, width - shift)
                    for col in range(col_start, col_stop):
                        if spread[src_row, col + shift] & bit:
                            hits[row, col] += 1.0
    if n_feat > 0:
        scale = 1.0 / n_feat
        for row in nb.prange(height):
            for col in range(width):
                hits[row, col] *= scale
    return hits
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap_rescored(
    spread: np.ndarray,   # uint8 (H, W)
    dx: np.ndarray,       # int32 (N,)
    dy: np.ndarray,       # int32 (N,)
    bins: np.ndarray,     # int8 (N,)
    bit_active: np.uint8,
    bg: np.ndarray,       # float32 (H, W) normalised background density
) -> np.ndarray:
    """Bitmap score and background rescoring fused into one kernel.

    The accumulation is the same as in ``_jit_score_bitmap``; the final
    normalisation pass then writes, per pixel::

        max(0, (score - bg) / (1 - bg + 1e-6))   if bg < 1
        0                                        otherwise

    so no intermediate score array is allocated. (The original author
    reports roughly 15% saved on a full find; not verified here.)
    When N == 0 the zero-filled accumulator is returned without rescoring.
    """
    height, width = spread.shape
    n_feat = dx.shape[0]
    hits = np.zeros((height, width), dtype=np.float32)
    for row in nb.prange(height):
        for k in range(n_feat):
            bit = np.uint8(1) << bins[k]
            if (bit_active & bit) != 0:
                src_row = row + dy[k]
                if 0 <= src_row < height:
                    shift = dx[k]
                    # Clip columns so that col + shift stays inside the bitmap.
                    col_start = max(0, -shift)
                    col_stop = min(width, width - shift)
                    for col in range(col_start, col_stop):
                        if spread[src_row, col + shift] & bit:
                            hits[row, col] += 1.0
    if n_feat > 0:
        scale = 1.0 / n_feat
        for row in nb.prange(height):
            for col in range(width):
                background = bg[row, col]
                rescored = 0.0
                if background < 1.0:
                    raw = hits[row, col] * scale
                    candidate = (raw - background) / (1.0 - background + 1e-6)
                    if candidate > 0.0:
                        rescored = candidate
                hits[row, col] = rescored
    return hits
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_top_max_per_variant(
    spread: np.ndarray,          # uint8 (H, W)
    dx_flat: np.ndarray,         # int32 (sum_N,)
    dy_flat: np.ndarray,         # int32 (sum_N,)
    bins_flat: np.ndarray,       # int8 (sum_N,)
    offsets: np.ndarray,         # int32 (n_vars+1,) prefix sum of feature counts
    bit_active: np.uint8,
    bg_per_variant: np.ndarray,  # float32 (n_scales, H, W), one map per scale
    scale_idx: np.ndarray,       # int32 (n_vars,) first-axis index into bg_per_variant
) -> np.ndarray:
    """Batched peak of the background-rescored bitmap score, one per variant.

    Variant ``vi`` owns the features ``offsets[vi]:offsets[vi + 1]`` of the
    flat buffers. For each variant the kernel scans every pixel, computes the
    bitmap score, rescores it against ``bg_per_variant[scale_idx[vi]]`` and
    keeps the maximum, clamped at 0. A variant with no features yields -1.

    The ``prange`` runs over variants (outer loop) so a single JIT call
    replaces one call per variant; intended for the top pruning stage when
    there are many more variants than threads.

    Returns:
        float32 array of shape ``(n_vars,)``.
    """
    height, width = spread.shape
    variant_count = offsets.shape[0] - 1
    result = np.zeros(variant_count, dtype=np.float32)
    for vi in nb.prange(variant_count):
        first = offsets[vi]
        last = offsets[vi + 1]
        n_feat = last - first
        if n_feat == 0:
            result[vi] = -1.0
            continue
        bg_slot = scale_idx[vi]
        scale = nb.float32(1.0 / n_feat)
        peak = nb.float32(-1.0)
        for row in range(height):
            for col in range(width):
                hits = nb.float32(0.0)
                for k in range(first, last):
                    bit = np.uint8(1) << bins_flat[k]
                    if (bit_active & bit) != 0:
                        src_row = row + dy_flat[k]
                        src_col = col + dx_flat[k]
                        if 0 <= src_row < height and 0 <= src_col < width:
                            if spread[src_row, src_col] & bit:
                                hits += nb.float32(1.0)
                hits *= scale
                background = bg_per_variant[bg_slot, row, col]
                # Pixels with background >= 1 cannot be rescored and are ignored.
                if background < 1.0:
                    candidate = (hits - background) / (1.0 - background + 1e-6)
                    if candidate > peak:
                        peak = candidate
        result[vi] = peak if peak > 0.0 else 0.0
    return result
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_popcount_density(spread: np.ndarray) -> np.ndarray:
    """Count the set bits of every pixel.

    Returns a float32 ``(H, W)`` array; for 8-bit input the values lie in
    0..8.
    """
    height, width = spread.shape
    density = np.zeros((height, width), dtype=np.float32)
    for row in nb.prange(height):
        for col in range(width):
            bits = spread[row, col]
            # Parallel bit count: sum adjacent bits, then pairs, then nibbles.
            pairs = (bits & 0x55) + ((bits >> 1) & 0x55)
            nibbles = (pairs & 0x33) + ((pairs >> 2) & 0x33)
            total = (nibbles & 0x0F) + ((nibbles >> 4) & 0x0F)
            density[row, col] = float(total)
    return density
|
||
|
||
def _warmup():
    """Run every JIT kernel once on small zero-filled inputs.

    The argument dtypes match what the public dispatchers pass, so these
    calls make Numba compile (or load from its cache) those specialisations.
    """
    side = 32
    all_bins = np.uint8(0xFF)
    one_dx = np.zeros(1, dtype=np.int32)
    one_dy = np.zeros(1, dtype=np.int32)
    one_bin = np.zeros(1, dtype=np.int8)

    _jit_score_by_shift(
        np.zeros((8, side, side), dtype=np.float32),
        one_dx, one_dy, one_bin,
        np.ones(8, dtype=np.bool_),
    )

    bitmap = np.zeros((side, side), dtype=np.uint8)
    _jit_score_bitmap(bitmap, one_dx, one_dy, one_bin, all_bins)
    _jit_score_bitmap_rescored(
        bitmap, one_dx, one_dy, one_bin, all_bins,
        np.zeros((side, side), dtype=np.float32),
    )
    # Single variant owning the single dummy feature.
    _jit_top_max_per_variant(
        bitmap, one_dx, one_dy, one_bin,
        np.array([0, 1], dtype=np.int32),
        all_bins,
        np.zeros((1, side, side), dtype=np.float32),
        np.zeros(1, dtype=np.int32),
    )
    _jit_popcount_density(bitmap)
|
||
|
||
else: # pragma: no cover
|
||
|
||
def _jit_score_by_shift(resp, dx, dy, bins, bin_active):
    """Stub defined when Numba is missing; always raises RuntimeError."""
    message = "numba non disponibile"
    raise RuntimeError(message)
|
||
|
||
def _jit_score_bitmap(spread, dx, dy, bins, bit_active):
    """Stub defined when Numba is missing; always raises RuntimeError."""
    message = "numba non disponibile"
    raise RuntimeError(message)
|
||
|
||
def _jit_score_bitmap_rescored(spread, dx, dy, bins, bit_active, bg):
    """Stub defined when Numba is missing; always raises RuntimeError."""
    message = "numba non disponibile"
    raise RuntimeError(message)
|
||
|
||
def _jit_top_max_per_variant(
    spread, dx_flat, dy_flat, bins_flat, offsets, bit_active,
    bg_per_variant, scale_idx,
):
    """Stub defined when Numba is missing; always raises RuntimeError."""
    message = "numba non disponibile"
    raise RuntimeError(message)
|
||
|
||
def _jit_popcount_density(spread):
    """Stub defined when Numba is missing; always raises RuntimeError."""
    message = "numba non disponibile"
    raise RuntimeError(message)
|
||
|
||
def _warmup():
    """No-op variant used when Numba is missing: there is nothing to compile."""
    return None
|
||
|
||
|
||
def score_bitmap(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int,
) -> np.ndarray:
    """Bitmap score dispatcher: JIT kernel when possible, NumPy otherwise.

    Args:
        spread: 2-D bitmap; bit ``b`` of a pixel marks bin ``b`` as present.
        dx: Per-feature column offsets.
        dy: Per-feature row offsets.
        bins: Per-feature bin index (0..7).
        bit_active: Bitmask of the bins to take into account; features whose
            bit is cleared contribute nothing (they still count in the
            divisor).

    Returns:
        float32 ``(H, W)`` array with the fraction of matching features per
        pixel.
    """
    if HAS_NUMBA and len(dx) > 0:
        return _jit_score_bitmap(
            np.ascontiguousarray(spread, dtype=np.uint8),
            np.ascontiguousarray(dx, dtype=np.int32),
            np.ascontiguousarray(dy, dtype=np.int32),
            np.ascontiguousarray(bins, dtype=np.int8),
            np.uint8(bit_active),
        )
    # Slow NumPy fallback: expand the bitmap into a 3-D response stack.
    H, W = spread.shape
    # Honour bit_active exactly like the JIT kernel does; previously the
    # fallback ignored it and scored features of inactive bins as well.
    active_bits = int(bit_active) & 0xFF
    bin_has_data = np.array(
        [bool((active_bits >> b) & 1) for b in range(8)], dtype=np.bool_,
    )
    resp = np.zeros((8, H, W), dtype=np.float32)
    for b in range(8):
        if bin_has_data[b]:  # inactive planes are never read, leave them zero
            resp[b] = ((spread >> b) & 1).astype(np.float32)
    return _numpy_score_by_shift(resp, dx, dy, bins, bin_has_data)
|
||
|
||
|
||
def score_bitmap_rescored(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int, bg: np.ndarray,
) -> np.ndarray:
    """Bitmap score rescored against a background density map.

    Uses the fused JIT kernel when Numba is available and there is at least
    one feature; otherwise computes ``score_bitmap`` and rescores in NumPy.

    Args:
        spread: 2-D bitmap; bit ``b`` of a pixel marks bin ``b`` as present.
        dx: Per-feature column offsets.
        dy: Per-feature row offsets.
        bins: Per-feature bin index.
        bit_active: Bitmask of the bins to take into account.
        bg: Background density, same shape as ``spread``.

    Returns:
        float32 array holding ``max(0, (score - bg) / (1 - bg + 1e-6))``
        where ``bg < 1`` and 0 elsewhere.
    """
    if HAS_NUMBA and len(dx) > 0:
        return _jit_score_bitmap_rescored(
            np.ascontiguousarray(spread, dtype=np.uint8),
            np.ascontiguousarray(dx, dtype=np.int32),
            np.ascontiguousarray(dy, dtype=np.int32),
            np.ascontiguousarray(bins, dtype=np.int8),
            np.uint8(bit_active),
            np.ascontiguousarray(bg, dtype=np.float32),
        )
    # Fallback: separate score and rescore steps.
    score = score_bitmap(spread, dx, dy, bins, bit_active)
    bg_f = np.asarray(bg, dtype=np.float32)
    # Same guard as the JIT kernel: where bg >= 1 the result is 0. Without it
    # a negative (or zero) denominator produced positive or infinite scores.
    valid = bg_f < 1.0
    denom = np.where(valid, 1.0 - bg_f + 1e-6, 1.0)
    out = np.where(valid, (score - bg_f) / denom, 0.0)
    return np.maximum(0.0, out).astype(np.float32)
|
||
|
||
|
||
def _numpy_top_max_one_variant(spread, dx, dy, bins, bit_active, bg):
    """NumPy peak rescored score of one variant (mirrors the JIT kernel)."""
    n_feat = len(dx)
    if n_feat == 0:
        return -1.0  # same sentinel the JIT kernel writes for empty variants
    H, W = spread.shape
    acc = np.zeros((H, W), dtype=np.float32)
    for k in range(n_feat):
        b = int(bins[k])
        if b < 0 or b > 7 or not (bit_active >> b) & 1:
            continue
        ddx = int(dx[k])
        ddy = int(dy[k])
        y0 = max(0, -ddy)
        y1 = min(H, H - ddy)
        x0 = max(0, -ddx)
        x1 = min(W, W - ddx)
        if y0 >= y1 or x0 >= x1:
            continue
        window = spread[y0 + ddy:y1 + ddy, x0 + ddx:x1 + ddx]
        acc[y0:y1, x0:x1] += (window >> b) & 1
    acc /= n_feat
    valid = bg < 1.0  # pixels with bg >= 1 are ignored, as in the kernel
    if not valid.any():
        return 0.0
    rescored = (acc[valid] - bg[valid]) / (1.0 - bg[valid] + 1e-6)
    return max(float(rescored.max()), 0.0)


def top_max_per_variant(
    spread: np.ndarray,
    dx_list: list, dy_list: list, bin_list: list,
    bg_per_scale: dict,
    variant_scales: list,
    bit_active: int,
) -> np.ndarray:
    """Peak background-rescored score of every variant in one batched call.

    Flattens the per-variant feature arrays into contiguous buffers plus a
    prefix-sum ``offsets`` array, stacks the background maps by scale, and
    runs the batched Numba kernel (outer ``prange`` over variants). Without
    Numba the same quantity is computed variant by variant in NumPy, like the
    other dispatchers in this module, instead of returning an empty result.

    Args:
        spread: 2-D bitmap; bit ``b`` of a pixel marks bin ``b`` as present.
        dx_list: Per-variant arrays of column offsets.
        dy_list: Per-variant arrays of row offsets (same lengths as dx).
        bin_list: Per-variant arrays of bin indices (same lengths as dx).
        bg_per_scale: Maps a scale key to its ``(H, W)`` background map.
        variant_scales: Scale key of each variant (looked up in
            ``bg_per_scale``).
        bit_active: Bitmask of the bins to take into account.

    Returns:
        float32 array of shape ``(n_vars,)``: the clamped (>= 0) peak score
        of each variant, or -1 for a variant without features. Empty when
        ``dx_list`` is empty.

    Raises:
        ValueError: If ``dy_list``, ``bin_list`` or ``variant_scales`` has
            fewer entries than ``dx_list``.
        KeyError: If a variant scale is missing from ``bg_per_scale``.
    """
    n_vars = len(dx_list)
    if n_vars == 0:
        return np.array([], dtype=np.float32)
    # A short list would leave uninitialised entries in the flat buffers (or
    # a too-short scale_idx) that the kernel reads with boundscheck disabled.
    if min(len(dy_list), len(bin_list), len(variant_scales)) < n_vars:
        raise ValueError(
            "dy_list, bin_list and variant_scales need one entry per "
            f"variant ({n_vars}); got {len(dy_list)}, {len(bin_list)}, "
            f"{len(variant_scales)}"
        )

    sizes = [len(d) for d in dx_list]
    offsets = np.zeros(n_vars + 1, dtype=np.int32)
    offsets[1:] = np.cumsum(sizes)
    total = int(offsets[-1])
    dx_flat = np.empty(total, dtype=np.int32)
    dy_flat = np.empty(total, dtype=np.int32)
    bins_flat = np.empty(total, dtype=np.int8)
    for vi, (dx, dy, bn) in enumerate(zip(dx_list, dy_list, bin_list)):
        i0 = int(offsets[vi])
        i1 = int(offsets[vi + 1])
        dx_flat[i0:i1] = dx
        dy_flat[i0:i1] = dy
        bins_flat[i0:i1] = bn

    # One background map per distinct scale; variants refer to it by index.
    scales_unique = sorted(bg_per_scale.keys())
    scale_to_idx = {s: i for i, s in enumerate(scales_unique)}
    H, W = spread.shape
    bg_pv = np.empty((len(scales_unique), H, W), dtype=np.float32)
    for s, idx in scale_to_idx.items():
        bg_pv[idx] = bg_per_scale[s]
    scale_idx = np.array(
        [scale_to_idx[s] for s in variant_scales], dtype=np.int32,
    )

    spread_u8 = np.ascontiguousarray(spread, dtype=np.uint8)
    if HAS_NUMBA:
        return _jit_top_max_per_variant(
            spread_u8,
            dx_flat, dy_flat, bins_flat, offsets, np.uint8(bit_active),
            bg_pv, scale_idx,
        )

    # NumPy fallback (slow): one variant at a time over the same buffers.
    active_bits = int(bit_active) & 0xFF
    out = np.empty(n_vars, dtype=np.float32)
    for vi in range(n_vars):
        i0 = int(offsets[vi])
        i1 = int(offsets[vi + 1])
        out[vi] = _numpy_top_max_one_variant(
            spread_u8, dx_flat[i0:i1], dy_flat[i0:i1], bins_flat[i0:i1],
            active_bits, bg_pv[scale_idx[vi]],
        )
    return out
|
||
|
||
|
||
def popcount_density(spread: np.ndarray) -> np.ndarray:
    """Per-pixel number of set bits (lowest eight) as a float32 ``(H, W)`` map.

    Delegates to the JIT kernel when Numba is available; otherwise sums the
    eight bit planes with NumPy.
    """
    if not HAS_NUMBA:
        height, width = spread.shape
        density = np.zeros((height, width), dtype=np.float32)
        for bit in range(8):
            density += ((spread >> bit) & 1).astype(np.float32)
        return density
    return _jit_popcount_density(np.ascontiguousarray(spread, dtype=np.uint8))
|
||
|
||
|
||
def score_by_shift(
    resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bin_has_data: np.ndarray | None = None,
) -> np.ndarray:
    """Shift-and-average score: JIT kernel when possible, NumPy otherwise.

    The NumPy implementation is used when Numba is missing or when there are
    no features. On the JIT path the inputs are converted to the contiguous
    dtypes the kernel expects, and a missing ``bin_has_data`` is replaced by
    an all-true mask with one entry per response plane.
    """
    if HAS_NUMBA and len(dx) != 0:
        kernel_args = (
            np.ascontiguousarray(resp, dtype=np.float32),
            np.ascontiguousarray(dx, dtype=np.int32),
            np.ascontiguousarray(dy, dtype=np.int32),
            np.ascontiguousarray(bins, dtype=np.int8),
        )
        active = (
            np.ones(resp.shape[0], dtype=np.bool_)
            if bin_has_data is None
            else np.ascontiguousarray(bin_has_data, dtype=np.bool_)
        )
        return _jit_score_by_shift(*kernel_args, active)
    return _numpy_score_by_shift(resp, dx, dy, bins, bin_has_data)
|