Skip to content

Commit 300fd90

Browse files
authored
[ENH] Test framework for ptf-v2 (#1841)
### Description This PR solves #1838 Implements Test framework for v2
1 parent 1fc5542 commit 300fd90

File tree

7 files changed

+466
-5
lines changed

7 files changed

+466
-5
lines changed

pytorch_forecasting/models/base/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from pytorch_forecasting.models.base._base_object import (
1111
_BaseObject,
1212
_BasePtForecaster,
13+
_BasePtForecasterV2,
1314
)
1415

1516
__all__ = [
1617
"_BaseObject",
1718
"_BasePtForecaster",
19+
"_BasePtForecasterV2",
1820
"AutoRegressiveBaseModel",
1921
"AutoRegressiveBaseModelWithCovariates",
2022
"BaseModel",

pytorch_forecasting/models/base/_base_object.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,12 @@ class _BaseObject(_SkbaseBaseObject):
1111
pass
1212

1313

14-
class _BasePtForecaster(_BaseObject):
14+
class _BasePtForecaster_Common(_BaseObject):
1515
"""Base class for all PyTorch Forecasting forecaster packages.
1616
1717
This class points to model objects and contains metadata as tags.
1818
"""
1919

20-
_tags = {
21-
"object_type": "forecaster_pytorch",
22-
}
23-
2420
@classmethod
2521
def get_model_cls(cls):
2622
"""Get model class."""
@@ -112,3 +108,19 @@ def create_test_instances_and_names(cls, parameter_set="default"):
112108
names = [cls.__name__]
113109

114110
return objs, names
111+
112+
113+
class _BasePtForecaster(_BasePtForecaster_Common):
114+
"""Base class for PyTorch Forecasting v1 forecasters."""
115+
116+
_tags = {
117+
"object_type": ["forecaster_pytorch", "forecaster_pytorch_v1"],
118+
}
119+
120+
121+
class _BasePtForecasterV2(_BasePtForecaster_Common):
122+
"""Base class for PyTorch Forecasting v2 forecasters."""
123+
124+
_tags = {
125+
"object_type": "forecaster_pytorch_v2",
126+
}

pytorch_forecasting/models/temporal_fusion_transformer/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from pytorch_forecasting.models.temporal_fusion_transformer._tft import (
44
TemporalFusionTransformer,
55
)
6+
from pytorch_forecasting.models.temporal_fusion_transformer._tft_pkg_v2 import (
7+
TFT_pkg_v2,
8+
)
69
from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import (
710
AddNorm,
811
GateAddNorm,
@@ -19,5 +22,6 @@
1922
"GatedLinearUnit",
2023
"GatedResidualNetwork",
2124
"InterpretableMultiHeadAttention",
25+
"TFT_pkg_v2",
2226
"VariableSelectionNetwork",
2327
]
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""TFT package container."""
2+
3+
from pytorch_forecasting.models.base import _BasePtForecasterV2
4+
5+
6+
class TFT_pkg_v2(_BasePtForecasterV2):
7+
"""TFT package container."""
8+
9+
_tags = {
10+
"info:name": "TFT",
11+
"authors": ["phoeenniixx"],
12+
"capability:exogenous": True,
13+
"capability:multivariate": True,
14+
"capability:pred_int": True,
15+
"capability:flexible_history_length": False,
16+
}
17+
18+
@classmethod
19+
def get_model_cls(cls):
20+
"""Get model class."""
21+
from pytorch_forecasting.models.temporal_fusion_transformer._tft_v2 import TFT
22+
23+
return TFT
24+
25+
@classmethod
26+
def _get_test_datamodule_from(cls, trainer_kwargs):
27+
"""Create test dataloaders from trainer_kwargs - following v1 pattern."""
28+
from pytorch_forecasting.data.data_module import (
29+
EncoderDecoderTimeSeriesDataModule,
30+
)
31+
from pytorch_forecasting.tests._data_scenarios import (
32+
data_with_covariates_v2,
33+
make_datasets_v2,
34+
)
35+
36+
data_with_covariates = data_with_covariates_v2()
37+
38+
data_loader_default_kwargs = dict(
39+
target="target",
40+
group_ids=["agency_encoded", "sku_encoded"],
41+
add_relative_time_idx=True,
42+
)
43+
44+
data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {})
45+
data_loader_default_kwargs.update(data_loader_kwargs)
46+
47+
datasets_info = make_datasets_v2(
48+
data_with_covariates, **data_loader_default_kwargs
49+
)
50+
51+
training_dataset = datasets_info["training_dataset"]
52+
validation_dataset = datasets_info["validation_dataset"]
53+
training_max_time_idx = datasets_info["training_max_time_idx"]
54+
55+
max_encoder_length = data_loader_kwargs.get("max_encoder_length", 4)
56+
max_prediction_length = data_loader_kwargs.get("max_prediction_length", 3)
57+
add_relative_time_idx = data_loader_kwargs.get("add_relative_time_idx", True)
58+
batch_size = data_loader_kwargs.get("batch_size", 2)
59+
60+
train_datamodule = EncoderDecoderTimeSeriesDataModule(
61+
time_series_dataset=training_dataset,
62+
max_encoder_length=max_encoder_length,
63+
max_prediction_length=max_prediction_length,
64+
add_relative_time_idx=add_relative_time_idx,
65+
batch_size=batch_size,
66+
train_val_test_split=(0.8, 0.2, 0.0),
67+
)
68+
69+
val_datamodule = EncoderDecoderTimeSeriesDataModule(
70+
time_series_dataset=validation_dataset,
71+
max_encoder_length=max_encoder_length,
72+
max_prediction_length=max_prediction_length,
73+
min_prediction_idx=training_max_time_idx,
74+
add_relative_time_idx=add_relative_time_idx,
75+
batch_size=batch_size,
76+
train_val_test_split=(0.0, 1.0, 0.0),
77+
)
78+
79+
test_datamodule = EncoderDecoderTimeSeriesDataModule(
80+
time_series_dataset=validation_dataset,
81+
max_encoder_length=max_encoder_length,
82+
max_prediction_length=max_prediction_length,
83+
min_prediction_idx=training_max_time_idx,
84+
add_relative_time_idx=add_relative_time_idx,
85+
batch_size=1,
86+
train_val_test_split=(0.0, 0.0, 1.0),
87+
)
88+
89+
train_datamodule.setup("fit")
90+
val_datamodule.setup("fit")
91+
test_datamodule.setup("test")
92+
93+
train_dataloader = train_datamodule.train_dataloader()
94+
val_dataloader = val_datamodule.val_dataloader()
95+
test_dataloader = test_datamodule.test_dataloader()
96+
97+
return {
98+
"train": train_dataloader,
99+
"val": val_dataloader,
100+
"test": test_dataloader,
101+
"data_module": train_datamodule,
102+
}
103+
104+
@classmethod
105+
def get_test_train_params(cls):
106+
"""Return testing parameter settings for the trainer.
107+
108+
Returns
109+
-------
110+
params : dict or list of dict, default = {}
111+
Parameters to create testing instances of the class
112+
Each dict are parameters to construct an "interesting" test instance, i.e.,
113+
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
114+
`create_test_instance` uses the first (or only) dictionary in `params`
115+
"""
116+
return [
117+
{},
118+
dict(
119+
hidden_size=25,
120+
attention_head_size=5,
121+
),
122+
dict(
123+
data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3)
124+
),
125+
dict(
126+
hidden_size=24,
127+
attention_head_size=8,
128+
data_loader_kwargs=dict(
129+
max_encoder_length=5,
130+
max_prediction_length=3,
131+
add_relative_time_idx=False,
132+
),
133+
),
134+
dict(
135+
hidden_size=12,
136+
data_loader_kwargs=dict(max_encoder_length=7, max_prediction_length=10),
137+
),
138+
dict(attention_head_size=2),
139+
]

0 commit comments

Comments
 (0)