Skip to content

[ENH] Test framework for ptf-v2 #1841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 112 commits into from
Jun 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
112 commits
Select commit Hold shift + click to select a range
b3644a6
test suite
fkiraly Feb 22, 2025
a1d64c6
Merge branch 'main' into test-suite
fkiraly Feb 22, 2025
4b2486e
skeleton
fkiraly Feb 22, 2025
02b0ce6
skeleton
fkiraly Feb 22, 2025
41cbf66
Update test_all_estimators.py
fkiraly Feb 23, 2025
cef62d3
Update _base_object.py
fkiraly Feb 23, 2025
bc2e93b
Update _lookup.py
fkiraly Feb 23, 2025
eee1c86
Update _lookup.py
fkiraly Feb 23, 2025
164fe0d
base metadatda
fkiraly Feb 23, 2025
20e88d0
registry
fkiraly Feb 23, 2025
318c1fb
fix private name
fkiraly Feb 23, 2025
012ab3d
Update _base_object.py
fkiraly Feb 23, 2025
86365a0
test failure
fkiraly Feb 23, 2025
f6dee46
Update test_all_estimators.py
fkiraly Feb 23, 2025
9b0e4ec
Update test_all_estimators.py
fkiraly Feb 23, 2025
7de5285
Update test_all_estimators.py
fkiraly Feb 23, 2025
57dfe3a
test folders
fkiraly Feb 23, 2025
c9f12db
Update test.yml
fkiraly Feb 23, 2025
fa8144e
test integration
fkiraly Feb 23, 2025
232a510
fixes
fkiraly Feb 23, 2025
1c8d4b5
Update _conftest.py
fkiraly Feb 23, 2025
f632e32
try scenarios
fkiraly Feb 23, 2025
252598d
D1, D2 layer commit
phoeenniixx Apr 6, 2025
d0d1c3e
remove one comment
phoeenniixx Apr 6, 2025
80e64d2
model layer commit
phoeenniixx Apr 6, 2025
6364780
update docstring
phoeenniixx Apr 6, 2025
82b3dc7
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 6, 2025
257183c
update data_module.py
phoeenniixx Apr 10, 2025
9cdcb19
update data_module.py
phoeenniixx Apr 10, 2025
a83bf32
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
ac56d4f
Add disclaimer
phoeenniixx Apr 10, 2025
0e7e36f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 10, 2025
4bfff21
update docstring
phoeenniixx Apr 11, 2025
ef98273
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 11, 2025
8a53ed6
Add tests for D1,D2 layer
phoeenniixx Apr 19, 2025
9f9df31
Merge branch 'main' into refactor-d1-d2
phoeenniixx Apr 19, 2025
cdecb77
Code quality
phoeenniixx Apr 19, 2025
86360fd
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx Apr 19, 2025
20aafb7
refactor file
fkiraly Apr 30, 2025
043820d
warning
fkiraly Apr 30, 2025
1720a15
linting
fkiraly May 1, 2025
af44474
move coercion to utils
fkiraly May 1, 2025
a3cb8b7
linting
fkiraly May 1, 2025
75d7fb5
Update _timeseries_v2.py
fkiraly May 1, 2025
1b946e6
Update __init__.py
fkiraly May 1, 2025
3edb08b
Update __init__.py
fkiraly May 1, 2025
a4bc9d8
Merge branch 'main' into pr/1811
fkiraly May 1, 2025
4c0d570
Merge branch 'pr/1811' into pr/1812
fkiraly May 1, 2025
ef37f55
Merge branch 'main' into test-suite
fkiraly May 1, 2025
a669134
Update _lookup.py
fkiraly May 4, 2025
d78bf5d
Update _lookup.py
fkiraly May 4, 2025
e350291
update tests
phoeenniixx May 11, 2025
f90c94f
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 11, 2025
3099691
update tft_v2
phoeenniixx May 11, 2025
77cb979
warnings and init attr handling
fkiraly May 13, 2025
28df3c3
Merge branch 'refactor-d1-d2' of https://github.com/phoeenniixx/pytor…
fkiraly May 13, 2025
f8c94e6
simplify TimeSeries.__getitem__
fkiraly May 13, 2025
c289255
Update _timeseries_v2.py
fkiraly May 13, 2025
9467f38
Update data_module.py
fkiraly May 13, 2025
c3b40ad
backwards compat of private/public attrs
fkiraly May 13, 2025
c007310
Merge branch 'refactor-d1-d2' into refactor-model
phoeenniixx May 13, 2025
2e25052
Merge branch 'main' into refactor-model
phoeenniixx May 13, 2025
38c28dc
add tests
phoeenniixx May 14, 2025
9d80eb8
add tests
phoeenniixx May 14, 2025
a8ccfe3
add tests
phoeenniixx May 14, 2025
f900ba5
add more docstrings
phoeenniixx May 14, 2025
ed1b799
add note about the commented out tests
phoeenniixx May 14, 2025
c947910
Merge branch 'main' into refactor-model
phoeenniixx May 16, 2025
c0ceb8a
add the commented out tests
phoeenniixx May 16, 2025
3828c26
remove note
phoeenniixx May 16, 2025
6d6d18e
Merge branch 'main' into refactor-model
phoeenniixx May 18, 2025
3144865
Merge branch 'test-suite' of https://github.com/sktime/pytorch-foreca…
phoeenniixx May 20, 2025
30b541b
make the modules private
phoeenniixx May 20, 2025
3f1e11f
Merge remote-tracking branch 'origin/refactor-model' into refactor-model
phoeenniixx May 20, 2025
5cc3ff1
initial commit
phoeenniixx May 20, 2025
1bcf181
Merge branch 'refactor-model' into test-framework
phoeenniixx May 20, 2025
f18e09d
add TFTMetadata class
phoeenniixx May 20, 2025
e1e360e
add TFTMetadata class
phoeenniixx May 20, 2025
168e16a
Merge branch 'main' into test-framework
phoeenniixx May 22, 2025
92c12bf
add TFT tests
phoeenniixx May 25, 2025
1d478d5
remove refactored TFT
phoeenniixx May 27, 2025
f9992f2
Merge branch 'main' into test-framework
phoeenniixx May 28, 2025
d049019
update test_all_estimators
phoeenniixx May 28, 2025
e72486b
linting
phoeenniixx May 28, 2025
7443b0b
Merge branch 'main' into test-framework
phoeenniixx May 29, 2025
a734f26
refactor
phoeenniixx May 29, 2025
7f466b2
Add more test_params
phoeenniixx May 29, 2025
0968452
Add metadata tests
phoeenniixx May 31, 2025
525bbb9
Merge branch 'main' into test-framework
phoeenniixx Jun 1, 2025
4267da6
Merge branch 'main' into test-framework
phoeenniixx Jun 1, 2025
4e8f863
add object-filter to ptf-v1
phoeenniixx Jun 1, 2025
c117092
Merge branch 'main' into test-framework
phoeenniixx Jun 5, 2025
f6d39fe
Merge branch 'main' into test-framework
phoeenniixx Jun 6, 2025
2c518ee
add new base classes
phoeenniixx Jun 6, 2025
7a5c58f
remove try block
phoeenniixx Jun 8, 2025
cb3e944
Merge branch 'main' into test-framework
phoeenniixx Jun 8, 2025
3b9de6d
add support for multiple datamodules
phoeenniixx Jun 9, 2025
032a7b0
typo
phoeenniixx Jun 9, 2025
4d9a19a
Merge branch 'main' into test-framework
phoeenniixx Jun 9, 2025
03c06e8
Merge branch 'main' into test-framework
phoeenniixx Jun 12, 2025
8b0087e
linting
phoeenniixx Jun 12, 2025
d328fae
Merge branch 'main' into test-framework
phoeenniixx Jun 13, 2025
68df4b6
merge main
phoeenniixx Jun 13, 2025
57d635b
add pkg name to v2
phoeenniixx Jun 13, 2025
e35a4ff
Merge branch 'main' into test-framework
phoeenniixx Jun 14, 2025
c4d5628
revert changes to conftest
fkiraly Jun 15, 2025
6129d33
reverts and fixes
fkiraly Jun 15, 2025
32ef57e
v2
fkiraly Jun 15, 2025
93ea865
Update __init__.py
fkiraly Jun 15, 2025
8e95e6e
Update __init__.py
fkiraly Jun 15, 2025
f990c8a
Update test_all_estimators_v2.py
fkiraly Jun 15, 2025
53747e0
Update _tft_pkg_v2.py
fkiraly Jun 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytorch_forecasting/models/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from pytorch_forecasting.models.base._base_object import (
_BaseObject,
_BasePtForecaster,
_BasePtForecasterV2,
)

