@@ -195,7 +195,7 @@ def __init__(
195
195
## set during prediction
196
196
self .future_periods = None
197
197
## later set by user (optional)
198
- self .forecast_in_focus = None
198
+ self .highlight_forecast_step_n = None
199
199
self .true_ar_weights = None
200
200
201
201
def _init_model (self ):
@@ -495,8 +495,8 @@ def _train(self, df, df_val=None):
495
495
loader = self ._init_train_loader (df )
496
496
val = df_val is not None
497
497
## Metrics
498
- if self .forecast_in_focus is not None :
499
- self .metrics .add_specific_target (target_pos = self .forecast_in_focus - 1 )
498
+ if self .highlight_forecast_step_n is not None :
499
+ self .metrics .add_specific_target (target_pos = self .highlight_forecast_step_n - 1 )
500
500
if self .normalize_y :
501
501
self .metrics .set_shift_scale ((self .data_params ['y' ].shift , self .data_params ['y' ].scale ))
502
502
if val :
@@ -526,12 +526,12 @@ def _train(self, df, df_val=None):
526
526
527
527
def _eval_true_ar (self , verbose = False ):
528
528
assert self .n_lags > 0
529
- if self .forecast_in_focus is None :
529
+ if self .highlight_forecast_step_n is None :
530
530
if self .n_lags > 1 :
531
531
raise ValueError ("Please define forecast_lag for sTPE computation" )
532
532
forecast_pos = 1
533
533
else :
534
- forecast_pos = self .forecast_in_focus
534
+ forecast_pos = self .highlight_forecast_step_n
535
535
weights = self .model .ar_weights .detach ().numpy ()
536
536
weights = weights [forecast_pos - 1 , :][::- 1 ]
537
537
sTPE = utils .symmetric_total_percentage_error (self .true_ar_weights , weights )
@@ -550,8 +550,8 @@ def _evaluate(self, loader, verbose=None):
550
550
if self .fitted is False : raise Exception ('Model object needs to be fit first.' )
551
551
if verbose is None : verbose = self .verbose
552
552
val_metrics = metrics .MetricsCollection ([m .new () for m in self .metrics .batch_metrics ])
553
- if self .forecast_in_focus is not None :
554
- val_metrics .add_specific_target (target_pos = self .forecast_in_focus - 1 )
553
+ if self .highlight_forecast_step_n is not None :
554
+ val_metrics .add_specific_target (target_pos = self .highlight_forecast_step_n - 1 )
555
555
## Run
556
556
val_metrics_dict = self ._evaluate_epoch (loader , val_metrics )
557
557
@@ -838,16 +838,16 @@ def set_true_ar_for_eval(self, true_ar_weights):
838
838
"""
839
839
self .true_ar_weights = true_ar_weights
840
840
841
- def set_forecast_in_focus (self , forecast_number = None ):
841
+ def highlight_nth_step_ahead_of_each_forecast (self , step_number = None ):
842
842
"""Set which forecast step to focus on for metrics evaluation and plotting.
843
843
844
844
Args:
845
- forecast_number (int): i-th step ahead forecast to use for performance statistics evaluation .
846
- Can also be None.
845
+ step_number (int): i-th step ahead forecast to use for statistics and plotting .
846
+ default: None.
847
847
"""
848
- if forecast_number is not None :
849
- assert forecast_number <= self .n_forecasts
850
- self .forecast_in_focus = forecast_number
848
+ if step_number is not None :
849
+ assert step_number <= self .n_forecasts
850
+ self .highlight_forecast_step_n = step_number
851
851
return self
852
852
853
853
def add_covariate (self , name , regularization = None , normalize = 'auto' , only_last_value = False ):
@@ -1027,7 +1027,7 @@ def plot(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10, 6)):
1027
1027
include_previous_forecasts = num_forecasts - 1 , plot_history_data = True )
1028
1028
return plotting .plot (
1029
1029
fcst = fcst , ax = ax , xlabel = xlabel , ylabel = ylabel , figsize = figsize ,
1030
- highlight_forecast = self .forecast_in_focus
1030
+ highlight_forecast = self .highlight_forecast_step_n
1031
1031
)
1032
1032
1033
1033
def plot_last_forecast (self , fcst , ax = None , xlabel = 'ds' , ylabel = 'y' , figsize = (10 , 6 ),
@@ -1056,7 +1056,7 @@ def plot_last_forecast(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10
1056
1056
fcst = utils .fcst_df_to_last_forecast (fcst , n_last = 1 + include_previous_forecasts )
1057
1057
return plotting .plot (
1058
1058
fcst = fcst , ax = ax , xlabel = xlabel , ylabel = ylabel , figsize = figsize ,
1059
- highlight_forecast = self .forecast_in_focus , line_per_origin = True ,
1059
+ highlight_forecast = self .highlight_forecast_step_n , line_per_origin = True ,
1060
1060
)
1061
1061
1062
1062
def plot_components (self , fcst , figsize = (10 , 6 )):
@@ -1074,7 +1074,7 @@ def plot_components(self, fcst, figsize=(10, 6)):
1074
1074
m = self ,
1075
1075
fcst = fcst ,
1076
1076
figsize = figsize ,
1077
- forecast_in_focus = self .forecast_in_focus ,
1077
+ forecast_in_focus = self .highlight_forecast_step_n ,
1078
1078
)
1079
1079
1080
1080
def plot_parameters (self , weekly_start = 0 , yearly_start = 0 , figsize = (10 , 6 )):
@@ -1091,7 +1091,7 @@ def plot_parameters(self, weekly_start=0, yearly_start=0, figsize=(10, 6)):
1091
1091
"""
1092
1092
return plotting .plot_parameters (
1093
1093
m = self ,
1094
- forecast_in_focus = self .forecast_in_focus ,
1094
+ forecast_in_focus = self .highlight_forecast_step_n ,
1095
1095
weekly_start = weekly_start ,
1096
1096
yearly_start = yearly_start ,
1097
1097
figsize = figsize ,
0 commit comments