Skip to content

[ENH] [WIP] Standardize model output to 4d tensor #1895

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 225 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 214 commits
Commits
Show all changes
225 commits
Select commit Hold shift + click to select a range
b3644a6
test suite
fkiraly Feb 22, 2025
a1d64c6
Merge branch 'main' into test-suite
fkiraly Feb 22, 2025
4b2486e
skeleton
fkiraly Feb 22, 2025
02b0ce6
skeleton
fkiraly Feb 22, 2025
41cbf66
Update test_all_estimators.py
fkiraly Feb 23, 2025
cef62d3
Update _base_object.py
fkiraly Feb 23, 2025
bc2e93b
Update _lookup.py
fkiraly Feb 23, 2025
eee1c86
Update _lookup.py
fkiraly Feb 23, 2025
164fe0d
base metadatda
fkiraly Feb 23, 2025
20e88d0
registry
fkiraly Feb 23, 2025
318c1fb
fix private name
fkiraly Feb 23, 2025
012ab3d
Update _base_object.py
fkiraly Feb 23, 2025
86365a0
test failure
fkiraly Feb 23, 2025
f6dee46
Update test_all_estimators.py
fkiraly Feb 23, 2025
9b0e4ec
Update test_all_estimators.py
fkiraly Feb 23, 2025
7de5285
Update test_all_estimators.py
fkiraly Feb 23, 2025
57dfe3a
test folders
fkiraly Feb 23, 2025
c9f12db
Update test.yml
fkiraly Feb 23, 2025
fa8144e
test integration
fkiraly Feb 23, 2025
232a510
fixes
fkiraly Feb 23, 2025
1c8d4b5
Update _conftest.py
fkiraly Feb 23, 2025
f632e32
try scenarios
fkiraly Feb 23, 2025
252598d
D1, D2 layer commit
phoeenniixx Apr 6, 2025
d0d1c3e
remove one comment
phoeenniixx Apr 6, 2025
80e64d2
model layer commit
phoeenniixx Apr 6, 2025
6364780
update docstring
phoeenniixx Apr 6, 2025
82b3dc7
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 6, 2025
257183c
update data_module.py
phoeenniixx Apr 10, 2025
9cdcb19
update data_module.py
phoeenniixx Apr 10, 2025
a83bf32
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
ac56d4f
Add disclaimer
phoeenniixx Apr 10, 2025
0e7e36f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
4bfff21
update docstring
phoeenniixx Apr 11, 2025
ef98273
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 11, 2025
8a53ed6
Add tests for D1,D2 layer
phoeenniixx Apr 19, 2025
9f9df31
Merge branch 'main' into refactor-d1-d2
phoeenniixx Apr 19, 2025
cdecb77
Code quality
phoeenniixx Apr 19, 2025
86360fd
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 19, 2025
20aafb7
refactor file
fkiraly Apr 30, 2025
043820d
warning
fkiraly Apr 30, 2025
1720a15
linting
fkiraly May 1, 2025
af44474
move coercion to utils
fkiraly May 1, 2025
a3cb8b7
linting
fkiraly May 1, 2025
75d7fb5
Update _timeseries_v2.py
fkiraly May 1, 2025
1b946e6
Update __init__.py
fkiraly May 1, 2025
3edb08b
Update __init__.py
fkiraly May 1, 2025
a4bc9d8
Merge branch 'main' into pr/1811
fkiraly May 1, 2025
4c0d570
Merge branch 'pr/1811' into pr/1812
fkiraly May 1, 2025
ef37f55
Merge branch 'main' into test-suite
fkiraly May 1, 2025
a669134
Update _lookup.py
fkiraly May 4, 2025
d78bf5d
Update _lookup.py
fkiraly May 4, 2025
e350291
update tests
phoeenniixx May 11, 2025
f90c94f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 11, 2025
3099691
update tft_v2
phoeenniixx May 11, 2025
77cb979
warnings and init attr handling
fkiraly May 13, 2025
28df3c3
Merge branch 'refactor-d1-d2' of https://github.com/phoeenniixx/pytor…
fkiraly May 13, 2025
f8c94e6
simplify TimeSeries.__getitem__
fkiraly May 13, 2025
c289255
Update _timeseries_v2.py
fkiraly May 13, 2025
9467f38
Update data_module.py
fkiraly May 13, 2025
c3b40ad
backwards compat of private/public attrs
fkiraly May 13, 2025
c007310
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 13, 2025
2e25052
Merge branch 'main' into refactor-model
phoeenniixx May 13, 2025
38c28dc
add tests
phoeenniixx May 14, 2025
9d80eb8
add tests
phoeenniixx May 14, 2025
a8ccfe3
add tests
phoeenniixx May 14, 2025
f900ba5
add more docstrings
phoeenniixx May 14, 2025
ed1b799
add note about the commented out tests
phoeenniixx May 14, 2025
fd59bac
initial commit - design decisions pending
May 15, 2025
c947910
Merge branch 'main' into refactor-model
phoeenniixx May 16, 2025
c0ceb8a
add the commented out tests
phoeenniixx May 16, 2025
3828c26
remove note
phoeenniixx May 16, 2025
0cb7df6
add _prepare_metadata function to d2 class
May 18, 2025
6d6d18e
Merge branch 'main' into refactor-model
phoeenniixx May 18, 2025
b32ed0e
add data processing and window mechanism for data module
May 18, 2025
cb55578
implement scaling, normalization in preprocessing, and windowing
May 18, 2025
2218336
create new tslib dataset, dataloader setup and collate_fn
PranavBhatP May 20, 2025
3144865
Merge branch 'test-suite' of https://github.com/sktime/pytorch-foreca…
phoeenniixx May 20, 2025
89231d9
Merge branch 'main' into tslib-d2-refactor
PranavBhatP May 20, 2025
30b541b
make the modules private
phoeenniixx May 20, 2025
3f1e11f
Merge remote-tracking branch 'origin/refactor-model' into refactor-model
phoeenniixx May 20, 2025
5cc3ff1
initial commit
phoeenniixx May 20, 2025
1bcf181
Merge branch 'refactor-model' into test-framework
phoeenniixx May 20, 2025
f18e09d
add TFTMetadata class
phoeenniixx May 20, 2025
e1e360e
add TFTMetadata class
phoeenniixx May 20, 2025
1fd0594
Merge remote-tracking branch 'phoeenniixx/refactor-model' into base-m…
PranavBhatP May 22, 2025
168e16a
Merge branch 'main' into test-framework
phoeenniixx May 22, 2025
b25ced9
implement base model for tslib
PranavBhatP May 22, 2025
f4afe90
add layer directory to store different architecture layer implementat…
PranavBhatP May 22, 2025
ae12c69
Merge branch 'main' into tslib-d2-refactor
PranavBhatP May 22, 2025
04e1a45
add model layers and timexer sample model code
PranavBhatP May 22, 2025
a27905d
format timexer file with proper linting
PranavBhatP May 22, 2025
e87c25b
complete the pipeline for tslib data module
PranavBhatP May 23, 2025
92c12bf
add TFT tests
phoeenniixx May 25, 2025
1831bcb
add example notebook and fix buy in _timexer.py
PranavBhatP May 25, 2025
a896b3f
clear cell outputs on trainer.fit()
PranavBhatP May 25, 2025
420de37
remove unnecessary squeeze method and add prediction demo in example …
PranavBhatP May 26, 2025
3b07263
restructure layers directory
PranavBhatP May 27, 2025
010298e
delete empty modules.py
PranavBhatP May 27, 2025
d0aa444
add warning and move tslib base model to a new file
PranavBhatP May 27, 2025
fef4113
fix wrong import statement in _timexer.py
PranavBhatP May 27, 2025
1d478d5
remove refactored TFT
phoeenniixx May 27, 2025
f9992f2
Merge branch 'main' into test-framework
phoeenniixx May 28, 2025
d049019
update test_all_estimators
phoeenniixx May 28, 2025
e72486b
linting
phoeenniixx May 28, 2025
8daeb95
Merge branch 'main' into tslib-d2-refactor
PranavBhatP May 28, 2025
5142d52
fix circular dependency error in en_embedding.py
PranavBhatP May 28, 2025
efbbc09
Merge branch 'main' into tslib-d2-refactor
PranavBhatP May 28, 2025
7443b0b
Merge branch 'main' into test-framework
phoeenniixx May 29, 2025
a734f26
refactor
phoeenniixx May 29, 2025
7f466b2
Add more test_params
phoeenniixx May 29, 2025
8a680df
add prelimnary tests for tslib d2
PranavBhatP May 29, 2025
0ccb078
add collate, setup and dataset tests
PranavBhatP May 29, 2025
826ac31
fix failing setup and tslib_dataset tests
PranavBhatP May 30, 2025
d70b07c
fix incorrect metadata handling in tslib dataset and fix tests for co…
PranavBhatP May 30, 2025
5ce4553
Merge branch 'main' into tslib-d2-refactor
PranavBhatP May 30, 2025
7b41140
fix code to comply with new linting syntax rules
PranavBhatP May 30, 2025
e3e5bb8
add tests for checking custom train test split
PranavBhatP May 30, 2025
d67ccae
add tests for multitarget dataset and fix handling of incosistent dty…
PranavBhatP May 30, 2025
0968452
Add metadata tests
phoeenniixx May 31, 2025
525bbb9
Merge branch 'main' into test-framework
phoeenniixx Jun 1, 2025
4267da6
Merge branch 'main' into test-framework
phoeenniixx Jun 1, 2025
4e8f863
add object-filter to ptf-v1
phoeenniixx Jun 1, 2025
5f79e25
add warnings for hyperparams and refactor d_model param to hidden_siz…
PranavBhatP Jun 2, 2025
7fb048b
Merge branch 'main' into tslib-d2-refactor
PranavBhatP Jun 2, 2025
4bfec1b
dummy commit to trigger code-quality checks
PranavBhatP Jun 2, 2025
7510509
refactor logic for handling shuffling at the series level
PranavBhatP Jun 4, 2025
1a52579
add assert statement for validation of time series indices
PranavBhatP Jun 4, 2025
693fbd2
add handling for time series datasets of small size.
PranavBhatP Jun 4, 2025
10b1e4a
change layers module files to private
PranavBhatP Jun 4, 2025
0fab57a
refactor notebook and change _timexer.py to fix broken import blocks
PranavBhatP Jun 4, 2025
668c901
remove features parameter and precompute the feature mode (S/MS/M) in…
PranavBhatP Jun 4, 2025
7c2855c
Merge branch 'main' into tslib-d2-refactor
PranavBhatP Jun 4, 2025
9d62ff0
Merge branch 'main' into tslib-d2-refactor
PranavBhatP Jun 5, 2025
8cb1484
initial setup for dlinear
PranavBhatP Jun 5, 2025
c117092
Merge branch 'main' into test-framework
phoeenniixx Jun 5, 2025
4845c9b
complete functional dlinear pipeline
PranavBhatP Jun 5, 2025
7f0495d
Merge branch 'main' into tslib-dlinear-model
PranavBhatP Jun 5, 2025
943151b
Merge branch 'main' into tslib-d2-refactor
PranavBhatP Jun 5, 2025
228ebed
Merge branch 'main' into tslib-d2-refactor
PranavBhatP Jun 5, 2025
b101e2e
Merge branch 'main' into tslib-dlinear-model
PranavBhatP Jun 5, 2025
913c418
remove unused import in _base_model_v2
PranavBhatP Jun 6, 2025
8990e8b
add function to validate presence of continous and categorical indices
PranavBhatP Jun 6, 2025
f6d39fe
Merge branch 'main' into test-framework
phoeenniixx Jun 6, 2025
2c518ee
add new base classes
phoeenniixx Jun 6, 2025
7a5c58f
remove try block
phoeenniixx Jun 8, 2025
cb3e944
Merge branch 'main' into test-framework
phoeenniixx Jun 8, 2025
d0009ff
fix bug in quantile predictions for timexer
PranavBhatP Jun 9, 2025
c9d3c26
refactor tensor handling to return pure tensor outputs with timexer a…
PranavBhatP Jun 9, 2025
33a99d1
Merge branch 'main' into tslib-d2-refactor
PranavBhatP Jun 9, 2025
927ee49
Merge branch 'tslib-d2-refactor' into tslib-dlinear-model
PranavBhatP Jun 9, 2025
1829de5
restore v1 version of _timexer
PranavBhatP Jun 9, 2025
8fc1865
change import statement for new location of timexer v2 in notebook
PranavBhatP Jun 9, 2025
ebe8d22
changed tslib_data_module to private
PranavBhatP Jun 9, 2025
01b2d78
change import statement for new location of tslib data module in note…
PranavBhatP Jun 9, 2025
8fd608b
Merge branch 'tslib-d2-refactor' into tslib-dlinear-model
PranavBhatP Jun 9, 2025
3b9de6d
add support for multiple datamodules
phoeenniixx Jun 9, 2025
032a7b0
typo
phoeenniixx Jun 9, 2025
4d9a19a
Merge branch 'main' into test-framework
phoeenniixx Jun 9, 2025
cef0292
implement quantile prediction support
PranavBhatP Jun 10, 2025
35c6973
fix quantile predictions bug and improve code readability
PranavBhatP Jun 11, 2025
1749cd2
fix feature mode handling in tslib data module and add error handling…
PranavBhatP Jun 11, 2025
2d86134
Merge branch 'tslib-d2-refactor' into tslib-dlinear-model
PranavBhatP Jun 11, 2025
cd477e6
handle single target variable casE
PranavBhatP Jun 11, 2025
7f8fca8
fix validation of empty cont and cat indices to allow univariate fore…
PranavBhatP Jun 11, 2025
953214e
Merge branch 'tslib-d2-refactor' into tslib-dlinear-model
PranavBhatP Jun 11, 2025
bdaecc2
Merge branch 'test-framework' of https://github.com/phoeenniixx/pytor…
PranavBhatP Jun 11, 2025
8e5864c
add metadata class for timexer but tests not running
PranavBhatP Jun 11, 2025
327919c
implement metadata container for v2 of timexer
PranavBhatP Jun 11, 2025
3809ad5
fix minor bug by simplifying the window creation and removing redunda…
PranavBhatP Jun 11, 2025
10c9290
Merge branch 'tslib-d2-refactor' into tslib-dlinear-model
PranavBhatP Jun 11, 2025
212e01d
improve notebook content
PranavBhatP Jun 11, 2025
12a79a8
remove 'gpu' param from trainer creation step
PranavBhatP Jun 11, 2025
75220be
add basic fixtures for timexer v2 tests
PranavBhatP Jun 12, 2025
03c06e8
Merge branch 'main' into test-framework
phoeenniixx Jun 12, 2025
8b0087e
linting
phoeenniixx Jun 12, 2025
d328fae
Merge branch 'main' into test-framework
phoeenniixx Jun 13, 2025
6e4e692
add tests for timexer in v2
PranavBhatP Jun 13, 2025
0003c54
minor changes to forecast and forecast multi to deal with explicit en…
PranavBhatP Jun 13, 2025
79b5682
Merge branch 'main' into tslib-d2-refactor
PranavBhatP Jun 13, 2025
b422af6
Merge branch 'tslib-d2-refactor' into tslib-dlinear-model
PranavBhatP Jun 13, 2025
2546410
add endogenous and exogenous feature handling for forecast_multi
PranavBhatP Jun 13, 2025
55e1869
Merge branch 'test-framework' of https://github.com/phoeenniixx/pytor…
PranavBhatP Jun 13, 2025
baf9a61
Merge branch 'main' into tslib-d2-refactor
PranavBhatP Jun 15, 2025
0352465
Merge branch 'main' into tslib-dlinear-model
PranavBhatP Jun 15, 2025
7c6385c
rename
fkiraly Jun 15, 2025
66a006c
renames, exports
fkiraly Jun 15, 2025
5871bfa
Merge branch 'main' into pr/1836
fkiraly Jun 15, 2025
eb0dfa5
reverts
fkiraly Jun 15, 2025
6fac509
revert
fkiraly Jun 15, 2025
ee1edf5
revert
fkiraly Jun 15, 2025
ff85e69
move
fkiraly Jun 15, 2025
377f416
Update test_tslib_data_module.py
fkiraly Jun 15, 2025
14db380
make layer folders all private
fkiraly Jun 15, 2025
7eaee38
Update _timexer_v2.py
fkiraly Jun 15, 2025
06ae102
Update _tslib_base_model_v2.py
fkiraly Jun 15, 2025
faf86ce
imports
fkiraly Jun 15, 2025
65e171e
revert
fkiraly Jun 15, 2025
3786a4c
Update _timexer_pkg_v2.py
fkiraly Jun 15, 2025
d3e6ce2
imports
fkiraly Jun 15, 2025
e8d72c1
Merge branch 'main' into tslib-dlinear-model
PranavBhatP Jun 16, 2025
56ea5af
revert BaseForecaster package parent class
PranavBhatP Jun 16, 2025
590beff
revert tide pkg base class
PranavBhatP Jun 16, 2025
081b840
Merge branch 'tslib-d2-refactor' into tslib-dlinear-model
PranavBhatP Jun 16, 2025
ffe6983
fix bug in moving average when kernel_size is even
PranavBhatP Jun 16, 2025
6882e23
add package container for dlinear model in v2
PranavBhatP Jun 16, 2025
87ceb10
remove redundant code for _forecast_multi
PranavBhatP Jun 16, 2025
fbb234a
Merge branch 'tslib-d2-refactor' into tslib-dlinear-model
PranavBhatP Jun 16, 2025
cc10dd5
remove redundant feaure mode
PranavBhatP Jun 16, 2025
369ed49
add test suite for dlinear model
PranavBhatP Jun 17, 2025
9d94548
add standardise output format function in BaseModel to enforce 4d ten…
PranavBhatP Jun 18, 2025
2d45ac6
refactor logic to handle inputs of arbitrary and standardise the outp…
PranavBhatP Jun 20, 2025
2b5b702
add contract for handling standard model output format in docstring
PranavBhatP Jun 20, 2025
ed71413
Merge branch 'main' into pr/1874
fkiraly Jun 20, 2025
b5aab41
Merge branch 'pr/1874' into pr/1895
fkiraly Jun 20, 2025
8a0e673
make name changes to DLinear
PranavBhatP Jun 22, 2025
9a2979c
Merge branch 'main' into tslib-dlinear-model
PranavBhatP Jun 22, 2025
fa8cc3b
Merge branch 'tslib-dlinear-model' of https://www.github.com/PranavBh…
PranavBhatP Jun 22, 2025
7e7688e
update DLinearModel -> DLinear for model name
PranavBhatP Jun 22, 2025
ea827a8
rename layers folder to private module
PranavBhatP Jun 22, 2025
06b898b
Merge branch 'main' into standardise-output-format
PranavBhatP Jun 22, 2025
5994af3
Merge branch 'tslib-dlinear-model' into standardise-output-format
PranavBhatP Jun 22, 2025
e82ab68
Merge branch 'standardise-output-format' of https://www.github.com/Pr…
PranavBhatP Jun 22, 2025
1f60883
changes to clarify multi-target forecasting in docstring
PranavBhatP Jun 22, 2025
ce4eeae
address code feedback and clarify docstring explanations
PranavBhatP Jun 22, 2025
a33b9ba
Merge branch 'main' into pr/1895
fkiraly Jun 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pytorch_forecasting/layers/decomposition/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Decomposition layers for PyTorch Forecasting.
"""

from pytorch_forecasting.layers.decomposition._series_decomp import SeriesDecomposition

__all__ = [
"SeriesDecomposition",
]
43 changes: 43 additions & 0 deletions pytorch_forecasting/layers/decomposition/_series_decomp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Series Decomposition Block for time series forecasting models.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_forecasting.layers.filter._moving_avg_filter import MovingAvg


class SeriesDecomposition(nn.Module):
"""
Series decomposition block from Autoformer.

