diff --git a/pytorch_forecasting/models/base/__init__.py b/pytorch_forecasting/models/base/__init__.py index 7b69ec246..a0ac824e1 100644 --- a/pytorch_forecasting/models/base/__init__.py +++ b/pytorch_forecasting/models/base/__init__.py @@ -10,11 +10,13 @@ from pytorch_forecasting.models.base._base_object import ( _BaseObject, _BasePtForecaster, + _BasePtForecasterV2, ) __all__ = [ "_BaseObject", "_BasePtForecaster", + "_BasePtForecasterV2", "AutoRegressiveBaseModel", "AutoRegressiveBaseModelWithCovariates", "BaseModel", diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 75ff45f9b..12e37fcc5 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -11,16 +11,12 @@ class _BaseObject(_SkbaseBaseObject): pass -class _BasePtForecaster(_BaseObject): +class _BasePtForecaster_Common(_BaseObject): """Base class for all PyTorch Forecasting forecaster packages. This class points to model objects and contains metadata as tags. 
""" - _tags = { - "object_type": "forecaster_pytorch", - } - @classmethod def get_model_cls(cls): """Get model class.""" @@ -112,3 +108,19 @@ def create_test_instances_and_names(cls, parameter_set="default"): names = [cls.__name__] return objs, names + + +class _BasePtForecaster(_BasePtForecaster_Common): + """Base class for PyTorch Forecasting v1 forecasters.""" + + _tags = { + "object_type": ["forecaster_pytorch", "forecaster_pytorch_v1"], + } + + +class _BasePtForecasterV2(_BasePtForecaster_Common): + """Base class for PyTorch Forecasting v2 forecasters.""" + + _tags = { + "object_type": "forecaster_pytorch_v2", + } diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index c823d6229..66ba9b58a 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -3,6 +3,9 @@ from pytorch_forecasting.models.temporal_fusion_transformer._tft import ( TemporalFusionTransformer, ) +from pytorch_forecasting.models.temporal_fusion_transformer._tft_pkg_v2 import ( + TFT_pkg_v2, +) from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import ( AddNorm, GateAddNorm, @@ -19,5 +22,6 @@ "GatedLinearUnit", "GatedResidualNetwork", "InterpretableMultiHeadAttention", + "TFT_pkg_v2", "VariableSelectionNetwork", ] diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py new file mode 100644 index 000000000..5b9bfe6c7 --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py @@ -0,0 +1,139 @@ +"""TFT package container.""" + +from pytorch_forecasting.models.base import _BasePtForecasterV2 + + +class TFT_pkg_v2(_BasePtForecasterV2): + """TFT package container.""" + + _tags = { + "info:name": "TFT", + "authors": ["phoeenniixx"], + "capability:exogenous": True, 
+ "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": False, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT + + return TFT + + @classmethod + def _get_test_datamodule_from(cls, trainer_kwargs): + """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + from pytorch_forecasting.data.data_module import ( + EncoderDecoderTimeSeriesDataModule, + ) + from pytorch_forecasting.tests._data_scenarios import ( + data_with_covariates_v2, + make_datasets_v2, + ) + + data_with_covariates = data_with_covariates_v2() + + data_loader_default_kwargs = dict( + target="target", + group_ids=["agency_encoded", "sku_encoded"], + add_relative_time_idx=True, + ) + + data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) + data_loader_default_kwargs.update(data_loader_kwargs) + + datasets_info = make_datasets_v2( + data_with_covariates, **data_loader_default_kwargs + ) + + training_dataset = datasets_info["training_dataset"] + validation_dataset = datasets_info["validation_dataset"] + training_max_time_idx = datasets_info["training_max_time_idx"] + + max_encoder_length = data_loader_kwargs.get("max_encoder_length", 4) + max_prediction_length = data_loader_kwargs.get("max_prediction_length", 3) + add_relative_time_idx = data_loader_kwargs.get("add_relative_time_idx", True) + batch_size = data_loader_kwargs.get("batch_size", 2) + + train_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=training_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + add_relative_time_idx=add_relative_time_idx, + batch_size=batch_size, + train_val_test_split=(0.8, 0.2, 0.0), + ) + + val_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + 
max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=batch_size, + train_val_test_split=(0.0, 1.0, 0.0), + ) + + test_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=1, + train_val_test_split=(0.0, 0.0, 1.0), + ) + + train_datamodule.setup("fit") + val_datamodule.setup("fit") + test_datamodule.setup("test") + + train_dataloader = train_datamodule.train_dataloader() + val_dataloader = val_datamodule.val_dataloader() + test_dataloader = test_datamodule.test_dataloader() + + return { + "train": train_dataloader, + "val": val_dataloader, + "test": test_dataloader, + "data_module": train_datamodule, + } + + @classmethod + def get_test_train_params(cls): + """Return testing parameter settings for the trainer. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. 
+ `create_test_instance` uses the first (or only) dictionary in `params` + """ + return [ + {}, + dict( + hidden_size=25, + attention_head_size=5, + ), + dict( + data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3) + ), + dict( + hidden_size=24, + attention_head_size=8, + data_loader_kwargs=dict( + max_encoder_length=5, + max_prediction_length=3, + add_relative_time_idx=False, + ), + ), + dict( + hidden_size=12, + data_loader_kwargs=dict(max_encoder_length=7, max_prediction_length=10), + ), + dict(attention_head_size=2), + ] diff --git a/pytorch_forecasting/tests/_data_scenarios.py b/pytorch_forecasting/tests/_data_scenarios.py index c79b202d4..c13ff0ae5 100644 --- a/pytorch_forecasting/tests/_data_scenarios.py +++ b/pytorch_forecasting/tests/_data_scenarios.py @@ -1,9 +1,13 @@ +from datetime import datetime + import numpy as np +import pandas as pd import torch from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data +from pytorch_forecasting.data.timeseries import TimeSeries torch.manual_seed(23) @@ -78,6 +82,187 @@ def make_dataloaders(data_with_covariates, **kwargs): return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) +def data_with_covariates_v2(): + """Create synthetic time series data with all numerical features.""" + + start_date = datetime(2015, 1, 1) + end_date = datetime(2017, 12, 31) + dates = pd.date_range(start_date, end_date, freq="M") + + agencies = [0, 1] + skus = [0, 1] + data_list = [] + + for agency in agencies: + for sku in skus: + for date in dates: + time_idx = (date.year - 2015) * 12 + date.month - 1 + + volume = ( + np.random.exponential(2) + + 0.1 * time_idx + + 0.5 * np.sin(date.month * np.pi / 6) + ) + volume = max(0.001, volume) + month = date.month + year = date.year + quarter = (date.month - 1) // 3 + 1 + + seasonal_1 = np.sin(2 * 
np.pi * date.month / 12) + seasonal_2 = np.cos(2 * np.pi * date.month / 12) + + agency_feature_1 = agency * 10 + np.random.normal(0, 0.1) + agency_feature_2 = agency * 5 + np.random.normal(0, 0.1) + + sku_feature_1 = sku * 8 + np.random.normal(0, 0.1) + sku_feature_2 = sku * 3 + np.random.normal(0, 0.1) + + trend = time_idx * 0.1 + noise = np.random.normal(0, 0.1) + + special_event_1 = 1 if date.month in [12, 1] else 0 + special_event_2 = 1 if date.month in [6, 7, 8] else 0 + + data_list.append( + { + "date": date, + "time_idx": time_idx, + "agency_encoded": agency, + "sku_encoded": sku, + "volume": volume, + "target": volume, + "weight": 1.0 + np.sqrt(volume), + "month": month, + "year": year, + "quarter": quarter, + "seasonal_1": seasonal_1, + "seasonal_2": seasonal_2, + "agency_feature_1": agency_feature_1, + "agency_feature_2": agency_feature_2, + "sku_feature_1": sku_feature_1, + "sku_feature_2": sku_feature_2, + "trend": trend, + "noise": noise, + "special_event_1": special_event_1, + "special_event_2": special_event_2, + "log_volume": np.log1p(volume), + } + ) + + data = pd.DataFrame(data_list) + + numeric_cols = [col for col in data.columns if col != "date"] + for col in numeric_cols: + data[col] = pd.to_numeric(data[col], errors="coerce") + data[numeric_cols] = data[numeric_cols].fillna(0) + + return data + + +def make_datasets_v2(data_with_covariates, **kwargs): + """Create datasets with consistent encoder/decoder features.""" + + training_cutoff = "2016-09-01" + target_col = kwargs.get("target", "target") + group_cols = kwargs.get("group_ids", ["agency_encoded", "sku_encoded"]) + + known_features = [ + "month", + "year", + "quarter", + "seasonal_1", + "seasonal_2", + "special_event_1", + "special_event_2", + "trend", + ] + unknown_features = [ + "agency_feature_1", + "agency_feature_2", + "sku_feature_1", + "sku_feature_2", + "noise", + "log_volume", + ] + + numerical_features = known_features + unknown_features + categorical_features = [] + 
static_features = group_cols + + for col in numerical_features + categorical_features + group_cols + [target_col]: + if col in data_with_covariates.columns: + data_with_covariates[col] = pd.to_numeric( + data_with_covariates[col], errors="coerce" + ).fillna(0) + + for col in categorical_features + group_cols: + if col in data_with_covariates.columns: + data_with_covariates[col] = data_with_covariates[col].astype("int64") + + if "weight" in data_with_covariates.columns: + data_with_covariates["weight"] = pd.to_numeric( + data_with_covariates["weight"], errors="coerce" + ).fillna(1.0) + + training_data = data_with_covariates[ + data_with_covariates.date < training_cutoff + ].copy() + validation_data = data_with_covariates.copy() + + required_columns = ( + ["time_idx", target_col, "weight", "date"] + + group_cols + + numerical_features + + categorical_features + ) + + available_columns = [ + col for col in required_columns if col in data_with_covariates.columns + ] + + training_data_clean = training_data[available_columns].copy() + validation_data_clean = validation_data[available_columns].copy() + + if "date" in training_data_clean.columns: + training_data_clean = training_data_clean.drop("date", axis=1) + if "date" in validation_data_clean.columns: + validation_data_clean = validation_data_clean.drop("date", axis=1) + + training_dataset = TimeSeries( + data=training_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + validation_dataset = TimeSeries( + data=validation_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + training_max_time_idx = 
training_data["time_idx"].max() + 1 + + return { + "training_dataset": training_dataset, + "validation_dataset": validation_dataset, + "training_max_time_idx": training_max_time_idx, + } + + def dataloaders_with_different_encoder_decoder_length(): return make_dataloaders( data_with_covariates(), diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index d1c4bb754..4dad73ab8 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -250,6 +250,8 @@ def _integration( class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" + object_type_filter = "forecaster_pytorch_v1" + def test_doctest_examples(self, object_class): """Runs doctests for estimator class.""" from skbase.utils.doctest_run import run_doctest diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py new file mode 100644 index 000000000..4b063ed44 --- /dev/null +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -0,0 +1,117 @@ +"""Automated tests based on the skbase test suite template.""" + +from inspect import isclass +import shutil + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +from lightning.pytorch.loggers import TensorBoardLogger +import torch.nn as nn + +from pytorch_forecasting.tests.test_all_estimators import ( + BaseFixtureGenerator, + PackageConfig, +) + +# whether to test only estimators from modules that are changed w.r.t. 
main +# default is False, can be set to True by pytest --only_changed_modules True flag +ONLY_CHANGED_MODULES = False + + +def _integration( + estimator_cls, + dataloaders, + tmp_path, + data_loader_kwargs={}, + clip_target: bool = False, + trainer_kwargs=None, + **kwargs, +): + train_dataloader = dataloaders["train"] + val_dataloader = dataloaders["val"] + test_dataloader = dataloaders["test"] + + early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" + ) + + logger = TensorBoardLogger(tmp_path) + if trainer_kwargs is None: + trainer_kwargs = {} + trainer = pl.Trainer( + max_epochs=3, + gradient_clip_val=0.1, + callbacks=[early_stop_callback], + enable_checkpointing=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + logger=logger, + **trainer_kwargs, + ) + training_data_module = dataloaders.get("data_module") + metadata = training_data_module.metadata + + assert isinstance( + metadata, dict + ), f"Expected metadata to be dict, got {type(metadata)}" + + net = estimator_cls( + metadata=metadata, + loss=nn.MSELoss(), + **kwargs, + ) + + trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ) + test_outputs = trainer.test(net, dataloaders=test_dataloader) + assert len(test_outputs) > 0 + + shutil.rmtree(tmp_path, ignore_errors=True) + + +class TestAllPtForecastersV2(PackageConfig, BaseFixtureGenerator): + """Generic tests for all objects in the mini package.""" + + object_type_filter = "forecaster_pytorch_v2" + + def test_doctest_examples(self, object_class): + """Runs doctests for estimator class.""" + from skbase.utils.doctest_run import run_doctest + + run_doctest(object_class, name=f"class {object_class.__name__}") + + def test_integration( + self, + object_pkg, + trainer_kwargs, + tmp_path, + ): + object_class = object_pkg.get_model_cls() + dataloaders = object_pkg._get_test_datamodule_from(trainer_kwargs) + + 
_integration(object_class, dataloaders, tmp_path, **trainer_kwargs)
+
+    def test_pkg_linkage(self, object_pkg, object_class):
+        """Test that the package is linked correctly."""
+        # check name method
+        msg = (
+            f"Package {object_pkg}.name() does not match class "
+            f"name {object_class.__name__}. "
+            "The expected package name is "
+            f"{object_class.__name__}."
+        )
+        assert object_pkg.name() == object_class.__name__, msg
+
+        # check naming convention
+        msg = (
+            f"Package {object_pkg.__name__} does not match class "
+            f"name {object_class.__name__}. "
+            "The expected package name is "
+            f"{object_class.__name__}_pkg_v2."
+        )
+        assert object_pkg.__name__ == object_class.__name__ + "_pkg_v2", msg