feat(data): expanding walk-forward splits

Aggiunge expanding_walk_forward e dataclass Split per generare fold
walk-forward con train che cresce e test sulla finestra successiva.
Rispetta min_train_days e ritorna lista vuota su fold troppo piccoli.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-09 19:08:43 +02:00
parent 116879400a
commit d30f981421
2 changed files with 90 additions and 0 deletions
+50
View File
@@ -0,0 +1,50 @@
from __future__ import annotations
from dataclasses import dataclass
import pandas as pd # type: ignore[import-untyped]
@dataclass(frozen=True)
class Split:
fold: int
train_idx: pd.DatetimeIndex
test_idx: pd.DatetimeIndex
def expanding_walk_forward(
index: pd.DatetimeIndex,
train_ratio: float = 0.7,
n_folds: int = 4,
min_train_days: int = 30,
) -> list[Split]:
"""Genera split walk-forward expanding: train cresce, test è la finestra successiva.
Esempio con n_folds=4, train_ratio=0.7:
fold 0: train [0..a0], test [a0..a0+(end-a0)/4]
fold 1: train [0..a1], test [a1..a1+(end-a1)/4]
...
Il train iniziale parte da train_ratio dell'intervallo totale.
"""
if n_folds < 1:
raise ValueError("n_folds must be >= 1")
if not 0 < train_ratio < 1:
raise ValueError("train_ratio must be in (0,1)")
total = len(index)
initial_train = int(total * train_ratio)
remaining = total - initial_train
fold_size = max(1, remaining // n_folds)
splits: list[Split] = []
for f in range(n_folds):
train_end = initial_train + f * fold_size
test_start = train_end
test_end = min(test_start + fold_size, total)
train_idx = index[:train_end]
test_idx = index[test_start:test_end]
if len(train_idx) < min_train_days or len(test_idx) == 0:
continue
splits.append(Split(fold=f, train_idx=train_idx, test_idx=test_idx))
return splits
+40
View File
@@ -0,0 +1,40 @@
import pandas as pd
import pytest
from multi_swarm.data.splits import expanding_walk_forward
@pytest.fixture
def daily_index():
return pd.date_range("2024-01-01", "2024-12-31", freq="D", tz="UTC")
def test_expanding_split_count(daily_index: pd.DatetimeIndex):
splits = expanding_walk_forward(
daily_index, train_ratio=0.7, n_folds=4, min_train_days=30
)
assert len(splits) == 4
def test_expanding_split_train_grows(daily_index: pd.DatetimeIndex):
splits = expanding_walk_forward(
daily_index, train_ratio=0.7, n_folds=4, min_train_days=30
)
train_lengths = [len(s.train_idx) for s in splits]
assert train_lengths == sorted(train_lengths)
assert train_lengths[0] < train_lengths[-1]
def test_no_overlap_train_test(daily_index: pd.DatetimeIndex):
splits = expanding_walk_forward(
daily_index, train_ratio=0.7, n_folds=4, min_train_days=30
)
for s in splits:
assert s.train_idx[-1] < s.test_idx[0]
def test_min_train_days_respected():
idx = pd.date_range("2024-01-01", "2024-02-15", freq="D", tz="UTC")
splits = expanding_walk_forward(idx, train_ratio=0.7, n_folds=2, min_train_days=20)
for s in splits:
assert len(s.train_idx) >= 20