Decomposes time series into trend and seasonal components using
moving average filtering.

Args:
kernel_size (int):
Size of the moving average kernel for trend extraction.
"""

def __init__(self, kernel_size):
super().__init__()
self.moving_avg = MovingAvg(kernel_size, stride=1)

def forward(self, x):
"""
Forward pass for series decomposition.

Args:
x (torch.Tensor):
Input time series tensor of shape (batch_size, seq_len, features).

Returns:
tuple:
- trend (torch.Tensor): Trend component of the time series.
- seasonal (torch.Tensor): Seasonal component of the time series.
"""
trend = self.moving_avg(x)
seasonal = x - trend
return seasonal, trend
9 changes: 9 additions & 0 deletions pytorch_forecasting/layers/filter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
Filtering layers for time series forecasting models.
"""

from pytorch_forecasting.layers.filter._moving_avg_filter import MovingAvg

__all__ = [
"MovingAvg",
]
48 changes: 48 additions & 0 deletions pytorch_forecasting/layers/filter/_moving_avg_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Moving Average Filter Block
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class MovingAvg(nn.Module):
"""
Moving Average block for smoothing and trend extraction from time series data.

A moving average is a smoothing technique that creates a series of average from
different subsets of a time series.

For example: Given a time series ``x = [x_1, x_2, ..., x_n]``, the moving average
with a kernel size of `k` calculates the average of each subset of `k` consecutive
elements, resulting in a new series of averages.

Args:
kernel_size (int):
Size of the moving average kernel.
stride (int):
Stride for the moving average operation, typically set to 1.
"""

