Files
Shape_Model_2D/pm2d/_jit_kernels.py
T
Adriano ba54b42fdc perf: spread bitmap uint8 + pre-NMS prima refine (3.5x globale, 49x worst case)
Due ottimizzazioni chiave:

1. Spread bitmap uint8 invece di response map (N_BINS, H, W) float32
   - 32x meno memoria, cache-friendly
   - Nuovi kernel Numba: _jit_score_bitmap, _jit_popcount_density
   - Formato: spread[y,x] bit b = bin b attivo nel raggio di spread
   - _refine_angle usa slicing su bitmap con mask & bit

2. Pre-NMS prima di refine_angle/verify_ncc
   - Problema: loop 'for raw in candidati' applicava refine+verify A OGNI
     candidato prima del check NMS → 2000+ refine chiamati per ~25 match
   - Fix: pre-NMS su (cx, cy) subpixel, limita a max_matches*3 candidati,
     poi refine + verify solo su quelli
   - Esempio worst case: lama_full_fast 55.9s → 1.13s (49x)

Benchmark suite 16 scenari (4 immagini x full/part x fast/preciso):
  prima: totale find 94.6s
  dopo:  totale find 27.3s (3.5x globale)
  casi peggiori <5s (prima erano >50s)

ROI parziali (solo metà oggetto) funzionano in tutti i casi.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-24 02:11:33 +02:00

