Skip to content

Commit f110796

Browse files
authored
[DOC] Docstring improved for dummy regressor (#2839)
* Docstring improved for dummy regressor * convo resolved
1 parent a892641 commit f110796

File tree

1 file changed

+49
-28
lines changed

1 file changed

+49
-28
lines changed

aeon/regression/_dummy.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,68 +3,84 @@
33
__maintainer__ = ["MatthewMiddlehurst"]
44
__all__ = ["DummyRegressor"]
55

6-
import numpy as np
76
from sklearn.dummy import DummyRegressor as SklearnDummyRegressor
87

98
from aeon.regression.base import BaseRegressor
109

1110

1211
class DummyRegressor(BaseRegressor):
13-
"""
14-
DummyRegressor makes predictions that ignore the input features.
12+
"""Dummy regressor that makes predictions ignoring input features.
1513
16-
This regressor is a wrapper for the scikit-learn DummyClassifier that serves as a
17-
simple baseline to compare against other more complex regressors.
18-
The specific behaviour of the baseline is selected with the ``strategy`` parameter.
14+
This regressor serves as a simple baseline to compare against other, more
15+
complex regressors. It is a wrapper for scikit-learn's DummyRegressor that
16+
has been adapted for aeon's time series regression framework. The specific
17+
behavior is controlled by the ``strategy`` parameter.
1918
2019
All strategies make predictions that ignore the input feature values passed
2120
as the ``X`` argument to ``fit`` and ``predict``. The predictions, however,
2221
typically depend on values observed in the ``y`` parameter passed to ``fit``.
2322
24-
Function-identical to ``sklearn.dummy.DummyRegressor``, which is called inside.
25-
2623
Parameters
2724
----------
2825
strategy : {"mean", "median", "quantile", "constant"}, default="mean"
29-
Strategy to use to generate predictions.
30-
* "mean": always predicts the mean of the training set
31-
* "median": always predicts the median of the training set
32-
* "quantile": always predicts a specified quantile of the training set,
33-
provided with the quantile parameter.
34-
* "constant": always predicts a constant value that is provided by
35-
the user.
36-
constant : int or float or array-like of shape (n_outputs,), default=None
37-
The explicit constant as predicted by the "constant" strategy. This
38-
parameter is useful only for the "constant" strategy.
26+
Strategy to use to generate predictions:
27+
28+
- "mean": always predicts the mean of the training set
29+
- "median": always predicts the median of the training set
30+
- "quantile": always predicts a specified quantile of the training set,
31+
provided with the ``quantile`` parameter
32+
- "constant": always predicts a constant value provided by the user
33+
34+
constant : int, float or array-like of shape (n_outputs,), default=None
35+
The explicit constant value predicted by the "constant" strategy.
36+
This parameter is only used when ``strategy="constant"``.
37+
3938
quantile : float in [0.0, 1.0], default=None
40-
The quantile to predict using the "quantile" strategy. A quantile of
41-
0.5 corresponds to the median, while 0.0 to the minimum and 1.0 to the
39+
The quantile to predict when using the "quantile" strategy. A quantile
40+
of 0.5 corresponds to the median, 0.0 to the minimum, and 1.0 to the
4241
maximum.
4342
43+
Attributes
44+
----------
45+
sklearn_dummy_regressor : sklearn.dummy.DummyRegressor
46+
The underlying scikit-learn DummyRegressor instance.
47+
48+
Notes
49+
-----
50+
Function-identical to ``sklearn.dummy.DummyRegressor``, which is called inside.
51+
This class has been adapted to work with aeon's time series regression framework.
52+
4453
Examples
4554
--------
4655
>>> from aeon.regression._dummy import DummyRegressor
4756
>>> from aeon.datasets import load_covid_3month
4857
>>> X_train, y_train = load_covid_3month(split="train")
4958
>>> X_test, y_test = load_covid_3month(split="test")
5059
60+
Using mean strategy:
61+
5162
>>> reg = DummyRegressor(strategy="mean")
5263
>>> reg.fit(X_train, y_train)
5364
DummyRegressor()
5465
>>> reg.predict(X_test)[:5]
5566
array([0.03689763, 0.03689763, 0.03689763, 0.03689763, 0.03689763])
5667
68+
Using quantile strategy:
69+
5770
>>> reg = DummyRegressor(strategy="quantile", quantile=0.75)
5871
>>> reg.fit(X_train, y_train)
5972
DummyRegressor(quantile=0.75, strategy='quantile')
6073
>>> reg.predict(X_test)[:5]
6174
array([0.05559524, 0.05559524, 0.05559524, 0.05559524, 0.05559524])
6275
76+
Using constant strategy:
77+
6378
>>> reg = DummyRegressor(strategy="constant", constant=0.5)
6479
>>> reg.fit(X_train, y_train)
6580
DummyRegressor(constant=0.5, strategy='constant')
6681
>>> reg.predict(X_test)[:5]
6782
array([0.5, 0.5, 0.5, 0.5, 0.5])
83+
6884
"""
6985

7086
_tags = {
@@ -86,29 +102,34 @@ def __init__(self, strategy="mean", constant=None, quantile=None):
86102
super().__init__()
87103

88104
def _fit(self, X, y):
89-
"""Fit the dummy regressor.
105+
"""Fit the dummy regressor to training data.
90106
91107
Parameters
92108
----------
93-
X : 3D np.ndarray of shape [n_cases, n_channels, n_timepoints]
94-
y : array-like, shape = [n_cases] - the target values
109+
X : np.ndarray of shape (n_cases, n_channels, n_timepoints)
110+
The training time series data.
111+
y : array-like of shape (n_cases,)
112+
The target values for training.
95113
96114
Returns
97115
-------
98-
self : reference to self.
116+
self : DummyRegressor
117+
Reference to the fitted regressor.
99118
"""
100119
self.sklearn_dummy_regressor.fit(X, y)
101120
return self
102121

103-
def _predict(self, X) -> np.ndarray:
104-
"""Perform regression on test vectors X.
122+
def _predict(self, X):
123+
"""Make predictions on test data.
105124
106125
Parameters
107126
----------
108-
X : 3D np.ndarray of shape [n_cases, n_channels, n_timepoints]
127+
X : np.ndarray of shape (n_cases, n_channels, n_timepoints)
128+
The test time series data.
109129
110130
Returns
111131
-------
112-
y : predictions of target values for X, np.ndarray
132+
y_pred : np.ndarray of shape (n_cases,)
133+
Predicted target values for X.
113134
"""
114135
return self.sklearn_dummy_regressor.predict(X)

0 commit comments

Comments
 (0)