"""Numba JIT kernels for the matching hot path.

Falls back automatically to a pure NumPy implementation when numba is not
installed or when the JIT kernel fails at runtime.
"""
from __future__ import annotations

import warnings

import numpy as np

try:
    import numba as nb
    HAS_NUMBA = True
except ImportError:  # pragma: no cover
    HAS_NUMBA = False

# Set once the JIT kernel has failed on input that the NumPy fallback handles
# fine, so later calls skip a kernel that is known to be unusable.
_jit_failed = False


def _numpy_score_by_shift(
    resp: np.ndarray,
    dx: np.ndarray,
    dy: np.ndarray,
    bins: np.ndarray,
    bin_has_data: np.ndarray | None = None,
) -> np.ndarray:
    """Accumulate shifted response planes with plain NumPy slicing.

    For every feature ``i`` the plane ``resp[bins[i]]`` is shifted by
    ``(dx[i], dy[i])`` and added to an ``(H, W)`` float32 accumulator; parts
    shifted outside the image are dropped. The sum is divided by ``len(dx)``
    (skipped features still count in the divisor).

    Args:
        resp: 3-D array indexed as ``[bin, y, x]``.
        dx: Per-feature horizontal shift.
        dy: Per-feature vertical shift.
        bins: Per-feature index into the first axis of ``resp``.
        bin_has_data: Optional per-bin mask; features whose bin is falsy
            are skipped.

    Returns:
        The ``(H, W)`` float32 score map.
    """
    _, H, W = resp.shape
    acc = np.zeros((H, W), dtype=np.float32)
    n = len(dx)
    for i in range(n):
        b = int(bins[i])
        if bin_has_data is not None and not bin_has_data[b]:
            continue
        ddx = int(dx[i])
        ddy = int(dy[i])
        # Destination window in acc that still maps inside resp after the shift.
        y0s = max(0, -ddy)
        y1s = min(H, H - ddy)
        x0s = max(0, -ddx)
        x1s = min(W, W - ddx)
        if y0s >= y1s or x0s >= x1s:
            continue  # shift moves the whole plane outside the image
        y0r = y0s + ddy
        y1r = y1s + ddy
        x0r = x0s + ddx
        x1r = x1s + ddx
        acc[y0s:y1s, x0s:x1s] += resp[b, y0r:y1r, x0r:x1r]
    if n > 0:
        acc /= n
    return acc


if HAS_NUMBA:

    @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False)
    def _jit_score_by_shift(
        resp: np.ndarray,        # float32 (N_BINS, H, W)
        dx: np.ndarray,          # int32 (N,)
        dy: np.ndarray,          # int32 (N,)
        bins: np.ndarray,        # int8 or int32 (N,)
        bin_active: np.ndarray,  # bool_ (N_BINS,)
    ) -> np.ndarray:
        """JIT version of the shifted-response accumulation.

        Compiled with ``boundscheck=False``: the caller must guarantee that
        ``dy`` and ``bins`` have at least ``len(dx)`` entries and that every
        bin index is valid for both ``resp`` and ``bin_active``.
        """
        n_bins, H, W = resp.shape
        N = dx.shape[0]
        acc = np.zeros((H, W), dtype=np.float32)

        # Parallelize over rows: no race, each y only writes acc[y, :].
        for y in nb.prange(H):
            for i in range(N):
                b = bins[i]
                if not bin_active[b]:
                    continue
                ddy = dy[i]
                yy = y + ddy
                if yy < 0 or yy >= H:
                    continue
                ddx = dx[i]
                # Clamp the column range so that x + ddx stays inside [0, W).
                x_lo = 0 if ddx >= 0 else -ddx
                x_hi = W if ddx <= 0 else W - ddx
                for x in range(x_lo, x_hi):
                    acc[y, x] += resp[b, yy, x + ddx]

        if N > 0:
            inv = 1.0 / N
            for y in nb.prange(H):
                for x in range(W):
                    acc[y, x] *= inv
        return acc

    def _warmup():
        """Precompile the kernel by calling it once on dummy data."""
        resp = np.zeros((8, 32, 32), dtype=np.float32)
        dx = np.zeros(1, dtype=np.int32)
        dy = np.zeros(1, dtype=np.int32)
        # int32 bins: the same specialization score_by_shift dispatches to.
        b = np.zeros(1, dtype=np.int32)
        ba = np.ones(8, dtype=np.bool_)
        _jit_score_by_shift(resp, dx, dy, b, ba)

else:  # pragma: no cover

    def _jit_score_by_shift(resp, dx, dy, bins, bin_active):
        """Placeholder used when numba is missing; always raises."""
        raise RuntimeError("numba non disponibile")

    def _warmup():
        """No-op: there is nothing to precompile without numba."""
        pass


def score_by_shift(
    resp: np.ndarray,
    dx: np.ndarray,
    dy: np.ndarray,
    bins: np.ndarray,
    bin_has_data: np.ndarray | None = None,
) -> np.ndarray:
    """Dispatch to the JIT kernel when possible, else to the NumPy fallback.

    Args:
        resp: 3-D response array indexed as ``[bin, y, x]``.
        dx: Per-feature horizontal shift.
        dy: Per-feature vertical shift.
        bins: Per-feature index into the first axis of ``resp``.
        bin_has_data: Optional per-bin mask; ``None`` treats every bin as
            active.

    Returns:
        The ``(H, W)`` float32 score map.

    Raises:
        IndexError: If ``dy``/``bins`` are shorter than ``dx`` or a bin index
            is out of range. The kernel runs without bounds checking, so this
            is verified here instead of letting it read out of bounds.
    """
    global _jit_failed

    if not HAS_NUMBA or _jit_failed or len(dx) == 0:
        return _numpy_score_by_shift(resp, dx, dy, bins, bin_has_data)

    # Normalize dtypes/layout for numba. bins go to int32 rather than int8 so
    # that indices above 127 cannot silently wrap around to another bin.
    resp_f = np.ascontiguousarray(resp, dtype=np.float32)
    dx_i = np.ascontiguousarray(dx, dtype=np.int32)
    dy_i = np.ascontiguousarray(dy, dtype=np.int32)
    bins_i = np.ascontiguousarray(bins, dtype=np.int32)
    if bin_has_data is None:
        bin_active = np.ones(resp_f.shape[0], dtype=np.bool_)
    else:
        bin_active = np.ascontiguousarray(bin_has_data, dtype=np.bool_)

    # The kernel is compiled with boundscheck=False: reject anything that
    # would make it index outside its arrays.
    n = dx_i.shape[0]
    if dy_i.shape[0] < n or bins_i.shape[0] < n:
        raise IndexError("dy and bins must have at least len(dx) entries")
    limit = min(resp_f.shape[0], bin_active.shape[0])
    used = bins_i[:n]
    if used.min() < -limit or used.max() >= limit:
        raise IndexError(f"bin index out of range for {limit} bins")

    try:
        return _jit_score_by_shift(resp_f, dx_i, dy_i, bins_i, bin_active)
    except Exception as exc:  # pragma: no cover - depends on numba runtime
        # Documented contract: a failing kernel falls back to NumPy. If the
        # input itself is bad, the fallback raises and the JIT stays enabled.
        result = _numpy_score_by_shift(resp, dx, dy, bins, bin_has_data)
        _jit_failed = True
        warnings.warn(
            f"numba kernel failed, using numpy fallback from now on: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return result