Files
Shape_Model_2D/pm2d/_jit_kernels.py
T

548 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Numba JIT kernels for the matching hot path.
Automatic fallback to NumPy when numba is unavailable or a kernel fails.
"""
from __future__ import annotations
import numpy as np
try:
import numba as nb
HAS_NUMBA = True
except ImportError: # pragma: no cover
HAS_NUMBA = False
def _numpy_score_by_shift(
resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
bin_has_data: np.ndarray | None = None,
) -> np.ndarray:
"""Fallback numpy (stessa logica di LineShapeMatcher._score_by_shift)."""
_, H, W = resp.shape
acc = np.zeros((H, W), dtype=np.float32)
n = len(dx)
for i in range(n):
b = int(bins[i])
if bin_has_data is not None and not bin_has_data[b]:
continue
ddx = int(dx[i]); ddy = int(dy[i])
y0s = max(0, -ddy); y1s = min(H, H - ddy)
x0s = max(0, -ddx); x1s = min(W, W - ddx)
if y0s >= y1s or x0s >= x1s:
continue
y0r = y0s + ddy; y1r = y1s + ddy
x0r = x0s + ddx; x1r = x1s + ddx
acc[y0s:y1s, x0s:x1s] += resp[b, y0r:y1r, x0r:x1r]
if n > 0:
acc /= n
return acc
if HAS_NUMBA:
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_by_shift(
    resp: np.ndarray,        # float32 (N_BINS, H, W)
    dx: np.ndarray,          # int32 (N,)
    dy: np.ndarray,          # int32 (N,)
    bins: np.ndarray,        # int8 (N,)
    bin_active: np.ndarray,  # bool_ (N_BINS,)
) -> np.ndarray:
    """Shift-and-add scoring over a float response stack (Numba kernel).

    ``out[y, x] = sum_i resp[bins[i], y + dy[i], x + dx[i]] / N`` where
    features whose bin is not flagged in ``bin_active`` and samples falling
    outside the image are skipped. Returns a float32 (H, W) array.

    Bounds checking is disabled: bin indices must be valid for ``resp`` and
    ``bin_active``.
    """
    height = resp.shape[1]
    width = resp.shape[2]
    n_feat = dx.shape[0]
    score = np.zeros((height, width), dtype=np.float32)

    # Each parallel iteration owns one output row, so writes never collide.
    for row in nb.prange(height):
        for k in range(n_feat):
            bin_idx = bins[k]
            if bin_active[bin_idx]:
                src_row = row + dy[k]
                if 0 <= src_row < height:
                    shift = dx[k]
                    # Columns for which col + shift stays inside [0, width).
                    col_start = 0
                    col_stop = width
                    if shift < 0:
                        col_start = -shift
                    if shift > 0:
                        col_stop = width - shift
                    for col in range(col_start, col_stop):
                        score[row, col] += resp[bin_idx, src_row, col + shift]

    if n_feat > 0:
        scale = 1.0 / n_feat
        for row in nb.prange(height):
            for col in range(width):
                score[row, col] *= scale
    return score
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap(
    spread: np.ndarray,    # uint8 (H, W), bit b set = bin b present
    dx: np.ndarray,        # int32 (N,)
    dy: np.ndarray,        # int32 (N,)
    bins: np.ndarray,      # int8 (N,) bin of each feature
    bit_active: np.uint8,  # bitmask of bins to consider
) -> np.ndarray:
    """Bitmap scoring kernel.

    ``score[y, x]`` is the fraction of the N features whose bin bit is set
    in ``spread[y + dy_i, x + dx_i]``. Features whose bit is cleared in
    ``bit_active`` and samples outside the image never count as hits, but
    the divisor is always N. Returns a float32 (H, W) array.
    """
    height, width = spread.shape
    n_feat = dx.shape[0]
    score = np.zeros((height, width), dtype=np.float32)

    # One output row per parallel iteration.
    for row in nb.prange(height):
        for k in range(n_feat):
            mask = np.uint8(1) << bins[k]
            if (bit_active & mask) != 0:
                src_row = row + dy[k]
                if 0 <= src_row < height:
                    shift = dx[k]
                    # Columns for which col + shift stays inside [0, width).
                    col_start = 0
                    col_stop = width
                    if shift < 0:
                        col_start = -shift
                    if shift > 0:
                        col_stop = width - shift
                    for col in range(col_start, col_stop):
                        if spread[src_row, col + shift] & mask:
                            score[row, col] += 1.0

    if n_feat > 0:
        scale = 1.0 / n_feat
        for row in nb.prange(height):
            for col in range(width):
                score[row, col] *= scale
    return score
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap_rescored_strided(
    spread: np.ndarray,
    dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: np.uint8,
    bg: np.ndarray,
    stride: nb.int32,
) -> np.ndarray:
    """Sub-sampled variant of the fused score + background rescore kernel.

    Only pixels whose row and column are multiples of ``stride`` are
    evaluated; the returned float32 map keeps the full (H, W) shape and
    every other cell stays 0. On the evaluated grid the value is
    ``max(0, (score - bg) / (1 - bg + 1e-6))`` where ``bg < 1`` and 0
    elsewhere.

    The parallel loop runs over grid indices and multiplies by ``stride``
    in the body, because ``prange`` needs a unit step.
    """
    height, width = spread.shape
    n_feat = dx.shape[0]
    score = np.zeros((height, width), dtype=np.float32)
    grid_rows = (height + stride - 1) // stride
    grid_cols = (width + stride - 1) // stride

    for gy in nb.prange(grid_rows):
        row = gy * stride
        for k in range(n_feat):
            mask = np.uint8(1) << bins[k]
            if (bit_active & mask) != 0:
                src_row = row + dy[k]
                if 0 <= src_row < height:
                    shift = dx[k]
                    # Columns for which col + shift stays inside [0, width).
                    col_start = 0
                    col_stop = width
                    if shift < 0:
                        col_start = -shift
                    if shift > 0:
                        col_stop = width - shift
                    # Snap the first column up to the next grid position.
                    misalign = col_start % stride
                    if misalign != 0:
                        col_start += stride - misalign
                    col = col_start
                    while col < col_stop:
                        if spread[src_row, col + shift] & mask:
                            score[row, col] += 1.0
                        col += stride

    if n_feat > 0:
        scale = 1.0 / n_feat
        for gy in nb.prange(grid_rows):
            row = gy * stride
            for gx in range(grid_cols):
                col = gx * stride
                bg_val = bg[row, col]
                rescored = 0.0
                if bg_val < 1.0:
                    ratio = (score[row, col] * scale - bg_val) / (1.0 - bg_val + 1e-6)
                    if ratio > 0.0:
                        rescored = ratio
                score[row, col] = rescored
    return score
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap_greedy(
    spread: np.ndarray,
    dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: np.uint8,
    min_score: nb.float32,
    greediness: nb.float32,
) -> np.ndarray:
    """Bitmap scoring with greedy early exit (no background rescore).

    For each pixel the N features are visited in order. After every miss
    the loop stops as soon as ``hits + remaining`` falls below
    ``greediness * min_score * N``, i.e. the required count can no longer
    be reached even if all remaining features hit.

    ``greediness = 0`` disables the early exit (the threshold is 0);
    ``greediness = 1`` exits as soon as ``hits + remaining < min_score * N``.
    The value stored for a pixel is ``hits / N`` at the moment the feature
    loop ended, so pruned pixels hold a partial score.
    """
    height, width = spread.shape
    n_feat = dx.shape[0]
    score = np.zeros((height, width), dtype=np.float32)
    if n_feat == 0:
        return score

    min_req = greediness * min_score * n_feat
    inv_n = nb.float32(1.0 / n_feat)

    for row in nb.prange(height):
        for col in range(width):
            hits = 0
            for k in range(n_feat):
                mask = np.uint8(1) << bins[k]
                matched = False
                if (bit_active & mask) != 0:
                    src_row = row + dy[k]
                    if 0 <= src_row < height:
                        src_col = col + dx[k]
                        if 0 <= src_col < width:
                            if spread[src_row, src_col] & mask:
                                matched = True
                if matched:
                    hits += 1
                elif hits + (n_feat - k - 1) < min_req:
                    # Inactive bin, out-of-image sample or cleared bit: all
                    # are misses, and the target is now out of reach.
                    break
            score[row, col] = nb.float32(hits) * inv_n
    return score
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_score_bitmap_rescored(
    spread: np.ndarray,    # uint8 (H, W)
    dx: np.ndarray,        # int32 (N,)
    dy: np.ndarray,        # int32 (N,)
    bins: np.ndarray,      # int8 (N,)
    bit_active: np.uint8,
    bg: np.ndarray,        # float32 (H, W) normalised background density
) -> np.ndarray:
    """Bitmap score and background rescore fused into one kernel.

    Computes the same hit fraction as the plain bitmap kernel and then, in
    the normalisation pass, replaces it with
    ``max(0, (score - bg) / (1 - bg + 1e-6))`` where ``bg < 1`` and with 0
    where ``bg >= 1``. Doing both in place avoids an intermediate score
    array. When N is 0 the map is returned as all zeros without rescoring.
    """
    height, width = spread.shape
    n_feat = dx.shape[0]
    score = np.zeros((height, width), dtype=np.float32)

    # Hit counting: one output row per parallel iteration.
    for row in nb.prange(height):
        for k in range(n_feat):
            mask = np.uint8(1) << bins[k]
            if (bit_active & mask) != 0:
                src_row = row + dy[k]
                if 0 <= src_row < height:
                    shift = dx[k]
                    col_start = 0
                    col_stop = width
                    if shift < 0:
                        col_start = -shift
                    if shift > 0:
                        col_stop = width - shift
                    for col in range(col_start, col_stop):
                        if spread[src_row, col + shift] & mask:
                            score[row, col] += 1.0

    # Normalise and rescore against the background in a single sweep.
    if n_feat > 0:
        scale = 1.0 / n_feat
        for row in nb.prange(height):
            for col in range(width):
                bg_val = bg[row, col]
                rescored = 0.0
                if bg_val < 1.0:
                    ratio = (score[row, col] * scale - bg_val) / (1.0 - bg_val + 1e-6)
                    if ratio > 0.0:
                        rescored = ratio
                score[row, col] = rescored
    return score
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_top_max_per_variant(
    spread: np.ndarray,          # uint8 (H, W)
    dx_flat: np.ndarray,         # int32 (sum_N,)
    dy_flat: np.ndarray,         # int32 (sum_N,)
    bins_flat: np.ndarray,       # int8 (sum_N,)
    offsets: np.ndarray,         # int32 (n_vars + 1,) prefix sum
    bit_active: np.uint8,
    bg_per_variant: np.ndarray,  # float32 (n_scales, H, W)
    scale_idx: np.ndarray,       # int32 (n_vars,) index into bg_per_variant
) -> np.ndarray:
    """Batch kernel: best background-rescored score of every variant.

    Variant ``vi`` owns the features ``offsets[vi]:offsets[vi + 1]`` of the
    flat arrays. For each variant the hit fraction is evaluated at every
    pixel, rescored as ``(s - bg) / (1 - bg + 1e-6)`` where ``bg < 1``, and
    the maximum over the image is kept (clamped to be >= 0). A variant with
    no features gets -1.

    The parallel loop runs over variants, so a single call replaces one
    kernel launch per variant. Returns a float32 array of length n_vars.
    """
    n_vars = offsets.shape[0] - 1
    height, width = spread.shape
    out = np.zeros(n_vars, dtype=np.float32)

    for vi in nb.prange(n_vars):
        start = offsets[vi]
        stop = offsets[vi + 1]
        n_feat = stop - start
        if n_feat == 0:
            out[vi] = -1.0
        else:
            bg_plane = scale_idx[vi]
            inv_n = nb.float32(1.0 / n_feat)
            best = nb.float32(-1.0)
            for row in range(height):
                for col in range(width):
                    frac = nb.float32(0.0)
                    for j in range(start, stop):
                        mask = np.uint8(1) << bins_flat[j]
                        if (bit_active & mask) != 0:
                            src_row = row + dy_flat[j]
                            if 0 <= src_row < height:
                                src_col = col + dx_flat[j]
                                if 0 <= src_col < width:
                                    if spread[src_row, src_col] & mask:
                                        frac += nb.float32(1.0)
                    frac *= inv_n
                    bg_val = bg_per_variant[bg_plane, row, col]
                    # Pixels with bg >= 1 are ignored for the maximum.
                    if bg_val < 1.0:
                        ratio = (frac - bg_val) / (1.0 - bg_val + 1e-6)
                        if ratio > best:
                            best = ratio
            out[vi] = best if best > 0.0 else 0.0
    return out
@nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
def _jit_popcount_density(spread: np.ndarray) -> np.ndarray:
    """Count the set bits of every pixel; returns float32 (H, W) in [0, 8]."""
    height, width = spread.shape
    density = np.zeros((height, width), dtype=np.float32)
    for row in nb.prange(height):
        for col in range(width):
            byte = spread[row, col]
            # Manual 8-bit popcount: sum bits pairwise, then per nibble,
            # then across the two nibbles.
            pairs = (byte & 0x55) + ((byte >> 1) & 0x55)
            nibbles = (pairs & 0x33) + ((pairs >> 2) & 0x33)
            total = (nibbles & 0x0F) + ((nibbles >> 4) & 0x0F)
            density[row, col] = float(total)
    return density
def _warmup():
    """Call every JIT kernel once on tiny dummy inputs.

    Running each kernel with the argument dtypes used by the dispatchers
    triggers compilation (or a cache load) ahead of the first real call.
    """
    one_feat_dx = np.zeros(1, dtype=np.int32)
    one_feat_dy = np.zeros(1, dtype=np.int32)
    one_feat_bin = np.zeros(1, dtype=np.int8)
    all_bits = np.uint8(0xFF)

    # Float response-stack kernel.
    dummy_resp = np.zeros((8, 32, 32), dtype=np.float32)
    all_bins_active = np.ones(8, dtype=np.bool_)
    _jit_score_by_shift(
        dummy_resp, one_feat_dx, one_feat_dy, one_feat_bin, all_bins_active,
    )

    # Bitmap kernels.
    dummy_spread = np.zeros((32, 32), dtype=np.uint8)
    _jit_score_bitmap(
        dummy_spread, one_feat_dx, one_feat_dy, one_feat_bin, np.uint8(0xFF),
    )
    dummy_bg = np.zeros((32, 32), dtype=np.float32)
    _jit_score_bitmap_rescored(
        dummy_spread, one_feat_dx, one_feat_dy, one_feat_bin, np.uint8(0xFF),
        dummy_bg,
    )
    _jit_score_bitmap_rescored_strided(
        dummy_spread, one_feat_dx, one_feat_dy, one_feat_bin, np.uint8(0xFF),
        dummy_bg, np.int32(2),
    )
    _jit_score_bitmap_greedy(
        dummy_spread, one_feat_dx, one_feat_dy, one_feat_bin, np.uint8(0xFF),
        np.float32(0.5), np.float32(0.8),
    )

    # Batch kernel: a single variant holding a single feature.
    prefix = np.array([0, 1], dtype=np.int32)
    variant_to_bg = np.zeros(1, dtype=np.int32)
    bg_stack = np.zeros((1, 32, 32), dtype=np.float32)
    _jit_top_max_per_variant(
        dummy_spread, one_feat_dx, one_feat_dy, one_feat_bin, prefix,
        all_bits, bg_stack, variant_to_bg,
    )

    _jit_popcount_density(dummy_spread)
else: # pragma: no cover
def _jit_score_by_shift(resp, dx, dy, bins, bin_active):
raise RuntimeError("numba non disponibile")
def _jit_score_bitmap(spread, dx, dy, bins, bit_active):
raise RuntimeError("numba non disponibile")
def _jit_score_bitmap_rescored(spread, dx, dy, bins, bit_active, bg):
raise RuntimeError("numba non disponibile")
def _jit_score_bitmap_rescored_strided(spread, dx, dy, bins, bit_active, bg, stride):
raise RuntimeError("numba non disponibile")
def _jit_score_bitmap_greedy(spread, dx, dy, bins, bit_active, min_score, greediness):
raise RuntimeError("numba non disponibile")
def _jit_top_max_per_variant(
spread, dx_flat, dy_flat, bins_flat, offsets, bit_active,
bg_per_variant, scale_idx,
):
raise RuntimeError("numba non disponibile")
def _jit_popcount_density(spread):
raise RuntimeError("numba non disponibile")
def _warmup():
pass
def score_bitmap(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int,
) -> np.ndarray:
    """Bitmap score dispatcher: JIT kernel when possible, NumPy otherwise.

    Args:
        spread: 2-D bitmap; bit ``b`` of a pixel marks bin ``b`` as present.
        dx: per-feature column offsets.
        dy: per-feature row offsets.
        bins: per-feature bin index (0..7).
        bit_active: bitmask of the bins to consider; features whose bit is
            cleared never score.

    Returns:
        float32 (H, W) map holding the fraction of features that hit.

    The NumPy path is taken when numba is missing or there are no features.
    It honours ``bit_active`` and casts ``spread`` to uint8, so both paths
    apply the same rules.
    """
    if HAS_NUMBA and len(dx) > 0:
        return _jit_score_bitmap(
            np.ascontiguousarray(spread, dtype=np.uint8),
            np.ascontiguousarray(dx, dtype=np.int32),
            np.ascontiguousarray(dy, dtype=np.int32),
            np.ascontiguousarray(bins, dtype=np.int8),
            np.uint8(bit_active),
        )

    # Slow NumPy fallback: expand the bitmap into one 0/1 plane per bin.
    spread_u8 = np.asarray(spread, dtype=np.uint8)
    bit_index = np.arange(8, dtype=np.uint8).reshape(8, 1, 1)
    resp = ((spread_u8[np.newaxis, :, :] >> bit_index) & 1).astype(np.float32)

    # Mirror the kernel's `(bit_active & mask) == 0` test: bins whose bit is
    # cleared in the mask are skipped.
    active_bits = int(bit_active) & 0xFF
    bin_has_data = np.array(
        [(active_bits >> b) & 1 for b in range(8)], dtype=np.bool_,
    )
    return _numpy_score_by_shift(resp, dx, dy, bins, bin_has_data)
def score_bitmap_rescored(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int, bg: np.ndarray, stride: int = 1,
) -> np.ndarray:
    """Bitmap score with background rescore.

    The result is ``max(0, (score - bg) / (1 - bg + 1e-6))`` where
    ``bg < 1`` and 0 where ``bg >= 1``.

    Args:
        spread: 2-D bitmap; bit ``b`` of a pixel marks bin ``b`` as present.
        dx, dy, bins: per-feature column offset, row offset and bin index.
        bit_active: bitmask of the bins to consider.
        bg: (H, W) background density subtracted from the score.
        stride: when > 1 and the JIT path is used, only pixels on the
            ``stride`` x ``stride`` grid are evaluated and every other cell
            stays 0. The NumPy fallback ignores ``stride`` and evaluates
            every pixel.

    Returns:
        float32 (H, W) rescored map.
    """
    if HAS_NUMBA and len(dx) > 0:
        spread_c = np.ascontiguousarray(spread, dtype=np.uint8)
        dx_c = np.ascontiguousarray(dx, dtype=np.int32)
        dy_c = np.ascontiguousarray(dy, dtype=np.int32)
        bins_c = np.ascontiguousarray(bins, dtype=np.int8)
        bg_c = np.ascontiguousarray(bg, dtype=np.float32)
        if stride > 1:
            return _jit_score_bitmap_rescored_strided(
                spread_c, dx_c, dy_c, bins_c, np.uint8(bit_active), bg_c,
                np.int32(stride),
            )
        return _jit_score_bitmap_rescored(
            spread_c, dx_c, dy_c, bins_c, np.uint8(bit_active), bg_c,
        )

    # Fallback: separate score and rescore steps (stride is ignored here).
    score = score_bitmap(spread, dx, dy, bins, bit_active)
    bg_f = np.broadcast_to(np.asarray(bg, dtype=np.float32), score.shape)

    # Apply the formula only where bg < 1, exactly as the JIT kernels do.
    # Dividing unconditionally would give positive scores for bg > 1
    # (negative / negative) and propagate NaN.
    out = np.zeros(score.shape, dtype=np.float32)
    valid = bg_f < 1.0
    bg_valid = bg_f[valid].astype(np.float64)
    out[valid] = (score[valid] - bg_valid) / (1.0 - bg_valid + 1e-6)
    np.maximum(out, 0.0, out=out)
    return out
def score_bitmap_greedy(
    spread: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bit_active: int, min_score: float, greediness: float,
) -> np.ndarray:
    """Bitmap score with greedy early exit, meant for an aggressive coarse pass.

    No background rescore is applied. When numba is missing or there are
    no features, the plain ``score_bitmap`` dispatcher is used instead and
    ``min_score`` / ``greediness`` have no effect.
    """
    if not HAS_NUMBA or len(dx) == 0:
        # Fallback: base scoring without early exit.
        return score_bitmap(spread, dx, dy, bins, bit_active)

    kernel_args = (
        np.ascontiguousarray(spread, dtype=np.uint8),
        np.ascontiguousarray(dx, dtype=np.int32),
        np.ascontiguousarray(dy, dtype=np.int32),
        np.ascontiguousarray(bins, dtype=np.int8),
        np.uint8(bit_active),
        np.float32(min_score),
        np.float32(greediness),
    )
    return _jit_score_bitmap_greedy(*kernel_args)
def top_max_per_variant(
    spread: np.ndarray,
    dx_list: list, dy_list: list, bin_list: list,
    bg_per_scale: dict,
    variant_scales: list,
    bit_active: int,
) -> np.ndarray:
    """Best background-rescored score of every variant, as one float32 array.

    Args:
        spread: 2-D bitmap; bit ``b`` of a pixel marks bin ``b`` as present.
        dx_list, dy_list, bin_list: per-variant feature arrays (column
            offsets, row offsets, bin indices).
        bg_per_scale: maps a scale key to its (H, W) background map.
        variant_scales: scale key of each variant (looked up in
            ``bg_per_scale``).
        bit_active: bitmask of the bins to consider.

    Returns:
        float32 array of length ``len(dx_list)``: for each variant the
        maximum over pixels with ``bg < 1`` of
        ``(score - bg) / (1 - bg + 1e-6)``, clamped to be >= 0, or -1 for a
        variant without features. An empty array when there are no variants.

    Raises:
        ValueError: if ``dy_list``, ``bin_list`` or ``variant_scales`` has
            fewer entries than ``dx_list``.

    With numba the flat buffers are built once and a single batch kernel is
    called. Without numba the same values are computed per variant through
    ``score_bitmap``, so callers always get one score per variant.
    """
    n_vars = len(dx_list)
    if n_vars == 0:
        return np.array([], dtype=np.float32)
    if min(len(dy_list), len(bin_list), len(variant_scales)) < n_vars:
        # Short lists would leave uninitialised buffer entries or make the
        # bounds-check-free kernel read past the end of scale_idx.
        raise ValueError(
            "dy_list, bin_list and variant_scales need one entry per variant"
        )

    if not HAS_NUMBA:
        # NumPy fallback with the same semantics as the batch kernel.
        best_scores = np.zeros(n_vars, dtype=np.float32)
        for vi in range(n_vars):
            if len(dx_list[vi]) == 0:
                best_scores[vi] = -1.0
                continue
            score = score_bitmap(
                spread, dx_list[vi], dy_list[vi], bin_list[vi], bit_active,
            )
            bg = np.broadcast_to(
                np.asarray(bg_per_scale[variant_scales[vi]], dtype=np.float32),
                score.shape,
            )
            valid = bg < 1.0
            if valid.any():
                bg_valid = bg[valid].astype(np.float64)
                ratio = (score[valid] - bg_valid) / (1.0 - bg_valid + 1e-6)
                best_scores[vi] = max(0.0, float(ratio.max()))
        return best_scores

    # Flatten the per-variant feature lists; offsets is the prefix sum.
    sizes = [len(d) for d in dx_list]
    offsets = np.zeros(n_vars + 1, dtype=np.int32)
    offsets[1:] = np.cumsum(sizes)
    total = int(offsets[-1])
    dx_flat = np.empty(total, dtype=np.int32)
    dy_flat = np.empty(total, dtype=np.int32)
    bins_flat = np.empty(total, dtype=np.int8)
    for vi, (dx, dy, bn) in enumerate(zip(dx_list, dy_list, bin_list)):
        i0 = int(offsets[vi])
        i1 = int(offsets[vi + 1])
        dx_flat[i0:i1] = dx
        dy_flat[i0:i1] = dy
        bins_flat[i0:i1] = bn

    # One background plane per scale; each variant refers to it by index.
    scales_unique = sorted(bg_per_scale.keys())
    scale_to_idx = {s: i for i, s in enumerate(scales_unique)}
    H, W = spread.shape
    bg_pv = np.empty((len(scales_unique), H, W), dtype=np.float32)
    for s, idx in scale_to_idx.items():
        bg_pv[idx] = bg_per_scale[s]
    scale_idx = np.array(
        [scale_to_idx[s] for s in variant_scales], dtype=np.int32,
    )
    return _jit_top_max_per_variant(
        np.ascontiguousarray(spread, dtype=np.uint8),
        dx_flat, dy_flat, bins_flat, offsets, np.uint8(bit_active),
        bg_pv, scale_idx,
    )
def popcount_density(spread: np.ndarray) -> np.ndarray:
    """Count the set bits (0..8) of every pixel of a 2-D bitmap.

    Uses the JIT kernel when numba is available, otherwise sums the eight
    bit planes with NumPy. Returns a float32 (H, W) array.
    """
    if not HAS_NUMBA:
        # Fallback: add up the eight bit planes one at a time.
        rows, cols = spread.shape
        density = np.zeros((rows, cols), dtype=np.float32)
        for bit in range(8):
            plane = (spread >> bit) & 1
            density += plane.astype(np.float32)
        return density
    return _jit_popcount_density(np.ascontiguousarray(spread, dtype=np.uint8))
def score_by_shift(
    resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray,
    bin_has_data: np.ndarray | None = None,
) -> np.ndarray:
    """Shift-and-add score dispatcher: JIT kernel when possible, NumPy otherwise.

    The NumPy implementation is used when numba is missing or there are no
    features. Otherwise the inputs are converted to the contiguous dtypes
    the kernel expects; a missing ``bin_has_data`` means every bin is
    active.
    """
    if HAS_NUMBA and len(dx) != 0:
        # Normalise dtypes and memory layout for Numba.
        resp_c = np.ascontiguousarray(resp, dtype=np.float32)
        dx_c = np.ascontiguousarray(dx, dtype=np.int32)
        dy_c = np.ascontiguousarray(dy, dtype=np.int32)
        bins_c = np.ascontiguousarray(bins, dtype=np.int8)
        if bin_has_data is not None:
            active = np.ascontiguousarray(bin_has_data, dtype=np.bool_)
        else:
            active = np.ones(resp.shape[0], dtype=np.bool_)
        return _jit_score_by_shift(resp_c, dx_c, dy_c, bins_c, active)
    return _numpy_score_by_shift(resp, dx, dy, bins, bin_has_data)