diff --git a/pm2d/__init__.py b/pm2d/__init__.py index 321a79a..452eb82 100644 --- a/pm2d/__init__.py +++ b/pm2d/__init__.py @@ -1,7 +1,14 @@ from pm2d.matcher import EdgeShapeMatcher, Match, Template from pm2d.line_matcher import LineShapeMatcher, Match as LineMatch +from pm2d._jit_kernels import HAS_NUMBA, _warmup as _warmup_jit + +# Precompila kernel JIT in background al primo import (evita lag al 1° match) +try: + _warmup_jit() +except Exception: + pass __all__ = [ "EdgeShapeMatcher", "Match", "Template", - "LineShapeMatcher", "LineMatch", + "LineShapeMatcher", "LineMatch", "HAS_NUMBA", ] diff --git a/pm2d/_jit_kernels.py b/pm2d/_jit_kernels.py new file mode 100644 index 0000000..6d0087e --- /dev/null +++ b/pm2d/_jit_kernels.py @@ -0,0 +1,110 @@ +"""Numba JIT kernels per hot-path del matching. + +Fallback automatico a numpy se numba non disponibile o kernel fallisce. +""" +from __future__ import annotations + +import numpy as np + +try: + import numba as nb + HAS_NUMBA = True +except ImportError: # pragma: no cover + HAS_NUMBA = False + + +def _numpy_score_by_shift( + resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray, + bin_has_data: np.ndarray | None = None, +) -> np.ndarray: + """Fallback numpy (stessa logica di LineShapeMatcher._score_by_shift).""" + _, H, W = resp.shape + acc = np.zeros((H, W), dtype=np.float32) + n = len(dx) + for i in range(n): + b = int(bins[i]) + if bin_has_data is not None and not bin_has_data[b]: + continue + ddx = int(dx[i]); ddy = int(dy[i]) + y0s = max(0, -ddy); y1s = min(H, H - ddy) + x0s = max(0, -ddx); x1s = min(W, W - ddx) + if y0s >= y1s or x0s >= x1s: + continue + y0r = y0s + ddy; y1r = y1s + ddy + x0r = x0s + ddx; x1r = x1s + ddx + acc[y0s:y1s, x0s:x1s] += resp[b, y0r:y1r, x0r:x1r] + if n > 0: + acc /= n + return acc + + +if HAS_NUMBA: + + @nb.njit(cache=True, parallel=True, fastmath=True, boundscheck=False) + def _jit_score_by_shift( + resp: np.ndarray, # float32 (N_BINS, H, W) + dx: np.ndarray, # int32 (N,) + dy: np.ndarray, # int32 (N,) + bins: np.ndarray, # int8 or int32 (N,) + bin_active: np.ndarray, # bool_ (N_BINS,) + ) -> np.ndarray: + n_bins, H, W = resp.shape + N = dx.shape[0] + acc = np.zeros((H, W), dtype=np.float32) + # Parallelizza per riga: niente race (ogni y scrive solo acc[y, :]) + for y in nb.prange(H): + for i in range(N): + b = bins[i] + if not bin_active[b]: + continue + ddy = dy[i] + yy = y + ddy + if yy < 0 or yy >= H: + continue + ddx = dx[i] + x_lo = 0 if ddx >= 0 else -ddx + x_hi = W if ddx <= 0 else W - ddx + for x in range(x_lo, x_hi): + acc[y, x] += resp[b, yy, x + ddx] + if N > 0: + inv = 1.0 / N + for y in nb.prange(H): + for x in range(W): + acc[y, x] *= inv + return acc + + # Warmup: precompila con dummy data + def _warmup(): + resp = np.zeros((8, 32, 32), dtype=np.float32) + dx = np.zeros(1, dtype=np.int32) + dy = np.zeros(1, dtype=np.int32) + b = np.zeros(1, dtype=np.int8) + ba = np.ones(8, dtype=np.bool_) + _jit_score_by_shift(resp, dx, dy, b, ba) + +else: # pragma: no cover + + def _jit_score_by_shift(resp, dx, dy, bins, bin_active): + raise RuntimeError("numba non disponibile") + + def _warmup(): + pass + + +def score_by_shift( + resp: np.ndarray, dx: np.ndarray, dy: np.ndarray, bins: np.ndarray, + bin_has_data: np.ndarray | None = None, +) -> np.ndarray: + """Dispatch: JIT se possibile, fallback numpy.""" + if not HAS_NUMBA or len(dx) == 0: + return _numpy_score_by_shift(resp, dx, dy, bins, bin_has_data) + # Normalizza tipi per Numba + resp_f = np.ascontiguousarray(resp, dtype=np.float32) + dx_i = np.ascontiguousarray(dx, dtype=np.int32) + dy_i = np.ascontiguousarray(dy, dtype=np.int32) + bins_i = np.ascontiguousarray(bins, dtype=np.int8) + if bin_has_data is None: + bin_active = np.ones(resp.shape[0], dtype=np.bool_) + else: + bin_active = np.ascontiguousarray(bin_has_data, dtype=np.bool_) + return _jit_score_by_shift(resp_f, dx_i, dy_i, bins_i, bin_active) diff --git a/pm2d/line_matcher.py b/pm2d/line_matcher.py index 0c62c62..0aa72ad 100644 --- a/pm2d/line_matcher.py +++ b/pm2d/line_matcher.py @@ -33,6 +33,8 @@ from dataclasses import dataclass import cv2 import numpy as np +from pm2d._jit_kernels import score_by_shift as _jit_score_by_shift, HAS_NUMBA + N_BINS = 8 # orientamenti quantizzati modulo π @@ -303,27 +305,9 @@ class LineShapeMatcher: ) -> np.ndarray: """score[y,x] = Σ_i resp[bin_i][y+dy_i, x+dx_i] / len(dx). - Ottimizzazione: se `bin_has_data` è fornito, skippa feature il cui - bin non ha pixel attivi nella scena (contribuzione = 0). + Dispatch a kernel Numba JIT se disponibile, altrimenti fallback numpy. """ - _, H, W = resp.shape - acc = np.zeros((H, W), dtype=np.float32) - n = len(dx) - for i in range(n): - b = int(bins[i]) - if bin_has_data is not None and not bin_has_data[b]: - continue - ddx = int(dx[i]); ddy = int(dy[i]) - y0s = max(0, -ddy); y1s = min(H, H - ddy) - x0s = max(0, -ddx); x1s = min(W, W - ddx) - if y0s >= y1s or x0s >= x1s: - continue - y0r = y0s + ddy; y1r = y1s + ddy - x0r = x0s + ddx; x1r = x1s + ddx - acc[y0s:y1s, x0s:x1s] += resp[b, y0r:y1r, x0r:x1r] - if n > 0: - acc /= n - return acc + return _jit_score_by_shift(resp, dx, dy, bins, bin_has_data) @staticmethod def _subpixel_peak(acc: np.ndarray, x: int, y: int) -> tuple[float, float]: diff --git a/pyproject.toml b/pyproject.toml index cfdc4f9..5ceb221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ name = "shape-model-2d" version = "0.1.0" requires-python = ">=3.13" dependencies = [ + "numba>=0.65.0", "numpy>=1.24", "opencv-python>=4.8", ] diff --git a/uv.lock b/uv.lock index 8c7ab2a..707c668 100644 --- a/uv.lock +++ b/uv.lock @@ -2,6 +2,50 @@ version = 1 revision = 3 requires-python = ">=3.13" +[[package]] +name = "llvmlite" +version = "0.47.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/88/a8952b6d5c21e74cbf158515b779666f692846502623e9e3c39d8e8ba25f/llvmlite-0.47.0.tar.gz", hash = "sha256:62031ce968ec74e95092184d4b0e857e444f8fdff0b8f9213707699570c33ccc", size = 193614, upload-time = "2026-03-31T18:29:53.497Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/6f/4615353e016799f80fa52ccb270a843c413b22361fadda2589b2922fb9b0/llvmlite-0.47.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a3c6a735d4e1041808434f9d440faa3d78d9b4af2ee64d05a66f351883b6ceec", size = 37232771, upload-time = "2026-03-31T18:29:01.324Z" }, + { url = "https://files.pythonhosted.org/packages/31/b8/69f5565f1a280d032525878a86511eebed0645818492feeb169dfb20ae8e/llvmlite-0.47.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2699a74321189e812d476a43d6d7f652f51811e7b5aad9d9bba842a1c7927acb", size = 56275178, upload-time = "2026-03-31T18:29:05.748Z" }, + { url = "https://files.pythonhosted.org/packages/d6/da/b32cafcb926fb0ce2aa25553bf32cb8764af31438f40e2481df08884c947/llvmlite-0.47.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c6951e2b29930227963e53ee152441f0e14be92e9d4231852102d986c761e40", size = 55128632, upload-time = "2026-03-31T18:29:11.235Z" }, + { url = "https://files.pythonhosted.org/packages/46/9f/4898b44e4042c60fafcb1162dfb7014f6f15b1ec19bf29cfea6bf26df90d/llvmlite-0.47.0-cp313-cp313-win_amd64.whl", hash = "sha256:c2e9adf8698d813a9a5efb2d4370caf344dbc1e145019851fee6a6f319ba760e", size = 38138695, upload-time = "2026-03-31T18:29:15.43Z" }, + { url = "https://files.pythonhosted.org/packages/1c/d4/33c8af00f0bf6f552d74f3a054f648af2c5bc6bece97972f3bfadce4f5ec/llvmlite-0.47.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:de966c626c35c9dff5ae7bf12db25637738d0df83fc370cf793bc94d43d92d14", size = 37232773, upload-time = "2026-03-31T18:29:19.453Z" }, + { url = "https://files.pythonhosted.org/packages/64/1d/a760e993e0c0ba6db38d46b9f48f6c7dceb8ac838824997fb9e25f97bc04/llvmlite-0.47.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ddbccff2aeaff8670368340a158abefc032fe9b3ccf7d9c496639263d00151aa", size = 56275176, upload-time = "2026-03-31T18:29:24.149Z" }, + { url = "https://files.pythonhosted.org/packages/84/3b/e679bc3b29127182a7f4aa2d2e9e5bea42adb93fb840484147d59c236299/llvmlite-0.47.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d4a7b778a2e144fc64468fb9bf509ac1226c9813a00b4d7afea5d988c4e22fca", size = 55128631, upload-time = "2026-03-31T18:29:29.536Z" }, + { url = "https://files.pythonhosted.org/packages/be/f7/19e2a09c62809c9e63bbd14ce71fb92c6ff7b7b3045741bb00c781efc3c9/llvmlite-0.47.0-cp314-cp314-win_amd64.whl", hash = "sha256:694e3c2cdc472ed2bd8bd4555ca002eec4310961dd58ef791d508f57b5cc4c94", size = 39153826, upload-time = "2026-03-31T18:29:33.681Z" }, + { url = "https://files.pythonhosted.org/packages/40/a1/581a8c707b5e80efdbbe1dd94527404d33fe50bceb71f39d5a7e11bd57b7/llvmlite-0.47.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:92ec8a169a20b473c1c54d4695e371bde36489fc1efa3688e11e99beba0abf9c", size = 37232772, upload-time = "2026-03-31T18:29:37.952Z" }, + { url = "https://files.pythonhosted.org/packages/11/03/16090dd6f74ba2b8b922276047f15962fbeea0a75d5601607edb301ba945/llvmlite-0.47.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fa1cbd800edd3b20bc141521f7fd45a6185a5b84109aa6855134e81397ffe72b", size = 56275178, upload-time = "2026-03-31T18:29:42.58Z" }, + { url = "https://files.pythonhosted.org/packages/f5/cb/0abf1dd4c5286a95ffe0c1d8c67aec06b515894a0dd2ac97f5e27b82ab0b/llvmlite-0.47.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f6725179b89f03b17dabe236ff3422cb8291b4c1bf40af152826dfd34e350ae8", size = 55128632, upload-time = "2026-03-31T18:29:46.939Z" }, + { url = "https://files.pythonhosted.org/packages/4f/79/d3bbab197e86e0ff4f9c07122895b66a3e0d024247fcff7f12c473cb36d9/llvmlite-0.47.0-cp314-cp314t-win_amd64.whl", hash = "sha256:6842cf6f707ec4be3d985a385ad03f72b2d724439e118fcbe99b2929964f0453", size = 39153839, upload-time = "2026-03-31T18:29:51.004Z" }, +] + +[[package]] +name = "numba" +version = "0.65.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llvmlite" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/61/7299643b9c18d669e04be7c5bcb64d985070d07553274817b45b049e7bfe/numba-0.65.0.tar.gz", hash = "sha256:edad0d9f6682e93624c00125a471ae4df186175d71fd604c983c377cdc03e68b", size = 2764131, upload-time = "2026-04-01T03:52:01.946Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/f8/eee0f1ff456218db036bfc9023995ec1f85a9dc8f2422f1594f6a87829e0/numba-0.65.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:c6334094563a456a695c812e6846288376ca02327cf246cdcc83e1bb27862367", size = 2680679, upload-time = "2026-04-01T03:51:39.491Z" }, + { url = "https://files.pythonhosted.org/packages/1b/8f/3d116e4b8e92f6abace431afa4b2b944f4d65bdee83af886f5c4b263df95/numba-0.65.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b8a9008411615c69d083d1dcf477f75a5aa727b30beb16e139799e2be945cdfd", size = 3809537, upload-time = "2026-04-01T03:51:41.42Z" }, + { url = "https://files.pythonhosted.org/packages/b5/2c/6a3ca4128e253cb67affe06deb47688f51ce968f5111e2a06d010e6f1fa6/numba-0.65.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af96c0cba53664efcb361528b8c75e011a6556c859c7e08424c2715201c6cf7a", size = 3508615, upload-time = "2026-04-01T03:51:43.444Z" }, + { url = "https://files.pythonhosted.org/packages/96/0e/267f9a36fb282c104a971d7eecb685b411c47dce2a740fe69cf5fc2945d9/numba-0.65.0-cp313-cp313-win_amd64.whl", hash = "sha256:6254e73b9c929dc736a1fbd3d6f5680789709a5067cae1fa7198707385129c04", size = 2749938, upload-time = "2026-04-01T03:51:45.218Z" }, + { url = "https://files.pythonhosted.org/packages/56/a4/90edb01e9176053578e343d7a7276bc28356741ee67059aed8ed2c1a4e59/numba-0.65.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:ee336b398a6fca51b1f626034de99f50cb1bd87d537a166275158a3cee744b82", size = 2680878, upload-time = "2026-04-01T03:51:46.91Z" }, + { url = "https://files.pythonhosted.org/packages/24/8d/e12d6ff4b9119db3cbf7b2db1ce257576441bd3c76388c786dea74f20b02/numba-0.65.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:05c0a9fdf75d85f57dee47b719e8d6415707b80aae45d75f63f9dc1b935c29f7", size = 3778456, upload-time = "2026-04-01T03:51:48.552Z" }, + { url = "https://files.pythonhosted.org/packages/17/89/abcd83e76f6a773276fe76244140671bcc5bf820f6e2ae1a15362ae4c8c9/numba-0.65.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:583680e0e8faf124d362df23b4b593f3221a8996341a63d1b664c122401bec2f", size = 3478464, upload-time = "2026-04-01T03:51:50.527Z" }, + { url = "https://files.pythonhosted.org/packages/73/5b/fbce55ce3d933afbc7ade04df826853e4a846aaa47d58d2fbb669b8f2d08/numba-0.65.0-cp314-cp314-win_amd64.whl", hash = "sha256:add297d3e1c08dd884f44100152612fa41e66a51d15fdf91307f9dde31d06830", size = 2752012, upload-time = "2026-04-01T03:51:52.691Z" }, + { url = "https://files.pythonhosted.org/packages/1e/ab/af705f4257d9388fb2fd6d7416573e98b6ca9c786e8b58f02720978557bd/numba-0.65.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:194a243ba53a9157c8538cbb3166ec015d785a8c5d584d06cdd88bee902233c7", size = 2683961, upload-time = "2026-04-01T03:51:54.281Z" }, + { url = "https://files.pythonhosted.org/packages/ff/e5/8267b0adb0c01b52b553df5062fbbb42c30ed5362d08b85cc913a36f838f/numba-0.65.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c7fa502960f7a2f3f5cb025bc7bff888a3551277b92431bfdc5ba2f11a375749", size = 3816373, upload-time = "2026-04-01T03:51:56.18Z" }, + { url = "https://files.pythonhosted.org/packages/b0/f5/b8397ca360971669a93706b9274592b6864e4367a37d498fbbcb62aa2d48/numba-0.65.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5046c63f783ca3eb6195f826a50797465e7c4ce811daa17c9bea47e310c9b964", size = 3532782, upload-time = "2026-04-01T03:51:58.387Z" }, + { url = "https://files.pythonhosted.org/packages/f5/21/1e73fa16bf0393ebb74c5bb208d712152ffdfc84600a8e93a3180317856e/numba-0.65.0-cp314-cp314t-win_amd64.whl", hash = "sha256:46fd679ae4f68c7a5d5721efbd29ecee0b0f3013211591891d79b51bfdf73113", size = 2757611, upload-time = "2026-04-01T03:52:00.083Z" }, +] + [[package]] name = "numpy" version = "2.4.4" @@ -75,12 +119,14 @@ name = "shape-model-2d" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "numba" }, { name = "numpy" }, { name = "opencv-python" }, ] [package.metadata] requires-dist = [ + { name = "numba", specifier = ">=0.65.0" }, { name = "numpy", specifier = ">=1.24" }, { name = "opencv-python", specifier = ">=4.8" }, ]