"""Tests for ``multi_swarm.data.splits.expanding_walk_forward``."""

import pandas as pd
import pytest

from multi_swarm.data.splits import expanding_walk_forward

# Split parameters shared by every test that uses the full-year fixture.
# Keeping them in one place ensures those tests exercise the same configuration.
_YEAR_SPLIT_KWARGS = {"train_ratio": 0.7, "n_folds": 4, "min_train_days": 30}


@pytest.fixture
def daily_index() -> pd.DatetimeIndex:
    """Return a UTC daily index covering calendar year 2024 (366 days)."""
    return pd.date_range("2024-01-01", "2024-12-31", freq="D", tz="UTC")


@pytest.fixture
def year_splits(daily_index: pd.DatetimeIndex):
    """Return the folds produced for ``daily_index`` with the shared parameters."""
    return expanding_walk_forward(daily_index, **_YEAR_SPLIT_KWARGS)


def test_expanding_split_count(year_splits):
    """Exactly ``n_folds`` folds are produced."""
    assert len(year_splits) == _YEAR_SPLIT_KWARGS["n_folds"]


def test_expanding_split_train_grows(year_splits):
    """Training lengths never shrink across folds, and the last exceeds the first."""
    train_lengths = [len(s.train_idx) for s in year_splits]
    assert train_lengths == sorted(train_lengths)
    assert train_lengths[0] < train_lengths[-1]


def test_no_overlap_train_test(year_splits):
    """In every fold, the last training entry precedes the first test entry."""
    for s in year_splits:
        # Explicit emptiness checks give a clear failure instead of an IndexError
        # from the positional lookups below.
        assert len(s.train_idx) > 0
        assert len(s.test_idx) > 0
        assert s.train_idx[-1] < s.test_idx[0]


def test_min_train_days_respected():
    """Every fold's training window has at least ``min_train_days`` entries."""
    # 46 daily timestamps: 2024-01-01 through 2024-02-15 inclusive.
    idx = pd.date_range("2024-01-01", "2024-02-15", freq="D", tz="UTC")
    splits = expanding_walk_forward(idx, train_ratio=0.7, n_folds=2, min_train_days=20)
    # Without this check, an empty result would make the loop below pass vacuously.
    assert len(splits) > 0, "expected at least one fold for a 46-day index"
    for s in splits:
        assert len(s.train_idx) >= 20