203 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Numba JIT kernels per hot-path del matching.
Fallback automatico a numpy se numba non disponibile o kernel fallisce.
"""
from __future__ import annotations
import numpy as np
# numba is optional. HAS_NUMBA records whether it could be imported; the
# kernel definitions and the dispatch functions in this module branch on it to
# choose between the JIT kernels and the NumPy fallbacks.
# NOTE(review): the NumPy fallback is selected only when the import fails
# (ImportError). Exceptions raised by a compiled kernel at call time are NOT
# caught by the dispatchers in this module; they propagate to the caller.
try:
    import numba as nb
    HAS_NUMBA = True
except ImportError:  # pragma: no cover
    HAS_NUMBA = False
def _numpy_score_by_shift(
resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
bin_has_data: np.ndarray | None = None,
) -> np.ndarray:
"""Fallback numpy (stessa logica di LineShapeMatcher._score_by_shift)."""
_, H, W = resp.shape
acc = np.zeros((H, W), dtype=np.float32)
n = len(dx)
for i in range(n):
b = int(bins[i])
if bin_has_data is not None and not bin_has_data[b]:
continue
ddx = int(dx[i]); ddy = int(dy[i])
y0s = max(0, -ddy); y1s = min(H, H - ddy)
x0s = max(0, -ddx); x1s = min(W, W - ddx)
if y0s >= y1s or x0s >= x1s:
continue
y0r = y0s + ddy; y1r = y1s + ddy
x0r = x0s + ddx; x1r = x1s + ddx
acc[y0s:y1s, x0s:x1s] += resp[b, y0r:y1r, x0r:x1r]
if n > 0:
acc /= n
return acc
if HAS_NUMBA:
    # Compiled kernels. Each uses parallel=True with nb.prange over image rows,
    # fastmath=True and boundscheck=False: callers must pass bin indices and
    # arrays of the dtypes noted on each parameter (the public dispatchers in
    # this module convert inputs before calling).

    @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
    def _jit_score_by_shift(
        resp: np.ndarray,  # float32 (N_BINS, H, W)
        dx: np.ndarray,  # int32 (N,)
        dy: np.ndarray,  # int32 (N,)
        bins: np.ndarray,  # int8 (N,)
        bin_active: np.ndarray,  # bool_ (N_BINS,)
    ) -> np.ndarray:
        """Shift-and-add score map over a float32 response stack.

        acc[y, x] = sum_i resp[bins[i], y + dy[i], x + dx[i]] / N, summed over
        features whose bin is True in ``bin_active``; out-of-image samples are
        skipped. N counts every feature, including skipped ones.
        Returns an (H, W) float32 array.
        """
        _, H, W = resp.shape
        N = dx.shape[0]
        acc = np.zeros((H, W), dtype=np.float32)
        # Each y iteration writes only row acc[y, :], so rows can run in
        # parallel without conflicting writes.
        for y in nb.prange(H):
            for i in range(N):
                b = bins[i]
                if not bin_active[b]:
                    continue
                ddy = dy[i]
                yy = y + ddy
                # Source row outside the image: this feature adds nothing to row y.
                if yy < 0 or yy >= H:
                    continue
                ddx = dx[i]
                # Clip x so that the sampled column x + ddx stays in [0, W).
                x_lo = 0 if ddx >= 0 else -ddx
                x_hi = W if ddx <= 0 else W - ddx
                for x in range(x_lo, x_hi):
                    acc[y, x] += resp[b, yy, x + ddx]
        if N > 0:
            # Normalise by the total feature count.
            inv = 1.0 / N
            for y in nb.prange(H):
                for x in range(W):
                    acc[y, x] *= inv
        return acc

    @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
    def _jit_score_bitmap(
        spread: np.ndarray,  # uint8 (H, W), bit b set = bin b active
        dx: np.ndarray,  # int32 (N,)
        dy: np.ndarray,  # int32 (N,)
        bins: np.ndarray,  # int8 (N,) bin of each feature
        bit_active: np.uint8,  # bitmask of bins active in the scene
    ) -> np.ndarray:
        """Score map from a spread bitmap.

        score[y, x] = (sum_i [bit bins[i] set in spread[y + dy_i, x + dx_i]]) / N.
        Features whose bin bit is clear in ``bit_active`` are skipped but still
        count in N; out-of-image samples contribute nothing.

        With 8 bins a uint8 bitmap uses 1 byte per pixel, versus 32 bytes for an
        (8, H, W) float32 response stack (32x less memory, more cache-friendly).
        """
        H, W = spread.shape
        N = dx.shape[0]
        acc = np.zeros((H, W), dtype=np.float32)
        # Each y iteration writes only row acc[y, :] (row-parallel, no conflicts).
        for y in nb.prange(H):
            for i in range(N):
                b = bins[i]
                mask = np.uint8(1) << b
                # Bin not enabled in bit_active: this feature cannot score.
                if (bit_active & mask) == 0:
                    continue
                ddy = dy[i]
                yy = y + ddy
                if yy < 0 or yy >= H:
                    continue
                ddx = dx[i]
                # Clip x so that the sampled column x + ddx stays in [0, W).
                x_lo = 0 if ddx >= 0 else -ddx
                x_hi = W if ddx <= 0 else W - ddx
                for x in range(x_lo, x_hi):
                    if spread[yy, x + ddx] & mask:
                        acc[y, x] += 1.0
        if N > 0:
            # Convert hit counts to a fraction of all N features.
            inv = 1.0 / N
            for y in nb.prange(H):
                for x in range(W):
                    acc[y, x] *= inv
        return acc

    @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
    def _jit_popcount_density(spread: np.ndarray) -> np.ndarray:
        """Count the set bits of each pixel: returns (H, W) float32 in [0..8]."""
        H, W = spread.shape
        out = np.zeros((H, W), dtype=np.float32)
        for y in nb.prange(H):
            for x in range(W):
                v = spread[y, x]
                # Manual popcount for an 8-bit value: sum adjacent 1-bit, then
                # 2-bit, then 4-bit fields in parallel (masks 0x55/0x33/0x0F).
                v = (v & 0x55) + ((v >> 1) & 0x55)
                v = (v & 0x33) + ((v >> 2) & 0x33)
                v = (v & 0x0F) + ((v >> 4) & 0x0F)
                out[y, x] = float(v)
        return out

    def _warmup():
        # Call each kernel once on tiny dummy inputs, using the same dtypes the
        # dispatchers pass (int32 offsets, int8 bins, bool_ bin flags, uint8
        # bitmap and mask), so the numba compilation happens here rather than
        # on the first real call. NOTE(review): caller of _warmup not visible
        # in this module — presumably invoked at startup; confirm.
        resp = np.zeros((8, 32, 32), dtype=np.float32)
        dx = np.zeros(1, dtype=np.int32)
        dy = np.zeros(1, dtype=np.int32)
        b = np.zeros(1, dtype=np.int8)
        ba = np.ones(8, dtype=np.bool_)
        _jit_score_by_shift(resp, dx, dy, b, ba)
        spread = np.zeros((32, 32), dtype=np.uint8)
        _jit_score_bitmap(spread, dx, dy, b, np.uint8(0xFF))
        _jit_popcount_density(spread)
else:  # pragma: no cover
    # numba not importable: stubs keep the kernel names defined but raise if
    # called (message means "numba not available"). The dispatchers in this
    # module check HAS_NUMBA and use the NumPy fallbacks instead.
    def _jit_score_by_shift(resp, dx, dy, bins, bin_active):
        raise RuntimeError("numba non disponibile")

    def _jit_score_bitmap(spread, dx, dy, bins, bit_active):
        raise RuntimeError("numba non disponibile")

    def _jit_popcount_density(spread):
        raise RuntimeError("numba non disponibile")

    def _warmup():
        # Nothing to compile without numba.
        pass
def score_bitmap(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int,
) -> np.ndarray:
    """Score every anchor position against a spread bitmap.

    ``score[y, x]`` is the fraction of the ``N = len(dx)`` features whose bin
    bit ``bins[i]`` is set in ``spread[y + dy[i], x + dx[i]]``. Features whose
    bin bit is clear in ``bit_active`` contribute 0 but still count in N, and
    out-of-image samples contribute nothing.

    Dispatches to the Numba kernel when numba is available and there is at
    least one feature; otherwise uses a slower NumPy fallback that applies the
    same ``bit_active`` mask, so both paths agree.

    Args:
        spread: (H, W) bitmap; bit ``b`` of a pixel flags bin ``b``. Converted
            to contiguous uint8 on both paths.
        dx, dy: per-feature integer offsets.
        bins: per-feature bin index, expected in ``0..7``.
        bit_active: bitmask of enabled bins (only bits 0..7 are meaningful).

    Returns:
        (H, W) float32 score map with values in [0, 1].
    """
    # Normalise once so the JIT and NumPy paths see the same bitmap.
    spread = np.ascontiguousarray(spread, dtype=np.uint8)
    if HAS_NUMBA and len(dx) > 0:
        return _jit_score_bitmap(
            spread,
            np.ascontiguousarray(dx, dtype=np.int32),
            np.ascontiguousarray(dy, dtype=np.int32),
            np.ascontiguousarray(bins, dtype=np.int8),
            np.uint8(bit_active),
        )
    H, W = spread.shape
    if len(dx) == 0:
        # No features: the score is identically zero; skip building the planes.
        return np.zeros((H, W), dtype=np.float32)
    # NumPy fallback (slow): expand the bitmap into an (8, H, W) stack of 0/1
    # planes and reuse the shift-and-add scorer.
    resp = np.empty((8, H, W), dtype=np.float32)
    for b in range(8):
        resp[b] = (spread >> b) & 1
    # Honour bit_active exactly like the JIT kernel: features on a bin whose
    # bit is clear are skipped (but still counted in the normalisation).
    active_bits = int(bit_active)
    bin_active = np.array([(active_bits >> b) & 1 for b in range(8)], dtype=np.bool_)
    return _numpy_score_by_shift(resp, dx, dy, bins, bin_active)
def popcount_density(spread: np.ndarray) -> np.ndarray:
    """Per-pixel number of set bits (0..8) of a spread bitmap.

    With numba available the compiled kernel is used, after converting the
    input to contiguous uint8. Otherwise the eight bit-planes are summed with
    NumPy.

    Returns:
        (H, W) float32 array.
    """
    if not HAS_NUMBA:
        height, width = spread.shape
        density = np.zeros((height, width), dtype=np.float32)
        for bit in range(8):
            plane = (spread >> bit) & 1
            density += plane.astype(np.float32)
        return density
    contiguous = np.ascontiguousarray(spread, dtype=np.uint8)
    return _jit_popcount_density(contiguous)
def score_by_shift(
    resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bin_has_data: np.ndarray | None = None,
) -> np.ndarray:
    """Shift-and-add score map from a per-bin response stack.

    Uses the Numba kernel when numba imported and at least one feature is
    given. Otherwise, including the empty-feature case, it uses the NumPy
    fallback.

    Args:
        resp: (N_BINS, H, W) response stack.
        dx, dy: per-feature integer offsets.
        bins: per-feature bin index into ``resp``.
        bin_has_data: optional per-bin flags; bins flagged False are skipped.
            ``None`` enables every bin.

    Returns:
        (H, W) float32 score map.
    """
    if HAS_NUMBA and len(dx) != 0:
        # Convert to the contiguous dtypes listed on the kernel's parameters
        # (float32 responses, int32 offsets, int8 bins, bool_ bin flags).
        resp_c = np.ascontiguousarray(resp, dtype=np.float32)
        dx_c = np.ascontiguousarray(dx, dtype=np.int32)
        dy_c = np.ascontiguousarray(dy, dtype=np.int32)
        bins_c = np.ascontiguousarray(bins, dtype=np.int8)
        active = (
            np.ones(resp.shape[0], dtype=np.bool_)
            if bin_has_data is None
            else np.ascontiguousarray(bin_has_data, dtype=np.bool_)
        )
        return _jit_score_by_shift(resp_c, dx_c, dy_c, bins_c, active)
    return _numpy_score_by_shift(resp, dx, dy, bins, bin_has_data)