Skip to content

Commit 6a74680

Browse files
committed
clean up train config setup
1 parent 63c935c commit 6a74680

File tree

3 files changed

+138
-88
lines changed

3 files changed

+138
-88
lines changed

neuralprophet/configure.py

Lines changed: 60 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,19 @@ class Train:
105105
loss_func_name: str = field(init=False)
106106
lr_finder_args: dict = field(default_factory=dict)
107107
optimizer_state: dict = field(default_factory=dict)
108+
continue_training: bool = False
108109

109110
def __post_init__(self):
110111
# assert the uncertainty estimation params and then finalize the quantiles
111-
self.set_quantiles()
112+
# self.set_quantiles()
112113
assert self.newer_samples_weight >= 1.0
113114
assert self.newer_samples_start >= 0.0
114115
assert self.newer_samples_start < 1.0
115116
self.set_loss_func()
116-
self.set_optimizer()
117-
self.set_scheduler()
117+
118+
# called in TimeNet configure_optimizers:
119+
# self.set_optimizer()
120+
# self.set_scheduler()
118121

119122
def set_loss_func(self):
120123
if isinstance(self.loss_func, str):
@@ -139,22 +142,22 @@ def set_loss_func(self):
139142
if len(self.quantiles) > 1:
140143
self.loss_func = PinballLoss(loss_func=self.loss_func, quantiles=self.quantiles)
141144

142-
def set_quantiles(self):
143-
# convert quantiles to empty list [] if None
144-
if self.quantiles is None:
145-
self.quantiles = []
146-
# assert quantiles is a list type
147-
assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar."
148-
# check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
149-
self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
150-
# check if quantiles are float values in (0, 1)
151-
assert all(
152-
0 < quantile < 1 for quantile in self.quantiles
153-
), "The quantiles specified need to be floats in-between (0, 1)."
154-
# sort the quantiles
155-
self.quantiles.sort()
156-
# 0 is the median quantile index
157-
self.quantiles.insert(0, 0.5)
145+
# def set_quantiles(self):
146+
# # convert quantiles to empty list [] if None
147+
# if self.quantiles is None:
148+
# self.quantiles = []
149+
# # assert quantiles is a list type
150+
# assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar."
151+
# # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
152+
# self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
153+
# # check if quantiles are float values in (0, 1)
154+
# assert all(
155+
# 0 < quantile < 1 for quantile in self.quantiles
156+
# ), "The quantiles specified need to be floats in-between (0, 1)."
157+
# # sort the quantiles
158+
# self.quantiles.sort()
159+
# # 0 is the median quantile index
160+
# self.quantiles.insert(0, 0.5)
158161

