@@ -195,7 +195,7 @@ def __init__(
195
195
## set during prediction
196
196
self .future_periods = None
197
197
## later set by user (optional)
198
- self .forecast_in_focus = None
198
+ self .highlight_forecast_step_n = None
199
199
self .true_ar_weights = None
200
200
201
201
def _init_model (self ):
@@ -497,8 +497,8 @@ def _train(self, df, df_val=None):
497
497
loader = self ._init_train_loader (df )
498
498
val = df_val is not None
499
499
## Metrics
500
- if self .forecast_in_focus is not None :
501
- self .metrics .add_specific_target (target_pos = self .forecast_in_focus - 1 )
500
+ if self .highlight_forecast_step_n is not None :
501
+ self .metrics .add_specific_target (target_pos = self .highlight_forecast_step_n - 1 )
502
502
if self .normalize_y :
503
503
self .metrics .set_shift_scale ((self .data_params ['y' ].shift , self .data_params ['y' ].scale ))
504
504
if val :
@@ -528,12 +528,12 @@ def _train(self, df, df_val=None):
528
528
529
529
def _eval_true_ar (self , verbose = False ):
530
530
assert self .n_lags > 0
531
- if self .forecast_in_focus is None :
531
+ if self .highlight_forecast_step_n is None :
532
532
if self .n_lags > 1 :
533
533
raise ValueError ("Please define forecast_lag for sTPE computation" )
534
534
forecast_pos = 1
535
535
else :
536
- forecast_pos = self .forecast_in_focus
536
+ forecast_pos = self .highlight_forecast_step_n
537
537
weights = self .model .ar_weights .detach ().numpy ()
538
538
weights = weights [forecast_pos - 1 , :][::- 1 ]
539
539
sTPE = utils .symmetric_total_percentage_error (self .true_ar_weights , weights )
@@ -552,8 +552,8 @@ def _evaluate(self, loader, verbose=None):
552
552
if self .fitted is False : raise Exception ('Model object needs to be fit first.' )
553
553
if verbose is None : verbose = self .verbose
554
554
val_metrics = metrics .MetricsCollection ([m .new () for m in self .metrics .batch_metrics ])
555
- if self .forecast_in_focus is not None :
556
- val_metrics .add_specific_target (target_pos = self .forecast_in_focus - 1 )
555
+ if self .highlight_forecast_step_n is not None :
556
+ val_metrics .add_specific_target (target_pos = self .highlight_forecast_step_n - 1 )
557
557
## Run
558
558
val_metrics_dict = self ._evaluate_epoch (loader , val_metrics )
559
559
@@ -840,16 +840,16 @@ def set_true_ar_for_eval(self, true_ar_weights):
840
840
"""
841
841
self .true_ar_weights = true_ar_weights
842
842
843
- def set_forecast_in_focus (self , forecast_number = None ):
843
+ def highlight_nth_step_ahead_of_each_forecast (self , step_number = None ):
844
844
"""Set which forecast step to focus on for metrics evaluation and plotting.
845
845
846
846
Args:
847
- forecast_number (int): i-th step ahead forecast to use for performance statistics evaluation .
848
- Can also be None.
847
+ step_number (int): i-th step ahead forecast to use for statistics and plotting .
848
+ default: None.
849
849
"""
850
- if forecast_number is not None :
851
- assert forecast_number <= self .n_forecasts
852
- self .forecast_in_focus = forecast_number
850
+ if step_number is not None :
851
+ assert step_number <= self .n_forecasts
852
+ self .highlight_forecast_step_n = step_number
853
853
return self
854
854
855
855
def add_covariate (self , name , regularization = None , normalize = 'auto' , only_last_value = False ):
@@ -1029,7 +1029,7 @@ def plot(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10, 6)):
1029
1029
include_previous_forecasts = num_forecasts - 1 , plot_history_data = True )
1030
1030
return plotting .plot (
1031
1031
fcst = fcst , ax = ax , xlabel = xlabel , ylabel = ylabel , figsize = figsize ,
1032
- highlight_forecast = self .forecast_in_focus
1032
+ highlight_forecast = self .highlight_forecast_step_n
1033
1033
)
1034
1034
1035
1035
def plot_last_forecast (self , fcst , ax = None , xlabel = 'ds' , ylabel = 'y' , figsize = (10 , 6 ),
@@ -1058,7 +1058,7 @@ def plot_last_forecast(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10
1058
1058
fcst = utils .fcst_df_to_last_forecast (fcst , n_last = 1 + include_previous_forecasts )
1059
1059
return plotting .plot (
1060
1060
fcst = fcst , ax = ax , xlabel = xlabel , ylabel = ylabel , figsize = figsize ,
1061
- highlight_forecast = self .forecast_in_focus , line_per_origin = True ,
1061
+ highlight_forecast = self .highlight_forecast_step_n , line_per_origin = True ,
1062
1062
)
1063
1063
1064
1064
def plot_components (self , fcst , figsize = None ):
@@ -1077,7 +1077,7 @@ def plot_components(self, fcst, figsize=None):
1077
1077
m = self ,
1078
1078
fcst = fcst ,
1079
1079
figsize = figsize ,
1080
- forecast_in_focus = self .forecast_in_focus ,
1080
+ forecast_in_focus = self .highlight_forecast_step_n ,
1081
1081
)
1082
1082
1083
1083
def plot_parameters (self , weekly_start = 0 , yearly_start = 0 , figsize = None ):
@@ -1095,7 +1095,7 @@ def plot_parameters(self, weekly_start=0, yearly_start=0, figsize=None):
1095
1095
"""
1096
1096
return plotting .plot_parameters (
1097
1097
m = self ,
1098
- forecast_in_focus = self .forecast_in_focus ,
1098
+ forecast_in_focus = self .highlight_forecast_step_n ,
1099
1099
weekly_start = weekly_start ,
1100
1100
yearly_start = yearly_start ,
1101
1101
figsize = figsize ,
0 commit comments