27a0ef1a45
Nuovo kernel JIT _jit_score_bitmap_rescored_strided: valuta solo pixel su griglia stride x stride al top della piramide. NMS + fase full-res recuperano precisione. Speed-up ~stride^2 sulla fase coarse, specie su scene grandi (1920x1080). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
350 lines
13 KiB
Python
350 lines
13 KiB
Python
"""Numba JIT kernels per hot-path del matching.
|
||
|
||
Fallback automatico a numpy se numba non disponibile o kernel fallisce.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import numpy as np
|
||
|
||
try:
|
||
import numba as nb
|
||
HAS_NUMBA = True
|
||
except ImportError: # pragma: no cover
|
||
HAS_NUMBA = False
|
||
|
||
|
||
def _numpy_score_by_shift(
|
||
resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
|
||
bin_has_data: np.ndarray | None = None,
|
||
) -> np.ndarray:
|
||
"""Fallback numpy (stessa logica di LineShapeMatcher._score_by_shift)."""
|
||
_, H, W = resp.shape
|
||
acc = np.zeros((H, W), dtype=np.float32)
|
||
n = len(dx)
|
||
for i in range(n):
|
||
b = int(bins[i])
|
||
if bin_has_data is not None and not bin_has_data[b]:
|
||
continue
|
||
ddx = int(dx[i]); ddy = int(dy[i])
|
||
y0s = max(0, -ddy); y1s = min(H, H - ddy)
|
||
x0s = max(0, -ddx); x1s = min(W, W - ddx)
|
||
if y0s >= y1s or x0s >= x1s:
|
||
continue
|
||
y0r = y0s + ddy; y1r = y1s + ddy
|
||
x0r = x0s + ddx; x1r = x1s + ddx
|
||
acc[y0s:y1s, x0s:x1s] += resp[b, y0r:y1r, x0r:x1r]
|
||
if n > 0:
|
||
acc /= n
|
||
return acc
|
||
|
||
|
||
if HAS_NUMBA:
    # Numba kernels. All are compiled with parallel=True (the outer row loop
    # is an nb.prange, and each iteration writes only its own row of the
    # output), fastmath=True and boundscheck=False: array indices are NOT
    # checked, so arguments must have the dtypes/shapes noted below. The
    # public dispatchers at the bottom of this module coerce dtypes with
    # np.ascontiguousarray before calling these.

    @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
    def _jit_score_by_shift(
        resp: np.ndarray,        # float32 (N_BINS, H, W)
        dx: np.ndarray,          # int32 (N,)
        dy: np.ndarray,          # int32 (N,)
        bins: np.ndarray,        # int8 (N,)
        bin_active: np.ndarray,  # bool_ (N_BINS,)
    ) -> np.ndarray:
        """Shift-and-sum score over per-bin response maps (Numba version).

        acc[y, x] = (sum over features i with bin_active[bins[i]] of
        resp[bins[i], y + dy[i], x + dx[i]]) / N, with N = len(dx).
        Offsets that land outside the image contribute 0; features of an
        inactive bin are skipped but still counted in N.
        """
        _, H, W = resp.shape
        N = dx.shape[0]
        acc = np.zeros((H, W), dtype=np.float32)
        for y in nb.prange(H):
            for i in range(N):
                b = bins[i]
                if not bin_active[b]:
                    continue
                ddy = dy[i]
                yy = y + ddy
                # Shifted row outside the image: nothing to add for this row.
                if yy < 0 or yy >= H:
                    continue
                ddx = dx[i]
                # Clip x so that the source column x + ddx stays in [0, W).
                x_lo = 0 if ddx >= 0 else -ddx
                x_hi = W if ddx <= 0 else W - ddx
                for x in range(x_lo, x_hi):
                    acc[y, x] += resp[b, yy, x + ddx]
        # Normalise by the total feature count (skipped features included).
        if N > 0:
            inv = 1.0 / N
            for y in nb.prange(H):
                for x in range(W):
                    acc[y, x] *= inv
        return acc

    @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
    def _jit_score_bitmap(
        spread: np.ndarray,   # uint8 (H, W), bit b set = bin b present
        dx: np.ndarray,       # int32 (N,)
        dy: np.ndarray,       # int32 (N,)
        bins: np.ndarray,     # int8 (N,) bin of each feature
        bit_active: np.uint8,  # bitmask of bins to consider
    ) -> np.ndarray:
        """score[y,x] = (Σ_i [bit bins[i] set in spread[y+dy_i, x+dx_i]]) / N.

        Same shift-and-sum as _jit_score_by_shift, but reading a 1-byte
        orientation bitmap per pixel instead of an (8, H, W) float32 response
        stack: 32x less memory, hence more cache-friendly. Features whose bin
        bit is not set in bit_active are skipped but still counted in N;
        out-of-image offsets contribute 0.
        """
        H, W = spread.shape
        N = dx.shape[0]
        acc = np.zeros((H, W), dtype=np.float32)
        for y in nb.prange(H):
            for i in range(N):
                b = bins[i]
                # Single-bit mask for this feature's orientation bin.
                mask = np.uint8(1) << b
                if (bit_active & mask) == 0:
                    continue
                ddy = dy[i]
                yy = y + ddy
                if yy < 0 or yy >= H:
                    continue
                ddx = dx[i]
                # Clip x so that the source column x + ddx stays in [0, W).
                x_lo = 0 if ddx >= 0 else -ddx
                x_hi = W if ddx <= 0 else W - ddx
                for x in range(x_lo, x_hi):
                    if spread[yy, x + ddx] & mask:
                        acc[y, x] += 1.0
        if N > 0:
            inv = 1.0 / N
            for y in nb.prange(H):
                for x in range(W):
                    acc[y, x] *= inv
        return acc

    @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
    def _jit_score_bitmap_rescored_strided(
        spread: np.ndarray,  # uint8 (H, W)
        dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,  # int32, int32, int8 (N,)
        bit_active: np.uint8,
        bg: np.ndarray,  # float32; read at every grid (y, x), so must cover (H, W)
        stride: nb.int32,
    ) -> np.ndarray:
        """Sub-sampled variant of _jit_score_bitmap_rescored.

        Only pixels on the stride x stride grid (y and x multiples of
        stride) are scored and rescored. The returned map has the same
        (H, W) shape; cells that are not evaluated stay 0.

        This evaluates roughly 1/stride² of the pixels (about 4x less work
        for stride=2). It is intended for the coarse pass; the original
        design note relies on NMS and the full-resolution stage to recover
        precision.

        The outer prange runs over grid row indices yi with a unit step
        (original note: numba's prange requires a constant step), and the
        body uses y = yi * stride.

        Expects stride >= 1. The public dispatcher only calls it with
        stride > 1.
        """
        H, W = spread.shape
        N = dx.shape[0]
        acc = np.zeros((H, W), dtype=np.float32)
        # Number of grid rows/columns (ceil division).
        ny = (H + stride - 1) // stride
        nx = (W + stride - 1) // stride
        for yi in nb.prange(ny):
            y = yi * stride
            for i in range(N):
                b = bins[i]
                mask = np.uint8(1) << b
                if (bit_active & mask) == 0:
                    continue
                ddy = dy[i]
                yy = y + ddy
                if yy < 0 or yy >= H:
                    continue
                ddx = dx[i]
                x_lo = 0 if ddx >= 0 else -ddx
                x_hi = W if ddx <= 0 else W - ddx
                # Round x_lo up to the next multiple of stride so that only
                # grid columns are visited.
                rem = x_lo % stride
                if rem != 0:
                    x_lo += stride - rem
                x = x_lo
                while x < x_hi:
                    if spread[yy, x + ddx] & mask:
                        acc[y, x] += 1.0
                    x += stride
        # Fused normalisation + background rescore, grid cells only:
        # r = max(0, (v - bg) / (1 - bg + 1e-6)), and 0 where bg >= 1.
        if N > 0:
            inv = 1.0 / N
            for yi in nb.prange(ny):
                y = yi * stride
                for xi in range(nx):
                    x = xi * stride
                    v = acc[y, x] * inv
                    bgv = bg[y, x]
                    if bgv < 1.0:
                        r = (v - bgv) / (1.0 - bgv + 1e-6)
                        acc[y, x] = r if r > 0.0 else 0.0
                    else:
                        acc[y, x] = 0.0
        return acc

    @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
    def _jit_score_bitmap_rescored(
        spread: np.ndarray,   # uint8 (H, W)
        dx: np.ndarray,       # int32 (N,)
        dy: np.ndarray,       # int32 (N,)
        bins: np.ndarray,     # int8 (N,)
        bit_active: np.uint8,
        bg: np.ndarray,       # float32 (H, W) normalised background density
    ) -> np.ndarray:
        """Score + rescore in a single pass, avoiding an intermediate array.

        Equivalent to:
            score = _jit_score_bitmap(...)
            out = max(0, (score - bg) / (1 - bg + 1e-6))
        except that pixels with bg >= 1 are set to 0. The rescore is fused
        into the final normalisation loop, which is more cache-friendly; the
        original note quotes a saving of ~15% on the total find time (author's
        measurement, not verified here).
        """
        H, W = spread.shape
        N = dx.shape[0]
        acc = np.zeros((H, W), dtype=np.float32)
        for y in nb.prange(H):
            for i in range(N):
                b = bins[i]
                mask = np.uint8(1) << b
                if (bit_active & mask) == 0:
                    continue
                ddy = dy[i]
                yy = y + ddy
                if yy < 0 or yy >= H:
                    continue
                ddx = dx[i]
                x_lo = 0 if ddx >= 0 else -ddx
                x_hi = W if ddx <= 0 else W - ddx
                for x in range(x_lo, x_hi):
                    if spread[yy, x + ddx] & mask:
                        acc[y, x] += 1.0
        if N > 0:
            inv = 1.0 / N
            for y in nb.prange(H):
                for x in range(W):
                    v = acc[y, x] * inv
                    bgv = bg[y, x]
                    if bgv < 1.0:
                        r = (v - bgv) / (1.0 - bgv + 1e-6)
                        acc[y, x] = r if r > 0.0 else 0.0
                    else:
                        acc[y, x] = 0.0
        return acc

    @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
    def _jit_popcount_density(spread: np.ndarray) -> np.ndarray:
        """Count set bits per pixel: returns an (H, W) float32 map in [0..8]."""
        H, W = spread.shape
        out = np.zeros((H, W), dtype=np.float32)
        for y in nb.prange(H):
            for x in range(W):
                v = spread[y, x]
                # Manual 8-bit popcount: sum adjacent 1-bit, then 2-bit, then
                # 4-bit fields in place.
                v = (v & 0x55) + ((v >> 1) & 0x55)
                v = (v & 0x33) + ((v >> 2) & 0x33)
                v = (v & 0x0F) + ((v >> 4) & 0x0F)
                out[y, x] = float(v)
        return out

    def _warmup() -> None:
        """Call every kernel once on tiny inputs to trigger compilation.

        The argument dtypes (float32 / int32 / int8 / bool_ / uint8 /
        np.int32 stride) match what the public dispatchers pass, so later
        calls reuse the same compiled specialisations (and, with cache=True,
        numba's on-disk cache).
        """
        resp = np.zeros((8, 32, 32), dtype=np.float32)
        dx = np.zeros(1, dtype=np.int32)
        dy = np.zeros(1, dtype=np.int32)
        b = np.zeros(1, dtype=np.int8)
        ba = np.ones(8, dtype=np.bool_)
        _jit_score_by_shift(resp, dx, dy, b, ba)
        spread = np.zeros((32, 32), dtype=np.uint8)
        _jit_score_bitmap(spread, dx, dy, b, np.uint8(0xFF))
        bg = np.zeros((32, 32), dtype=np.float32)
        _jit_score_bitmap_rescored(spread, dx, dy, b, np.uint8(0xFF), bg)
        _jit_score_bitmap_rescored_strided(
            spread, dx, dy, b, np.uint8(0xFF), bg, np.int32(2),
        )
        _jit_popcount_density(spread)

else:  # pragma: no cover
    # numba is not importable: stubs with the same signatures that raise
    # RuntimeError if called. The public dispatchers check HAS_NUMBA and use
    # their numpy fallbacks instead; _warmup becomes a no-op.

    def _jit_score_by_shift(resp, dx, dy, bins, bin_active):
        raise RuntimeError("numba non disponibile")

    def _jit_score_bitmap(spread, dx, dy, bins, bit_active):
        raise RuntimeError("numba non disponibile")

    def _jit_score_bitmap_rescored(spread, dx, dy, bins, bit_active, bg):
        raise RuntimeError("numba non disponibile")

    def _jit_score_bitmap_rescored_strided(spread, dx, dy, bins, bit_active, bg, stride):
        raise RuntimeError("numba non disponibile")

    def _jit_popcount_density(spread):
        raise RuntimeError("numba non disponibile")

    def _warmup() -> None:
        pass
|
||
|
||
|
||
def score_bitmap(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int,
) -> np.ndarray:
    """Score a template against an orientation bitmap (JIT, numpy fallback).

    ``score[y, x] = (sum_i [bit bins[i] set in spread[y + dy[i], x + dx[i]]]) / N``
    with ``N = len(dx)``. Features whose bin bit is not set in ``bit_active``
    contribute 0 but still count in ``N``. Offsets falling outside the image
    contribute 0.

    Both the Numba path and the numpy fallback apply the same ``bit_active``
    filtering, so they return the same map.

    Args:
        spread: (H, W) bitmap; bit b set means orientation bin b is present.
            Converted to contiguous uint8 on both paths.
        dx, dy: per-feature integer offsets, shape (N,).
        bins: per-feature orientation bin in [0, 8), shape (N,).
        bit_active: bitmask of bins to consider (bit b -> bin b).

    Returns:
        float32 (H, W) score map; all zeros when there are no features.
    """
    spread_u8 = np.ascontiguousarray(spread, dtype=np.uint8)
    if HAS_NUMBA and len(dx) > 0:
        return _jit_score_bitmap(
            spread_u8,
            np.ascontiguousarray(dx, dtype=np.int32),
            np.ascontiguousarray(dy, dtype=np.int32),
            np.ascontiguousarray(bins, dtype=np.int8),
            np.uint8(bit_active),
        )
    H, W = spread_u8.shape
    if len(dx) == 0:
        # Nothing to accumulate: skip the bit-plane expansion entirely.
        return np.zeros((H, W), dtype=np.float32)
    # Numpy fallback (slow): expand the bitmap into 8 binary planes. uint8
    # planes take 8*H*W bytes (4x less than float32); the float32
    # accumulator in _numpy_score_by_shift upcasts on addition.
    planes = np.stack([(spread_u8 >> b) & 1 for b in range(8)])
    # Mirror the kernel's `bit_active & (1 << b)` test so both paths agree.
    active = int(bit_active)
    bin_active = np.array([(active >> b) & 1 for b in range(8)], dtype=np.bool_)
    return _numpy_score_by_shift(planes, dx, dy, bins, bin_active)
|
||
|
||
|
||
def score_bitmap_rescored(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int, bg: np.ndarray, stride: int = 1,
) -> np.ndarray:
    """Bitmap score and background rescore fused in a single pass (JIT).

    Computes ``score`` as :func:`score_bitmap` does, then
    ``out = max(0, (score - bg) / (1 - bg + 1e-6))`` where ``bg < 1`` and
    ``out = 0`` where ``bg >= 1``.

    When ``stride > 1``, only pixels on the stride x stride grid (rows and
    columns that are multiples of ``stride``) are evaluated. Every other cell
    is 0 in the returned map. This is meant for the coarse pass at the top of
    the pyramid; the full-resolution refinement then recovers precision.

    The numpy fallback computes the full map and then zeroes the off-grid
    cells, so both paths return the same layout.

    Args:
        spread: (H, W) orientation bitmap, converted to contiguous uint8.
        dx, dy: per-feature integer offsets, shape (N,).
        bins: per-feature orientation bin, shape (N,).
        bit_active: bitmask of bins to consider.
        bg: background density. It must be broadcastable to ``spread.shape``
            (normally an (H, W) array).
        stride: grid step for the coarse evaluation. Values <= 1 evaluate
            every pixel.

    Returns:
        float32 (H, W) rescored map.

    Raises:
        ValueError: if ``bg`` cannot be broadcast to ``spread.shape``.
    """
    spread_c = np.ascontiguousarray(spread, dtype=np.uint8)
    bg_c = np.ascontiguousarray(bg, dtype=np.float32)
    if bg_c.shape != spread_c.shape:
        # The kernels read bg[y, x] with boundscheck=False, so a smaller bg
        # would be read out of bounds. Broadcast instead (raises ValueError
        # when incompatible) and materialise a writable C-contiguous copy.
        bg_c = np.array(
            np.broadcast_to(bg_c, spread_c.shape), dtype=np.float32, order="C",
        )
    if HAS_NUMBA and len(dx) > 0:
        dx_c = np.ascontiguousarray(dx, dtype=np.int32)
        dy_c = np.ascontiguousarray(dy, dtype=np.int32)
        bins_c = np.ascontiguousarray(bins, dtype=np.int8)
        if stride > 1:
            return _jit_score_bitmap_rescored_strided(
                spread_c, dx_c, dy_c, bins_c, np.uint8(bit_active), bg_c,
                np.int32(stride),
            )
        return _jit_score_bitmap_rescored(
            spread_c, dx_c, dy_c, bins_c, np.uint8(bit_active), bg_c,
        )
    # Numpy fallback: compute the full-resolution score, then rescore with
    # the kernels' bg >= 1 -> 0 rule. Without that rule the plain formula
    # flips sign for bg > 1 + 1e-6, where the denominator becomes negative.
    score = score_bitmap(spread_c, dx, dy, bins, bit_active)
    with np.errstate(divide="ignore", invalid="ignore"):
        rescored = (score - bg_c) / (1.0 - bg_c + 1e-6)
    out = np.where(bg_c < 1.0, np.maximum(rescored, 0.0), 0.0).astype(np.float32)
    step = int(stride)
    if step > 1:
        # Keep only the evaluation grid, matching the strided kernel.
        grid_only = np.zeros_like(out)
        grid_only[::step, ::step] = out[::step, ::step]
        out = grid_only
    return out
|
||
|
||
|
||
def popcount_density(spread: np.ndarray) -> np.ndarray:
    """Count the set bits of every pixel of an 8-bit orientation bitmap.

    The input is converted to contiguous uint8 before either the Numba
    kernel or the numpy fallback runs, so both paths accept the same inputs
    and see the same 8 bits per pixel.

    Args:
        spread: (H, W) bitmap.

    Returns:
        float32 (H, W) array with values in [0, 8].
    """
    spread_u8 = np.ascontiguousarray(spread, dtype=np.uint8)
    if HAS_NUMBA:
        return _jit_popcount_density(spread_u8)
    # Fallback: add up the 8 bit planes.
    H, W = spread_u8.shape
    out = np.zeros((H, W), dtype=np.float32)
    for b in range(8):
        out += ((spread_u8 >> b) & 1).astype(np.float32)
    return out
|
||
|
||
|
||
def score_by_shift(
    resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bin_has_data: np.ndarray | None = None,
) -> np.ndarray:
    """Shift-and-sum score over per-bin response maps.

    Uses the Numba kernel when numba is importable and there is at least one
    feature; otherwise defers to the numpy reference implementation.

    Args:
        resp: (N_BINS, H, W) response stack.
        dx, dy: per-feature integer offsets, shape (N,).
        bins: per-feature bin index, shape (N,).
        bin_has_data: optional per-bin flags; features of a falsy bin are
            skipped (but still counted in the normaliser). None means all
            bins are active.

    Returns:
        float32 (H, W) score map.
    """
    use_jit = HAS_NUMBA and len(dx) != 0
    if not use_jit:
        return _numpy_score_by_shift(resp, dx, dy, bins, bin_has_data)
    # The kernel is specialised on exact dtypes/layouts: coerce every input.
    as_c = np.ascontiguousarray
    active = (
        np.ones(resp.shape[0], dtype=np.bool_)
        if bin_has_data is None
        else as_c(bin_has_data, dtype=np.bool_)
    )
    return _jit_score_by_shift(
        as_c(resp, dtype=np.float32),
        as_c(dx, dtype=np.int32),
        as_c(dy, dtype=np.int32),
        as_c(bins, dtype=np.int8),
        active,
    )
|