Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion vectorbt/generic/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@
from vectorbt.generic.drawdowns import Drawdowns
from vectorbt.generic.plots_builder import PlotsBuilderMixin
from vectorbt.generic.ranges import Ranges
from vectorbt.generic.splitters import SplitterT, RangeSplitter, RollingSplitter, ExpandingSplitter
from vectorbt.generic.splitters import SplitterT, RangeSplitter, RollingSplitter, ExpandingSplitter, ShrinkingSplitter
from vectorbt.generic.stats_builder import StatsBuilderMixin
from vectorbt.records.mapped_array import MappedArray
from vectorbt.utils import checks
Expand Down Expand Up @@ -1595,6 +1595,41 @@ def expanding_split(self, **kwargs) -> SplitOutputT:
![](/assets/images/expanding_split_plot.svg)
"""
return self.split(ExpandingSplitter(), **kwargs)


def shrinking_split(self, **kwargs) -> SplitOutputT:
"""Split using `GenericAccessor.split` on `vectorbt.generic.splitters.ShrinkingSplitter`.

Usage:
```pycon
>>> train_set, valid_set, test_set = sr.vbt.shrinking_split(
... n=5, set_lens=(1, 1), min_len=3, left_to_right=False)
>>> train_set[0]
split_idx 0 1 2 3 4 5 6 7
0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0
1 NaN 1.0 1.0 1.0 1.0 1.0 1.0 1
2 NaN NaN 2.0 2.0 2.0 2.0 2.0 2
3 NaN NaN NaN 3.0 3.0 3.0 3.0 3
4 NaN NaN NaN NaN 4.0 4.0 4.0 4
5 NaN NaN NaN NaN NaN 5.0 5.0 5
6 NaN NaN NaN NaN NaN NaN 6.0 6
7 NaN NaN NaN NaN NaN NaN NaN 7
>>> valid_set[0]
split_idx 0 1 2 3 4 5 6 7
0 1 2 3 4 5 6 7 8
>>> test_set[0]
split_idx 0 1 2 3 4 5 6 7
0 2 3 4 5 6 7 8 9

>>> sr.vbt.shrinking_split(
... set_lens=(1, 1), min_len=3, left_to_right=False,
... plot=True, trace_names=['train', 'valid', 'test'])
```

![](/assets/images/shrinking_split_plot.svg)
"""
return self.split(ShrinkingSplitter(), **kwargs)


# ############# Plotting ############# #

Expand Down
44 changes: 44 additions & 0 deletions vectorbt/generic/splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,47 @@ def split(self,
end_idxs = end_idxs[idxs]

return split_ranges_into_sets(start_idxs, end_idxs, **kwargs)


class ShrinkingSplitter(BaseSplitter):
"""Shrinking walk-forward splitter."""

def split(self,
X: tp.ArrayLike,
n: tp.Optional[int] = None,
min_len: int = 1,
**kwargs) -> RangesT:
"""Similar to `RollingSplitter.split`, but shrinking.

`**kwargs` are passed to `split_ranges_into_sets`."""
X = to_any_array(X)
if isinstance(X, (pd.Series, pd.DataFrame)):
index = X.index
else:
index = pd.Index(np.arange(X.shape[0]))

# Resolve start_idxs and end_idxs
start_idxs = np.arange(len(index))
end_idxs = np.full(len(index), len(index)-1)

# Filter out short ranges
window_lens = end_idxs - start_idxs + 1
min_len_mask = window_lens >= min_len
if not np.any(min_len_mask):
raise ValueError(f"There are no ranges that meet window_len>={min_len}")
start_idxs = start_idxs[min_len_mask]
end_idxs = end_idxs[min_len_mask]

# Evenly select n ranges
if n is not None:
if n > len(start_idxs):
raise ValueError(f"n cannot be bigger than the maximum number of windows {len(start_idxs)}")
idxs = np.round(np.linspace(0, len(start_idxs) - 1, n)).astype(int)
start_idxs = start_idxs[idxs]
end_idxs = end_idxs[idxs]

if 'left_to_right' not in kwargs:
kwargs['left_to_right'] = False

return split_ranges_into_sets(start_idxs, end_idxs, **kwargs)