diff --git a/.gitignore b/.gitignore
index 771f6f4..82be5d8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,7 +30,7 @@ runs.db-shm
 series/
 *.parquet
 *.feather
-data/
+/data/
 checkpoints/
 logs/
 *.log
diff --git a/src/multi_swarm/data/__init__.py b/src/multi_swarm/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/multi_swarm/data/ohlcv_loader.py b/src/multi_swarm/data/ohlcv_loader.py
new file mode 100644
index 0000000..d79d083
--- /dev/null
+++ b/src/multi_swarm/data/ohlcv_loader.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import hashlib
+from dataclasses import dataclass
+from datetime import datetime
+from pathlib import Path
+
+import ccxt  # type: ignore[import-untyped]
+import pandas as pd  # type: ignore[import-untyped]
+
+
+@dataclass(frozen=True)
+class OHLCVRequest:
+    symbol: str
+    timeframe: str
+    start: datetime
+    end: datetime
+
+    def cache_key(self) -> str:
+        s = f"{self.symbol}|{self.timeframe}|{self.start.isoformat()}|{self.end.isoformat()}"
+        return hashlib.sha1(s.encode()).hexdigest()[:16]  # NOTE(review): key omits exchange_name, so loaders on different exchanges share cache entries
+
+
+class OHLCVLoader:
+    """Load OHLCV candles via ccxt (Binance) and cache them as parquet files."""
+
+    def __init__(self, cache_dir: Path, exchange_name: str = "binance"):
+        self.cache_dir = Path(cache_dir)
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        self.exchange_name = exchange_name
+
+    def load(self, req: OHLCVRequest) -> pd.DataFrame:
+        cache_file = self.cache_dir / f"{req.cache_key()}.parquet"
+        if cache_file.exists():
+            return pd.read_parquet(cache_file)
+
+        df = self._fetch_paginated(req)
+        df.to_parquet(cache_file)
+        return df
+
+    @staticmethod
+    def _timeframe_to_ms(timeframe: str) -> int:
+        units = {"m": 60, "h": 3600, "d": 86400, "w": 604800}
+        unit = timeframe[-1]
+        if unit not in units:
+            raise ValueError(f"Unsupported timeframe: {timeframe}")
+        return int(timeframe[:-1]) * units[unit] * 1000
+
+    def _fetch_paginated(self, req: OHLCVRequest) -> pd.DataFrame:
+        exchange = getattr(ccxt, self.exchange_name)({"enableRateLimit": True})
+        timeframe_ms = self._timeframe_to_ms(req.timeframe)
+        since = int(req.start.timestamp() * 1000)
+        end_ms = int(req.end.timestamp() * 1000)
+        all_rows: list[list[float]] = []
+        limit = 1000
+
+        while since <= end_ms:
+            rows = exchange.fetch_ohlcv(req.symbol, req.timeframe, since=since, limit=limit)
+            if not rows:
+                break
+            all_rows.extend(rows)
+            last_ts = rows[-1][0]
+            new_since = last_ts + timeframe_ms
+            if new_since <= since:
+                break  # guard against a stuck cursor (exchange returning the same page)
+            since = new_since  # advance one candle past the last row to avoid refetching it
+
+        df = pd.DataFrame(all_rows, columns=["ts", "open", "high", "low", "close", "volume"])
+        df = df.drop_duplicates(subset=["ts"]).sort_values("ts")
+        df["ts"] = pd.to_datetime(df["ts"], unit="ms", utc=True)
+        df = df.set_index("ts")
+        df = df[(df.index >= req.start) & (df.index < req.end)]  # end-exclusive; requires tz-aware bounds -- naive datetimes would raise here
+        return df[["open", "high", "low", "close", "volume"]].astype("float64")
diff --git a/tests/unit/test_ohlcv_loader.py b/tests/unit/test_ohlcv_loader.py
new file mode 100644
index 0000000..ea8a102
--- /dev/null
+++ b/tests/unit/test_ohlcv_loader.py
@@ -0,0 +1,64 @@
+from datetime import UTC, datetime
+from pathlib import Path
+
+import pandas as pd
+import pytest
+
+from multi_swarm.data.ohlcv_loader import OHLCVLoader, OHLCVRequest
+
+
+@pytest.fixture
+def sample_ohlcv_rows():
+    base_ts = int(datetime(2024, 1, 1, tzinfo=UTC).timestamp() * 1000)
+    rows = []
+    for i in range(48):
+        rows.append(
+            [base_ts + i * 3600 * 1000, 40000 + i, 40100 + i, 39900 + i, 40050 + i, 100.0 + i]
+        )
+    return rows
+
+
+def test_loader_fetches_and_caches(tmp_path: Path, mocker, sample_ohlcv_rows):
+    fake_exchange = mocker.MagicMock()
+    fake_exchange.fetch_ohlcv.return_value = sample_ohlcv_rows
+    mocker.patch("multi_swarm.data.ohlcv_loader.ccxt.binance", return_value=fake_exchange)
+
+    loader = OHLCVLoader(cache_dir=tmp_path)
+    req = OHLCVRequest(
+        symbol="BTC/USDT",
+        timeframe="1h",
+        start=datetime(2024, 1, 1, tzinfo=UTC),
+        end=datetime(2024, 1, 3, tzinfo=UTC),
+    )
+    df = loader.load(req)
+
+    assert isinstance(df, pd.DataFrame)
+    assert list(df.columns) == ["open", "high", "low", "close", "volume"]
+    assert len(df) == 48
+    assert df.index.is_monotonic_increasing
+    cache_files = list(tmp_path.glob("*.parquet"))
+    assert len(cache_files) == 1
+
+
+def test_loader_uses_cache_on_second_call(tmp_path: Path, mocker, sample_ohlcv_rows):
+    fake_exchange = mocker.MagicMock()
+    fake_exchange.fetch_ohlcv.return_value = sample_ohlcv_rows
+    mocker.patch("multi_swarm.data.ohlcv_loader.ccxt.binance", return_value=fake_exchange)
+
+    loader = OHLCVLoader(cache_dir=tmp_path)
+    req = OHLCVRequest(
+        symbol="BTC/USDT",
+        timeframe="1h",
+        start=datetime(2024, 1, 1, tzinfo=UTC),
+        end=datetime(2024, 1, 3, tzinfo=UTC),
+    )
+    df1 = loader.load(req)
+    df2 = loader.load(req)
+
+    assert fake_exchange.fetch_ohlcv.call_count == 2  # internal pagination, not caching
+    pd.testing.assert_frame_equal(df1, df2)
+    # Second call reads from the cache and does not hit the exchange
+    fake_exchange.fetch_ohlcv.reset_mock()
+    df3 = loader.load(req)
+    assert fake_exchange.fetch_ohlcv.call_count == 0
+    pd.testing.assert_frame_equal(df1, df3)