
Commit 622d979

update
1 parent d034e0b commit 622d979

3 files changed (+88 -48 lines)

pytorch_forecasting/models/base_model.py

Lines changed: 1 addition & 1 deletion
@@ -564,7 +564,7 @@ def transform_output(
         prediction: Union[torch.Tensor, List[torch.Tensor]],
         target_scale: Union[torch.Tensor, List[torch.Tensor]],
         loss: Optional[Metric] = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
         """
         Extract prediction from network output and rescale it to real space / de-normalize it.
 
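The broadened return type reflects that de-normalized predictions come back as a single tensor for one target but as one tensor per target in multi-target setups. A minimal, hypothetical sketch (not part of this commit; `as_list` is an illustrative helper) of how a caller can treat both cases uniformly:

from typing import List, Union

import torch


def as_list(out: Union[torch.Tensor, List[torch.Tensor]]) -> List[torch.Tensor]:
    # Normalize single- and multi-target outputs to a list of tensors.
    return out if isinstance(out, list) else [out]


as_list(torch.randn(32, 24))                        # single target -> [Tensor]
as_list([torch.randn(32, 24), torch.randn(32, 6)])  # two targets -> [Tensor, Tensor]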

pytorch_forecasting/models/lstm.py

Lines changed: 45 additions & 11 deletions
@@ -1,6 +1,6 @@
 __all__ = ["LSTMModel"]
 
-from typing import Any, Dict, List, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 from loguru import logger
 import torch
@@ -13,35 +13,68 @@
 
 
 class LSTMModel(AutoRegressiveBaseModel):
-    """Simple LSTM model."""
+    """Simple LSTM model.
+
+    Args:
+        target (Union[str, Sequence[str]]):
+            Name (or list of names) of target variable(s).
+
+        target_lags (Dict[str, Dict[str, int]]): _description_
+
+        n_layers (int):
+            Number of LSTM layers.
+
+        hidden_size (int):
+            Hidden size for LSTM model.
+
+        dropout (float, optional):
+            Droput probability (<1). Defaults to 0.1.
+
+        input_size (int, optional):
+            Input size. Defaults to: inferred from `target`.
+
+        loss (Metric):
+            Loss criterion. Can be different for each target in multi-target setting thanks to
+            `MultiLoss`. Defaults to `MAE`.
+
+        **kwargs:
+            See :class:`pytorch_forecasting.models.base_model.AutoRegressiveBaseModel`.
+    """
 
     def __init__(
         self,
         target: Union[str, Sequence[str]],
-        target_lags: Dict[str, Dict[str, int]],
+        target_lags: Dict[str, Dict[str, int]],  # pylint: disable=unused-argument
         n_layers: int,
         hidden_size: int,
         dropout: float = 0.1,
-        input_size: int = None,
-        loss: Metric = None,
+        input_size: Optional[int] = None,
+        loss: Optional[Metric] = None,
         **kwargs: Any,
     ):
         """Prefer using the `LSTMModel.from_dataset()` method rather than this constructor.
+
         Args:
             target (Union[str, Sequence[str]]):
                 Name (or list of names) of target variable(s).
             target_lags (Dict[str, Dict[str, int]]): _description_
+
             n_layers (int):
                 Number of LSTM layers.
+
             hidden_size (int):
                 Hidden size for LSTM model.
+
             dropout (float, optional):
                 Droput probability (<1). Defaults to 0.1.
+
             input_size (int, optional):
                 Input size. Defaults to: inferred from `target`.
+
             loss (Metric):
                 Loss criterion. Can be different for each target in multi-target setting thanks to
                 `MultiLoss`. Defaults to `MAE`.
+
             **kwargs:
                 See :class:`pytorch_forecasting.models.base_model.AutoRegressiveBaseModel`.
         """
@@ -55,9 +88,9 @@ def __init__(
         self.save_hyperparameters()
         # loss
         if loss is None:
-            loss = MultiLoss([MAE() for _ in range(n_targets)]) if n_targets > 1 else MAE()
+            loss = MultiLoss([MAE() for _ in range(n_targets)]) if n_targets > 1 else MAE()  # type: ignore
         # pass additional arguments to BaseModel.__init__, mandatory call - do not skip this
-        super().__init__(loss=loss, **kwargs)
+        super().__init__(loss=loss, **kwargs)  # type: ignore
         # use version of LSTM that can handle zero-length sequences
         self.lstm = LSTM(
             hidden_size=hidden_size,
@@ -168,15 +201,16 @@ def decode(
         return output
 
     def forward(self, x: Dict[str, Tensor]) -> Dict[str, Union[Tensor, List[Tensor]]]:
-        """_summary_
+        """
         Args:
-            x (Dict[str, torch.Tensor]): _description_
+            x (Dict[str, torch.Tensor]): Input dict.
+
         Returns:
-            Dict[str, torch.Tensor]: _description_
+            Dict[str, torch.Tensor]: Output dict.
         """
         hidden_state = self.encode(x)  # encode to hidden state
         output = self.decode(x, hidden_state)  # decode leveraging hidden state
-        out: Dict[str, torch.Tensor] = self.to_network_output(prediction=output)
+        out = self.to_network_output(prediction=output)
         return out
 
     def decode_one(
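The second hunk above defaults `loss` to one `MAE` per target, wrapped in `MultiLoss` only when there is more than one target. A short sketch of that default, assuming pytorch_forecasting's `MAE` and `MultiLoss` metrics; the way `n_targets` is derived from `target` here is an illustrative assumption:

from typing import Sequence, Union

from pytorch_forecasting.metrics import MAE, MultiLoss


def default_loss(target: Union[str, Sequence[str]]):
    # Assumed: a single target name gets a plain MAE, several names get a MultiLoss of MAEs.
    n_targets = len(target) if isinstance(target, (list, tuple)) else 1
    return MultiLoss([MAE() for _ in range(n_targets)]) if n_targets > 1 else MAE()


default_loss("volume")             # -> MAE()
default_loss(["volume", "price"])  # -> MultiLoss wrapping two MAE metrics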

pytorch_forecasting/models/tuning.py

Lines changed: 42 additions & 36 deletions
@@ -1,5 +1,7 @@
 """
-Hyperparameters can be efficiently tuned with `optuna <https://optuna.readthedocs.io/>`_.
+Module for hyperparameter optimization.
+
+Hyperparameters can be efficiently tuned with `optuna <https://optuna.readthedocs.io/>`.
 """
 
 __all__ = ["optimize_hyperparameters"]
@@ -87,8 +89,9 @@ def optimize_hyperparameters(
     **kwargs: Any,
 ) -> optuna.Study:
     """
-    Optimize hyperparameters. Run hyperparameter optimization. Learning rate for is determined with the
-    PyTorch Lightning learning rate finder.
+    Optimize hyperparameters. Run hyperparameter optimization.
+
+    Learning rate for is determined with the PyTorch Lightning learning rate finder.
 
     Args:
         train_dataloaders (DataLoader):
@@ -98,65 +101,68 @@ def optimize_hyperparameters(
         model_path (str):
             Folder to which model checkpoints are saved.
         monitor (str):
-            Metric to return. The hyper-parameter (HP) tuner trains a model for a certain HP config, and
-            reads this metric to score configuration. By default, the lower the better.
+            Metric to return. The hyper-parameter (HP) tuner trains a model for a certain HP config,
+            and reads this metric to score configuration. By default, the lower the better.
         direction (str):
-            By default, direction is "minimize", meaning that lower values of the specified `monitor` are
-            better. You can change this, e.g. to "maximize".
+            By default, direction is "minimize", meaning that lower values of the specified
+            ``monitor`` are better. You can change this, e.g. to "maximize".
         max_epochs (int, optional):
             Maximum number of epochs to run training. Defaults to 20.
         n_trials (int, optional):
             Number of hyperparameter trials to run. Defaults to 100.
         timeout (float, optional):
-            Time in seconds after which training is stopped regardless of number of epochs or validation
-            metric. Defaults to 3600*8.0.
+            Time in seconds after which training is stopped regardless of number of epochs or
+            validation metric. Defaults to 3600*8.0.
         input_params (dict, optional):
-            A dictionary, where each `key` contains another dictionary with two keys: `"method"` and
-            `"ranges"`. Example:
-            >>> {"hidden_size": {
-            >>>     "method": "suggest_int",
-            >>>     "ranges": (16, 265),
-            >>> }}
-            The method key has to be a method of the `optuna.Trial` object. The ranges key are the input
-            ranges for the specified method.
+            A dictionary, where each ``key`` contains another dictionary with two keys: ``"method"``
+            and ``"ranges"``. Example:
+            >>> {"hidden_size": {
+                    "method": "suggest_int",
+                    "ranges": (16, 265),
+                }}
+            The method key has to be a method of the ``optuna.Trial`` object.
+            The ranges key are the input ranges for the specified method.
         input_params_generator (Callable, optional):
-            A function with the following signature: `fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]
-            `, returning the parameter values to set up your model for the current trial/run.
+            A function with the following signature:
+            `fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]`,
+            returning the parameter values to set up your model for the current trial/run.
             Example:
-            >>> def fn(trial: optuna.Trial, param_ranges: Tuple[int, int] = (16, 265)) -> Dict[str, Any]:
-            >>>     param = trial.suggest_int("param", *param_ranges, log=True)
-            >>>     model_params = {"param": param}
-            >>>     return model_params
-            Then, when your model is created (before training it and report the metrics for the current
-            combination of hyperparameters), these dictionary is used as follows:
-            >>> model = YourModelClass.from_dataset(
-            >>>     train_dataloaders.dataset,
-            >>>     log_interval=-1,
-            >>>     **model_params,
-            >>> )
+            >>> def fn(trial, param_ranges = (16, 265)) -> Dict[str, Any]:
+                    param = trial.suggest_int("param", *param_ranges, log=True)
+                    model_params = {"param": param}
+                    return model_params
+            Then, when your model is created (before training it and report the metrics for
+            the current combination of hyperparameters), these dictionary is used as follows:
+            >>> model = YourModelClass.from_dataset(
+                    train_dataloaders.dataset,
+                    log_interval=-1,
+                    **model_params,
+                )
         generator_params (dict, optional):
-            The additional parameters to be passed to the `input_params_generator` function, if required.
+            The additional parameters to be passed to the ``input_params_generator`` function,
+            if required.
         learning_rate_range (Tuple[float, float], optional):
             Learning rate range. Defaults to (1e-5, 1.0).
         use_learning_rate_finder (bool):
-            If to use learning rate finder or optimize as part of hyperparameters. Defaults to True.
+            If to use learning rate finder or optimize as part of hyperparameters.
+            Defaults to True.
         trainer_kwargs (Dict[str, Any], optional):
             Additional arguments to the
-            `PyTorch Lightning trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>`
-            such as `limit_train_batches`. Defaults to {}.
+            PyTorch Lightning trainer such as ``limit_train_batches``.
+            Defaults to {}.
         log_dir (str, optional):
             Folder into which to log results for tensorboard. Defaults to "lightning_logs".
         study (optuna.Study, optional):
             Study to resume. Will create new study by default.
         verbose (Union[int, bool]):
             Level of verbosity.
-            * None: no change in verbosity level (equivalent to verbose=1 by optuna-set default).
+            * None: no change in verbosity level (equivalent to verbose=1).
             * 0 or False: log only warnings.
             * 1 or True: log pruning events.
             * 2: optuna logging level at debug level.
             Defaults to None.
         pruner (optuna.pruners.BasePruner, optional):
-            The optuna pruner to use. Defaults to `optuna.pruners.SuccessiveHalvingPruner()`.
+            The optuna pruner to use. Defaults to ``optuna.pruners.SuccessiveHalvingPruner()``.
         **kwargs:
             Additional arguments for your model's class.
 
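As a hedged usage sketch of the parameters documented above: the dataloader, metric name, and checkpoint folder are assumptions, and any arguments not visible in this diff (e.g. validation dataloaders) are omitted.

from pytorch_forecasting.models.tuning import optimize_hyperparameters

# train_dataloaders is assumed to be a prepared PyTorch DataLoader.
study = optimize_hyperparameters(
    train_dataloaders,
    model_path="optuna_checkpoints",   # folder for trial checkpoints
    monitor="val_loss",                # assumed metric name; lower is better by default
    direction="minimize",
    max_epochs=20,
    n_trials=100,
    timeout=3600 * 8.0,
    input_params={
        # each key names a hyperparameter; "method" is an optuna.Trial method, "ranges" its inputs
        "hidden_size": {"method": "suggest_int", "ranges": (16, 265)},
    },
    learning_rate_range=(1e-5, 1.0),
    trainer_kwargs={"limit_train_batches": 30},
)
print(study.best_trial.params)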
