feat(data): expanding walk-forward splits
Aggiunge la funzione expanding_walk_forward e la dataclass Split per generare fold walk-forward in cui il train cresce e il test è la finestra successiva. Rispetta min_train_days: i fold con train troppo corto o con test vuoto vengono scartati, quindi la lista restituita può essere vuota.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pandas as pd # type: ignore[import-untyped]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Split:
    """One walk-forward fold: a training window plus the test window for it.

    Equality and hashing are defined explicitly. The dataclass-generated
    versions compare and hash the raw field values, which fails for
    ``pd.DatetimeIndex``: ``==`` on an index is element-wise (so truth-testing
    the result raises ``ValueError``) and an index is unhashable.
    """

    fold: int  # fold number assigned by the code that built the split
    train_idx: pd.DatetimeIndex  # timestamps of the training window
    test_idx: pd.DatetimeIndex  # timestamps of the test window

    def __eq__(self, other: object) -> bool:
        # Same class check the generated dataclass __eq__ would perform.
        if other.__class__ is not self.__class__:
            return NotImplemented
        return (
            self.fold == other.fold
            and self.train_idx.equals(other.train_idx)
            and self.test_idx.equals(other.test_idx)
        )

    def __hash__(self) -> int:
        # Equal splits have equal fold numbers and equal-length indexes, so
        # this stays consistent with __eq__ without hashing every timestamp.
        return hash((self.fold, len(self.train_idx), len(self.test_idx)))
|
||||
|
||||
|
||||
def expanding_walk_forward(
    index: pd.DatetimeIndex,
    train_ratio: float = 0.7,
    n_folds: int = 4,
    min_train_days: int = 30,
) -> list[Split]:
    """Build expanding-window walk-forward splits over ``index``.

    The first ``a = int(len(index) * train_ratio)`` rows seed the training
    window. The rows after it are cut into consecutive test windows of
    ``s = max(1, (len(index) - a) // n_folds)`` rows. Each fold trains on
    every row before its test window, so the training window grows by one
    test window per fold::

        fold 0:    train [0, a),      test [a, a + s)
        fold 1:    train [0, a + s),  test [a + s, a + 2s)
        ...
        last fold: train [0, a + (n_folds - 1) * s), test [..., end)

    The last fold's test window runs to the end of ``index`` so the rows left
    over by the integer division are evaluated instead of being dropped.

    Splitting is purely positional: ``index`` is sliced as given and is not
    sorted or checked for ordering, so callers should pass it in
    chronological order.

    Args:
        index: Timestamps to split.
        train_ratio: Fraction of ``index`` used as the initial training
            window; must be strictly between 0 and 1.
        n_folds: Number of folds to attempt; must be at least 1.
        min_train_days: Minimum number of rows the training window must hold
            (one row per day for a daily index). Folds below it are skipped.

    Returns:
        The kept splits in fold order. ``Split.fold`` is the position of the
        fold among the ``n_folds`` attempted, so numbers can have gaps when
        folds are skipped. The list is empty when no fold qualifies.

    Raises:
        ValueError: If ``n_folds < 1`` or ``train_ratio`` is not in (0, 1).
    """
    if n_folds < 1:
        raise ValueError("n_folds must be >= 1")
    if not 0 < train_ratio < 1:
        raise ValueError("train_ratio must be in (0,1)")

    total = len(index)
    initial_train = int(total * train_ratio)
    fold_size = max(1, (total - initial_train) // n_folds)
    # An empty training window is never usable, even if min_train_days <= 0.
    min_train_rows = max(1, min_train_days)
    last_fold = n_folds - 1

    splits: list[Split] = []
    for fold in range(n_folds):
        train_end = initial_train + fold * fold_size
        if train_end >= total:
            # No rows left to test on; every later fold would be empty too.
            break
        if train_end < min_train_rows:
            continue
        if fold == last_fold:
            # Absorb the division remainder so the newest rows are tested.
            test_end = total
        else:
            test_end = min(train_end + fold_size, total)
        splits.append(
            Split(
                fold=fold,
                train_idx=index[:train_end],
                test_idx=index[train_end:test_end],
            )
        )

    return splits
|
||||
@@ -0,0 +1,40 @@
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from multi_swarm.data.splits import expanding_walk_forward
|
||||
|
||||
|
||||
@pytest.fixture
def daily_index():
    """Daily, UTC-aware timestamps from 2024-01-01 through 2024-12-31."""
    return pd.date_range(
        start="2024-01-01",
        end="2024-12-31",
        freq="D",
        tz="UTC",
    )
|
||||
|
||||
|
||||
def test_expanding_split_count(daily_index: pd.DatetimeIndex):
    """Every requested fold is produced when the index is long enough."""
    requested_folds = 4
    folds = expanding_walk_forward(
        daily_index,
        train_ratio=0.7,
        n_folds=requested_folds,
        min_train_days=30,
    )
    assert len(folds) == requested_folds
|
||||
|
||||
|
||||
def test_expanding_split_train_grows(daily_index: pd.DatetimeIndex):
    """Training windows never shrink and the last one is larger than the first."""
    folds = expanding_walk_forward(
        daily_index,
        train_ratio=0.7,
        n_folds=4,
        min_train_days=30,
    )
    sizes = [len(split.train_idx) for split in folds]
    # Pairwise non-decreasing is the same condition as "already sorted".
    assert all(earlier <= later for earlier, later in zip(sizes, sizes[1:]))
    assert sizes[0] < sizes[-1]
|
||||
|
||||
|
||||
def test_no_overlap_train_test(daily_index: pd.DatetimeIndex):
    """In each fold the last training timestamp precedes the first test one."""
    folds = expanding_walk_forward(
        daily_index,
        train_ratio=0.7,
        n_folds=4,
        min_train_days=30,
    )
    for split in folds:
        last_train = split.train_idx[-1]
        first_test = split.test_idx[0]
        assert last_train < first_test
|
||||
|
||||
|
||||
def test_min_train_days_respected():
    """Folds whose training window is shorter than min_train_days are dropped.

    The index has 46 daily rows; with train_ratio=0.7 and n_folds=2 the two
    folds train on 32 and 39 rows respectively.
    """
    idx = pd.date_range("2024-01-01", "2024-02-15", freq="D", tz="UTC")

    # Threshold below both folds: nothing is filtered. Checking that the
    # result is non-empty keeps the loop from passing vacuously.
    splits = expanding_walk_forward(idx, train_ratio=0.7, n_folds=2, min_train_days=20)
    assert splits
    for s in splits:
        assert len(s.train_idx) >= 20

    # Threshold between the two training sizes: only the second fold survives.
    filtered = expanding_walk_forward(
        idx, train_ratio=0.7, n_folds=2, min_train_days=35
    )
    assert [len(s.train_idx) for s in filtered] == [39]

    # Threshold above every training size: no fold qualifies.
    none_left = expanding_walk_forward(
        idx, train_ratio=0.7, n_folds=2, min_train_days=len(idx) + 1
    )
    assert none_left == []
|
||||
Reference in New Issue
Block a user