def __init__(self, kernel_size, stride):
super().__init__()
self.kernel_size = kernel_size
self.avg = nn.AvgPool1d(kernel_size, stride=stride, padding=0)

def forward(self, x):
if self.kernel_size % 2 == 0:
self.padding_left = self.kernel_size // 2 - 1
self.padding_right = self.kernel_size // 2
else:
self.padding_left = self.kernel_size // 2
self.padding_right = self.kernel_size // 2

front = x[:, 0:1, :].repeat(1, self.padding_left, 1)
end = x[:, -1:, :].repeat(1, self.padding_right, 1)

x_padded = torch.cat([front, x, end], dim=1)
x_transposed = x_padded.permute(0, 2, 1)
x_smoothed = self.avg(x_transposed)
x_out = x_smoothed.permute(0, 2, 1)
return x_out
172 changes: 172 additions & 0 deletions pytorch_forecasting/models/base/_base_model_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,175 @@
prog_bar=True,
logger=True,
)

def standardize_model_output(
self,
prediction: torch.Tensor,
expected_dims: tuple[int, Optional[int], Optional[int], Optional[int]] = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hard to read - why not simply tuple[int]?

None,
None,
None,
None,
), # noqa: E501
) -> torch.Tensor:
"""
Standardize model outputs to a 4-dimensional tensor, with shape
(batch_size, timesteps, num_features, last_dim).

Parameters
----------
prediction : torch.Tensor
The raw prediction tensor from the model.
- Must be a torch.Tensor (in the future, also accept a list of tensors for
multi-target forecasting).
- Supported dims: 2D, 3D or 4D tensors.
- if 2D: (batch_size, timesteps) - univariate forecasting
- if 3D:
a) (batch_size, timesteps, n_targets) - multivariate forecasting
b) (batch_size, timesteps, last_dim) - univariate forecasting with quantiles or distribution.
c) (batch_size, timesteps, n_targets * last_dim) - multivariate
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we think of a better name than last_dim? I have no good alternative idea though. So this is not blocking for now (only appears in docstring anyway)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does forecast_param_dim sound good?

forecasting with quantiles, where features and quantiles are flattened in dim 2.
- if 4D: (batch_size, timesteps, n_targets, last_dim) - multivariate
forecasting with quantiles or distribution parameters.
- In the future, once multi-target forecasting is supported, this
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait, not sure if I understand this. Is multi-target forcasting not just the case n_targets>1?

Copy link
Contributor Author

@PranavBhatP PranavBhatP Jun 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, you are right. I think there is a slight confusion in the docstring. What I was actually referring to is multi-target forecasting MultiLoss where an individual target has its own loss function. I've made the changes in the docstring now. Currently, when a single loss function is used with multiple-targets, we can directly use n_targets > 1.

will also accept a list of tensors, where each tensor inside the list
is treated as above.
- If anything apart from the above dimensions is provided, an error is raised.

expected_dims : tuple[int, Optional[int], Optional[int], Optional[int]], default=(None, None, None, None)
A tuple specifying the dimensions: (n_targets, batch_size, timesteps, last_dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is very confusing since the dimensions are not in the same order as for the tensor. Can we change that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed it to the order of the output 4d tensor


n_targets : int
- Position 0: Number of target features
- Must be provided explicitly (cannot be None)
- Used for reshaping 2D and 3D tensors to 4D.

batch_size : Optional[int], default=None
- Position 1: Expected batch size
- When specified: Validates prediction.shape[0]
- When None: Uses actual tensor dimension

timesteps : Optional[int], default=None
- Position 2: Expected number of timesteps
- When specified: Validates prediction.shape[1]
- When None: Uses actual tensor dimension

last_dim : Optional[int], default=None
- Position 3: Size of the last dimension.
- Common use case - quantile, sample, distribution params.
- When it is specified, it is used to directly reshape.
- When None and model uses QuantileLoss: It is set to the number of quantiles
- When None and no quantile information is available: It defaults to 1.
- If required, this can be extended to handle other cases where the last_dim is None
but its value can be inferred from the loss function or model configuration (apart from
the existing QuantileLoss case, of course).
Returns
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor formatting remark:

  • there should be a newline before headers
  • there shuld be a newline before bullet point lists, and one after

(not consequential now, but this is an assumption when generating rst autodocs)

-------
torch.Tensor
The standardized prediction tensor with shape (batch_size, timesteps, n_targets, last_dim).
The prediction tensor is obtained by reshaping the input tensor. There are
several cases to consider:

- If the input tensor is 2D, it is reshaped to (batch_size, timesteps, 1, 1).
- If the input tensor is 3D, it is reshaped to (batch_size, timesteps, 1, 1) for a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is unclear how the reshaping happens for 3D, can you be more precise?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have made changes, you can take a look now..

multivariate single-target forecast, or (batch_size, timesteps, 1, last_dim) for a univariate quantile forecast.
- If the input tensor is 4D, it is assumed to be in the shape
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for 4D, you need to be explicit about how it is getting reshaped.

(batch_size, timesteps, n_targets, last_dim) or (batch_size, timesteps, last_dim, n_targets).
Notes
-----
[1] The fourth dimension (last_dim) commonly represents:

* Quantiles: For quantile regression (e.g., 0.1, 0.5, 0.9)
* Distribution parameters: For parametric forecasts (e.g., mean, variance)
* Samples: For sample-based uncertainty estimates

The current implementation assumes the most common case of quantile forecasts
when automatically inferring this dimension from the loss function,
but any value can be explicitly provided. A fallback of 1 is used in case where
no information is available on ``last_dim``.

[2] This can currently handle situations where a single target is used
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain that and give an example? Can you give the simplest example that shows a 4D tensor is not possible in this case? Is this, for instance, that we use squared loss for variable 1 but parametric log-loss on mean/variance (of normal) on variable 2?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe using nestedtensor-s?
https://docs.pytorch.org/tutorials/prototype/nestedtensor.html
(is this a stable feature? Looks like it?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can try looking at this..

either in a univariate or multivariate situation. In case of multi-target
forecasting, where each target has its own loss function, a list of tensors is
returned, where each tensor corresponds to a target. This requires some change
to the existing code.
""" # noqa: E501

