@@ -298,6 +298,7 @@ class NeuralProphet:
             >>> m = NeuralProphet(collect_metrics=["MSE", "MAE", "RMSE"])
             >>> # use custom torchmetrics names
             >>> m = NeuralProphet(collect_metrics={"MAPE": "MeanAbsolutePercentageError", "MSLE": "MeanSquaredLogError",
+
         scheduler : str, torch.optim.lr_scheduler._LRScheduler
             Type of learning rate scheduler to use.
 
@@ -446,7 +447,8 @@ def __init__(
         batch_size: Optional[int] = None,
         loss_func: Union[str, torch.nn.modules.loss._Loss, Callable] = "SmoothL1Loss",
         optimizer: Union[str, Type[torch.optim.Optimizer]] = "AdamW",
-        scheduler: Optional[str] = "onecyclelr",
+        scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = "onecyclelr",
+        scheduler_args: Optional[dict] = None,
         newer_samples_weight: float = 2,
         newer_samples_start: float = 0.0,
         quantiles: List[float] = [],
@@ -521,6 +523,7 @@ def __init__(
             quantiles=quantiles,
             learning_rate=learning_rate,
             scheduler=scheduler,
+            scheduler_args=scheduler_args,
             epochs=epochs,
             batch_size=batch_size,
             loss_func=loss_func,
@@ -932,7 +935,8 @@ def fit(
         continue_training: bool = False,
         num_workers: int = 0,
         deterministic: bool = False,
-        scheduler: Optional[str] = None,
+        scheduler: Optional[Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]] = None,
+        scheduler_args: Optional[dict] = None,
     ):
         """Train, and potentially evaluate model.
 
@@ -1002,20 +1006,30 @@ def fit(
                 "Model has been fitted already. If you want to continue training please set the flag continue_training."
             )
 
-        if continue_training and epochs is None:
-            raise ValueError("Continued training requires setting the number of epochs to train for.")
-
         if continue_training:
-            if scheduler is not None:
-                self.config_train.scheduler = scheduler
-            else:
+            if epochs is None:
+                raise ValueError("Continued training requires setting the number of epochs to train for.")
+
+            if continue_training and self.metrics_logger.checkpoint_path is None:
+                log.error("Continued training requires checkpointing in model to continue from last epoch.")
+
+            # if scheduler is not None:
+            #     log.warning(
+            #         "Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
+            #     )
+
+            if scheduler is None:
+                log.warning(
+                    "No scheduler specified for continued training. Using a fallback scheduler for continued training."
+                )
                 self.config_train.scheduler = None
-                self.config_train.set_scheduler()
+                self.config_train.scheduler_args = None
+                self.config_train.set_scheduler()
 
-        if scheduler is not None and not continue_training:
-            log.warning(
-                "Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
-            )
+        if scheduler is not None:
+            self.config_train.scheduler = scheduler
+            self.config_train.scheduler_args = scheduler_args
+            self.config_train.set_scheduler()
 
         # Configuration
         if epochs is not None:
@@ -1061,6 +1075,7 @@ def fit(
             log.info("When Global modeling with local normalization, metrics are displayed in normalized scale.")
 
         if minimal:
+            # overrides these settings:
             checkpointing = False
             self.metrics = False
            progress = None
@@ -1101,9 +1116,6 @@ def fit(
             or any(value != 1 for value in self.num_seasonalities_modelled_dict.values())
         )
 
-        if continue_training and self.metrics_logger.checkpoint_path is None:
-            log.error("Continued training requires checkpointing in model to continue from last epoch.")
-
         self.max_lags = df_utils.get_max_num_lags(
             n_lags=self.n_lags, config_lagged_regressors=self.config_lagged_regressors
         )
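
Usage sketch (not part of the patch): the snippet below illustrates how the extended `scheduler` / `scheduler_args` arguments and the revised continued-training checks in `fit()` might be exercised. The keyword names come from this diff; the toy dataframe, the choice of `StepLR`/`ExponentialLR`, and the `checkpointing`, `epochs`, and `freq` values are illustrative assumptions rather than part of the change.

```python
# Illustrative sketch based on the signatures added in this diff; not part of the patch.
# Assumes the neuralprophet package is installed and otherwise matches the released API.
import pandas as pd
import torch
from neuralprophet import NeuralProphet

# Toy dataframe with the usual "ds"/"y" columns (made up for illustration).
df = pd.DataFrame(
    {
        "ds": pd.date_range("2023-01-01", periods=200, freq="D"),
        "y": [float(i % 7) for i in range(200)],
    }
)

# Existing behaviour: scheduler given as a string name (default "onecyclelr").
m = NeuralProphet(epochs=10)
m.fit(df, freq="D")

# New behaviour from this diff: pass a torch scheduler class plus its kwargs.
m2 = NeuralProphet(
    epochs=10,
    scheduler=torch.optim.lr_scheduler.StepLR,  # a class, not an instance
    scheduler_args={"step_size": 5, "gamma": 0.5},
)
m2.fit(df, freq="D", checkpointing=True)

# Continued training per the revised fit() logic: epochs must be given and a
# checkpoint must exist; a scheduler passed here overrides the configured one,
# otherwise a fallback scheduler is used with a warning.
m2.fit(
    df,
    freq="D",
    epochs=5,
    continue_training=True,
    checkpointing=True,
    scheduler=torch.optim.lr_scheduler.ExponentialLR,  # hypothetical choice
    scheduler_args={"gamma": 0.9},
)
```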