From af94a866b7ea9604043c421a7db3edf17925e29d Mon Sep 17 00:00:00 2001 From: lucifer4073 Date: Thu, 22 May 2025 21:51:48 +0530 Subject: [PATCH 1/6] Basedeep forecaster added --- aeon/forecasting/deep_learning/__init__.py | 1 + aeon/forecasting/deep_learning/base.py | 215 +++++++++++++++++++++ 2 files changed, 216 insertions(+) create mode 100644 aeon/forecasting/deep_learning/__init__.py create mode 100644 aeon/forecasting/deep_learning/base.py diff --git a/aeon/forecasting/deep_learning/__init__.py b/aeon/forecasting/deep_learning/__init__.py new file mode 100644 index 0000000000..42067031dc --- /dev/null +++ b/aeon/forecasting/deep_learning/__init__.py @@ -0,0 +1 @@ +"""Initialization for aeon forecasting deep learning module.""" diff --git a/aeon/forecasting/deep_learning/base.py b/aeon/forecasting/deep_learning/base.py new file mode 100644 index 0000000000..17eead9c49 --- /dev/null +++ b/aeon/forecasting/deep_learning/base.py @@ -0,0 +1,215 @@ +""" +BaseDeepForecaster class. + +A simplified first base class for deep learning forecasting models. +This class is a subclass of BaseForecaster and inherits its methods and attributes. +It provides a base for deep learning models, including methods for training and +predicting. + +""" + +from abc import abstractmethod + +import numpy as np +import pandas as pd +import tensorflow as tf + +from aeon.forecasting.base import BaseForecaster + + +class BaseDeepForecaster(BaseForecaster): + """Base class for deep learning forecasters in aeon. + + Parameters + ---------- + horizon : int, default=1 + Forecasting horizon, the number of steps ahead to predict. + window : int, default=10 + The window size for creating input sequences. + batch_size : int, default=32 + Batch size for training the model. + epochs : int, default=100 + Number of epochs to train the model. + verbose : int, default=0 + Verbosity mode (0, 1, or 2). + optimizer : str or tf.keras.optimizers.Optimizer, default='adam' + Optimizer to use for training. + loss : str or tf.keras.losses.Loss, default='mse' + Loss function for training. + random_state : int, default=None + Seed for random number generators. + """ + + def __init__( + self, + horizon=1, + window=10, + batch_size=32, + epochs=100, + verbose=0, + optimizer="adam", + loss="mse", + random_state=None, + ): + self.horizon = horizon + self.window = window + self.batch_size = batch_size + self.epochs = epochs + self.verbose = verbose + self.optimizer = optimizer + self.loss = loss + self.random_state = random_state + self.model_ = None + super().__init__() + + def _fit(self, y, X=None): + """Fit the forecaster to training data. + + Parameters + ---------- + y : np.ndarray or pd.Series + Target time series to which to fit the forecaster. + X : np.ndarray or pd.DataFrame, default=None + Exogenous variables. + + Returns + ------- + self : returns an instance of self + """ + # Set random seed for reproducibility + if self.random_state is not None: + np.random.seed(self.random_state) + tf.random.set_seed(self.random_state) + + # Convert input data to numpy array + y_inner = self._convert_input(y) + + # Create sequences for training + X_train, y_train = self._create_sequences(y_inner) + + # Build and compile the model + self.model_ = self._build_model(X_train.shape[1:]) + self.model_.compile(optimizer=self.optimizer, loss=self.loss) + + # Train the model + self.model_.fit( + X_train, + y_train, + batch_size=self.batch_size, + epochs=self.epochs, + verbose=self.verbose, + ) + + return self + + def _predict(self, y=None, X=None): + """Make forecasts for y. + + Parameters + ---------- + y : np.ndarray or pd.Series, default=None + Series to predict from. + X : np.ndarray or pd.DataFrame, default=None + Exogenous variables. + + Returns + ------- + predictions : np.ndarray + Predicted values. + """ + if y is None: + raise ValueError("y cannot be None for prediction") + + # Convert input data to numpy array + y_inner = self._convert_input(y) + + # Use the last window of data for prediction + last_window = y_inner[-self.window :].reshape(1, self.window, 1) + + # Make prediction + prediction = self.model_.predict(last_window, verbose=0) + + return prediction.flatten() + + def _forecast(self, y, X=None): + """Forecast time series at future horizon. + + Parameters + ---------- + y : np.ndarray or pd.Series + Time series to forecast from. + X : np.ndarray or pd.DataFrame, default=None + Exogenous variables. + + Returns + ------- + forecasts : np.ndarray + Forecasted values. + """ + # Fit the model + self._fit(y, X) + + # Make prediction + return self._predict(y, X) + + def _convert_input(self, y): + """Convert input data to numpy array. + + Parameters + ---------- + y : np.ndarray or pd.Series + Input time series. + + Returns + ------- + y_inner : np.ndarray + Converted numpy array. + """ + if isinstance(y, pd.Series) or isinstance(y, pd.DataFrame): + y_inner = y.values + else: + y_inner = y + + # Ensure 1D array + if len(y_inner.shape) > 1: + y_inner = y_inner.flatten() + + return y_inner + + def _create_sequences(self, data): + """Create input sequences and target values for training. + + Parameters + ---------- + data : np.ndarray + Time series data. + + Returns + ------- + X : np.ndarray + Input sequences. + y : np.ndarray + Target values. + """ + X, y = [], [] + for i in range(len(data) - self.window - self.horizon + 1): + X.append(data[i : (i + self.window)]) + y.append(data[i + self.window : (i + self.window + self.horizon)]) + + return np.array(X).reshape(-1, self.window, 1), np.array(y) + + @abstractmethod + def _build_model(self, input_shape): + """Build the deep learning model. + + Parameters + ---------- + input_shape : tuple + Shape of input data. + + Returns + ------- + model : tf.keras.Model + Compiled Keras model. + """ + pass From d2ee9ec5acda38de318ad22c1df0563b3f9d526f Mon Sep 17 00:00:00 2001 From: lucifer4073 Date: Mon, 26 May 2025 20:30:50 +0530 Subject: [PATCH 2/6] init for basedlf added --- aeon/forecasting/deep_learning/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/aeon/forecasting/deep_learning/__init__.py b/aeon/forecasting/deep_learning/__init__.py index 42067031dc..c4b7a27030 100644 --- a/aeon/forecasting/deep_learning/__init__.py +++ b/aeon/forecasting/deep_learning/__init__.py @@ -1 +1,7 @@ """Initialization for aeon forecasting deep learning module.""" + +__all__ = [ + "BaseDeepForecaster", +] + +from aeon.forecasting.deep_learning.base import BaseDeepForecaster From ab3030c0767fb773714c523a32963aa99c18c078 Mon Sep 17 00:00:00 2001 From: lucifer4073 Date: Sun, 15 Jun 2025 16:13:43 +0530 Subject: [PATCH 3/6] test file and axis added for basedeepforecaster --- aeon/forecasting/deep_learning/base.py | 76 ++++++++++++++----- .../deep_learning/tests/__init__.py | 1 + .../deep_learning/tests/test_base.py | 62 +++++++++++++++ 3 files changed, 120 insertions(+), 19 deletions(-) create mode 100644 aeon/forecasting/deep_learning/tests/__init__.py create mode 100644 aeon/forecasting/deep_learning/tests/test_base.py diff --git a/aeon/forecasting/deep_learning/base.py b/aeon/forecasting/deep_learning/base.py index 17eead9c49..ebab0116bd 100644 --- a/aeon/forecasting/deep_learning/base.py +++ b/aeon/forecasting/deep_learning/base.py @@ -1,12 +1,13 @@ +"""Base class module for deep learning forecasters in aeon. + +This module defines the `BaseDeepForecaster` class, an abstract base class for +deep learning-based forecasting models within the aeon toolkit. """ -BaseDeepForecaster class. -A simplified first base class for deep learning forecasting models. -This class is a subclass of BaseForecaster and inherits its methods and attributes. -It provides a base for deep learning models, including methods for training and -predicting. +from __future__ import annotations -""" +__maintainer__ = [] +__all__ = ["BaseDeepForecaster"] from abc import abstractmethod @@ -20,6 +21,9 @@ class BaseDeepForecaster(BaseForecaster): """Base class for deep learning forecasters in aeon. + This class provides a foundation for deep learning-based forecasting models, + handling data preprocessing, model training, and prediction. + Parameters ---------- horizon : int, default=1 @@ -38,6 +42,9 @@ class BaseDeepForecaster(BaseForecaster): Loss function for training. random_state : int, default=None Seed for random number generators. + axis : int, default=0 + Axis along which to apply the forecaster. + Default is 0 for univariate time series. """ def __init__( @@ -50,6 +57,7 @@ def __init__( optimizer="adam", loss="mse", random_state=None, + axis=0, ): self.horizon = horizon self.window = window @@ -59,8 +67,11 @@ def __init__( self.optimizer = optimizer self.loss = loss self.random_state = random_state + self.axis = axis self.model_ = None - super().__init__() + + # Pass horizon and axis to BaseForecaster + super().__init__(horizon=horizon, axis=axis) def _fit(self, y, X=None): """Fit the forecaster to training data. @@ -74,7 +85,8 @@ def _fit(self, y, X=None): Returns ------- - self : returns an instance of self + self : BaseDeepForecaster + Returns an instance of self. """ # Set random seed for reproducibility if self.random_state is not None: @@ -83,12 +95,21 @@ def _fit(self, y, X=None): # Convert input data to numpy array y_inner = self._convert_input(y) + if y_inner.shape[0] < self.window + self.horizon: + raise ValueError( + f"Data length ({y_inner.shape[0]}) is insufficient" + f"({self.window}) and horizon ({self.horizon})." + ) # Create sequences for training X_train, y_train = self._create_sequences(y_inner) + if X_train.shape[0] == 0: + raise ValueError("No training sequences could be created.") + # Build and compile the model - self.model_ = self._build_model(X_train.shape[1:]) + input_shape = X_train.shape[1:] + self.model_ = self._build_model(input_shape) self.model_.compile(optimizer=self.optimizer, loss=self.loss) # Train the model @@ -115,7 +136,7 @@ def _predict(self, y=None, X=None): Returns ------- predictions : np.ndarray - Predicted values. + Predicted values for the specified horizon. """ if y is None: raise ValueError("y cannot be None for prediction") @@ -123,13 +144,26 @@ def _predict(self, y=None, X=None): # Convert input data to numpy array y_inner = self._convert_input(y) + if len(y_inner) < self.window: + raise ValueError( + f"Input data length ({len(y_inner)}) is less than the window size " + f"({self.window})." + ) + # Use the last window of data for prediction last_window = y_inner[-self.window :].reshape(1, self.window, 1) # Make prediction - prediction = self.model_.predict(last_window, verbose=0) + predictions = [] + current_window = last_window + for _ in range(self.horizon): + pred = self.model_.predict(current_window, verbose=0) + predictions.append(pred[0, 0]) + # Update the window with the latest prediction (autoregressive) + current_window = np.roll(current_window, -1, axis=1) + current_window[0, -1, 0] = pred[0, 0] - return prediction.flatten() + return np.array(predictions) def _forecast(self, y, X=None): """Forecast time series at future horizon. @@ -144,13 +178,9 @@ def _forecast(self, y, X=None): Returns ------- forecasts : np.ndarray - Forecasted values. + Forecasted values for the specified horizon. """ - # Fit the model - self._fit(y, X) - - # Make prediction - return self._predict(y, X) + return self._fit(y, X)._predict(y, X) def _convert_input(self, y): """Convert input data to numpy array. @@ -191,12 +221,20 @@ def _create_sequences(self, data): y : np.ndarray Target values. """ + if len(data) < self.window + self.horizon: + raise ValueError( + f"Data length ({len(data)}) is insufficient for window " + f"({self.window}) and horizon ({self.horizon})." + ) + X, y = [], [] for i in range(len(data) - self.window - self.horizon + 1): X.append(data[i : (i + self.window)]) y.append(data[i + self.window : (i + self.window + self.horizon)]) - return np.array(X).reshape(-1, self.window, 1), np.array(y) + X = np.array(X).reshape(-1, self.window, 1) + y = np.array(y).reshape(-1, self.horizon) + return X, y @abstractmethod def _build_model(self, input_shape): diff --git a/aeon/forecasting/deep_learning/tests/__init__.py b/aeon/forecasting/deep_learning/tests/__init__.py new file mode 100644 index 0000000000..3dda9d25ea --- /dev/null +++ b/aeon/forecasting/deep_learning/tests/__init__.py @@ -0,0 +1 @@ +"""Deep Learning Forecasting Tests File.""" diff --git a/aeon/forecasting/deep_learning/tests/test_base.py b/aeon/forecasting/deep_learning/tests/test_base.py new file mode 100644 index 0000000000..05536f98c5 --- /dev/null +++ b/aeon/forecasting/deep_learning/tests/test_base.py @@ -0,0 +1,62 @@ +"""Test for BaseDeepForecaster class in aeon.""" + +import numpy as np +import pytest + +from aeon.forecasting.deep_learning import BaseDeepForecaster +from aeon.utils.validation._dependencies import _check_soft_dependencies + + +class SimpleDeepForecaster(BaseDeepForecaster): + """A simple concrete implementation of BaseDeepForecaster for testing.""" + + def _build_model(self, input_shape): + import tensorflow as tf + + model = tf.keras.Sequential( + [ + tf.keras.layers.Flatten(input_shape=input_shape), + tf.keras.layers.Dense(10, activation="relu"), + tf.keras.layers.Dense(self.horizon), + ] + ) + return model + + +@pytest.mark.skipif( + not _check_soft_dependencies("tensorflow", severity="none"), + reason="skip test if required soft dependency not available", +) +def test_base_deep_forecaster_fit_predict(): + """Test fitting and predicting with BaseDeepForecaster implementation.""" + # Generate synthetic data + np.random.seed(42) + data = np.random.randn(50) + + # Initialize forecaster + forecaster = SimpleDeepForecaster(horizon=2, window=5, epochs=1, verbose=0) + + # Fit the model + forecaster.fit(data) + + # Predict + predictions = forecaster.predict(data) + + # Validate output shape + assert ( + len(predictions) == 2 + ), f"Expected predictions of length 2, got {len(predictions)}" + assert isinstance(predictions, np.ndarray), "Predictions should be a numpy array" + + +@pytest.mark.skipif( + not _check_soft_dependencies("tensorflow", severity="none"), + reason="skip test if required soft dependency not available", +) +def test_base_deep_forecaster_insufficient_data(): + """Test error handling for insufficient data.""" + data = np.random.randn(5) + forecaster = SimpleDeepForecaster(horizon=2, window=5, epochs=1, verbose=0) + + with pytest.raises(ValueError, match="Data length.*insufficient"): + forecaster.fit(data) From 1f202db1cae45834503986ffb37859599b29759a Mon Sep 17 00:00:00 2001 From: lucifer4073 Date: Sun, 15 Jun 2025 16:38:39 +0530 Subject: [PATCH 4/6] test locally --- .github/workflows/pr_pytest.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr_pytest.yml b/.github/workflows/pr_pytest.yml index cf1baee900..9c70b16b4c 100644 --- a/.github/workflows/pr_pytest.yml +++ b/.github/workflows/pr_pytest.yml @@ -3,7 +3,7 @@ name: PR pytest on: push: branches: - - main + - basedlf pull_request: paths: - "aeon/**" From 14eb41fa83a5799d0fa8608ffd516f1766da7a1c Mon Sep 17 00:00:00 2001 From: lucifer4073 Date: Sun, 15 Jun 2025 17:19:27 +0530 Subject: [PATCH 5/6] dlf corrected --- .github/workflows/pr_pytest.yml | 2 +- aeon/forecasting/deep_learning/tests/test_base.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pr_pytest.yml b/.github/workflows/pr_pytest.yml index 9c70b16b4c..cf1baee900 100644 --- a/.github/workflows/pr_pytest.yml +++ b/.github/workflows/pr_pytest.yml @@ -3,7 +3,7 @@ name: PR pytest on: push: branches: - - basedlf + - main pull_request: paths: - "aeon/**" diff --git a/aeon/forecasting/deep_learning/tests/test_base.py b/aeon/forecasting/deep_learning/tests/test_base.py index 05536f98c5..1eae0969e1 100644 --- a/aeon/forecasting/deep_learning/tests/test_base.py +++ b/aeon/forecasting/deep_learning/tests/test_base.py @@ -10,6 +10,9 @@ class SimpleDeepForecaster(BaseDeepForecaster): """A simple concrete implementation of BaseDeepForecaster for testing.""" + def __init__(self, horizon=1, window=5, epochs=1, verbose=0): + super().__init__(horizon=horizon, window=window, epochs=epochs, verbose=verbose) + def _build_model(self, input_shape): import tensorflow as tf From d1a2aab72097dd38658fa0a10f572005c5b70aaa Mon Sep 17 00:00:00 2001 From: lucifer4073 Date: Sun, 22 Jun 2025 12:27:15 +0530 Subject: [PATCH 6/6] tf soft dep added --- aeon/forecasting/deep_learning/base.py | 3 ++- aeon/forecasting/deep_learning/tests/test_base.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/aeon/forecasting/deep_learning/base.py b/aeon/forecasting/deep_learning/base.py index ebab0116bd..ba33331fc7 100644 --- a/aeon/forecasting/deep_learning/base.py +++ b/aeon/forecasting/deep_learning/base.py @@ -13,7 +13,6 @@ import numpy as np import pandas as pd -import tensorflow as tf from aeon.forecasting.base import BaseForecaster @@ -88,6 +87,8 @@ def _fit(self, y, X=None): self : BaseDeepForecaster Returns an instance of self. """ + import tensorflow as tf + # Set random seed for reproducibility if self.random_state is not None: np.random.seed(self.random_state) diff --git a/aeon/forecasting/deep_learning/tests/test_base.py b/aeon/forecasting/deep_learning/tests/test_base.py index 1eae0969e1..270a60225e 100644 --- a/aeon/forecasting/deep_learning/tests/test_base.py +++ b/aeon/forecasting/deep_learning/tests/test_base.py @@ -7,6 +7,10 @@ from aeon.utils.validation._dependencies import _check_soft_dependencies +@pytest.mark.skipif( + not _check_soft_dependencies("tensorflow", severity="none"), + reason="skip test if required soft dependency not available", +) class SimpleDeepForecaster(BaseDeepForecaster): """A simple concrete implementation of BaseDeepForecaster for testing."""