n_targets, batch_size, timesteps, last_dim = expected_dims

Check warning on line 391 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L391

Added line #L391 was not covered by tests

if not isinstance(prediction, torch.Tensor):
raise TypeError(

Check warning on line 394 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L393-L394

Added lines #L393 - L394 were not covered by tests
f"Expected prediction to be a torch.Tensor, but got {type(prediction)}"
)

if n_targets is None:
raise ValueError(

Check warning on line 399 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L398-L399

Added lines #L398 - L399 were not covered by tests
"Expected n_targets to be a positive integer, but got `None`."
)

if last_dim is None:
if hasattr(self.loss, "quantiles") and self.loss.quantiles is not None:
last_dim = len(self.loss.quantiles) # Quantile regression case

Check warning on line 405 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L403-L405

Added lines #L403 - L405 were not covered by tests
# we can add more cases here in the future, where we refer to the specific
# loss function to determine the last dimension. For now we are sticking
# to the quantile regression case.
else:
last_dim = 1

Check warning on line 410 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L410

Added line #L410 was not covered by tests

if batch_size is not None:
if prediction.shape[0] != batch_size:
raise ValueError(

Check warning on line 414 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L412-L414

Added lines #L412 - L414 were not covered by tests
f"Expected batch size {batch_size}, but got {prediction.shape[0]}."
)

