
Commit 63c935c

robustify scheduler config
1 parent df74dc3 · commit 63c935c

File tree: 2 files changed, +59 -49 lines

neuralprophet/configure.py

Lines changed: 31 additions & 33 deletions
@@ -94,7 +94,7 @@ class Train:
     optimizer: Union[str, Type[torch.optim.Optimizer]]
     quantiles: List[float] = field(default_factory=list)
     optimizer_args: dict = field(default_factory=dict)
-    scheduler: Optional[Type[torch.optim.lr_scheduler._LRScheduler]] = None
+    scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None
     scheduler_args: dict = field(default_factory=dict)
     newer_samples_weight: float = 1.0
     newer_samples_start: float = 0.0
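The widened annotation means the `scheduler` field of the Train config can now carry either a preset name or a torch scheduler class. A minimal sketch of the two accepted forms, assuming torch >= 2.0 where torch.optim.lr_scheduler.LRScheduler is the public base class (the variable names are illustrative, not library code):

import torch

by_name = "cosineannealinglr"  # resolved to a scheduler class inside set_scheduler()
by_class = torch.optim.lr_scheduler.CosineAnnealingLR  # validated via issubclass(...) in the new else branch

assert isinstance(by_name, str)
assert issubclass(by_class, torch.optim.lr_scheduler.LRScheduler)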
@@ -193,50 +193,48 @@ def set_scheduler(self):
         Set the scheduler and scheduler arg depending on the user selection.
         The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet.
         """
-        self.scheduler_args.clear()
         if isinstance(self.scheduler, str):
             if self.scheduler.lower() == "onecyclelr":
                 self.scheduler = torch.optim.lr_scheduler.OneCycleLR
-                self.scheduler_args.update(
-                    {
-                        "pct_start": 0.3,
-                        "anneal_strategy": "cos",
-                        "div_factor": 10.0,
-                        "final_div_factor": 10.0,
-                        "three_phase": True,
-                    }
-                )
+                defaults = {
+                    "pct_start": 0.3,
+                    "anneal_strategy": "cos",
+                    "div_factor": 10.0,
+                    "final_div_factor": 10.0,
+                    "three_phase": True,
+                }
             elif self.scheduler.lower() == "steplr":
                 self.scheduler = torch.optim.lr_scheduler.StepLR
-                self.scheduler_args.update(
-                    {
-                        "step_size": 10,
-                        "gamma": 0.1,
-                    }
-                )
+                defaults = {
+                    "step_size": 10,
+                    "gamma": 0.1,
+                }
             elif self.scheduler.lower() == "exponentiallr":
                 self.scheduler = torch.optim.lr_scheduler.ExponentialLR
-                self.scheduler_args.update(
-                    {
-                        "gamma": 0.95,
-                    }
-                )
+                defaults = {
+                    "gamma": 0.95,
+                }
             elif self.scheduler.lower() == "cosineannealinglr":
                 self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
-                self.scheduler_args.update(
-                    {
-                        "T_max": 50,
-                    }
-                )
+                defaults = {
+                    "T_max": 50,
+                }
             else:
-                raise NotImplementedError(f"Scheduler {self.scheduler} is not supported.")
+                raise NotImplementedError(
+                    f"Scheduler {self.scheduler} is not supported from string. Please pass the scheduler class."
+                )
+            if self.scheduler_args is not None:
+                defaults.update(self.scheduler_args)
+            self.scheduler_args = defaults
         elif self.scheduler is None:
             self.scheduler = torch.optim.lr_scheduler.ExponentialLR
-            self.scheduler_args.update(
-                {
-                    "gamma": 0.95,
-                }
-            )
+            self.scheduler_args = {
+                "gamma": 0.95,
+            }
+        else:  # if scheduler is a class
+            assert issubclass(
+                self.scheduler, torch.optim.lr_scheduler.LRScheduler
+            ), "Scheduler must be a subclass of torch.optim.lr_scheduler.LRScheduler"

     def set_lr_finder_args(self, dataset_size, num_batches):
         """

neuralprophet/forecaster.py

Lines changed: 28 additions & 16 deletions
@@ -298,6 +298,7 @@ class NeuralProphet:
         >>> m = NeuralProphet(collect_metrics=["MSE", "MAE", "RMSE"])
         >>> # use custorm torchmetrics names
         >>> m = NeuralProphet(collect_metrics={"MAPE": "MeanAbsolutePercentageError", "MSLE": "MeanSquaredLogError",
+
     scheduler : str, torch.optim.lr_scheduler._LRScheduler
         Type of learning rate scheduler to use.
@@ -446,7 +447,8 @@ def __init__(
         batch_size: Optional[int] = None,
         loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] = "SmoothL1Loss",
         optimizer: Union[str, Type[torch.optim.Optimizer]] = "AdamW",
-        scheduler: Optional[str] = "onecyclelr",
+        scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = "onecyclelr",
+        scheduler_args: Optional[dict] = None,
         newer_samples_weight: float = 2,
         newer_samples_start: float = 0.0,
         quantiles: List[float] = [],
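In user-facing terms, the constructor now takes the scheduler in either form together with explicit scheduler arguments. A short usage sketch based on the new signature (all other arguments keep their defaults):

import torch
from neuralprophet import NeuralProphet

# Preset by name, overriding one of its default arguments:
m1 = NeuralProphet(scheduler="steplr", scheduler_args={"gamma": 0.5})

# Or pass a torch scheduler class directly with its keyword arguments:
m2 = NeuralProphet(
    scheduler=torch.optim.lr_scheduler.CosineAnnealingLR,
    scheduler_args={"T_max": 100},
)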
@@ -521,6 +523,7 @@ def __init__(
             quantiles=quantiles,
             learning_rate=learning_rate,
             scheduler=scheduler,
+            scheduler_args=scheduler_args,
             epochs=epochs,
             batch_size=batch_size,
             loss_func=loss_func,
@@ -932,7 +935,8 @@ def fit(
         continue_training: bool = False,
         num_workers: int = 0,
         deterministic: bool = False,
-        scheduler: Optional[str] = None,
+        scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None,
+        scheduler_args: Optional[dict] = None,
     ):
         """Train, and potentially evaluate model.
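The same pair of arguments is now exposed on fit, which matters mainly for continued training. A hedged sketch of that flow; df is assumed to be a prepared dataframe with ds and y columns, and checkpointing is enabled via the existing fit flag since the logic below requires a checkpoint to resume from:

m = NeuralProphet(epochs=10)
metrics = m.fit(df, freq="D", checkpointing=True)

# Resume training with a different scheduler and explicit arguments:
metrics_continued = m.fit(
    df,
    freq="D",
    checkpointing=True,
    continue_training=True,
    epochs=5,  # continued training requires an explicit epoch count (see the next hunk)
    scheduler="exponentiallr",
    scheduler_args={"gamma": 0.9},
)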
@@ -1002,20 +1006,30 @@ def fit(
                 "Model has been fitted already. If you want to continue training please set the flag continue_training."
             )

-        if continue_training and epochs is None:
-            raise ValueError("Continued training requires setting the number of epochs to train for.")
-
         if continue_training:
-            if scheduler is not None:
-                self.config_train.scheduler = scheduler
-            else:
+            if epochs is None:
+                raise ValueError("Continued training requires setting the number of epochs to train for.")
+
+        if continue_training and self.metrics_logger.checkpoint_path is None:
+            log.error("Continued training requires checkpointing in model to continue from last epoch.")
+
+            # if scheduler is not None:
+            #     log.warning(
+            #         "Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
+            #     )
+
+            if scheduler is None:
+                log.warning(
+                    "No scheduler specified for continued training. Using a fallback scheduler for continued training."
+                )
                 self.config_train.scheduler = None
-                self.config_train.set_scheduler()
+                self.config_train.scheduler_args = None
+                self.config_train.set_scheduler()

-        if scheduler is not None and not continue_training:
-            log.warning(
-                "Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
-            )
+        if scheduler is not None:
+            self.config_train.scheduler = scheduler
+            self.config_train.scheduler_args = scheduler_args
+            self.config_train.set_scheduler()

         # Configuration
         if epochs is not None:
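Net effect of the restructured block: a scheduler passed to fit now takes precedence over the one chosen at init, while a continued-training run without an explicit scheduler falls back to the ExponentialLR default through set_scheduler. A small sketch of that fallback, reusing m from the example above and reaching into the internal Train config purely for illustration:

cfg = m.config_train  # internal training config the forecaster delegates to
cfg.scheduler = None
cfg.scheduler_args = None
cfg.set_scheduler()
print(cfg.scheduler, cfg.scheduler_args)
# <class 'torch.optim.lr_scheduler.ExponentialLR'> {'gamma': 0.95}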
@@ -1061,6 +1075,7 @@ def fit(
         log.info("When Global modeling with local normalization, metrics are displayed in normalized scale.")

         if minimal:
+            # overrides these settings:
             checkpointing = False
             self.metrics = False
             progress = None
@@ -1101,9 +1116,6 @@ def fit(
             or any(value != 1 for value in self.num_seasonalities_modelled_dict.values())
         )

-        if continue_training and self.metrics_logger.checkpoint_path is None:
-            log.error("Continued training requires checkpointing in model to continue from last epoch.")
-
         self.max_lags = df_utils.get_max_num_lags(
             n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors
         )
