Description
Discussion thread for the API re-design of pytorch-forecasting, for the next 1.X releases and towards 2.0. Comments appreciated from everyone!
Link to enhancement proposal: sktime/enhancement-proposals#39
Context and goals
High-level directions:
- the DSIPTS project (https://github.com/DSIP-FBK/DSIPTS) is to be used as an experimental branch for pytorch-forecasting 2.0. We will need to homogenize interfaces, consolidate design ideas, and ensure downwards compatibility.
- further inspiration can be taken from the thuml project, also see [ENH] neural network libraries in thuml time-series-library sktime#7243.
- @agobbifbk and a team at FBK with substantial time on this will drive this together with @fkiraly and developers at sktime.
High-level features for 2.0 with MoSCoW analysis:
- M: unified model API which is easily extensible and composable, similar to sktime and DSIPTS, but as close to the pytorch level as possible. The API need not cover forecasters in general, only torch based forecasters.
- M: unified monitoring and logging API, also see [API] redesign of logging and monitoring for 2.0 #1700
- M: extension templates need to be created
- S: skbase can be used to curate the forecasters as records, with tags, etc.
- S: model persistence
- C: third party extension patterns, so new models can "live" in other repositories or packages, for instance thuml
- M: reworked and unified data input API
- M: support static variables and categoricals
- S: support for multiple data input locations and formats - pandas, polars, distributed solutions etc
- M: MLops and benchmarking features as in DSIPTS
- S: support for pre-training, model hubs, foundation models, but this could be post-2.0
Meeting notes
Summary of discussion on Dec 20, 2024 and prior
FYI @agobbifbk, @thawn, @sktime/core-developers.
High-level directions:
- the DSIPTS project (https://github.com/DSIP-FBK/DSIPTS) be used as an experimental branch for
pytorch-forecasting 2.0
. We will need to homogenize interfaces, consolidate design ideas, and ensure downwards compatibility. - further inspiration can be taken from the
thuml
project, also see [ENH] neural network libraries in thuml time-series-library sktime#7243. - @agobbifbk and a team at FBK with substantial time on this will drive this together with @fkiraly and developers at
sktime
.
High-level features for 2.0 with MoSCoW analysis:
- M: unified model API which is easily extensible and composable, similar to
sktime
and DSIPTS, but as closely to thepytorch
level as possible. The API need not cover forecasters in general, onlytorch
based forecasters.- M: unified monitoring and logging API, also see [API] redesign of logging and monitoring for 2.0 #1700
- M: extension templates need to be created
- S:
skbase
can be used to curate the forecasters as records, with tags, etc - S: model persistence
- C: third party extension patterns, so new models can "live" in other repositories or packages, for instance
thuml
- M: reworked and unified data input API
- M: support static variables and categoricals
- S: support for multiple data input locations and formats - pandas, polars, distributed solutions etc
- M: MLops and benchmarking features as in DSIPTS
- S: support for pre-training, model hubs, foundation models, but this could be post-2.0
Todos:
0. update the DSIPTS documentation (README etc.) to signpost the above.
- highest priority: consolidated API design for the model and data layers.
- depending on the distance to current ptf and dsipts, use one or the other location for improvements (2.0 -> dsipts, 1.X -> ptf, as currently).
- ptf = stable and downwards compatible; dsipts = "playground"
- first step for that: side-by-side comparisons of code, defined core workflows
- planning sessions & sprints from Jan 2025
Roadmap planning Jan 15, 2025
👋 Attendees
- Andrea Gobbi
- Aryan Saini
- Franz Kiraly
- Benedikt Heidrich
- Sandeep Kumar
- Felix Hirwa Nshuti
Prioritization
👍👍👍👍👍👍👍👍👍👍👍👍👍
✔️✔️✔️
💬 💬 💬
data layer - dataset, dataloader 👍👍👍👍👍👍👍 💬 ✔️✔️✔️
- dataset and dataloader API consolidation
model layer - base classes, configs, unified API 👍👍👍👍 ✔️
- possibly more refined base classes, with proper documentation of them
- tests: input-output shapes of the batches
- operationalization of the models (inference must be a clear process/output)
- think about how model configs are stored, e.g., versioning, so that we know how to load model weights.
- easy interface for adding new architectures
foundation models, model hubs 👍 👍 👍
- think about how to handle pre-trained models in ptf and how to interface with them
- integration with model hubs (hf/kaggle/...)
- [IDEA] look into Chronos-Bolt and integrate it into the repo.
documentation 👍 ✔️
- adding tutorials and examples for new users
- proper documentation of base classes
benchmarking 👍 👍💬
- easy way to use benchmark datasets
- reproducibility, scalability, concept of an experiment (same dataset, different models/parameters to be compared)
mlops and scaling (distributed, cluster etc.) 👍 👍
- operationalization of the models (inference)
- hooks for: slurm cluster, multiprocess, Optuna
- think about how model configs are stored, e.g., versioning, so that we know how to load model weights.
more learning tasks supported
- [IDEA] continuous learning / active learning
- [IDEA] easy to convert DL architectures from regression to classification
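The "how model configs are stored / versioning" item (raised under both the model layer and mlops) could, for instance, look like the following: a config file that records the schema version it was written with, plus a loader that upgrades older versions before weights are restored. This is a minimal sketch under assumptions: `CONFIG_SCHEMA_VERSION`, `save_config`, `load_config` and the `past_len` -> `context_length` migration are all hypothetical, and the weights file itself (e.g. written with `torch.save`) is left out.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical schema version of the stored config; bump it whenever the
# meaning or layout of config fields changes.
CONFIG_SCHEMA_VERSION = 2


def save_config(path: Path, model_name: str, params: dict) -> None:
    """Write a model config together with the schema version it was written with."""
    record = {
        "schema_version": CONFIG_SCHEMA_VERSION,
        "model": model_name,
        "params": params,
    }
    path.write_text(json.dumps(record, indent=2, sort_keys=True))


def load_config(path: Path) -> dict:
    """Read a config and upgrade it to the current schema version, if needed."""
    record = json.loads(path.read_text())
    version = record.get("schema_version", 1)
    if version > CONFIG_SCHEMA_VERSION:
        raise ValueError(f"config written by a newer schema version ({version})")
    if version == 1:
        # hypothetical migration: v1 stored the context length as "past_len"
        params = dict(record["params"])
        params["context_length"] = params.pop("past_len")
        record = {**record, "schema_version": 2, "params": params}
    return record


with tempfile.TemporaryDirectory() as tmp:
    cfg_path = Path(tmp) / "config.json"
    save_config(cfg_path, "MyModel", {"context_length": 20, "horizon": 10})
    print(load_config(cfg_path)["params"])  # {'context_length': 20, 'horizon': 10}
```

Keeping an explicit version next to the config lets a loader fail loudly on configs from a newer release instead of silently mis-reading them.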
Tech meeting Jan 20, 2025
Attendees:
- Andrea
- Felix
- Franz
- Pranav B
- Sandeep
- Till
Agenda
- discussing agenda
- number of classes, dataset, dataloader, "bottleneck" idea
- __getitem__ output convention
- __init__ input convention(s)
- handling large data sets, pandas vs polars
References
Umbrella issue design: #1736
Notes
number of classes, dataset, dataloader, "bottleneck" idea
AG: we should make it as modular as possible
- avoid the "god class" anti-pattern (like the current ptf TimeSeriesDataSet that does everything), so more than one class?
- thinks there are two different ways to implement the dataloader, with respect to where we compute slices when re-sampling:
  - option 1: pre-compute in __init__
    - more memory intensive, clear distinction between train and inference; a naive implementation needs to load everything in memory
  - option 2: during sample generation, e.g., compute at __getitem__ time (dataset or dataloader)
    - feels this might be compute intensive, if we are recomputing and not caching etc.
  - need a decision?
- thinks the output of __getitem__ should be as general as possible
  - but we should keep the number of dictionary keys as small as possible, not overdesign or proliferate
- most architectures so far are more or less the same
- important input to this discussion: is the input to foundation models similar to that of classical models? Or do we need different dataloaders, perhaps even different datasets?
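To make AG's two options concrete, here is a minimal sketch (not DSIPTS or ptf code) of a sliding-window dataset over a single (time x channels) array, once with windows pre-computed in __init__ (option 1) and once sliced lazily in __getitem__ (option 2). The class names and the `context_length`/`horizon` parameters are hypothetical. Both classes implement the map-style protocol (`__len__`/`__getitem__`) that `torch.utils.data.DataLoader` accepts; they return NumPy arrays to keep the sketch free of a torch dependency.

```python
import numpy as np


class PrecomputedWindows:
    """Option 1: materialize every (past, future) window once in __init__."""

    def __init__(self, series: np.ndarray, context_length: int, horizon: int):
        n_windows = len(series) - context_length - horizon + 1
        # memory grows with n_windows * (context_length + horizon) * n_channels
        self.past = np.stack([series[i : i + context_length] for i in range(n_windows)])
        self.future = np.stack(
            [series[i + context_length : i + context_length + horizon] for i in range(n_windows)]
        )

    def __len__(self) -> int:
        return len(self.past)

    def __getitem__(self, idx: int) -> dict:
        return {"past": self.past[idx], "future": self.future[idx]}


class LazyWindows:
    """Option 2: keep only the raw series and slice it on every __getitem__ call."""

    def __init__(self, series: np.ndarray, context_length: int, horizon: int):
        self.series = series
        self.context_length = context_length
        self.horizon = horizon

    def __len__(self) -> int:
        return len(self.series) - self.context_length - self.horizon + 1

    def __getitem__(self, idx: int) -> dict:
        # basic slicing of a NumPy array returns a view, so no copy is made here
        start, split = idx, idx + self.context_length
        return {
            "past": self.series[start:split],
            "future": self.series[split : split + self.horizon],
        }


series = np.arange(20, dtype=float).reshape(10, 2)  # 10 time points, 2 channels
eager = PrecomputedWindows(series, context_length=4, horizon=2)
lazy = LazyWindows(series, context_length=4, horizon=2)
print(len(eager), len(lazy))  # 5 5
```

Option 1 pays memory proportional to the number of windows but makes each access a cheap lookup; option 2 stores the series once and pays a slice per access, which is cheap for an in-memory array but becomes more expensive once data has to be read from disk or transformed on the fly.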
FHN:
- prefers lazy computation at __getitem__ time, option 2
- need to ensure ptf can be used by people with large data!
- agrees with most points of AG otherwise
PB:
- concerned about difference in datasets, dataloaders for foundation models
- this question needs to be resolved before adopting an input/output API
S:
- need to discuss support for data files on the hard drive
- list of datasets seen:
- in sktime we are using the timesfm library
  - timesfm is by Google
  - maybe we can use or interface it
- we could have two kinds of output, point based vs interval or range prediction
  - in inference mode
  - we should ensure we can cover both
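One way to cover both output kinds with a single inference path is sketched below: the model produces sample paths, and the point and the interval forecast are both derived from them. The function name `summarize_forecast` and the output keys are hypothetical, and deriving intervals from samples is only one option (a model could equally emit quantiles directly). A deterministic model fits the same shape by passing a single path, in which case `lower`, `point` and `upper` coincide.

```python
import numpy as np


def summarize_forecast(samples: np.ndarray, coverage: float = 0.9) -> dict:
    """Turn sampled forecast paths into a point forecast plus a central interval.

    samples: array of shape (n_samples, horizon), e.g. draws from a probabilistic model.
    """
    alpha = (1.0 - coverage) / 2.0
    # np.quantile with a list of levels returns shape (3, horizon)
    lower, median, upper = np.quantile(samples, [alpha, 0.5, 1.0 - alpha], axis=0)
    return {"point": median, "lower": lower, "upper": upper}


rng = np.random.default_rng(0)
samples = rng.normal(loc=10.0, scale=1.0, size=(1000, 3))  # 1000 paths, horizon 3
out = summarize_forecast(samples)
print(out["point"].shape, out["lower"].shape)  # (3,) (3,)
```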
T:
- mostly agree with AG on important points
- additional idea on in-memory vs lazy loading
  - two data classes, one in-memory, one on-disk, with the same __getitem__ protocol
  - can use on-disk for training, in-memory for inference
  - ideally one class inherits from the other (or they have a common ancestor)
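T's proposal, two data classes with the same __getitem__ protocol and a common ancestor, might be sketched as follows. This is a minimal sketch under assumptions: the class names, the `_load` hook and the one-`.npy`-file-per-series layout are hypothetical; a real on-disk variant might use zarr, parquet or a database instead. Memory-mapping via `np.load(..., mmap_mode="r")` means the data is only read from disk when it is accessed.

```python
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path

import numpy as np


class BaseSeriesDataset(ABC):
    """Common ancestor: fixes the __getitem__ protocol, children decide where data lives."""

    @abstractmethod
    def __len__(self) -> int: ...

    @abstractmethod
    def _load(self, idx: int) -> np.ndarray:
        """Return the idx-th series as an array of shape (n_timepoints, n_channels)."""

    def __getitem__(self, idx: int) -> dict:
        return {"y": self._load(idx)}


class InMemorySeriesDataset(BaseSeriesDataset):
    def __init__(self, series: list):
        self.series = series

    def __len__(self) -> int:
        return len(self.series)

    def _load(self, idx: int) -> np.ndarray:
        return self.series[idx]


class OnDiskSeriesDataset(BaseSeriesDataset):
    """Assumes one ``.npy`` file per series; arrays are memory-mapped, not read eagerly."""

    def __init__(self, directory):
        self.files = sorted(Path(directory).glob("*.npy"))

    def __len__(self) -> int:
        return len(self.files)

    def _load(self, idx: int) -> np.ndarray:
        return np.load(self.files[idx], mmap_mode="r")


rng = np.random.default_rng(0)
series = [rng.random((5, 2)), rng.random((8, 2))]
with tempfile.TemporaryDirectory() as tmp:
    for i, s in enumerate(series):
        np.save(Path(tmp) / f"series_{i:03d}.npy", s)
    on_disk = OnDiskSeriesDataset(tmp)
    in_memory = InMemorySeriesDataset(series)
    # same protocol, same content, different storage
    same = all(np.array_equal(on_disk[i]["y"], in_memory[i]["y"]) for i in range(len(series)))
print(same)  # True
```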
FK:
- the idea of a "bottleneck" or "least common denominator" did not come up, surprised (it came up before)
- think we need at least one class, likely a DataSet for "raw time series" (a collection of them, with all metadata)
  - this can be used as a component of another dataset or data loader
  - conceptual model would be the "minimal information that is required to specify the abstract data model" of a collection of time series
  - alternatively, minimal for all deep learning forecasters ever seen somewhere on GitHub (which could be less minimal)
- Benedikt (not here today) also suggested this idea, and that DataSet-s could depend on each other
- current best guess for a structure:
  - "minimal" layer DataSet-s; these inherit from a common base and handle pandas as well as hard drive data
  - "specialized" DataSet-s; these could add re-sampling on top, normalization etc.
  - "specialized" DataLoader-s; these are specific to data sets and classes of neural networks
- alternative structure:
  - DataSet-s only have a minimal representation of "time series"
    - common base class and children for pandas/in-memory vs hard drive
  - everything else is done by "specialized" DataLoader-s that adapt data sets to neural networks
- T: one of the "final layer" classes (or middle layer classes) could be an adapter to the 1.X API of pytorch-forecasting (current), ensuring downwards compatibility.
- FK: big question for me is how many "layers" to have, e.g., two dataset layers and one data loader layer, or a single dataset layer and one data loader layer (where data loaders do more).
  - T: this will depend on the model, and on whether one wants 1.X compatibility; for that, one would need two data sets
  - FK: true for 1.X compatibility, but what about the long term: do we need the second dataset?
  - T: thinking; the 1.X API will not understand the dataset directly, so we will need to write adapter code.
  - FK: it does not follow that we need two layers of dataset, because we could have a dataloader that adapts the "lower layer" dataset directly to the loader consumed by the current ptf 1.X API
    - so there could be a dataloader that consumes the "minimal" dataset and produces the current dataloader interface
- T: had assumed we will use the standard pytorch dataloader; if that is the case, we will need two datasets for downwards compatibility.
  - FK: should we rewrite data loaders or not?
  - T: as long as we adhere to the torch API, it is fine to write custom loaders; there is no a-priori reason not to write our own
  - in this case, we also have the "option 1 vs option 2" problem (precompute vs lazy) that AG mentioned
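FK's "current best guess" (a minimal DataSet layer plus specialized layers built on it) can be illustrated with two small classes, where the specialized dataset holds the minimal one as a component. This is a hypothetical sketch, not an agreed API: `RawTimeSeries`, `WindowedDataset` and the dict keys are placeholders, and the specialized layer shown only does windowing (normalization etc. would go in the same place).

```python
import numpy as np


class RawTimeSeries:
    """'Minimal' layer: item i is the i-th series of a collection, plus metadata."""

    def __init__(self, series: list, group_ids: list):
        self.series = series  # each array has shape (n_timepoints_i, n_channels)
        self.group_ids = group_ids  # one id per series

    def __len__(self) -> int:
        return len(self.series)

    def __getitem__(self, idx: int) -> dict:
        return {"y": self.series[idx], "group": self.group_ids[idx]}


class WindowedDataset:
    """'Specialized' layer: built on top of a minimal dataset, yields past/future windows."""

    def __init__(self, raw: RawTimeSeries, context_length: int, horizon: int):
        self.raw = raw
        self.context_length = context_length
        self.horizon = horizon
        window = context_length + horizon
        # only (series index, window start) pairs are stored, not the windows themselves;
        # series shorter than one window contribute no pairs
        self.index = [
            (i, start)
            for i in range(len(raw))
            for start in range(len(raw[i]["y"]) - window + 1)
        ]

    def __len__(self) -> int:
        return len(self.index)

    def __getitem__(self, idx: int) -> dict:
        i, start = self.index[idx]
        item = self.raw[i]
        split = start + self.context_length
        return {
            "x_past": item["y"][start:split],
            "y_future": item["y"][split : split + self.horizon],
            "group": item["group"],
        }


raw = RawTimeSeries([np.zeros((6, 1)), np.ones((4, 1))], group_ids=["a", "b"])
windows = WindowedDataset(raw, context_length=2, horizon=1)
print(len(windows))  # 6
```

The specialized layer pre-computes only the index, so it sits between AG's option 1 and option 2: the (series, start) list is built in __init__, while the data itself is sliced lazily in __getitem__.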
FHN: if we keep using the vanilla torch dataloader, we need two dataset layers
- is this a contradiction to the dataset being "minimal"?
  - FK: thinks it is not a contradiction, since there are two layers of datasets
    - lower layer is "minimal" as discussed
    - 2nd layer is specialized and specific to neural network(s)
  - T: would make the 2nd layer optional, not all models will need it
  - FK: agree, have the feeling that foundation models do not need it (or need another, different, possibly simpler layer)
FK: feels there is convergence, but with two open questions:
- "precompute or lazy" for resampling ("option 1 and 2")
- custom data loaders, or not (with implications on lower layers)
  - if custom, then only one dataset layer needed
  - if vanilla, then two dataset layers needed
(__getitem__ format to be handled in the next agenda point)
- AG: there is one more complication: "stacked models", which are composites that use other models and their outputs to generate improved outputs
  - possible in dsipts but not well-architected (yet)
  - impacts dataset/dataloader because this may require holding data in memory (option 1 vs 2 discussion)
  - this is more of a general point, perhaps about models; implications on datasets/dataloaders are not clear, but perhaps relevant
- FK: we could have both options with a flag or two classes; this is really about the internals of the class and does not impact
  - number of classes
  - interfaces
  - it is an orthogonal question
  - T: prefers having different classes if we have "option 1 classes" vs "option 2 classes"
    - the in-memory class should be an optional adapter and does not change the __getitem__ output
- T: commenting about "stacked models"
  - sounds very useful if model output can easily act as input!
  - FK: like in sktime, this makes composition easy (and possible at all!)
  - specifically, it should be possible to plug inference outputs into dataloader input locations!
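A minimal sketch of AG's "stacked models" and of the requirement that inference outputs can be plugged back into input locations: the predictions of a first-stage model are appended as extra known-future channels of a DSIPTS-style sample (keys as in the DSIPTS MyDataset snippet). The helper name `add_predictions_as_covariates` is hypothetical, and treating first-stage predictions as known future covariates is only one of several possible stacking designs.

```python
import numpy as np


def add_predictions_as_covariates(sample: dict, y_hat: np.ndarray) -> dict:
    """Return a copy of ``sample`` whose future numerical inputs include ``y_hat``.

    sample["x_num_future"]: shape (n_timepoints_future, n_future_numerical)
    y_hat: first-stage predictions, shape (n_timepoints_future, n_targets)
    """
    if y_hat.shape[0] != sample["x_num_future"].shape[0]:
        raise ValueError("predictions must cover the same future time points")
    stacked = dict(sample)  # shallow copy, the original sample is left untouched
    stacked["x_num_future"] = np.concatenate([sample["x_num_future"], y_hat], axis=1)
    return stacked


sample = {"x_num_future": np.zeros((10, 3)), "y": np.zeros((10, 1))}
y_hat = np.full((10, 1), 0.5)  # e.g. output of a first-stage model
second_stage_input = add_predictions_as_covariates(sample, y_hat)
print(second_stage_input["x_num_future"].shape)  # (10, 4)
```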
strong opinions on using a vanilla dataloader and two dataset layers, vs a custom dataloader and one dataset layer
- T: community standard seems to be the vanilla dataloader, would have a slight preference due to that
- FHN: good to have a custom dataloader, since there are fewer dataset layers, but we would need to sort out the in-memory vs lazy issue
  - prefers the vanilla dataloader, because familiar with how to use it by default
- AG: weak preference for the vanilla dataloader, same as in dsipts (status quo)
- FHN: if we need caching during training, can do it via a custom dataloader
- S: thinks vanilla is much better, already fairly optimized by the community
  - extra features can use customization, e.g., dynamic batches
- FK: Benedikt prefers (moderately strongly) the "two dataset layer" design, I infer this from the GitHub conversation on the linked issue
__getitem__ output convention
- FHN: unsure
- T: "as simple as possible"
  - suggest to use a dict with arrays (tensors etc.) inside
  - should not be much more complicated
- S: do we have a clear picture of what should be there?
  - FK: we have decided on two dataset layers, so we need to answer this for both layers
    - lower "minimal" layer is probably key
    - intermediate layer implied by NN architecture
    - what the ADT of the "minimal" layer is, is already unclear:
      - single time series?
      - chunk of time series?
      - out of a collection of?
      - metadata?
- T: would prefer pure tensors
  - just a single tensor!
  - FK: what would this encode?
    - batch/time/channels
    - that's it?
  - AG: thinks this is too restrictive for a set of models
    - the lengths of input and output can be different
    - e.g., a context of 20 in the past, 10 in the future
  - FK: this is an argument for the middle layer, not the "minimal layer"
  - AG: question, what do you mean by layers?
    - does "minimal layer" mean autoregressive?
  - FK: by "minimal layer" I was thinking of the simplest possible dataset that is a dataset for all DL models
    - more concretely, it corresponds to the abstract data type that actual datasets have, e.g., a collection of time series of potentially different lengths
    - possibly with metadata
  - main reasoning: eventually it has to be tensors! That is the minimal encoding
    - FK: what about categoricals?
      - T: categoricals would be channels
      - T: which is which is metadata!
      - FK: would imply metadata tracking "already encoded" and "not yet encoded"
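T's "single tensor, categoricals as channels, metadata says which is which" idea can be sketched as a small encoder that packs named columns into one (time, channels) array and returns per-channel metadata; stacking samples then gives batch/time/channels. Assumptions: the function name `encode_to_tensor` and the metadata layout are hypothetical, the "c"/"n" kind codes are borrowed from the type codes in FK's TimeSeries docstring, and NumPy stands in for torch.

```python
import numpy as np


def encode_to_tensor(columns: dict, categorical: set) -> tuple:
    """Pack named 1D columns of equal length into one (time, channels) array.

    Categorical columns are replaced by integer codes; the returned metadata
    records, per channel, its name, its kind ("c" or "n") and, for categoricals,
    the category list needed to decode the codes again.
    """
    channels, meta = [], []
    for name, values in columns.items():
        if name in categorical:
            categories, codes = np.unique(np.asarray(values), return_inverse=True)
            channels.append(codes.astype(float))
            meta.append({"name": name, "kind": "c", "categories": [str(c) for c in categories]})
        else:
            channels.append(np.asarray(values, dtype=float))
            meta.append({"name": name, "kind": "n"})
    return np.stack(channels, axis=-1), meta


columns = {"sales": [3.0, 4.0, 5.0, 4.5], "weekday": ["mon", "tue", "mon", "wed"]}
x, meta = encode_to_tensor(columns, categorical={"weekday"})
batch = np.stack([x, x])  # (batch, time, channels)
print(batch.shape)  # (2, 4, 2)
```

This encoding has a single time axis, so AG's objection (e.g., 20 past and 10 future steps) is not addressed by it; that case needs a second array or the middle layer FK mentions.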
Tech meeting Jan 24, 2025
Attendees:
- Andrea
- Aryan
- Franz
- Jigyasu
- Pranav B
- Satvik
- Sandeep
- whatdoes12
Notes
- discussion of input
  - need adaptation for distributed inputs or hard drive
  - FK: can be dealt with by varying inputs for the same output
Recap
- need to define dataset/dataloader layers
  - last time, strong opinion for dataloader = vanilla, two dataset layers
  - two "sides" need to be discussed: the "input API", __init__, and the "output API", __getitem__
- FK: suggest to focus on the output first
  - input can vary, so we can adapt pandas and zarr input etc.
__getitem__ designs based on last time's discussion
AG design suggestion
- two DataSets (__getitem__):
  - simple: x_num: [length x channels]
Current DSIPTS str
```python
from typing import Union

import numpy as np
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(
        self,
        data: dict,
        t: np.ndarray,
        groups: np.ndarray,
        idx_target: Union[np.ndarray, None],
        idx_target_future: Union[np.ndarray, None],
    ) -> None:
        """Extension of Dataset class. While training, the returned item is a batch containing the standard keys.

        Args:
            data (dict): a dictionary. Each key is a np.array containing the data. The keys are:
                y : the target variable(s)
                x_num_past: the numerical past variables
                x_num_future: the numerical future variables
                x_cat_past: the categorical past variables
                x_cat_future: the categorical future variables
                idx_target: index of target features in the past array
            t (np.array): the time array related to the target variables
            idx_target (Union[np.array,None]): you can specify the index in the past data that represents the input features (for differential analysis or detrending strategies)
            idx_target_future (Union[np.array,None]): you can specify the index in the future data that represents the input features (for differential analysis or detrending strategies)
        """
        # store the inputs; __getitem__ reads them back
        self.data = data
        self.t = t
        self.groups = groups
        self.idx_target = idx_target
        self.idx_target_future = idx_target_future

    def __len__(self) -> int:
        # assumption: the first axis of every array in ``data`` is the sample axis
        return len(self.data["y"])

    def __getitem__(self, idxs):
        sample = {}
        for k in self.data:
            sample[k] = self.data[k][idxs]
        if self.idx_target is not None:
            sample["idx_target"] = self.idx_target
        if self.idx_target_future is not None:
            sample["idx_target_future"] = self.idx_target_future
        return sample


# Sampling via ``__getitem__`` returns a dictionary,
# which always has the following str-keyed entries:
#   y           : (n_timepoints_future, n_targets)
#   x_num_past  : (n_timepoints_past, n_targets + n_past_covariates_numerical)
#   x_num_future: (n_timepoints_future, n_future_numerical)
#   x_cat_past  : (n_timepoints_past, n_past_covariates_categorical)
#   x_cat_future: (n_timepoints_future, n_future_covariates_categorical)
#   idx_target  : list containing the column indexes of x_num_past corresponding to y
# dsipts neural networks currently do not use t, so it is not passed!

# The input to ``__init__`` expects a dictionary:
#   y           : (n_samples, n_timepoints_future, n_targets)
#   x_num_past  : (n_samples, n_timepoints_past, n_targets + n_past_covariates_numerical)
#   x_num_future: (n_samples, n_timepoints_future, n_future_numerical)
#   x_cat_past  : (n_samples, n_timepoints_past, n_past_covariates_categorical)
#   x_cat_future: (n_samples, n_timepoints_future, n_future_covariates_categorical)
#   t           :
#   idx_target  : list containing the column indexes of x_num_past corresponding to y
```
FK comments:
this looks like the top layer. It is closer to the "raw" or "bottleneck" layer, but it already has the data resampled.
The "sample" index is the first index in the input to __init__.
FK opinion: the resampling should be part of a pipeline to prepare a data loader.
So we have different artefacts:
- A: raw data; y and x have not been split, resampled, etc.
- B: resampled data; input to the dsipts DataSet. Obtained from raw data via a resampling/normalization utility
- C: DataLoader using the output of __getitem__
observation:
pytorch-forecasting covers A-C in a single DataSet
DSIPTS covers B-C in a single DataSet, and A-B in utilities
FK: last time, we agreed we should have two DataSet layers,
but none of the current solutions has the "bottleneck" layer
- ptf should take a DataSet instead of a DataFrame
  - this would just convert the df into the "bottleneck" format
- DSIPTS has A-B outside torch-idiomatic structures
  - FK not sure how to address this best
  - option: putting preprocessing into the DataSet, as currently
    - perhaps not a good idea as it adds features to the input
  - option: one more DataSet between points A and B?
    - if we start at the "bottleneck", we enable hard drive, zarr, etc., because that will have a conversion layer
- alternatively, we could have a custom class handle conversions up to the dataloader format, or the input required for the dataset closest to the model
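FK's step from A (raw data) to B (resampled data) can be sketched as a stand-alone utility that cuts one raw series into the (n_samples, n_timepoints, channels) arrays expected by the DSIPTS MyDataset __init__ shown above. This is a hypothetical helper, not the DSIPTS implementation: it handles a single series and only numerical variables (no categoricals, no normalization), and it uses the known covariates as both past and future inputs. Because MyDataset takes idx_target as a separate constructor argument, the sketch returns it separately from the data dict.

```python
import numpy as np


def resample(y: np.ndarray, x_known: np.ndarray, past_steps: int, future_steps: int):
    """Cut one raw series (artefact A) into windows (artefact B).

    y:       (n_timepoints, n_targets) raw target values
    x_known: (n_timepoints, n_future_numerical) numerical covariates known in advance
    Returns the data dict and idx_target (targets are the first columns of x_num_past).
    """
    n_samples = len(y) - past_steps - future_steps + 1
    starts = np.arange(n_samples)[:, None]
    past = starts + np.arange(past_steps)  # (n_samples, past_steps) integer index
    future = starts + past_steps + np.arange(future_steps)  # (n_samples, future_steps)
    data = {
        "y": y[future],  # (n_samples, future_steps, n_targets)
        "x_num_past": np.concatenate([y[past], x_known[past]], axis=-1),
        "x_num_future": x_known[future],  # (n_samples, future_steps, n_future_numerical)
    }
    idx_target = np.arange(y.shape[1])
    return data, idx_target


T = 30
y = np.arange(T, dtype=float)[:, None]  # one target
x_known = np.stack([np.arange(T) % 7, np.arange(T) % 24], axis=1).astype(float)
data, idx_target = resample(y, x_known, past_steps=20, future_steps=5)
print(data["x_num_past"].shape, data["y"].shape)  # (6, 20, 3) (6, 5, 1)
```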
FK design suggestion
```python
from torch.utils.data import Dataset


class TimeSeries(Dataset):
    """PyTorch Dataset for storing raw time series from a pandas DataFrame.

    This dataset follows the base raw time series dataset API in pytorch-forecasting.

    A single sample corresponds to the i-th time series instance in the dataset.

    Sampling via ``__getitem__`` returns a dictionary,
    which always has the following str-keyed entries:

    * t: tensor of shape (n_timepoints)
      Time index for each time point in the past or present. Aligned with ``y``,
      and ``x`` not ending in ``f``.
    * y: tensor of shape (n_timepoints, n_targets)
      Target values for each time point. Rows are time points, aligned with ``t``.
      Columns are targets, aligned with ``y_cols``.
    * x: tensor of shape (n_timepoints, n_features)
      Features for each time point. Rows are time points, aligned with ``t``.
    * group: tensor of shape (n_groups)
      Group ids for time series instance.
    * st: tensor of shape (n_static_features)
      Static features.
    * y_cols: list of str of length (n_targets)
      Names of columns of ``y``, in same order as columns in ``y``.
    * x_cols: list of str of length (n_features)
      Names of columns of ``x``, in same order as columns in ``x``.
    * st_cols: list of str of length (n_static_features)
      Names of entries of ``st``, in same order as entries in ``st``.
    * y_types: list of str of length (n_targets)
      Types of columns of ``y``, in same order as columns in ``y``.
      Types can be "c" for categorical, "n" for numerical.
    * x_types: list of str of length (n_features)
      Types of columns of ``x``, in same order as columns in ``x``.
      Types can be "c" for categorical, "n" for numerical.
    * st_types: list of str of length (n_static_features)
      Types of entries of ``st``, in same order as entries in ``st``.
    * x_k: list of int of length (n_features)
      Whether the feature is known in the future, encoded by 0 or 1,
      in same order as columns in ``x``.
      0 means the feature is not known in the future, 1 means it is known.

    Optionally, the following str-keyed entries can be included:

    * t_f: tensor of shape (n_timepoints_future)
      Time index for each time point in the future.
      Aligned with ``x_f``.
    * x_f: tensor of shape (n_timepoints_future, n_features)
      Known features for each time point in the future.
      Rows are time points, aligned with ``t_f``.
    * weight: tensor of shape (n_timepoints), only if weight is not None
    * weight_f: tensor of shape (n_timepoints_future), only if weight is not None

    Parameters
    ----------
    data : pd.DataFrame
        data frame with sequence data.
        Column names must all be str, and contain str as referred to below.
    data_future : pd.DataFrame, optional, default=None
        data frame with future data.
        Column names must all be str, and contain str as referred to below.
        May contain only columns that are in time, group, weight, known, or static.
    time : str, optional, default = first col not in group, weight, target, static.
        integer typed column denoting the time index within ``data``.
        This column is used to determine the sequence of samples.
        If there are no missing observations,
        the time index should increase by ``+1`` for each subsequent sample.
        The first time_idx for each series does not necessarily
        have to be ``0`` but any value is allowed.
    target : str or List[str], optional, default = last column (at iloc -1)
        column(s) in ``data`` denoting the forecasting target.
        Can be categorical or numerical dtype.
    group : List[str], optional, default = None
        list of column names identifying a time series instance within ``data``.
        This means that the ``group`` together uniquely identify an instance,
        and ``group`` together with ``time`` uniquely identify a single observation
        within a time series instance.
        If ``None``, the dataset is assumed to be a single time series.
    weight : str, optional, default=None
        column name for weights.
        If ``None``, it is assumed that there is no weight column.
    num : list of str, optional, default = all columns with dtype in "fi"
        list of numerical variables in ``data``,
        list may also contain list of str, which are then grouped together.
    cat : list of str, optional, default = all columns with dtype in "Obc"
        list of categorical variables in ``data``,
        list may also contain list of str, which are then grouped together
        (e.g. useful for product categories).
    known : list of str, optional, default = all variables
        list of variables that change over time and are known in the future,
        list may also contain list of str, which are then grouped together
        (e.g. useful for special days or promotion categories).
    unknown : list of str, optional, default = no variables
        list of variables that are not known in the future,
        list may also contain list of str, which are then grouped together
        (e.g. useful for weather categories).
    static : list of str, optional, default = all variables not in known, unknown
        list of variables that do not change over time,
        list may also contain list of str, which are then grouped together.
    """
```