if timesteps is not None:
if prediction.shape[1] != timesteps:
raise ValueError(

Check warning on line 420 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L418-L420

Added lines #L418 - L420 were not covered by tests
f"Expected timesteps {timesteps}, but got {prediction.shape[1]}."
)

if prediction.ndim == 2:

Check warning on line 424 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L424

Added line #L424 was not covered by tests
# reshape to (batch_size, timsteps, 1, 1)
prediction = prediction.unsqueeze(-1).unsqueeze(-1)

Check warning on line 426 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L426

Added line #L426 was not covered by tests

elif prediction.ndim == 3:
if prediction.shape[2] == n_targets:

Check warning on line 429 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L428-L429

Added lines #L428 - L429 were not covered by tests
# reshape to (batch_size, timesteps, n_targets, 1)
prediction = prediction.unsqueeze(-1)
elif prediction.shape[2] == last_dim:

Check warning on line 432 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L431-L432

Added lines #L431 - L432 were not covered by tests
# reshape to (batch_size, timesteps, 1, last_dim)
prediction = prediction.unsqueeze(2)
elif prediction.shape[2] == n_targets * last_dim:

Check warning on line 435 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L434-L435

Added lines #L434 - L435 were not covered by tests
# multivariate forecast with quantiles
# where features and quantiles are flattened in dim 2.
# reshape to (batch_size, timesteps, n_targets, last_dim)
prediction = prediction.reshape(

Check warning on line 439 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L439

Added line #L439 was not covered by tests
prediction.shape[0], prediction.shape[1], n_targets, last_dim
)
else:
# reshape to (batch_size, timesteps, n_targets, last_dim)
prediction = prediction.unsqueeze(-1)

