Skip to content

Commit d22855a

Browse files
[ENH] Allow exogenous variables in regression forecasters (#2915)
* add exogenous variable feature to base and regression forecasters
* delete self._check_X for exog

---------

Co-authored-by: Tony Bagnall <a.j.bagnall@soton.ac.uk>
1 parent 7ef70ed commit d22855a

File tree

4 files changed

+133
-24
lines changed

4 files changed

+133
-24
lines changed

aeon/forecasting/_regression.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ class RegressionForecaster(BaseForecaster):
2020
window to form training collection ``X``, take ``horizon`` points ahead to form
2121
``y``, then apply an aeon or sklearn regressor.
2222
23+
If exogenous variables are provided, they are concatenated with the main series
24+
and included in the regression windows.
2325
2426
Parameters
2527
----------
@@ -36,6 +38,10 @@ class RegressionForecaster(BaseForecaster):
3638
with sklearn regressors.
3739
"""
3840

41+
_tags = {
42+
"capability:exogenous": True,
43+
}
44+
3945
def __init__(self, window: int, horizon: int = 1, regressor=None):
4046
self.window = window
4147
self.regressor = regressor
@@ -52,8 +58,7 @@ def _fit(self, y, exog=None):
5258
y : np.ndarray
5359
A time series on which to learn a forecaster to predict horizon ahead.
5460
exog : np.ndarray, default=None
55-
Optional exogenous time series data. Included for interface
56-
compatibility but ignored in this estimator.
61+
Optional exogenous time series data, assumed to be aligned with y.
5762
5863
Returns
5964
-------
@@ -65,18 +70,38 @@ def _fit(self, y, exog=None):
6570
self.regressor_ = LinearRegression()
6671
else:
6772
self.regressor_ = self.regressor
68-
y = y.squeeze()
69-
if self.window < 1 or self.window > len(y) - 3:
73+
74+
# Combine y and exog for windowing
75+
if exog is not None:
76+
if exog.ndim == 1:
77+
exog = exog.reshape(1, -1)
78+
if exog.shape[1] != y.shape[1]:
79+
raise ValueError("y and exog must have the same number of time points.")
80+
combined_data = np.vstack([y, exog])
81+
else:
82+
combined_data = y
83+
84+
# Enforce a minimum number of training samples, currently 3
85+
if self.window < 1 or self.window >= combined_data.shape[1] - 3:
7086
raise ValueError(
71-
f" window value {self.window} is invalid for series " f"length {len(y)}"
87+
f"window value {self.window} is invalid for series length "
88+
f"{combined_data.shape[1]}"
7289
)
73-
X = np.lib.stride_tricks.sliding_window_view(y, window_shape=self.window)
74-
# Ignore the final horizon values: need to store these for pred with empty y
90+
91+
# Create windowed data for X
92+
X = np.lib.stride_tricks.sliding_window_view(
93+
combined_data, window_shape=(combined_data.shape[0], self.window)
94+
)
95+
X = X.squeeze(axis=0)
96+
X = X[:, :, :].reshape(X.shape[0], -1)
97+
98+
# Ignore the final horizon values for X
7599
X = X[: -self.horizon]
76-
# Extract y_train
77-
y_train = y[self.window + self.horizon - 1 :]
78-
self.last_ = y[-self.window :]
79-
self.last_ = self.last_.reshape(1, -1)
100+
101+
# Extract y_train from the original series
102+
y_train = y.squeeze()[self.window + self.horizon - 1 :]
103+
104+
self.last_ = combined_data[:, -self.window :]
80105
self.regressor_.fit(X=X, y=y_train)
81106
return self
82107

@@ -90,18 +115,33 @@ def _predict(self, y=None, exog=None):
90115
A time series to predict the next horizon value for. If None,
91116
predict the next horizon value after series seen in fit.
92117
exog : np.ndarray, default=None
93-
Optional exogenous time series data. Included for interface
94-
compatibility but ignored in this estimator.
118+
Optional exogenous time series data, assumed to be aligned with y.
95119
96120
Returns
97121
-------
98122
float
99123
single prediction self.horizon steps ahead of y.
100124
"""
101125
if y is None:
102-
return self.regressor_.predict(self.last_)[0]
103-
last = y[:, -self.window :]
104-
return self.regressor_.predict(last)[0]
126+
# Flatten the last window to be compatible with sklearn regressors
127+
last_window_flat = self.last_.reshape(1, -1)
128+
return self.regressor_.predict(last_window_flat)[0]
129+
130+
# Combine y and exog for prediction
131+
if exog is not None:
132+
if exog.ndim == 1:
133+
exog = exog.reshape(1, -1)
134+
if exog.shape[1] != y.shape[1]:
135+
raise ValueError("y and exog must have the same number of time points.")
136+
combined_data = np.vstack([y, exog])
137+
else:
138+
combined_data = y
139+
140+
# Extract the last window and flatten for prediction
141+
last_window = combined_data[:, -self.window :]
142+
last_window_flat = last_window.reshape(1, -1)
143+
144+
return self.regressor_.predict(last_window_flat)[0]
105145

106146
@classmethod
107147
def _get_test_params(cls, parameter_set: str = "default"):

aeon/forecasting/base.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,14 @@ def fit(self, y, exog=None):
6868
if self.get_tag("fit_is_empty"):
6969
self.is_fitted = True
7070
return self
71+
7172
horizon = self.get_tag("capability:horizon")
7273
if not horizon and self.horizon > 1:
7374
raise ValueError(
7475
f"Horizon is set >1, but {self.__class__.__name__} cannot handle a "
7576
f"horizon greater than 1"
7677
)
78+
7779
exog_tag = self.get_tag("capability:exogenous")
7880
if not exog_tag and exog is not None:
7981
raise ValueError(
@@ -83,8 +85,10 @@ def fit(self, y, exog=None):
8385

8486
self._check_X(y, self.axis)
8587
y = self._convert_y(y, self.axis)
88+
8689
if exog is not None:
87-
raise NotImplementedError("Exogenous variables not yet supported")
90+
exog = self._convert_y(exog, self.axis)
91+
8892
self.is_fitted = True
8993
return self._fit(y, exog)
9094

@@ -113,9 +117,9 @@ def predict(self, y=None, exog=None):
113117
self._check_X(y, self.axis)
114118
y = self._convert_y(y, self.axis)
115119
if exog is not None:
116-
raise NotImplementedError("Exogenous variables not yet supported")
117-
x = self._predict(y, exog)
118-
return x
120+
exog = self._convert_y(exog, self.axis)
121+
122+
return self._predict(y, exog)
119123

120124
@abstractmethod
121125
def _predict(self, y=None, exog=None): ...
@@ -141,6 +145,8 @@ def forecast(self, y, exog=None):
141145
"""
142146
self._check_X(y, self.axis)
143147
y = self._convert_y(y, self.axis)
148+
if exog is not None:
149+
exog = self._convert_y(exog, self.axis)
144150
return self._forecast(y, exog)
145151

146152
def _forecast(self, y, exog=None):
@@ -149,7 +155,7 @@ def _forecast(self, y, exog=None):
149155
return self._predict(y, exog)
150156

151157
@final
152-
def direct_forecast(self, y, prediction_horizon):
158+
def direct_forecast(self, y, prediction_horizon, exog=None):
153159
"""
154160
Make ``prediction_horizon`` ahead forecasts using a fit for each horizon.
155161
@@ -166,7 +172,8 @@ def direct_forecast(self, y, prediction_horizon):
166172
The time series to make forecasts about.
167173
prediction_horizon : int
168174
The number of future time steps to forecast.
169-
175+
exog : np.ndarray, default=None
176+
Optional exogenous time series data assumed to be aligned with y.
170177
predictions : np.ndarray
171178
An array of shape `(prediction_horizon,)` containing the forecasts for
172179
each horizon.
@@ -198,7 +205,7 @@ def direct_forecast(self, y, prediction_horizon):
198205
preds = np.zeros(prediction_horizon)
199206
for i in range(0, prediction_horizon):
200207
self.horizon = i + 1
201-
preds[i] = self.forecast(y)
208+
preds[i] = self.forecast(y, exog)
202209
return preds
203210

204211
def iterative_forecast(self, y, prediction_horizon):
@@ -263,7 +270,6 @@ def _convert_y(self, y: VALID_SERIES_INNER_TYPES, axis: int):
263270
if inner_names[0] == "ndarray":
264271
y = y.to_numpy()
265272
elif inner_names[0] == "DataFrame":
266-
# converting a 1d array will create a 2d array in axis 0 format
267273
transpose = False
268274
if y.ndim == 1 and axis == 1:
269275
transpose = True

aeon/forecasting/tests/test_base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,17 @@ def test_recursive_forecast():
6565
p = f.predict(y)
6666
assert p == preds[i]
6767
y = np.append(y, p)
68+
69+
70+
def test_direct_forecast_with_exog():
71+
"""Test direct forecasting with exogenous variables."""
72+
y = np.arange(50)
73+
exog = np.arange(50) * 2
74+
f = RegressionForecaster(window=10)
75+
76+
preds = f.direct_forecast(y, prediction_horizon=10, exog=exog)
77+
assert isinstance(preds, np.ndarray) and len(preds) == 10
78+
79+
# Check that predictions are different from when no exog is used
80+
preds_no_exog = f.direct_forecast(y, prediction_horizon=10)
81+
assert not np.array_equal(preds, preds_no_exog)

aeon/forecasting/tests/test_regressor.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,52 @@ def test_regression_forecaster():
3232
with pytest.raises(ValueError):
3333
f = RegressionForecaster(window=101)
3434
f.fit(y)
35+
36+
37+
def test_regression_forecaster_with_exog():
38+
"""Test the regression forecaster with exogenous variables."""
39+
np.random.seed(0)
40+
41+
n_samples = 100
42+
exog = np.random.rand(n_samples) * 10
43+
y = 2 * exog + np.random.rand(n_samples) * 0.1
44+
45+
f = RegressionForecaster(window=10)
46+
47+
# Test fit and predict with exog
48+
f.fit(y, exog=exog)
49+
p1 = f.predict()
50+
assert isinstance(p1, float)
51+
52+
# Test that exog variable has an impact
53+
exog_zeros = np.zeros(n_samples)
54+
f.fit(y, exog=exog_zeros)
55+
p2 = f.predict()
56+
assert p1 != p2
57+
58+
# Test that forecast method works and is equivalent to fit+predict
59+
y_new = np.arange(50, 150)
60+
exog_new = np.arange(50, 150) * 2
61+
62+
# Manual fit + predict
63+
f.fit(y=y_new, exog=exog_new)
64+
p_manual = f.predict()
65+
66+
# forecast() method
67+
p_forecast = f.forecast(y=y_new, exog=exog_new)
68+
assert p_manual == pytest.approx(p_forecast)
69+
70+
71+
def test_regression_forecaster_with_exog_errors():
72+
"""Test errors in regression forecaster with exogenous variables."""
73+
y = np.random.rand(100)
74+
exog_short = np.random.rand(99)
75+
f = RegressionForecaster(window=10)
76+
77+
# Test for unequal length series
78+
with pytest.raises(ValueError, match="must have the same number of time points"):
79+
f.fit(y, exog=exog_short)
80+
81+
with pytest.raises(ValueError, match="must have the same number of time points"):
82+
f.fit(y)
83+
f.predict(y, exog=exog_short)

0 commit comments

Comments
 (0)