1
1
import logging
2
+ import math
2
3
import os
3
4
import time
4
5
from collections import OrderedDict
@@ -518,20 +519,36 @@ def __init__(
518
519
trend_local_reg = trend_local_reg ,
519
520
)
520
521
522
+ # Model
523
+ self .quantiles = quantiles
524
+ # convert quantiles to empty list [] if None
525
+ if self .quantiles is None :
526
+ self .quantiles = []
527
+ # assert quantiles is a list type
528
+ assert isinstance (self .quantiles , list ), "Quantiles must be in a list format, not None or scalar."
529
+ # check if quantiles contain 0.5 or close to 0.5, remove if so as 0.5 will be inserted again as first index
530
+ self .quantiles = [quantile for quantile in self .quantiles if not math .isclose (0.5 , quantile )]
531
+ # check if quantiles are float values in (0, 1)
532
+ assert all (
533
+ 0 < quantile < 1 for quantile in self .quantiles
534
+ ), "The quantiles specified need to be floats in-between (0, 1)."
535
+ # sort the quantiles
536
+ self .quantiles .sort ()
537
+ # 0 is the median quantile index
538
+ self .quantiles .insert (0 , 0.5 )
539
+
521
540
# Training
522
- self .config_train = configure .Train (
523
- quantiles = quantiles ,
524
- learning_rate = learning_rate ,
525
- scheduler = scheduler ,
526
- scheduler_args = scheduler_args ,
527
- epochs = epochs ,
528
- batch_size = batch_size ,
529
- loss_func = loss_func ,
530
- optimizer = optimizer ,
531
- newer_samples_weight = newer_samples_weight ,
532
- newer_samples_start = newer_samples_start ,
533
- trend_reg_threshold = self .config_trend .trend_reg_threshold ,
534
- )
541
+ self .learning_rate = learning_rate
542
+ self .scheduler = scheduler
543
+ self .scheduler_args = scheduler_args
544
+ self .epochs = epochs
545
+ self .batch_size = batch_size
546
+ self .loss_func = loss_func
547
+ self .optimizer = optimizer
548
+ self .newer_samples_weight = newer_samples_weight
549
+ self .newer_samples_start = newer_samples_start
550
+ self .trend_reg_threshold = self .config_trend .trend_reg_threshold
551
+ self .continue_training = False
535
552
536
553
# Seasonality
537
554
self .config_seasonality = configure .ConfigSeasonality (
@@ -1013,25 +1030,29 @@ def fit(
1013
1030
if continue_training and self .metrics_logger .checkpoint_path is None :
1014
1031
log .error ("Continued training requires checkpointing in model to continue from last epoch." )
1015
1032
1016
- # if scheduler is not None:
1017
- # log.warning(
1018
- # "Scheduler can only be set in fit when continuing training. Please set the scheduler when initializing the model."
1019
- # )
1033
+ # Configuration
1034
+ self .continue_training = continue_training
1020
1035
1021
- if scheduler is None :
1022
- log .warning (
1023
- "No scheduler specified for continued training. Using a fallback scheduler for continued training."
1024
- )
1025
- self .config_train .scheduler = None
1026
- self .config_train .scheduler_args = None
1027
- self .config_train .set_scheduler ()
1036
+ # Config
1037
+ self .config_train = configure .Train (
1038
+ quantiles = self .quantiles ,
1039
+ learning_rate = self .learning_rate ,
1040
+ scheduler = self .scheduler ,
1041
+ scheduler_args = self .scheduler_args ,
1042
+ epochs = self .epochs ,
1043
+ batch_size = self .batch_size ,
1044
+ loss_func = self .loss_func ,
1045
+ optimizer = self .optimizer ,
1046
+ newer_samples_weight = self .newer_samples_weight ,
1047
+ newer_samples_start = self .newer_samples_start ,
1048
+ trend_reg_threshold = self .config_trend .trend_reg_threshold ,
1049
+ continue_training = self .continue_training ,
1050
+ )
1028
1051
1029
1052
if scheduler is not None :
1030
1053
self .config_train .scheduler = scheduler
1031
1054
self .config_train .scheduler_args = scheduler_args
1032
- self .config_train .set_scheduler ()
1033
1055
1034
- # Configuration
1035
1056
if epochs is not None :
1036
1057
self .config_train .epochs = epochs
1037
1058
@@ -1245,7 +1266,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a
1245
1266
dates = dates ,
1246
1267
predicted = predicted ,
1247
1268
n_forecasts = self .n_forecasts ,
1248
- quantiles = self .config_train . quantiles ,
1269
+ quantiles = self .quantiles ,
1249
1270
components = components ,
1250
1271
)
1251
1272
if auto_extend and periods_added [df_name ] > 0 :
@@ -1260,7 +1281,7 @@ def predict(self, df: pd.DataFrame, decompose: bool = True, raw: bool = False, a
1260
1281
n_forecasts = self .n_forecasts ,
1261
1282
max_lags = self .max_lags ,
1262
1283
freq = self .data_freq ,
1263
- quantiles = self .config_train . quantiles ,
1284
+ quantiles = self .quantiles ,
1264
1285
config_lagged_regressors = self .config_lagged_regressors ,
1265
1286
)
1266
1287
if auto_extend and periods_added [df_name ] > 0 :
@@ -1901,7 +1922,7 @@ def predict_trend(self, df: pd.DataFrame, quantile: float = 0.5):
1901
1922
else :
1902
1923
meta_name_tensor = None
1903
1924
1904
- quantile_index = self .config_train . quantiles .index (quantile )
1925
+ quantile_index = self .quantiles .index (quantile )
1905
1926
trend = self .model .trend (t , meta_name_tensor ).detach ().numpy ()[:, :, quantile_index ].squeeze ()
1906
1927
1907
1928
data_params = self .config_normalization .get_data_params (df_name )
@@ -1966,7 +1987,7 @@ def predict_seasonal_components(self, df: pd.DataFrame, quantile: float = 0.5):
1966
1987
1967
1988
for name in self .config_seasonality .periods :
1968
1989
features = inputs ["seasonalities" ][name ]
1969
- quantile_index = self .config_train . quantiles .index (quantile )
1990
+ quantile_index = self .quantiles .index (quantile )
1970
1991
y_season = torch .squeeze (
1971
1992
self .model .seasonality .compute_fourier (features = features , name = name , meta = meta_name_tensor )[
1972
1993
:, :, quantile_index
@@ -2098,7 +2119,7 @@ def plot(
2098
2119
log .info (f"Plotting data from ID { df_name } " )
2099
2120
if forecast_in_focus is None :
2100
2121
forecast_in_focus = self .highlight_forecast_step_n
2101
- if len (self .config_train . quantiles ) > 1 :
2122
+ if len (self .quantiles ) > 1 :
2102
2123
if (self .highlight_forecast_step_n ) is None and (
2103
2124
self .n_forecasts > 1 or self .n_lags > 0
2104
2125
): # rather query if n_forecasts >1 than n_lags>1
@@ -2138,7 +2159,7 @@ def plot(
2138
2159
if plotting_backend .startswith ("plotly" ):
2139
2160
return plot_plotly (
2140
2161
fcst = fcst ,
2141
- quantiles = self .config_train . quantiles ,
2162
+ quantiles = self .quantiles ,
2142
2163
xlabel = xlabel ,
2143
2164
ylabel = ylabel ,
2144
2165
figsize = tuple (x * 70 for x in figsize ),
@@ -2149,7 +2170,7 @@ def plot(
2149
2170
else :
2150
2171
return plot (
2151
2172
fcst = fcst ,
2152
- quantiles = self .config_train . quantiles ,
2173
+ quantiles = self .quantiles ,
2153
2174
ax = ax ,
2154
2175
xlabel = xlabel ,
2155
2176
ylabel = ylabel ,
@@ -2217,9 +2238,7 @@ def get_latest_forecast(
2217
2238
fcst = fcst [- (include_previous_forecasts + self .n_forecasts ) :]
2218
2239
elif include_history_data is True :
2219
2240
fcst = fcst
2220
- fcst = utils .fcst_df_to_latest_forecast (
2221
- fcst , self .config_train .quantiles , n_last = 1 + include_previous_forecasts
2222
- )
2241
+ fcst = utils .fcst_df_to_latest_forecast (fcst , self .quantiles , n_last = 1 + include_previous_forecasts )
2223
2242
return fcst
2224
2243
2225
2244
def plot_latest_forecast (
@@ -2287,7 +2306,7 @@ def plot_latest_forecast(
2287
2306
else :
2288
2307
fcst = fcst [fcst ["ID" ] == df_name ].copy (deep = True )
2289
2308
log .info (f"Plotting data from ID { df_name } " )
2290
- if len (self .config_train . quantiles ) > 1 :
2309
+ if len (self .quantiles ) > 1 :
2291
2310
log .warning (
2292
2311
"Plotting latest forecasts when uncertainty estimation enabled"
2293
2312
" plots only the median quantile forecasts."
@@ -2298,9 +2317,7 @@ def plot_latest_forecast(
2298
2317
fcst = fcst [- (include_previous_forecasts + self .n_forecasts ) :]
2299
2318
elif plot_history_data is True :
2300
2319
fcst = fcst
2301
- fcst = utils .fcst_df_to_latest_forecast (
2302
- fcst , self .config_train .quantiles , n_last = 1 + include_previous_forecasts
2303
- )
2320
+ fcst = utils .fcst_df_to_latest_forecast (fcst , self .quantiles , n_last = 1 + include_previous_forecasts )
2304
2321
2305
2322
# Check whether a local or global plotting backend is set.
2306
2323
plotting_backend = select_plotting_backend (model = self , plotting_backend = plotting_backend )
@@ -2309,7 +2326,7 @@ def plot_latest_forecast(
2309
2326
if plotting_backend .startswith ("plotly" ):
2310
2327
return plot_plotly (
2311
2328
fcst = fcst ,
2312
- quantiles = self .config_train . quantiles ,
2329
+ quantiles = self .quantiles ,
2313
2330
ylabel = ylabel ,
2314
2331
xlabel = xlabel ,
2315
2332
figsize = tuple (x * 70 for x in figsize ),
@@ -2321,7 +2338,7 @@ def plot_latest_forecast(
2321
2338
else :
2322
2339
return plot (
2323
2340
fcst = fcst ,
2324
- quantiles = self .config_train . quantiles ,
2341
+ quantiles = self .quantiles ,
2325
2342
ax = ax ,
2326
2343
ylabel = ylabel ,
2327
2344
xlabel = xlabel ,
@@ -2487,7 +2504,7 @@ def plot_components(
2487
2504
m = self ,
2488
2505
fcst = fcst ,
2489
2506
plot_configuration = valid_plot_configuration ,
2490
- quantile = self .config_train . quantiles [0 ], # plot components only for median quantile
2507
+ quantile = self .quantiles [0 ], # plot components only for median quantile
2491
2508
figsize = figsize ,
2492
2509
df_name = df_name ,
2493
2510
one_period_per_season = one_period_per_season ,
@@ -2597,11 +2614,11 @@ def plot_parameters(
2597
2614
if not (0 < quantile < 1 ):
2598
2615
raise ValueError ("The quantile selected needs to be a float in-between (0,1)" )
2599
2616
# ValueError if selected quantile is out of range
2600
- if quantile not in self .config_train . quantiles :
2617
+ if quantile not in self .quantiles :
2601
2618
raise ValueError ("Selected quantile is not specified in the model configuration." )
2602
2619
else :
2603
2620
# plot parameters for median quantile if not specified
2604
- quantile = self .config_train . quantiles [0 ]
2621
+ quantile = self .quantiles [0 ]
2605
2622
2606
2623
# Validate components to be plotted
2607
2624
valid_parameters_set = [
@@ -3148,7 +3165,7 @@ def conformal_predict(
3148
3165
alpha = alpha ,
3149
3166
method = method ,
3150
3167
n_forecasts = self .n_forecasts ,
3151
- quantiles = self .config_train . quantiles ,
3168
+ quantiles = self .quantiles ,
3152
3169
)
3153
3170
3154
3171
df_forecast = c .predict (df = df_test , df_cal = df_cal , show_all_PI = show_all_PI )
0 commit comments