Check warning on line 444 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L444

Added line #L444 was not covered by tests

elif prediction.ndim == 4:

Check warning on line 446 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L446

Added line #L446 was not covered by tests
# assuming only a single case where n_targets and last_dim are swapped.
if prediction.shape[2] == last_dim and prediction.shape[3] == n_targets:

Check warning on line 448 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L448

Added line #L448 was not covered by tests
# reshape to (batch_size, timesteps, n_targets, last_dim)
warn(

Check warning on line 450 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L450

Added line #L450 was not covered by tests
"Prediction tensor has shape (batch_size, timesteps, last_dim, n_targets). " # noqa: E501
"This is not the expected shape. Transposing the last two dimensions." # noqa: E501
)
prediction = prediction.permute(0, 1, 3, 2)

Check warning on line 454 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L454

Added line #L454 was not covered by tests

else:
raise ValueError(

Check warning on line 457 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L457

Added line #L457 was not covered by tests
f"Expected prediction tensor to have 2, 3, or 4 dimensions, "
f"but got {prediction.ndim} dimensions."
)

# final check to ensure the output is 4D
if prediction.ndim != 4:
raise ValueError(

Check warning on line 464 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L463-L464

Added lines #L463 - L464 were not covered by tests
f"Failed to standardize output to 4D tensor. Current shape: {prediction.shape}" # noqa: E501
)

return prediction

Check warning on line 468 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L468

Added line #L468 was not covered by tests
10 changes: 10 additions & 0 deletions pytorch_forecasting/models/dlinear/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
Decomposition-Linear model for time series forecasting.
"""

from pytorch_forecasting.models.dlinear._dlinear_pkg_v2 import DLinearModel_pkg_v2
from pytorch_forecasting.models.dlinear._dlinear_v2 import DLinearModel

__all__ = [
"DLinearModel" "DLinearModel_pkg_v2",
]
Loading
Loading