84b73dc651
Flag opt-in use_polarity=True su LineShapeMatcher: distingue edge chiaro->scuro da scuro->chiaro raddoppiando i bin (8 mod pi a 16 mod 2pi). Riduce match accidentali quando il template e direzionale ma scena ha bordo opposto (es. pezzo nero su bg chiaro vs pezzo chiaro su bg nero). Implementazione: - _gradient calcola atan2 mod 2pi quando use_polarity - _spread_bitmap usa uint16 (16 bit) invece di uint8 (8 bit) - Nuovi kernel JIT _jit_score_bitmap_rescored_u16 e _jit_popcount_density_u16 - Wrapper Python score_bitmap_rescored / popcount_density fanno dispatch su dtype dello spread Default off (use_polarity=False) = backward compat completo, 8 bin. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
645 lines
24 KiB
Python
645 lines
24 KiB
Python
"""Numba JIT kernels per hot-path del matching.
|
||
|
||
Fallback automatico a numpy se numba non disponibile o kernel fallisce.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import numpy as np
|
||
|
||
try:
|
||
import numba as nb
|
||
HAS_NUMBA = True
|
||
except ImportError: # pragma: no cover
|
||
HAS_NUMBA = False
|
||
|
||
|
||
def _numpy_score_by_shift(
|
||
resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
|
||
bin_has_data: np.ndarray | None = None,
|
||
) -> np.ndarray:
|
||
"""Fallback numpy (stessa logica di LineShapeMatcher._score_by_shift)."""
|
||
_, H, W = resp.shape
|
||
acc = np.zeros((H, W), dtype=np.float32)
|
||
n = len(dx)
|
||
for i in range(n):
|
||
b = int(bins[i])
|
||
if bin_has_data is not None and not bin_has_data[b]:
|
||
continue
|
||
ddx = int(dx[i]); ddy = int(dy[i])
|
||
y0s = max(0, -ddy); y1s = min(H, H - ddy)
|
||
x0s = max(0, -ddx); x1s = min(W, W - ddx)
|
||
if y0s >= y1s or x0s >= x1s:
|
||
continue
|
||
y0r = y0s + ddy; y1r = y1s + ddy
|
||
x0r = x0s + ddx; x1r = x1s + ddx
|
||
acc[y0s:y1s, x0s:x1s] += resp[b, y0r:y1r, x0r:x1r]
|
||
if n > 0:
|
||
acc /= n
|
||
return acc
|
||
|
||
|
||
if HAS_NUMBA:
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_by_shift(
    resp: np.ndarray,        # (N_BINS, H, W); float32 when called via score_by_shift
    dx: np.ndarray,          # (N,) x offsets; int32 when called via score_by_shift
    dy: np.ndarray,          # (N,) y offsets; int32 when called via score_by_shift
    bins: np.ndarray,        # (N,) plane index per feature; int8 via score_by_shift
    bin_active: np.ndarray,  # (N_BINS,) flags; bool_ when called via score_by_shift
) -> np.ndarray:
    """Compiled shift-and-accumulate score, parallel over output rows.

    The result at ``[y, x]`` is the sum of ``resp[bins[k], y + dy[k],
    x + dx[k]]`` over the features whose bin is active and whose shifted
    position is inside the image, divided by the total feature count N.
    """
    height = resp.shape[1]
    width = resp.shape[2]
    n_feat = dx.shape[0]
    total = np.zeros((height, width), dtype=np.float32)
    for row in nb.prange(height):
        for k in range(n_feat):
            plane = bins[k]
            if bin_active[plane]:
                src_row = row + dy[k]
                if src_row >= 0 and src_row < height:
                    shift = dx[k]
                    # Column range for which col + shift stays within [0, W).
                    first = -shift if shift < 0 else 0
                    last = width - shift if shift > 0 else width
                    for col in range(first, last):
                        total[row, col] += resp[plane, src_row, col + shift]
    if n_feat > 0:
        scale = 1.0 / n_feat
        for row in nb.prange(height):
            for col in range(width):
                total[row, col] *= scale
    return total
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap(
    spread: np.ndarray,    # (H, W) bitmap; bit b set = bin b present (uint8 via score_bitmap)
    dx: np.ndarray,        # (N,) x offsets; int32 when called via score_bitmap
    dy: np.ndarray,        # (N,) y offsets; int32 when called via score_bitmap
    bins: np.ndarray,      # (N,) bin of each feature; int8 when called via score_bitmap
    bit_active: np.uint8,  # bitmask of the bins to take into account
) -> np.ndarray:
    """Fraction of features whose bin bit is set at the shifted position.

    ``score[y, x] = (#k with bit bins[k] set in spread[y + dy[k], x + dx[k]]) / N``.
    Features whose bin is not in ``bit_active`` or whose shifted position is
    out of bounds contribute nothing but still count in N.

    A one-byte bitmap replaces eight float32 response planes (32 bytes per
    pixel), which is where the original note's "32x less memory" comes from.
    """
    height, width = spread.shape
    n_feat = dx.shape[0]
    votes = np.zeros((height, width), dtype=np.float32)
    for row in nb.prange(height):
        for k in range(n_feat):
            bit = np.uint8(1) << bins[k]
            if (bit_active & bit) != 0:
                src_row = row + dy[k]
                if src_row >= 0 and src_row < height:
                    shift = dx[k]
                    # Column range for which col + shift stays within [0, W).
                    first = -shift if shift < 0 else 0
                    last = width - shift if shift > 0 else width
                    for col in range(first, last):
                        if spread[src_row, col + shift] & bit:
                            votes[row, col] += 1.0
    if n_feat > 0:
        scale = 1.0 / n_feat
        for row in nb.prange(height):
            for col in range(width):
                votes[row, col] *= scale
    return votes
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap_rescored_strided(
    spread: np.ndarray,
    dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: np.uint8,
    bg: np.ndarray,
    stride: nb.int32,
) -> np.ndarray:
    """Sub-sampled bitmap score with background rescoring.

    Only pixels on the ``stride x stride`` grid (rows and columns that are
    multiples of ``stride``) are evaluated. The returned array still has shape
    (H, W); cells off the grid stay 0.

    On-grid cells hold ``max(0, (score - bg) / (1 - bg + 1e-6))`` where
    ``score`` is the hit fraction, or 0 where ``bg >= 1``.

    The original author quotes about a 4x speed-up for stride=2, relying on a
    later full-resolution pass to recover precision. The loops run over grid
    indices and multiply by ``stride`` in the body because the parallel range
    is iterated with a unit step.
    """
    height, width = spread.shape
    n_feat = dx.shape[0]
    result = np.zeros((height, width), dtype=np.float32)
    grid_rows = (height + stride - 1) // stride
    grid_cols = (width + stride - 1) // stride
    for gy in nb.prange(grid_rows):
        row = gy * stride
        for k in range(n_feat):
            bit = np.uint8(1) << bins[k]
            if (bit_active & bit) != 0:
                src_row = row + dy[k]
                if src_row >= 0 and src_row < height:
                    shift = dx[k]
                    first = -shift if shift < 0 else 0
                    last = width - shift if shift > 0 else width
                    # Round the first valid column up to a multiple of stride
                    # so that only on-grid columns are visited.
                    col = ((first + stride - 1) // stride) * stride
                    while col < last:
                        if spread[src_row, col + shift] & bit:
                            result[row, col] += 1.0
                        col += stride
    if n_feat > 0:
        scale = 1.0 / n_feat
        for gy in nb.prange(grid_rows):
            row = gy * stride
            for gx in range(grid_cols):
                col = gx * stride
                background = bg[row, col]
                rescored = 0.0
                if background < 1.0:
                    ratio = (result[row, col] * scale - background) / (
                        1.0 - background + 1e-6
                    )
                    if ratio > 0.0:
                        rescored = ratio
                result[row, col] = rescored
    return result
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap_greedy(
    spread: np.ndarray,
    dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: np.uint8,
    min_score: nb.float32,
    greediness: nb.float32,
) -> np.ndarray:
    """Bitmap score with greedy early exit (no background rescoring).

    For each pixel the N features are visited in order. The scan is abandoned
    as soon as the required hit count can no longer be reached even if every
    remaining feature matched, where
    ``min_req = greediness * min_score * N``.

    * ``greediness = 0``: never exits early (same result as the base kernel).
    * ``greediness = 1``: exits as soon as ``hits + remaining < min_score * N``.

    The original author suggests 0.7-0.9 as typical, quoting a 2-4x speed-up.
    A pixel whose scan was cut short keeps the partial ``hits / N`` reached so
    far.
    """
    height, width = spread.shape
    n_feat = dx.shape[0]
    scores = np.zeros((height, width), dtype=np.float32)
    if n_feat == 0:
        return scores
    min_req = greediness * min_score * n_feat
    inv_n = nb.float32(1.0 / n_feat)
    for row in nb.prange(height):
        for col in range(width):
            hits = 0
            for k in range(n_feat):
                bit = np.uint8(1) << bins[k]
                matched = False
                if (bit_active & bit) != 0:
                    src_row = row + dy[k]
                    if src_row >= 0 and src_row < height:
                        src_col = col + dx[k]
                        if src_col >= 0 and src_col < width:
                            if spread[src_row, src_col] & bit:
                                matched = True
                if matched:
                    hits += 1
                elif hits + (n_feat - k - 1) < min_req:
                    # Even a perfect tail cannot reach the threshold.
                    break
            scores[row, col] = nb.float32(hits) * inv_n
    return scores
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap_rescored(
    spread: np.ndarray,    # (H, W) bitmap; uint8 when called via score_bitmap_rescored
    dx: np.ndarray,        # (N,) x offsets; int32 via the wrapper
    dy: np.ndarray,        # (N,) y offsets; int32 via the wrapper
    bins: np.ndarray,      # (N,) bin per feature; int8 via the wrapper
    bit_active: np.uint8,  # bitmask of the bins to take into account
    bg: np.ndarray,        # (H, W) background density; float32 via the wrapper
) -> np.ndarray:
    """Bitmap score and background rescoring without an intermediate array.

    Equivalent to computing the plain hit fraction ``score`` and then
    ``out = max(0, (score - bg) / (1 - bg + 1e-6))``, with ``out = 0`` where
    ``bg >= 1``. The rescoring is folded into the final normalisation pass;
    the original author attributes roughly a 15% saving to this.
    """
    height, width = spread.shape
    n_feat = dx.shape[0]
    result = np.zeros((height, width), dtype=np.float32)
    for row in nb.prange(height):
        for k in range(n_feat):
            bit = np.uint8(1) << bins[k]
            if (bit_active & bit) != 0:
                src_row = row + dy[k]
                if src_row >= 0 and src_row < height:
                    shift = dx[k]
                    # Column range for which col + shift stays within [0, W).
                    first = -shift if shift < 0 else 0
                    last = width - shift if shift > 0 else width
                    for col in range(first, last):
                        if spread[src_row, col + shift] & bit:
                            result[row, col] += 1.0
    if n_feat > 0:
        scale = 1.0 / n_feat
        for row in nb.prange(height):
            for col in range(width):
                background = bg[row, col]
                rescored = 0.0
                if background < 1.0:
                    ratio = (result[row, col] * scale - background) / (
                        1.0 - background + 1e-6
                    )
                    if ratio > 0.0:
                        rescored = ratio
                result[row, col] = rescored
    return result
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_top_max_per_variant(
    spread: np.ndarray,          # (H, W) bitmap; uint8 via top_max_per_variant
    dx_flat: np.ndarray,         # (sum_N,) concatenated x offsets, int32
    dy_flat: np.ndarray,         # (sum_N,) concatenated y offsets, int32
    bins_flat: np.ndarray,       # (sum_N,) concatenated bins, int8
    offsets: np.ndarray,         # (n_vars + 1,) prefix sums delimiting each variant
    bit_active: np.uint8,        # bitmask of the bins to take into account
    bg_per_variant: np.ndarray,  # (n_scales, H, W) background planes, float32
    scale_idx: np.ndarray,       # (n_vars,) plane of bg_per_variant per variant
) -> np.ndarray:
    """Best background-rescored score of every variant, in one call.

    Returns a float32 array of length ``n_vars``. Entry ``v`` is the maximum
    over all pixels of ``(score - bg) / (1 - bg + 1e-6)`` (pixels with
    ``bg >= 1`` are ignored), clamped below at 0. A variant without features
    gets -1.

    The parallel loop runs over variants rather than rows; the original note
    says this avoids the cost of one kernel call per variant and targets the
    top pruning stage, where variants far outnumber threads.
    """
    n_vars = offsets.shape[0] - 1
    height, width = spread.shape
    best_scores = np.zeros(n_vars, dtype=np.float32)
    for v in nb.prange(n_vars):
        start = offsets[v]
        stop = offsets[v + 1]
        n_feat = stop - start
        if n_feat == 0:
            best_scores[v] = -1.0
            continue
        bg_plane = scale_idx[v]
        scale = nb.float32(1.0 / n_feat)
        top = nb.float32(-1.0)
        for row in range(height):
            for col in range(width):
                acc = nb.float32(0.0)
                for j in range(start, stop):
                    bit = np.uint8(1) << bins_flat[j]
                    if (bit_active & bit) != 0:
                        src_row = row + dy_flat[j]
                        if src_row >= 0 and src_row < height:
                            src_col = col + dx_flat[j]
                            if src_col >= 0 and src_col < width:
                                if spread[src_row, src_col] & bit:
                                    acc += nb.float32(1.0)
                acc *= scale
                background = bg_per_variant[bg_plane, row, col]
                if background < 1.0:
                    ratio = (acc - background) / (1.0 - background + 1e-6)
                    if ratio > top:
                        top = ratio
        best_scores[v] = top if top > 0.0 else 0.0
    return best_scores
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap_rescored_u16(
    spread: np.ndarray,  # (H, W) bitmap; uint16 when called via score_bitmap_rescored
    dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: np.uint16,
    bg: np.ndarray,
) -> np.ndarray:
    """16-bit variant of ``_jit_score_bitmap_rescored`` for polarity bins.

    Same logic as the 8-bit kernel, but the per-feature mask is
    ``uint16(1) << b`` so that bins 0..15 can be addressed (orientation taken
    modulo 2*pi instead of modulo pi, per the original note).

    Returns ``max(0, (score - bg) / (1 - bg + 1e-6))`` per pixel, with 0
    where ``bg >= 1``.
    """
    height, width = spread.shape
    n_feat = dx.shape[0]
    result = np.zeros((height, width), dtype=np.float32)
    for row in nb.prange(height):
        for k in range(n_feat):
            bit = np.uint16(1) << bins[k]
            if (bit_active & bit) != 0:
                src_row = row + dy[k]
                if src_row >= 0 and src_row < height:
                    shift = dx[k]
                    # Column range for which col + shift stays within [0, W).
                    first = -shift if shift < 0 else 0
                    last = width - shift if shift > 0 else width
                    for col in range(first, last):
                        if spread[src_row, col + shift] & bit:
                            result[row, col] += 1.0
    if n_feat > 0:
        scale = 1.0 / n_feat
        for row in nb.prange(height):
            for col in range(width):
                background = bg[row, col]
                rescored = 0.0
                if background < 1.0:
                    ratio = (result[row, col] * scale - background) / (
                        1.0 - background + 1e-6
                    )
                    if ratio > 0.0:
                        rescored = ratio
                result[row, col] = rescored
    return result
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_popcount_density_u16(spread: np.ndarray) -> np.ndarray:
    """Count the set bits of every pixel of a 16-bit bitmap.

    Returns a float32 (H, W) array with values in [0..16]. Uses the same
    branch-free pairwise-sum popcount as the 8-bit kernel, extended with one
    extra folding step, instead of testing the 16 bits one by one.
    """
    H, W = spread.shape
    out = np.zeros((H, W), dtype=np.float32)
    for y in nb.prange(H):
        for x in range(W):
            v = spread[y, x]
            # Sum adjacent 1-bit, 2-bit, 4-bit and finally 8-bit fields.
            pairs = (v & 0x5555) + ((v >> 1) & 0x5555)
            nibbles = (pairs & 0x3333) + ((pairs >> 2) & 0x3333)
            octets = (nibbles & 0x0F0F) + ((nibbles >> 4) & 0x0F0F)
            cnt = (octets & 0x00FF) + ((octets >> 8) & 0x00FF)
            out[y, x] = float(cnt)
    return out
|
||
|
||
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_popcount_density(spread: np.ndarray) -> np.ndarray:
    """Count the set bits of every pixel: float32 (H, W) with values in [0..8]."""
    height, width = spread.shape
    counts = np.zeros((height, width), dtype=np.float32)
    for row in nb.prange(height):
        for col in range(width):
            px = spread[row, col]
            # Hand-written popcount: fold 1-bit fields into 2-bit sums, then
            # 2-bit into 4-bit, then the two nibbles into the final count.
            pairs = (px & 0x55) + ((px >> 1) & 0x55)
            nibbles = (pairs & 0x33) + ((pairs >> 2) & 0x33)
            total = (nibbles & 0x0F) + ((nibbles >> 4) & 0x0F)
            counts[row, col] = float(total)
    return counts
|
||
|
||
def _warmup():
    """Call every compiled kernel once on tiny dummy inputs.

    The dtypes used here are the ones the Python wrappers pass, so the first
    real call presumably finds each kernel already compiled (or loaded from
    numba's cache) for those argument types.
    """
    side = 32
    offs_x = np.zeros(1, dtype=np.int32)
    offs_y = np.zeros(1, dtype=np.int32)
    feat_bins = np.zeros(1, dtype=np.int8)

    # Float response-map kernel.
    planes = np.zeros((8, side, side), dtype=np.float32)
    planes_on = np.ones(8, dtype=np.bool_)
    _jit_score_by_shift(planes, offs_x, offs_y, feat_bins, planes_on)

    # 8-bit bitmap kernels.
    bitmap8 = np.zeros((side, side), dtype=np.uint8)
    _jit_score_bitmap(bitmap8, offs_x, offs_y, feat_bins, np.uint8(0xFF))
    background = np.zeros((side, side), dtype=np.float32)
    _jit_score_bitmap_rescored(
        bitmap8, offs_x, offs_y, feat_bins, np.uint8(0xFF), background,
    )
    _jit_score_bitmap_rescored_strided(
        bitmap8, offs_x, offs_y, feat_bins, np.uint8(0xFF), background,
        np.int32(2),
    )
    _jit_score_bitmap_greedy(
        bitmap8, offs_x, offs_y, feat_bins, np.uint8(0xFF),
        np.float32(0.5), np.float32(0.8),
    )

    # Batched per-variant kernel: a single variant with one feature.
    variant_bounds = np.array([0, 1], dtype=np.int32)
    variant_scale = np.zeros(1, dtype=np.int32)
    background_stack = np.zeros((1, side, side), dtype=np.float32)
    _jit_top_max_per_variant(
        bitmap8, offs_x, offs_y, feat_bins, variant_bounds, np.uint8(0xFF),
        background_stack, variant_scale,
    )
    _jit_popcount_density(bitmap8)

    # 16-bit (polarity) kernels.
    bitmap16 = np.zeros((side, side), dtype=np.uint16)
    _jit_score_bitmap_rescored_u16(
        bitmap16, offs_x, offs_y, feat_bins, np.uint16(0xFFFF), background,
    )
    _jit_popcount_density_u16(bitmap16)
|
||
|
||
else: # pragma: no cover
|
||
|
||
def _jit_score_by_shift(resp, dx, dy, bins, bin_active):
    """Stand-in used when numba is missing; always raises RuntimeError."""
    reason = "numba non disponibile"
    raise RuntimeError(reason)
|
||
|
||
def _jit_score_bitmap(spread, dx, dy, bins, bit_active):
    """Stand-in used when numba is missing; always raises RuntimeError."""
    reason = "numba non disponibile"
    raise RuntimeError(reason)
|
||
|
||
def _jit_score_bitmap_rescored(spread, dx, dy, bins, bit_active, bg):
    """Stand-in used when numba is missing; always raises RuntimeError."""
    reason = "numba non disponibile"
    raise RuntimeError(reason)
|
||
|
||
def _jit_score_bitmap_rescored_strided(spread, dx, dy, bins, bit_active, bg, stride):
    """Stand-in used when numba is missing; always raises RuntimeError."""
    reason = "numba non disponibile"
    raise RuntimeError(reason)
|
||
|
||
def _jit_score_bitmap_greedy(spread, dx, dy, bins, bit_active, min_score, greediness):
    """Stand-in used when numba is missing; always raises RuntimeError."""
    reason = "numba non disponibile"
    raise RuntimeError(reason)
|
||
|
||
def _jit_top_max_per_variant(
    spread, dx_flat, dy_flat, bins_flat, offsets, bit_active,
    bg_per_variant, scale_idx,
):
    """Stand-in used when numba is missing; always raises RuntimeError."""
    reason = "numba non disponibile"
    raise RuntimeError(reason)
|
||
|
||
def _jit_score_bitmap_rescored_u16(spread, dx, dy, bins, bit_active, bg):
    """Stand-in used when numba is missing; always raises RuntimeError."""
    reason = "numba non disponibile"
    raise RuntimeError(reason)
|
||
|
||
def _jit_popcount_density_u16(spread):
    """Stand-in used when numba is missing; always raises RuntimeError."""
    reason = "numba non disponibile"
    raise RuntimeError(reason)
|
||
|
||
def _jit_popcount_density(spread):
    """Stand-in used when numba is missing; always raises RuntimeError."""
    reason = "numba non disponibile"
    raise RuntimeError(reason)
|
||
|
||
def _warmup():
    """No-op: without numba there are no kernels to call ahead of time."""
    return None
|
||
|
||
|
||
def score_bitmap(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int,
) -> np.ndarray:
    """Bitmap score dispatch: JIT kernel when possible, NumPy otherwise.

    Args:
        spread: (H, W) bitmap in which bit ``b`` marks bin ``b`` as present.
        dx: Per-feature x offsets.
        dy: Per-feature y offsets.
        bins: Per-feature bin index.
        bit_active: Bitmask of the bins to take into account.

    Returns:
        float32 (H, W) array holding, per pixel, the fraction of features
        whose bin bit is set at the shifted position.

    The compiled kernel only exists for 8-bit bitmaps. A uint16 (16-bin)
    spread is therefore scored by the NumPy path rather than being narrowed
    to uint8, which would discard bins 8..15.
    """
    is_u16 = spread.dtype == np.uint16
    if HAS_NUMBA and len(dx) > 0 and not is_u16:
        return _jit_score_bitmap(
            np.ascontiguousarray(spread, dtype=np.uint8),
            np.ascontiguousarray(dx, dtype=np.int32),
            np.ascontiguousarray(dy, dtype=np.int32),
            np.ascontiguousarray(bins, dtype=np.int8),
            np.uint8(bit_active),
        )
    # NumPy fallback (slow): expand the bitmap into one 0/1 plane per bin.
    n_planes = 16 if is_u16 else 8
    H, W = spread.shape
    resp = np.zeros((n_planes, H, W), dtype=np.float32)
    # Mirror the kernel: features whose bin is not in bit_active are skipped.
    active = np.zeros(n_planes, dtype=np.bool_)
    for b in range(n_planes):
        resp[b] = ((spread >> b) & 1).astype(np.float32)
        active[b] = bool((int(bit_active) >> b) & 1)
    return _numpy_score_by_shift(resp, dx, dy, bins, active)
|
||
|
||
|
||
def score_bitmap_rescored(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int, bg: np.ndarray, stride: int = 1,
) -> np.ndarray:
    """Bitmap score and background rescoring fused in a single pass (JIT).

    Dispatches on the dtype of ``spread``: uint16 goes to the 16-bin polarity
    kernel, anything else to the standard 8-bin kernel, optionally the
    sub-sampled one when ``stride > 1``.

    Args:
        spread: (H, W) bitmap in which bit ``b`` marks bin ``b`` as present.
        dx: Per-feature x offsets.
        dy: Per-feature y offsets.
        bins: Per-feature bin index.
        bit_active: Bitmask of the bins to take into account.
        bg: Background density, same shape as ``spread``.
        stride: Sub-sampling step. Only honoured by the 8-bit JIT path; it is
            ignored for uint16 bitmaps and in the NumPy fallback.

    Returns:
        float32 (H, W) array of ``max(0, (score - bg) / (1 - bg + 1e-6))``,
        with 0 wherever ``bg >= 1``.
    """
    if HAS_NUMBA and len(dx) > 0:
        dx_c = np.ascontiguousarray(dx, dtype=np.int32)
        dy_c = np.ascontiguousarray(dy, dtype=np.int32)
        bins_c = np.ascontiguousarray(bins, dtype=np.int8)
        bg_c = np.ascontiguousarray(bg, dtype=np.float32)
        if spread.dtype == np.uint16:
            spread_c = np.ascontiguousarray(spread, dtype=np.uint16)
            return _jit_score_bitmap_rescored_u16(
                spread_c, dx_c, dy_c, bins_c, np.uint16(bit_active), bg_c,
            )
        spread_c = np.ascontiguousarray(spread, dtype=np.uint8)
        if stride > 1:
            return _jit_score_bitmap_rescored_strided(
                spread_c, dx_c, dy_c, bins_c, np.uint8(bit_active), bg_c,
                np.int32(stride),
            )
        return _jit_score_bitmap_rescored(
            spread_c, dx_c, dy_c, bins_c, np.uint8(bit_active), bg_c,
        )
    # Fallback: separate score and rescore steps (stride is ignored here).
    score = score_bitmap(spread, dx, dy, bins, bit_active)
    # The denominator can be zero or negative where bg >= 1; those cells are
    # overwritten just below, so silence the warnings they would raise.
    with np.errstate(divide="ignore", invalid="ignore"):
        out = (score - bg) / (1.0 - bg + 1e-6)
    # Same guard as the kernels: a saturated background always scores 0.
    out = np.where(np.asarray(bg) < 1.0, out, 0.0)
    return np.maximum(0.0, out).astype(np.float32)
|
||
|
||
|
||
def score_bitmap_greedy(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int, min_score: float, greediness: float,
) -> np.ndarray:
    """Bitmap score with greedy early exit, meant for an aggressive coarse pass.

    No background rescoring is applied: the original note recommends it for
    scenes with little clutter, or for quickly pruning many variants at the
    top level.

    Without numba, or with no features, this falls back to ``score_bitmap``,
    which has no early exit.
    """
    if not HAS_NUMBA or len(dx) == 0:
        # Fallback: base scoring without early exit.
        return score_bitmap(spread, dx, dy, bins, bit_active)
    spread_u8 = np.ascontiguousarray(spread, dtype=np.uint8)
    offs_x = np.ascontiguousarray(dx, dtype=np.int32)
    offs_y = np.ascontiguousarray(dy, dtype=np.int32)
    feat_bins = np.ascontiguousarray(bins, dtype=np.int8)
    active_mask = np.uint8(bit_active)
    return _jit_score_bitmap_greedy(
        spread_u8, offs_x, offs_y, feat_bins, active_mask,
        np.float32(min_score), np.float32(greediness),
    )
|
||
|
||
|
||
def top_max_per_variant(
    spread: np.ndarray,
    dx_list: list, dy_list: list, bin_list: list,
    bg_per_scale: dict,
    variant_scales: list,
    bit_active: int,
) -> np.ndarray:
    """Flatten the per-variant features and run the batched kernel once.

    Args:
        spread: (H, W) bitmap; converted to uint8 before the kernel call.
        dx_list: One sequence of x offsets per variant.
        dy_list: One sequence of y offsets per variant.
        bin_list: One sequence of bins per variant.
        bg_per_scale: Maps a scale key to its (H, W) background map.
        variant_scales: Scale key of each variant, looked up in
            ``bg_per_scale``.
        bit_active: Bitmask of the bins to take into account.

    Returns:
        float32 array with the best rescored value of each variant, or an
        empty float32 array when numba is unavailable or there are no variants.

    The kernel parallelises over variants; the original note prefers this to a
    Python thread pool, which pays for one JIT call per variant.
    """
    if not HAS_NUMBA or len(dx_list) == 0:
        return np.array([], dtype=np.float32)

    # Prefix sums delimit each variant inside the flat feature buffers.
    n_vars = len(dx_list)
    offsets = np.zeros(n_vars + 1, dtype=np.int32)
    offsets[1:] = np.cumsum([len(d) for d in dx_list])
    total = int(offsets[-1])
    dx_flat = np.empty(total, dtype=np.int32)
    dy_flat = np.empty(total, dtype=np.int32)
    bins_flat = np.empty(total, dtype=np.int8)
    for vi, (var_dx, var_dy, var_bins) in enumerate(
        zip(dx_list, dy_list, bin_list)
    ):
        segment = slice(int(offsets[vi]), int(offsets[vi + 1]))
        dx_flat[segment] = var_dx
        dy_flat[segment] = var_dy
        bins_flat[segment] = var_bins

    # One background plane per distinct scale; variants refer to it by index.
    scales_unique = sorted(bg_per_scale)
    scale_to_idx = {scale: idx for idx, scale in enumerate(scales_unique)}
    H, W = spread.shape
    bg_pv = np.empty((len(scales_unique), H, W), dtype=np.float32)
    for idx, scale in enumerate(scales_unique):
        bg_pv[idx] = bg_per_scale[scale]
    scale_idx = np.array(
        [scale_to_idx[scale] for scale in variant_scales], dtype=np.int32,
    )

    spread_u8 = np.ascontiguousarray(spread, dtype=np.uint8)
    return _jit_top_max_per_variant(
        spread_u8, dx_flat, dy_flat, bins_flat, offsets,
        np.uint8(bit_active), bg_pv, scale_idx,
    )
|
||
|
||
|
||
# True when the installed NumPy provides np.bitwise_count (NumPy 2.0+).
_HAS_NP_BITCOUNT = hasattr(np, "bitwise_count")
|
||
|
||
|
||
def popcount_density(spread: np.ndarray) -> np.ndarray:
    """Count the set bits of every pixel of a bin bitmap.

    A uint16 bitmap is counted over 16 bits; any other dtype is converted to
    uint8 and counted over 8 bits.

    Implementations are tried in this order:
      1) Numba parallel kernel (preferred; the original note measures 0.5 ms
         versus 1.6 ms at 1080p).
      2) ``numpy.bitwise_count`` (NumPy 2.0+).
      3) Plain NumPy bit shifts.

    Returns:
        float32 (H, W) array of per-pixel bit counts.
    """
    if spread.dtype == np.uint16:
        word, n_bits, jit_kernel = np.uint16, 16, _jit_popcount_density_u16
    else:
        word, n_bits, jit_kernel = np.uint8, 8, _jit_popcount_density
    spread_c = np.ascontiguousarray(spread, dtype=word)
    if HAS_NUMBA:
        return jit_kernel(spread_c)
    if _HAS_NP_BITCOUNT:
        return np.bitwise_count(spread_c).astype(np.float32, copy=False)
    # Shift the normalised array so all three paths see the same values.
    H, W = spread_c.shape
    out = np.zeros((H, W), dtype=np.float32)
    for b in range(n_bits):
        out += ((spread_c >> b) & 1).astype(np.float32)
    return out
|
||
|
||
|
||
def score_by_shift(
    resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bin_has_data: np.ndarray | None = None,
) -> np.ndarray:
    """Shift-and-accumulate score: JIT kernel when possible, NumPy otherwise.

    Args:
        resp: 3-D response stack indexed as ``[bin, y, x]``.
        dx: Per-feature x offsets.
        dy: Per-feature y offsets.
        bins: Per-feature index into the first axis of ``resp``.
        bin_has_data: Optional per-bin flags; ``None`` treats every bin as
            usable.

    Returns:
        float32 (H, W) score map.
    """
    if HAS_NUMBA and len(dx) != 0:
        # The kernel expects fixed dtypes and contiguous buffers.
        resp_f = np.ascontiguousarray(resp, dtype=np.float32)
        offs_x = np.ascontiguousarray(dx, dtype=np.int32)
        offs_y = np.ascontiguousarray(dy, dtype=np.int32)
        feat_bins = np.ascontiguousarray(bins, dtype=np.int8)
        if bin_has_data is not None:
            active = np.ascontiguousarray(bin_has_data, dtype=np.bool_)
        else:
            active = np.ones(resp.shape[0], dtype=np.bool_)
        return _jit_score_by_shift(resp_f, offs_x, offs_y, feat_bins, active)
    return _numpy_score_by_shift(resp, dx, dy, bins, bin_has_data)
|