From b3644a65b041a790b94756fb1d9bbf2797236d0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 22 Feb 2025 23:18:51 +0100 Subject: [PATCH 01/80] test suite --- pyproject.toml | 1 + pytorch_forecasting/tests/__init__.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 pytorch_forecasting/tests/__init__.py diff --git a/pyproject.toml b/pyproject.toml index f3d1e339c..8c661db1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ dev = [ "pytest-dotenv>=0.5.2,<1.0.0", "tensorboard>=2.12.1,<3.0.0", "pandoc>=2.3,<3.0.0", + "scikit-base", ] # docs - dependencies for building the documentation diff --git a/pytorch_forecasting/tests/__init__.py b/pytorch_forecasting/tests/__init__.py new file mode 100644 index 000000000..6c2d26856 --- /dev/null +++ b/pytorch_forecasting/tests/__init__.py @@ -0,0 +1 @@ +"""PyTorch Forecasting test suite.""" From 4b2486e083ca93d8f4c1a29a6a25d882027815f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 22 Feb 2025 23:33:29 +0100 Subject: [PATCH 02/80] skeleton --- pytorch_forecasting/_registry/__init__.py | 16 ++ pytorch_forecasting/_registry/_lookup.py | 214 ++++++++++++++++++ pytorch_forecasting/models/base/__init__.py | 2 + .../models/base/_base_object.py | 8 + pytorch_forecasting/tests/_config.py | 13 ++ .../tests/test_all_estimators.py | 118 ++++++++++ 6 files changed, 371 insertions(+) create mode 100644 pytorch_forecasting/_registry/__init__.py create mode 100644 pytorch_forecasting/_registry/_lookup.py create mode 100644 pytorch_forecasting/models/base/_base_object.py create mode 100644 pytorch_forecasting/tests/_config.py create mode 100644 pytorch_forecasting/tests/test_all_estimators.py diff --git a/pytorch_forecasting/_registry/__init__.py b/pytorch_forecasting/_registry/__init__.py new file mode 100644 index 000000000..bb0b88e61 --- /dev/null +++ b/pytorch_forecasting/_registry/__init__.py @@ -0,0 +1,16 @@ +"""PyTorch Forecasting registry.""" + +from pytorch_forecasting._registry._lookup import all_objects, all_tags +from pytorch_forecasting._registry._tags import ( + OBJECT_TAG_LIST, + OBJECT_TAG_REGISTER, + check_tag_is_valid, +) + +__all__ = [ + "OBJECT_TAG_LIST", + "OBJECT_TAG_REGISTER", + "all_objects", + "all_tags", + "check_tag_is_valid", +] diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py new file mode 100644 index 000000000..ea3210cb0 --- /dev/null +++ b/pytorch_forecasting/_registry/_lookup.py @@ -0,0 +1,214 @@ +"""Registry lookup methods. + +This module exports the following methods for registry lookup: + +all_objects(object_types, filter_tags) + lookup and filtering of objects +""" +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) +# based on the sktime module of same name + +__author__ = ["fkiraly"] +# all_objects is based on the sklearn utility all_estimators + + +from copy import deepcopy +from operator import itemgetter +from pathlib import Path + +import pandas as pd +from skbase.lookup import all_objects as _all_objects + +from pytorch_forecasting.models.base import _BaseObject + + +def all_objects( + object_types=None, + filter_tags=None, + exclude_objects=None, + return_names=True, + as_dataframe=False, + return_tags=None, + suppress_import_stdout=True, +): + """Get a list of all objects from pytorch_forecasting. + + This function crawls the module and gets all classes that inherit + from skbase compatible base classes. 
+ + Not included are: the base classes themselves, classes defined in test + modules. + + Parameters + ---------- + object_types: str, list of str, optional (default=None) + Which kind of objects should be returned. + if None, no filter is applied and all objects are returned. + if str or list of str, strings define scitypes specified in search + only objects that are of (at least) one of the scitypes are returned + possible str values are entries of registry.BASE_CLASS_REGISTER (first col) + for instance 'regrssor_proba', 'distribution, 'metric' + + return_names: bool, optional (default=True) + + if True, estimator class name is included in the ``all_objects`` + return in the order: name, estimator class, optional tags, either as + a tuple or as pandas.DataFrame columns + + if False, estimator class name is removed from the ``all_objects`` return. + + filter_tags: dict of (str or list of str), optional (default=None) + For a list of valid tag strings, use the registry.all_tags utility. + + ``filter_tags`` subsets the returned estimators as follows: + + * each key/value pair is statement in "and"/conjunction + * key is tag name to sub-set on + * value str or list of string are tag values + * condition is "key must be equal to value, or in set(value)" + + exclude_estimators: str, list of str, optional (default=None) + Names of estimators to exclude. + + as_dataframe: bool, optional (default=False) + + True: ``all_objects`` will return a pandas.DataFrame with named + columns for all of the attributes being returned. + + False: ``all_objects`` will return a list (either a list of + estimators or a list of tuples, see Returns) + + return_tags: str or list of str, optional (default=None) + Names of tags to fetch and return each estimator's value of. + For a list of valid tag strings, use the registry.all_tags utility. + if str or list of str, + the tag values named in return_tags will be fetched for each + estimator and will be appended as either columns or tuple entries. + + suppress_import_stdout : bool, optional. Default=True + whether to suppress stdout printout upon import. + + Returns + ------- + all_objects will return one of the following: + 1. list of objects, if return_names=False, and return_tags is None + 2. list of tuples (optional object name, class, ~optional object + tags), if return_names=True or return_tags is not None. + 3. pandas.DataFrame if as_dataframe = True + if list of objects: + entries are objects matching the query, + in alphabetical order of object name + if list of tuples: + list of (optional object name, object, optional object + tags) matching the query, in alphabetical order of object name, + where + ``name`` is the object name as string, and is an + optional return + ``object`` is the actual object + ``tags`` are the object's values for each tag in return_tags + and is an optional return. + if dataframe: + all_objects will return a pandas.DataFrame. + column names represent the attributes contained in each column. + "objects" will be the name of the column of objects, "names" + will be the name of the column of object class names and the string(s) + passed in return_tags will serve as column names for all columns of + tags that were optionally requested. 
+ + Examples + -------- + >>> from skpro.registry import all_objects + >>> # return a complete list of objects as pd.Dataframe + >>> all_objects(as_dataframe=True) # doctest: +SKIP + >>> # return all probabilistic regressors by filtering for object type + >>> all_objects("regressor_proba", as_dataframe=True) # doctest: +SKIP + >>> # return all regressors which handle missing data in the input by tag filtering + >>> all_objects( + ... "regressor_proba", + ... filter_tags={"capability:missing": True}, + ... as_dataframe=True + ... ) # doctest: +SKIP + + References + ---------- + Adapted version of sktime's ``all_estimators``, + which is an evolution of scikit-learn's ``all_estimators`` + """ + MODULES_TO_IGNORE = ( + "tests", + "setup", + "contrib", + "utils", + "all", + ) + + result = [] + ROOT = str(Path(__file__).parent.parent) # skpro package root directory + + if isinstance(filter_tags, str): + filter_tags = {filter_tags: True} + filter_tags = filter_tags.copy() if filter_tags else None + + if object_types: + if filter_tags and "object_type" not in filter_tags.keys(): + object_tag_filter = {"object_type": object_types} + elif filter_tags: + filter_tags_filter = filter_tags.get("object_type", []) + if isinstance(object_types, str): + object_types = [object_types] + object_tag_update = {"object_type": object_types + filter_tags_filter} + filter_tags.update(object_tag_update) + else: + object_tag_filter = {"object_type": object_types} + if filter_tags: + filter_tags.update(object_tag_filter) + else: + filter_tags = object_tag_filter + + result = _all_objects( + object_types=[_BaseObject], + filter_tags=filter_tags, + exclude_objects=exclude_objects, + return_names=return_names, + as_dataframe=as_dataframe, + return_tags=return_tags, + suppress_import_stdout=suppress_import_stdout, + package_name="skpro", + path=ROOT, + modules_to_ignore=MODULES_TO_IGNORE, + ) + + return result + + +def _check_list_of_str_or_error(arg_to_check, arg_name): + """Check that certain arguments are str or list of str. + + Parameters + ---------- + arg_to_check: argument we are testing the type of + arg_name: str, + name of the argument we are testing, will be added to the error if + ``arg_to_check`` is not a str or a list of str + + Returns + ------- + arg_to_check: list of str, + if arg_to_check was originally a str it converts it into a list of str + so that it can be iterated over. + + Raises + ------ + TypeError if arg_to_check is not a str or list of str + """ + # check that return_tags has the right type: + if isinstance(arg_to_check, str): + arg_to_check = [arg_to_check] + if not isinstance(arg_to_check, list) or not all( + isinstance(value, str) for value in arg_to_check + ): + raise TypeError( + f"Error in all_objects! 
Argument {arg_name} must be either\ + a str or list of str" + ) + return arg_to_check diff --git a/pytorch_forecasting/models/base/__init__.py b/pytorch_forecasting/models/base/__init__.py index 4860e4838..474e7d564 100644 --- a/pytorch_forecasting/models/base/__init__.py +++ b/pytorch_forecasting/models/base/__init__.py @@ -7,8 +7,10 @@ BaseModelWithCovariates, Prediction, ) +from pytorch_forecasting.models.base._base_object import _BaseObject __all__ = [ + "_BaseObject", "AutoRegressiveBaseModel", "AutoRegressiveBaseModelWithCovariates", "BaseModel", diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py new file mode 100644 index 000000000..7330867b1 --- /dev/null +++ b/pytorch_forecasting/models/base/_base_object.py @@ -0,0 +1,8 @@ +"""Base Classes for pytorch-forecasting models, skbase compatible for indexing.""" + +from skbase.base import BaseObject as _SkbaseBaseObject + + +class _BaseObject(_SkbaseBaseObject): + + pass diff --git a/pytorch_forecasting/tests/_config.py b/pytorch_forecasting/tests/_config.py new file mode 100644 index 000000000..dd9c2e889 --- /dev/null +++ b/pytorch_forecasting/tests/_config.py @@ -0,0 +1,13 @@ +"""Test configs.""" + +# list of str, names of estimators to exclude from testing +# WARNING: tests for these estimators will be skipped +EXCLUDE_ESTIMATORS = [ + "DummySkipped", + "ClassName", # exclude classes from extension templates +] + +# dictionary of lists of str, names of tests to exclude from testing +# keys are class names of estimators, values are lists of test names to exclude +# WARNING: tests with these names will be skipped +EXCLUDED_TESTS = {} diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py new file mode 100644 index 000000000..c806691fa --- /dev/null +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -0,0 +1,118 @@ +"""Automated tests based on the skbase test suite template.""" +from inspect import isclass + +from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator +from skbase.testing import TestAllObjects as _TestAllObjects + +from pytorch_forecasting._registry import all_objects +from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS + + +# whether to test only estimators from modules that are changed w.r.t. main +# default is False, can be set to True by pytest --only_changed_modules True flag +ONLY_CHANGED_MODULES = False + + +class PackageConfig: + """Contains package config variables for test classes.""" + + # class variables which can be overridden by descendants + # ------------------------------------------------------ + + # package to search for objects + # expected type: str, package/module name, relative to python environment root + package_name = "pytroch_forecasting" + + # list of object types (class names) to exclude + # expected type: list of str, str are class names + exclude_objects = EXCLUDE_ESTIMATORS + + # list of tests to exclude + # expected type: dict of lists, key:str, value: List[str] + # keys are class names of estimators, values are lists of test names to exclude + excluded_tests = EXCLUDED_TESTS + + +class BaseFixtureGenerator(_BaseFixtureGenerator): + """Fixture generator for base testing functionality in sktime. + + Test classes inheriting from this and not overriding pytest_generate_tests + will have estimator and scenario fixtures parametrized out of the box. 
+ + Descendants can override: + estimator_type_filter: str, class variable; None or scitype string + e.g., "forecaster", "transformer", "classifier", see BASE_CLASS_SCITYPE_LIST + which estimators are being retrieved and tested + fixture_sequence: list of str + sequence of fixture variable names in conditional fixture generation + _generate_[variable]: object methods, all (test_name: str, **kwargs) -> list + generating list of fixtures for fixture variable with name [variable] + to be used in test with name test_name + can optionally use values for fixtures earlier in fixture_sequence, + these must be input as kwargs in a call + is_excluded: static method (test_name: str, est: class) -> bool + whether test with name test_name should be excluded for estimator est + should be used only for encoding general rules, not individual skips + individual skips should go on the EXCLUDED_TESTS list in _config + requires _generate_object_class and _generate_object_instance as is + _excluded_scenario: static method (test_name: str, scenario) -> bool + whether scenario should be skipped in test with test_name test_name + requires _generate_estimator_scenario as is + + Fixtures parametrized + --------------------- + object_class: estimator inheriting from BaseObject + ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS + object_instance: instance of estimator inheriting from BaseObject + ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS + instances are generated by create_test_instance class method of object_class + """ + + # overrides object retrieval in scikit-base + def _all_objects(self): + """Retrieve list of all object classes of type self.object_type_filter. + + If self.object_type_filter is None, retrieve all objects. + If class, retrieve all classes inheriting from self.object_type_filter. + Otherwise (assumed str or list of str), retrieve all classes with tags + object_type in self.object_type_filter. 
+ """ + filter = getattr(self, "object_type_filter", None) + + if isclass(filter): + object_types = filter.get_class_tag("object_type", None) + else: + object_types = filter + + obj_list = all_objects( + object_types=object_types, + return_names=False, + exclude_objects=self.exclude_objects, + ) + + if isclass(filter): + obj_list = [obj for obj in obj_list if issubclass(obj, filter)] + + # run_test_for_class selects the estimators to run + # based on whether they have changed, and whether they have all dependencies + # internally, uses the ONLY_CHANGED_MODULES flag, + # and checks the python env against python_dependencies tag + # obj_list = [obj for obj in obj_list if run_test_for_class(obj)] + + return obj_list + + # which sequence the conditional fixtures are generated in + fixture_sequence = [ + "object_class", + "object_instance", + ] + + +class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator, _TestAllObjects): + """Generic tests for all objects in the mini package.""" + + def test_doctest_examples(self, object_class): + """Runs doctests for estimator class.""" + import doctest + + doctest.run_docstring_examples(object_class, globals()) From 02b0ce6fa53443044fffce8cbbce54a0c6d6b947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 22 Feb 2025 23:33:58 +0100 Subject: [PATCH 03/80] skeleton --- pytorch_forecasting/_registry/__init__.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/pytorch_forecasting/_registry/__init__.py b/pytorch_forecasting/_registry/__init__.py index bb0b88e61..f71836bfe 100644 --- a/pytorch_forecasting/_registry/__init__.py +++ b/pytorch_forecasting/_registry/__init__.py @@ -1,16 +1,5 @@ """PyTorch Forecasting registry.""" -from pytorch_forecasting._registry._lookup import all_objects, all_tags -from pytorch_forecasting._registry._tags import ( - OBJECT_TAG_LIST, - OBJECT_TAG_REGISTER, - check_tag_is_valid, -) +from pytorch_forecasting._registry._lookup import all_objects -__all__ = [ - "OBJECT_TAG_LIST", - "OBJECT_TAG_REGISTER", - "all_objects", - "all_tags", - "check_tag_is_valid", -] +__all__ = ["all_objects"] From 41cbf667f9aea5848c3390778a53612338319504 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 06:42:02 +0100 Subject: [PATCH 04/80] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index c806691fa..704ddfc20 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -1,13 +1,15 @@ """Automated tests based on the skbase test suite template.""" + from inspect import isclass -from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator -from skbase.testing import TestAllObjects as _TestAllObjects +from skbase.testing import ( + BaseFixtureGenerator as _BaseFixtureGenerator, + TestAllObjects as _TestAllObjects, +) from pytorch_forecasting._registry import all_objects from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS - # whether to test only estimators from modules that are changed w.r.t. 
main # default is False, can be set to True by pytest --only_changed_modules True flag ONLY_CHANGED_MODULES = False From cef62d36df5eceff5238a2d6c7fd829319028446 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 06:47:24 +0100 Subject: [PATCH 05/80] Update _base_object.py --- pytorch_forecasting/models/base/_base_object.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 7330867b1..a91b10525 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -1,6 +1,8 @@ """Base Classes for pytorch-forecasting models, skbase compatible for indexing.""" -from skbase.base import BaseObject as _SkbaseBaseObject +from pytorch_forecasting.utils._dependencies import _safe_import + +_SkbaseBaseObject = _safe_import("skbase._base_object._BaseObject") class _BaseObject(_SkbaseBaseObject): From bc2e93b606095440772f7236eeebb070109c649f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 09:45:10 +0100 Subject: [PATCH 06/80] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index ea3210cb0..517bfa9af 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -5,7 +5,6 @@ all_objects(object_types, filter_tags) lookup and filtering of objects """ -# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) # based on the sktime module of same name __author__ = ["fkiraly"] From eee1c86859dc1d66d46eb85c7b39938639f8231e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 10:37:22 +0100 Subject: [PATCH 07/80] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index 517bfa9af..1ab4cfdb3 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -5,6 +5,7 @@ all_objects(object_types, filter_tags) lookup and filtering of objects """ + # based on the sktime module of same name __author__ = ["fkiraly"] From 164fe0d238ebe6b9f888c416f998a73948b365f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:05:54 +0100 Subject: [PATCH 08/80] base metadatda --- .../models/base/_base_object.py | 98 ++++++++++++++ .../models/deepar/_deepar_metadata.py | 128 ++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 pytorch_forecasting/models/deepar/_deepar_metadata.py diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index a91b10525..62fad456b 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -1,5 +1,7 @@ """Base Classes for pytorch-forecasting models, skbase compatible for indexing.""" +import inspect + from pytorch_forecasting.utils._dependencies import _safe_import _SkbaseBaseObject = _safe_import("skbase._base_object._BaseObject") @@ -8,3 +10,99 @@ class _BaseObject(_SkbaseBaseObject): pass + + +class _BasePtForecaster(_BaseObject): + """Base class for all PyTorch Forecasting forecaster metadata. + + This class points to model objects and contains metadata as tags. 
+ """ + + _tags = { + "object_type": "forecaster_pytorch", + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + raise NotImplementedError + + + @classmethod + def create_test_instance(cls, parameter_set="default"): + """Construct an instance of the class, using first test parameter set. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + instance : instance of the class with default parameters + + """ + if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: + params = cls.get_test_params(parameter_set=parameter_set) + else: + params = cls.get_test_params() + + if isinstance(params, list) and isinstance(params[0], dict): + params = params[0] + elif isinstance(params, dict): + pass + else: + raise TypeError( + "get_test_params should either return a dict or list of dict." + ) + + return cls.get_model_cls()(**params) + + @classmethod + def create_test_instances_and_names(cls, parameter_set="default"): + """Create list of all test instances and a list of names for them. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + objs : list of instances of cls + i-th instance is ``cls(**cls.get_test_params()[i])`` + names : list of str, same length as objs + i-th element is name of i-th instance of obj in tests. + The naming convention is ``{cls.__name__}-{i}`` if more than one instance, + otherwise ``{cls.__name__}`` + """ + if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: + param_list = cls.get_test_params(parameter_set=parameter_set) + else: + param_list = cls.get_test_params() + + objs = [] + if not isinstance(param_list, (dict, list)): + raise RuntimeError( + f"Error in {cls.__name__}.get_test_params, " + "return must be param dict for class, or list thereof" + ) + if isinstance(param_list, dict): + param_list = [param_list] + for params in param_list: + if not isinstance(params, dict): + raise RuntimeError( + f"Error in {cls.__name__}.get_test_params, " + "return must be param dict for class, or list thereof" + ) + objs += [cls.get_model_cls()(**params)] + + num_instances = len(param_list) + if num_instances > 1: + names = [cls.__name__ + "-" + str(i) for i in range(num_instances)] + else: + names = [cls.__name__] + + return objs, names diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py new file mode 100644 index 000000000..330b86e80 --- /dev/null +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -0,0 +1,128 @@ +"""DeepAR metadata container.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecaster + + +class _DeepARMetadata(_BasePtForecaster): + """DeepAR metadata container.""" + + _tags = { + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": True, + "capability:cold_start": False, + "info:compute": 3, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models import DeepAR + + return DeepAR + + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the skbase object. 
+ + ``get_test_params`` is a unified interface point to store + parameter settings for testing purposes. This function is also + used in ``create_test_instance`` and ``create_test_instances_and_names`` + to construct test instances. + + ``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``. + + Each ``dict`` is a parameter configuration for testing, + and can be used to construct an "interesting" test instance. + A call to ``cls(**params)`` should + be valid for all dictionaries ``params`` in the return of ``get_test_params``. + + The ``get_test_params`` need not return fixed lists of dictionaries, + it can also return dynamic or stochastic parameter settings. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params` + """ + from pytorch_forecasting.data.encoders import GroupNormalizer + from pytorch_forecasting.metrics import ( + BetaDistributionLoss, + ImplicitQuantileNetworkDistributionLoss, + LogNormalDistributionLoss, + MultivariateNormalDistributionLoss, + NegativeBinomialDistributionLoss, + ) + + return [ + {}, + {"cell_type": "GRU"}, + dict( + loss=LogNormalDistributionLoss(), + clip_target=True, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="log" + ) + ), + ), + dict( + loss=NegativeBinomialDistributionLoss(), + clip_target=False, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], center=False + ) + ), + ), + dict( + loss=BetaDistributionLoss(), + clip_target=True, + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="logit" + ) + ), + ), + dict( + data_loader_kwargs=dict( + lags={"volume": [2, 5]}, + target="volume", + time_varying_unknown_reals=["volume"], + min_encoder_length=2, + ) + ), + dict( + data_loader_kwargs=dict( + time_varying_unknown_reals=["volume", "discount"], + target=["volume", "discount"], + lags={"volume": [2], "discount": [2]}, + ) + ), + dict( + loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8), + ), + dict( + loss=MultivariateNormalDistributionLoss(), + trainer_kwargs=dict(accelerator="cpu"), + ), + dict( + loss=MultivariateNormalDistributionLoss(), + data_loader_kwargs=dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="log1p" + ) + ), + trainer_kwargs=dict(accelerator="cpu"), + ), + ] From 20e88d09993f3fed62ab52f93d0a4678f1a0c068 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:21:34 +0100 Subject: [PATCH 09/80] registry --- pytorch_forecasting/_registry/_lookup.py | 18 +++--------------- pytorch_forecasting/models/base/__init__.py | 6 +++++- .../models/base/_base_object.py | 2 +- .../utils/_dependencies/_safe_import.py | 3 +++ 4 files changed, 12 insertions(+), 17 deletions(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index 1ab4cfdb3..828c448b1 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ 
b/pytorch_forecasting/_registry/_lookup.py @@ -11,12 +11,8 @@ __author__ = ["fkiraly"] # all_objects is based on the sklearn utility all_estimators - -from copy import deepcopy -from operator import itemgetter from pathlib import Path -import pandas as pd from skbase.lookup import all_objects as _all_objects from pytorch_forecasting.models.base import _BaseObject @@ -117,17 +113,9 @@ def all_objects( Examples -------- - >>> from skpro.registry import all_objects + >>> from pytorch_forecasting._registry import all_objects >>> # return a complete list of objects as pd.Dataframe >>> all_objects(as_dataframe=True) # doctest: +SKIP - >>> # return all probabilistic regressors by filtering for object type - >>> all_objects("regressor_proba", as_dataframe=True) # doctest: +SKIP - >>> # return all regressors which handle missing data in the input by tag filtering - >>> all_objects( - ... "regressor_proba", - ... filter_tags={"capability:missing": True}, - ... as_dataframe=True - ... ) # doctest: +SKIP References ---------- @@ -143,7 +131,7 @@ def all_objects( ) result = [] - ROOT = str(Path(__file__).parent.parent) # skpro package root directory + ROOT = str(Path(__file__).parent.parent) # package root directory if isinstance(filter_tags, str): filter_tags = {filter_tags: True} @@ -173,7 +161,7 @@ def all_objects( as_dataframe=as_dataframe, return_tags=return_tags, suppress_import_stdout=suppress_import_stdout, - package_name="skpro", + package_name="pytorch_forecasting", path=ROOT, modules_to_ignore=MODULES_TO_IGNORE, ) diff --git a/pytorch_forecasting/models/base/__init__.py b/pytorch_forecasting/models/base/__init__.py index 474e7d564..7b69ec246 100644 --- a/pytorch_forecasting/models/base/__init__.py +++ b/pytorch_forecasting/models/base/__init__.py @@ -7,10 +7,14 @@ BaseModelWithCovariates, Prediction, ) -from pytorch_forecasting.models.base._base_object import _BaseObject +from pytorch_forecasting.models.base._base_object import ( + _BaseObject, + _BasePtForecaster, +) __all__ = [ "_BaseObject", + "_BasePtForecaster", "AutoRegressiveBaseModel", "AutoRegressiveBaseModelWithCovariates", "BaseModel", diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 62fad456b..8895b4c2c 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -4,7 +4,7 @@ from pytorch_forecasting.utils._dependencies import _safe_import -_SkbaseBaseObject = _safe_import("skbase._base_object._BaseObject") +_SkbaseBaseObject = _safe_import("skbase.base.BaseObject", pkg_name="scikit-base") class _BaseObject(_SkbaseBaseObject): diff --git a/pytorch_forecasting/utils/_dependencies/_safe_import.py b/pytorch_forecasting/utils/_dependencies/_safe_import.py index f4805f9c1..ffbde8b5d 100644 --- a/pytorch_forecasting/utils/_dependencies/_safe_import.py +++ b/pytorch_forecasting/utils/_dependencies/_safe_import.py @@ -70,6 +70,9 @@ def _safe_import(import_path, pkg_name=None): if pkg_name is None: path_list = import_path.split(".") pkg_name = path_list[0] + else: + path_list = import_path.split(".") + path_list = [pkg_name] + path_list[1:] if pkg_name in _get_installed_packages(): try: From 318c1fbdbfface24fdc67568cf9a01a6dde1650c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:33:42 +0100 Subject: [PATCH 10/80] fix private name --- pytorch_forecasting/models/deepar/__init__.py | 3 ++- pytorch_forecasting/models/deepar/_deepar_metadata.py | 2 +- 2 files changed, 3 
insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/deepar/__init__.py b/pytorch_forecasting/models/deepar/__init__.py index 679f296f6..149e19e0d 100644 --- a/pytorch_forecasting/models/deepar/__init__.py +++ b/pytorch_forecasting/models/deepar/__init__.py @@ -1,5 +1,6 @@ """DeepAR: Probabilistic forecasting with autoregressive recurrent networks.""" from pytorch_forecasting.models.deepar._deepar import DeepAR +from pytorch_forecasting.models.deepar._deepar_metadata import DeepARMetadata -__all__ = ["DeepAR"] +__all__ = ["DeepAR", "DeepARMetadata"] diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index 330b86e80..89aefc1b0 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -3,7 +3,7 @@ from pytorch_forecasting.models.base._base_object import _BasePtForecaster -class _DeepARMetadata(_BasePtForecaster): +class DeepARMetadata(_BasePtForecaster): """DeepAR metadata container.""" _tags = { From 012ab3d78ed8a99e6920f3df704188834bbe1c14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:48:02 +0100 Subject: [PATCH 11/80] Update _base_object.py --- pytorch_forecasting/models/base/_base_object.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 8895b4c2c..4fcb1bd22 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -27,7 +27,6 @@ def get_model_cls(cls): """Get model class.""" raise NotImplementedError - @classmethod def create_test_instance(cls, parameter_set="default"): """Construct an instance of the class, using first test parameter set. 
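
A minimal orientation sketch of how the pieces introduced so far are meant to compose: the metadata containers from PATCH 08-10 carry tags and point to model classes, and the registry lookup from PATCH 02 retrieves them by tag. The sketch below is illustrative only, not part of any patch, and assumes scikit-base is installed so that the skbase-backed imports resolve.

    from pytorch_forecasting._registry import all_objects

    # retrieve all forecaster metadata classes via their "object_type" tag,
    # which _BasePtForecaster sets to "forecaster_pytorch"
    metadata_classes = all_objects(
        object_types="forecaster_pytorch", return_names=False
    )

    for meta in metadata_classes:
        model_cls = meta.get_model_cls()  # e.g. DeepARMetadata -> DeepAR
        print(model_cls.__name__, meta.get_class_tags())

Here get_class_tags comes from the scikit-base BaseObject API that _BaseObject inherits; the remaining calls are defined in the patches above.
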
From 86365a00d88cda407674dbcac5c4d53bd26f3fce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:57:16 +0100 Subject: [PATCH 12/80] test failure --- pytorch_forecasting/tests/_conftest.py | 262 ++++++++++++++++++ .../tests/test_all_estimators.py | 101 +++++++ 2 files changed, 363 insertions(+) create mode 100644 pytorch_forecasting/tests/_conftest.py diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py new file mode 100644 index 000000000..e276446a6 --- /dev/null +++ b/pytorch_forecasting/tests/_conftest.py @@ -0,0 +1,262 @@ +import numpy as np +import pytest +import torch + +from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data + +torch.manual_seed(23) + + +@pytest.fixture(scope="session") +def gpus(): + if torch.cuda.is_available(): + return [0] + else: + return 0 + + +@pytest.fixture(scope="session") +def data_with_covariates(): + data = get_stallion_data() + data["month"] = data.date.dt.month.astype(str) + data["log_volume"] = np.log1p(data.volume) + data["weight"] = 1 + np.sqrt(data.volume) + + data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month + data["time_idx"] -= data["time_idx"].min() + + # convert special days into strings + special_days = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + data[special_days] = ( + data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") + ) + data = data.astype(dict(industry_volume=float)) + + # select data subset + data = data[lambda x: x.sku.isin(data.sku.unique()[:2])][ + lambda x: x.agency.isin(data.agency.unique()[:2]) + ] + + # default target + data["target"] = data["volume"].clip(1e-3, 1.0) + + return data + + +def make_dataloaders(data_with_covariates, **kwargs): + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + kwargs.setdefault("target", "volume") + kwargs.setdefault("group_ids", ["agency", "sku"]) + kwargs.setdefault("add_relative_time_idx", True) + kwargs.setdefault("time_varying_unknown_reals", ["volume"]) + + training = TimeSeriesDataSet( + data_with_covariates[lambda x: x.date < training_cutoff].copy(), + time_idx="time_idx", + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + **kwargs, # fixture parametrization + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data_with_covariates.copy(), + min_prediction_idx=training.index.time.max() + 1, + ) + train_dataloader = training.to_dataloader(train=True, batch_size=2, num_workers=0) + val_dataloader = validation.to_dataloader(train=False, batch_size=2, num_workers=0) + test_dataloader = validation.to_dataloader(train=False, batch_size=1, num_workers=0) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) + + +@pytest.fixture( + params=[ + dict(), + dict( + static_categoricals=["agency", "sku"], + static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + 
"regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + constant_fill_strategy={"volume": 0}, + categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)}, + ), + dict(static_categoricals=["agency", "sku"]), + dict(randomize_length=True, min_encoder_length=2), + dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2), + dict(target_normalizer=GroupNormalizer(transformation="log1p")), + dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="softplus", center=False + ) + ), + dict(target="agency"), + # test multiple targets + dict(target=["industry_volume", "volume"]), + dict(target=["agency", "volume"]), + dict( + target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1 + ), + dict(target=["agency", "volume"], weight="volume"), + # test weights + dict(target="volume", weight="volume"), + ], + scope="session", +) +def multiple_dataloaders_with_covariates(data_with_covariates, request): + return make_dataloaders(data_with_covariates, **request.param) + + +@pytest.fixture(scope="session") +def dataloaders_with_different_encoder_decoder_length(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "target", + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_with_covariates(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_reals=["discount"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_multi_target(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + time_varying_unknown_reals=["target", "discount"], + target=["target", "discount"], + add_relative_time_idx=False, + ) + + +@pytest.fixture(scope="session") +def dataloaders_fixed_window_without_covariates(): + data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2) + validation = data.series.iloc[:2] + + max_encoder_length = 30 + max_prediction_length = 10 + + training = TimeSeriesDataSet( + data[lambda x: ~x.series.isin(validation)], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + static_categoricals=[], + max_encoder_length=max_encoder_length, + 
max_prediction_length=max_prediction_length, + time_varying_unknown_reals=["value"], + target_normalizer=EncoderNormalizer(), + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data[lambda x: x.series.isin(validation)], + stop_randomization=True, + ) + batch_size = 2 + train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 + ) + val_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + test_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 704ddfc20..609761e21 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -1,7 +1,11 @@ """Automated tests based on the skbase test suite template.""" from inspect import isclass +import shutil +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +from lightning.pytorch.loggers import TensorBoardLogger from skbase.testing import ( BaseFixtureGenerator as _BaseFixtureGenerator, TestAllObjects as _TestAllObjects, @@ -9,6 +13,7 @@ from pytorch_forecasting._registry import all_objects from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS +from pytorch_forecasting.tests._conftest import make_dataloaders # whether to test only estimators from modules that are changed w.r.t. main # default is False, can be set to True by pytest --only_changed_modules True flag @@ -110,6 +115,98 @@ def _all_objects(self): ] +def _integration( + data_with_covariates, + tmp_path, + cell_type="LSTM", + data_loader_kwargs={}, + clip_target: bool = False, + trainer_kwargs=None, + **kwargs, +): + data_with_covariates = data_with_covariates.copy() + if clip_target: + data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0) + else: + data_with_covariates["target"] = data_with_covariates["volume"] + data_loader_default_kwargs = dict( + target="target", + time_varying_known_reals=["price_actual"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=True, + ) + data_loader_default_kwargs.update(data_loader_kwargs) + dataloaders_with_covariates = make_dataloaders( + data_with_covariates, **data_loader_default_kwargs + ) + + train_dataloader = dataloaders_with_covariates["train"] + val_dataloader = dataloaders_with_covariates["val"] + test_dataloader = dataloaders_with_covariates["test"] + + early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" + ) + + logger = TensorBoardLogger(tmp_path) + if trainer_kwargs is None: + trainer_kwargs = {} + trainer = pl.Trainer( + max_epochs=3, + gradient_clip_val=0.1, + callbacks=[early_stop_callback], + enable_checkpointing=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + logger=logger, + **trainer_kwargs, + ) + + net = DeepAR.from_dataset( + train_dataloader.dataset, + hidden_size=5, + cell_type=cell_type, + learning_rate=0.01, + log_gradient_flow=True, + log_interval=1000, + n_plotting_samples=100, + **kwargs, + ) + net.size() + try: + trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ) + test_outputs = trainer.test(net, dataloaders=test_dataloader) + assert 
len(test_outputs) > 0 + # check loading + net = DeepAR.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + + # check prediction + net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + trainer_kwargs=trainer_kwargs, + ) + finally: + shutil.rmtree(tmp_path, ignore_errors=True) + + net.predict( + val_dataloader, + fast_dev_run=True, + return_index=True, + return_decoder_lengths=True, + trainer_kwargs=trainer_kwargs, + ) + + class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator, _TestAllObjects): """Generic tests for all objects in the mini package.""" @@ -118,3 +215,7 @@ def test_doctest_examples(self, object_class): import doctest doctest.run_docstring_examples(object_class, globals()) + + def certain_failure(self, object_class): + """Fails for certain, for testing.""" + assert False From f6dee46efaa6853afa299d5edca80d11a80367ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 13:59:52 +0100 Subject: [PATCH 13/80] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 609761e21..a5f2c3783 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -116,6 +116,7 @@ def _all_objects(self): def _integration( + estimator_cls, data_with_covariates, tmp_path, cell_type="LSTM", @@ -165,7 +166,7 @@ def _integration( **trainer_kwargs, ) - net = DeepAR.from_dataset( + net = estimator_cls.from_dataset( train_dataloader.dataset, hidden_size=5, cell_type=cell_type, @@ -185,7 +186,7 @@ def _integration( test_outputs = trainer.test(net, dataloaders=test_dataloader) assert len(test_outputs) > 0 # check loading - net = DeepAR.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + net = estimator_cls.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # check prediction net.predict( From 9b0e4ec4c7d47dc0115a87ea4297a22a2f0fe5eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 15:51:07 +0100 Subject: [PATCH 14/80] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index a5f2c3783..2934d42db 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -116,7 +116,7 @@ def _all_objects(self): def _integration( - estimator_cls, + estimator_cls, data_with_covariates, tmp_path, cell_type="LSTM", @@ -186,7 +186,9 @@ def _integration( test_outputs = trainer.test(net, dataloaders=test_dataloader) assert len(test_outputs) > 0 # check loading - net = estimator_cls.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + net = estimator_cls.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path + ) # check prediction net.predict( From 7de528537d6fe36dd554f3cad5550d6f66c512e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 16:02:28 +0100 Subject: [PATCH 15/80] Update test_all_estimators.py --- pytorch_forecasting/tests/test_all_estimators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/tests/test_all_estimators.py 
b/pytorch_forecasting/tests/test_all_estimators.py index 2934d42db..37b597712 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -28,7 +28,7 @@ class PackageConfig: # package to search for objects # expected type: str, package/module name, relative to python environment root - package_name = "pytroch_forecasting" + package_name = "pytorch_forecasting" # list of object types (class names) to exclude # expected type: list of str, str are class names From 57dfe3a4e47ac3a34199d787cc6282f43b18a9f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 16:42:01 +0100 Subject: [PATCH 16/80] test folders --- pytest.ini | 4 +- .../tests/test_all_estimators.py | 43 ++++++++++++++++--- 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/pytest.ini b/pytest.ini index 457863f87..52f4fa1c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,7 +10,9 @@ addopts = --no-cov-on-fail markers = -testpaths = tests/ +testpaths = + tests/ + pytorch_forecasting/tests/ log_cli_level = ERROR log_format = %(asctime)s %(levelname)s %(message)s log_date_format = %Y-%m-%d %H:%M:%S diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 37b597712..8fbcc6ffe 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -6,10 +6,8 @@ import lightning.pytorch as pl from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger -from skbase.testing import ( - BaseFixtureGenerator as _BaseFixtureGenerator, - TestAllObjects as _TestAllObjects, -) +import pytest +from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator from pytorch_forecasting._registry import all_objects from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS @@ -110,10 +108,43 @@ def _all_objects(self): # which sequence the conditional fixtures are generated in fixture_sequence = [ + "object_metadata", "object_class", "object_instance", ] + def _generate_object_metadata(self, test_name, **kwargs): + """Return object class fixtures. + + Fixtures parametrized + --------------------- + object_class: object inheriting from BaseObject + ranges over all object classes not excluded by self.excluded_tests + """ + object_classes_to_test = [ + est for est in self._all_objects() if not self.is_excluded(test_name, est) + ] + object_names = [est.__name__ for est in object_classes_to_test] + + return object_classes_to_test, object_names + + def _generate_object_class(self, test_name, **kwargs): + """Return object class fixtures. 
+ + Fixtures parametrized + --------------------- + object_class: object inheriting from BaseObject + ranges over all object classes not excluded by self.excluded_tests + """ + all_metadata = self._all_objects() + all_cls = [est.get_model_cls() for est in all_metadata] + object_classes_to_test = [ + est for est in all_cls if not self.is_excluded(test_name, est) + ] + object_names = [est.__name__ for est in object_classes_to_test] + + return object_classes_to_test, object_names + def _integration( estimator_cls, @@ -210,7 +241,7 @@ def _integration( ) -class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator, _TestAllObjects): +class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" def test_doctest_examples(self, object_class): @@ -219,6 +250,6 @@ def test_doctest_examples(self, object_class): doctest.run_docstring_examples(object_class, globals()) - def certain_failure(self, object_class): + def test_certain_failure(self, object_class): """Fails for certain, for testing.""" assert False From c9f12dbdeea4aa431b52620b226cd57193a9a249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 16:59:30 +0100 Subject: [PATCH 17/80] Update test.yml --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5fcd9c1ff..0083302dd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -71,7 +71,7 @@ jobs: - name: Run pytest shell: bash - run: python -m pytest tests + run: python -m pytest pytest: name: Run pytest @@ -110,7 +110,7 @@ jobs: - name: Run pytest shell: bash - run: python -m pytest tests + run: python -m pytest - name: Statistics run: | From fa8144ebae6312d34458a2464da2d1bc3f186754 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 18:56:39 +0100 Subject: [PATCH 18/80] test integration --- .../models/deepar/_deepar_metadata.py | 25 +---------- .../tests/test_all_estimators.py | 41 ++++++++++++++++++- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index 89aefc1b0..206f113f0 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -23,29 +23,8 @@ def get_model_cls(cls): return DeepAR @classmethod - def get_test_params(cls, parameter_set="default"): - """Return testing parameter settings for the skbase object. - - ``get_test_params`` is a unified interface point to store - parameter settings for testing purposes. This function is also - used in ``create_test_instance`` and ``create_test_instances_and_names`` - to construct test instances. - - ``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``. - - Each ``dict`` is a parameter configuration for testing, - and can be used to construct an "interesting" test instance. - A call to ``cls(**params)`` should - be valid for all dictionaries ``params`` in the return of ``get_test_params``. - - The ``get_test_params`` need not return fixed lists of dictionaries, - it can also return dynamic or stochastic parameter settings. - - Parameters - ---------- - parameter_set : str, default="default" - Name of the set of test parameters to return, for use in tests. If no - special parameters are defined for a value, will return `"default"` set. 
+ def get_test_train_params(cls): + """Return testing parameter settings for the trainer. Returns ------- diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 8fbcc6ffe..378316d7a 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -71,6 +71,8 @@ class BaseFixtureGenerator(_BaseFixtureGenerator): object_instance: instance of estimator inheriting from BaseObject ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS instances are generated by create_test_instance class method of object_class + trainer_kwargs: list of dict + ranges over dictionaries of kwargs for the trainer """ # overrides object retrieval in scikit-base @@ -111,6 +113,7 @@ def _all_objects(self): "object_metadata", "object_class", "object_instance", + "trainer_kwargs", ] def _generate_object_metadata(self, test_name, **kwargs): @@ -145,6 +148,30 @@ def _generate_object_class(self, test_name, **kwargs): return object_classes_to_test, object_names + def _generate_trainer_kwargs(self, test_name, **kwargs): + """Return kwargs for the trainer. + + Fixtures parametrized + --------------------- + trainer_kwargs: dict + ranges over all kwargs for the trainer + """ + # call _generate_object_class to get all the classes + object_meta_to_test, _ = self._generate_object_metadata(test_name=test_name) + + # create instances from the classes + train_kwargs_to_test = [] + train_kwargs_names = [] + # retrieve all object parameters if multiple, construct instances + for est in object_meta_to_test: + est_name = est.__name__ + all_train_kwargs = est.get_test_train_params() + train_kwargs_to_test += all_train_kwargs + rg = range(len(all_train_kwargs)) + train_kwargs_names += [f"{est_name}_{i}" for i in rg] + + return train_kwargs_to_test, train_kwargs_names + def _integration( estimator_cls, @@ -250,6 +277,16 @@ def test_doctest_examples(self, object_class): doctest.run_docstring_examples(object_class, globals()) - def test_certain_failure(self, object_class): + def test_integration( + self, object_class, trainer_kwargs, data_with_covariates, tmp_path + ): """Fails for certain, for testing.""" - assert False + from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss + + if "loss" in trainer_kwargs and isinstance( + trainer_kwargs["loss"], NegativeBinomialDistributionLoss + ): + data_with_covariates = data_with_covariates.assign( + volume=lambda x: x.volume.round() + ) + _integration(object_class, data_with_covariates, tmp_path, **trainer_kwargs) From 232a510bc6820786ef1ce46ab3115a04126054ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 19:07:35 +0100 Subject: [PATCH 19/80] fixes --- .../models/base/_base_object.py | 8 +++++ .../models/deepar/_deepar_metadata.py | 4 ++- .../tests/test_all_estimators.py | 31 ++++++++++--------- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 4fcb1bd22..7fd59d6a4 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -27,6 +27,14 @@ def get_model_cls(cls): """Get model class.""" raise NotImplementedError + @classmethod + def name(cls): + """Get model name.""" + name = cls.get_class_tags().get("info:name", None) + if name is None: + name = cls.get_model_cls().__name__ + return name + @classmethod def 
create_test_instance(cls, parameter_set="default"): """Construct an instance of the class, using first test parameter set. diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index 206f113f0..e477d63b0 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -7,12 +7,14 @@ class DeepARMetadata(_BasePtForecaster): """DeepAR metadata container.""" _tags = { + "info:name": "DeepAR", + "info:compute": 3, + "authors": ["jdb78"], "capability:exogenous": True, "capability:multivariate": True, "capability:pred_int": True, "capability:flexible_history_length": True, "capability:cold_start": False, - "info:compute": 3, } @classmethod diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 378316d7a..c2d86e40e 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -127,7 +127,7 @@ def _generate_object_metadata(self, test_name, **kwargs): object_classes_to_test = [ est for est in self._all_objects() if not self.is_excluded(test_name, est) ] - object_names = [est.__name__ for est in object_classes_to_test] + object_names = [est.name() for est in object_classes_to_test] return object_classes_to_test, object_names @@ -156,21 +156,16 @@ def _generate_trainer_kwargs(self, test_name, **kwargs): trainer_kwargs: dict ranges over all kwargs for the trainer """ - # call _generate_object_class to get all the classes - object_meta_to_test, _ = self._generate_object_metadata(test_name=test_name) + if "object_metadata" in kwargs.keys(): + obj_meta = kwargs["object_metadata"] + else: + return [] - # create instances from the classes - train_kwargs_to_test = [] - train_kwargs_names = [] - # retrieve all object parameters if multiple, construct instances - for est in object_meta_to_test: - est_name = est.__name__ - all_train_kwargs = est.get_test_train_params() - train_kwargs_to_test += all_train_kwargs - rg = range(len(all_train_kwargs)) - train_kwargs_names += [f"{est_name}_{i}" for i in rg] + all_train_kwargs = obj_meta.get_test_train_params() + rg = range(len(all_train_kwargs)) + train_kwargs_names = [str(i) for i in rg] - return train_kwargs_to_test, train_kwargs_names + return all_train_kwargs, train_kwargs_names def _integration( @@ -278,11 +273,17 @@ def test_doctest_examples(self, object_class): doctest.run_docstring_examples(object_class, globals()) def test_integration( - self, object_class, trainer_kwargs, data_with_covariates, tmp_path + self, + object_metadata, + trainer_kwargs, + data_with_covariates, + tmp_path, ): """Fails for certain, for testing.""" from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss + object_class = object_metadata.get_model_cls() + if "loss" in trainer_kwargs and isinstance( trainer_kwargs["loss"], NegativeBinomialDistributionLoss ): From 1c8d4b5c4fbf8dca91ec28e36e8781bb08a291bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 19:30:39 +0100 Subject: [PATCH 20/80] Update _conftest.py --- pytorch_forecasting/tests/_conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index e276446a6..20cad22c5 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -17,7 +17,7 @@ def gpus(): return 0 
-@pytest.fixture(scope="session") +@pytest.fixture(scope="package") def data_with_covariates(): data = get_stallion_data() data["month"] = data.date.dt.month.astype(str) From f632e32325a657fe975f9c76344eaba0585e17e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 23 Feb 2025 19:37:01 +0100 Subject: [PATCH 21/80] try scenarios --- pytorch_forecasting/tests/_conftest.py | 2 +- pytorch_forecasting/tests/_data_scenarios.py | 261 ++++++++++++++++++ .../tests/test_all_estimators.py | 5 +- 3 files changed, 265 insertions(+), 3 deletions(-) create mode 100644 pytorch_forecasting/tests/_data_scenarios.py diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index 20cad22c5..e276446a6 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -17,7 +17,7 @@ def gpus(): return 0 -@pytest.fixture(scope="package") +@pytest.fixture(scope="session") def data_with_covariates(): data = get_stallion_data() data["month"] = data.date.dt.month.astype(str) diff --git a/pytorch_forecasting/tests/_data_scenarios.py b/pytorch_forecasting/tests/_data_scenarios.py new file mode 100644 index 000000000..062db97dd --- /dev/null +++ b/pytorch_forecasting/tests/_data_scenarios.py @@ -0,0 +1,261 @@ +import numpy as np +import pytest +import torch + +from pytorch_forecasting import TimeSeriesDataSet +from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data + +torch.manual_seed(23) + + +@pytest.fixture(scope="session") +def gpus(): + if torch.cuda.is_available(): + return [0] + else: + return 0 + + +def data_with_covariates(): + data = get_stallion_data() + data["month"] = data.date.dt.month.astype(str) + data["log_volume"] = np.log1p(data.volume) + data["weight"] = 1 + np.sqrt(data.volume) + + data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month + data["time_idx"] -= data["time_idx"].min() + + # convert special days into strings + special_days = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + data[special_days] = ( + data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") + ) + data = data.astype(dict(industry_volume=float)) + + # select data subset + data = data[lambda x: x.sku.isin(data.sku.unique()[:2])][ + lambda x: x.agency.isin(data.agency.unique()[:2]) + ] + + # default target + data["target"] = data["volume"].clip(1e-3, 1.0) + + return data + + +def make_dataloaders(data_with_covariates, **kwargs): + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + kwargs.setdefault("target", "volume") + kwargs.setdefault("group_ids", ["agency", "sku"]) + kwargs.setdefault("add_relative_time_idx", True) + kwargs.setdefault("time_varying_unknown_reals", ["volume"]) + + training = TimeSeriesDataSet( + data_with_covariates[lambda x: x.date < training_cutoff].copy(), + time_idx="time_idx", + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + **kwargs, # fixture parametrization + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data_with_covariates.copy(), + min_prediction_idx=training.index.time.max() + 1, + ) + train_dataloader = training.to_dataloader(train=True, batch_size=2, num_workers=0) + 
val_dataloader = validation.to_dataloader(train=False, batch_size=2, num_workers=0) + test_dataloader = validation.to_dataloader(train=False, batch_size=1, num_workers=0) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) + + +@pytest.fixture( + params=[ + dict(), + dict( + static_categoricals=["agency", "sku"], + static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + constant_fill_strategy={"volume": 0}, + categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)}, + ), + dict(static_categoricals=["agency", "sku"]), + dict(randomize_length=True, min_encoder_length=2), + dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2), + dict(target_normalizer=GroupNormalizer(transformation="log1p")), + dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="softplus", center=False + ) + ), + dict(target="agency"), + # test multiple targets + dict(target=["industry_volume", "volume"]), + dict(target=["agency", "volume"]), + dict( + target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1 + ), + dict(target=["agency", "volume"], weight="volume"), + # test weights + dict(target="volume", weight="volume"), + ], + scope="session", +) +def multiple_dataloaders_with_covariates(data_with_covariates, request): + return make_dataloaders(data_with_covariates, **request.param) + + +@pytest.fixture(scope="session") +def dataloaders_with_different_encoder_decoder_length(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "target", + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_with_covariates(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + target="target", + time_varying_known_reals=["discount"], + time_varying_unknown_reals=["target"], + static_categoricals=["agency"], + add_relative_time_idx=False, + target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False), + ) + + +@pytest.fixture(scope="session") +def dataloaders_multi_target(data_with_covariates): + return make_dataloaders( + data_with_covariates.copy(), + 
time_varying_unknown_reals=["target", "discount"], + target=["target", "discount"], + add_relative_time_idx=False, + ) + + +@pytest.fixture(scope="session") +def dataloaders_fixed_window_without_covariates(): + data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2) + validation = data.series.iloc[:2] + + max_encoder_length = 30 + max_prediction_length = 10 + + training = TimeSeriesDataSet( + data[lambda x: ~x.series.isin(validation)], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + static_categoricals=[], + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + time_varying_unknown_reals=["value"], + target_normalizer=EncoderNormalizer(), + ) + + validation = TimeSeriesDataSet.from_dataset( + training, + data[lambda x: x.series.isin(validation)], + stop_randomization=True, + ) + batch_size = 2 + train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 + ) + val_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + test_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 + ) + + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index c2d86e40e..b8a21cc6a 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -6,7 +6,6 @@ import lightning.pytorch as pl from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger -import pytest from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator from pytorch_forecasting._registry import all_objects @@ -276,11 +275,13 @@ def test_integration( self, object_metadata, trainer_kwargs, - data_with_covariates, tmp_path, ): """Fails for certain, for testing.""" from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss + from pytorch_forecasting.tests._data_scenarios import data_with_covariates + + data_with_covariates = data_with_covariates() object_class = object_metadata.get_model_cls() From 252598d2ce3f31244a422cd9206961776ea79615 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 18:43:51 +0530 Subject: [PATCH 22/80] D1, D2 layer commit --- pytorch_forecasting/data/data_module.py | 633 ++++++++++++++++++++++++ pytorch_forecasting/data/timeseries.py | 257 ++++++++++ 2 files changed, 890 insertions(+) create mode 100644 pytorch_forecasting/data/data_module.py diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py new file mode 100644 index 000000000..56917696d --- /dev/null +++ b/pytorch_forecasting/data/data_module.py @@ -0,0 +1,633 @@ +####################################################################################### +# Disclaimer: This data-module is still work in progress and experimental, please +# use with care. This data-module is a basic skeleton of how the data-handling pipeline +# may look like in the future. +# This is D2 layer that will handle the preprocessing and data loaders. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. 
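+#
+# The module consumes a TimeSeries dataset (the D1 "raw dataset" layer defined in
+# pytorch_forecasting/data/timeseries.py) and exposes train/val/test/predict
+# dataloaders whose batches are (x, y) pairs, where x is a dict of tensors for
+# encoder-decoder models and y is the target tensor.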
+####################################################################################### + +from typing import Any, Dict, List, Optional, Tuple, Union + +from lightning.pytorch import LightningDataModule +from sklearn.preprocessing import RobustScaler, StandardScaler +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_forecasting.data.encoders import ( + EncoderNormalizer, + NaNLabelEncoder, + TorchNormalizer, +) +from pytorch_forecasting.data.timeseries import TimeSeries, _coerce_to_dict + +NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] + + +class EncoderDecoderTimeSeriesDataModule(LightningDataModule): + """ + Lightning DataModule for processing time series data in an encoder-decoder format. + + This module handles preprocessing, splitting, and batching of time series data + for use in deep learning models. It supports categorical and continuous features, + various scalers, and automatic target normalization. + + Parameters + ---------- + time_series_dataset : TimeSeries + The dataset containing time series data. + max_encoder_length : int, default=30 + Maximum length of the encoder input sequence. + min_encoder_length : Optional[int], default=None + Minimum length of the encoder input sequence. + Defaults to `max_encoder_length` if not specified. + max_prediction_length : int, default=1 + Maximum length of the decoder output sequence. + min_prediction_length : Optional[int], default=None + Minimum length of the decoder output sequence. + Defaults to `max_prediction_length` if not specified. + min_prediction_idx : Optional[int], default=None + Minimum index from which predictions start. + allow_missing_timesteps : bool, default=False + Whether to allow missing timesteps in the dataset. + add_relative_time_idx : bool, default=False + Whether to add a relative time index feature. + add_target_scales : bool, default=False + Whether to add target scaling information. + add_encoder_length : Union[bool, str], default="auto" + Whether to include encoder length information. + target_normalizer : + Union[NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None], + default="auto" + Normalizer for the target variable. If "auto", uses `RobustScaler`. + + categorical_encoders : Optional[Dict[str, NaNLabelEncoder]], default=None + Dictionary of categorical encoders. + + scalers : + Optional[Dict[str, Union[StandardScaler, RobustScaler, + TorchNormalizer, EncoderNormalizer]]], default=None + Dictionary of feature scalers. + + randomize_length : Union[None, Tuple[float, float], bool], default=False + Whether to randomize input sequence length. + batch_size : int, default=32 + Batch size for DataLoader. + num_workers : int, default=0 + Number of workers for DataLoader. + train_val_test_split : tuple, default=(0.7, 0.15, 0.15) + Proportions for train, validation, and test dataset splits. 
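+
+    Examples
+    --------
+    A minimal usage sketch, assuming ``df`` is a pandas DataFrame with an integer
+    time column ``"time_idx"``, a target column ``"value"``, and a series
+    identifier column ``"group_id"`` (all column names are illustrative only)::
+
+        from pytorch_forecasting.data.data_module import (
+            EncoderDecoderTimeSeriesDataModule,
+        )
+        from pytorch_forecasting.data.timeseries import TimeSeries
+
+        ts = TimeSeries(df, time="time_idx", target="value", group=["group_id"])
+        dm = EncoderDecoderTimeSeriesDataModule(
+            time_series_dataset=ts,
+            max_encoder_length=12,
+            max_prediction_length=3,
+            batch_size=16,
+        )
+        dm.setup("fit")
+        train_loader = dm.train_dataloader()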
+ """ + + def __init__( + self, + time_series_dataset: TimeSeries, + max_encoder_length: int = 30, + min_encoder_length: Optional[int] = None, + max_prediction_length: int = 1, + min_prediction_length: Optional[int] = None, + min_prediction_idx: Optional[int] = None, + allow_missing_timesteps: bool = False, + add_relative_time_idx: bool = False, + add_target_scales: bool = False, + add_encoder_length: Union[bool, str] = "auto", + target_normalizer: Union[ + NORMALIZER, str, List[NORMALIZER], Tuple[NORMALIZER], None + ] = "auto", + categorical_encoders: Optional[Dict[str, NaNLabelEncoder]] = None, + scalers: Optional[ + Dict[ + str, + Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer], + ] + ] = None, + randomize_length: Union[None, Tuple[float, float], bool] = False, + batch_size: int = 32, + num_workers: int = 0, + train_val_test_split: tuple = (0.7, 0.15, 0.15), + ): + super().__init__() + self.time_series_dataset = time_series_dataset + self.time_series_metadata = time_series_dataset.get_metadata() + + self.max_encoder_length = max_encoder_length + self.min_encoder_length = min_encoder_length or max_encoder_length + self.max_prediction_length = max_prediction_length + self.min_prediction_length = min_prediction_length or max_prediction_length + self.min_prediction_idx = min_prediction_idx + + self.allow_missing_timesteps = allow_missing_timesteps + self.add_relative_time_idx = add_relative_time_idx + self.add_target_scales = add_target_scales + self.add_encoder_length = add_encoder_length + self.randomize_length = randomize_length + + self.batch_size = batch_size + self.num_workers = num_workers + self.train_val_test_split = train_val_test_split + + if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto": + self.target_normalizer = RobustScaler() + else: + self.target_normalizer = target_normalizer + + self.categorical_encoders = _coerce_to_dict(categorical_encoders) + self.scalers = _coerce_to_dict(scalers) + + self.categorical_indices = [] + self.continuous_indices = [] + self._metadata = None + + for idx, col in enumerate(self.time_series_metadata["cols"]["x"]): + if self.time_series_metadata["col_type"].get(col) == "C": + self.categorical_indices.append(idx) + else: + self.continuous_indices.append(idx) + + def _prepare_metadata(self): + """Prepare metadata for model initialisation. + + Returns + ------- + dict + dictionary containing the following keys: + + * ``encoder_cat``: Number of categorical variables in the encoder. + Computed as ``len(self.categorical_indices)``, which counts the + categorical feature indices. + * ``encoder_cont``: Number of continuous variables in the encoder. + Computed as ``len(self.continuous_indices)``, which counts the + continuous feature indices. + * ``decoder_cat``: Number of categorical variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "C"(categorical) and col_known == "K" (known) + * ``decoder_cont``: Number of continuous variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "F"(continuous) and col_known == "K"(known) + * ``target``: Number of target variables. + Computed as ``len(self.time_series_metadata["cols"]["y"])``, which + gives the number of output target columns.. 
+ * ``static_categorical_features``: Number of static categorical features + Computed by filtering ``self.time_series_metadata["cols"]["st"]`` + (static features) where col_type == "C" (categorical). + * ``static_continuous_features``: Number of static continuous features + Computed as difference of + ``len(self.time_series_metadata["cols"]["st"])`` (static features) + and static_categorical_features that gives static continuous feature + * ``max_encoder_length``: maximum encoder length + Taken directly from `self.max_encoder_length`. + * ``max_prediction_length``: maximum prediction length + Taken directly from `self.max_prediction_length`. + * ``min_encoder_length``: minimum encoder length + Taken directly from `self.min_encoder_length`. + * ``min_prediction_length``: minimum prediction length + Taken directly from `self.min_prediction_length`. + + """ + encoder_cat_count = len(self.categorical_indices) + encoder_cont_count = len(self.continuous_indices) + + decoder_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "C" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + decoder_cont_count = len( + [ + col + for col in self.time_series_metadata["cols"]["x"] + if self.time_series_metadata["col_type"].get(col) == "F" + and self.time_series_metadata["col_known"].get(col) == "K" + ] + ) + + target_count = len(self.time_series_metadata["cols"]["y"]) + metadata = { + "encoder_cat": encoder_cat_count, + "encoder_cont": encoder_cont_count, + "decoder_cat": decoder_cat_count, + "decoder_cont": decoder_cont_count, + "target": target_count, + } + if self.time_series_metadata["cols"]["st"]: + static_cat_count = len( + [ + col + for col in self.time_series_metadata["cols"]["st"] + if self.time_series_metadata["col_type"].get(col) == "C" + ] + ) + static_cont_count = ( + len(self.time_series_metadata["cols"]["st"]) - static_cat_count + ) + + metadata["static_categorical_features"] = static_cat_count + metadata["static_continuous_features"] = static_cont_count + else: + metadata["static_categorical_features"] = 0 + metadata["static_continuous_features"] = 0 + + metadata.update( + { + "max_encoder_length": self.max_encoder_length, + "max_prediction_length": self.max_prediction_length, + "min_encoder_length": self.min_encoder_length, + "min_prediction_length": self.min_prediction_length, + } + ) + + return metadata + + @property + def metadata(self): + """Compute metadata for model initialization. + + This property returns a dictionary containing the shapes and key information + related to the time series model. The metadata includes: + + * ``encoder_cat``: Number of categorical variables in the encoder. + * ``encoder_cont``: Number of continuous variables in the encoder. + * ``decoder_cat``: Number of categorical variables in the decoder that are + known in advance. + * ``decoder_cont``: Number of continuous variables in the decoder that are + known in advance. + * ``target``: Number of target variables. 
+ + If static features are present, the following keys are added: + + * ``static_categorical_features``: Number of static categorical features + * ``static_continuous_features``: Number of static continuous features + + It also contains the following information: + + * ``max_encoder_length``: maximum encoder length + * ``max_prediction_length``: maximum prediction length + * ``min_encoder_length``: minimum encoder length + * ``min_prediction_length``: minimum prediction length + """ + if self._metadata is None: + self._metadata = self._prepare_metadata() + return self._metadata + + def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: + """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. + + Preprocessing steps + -------------------- + + * Converts target (`y`) and features (`x`) to `torch.float32`. + * Masks time points that are at or before the cutoff time. + * Splits features into categorical and continuous subsets based on + predefined indices. + + + TODO: add scalers, target normalizers etc. + """ + processed_data = [] + + for idx in indices: + sample = self.time_series_dataset[idx.item()] + + target = sample["y"] + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] + + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) + + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) + + if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) + + # TODO: add scalers, target normalizers etc. + + categorical = ( + features[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features.shape[0], 0)) + ) + continuous = ( + features[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features.shape[0], 0)) + ) + + processed_data.append( + { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + "time_mask": time_mask, + "times": times, + "cutoff_time": cutoff_time, + } + ) + + return processed_data + + class _ProcessedEncoderDecoderDataset(Dataset): + """PyTorch Dataset for processed encoder-decoder time series data. + + Parameters + ---------- + processed_data : List[Dict[str, Any]] + List of preprocessed time series samples. + windows : List[Tuple[int, int, int, int]] + List of window tuples containing + (series_idx, start_idx, enc_length, pred_length). + add_relative_time_idx : bool, default=False + Whether to include relative time indices. + """ + + def __init__( + self, + processed_data: List[Dict[str, Any]], + windows: List[Tuple[int, int, int, int]], + add_relative_time_idx: bool = False, + ): + self.processed_data = processed_data + self.windows = windows + self.add_relative_time_idx = add_relative_time_idx + + def __len__(self): + return len(self.windows) + + def __getitem__(self, idx): + """Retrieve a processed time series window for dataloader input. + + x : dict + Dictionary containing model inputs: + + * ``encoder_cat`` : tensor of shape (enc_length, n_cat_features) + Categorical features for the encoder. + * ``encoder_cont`` : tensor of shape (enc_length, n_cont_features) + Continuous features for the encoder. + * ``decoder_cat`` : tensor of shape (pred_length, n_cat_features) + Categorical features for the decoder. 
+ * ``decoder_cont`` : tensor of shape (pred_length, n_cont_features) + Continuous features for the decoder. + * ``encoder_lengths`` : tensor of shape (1,) + Length of the encoder sequence. + * ``decoder_lengths`` : tensor of shape (1,) + Length of the decoder sequence. + * ``decoder_target_lengths`` : tensor of shape (1,) + Length of the decoder target sequence. + * ``groups`` : tensor of shape (1,) + Group identifier for the time series instance. + * ``encoder_time_idx`` : tensor of shape (enc_length,) + Time indices for the encoder sequence. + * ``decoder_time_idx`` : tensor of shape (pred_length,) + Time indices for the decoder sequence. + * ``target_scale`` : tensor of shape (1,) + Scaling factor for the target values. + * ``encoder_mask`` : tensor of shape (enc_length,) + Boolean mask indicating valid encoder time points. + * ``decoder_mask`` : tensor of shape (pred_length,) + Boolean mask indicating valid decoder time points. + + If static features are present, the following keys are added: + + * ``static_categorical_features`` : tensor of shape + (1, n_static_cat_features), optional + Static categorical features, if available. + * ``static_continuous_features`` : tensor of shape (1, 0), optional + Placeholder for static continuous features (currently empty). + + y : tensor of shape ``(pred_length, n_targets)`` + Target values for the decoder sequence. + """ + series_idx, start_idx, enc_length, pred_length = self.windows[idx] + data = self.processed_data[series_idx] + + end_idx = start_idx + enc_length + pred_length + encoder_indices = slice(start_idx, start_idx + enc_length) + decoder_indices = slice(start_idx + enc_length, end_idx) + + target_scale = data["target"][encoder_indices] + target_scale = target_scale[~torch.isnan(target_scale)].abs().mean() + if torch.isnan(target_scale) or target_scale == 0: + target_scale = torch.tensor(1.0) + + encoder_mask = ( + data["time_mask"][encoder_indices] + if "time_mask" in data + else torch.ones(enc_length, dtype=torch.bool) + ) + decoder_mask = ( + data["time_mask"][decoder_indices] + if "time_mask" in data + else torch.zeros(pred_length, dtype=torch.bool) + ) + + x = { + "encoder_cat": data["features"]["categorical"][encoder_indices], + "encoder_cont": data["features"]["continuous"][encoder_indices], + "decoder_cat": data["features"]["categorical"][decoder_indices], + "decoder_cont": data["features"]["continuous"][decoder_indices], + "encoder_lengths": torch.tensor(enc_length), + "decoder_lengths": torch.tensor(pred_length), + "decoder_target_lengths": torch.tensor(pred_length), + "groups": data["group"], + "encoder_time_idx": torch.arange(enc_length), + "decoder_time_idx": torch.arange(enc_length, enc_length + pred_length), + "target_scale": target_scale, + "encoder_mask": encoder_mask, + "decoder_mask": decoder_mask, + } + if data["static"] is not None: + x["static_categorical_features"] = data["static"].unsqueeze(0) + x["static_continuous_features"] = torch.zeros((1, 0)) + + y = data["target"][decoder_indices] + if y.ndim == 1: + y = y.unsqueeze(-1) + + return x, y + + def _create_windows( + self, processed_data: List[Dict[str, Any]] + ) -> List[Tuple[int, int, int, int]]: + """Generate sliding windows for training, validation, and testing. + + Returns + ------- + List[Tuple[int, int, int, int]] + A list of tuples, where each tuple consists of: + - ``series_idx`` : int + Index of the time series in `processed_data`. + - ``start_idx`` : int + Start index of the encoder window. 
+ - ``enc_length`` : int + Length of the encoder input sequence. + - ``pred_length`` : int + Length of the decoder output sequence. + """ + windows = [] + + for idx, data in enumerate(processed_data): + sequence_length = data["length"] + + if sequence_length < self.max_encoder_length + self.max_prediction_length: + continue + + effective_min_prediction_idx = ( + self.min_prediction_idx + if self.min_prediction_idx is not None + else self.max_encoder_length + ) + + max_prediction_idx = sequence_length - self.max_prediction_length + 1 + + if max_prediction_idx <= effective_min_prediction_idx: + continue + + for start_idx in range( + 0, max_prediction_idx - effective_min_prediction_idx + ): + if ( + start_idx + self.max_encoder_length + self.max_prediction_length + <= sequence_length + ): + windows.append( + ( + idx, + start_idx, + self.max_encoder_length, + self.max_prediction_length, + ) + ) + + return windows + + def setup(self, stage: Optional[str] = None): + """Prepare the datasets for training, validation, testing, or prediction. + + Parameters + ---------- + stage : Optional[str], default=None + Specifies the stage of setup. Can be one of: + - ``"fit"`` : Prepares training and validation datasets. + - ``"test"`` : Prepares the test dataset. + - ``"predict"`` : Prepares the dataset for inference. + - ``None`` : Prepares all datasets. + """ + total_series = len(self.time_series_dataset) + self._split_indices = torch.randperm(total_series) + + self._train_size = int(self.train_val_test_split[0] * total_series) + self._val_size = int(self.train_val_test_split[1] * total_series) + + self._train_indices = self._split_indices[: self._train_size] + self._val_indices = self._split_indices[ + self._train_size : self._train_size + self._val_size + ] + self._test_indices = self._split_indices[self._train_size + self._val_size :] + + if stage is None or stage == "fit": + if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): + self.train_processed = self._preprocess_data(self._train_indices) + self.val_processed = self._preprocess_data(self._val_indices) + + self.train_windows = self._create_windows(self.train_processed) + self.val_windows = self._create_windows(self.val_processed) + + self.train_dataset = self._ProcessedEncoderDecoderDataset( + self.train_processed, self.train_windows, self.add_relative_time_idx + ) + self.val_dataset = self._ProcessedEncoderDecoderDataset( + self.val_processed, self.val_windows, self.add_relative_time_idx + ) + # print(self.val_dataset[0]) + + elif stage is None or stage == "test": + if not hasattr(self, "test_dataset"): + self.test_processed = self._preprocess_data(self._test_indices) + self.test_windows = self._create_windows(self.test_processed) + + self.test_dataset = self._ProcessedEncoderDecoderDataset( + self.test_processed, self.test_windows, self.add_relative_time_idx + ) + elif stage == "predict": + predict_indices = torch.arange(len(self.time_series_dataset)) + self.predict_processed = self._preprocess_data(predict_indices) + self.predict_windows = self._create_windows(self.predict_processed) + self.predict_dataset = self._ProcessedEncoderDecoderDataset( + self.predict_processed, self.predict_windows, self.add_relative_time_idx + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + collate_fn=self.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + 
num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + def predict_dataloader(self): + return DataLoader( + self.predict_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + ) + + @staticmethod + def collate_fn(batch): + x_batch = { + "encoder_cat": torch.stack([x["encoder_cat"] for x, _ in batch]), + "encoder_cont": torch.stack([x["encoder_cont"] for x, _ in batch]), + "decoder_cat": torch.stack([x["decoder_cat"] for x, _ in batch]), + "decoder_cont": torch.stack([x["decoder_cont"] for x, _ in batch]), + "encoder_lengths": torch.stack([x["encoder_lengths"] for x, _ in batch]), + "decoder_lengths": torch.stack([x["decoder_lengths"] for x, _ in batch]), + "decoder_target_lengths": torch.stack( + [x["decoder_target_lengths"] for x, _ in batch] + ), + "groups": torch.stack([x["groups"] for x, _ in batch]), + "encoder_time_idx": torch.stack([x["encoder_time_idx"] for x, _ in batch]), + "decoder_time_idx": torch.stack([x["decoder_time_idx"] for x, _ in batch]), + "target_scale": torch.stack([x["target_scale"] for x, _ in batch]), + "encoder_mask": torch.stack([x["encoder_mask"] for x, _ in batch]), + "decoder_mask": torch.stack([x["decoder_mask"] for x, _ in batch]), + } + + if "static_categorical_features" in batch[0][0]: + x_batch["static_categorical_features"] = torch.stack( + [x["static_categorical_features"] for x, _ in batch] + ) + x_batch["static_continuous_features"] = torch.stack( + [x["static_continuous_features"] for x, _ in batch] + ) + + y_batch = torch.stack([y for _, y in batch]) + return x_batch, y_batch diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 336eecd5f..bc8300300 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2657,6 +2657,8 @@ def _coerce_to_list(obj): """ if obj is None: return [] + if isinstance(obj, str): + return [obj] return list(obj) @@ -2668,3 +2670,258 @@ def _coerce_to_dict(obj): if obj is None: return {} return deepcopy(obj) + + +####################################################################################### +# Disclaimer: This dataset class is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the data-handling pipeline may +# look like in the future. +# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion +# and turning the data to tensors. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + + +class TimeSeries(Dataset): + """PyTorch Dataset for time series data stored in pandas DataFrame. + + Parameters + ---------- + data : pd.DataFrame + data frame with sequence data. + Column names must all be str, and contain str as referred to below. + data_future : pd.DataFrame, optional, default=None + data frame with future data. + Column names must all be str, and contain str as referred to below. + May contain only columns that are in time, group, weight, known, or static. + time : str, optional, default = first col not in group_ids, weight, target, static. + integer typed column denoting the time index within ``data``. + This column is used to determine the sequence of samples. 
+ If there are no missing observations, + the time index should increase by ``+1`` for each subsequent sample. + The first time_idx for each series does not necessarily + have to be ``0`` but any value is allowed. + target : str or List[str], optional, default = last column (at iloc -1) + column(s) in ``data`` denoting the forecasting target. + Can be categorical or numerical dtype. + group : List[str], optional, default = None + list of column names identifying a time series instance within ``data``. + This means that the ``group`` together uniquely identify an instance, + and ``group`` together with ``time`` uniquely identify a single observation + within a time series instance. + If ``None``, the dataset is assumed to be a single time series. + weight : str, optional, default=None + column name for weights. + If ``None``, it is assumed that there is no weight column. + num : list of str, optional, default = all columns with dtype in "fi" + list of numerical variables in ``data``, + list may also contain list of str, which are then grouped together. + cat : list of str, optional, default = all columns with dtype in "Obc" + list of categorical variables in ``data``, + list may also contain list of str, which are then grouped together + (e.g. useful for product categories). + known : list of str, optional, default = all variables + list of variables that change over time and are known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for special days or promotion categories). + unknown : list of str, optional, default = no variables + list of variables that are not known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for weather categories). + static : list of str, optional, default = all variables not in known, unknown + list of variables that do not change over time, + list may also contain list of str, which are then grouped together. + """ + + def __init__( + self, + data: pd.DataFrame, + data_future: Optional[pd.DataFrame] = None, + time: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group: Optional[List[str]] = None, + weight: Optional[str] = None, + num: Optional[List[Union[str, List[str]]]] = None, + cat: Optional[List[Union[str, List[str]]]] = None, + known: Optional[List[Union[str, List[str]]]] = None, + unknown: Optional[List[Union[str, List[str]]]] = None, + static: Optional[List[Union[str, List[str]]]] = None, + ): + + self.data = data + self.data_future = data_future + self.time = time + self.target = _coerce_to_list(target) + self.group = _coerce_to_list(group) + self.weight = weight + self.num = _coerce_to_list(num) + self.cat = _coerce_to_list(cat) + self.known = _coerce_to_list(known) + self.unknown = _coerce_to_list(unknown) + self.static = _coerce_to_list(static) + + self.feature_cols = [ + col + for col in data.columns + if col not in [self.time] + self.group + [self.weight] + self.target + ] + if self.group: + self._groups = self.data.groupby(self.group).groups + self._group_ids = list(self._groups.keys()) + else: + self._groups = {"_single_group": self.data.index} + self._group_ids = ["_single_group"] + + self._prepare_metadata() + + def _prepare_metadata(self): + """Prepare metadata for the dataset. + + The funcion returns metadata that contains: + + * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } + Names of columns for y, x, and static features. + List elements are in same order as column dimensions. 
+ Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". + """ + self.metadata = { + "cols": { + "y": self.target, + "x": self.feature_cols, + "st": self.static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self.target + self.feature_cols + self.static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self.cat else "F" + + self.metadata["col_known"][col] = "K" if col in self.known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index. + + It returns: + + * ``t``: ``numpy.ndarray`` of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with ``y``, + and ``x`` not ending in ``f``. + * ``y``: tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with ``t``. + * ``x``: tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with ``t``. + * ``group``: tensor of shape (n_groups) + Group identifiers for time series instances. + * ``st``: tensor of shape (n_static_features) + Static features. + * ``cutoff_time``: float or ``numpy.float64`` + Cutoff time for the time series instance. + + Optionally, the following str-keyed entry can be included: + + * ``weights``: tensor of shape (n_timepoints), only if weight is not None + """ + group_id = self._group_ids[index] + + if self.group: + mask = self._groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + cutoff_time = data[self.time].max() + + result = { + "t": data[self.time].values, + "y": torch.tensor(data[self.target].values), + "x": torch.tensor(data[self.feature_cols].values), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "cutoff_time": cutoff_time, + } + + if self.data_future is not None: + if self.group: + future_mask = self.data_future.groupby(self.group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + combined_times = np.concatenate( + [data[self.time].values, future_data[self.time].values] + ) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(self.target)), np.nan) + + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + x_merged[idx] = data[self.feature_cols].values[i] + y_merged[idx] = data[self.target].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices: + idx = current_time_indices[t] + for j, col in enumerate(self.known): + if col in self.feature_cols: + feature_idx = self.feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + + result.update( + { + "t": combined_times, + "x": torch.tensor(x_merged, dtype=torch.float32), + "y": 
torch.tensor(y_merged, dtype=torch.float32), + } + ) + + if self.weight: + if self.data_future is not None and self.weight in self.data_future.columns: + weights_merged = np.full(num_timepoints, np.nan) + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + weights_merged[idx] = data[self.weight].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices and self.weight in future_data.columns: + idx = current_time_indices[t] + weights_merged[idx] = future_data[self.weight].values[i] + + result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) + else: + result["weights"] = torch.tensor( + data[self.weight].values, dtype=torch.float32 + ) + + return result + + def get_metadata(self) -> Dict: + """Return metadata about the dataset. + + Returns + ------- + Dict + Dictionary containing: + - cols: column names for y, x, and static features + - col_type: mapping of columns to their types (F/C) + - col_known: mapping of columns to their future known status (K/U) + """ + return self.metadata From d0d1c3ec7fb3bdee8e80d9ff83cd43e8990a5319 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 18:47:46 +0530 Subject: [PATCH 23/80] remove one comment --- pytorch_forecasting/data/data_module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 56917696d..2958f1705 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -550,7 +550,6 @@ def setup(self, stage: Optional[str] = None): self.val_dataset = self._ProcessedEncoderDecoderDataset( self.val_processed, self.val_windows, self.add_relative_time_idx ) - # print(self.val_dataset[0]) elif stage is None or stage == "test": if not hasattr(self, "test_dataset"): From 80e64d218a744557bd493ea07547f0f42b029573 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 19:07:01 +0530 Subject: [PATCH 24/80] model layer commit --- .../models/base/base_model_refactor.py | 283 ++++++++++++++++++ .../tft_version_two.py | 218 ++++++++++++++ 2 files changed, 501 insertions(+) create mode 100644 pytorch_forecasting/models/base/base_model_refactor.py create mode 100644 pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py diff --git a/pytorch_forecasting/models/base/base_model_refactor.py b/pytorch_forecasting/models/base/base_model_refactor.py new file mode 100644 index 000000000..ccd2c2600 --- /dev/null +++ b/pytorch_forecasting/models/base/base_model_refactor.py @@ -0,0 +1,283 @@ +######################################################################################## +# Disclaimer: This baseclass is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the base classes may look like +# in the version-2. 
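+#
+# Concrete forecasters are expected to subclass BaseModel and implement forward();
+# training, validation, test and prediction steps as well as optimizer and scheduler
+# configuration are inherited from this class. See the experimental TFT in
+# pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py for an
+# example subclass.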
+######################################################################################## + + +from typing import Dict, List, Optional, Tuple, Union + +from lightning.pytorch import LightningModule +from lightning.pytorch.utilities.types import STEP_OUTPUT +import torch +import torch.nn as nn +from torch.optim import Optimizer + + +class BaseModel(LightningModule): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + ): + """ + Base model for time series forecasting. + + Parameters + ---------- + loss : nn.Module + Loss function to use for training. + logging_metrics : Optional[List[nn.Module]], optional + List of metrics to log during training, validation, and testing. + optimizer : Optional[Union[Optimizer, str]], optional + Optimizer to use for training. + Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`. + optimizer_params : Optional[Dict], optional + Parameters for the optimizer. + lr_scheduler : Optional[str], optional + Learning rate scheduler to use. + Supported values: "reduce_lr_on_plateau", "step_lr". + lr_scheduler_params : Optional[Dict], optional + Parameters for the learning rate scheduler. + """ + super().__init__() + self.loss = loss + self.logging_metrics = logging_metrics if logging_metrics is not None else [] + self.optimizer = optimizer + self.optimizer_params = optimizer_params if optimizer_params is not None else {} + self.lr_scheduler = lr_scheduler + self.lr_scheduler_params = ( + lr_scheduler_params if lr_scheduler_params is not None else {} + ) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the model. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors + """ + raise NotImplementedError("Forward method must be implemented by subclass.") + + def training_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Training step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="train") + return {"loss": loss} + + def validation_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Validation step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. 
+ """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="val") + return {"val_loss": loss} + + def test_step( + self, batch: Tuple[Dict[str, torch.Tensor]], batch_idx: int + ) -> STEP_OUTPUT: + """ + Test step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input and target tensors. + batch_idx : int + Index of the batch. + + Returns + ------- + STEP_OUTPUT + Dictionary containing the loss and other metrics. + """ + x, y = batch + y_hat_dict = self(x) + y_hat = y_hat_dict["prediction"] + loss = self.loss(y_hat, y) + self.log( + "test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True + ) + self.log_metrics(y_hat, y, prefix="test") + return {"test_loss": loss} + + def predict_step( + self, + batch: Tuple[Dict[str, torch.Tensor]], + batch_idx: int, + dataloader_idx: int = 0, + ) -> torch.Tensor: + """ + Prediction step for the model. + + Parameters + ---------- + batch : Tuple[Dict[str, torch.Tensor]] + Batch of data containing input tensors. + batch_idx : int + Index of the batch. + dataloader_idx : int + Index of the dataloader. + + Returns + ------- + torch.Tensor + Predicted output tensor. + """ + x, _ = batch + y_hat = self(x) + return y_hat + + def configure_optimizers(self) -> Dict: + """ + Configure the optimizer and learning rate scheduler. + + Returns + ------- + Dict + Dictionary containing the optimizer and scheduler configuration. + """ + optimizer = self._get_optimizer() + if self.lr_scheduler is not None: + scheduler = self._get_scheduler(optimizer) + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val_loss", + }, + } + else: + return {"optimizer": optimizer, "lr_scheduler": scheduler} + return {"optimizer": optimizer} + + def _get_optimizer(self) -> Optimizer: + """ + Get the optimizer based on the specified optimizer name and parameters. + + Returns + ------- + Optimizer + The optimizer instance. + """ + if isinstance(self.optimizer, str): + if self.optimizer.lower() == "adam": + return torch.optim.Adam(self.parameters(), **self.optimizer_params) + elif self.optimizer.lower() == "sgd": + return torch.optim.SGD(self.parameters(), **self.optimizer_params) + else: + raise ValueError(f"Optimizer {self.optimizer} not supported.") + elif isinstance(self.optimizer, Optimizer): + return self.optimizer + else: + raise ValueError( + "Optimizer must be either a string or " + "an instance of torch.optim.Optimizer." + ) + + def _get_scheduler( + self, optimizer: Optimizer + ) -> torch.optim.lr_scheduler._LRScheduler: + """ + Get the lr scheduler based on the specified scheduler name and params. + + Parameters + ---------- + optimizer : Optimizer + The optimizer instance. + + Returns + ------- + torch.optim.lr_scheduler._LRScheduler + The learning rate scheduler instance. 
+ """ + if self.lr_scheduler.lower() == "reduce_lr_on_plateau": + return torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, **self.lr_scheduler_params + ) + elif self.lr_scheduler.lower() == "step_lr": + return torch.optim.lr_scheduler.StepLR( + optimizer, **self.lr_scheduler_params + ) + else: + raise ValueError(f"Scheduler {self.lr_scheduler} not supported.") + + def log_metrics( + self, y_hat: torch.Tensor, y: torch.Tensor, prefix: str = "val" + ) -> None: + """ + Log additional metrics during training, validation, or testing. + + Parameters + ---------- + y_hat : torch.Tensor + Predicted output tensor. + y : torch.Tensor + Target output tensor. + prefix : str + Prefix for the logged metrics (e.g., "train", "val", "test"). + """ + for metric in self.logging_metrics: + metric_value = metric(y_hat, y) + self.log( + f"{prefix}_{metric.__class__.__name__}", + metric_value, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py new file mode 100644 index 000000000..30f70f98e --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py @@ -0,0 +1,218 @@ +######################################################################################## +# Disclaimer: This implementation is based on the new version of data pipeline and is +# experimental, please use with care. +######################################################################################## + +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn +from torch.optim import Optimizer + +from pytorch_forecasting.models.base.base_model_refactor import BaseModel + + +class TFT(BaseModel): + def __init__( + self, + loss: nn.Module, + logging_metrics: Optional[List[nn.Module]] = None, + optimizer: Optional[Union[Optimizer, str]] = "adam", + optimizer_params: Optional[Dict] = None, + lr_scheduler: Optional[str] = None, + lr_scheduler_params: Optional[Dict] = None, + hidden_size: int = 64, + num_layers: int = 2, + attention_head_size: int = 4, + dropout: float = 0.1, + metadata: Optional[Dict] = None, + output_size: int = 1, + ): + super().__init__( + loss=loss, + logging_metrics=logging_metrics, + optimizer=optimizer, + optimizer_params=optimizer_params, + lr_scheduler=lr_scheduler, + lr_scheduler_params=lr_scheduler_params, + ) + self.hidden_size = hidden_size + self.num_layers = num_layers + self.attention_head_size = attention_head_size + self.dropout = dropout + self.metadata = metadata + self.output_size = output_size + + self.max_encoder_length = self.metadata["max_encoder_length"] + self.max_prediction_length = self.metadata["max_prediction_length"] + self.encoder_cont = self.metadata["encoder_cont"] + self.encoder_cat = self.metadata["encoder_cat"] + self.static_categorical_features = self.metadata["static_categorical_features"] + self.static_continuous_features = self.metadata["static_continuous_features"] + + total_feature_size = self.encoder_cont + self.encoder_cat + total_static_size = ( + self.static_categorical_features + self.static_continuous_features + ) + + self.encoder_var_selection = nn.Sequential( + nn.Linear(total_feature_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, total_feature_size), + nn.Sigmoid(), + ) + + self.decoder_var_selection = nn.Sequential( + nn.Linear(total_feature_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, total_feature_size), + 
nn.Sigmoid(), + ) + + self.static_context_linear = ( + nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None + ) + + self.lstm_encoder = nn.LSTM( + input_size=total_feature_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + self.lstm_decoder = nn.LSTM( + input_size=total_feature_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + + self.self_attention = nn.MultiheadAttention( + embed_dim=hidden_size, + num_heads=attention_head_size, + dropout=dropout, + batch_first=True, + ) + + self.pre_output = nn.Linear(hidden_size, hidden_size) + self.output_layer = nn.Linear(hidden_size, output_size) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Forward pass of the TFT model. + + Parameters + ---------- + x : Dict[str, torch.Tensor] + Dictionary containing input tensors: + - encoder_cat: Categorical encoder features + - encoder_cont: Continuous encoder features + - decoder_cat: Categorical decoder features + - decoder_cont: Continuous decoder features + - static_categorical_features: Static categorical features + - static_continuous_features: Static continuous features + + Returns + ------- + Dict[str, torch.Tensor] + Dictionary containing output tensors: + - prediction: Prediction output (batch_size, prediction_length, output_size) + """ + batch_size = x["encoder_cont"].shape[0] + + encoder_cat = x.get( + "encoder_cat", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + encoder_cont = x.get( + "encoder_cont", + torch.zeros(batch_size, self.max_encoder_length, 0, device=self.device), + ) + decoder_cat = x.get( + "decoder_cat", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + decoder_cont = x.get( + "decoder_cont", + torch.zeros(batch_size, self.max_prediction_length, 0, device=self.device), + ) + + encoder_input = torch.cat([encoder_cont, encoder_cat], dim=2) + decoder_input = torch.cat([decoder_cont, decoder_cat], dim=2) + + static_context = None + if self.static_context_linear is not None: + static_cat = x.get( + "static_categorical_features", + torch.zeros(batch_size, 0, device=self.device), + ) + static_cont = x.get( + "static_continuous_features", + torch.zeros(batch_size, 0, device=self.device), + ) + + if static_cat.size(2) == 0 and static_cont.size(2) == 0: + static_context = None + elif static_cat.size(2) == 0: + static_input = static_cont.to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + elif static_cont.size(2) == 0: + static_input = static_cat.to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + else: + + static_input = torch.cat([static_cont, static_cat], dim=1).to( + dtype=self.static_context_linear.weight.dtype + ) + static_context = self.static_context_linear(static_input) + static_context = static_context.view(batch_size, self.hidden_size) + + encoder_weights = self.encoder_var_selection(encoder_input) + encoder_input = encoder_input * encoder_weights + + decoder_weights = self.decoder_var_selection(decoder_input) + decoder_input = decoder_input * decoder_weights + + if static_context is not None: + encoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_encoder_length, 
-1 + ) + decoder_static_context = static_context.unsqueeze(1).expand( + -1, self.max_prediction_length, -1 + ) + + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + encoder_output = encoder_output + encoder_static_context + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + decoder_output = decoder_output + decoder_static_context + else: + encoder_output, (h_n, c_n) = self.lstm_encoder(encoder_input) + decoder_output, _ = self.lstm_decoder(decoder_input, (h_n, c_n)) + + sequence = torch.cat([encoder_output, decoder_output], dim=1) + + if static_context is not None: + expanded_static_context = static_context.unsqueeze(1).expand( + -1, sequence.size(1), -1 + ) + + attended_output, _ = self.self_attention( + sequence + expanded_static_context, sequence, sequence + ) + else: + attended_output, _ = self.self_attention(sequence, sequence, sequence) + + decoder_attended = attended_output[:, -self.max_prediction_length :, :] + + output = nn.functional.relu(self.pre_output(decoder_attended)) + prediction = self.output_layer(output) + + return {"prediction": prediction} From 6364780ae121298e3d98a2c14c6f6747bf62a7b4 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 6 Apr 2025 19:34:57 +0530 Subject: [PATCH 25/80] update docstring --- pytorch_forecasting/data/timeseries.py | 44 +++++++++++++++----------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index bc8300300..9da02d3a0 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2815,25 +2815,31 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: """Get time series data for given index. - It returns: - - * ``t``: ``numpy.ndarray`` of shape (n_timepoints,) - Time index for each time point in the past or present. Aligned with ``y``, - and ``x`` not ending in ``f``. - * ``y``: tensor of shape (n_timepoints, n_targets) - Target values for each time point. Rows are time points, aligned with ``t``. - * ``x``: tensor of shape (n_timepoints, n_features) - Features for each time point. Rows are time points, aligned with ``t``. - * ``group``: tensor of shape (n_groups) - Group identifiers for time series instances. - * ``st``: tensor of shape (n_static_features) - Static features. - * ``cutoff_time``: float or ``numpy.float64`` - Cutoff time for the time series instance. - - Optionally, the following str-keyed entry can be included: - - * ``weights``: tensor of shape (n_timepoints), only if weight is not None + Returns + ------- + t : numpy.ndarray of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with `y`, + and `x` not ending in `f`. + + y : torch.Tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with `t`. + + x : torch.Tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with `t`. + + group : torch.Tensor of shape (n_groups,) + Group identifiers for time series instances. + + st : torch.Tensor of shape (n_static_features,) + Static features. + + cutoff_time : float or numpy.float64 + Cutoff time for the time series instance. + + Other Returns + ------------- + weights : torch.Tensor of shape (n_timepoints,), optional + Only included if weights are not `None`. 
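As a hedged usage sketch of the sample structure documented above, with an invented single-series frame (the column names are illustrative, not taken from this patch):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame(
        {
            "time_idx": np.arange(10),
            "value": np.sin(np.arange(10.0)),
            "feat": np.random.randn(10),
        }
    )
    # ds = TimeSeries(data=df, time="time_idx", target="value")
    # sample = ds[0]
    # sample["t"]  -> numpy array of the 10 time points
    # sample["y"]  -> tensor of shape (10, 1), rows aligned with sample["t"]
    # sample["x"]  -> tensor of shape (10, n_features), same row alignment
    # sample["cutoff_time"] -> largest time value observed for this series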
""" group_id = self._group_ids[index] From 257183ce4d2b1f7fd40c95ecd7dc38c8004a017b Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 01:54:50 +0530 Subject: [PATCH 26/80] update data_module.py --- pytorch_forecasting/data/data_module.py | 160 ++++++++++++------------ 1 file changed, 80 insertions(+), 80 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 2958f1705..c796b85fa 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -1,15 +1,6 @@ -####################################################################################### -# Disclaimer: This data-module is still work in progress and experimental, please -# use with care. This data-module is a basic skeleton of how the data-handling pipeline -# may look like in the future. -# This is D2 layer that will handle the preprocessing and data loaders. -# For now, this pipeline handles the simplest situation: The whole data can be loaded -# into the memory. -####################################################################################### - from typing import Any, Dict, List, Optional, Tuple, Union -from lightning.pytorch import LightningDataModule +from lightning.pytorch import LightningDataModule, LightningModule from sklearn.preprocessing import RobustScaler, StandardScaler import torch from torch.utils.data import DataLoader, Dataset @@ -19,7 +10,11 @@ NaNLabelEncoder, TorchNormalizer, ) -from pytorch_forecasting.data.timeseries import TimeSeries, _coerce_to_dict +from pytorch_forecasting.data.timeseries import ( + TimeSeries, + _coerce_to_dict, + _coerce_to_list, +) NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] @@ -274,7 +269,7 @@ def metadata(self): self._metadata = self._prepare_metadata() return self._metadata - def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: + def _preprocess_data(self, series_idx: torch.Tensor) -> List[Dict[str, Any]]: """Preprocess the data before feeding it into _ProcessedEncoderDecoderDataset. Preprocessing steps @@ -288,63 +283,58 @@ def _preprocess_data(self, indices: torch.Tensor) -> List[Dict[str, Any]]: TODO: add scalers, target normalizers etc. """ - processed_data = [] + sample = self.time_series_dataset[series_idx] - for idx in indices: - sample = self.time_series_dataset[idx.item()] + target = sample["y"] + features = sample["x"] + times = sample["t"] + cutoff_time = sample["cutoff_time"] - target = sample["y"] - features = sample["x"] - times = sample["t"] - cutoff_time = sample["cutoff_time"] + time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) - time_mask = torch.tensor(times <= cutoff_time, dtype=torch.bool) - - if isinstance(target, torch.Tensor): - target = target.float() - else: - target = torch.tensor(target, dtype=torch.float32) - - if isinstance(features, torch.Tensor): - features = features.float() - else: - features = torch.tensor(features, dtype=torch.float32) + if isinstance(target, torch.Tensor): + target = target.float() + else: + target = torch.tensor(target, dtype=torch.float32) - # TODO: add scalers, target normalizers etc. 
+ if isinstance(features, torch.Tensor): + features = features.float() + else: + features = torch.tensor(features, dtype=torch.float32) - categorical = ( - features[:, self.categorical_indices] - if self.categorical_indices - else torch.zeros((features.shape[0], 0)) - ) - continuous = ( - features[:, self.continuous_indices] - if self.continuous_indices - else torch.zeros((features.shape[0], 0)) - ) + # TODO: add scalers, target normalizers etc. - processed_data.append( - { - "features": {"categorical": categorical, "continuous": continuous}, - "target": target, - "static": sample.get("st", None), - "group": sample.get("group", torch.tensor([0])), - "length": len(target), - "time_mask": time_mask, - "times": times, - "cutoff_time": cutoff_time, - } - ) + categorical = ( + features[:, self.categorical_indices] + if self.categorical_indices + else torch.zeros((features.shape[0], 0)) + ) + continuous = ( + features[:, self.continuous_indices] + if self.continuous_indices + else torch.zeros((features.shape[0], 0)) + ) - return processed_data + return { + "features": {"categorical": categorical, "continuous": continuous}, + "target": target, + "static": sample.get("st", None), + "group": sample.get("group", torch.tensor([0])), + "length": len(target), + "time_mask": time_mask, + "times": times, + "cutoff_time": cutoff_time, + } class _ProcessedEncoderDecoderDataset(Dataset): """PyTorch Dataset for processed encoder-decoder time series data. Parameters ---------- - processed_data : List[Dict[str, Any]] - List of preprocessed time series samples. + dataset : TimeSeries + The base time series dataset that provides access to raw data and metadata. + data_module : EncoderDecoderTimeSeriesDataModule + The data module handling preprocessing and metadata configuration. windows : List[Tuple[int, int, int, int]] List of window tuples containing (series_idx, start_idx, enc_length, pred_length). @@ -354,11 +344,13 @@ class _ProcessedEncoderDecoderDataset(Dataset): def __init__( self, - processed_data: List[Dict[str, Any]], + dataset: TimeSeries, + data_module: "EncoderDecoderTimeSeriesDataModule", windows: List[Tuple[int, int, int, int]], add_relative_time_idx: bool = False, ): - self.processed_data = processed_data + self.dataset = dataset + self.data_module = data_module self.windows = windows self.add_relative_time_idx = add_relative_time_idx @@ -410,7 +402,7 @@ def __getitem__(self, idx): Target values for the decoder sequence. """ series_idx, start_idx, enc_length, pred_length = self.windows[idx] - data = self.processed_data[series_idx] + data = self.data_module._preprocess_data(series_idx) end_idx = start_idx + enc_length + pred_length encoder_indices = slice(start_idx, start_idx + enc_length) @@ -457,9 +449,7 @@ def __getitem__(self, idx): return x, y - def _create_windows( - self, processed_data: List[Dict[str, Any]] - ) -> List[Tuple[int, int, int, int]]: + def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, int]]: """Generate sliding windows for training, validation, and testing. 
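For intuition, the window arithmetic can be sketched on a single series (the lengths below are placeholders; the real method also skips series that are shorter than encoder plus prediction length):

    max_encoder_length = 24
    max_prediction_length = 12
    sequence_length = 100  # observations in one series

    windows = []
    max_start = sequence_length - max_encoder_length - max_prediction_length + 1
    for start_idx in range(max_start):
        # (series_idx, start_idx, enc_length, pred_length)
        windows.append((0, start_idx, max_encoder_length, max_prediction_length))

    assert len(windows) == 65  # 100 - 24 - 12 + 1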
Returns @@ -477,8 +467,10 @@ def _create_windows( """ windows = [] - for idx, data in enumerate(processed_data): - sequence_length = data["length"] + for idx in indices: + series_idx = idx.item() + sample = self.time_series_dataset[series_idx] + sequence_length = len(sample["y"]) if sequence_length < self.max_encoder_length + self.max_prediction_length: continue @@ -503,7 +495,7 @@ def _create_windows( ): windows.append( ( - idx, + series_idx, start_idx, self.max_encoder_length, self.max_prediction_length, @@ -538,33 +530,41 @@ def setup(self, stage: Optional[str] = None): if stage is None or stage == "fit": if not hasattr(self, "train_dataset") or not hasattr(self, "val_dataset"): - self.train_processed = self._preprocess_data(self._train_indices) - self.val_processed = self._preprocess_data(self._val_indices) - - self.train_windows = self._create_windows(self.train_processed) - self.val_windows = self._create_windows(self.val_processed) + self.train_windows = self._create_windows(self._train_indices) + self.val_windows = self._create_windows(self._val_indices) self.train_dataset = self._ProcessedEncoderDecoderDataset( - self.train_processed, self.train_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.train_windows, + self.add_relative_time_idx, ) self.val_dataset = self._ProcessedEncoderDecoderDataset( - self.val_processed, self.val_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.val_windows, + self.add_relative_time_idx, ) - elif stage is None or stage == "test": + elif stage == "test": if not hasattr(self, "test_dataset"): - self.test_processed = self._preprocess_data(self._test_indices) - self.test_windows = self._create_windows(self.test_processed) - + self.test_windows = self._create_windows(self._test_indices) self.test_dataset = self._ProcessedEncoderDecoderDataset( - self.test_processed, self.test_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.test_windows, + self, + self.add_relative_time_idx, ) elif stage == "predict": predict_indices = torch.arange(len(self.time_series_dataset)) - self.predict_processed = self._preprocess_data(predict_indices) - self.predict_windows = self._create_windows(self.predict_processed) + self.predict_windows = self._create_windows(predict_indices) self.predict_dataset = self._ProcessedEncoderDecoderDataset( - self.predict_processed, self.predict_windows, self.add_relative_time_idx + self.time_series_dataset, + self, + self.predict_windows, + self, + self.add_relative_time_idx, ) def train_dataloader(self): From 9cdcb195c4c9e3f9b6d0e76ef3b6ed889bc14998 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 01:56:55 +0530 Subject: [PATCH 27/80] update data_module.py --- pytorch_forecasting/data/data_module.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index c796b85fa..9a4a5bf5e 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -553,7 +553,6 @@ def setup(self, stage: Optional[str] = None): self.time_series_dataset, self, self.test_windows, - self, self.add_relative_time_idx, ) elif stage == "predict": @@ -563,7 +562,6 @@ def setup(self, stage: Optional[str] = None): self.time_series_dataset, self, self.predict_windows, - self, self.add_relative_time_idx, ) From ac56d4fd56aeeb1287f162559c67e785de4446f4 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 02:05:58 +0530 Subject: [PATCH 28/80] Add 
disclaimer --- pytorch_forecasting/data/data_module.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 9a4a5bf5e..b33a11d47 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -1,6 +1,15 @@ +####################################################################################### +# Disclaimer: This data-module is still work in progress and experimental, please +# use with care. This data-module is a basic skeleton of how the data-handling pipeline +# may look like in the future. +# This is D2 layer that will handle the preprocessing and data loaders. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + from typing import Any, Dict, List, Optional, Tuple, Union -from lightning.pytorch import LightningDataModule, LightningModule +from lightning.pytorch import LightningDataModule from sklearn.preprocessing import RobustScaler, StandardScaler import torch from torch.utils.data import DataLoader, Dataset @@ -13,7 +22,6 @@ from pytorch_forecasting.data.timeseries import ( TimeSeries, _coerce_to_dict, - _coerce_to_list, ) NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] From 4bfff21de1a75be0c93dcb713cb91defe6bc2fad Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 11 Apr 2025 12:44:44 +0530 Subject: [PATCH 29/80] update docstring --- pytorch_forecasting/data/data_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index b33a11d47..9d4e0b02f 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -465,7 +465,7 @@ def _create_windows(self, indices: torch.Tensor) -> List[Tuple[int, int, int, in List[Tuple[int, int, int, int]] A list of tuples, where each tuple consists of: - ``series_idx`` : int - Index of the time series in `processed_data`. + Index of the time series in `time_series_dataset`. - ``start_idx`` : int Start index of the encoder window. - ``enc_length`` : int @@ -522,7 +522,7 @@ def setup(self, stage: Optional[str] = None): - ``"fit"`` : Prepares training and validation datasets. - ``"test"`` : Prepares the test dataset. - ``"predict"`` : Prepares the dataset for inference. - - ``None`` : Prepares all datasets. + - ``None`` : Prepares ``fit`` datasets. 
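A short hedged sketch of how these stages are exercised, either directly or through Lightning (the trainer settings are placeholders):

    # dm = EncoderDecoderTimeSeriesDataModule(time_series_dataset=ds, batch_size=32)
    # dm.setup(stage="fit")      # builds train_dataset and val_dataset
    # dm.setup(stage="test")     # builds test_dataset
    # dm.setup(stage="predict")  # builds predict_dataset over all series
    #
    # With Lightning, Trainer.fit / .test / .predict trigger the matching
    # setup() call automatically:
    # import lightning.pytorch as pl
    # trainer = pl.Trainer(max_epochs=1)
    # trainer.fit(model, datamodule=dm)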
""" total_series = len(self.time_series_dataset) self._split_indices = torch.randperm(total_series) From 8a53ed63933b0b92d752eaa707eadc7c45d35566 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 19 Apr 2025 19:37:40 +0530 Subject: [PATCH 30/80] Add tests for D1,D2 layer --- pytorch_forecasting/data/data_module.py | 56 ++- tests/test_data/test_d1.py | 379 +++++++++++++++++++ tests/test_data/test_data_module.py | 464 ++++++++++++++++++++++++ 3 files changed, 895 insertions(+), 4 deletions(-) create mode 100644 tests/test_data/test_d1.py create mode 100644 tests/test_data/test_data_module.py diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 9d4e0b02f..1203e83ac 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -432,11 +432,59 @@ def __getitem__(self, idx): else torch.zeros(pred_length, dtype=torch.bool) ) + encoder_cat = data["features"]["categorical"][encoder_indices] + encoder_cont = data["features"]["continuous"][encoder_indices] + + features = data["features"] + metadata = self.data_module.time_series_metadata + + known_cat_indices = [ + i + for i, col in enumerate(metadata["cols"]["x"]) + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + + known_cont_indices = [ + i + for i, col in enumerate(metadata["cols"]["x"]) + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + + cat_map = { + orig_idx: i + for i, orig_idx in enumerate(self.data_module.categorical_indices) + } + cont_map = { + orig_idx: i + for i, orig_idx in enumerate(self.data_module.continuous_indices) + } + + mapped_known_cat_indices = [ + cat_map[idx] for idx in known_cat_indices if idx in cat_map + ] + mapped_known_cont_indices = [ + cont_map[idx] for idx in known_cont_indices if idx in cont_map + ] + + decoder_cat = ( + features["categorical"][decoder_indices][:, mapped_known_cat_indices] + if mapped_known_cat_indices + else torch.zeros((pred_length, 0)) + ) + + decoder_cont = ( + features["continuous"][decoder_indices][:, mapped_known_cont_indices] + if mapped_known_cont_indices + else torch.zeros((pred_length, 0)) + ) + x = { - "encoder_cat": data["features"]["categorical"][encoder_indices], - "encoder_cont": data["features"]["continuous"][encoder_indices], - "decoder_cat": data["features"]["categorical"][decoder_indices], - "decoder_cont": data["features"]["continuous"][decoder_indices], + "encoder_cat": encoder_cat, + "encoder_cont": encoder_cont, + "decoder_cat": decoder_cat, + "decoder_cont": decoder_cont, "encoder_lengths": torch.tensor(enc_length), "decoder_lengths": torch.tensor(pred_length), "decoder_target_lengths": torch.tensor(pred_length), diff --git a/tests/test_data/test_d1.py b/tests/test_data/test_d1.py new file mode 100644 index 000000000..b32c13213 --- /dev/null +++ b/tests/test_data/test_d1.py @@ -0,0 +1,379 @@ +import numpy as np +import pandas as pd +import pytest +import torch + +from pytorch_forecasting.data.timeseries import TimeSeries + + +@pytest.fixture +def sample_data(): + """Create time series data for testing.""" + dates = pd.date_range(start="2023-01-01", periods=10, freq="D") + data = pd.DataFrame( + { + "timestamp": dates, + "target_value": np.sin(np.arange(10)) + 10, + "feature1": np.random.randn(10), + "feature2": np.random.randn(10), + "feature3": np.random.randn(10), + "group_id": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2], + "weight": np.abs(np.random.randn(10)) + 0.1, + "static_feat": [10, 10, 10, 10, 10, 20, 20, 
20, 20, 20], + } + ) + return data + + +@pytest.fixture +def future_data(): + """Create future time series data.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + data = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + "feature2": np.random.randn(5), + "feature3": np.random.randn(5), + "group_id": [1, 1, 1, 2, 2], + "weight": np.abs(np.random.randn(5)) + 0.1, + "static_feat": [10, 10, 10, 20, 20], + } + ) + return data + + +def test_init_basic(sample_data): + """Test basic initialization of TimeSeries class. + + Ensures that the class stores time, target, and correctly detects feature columns + when no group, known/unknown features, or static/weight features are specified.""" + ts = TimeSeries(data=sample_data, time="timestamp", target="target_value") + + assert ts.time == "timestamp" + assert ts.target == ["target_value"] + assert len(ts.feature_cols) == 6 # All columns except timestamp, target_value + assert len(ts) == 1 # Single group by default + + +def test_init_with_groups(sample_data): + """Test initialization with group parameter. + + Verifies that data is grouped correctly and each group is handled as a + separate time series. + """ + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", group=["group_id"] + ) + + assert ts.group == ["group_id"] + assert len(ts) == 2 # Two groups (1 and 2) + assert set(ts._group_ids) == {1, 2} + + +def test_init_with_features_categorization(sample_data): + """Test feature categorization. + + Ensures that numeric, categorical, and static features are categorized and + stored correctly in metadata.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + num=["feature1", "feature2", "feature3"], + cat=[], + static=["static_feat"], + ) + + assert ts.num == ["feature1", "feature2", "feature3"] + assert ts.cat == [] + assert ts.static == ["static_feat"] + assert ts.metadata["col_type"]["feature1"] == "F" + assert ts.metadata["col_type"]["feature2"] == "F" + + +def test_init_with_known_unknown(sample_data): + """Test known and unknown features classification. + + Checks if the known and unknown feature categorization is correctly set + and stored in metadata.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + known=["feature1"], + unknown=["feature2", "feature3"], + ) + + assert ts.known == ["feature1"] + assert ts.unknown == ["feature2", "feature3"] + assert ts.metadata["col_known"]["feature1"] == "K" + assert ts.metadata["col_known"]["feature2"] == "U" + + +def test_init_with_weight(sample_data): + """Test initialization with weight parameter. + + Verifies that the weight column is stored correctly and excluded + from the feature columns.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", weight="weight" + ) + + assert ts.weight == "weight" + assert "weight" not in ts.feature_cols + + +def test_getitem_basic(sample_data): + """Test __getitem__ with basic configuration. 
+ + Checks the output structure of a single time series without grouping, + ensuring x, y are tensors of correct shapes.""" + ts = TimeSeries(data=sample_data, time="timestamp", target="target_value") + + result = ts[0] + assert torch.is_tensor(result["y"]) + assert torch.is_tensor(result["x"]) + assert "t" in result + assert "cutoff_time" in result + assert len(result["y"]) == 10 # 10 data points + assert result["y"].shape == (10, 1) # One target variable + assert result["x"].shape[1] == 6 # Six feature columns + + +def test_getitem_with_groups(sample_data): + """Test __getitem__ with groups parameter. + + Verifies the per-group access using index and checks that each group + has the correct number of time steps.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", group=["group_id"] + ) + + # group (1) + result_g1 = ts[0] + assert len(result_g1["t"]) == 5 # 5 data points in group 1 + + # group (2) + result_g2 = ts[1] + assert len(result_g2["t"]) == 5 # 5 data points in group 2 + + +def test_getitem_with_static(sample_data): + """Test __getitem__ with static features. + + Ensures static features are included in the output and correctly + mapped per group.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + group=["group_id"], + static=["static_feat"], + ) + + result_g1 = ts[0] + result_g2 = ts[1] + + assert torch.is_tensor(result_g1["st"]) + assert result_g1["st"].item() == 10 # Static feature for group 1 + assert result_g2["st"].item() == 20 # Static feature for group 2 + + +def test_getitem_with_weight(sample_data): + """Test __getitem__ with weight parameter. + + Validates that weights are correctly returned in the output and have the + expected length and type.""" + ts = TimeSeries( + data=sample_data, time="timestamp", target="target_value", weight="weight" + ) + + result = ts[0] + assert "weights" in result + assert torch.is_tensor(result["weights"]) + assert len(result["weights"]) == 10 + + +def test_with_future_data(sample_data, future_data): + """Test with future data provided. + + Verifies that future time steps are appended to the end of each group, + especially for known features.""" + ts = TimeSeries( + data=sample_data, + data_future=future_data, + time="timestamp", + target="target_value", + group=["group_id"], + known=["feature1"], + ) + + result_g1 = ts[0] # Group 1 + + assert len(result_g1["t"]) == 8 # 5 original + 3 future for group 1 + + feature1_idx = ts.feature_cols.index("feature1") + assert not torch.isnan( + result_g1["x"][-1, feature1_idx] + ) # feature1 is not NaN in last row + + +def test_future_data_with_weights(sample_data, future_data): + """Test handling of weights with future data. + + Ensures that weights from future data are combined properly and match the + time indices.""" + ts = TimeSeries( + data=sample_data, + data_future=future_data, + time="timestamp", + target="target_value", + group=["group_id"], + weight="weight", + ) + + result = ts[0] # Group 1 + assert "weights" in result + assert torch.is_tensor(result["weights"]) + assert len(result["weights"]) == len(result["t"]) + + +def test_future_data_missing_columns(sample_data): + """Test handling when future data is missing some columns. 
+ + Verifies the handling of missing feature columns in future data by + checking NaN padding.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + incomplete_future = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + # Missing feature2, feature3 + "group_id": [1, 1, 1, 2, 2], + "weight": np.abs(np.random.randn(5)) + 0.1, + } + ) + + ts = TimeSeries( + data=sample_data, + data_future=incomplete_future, + time="timestamp", + target="target_value", + group=["group_id"], + known=["feature1"], + ) + + result = ts[0] + # Check that missing features are NaN in future timepoints + future_indices = np.where(result["t"] >= np.datetime64("2023-01-11"))[0] + feature2_idx = ts.feature_cols.index("feature2") + feature3_idx = ts.feature_cols.index("feature3") + assert torch.isnan(result["x"][future_indices[0], feature2_idx]) + assert torch.isnan(result["x"][future_indices[0], feature3_idx]) + + +def test_different_future_groups(sample_data): + """Test with future data that has different groups than original data. + + Ensures that groups present only in future data are ignored if not + in the original dataset.""" + dates = pd.date_range(start="2023-01-11", periods=5, freq="D") + future_with_new_group = pd.DataFrame( + { + "timestamp": dates, + "feature1": np.random.randn(5), + "feature2": np.random.randn(5), + "feature3": np.random.randn(5), + "group_id": [1, 1, 3, 3, 3], # Group 3 is new + "weight": np.abs(np.random.randn(5)) + 0.1, + "static_feat": [10, 10, 30, 30, 30], + } + ) + + ts = TimeSeries( + data=sample_data, + data_future=future_with_new_group, + time="timestamp", + target="target_value", + group=["group_id"], + ) + + # Original data has groups 1 and 2, but not 3 + assert len(ts) == 2 + assert 3 not in ts._group_ids + + +def test_multiple_targets(sample_data): + """Test handling of multiple target variables. + + Verifies that multiple target columns are handled and returned + as the correct shape in the output.""" + sample_data["target_value2"] = np.cos(np.arange(10)) + 5 + + ts = TimeSeries( + data=sample_data, time="timestamp", target=["target_value", "target_value2"] + ) + + result = ts[0] + assert result["y"].shape == (10, 2) # Two target variables + + +def test_empty_groups(): + """Test handling of empty groups. + + Confirms that the class handles datasets with a single group and + no empty group errors occur.""" + data = pd.DataFrame( + { + "timestamp": pd.date_range(start="2023-01-01", periods=5, freq="D"), + "target_value": np.random.randn(5), + "group_id": [1, 1, 1, 1, 1], # Only one group + } + ) + + ts = TimeSeries( + data=data, time="timestamp", target="target_value", group=["group_id"] + ) + + assert len(ts) == 1 # Only one group + + +def test_metadata_structure(sample_data): + """Test the structure of metadata. 
+ + Ensures the metadata dictionary includes the expected keys and + correct mappings of feature roles.""" + ts = TimeSeries( + data=sample_data, + time="timestamp", + target="target_value", + num=["feature1", "feature2", "feature3"], + cat=[], # No categorical features + static=["static_feat"], + known=["feature1"], + unknown=["feature2", "feature3"], + ) + + metadata = ts.get_metadata() + + assert "cols" in metadata + assert "col_type" in metadata + assert "col_known" in metadata + + assert metadata["cols"]["y"] == ["target_value"] + assert set(metadata["cols"]["x"]) == { + "feature1", + "feature2", + "feature3", + "group_id", + "weight", + "static_feat", + } + assert metadata["cols"]["st"] == ["static_feat"] + + assert metadata["col_type"]["feature1"] == "F" + assert metadata["col_type"]["feature2"] == "F" + + assert metadata["col_known"]["feature1"] == "K" + assert metadata["col_known"]["feature2"] == "U" diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py new file mode 100644 index 000000000..c14e3d8f4 --- /dev/null +++ b/tests/test_data/test_data_module.py @@ -0,0 +1,464 @@ +import numpy as np +import pandas as pd +import pytest + +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule +from pytorch_forecasting.data.timeseries import TimeSeries + + +@pytest.fixture +def sample_timeseries_data(): + """Create a sample time series dataset with only numerical values.""" + num_groups = 5 + seq_length = 100 + + groups = [] + times = [] + values = [] + categorical_feature = [] + continuous_feature1 = [] + continuous_feature2 = [] + known_future = [] + + for g in range(num_groups): + for t in range(seq_length): + groups.append(g) + times.append(pd.Timestamp("2020-01-01") + pd.Timedelta(days=t)) + + value = 10 + 0.1 * t + 5 * np.sin(t / 10) + g * 2 + np.random.normal(0, 1) + values.append(value) + + categorical_feature.append(np.random.choice([0, 1, 2])) + + continuous_feature1.append(np.random.normal(g, 1)) + continuous_feature2.append(value * 0.5 + np.random.normal(0, 0.5)) + + known_future.append(t % 7) + + df = pd.DataFrame( + { + "group": groups, + "time": times, + "target": values, + "cat_feat": categorical_feature, + "cont_feat1": continuous_feature1, + "cont_feat2": continuous_feature2, + "known_future": known_future, + } + ) + + time_series = TimeSeries( + data=df, + time="time", + target="target", + group=["group"], + num=["cont_feat1", "cont_feat2", "known_future"], + cat=["cat_feat"], + known=["known_future"], + ) + + return time_series + + +@pytest.fixture +def data_module(sample_timeseries_data): + """Create a data module instance.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + train_val_test_split=(0.7, 0.15, 0.15), + ) + return dm + + +def test_init(sample_timeseries_data): + """Test the initialization of the data module. 
+ + Verifies hyperparameter assignment and basic time_series_metadata creation.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=8, + ) + + assert dm.max_encoder_length == 24 + assert dm.max_prediction_length == 12 + assert dm.min_encoder_length == 24 + assert dm.min_prediction_length == 12 + assert dm.batch_size == 8 + assert dm.train_val_test_split == (0.7, 0.15, 0.15) + + assert isinstance(dm.time_series_metadata, dict) + assert "cols" in dm.time_series_metadata + + +def test_prepare_metadata(data_module): + """Test the metadata preparation method. + + Ensures that internal metadata keys are created correctly.""" + metadata = data_module._prepare_metadata() + + assert "encoder_cat" in metadata + assert "encoder_cont" in metadata + assert "decoder_cat" in metadata + assert "decoder_cont" in metadata + assert "target" in metadata + assert "max_encoder_length" in metadata + assert "max_prediction_length" in metadata + + assert metadata["max_encoder_length"] == 24 + assert metadata["max_prediction_length"] == 12 + + +def test_metadata_property(data_module): + """Test the metadata property. + + Confirms caching behavior and correct feature counts.""" + metadata = data_module.metadata + + # Should return the same object when called multiple times (caching) + assert data_module.metadata is metadata + + assert metadata["encoder_cat"] == 1 # cat_feat + assert metadata["encoder_cont"] == 3 # cont_feat1, cont_feat2, known_future + assert metadata["decoder_cat"] == 0 # No categorical features marked as known + assert metadata["decoder_cont"] == 1 # Only known_future marked as known + + +# def test_setup(data_module): +# """Test the setup method that prepares the datasets.""" +# data_module.setup(stage="fit") +# print(data_module._val_indices) +# assert hasattr(data_module, "train_dataset") +# assert hasattr(data_module, "val_dataset") +# assert len(data_module.train_windows) > 0 +# assert len(data_module.val_windows) > 0 +# +# data_module.setup(stage="test") +# assert hasattr(data_module, "test_dataset") +# assert len(data_module.test_windows) > 0 +# +# data_module.setup(stage="predict") +# assert hasattr(data_module, "predict_dataset") +# assert len(data_module.predict_windows) > 0 + + +def test_create_windows(data_module): + """Test the window creation logic. + + Validates window structure and length settings.""" + data_module.setup() + + windows = data_module._create_windows(data_module._train_indices) + + assert len(windows) > 0 + + for window in windows: + assert len(window) == 4 + assert window[2] == data_module.max_encoder_length + assert window[3] == data_module.max_prediction_length + + +def test_dataloader_creation(data_module): + """Test that dataloaders are created correctly. 
+ + Checks batch sizes and dataloader instantiation across all stages.""" + data_module.setup() + + train_loader = data_module.train_dataloader() + assert train_loader.batch_size == data_module.batch_size + assert train_loader.num_workers == data_module.num_workers + + val_loader = data_module.val_dataloader() + assert val_loader.batch_size == data_module.batch_size + + data_module.setup(stage="test") + test_loader = data_module.test_dataloader() + assert test_loader.batch_size == data_module.batch_size + + data_module.setup(stage="predict") + predict_loader = data_module.predict_dataloader() + assert predict_loader.batch_size == data_module.batch_size + + +def test_processed_dataset(data_module): + """Test the internal ProcessedEncoderDecoderDataset class. + + Verifies sample structure and tensor dimensions for encoder/decoder inputs.""" + data_module.setup() + + assert len(data_module.train_dataset) == len(data_module.train_windows) + assert len(data_module.val_dataset) == len(data_module.val_windows) + + x, y = data_module.train_dataset[0] + + required_keys = [ + "encoder_cat", + "encoder_cont", + "decoder_cat", + "decoder_cont", + "encoder_lengths", + "decoder_lengths", + "decoder_target_lengths", + "groups", + "encoder_time_idx", + "decoder_time_idx", + "target_scale", + "encoder_mask", + "decoder_mask", + ] + + for key in required_keys: + assert key in x + + assert x["encoder_cat"].shape[0] == data_module.max_encoder_length + assert x["decoder_cat"].shape[0] == data_module.max_prediction_length + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x["decoder_cat"].shape[1] == known_cat_count + assert x["decoder_cont"].shape[1] == known_cont_count + + assert y.shape[0] == data_module.max_prediction_length + + +def test_collate_fn(data_module): + """Test the collate function that combines batch samples. + + Ensures proper stacking of dictionary keys and batch outputs.""" + data_module.setup() + + batch_size = 3 + batch = [data_module.train_dataset[i] for i in range(batch_size)] + + x_batch, y_batch = data_module.collate_fn(batch) + + for key in x_batch: + assert x_batch[key].shape[0] == batch_size + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x_batch["decoder_cat"].shape[2] == known_cat_count + assert x_batch["decoder_cont"].shape[2] == known_cont_count + assert y_batch.shape[0] == batch_size + assert y_batch.shape[1] == data_module.max_prediction_length + + +def test_full_dataloader_iteration(data_module): + """Test a full iteration through the train dataloader. 
+ + Confirms batch retrieval and tensor dimensions match configuration.""" + data_module.setup() + train_loader = data_module.train_dataloader() + + batch = next(iter(train_loader)) + x_batch, y_batch = batch + + assert x_batch["encoder_cat"].shape[0] == data_module.batch_size + assert x_batch["encoder_cat"].shape[1] == data_module.max_encoder_length + + metadata = data_module.time_series_metadata + known_cat_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "C" + and metadata["col_known"].get(col) == "K" + ] + ) + + known_cont_count = len( + [ + col + for col in metadata["cols"]["x"] + if metadata["col_type"].get(col) == "F" + and metadata["col_known"].get(col) == "K" + ] + ) + + assert x_batch["decoder_cat"].shape[0] == data_module.batch_size + assert x_batch["decoder_cat"].shape[2] == known_cat_count + assert x_batch["decoder_cont"].shape[0] == data_module.batch_size + assert x_batch["decoder_cont"].shape[2] == known_cont_count + assert y_batch.shape[0] == data_module.batch_size + assert y_batch.shape[1] == data_module.max_prediction_length + + +def test_variable_encoder_lengths(sample_timeseries_data): + """Test with variable encoder lengths. + + Ensures random length behavior is respected and functional.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + min_encoder_length=12, + max_prediction_length=12, + batch_size=4, + randomize_length=True, + ) + + dm.setup() + assert dm.min_encoder_length == 12 + assert dm.max_encoder_length == 24 + + +def test_preprocess_data(data_module, sample_timeseries_data): + """Test the _preprocess_data method. + + Checks preprocessing output structure and alignment with raw data.""" + if not hasattr(data_module, "_split_indices"): + data_module.setup() + + series_idx = data_module._train_indices[0] + + processed = data_module._preprocess_data(series_idx) + + assert "features" in processed + assert "categorical" in processed["features"] + assert "continuous" in processed["features"] + assert "target" in processed + assert "time_mask" in processed + + original_sample = sample_timeseries_data[series_idx.item()] + expected_length = len(original_sample["y"]) + + assert processed["features"]["categorical"].shape[0] == expected_length + assert processed["features"]["continuous"].shape[0] == expected_length + assert processed["target"].shape[0] == expected_length + + +def test_with_static_features(): + """Test with static features included. 
+ + Validates static feature support in both metadata and sample input.""" + df = pd.DataFrame( + { + "group": [0, 0, 0, 1, 1, 1], + "time": pd.date_range("2020-01-01", periods=6), + "target": [1, 2, 3, 4, 5, 6], + "static_cat": [0, 0, 0, 1, 1, 1], + "static_num": [10, 10, 10, 20, 20, 20], + "feature1": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + } + ) + + ts = TimeSeries( + data=df, + time="time", + target="target", + group=["group"], + num=["feature1", "static_num"], + static=["static_cat", "static_num"], + cat=["static_cat"], + ) + + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=ts, + max_encoder_length=2, + max_prediction_length=1, + batch_size=2, + ) + + dm.setup() + + metadata = dm.metadata + assert metadata["static_categorical_features"] == 1 + assert metadata["static_continuous_features"] == 1 + + x, y = dm.train_dataset[0] + assert "static_categorical_features" in x + assert "static_continuous_features" in x + + +# def test_different_train_val_test_split(sample_timeseries_data): +# """Test with different train/val/test split ratios.""" +# dm = EncoderDecoderTimeSeriesDataModule( +# time_series_dataset=sample_timeseries_data, +# max_encoder_length=24, +# max_prediction_length=12, +# batch_size=4, +# train_val_test_split=(0.8, 0.1, 0.1), +# ) +# +# dm.setup() +# +# total_series = len(sample_timeseries_data) +# expected_train = int(0.8 * total_series) +# expected_val = int(0.1 * total_series) +# +# assert len(dm._train_indices) == expected_train +# assert len(dm._val_indices) == expected_val +# assert len(dm._test_indices) == total_series - expected_train - expected_val + + +def test_multivariate_target(): + """Test with multivariate target (multiple target columns). + + Verifies correct handling of multivariate targets in data pipeline.""" + df = pd.DataFrame( + { + "group": np.repeat([0, 1], 50), + "time": np.tile(pd.date_range("2020-01-01", periods=50), 2), + "target1": np.random.normal(0, 1, 100), + "target2": np.random.normal(5, 2, 100), + "feature1": np.random.normal(0, 1, 100), + "feature2": np.random.normal(0, 1, 100), + } + ) + + ts = TimeSeries( + data=df, + time="time", + target=["target1", "target2"], + group=["group"], + num=["feature1", "feature2"], + ) + + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=ts, + max_encoder_length=10, + max_prediction_length=5, + batch_size=4, + ) + + dm.setup() + + x, y = dm.train_dataset[0] + assert y.shape[-1] == 2 From cdecb770a63269c965261cee3a54744449b445a4 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 19 Apr 2025 19:44:16 +0530 Subject: [PATCH 31/80] Code quality --- pytorch_forecasting/data/timeseries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 6b2662e95..fda08d561 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -9,7 +9,7 @@ from copy import copy as _copy, deepcopy from functools import lru_cache import inspect -from typing import Any, Callable, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union import warnings import numpy as np From 20aafb749cfebdb1f9789b4dff5120fa8527db74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 30 Apr 2025 18:40:01 +0200 Subject: [PATCH 32/80] refactor file --- pytorch_forecasting/data/__init__.py | 3 +- .../data/timeseries/__init__.py | 9 + .../data/timeseries/_coerce.py | 25 ++ .../_timeseries.py} | 286 +----------------- 
.../data/timeseries/_timeseries_v2.py | 276 +++++++++++++++++ 5 files changed, 314 insertions(+), 285 deletions(-) create mode 100644 pytorch_forecasting/data/timeseries/__init__.py create mode 100644 pytorch_forecasting/data/timeseries/_coerce.py rename pytorch_forecasting/data/{timeseries.py => timeseries/_timeseries.py} (90%) create mode 100644 pytorch_forecasting/data/timeseries/_timeseries_v2.py diff --git a/pytorch_forecasting/data/__init__.py b/pytorch_forecasting/data/__init__.py index 301c8394d..17be285a0 100644 --- a/pytorch_forecasting/data/__init__.py +++ b/pytorch_forecasting/data/__init__.py @@ -13,10 +13,11 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler -from pytorch_forecasting.data.timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries import TimeSeries, TimeSeriesDataSet __all__ = [ "TimeSeriesDataSet", + "TimeSeries", "NaNLabelEncoder", "GroupNormalizer", "TorchNormalizer", diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py new file mode 100644 index 000000000..7734cccf2 --- /dev/null +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -0,0 +1,9 @@ +"""Data loaders for time series data.""" + +from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries +from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet + +__all__ = [ + "TimeSeriesDataSet", + "TimeSeries", +] diff --git a/pytorch_forecasting/data/timeseries/_coerce.py b/pytorch_forecasting/data/timeseries/_coerce.py new file mode 100644 index 000000000..328431aa8 --- /dev/null +++ b/pytorch_forecasting/data/timeseries/_coerce.py @@ -0,0 +1,25 @@ +"""Coercion functions for various data types.""" + +from copy import deepcopy + + +def _coerce_to_list(obj): + """Coerce object to list. + + None is coerced to empty list, otherwise list constructor is used. + """ + if obj is None: + return [] + if isinstance(obj, str): + return [obj] + return list(obj) + + +def _coerce_to_dict(obj): + """Coerce object to dict. + + None is coerce to empty dict, otherwise deepcopy is used. + """ + if obj is None: + return {} + return deepcopy(obj) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py similarity index 90% rename from pytorch_forecasting/data/timeseries.py rename to pytorch_forecasting/data/timeseries/_timeseries.py index fda08d561..263e0ea3a 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -9,7 +9,7 @@ from copy import copy as _copy, deepcopy from functools import lru_cache import inspect -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Optional, Type, TypeVar, Union import warnings import numpy as np @@ -31,6 +31,7 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler +from pytorch_forecasting.data.timeseries._coerce import _coerce_to_dict, _coerce_to_list from pytorch_forecasting.utils import repr_class from pytorch_forecasting.utils._dependencies import _check_matplotlib @@ -2663,286 +2664,3 @@ def __repr__(self) -> str: attributes=self.get_parameters(), extra_attributes=dict(length=len(self)), ) - - -def _coerce_to_list(obj): - """Coerce object to list. - - None is coerced to empty list, otherwise list constructor is used. 
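The behaviour of the two helpers moved into ``_coerce.py`` is summarised by a few checks (a hedged sketch; the bodies mirror the new module added above):

    from copy import deepcopy

    def _coerce_to_list(obj):
        # None -> [], str -> [str], otherwise list(...)
        if obj is None:
            return []
        if isinstance(obj, str):
            return [obj]
        return list(obj)

    def _coerce_to_dict(obj):
        # None -> {}, otherwise an independent deep copy
        if obj is None:
            return {}
        return deepcopy(obj)

    assert _coerce_to_list(None) == []
    assert _coerce_to_list("target") == ["target"]
    assert _coerce_to_list(("a", "b")) == ["a", "b"]

    params = {"lr": 1e-3}
    assert _coerce_to_dict(None) == {}
    assert _coerce_to_dict(params) == params
    assert _coerce_to_dict(params) is not params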
- """ - if obj is None: - return [] - if isinstance(obj, str): - return [obj] - return list(obj) - - -def _coerce_to_dict(obj): - """Coerce object to dict. - - None is coerce to empty dict, otherwise deepcopy is used. - """ - if obj is None: - return {} - return deepcopy(obj) - - -####################################################################################### -# Disclaimer: This dataset class is still work in progress and experimental, please -# use with care. This class is a basic skeleton of how the data-handling pipeline may -# look like in the future. -# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion -# and turning the data to tensors. -# For now, this pipeline handles the simplest situation: The whole data can be loaded -# into the memory. -####################################################################################### - - -class TimeSeries(Dataset): - """PyTorch Dataset for time series data stored in pandas DataFrame. - - Parameters - ---------- - data : pd.DataFrame - data frame with sequence data. - Column names must all be str, and contain str as referred to below. - data_future : pd.DataFrame, optional, default=None - data frame with future data. - Column names must all be str, and contain str as referred to below. - May contain only columns that are in time, group, weight, known, or static. - time : str, optional, default = first col not in group_ids, weight, target, static. - integer typed column denoting the time index within ``data``. - This column is used to determine the sequence of samples. - If there are no missing observations, - the time index should increase by ``+1`` for each subsequent sample. - The first time_idx for each series does not necessarily - have to be ``0`` but any value is allowed. - target : str or List[str], optional, default = last column (at iloc -1) - column(s) in ``data`` denoting the forecasting target. - Can be categorical or numerical dtype. - group : List[str], optional, default = None - list of column names identifying a time series instance within ``data``. - This means that the ``group`` together uniquely identify an instance, - and ``group`` together with ``time`` uniquely identify a single observation - within a time series instance. - If ``None``, the dataset is assumed to be a single time series. - weight : str, optional, default=None - column name for weights. - If ``None``, it is assumed that there is no weight column. - num : list of str, optional, default = all columns with dtype in "fi" - list of numerical variables in ``data``, - list may also contain list of str, which are then grouped together. - cat : list of str, optional, default = all columns with dtype in "Obc" - list of categorical variables in ``data``, - list may also contain list of str, which are then grouped together - (e.g. useful for product categories). - known : list of str, optional, default = all variables - list of variables that change over time and are known in the future, - list may also contain list of str, which are then grouped together - (e.g. useful for special days or promotion categories). - unknown : list of str, optional, default = no variables - list of variables that are not known in the future, - list may also contain list of str, which are then grouped together - (e.g. useful for weather categories). - static : list of str, optional, default = all variables not in known, unknown - list of variables that do not change over time, - list may also contain list of str, which are then grouped together. 
- """ - - def __init__( - self, - data: pd.DataFrame, - data_future: Optional[pd.DataFrame] = None, - time: Optional[str] = None, - target: Optional[Union[str, List[str]]] = None, - group: Optional[List[str]] = None, - weight: Optional[str] = None, - num: Optional[List[Union[str, List[str]]]] = None, - cat: Optional[List[Union[str, List[str]]]] = None, - known: Optional[List[Union[str, List[str]]]] = None, - unknown: Optional[List[Union[str, List[str]]]] = None, - static: Optional[List[Union[str, List[str]]]] = None, - ): - - self.data = data - self.data_future = data_future - self.time = time - self.target = _coerce_to_list(target) - self.group = _coerce_to_list(group) - self.weight = weight - self.num = _coerce_to_list(num) - self.cat = _coerce_to_list(cat) - self.known = _coerce_to_list(known) - self.unknown = _coerce_to_list(unknown) - self.static = _coerce_to_list(static) - - self.feature_cols = [ - col - for col in data.columns - if col not in [self.time] + self.group + [self.weight] + self.target - ] - if self.group: - self._groups = self.data.groupby(self.group).groups - self._group_ids = list(self._groups.keys()) - else: - self._groups = {"_single_group": self.data.index} - self._group_ids = ["_single_group"] - - self._prepare_metadata() - - def _prepare_metadata(self): - """Prepare metadata for the dataset. - - The funcion returns metadata that contains: - - * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } - Names of columns for y, x, and static features. - List elements are in same order as column dimensions. - Columns not appearing are assumed to be named (x0, x1, etc.), - (y0, y1, etc.), (st0, st1, etc.). - * ``col_type``: dict[str, str] - maps column names to data types "F" (numerical) and "C" (categorical). - Column names not occurring are assumed "F". - * ``col_known``: dict[str, str] - maps column names to "K" (future known) or "U" (future unknown). - Column names not occurring are assumed "K". - """ - self.metadata = { - "cols": { - "y": self.target, - "x": self.feature_cols, - "st": self.static, - }, - "col_type": {}, - "col_known": {}, - } - - all_cols = self.target + self.feature_cols + self.static - for col in all_cols: - self.metadata["col_type"][col] = "C" if col in self.cat else "F" - - self.metadata["col_known"][col] = "K" if col in self.known else "U" - - def __len__(self) -> int: - """Return number of time series in the dataset.""" - return len(self._group_ids) - - def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: - """Get time series data for given index. - - Returns - ------- - t : numpy.ndarray of shape (n_timepoints,) - Time index for each time point in the past or present. Aligned with `y`, - and `x` not ending in `f`. - - y : torch.Tensor of shape (n_timepoints, n_targets) - Target values for each time point. Rows are time points, aligned with `t`. - - x : torch.Tensor of shape (n_timepoints, n_features) - Features for each time point. Rows are time points, aligned with `t`. - - group : torch.Tensor of shape (n_groups,) - Group identifiers for time series instances. - - st : torch.Tensor of shape (n_static_features,) - Static features. - - cutoff_time : float or numpy.float64 - Cutoff time for the time series instance. - - Other Returns - ------------- - weights : torch.Tensor of shape (n_timepoints,), optional - Only included if weights are not `None`. 
- """ - group_id = self._group_ids[index] - - if self.group: - mask = self._groups[group_id] - data = self.data.loc[mask] - else: - data = self.data - - cutoff_time = data[self.time].max() - - result = { - "t": data[self.time].values, - "y": torch.tensor(data[self.target].values), - "x": torch.tensor(data[self.feature_cols].values), - "group": torch.tensor([hash(str(group_id))]), - "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), - "cutoff_time": cutoff_time, - } - - if self.data_future is not None: - if self.group: - future_mask = self.data_future.groupby(self.group).groups[group_id] - future_data = self.data_future.loc[future_mask] - else: - future_data = self.data_future - - combined_times = np.concatenate( - [data[self.time].values, future_data[self.time].values] - ) - combined_times = np.unique(combined_times) - combined_times.sort() - - num_timepoints = len(combined_times) - x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) - y_merged = np.full((num_timepoints, len(self.target)), np.nan) - - current_time_indices = {t: i for i, t in enumerate(combined_times)} - for i, t in enumerate(data[self.time].values): - idx = current_time_indices[t] - x_merged[idx] = data[self.feature_cols].values[i] - y_merged[idx] = data[self.target].values[i] - - for i, t in enumerate(future_data[self.time].values): - if t in current_time_indices: - idx = current_time_indices[t] - for j, col in enumerate(self.known): - if col in self.feature_cols: - feature_idx = self.feature_cols.index(col) - x_merged[idx, feature_idx] = future_data[col].values[i] - - result.update( - { - "t": combined_times, - "x": torch.tensor(x_merged, dtype=torch.float32), - "y": torch.tensor(y_merged, dtype=torch.float32), - } - ) - - if self.weight: - if self.data_future is not None and self.weight in self.data_future.columns: - weights_merged = np.full(num_timepoints, np.nan) - for i, t in enumerate(data[self.time].values): - idx = current_time_indices[t] - weights_merged[idx] = data[self.weight].values[i] - - for i, t in enumerate(future_data[self.time].values): - if t in current_time_indices and self.weight in future_data.columns: - idx = current_time_indices[t] - weights_merged[idx] = future_data[self.weight].values[i] - - result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) - else: - result["weights"] = torch.tensor( - data[self.weight].values, dtype=torch.float32 - ) - - return result - - def get_metadata(self) -> Dict: - """Return metadata about the dataset. - - Returns - ------- - Dict - Dictionary containing: - - cols: column names for y, x, and static features - - col_type: mapping of columns to their types (F/C) - - col_known: mapping of columns to their future known status (K/U) - """ - return self.metadata diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py new file mode 100644 index 000000000..53bf7228d --- /dev/null +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -0,0 +1,276 @@ +""" +Timeseries dataset - v2 prototype. + +Beta version, experimental - use for testing but not in production. 
+""" + +from typing import Dict, List, Optional, Union +import warnings + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset + +from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list + + +####################################################################################### +# Disclaimer: This dataset class is still work in progress and experimental, please +# use with care. This class is a basic skeleton of how the data-handling pipeline may +# look like in the future. +# This is the D1 layer that is a "Raw Dataset Layer" mainly for raw data ingestion +# and turning the data to tensors. +# For now, this pipeline handles the simplest situation: The whole data can be loaded +# into the memory. +####################################################################################### + + +class TimeSeries(Dataset): + """PyTorch Dataset for time series data stored in pandas DataFrame. + + Parameters + ---------- + data : pd.DataFrame + data frame with sequence data. + Column names must all be str, and contain str as referred to below. + data_future : pd.DataFrame, optional, default=None + data frame with future data. + Column names must all be str, and contain str as referred to below. + May contain only columns that are in time, group, weight, known, or static. + time : str, optional, default = first col not in group_ids, weight, target, static. + integer typed column denoting the time index within ``data``. + This column is used to determine the sequence of samples. + If there are no missing observations, + the time index should increase by ``+1`` for each subsequent sample. + The first time_idx for each series does not necessarily + have to be ``0`` but any value is allowed. + target : str or List[str], optional, default = last column (at iloc -1) + column(s) in ``data`` denoting the forecasting target. + Can be categorical or numerical dtype. + group : List[str], optional, default = None + list of column names identifying a time series instance within ``data``. + This means that the ``group`` together uniquely identify an instance, + and ``group`` together with ``time`` uniquely identify a single observation + within a time series instance. + If ``None``, the dataset is assumed to be a single time series. + weight : str, optional, default=None + column name for weights. + If ``None``, it is assumed that there is no weight column. + num : list of str, optional, default = all columns with dtype in "fi" + list of numerical variables in ``data``, + list may also contain list of str, which are then grouped together. + cat : list of str, optional, default = all columns with dtype in "Obc" + list of categorical variables in ``data``, + list may also contain list of str, which are then grouped together + (e.g. useful for product categories). + known : list of str, optional, default = all variables + list of variables that change over time and are known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for special days or promotion categories). + unknown : list of str, optional, default = no variables + list of variables that are not known in the future, + list may also contain list of str, which are then grouped together + (e.g. useful for weather categories). + static : list of str, optional, default = all variables not in known, unknown + list of variables that do not change over time, + list may also contain list of str, which are then grouped together. 
+ """ + + def __init__( + self, + data: pd.DataFrame, + data_future: Optional[pd.DataFrame] = None, + time: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group: Optional[List[str]] = None, + weight: Optional[str] = None, + num: Optional[List[Union[str, List[str]]]] = None, + cat: Optional[List[Union[str, List[str]]]] = None, + known: Optional[List[Union[str, List[str]]]] = None, + unknown: Optional[List[Union[str, List[str]]]] = None, + static: Optional[List[Union[str, List[str]]]] = None, + ): + + self.data = data + self.data_future = data_future + self.time = time + self.target = _coerce_to_list(target) + self.group = _coerce_to_list(group) + self.weight = weight + self.num = _coerce_to_list(num) + self.cat = _coerce_to_list(cat) + self.known = _coerce_to_list(known) + self.unknown = _coerce_to_list(unknown) + self.static = _coerce_to_list(static) + + self.feature_cols = [ + col + for col in data.columns + if col not in [self.time] + self.group + [self.weight] + self.target + ] + if self.group: + self._groups = self.data.groupby(self.group).groups + self._group_ids = list(self._groups.keys()) + else: + self._groups = {"_single_group": self.data.index} + self._group_ids = ["_single_group"] + + self._prepare_metadata() + + def _prepare_metadata(self): + """Prepare metadata for the dataset. + + The funcion returns metadata that contains: + + * ``cols``: dict { 'y': list[str], 'x': list[str], 'st': list[str] } + Names of columns for y, x, and static features. + List elements are in same order as column dimensions. + Columns not appearing are assumed to be named (x0, x1, etc.), + (y0, y1, etc.), (st0, st1, etc.). + * ``col_type``: dict[str, str] + maps column names to data types "F" (numerical) and "C" (categorical). + Column names not occurring are assumed "F". + * ``col_known``: dict[str, str] + maps column names to "K" (future known) or "U" (future unknown). + Column names not occurring are assumed "K". + """ + self.metadata = { + "cols": { + "y": self.target, + "x": self.feature_cols, + "st": self.static, + }, + "col_type": {}, + "col_known": {}, + } + + all_cols = self.target + self.feature_cols + self.static + for col in all_cols: + self.metadata["col_type"][col] = "C" if col in self.cat else "F" + + self.metadata["col_known"][col] = "K" if col in self.known else "U" + + def __len__(self) -> int: + """Return number of time series in the dataset.""" + return len(self._group_ids) + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get time series data for given index. + + Returns + ------- + t : numpy.ndarray of shape (n_timepoints,) + Time index for each time point in the past or present. Aligned with `y`, + and `x` not ending in `f`. + + y : torch.Tensor of shape (n_timepoints, n_targets) + Target values for each time point. Rows are time points, aligned with `t`. + + x : torch.Tensor of shape (n_timepoints, n_features) + Features for each time point. Rows are time points, aligned with `t`. + + group : torch.Tensor of shape (n_groups,) + Group identifiers for time series instances. + + st : torch.Tensor of shape (n_static_features,) + Static features. + + cutoff_time : float or numpy.float64 + Cutoff time for the time series instance. + + Other Returns + ------------- + weights : torch.Tensor of shape (n_timepoints,), optional + Only included if weights are not `None`. 
+ """ + group_id = self._group_ids[index] + + if self.group: + mask = self._groups[group_id] + data = self.data.loc[mask] + else: + data = self.data + + cutoff_time = data[self.time].max() + + result = { + "t": data[self.time].values, + "y": torch.tensor(data[self.target].values), + "x": torch.tensor(data[self.feature_cols].values), + "group": torch.tensor([hash(str(group_id))]), + "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "cutoff_time": cutoff_time, + } + + if self.data_future is not None: + if self.group: + future_mask = self.data_future.groupby(self.group).groups[group_id] + future_data = self.data_future.loc[future_mask] + else: + future_data = self.data_future + + combined_times = np.concatenate( + [data[self.time].values, future_data[self.time].values] + ) + combined_times = np.unique(combined_times) + combined_times.sort() + + num_timepoints = len(combined_times) + x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(self.target)), np.nan) + + current_time_indices = {t: i for i, t in enumerate(combined_times)} + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + x_merged[idx] = data[self.feature_cols].values[i] + y_merged[idx] = data[self.target].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices: + idx = current_time_indices[t] + for j, col in enumerate(self.known): + if col in self.feature_cols: + feature_idx = self.feature_cols.index(col) + x_merged[idx, feature_idx] = future_data[col].values[i] + + result.update( + { + "t": combined_times, + "x": torch.tensor(x_merged, dtype=torch.float32), + "y": torch.tensor(y_merged, dtype=torch.float32), + } + ) + + if self.weight: + if self.data_future is not None and self.weight in self.data_future.columns: + weights_merged = np.full(num_timepoints, np.nan) + for i, t in enumerate(data[self.time].values): + idx = current_time_indices[t] + weights_merged[idx] = data[self.weight].values[i] + + for i, t in enumerate(future_data[self.time].values): + if t in current_time_indices and self.weight in future_data.columns: + idx = current_time_indices[t] + weights_merged[idx] = future_data[self.weight].values[i] + + result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) + else: + result["weights"] = torch.tensor( + data[self.weight].values, dtype=torch.float32 + ) + + return result + + def get_metadata(self) -> Dict: + """Return metadata about the dataset. 
+ + Returns + ------- + Dict + Dictionary containing: + - cols: column names for y, x, and static features + - col_type: mapping of columns to their types (F/C) + - col_known: mapping of columns to their future known status (K/U) + """ + return self.metadata From 043820dd3be3041a019fd9cd2cb1e681d25a79a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Wed, 30 Apr 2025 18:43:50 +0200 Subject: [PATCH 33/80] warning --- .../data/timeseries/_timeseries_v2.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 53bf7228d..1c91d2525 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -104,6 +104,18 @@ def __init__( self.unknown = _coerce_to_list(unknown) self.static = _coerce_to_list(static) + warnings.warn( + "TimeSeries is part of an experimental rework of the " + "pytorch-forecasting data layer, " + "scheduled for release with v2.0.0. " + "The API is not stable and may change without prior warning. " + "For beta testing, but not for stable production use. " + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) + self.feature_cols = [ col for col in data.columns From 1720a15e9cff3e5c3ebcd0bf3ec03995d068e4b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 13:58:09 +0200 Subject: [PATCH 34/80] linting --- pytorch_forecasting/data/timeseries/__init__.py | 2 +- pytorch_forecasting/data/timeseries/_timeseries_v2.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py index 7734cccf2..85973267a 100644 --- a/pytorch_forecasting/data/timeseries/__init__.py +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -1,7 +1,7 @@ """Data loaders for time series data.""" -from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries __all__ = [ "TimeSeriesDataSet", diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 1c91d2525..76972ab4d 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -14,7 +14,6 @@ from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list - ####################################################################################### # Disclaimer: This dataset class is still work in progress and experimental, please # use with care. 
This class is a basic skeleton of how the data-handling pipeline may From af44474d16b3fcdf5e99acb4b9d1f7345119d8cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 14:21:58 +0200 Subject: [PATCH 35/80] move coercion to utils --- pytorch_forecasting/data/data_module.py | 6 ++---- pytorch_forecasting/{data/timeseries => utils}/_coerce.py | 0 2 files changed, 2 insertions(+), 4 deletions(-) rename pytorch_forecasting/{data/timeseries => utils}/_coerce.py (100%) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 1203e83ac..9d3ebbedb 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -19,10 +19,8 @@ NaNLabelEncoder, TorchNormalizer, ) -from pytorch_forecasting.data.timeseries import ( - TimeSeries, - _coerce_to_dict, -) +from pytorch_forecasting.data.timeseries import TimeSeries +from pytorch_forecasting.utils._coerce import _coerce_to_dict NORMALIZER = Union[TorchNormalizer, NaNLabelEncoder, EncoderNormalizer] diff --git a/pytorch_forecasting/data/timeseries/_coerce.py b/pytorch_forecasting/utils/_coerce.py similarity index 100% rename from pytorch_forecasting/data/timeseries/_coerce.py rename to pytorch_forecasting/utils/_coerce.py From a3cb8b736b0b134c8faa97f5ef2993deb28fb75b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 14:22:18 +0200 Subject: [PATCH 36/80] linting --- pytorch_forecasting/data/timeseries/_timeseries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries/_timeseries.py b/pytorch_forecasting/data/timeseries/_timeseries.py index 263e0ea3a..30fe9e0bb 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries.py +++ b/pytorch_forecasting/data/timeseries/_timeseries.py @@ -31,8 +31,8 @@ TorchNormalizer, ) from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler -from pytorch_forecasting.data.timeseries._coerce import _coerce_to_dict, _coerce_to_list from pytorch_forecasting.utils import repr_class +from pytorch_forecasting.utils._coerce import _coerce_to_dict, _coerce_to_list from pytorch_forecasting.utils._dependencies import _check_matplotlib From 75d7fb54d8405ef493197c5a4d2fc86a5e9e9d5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 14:25:51 +0200 Subject: [PATCH 37/80] Update _timeseries_v2.py --- pytorch_forecasting/data/timeseries/_timeseries_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 76972ab4d..afa45725b 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -12,7 +12,7 @@ import torch from torch.utils.data import Dataset -from pytorch_forecasting.data.timeseries._coerce import _coerce_to_list +from pytorch_forecasting.utils._coerce import _coerce_to_list ####################################################################################### # Disclaimer: This dataset class is still work in progress and experimental, please From 1b946e699be9db2e201a2361779a695356a0460b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 14:30:13 +0200 Subject: [PATCH 38/80] Update __init__.py --- pytorch_forecasting/data/timeseries/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries/__init__.py 
b/pytorch_forecasting/data/timeseries/__init__.py index 85973267a..b359a0aa9 100644 --- a/pytorch_forecasting/data/timeseries/__init__.py +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -1,9 +1,15 @@ """Data loaders for time series data.""" -from pytorch_forecasting.data.timeseries._timeseries import TimeSeriesDataSet +from pytorch_forecasting.data.timeseries._timeseries import ( + _find_end_indices, + check_for_nonfinite, + TimeSeriesDataSet, +) from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries __all__ = [ + "_find_end_indices", + "check_for_nonfinite", "TimeSeriesDataSet", "TimeSeries", ] From 3edb08b7ea1b97d06b47b0ebcc83aaef9bec8083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Thu, 1 May 2025 14:33:17 +0200 Subject: [PATCH 39/80] Update __init__.py --- pytorch_forecasting/data/timeseries/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries/__init__.py b/pytorch_forecasting/data/timeseries/__init__.py index b359a0aa9..788c08201 100644 --- a/pytorch_forecasting/data/timeseries/__init__.py +++ b/pytorch_forecasting/data/timeseries/__init__.py @@ -1,9 +1,9 @@ """Data loaders for time series data.""" from pytorch_forecasting.data.timeseries._timeseries import ( + TimeSeriesDataSet, _find_end_indices, check_for_nonfinite, - TimeSeriesDataSet, ) from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries From a6691341001f813ac4c5d12aafb173645087d679 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 4 May 2025 18:28:04 +0200 Subject: [PATCH 40/80] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 49 ++++++++++++++++-------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index 828c448b1..b4238f980 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -11,6 +11,7 @@ __author__ = ["fkiraly"] # all_objects is based on the sklearn utility all_estimators +from inspect import isclass from pathlib import Path from skbase.lookup import all_objects as _all_objects @@ -133,25 +134,39 @@ def all_objects( result = [] ROOT = str(Path(__file__).parent.parent) # package root directory - if isinstance(filter_tags, str): - filter_tags = {filter_tags: True} - filter_tags = filter_tags.copy() if filter_tags else None - - if object_types: - if filter_tags and "object_type" not in filter_tags.keys(): - object_tag_filter = {"object_type": object_types} - elif filter_tags: - filter_tags_filter = filter_tags.get("object_type", []) - if isinstance(object_types, str): - object_types = [object_types] - object_tag_update = {"object_type": object_types + filter_tags_filter} - filter_tags.update(object_tag_update) + def _coerce_to_str(obj): + if isinstance(obj, (list, tuple)): + return [_coerce_to_str(o) for o in obj] + if isclass(obj): + obj = obj.get_tag("object_type") + return obj + + def _coerce_to_list_of_str(obj): + obj = _coerce_to_str(obj) + if isinstance(obj, str): + return [obj] + return obj + + if object_types is not None: + object_types = _coerce_to_list_of_str(object_types) + object_types = list(set(object_types)) + + if object_types is not None: + if filter_tags is None: + filter_tags = {} + elif isinstance(filter_tags, str): + filter_tags = {filter_tags: True} else: - object_tag_filter = {"object_type": object_types} - if filter_tags: - filter_tags.update(object_tag_filter) + filter_tags = 
filter_tags.copy() + + if "object_type" in filter_tags: + obj_field = filter_tags["object_type"] + obj_field = _coerce_to_list_of_str(obj_field) + obj_field = obj_field + object_types else: - filter_tags = object_tag_filter + obj_field = object_types + + filter_tags["object_type"] = obj_field result = _all_objects( object_types=[_BaseObject], From d78bf5dc19cef1e659e8258552691c1713b2dd4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 4 May 2025 18:32:38 +0200 Subject: [PATCH 41/80] Update _lookup.py --- pytorch_forecasting/_registry/_lookup.py | 95 +++++++++++++++--------- 1 file changed, 60 insertions(+), 35 deletions(-) diff --git a/pytorch_forecasting/_registry/_lookup.py b/pytorch_forecasting/_registry/_lookup.py index b4238f980..0fb4c0c9d 100644 --- a/pytorch_forecasting/_registry/_lookup.py +++ b/pytorch_forecasting/_registry/_lookup.py @@ -40,44 +40,64 @@ def all_objects( ---------- object_types: str, list of str, optional (default=None) Which kind of objects should be returned. - if None, no filter is applied and all objects are returned. - if str or list of str, strings define scitypes specified in search - only objects that are of (at least) one of the scitypes are returned - possible str values are entries of registry.BASE_CLASS_REGISTER (first col) - for instance 'regrssor_proba', 'distribution, 'metric' - return_names: bool, optional (default=True) + * if None, no filter is applied and all objects are returned. + * if str or list of str, strings define scitypes specified in search + only objects that are of (at least) one of the scitypes are returned - if True, estimator class name is included in the ``all_objects`` - return in the order: name, estimator class, optional tags, either as - a tuple or as pandas.DataFrame columns + return_names: bool, optional (default=True) - if False, estimator class name is removed from the ``all_objects`` return. + * if True, estimator class name is included in the ``all_objects`` + return in the order: name, estimator class, optional tags, either as + a tuple or as pandas.DataFrame columns + * if False, estimator class name is removed from the ``all_objects`` return. - filter_tags: dict of (str or list of str), optional (default=None) + filter_tags: dict of (str or list of str or re.Pattern), optional (default=None) For a list of valid tag strings, use the registry.all_tags utility. - ``filter_tags`` subsets the returned estimators as follows: + ``filter_tags`` subsets the returned objects as follows: * each key/value pair is statement in "and"/conjunction * key is tag name to sub-set on * value str or list of string are tag values * condition is "key must be equal to value, or in set(value)" - exclude_estimators: str, list of str, optional (default=None) - Names of estimators to exclude. + In detail, he return will be filtered to keep exactly the classes + where tags satisfy all the filter conditions specified by ``filter_tags``. + Filter conditions are as follows, for ``tag_name: search_value`` pairs in + the ``filter_tags`` dict, applied to a class ``klass``: + + - If ``klass`` does not have a tag with name ``tag_name``, it is excluded. + Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``. + - If ``search_value`` is a string, and ``tag_value`` is a string, + the filter condition is that ``search_value`` must match the tag value. + - If ``search_value`` is a string, and ``tag_value`` is a list, + the filter condition is that ``search_value`` is contained in ``tag_value``. 
+ - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string, + the filter condition is that ``search_value.fullmatch(tag_value)`` + is true, i.e., the regex matches the tag value. + - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list, + the filter condition is that at least one element of ``tag_value`` + matches the regex. + - If ``search_value`` is iterable, then the filter condition is that + at least one element of ``search_value`` satisfies the above conditions, + applied to ``tag_value``. + + Note: ``re.Pattern`` is supported only from ``scikit-base`` version 0.8.0. + + exclude_objects: str, list of str, optional (default=None) + Names of objects to exclude. as_dataframe: bool, optional (default=False) - True: ``all_objects`` will return a pandas.DataFrame with named - columns for all of the attributes being returned. - - False: ``all_objects`` will return a list (either a list of - estimators or a list of tuples, see Returns) + * True: ``all_objects`` will return a ``pandas.DataFrame`` with named + columns for all of the attributes being returned. + * False: ``all_objects`` will return a list (either a list of + objects or a list of tuples, see Returns) return_tags: str or list of str, optional (default=None) Names of tags to fetch and return each estimator's value of. - For a list of valid tag strings, use the registry.all_tags utility. + For a list of valid tag strings, use the ``registry.all_tags`` utility. if str or list of str, the tag values named in return_tags will be fetched for each estimator and will be appended as either columns or tuple entries. @@ -88,27 +108,32 @@ def all_objects( Returns ------- all_objects will return one of the following: - 1. list of objects, if return_names=False, and return_tags is None - 2. list of tuples (optional object name, class, ~optional object - tags), if return_names=True or return_tags is not None. - 3. pandas.DataFrame if as_dataframe = True + + 1. list of objects, if ``return_names=False``, and ``return_tags`` is None + + 2. list of tuples (optional estimator name, class, optional estimator + tags), if ``return_names=True`` or ``return_tags`` is not ``None``. + + 3. ``pandas.DataFrame`` if ``as_dataframe = True`` + if list of objects: entries are objects matching the query, - in alphabetical order of object name + in alphabetical order of estimator name + if list of tuples: - list of (optional object name, object, optional object - tags) matching the query, in alphabetical order of object name, + list of (optional estimator name, estimator, optional estimator + tags) matching the query, in alphabetical order of estimator name, where - ``name`` is the object name as string, and is an - optional return - ``object`` is the actual object - ``tags`` are the object's values for each tag in return_tags - and is an optional return. - if dataframe: - all_objects will return a pandas.DataFrame. + ``name`` is the estimator name as string, and is an + optional return + ``estimator`` is the actual estimator + ``tags`` are the estimator's values for each tag in return_tags + and is an optional return. + + if ``DataFrame``: column names represent the attributes contained in each column. "objects" will be the name of the column of objects, "names" - will be the name of the column of object class names and the string(s) + will be the name of the column of estimator class names and the string(s) passed in return_tags will serve as column names for all columns of tags that were optionally requested. 
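For concreteness, the lookup could be used roughly as below once objects are registered; the tag name and values passed to ``filter_tags`` are placeholders for illustration, not tags defined in this patch series:

    from pytorch_forecasting._registry import all_objects

    # all registered objects as (name, class) pairs
    objs = all_objects(return_names=True)

    # filtered query returned as a DataFrame, with the filter tag also returned as a column
    df = all_objects(
        filter_tags={"some_tag": ["value_a", "value_b"]},  # placeholder tag/values
        as_dataframe=True,
        return_tags="some_tag",
    )

Each key/value pair in ``filter_tags`` acts as a conjunctive condition, so the query keeps only classes whose tag values match every entry.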
From e350291c110f567e69946e0e113f2471b7472738 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 11 May 2025 22:10:01 +0530 Subject: [PATCH 42/80] update tests --- tests/test_data/test_data_module.py | 72 ++++++++++++++--------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index c14e3d8f4..4051b852c 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -9,7 +9,7 @@ @pytest.fixture def sample_timeseries_data(): """Create a sample time series dataset with only numerical values.""" - num_groups = 5 + num_groups = 10 seq_length = 100 groups = [] @@ -128,22 +128,22 @@ def test_metadata_property(data_module): assert metadata["decoder_cont"] == 1 # Only known_future marked as known -# def test_setup(data_module): -# """Test the setup method that prepares the datasets.""" -# data_module.setup(stage="fit") -# print(data_module._val_indices) -# assert hasattr(data_module, "train_dataset") -# assert hasattr(data_module, "val_dataset") -# assert len(data_module.train_windows) > 0 -# assert len(data_module.val_windows) > 0 -# -# data_module.setup(stage="test") -# assert hasattr(data_module, "test_dataset") -# assert len(data_module.test_windows) > 0 -# -# data_module.setup(stage="predict") -# assert hasattr(data_module, "predict_dataset") -# assert len(data_module.predict_windows) > 0 +def test_setup(data_module): + """Test the setup method that prepares the datasets.""" + data_module.setup(stage="fit") + print(data_module._val_indices) + assert hasattr(data_module, "train_dataset") + assert hasattr(data_module, "val_dataset") + assert len(data_module.train_windows) > 0 + assert len(data_module.val_windows) > 0 + + data_module.setup(stage="test") + assert hasattr(data_module, "test_dataset") + assert len(data_module.test_windows) > 0 + + data_module.setup(stage="predict") + assert hasattr(data_module, "predict_dataset") + assert len(data_module.predict_windows) > 0 def test_create_windows(data_module): @@ -407,25 +407,25 @@ def test_with_static_features(): assert "static_continuous_features" in x -# def test_different_train_val_test_split(sample_timeseries_data): -# """Test with different train/val/test split ratios.""" -# dm = EncoderDecoderTimeSeriesDataModule( -# time_series_dataset=sample_timeseries_data, -# max_encoder_length=24, -# max_prediction_length=12, -# batch_size=4, -# train_val_test_split=(0.8, 0.1, 0.1), -# ) -# -# dm.setup() -# -# total_series = len(sample_timeseries_data) -# expected_train = int(0.8 * total_series) -# expected_val = int(0.1 * total_series) -# -# assert len(dm._train_indices) == expected_train -# assert len(dm._val_indices) == expected_val -# assert len(dm._test_indices) == total_series - expected_train - expected_val +def test_different_train_val_test_split(sample_timeseries_data): + """Test with different train/val/test split ratios.""" + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=sample_timeseries_data, + max_encoder_length=24, + max_prediction_length=12, + batch_size=4, + train_val_test_split=(0.8, 0.1, 0.1), + ) + + dm.setup() + + total_series = len(sample_timeseries_data) + expected_train = int(0.8 * total_series) + expected_val = int(0.1 * total_series) + + assert len(dm._train_indices) == expected_train + assert len(dm._val_indices) == expected_val + assert len(dm._test_indices) == total_series - expected_train - expected_val def test_multivariate_target(): From 3099691d3cc792bd528f50ff3c51a0fa4a9ce28a 
Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 12 May 2025 00:22:27 +0530 Subject: [PATCH 43/80] update tft_v2 --- .../tft_version_two.py | 65 +++++++++++-------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py index 30f70f98e..2bfe407d7 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py @@ -36,6 +36,8 @@ def __init__( lr_scheduler=lr_scheduler, lr_scheduler_params=lr_scheduler_params, ) + self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"]) + self.hidden_size = hidden_size self.num_layers = num_layers self.attention_head_size = attention_head_size @@ -47,42 +49,51 @@ def __init__( self.max_prediction_length = self.metadata["max_prediction_length"] self.encoder_cont = self.metadata["encoder_cont"] self.encoder_cat = self.metadata["encoder_cat"] - self.static_categorical_features = self.metadata["static_categorical_features"] - self.static_continuous_features = self.metadata["static_continuous_features"] - - total_feature_size = self.encoder_cont + self.encoder_cat - total_static_size = ( - self.static_categorical_features + self.static_continuous_features - ) - - self.encoder_var_selection = nn.Sequential( - nn.Linear(total_feature_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, total_feature_size), - nn.Sigmoid(), - ) - - self.decoder_var_selection = nn.Sequential( - nn.Linear(total_feature_size, hidden_size), - nn.ReLU(), - nn.Linear(hidden_size, total_feature_size), - nn.Sigmoid(), - ) + self.encoder_input_dim = self.encoder_cont + self.encoder_cat + self.decoder_cont = self.metadata["decoder_cont"] + self.decoder_cat = self.metadata["decoder_cat"] + self.decoder_input_dim = self.decoder_cont + self.decoder_cat + self.static_cat_dim = self.metadata.get("static_categorical_features", 0) + self.static_cont_dim = self.metadata.get("static_continuous_features", 0) + self.static_input_dim = self.static_cat_dim + self.static_cont_dim + + if self.encoder_input_dim > 0: + self.encoder_var_selection = nn.Sequential( + nn.Linear(self.encoder_input_dim, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, self.encoder_input_dim), + nn.Sigmoid(), + ) + else: + self.encoder_var_selection = None + + if self.decoder_input_dim > 0: + self.decoder_var_selection = nn.Sequential( + nn.Linear(self.decoder_input_dim, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, self.decoder_input_dim), + nn.Sigmoid(), + ) + else: + self.decoder_var_selection = None - self.static_context_linear = ( - nn.Linear(total_static_size, hidden_size) if total_static_size > 0 else None - ) + if self.static_input_dim > 0: + self.static_context_linear = nn.Linear(self.static_input_dim, hidden_size) + else: + self.static_context_linear = None + _lstm_encoder_input_actual_dim = self.encoder_input_dim self.lstm_encoder = nn.LSTM( - input_size=total_feature_size, + input_size=max(1, _lstm_encoder_input_actual_dim), hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, batch_first=True, ) + _lstm_decoder_input_actual_dim = self.decoder_input_dim self.lstm_decoder = nn.LSTM( - input_size=total_feature_size, + input_size=max(1, _lstm_decoder_input_actual_dim), hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, @@ -97,7 +108,7 @@ def __init__( ) self.pre_output = nn.Linear(hidden_size, hidden_size) - 
self.output_layer = nn.Linear(hidden_size, output_size) + self.output_layer = nn.Linear(hidden_size, self.output_size) def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ From 77cb979808d83cbcfb4e7c3ed5ffd888c0828d31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 13 May 2025 08:14:03 +0200 Subject: [PATCH 44/80] warnings and init attr handling --- pytorch_forecasting/data/data_module.py | 44 +++++++++---- .../data/timeseries/_timeseries_v2.py | 61 +++++++++++-------- 2 files changed, 67 insertions(+), 38 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 9d3ebbedb..690fb6057 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -8,6 +8,7 @@ ####################################################################################### from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn from lightning.pytorch import LightningDataModule from sklearn.preprocessing import RobustScaler, StandardScaler @@ -107,33 +108,50 @@ def __init__( num_workers: int = 0, train_val_test_split: tuple = (0.7, 0.15, 0.15), ): - super().__init__() - self.time_series_dataset = time_series_dataset - self.time_series_metadata = time_series_dataset.get_metadata() + self.time_series_dataset = time_series_dataset self.max_encoder_length = max_encoder_length - self.min_encoder_length = min_encoder_length or max_encoder_length + self.min_encoder_length = min_encoder_length self.max_prediction_length = max_prediction_length - self.min_prediction_length = min_prediction_length or max_prediction_length + self.min_prediction_length = min_prediction_length self.min_prediction_idx = min_prediction_idx - self.allow_missing_timesteps = allow_missing_timesteps self.add_relative_time_idx = add_relative_time_idx self.add_target_scales = add_target_scales self.add_encoder_length = add_encoder_length self.randomize_length = randomize_length - + self.target_normalizer = target_normalizer + self.categorical_encoders = categorical_encoders + self.scalers = scalers self.batch_size = batch_size self.num_workers = num_workers self.train_val_test_split = train_val_test_split + warn( + "TimeSeries is part of an experimental rework of the " + "pytorch-forecasting data layer, " + "scheduled for release with v2.0.0. " + "The API is not stable and may change without prior warning. " + "For beta testing, but not for stable production use. 
" + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) + + super().__init__() + + # handle defaults and derived attributes if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto": - self.target_normalizer = RobustScaler() + self._target_normalizer = RobustScaler() else: - self.target_normalizer = target_normalizer + self._target_normalizer = target_normalizer - self.categorical_encoders = _coerce_to_dict(categorical_encoders) - self.scalers = _coerce_to_dict(scalers) + self.time_series_metadata = time_series_dataset.get_metadata() + self._min_prediction_length = min_prediction_length or max_prediction_length + self._min_encoder_length = min_encoder_length or max_encoder_length + self._categorical_encoders = _coerce_to_dict(categorical_encoders) + self._scalers = _coerce_to_dict(scalers) self.categorical_indices = [] self.continuous_indices = [] @@ -237,8 +255,8 @@ def _prepare_metadata(self): { "max_encoder_length": self.max_encoder_length, "max_prediction_length": self.max_prediction_length, - "min_encoder_length": self.min_encoder_length, - "min_prediction_length": self.min_prediction_length, + "min_encoder_length": self._min_encoder_length, + "min_prediction_length": self._min_prediction_length, } ) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index afa45725b..1f0ba6820 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -5,7 +5,7 @@ """ from typing import Dict, List, Optional, Union -import warnings +from warnings import warn import numpy as np import pandas as pd @@ -94,16 +94,16 @@ def __init__( self.data = data self.data_future = data_future self.time = time - self.target = _coerce_to_list(target) - self.group = _coerce_to_list(group) + self.target = target + self.group = group self.weight = weight - self.num = _coerce_to_list(num) - self.cat = _coerce_to_list(cat) - self.known = _coerce_to_list(known) - self.unknown = _coerce_to_list(unknown) - self.static = _coerce_to_list(static) + self.num = num + self.cat = cat + self.known = known + self.unknown = unknown + self.static = static - warnings.warn( + warn( "TimeSeries is part of an experimental rework of the " "pytorch-forecasting data layer, " "scheduled for release with v2.0.0. 
" @@ -115,13 +115,24 @@ def __init__( UserWarning, ) + super.__init__() + + # handle defaults, coercion, and derived attributes + self._target = _coerce_to_list(target) + self._group = _coerce_to_list(group) + self._num = _coerce_to_list(num) + self._cat = _coerce_to_list(cat) + self._known = _coerce_to_list(known) + self._unknown = _coerce_to_list(unknown) + self._static = _coerce_to_list(static) + self.feature_cols = [ col for col in data.columns - if col not in [self.time] + self.group + [self.weight] + self.target + if col not in [self.time] + self._group + [self.weight] + self._target ] - if self.group: - self._groups = self.data.groupby(self.group).groups + if self._group: + self._groups = self.data.groupby(self._group).groups self._group_ids = list(self._groups.keys()) else: self._groups = {"_single_group": self.data.index} @@ -148,19 +159,19 @@ def _prepare_metadata(self): """ self.metadata = { "cols": { - "y": self.target, + "y": self._target, "x": self.feature_cols, - "st": self.static, + "st": self._static, }, "col_type": {}, "col_known": {}, } - all_cols = self.target + self.feature_cols + self.static + all_cols = self._target + self.feature_cols + self._static for col in all_cols: - self.metadata["col_type"][col] = "C" if col in self.cat else "F" + self.metadata["col_type"][col] = "C" if col in self._cat else "F" - self.metadata["col_known"][col] = "K" if col in self.known else "U" + self.metadata["col_known"][col] = "K" if col in self._known else "U" def __len__(self) -> int: """Return number of time series in the dataset.""" @@ -197,7 +208,7 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: """ group_id = self._group_ids[index] - if self.group: + if self._group: mask = self._groups[group_id] data = self.data.loc[mask] else: @@ -207,16 +218,16 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: result = { "t": data[self.time].values, - "y": torch.tensor(data[self.target].values), + "y": torch.tensor(data[self._target].values), "x": torch.tensor(data[self.feature_cols].values), "group": torch.tensor([hash(str(group_id))]), - "st": torch.tensor(data[self.static].iloc[0].values if self.static else []), + "st": torch.tensor(data[self._static].iloc[0].values if self._static else []), "cutoff_time": cutoff_time, } if self.data_future is not None: - if self.group: - future_mask = self.data_future.groupby(self.group).groups[group_id] + if self._group: + future_mask = self.data_future.groupby(self._group).groups[group_id] future_data = self.data_future.loc[future_mask] else: future_data = self.data_future @@ -229,18 +240,18 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: num_timepoints = len(combined_times) x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) - y_merged = np.full((num_timepoints, len(self.target)), np.nan) + y_merged = np.full((num_timepoints, len(self._target)), np.nan) current_time_indices = {t: i for i, t in enumerate(combined_times)} for i, t in enumerate(data[self.time].values): idx = current_time_indices[t] x_merged[idx] = data[self.feature_cols].values[i] - y_merged[idx] = data[self.target].values[i] + y_merged[idx] = data[self._target].values[i] for i, t in enumerate(future_data[self.time].values): if t in current_time_indices: idx = current_time_indices[t] - for j, col in enumerate(self.known): + for j, col in enumerate(self._known): if col in self.feature_cols: feature_idx = self.feature_cols.index(col) x_merged[idx, feature_idx] = future_data[col].values[i] From 
f8c94e626010d165cf022e0fd3f0a22c994759c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 13 May 2025 08:25:53 +0200 Subject: [PATCH 45/80] simplify TimeSeries.__getitem__ --- .../data/timeseries/_timeseries_v2.py | 73 +++++++++++-------- 1 file changed, 44 insertions(+), 29 deletions(-) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 1f0ba6820..5e24f6454 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -206,54 +206,69 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: weights : torch.Tensor of shape (n_timepoints,), optional Only included if weights are not `None`. """ - group_id = self._group_ids[index] - - if self._group: - mask = self._groups[group_id] + time = self.time + feature_cols = self.feature_cols + _target = self._target + _known = self._known + _static = self._static + _group = self._group + _groups = self._groups + _group_ids = self._group_ids + weight = self.weight + data_future = self.data_future + + group_id = _group_ids[index] + + if _group: + mask = _groups[group_id] data = self.data.loc[mask] else: data = self.data - cutoff_time = data[self.time].max() + cutoff_time = data[time].max() + + data_vals = data[time].values + data_tgt_vals = data[_target].values + data_feat_vals = data[feature_cols].values result = { - "t": data[self.time].values, - "y": torch.tensor(data[self._target].values), - "x": torch.tensor(data[self.feature_cols].values), + "t": data_vals, + "y": torch.tensor(data_tgt_vals), + "x": torch.tensor(data_feat_vals), "group": torch.tensor([hash(str(group_id))]), - "st": torch.tensor(data[self._static].iloc[0].values if self._static else []), + "st": torch.tensor(data[_static].iloc[0].values if _static else []), "cutoff_time": cutoff_time, } - if self.data_future is not None: - if self._group: - future_mask = self.data_future.groupby(self._group).groups[group_id] + if data_future is not None: + if _group: + future_mask = self.data_future.groupby(_group).groups[group_id] future_data = self.data_future.loc[future_mask] else: future_data = self.data_future - combined_times = np.concatenate( - [data[self.time].values, future_data[self.time].values] - ) + data_fut_vals = future_data[time].values + + combined_times = np.concatenate([data_vals, data_fut_vals]) combined_times = np.unique(combined_times) combined_times.sort() num_timepoints = len(combined_times) - x_merged = np.full((num_timepoints, len(self.feature_cols)), np.nan) - y_merged = np.full((num_timepoints, len(self._target)), np.nan) + x_merged = np.full((num_timepoints, len(feature_cols)), np.nan) + y_merged = np.full((num_timepoints, len(_target)), np.nan) current_time_indices = {t: i for i, t in enumerate(combined_times)} - for i, t in enumerate(data[self.time].values): + for i, t in enumerate(data_vals): idx = current_time_indices[t] - x_merged[idx] = data[self.feature_cols].values[i] - y_merged[idx] = data[self._target].values[i] + x_merged[idx] = data_feat_vals[i] + y_merged[idx] = data_tgt_vals[i] - for i, t in enumerate(future_data[self.time].values): + for i, t in enumerate(data_fut_vals): if t in current_time_indices: idx = current_time_indices[t] - for j, col in enumerate(self._known): - if col in self.feature_cols: - feature_idx = self.feature_cols.index(col) + for j, col in enumerate(_known): + if col in feature_cols: + feature_idx = feature_cols.index(col) x_merged[idx, feature_idx] = 
future_data[col].values[i] result.update( @@ -264,17 +279,17 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: } ) - if self.weight: + if weight: if self.data_future is not None and self.weight in self.data_future.columns: weights_merged = np.full(num_timepoints, np.nan) - for i, t in enumerate(data[self.time].values): + for i, t in enumerate(data_vals): idx = current_time_indices[t] - weights_merged[idx] = data[self.weight].values[i] + weights_merged[idx] = data[weight].values[i] - for i, t in enumerate(future_data[self.time].values): + for i, t in enumerate(data_fut_vals): if t in current_time_indices and self.weight in future_data.columns: idx = current_time_indices[t] - weights_merged[idx] = future_data[self.weight].values[i] + weights_merged[idx] = future_data[weight].values[i] result["weights"] = torch.tensor(weights_merged, dtype=torch.float32) else: From c289255286540b96ddcf5667851f06edf7af0c7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 13 May 2025 08:36:17 +0200 Subject: [PATCH 46/80] Update _timeseries_v2.py --- pytorch_forecasting/data/timeseries/_timeseries_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 5e24f6454..178b273bc 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -115,7 +115,7 @@ def __init__( UserWarning, ) - super.__init__() + super().__init__() # handle defaults, coercion, and derived attributes self._target = _coerce_to_list(target) From 9467f387287f3ba4a56ef1a1a4673c2215deb355 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 13 May 2025 08:44:38 +0200 Subject: [PATCH 47/80] Update data_module.py --- pytorch_forecasting/data/data_module.py | 65 ++++++++++++------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 690fb6057..7b0d45312 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -171,39 +171,38 @@ def _prepare_metadata(self): dict dictionary containing the following keys: - * ``encoder_cat``: Number of categorical variables in the encoder. - Computed as ``len(self.categorical_indices)``, which counts the - categorical feature indices. - * ``encoder_cont``: Number of continuous variables in the encoder. - Computed as ``len(self.continuous_indices)``, which counts the - continuous feature indices. - * ``decoder_cat``: Number of categorical variables in the decoder that - are known in advance. - Computed by filtering ``self.time_series_metadata["cols"]["x"]`` - where col_type == "C"(categorical) and col_known == "K" (known) - * ``decoder_cont``: Number of continuous variables in the decoder that - are known in advance. - Computed by filtering ``self.time_series_metadata["cols"]["x"]`` - where col_type == "F"(continuous) and col_known == "K"(known) - * ``target``: Number of target variables. - Computed as ``len(self.time_series_metadata["cols"]["y"])``, which - gives the number of output target columns.. - * ``static_categorical_features``: Number of static categorical features - Computed by filtering ``self.time_series_metadata["cols"]["st"]`` - (static features) where col_type == "C" (categorical). 
- * ``static_continuous_features``: Number of static continuous features - Computed as difference of - ``len(self.time_series_metadata["cols"]["st"])`` (static features) - and static_categorical_features that gives static continuous feature - * ``max_encoder_length``: maximum encoder length - Taken directly from `self.max_encoder_length`. - * ``max_prediction_length``: maximum prediction length - Taken directly from `self.max_prediction_length`. - * ``min_encoder_length``: minimum encoder length - Taken directly from `self.min_encoder_length`. - * ``min_prediction_length``: minimum prediction length - Taken directly from `self.min_prediction_length`. - + * ``encoder_cat``: Number of categorical variables in the encoder. + Computed as ``len(self.categorical_indices)``, which counts the + categorical feature indices. + * ``encoder_cont``: Number of continuous variables in the encoder. + Computed as ``len(self.continuous_indices)``, which counts the + continuous feature indices. + * ``decoder_cat``: Number of categorical variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "C"(categorical) and col_known == "K" (known) + * ``decoder_cont``: Number of continuous variables in the decoder that + are known in advance. + Computed by filtering ``self.time_series_metadata["cols"]["x"]`` + where col_type == "F"(continuous) and col_known == "K"(known) + * ``target``: Number of target variables. + Computed as ``len(self.time_series_metadata["cols"]["y"])``, which + gives the number of output target columns.. + * ``static_categorical_features``: Number of static categorical features + Computed by filtering ``self.time_series_metadata["cols"]["st"]`` + (static features) where col_type == "C" (categorical). + * ``static_continuous_features``: Number of static continuous features + Computed as difference of + ``len(self.time_series_metadata["cols"]["st"])`` (static features) + and static_categorical_features that gives static continuous feature + * ``max_encoder_length``: maximum encoder length + Taken directly from `self.max_encoder_length`. + * ``max_prediction_length``: maximum prediction length + Taken directly from `self.max_prediction_length`. + * ``min_encoder_length``: minimum encoder length + Taken directly from `self.min_encoder_length`. + * ``min_prediction_length``: minimum prediction length + Taken directly from `self.min_prediction_length`. """ encoder_cat_count = len(self.categorical_indices) encoder_cont_count = len(self.continuous_indices) From c3b40ad0f3298e84b70b12a050614da3909799e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Tue, 13 May 2025 08:50:43 +0200 Subject: [PATCH 48/80] backwards compat of private/public attrs --- pytorch_forecasting/data/data_module.py | 8 ++++++++ pytorch_forecasting/data/timeseries/_timeseries_v2.py | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/pytorch_forecasting/data/data_module.py b/pytorch_forecasting/data/data_module.py index 7b0d45312..c8252014d 100644 --- a/pytorch_forecasting/data/data_module.py +++ b/pytorch_forecasting/data/data_module.py @@ -163,6 +163,14 @@ def __init__( else: self.continuous_indices.append(idx) + # overwrite __init__ params for upwards compatibility with AS PRs + # todo: should we avoid this and ensure classes are dataclass-like? 
+ self.min_prediction_length = self._min_prediction_length + self.min_encoder_length = self._min_encoder_length + self.categorical_encoders = self._categorical_encoders + self.scalers = self._scalers + self.target_normalizer = self._target_normalizer + def _prepare_metadata(self): """Prepare metadata for model initialisation. diff --git a/pytorch_forecasting/data/timeseries/_timeseries_v2.py b/pytorch_forecasting/data/timeseries/_timeseries_v2.py index 178b273bc..d5ecbcabb 100644 --- a/pytorch_forecasting/data/timeseries/_timeseries_v2.py +++ b/pytorch_forecasting/data/timeseries/_timeseries_v2.py @@ -140,6 +140,16 @@ def __init__( self._prepare_metadata() + # overwrite __init__ params for upwards compatibility with AS PRs + # todo: should we avoid this and ensure classes are dataclass-like? + self.group = self._group + self.target = self._target + self.num = self._num + self.cat = self._cat + self.known = self._known + self.unknown = self._unknown + self.static = self._static + def _prepare_metadata(self): """Prepare metadata for the dataset. From 38c28dc031ecebddca3385bb0f1c58b4423a1b35 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 18:51:05 +0530 Subject: [PATCH 49/80] add tests --- .../tft_version_two.py | 38 +- tests/test_models/test_tft_v2.py | 367 ++++++++++++++++++ 2 files changed, 398 insertions(+), 7 deletions(-) create mode 100644 tests/test_models/test_tft_v2.py diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py index 2bfe407d7..1a1634356 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py @@ -157,11 +157,11 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: if self.static_context_linear is not None: static_cat = x.get( "static_categorical_features", - torch.zeros(batch_size, 0, device=self.device), + torch.zeros(batch_size, 1, 0, device=self.device), ) static_cont = x.get( "static_continuous_features", - torch.zeros(batch_size, 0, device=self.device), + torch.zeros(batch_size, 1, 0, device=self.device), ) if static_cat.size(2) == 0 and static_cont.size(2) == 0: @@ -180,17 +180,41 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: static_context = static_context.view(batch_size, self.hidden_size) else: - static_input = torch.cat([static_cont, static_cat], dim=1).to( + static_input = torch.cat([static_cont, static_cat], dim=2).to( dtype=self.static_context_linear.weight.dtype ) static_context = self.static_context_linear(static_input) static_context = static_context.view(batch_size, self.hidden_size) - encoder_weights = self.encoder_var_selection(encoder_input) - encoder_input = encoder_input * encoder_weights + if self.encoder_var_selection is not None: + encoder_weights = self.encoder_var_selection(encoder_input) + encoder_input = encoder_input * encoder_weights + else: + if self.encoder_input_dim == 0: + encoder_input = torch.zeros( + batch_size, + self.max_encoder_length, + 1, + device=self.device, + dtype=encoder_input.dtype, + ) + else: + encoder_input = encoder_input - decoder_weights = self.decoder_var_selection(decoder_input) - decoder_input = decoder_input * decoder_weights + if self.decoder_var_selection is not None: + decoder_weights = self.decoder_var_selection(decoder_input) + decoder_input = decoder_input * decoder_weights + else: + if self.decoder_input_dim == 0: + 
decoder_input = torch.zeros( + batch_size, + self.max_prediction_length, + 1, + device=self.device, + dtype=decoder_input.dtype, + ) + else: + decoder_input = decoder_input if static_context is not None: encoder_static_context = static_context.unsqueeze(1).expand( diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py new file mode 100644 index 000000000..e69d3d06d --- /dev/null +++ b/tests/test_models/test_tft_v2.py @@ -0,0 +1,367 @@ +import numpy as np +import pandas as pd +import pytest +import torch +import torch.nn as nn + +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule +from pytorch_forecasting.data.timeseries import TimeSeries +from pytorch_forecasting.models.temporal_fusion_transformer.tft_version_two import TFT + +BATCH_SIZE_TEST = 2 +MAX_ENCODER_LENGTH_TEST = 10 +MAX_PREDICTION_LENGTH_TEST = 5 +HIDDEN_SIZE_TEST = 8 +OUTPUT_SIZE_TEST = 1 +ATTENTION_HEAD_SIZE_TEST = 2 +NUM_LAYERS_TEST = 1 +DROPOUT_TEST = 0.1 + + +def get_default_test_metadata( + enc_cont=2, + enc_cat=1, + dec_cont=1, + dec_cat=1, + static_cat=1, + static_cont=1, + output_size=OUTPUT_SIZE_TEST, +): + return { + "max_encoder_length": MAX_ENCODER_LENGTH_TEST, + "max_prediction_length": MAX_PREDICTION_LENGTH_TEST, + "encoder_cont": enc_cont, + "encoder_cat": enc_cat, + "decoder_cont": dec_cont, + "decoder_cat": dec_cat, + "static_categorical_features": static_cat, + "static_continuous_features": static_cont, + "target": output_size, + } + + +def create_tft_input_batch_for_test(metadata, batch_size=BATCH_SIZE_TEST, device="cpu"): + def _get_dim_val(key): + return metadata.get(key, 0) + + x = { + "encoder_cont": torch.randn( + batch_size, + metadata["max_encoder_length"], + _get_dim_val("encoder_cont"), + device=device, + ), + "encoder_cat": torch.randn( + batch_size, + metadata["max_encoder_length"], + _get_dim_val("encoder_cat"), + device=device, + ), + "decoder_cont": torch.randn( + batch_size, + metadata["max_prediction_length"], + _get_dim_val("decoder_cont"), + device=device, + ), + "decoder_cat": torch.randn( + batch_size, + metadata["max_prediction_length"], + _get_dim_val("decoder_cat"), + device=device, + ), + "static_categorical_features": torch.randn( + batch_size, 1, _get_dim_val("static_categorical_features"), device=device + ), + "static_continuous_features": torch.randn( + batch_size, 1, _get_dim_val("static_continuous_features"), device=device + ), + "encoder_lengths": torch.full( + (batch_size,), + metadata["max_encoder_length"], + dtype=torch.long, + device=device, + ), + "decoder_lengths": torch.full( + (batch_size,), + metadata["max_prediction_length"], + dtype=torch.long, + device=device, + ), + "groups": torch.arange(batch_size, device=device).unsqueeze(1), + "encoder_time_idx": torch.stack( + [torch.arange(metadata["max_encoder_length"], device=device)] * batch_size + ), + "decoder_time_idx": torch.stack( + [ + torch.arange( + metadata["max_encoder_length"], + metadata["max_encoder_length"] + metadata["max_prediction_length"], + device=device, + ) + ] + * batch_size + ), + "target_scale": torch.ones((batch_size, 1), device=device), + } + return x + + +dummy_loss_for_test = nn.MSELoss() + + +@pytest.fixture(scope="module") +def tft_model_params_fixture_func(): + return { + "loss": dummy_loss_for_test, + "hidden_size": HIDDEN_SIZE_TEST, + "num_layers": NUM_LAYERS_TEST, + "attention_head_size": ATTENTION_HEAD_SIZE_TEST, + "dropout": DROPOUT_TEST, + "output_size": OUTPUT_SIZE_TEST, + } + + +class TestTFTInitialization: + def 
test_basic_initialization(self, tft_model_params_fixture_func): + metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.hidden_size == HIDDEN_SIZE_TEST + assert model.num_layers == NUM_LAYERS_TEST + assert hasattr(model, "metadata") and model.metadata == metadata + assert ( + model.encoder_input_dim + == metadata["encoder_cont"] + metadata["encoder_cat"] + ) + assert ( + model.static_input_dim + == metadata["static_categorical_features"] + + metadata["static_continuous_features"] + ) + assert isinstance(model.lstm_encoder, nn.LSTM) + assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim) + assert isinstance(model.self_attention, nn.MultiheadAttention) + if hasattr(model, "hparams") and model.hparams: + assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST + assert model.output_size == OUTPUT_SIZE_TEST + + def test_initialization_no_time_varying_features( + self, tft_model_params_fixture_func + ): + metadata = get_default_test_metadata( + enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST + ) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.encoder_input_dim == 0 + assert model.encoder_var_selection is None + assert model.lstm_encoder.input_size == 1 + assert model.decoder_input_dim == 0 + assert model.decoder_var_selection is None + assert model.lstm_decoder.input_size == 1 + + def test_initialization_no_static_features(self, tft_model_params_fixture_func): + metadata = get_default_test_metadata( + static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST + ) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.static_input_dim == 0 + assert model.static_context_linear is None + + +class TestTFTForwardPass: + @pytest.mark.parametrize( + "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k", + [ + (2, 1, 1, 1, 1, 1), + (2, 0, 1, 0, 0, 0), + (0, 0, 0, 0, 1, 1), + (0, 0, 0, 0, 0, 0), + (1, 0, 1, 0, 1, 0), + (1, 0, 1, 0, 0, 1), + ], + ) + def test_forward_pass_configs( + self, tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k + ): + current_tft_actual_output_size = tft_model_params_fixture_func["output_size"] + metadata = get_default_test_metadata( + enc_cont=enc_c, + enc_cat=enc_k, + dec_cont=dec_c, + dec_cat=dec_k, + static_cat=stat_c, + static_cont=stat_k, + output_size=current_tft_actual_output_size, + ) + model_params = tft_model_params_fixture_func.copy() + model_params["output_size"] = current_tft_actual_output_size + model = TFT(**model_params, metadata=metadata) + model.eval() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + x = create_tft_input_batch_for_test( + metadata, batch_size=BATCH_SIZE_TEST, device=device + ) + output_dict = model(x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + BATCH_SIZE_TEST, + MAX_PREDICTION_LENGTH_TEST, + current_tft_actual_output_size, + ) + assert not torch.isnan(predictions).any(), "NaNs in prediction" + assert not torch.isinf(predictions).any(), "Infs in prediction" + + +@pytest.fixture +def sample_pandas_data_for_test(): + """Create sample data ensuring all feature columns are numeric (float32).""" + series_len = MAX_ENCODER_LENGTH_TEST + MAX_PREDICTION_LENGTH_TEST + 5 + num_groups = 6 + data = [] + + for i in range(num_groups): + static_cont_val = np.float32(i * 10.0) + static_cat_code = np.float32(i % 2) + + df_group = pd.DataFrame( + { + "time_idx": 
np.arange(series_len, dtype=np.int64), + "group_id_str": np.repeat(f"g{i}", series_len), + "target": np.random.rand(series_len).astype(np.float32) + i, + "enc_cont1": np.random.rand(series_len).astype(np.float32), + "enc_cat1_codes": np.random.randint(0, 3, series_len).astype( + np.float32 + ), + "dec_known_cont": np.sin(np.arange(series_len) / 5.0).astype( + np.float32 + ), + "dec_known_cat_codes": np.random.randint(0, 2, series_len).astype( + np.float32 + ), + "static_cat_feat_codes": np.full( + series_len, static_cat_code, dtype=np.float32 + ), + "static_cont_feat": np.full( + series_len, static_cont_val, dtype=np.float32 + ), + } + ) + data.append(df_group) + + df = pd.concat(data, ignore_index=True) + + df["group_id"] = df["group_id_str"].astype("category") + df.drop(columns=["group_id_str"], inplace=True) + + return df + + +@pytest.fixture +def timeseries_obj_for_test(sample_pandas_data_for_test): + df = sample_pandas_data_for_test + + return TimeSeries( + data=df, + time="time_idx", + target="target", + group=["group_id"], + num=[ + "enc_cont1", + "enc_cat1_codes", + "dec_known_cont", + "dec_known_cat_codes", + "static_cat_feat_codes", + "static_cont_feat", + ], + cat=[], + known=["dec_known_cont", "dec_known_cat_codes", "time_idx"], + static=["static_cat_feat_codes", "static_cont_feat"], + ) + + +@pytest.fixture +def data_module_for_test(timeseries_obj_for_test): + dm = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=timeseries_obj_for_test, + batch_size=BATCH_SIZE_TEST, + max_encoder_length=MAX_ENCODER_LENGTH_TEST, + max_prediction_length=MAX_PREDICTION_LENGTH_TEST, + train_val_test_split=(0.5, 0.25, 0.25), + num_workers=0, # Added for consistency + ) + dm.setup("fit") + dm.setup("test") + return dm + + +class TestTFTWithDataModule: + def test_model_with_datamodule_integration( + self, tft_model_params_fixture_func, data_module_for_test + ): + dm = data_module_for_test + model_metadata_from_dm = dm.metadata + + assert ( + model_metadata_from_dm["encoder_cont"] == 6 + ), f"Actual encoder_cont: {model_metadata_from_dm['encoder_cont']}" + assert ( + model_metadata_from_dm["encoder_cat"] == 0 + ), f"Actual encoder_cat: {model_metadata_from_dm['encoder_cat']}" + assert ( + model_metadata_from_dm["decoder_cont"] == 2 + ), f"Actual decoder_cont: {model_metadata_from_dm['decoder_cont']}" + assert ( + model_metadata_from_dm["decoder_cat"] == 0 + ), f"Actual decoder_cat: {model_metadata_from_dm['decoder_cat']}" + assert ( + model_metadata_from_dm["static_categorical_features"] == 0 + ), f"Actual static_cat: {model_metadata_from_dm['static_categorical_features']}" + assert ( + model_metadata_from_dm["static_continuous_features"] == 2 + ), f"Actual static_cont: {model_metadata_from_dm['static_continuous_features']}" + assert model_metadata_from_dm["target"] == 1 + + tft_init_args = tft_model_params_fixture_func.copy() + tft_init_args["output_size"] = model_metadata_from_dm["target"] + model = TFT(**tft_init_args, metadata=model_metadata_from_dm) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + + train_loader = dm.train_dataloader() + batch_x, batch_y = next(iter(train_loader)) + + actual_batch_size = batch_x["encoder_cont"].shape[0] + batch_x = {k: v.to(device) for k, v in batch_x.items()} + batch_y = batch_y.to(device) + + assert ( + batch_x["encoder_cont"].shape[2] == model_metadata_from_dm["encoder_cont"] + ) + assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"] + assert ( + 
batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"] + ) + assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"] + # assert ( + # batch_x["static_categorical_features"].shape[2] + # == model_metadata_from_dm["static_categorical_features"] + # ) + # assert ( + # batch_x["static_continuous_features"].shape[2] + # == model_metadata_from_dm["static_continuous_features"] + # ) + + output_dict = model(batch_x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) + assert not torch.isnan(predictions).any() + assert batch_y.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) From 9d80eb822e47c92e3b542cd70fe98103e00bd829 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 19:10:57 +0530 Subject: [PATCH 50/80] add tests --- tests/test_models/test_tft_v2.py | 311 +++++++++++++++---------------- 1 file changed, 152 insertions(+), 159 deletions(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index e69d3d06d..0455ad818 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -121,95 +121,92 @@ def tft_model_params_fixture_func(): } -class TestTFTInitialization: - def test_basic_initialization(self, tft_model_params_fixture_func): - metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST) - model = TFT(**tft_model_params_fixture_func, metadata=metadata) - assert model.hidden_size == HIDDEN_SIZE_TEST - assert model.num_layers == NUM_LAYERS_TEST - assert hasattr(model, "metadata") and model.metadata == metadata - assert ( - model.encoder_input_dim - == metadata["encoder_cont"] + metadata["encoder_cat"] - ) - assert ( - model.static_input_dim - == metadata["static_categorical_features"] - + metadata["static_continuous_features"] - ) - assert isinstance(model.lstm_encoder, nn.LSTM) - assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim) - assert isinstance(model.self_attention, nn.MultiheadAttention) - if hasattr(model, "hparams") and model.hparams: - assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST - assert model.output_size == OUTPUT_SIZE_TEST - - def test_initialization_no_time_varying_features( - self, tft_model_params_fixture_func - ): - metadata = get_default_test_metadata( - enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST - ) - model = TFT(**tft_model_params_fixture_func, metadata=metadata) - assert model.encoder_input_dim == 0 - assert model.encoder_var_selection is None - assert model.lstm_encoder.input_size == 1 - assert model.decoder_input_dim == 0 - assert model.decoder_var_selection is None - assert model.lstm_decoder.input_size == 1 - - def test_initialization_no_static_features(self, tft_model_params_fixture_func): - metadata = get_default_test_metadata( - static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST - ) - model = TFT(**tft_model_params_fixture_func, metadata=metadata) - assert model.static_input_dim == 0 - assert model.static_context_linear is None - - -class TestTFTForwardPass: - @pytest.mark.parametrize( - "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k", - [ - (2, 1, 1, 1, 1, 1), - (2, 0, 1, 0, 0, 0), - (0, 0, 0, 0, 1, 1), - (0, 0, 0, 0, 0, 0), - (1, 0, 1, 0, 1, 0), - (1, 0, 1, 0, 0, 1), - ], +# Converted from TestTFTInitialization class +def test_basic_initialization(tft_model_params_fixture_func): + metadata = 
get_default_test_metadata(output_size=OUTPUT_SIZE_TEST) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.hidden_size == HIDDEN_SIZE_TEST + assert model.num_layers == NUM_LAYERS_TEST + assert hasattr(model, "metadata") and model.metadata == metadata + assert model.encoder_input_dim == metadata["encoder_cont"] + metadata["encoder_cat"] + assert ( + model.static_input_dim + == metadata["static_categorical_features"] + + metadata["static_continuous_features"] ) - def test_forward_pass_configs( - self, tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k - ): - current_tft_actual_output_size = tft_model_params_fixture_func["output_size"] - metadata = get_default_test_metadata( - enc_cont=enc_c, - enc_cat=enc_k, - dec_cont=dec_c, - dec_cat=dec_k, - static_cat=stat_c, - static_cont=stat_k, - output_size=current_tft_actual_output_size, - ) - model_params = tft_model_params_fixture_func.copy() - model_params["output_size"] = current_tft_actual_output_size - model = TFT(**model_params, metadata=metadata) - model.eval() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - x = create_tft_input_batch_for_test( - metadata, batch_size=BATCH_SIZE_TEST, device=device - ) - output_dict = model(x) - predictions = output_dict["prediction"] - assert predictions.shape == ( - BATCH_SIZE_TEST, - MAX_PREDICTION_LENGTH_TEST, - current_tft_actual_output_size, - ) - assert not torch.isnan(predictions).any(), "NaNs in prediction" - assert not torch.isinf(predictions).any(), "Infs in prediction" + assert isinstance(model.lstm_encoder, nn.LSTM) + assert model.lstm_encoder.input_size == max(1, model.encoder_input_dim) + assert isinstance(model.self_attention, nn.MultiheadAttention) + if hasattr(model, "hparams") and model.hparams: + assert model.hparams.get("hidden_size") == HIDDEN_SIZE_TEST + assert model.output_size == OUTPUT_SIZE_TEST + + +def test_initialization_no_time_varying_features(tft_model_params_fixture_func): + metadata = get_default_test_metadata( + enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST + ) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.encoder_input_dim == 0 + assert model.encoder_var_selection is None + assert model.lstm_encoder.input_size == 1 + assert model.decoder_input_dim == 0 + assert model.decoder_var_selection is None + assert model.lstm_decoder.input_size == 1 + + +def test_initialization_no_static_features(tft_model_params_fixture_func): + metadata = get_default_test_metadata( + static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST + ) + model = TFT(**tft_model_params_fixture_func, metadata=metadata) + assert model.static_input_dim == 0 + assert model.static_context_linear is None + + +# Converted from TestTFTForwardPass class +@pytest.mark.parametrize( + "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k", + [ + (2, 1, 1, 1, 1, 1), + (2, 0, 1, 0, 0, 0), + (0, 0, 0, 0, 1, 1), + (0, 0, 0, 0, 0, 0), + (1, 0, 1, 0, 1, 0), + (1, 0, 1, 0, 0, 1), + ], +) +def test_forward_pass_configs( + tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k +): + current_tft_actual_output_size = tft_model_params_fixture_func["output_size"] + metadata = get_default_test_metadata( + enc_cont=enc_c, + enc_cat=enc_k, + dec_cont=dec_c, + dec_cat=dec_k, + static_cat=stat_c, + static_cont=stat_k, + output_size=current_tft_actual_output_size, + ) + model_params = tft_model_params_fixture_func.copy() + model_params["output_size"] = 
current_tft_actual_output_size + model = TFT(**model_params, metadata=metadata) + model.eval() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + x = create_tft_input_batch_for_test( + metadata, batch_size=BATCH_SIZE_TEST, device=device + ) + output_dict = model(x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + BATCH_SIZE_TEST, + MAX_PREDICTION_LENGTH_TEST, + current_tft_actual_output_size, + ) + assert not torch.isnan(predictions).any(), "NaNs in prediction" + assert not torch.isinf(predictions).any(), "Infs in prediction" @pytest.fixture @@ -294,74 +291,70 @@ def data_module_for_test(timeseries_obj_for_test): return dm -class TestTFTWithDataModule: - def test_model_with_datamodule_integration( - self, tft_model_params_fixture_func, data_module_for_test - ): - dm = data_module_for_test - model_metadata_from_dm = dm.metadata - - assert ( - model_metadata_from_dm["encoder_cont"] == 6 - ), f"Actual encoder_cont: {model_metadata_from_dm['encoder_cont']}" - assert ( - model_metadata_from_dm["encoder_cat"] == 0 - ), f"Actual encoder_cat: {model_metadata_from_dm['encoder_cat']}" - assert ( - model_metadata_from_dm["decoder_cont"] == 2 - ), f"Actual decoder_cont: {model_metadata_from_dm['decoder_cont']}" - assert ( - model_metadata_from_dm["decoder_cat"] == 0 - ), f"Actual decoder_cat: {model_metadata_from_dm['decoder_cat']}" - assert ( - model_metadata_from_dm["static_categorical_features"] == 0 - ), f"Actual static_cat: {model_metadata_from_dm['static_categorical_features']}" - assert ( - model_metadata_from_dm["static_continuous_features"] == 2 - ), f"Actual static_cont: {model_metadata_from_dm['static_continuous_features']}" - assert model_metadata_from_dm["target"] == 1 - - tft_init_args = tft_model_params_fixture_func.copy() - tft_init_args["output_size"] = model_metadata_from_dm["target"] - model = TFT(**tft_init_args, metadata=model_metadata_from_dm) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - model.eval() - - train_loader = dm.train_dataloader() - batch_x, batch_y = next(iter(train_loader)) - - actual_batch_size = batch_x["encoder_cont"].shape[0] - batch_x = {k: v.to(device) for k, v in batch_x.items()} - batch_y = batch_y.to(device) - - assert ( - batch_x["encoder_cont"].shape[2] == model_metadata_from_dm["encoder_cont"] - ) - assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"] - assert ( - batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"] - ) - assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"] - # assert ( - # batch_x["static_categorical_features"].shape[2] - # == model_metadata_from_dm["static_categorical_features"] - # ) - # assert ( - # batch_x["static_continuous_features"].shape[2] - # == model_metadata_from_dm["static_continuous_features"] - # ) - - output_dict = model(batch_x) - predictions = output_dict["prediction"] - assert predictions.shape == ( - actual_batch_size, - MAX_PREDICTION_LENGTH_TEST, - model_metadata_from_dm["target"], - ) - assert not torch.isnan(predictions).any() - assert batch_y.shape == ( - actual_batch_size, - MAX_PREDICTION_LENGTH_TEST, - model_metadata_from_dm["target"], - ) +# Converted from TestTFTWithDataModule class +def test_model_with_datamodule_integration( + tft_model_params_fixture_func, data_module_for_test +): + dm = data_module_for_test + model_metadata_from_dm = dm.metadata + + assert ( + model_metadata_from_dm["encoder_cont"] == 6 + ), 
f"Actual encoder_cont: {model_metadata_from_dm['encoder_cont']}" + assert ( + model_metadata_from_dm["encoder_cat"] == 0 + ), f"Actual encoder_cat: {model_metadata_from_dm['encoder_cat']}" + assert ( + model_metadata_from_dm["decoder_cont"] == 2 + ), f"Actual decoder_cont: {model_metadata_from_dm['decoder_cont']}" + assert ( + model_metadata_from_dm["decoder_cat"] == 0 + ), f"Actual decoder_cat: {model_metadata_from_dm['decoder_cat']}" + assert ( + model_metadata_from_dm["static_categorical_features"] == 0 + ), f"Actual static_cat: {model_metadata_from_dm['static_categorical_features']}" + assert ( + model_metadata_from_dm["static_continuous_features"] == 2 + ), f"Actual static_cont: {model_metadata_from_dm['static_continuous_features']}" + assert model_metadata_from_dm["target"] == 1 + + tft_init_args = tft_model_params_fixture_func.copy() + tft_init_args["output_size"] = model_metadata_from_dm["target"] + model = TFT(**tft_init_args, metadata=model_metadata_from_dm) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + + train_loader = dm.train_dataloader() + batch_x, batch_y = next(iter(train_loader)) + + actual_batch_size = batch_x["encoder_cont"].shape[0] + batch_x = {k: v.to(device) for k, v in batch_x.items()} + batch_y = batch_y.to(device) + + assert batch_x["encoder_cont"].shape[2] == model_metadata_from_dm["encoder_cont"] + assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"] + assert batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"] + assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"] + # assert ( + # batch_x["static_categorical_features"].shape[2] + # == model_metadata_from_dm["static_categorical_features"] + # ) + # assert ( + # batch_x["static_continuous_features"].shape[2] + # == model_metadata_from_dm["static_continuous_features"] + # ) + + output_dict = model(batch_x) + predictions = output_dict["prediction"] + assert predictions.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) + assert not torch.isnan(predictions).any() + assert batch_y.shape == ( + actual_batch_size, + MAX_PREDICTION_LENGTH_TEST, + model_metadata_from_dm["target"], + ) From a8ccfe36d383191ba6bd23902543aed40dbe0d39 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 19:17:36 +0530 Subject: [PATCH 51/80] add tests --- tests/test_models/test_tft_v2.py | 37 +++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index 0455ad818..ae74d59fc 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -123,6 +123,14 @@ def tft_model_params_fixture_func(): # Converted from TestTFTInitialization class def test_basic_initialization(tft_model_params_fixture_func): + """Test basic initialization of the TFT model with default metadata. + + Verifies: + - Model attributes match the provided metadata (e.g., hidden_size, num_layers). + - Proper construction of key model components (LSTM, attention, etc.). + - Correct dimensionality of input layers based on metadata. + - Model retains metadata and hyperparameters as expected. 
+ """ metadata = get_default_test_metadata(output_size=OUTPUT_SIZE_TEST) model = TFT(**tft_model_params_fixture_func, metadata=metadata) assert model.hidden_size == HIDDEN_SIZE_TEST @@ -143,6 +151,13 @@ def test_basic_initialization(tft_model_params_fixture_func): def test_initialization_no_time_varying_features(tft_model_params_fixture_func): + """Test TFT initialization with no time-varying (encoder/decoder) features. + + Verifies: + - Model handles zero encoder/decoder input dimensions correctly. + - Skips creation of encoder/decoder variable selection networks. + - Defaults to input size 1 for LSTMs when no time-varying features exist. + """ metadata = get_default_test_metadata( enc_cont=0, enc_cat=0, dec_cont=0, dec_cat=0, output_size=OUTPUT_SIZE_TEST ) @@ -156,6 +171,12 @@ def test_initialization_no_time_varying_features(tft_model_params_fixture_func): def test_initialization_no_static_features(tft_model_params_fixture_func): + """Test TFT initialization with no static features. + + Verifies: + - Model static input dim is 0. + - Static context linear layer is not created. + """ metadata = get_default_test_metadata( static_cat=0, static_cont=0, output_size=OUTPUT_SIZE_TEST ) @@ -179,6 +200,13 @@ def test_initialization_no_static_features(tft_model_params_fixture_func): def test_forward_pass_configs( tft_model_params_fixture_func, enc_c, enc_k, dec_c, dec_k, stat_c, stat_k ): + """Test TFT forward pass across multiple feature configurations. + + Verifies: + - Model can forward pass without errors for varying combinations of input types. + - Output prediction tensor has expected shape. + - Output contains no NaNs or infinities. + """ current_tft_actual_output_size = tft_model_params_fixture_func["output_size"] metadata = get_default_test_metadata( enc_cont=enc_c, @@ -211,7 +239,6 @@ def test_forward_pass_configs( @pytest.fixture def sample_pandas_data_for_test(): - """Create sample data ensuring all feature columns are numeric (float32).""" series_len = MAX_ENCODER_LENGTH_TEST + MAX_PREDICTION_LENGTH_TEST + 5 num_groups = 6 data = [] @@ -295,6 +322,14 @@ def data_module_for_test(timeseries_obj_for_test): def test_model_with_datamodule_integration( tft_model_params_fixture_func, data_module_for_test ): + """Integration test to ensure TFT works correctly with data module. + + Verifies: + - Metadata inferred from data module matches expected input dimensions. + - Model processes real dataloader batches correctly. + - Output and target tensors from model and data module align in shape. + - No NaNs in predictions. 
+ """ dm = data_module_for_test model_metadata_from_dm = dm.metadata From f900ba5e4d4912573e7dc79c398386e683d5e807 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 19:24:21 +0530 Subject: [PATCH 52/80] add more docstrings --- tests/test_models/test_tft_v2.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index ae74d59fc..d79eac874 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -27,6 +27,7 @@ def get_default_test_metadata( static_cont=1, output_size=OUTPUT_SIZE_TEST, ): + """Return a dict representing default metadata for TFT model initialization.""" return { "max_encoder_length": MAX_ENCODER_LENGTH_TEST, "max_prediction_length": MAX_PREDICTION_LENGTH_TEST, @@ -41,6 +42,8 @@ def get_default_test_metadata( def create_tft_input_batch_for_test(metadata, batch_size=BATCH_SIZE_TEST, device="cpu"): + """Create a synthetic input batch dictionary for testing TFT forward passes.""" + def _get_dim_val(key): return metadata.get(key, 0) @@ -111,6 +114,7 @@ def _get_dim_val(key): @pytest.fixture(scope="module") def tft_model_params_fixture_func(): + """Create a default set of model parameters for TFT.""" return { "loss": dummy_loss_for_test, "hidden_size": HIDDEN_SIZE_TEST, @@ -121,7 +125,6 @@ def tft_model_params_fixture_func(): } -# Converted from TestTFTInitialization class def test_basic_initialization(tft_model_params_fixture_func): """Test basic initialization of the TFT model with default metadata. @@ -239,6 +242,7 @@ def test_forward_pass_configs( @pytest.fixture def sample_pandas_data_for_test(): + """Create synthetic multivariate time series data as a pandas DataFrame.""" series_len = MAX_ENCODER_LENGTH_TEST + MAX_PREDICTION_LENGTH_TEST + 5 num_groups = 6 data = [] @@ -282,6 +286,7 @@ def sample_pandas_data_for_test(): @pytest.fixture def timeseries_obj_for_test(sample_pandas_data_for_test): + """Convert sample DataFrame into a TimeSeries object.""" df = sample_pandas_data_for_test return TimeSeries( @@ -305,6 +310,7 @@ def timeseries_obj_for_test(sample_pandas_data_for_test): @pytest.fixture def data_module_for_test(timeseries_obj_for_test): + """Initialize and sets up an EncoderDecoderTimeSeriesDataModule.""" dm = EncoderDecoderTimeSeriesDataModule( time_series_dataset=timeseries_obj_for_test, batch_size=BATCH_SIZE_TEST, @@ -318,7 +324,6 @@ def data_module_for_test(timeseries_obj_for_test): return dm -# Converted from TestTFTWithDataModule class def test_model_with_datamodule_integration( tft_model_params_fixture_func, data_module_for_test ): From ed1b79936df9c4cb18c29393f964228997001b98 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 14 May 2025 19:26:40 +0530 Subject: [PATCH 53/80] add note about the commented out tests --- tests/test_models/test_tft_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index d79eac874..57a50e75e 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -188,7 +188,6 @@ def test_initialization_no_static_features(tft_model_params_fixture_func): assert model.static_context_linear is None -# Converted from TestTFTForwardPass class @pytest.mark.parametrize( "enc_c, enc_k, dec_c, dec_k, stat_c, stat_k", [ @@ -334,6 +333,8 @@ def test_model_with_datamodule_integration( - Model processes real dataloader batches correctly. 
- Output and target tensors from model and data module align in shape. - No NaNs in predictions. + + Note: The commented out tests are to test a bug in data_module """ dm = data_module_for_test model_metadata_from_dm = dm.metadata From c0ceb8a16703573144e3d0bd3aa6ab978157a341 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 17 May 2025 02:08:06 +0530 Subject: [PATCH 54/80] add the commented out tests --- tests/test_models/test_tft_v2.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index 57a50e75e..f541082ce 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -316,7 +316,6 @@ def data_module_for_test(timeseries_obj_for_test): max_encoder_length=MAX_ENCODER_LENGTH_TEST, max_prediction_length=MAX_PREDICTION_LENGTH_TEST, train_val_test_split=(0.5, 0.25, 0.25), - num_workers=0, # Added for consistency ) dm.setup("fit") dm.setup("test") @@ -377,14 +376,14 @@ def test_model_with_datamodule_integration( assert batch_x["encoder_cat"].shape[2] == model_metadata_from_dm["encoder_cat"] assert batch_x["decoder_cont"].shape[2] == model_metadata_from_dm["decoder_cont"] assert batch_x["decoder_cat"].shape[2] == model_metadata_from_dm["decoder_cat"] - # assert ( - # batch_x["static_categorical_features"].shape[2] - # == model_metadata_from_dm["static_categorical_features"] - # ) - # assert ( - # batch_x["static_continuous_features"].shape[2] - # == model_metadata_from_dm["static_continuous_features"] - # ) + assert ( + batch_x["static_categorical_features"].shape[2] + == model_metadata_from_dm["static_categorical_features"] + ) + assert ( + batch_x["static_continuous_features"].shape[2] + == model_metadata_from_dm["static_continuous_features"] + ) output_dict = model(batch_x) predictions = output_dict["prediction"] From 3828c260d4b32ee7fcd9fc300776126c70f6a3b6 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 17 May 2025 02:09:16 +0530 Subject: [PATCH 55/80] remove note --- tests/test_models/test_tft_v2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/test_tft_v2.py index f541082ce..791ea10ef 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/test_tft_v2.py @@ -332,8 +332,6 @@ def test_model_with_datamodule_integration( - Model processes real dataloader batches correctly. - Output and target tensors from model and data module align in shape. - No NaNs in predictions. 
- - Note: The commented out tests are to test a bug in data_module """ dm = data_module_for_test model_metadata_from_dm = dm.metadata From 30b541b2910c461e3e488e19137c1242c0b0627b Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 21 May 2025 00:52:29 +0530 Subject: [PATCH 56/80] make the modules private --- .../{base_model_refactor.py => _base_model_v2.py} | 13 +++++++++++++ .../{tft_version_two.py => _tft_v2.py} | 2 +- .../test_models/{test_tft_v2.py => _test_tft_v2.py} | 2 +- 3 files changed, 15 insertions(+), 2 deletions(-) rename pytorch_forecasting/models/base/{base_model_refactor.py => _base_model_v2.py} (93%) rename pytorch_forecasting/models/temporal_fusion_transformer/{tft_version_two.py => _tft_v2.py} (99%) rename tests/test_models/{test_tft_v2.py => _test_tft_v2.py} (99%) diff --git a/pytorch_forecasting/models/base/base_model_refactor.py b/pytorch_forecasting/models/base/_base_model_v2.py similarity index 93% rename from pytorch_forecasting/models/base/base_model_refactor.py rename to pytorch_forecasting/models/base/_base_model_v2.py index ccd2c2600..ddefc29fb 100644 --- a/pytorch_forecasting/models/base/base_model_refactor.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -6,6 +6,7 @@ from typing import Dict, List, Optional, Tuple, Union +from warnings import warn from lightning.pytorch import LightningModule from lightning.pytorch.utilities.types import STEP_OUTPUT @@ -53,6 +54,18 @@ def __init__( self.lr_scheduler_params = ( lr_scheduler_params if lr_scheduler_params is not None else {} ) + self.model_name = self.__class__.__name__ + warn( + f"The Model '{self.model_name}' is part of an experimental rework" + "of the pytorch-forecasting model layer, scheduled for release with v2.0.0." + " The API is not stable and may change without prior warning. " + "This class is intended for beta testing and as a basic skeleton, " + "but not for stable production use. 
" + "Feedback and suggestions are very welcome in " + "pytorch-forecasting issue 1736, " + "https://github.com/sktime/pytorch-forecasting/issues/1736", + UserWarning, + ) def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py similarity index 99% rename from pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py rename to pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py index 1a1634356..a0cf7d39e 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_version_two.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py @@ -9,7 +9,7 @@ import torch.nn as nn from torch.optim import Optimizer -from pytorch_forecasting.models.base.base_model_refactor import BaseModel +from pytorch_forecasting.models.base._base_model_v2 import BaseModel class TFT(BaseModel): diff --git a/tests/test_models/test_tft_v2.py b/tests/test_models/_test_tft_v2.py similarity index 99% rename from tests/test_models/test_tft_v2.py rename to tests/test_models/_test_tft_v2.py index 791ea10ef..13d92d5db 100644 --- a/tests/test_models/test_tft_v2.py +++ b/tests/test_models/_test_tft_v2.py @@ -6,7 +6,7 @@ from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule from pytorch_forecasting.data.timeseries import TimeSeries -from pytorch_forecasting.models.temporal_fusion_transformer.tft_version_two import TFT +from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT BATCH_SIZE_TEST = 2 MAX_ENCODER_LENGTH_TEST = 10 From 5cc3ff1dd8be8f4b325d383abea14c4c06ace280 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 21 May 2025 01:05:44 +0530 Subject: [PATCH 57/80] initial commit --- .../_tft_v2_metadata.py | 0 .../tests/test_all_estimators_v2.py | 49 +++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py create mode 100644 pytorch_forecasting/tests/test_all_estimators_v2.py diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py new file mode 100644 index 000000000..e69de29bb diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py new file mode 100644 index 000000000..1dc7859ab --- /dev/null +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -0,0 +1,49 @@ +"""Automated tests based on the skbase test suite template.""" + +from inspect import isclass +import shutil + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +from lightning.pytorch.loggers import TensorBoardLogger + +from pytorch_forecasting.tests._conftest import make_dataloaders +from pytorch_forecasting.tests.test_all_estimators import ( + BaseFixtureGenerator, + PackageConfig, +) + +# whether to test only estimators from modules that are changed w.r.t. 
main +# default is False, can be set to True by pytest --only_changed_modules True flag +ONLY_CHANGED_MODULES = False + + +def _integration( + estimator_cls, + data_with_covariates, + tmp_path, + cell_type="LSTM", + data_loader_kwargs={}, + clip_target: bool = False, + trainer_kwargs=None, + **kwargs, +): + pass + + +class TestAllPtForecastersV2(PackageConfig, BaseFixtureGenerator): + """Generic tests for all objects in the mini package.""" + + def test_doctest_examples(self, object_class): + """Runs doctests for estimator class.""" + import doctest + + doctest.run_docstring_examples(object_class, globals()) + + def test_integration( + self, + object_metadata, + trainer_kwargs, + tmp_path, + ): + pass From f18e09d183f40608d9d92e5cb7ef50d02cbf4644 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 21 May 2025 01:16:44 +0530 Subject: [PATCH 58/80] add TFTMetadata class --- .../_tft_v2_metadata.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py index e69de29bb..5ea87c24b 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py @@ -0,0 +1,24 @@ +"""TFT metadata container.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecaster + + +class TFTMetadata(_BasePtForecaster): + """TFT metadata container.""" + + _tags = { + "info:name": "TFT", + "info:compute": 3, + "authors": ["jdb78"], + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": False, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT + + return TFT From e1e360eb7e626ce590e8c8a1bf6a3467220cdbb2 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 21 May 2025 01:16:58 +0530 Subject: [PATCH 59/80] add TFTMetadata class --- .../models/temporal_fusion_transformer/_tft_v2_metadata.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py index 5ea87c24b..8f11d78f3 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py @@ -8,7 +8,6 @@ class TFTMetadata(_BasePtForecaster): _tags = { "info:name": "TFT", - "info:compute": 3, "authors": ["jdb78"], "capability:exogenous": True, "capability:multivariate": True, From 92c12bf1da172f9bf86b9e40667d6cd120fa3958 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 26 May 2025 01:32:35 +0530 Subject: [PATCH 60/80] add TFT tests --- .../models/base/_base_model_v2.py | 44 +- .../temporal_fusion_transformer/_tft_v2.py | 3 +- .../_tft_v2_metadata.py | 23 - .../temporal_fusion_transformer/_tft_ver2.py | 1130 +++++++++++++++++ .../tft_v2_metadata.py | 90 ++ pytorch_forecasting/tests/_conftest.py | 232 ++++ pytorch_forecasting/tests/_data_scenarios.py | 231 ++++ .../tests/test_all_estimators_v2.py | 81 +- 8 files changed, 1805 insertions(+), 29 deletions(-) delete mode 100644 pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py create mode 100644 pytorch_forecasting/models/temporal_fusion_transformer/_tft_ver2.py create mode 100644 
pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py diff --git a/pytorch_forecasting/models/base/_base_model_v2.py b/pytorch_forecasting/models/base/_base_model_v2.py index ddefc29fb..15e9cb9d8 100644 --- a/pytorch_forecasting/models/base/_base_model_v2.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -14,11 +14,23 @@ import torch.nn as nn from torch.optim import Optimizer +from pytorch_forecasting.metrics import ( + MAE, + MASE, + SMAPE, + DistributionLoss, + Metric, + MultiHorizonMetric, + MultiLoss, + QuantileLoss, + convert_torchmetric_to_pytorch_forecasting_metric, +) + class BaseModel(LightningModule): def __init__( self, - loss: nn.Module, + loss: Metric = SMAPE(), logging_metrics: Optional[List[nn.Module]] = None, optimizer: Optional[Union[Optimizer, str]] = "adam", optimizer_params: Optional[Dict] = None, @@ -104,6 +116,7 @@ def training_step( x, y = batch y_hat_dict = self(x) y_hat = y_hat_dict["prediction"] + y_hat, y = self._align_prediction_target_shapes(y_hat, y) loss = self.loss(y_hat, y) self.log( "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True @@ -132,6 +145,7 @@ def validation_step( x, y = batch y_hat_dict = self(x) y_hat = y_hat_dict["prediction"] + y_hat, y = self._align_prediction_target_shapes(y_hat, y) loss = self.loss(y_hat, y) self.log( "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True @@ -160,6 +174,7 @@ def test_step( x, y = batch y_hat_dict = self(x) y_hat = y_hat_dict["prediction"] + y_hat, y = self._align_prediction_target_shapes(y_hat, y) loss = self.loss(y_hat, y) self.log( "test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True @@ -294,3 +309,30 @@ def log_metrics( prog_bar=True, logger=True, ) + + def _align_prediction_target_shapes( + self, y_hat: torch.Tensor, y: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Align prediction and target tensor shapes for loss/metric calculation. 
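+
+        Parameters
+        ----------
+        y_hat : torch.Tensor
+            prediction tensor returned by the model's ``forward``
+        y : torch.Tensor
+            target tensor taken from the dataloader batch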
+ + Returns + ------- + Tuple of aligned prediction and target tensors + """ + if y.dim() == 3 and y.shape[-1] == 1: + y = y.squeeze(-1) + if y_hat.dim() < y.dim(): + y_hat = y_hat.unsqueeze(-1) + elif y_hat.dim() > y.dim(): + if y_hat.shape[-1] == 1: + y_hat = y_hat.squeeze(-1) + if y_hat.shape != y.shape: + if y_hat.numel() == y.numel(): + y_hat = y_hat.view(y.shape) + else: + raise ValueError( + f"Cannot align shapes: y_hat {y_hat.shape} vs y {y.shape}" + ) + + return y_hat, y diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py index a0cf7d39e..fd41fe2a1 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py @@ -9,13 +9,14 @@ import torch.nn as nn from torch.optim import Optimizer +from pytorch_forecasting.metrics import Metric from pytorch_forecasting.models.base._base_model_v2 import BaseModel class TFT(BaseModel): def __init__( self, - loss: nn.Module, + loss: Metric, logging_metrics: Optional[List[nn.Module]] = None, optimizer: Optional[Union[Optimizer, str]] = "adam", optimizer_params: Optional[Dict] = None, diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py deleted file mode 100644 index 8f11d78f3..000000000 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2_metadata.py +++ /dev/null @@ -1,23 +0,0 @@ -"""TFT metadata container.""" - -from pytorch_forecasting.models.base._base_object import _BasePtForecaster - - -class TFTMetadata(_BasePtForecaster): - """TFT metadata container.""" - - _tags = { - "info:name": "TFT", - "authors": ["jdb78"], - "capability:exogenous": True, - "capability:multivariate": True, - "capability:pred_int": True, - "capability:flexible_history_length": False, - } - - @classmethod - def get_model_cls(cls): - """Get model class.""" - from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT - - return TFT diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_ver2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_ver2.py new file mode 100644 index 000000000..cff63385c --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_ver2.py @@ -0,0 +1,1130 @@ +""" +The temporal fusion transformer is a powerful predictive model for forecasting timeseries +""" # noqa: E501 + +from copy import copy +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torchmetrics import Metric as LightningMetric + +from pytorch_forecasting.data import TimeSeriesDataSet +from pytorch_forecasting.metrics import ( + MAE, + MAPE, + RMSE, + SMAPE, + MultiHorizonMetric, + QuantileLoss, +) +from pytorch_forecasting.models.base import BaseModelWithCovariates +from pytorch_forecasting.models.nn import LSTM, MultiEmbedding +from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import ( + AddNorm, + GateAddNorm, + GatedLinearUnit, + GatedResidualNetwork, + InterpretableMultiHeadAttention, + VariableSelectionNetwork, +) +from pytorch_forecasting.utils import ( + create_mask, + detach, + integer_histogram, + masked_op, + padded_stack, + to_list, +) +from pytorch_forecasting.utils._dependencies import _check_matplotlib + + +class TemporalFusionTransformer(BaseModelWithCovariates): + def __init__( + self, + metadata: 
Dict[str, Any], + hidden_size: int = 16, + lstm_layers: int = 1, + dropout: float = 0.1, + output_size: Union[int, List[int]] = None, + loss: MultiHorizonMetric = None, + attention_head_size: int = 4, + categorical_groups: Optional[Union[Dict, List[str]]] = None, + hidden_continuous_size: int = 8, + hidden_continuous_sizes: Optional[Dict[str, int]] = None, + embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None, + embedding_paddings: Optional[List[str]] = None, + embedding_labels: Optional[Dict[str, np.ndarray]] = None, + learning_rate: float = 1e-3, + log_interval: Union[int, float] = -1, + log_val_interval: Union[int, float] = None, + log_gradient_flow: bool = False, + reduce_on_plateau_patience: int = 1000, + monotone_constaints: Optional[Dict[str, int]] = None, + share_single_variable_networks: bool = False, + causal_attention: bool = True, + logging_metrics: nn.ModuleList = None, + **kwargs, + ): + """ + Temporal Fusion Transformer for forecasting timeseries - use its :py:meth:`~from_dataset` method if possible. + + Implementation of the article + `Temporal Fusion Transformers for Interpretable Multi-horizon Time Series + Forecasting `_. The network outperforms DeepAR by Amazon by 36-69% + in benchmarks. + + Enhancements compared to the original implementation (apart from capabilities added through base model + such as monotone constraints): + + * static variables can be continuous + * multiple categorical variables can be summarized with an EmbeddingBag + * variable encoder and decoder length by sample + * categorical embeddings are not transformed by variable selection network (because it is a redundant operation) + * variable dimension in variable selection network are scaled up via linear interpolation to reduce + number of parameters + * non-linear variable processing in variable selection network can be shared among decoder and encoder + (not shared by default) + + Tune its hyperparameters with + :py:func:`~pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters`. + + Args: + + hidden_size: hidden size of network which is its main hyperparameter and can range from 8 to 512 + lstm_layers: number of LSTM layers (2 is mostly optimal) + dropout: dropout rate + output_size: number of outputs (e.g. number of quantiles for QuantileLoss and one target or list + of output sizes). 
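+                Defaults to ``metadata["target"]`` (the number of target variables) if left as None.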
+ loss: loss function taking prediction and targets + attention_head_size: number of attention heads (4 is a good default) + max_encoder_length: length to encode (can be far longer than the decoder length but does not have to be) + static_categoricals: names of static categorical variables + static_reals: names of static continuous variables + time_varying_categoricals_encoder: names of categorical variables for encoder + time_varying_categoricals_decoder: names of categorical variables for decoder + time_varying_reals_encoder: names of continuous variables for encoder + time_varying_reals_decoder: names of continuous variables for decoder + categorical_groups: dictionary where values + are list of categorical variables that are forming together a new categorical + variable which is the key in the dictionary + x_reals: order of continuous variables in tensor passed to forward function + x_categoricals: order of categorical variables in tensor passed to forward function + hidden_continuous_size: default for hidden size for processing continous variables (similar to categorical + embedding size) + hidden_continuous_sizes: dictionary mapping continuous input indices to sizes for variable selection + (fallback to hidden_continuous_size if index is not in dictionary) + embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and + embedding size + embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector + embedding_labels: dictionary mapping (string) indices to list of categorical labels + learning_rate: learning rate + log_interval: log predictions every x batches, do not log if 0 or less, log interpretation if > 0. If < 1.0 + , will log multiple entries per batch. Defaults to -1. + log_val_interval: frequency with which to log validation set metrics, defaults to log_interval + log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training + failures + reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10 + monotone_constaints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder + variables mapping + position (e.g. ``"0"`` for first position) to constraint (``-1`` for negative and ``+1`` for positive, + larger numbers add more weight to the constraint vs. the loss but are usually not necessary). + This constraint significantly slows down training. Defaults to {}. + share_single_variable_networks (bool): if to share the single variable networks between the encoder and + decoder. Defaults to False. + causal_attention (bool): If to attend only at previous timesteps in the decoder or also include future + predictions. Defaults to True. + logging_metrics (nn.ModuleList[LightningMetric]): list of metrics that are logged during training. + Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]). + **kwargs: additional arguments to :py:class:`~BaseModel`. 
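+            metadata: dict of dataset properties (e.g. the data module's ``metadata``), read for
+                ``max_encoder_length``, ``target``, the encoder/decoder continuous and categorical
+                feature counts and the static feature counts.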
+ """ # noqa: E501 + + max_encoder_length = metadata["max_encoder_length"] + if output_size is None: + output_size = metadata["target"] + static_categoricals = [ + f"static_cat_{i}" + for i in range(metadata.get("static_categorical_features", 0)) + ] + static_reals = [ + f"static_real_{i}" + for i in range(metadata.get("static_continuous_features", 0)) + ] + time_varying_categoricals_encoder = [ + f"encoder_cat_{i}" for i in range(metadata["encoder_cat"]) + ] + time_varying_reals_encoder = [ + f"encoder_real_{i}" for i in range(metadata["encoder_cont"]) + ] + time_varying_categoricals_decoder = [ + f"decoder_cat_{i}" for i in range(metadata["decoder_cat"]) + ] + time_varying_reals_decoder = [ + f"decoder_real_{i}" for i in range(metadata["decoder_cont"]) + ] + x_categoricals = ( + static_categoricals + + time_varying_categoricals_encoder + + time_varying_categoricals_decoder + ) + x_reals = static_reals + time_varying_reals_encoder + time_varying_reals_decoder + + if monotone_constaints is None: + monotone_constaints = {} + if embedding_labels is None: + embedding_labels = {} + if embedding_paddings is None: + embedding_paddings = [] + if embedding_sizes is None: + embedding_sizes = {} + if hidden_continuous_sizes is None: + hidden_continuous_sizes = {} + if categorical_groups is None: + categorical_groups = {} + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]) + if loss is None: + loss = QuantileLoss() + self.save_hyperparameters(ignore=["metadata"]) + + self.hparams.max_encoder_length = max_encoder_length + self.hparams.output_size = output_size + self.hparams.static_categoricals = static_categoricals + self.hparams.static_reals = static_reals + self.hparams.time_varying_categoricals_encoder = ( + time_varying_categoricals_encoder + ) + self.hparams.time_varying_categoricals_decoder = ( + time_varying_categoricals_decoder + ) + self.hparams.time_varying_reals_encoder = time_varying_reals_encoder + self.hparams.time_varying_reals_decoder = time_varying_reals_decoder + self.hparams.x_categoricals = x_categoricals + self.hparams.x_reals = x_reals + # store loss function separately as it is a module + assert isinstance( + loss, LightningMetric + ), "Loss has to be a PyTorch Lightning `Metric`" + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + # processing inputs + # embeddings + self.input_embeddings = MultiEmbedding( + embedding_sizes=self.hparams.embedding_sizes, + categorical_groups=self.hparams.categorical_groups, + embedding_paddings=self.hparams.embedding_paddings, + x_categoricals=self.hparams.x_categoricals, + max_embedding_size=self.hparams.hidden_size, + ) + + # continuous variable processing + self.prescalers = nn.ModuleDict( + { + name: nn.Linear( + 1, + self.hparams.hidden_continuous_sizes.get( + name, self.hparams.hidden_continuous_size + ), + ) + for name in self.reals + } + ) + + # variable selection + # variable selection for static variables + static_input_sizes = { + name: self.input_embeddings.output_size[name] + for name in self.hparams.static_categoricals + } + static_input_sizes.update( + { + name: self.hparams.hidden_continuous_sizes.get( + name, self.hparams.hidden_continuous_size + ) + for name in self.hparams.static_reals + } + ) + self.static_variable_selection = VariableSelectionNetwork( + input_sizes=static_input_sizes, + hidden_size=self.hparams.hidden_size, + input_embedding_flags={ + name: True for name in self.hparams.static_categoricals + }, + dropout=self.hparams.dropout, + 
prescalers=self.prescalers, + ) + + # variable selection for encoder and decoder + encoder_input_sizes = { + name: self.input_embeddings.output_size[name] + for name in self.hparams.time_varying_categoricals_encoder + } + encoder_input_sizes.update( + { + name: self.hparams.hidden_continuous_sizes.get( + name, self.hparams.hidden_continuous_size + ) + for name in self.hparams.time_varying_reals_encoder + } + ) + + decoder_input_sizes = { + name: self.input_embeddings.output_size[name] + for name in self.hparams.time_varying_categoricals_decoder + } + decoder_input_sizes.update( + { + name: self.hparams.hidden_continuous_sizes.get( + name, self.hparams.hidden_continuous_size + ) + for name in self.hparams.time_varying_reals_decoder + } + ) + + # create single variable grns that are shared across decoder and encoder + if self.hparams.share_single_variable_networks: + self.shared_single_variable_grns = nn.ModuleDict() + for name, input_size in encoder_input_sizes.items(): + self.shared_single_variable_grns[name] = GatedResidualNetwork( + input_size, + min(input_size, self.hparams.hidden_size), + self.hparams.hidden_size, + self.hparams.dropout, + ) + for name, input_size in decoder_input_sizes.items(): + if name not in self.shared_single_variable_grns: + self.shared_single_variable_grns[name] = GatedResidualNetwork( + input_size, + min(input_size, self.hparams.hidden_size), + self.hparams.hidden_size, + self.hparams.dropout, + ) + + self.encoder_variable_selection = VariableSelectionNetwork( + input_sizes=encoder_input_sizes, + hidden_size=self.hparams.hidden_size, + input_embedding_flags={ + name: True for name in self.hparams.time_varying_categoricals_encoder + }, + dropout=self.hparams.dropout, + context_size=self.hparams.hidden_size, + prescalers=self.prescalers, + single_variable_grns=( + {} + if not self.hparams.share_single_variable_networks + else self.shared_single_variable_grns + ), + ) + + self.decoder_variable_selection = VariableSelectionNetwork( + input_sizes=decoder_input_sizes, + hidden_size=self.hparams.hidden_size, + input_embedding_flags={ + name: True for name in self.hparams.time_varying_categoricals_decoder + }, + dropout=self.hparams.dropout, + context_size=self.hparams.hidden_size, + prescalers=self.prescalers, + single_variable_grns=( + {} + if not self.hparams.share_single_variable_networks + else self.shared_single_variable_grns + ), + ) + + # static encoders + # for variable selection + self.static_context_variable_selection = GatedResidualNetwork( + input_size=self.hparams.hidden_size, + hidden_size=self.hparams.hidden_size, + output_size=self.hparams.hidden_size, + dropout=self.hparams.dropout, + ) + + # for hidden state of the lstm + self.static_context_initial_hidden_lstm = GatedResidualNetwork( + input_size=self.hparams.hidden_size, + hidden_size=self.hparams.hidden_size, + output_size=self.hparams.hidden_size, + dropout=self.hparams.dropout, + ) + + # for cell state of the lstm + self.static_context_initial_cell_lstm = GatedResidualNetwork( + input_size=self.hparams.hidden_size, + hidden_size=self.hparams.hidden_size, + output_size=self.hparams.hidden_size, + dropout=self.hparams.dropout, + ) + + # for post lstm static enrichment + self.static_context_enrichment = GatedResidualNetwork( + self.hparams.hidden_size, + self.hparams.hidden_size, + self.hparams.hidden_size, + self.hparams.dropout, + ) + + # lstm encoder (history) and decoder (future) for local processing + self.lstm_encoder = LSTM( + input_size=self.hparams.hidden_size, + 
hidden_size=self.hparams.hidden_size, + num_layers=self.hparams.lstm_layers, + dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0, + batch_first=True, + ) + + self.lstm_decoder = LSTM( + input_size=self.hparams.hidden_size, + hidden_size=self.hparams.hidden_size, + num_layers=self.hparams.lstm_layers, + dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0, + batch_first=True, + ) + + # skip connection for lstm + self.post_lstm_gate_encoder = GatedLinearUnit( + self.hparams.hidden_size, dropout=self.hparams.dropout + ) + self.post_lstm_gate_decoder = self.post_lstm_gate_encoder + # self.post_lstm_gate_decoder = GatedLinearUnit( + # self.hparams.hidden_size, dropout=self.hparams.dropout) + self.post_lstm_add_norm_encoder = AddNorm( + self.hparams.hidden_size, trainable_add=False + ) + # self.post_lstm_add_norm_decoder = AddNorm( + # self.hparams.hidden_size, trainable_add=True) + self.post_lstm_add_norm_decoder = self.post_lstm_add_norm_encoder + + # static enrichment and processing past LSTM + self.static_enrichment = GatedResidualNetwork( + input_size=self.hparams.hidden_size, + hidden_size=self.hparams.hidden_size, + output_size=self.hparams.hidden_size, + dropout=self.hparams.dropout, + context_size=self.hparams.hidden_size, + ) + + # attention for long-range processing + self.multihead_attn = InterpretableMultiHeadAttention( + d_model=self.hparams.hidden_size, + n_head=self.hparams.attention_head_size, + dropout=self.hparams.dropout, + ) + self.post_attn_gate_norm = GateAddNorm( + self.hparams.hidden_size, dropout=self.hparams.dropout, trainable_add=False + ) + self.pos_wise_ff = GatedResidualNetwork( + self.hparams.hidden_size, + self.hparams.hidden_size, + self.hparams.hidden_size, + dropout=self.hparams.dropout, + ) + + # output processing -> no dropout at this late stage + self.pre_output_gate_norm = GateAddNorm( + self.hparams.hidden_size, dropout=None, trainable_add=False + ) + + if self.n_targets > 1: # if to run with multiple targets + self.output_layer = nn.ModuleList( + [ + nn.Linear(self.hparams.hidden_size, output_size) + for output_size in self.hparams.output_size + ] + ) + else: + self.output_layer = nn.Linear( + self.hparams.hidden_size, self.hparams.output_size + ) + + @classmethod + def from_dataset( + cls, + dataset: TimeSeriesDataSet, + allowed_encoder_known_variable_names: List[str] = None, + **kwargs, + ): + """ + Create model from dataset. + + Args: + dataset: timeseries dataset + allowed_encoder_known_variable_names: List of known variables that are allowed in encoder, defaults to all + **kwargs: additional arguments such as hyperparameters for model (see ``__init__()``) + + Returns: + TemporalFusionTransformer + """ # noqa: E501 + # add maximum encoder length + # update defaults + new_kwargs = copy(kwargs) + new_kwargs["max_encoder_length"] = dataset.max_encoder_length + new_kwargs.update( + cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss()) + ) + + # create class and return + return super().from_dataset( + dataset, + allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, + **new_kwargs, + ) + + def expand_static_context(self, context, timesteps): + """ + add time dimension to static context + """ + return context[:, None].expand(-1, timesteps, -1) + + def get_attention_mask( + self, encoder_lengths: torch.LongTensor, decoder_lengths: torch.LongTensor + ): + """ + Returns causal mask to apply for self-attention layer. 
+ """ + decoder_length = decoder_lengths.max() + if self.hparams.causal_attention: + # indices to which is attended + attend_step = torch.arange(decoder_length, device=self.device) + # indices for which is predicted + predict_step = torch.arange(0, decoder_length, device=self.device)[:, None] + # do not attend to steps to self or after prediction + decoder_mask = ( + (attend_step >= predict_step) + .unsqueeze(0) + .expand(encoder_lengths.size(0), -1, -1) + ) + else: + # there is value in attending to future forecasts if + # they are made with knowledge currently available + # one possibility is here to use a second attention layer + # for future attention + # (assuming different effects matter in the future than the past) + # or alternatively using the same layer but + # allowing forward attention - i.e. only + # masking out non-available data and self + decoder_mask = ( + create_mask(decoder_length, decoder_lengths) + .unsqueeze(1) + .expand(-1, decoder_length, -1) + ) + # do not attend to steps where data is padded + encoder_mask = ( + create_mask(encoder_lengths.max(), encoder_lengths) + .unsqueeze(1) + .expand(-1, decoder_length, -1) + ) + # combine masks along attended time - first encoder and then decoder + mask = torch.cat( + ( + encoder_mask, + decoder_mask, + ), + dim=2, + ) + return mask + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + input dimensions: n_samples x time x variables + """ + encoder_lengths = x["encoder_lengths"] + decoder_lengths = x["decoder_lengths"] + x_cat = torch.cat( + [x["encoder_cat"], x["decoder_cat"]], dim=1 + ) # concatenate in time dimension + + ############different from _tft########################### + encoder_cont = x["encoder_cont"] + decoder_cont = x["decoder_cont"] + if encoder_cont.shape[-1] != decoder_cont.shape[-1]: + max_features = max(encoder_cont.shape[-1], decoder_cont.shape[-1]) + + if encoder_cont.shape[-1] < max_features: + encoder_padding = torch.zeros( + encoder_cont.shape[0], + encoder_cont.shape[1], + max_features - encoder_cont.shape[-1], + ).to(encoder_cont.device) + encoder_cont = torch.cat([encoder_cont, encoder_padding], dim=-1) + + if decoder_cont.shape[-1] < max_features: + decoder_padding = torch.zeros( + decoder_cont.shape[0], + decoder_cont.shape[1], + max_features - decoder_cont.shape[-1], + ).to(decoder_cont.device) + decoder_cont = torch.cat([decoder_cont, decoder_padding], dim=-1) + x_cont = torch.cat( + [encoder_cont, decoder_cont], dim=1 + ) # concatenate in time dimension + ########## + + timesteps = x_cont.size(1) # encode + decode length + max_encoder_length = int(encoder_lengths.max()) + input_vectors = self.input_embeddings(x_cat) + + ############different from _tft########################### + available_reals = [name for name in self.hparams.x_reals if name in self.reals] + max_features = x_cont.shape[-1] + + input_vectors.update( + { + name: x_cont[..., idx].unsqueeze(-1) + for idx, name in enumerate(available_reals) + if idx < max_features + } + ) + + all_expected_vars = set(self.encoder_variables + self.decoder_variables) + real_vars_in_input = set(input_vectors.keys()) + missing_vars = all_expected_vars - real_vars_in_input + + for var_name in missing_vars: + if var_name.startswith(("encoder_real_", "decoder_real_")): + # Create zero tensor with same shape as other real variables + zero_tensor = torch.zeros_like(x_cont[..., 0].unsqueeze(-1)) + input_vectors[var_name] = zero_tensor + ############ + + # Embedding and variable selection + if len(self.static_variables) > 0: 
+ # static embeddings will be constant over entire batch + static_embedding = { + name: input_vectors[name][:, 0] for name in self.static_variables + } + static_embedding, static_variable_selection = ( + self.static_variable_selection(static_embedding) + ) + else: + static_embedding = torch.zeros( + (x_cont.size(0), self.hparams.hidden_size), + dtype=self.dtype, + device=self.device, + ) + static_variable_selection = torch.zeros( + (x_cont.size(0), 0), dtype=self.dtype, device=self.device + ) + + static_context_variable_selection = self.expand_static_context( + self.static_context_variable_selection(static_embedding), timesteps + ) + + embeddings_varying_encoder = { + name: input_vectors[name][:, :max_encoder_length] + for name in self.encoder_variables + } + embeddings_varying_encoder, encoder_sparse_weights = ( + self.encoder_variable_selection( + embeddings_varying_encoder, + static_context_variable_selection[:, :max_encoder_length], + ) + ) + + embeddings_varying_decoder = { + name: input_vectors[name][:, max_encoder_length:] + for name in self.decoder_variables # select decoder + } + embeddings_varying_decoder, decoder_sparse_weights = ( + self.decoder_variable_selection( + embeddings_varying_decoder, + static_context_variable_selection[:, max_encoder_length:], + ) + ) + + # LSTM + # calculate initial state + input_hidden = self.static_context_initial_hidden_lstm(static_embedding).expand( + self.hparams.lstm_layers, -1, -1 + ) + input_cell = self.static_context_initial_cell_lstm(static_embedding).expand( + self.hparams.lstm_layers, -1, -1 + ) + + # run local encoder + encoder_output, (hidden, cell) = self.lstm_encoder( + embeddings_varying_encoder, + (input_hidden, input_cell), + lengths=encoder_lengths, + enforce_sorted=False, + ) + + # run local decoder + decoder_output, _ = self.lstm_decoder( + embeddings_varying_decoder, + (hidden, cell), + lengths=decoder_lengths, + enforce_sorted=False, + ) + + # skip connection over lstm + lstm_output_encoder = self.post_lstm_gate_encoder(encoder_output) + lstm_output_encoder = self.post_lstm_add_norm_encoder( + lstm_output_encoder, embeddings_varying_encoder + ) + + lstm_output_decoder = self.post_lstm_gate_decoder(decoder_output) + lstm_output_decoder = self.post_lstm_add_norm_decoder( + lstm_output_decoder, embeddings_varying_decoder + ) + + lstm_output = torch.cat([lstm_output_encoder, lstm_output_decoder], dim=1) + + # static enrichment + static_context_enrichment = self.static_context_enrichment(static_embedding) + attn_input = self.static_enrichment( + lstm_output, + self.expand_static_context(static_context_enrichment, timesteps), + ) + + # Attention + attn_output, attn_output_weights = self.multihead_attn( + q=attn_input[:, max_encoder_length:], # query only for predictions + k=attn_input, + v=attn_input, + mask=self.get_attention_mask( + encoder_lengths=encoder_lengths, decoder_lengths=decoder_lengths + ), + ) + + # skip connection over attention + attn_output = self.post_attn_gate_norm( + attn_output, attn_input[:, max_encoder_length:] + ) + + output = self.pos_wise_ff(attn_output) + + # skip connection over temporal fusion decoder (not LSTM decoder + # despite the LSTM output contains + # a skip from the variable selection network) + output = self.pre_output_gate_norm(output, lstm_output[:, max_encoder_length:]) + if self.n_targets > 1: # if to use multi-target architecture + output = [output_layer(output) for output_layer in self.output_layer] + else: + output = self.output_layer(output) + + return self.to_network_output( + 
prediction=self.transform_output(output, target_scale=x["target_scale"]), + encoder_attention=attn_output_weights[..., :max_encoder_length], + decoder_attention=attn_output_weights[..., max_encoder_length:], + static_variables=static_variable_selection, + encoder_variables=encoder_sparse_weights, + decoder_variables=decoder_sparse_weights, + decoder_lengths=decoder_lengths, + encoder_lengths=encoder_lengths, + ) + + def on_fit_end(self): + if self.log_interval > 0: + self.log_embeddings() + + def create_log(self, x, y, out, batch_idx, **kwargs): + log = super().create_log(x, y, out, batch_idx, **kwargs) + if self.log_interval > 0: + log["interpretation"] = self._log_interpretation(out) + return log + + def _log_interpretation(self, out): + # calculate interpretations etc for latter logging + interpretation = self.interpret_output( + detach(out), + reduction="sum", + attention_prediction_horizon=0, # attention only for first prediction horizon # noqa: E501 + ) + return interpretation + + def on_epoch_end(self, outputs): + """ + run at epoch end for training or validation + """ + if self.log_interval > 0 and not self.training: + self.log_interpretation(outputs) + + def interpret_output( + self, + out: Dict[str, torch.Tensor], + reduction: str = "none", + attention_prediction_horizon: int = 0, + ) -> Dict[str, torch.Tensor]: + """ + interpret output of model + + Args: + out: output as produced by ``forward()`` + reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for + normalizing by encode lengths + attention_prediction_horizon: which prediction horizon to use for attention + + Returns: + interpretations that can be plotted with ``plot_interpretation()`` + """ # noqa: E501 + # take attention and concatenate if a list to proper attention object + batch_size = len(out["decoder_attention"]) + if isinstance(out["decoder_attention"], (list, tuple)): + # start with decoder attention + # assume issue is in last dimension, we need to find max + max_last_dimension = max(x.size(-1) for x in out["decoder_attention"]) + first_elm = out["decoder_attention"][0] + # create new attention tensor into which we will scatter + decoder_attention = torch.full( + (batch_size, *first_elm.shape[:-1], max_last_dimension), + float("nan"), + dtype=first_elm.dtype, + device=first_elm.device, + ) + # scatter into tensor + for idx, x in enumerate(out["decoder_attention"]): + decoder_length = out["decoder_lengths"][idx] + decoder_attention[idx, :, :, :decoder_length] = x[..., :decoder_length] + else: + decoder_attention = out["decoder_attention"].clone() + decoder_mask = create_mask( + out["decoder_attention"].size(1), out["decoder_lengths"] + ) + decoder_attention[ + decoder_mask[..., None, None].expand_as(decoder_attention) + ] = float("nan") + + if isinstance(out["encoder_attention"], (list, tuple)): + # same game for encoder attention + # create new attention tensor into which we will scatter + first_elm = out["encoder_attention"][0] + encoder_attention = torch.full( + (batch_size, *first_elm.shape[:-1], self.hparams.max_encoder_length), + float("nan"), + dtype=first_elm.dtype, + device=first_elm.device, + ) + # scatter into tensor + for idx, x in enumerate(out["encoder_attention"]): + encoder_length = out["encoder_lengths"][idx] + encoder_attention[ + idx, :, :, self.hparams.max_encoder_length - encoder_length : + ] = x[..., :encoder_length] + else: + # roll encoder attention (so start last encoder value is on the right) + encoder_attention = out["encoder_attention"].clone() + shifts 
= encoder_attention.size(3) - out["encoder_lengths"] + new_index = ( + torch.arange( + encoder_attention.size(3), device=encoder_attention.device + )[None, None, None].expand_as(encoder_attention) + - shifts[:, None, None, None] + ) % encoder_attention.size(3) + encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index) + # expand encoder_attention to full size + if encoder_attention.size(-1) < self.hparams.max_encoder_length: + encoder_attention = torch.concat( + [ + torch.full( + ( + *encoder_attention.shape[:-1], + self.hparams.max_encoder_length + - out["encoder_lengths"].max(), + ), + float("nan"), + dtype=encoder_attention.dtype, + device=encoder_attention.device, + ), + encoder_attention, + ], + dim=-1, + ) + + # combine attention vector + attention = torch.concat([encoder_attention, decoder_attention], dim=-1) + attention[attention < 1e-5] = float("nan") + + # histogram of decode and encode lengths + encoder_length_histogram = integer_histogram( + out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length + ) + decoder_length_histogram = integer_histogram( + out["decoder_lengths"], min=1, max=out["decoder_variables"].size(1) + ) + + # mask where decoder and encoder where not applied + # when averaging variable selection weights + encoder_variables = out["encoder_variables"].squeeze(-2).clone() + encode_mask = create_mask(encoder_variables.size(1), out["encoder_lengths"]) + encoder_variables = encoder_variables.masked_fill( + encode_mask.unsqueeze(-1), 0.0 + ).sum(dim=1) + encoder_variables /= ( + out["encoder_lengths"] + .where(out["encoder_lengths"] > 0, torch.ones_like(out["encoder_lengths"])) + .unsqueeze(-1) + ) + + decoder_variables = out["decoder_variables"].squeeze(-2).clone() + decode_mask = create_mask(decoder_variables.size(1), out["decoder_lengths"]) + decoder_variables = decoder_variables.masked_fill( + decode_mask.unsqueeze(-1), 0.0 + ).sum(dim=1) + decoder_variables /= out["decoder_lengths"].unsqueeze(-1) + + # static variables need no masking + static_variables = out["static_variables"].squeeze(1) + # attention is batch x time x heads x time_to_attend + # average over heads + only keep prediction attention and + # attention on observed timesteps + attention = masked_op( + attention[ + :, + attention_prediction_horizon, + :, + : self.hparams.max_encoder_length + attention_prediction_horizon, + ], + op="mean", + dim=1, + ) + + if reduction != "none": # if to average over batches + static_variables = static_variables.sum(dim=0) + encoder_variables = encoder_variables.sum(dim=0) + decoder_variables = decoder_variables.sum(dim=0) + + attention = masked_op(attention, dim=0, op=reduction) + else: + attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze( + -1 + ) # renormalize + + interpretation = dict( + attention=attention.masked_fill(torch.isnan(attention), 0.0), + static_variables=static_variables, + encoder_variables=encoder_variables, + decoder_variables=decoder_variables, + encoder_length_histogram=encoder_length_histogram, + decoder_length_histogram=decoder_length_histogram, + ) + return interpretation + + def plot_prediction( + self, + x: Dict[str, torch.Tensor], + out: Dict[str, torch.Tensor], + idx: int, + plot_attention: bool = True, + add_loss_to_title: bool = False, + show_future_observed: bool = True, + ax=None, + **kwargs, + ): + """ + Plot actuals vs prediction and attention + + Args: + x (Dict[str, torch.Tensor]): network input + out (Dict[str, torch.Tensor]): network output + idx (int): sample index + 
plot_attention: if to plot attention on secondary axis + add_loss_to_title: if to add loss to title. Default to False. + show_future_observed: if to show actuals for future. Defaults to True. + ax: matplotlib axes to plot on + + Returns: + plt.Figure: matplotlib figure + """ + # plot prediction as normal + fig = super().plot_prediction( + x, + out, + idx=idx, + add_loss_to_title=add_loss_to_title, + show_future_observed=show_future_observed, + ax=ax, + **kwargs, + ) + + # add attention on secondary axis + if plot_attention: + interpretation = self.interpret_output(out.iget(slice(idx, idx + 1))) + for f in to_list(fig): + ax = f.axes[0] + ax2 = ax.twinx() + ax2.set_ylabel("Attention") + encoder_length = x["encoder_lengths"][0] + ax2.plot( + torch.arange(-encoder_length, 0), + interpretation["attention"][0, -encoder_length:].detach().cpu(), + alpha=0.2, + color="k", + ) + f.tight_layout() + return fig + + def plot_interpretation(self, interpretation: Dict[str, torch.Tensor]): + """ + Make figures that interpret model. + + * Attention + * Variable selection weights / importances + + Args: + interpretation: as obtained from ``interpret_output()`` + + Returns: + dictionary of matplotlib figures + """ + _check_matplotlib("plot_interpretation") + + import matplotlib.pyplot as plt + + figs = {} + + # attention + fig, ax = plt.subplots() + attention = interpretation["attention"].detach().cpu() + attention = attention / attention.sum(-1).unsqueeze(-1) + ax.plot( + np.arange( + -self.hparams.max_encoder_length, + attention.size(0) - self.hparams.max_encoder_length, + ), + attention, + ) + ax.set_xlabel("Time index") + ax.set_ylabel("Attention") + ax.set_title("Attention") + figs["attention"] = fig + + # variable selection + def make_selection_plot(title, values, labels): + fig, ax = plt.subplots(figsize=(7, len(values) * 0.25 + 2)) + order = np.argsort(values) + values = values / values.sum(-1).unsqueeze(-1) + ax.barh( + np.arange(len(values)), + values[order] * 100, + tick_label=np.asarray(labels)[order], + ) + ax.set_title(title) + ax.set_xlabel("Importance in %") + plt.tight_layout() + return fig + + figs["static_variables"] = make_selection_plot( + "Static variables importance", + interpretation["static_variables"].detach().cpu(), + self.static_variables, + ) + figs["encoder_variables"] = make_selection_plot( + "Encoder variables importance", + interpretation["encoder_variables"].detach().cpu(), + self.encoder_variables, + ) + figs["decoder_variables"] = make_selection_plot( + "Decoder variables importance", + interpretation["decoder_variables"].detach().cpu(), + self.decoder_variables, + ) + + return figs + + def log_interpretation(self, outputs): + """ + Log interpretation metrics to tensorboard. + """ + # extract interpretations + interpretation = { + # use padded_stack because decoder + # length histogram can be of different length + name: padded_stack( + [x["interpretation"][name].detach() for x in outputs], + side="right", + value=0, + ).sum(0) + for name in outputs[0]["interpretation"].keys() + } + # normalize attention with length histogram squared to account for: + # 1. zeros in attention and + # 2. 
higher attention due to less values + attention_occurances = ( + interpretation["encoder_length_histogram"][1:].flip(0).float().cumsum(0) + ) + attention_occurances = attention_occurances / attention_occurances.max() + attention_occurances = torch.cat( + [ + attention_occurances, + torch.ones( + interpretation["attention"].size(0) - attention_occurances.size(0), + dtype=attention_occurances.dtype, + device=attention_occurances.device, + ), + ], + dim=0, + ) + interpretation["attention"] = interpretation[ + "attention" + ] / attention_occurances.pow(2).clamp(1.0) + interpretation["attention"] = ( + interpretation["attention"] / interpretation["attention"].sum() + ) + + mpl_available = _check_matplotlib("log_interpretation", raise_error=False) + + # Don't log figures if matplotlib or add_figure is not available + if not mpl_available or not self._logger_supports("add_figure"): + return None + + import matplotlib.pyplot as plt + + figs = self.plot_interpretation(interpretation) # make interpretation figures + label = self.current_stage + # log to tensorboard + for name, fig in figs.items(): + self.logger.experiment.add_figure( + f"{label.capitalize()} {name} importance", + fig, + global_step=self.global_step, + ) + + # log lengths of encoder/decoder + for type in ["encoder", "decoder"]: + fig, ax = plt.subplots() + lengths = ( + padded_stack( + [ + out["interpretation"][f"{type}_length_histogram"] + for out in outputs + ] + ) + .sum(0) + .detach() + .cpu() + ) + if type == "decoder": + start = 1 + else: + start = 0 + ax.plot(torch.arange(start, start + len(lengths)), lengths) + ax.set_xlabel(f"{type.capitalize()} length") + ax.set_ylabel("Number of samples") + ax.set_title(f"{type.capitalize()} length distribution in {label} epoch") + + self.logger.experiment.add_figure( + f"{label.capitalize()} {type} length distribution", + fig, + global_step=self.global_step, + ) + + def log_embeddings(self): + """ + Log embeddings to tensorboard + """ + + # Don't log embeddings if add_embedding is not available + if not self._logger_supports("add_embedding"): + return None + + for name, emb in self.input_embeddings.items(): + labels = self.hparams.embedding_labels[name] + self.logger.experiment.add_embedding( + emb.weight.data.detach().cpu(), + metadata=labels, + tag=name, + global_step=self.global_step, + ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py new file mode 100644 index 000000000..5ebea75b1 --- /dev/null +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py @@ -0,0 +1,90 @@ +"""TFT metadata container.""" + +from pytorch_forecasting.models.base._base_object import _BasePtForecaster + + +class TFTMetadata(_BasePtForecaster): + """TFT metadata container.""" + + _tags = { + "info:name": "TFT", + "object_type": "ptf-v2", + "authors": ["phoeenniixx"], + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": False, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT + + return TFT + + @classmethod + def get_test_train_params(cls): + """Return testing parameter settings for the trainer. 
+ + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params` + """ + return [ + {}, + dict( + hidden_size=25, + attention_head_size=5, + ), + ] + + +class TemporalFusionTransformerMetadata(_BasePtForecaster): + """TFT metadata container.""" + + _tags = { + "info:name": "TemporalFusionTransformerM", + "object_type": "ptf-v2", + "authors": ["jdb78"], + "capability:exogenous": True, + "capability:multivariate": True, + "capability:pred_int": True, + "capability:flexible_history_length": False, + } + + @classmethod + def get_model_cls(cls): + """Get model class.""" + from pytorch_forecasting.models.temporal_fusion_transformer._tft_ver2 import ( + TemporalFusionTransformer, + ) + + return TemporalFusionTransformer + + @classmethod + def get_test_train_params(cls): + """Return testing parameter settings for the trainer. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params` + """ + return [ + {}, + dict( + hidden_size=25, + attention_head_size=5, + data_loader_kwargs={ + "add_relative_time_idx": False, + }, + ), + ] diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index e276446a6..76d280f61 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -1,10 +1,15 @@ +from datetime import datetime, timedelta + import numpy as np +import pandas as pd import pytest import torch from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data +from pytorch_forecasting.data.timeseries import TimeSeries torch.manual_seed(23) @@ -88,6 +93,233 @@ def make_dataloaders(data_with_covariates, **kwargs): return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) +@pytest.fixture(scope="session") +def data_with_covariates_v2(): + """Create synthetic time series data with all numerical features.""" + + start_date = datetime(2015, 1, 1) + end_date = datetime(2017, 12, 31) + dates = pd.date_range(start_date, end_date, freq="M") + + agencies = [0, 1] + skus = [0, 1] + data_list = [] + + for agency in agencies: + for sku in skus: + for date in dates: + time_idx = (date.year - 2015) * 12 + date.month - 1 + + volume = ( + np.random.exponential(2) + + 0.1 * time_idx + + 0.5 * np.sin(date.month * np.pi / 6) + ) + volume = max(0.001, volume) + month = date.month + year = date.year + quarter = (date.month - 1) // 3 + 1 + + seasonal_1 = np.sin(2 * np.pi * date.month / 12) + seasonal_2 = np.cos(2 * np.pi * date.month / 12) + + agency_feature_1 = agency * 10 + np.random.normal(0, 0.1) + agency_feature_2 = agency * 5 + np.random.normal(0, 0.1) + + sku_feature_1 = sku * 8 + np.random.normal(0, 0.1) + sku_feature_2 = sku * 3 + np.random.normal(0, 0.1) + + trend = time_idx * 0.1 + noise = np.random.normal(0, 0.1) + 
+ special_event_1 = 1 if date.month in [12, 1] else 0 + special_event_2 = 1 if date.month in [6, 7, 8] else 0 + + data_list.append( + { + "date": date, + "time_idx": time_idx, + "agency_encoded": agency, + "sku_encoded": sku, + "volume": volume, + "target": volume, + "weight": 1.0 + np.sqrt(volume), + "month": month, + "year": year, + "quarter": quarter, + "seasonal_1": seasonal_1, + "seasonal_2": seasonal_2, + "agency_feature_1": agency_feature_1, + "agency_feature_2": agency_feature_2, + "sku_feature_1": sku_feature_1, + "sku_feature_2": sku_feature_2, + "trend": trend, + "noise": noise, + "special_event_1": special_event_1, + "special_event_2": special_event_2, + "log_volume": np.log1p(volume), + } + ) + + data = pd.DataFrame(data_list) + + numeric_cols = [col for col in data.columns if col != "date"] + for col in numeric_cols: + data[col] = pd.to_numeric(data[col], errors="coerce") + data[numeric_cols] = data[numeric_cols].fillna(0) + + return data + + +def make_dataloaders_v2(data_with_covariates, **kwargs): + """Create dataloaders with consistent encoder/decoder features.""" + + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + target_col = kwargs.get("target", "target") + group_cols = kwargs.get("group_ids", ["agency_encoded", "sku_encoded"]) + add_relative_time_idx = kwargs.get("add_relative_time_idx", True) + + known_features = [ + "month", + "year", + "quarter", + "seasonal_1", + "seasonal_2", + "special_event_1", + "special_event_2", + "trend", + ] + unknown_features = [ + "agency_feature_1", + "agency_feature_2", + "sku_feature_1", + "sku_feature_2", + "noise", + "log_volume", + ] + + numerical_features = known_features + unknown_features + categorical_features = [] + static_features = group_cols + + for col in numerical_features + categorical_features + group_cols + [target_col]: + if col in data_with_covariates.columns: + data_with_covariates[col] = pd.to_numeric( + data_with_covariates[col], errors="coerce" + ).fillna(0) + + for col in categorical_features + group_cols: + if col in data_with_covariates.columns: + data_with_covariates[col] = data_with_covariates[col].astype("int64") + + if "weight" in data_with_covariates.columns: + data_with_covariates["weight"] = pd.to_numeric( + data_with_covariates["weight"], errors="coerce" + ).fillna(1.0) + + training_data = data_with_covariates[ + data_with_covariates.date < training_cutoff + ].copy() + validation_data = data_with_covariates.copy() + + required_columns = ( + ["time_idx", target_col, "weight", "date"] + + group_cols + + numerical_features + + categorical_features + ) + + available_columns = [ + col for col in required_columns if col in data_with_covariates.columns + ] + + training_data_clean = training_data[available_columns].copy() + validation_data_clean = validation_data[available_columns].copy() + + if "date" in training_data_clean.columns: + training_data_clean = training_data_clean.drop("date", axis=1) + if "date" in validation_data_clean.columns: + validation_data_clean = validation_data_clean.drop("date", axis=1) + + training_dataset = TimeSeries( + data=training_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + validation_dataset = TimeSeries( + data=validation_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + 
num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + training_max_time_idx = training_data["time_idx"].max() + 1 + + train_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=training_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.8, 0.2, 0.0), + ) + + val_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.0, 1.0, 0.0), + ) + + test_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=1, + num_workers=0, + train_val_test_split=(0.0, 0.0, 1.0), + ) + + train_datamodule.setup("fit") + val_datamodule.setup("fit") + test_datamodule.setup("test") + + train_dataloader = train_datamodule.train_dataloader() + val_dataloader = val_datamodule.val_dataloader() + test_dataloader = test_datamodule.test_dataloader() + + return { + "train": train_dataloader, + "val": val_dataloader, + "test": test_dataloader, + "data_module": train_datamodule, + } + + @pytest.fixture( params=[ dict(), diff --git a/pytorch_forecasting/tests/_data_scenarios.py b/pytorch_forecasting/tests/_data_scenarios.py index 062db97dd..b3037b3f6 100644 --- a/pytorch_forecasting/tests/_data_scenarios.py +++ b/pytorch_forecasting/tests/_data_scenarios.py @@ -1,10 +1,15 @@ +from datetime import datetime, timedelta + import numpy as np +import pandas as pd import pytest import torch from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data +from pytorch_forecasting.data.timeseries import TimeSeries torch.manual_seed(23) @@ -87,6 +92,232 @@ def make_dataloaders(data_with_covariates, **kwargs): return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) +def data_with_covariates_v2(): + """Create synthetic time series data with all numerical features.""" + + start_date = datetime(2015, 1, 1) + end_date = datetime(2017, 12, 31) + dates = pd.date_range(start_date, end_date, freq="M") + + agencies = [0, 1] + skus = [0, 1] + data_list = [] + + for agency in agencies: + for sku in skus: + for date in dates: + time_idx = (date.year - 2015) * 12 + date.month - 1 + + volume = ( + np.random.exponential(2) + + 0.1 * time_idx + + 0.5 * np.sin(date.month * np.pi / 6) + ) + volume = max(0.001, volume) + month = date.month + year = date.year + quarter = (date.month - 1) // 3 + 1 + + seasonal_1 = np.sin(2 * np.pi * date.month / 12) + seasonal_2 = np.cos(2 * np.pi * date.month / 12) + + agency_feature_1 = agency * 10 + np.random.normal(0, 0.1) + agency_feature_2 = agency * 5 + np.random.normal(0, 0.1) + + sku_feature_1 = sku * 8 + np.random.normal(0, 0.1) + sku_feature_2 = sku * 3 + np.random.normal(0, 0.1) + 
+ trend = time_idx * 0.1 + noise = np.random.normal(0, 0.1) + + special_event_1 = 1 if date.month in [12, 1] else 0 + special_event_2 = 1 if date.month in [6, 7, 8] else 0 + + data_list.append( + { + "date": date, + "time_idx": time_idx, + "agency_encoded": agency, + "sku_encoded": sku, + "volume": volume, + "target": volume, + "weight": 1.0 + np.sqrt(volume), + "month": month, + "year": year, + "quarter": quarter, + "seasonal_1": seasonal_1, + "seasonal_2": seasonal_2, + "agency_feature_1": agency_feature_1, + "agency_feature_2": agency_feature_2, + "sku_feature_1": sku_feature_1, + "sku_feature_2": sku_feature_2, + "trend": trend, + "noise": noise, + "special_event_1": special_event_1, + "special_event_2": special_event_2, + "log_volume": np.log1p(volume), + } + ) + + data = pd.DataFrame(data_list) + + numeric_cols = [col for col in data.columns if col != "date"] + for col in numeric_cols: + data[col] = pd.to_numeric(data[col], errors="coerce") + data[numeric_cols] = data[numeric_cols].fillna(0) + + return data + + +def make_dataloaders_v2(data_with_covariates, **kwargs): + """Create dataloaders with consistent encoder/decoder features.""" + + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + target_col = kwargs.get("target", "target") + group_cols = kwargs.get("group_ids", ["agency_encoded", "sku_encoded"]) + add_relative_time_idx = kwargs.get("add_relative_time_idx", True) + + known_features = [ + "month", + "year", + "quarter", + "seasonal_1", + "seasonal_2", + "special_event_1", + "special_event_2", + "trend", + ] + unknown_features = [ + "agency_feature_1", + "agency_feature_2", + "sku_feature_1", + "sku_feature_2", + "noise", + "log_volume", + ] + + numerical_features = known_features + unknown_features + categorical_features = [] + static_features = group_cols + + for col in numerical_features + categorical_features + group_cols + [target_col]: + if col in data_with_covariates.columns: + data_with_covariates[col] = pd.to_numeric( + data_with_covariates[col], errors="coerce" + ).fillna(0) + + for col in categorical_features + group_cols: + if col in data_with_covariates.columns: + data_with_covariates[col] = data_with_covariates[col].astype("int64") + + if "weight" in data_with_covariates.columns: + data_with_covariates["weight"] = pd.to_numeric( + data_with_covariates["weight"], errors="coerce" + ).fillna(1.0) + + training_data = data_with_covariates[ + data_with_covariates.date < training_cutoff + ].copy() + validation_data = data_with_covariates.copy() + + required_columns = ( + ["time_idx", target_col, "weight", "date"] + + group_cols + + numerical_features + + categorical_features + ) + + available_columns = [ + col for col in required_columns if col in data_with_covariates.columns + ] + + training_data_clean = training_data[available_columns].copy() + validation_data_clean = validation_data[available_columns].copy() + + if "date" in training_data_clean.columns: + training_data_clean = training_data_clean.drop("date", axis=1) + if "date" in validation_data_clean.columns: + validation_data_clean = validation_data_clean.drop("date", axis=1) + + training_dataset = TimeSeries( + data=training_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + validation_dataset = TimeSeries( + data=validation_data_clean, + time="time_idx", + 
target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + training_max_time_idx = training_data["time_idx"].max() + 1 + + train_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=training_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.8, 0.2, 0.0), + ) + + val_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.0, 1.0, 0.0), + ) + + test_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=1, + num_workers=0, + train_val_test_split=(0.0, 0.0, 1.0), + ) + + train_datamodule.setup("fit") + val_datamodule.setup("fit") + test_datamodule.setup("test") + + train_dataloader = train_datamodule.train_dataloader() + val_dataloader = val_datamodule.val_dataloader() + test_dataloader = test_datamodule.test_dataloader() + + return { + "train": train_dataloader, + "val": val_dataloader, + "test": test_dataloader, + "data_module": train_datamodule, + } + + @pytest.fixture( params=[ dict(), diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index 1dc7859ab..efbe170f0 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -7,7 +7,8 @@ from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger -from pytorch_forecasting.tests._conftest import make_dataloaders +from pytorch_forecasting.metrics import SMAPE +from pytorch_forecasting.tests._conftest import make_dataloaders_v2 as make_dataloaders from pytorch_forecasting.tests.test_all_estimators import ( BaseFixtureGenerator, PackageConfig, @@ -22,18 +23,86 @@ def _integration( estimator_cls, data_with_covariates, tmp_path, - cell_type="LSTM", data_loader_kwargs={}, clip_target: bool = False, trainer_kwargs=None, **kwargs, ): - pass + data_with_covariates = data_with_covariates.copy() + if clip_target: + data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0) + else: + data_with_covariates["target"] = data_with_covariates["volume"] + + data_loader_default_kwargs = dict( + target="target", + group_ids=["agency_encoded", "sku_encoded"], + add_relative_time_idx=True, + ) + data_loader_default_kwargs.update(data_loader_kwargs) + + dataloaders_with_covariates = make_dataloaders( + data_with_covariates, **data_loader_default_kwargs + ) + + train_dataloader = dataloaders_with_covariates["train"] + val_dataloader = dataloaders_with_covariates["val"] + test_dataloader = dataloaders_with_covariates["test"] + + early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" + ) + + logger = TensorBoardLogger(tmp_path) + if trainer_kwargs is None: + trainer_kwargs = {} + 
trainer = pl.Trainer( + max_epochs=3, + gradient_clip_val=0.1, + callbacks=[early_stop_callback], + enable_checkpointing=True, + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + logger=logger, + **trainer_kwargs, + ) + training_data_module = dataloaders_with_covariates["data_module"] + metadata = training_data_module.metadata + + net = estimator_cls( + metadata=metadata, + loss=SMAPE(), + **kwargs, + ) + + try: + trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ) + test_outputs = trainer.test(net, dataloaders=test_dataloader) + assert len(test_outputs) > 0 + + # check loading + # net = estimator_cls.load_from_checkpoint( + # trainer.checkpoint_callback.best_model_path + # ) + # net.predict(val_dataloader) + + finally: + shutil.rmtree(tmp_path, ignore_errors=True) + + # net.predict(val_dataloader) class TestAllPtForecastersV2(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" + object_type_filter = "ptf-v2" + def test_doctest_examples(self, object_class): """Runs doctests for estimator class.""" import doctest @@ -46,4 +115,8 @@ def test_integration( trainer_kwargs, tmp_path, ): - pass + from pytorch_forecasting.tests._data_scenarios import data_with_covariates_v2 + + data_with_covariates = data_with_covariates_v2() + object_class = object_metadata.get_model_cls() + _integration(object_class, data_with_covariates, tmp_path, **trainer_kwargs) From 1d478d5b13e68efe1b75f2e6bc17b2500bceaacb Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 28 May 2025 01:05:19 +0530 Subject: [PATCH 61/80] remove refactored TFT --- .../temporal_fusion_transformer/_tft_ver2.py | 1130 ----------------- .../tft_v2_metadata.py | 46 - .../tests/test_all_estimators.py | 15 + 3 files changed, 15 insertions(+), 1176 deletions(-) delete mode 100644 pytorch_forecasting/models/temporal_fusion_transformer/_tft_ver2.py diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_ver2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_ver2.py deleted file mode 100644 index cff63385c..000000000 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_ver2.py +++ /dev/null @@ -1,1130 +0,0 @@ -""" -The temporal fusion transformer is a powerful predictive model for forecasting timeseries -""" # noqa: E501 - -from copy import copy -from typing import Any, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from torch import nn -from torchmetrics import Metric as LightningMetric - -from pytorch_forecasting.data import TimeSeriesDataSet -from pytorch_forecasting.metrics import ( - MAE, - MAPE, - RMSE, - SMAPE, - MultiHorizonMetric, - QuantileLoss, -) -from pytorch_forecasting.models.base import BaseModelWithCovariates -from pytorch_forecasting.models.nn import LSTM, MultiEmbedding -from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import ( - AddNorm, - GateAddNorm, - GatedLinearUnit, - GatedResidualNetwork, - InterpretableMultiHeadAttention, - VariableSelectionNetwork, -) -from pytorch_forecasting.utils import ( - create_mask, - detach, - integer_histogram, - masked_op, - padded_stack, - to_list, -) -from pytorch_forecasting.utils._dependencies import _check_matplotlib - - -class TemporalFusionTransformer(BaseModelWithCovariates): - def __init__( - self, - metadata: Dict[str, Any], - hidden_size: int = 16, - lstm_layers: int = 1, - dropout: float = 0.1, - output_size: Union[int, List[int]] = None, - loss: 
MultiHorizonMetric = None, - attention_head_size: int = 4, - categorical_groups: Optional[Union[Dict, List[str]]] = None, - hidden_continuous_size: int = 8, - hidden_continuous_sizes: Optional[Dict[str, int]] = None, - embedding_sizes: Optional[Dict[str, Tuple[int, int]]] = None, - embedding_paddings: Optional[List[str]] = None, - embedding_labels: Optional[Dict[str, np.ndarray]] = None, - learning_rate: float = 1e-3, - log_interval: Union[int, float] = -1, - log_val_interval: Union[int, float] = None, - log_gradient_flow: bool = False, - reduce_on_plateau_patience: int = 1000, - monotone_constaints: Optional[Dict[str, int]] = None, - share_single_variable_networks: bool = False, - causal_attention: bool = True, - logging_metrics: nn.ModuleList = None, - **kwargs, - ): - """ - Temporal Fusion Transformer for forecasting timeseries - use its :py:meth:`~from_dataset` method if possible. - - Implementation of the article - `Temporal Fusion Transformers for Interpretable Multi-horizon Time Series - Forecasting `_. The network outperforms DeepAR by Amazon by 36-69% - in benchmarks. - - Enhancements compared to the original implementation (apart from capabilities added through base model - such as monotone constraints): - - * static variables can be continuous - * multiple categorical variables can be summarized with an EmbeddingBag - * variable encoder and decoder length by sample - * categorical embeddings are not transformed by variable selection network (because it is a redundant operation) - * variable dimension in variable selection network are scaled up via linear interpolation to reduce - number of parameters - * non-linear variable processing in variable selection network can be shared among decoder and encoder - (not shared by default) - - Tune its hyperparameters with - :py:func:`~pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters`. - - Args: - - hidden_size: hidden size of network which is its main hyperparameter and can range from 8 to 512 - lstm_layers: number of LSTM layers (2 is mostly optimal) - dropout: dropout rate - output_size: number of outputs (e.g. number of quantiles for QuantileLoss and one target or list - of output sizes). 
- loss: loss function taking prediction and targets - attention_head_size: number of attention heads (4 is a good default) - max_encoder_length: length to encode (can be far longer than the decoder length but does not have to be) - static_categoricals: names of static categorical variables - static_reals: names of static continuous variables - time_varying_categoricals_encoder: names of categorical variables for encoder - time_varying_categoricals_decoder: names of categorical variables for decoder - time_varying_reals_encoder: names of continuous variables for encoder - time_varying_reals_decoder: names of continuous variables for decoder - categorical_groups: dictionary where values - are list of categorical variables that are forming together a new categorical - variable which is the key in the dictionary - x_reals: order of continuous variables in tensor passed to forward function - x_categoricals: order of categorical variables in tensor passed to forward function - hidden_continuous_size: default for hidden size for processing continous variables (similar to categorical - embedding size) - hidden_continuous_sizes: dictionary mapping continuous input indices to sizes for variable selection - (fallback to hidden_continuous_size if index is not in dictionary) - embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and - embedding size - embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector - embedding_labels: dictionary mapping (string) indices to list of categorical labels - learning_rate: learning rate - log_interval: log predictions every x batches, do not log if 0 or less, log interpretation if > 0. If < 1.0 - , will log multiple entries per batch. Defaults to -1. - log_val_interval: frequency with which to log validation set metrics, defaults to log_interval - log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training - failures - reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10 - monotone_constaints (Dict[str, int]): dictionary of monotonicity constraints for continuous decoder - variables mapping - position (e.g. ``"0"`` for first position) to constraint (``-1`` for negative and ``+1`` for positive, - larger numbers add more weight to the constraint vs. the loss but are usually not necessary). - This constraint significantly slows down training. Defaults to {}. - share_single_variable_networks (bool): if to share the single variable networks between the encoder and - decoder. Defaults to False. - causal_attention (bool): If to attend only at previous timesteps in the decoder or also include future - predictions. Defaults to True. - logging_metrics (nn.ModuleList[LightningMetric]): list of metrics that are logged during training. - Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]). - **kwargs: additional arguments to :py:class:`~BaseModel`. 
- """ # noqa: E501 - - max_encoder_length = metadata["max_encoder_length"] - if output_size is None: - output_size = metadata["target"] - static_categoricals = [ - f"static_cat_{i}" - for i in range(metadata.get("static_categorical_features", 0)) - ] - static_reals = [ - f"static_real_{i}" - for i in range(metadata.get("static_continuous_features", 0)) - ] - time_varying_categoricals_encoder = [ - f"encoder_cat_{i}" for i in range(metadata["encoder_cat"]) - ] - time_varying_reals_encoder = [ - f"encoder_real_{i}" for i in range(metadata["encoder_cont"]) - ] - time_varying_categoricals_decoder = [ - f"decoder_cat_{i}" for i in range(metadata["decoder_cat"]) - ] - time_varying_reals_decoder = [ - f"decoder_real_{i}" for i in range(metadata["decoder_cont"]) - ] - x_categoricals = ( - static_categoricals - + time_varying_categoricals_encoder - + time_varying_categoricals_decoder - ) - x_reals = static_reals + time_varying_reals_encoder + time_varying_reals_decoder - - if monotone_constaints is None: - monotone_constaints = {} - if embedding_labels is None: - embedding_labels = {} - if embedding_paddings is None: - embedding_paddings = [] - if embedding_sizes is None: - embedding_sizes = {} - if hidden_continuous_sizes is None: - hidden_continuous_sizes = {} - if categorical_groups is None: - categorical_groups = {} - if logging_metrics is None: - logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE()]) - if loss is None: - loss = QuantileLoss() - self.save_hyperparameters(ignore=["metadata"]) - - self.hparams.max_encoder_length = max_encoder_length - self.hparams.output_size = output_size - self.hparams.static_categoricals = static_categoricals - self.hparams.static_reals = static_reals - self.hparams.time_varying_categoricals_encoder = ( - time_varying_categoricals_encoder - ) - self.hparams.time_varying_categoricals_decoder = ( - time_varying_categoricals_decoder - ) - self.hparams.time_varying_reals_encoder = time_varying_reals_encoder - self.hparams.time_varying_reals_decoder = time_varying_reals_decoder - self.hparams.x_categoricals = x_categoricals - self.hparams.x_reals = x_reals - # store loss function separately as it is a module - assert isinstance( - loss, LightningMetric - ), "Loss has to be a PyTorch Lightning `Metric`" - super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) - - # processing inputs - # embeddings - self.input_embeddings = MultiEmbedding( - embedding_sizes=self.hparams.embedding_sizes, - categorical_groups=self.hparams.categorical_groups, - embedding_paddings=self.hparams.embedding_paddings, - x_categoricals=self.hparams.x_categoricals, - max_embedding_size=self.hparams.hidden_size, - ) - - # continuous variable processing - self.prescalers = nn.ModuleDict( - { - name: nn.Linear( - 1, - self.hparams.hidden_continuous_sizes.get( - name, self.hparams.hidden_continuous_size - ), - ) - for name in self.reals - } - ) - - # variable selection - # variable selection for static variables - static_input_sizes = { - name: self.input_embeddings.output_size[name] - for name in self.hparams.static_categoricals - } - static_input_sizes.update( - { - name: self.hparams.hidden_continuous_sizes.get( - name, self.hparams.hidden_continuous_size - ) - for name in self.hparams.static_reals - } - ) - self.static_variable_selection = VariableSelectionNetwork( - input_sizes=static_input_sizes, - hidden_size=self.hparams.hidden_size, - input_embedding_flags={ - name: True for name in self.hparams.static_categoricals - }, - dropout=self.hparams.dropout, - 
prescalers=self.prescalers, - ) - - # variable selection for encoder and decoder - encoder_input_sizes = { - name: self.input_embeddings.output_size[name] - for name in self.hparams.time_varying_categoricals_encoder - } - encoder_input_sizes.update( - { - name: self.hparams.hidden_continuous_sizes.get( - name, self.hparams.hidden_continuous_size - ) - for name in self.hparams.time_varying_reals_encoder - } - ) - - decoder_input_sizes = { - name: self.input_embeddings.output_size[name] - for name in self.hparams.time_varying_categoricals_decoder - } - decoder_input_sizes.update( - { - name: self.hparams.hidden_continuous_sizes.get( - name, self.hparams.hidden_continuous_size - ) - for name in self.hparams.time_varying_reals_decoder - } - ) - - # create single variable grns that are shared across decoder and encoder - if self.hparams.share_single_variable_networks: - self.shared_single_variable_grns = nn.ModuleDict() - for name, input_size in encoder_input_sizes.items(): - self.shared_single_variable_grns[name] = GatedResidualNetwork( - input_size, - min(input_size, self.hparams.hidden_size), - self.hparams.hidden_size, - self.hparams.dropout, - ) - for name, input_size in decoder_input_sizes.items(): - if name not in self.shared_single_variable_grns: - self.shared_single_variable_grns[name] = GatedResidualNetwork( - input_size, - min(input_size, self.hparams.hidden_size), - self.hparams.hidden_size, - self.hparams.dropout, - ) - - self.encoder_variable_selection = VariableSelectionNetwork( - input_sizes=encoder_input_sizes, - hidden_size=self.hparams.hidden_size, - input_embedding_flags={ - name: True for name in self.hparams.time_varying_categoricals_encoder - }, - dropout=self.hparams.dropout, - context_size=self.hparams.hidden_size, - prescalers=self.prescalers, - single_variable_grns=( - {} - if not self.hparams.share_single_variable_networks - else self.shared_single_variable_grns - ), - ) - - self.decoder_variable_selection = VariableSelectionNetwork( - input_sizes=decoder_input_sizes, - hidden_size=self.hparams.hidden_size, - input_embedding_flags={ - name: True for name in self.hparams.time_varying_categoricals_decoder - }, - dropout=self.hparams.dropout, - context_size=self.hparams.hidden_size, - prescalers=self.prescalers, - single_variable_grns=( - {} - if not self.hparams.share_single_variable_networks - else self.shared_single_variable_grns - ), - ) - - # static encoders - # for variable selection - self.static_context_variable_selection = GatedResidualNetwork( - input_size=self.hparams.hidden_size, - hidden_size=self.hparams.hidden_size, - output_size=self.hparams.hidden_size, - dropout=self.hparams.dropout, - ) - - # for hidden state of the lstm - self.static_context_initial_hidden_lstm = GatedResidualNetwork( - input_size=self.hparams.hidden_size, - hidden_size=self.hparams.hidden_size, - output_size=self.hparams.hidden_size, - dropout=self.hparams.dropout, - ) - - # for cell state of the lstm - self.static_context_initial_cell_lstm = GatedResidualNetwork( - input_size=self.hparams.hidden_size, - hidden_size=self.hparams.hidden_size, - output_size=self.hparams.hidden_size, - dropout=self.hparams.dropout, - ) - - # for post lstm static enrichment - self.static_context_enrichment = GatedResidualNetwork( - self.hparams.hidden_size, - self.hparams.hidden_size, - self.hparams.hidden_size, - self.hparams.dropout, - ) - - # lstm encoder (history) and decoder (future) for local processing - self.lstm_encoder = LSTM( - input_size=self.hparams.hidden_size, - 
hidden_size=self.hparams.hidden_size, - num_layers=self.hparams.lstm_layers, - dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0, - batch_first=True, - ) - - self.lstm_decoder = LSTM( - input_size=self.hparams.hidden_size, - hidden_size=self.hparams.hidden_size, - num_layers=self.hparams.lstm_layers, - dropout=self.hparams.dropout if self.hparams.lstm_layers > 1 else 0, - batch_first=True, - ) - - # skip connection for lstm - self.post_lstm_gate_encoder = GatedLinearUnit( - self.hparams.hidden_size, dropout=self.hparams.dropout - ) - self.post_lstm_gate_decoder = self.post_lstm_gate_encoder - # self.post_lstm_gate_decoder = GatedLinearUnit( - # self.hparams.hidden_size, dropout=self.hparams.dropout) - self.post_lstm_add_norm_encoder = AddNorm( - self.hparams.hidden_size, trainable_add=False - ) - # self.post_lstm_add_norm_decoder = AddNorm( - # self.hparams.hidden_size, trainable_add=True) - self.post_lstm_add_norm_decoder = self.post_lstm_add_norm_encoder - - # static enrichment and processing past LSTM - self.static_enrichment = GatedResidualNetwork( - input_size=self.hparams.hidden_size, - hidden_size=self.hparams.hidden_size, - output_size=self.hparams.hidden_size, - dropout=self.hparams.dropout, - context_size=self.hparams.hidden_size, - ) - - # attention for long-range processing - self.multihead_attn = InterpretableMultiHeadAttention( - d_model=self.hparams.hidden_size, - n_head=self.hparams.attention_head_size, - dropout=self.hparams.dropout, - ) - self.post_attn_gate_norm = GateAddNorm( - self.hparams.hidden_size, dropout=self.hparams.dropout, trainable_add=False - ) - self.pos_wise_ff = GatedResidualNetwork( - self.hparams.hidden_size, - self.hparams.hidden_size, - self.hparams.hidden_size, - dropout=self.hparams.dropout, - ) - - # output processing -> no dropout at this late stage - self.pre_output_gate_norm = GateAddNorm( - self.hparams.hidden_size, dropout=None, trainable_add=False - ) - - if self.n_targets > 1: # if to run with multiple targets - self.output_layer = nn.ModuleList( - [ - nn.Linear(self.hparams.hidden_size, output_size) - for output_size in self.hparams.output_size - ] - ) - else: - self.output_layer = nn.Linear( - self.hparams.hidden_size, self.hparams.output_size - ) - - @classmethod - def from_dataset( - cls, - dataset: TimeSeriesDataSet, - allowed_encoder_known_variable_names: List[str] = None, - **kwargs, - ): - """ - Create model from dataset. - - Args: - dataset: timeseries dataset - allowed_encoder_known_variable_names: List of known variables that are allowed in encoder, defaults to all - **kwargs: additional arguments such as hyperparameters for model (see ``__init__()``) - - Returns: - TemporalFusionTransformer - """ # noqa: E501 - # add maximum encoder length - # update defaults - new_kwargs = copy(kwargs) - new_kwargs["max_encoder_length"] = dataset.max_encoder_length - new_kwargs.update( - cls.deduce_default_output_parameters(dataset, kwargs, QuantileLoss()) - ) - - # create class and return - return super().from_dataset( - dataset, - allowed_encoder_known_variable_names=allowed_encoder_known_variable_names, - **new_kwargs, - ) - - def expand_static_context(self, context, timesteps): - """ - add time dimension to static context - """ - return context[:, None].expand(-1, timesteps, -1) - - def get_attention_mask( - self, encoder_lengths: torch.LongTensor, decoder_lengths: torch.LongTensor - ): - """ - Returns causal mask to apply for self-attention layer. 
- """ - decoder_length = decoder_lengths.max() - if self.hparams.causal_attention: - # indices to which is attended - attend_step = torch.arange(decoder_length, device=self.device) - # indices for which is predicted - predict_step = torch.arange(0, decoder_length, device=self.device)[:, None] - # do not attend to steps to self or after prediction - decoder_mask = ( - (attend_step >= predict_step) - .unsqueeze(0) - .expand(encoder_lengths.size(0), -1, -1) - ) - else: - # there is value in attending to future forecasts if - # they are made with knowledge currently available - # one possibility is here to use a second attention layer - # for future attention - # (assuming different effects matter in the future than the past) - # or alternatively using the same layer but - # allowing forward attention - i.e. only - # masking out non-available data and self - decoder_mask = ( - create_mask(decoder_length, decoder_lengths) - .unsqueeze(1) - .expand(-1, decoder_length, -1) - ) - # do not attend to steps where data is padded - encoder_mask = ( - create_mask(encoder_lengths.max(), encoder_lengths) - .unsqueeze(1) - .expand(-1, decoder_length, -1) - ) - # combine masks along attended time - first encoder and then decoder - mask = torch.cat( - ( - encoder_mask, - decoder_mask, - ), - dim=2, - ) - return mask - - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - input dimensions: n_samples x time x variables - """ - encoder_lengths = x["encoder_lengths"] - decoder_lengths = x["decoder_lengths"] - x_cat = torch.cat( - [x["encoder_cat"], x["decoder_cat"]], dim=1 - ) # concatenate in time dimension - - ############different from _tft########################### - encoder_cont = x["encoder_cont"] - decoder_cont = x["decoder_cont"] - if encoder_cont.shape[-1] != decoder_cont.shape[-1]: - max_features = max(encoder_cont.shape[-1], decoder_cont.shape[-1]) - - if encoder_cont.shape[-1] < max_features: - encoder_padding = torch.zeros( - encoder_cont.shape[0], - encoder_cont.shape[1], - max_features - encoder_cont.shape[-1], - ).to(encoder_cont.device) - encoder_cont = torch.cat([encoder_cont, encoder_padding], dim=-1) - - if decoder_cont.shape[-1] < max_features: - decoder_padding = torch.zeros( - decoder_cont.shape[0], - decoder_cont.shape[1], - max_features - decoder_cont.shape[-1], - ).to(decoder_cont.device) - decoder_cont = torch.cat([decoder_cont, decoder_padding], dim=-1) - x_cont = torch.cat( - [encoder_cont, decoder_cont], dim=1 - ) # concatenate in time dimension - ########## - - timesteps = x_cont.size(1) # encode + decode length - max_encoder_length = int(encoder_lengths.max()) - input_vectors = self.input_embeddings(x_cat) - - ############different from _tft########################### - available_reals = [name for name in self.hparams.x_reals if name in self.reals] - max_features = x_cont.shape[-1] - - input_vectors.update( - { - name: x_cont[..., idx].unsqueeze(-1) - for idx, name in enumerate(available_reals) - if idx < max_features - } - ) - - all_expected_vars = set(self.encoder_variables + self.decoder_variables) - real_vars_in_input = set(input_vectors.keys()) - missing_vars = all_expected_vars - real_vars_in_input - - for var_name in missing_vars: - if var_name.startswith(("encoder_real_", "decoder_real_")): - # Create zero tensor with same shape as other real variables - zero_tensor = torch.zeros_like(x_cont[..., 0].unsqueeze(-1)) - input_vectors[var_name] = zero_tensor - ############ - - # Embedding and variable selection - if len(self.static_variables) > 0: 
- # static embeddings will be constant over entire batch - static_embedding = { - name: input_vectors[name][:, 0] for name in self.static_variables - } - static_embedding, static_variable_selection = ( - self.static_variable_selection(static_embedding) - ) - else: - static_embedding = torch.zeros( - (x_cont.size(0), self.hparams.hidden_size), - dtype=self.dtype, - device=self.device, - ) - static_variable_selection = torch.zeros( - (x_cont.size(0), 0), dtype=self.dtype, device=self.device - ) - - static_context_variable_selection = self.expand_static_context( - self.static_context_variable_selection(static_embedding), timesteps - ) - - embeddings_varying_encoder = { - name: input_vectors[name][:, :max_encoder_length] - for name in self.encoder_variables - } - embeddings_varying_encoder, encoder_sparse_weights = ( - self.encoder_variable_selection( - embeddings_varying_encoder, - static_context_variable_selection[:, :max_encoder_length], - ) - ) - - embeddings_varying_decoder = { - name: input_vectors[name][:, max_encoder_length:] - for name in self.decoder_variables # select decoder - } - embeddings_varying_decoder, decoder_sparse_weights = ( - self.decoder_variable_selection( - embeddings_varying_decoder, - static_context_variable_selection[:, max_encoder_length:], - ) - ) - - # LSTM - # calculate initial state - input_hidden = self.static_context_initial_hidden_lstm(static_embedding).expand( - self.hparams.lstm_layers, -1, -1 - ) - input_cell = self.static_context_initial_cell_lstm(static_embedding).expand( - self.hparams.lstm_layers, -1, -1 - ) - - # run local encoder - encoder_output, (hidden, cell) = self.lstm_encoder( - embeddings_varying_encoder, - (input_hidden, input_cell), - lengths=encoder_lengths, - enforce_sorted=False, - ) - - # run local decoder - decoder_output, _ = self.lstm_decoder( - embeddings_varying_decoder, - (hidden, cell), - lengths=decoder_lengths, - enforce_sorted=False, - ) - - # skip connection over lstm - lstm_output_encoder = self.post_lstm_gate_encoder(encoder_output) - lstm_output_encoder = self.post_lstm_add_norm_encoder( - lstm_output_encoder, embeddings_varying_encoder - ) - - lstm_output_decoder = self.post_lstm_gate_decoder(decoder_output) - lstm_output_decoder = self.post_lstm_add_norm_decoder( - lstm_output_decoder, embeddings_varying_decoder - ) - - lstm_output = torch.cat([lstm_output_encoder, lstm_output_decoder], dim=1) - - # static enrichment - static_context_enrichment = self.static_context_enrichment(static_embedding) - attn_input = self.static_enrichment( - lstm_output, - self.expand_static_context(static_context_enrichment, timesteps), - ) - - # Attention - attn_output, attn_output_weights = self.multihead_attn( - q=attn_input[:, max_encoder_length:], # query only for predictions - k=attn_input, - v=attn_input, - mask=self.get_attention_mask( - encoder_lengths=encoder_lengths, decoder_lengths=decoder_lengths - ), - ) - - # skip connection over attention - attn_output = self.post_attn_gate_norm( - attn_output, attn_input[:, max_encoder_length:] - ) - - output = self.pos_wise_ff(attn_output) - - # skip connection over temporal fusion decoder (not LSTM decoder - # despite the LSTM output contains - # a skip from the variable selection network) - output = self.pre_output_gate_norm(output, lstm_output[:, max_encoder_length:]) - if self.n_targets > 1: # if to use multi-target architecture - output = [output_layer(output) for output_layer in self.output_layer] - else: - output = self.output_layer(output) - - return self.to_network_output( - 
prediction=self.transform_output(output, target_scale=x["target_scale"]), - encoder_attention=attn_output_weights[..., :max_encoder_length], - decoder_attention=attn_output_weights[..., max_encoder_length:], - static_variables=static_variable_selection, - encoder_variables=encoder_sparse_weights, - decoder_variables=decoder_sparse_weights, - decoder_lengths=decoder_lengths, - encoder_lengths=encoder_lengths, - ) - - def on_fit_end(self): - if self.log_interval > 0: - self.log_embeddings() - - def create_log(self, x, y, out, batch_idx, **kwargs): - log = super().create_log(x, y, out, batch_idx, **kwargs) - if self.log_interval > 0: - log["interpretation"] = self._log_interpretation(out) - return log - - def _log_interpretation(self, out): - # calculate interpretations etc for latter logging - interpretation = self.interpret_output( - detach(out), - reduction="sum", - attention_prediction_horizon=0, # attention only for first prediction horizon # noqa: E501 - ) - return interpretation - - def on_epoch_end(self, outputs): - """ - run at epoch end for training or validation - """ - if self.log_interval > 0 and not self.training: - self.log_interpretation(outputs) - - def interpret_output( - self, - out: Dict[str, torch.Tensor], - reduction: str = "none", - attention_prediction_horizon: int = 0, - ) -> Dict[str, torch.Tensor]: - """ - interpret output of model - - Args: - out: output as produced by ``forward()`` - reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for - normalizing by encode lengths - attention_prediction_horizon: which prediction horizon to use for attention - - Returns: - interpretations that can be plotted with ``plot_interpretation()`` - """ # noqa: E501 - # take attention and concatenate if a list to proper attention object - batch_size = len(out["decoder_attention"]) - if isinstance(out["decoder_attention"], (list, tuple)): - # start with decoder attention - # assume issue is in last dimension, we need to find max - max_last_dimension = max(x.size(-1) for x in out["decoder_attention"]) - first_elm = out["decoder_attention"][0] - # create new attention tensor into which we will scatter - decoder_attention = torch.full( - (batch_size, *first_elm.shape[:-1], max_last_dimension), - float("nan"), - dtype=first_elm.dtype, - device=first_elm.device, - ) - # scatter into tensor - for idx, x in enumerate(out["decoder_attention"]): - decoder_length = out["decoder_lengths"][idx] - decoder_attention[idx, :, :, :decoder_length] = x[..., :decoder_length] - else: - decoder_attention = out["decoder_attention"].clone() - decoder_mask = create_mask( - out["decoder_attention"].size(1), out["decoder_lengths"] - ) - decoder_attention[ - decoder_mask[..., None, None].expand_as(decoder_attention) - ] = float("nan") - - if isinstance(out["encoder_attention"], (list, tuple)): - # same game for encoder attention - # create new attention tensor into which we will scatter - first_elm = out["encoder_attention"][0] - encoder_attention = torch.full( - (batch_size, *first_elm.shape[:-1], self.hparams.max_encoder_length), - float("nan"), - dtype=first_elm.dtype, - device=first_elm.device, - ) - # scatter into tensor - for idx, x in enumerate(out["encoder_attention"]): - encoder_length = out["encoder_lengths"][idx] - encoder_attention[ - idx, :, :, self.hparams.max_encoder_length - encoder_length : - ] = x[..., :encoder_length] - else: - # roll encoder attention (so start last encoder value is on the right) - encoder_attention = out["encoder_attention"].clone() - shifts 
= encoder_attention.size(3) - out["encoder_lengths"] - new_index = ( - torch.arange( - encoder_attention.size(3), device=encoder_attention.device - )[None, None, None].expand_as(encoder_attention) - - shifts[:, None, None, None] - ) % encoder_attention.size(3) - encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index) - # expand encoder_attention to full size - if encoder_attention.size(-1) < self.hparams.max_encoder_length: - encoder_attention = torch.concat( - [ - torch.full( - ( - *encoder_attention.shape[:-1], - self.hparams.max_encoder_length - - out["encoder_lengths"].max(), - ), - float("nan"), - dtype=encoder_attention.dtype, - device=encoder_attention.device, - ), - encoder_attention, - ], - dim=-1, - ) - - # combine attention vector - attention = torch.concat([encoder_attention, decoder_attention], dim=-1) - attention[attention < 1e-5] = float("nan") - - # histogram of decode and encode lengths - encoder_length_histogram = integer_histogram( - out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length - ) - decoder_length_histogram = integer_histogram( - out["decoder_lengths"], min=1, max=out["decoder_variables"].size(1) - ) - - # mask where decoder and encoder where not applied - # when averaging variable selection weights - encoder_variables = out["encoder_variables"].squeeze(-2).clone() - encode_mask = create_mask(encoder_variables.size(1), out["encoder_lengths"]) - encoder_variables = encoder_variables.masked_fill( - encode_mask.unsqueeze(-1), 0.0 - ).sum(dim=1) - encoder_variables /= ( - out["encoder_lengths"] - .where(out["encoder_lengths"] > 0, torch.ones_like(out["encoder_lengths"])) - .unsqueeze(-1) - ) - - decoder_variables = out["decoder_variables"].squeeze(-2).clone() - decode_mask = create_mask(decoder_variables.size(1), out["decoder_lengths"]) - decoder_variables = decoder_variables.masked_fill( - decode_mask.unsqueeze(-1), 0.0 - ).sum(dim=1) - decoder_variables /= out["decoder_lengths"].unsqueeze(-1) - - # static variables need no masking - static_variables = out["static_variables"].squeeze(1) - # attention is batch x time x heads x time_to_attend - # average over heads + only keep prediction attention and - # attention on observed timesteps - attention = masked_op( - attention[ - :, - attention_prediction_horizon, - :, - : self.hparams.max_encoder_length + attention_prediction_horizon, - ], - op="mean", - dim=1, - ) - - if reduction != "none": # if to average over batches - static_variables = static_variables.sum(dim=0) - encoder_variables = encoder_variables.sum(dim=0) - decoder_variables = decoder_variables.sum(dim=0) - - attention = masked_op(attention, dim=0, op=reduction) - else: - attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze( - -1 - ) # renormalize - - interpretation = dict( - attention=attention.masked_fill(torch.isnan(attention), 0.0), - static_variables=static_variables, - encoder_variables=encoder_variables, - decoder_variables=decoder_variables, - encoder_length_histogram=encoder_length_histogram, - decoder_length_histogram=decoder_length_histogram, - ) - return interpretation - - def plot_prediction( - self, - x: Dict[str, torch.Tensor], - out: Dict[str, torch.Tensor], - idx: int, - plot_attention: bool = True, - add_loss_to_title: bool = False, - show_future_observed: bool = True, - ax=None, - **kwargs, - ): - """ - Plot actuals vs prediction and attention - - Args: - x (Dict[str, torch.Tensor]): network input - out (Dict[str, torch.Tensor]): network output - idx (int): sample index - 
plot_attention: if to plot attention on secondary axis - add_loss_to_title: if to add loss to title. Default to False. - show_future_observed: if to show actuals for future. Defaults to True. - ax: matplotlib axes to plot on - - Returns: - plt.Figure: matplotlib figure - """ - # plot prediction as normal - fig = super().plot_prediction( - x, - out, - idx=idx, - add_loss_to_title=add_loss_to_title, - show_future_observed=show_future_observed, - ax=ax, - **kwargs, - ) - - # add attention on secondary axis - if plot_attention: - interpretation = self.interpret_output(out.iget(slice(idx, idx + 1))) - for f in to_list(fig): - ax = f.axes[0] - ax2 = ax.twinx() - ax2.set_ylabel("Attention") - encoder_length = x["encoder_lengths"][0] - ax2.plot( - torch.arange(-encoder_length, 0), - interpretation["attention"][0, -encoder_length:].detach().cpu(), - alpha=0.2, - color="k", - ) - f.tight_layout() - return fig - - def plot_interpretation(self, interpretation: Dict[str, torch.Tensor]): - """ - Make figures that interpret model. - - * Attention - * Variable selection weights / importances - - Args: - interpretation: as obtained from ``interpret_output()`` - - Returns: - dictionary of matplotlib figures - """ - _check_matplotlib("plot_interpretation") - - import matplotlib.pyplot as plt - - figs = {} - - # attention - fig, ax = plt.subplots() - attention = interpretation["attention"].detach().cpu() - attention = attention / attention.sum(-1).unsqueeze(-1) - ax.plot( - np.arange( - -self.hparams.max_encoder_length, - attention.size(0) - self.hparams.max_encoder_length, - ), - attention, - ) - ax.set_xlabel("Time index") - ax.set_ylabel("Attention") - ax.set_title("Attention") - figs["attention"] = fig - - # variable selection - def make_selection_plot(title, values, labels): - fig, ax = plt.subplots(figsize=(7, len(values) * 0.25 + 2)) - order = np.argsort(values) - values = values / values.sum(-1).unsqueeze(-1) - ax.barh( - np.arange(len(values)), - values[order] * 100, - tick_label=np.asarray(labels)[order], - ) - ax.set_title(title) - ax.set_xlabel("Importance in %") - plt.tight_layout() - return fig - - figs["static_variables"] = make_selection_plot( - "Static variables importance", - interpretation["static_variables"].detach().cpu(), - self.static_variables, - ) - figs["encoder_variables"] = make_selection_plot( - "Encoder variables importance", - interpretation["encoder_variables"].detach().cpu(), - self.encoder_variables, - ) - figs["decoder_variables"] = make_selection_plot( - "Decoder variables importance", - interpretation["decoder_variables"].detach().cpu(), - self.decoder_variables, - ) - - return figs - - def log_interpretation(self, outputs): - """ - Log interpretation metrics to tensorboard. - """ - # extract interpretations - interpretation = { - # use padded_stack because decoder - # length histogram can be of different length - name: padded_stack( - [x["interpretation"][name].detach() for x in outputs], - side="right", - value=0, - ).sum(0) - for name in outputs[0]["interpretation"].keys() - } - # normalize attention with length histogram squared to account for: - # 1. zeros in attention and - # 2. 
higher attention due to less values - attention_occurances = ( - interpretation["encoder_length_histogram"][1:].flip(0).float().cumsum(0) - ) - attention_occurances = attention_occurances / attention_occurances.max() - attention_occurances = torch.cat( - [ - attention_occurances, - torch.ones( - interpretation["attention"].size(0) - attention_occurances.size(0), - dtype=attention_occurances.dtype, - device=attention_occurances.device, - ), - ], - dim=0, - ) - interpretation["attention"] = interpretation[ - "attention" - ] / attention_occurances.pow(2).clamp(1.0) - interpretation["attention"] = ( - interpretation["attention"] / interpretation["attention"].sum() - ) - - mpl_available = _check_matplotlib("log_interpretation", raise_error=False) - - # Don't log figures if matplotlib or add_figure is not available - if not mpl_available or not self._logger_supports("add_figure"): - return None - - import matplotlib.pyplot as plt - - figs = self.plot_interpretation(interpretation) # make interpretation figures - label = self.current_stage - # log to tensorboard - for name, fig in figs.items(): - self.logger.experiment.add_figure( - f"{label.capitalize()} {name} importance", - fig, - global_step=self.global_step, - ) - - # log lengths of encoder/decoder - for type in ["encoder", "decoder"]: - fig, ax = plt.subplots() - lengths = ( - padded_stack( - [ - out["interpretation"][f"{type}_length_histogram"] - for out in outputs - ] - ) - .sum(0) - .detach() - .cpu() - ) - if type == "decoder": - start = 1 - else: - start = 0 - ax.plot(torch.arange(start, start + len(lengths)), lengths) - ax.set_xlabel(f"{type.capitalize()} length") - ax.set_ylabel("Number of samples") - ax.set_title(f"{type.capitalize()} length distribution in {label} epoch") - - self.logger.experiment.add_figure( - f"{label.capitalize()} {type} length distribution", - fig, - global_step=self.global_step, - ) - - def log_embeddings(self): - """ - Log embeddings to tensorboard - """ - - # Don't log embeddings if add_embedding is not available - if not self._logger_supports("add_embedding"): - return None - - for name, emb in self.input_embeddings.items(): - labels = self.hparams.embedding_labels[name] - self.logger.experiment.add_embedding( - emb.weight.data.detach().cpu(), - metadata=labels, - tag=name, - global_step=self.global_step, - ) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py index 5ebea75b1..2a98767c2 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py @@ -42,49 +42,3 @@ def get_test_train_params(cls): attention_head_size=5, ), ] - - -class TemporalFusionTransformerMetadata(_BasePtForecaster): - """TFT metadata container.""" - - _tags = { - "info:name": "TemporalFusionTransformerM", - "object_type": "ptf-v2", - "authors": ["jdb78"], - "capability:exogenous": True, - "capability:multivariate": True, - "capability:pred_int": True, - "capability:flexible_history_length": False, - } - - @classmethod - def get_model_cls(cls): - """Get model class.""" - from pytorch_forecasting.models.temporal_fusion_transformer._tft_ver2 import ( - TemporalFusionTransformer, - ) - - return TemporalFusionTransformer - - @classmethod - def get_test_train_params(cls): - """Return testing parameter settings for the trainer. 
- - Returns - ------- - params : dict or list of dict, default = {} - Parameters to create testing instances of the class - Each dict are parameters to construct an "interesting" test instance, i.e., - `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. - `create_test_instance` uses the first (or only) dictionary in `params` - """ - return [ - {}, - dict( - hidden_size=25, - attention_head_size=5, - data_loader_kwargs={ - "add_relative_time_idx": False, - }, - ), - ] diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index b8a21cc6a..18199214b 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -265,6 +265,21 @@ def _integration( class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" + def _all_objects(self): + """Retrieve list of all object classes, excluding ptf-v2 objects.""" + obj_list = super()._all_objects() + + filtered_obj_list = [] + for obj in obj_list: + if hasattr(obj, "get_class_tag"): + object_type = obj.get_class_tag("object_type", None) + if object_type != "ptf-v2": + filtered_obj_list.append(obj) + else: + filtered_obj_list.append(obj) + + return filtered_obj_list + def test_doctest_examples(self, object_class): """Runs doctests for estimator class.""" import doctest From d0490192b62ddaf8575318f55dcaf3582e0e8eef Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 28 May 2025 16:56:23 +0530 Subject: [PATCH 62/80] update test_all_estimators --- pytorch_forecasting/tests/_conftest.py | 234 +++++++++++++++++- pytorch_forecasting/tests/_data_scenarios.py | 233 ++++++++++++++++- .../tests/test_all_estimators.py | 17 +- .../tests/test_all_estimators_v2.py | 4 +- 4 files changed, 483 insertions(+), 5 deletions(-) diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index 36691e850..caa0f5600 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -1,10 +1,15 @@ +from datetime import datetime + import numpy as np +import pandas as pd import pytest import torch from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data +from pytorch_forecasting.data.timeseries import TimeSeries torch.manual_seed(23) @@ -155,6 +160,233 @@ def multiple_dataloaders_with_covariates(data_with_covariates, request): return make_dataloaders(data_with_covariates, **request.param) +@pytest.fixture(scope="session") +def data_with_covariates_v2(): + """Create synthetic time series data with all numerical features.""" + + start_date = datetime(2015, 1, 1) + end_date = datetime(2017, 12, 31) + dates = pd.date_range(start_date, end_date, freq="M") + + agencies = [0, 1] + skus = [0, 1] + data_list = [] + + for agency in agencies: + for sku in skus: + for date in dates: + time_idx = (date.year - 2015) * 12 + date.month - 1 + + volume = ( + np.random.exponential(2) + + 0.1 * time_idx + + 0.5 * np.sin(date.month * np.pi / 6) + ) + volume = max(0.001, volume) + month = date.month + year = date.year + quarter = (date.month - 1) // 3 + 1 + + seasonal_1 = np.sin(2 * np.pi * date.month / 12) + seasonal_2 = np.cos(2 * np.pi * date.month / 12) + + agency_feature_1 = agency * 10 + 
np.random.normal(0, 0.1) + agency_feature_2 = agency * 5 + np.random.normal(0, 0.1) + + sku_feature_1 = sku * 8 + np.random.normal(0, 0.1) + sku_feature_2 = sku * 3 + np.random.normal(0, 0.1) + + trend = time_idx * 0.1 + noise = np.random.normal(0, 0.1) + + special_event_1 = 1 if date.month in [12, 1] else 0 + special_event_2 = 1 if date.month in [6, 7, 8] else 0 + + data_list.append( + { + "date": date, + "time_idx": time_idx, + "agency_encoded": agency, + "sku_encoded": sku, + "volume": volume, + "target": volume, + "weight": 1.0 + np.sqrt(volume), + "month": month, + "year": year, + "quarter": quarter, + "seasonal_1": seasonal_1, + "seasonal_2": seasonal_2, + "agency_feature_1": agency_feature_1, + "agency_feature_2": agency_feature_2, + "sku_feature_1": sku_feature_1, + "sku_feature_2": sku_feature_2, + "trend": trend, + "noise": noise, + "special_event_1": special_event_1, + "special_event_2": special_event_2, + "log_volume": np.log1p(volume), + } + ) + + data = pd.DataFrame(data_list) + + numeric_cols = [col for col in data.columns if col != "date"] + for col in numeric_cols: + data[col] = pd.to_numeric(data[col], errors="coerce") + data[numeric_cols] = data[numeric_cols].fillna(0) + + return data + + +def make_dataloaders_v2(data_with_covariates, **kwargs): + """Create dataloaders with consistent encoder/decoder features.""" + + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + target_col = kwargs.get("target", "target") + group_cols = kwargs.get("group_ids", ["agency_encoded", "sku_encoded"]) + add_relative_time_idx = kwargs.get("add_relative_time_idx", True) + + known_features = [ + "month", + "year", + "quarter", + "seasonal_1", + "seasonal_2", + "special_event_1", + "special_event_2", + "trend", + ] + unknown_features = [ + "agency_feature_1", + "agency_feature_2", + "sku_feature_1", + "sku_feature_2", + "noise", + "log_volume", + ] + + numerical_features = known_features + unknown_features + categorical_features = [] + static_features = group_cols + + for col in numerical_features + categorical_features + group_cols + [target_col]: + if col in data_with_covariates.columns: + data_with_covariates[col] = pd.to_numeric( + data_with_covariates[col], errors="coerce" + ).fillna(0) + + for col in categorical_features + group_cols: + if col in data_with_covariates.columns: + data_with_covariates[col] = data_with_covariates[col].astype("int64") + + if "weight" in data_with_covariates.columns: + data_with_covariates["weight"] = pd.to_numeric( + data_with_covariates["weight"], errors="coerce" + ).fillna(1.0) + + training_data = data_with_covariates[ + data_with_covariates.date < training_cutoff + ].copy() + validation_data = data_with_covariates.copy() + + required_columns = ( + ["time_idx", target_col, "weight", "date"] + + group_cols + + numerical_features + + categorical_features + ) + + available_columns = [ + col for col in required_columns if col in data_with_covariates.columns + ] + + training_data_clean = training_data[available_columns].copy() + validation_data_clean = validation_data[available_columns].copy() + + if "date" in training_data_clean.columns: + training_data_clean = training_data_clean.drop("date", axis=1) + if "date" in validation_data_clean.columns: + validation_data_clean = validation_data_clean.drop("date", axis=1) + + training_dataset = TimeSeries( + data=training_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if 
categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + validation_dataset = TimeSeries( + data=validation_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + training_max_time_idx = training_data["time_idx"].max() + 1 + + train_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=training_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.8, 0.2, 0.0), + ) + + val_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.0, 1.0, 0.0), + ) + + test_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=1, + num_workers=0, + train_val_test_split=(0.0, 0.0, 1.0), + ) + + train_datamodule.setup("fit") + val_datamodule.setup("fit") + test_datamodule.setup("test") + + train_dataloader = train_datamodule.train_dataloader() + val_dataloader = val_datamodule.val_dataloader() + test_dataloader = test_datamodule.test_dataloader() + + return { + "train": train_dataloader, + "val": val_dataloader, + "test": test_dataloader, + "data_module": train_datamodule, + } + + @pytest.fixture(scope="session") def dataloaders_with_different_encoder_decoder_length(data_with_covariates): return make_dataloaders( @@ -259,4 +491,4 @@ def dataloaders_fixed_window_without_covariates(): train=False, batch_size=batch_size, num_workers=0 ) - return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) \ No newline at end of file + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) diff --git a/pytorch_forecasting/tests/_data_scenarios.py b/pytorch_forecasting/tests/_data_scenarios.py index fdf8e5e6d..d39f6d988 100644 --- a/pytorch_forecasting/tests/_data_scenarios.py +++ b/pytorch_forecasting/tests/_data_scenarios.py @@ -1,10 +1,15 @@ +from datetime import datetime + import numpy as np +import pandas as pd import pytest import torch from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder +from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data +from pytorch_forecasting.data.timeseries import TimeSeries torch.manual_seed(23) @@ -87,6 +92,232 @@ def make_dataloaders(data_with_covariates, **kwargs): return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) +def data_with_covariates_v2(): + """Create synthetic time series data with all numerical features.""" + + start_date = datetime(2015, 1, 1) + end_date = datetime(2017, 12, 31) + dates = pd.date_range(start_date, end_date, freq="M") + + agencies = [0, 1] + skus = [0, 1] + data_list = [] + + 
for agency in agencies: + for sku in skus: + for date in dates: + time_idx = (date.year - 2015) * 12 + date.month - 1 + + volume = ( + np.random.exponential(2) + + 0.1 * time_idx + + 0.5 * np.sin(date.month * np.pi / 6) + ) + volume = max(0.001, volume) + month = date.month + year = date.year + quarter = (date.month - 1) // 3 + 1 + + seasonal_1 = np.sin(2 * np.pi * date.month / 12) + seasonal_2 = np.cos(2 * np.pi * date.month / 12) + + agency_feature_1 = agency * 10 + np.random.normal(0, 0.1) + agency_feature_2 = agency * 5 + np.random.normal(0, 0.1) + + sku_feature_1 = sku * 8 + np.random.normal(0, 0.1) + sku_feature_2 = sku * 3 + np.random.normal(0, 0.1) + + trend = time_idx * 0.1 + noise = np.random.normal(0, 0.1) + + special_event_1 = 1 if date.month in [12, 1] else 0 + special_event_2 = 1 if date.month in [6, 7, 8] else 0 + + data_list.append( + { + "date": date, + "time_idx": time_idx, + "agency_encoded": agency, + "sku_encoded": sku, + "volume": volume, + "target": volume, + "weight": 1.0 + np.sqrt(volume), + "month": month, + "year": year, + "quarter": quarter, + "seasonal_1": seasonal_1, + "seasonal_2": seasonal_2, + "agency_feature_1": agency_feature_1, + "agency_feature_2": agency_feature_2, + "sku_feature_1": sku_feature_1, + "sku_feature_2": sku_feature_2, + "trend": trend, + "noise": noise, + "special_event_1": special_event_1, + "special_event_2": special_event_2, + "log_volume": np.log1p(volume), + } + ) + + data = pd.DataFrame(data_list) + + numeric_cols = [col for col in data.columns if col != "date"] + for col in numeric_cols: + data[col] = pd.to_numeric(data[col], errors="coerce") + data[numeric_cols] = data[numeric_cols].fillna(0) + + return data + + +def make_dataloaders_v2(data_with_covariates, **kwargs): + """Create dataloaders with consistent encoder/decoder features.""" + + training_cutoff = "2016-09-01" + max_encoder_length = 4 + max_prediction_length = 3 + + target_col = kwargs.get("target", "target") + group_cols = kwargs.get("group_ids", ["agency_encoded", "sku_encoded"]) + add_relative_time_idx = kwargs.get("add_relative_time_idx", True) + + known_features = [ + "month", + "year", + "quarter", + "seasonal_1", + "seasonal_2", + "special_event_1", + "special_event_2", + "trend", + ] + unknown_features = [ + "agency_feature_1", + "agency_feature_2", + "sku_feature_1", + "sku_feature_2", + "noise", + "log_volume", + ] + + numerical_features = known_features + unknown_features + categorical_features = [] + static_features = group_cols + + for col in numerical_features + categorical_features + group_cols + [target_col]: + if col in data_with_covariates.columns: + data_with_covariates[col] = pd.to_numeric( + data_with_covariates[col], errors="coerce" + ).fillna(0) + + for col in categorical_features + group_cols: + if col in data_with_covariates.columns: + data_with_covariates[col] = data_with_covariates[col].astype("int64") + + if "weight" in data_with_covariates.columns: + data_with_covariates["weight"] = pd.to_numeric( + data_with_covariates["weight"], errors="coerce" + ).fillna(1.0) + + training_data = data_with_covariates[ + data_with_covariates.date < training_cutoff + ].copy() + validation_data = data_with_covariates.copy() + + required_columns = ( + ["time_idx", target_col, "weight", "date"] + + group_cols + + numerical_features + + categorical_features + ) + + available_columns = [ + col for col in required_columns if col in data_with_covariates.columns + ] + + training_data_clean = training_data[available_columns].copy() + validation_data_clean = 
validation_data[available_columns].copy() + + if "date" in training_data_clean.columns: + training_data_clean = training_data_clean.drop("date", axis=1) + if "date" in validation_data_clean.columns: + validation_data_clean = validation_data_clean.drop("date", axis=1) + + training_dataset = TimeSeries( + data=training_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + validation_dataset = TimeSeries( + data=validation_data_clean, + time="time_idx", + target=[target_col], + group=group_cols, + weight="weight", + num=numerical_features, + cat=categorical_features if categorical_features else None, + known=known_features, + unknown=unknown_features, + static=static_features, + ) + + training_max_time_idx = training_data["time_idx"].max() + 1 + + train_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=training_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.8, 0.2, 0.0), + ) + + val_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=2, + num_workers=0, + train_val_test_split=(0.0, 1.0, 0.0), + ) + + test_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=1, + num_workers=0, + train_val_test_split=(0.0, 0.0, 1.0), + ) + + train_datamodule.setup("fit") + val_datamodule.setup("fit") + test_datamodule.setup("test") + + train_dataloader = train_datamodule.train_dataloader() + val_dataloader = val_datamodule.val_dataloader() + test_dataloader = test_datamodule.test_dataloader() + + return { + "train": train_dataloader, + "val": val_dataloader, + "test": test_dataloader, + "data_module": train_datamodule, + } + + @pytest.fixture( params=[ dict(), @@ -258,4 +489,4 @@ def dataloaders_fixed_window_without_covariates(): train=False, batch_size=batch_size, num_workers=0 ) - return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) \ No newline at end of file + return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index 8b273d449..dca0f30c3 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -262,6 +262,21 @@ def _integration( class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" + def _all_objects(self): + """Retrieve list of all object classes, excluding ptf-v2 objects.""" + obj_list = super()._all_objects() + + filtered_obj_list = [] + for obj in obj_list: + if hasattr(obj, "get_class_tag"): + object_type = obj.get_class_tag("object_type", None) + if object_type != "ptf-v2": + filtered_obj_list.append(obj) + else: + filtered_obj_list.append(obj) + + return filtered_obj_list + 
def test_doctest_examples(self, object_class): """Runs doctests for estimator class.""" from skbase.utils.doctest_run import run_doctest @@ -288,4 +303,4 @@ def test_integration( data_with_covariates = data_with_covariates.assign( volume=lambda x: x.volume.round() ) - _integration(object_class, data_with_covariates, tmp_path, **trainer_kwargs) \ No newline at end of file + _integration(object_class, data_with_covariates, tmp_path, **trainer_kwargs) diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index efbe170f0..8a68d8e17 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -105,9 +105,9 @@ class TestAllPtForecastersV2(PackageConfig, BaseFixtureGenerator): def test_doctest_examples(self, object_class): """Runs doctests for estimator class.""" - import doctest + from skbase.utils.doctest_run import run_doctest - doctest.run_docstring_examples(object_class, globals()) + run_doctest(object_class, name=f"class {object_class.__name__}") def test_integration( self, From e72486b79e4b5984400822e8e81ceec665f75148 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 28 May 2025 16:59:50 +0530 Subject: [PATCH 63/80] linting --- pytorch_forecasting/models/deepar/_deepar_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py index 59df96441..a9eb46a04 100644 --- a/pytorch_forecasting/models/deepar/_deepar_metadata.py +++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py @@ -122,4 +122,4 @@ def get_test_train_params(cls): n_plotting_samples=100, trainer_kwargs=dict(accelerator="cpu"), ), - ] \ No newline at end of file + ] From a734f265970cccc4c7e2e944916ac2f2607afb0f Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 29 May 2025 22:58:11 +0530 Subject: [PATCH 64/80] refactor --- .../models/base/_base_model_v2.py | 42 ------ .../temporal_fusion_transformer/_tft_v2.py | 1 - pytorch_forecasting/tests/_conftest.py | 134 +++++++++--------- .../tests/test_all_estimators_v2.py | 4 +- 4 files changed, 69 insertions(+), 112 deletions(-) diff --git a/pytorch_forecasting/models/base/_base_model_v2.py b/pytorch_forecasting/models/base/_base_model_v2.py index a74f926b9..aceec0869 100644 --- a/pytorch_forecasting/models/base/_base_model_v2.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -14,18 +14,6 @@ import torch.nn as nn from torch.optim import Optimizer -from pytorch_forecasting.metrics import ( - MAE, - MASE, - SMAPE, - DistributionLoss, - Metric, - MultiHorizonMetric, - MultiLoss, - QuantileLoss, - convert_torchmetric_to_pytorch_forecasting_metric, -) - class BaseModel(LightningModule): def __init__( @@ -116,7 +104,6 @@ def training_step( x, y = batch y_hat_dict = self(x) y_hat = y_hat_dict["prediction"] - y_hat, y = self._align_prediction_target_shapes(y_hat, y) loss = self.loss(y_hat, y) self.log( "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True @@ -145,7 +132,6 @@ def validation_step( x, y = batch y_hat_dict = self(x) y_hat = y_hat_dict["prediction"] - y_hat, y = self._align_prediction_target_shapes(y_hat, y) loss = self.loss(y_hat, y) self.log( "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True @@ -174,7 +160,6 @@ def test_step( x, y = batch y_hat_dict = self(x) y_hat = y_hat_dict["prediction"] - y_hat, y = self._align_prediction_target_shapes(y_hat, y) loss = 
self.loss(y_hat, y) self.log( "test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True @@ -309,30 +294,3 @@ def log_metrics( prog_bar=True, logger=True, ) - - def _align_prediction_target_shapes( - self, y_hat: torch.Tensor, y: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Align prediction and target tensor shapes for loss/metric calculation. - - Returns - ------- - Tuple of aligned prediction and target tensors - """ - if y.dim() == 3 and y.shape[-1] == 1: - y = y.squeeze(-1) - if y_hat.dim() < y.dim(): - y_hat = y_hat.unsqueeze(-1) - elif y_hat.dim() > y.dim(): - if y_hat.shape[-1] == 1: - y_hat = y_hat.squeeze(-1) - if y_hat.shape != y.shape: - if y_hat.numel() == y.numel(): - y_hat = y_hat.view(y.shape) - else: - raise ValueError( - f"Cannot align shapes: y_hat {y_hat.shape} vs y {y.shape}" - ) - - return y_hat, y diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py index fa4f8f6c0..e74f7cf32 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_v2.py @@ -9,7 +9,6 @@ import torch.nn as nn from torch.optim import Optimizer -from pytorch_forecasting.metrics import Metric from pytorch_forecasting.models.base._base_model_v2 import BaseModel diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index caa0f5600..d175db0c6 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -93,73 +93,6 @@ def make_dataloaders(data_with_covariates, **kwargs): return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) -@pytest.fixture( - params=[ - dict(), - dict( - static_categoricals=["agency", "sku"], - static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], - time_varying_known_categoricals=["special_days", "month"], - variable_groups=dict( - special_days=[ - "easter_day", - "good_friday", - "new_year", - "christmas", - "labor_day", - "independence_day", - "revolution_day_memorial", - "regional_games", - "fifa_u_17_world_cup", - "football_gold_cup", - "beer_capital", - "music_fest", - ] - ), - time_varying_known_reals=[ - "time_idx", - "price_regular", - "price_actual", - "discount", - "discount_in_percent", - ], - time_varying_unknown_categoricals=[], - time_varying_unknown_reals=[ - "volume", - "log_volume", - "industry_volume", - "soda_volume", - "avg_max_temp", - ], - constant_fill_strategy={"volume": 0}, - categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)}, - ), - dict(static_categoricals=["agency", "sku"]), - dict(randomize_length=True, min_encoder_length=2), - dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2), - dict(target_normalizer=GroupNormalizer(transformation="log1p")), - dict( - target_normalizer=GroupNormalizer( - groups=["agency", "sku"], transformation="softplus", center=False - ) - ), - dict(target="agency"), - # test multiple targets - dict(target=["industry_volume", "volume"]), - dict(target=["agency", "volume"]), - dict( - target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1 - ), - dict(target=["agency", "volume"], weight="volume"), - # test weights - dict(target="volume", weight="volume"), - ], - scope="session", -) -def multiple_dataloaders_with_covariates(data_with_covariates, request): - return make_dataloaders(data_with_covariates, **request.param) - - @pytest.fixture(scope="session") def 
data_with_covariates_v2(): """Create synthetic time series data with all numerical features.""" @@ -387,6 +320,73 @@ def make_dataloaders_v2(data_with_covariates, **kwargs): } +@pytest.fixture( + params=[ + dict(), + dict( + static_categoricals=["agency", "sku"], + static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], + time_varying_known_categoricals=["special_days", "month"], + variable_groups=dict( + special_days=[ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", + ] + ), + time_varying_known_reals=[ + "time_idx", + "price_regular", + "price_actual", + "discount", + "discount_in_percent", + ], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + ], + constant_fill_strategy={"volume": 0}, + categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)}, + ), + dict(static_categoricals=["agency", "sku"]), + dict(randomize_length=True, min_encoder_length=2), + dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2), + dict(target_normalizer=GroupNormalizer(transformation="log1p")), + dict( + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="softplus", center=False + ) + ), + dict(target="agency"), + # test multiple targets + dict(target=["industry_volume", "volume"]), + dict(target=["agency", "volume"]), + dict( + target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1 + ), + dict(target=["agency", "volume"], weight="volume"), + # test weights + dict(target="volume", weight="volume"), + ], + scope="session", +) +def multiple_dataloaders_with_covariates(data_with_covariates, request): + return make_dataloaders(data_with_covariates, **request.param) + + @pytest.fixture(scope="session") def dataloaders_with_different_encoder_decoder_length(data_with_covariates): return make_dataloaders( diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index 8a68d8e17..29e6ab22a 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -6,8 +6,8 @@ import lightning.pytorch as pl from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.loggers import TensorBoardLogger +import torch.nn as nn -from pytorch_forecasting.metrics import SMAPE from pytorch_forecasting.tests._conftest import make_dataloaders_v2 as make_dataloaders from pytorch_forecasting.tests.test_all_estimators import ( BaseFixtureGenerator, @@ -73,7 +73,7 @@ def _integration( net = estimator_cls( metadata=metadata, - loss=SMAPE(), + loss=nn.MSELoss(), **kwargs, ) From 7f466b29a742d6af8fc5360ee02d1e39bab0a3c0 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 30 May 2025 00:08:56 +0530 Subject: [PATCH 65/80] Add more test_params --- .../tft_v2_metadata.py | 17 +++++++++++++++++ pytorch_forecasting/tests/_conftest.py | 4 ++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py index 2a98767c2..91e2440ed 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py @@ -41,4 
+41,21 @@ def get_test_train_params(cls): hidden_size=25, attention_head_size=5, ), + dict( + data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3) + ), + dict( + hidden_size=24, + attention_head_size=8, + data_loader_kwargs=dict( + max_encoder_length=5, + max_prediction_length=3, + add_relative_time_idx=False, + ), + ), + dict( + hidden_size=12, + data_loader_kwargs=dict(max_encoder_length=7, max_prediction_length=10), + ), + dict(attention_head_size=2), ] diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index d175db0c6..7bc19920a 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -175,8 +175,8 @@ def make_dataloaders_v2(data_with_covariates, **kwargs): """Create dataloaders with consistent encoder/decoder features.""" training_cutoff = "2016-09-01" - max_encoder_length = 4 - max_prediction_length = 3 + max_encoder_length = kwargs.get("max_encoder_length", 4) + max_prediction_length = kwargs.get("max_prediction_length", 3) target_col = kwargs.get("target", "target") group_cols = kwargs.get("group_ids", ["agency_encoded", "sku_encoded"]) From 0968452c7beaca969aeaf40728f61a8d4b165534 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 1 Jun 2025 00:07:15 +0530 Subject: [PATCH 66/80] Add metadata tests --- .../tests/test_all_estimators_v2.py | 41 +++++++++++++++---- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index 29e6ab22a..857ea1f76 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -71,6 +71,38 @@ def _integration( training_data_module = dataloaders_with_covariates["data_module"] metadata = training_data_module.metadata + assert metadata["encoder_cont"] == 14 # 14 features (8 known + 6 unknown) + assert metadata["encoder_cat"] == 0 + assert metadata["decoder_cont"] == 8 # 8 (only known features) + assert metadata["decoder_cat"] == 0 + assert metadata["static_categorical_features"] == 0 + assert ( + metadata["static_continuous_features"] == 2 + ) # 2 (agency_encoded, sku_encoded) + assert metadata["target"] == 1 + + batch_x, batch_y = next(iter(train_dataloader)) + + assert batch_x["encoder_cont"].shape[2] == metadata["encoder_cont"] + assert batch_x["encoder_cat"].shape[2] == metadata["encoder_cat"] + + assert batch_x["decoder_cont"].shape[2] == metadata["decoder_cont"] + assert batch_x["decoder_cat"].shape[2] == metadata["decoder_cat"] + + if "static_categorical_features" in batch_x: + assert ( + batch_x["static_categorical_features"].shape[2] + == metadata["static_categorical_features"] + ) + + if "static_continuous_features" in batch_x: + assert ( + batch_x["static_continuous_features"].shape[2] + == metadata["static_continuous_features"] + ) + + assert batch_y.shape[2] == metadata["target"] + net = estimator_cls( metadata=metadata, loss=nn.MSELoss(), @@ -85,18 +117,9 @@ def _integration( ) test_outputs = trainer.test(net, dataloaders=test_dataloader) assert len(test_outputs) > 0 - - # check loading - # net = estimator_cls.load_from_checkpoint( - # trainer.checkpoint_callback.best_model_path - # ) - # net.predict(val_dataloader) - finally: shutil.rmtree(tmp_path, ignore_errors=True) - # net.predict(val_dataloader) - class TestAllPtForecastersV2(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" From 
4e8f86343a888e3bf820c132fa1713948ea5838f Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Sun, 1 Jun 2025 19:14:17 +0530
Subject: [PATCH 67/80] add object-filter to ptf-v1

---
 .../models/deepar/_deepar_metadata.py              |  1 +
 .../models/nbeats/_nbeats_metadata.py              |  1 +
 pytorch_forecasting/models/tide/_tide_metadata.py  |  1 +
 pytorch_forecasting/tests/test_all_estimators.py   | 15 +--------------
 4 files changed, 4 insertions(+), 14 deletions(-)

diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py
index ad0e210a5..28754298a 100644
--- a/pytorch_forecasting/models/deepar/_deepar_metadata.py
+++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py
@@ -9,6 +9,7 @@ class DeepARMetadata(_BasePtForecaster):
     _tags = {
         "info:name": "DeepAR",
         "info:compute": 3,
+        "object_type": "ptf-v1",
         "authors": ["jdb78"],
         "capability:exogenous": True,
         "capability:multivariate": True,
diff --git a/pytorch_forecasting/models/nbeats/_nbeats_metadata.py b/pytorch_forecasting/models/nbeats/_nbeats_metadata.py
index 9910a0ba1..f644b378a 100644
--- a/pytorch_forecasting/models/nbeats/_nbeats_metadata.py
+++ b/pytorch_forecasting/models/nbeats/_nbeats_metadata.py
@@ -9,6 +9,7 @@ class NBeatsMetadata(_BasePtForecaster):
     _tags = {
         "info:name": "NBeats",
         "info:compute": 1,
+        "object_type": "ptf-v1",
         "authors": ["jdb78"],
         "capability:exogenous": False,
         "capability:multivariate": False,
diff --git a/pytorch_forecasting/models/tide/_tide_metadata.py b/pytorch_forecasting/models/tide/_tide_metadata.py
index 502229b71..49a2acc67 100644
--- a/pytorch_forecasting/models/tide/_tide_metadata.py
+++ b/pytorch_forecasting/models/tide/_tide_metadata.py
@@ -9,6 +9,7 @@ class TiDEModelMetadata(_BasePtForecaster):
     _tags = {
         "info:name": "TiDEModel",
         "info:compute": 3,
+        "object_type": "ptf-v1",
         "authors": ["Sohaib-Ahmed21"],
         "capability:exogenous": True,
         "capability:multivariate": True,
diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py
index efc944937..add3ef7ba 100644
--- a/pytorch_forecasting/tests/test_all_estimators.py
+++ b/pytorch_forecasting/tests/test_all_estimators.py
@@ -245,20 +245,7 @@ def _integration(
 class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator):
     """Generic tests for all objects in the mini package."""
 
-    def _all_objects(self):
-        """Retrieve list of all object classes, excluding ptf-v2 objects."""
-        obj_list = super()._all_objects()
-
-        filtered_obj_list = []
-        for obj in obj_list:
-            if hasattr(obj, "get_class_tag"):
-                object_type = obj.get_class_tag("object_type", None)
-                if object_type != "ptf-v2":
-                    filtered_obj_list.append(obj)
-            else:
-                filtered_obj_list.append(obj)
-
-        return filtered_obj_list
+    object_type_filter = "ptf-v1"
 
     def test_doctest_examples(self, object_class):
         """Runs doctests for estimator class."""

From 2c518ee554afee1e45d5a6dfa69782c35e75498b Mon Sep 17 00:00:00 2001
From: Aryan Saini
Date: Sat, 7 Jun 2025 01:08:53 +0530
Subject: [PATCH 68/80] add new base classes

---
 .../models/base/_base_object.py       | 20 +++++++++++++++----
 .../models/deepar/_deepar_metadata.py |  5 ++---
 .../models/mlp/_decodermlp_metadata.py |  4 ++--
 .../models/nbeats/_nbeats_metadata.py |  5 ++---
 .../tft_v2_metadata.py                |  5 ++---
 .../models/tide/_tide_metadata.py     |  5 ++---
 .../tests/test_all_estimators.py      |  2 +-
 .../tests/test_all_estimators_v2.py   |  2 +-
 8 files changed, 28 insertions(+), 20 deletions(-)

diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py
index 0106b7afa..a7cccfae5 100644
--- a/pytorch_forecasting/models/base/_base_object.py
+++ b/pytorch_forecasting/models/base/_base_object.py
@@ -17,10 +17,6 @@ class _BasePtForecaster(_BaseObject):
     This class points to model objects and contains metadata as tags.
     """
 
-    _tags = {
-        "object_type": "forecaster_pytorch",
-    }
-
     @classmethod
     def get_model_cls(cls):
         """Get model class."""
@@ -112,3 +108,19 @@ def create_test_instances_and_names(cls, parameter_set="default"):
             names = [cls.__name__]
 
         return objs, names
+
+
+class _BasePtForecasterV1(_BasePtForecaster):
+    """Base class for PyTorch Forecasting v1 forecasters."""
+
+    _tags = {
+        "object_type": "forecaster_pytorch_v1",
+    }
+
+
+class _BasePtForecasterV2(_BasePtForecaster):
+    """Base class for PyTorch Forecasting v2 forecasters."""
+
+    _tags = {
+        "object_type": "forecaster_pytorch_v2",
+    }
diff --git a/pytorch_forecasting/models/deepar/_deepar_metadata.py b/pytorch_forecasting/models/deepar/_deepar_metadata.py
index 28754298a..27e5af168 100644
--- a/pytorch_forecasting/models/deepar/_deepar_metadata.py
+++ b/pytorch_forecasting/models/deepar/_deepar_metadata.py
@@ -1,15 +1,14 @@
 """DeepAR metadata container."""
 
-from pytorch_forecasting.models.base._base_object import _BasePtForecaster
+from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1
 
 
-class DeepARMetadata(_BasePtForecaster):
+class DeepARMetadata(_BasePtForecasterV1):
     """DeepAR metadata container."""
 
     _tags = {
         "info:name": "DeepAR",
         "info:compute": 3,
-        "object_type": "ptf-v1",
         "authors": ["jdb78"],
         "capability:exogenous": True,
         "capability:multivariate": True,
diff --git a/pytorch_forecasting/models/mlp/_decodermlp_metadata.py b/pytorch_forecasting/models/mlp/_decodermlp_metadata.py
index c7abead33..c3fd71dce 100644
--- a/pytorch_forecasting/models/mlp/_decodermlp_metadata.py
+++ b/pytorch_forecasting/models/mlp/_decodermlp_metadata.py
@@ -1,9 +1,9 @@
 """DecoderMLP metadata container."""
 
-from pytorch_forecasting.models.base._base_object import _BasePtForecaster
+from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1
 
 
-class DecoderMLPMetadata(_BasePtForecaster):
+class DecoderMLPMetadata(_BasePtForecasterV1):
     """DecoderMLP metadata container."""
 
     _tags = {
diff --git a/pytorch_forecasting/models/nbeats/_nbeats_metadata.py b/pytorch_forecasting/models/nbeats/_nbeats_metadata.py
index f644b378a..9fd3132f1 100644
--- a/pytorch_forecasting/models/nbeats/_nbeats_metadata.py
+++ b/pytorch_forecasting/models/nbeats/_nbeats_metadata.py
@@ -1,15 +1,14 @@
 """NBeats metadata container."""
 
-from pytorch_forecasting.models.base._base_object import _BasePtForecaster
+from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1
 
 
-class NBeatsMetadata(_BasePtForecaster):
+class NBeatsMetadata(_BasePtForecasterV1):
     """NBeats metadata container."""
 
     _tags = {
         "info:name": "NBeats",
         "info:compute": 1,
-        "object_type": "ptf-v1",
         "authors": ["jdb78"],
         "capability:exogenous": False,
         "capability:multivariate": False,
diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py
index 91e2440ed..41d2df27b 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py
@@ -1,14 +1,13 @@
 """TFT metadata container."""
 
-from pytorch_forecasting.models.base._base_object import _BasePtForecaster
+from 
pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 -class TFTMetadata(_BasePtForecaster): +class TFTMetadata(_BasePtForecasterV2): """TFT metadata container.""" _tags = { "info:name": "TFT", - "object_type": "ptf-v2", "authors": ["phoeenniixx"], "capability:exogenous": True, "capability:multivariate": True, diff --git a/pytorch_forecasting/models/tide/_tide_metadata.py b/pytorch_forecasting/models/tide/_tide_metadata.py index 49a2acc67..e8866b38e 100644 --- a/pytorch_forecasting/models/tide/_tide_metadata.py +++ b/pytorch_forecasting/models/tide/_tide_metadata.py @@ -1,15 +1,14 @@ """TiDE metadata container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecaster +from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1 -class TiDEModelMetadata(_BasePtForecaster): +class TiDEModelMetadata(_BasePtForecasterV1): """Metadata container for TiDE Model.""" _tags = { "info:name": "TiDEModel", "info:compute": 3, - "object_type": "ptf-v1", "authors": ["Sohaib-Ahmed21"], "capability:exogenous": True, "capability:multivariate": True, diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index add3ef7ba..c690a5460 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -245,7 +245,7 @@ def _integration( class TestAllPtForecasters(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" - object_type_filter = "ptf-v1" + object_type_filter = "forecaster_pytorch_v1" def test_doctest_examples(self, object_class): """Runs doctests for estimator class.""" diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index 857ea1f76..85ebe1cd0 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -124,7 +124,7 @@ def _integration( class TestAllPtForecastersV2(PackageConfig, BaseFixtureGenerator): """Generic tests for all objects in the mini package.""" - object_type_filter = "ptf-v2" + object_type_filter = "forecaster_pytorch_v2" def test_doctest_examples(self, object_class): """Runs doctests for estimator class.""" From 7a5c58fd37bccff3eeb8c065dca252c7c4a872c5 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Mon, 9 Jun 2025 02:07:31 +0530 Subject: [PATCH 69/80] remove try block --- .../tests/test_all_estimators_v2.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index 85ebe1cd0..7e5720a4b 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -109,16 +109,15 @@ def _integration( **kwargs, ) - try: - trainer.fit( - net, - train_dataloaders=train_dataloader, - val_dataloaders=val_dataloader, - ) - test_outputs = trainer.test(net, dataloaders=test_dataloader) - assert len(test_outputs) > 0 - finally: - shutil.rmtree(tmp_path, ignore_errors=True) + trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, + ) + test_outputs = trainer.test(net, dataloaders=test_dataloader) + assert len(test_outputs) > 0 + + shutil.rmtree(tmp_path, ignore_errors=True) class TestAllPtForecastersV2(PackageConfig, BaseFixtureGenerator): From 3b9de6d9f2be70b7f5833e60555d21b369a73030 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Tue, 10 
Jun 2025 01:17:55 +0530 Subject: [PATCH 70/80] add support for multiple datamodules --- .../tft_v2_metadata.py | 77 +++++++++++++++++++ pytorch_forecasting/tests/_conftest.py | 56 ++------------ pytorch_forecasting/tests/_data_scenarios.py | 56 ++------------ .../tests/test_all_estimators_v2.py | 69 +++-------------- 4 files changed, 98 insertions(+), 160 deletions(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py index 41d2df27b..c1277fcb4 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py @@ -22,6 +22,83 @@ def get_model_cls(cls): return TFT + @classmethod + def _get_test_dataloaders_from(cls, trainer_kwargs): + """Create test dataloaders from trainer_kwargs - following v1 pattern.""" + from pytorch_forecasting.data.data_module import ( + EncoderDecoderTimeSeriesDataModule, + ) + from pytorch_forecasting.tests._conftest import make_datasets_v2 + from pytorch_forecasting.tests._data_scenarios import data_with_covariates_v2 + + data_with_covariates = data_with_covariates_v2() + + data_loader_default_kwargs = dict( + target="target", + group_ids=["agency_encoded", "sku_encoded"], + add_relative_time_idx=True, + ) + + data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {}) + data_loader_default_kwargs.update(data_loader_kwargs) + + datasets_info = make_datasets_v2( + data_with_covariates, **data_loader_default_kwargs + ) + + training_dataset = datasets_info["training_dataset"] + validation_dataset = datasets_info["validation_dataset"] + training_max_time_idx = datasets_info["training_max_time_idx"] + + max_encoder_length = data_loader_kwargs.get("max_encoder_length", 4) + max_prediction_length = data_loader_kwargs.get("max_prediction_length", 3) + add_relative_time_idx = data_loader_kwargs.get("add_relative_time_idx", True) + batch_size = data_loader_kwargs.get("batch_size", 2) + + train_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=training_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + add_relative_time_idx=add_relative_time_idx, + batch_size=batch_size, + train_val_test_split=(0.8, 0.2, 0.0), + ) + + val_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=batch_size, + train_val_test_split=(0.0, 1.0, 0.0), + ) + + test_datamodule = EncoderDecoderTimeSeriesDataModule( + time_series_dataset=validation_dataset, + max_encoder_length=max_encoder_length, + max_prediction_length=max_prediction_length, + min_prediction_idx=training_max_time_idx, + add_relative_time_idx=add_relative_time_idx, + batch_size=1, + train_val_test_split=(0.0, 0.0, 1.0), + ) + + train_datamodule.setup("fit") + val_datamodule.setup("fit") + test_datamodule.setup("test") + + train_dataloader = train_datamodule.train_dataloader() + val_dataloader = val_datamodule.val_dataloader() + test_dataloader = test_datamodule.test_dataloader() + + return { + "train": train_dataloader, + "val": val_dataloader, + "test": test_dataloader, + "data_module": train_datamodule, + } + @classmethod def get_test_train_params(cls): """Return testing parameter settings for the trainer. 
diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index a3b2bba5d..9f2806cfb 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -7,7 +7,6 @@ from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder -from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data from pytorch_forecasting.data.timeseries import TimeSeries @@ -175,16 +174,12 @@ def data_with_covariates_v2(): return data -def make_dataloaders_v2(data_with_covariates, **kwargs): - """Create dataloaders with consistent encoder/decoder features.""" +def make_datasets_v2(data_with_covariates, **kwargs): + """Create datasets with consistent encoder/decoder features.""" training_cutoff = "2016-09-01" - max_encoder_length = kwargs.get("max_encoder_length", 4) - max_prediction_length = kwargs.get("max_prediction_length", 3) - target_col = kwargs.get("target", "target") group_cols = kwargs.get("group_ids", ["agency_encoded", "sku_encoded"]) - add_relative_time_idx = kwargs.get("add_relative_time_idx", True) known_features = [ "month", @@ -276,51 +271,10 @@ def make_dataloaders_v2(data_with_covariates, **kwargs): training_max_time_idx = training_data["time_idx"].max() + 1 - train_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=training_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - add_relative_time_idx=add_relative_time_idx, - batch_size=2, - num_workers=0, - train_val_test_split=(0.8, 0.2, 0.0), - ) - - val_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=2, - num_workers=0, - train_val_test_split=(0.0, 1.0, 0.0), - ) - - test_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=1, - num_workers=0, - train_val_test_split=(0.0, 0.0, 1.0), - ) - - train_datamodule.setup("fit") - val_datamodule.setup("fit") - test_datamodule.setup("test") - - train_dataloader = train_datamodule.train_dataloader() - val_dataloader = val_datamodule.val_dataloader() - test_dataloader = test_datamodule.test_dataloader() - return { - "train": train_dataloader, - "val": val_dataloader, - "test": test_dataloader, - "data_module": train_datamodule, + "training_dataset": training_dataset, + "validation_dataset": validation_dataset, + "training_max_time_idx": training_max_time_idx, } diff --git a/pytorch_forecasting/tests/_data_scenarios.py b/pytorch_forecasting/tests/_data_scenarios.py index d39f6d988..40c6fa9d9 100644 --- a/pytorch_forecasting/tests/_data_scenarios.py +++ b/pytorch_forecasting/tests/_data_scenarios.py @@ -7,7 +7,6 @@ from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder -from pytorch_forecasting.data.data_module import EncoderDecoderTimeSeriesDataModule from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data from 
pytorch_forecasting.data.timeseries import TimeSeries @@ -169,16 +168,12 @@ def data_with_covariates_v2(): return data -def make_dataloaders_v2(data_with_covariates, **kwargs): - """Create dataloaders with consistent encoder/decoder features.""" +def make_datasets_v2(data_with_covariates, **kwargs): + """Create datasets with consistent encoder/decoder features.""" training_cutoff = "2016-09-01" - max_encoder_length = 4 - max_prediction_length = 3 - target_col = kwargs.get("target", "target") group_cols = kwargs.get("group_ids", ["agency_encoded", "sku_encoded"]) - add_relative_time_idx = kwargs.get("add_relative_time_idx", True) known_features = [ "month", @@ -270,51 +265,10 @@ def make_dataloaders_v2(data_with_covariates, **kwargs): training_max_time_idx = training_data["time_idx"].max() + 1 - train_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=training_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - add_relative_time_idx=add_relative_time_idx, - batch_size=2, - num_workers=0, - train_val_test_split=(0.8, 0.2, 0.0), - ) - - val_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=2, - num_workers=0, - train_val_test_split=(0.0, 1.0, 0.0), - ) - - test_datamodule = EncoderDecoderTimeSeriesDataModule( - time_series_dataset=validation_dataset, - max_encoder_length=max_encoder_length, - max_prediction_length=max_prediction_length, - min_prediction_idx=training_max_time_idx, - add_relative_time_idx=add_relative_time_idx, - batch_size=1, - num_workers=0, - train_val_test_split=(0.0, 0.0, 1.0), - ) - - train_datamodule.setup("fit") - val_datamodule.setup("fit") - test_datamodule.setup("test") - - train_dataloader = train_datamodule.train_dataloader() - val_dataloader = val_datamodule.val_dataloader() - test_dataloader = test_datamodule.test_dataloader() - return { - "train": train_dataloader, - "val": val_dataloader, - "test": test_dataloader, - "data_module": train_datamodule, + "training_dataset": training_dataset, + "validation_dataset": validation_dataset, + "training_max_time_idx": training_max_time_idx, } diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index 7e5720a4b..7cb283454 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -8,7 +8,6 @@ from lightning.pytorch.loggers import TensorBoardLogger import torch.nn as nn -from pytorch_forecasting.tests._conftest import make_dataloaders_v2 as make_dataloaders from pytorch_forecasting.tests.test_all_estimators import ( BaseFixtureGenerator, PackageConfig, @@ -21,33 +20,16 @@ def _integration( estimator_cls, - data_with_covariates, + dataloaders, tmp_path, data_loader_kwargs={}, clip_target: bool = False, trainer_kwargs=None, **kwargs, ): - data_with_covariates = data_with_covariates.copy() - if clip_target: - data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0) - else: - data_with_covariates["target"] = data_with_covariates["volume"] - - data_loader_default_kwargs = dict( - target="target", - group_ids=["agency_encoded", "sku_encoded"], - add_relative_time_idx=True, - ) - data_loader_default_kwargs.update(data_loader_kwargs) - - dataloaders_with_covariates = 
make_dataloaders( - data_with_covariates, **data_loader_default_kwargs - ) - - train_dataloader = dataloaders_with_covariates["train"] - val_dataloader = dataloaders_with_covariates["val"] - test_dataloader = dataloaders_with_covariates["test"] + train_dataloader = dataloaders["train"] + val_dataloader = dataloaders["val"] + test_dataloader = dataloaders["test"] early_stop_callback = EarlyStopping( monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" @@ -68,40 +50,12 @@ def _integration( logger=logger, **trainer_kwargs, ) - training_data_module = dataloaders_with_covariates["data_module"] + training_data_module = dataloaders.get("data_module") metadata = training_data_module.metadata - assert metadata["encoder_cont"] == 14 # 14 features (8 known + 6 unknown) - assert metadata["encoder_cat"] == 0 - assert metadata["decoder_cont"] == 8 # 8 (only known features) - assert metadata["decoder_cat"] == 0 - assert metadata["static_categorical_features"] == 0 - assert ( - metadata["static_continuous_features"] == 2 - ) # 2 (agency_encoded, sku_encoded) - assert metadata["target"] == 1 - - batch_x, batch_y = next(iter(train_dataloader)) - - assert batch_x["encoder_cont"].shape[2] == metadata["encoder_cont"] - assert batch_x["encoder_cat"].shape[2] == metadata["encoder_cat"] - - assert batch_x["decoder_cont"].shape[2] == metadata["decoder_cont"] - assert batch_x["decoder_cat"].shape[2] == metadata["decoder_cat"] - - if "static_categorical_features" in batch_x: - assert ( - batch_x["static_categorical_features"].shape[2] - == metadata["static_categorical_features"] - ) - - if "static_continuous_features" in batch_x: - assert ( - batch_x["static_continuous_features"].shape[2] - == metadata["static_continuous_features"] - ) - - assert batch_y.shape[2] == metadata["target"] + assert isinstance( + metadata, dict + ), f"Expected metadata to be dict, got {type(metadata)}" net = estimator_cls( metadata=metadata, @@ -137,8 +91,7 @@ def test_integration( trainer_kwargs, tmp_path, ): - from pytorch_forecasting.tests._data_scenarios import data_with_covariates_v2 - - data_with_covariates = data_with_covariates_v2() object_class = object_metadata.get_model_cls() - _integration(object_class, data_with_covariates, tmp_path, **trainer_kwargs) + dataloaders = object_metadata._get_test_dataloaders_from(trainer_kwargs) + + _integration(object_class, dataloaders, tmp_path, **trainer_kwargs) From 032a7b0a202b995cf0d522ba6c3206a05f0726a3 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Tue, 10 Jun 2025 01:23:30 +0530 Subject: [PATCH 71/80] typo --- .../models/temporal_fusion_transformer/tft_v2_metadata.py | 2 +- pytorch_forecasting/tests/test_all_estimators_v2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py index c1277fcb4..f17f514f2 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py @@ -23,7 +23,7 @@ def get_model_cls(cls): return TFT @classmethod - def _get_test_dataloaders_from(cls, trainer_kwargs): + def _get_test_datamodule_from(cls, trainer_kwargs): """Create test dataloaders from trainer_kwargs - following v1 pattern.""" from pytorch_forecasting.data.data_module import ( EncoderDecoderTimeSeriesDataModule, diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py 
b/pytorch_forecasting/tests/test_all_estimators_v2.py index 7cb283454..2bde90505 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -92,6 +92,6 @@ def test_integration( tmp_path, ): object_class = object_metadata.get_model_cls() - dataloaders = object_metadata._get_test_dataloaders_from(trainer_kwargs) + dataloaders = object_metadata._get_test_datamodule_from(trainer_kwargs) _integration(object_class, dataloaders, tmp_path, **trainer_kwargs) From 8b0087eec5b713d4e7b414790f1d6fa8b6d7c135 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 13 Jun 2025 02:38:40 +0530 Subject: [PATCH 72/80] linting --- pytorch_forecasting/tests/_data_scenarios.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_forecasting/tests/_data_scenarios.py b/pytorch_forecasting/tests/_data_scenarios.py index de20a2eae..c13ff0ae5 100644 --- a/pytorch_forecasting/tests/_data_scenarios.py +++ b/pytorch_forecasting/tests/_data_scenarios.py @@ -263,7 +263,6 @@ def make_datasets_v2(data_with_covariates, **kwargs): } - def dataloaders_with_different_encoder_decoder_length(): return make_dataloaders( data_with_covariates(), From 57d635b09147cf060555fc6c5c62f8c0818cbfa8 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 14 Jun 2025 00:14:12 +0530 Subject: [PATCH 73/80] add pkg name to v2 --- .../models/deepar/_deepar_pkg.py | 4 +-- .../models/mlp/_decodermlp_pkg.py | 4 +-- .../models/nbeats/_nbeats_pkg.py | 4 +-- .../tft_v2_metadata.py | 6 ++--- pytorch_forecasting/models/tide/_tide_pkg.py | 4 +-- .../models/timexer/_timexer_pkg.py | 4 +-- .../tests/test_all_estimators_v2.py | 26 ++++++++++++++++--- 7 files changed, 36 insertions(+), 16 deletions(-) diff --git a/pytorch_forecasting/models/deepar/_deepar_pkg.py b/pytorch_forecasting/models/deepar/_deepar_pkg.py index eb30d639c..ffee0d328 100644 --- a/pytorch_forecasting/models/deepar/_deepar_pkg.py +++ b/pytorch_forecasting/models/deepar/_deepar_pkg.py @@ -1,9 +1,9 @@ """DeepAR package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecaster +from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1 -class DeepAR_pkg(_BasePtForecaster): +class DeepAR_pkg(_BasePtForecasterV1): """DeepAR package container.""" _tags = { diff --git a/pytorch_forecasting/models/mlp/_decodermlp_pkg.py b/pytorch_forecasting/models/mlp/_decodermlp_pkg.py index 917d99fb2..df5060088 100644 --- a/pytorch_forecasting/models/mlp/_decodermlp_pkg.py +++ b/pytorch_forecasting/models/mlp/_decodermlp_pkg.py @@ -1,9 +1,9 @@ """DecoderMLP package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecaster +from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1 -class DecoderMLP_pkg(_BasePtForecaster): +class DecoderMLP_pkg(_BasePtForecasterV1): """DecoderMLP package container.""" _tags = { diff --git a/pytorch_forecasting/models/nbeats/_nbeats_pkg.py b/pytorch_forecasting/models/nbeats/_nbeats_pkg.py index 7650749c2..837e709ce 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_pkg.py @@ -1,9 +1,9 @@ """NBeats package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecaster +from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1 -class NBeats_pkg(_BasePtForecaster): +class NBeats_pkg(_BasePtForecasterV1): """NBeats package container.""" _tags = { diff --git 
a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py index f17f514f2..83f9a5b34 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py @@ -1,10 +1,10 @@ -"""TFT metadata container.""" +"""TFT package container.""" from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 -class TFTMetadata(_BasePtForecasterV2): - """TFT metadata container.""" +class TFT_pkg(_BasePtForecasterV2): + """TFT package container.""" _tags = { "info:name": "TFT", diff --git a/pytorch_forecasting/models/tide/_tide_pkg.py b/pytorch_forecasting/models/tide/_tide_pkg.py index 67fdcf154..373a95525 100644 --- a/pytorch_forecasting/models/tide/_tide_pkg.py +++ b/pytorch_forecasting/models/tide/_tide_pkg.py @@ -1,9 +1,9 @@ """TiDE package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecaster +from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1 -class TiDEModel_pkg(_BasePtForecaster): +class TiDEModel_pkg(_BasePtForecasterV1): """Package container for TiDE Model.""" _tags = { diff --git a/pytorch_forecasting/models/timexer/_timexer_pkg.py b/pytorch_forecasting/models/timexer/_timexer_pkg.py index d1febcdb1..f22c2db0f 100644 --- a/pytorch_forecasting/models/timexer/_timexer_pkg.py +++ b/pytorch_forecasting/models/timexer/_timexer_pkg.py @@ -1,9 +1,9 @@ """TimeXer package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecaster +from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1 -class TimeXer_pkg(_BasePtForecaster): +class TimeXer_pkg(_BasePtForecasterV1): """TimeXer package container.""" _tags = { diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index 2bde90505..9e4ba316b 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -87,11 +87,31 @@ def test_doctest_examples(self, object_class): def test_integration( self, - object_metadata, + object_pkg, trainer_kwargs, tmp_path, ): - object_class = object_metadata.get_model_cls() - dataloaders = object_metadata._get_test_datamodule_from(trainer_kwargs) + object_class = object_pkg.get_model_cls() + dataloaders = object_pkg._get_test_datamodule_from(trainer_kwargs) _integration(object_class, dataloaders, tmp_path, **trainer_kwargs) + + def test_pkg_linkage(self, object_pkg, object_class): + """Test that the package is linked correctly.""" + # check name method + msg = ( + f"Package {object_pkg}.name() does not match class " + f"name {object_class.__name__}. " + "The expected package name is " + f"{object_class.__name__}_pkg." + ) + assert object_pkg.name() == object_class.__name__, msg + + # check naming convention + msg = ( + f"Package {object_pkg.__name__} does not match class " + f"name {object_class.__name__}. " + "The expected package name is " + f"{object_class.__name__}_pkg." 
+ ) + assert object_pkg.__name__ == object_class.__name__ + "_pkg", msg From c4d5628694015ca325d3c3b906a00fdb95952ccc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 15 Jun 2025 21:27:00 +0200 Subject: [PATCH 74/80] revert changes to conftest --- pytorch_forecasting/tests/_conftest.py | 186 ------------------------- 1 file changed, 186 deletions(-) diff --git a/pytorch_forecasting/tests/_conftest.py b/pytorch_forecasting/tests/_conftest.py index dd8d1d8d9..8def0bfe2 100644 --- a/pytorch_forecasting/tests/_conftest.py +++ b/pytorch_forecasting/tests/_conftest.py @@ -1,14 +1,10 @@ -from datetime import datetime - import numpy as np -import pandas as pd import pytest import torch from pytorch_forecasting import TimeSeriesDataSet from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data -from pytorch_forecasting.data.timeseries import TimeSeries torch.manual_seed(23) @@ -60,188 +56,6 @@ def make_dataloaders(data_with_covariates, **kwargs): return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader) -@pytest.fixture(scope="session") -def data_with_covariates_v2(): - """Create synthetic time series data with all numerical features.""" - - start_date = datetime(2015, 1, 1) - end_date = datetime(2017, 12, 31) - dates = pd.date_range(start_date, end_date, freq="M") - - agencies = [0, 1] - skus = [0, 1] - data_list = [] - - for agency in agencies: - for sku in skus: - for date in dates: - time_idx = (date.year - 2015) * 12 + date.month - 1 - - volume = ( - np.random.exponential(2) - + 0.1 * time_idx - + 0.5 * np.sin(date.month * np.pi / 6) - ) - volume = max(0.001, volume) - month = date.month - year = date.year - quarter = (date.month - 1) // 3 + 1 - - seasonal_1 = np.sin(2 * np.pi * date.month / 12) - seasonal_2 = np.cos(2 * np.pi * date.month / 12) - - agency_feature_1 = agency * 10 + np.random.normal(0, 0.1) - agency_feature_2 = agency * 5 + np.random.normal(0, 0.1) - - sku_feature_1 = sku * 8 + np.random.normal(0, 0.1) - sku_feature_2 = sku * 3 + np.random.normal(0, 0.1) - - trend = time_idx * 0.1 - noise = np.random.normal(0, 0.1) - - special_event_1 = 1 if date.month in [12, 1] else 0 - special_event_2 = 1 if date.month in [6, 7, 8] else 0 - - data_list.append( - { - "date": date, - "time_idx": time_idx, - "agency_encoded": agency, - "sku_encoded": sku, - "volume": volume, - "target": volume, - "weight": 1.0 + np.sqrt(volume), - "month": month, - "year": year, - "quarter": quarter, - "seasonal_1": seasonal_1, - "seasonal_2": seasonal_2, - "agency_feature_1": agency_feature_1, - "agency_feature_2": agency_feature_2, - "sku_feature_1": sku_feature_1, - "sku_feature_2": sku_feature_2, - "trend": trend, - "noise": noise, - "special_event_1": special_event_1, - "special_event_2": special_event_2, - "log_volume": np.log1p(volume), - } - ) - - data = pd.DataFrame(data_list) - - numeric_cols = [col for col in data.columns if col != "date"] - for col in numeric_cols: - data[col] = pd.to_numeric(data[col], errors="coerce") - data[numeric_cols] = data[numeric_cols].fillna(0) - - return data - - -def make_datasets_v2(data_with_covariates, **kwargs): - """Create datasets with consistent encoder/decoder features.""" - - training_cutoff = "2016-09-01" - target_col = kwargs.get("target", "target") - group_cols = kwargs.get("group_ids", ["agency_encoded", "sku_encoded"]) - - known_features = [ - "month", - "year", - "quarter", - "seasonal_1", - 
"seasonal_2", - "special_event_1", - "special_event_2", - "trend", - ] - unknown_features = [ - "agency_feature_1", - "agency_feature_2", - "sku_feature_1", - "sku_feature_2", - "noise", - "log_volume", - ] - - numerical_features = known_features + unknown_features - categorical_features = [] - static_features = group_cols - - for col in numerical_features + categorical_features + group_cols + [target_col]: - if col in data_with_covariates.columns: - data_with_covariates[col] = pd.to_numeric( - data_with_covariates[col], errors="coerce" - ).fillna(0) - - for col in categorical_features + group_cols: - if col in data_with_covariates.columns: - data_with_covariates[col] = data_with_covariates[col].astype("int64") - - if "weight" in data_with_covariates.columns: - data_with_covariates["weight"] = pd.to_numeric( - data_with_covariates["weight"], errors="coerce" - ).fillna(1.0) - - training_data = data_with_covariates[ - data_with_covariates.date < training_cutoff - ].copy() - validation_data = data_with_covariates.copy() - - required_columns = ( - ["time_idx", target_col, "weight", "date"] - + group_cols - + numerical_features - + categorical_features - ) - - available_columns = [ - col for col in required_columns if col in data_with_covariates.columns - ] - - training_data_clean = training_data[available_columns].copy() - validation_data_clean = validation_data[available_columns].copy() - - if "date" in training_data_clean.columns: - training_data_clean = training_data_clean.drop("date", axis=1) - if "date" in validation_data_clean.columns: - validation_data_clean = validation_data_clean.drop("date", axis=1) - - training_dataset = TimeSeries( - data=training_data_clean, - time="time_idx", - target=[target_col], - group=group_cols, - weight="weight", - num=numerical_features, - cat=categorical_features if categorical_features else None, - known=known_features, - unknown=unknown_features, - static=static_features, - ) - - validation_dataset = TimeSeries( - data=validation_data_clean, - time="time_idx", - target=[target_col], - group=group_cols, - weight="weight", - num=numerical_features, - cat=categorical_features if categorical_features else None, - known=known_features, - unknown=unknown_features, - static=static_features, - ) - - training_max_time_idx = training_data["time_idx"].max() + 1 - - return { - "training_dataset": training_dataset, - "validation_dataset": validation_dataset, - "training_max_time_idx": training_max_time_idx, - } - - @pytest.fixture( params=[ dict(), From 6129d333bc40e918d6f682a269401b691216964e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 15 Jun 2025 21:30:51 +0200 Subject: [PATCH 75/80] reverts and fixes --- pytorch_forecasting/models/base/__init__.py | 2 ++ pytorch_forecasting/models/base/_base_object.py | 8 ++++---- pytorch_forecasting/models/deepar/_deepar_pkg.py | 4 ++-- pytorch_forecasting/models/mlp/_decodermlp_pkg.py | 4 ++-- pytorch_forecasting/models/nbeats/_nbeats_pkg.py | 4 ++-- .../models/temporal_fusion_transformer/__init__.py | 4 ++++ .../{tft_v2_metadata.py => _tft_pkg_v2.py} | 2 +- pytorch_forecasting/models/tide/_tide_pkg.py | 4 ++-- pytorch_forecasting/models/timexer/_timexer_pkg.py | 4 ++-- 9 files changed, 21 insertions(+), 15 deletions(-) rename pytorch_forecasting/models/temporal_fusion_transformer/{tft_v2_metadata.py => _tft_pkg_v2.py} (98%) diff --git a/pytorch_forecasting/models/base/__init__.py b/pytorch_forecasting/models/base/__init__.py index 7b69ec246..a0ac824e1 100644 --- 
a/pytorch_forecasting/models/base/__init__.py +++ b/pytorch_forecasting/models/base/__init__.py @@ -10,11 +10,13 @@ from pytorch_forecasting.models.base._base_object import ( _BaseObject, _BasePtForecaster, + _BasePtForecasterV2, ) __all__ = [ "_BaseObject", "_BasePtForecaster", + "_BasePtForecasterV2", "AutoRegressiveBaseModel", "AutoRegressiveBaseModelWithCovariates", "BaseModel", diff --git a/pytorch_forecasting/models/base/_base_object.py b/pytorch_forecasting/models/base/_base_object.py index 81c91478b..12e37fcc5 100644 --- a/pytorch_forecasting/models/base/_base_object.py +++ b/pytorch_forecasting/models/base/_base_object.py @@ -11,7 +11,7 @@ class _BaseObject(_SkbaseBaseObject): pass -class _BasePtForecaster(_BaseObject): +class _BasePtForecaster_Common(_BaseObject): """Base class for all PyTorch Forecasting forecaster packages. This class points to model objects and contains metadata as tags. @@ -110,15 +110,15 @@ def create_test_instances_and_names(cls, parameter_set="default"): return objs, names -class _BasePtForecasterV1(_BasePtForecaster): +class _BasePtForecaster(_BasePtForecaster_Common): """Base class for PyTorch Forecasting v1 forecasters.""" _tags = { - "object_type": "forecaster_pytorch_v1", + "object_type": ["forecaster_pytorch", "forecaster_pytorch_v1"], } -class _BasePtForecasterV2(_BasePtForecaster): +class _BasePtForecasterV2(_BasePtForecaster_Common): """Base class for PyTorch Forecasting v2 forecasters.""" _tags = { diff --git a/pytorch_forecasting/models/deepar/_deepar_pkg.py b/pytorch_forecasting/models/deepar/_deepar_pkg.py index ffee0d328..eb30d639c 100644 --- a/pytorch_forecasting/models/deepar/_deepar_pkg.py +++ b/pytorch_forecasting/models/deepar/_deepar_pkg.py @@ -1,9 +1,9 @@ """DeepAR package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1 +from pytorch_forecasting.models.base._base_object import _BasePtForecaster -class DeepAR_pkg(_BasePtForecasterV1): +class DeepAR_pkg(_BasePtForecaster): """DeepAR package container.""" _tags = { diff --git a/pytorch_forecasting/models/mlp/_decodermlp_pkg.py b/pytorch_forecasting/models/mlp/_decodermlp_pkg.py index df5060088..917d99fb2 100644 --- a/pytorch_forecasting/models/mlp/_decodermlp_pkg.py +++ b/pytorch_forecasting/models/mlp/_decodermlp_pkg.py @@ -1,9 +1,9 @@ """DecoderMLP package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1 +from pytorch_forecasting.models.base._base_object import _BasePtForecaster -class DecoderMLP_pkg(_BasePtForecasterV1): +class DecoderMLP_pkg(_BasePtForecaster): """DecoderMLP package container.""" _tags = { diff --git a/pytorch_forecasting/models/nbeats/_nbeats_pkg.py b/pytorch_forecasting/models/nbeats/_nbeats_pkg.py index 837e709ce..7650749c2 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats_pkg.py +++ b/pytorch_forecasting/models/nbeats/_nbeats_pkg.py @@ -1,9 +1,9 @@ """NBeats package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1 +from pytorch_forecasting.models.base._base_object import _BasePtForecaster -class NBeats_pkg(_BasePtForecasterV1): +class NBeats_pkg(_BasePtForecaster): """NBeats package container.""" _tags = { diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index c823d6229..c6a45a95f 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ 
-11,6 +11,9 @@ InterpretableMultiHeadAttention, VariableSelectionNetwork, ) +from pytorch_forecasting.models.temporal_fusion_transformer._tft_pkg_v2 import ( + TFT_pkg, +) __all__ = [ "TemporalFusionTransformer", @@ -19,5 +22,6 @@ "GatedLinearUnit", "GatedResidualNetwork", "InterpretableMultiHeadAttention", + "TFT_pkg", "VariableSelectionNetwork", ] diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py similarity index 98% rename from pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py rename to pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py index 83f9a5b34..4799aa8f1 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/tft_v2_metadata.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py @@ -1,6 +1,6 @@ """TFT package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 +from pytorch_forecasting.models.base import _BasePtForecasterV2 class TFT_pkg(_BasePtForecasterV2): diff --git a/pytorch_forecasting/models/tide/_tide_pkg.py b/pytorch_forecasting/models/tide/_tide_pkg.py index 373a95525..67fdcf154 100644 --- a/pytorch_forecasting/models/tide/_tide_pkg.py +++ b/pytorch_forecasting/models/tide/_tide_pkg.py @@ -1,9 +1,9 @@ """TiDE package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1 +from pytorch_forecasting.models.base._base_object import _BasePtForecaster -class TiDEModel_pkg(_BasePtForecasterV1): +class TiDEModel_pkg(_BasePtForecaster): """Package container for TiDE Model.""" _tags = { diff --git a/pytorch_forecasting/models/timexer/_timexer_pkg.py b/pytorch_forecasting/models/timexer/_timexer_pkg.py index f22c2db0f..d1febcdb1 100644 --- a/pytorch_forecasting/models/timexer/_timexer_pkg.py +++ b/pytorch_forecasting/models/timexer/_timexer_pkg.py @@ -1,9 +1,9 @@ """TimeXer package container.""" -from pytorch_forecasting.models.base._base_object import _BasePtForecasterV1 +from pytorch_forecasting.models.base._base_object import _BasePtForecaster -class TimeXer_pkg(_BasePtForecasterV1): +class TimeXer_pkg(_BasePtForecaster): """TimeXer package container.""" _tags = { From 32ef57e763454ca12d9bd49846431ca24b057d38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 15 Jun 2025 21:31:24 +0200 Subject: [PATCH 76/80] v2 --- .../models/temporal_fusion_transformer/__init__.py | 2 +- .../models/temporal_fusion_transformer/_tft_pkg_v2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index c6a45a95f..8d700b641 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -22,6 +22,6 @@ "GatedLinearUnit", "GatedResidualNetwork", "InterpretableMultiHeadAttention", - "TFT_pkg", + "TFT_pkg_v2", "VariableSelectionNetwork", ] diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py index 4799aa8f1..b2a5f357f 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py @@ -3,7 +3,7 @@ from pytorch_forecasting.models.base import _BasePtForecasterV2 -class 
TFT_pkg(_BasePtForecasterV2): +class TFT_pkg_v2(_BasePtForecasterV2): """TFT package container.""" _tags = { From 93ea865f4c02bfc0d810707cd1ccadf7c9805781 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 15 Jun 2025 21:55:57 +0200 Subject: [PATCH 77/80] Update __init__.py --- .../models/temporal_fusion_transformer/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index 8d700b641..05fc8958a 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -1,5 +1,8 @@ """Temporal fusion transformer for forecasting timeseries.""" +from pytorch_forecasting.models.temporal_fusion_transformer._tft_pkg_v2 import ( + TFT_pkg_v2, +) from pytorch_forecasting.models.temporal_fusion_transformer._tft import ( TemporalFusionTransformer, ) @@ -11,9 +14,6 @@ InterpretableMultiHeadAttention, VariableSelectionNetwork, ) -from pytorch_forecasting.models.temporal_fusion_transformer._tft_pkg_v2 import ( - TFT_pkg, -) __all__ = [ "TemporalFusionTransformer", From 8e95e6e453f92a2ae87202a3f53a70c32bf8a39a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 15 Jun 2025 21:57:13 +0200 Subject: [PATCH 78/80] Update __init__.py --- .../models/temporal_fusion_transformer/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index 05fc8958a..66ba9b58a 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -1,11 +1,11 @@ """Temporal fusion transformer for forecasting timeseries.""" -from pytorch_forecasting.models.temporal_fusion_transformer._tft_pkg_v2 import ( - TFT_pkg_v2, -) from pytorch_forecasting.models.temporal_fusion_transformer._tft import ( TemporalFusionTransformer, ) +from pytorch_forecasting.models.temporal_fusion_transformer._tft_pkg_v2 import ( + TFT_pkg_v2, +) from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import ( AddNorm, GateAddNorm, From f990c8a4611792cd71881a97093d763d094ff1ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 15 Jun 2025 22:35:43 +0200 Subject: [PATCH 79/80] Update test_all_estimators_v2.py --- pytorch_forecasting/tests/test_all_estimators_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_forecasting/tests/test_all_estimators_v2.py b/pytorch_forecasting/tests/test_all_estimators_v2.py index 9e4ba316b..4b063ed44 100644 --- a/pytorch_forecasting/tests/test_all_estimators_v2.py +++ b/pytorch_forecasting/tests/test_all_estimators_v2.py @@ -114,4 +114,4 @@ def test_pkg_linkage(self, object_pkg, object_class): "The expected package name is " f"{object_class.__name__}_pkg." 
) - assert object_pkg.__name__ == object_class.__name__ + "_pkg", msg + assert object_pkg.__name__ == object_class.__name__ + "_pkg_v2", msg From 53747e0ebdf369ea14210064966868ac25cc9047 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sun, 15 Jun 2025 22:37:36 +0200 Subject: [PATCH 80/80] Update _tft_pkg_v2.py --- .../models/temporal_fusion_transformer/_tft_pkg_v2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py index b2a5f357f..5b9bfe6c7 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/_tft_pkg_v2.py @@ -28,8 +28,10 @@ def _get_test_datamodule_from(cls, trainer_kwargs): from pytorch_forecasting.data.data_module import ( EncoderDecoderTimeSeriesDataModule, ) - from pytorch_forecasting.tests._conftest import make_datasets_v2 - from pytorch_forecasting.tests._data_scenarios import data_with_covariates_v2 + from pytorch_forecasting.tests._data_scenarios import ( + data_with_covariates_v2, + make_datasets_v2, + ) data_with_covariates = data_with_covariates_v2()