diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 45d4772d5..56226985a 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -6,9 +6,12 @@ import lightning.pytorch as pl from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger +import numpy as np +import pandas as pd from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator from pytorch_forecasting._registry import all_objects +from pytorch_forecasting.data.timeseries import TimeSeriesDataSet from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS # whether to test only estimators from modules that are changed w.r.t. main @@ -242,6 +245,51 @@ def _integration( ) +def _timeseries_integration(model, name): + print(f"Testing {name} integration with timeseries") + n_timeseries = 10 + time_points = 10 + max_prediction_length = 2 + data = pd.DataFrame( + data={ + "target": np.random.rand(time_points * n_timeseries), + "time_idx": np.tile(np.arange(time_points), n_timeseries), + "group_id": np.repeat(np.arange(n_timeseries), time_points), + } + ) + training_dataset = TimeSeriesDataSet( + data=data, + time_idx="time_idx", + target="target", + group_ids=["group_id"], + time_varying_unknown_reals=["target"], + max_prediction_length=max_prediction_length, + max_encoder_length=3, + ) + training_data_loader = training_dataset.to_dataloader(train=True) + forecaster = model.from_dataset(training_dataset, log_val_interval=1) + trainer = pl.Trainer( + accelerator="cpu", + max_epochs=3, + min_epochs=2, + limit_train_batches=10, + ) + trainer.fit( + forecaster, + train_dataloaders=training_data_loader, + ) + validation_dataset = TimeSeriesDataSet.from_dataset( + training_dataset, data, stop_randomization=True, predict=True + ) + validation_data_loader = validation_dataset.to_dataloader(train=False) + forecaster.predict( + validation_data_loader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + ) + + class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" @@ -251,6 +299,11 @@ def test_doctest_examples(self, object_class): run_doctest(object_class, name=f"class {object_class.__name__}") + def test_timeseries_integration(self, object_class): + """Runs timeseries integration for estimator class.""" + + _timeseries_integration(object_class, name=f"class {object_class.__name__}") + def test_integration( self, object_metadata,