__all__ = [
"_BaseObject",
"_BasePtForecaster",
"_BasePtForecasterV2",
"AutoRegressiveBaseModel",
"AutoRegressiveBaseModelWithCovariates",
"BaseModel",
Expand Down
22 changes: 17 additions & 5 deletions pytorch_forecasting/models/base/_base_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,12 @@ class _BaseObject(_SkbaseBaseObject):
pass


class _BasePtForecaster(_BaseObject):
class _BasePtForecaster_Common(_BaseObject):
"""Base class for all PyTorch Forecasting forecaster packages.

This class points to model objects and contains metadata as tags.
"""

_tags = {
"object_type": "forecaster_pytorch",
}

@classmethod
def get_model_cls(cls):
"""Get model class."""
Expand Down Expand Up @@ -112,3 +108,19 @@ def create_test_instances_and_names(cls, parameter_set="default"):
names = [cls.__name__]

return objs, names


class _BasePtForecaster(_BasePtForecaster_Common):
"""Base class for PyTorch Forecasting v1 forecasters."""

_tags = {
"object_type": ["forecaster_pytorch", "forecaster_pytorch_v1"],
}


class _BasePtForecasterV2(_BasePtForecaster_Common):
"""Base class for PyTorch Forecasting v2 forecasters."""

_tags = {
"object_type": "forecaster_pytorch_v2",
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from pytorch_forecasting.models.temporal_fusion_transformer._tft import (
TemporalFusionTransformer,
)
from pytorch_forecasting.models.temporal_fusion_transformer._tft_pkg_v2 import (
TFT_pkg_v2,
)
from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import (
AddNorm,
GateAddNorm,
Expand All @@ -19,5 +22,6 @@
"GatedLinearUnit",
"GatedResidualNetwork",
"InterpretableMultiHeadAttention",
"TFT_pkg_v2",
"VariableSelectionNetwork",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""TFT package container."""

from pytorch_forecasting.models.base import _BasePtForecasterV2


class TFT_pkg_v2(_BasePtForecasterV2):
"""TFT package container."""

_tags = {
"info:name": "TFT",
"authors": ["phoeenniixx"],
"capability:exogenous": True,
"capability:multivariate": True,
"capability:pred_int": True,
"capability:flexible_history_length": False,
}

@classmethod
def get_model_cls(cls):
"""Get model class."""
from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT

return TFT

@classmethod
def _get_test_datamodule_from(cls, trainer_kwargs):
"""Create test dataloaders from trainer_kwargs - following v1 pattern."""
from pytorch_forecasting.data.data_module import (
EncoderDecoderTimeSeriesDataModule,
)
from pytorch_forecasting.tests._data_scenarios import (
data_with_covariates_v2,
make_datasets_v2,
)

data_with_covariates = data_with_covariates_v2()

data_loader_default_kwargs = dict(
target="target",
group_ids=["agency_encoded", "sku_encoded"],
add_relative_time_idx=True,
)

data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {})
data_loader_default_kwargs.update(data_loader_kwargs)

datasets_info = make_datasets_v2(
data_with_covariates, **data_loader_default_kwargs
)

training_dataset = datasets_info["training_dataset"]
validation_dataset = datasets_info["validation_dataset"]
training_max_time_idx = datasets_info["training_max_time_idx"]

max_encoder_length = data_loader_kwargs.get("max_encoder_length", 4)
max_prediction_length = data_loader_kwargs.get("max_prediction_length", 3)
add_relative_time_idx = data_loader_kwargs.get("add_relative_time_idx", True)
batch_size = data_loader_kwargs.get("batch_size", 2)

train_datamodule = EncoderDecoderTimeSeriesDataModule(
time_series_dataset=training_dataset,
max_encoder_length=max_encoder_length,
max_prediction_length=max_prediction_length,
add_relative_time_idx=add_relative_time_idx,
batch_size=batch_size,
train_val_test_split=(0.8, 0.2, 0.0),
)

val_datamodule = EncoderDecoderTimeSeriesDataModule(
time_series_dataset=validation_dataset,
max_encoder_length=max_encoder_length,
max_prediction_length=max_prediction_length,
min_prediction_idx=training_max_time_idx,
add_relative_time_idx=add_relative_time_idx,
batch_size=batch_size,
train_val_test_split=(0.0, 1.0, 0.0),
)

test_datamodule = EncoderDecoderTimeSeriesDataModule(
time_series_dataset=validation_dataset,
max_encoder_length=max_encoder_length,
max_prediction_length=max_prediction_length,
min_prediction_idx=training_max_time_idx,
add_relative_time_idx=add_relative_time_idx,
batch_size=1,
train_val_test_split=(0.0, 0.0, 1.0),
)

train_datamodule.setup("fit")
val_datamodule.setup("fit")
test_datamodule.setup("test")

train_dataloader = train_datamodule.train_dataloader()
val_dataloader = val_datamodule.val_dataloader()
test_dataloader = test_datamodule.test_dataloader()

return {
"train": train_dataloader,
"val": val_dataloader,
"test": test_dataloader,
"data_module": train_datamodule,
}

@classmethod
def get_test_train_params(cls):
"""Return testing parameter settings for the trainer.

Returns
-------
params : dict or list of dict, default = {}
Parameters to create testing instances of the class
Each dict are parameters to construct an "interesting" test instance, i.e.,
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
`create_test_instance` uses the first (or only) dictionary in `params`
"""
return [
{},
dict(
hidden_size=25,
attention_head_size=5,
),
dict(
data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3)
),
dict(
hidden_size=24,
attention_head_size=8,
data_loader_kwargs=dict(
max_encoder_length=5,
max_prediction_length=3,
add_relative_time_idx=False,
),
),
dict(
hidden_size=12,
data_loader_kwargs=dict(max_encoder_length=7, max_prediction_length=10),
),
dict(attention_head_size=2),
]
Loading
Loading