159162
def set_auto_batch_epoch(
160163
self,
@@ -183,16 +186,50 @@ def set_optimizer(self):
183186
"""
184187
Set the optimizer and optimizer args. If optimizer is a string, then it will be converted to the corresponding
185188
torch optimizer. The optimizer is not initialized yet as this is done in configure_optimizers in TimeNet.
189+
190+
Parameters
191+
----------
192+
optimizer : str or torch.optim.Optimizer subclass
193+
Object provided to NeuralProphet as optimizer.
194+
optimizer_args : dict
195+
Arguments for the optimizer.
196+
186197
"""
187-
self.optimizer, self.optimizer_args = utils_torch.create_optimizer_from_config(
188-
self.optimizer, self.optimizer_args
189-
)
198+
if isinstance(self.optimizer, str):
199+
if self.optimizer.lower() == "adamw":
200+
# Tends to overfit, but reliable
201+
self.optimizer = torch.optim.AdamW
202+
self.optimizer_args["weight_decay"] = 1e-3
203+
elif self.optimizer.lower() == "sgd":
204+
# better validation performance, but diverges sometimes
205+
self.optimizer = torch.optim.SGD
206+
self.optimizer_args["momentum"] = 0.9
207+
self.optimizer_args["weight_decay"] = 1e-4
208+
else:
209+
raise ValueError(
210+
f"The optimizer name {self.optimizer} is not supported. Please pass the optimizer class."
211+
)
212+
elif not issubclass(self.optimizer, torch.optim.Optimizer):
213+
raise ValueError("The provided optimizer is not supported.")
190214

191215
def set_scheduler(self):
192216
"""
193217
Set the scheduler and scheduler args depending on the user selection.
194218
The scheduler is not initialized yet as this is done in configure_optimizers in TimeNet.
195219
"""
220+
if self.continue_training:
221+
if (isinstance(self.scheduler, str) and self.scheduler.lower() == "onecyclelr") or isinstance(
222+
self.scheduler, torch.optim.lr_scheduler.OneCycleLR
223+
):
224+
log.warning(
225+
"OneCycleLR scheduler is not supported for continued training. Please set another scheduler. Falling back to ExponentialLR scheduler"
226+
)
227+
self.scheduler = "exponentiallr"
228+
229+
if self.scheduler is None:
230+
log.warning("No scheduler specified. Falling back to ExponentialLR scheduler.")
231+
self.scheduler = "exponentiallr"
232+
196233
if isinstance(self.scheduler, str):
197234
if self.scheduler.lower() == "onecyclelr":
198235
self.scheduler = torch.optim.lr_scheduler.OneCycleLR
@@ -226,12 +263,7 @@ def set_scheduler(self):
226263
if self.scheduler_args is not None:
227264
defaults.update(self.scheduler_args)
228265
self.scheduler_args = defaults
229-
elif self.scheduler is None:
230-
self.scheduler = torch.optim.lr_scheduler.ExponentialLR
231-
self.scheduler_args = {
232-
"gamma": 0.95,
233-
}
234-
else: # if scheduler is a class
266+
else:
235267
assert issubclass(
236268
self.scheduler, torch.optim.lr_scheduler.LRScheduler
237269
), "Scheduler must be a subclass of torch.optim.lr_scheduler.LRScheduler"

neuralprophet/forecaster.py

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import math
23
import os
34
import time
45
from collections import OrderedDict
@@ -518,20 +519,36 @@ def __init__(
518519
trend_local_reg=trend_local_reg,
519520
)
520521

522+
# Model
523+
self.quantiles = quantiles
524+
# convert quantiles to empty list [] if None
525+
if self.quantiles is None:
526+
self.quantiles = []
527+
# assert quantiles is a list type
528+
assert isinstance(self.quantiles, list), "Quantiles must be in a list format, not None or scalar."
529+
# check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
530+
self.quantiles = [quantile for quantile in self.quantiles if not math.isclose(0.5, quantile)]
531+
# check if quantiles are float values in (0, 1)
532+
assert all(
533+
0 < quantile < 1 for quantile in self.quantiles
534+
), "The quantiles specified need to be floats in-between (0, 1)."
535+
# sort the quantiles
536+
self.quantiles.sort()
537+
# 0 is the median quantile index
538+
self.quantiles.insert(0, 0.5)
539+
521540
# Training
522-
self.config_train = configure.Train(
523-
quantiles=quantiles,
524-
learning_rate=learning_rate,
525-
scheduler=scheduler,
526-
scheduler_args=scheduler_args,
527-
epochs=epochs,
528-
batch_size=batch_size,
529-
loss_func=loss_func,
530-
optimizer=optimizer,
531-
newer_samples_weight=newer_samples_weight,
532-
newer_samples_start=newer_samples_start,
533-
trend_reg_threshold=self.config_trend.trend_reg_threshold,
534-
)
541+
self.learning_rate = learning_rate
542+
self.scheduler = scheduler
543+
self.scheduler_args = scheduler_args
544+
self.epochs = epochs
545+
self.batch_size = batch_size
546+
self.loss_func = loss_func
547+
self.optimizer = optimizer
548+
self.newer_samples_weight = newer_samples_weight
549+
self.newer_samples_start = newer_samples_start
550+
self.trend_reg_threshold = self.config_trend.trend_reg_threshold
551+
self.continue_training = False
535552

536553
# Seasonality
537554
self.config_seasonality = configure.ConfigSeasonality(
@@ -1013,25 +1030,29 @@ def fit(
10131030
if continue_training and self.metrics_logger.checkpoint_path is None:
10141031
log.error("Continued training requires checkpointing in model to continue from last epoch.")
10151032

1016-
# if scheduler is not None:
1017-
# log.warning(
1018-
# "Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
1019-
# )
1033+
# Configuration
1034+
self.continue_training = continue_training
10201035

1021-
if scheduler is None:
1022-
log.warning(
1023-
"No scheduler specified for continued training. Using a fallback scheduler for continued training."
1024-
)
1025-
self.config_train.scheduler = None
1026-
self.config_train.scheduler_args = None
1027-
self.config_train.set_scheduler()
1036+
# Config
1037+
self.config_train = configure.Train(
1038+
quantiles=self.quantiles,
1039+
learning_rate=self.learning_rate,
1040+
scheduler=self.scheduler,
1041+
scheduler_args=self.scheduler_args,
1042+
epochs=self.epochs,
1043+
batch_size=self.batch_size,
1044+
loss_func=self.loss_func,
1045+
optimizer=self.optimizer,
1046+
newer_samples_weight=self.newer_samples_weight,
1047+
newer_samples_start=self.newer_samples_start,
1048+
trend_reg_threshold=self.config_trend.trend_reg_threshold,
1049+
continue_training=self.continue_training,
1050+
)
10281051

10291052
if scheduler is not None:
10301053
self.config_train.scheduler = scheduler
10311054
self.config_train.scheduler_args = scheduler_args
1032-
self.config_train.set_scheduler()
10331055

1034-
# Configuration
10351056
if epochs is not None:
10361057
self.config_train.epochs = epochs
10371058

@@ -1245,7 +1266,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a
12451266
dates=dates,
12461267
predicted=predicted,
12471268
n_forecasts=self.n_forecasts,
1248-
quantiles=self.config_train.quantiles,
1269+
quantiles=self.quantiles,
12491270
components=components,
12501271
)
12511272
if auto_extend and periods_added[df_name] > 0:
@@ -1260,7 +1281,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a
12601281
n_forecasts=self.n_forecasts,
12611282
max_lags=self.max_lags,
12621283
freq=self.data_freq,
1263-
quantiles=self.config_train.quantiles,
1284+
quantiles=self.quantiles,
12641285
config_lagged_regressors=self.config_lagged_regressors,
12651286
)
12661287
if auto_extend and periods_added[df_name] > 0:
@@ -1901,7 +1922,7 @@ def predict_trend(self, df: pd.DataFrame, quantile: float = 0.5):
19011922
else:
19021923
meta_name_tensor = None
19031924

1904-
quantile_index = self.config_train.quantiles.index(quantile)
1925+
quantile_index = self.quantiles.index(quantile)
19051926
trend = self.model.trend(t, meta_name_tensor).detach().numpy()[:, :, quantile_index].squeeze()
19061927

19071928
data_params = self.config_normalization.get_data_params(df_name)
@@ -1966,7 +1987,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
19661987

19671988
for name in self.config_seasonality.periods:
19681989
features = inputs["seasonalities"][name]
1969-
quantile_index = self.config_train.quantiles.index(quantile)
1990+
quantile_index = self.quantiles.index(quantile)
19701991
y_season = torch.squeeze(
19711992
self.model.seasonality.compute_fourier(features=features, name=name, meta=meta_name_tensor)[
19721993
:, :, quantile_index
@@ -2098,7 +2119,7 @@ def plot(
20982119
log.info(f"Plotting data from ID {df_name}")
20992120
if forecast_in_focus is None:
21002121
forecast_in_focus = self.highlight_forecast_step_n
2101-
if len(self.config_train.quantiles) > 1:
2122+
if len(self.quantiles) > 1:
21022123
if (self.highlight_forecast_step_n) is None and (
21032124
self.n_forecasts > 1 or self.n_lags > 0
21042125
): # rather query if n_forecasts >1 than n_lags>1
@@ -2138,7 +2159,7 @@ def plot(
21382159
if plotting_backend.startswith("plotly"):
21392160
return plot_plotly(
21402161
fcst=fcst,
2141-
quantiles=self.config_train.quantiles,
2162+
quantiles=self.quantiles,
21422163
xlabel=xlabel,
21432164
ylabel=ylabel,
21442165
figsize=tuple(x * 70 for x in figsize),
@@ -2149,7 +2170,7 @@ def plot(
21492170
else:
21502171
return plot(
21512172
fcst=fcst,
2152-
quantiles=self.config_train.quantiles,
2173+
quantiles=self.quantiles,
21532174
ax=ax,
21542175
xlabel=xlabel,
21552176
ylabel=ylabel,
@@ -2217,9 +2238,7 @@ def get_latest_forecast(
22172238
fcst = fcst[-(include_previous_forecasts + self.n_forecasts) :]
22182239
elif include_history_data is True:
22192240
fcst = fcst
2220-
fcst = utils.fcst_df_to_latest_forecast(
2221-
fcst, self.config_train.quantiles, n_last=1 + include_previous_forecasts
2222-
)
2241+
fcst = utils.fcst_df_to_latest_forecast(fcst, self.quantiles, n_last=1 + include_previous_forecasts)
22232242
return fcst
22242243

22252244
def plot_latest_forecast(
@@ -2287,7 +2306,7 @@ def plot_latest_forecast(
22872306
else:
22882307
fcst = fcst[fcst["ID"] == df_name].copy(deep=True)
22892308
log.info(f"Plotting data from ID {df_name}")
2290-
if len(self.config_train.quantiles) > 1:
2309+
if len(self.quantiles) > 1:
22912310
log.warning(
22922311
"Plotting latest forecasts when uncertainty estimation enabled"
22932312
" plots only the median quantile forecasts."
@@ -2298,9 +2317,7 @@ def plot_latest_forecast(
22982317
fcst = fcst[-(include_previous_forecasts + self.n_forecasts) :]
22992318
elif plot_history_data is True:
23002319
fcst = fcst
2301-
fcst = utils.fcst_df_to_latest_forecast(
2302-
fcst, self.config_train.quantiles, n_last=1 + include_previous_forecasts
2303-
)
2320+
fcst = utils.fcst_df_to_latest_forecast(fcst, self.quantiles, n_last=1 + include_previous_forecasts)
23042321

23052322
# Check whether a local or global plotting backend is set.
23062323
plotting_backend = select_plotting_backend(model=self, plotting_backend=plotting_backend)
@@ -2309,7 +2326,7 @@ def plot_latest_forecast(
23092326
if plotting_backend.startswith("plotly"):
23102327
return plot_plotly(
23112328
fcst=fcst,
2312-
quantiles=self.config_train.quantiles,
2329+
quantiles=self.quantiles,
23132330
ylabel=ylabel,
23142331
xlabel=xlabel,
23152332
figsize=tuple(x * 70 for x in figsize),
@@ -2321,7 +2338,7 @@ def plot_latest_forecast(
23212338
else:
23222339
return plot(
23232340
fcst=fcst,
2324-
quantiles=self.config_train.quantiles,
2341+
quantiles=self.quantiles,
23252342
ax=ax,
23262343
ylabel=ylabel,
23272344
xlabel=xlabel,
@@ -2487,7 +2504,7 @@ def plot_components(
24872504
m=self,
24882505
fcst=fcst,
24892506
plot_configuration=valid_plot_configuration,
2490-
quantile=self.config_train.quantiles[0], # plot components only for median quantile
2507+
quantile=self.quantiles[0], # plot components only for median quantile
24912508
figsize=figsize,
24922509
df_name=df_name,
24932510
one_period_per_season=one_period_per_season,
@@ -2597,11 +2614,11 @@ def plot_parameters(
25972614
if not (0 < quantile < 1):
25982615
raise ValueError("The quantile selected needs to be a float in-between (0,1)")
25992616
# ValueError if selected quantile is out of range
2600-
if quantile not in self.config_train.quantiles:
2617+
if quantile not in self.quantiles:
26012618
raise ValueError("Selected quantile is not specified in the model configuration.")
26022619
else:
26032620
# plot parameters for median quantile if not specified
2604-
quantile = self.config_train.quantiles[0]
2621+
quantile = self.quantiles[0]
26052622

26062623
# Validate components to be plotted
26072624
valid_parameters_set = [
@@ -3148,7 +3165,7 @@ def conformal_predict(
31483165
alpha=alpha,
31493166
method=method,
31503167
n_forecasts=self.n_forecasts,
3151-
quantiles=self.config_train.quantiles,
3168+
quantiles=self.quantiles,
31523169
)
31533170

31543171
df_forecast = c.predict(df=df_test, df_cal=df_cal, show_all_PI=show_all_PI)

0 commit comments

Comments
 (0)