From d30f981421347d6e328dd718e75d35b49a5ed9da Mon Sep 17 00:00:00 2001
From: AdrianoDev
Date: Sat, 9 May 2026 19:08:43 +0200
Subject: [PATCH] feat(data): expanding walk-forward splits

Aggiunge expanding_walk_forward e dataclass Split per generare fold
walk-forward con train che cresce e test sulla finestra successiva.
Rispetta min_train_days e ritorna lista vuota su fold troppo piccoli.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 src/multi_swarm/data/splits.py | 84 ++++++++++++++++++++++++++++++++++
 tests/unit/test_splits.py      | 47 +++++++++++++++++++
 2 files changed, 131 insertions(+)
 create mode 100644 src/multi_swarm/data/splits.py
 create mode 100644 tests/unit/test_splits.py

diff --git a/src/multi_swarm/data/splits.py b/src/multi_swarm/data/splits.py
new file mode 100644
index 0000000..fc4dc32
--- /dev/null
+++ b/src/multi_swarm/data/splits.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import pandas as pd  # type: ignore[import-untyped]
+
+
+@dataclass(frozen=True)
+class Split:
+    """One walk-forward fold: a training window and the test window after it."""
+
+    fold: int  # zero-based position of the fold in the requested sequence
+    train_idx: pd.DatetimeIndex  # index entries used for training
+    test_idx: pd.DatetimeIndex  # index entries used for testing
+
+
+def expanding_walk_forward(
+    index: pd.DatetimeIndex,
+    train_ratio: float = 0.7,
+    n_folds: int = 4,
+    min_train_days: int = 30,
+) -> list[Split]:
+    """Generate expanding walk-forward splits over ``index``.
+
+    The first ``int(len(index) * train_ratio)`` entries form the initial
+    training window. The remaining entries are cut into ``n_folds``
+    consecutive test windows of ``remaining // n_folds`` entries (at least
+    one); the last window also takes any leftover entries so the tail of
+    ``index`` is never silently dropped. Each fold trains on everything
+    before its test window, so the training set grows from fold to fold.
+
+    ``index`` is sliced by position and is not sorted or validated here;
+    pass it in chronological order.
+
+    Args:
+        index: Timestamps to split.
+        train_ratio: Fraction of ``index`` used by the first training window.
+        n_folds: Number of folds to attempt.
+        min_train_days: Minimum number of entries in a training window. It is
+            compared with ``len(train_idx)``, so it is a day count only when
+            ``index`` has one entry per day.
+
+    Returns:
+        The folds whose training window has at least ``min_train_days``
+        entries (and at least one) and whose test window is not empty.
+        ``Split.fold`` keeps the fold's position in the requested sequence,
+        so the first returned fold is not necessarily 0. May be empty.
+
+    Raises:
+        ValueError: If ``n_folds < 1`` or ``train_ratio`` is not in (0, 1).
+    """
+    if n_folds < 1:
+        raise ValueError("n_folds must be >= 1")
+    if not 0 < train_ratio < 1:
+        raise ValueError("train_ratio must be in (0,1)")
+
+    total = len(index)
+    initial_train = int(total * train_ratio)
+    fold_size = max(1, (total - initial_train) // n_folds)
+    # An empty training window is never useful, even if min_train_days <= 0.
+    min_train = max(1, min_train_days)
+
+    splits: list[Split] = []
+    for fold in range(n_folds):
+        train_end = initial_train + fold * fold_size
+        if train_end >= total:
+            break  # nothing left to test on, for this fold or any later one
+        if train_end < min_train:
+            continue
+        # The last fold absorbs the remainder of the integer division so the
+        # most recent observations are always evaluated.
+        if fold == n_folds - 1:
+            test_end = total
+        else:
+            test_end = min(train_end + fold_size, total)
+        splits.append(
+            Split(
+                fold=fold,
+                train_idx=index[:train_end],
+                test_idx=index[train_end:test_end],
+            )
+        )
+
+    return splits
diff --git a/tests/unit/test_splits.py b/tests/unit/test_splits.py
new file mode 100644
index 0000000..c3151b7
--- /dev/null
+++ b/tests/unit/test_splits.py
@@ -0,0 +1,47 @@
+import pandas as pd
+import pytest
+
+from multi_swarm.data.splits import expanding_walk_forward
+
+
+@pytest.fixture
+def daily_index():
+    return pd.date_range("2024-01-01", "2024-12-31", freq="D", tz="UTC")
+
+
+def test_expanding_split_count(daily_index: pd.DatetimeIndex):
+    splits = expanding_walk_forward(
+        daily_index, train_ratio=0.7, n_folds=4, min_train_days=30
+    )
+    assert len(splits) == 4
+
+
+def test_expanding_split_train_grows(daily_index: pd.DatetimeIndex):
+    splits = expanding_walk_forward(
+        daily_index, train_ratio=0.7, n_folds=4, min_train_days=30
+    )
+    train_lengths = [len(s.train_idx) for s in splits]
+    assert train_lengths == sorted(train_lengths)
+    assert train_lengths[0] < train_lengths[-1]
+
+
+def test_no_overlap_train_test(daily_index: pd.DatetimeIndex):
+    splits = expanding_walk_forward(
+        daily_index, train_ratio=0.7, n_folds=4, min_train_days=30
+    )
+    for s in splits:
+        assert s.train_idx[-1] < s.test_idx[0]
+
+
+def test_min_train_days_respected():
+    idx = pd.date_range("2024-01-01", "2024-02-15", freq="D", tz="UTC")
+    splits = expanding_walk_forward(idx, train_ratio=0.7, n_folds=2, min_train_days=20)
+    for s in splits:
+        assert len(s.train_idx) >= 20
+
+
+def test_last_fold_reaches_end_of_index(daily_index: pd.DatetimeIndex):
+    splits = expanding_walk_forward(
+        daily_index, train_ratio=0.7, n_folds=4, min_train_days=30
+    )
+    assert splits[-1].test_idx[-1] == daily_index[-1]