Skip to content

Commit 7fdbddd

Browse files
authored
Merge pull request #80 from ourownstory/forecast-in-focus
update name of forecast_in_focus to highlight_forecast_step_n
2 parents d90aadf + 68da7d3 commit 7fdbddd

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

neuralprophet/neural_prophet.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def __init__(
195195
## set during prediction
196196
self.future_periods = None
197197
## later set by user (optional)
198-
self.forecast_in_focus = None
198+
self.highlight_forecast_step_n = None
199199
self.true_ar_weights = None
200200

201201
def _init_model(self):
@@ -497,8 +497,8 @@ def _train(self, df, df_val=None):
497497
loader = self._init_train_loader(df)
498498
val = df_val is not None
499499
## Metrics
500-
if self.forecast_in_focus is not None:
501-
self.metrics.add_specific_target(target_pos=self.forecast_in_focus - 1)
500+
if self.highlight_forecast_step_n is not None:
501+
self.metrics.add_specific_target(target_pos=self.highlight_forecast_step_n - 1)
502502
if self.normalize_y:
503503
self.metrics.set_shift_scale((self.data_params['y'].shift, self.data_params['y'].scale))
504504
if val:
@@ -528,12 +528,12 @@ def _train(self, df, df_val=None):
528528

529529
def _eval_true_ar(self, verbose=False):
530530
assert self.n_lags > 0
531-
if self.forecast_in_focus is None:
531+
if self.highlight_forecast_step_n is None:
532532
if self.n_lags > 1:
533533
raise ValueError("Please define forecast_lag for sTPE computation")
534534
forecast_pos = 1
535535
else:
536-
forecast_pos = self.forecast_in_focus
536+
forecast_pos = self.highlight_forecast_step_n
537537
weights = self.model.ar_weights.detach().numpy()
538538
weights = weights[forecast_pos - 1, :][::-1]
539539
sTPE = utils.symmetric_total_percentage_error(self.true_ar_weights, weights)
@@ -552,8 +552,8 @@ def _evaluate(self, loader, verbose=None):
552552
if self.fitted is False: raise Exception('Model object needs to be fit first.')
553553
if verbose is None: verbose = self.verbose
554554
val_metrics = metrics.MetricsCollection([m.new() for m in self.metrics.batch_metrics])
555-
if self.forecast_in_focus is not None:
556-
val_metrics.add_specific_target(target_pos=self.forecast_in_focus - 1)
555+
if self.highlight_forecast_step_n is not None:
556+
val_metrics.add_specific_target(target_pos=self.highlight_forecast_step_n - 1)
557557
## Run
558558
val_metrics_dict = self._evaluate_epoch(loader, val_metrics)
559559

@@ -840,16 +840,16 @@ def set_true_ar_for_eval(self, true_ar_weights):
840840
"""
841841
self.true_ar_weights = true_ar_weights
842842

843-
def set_forecast_in_focus(self, forecast_number=None):
843+
def highlight_nth_step_ahead_of_each_forecast(self, step_number=None):
844844
"""Set which forecast step to focus on for metrics evaluation and plotting.
845845
846846
Args:
847-
forecast_number (int): i-th step ahead forecast to use for performance statistics evaluation.
848-
Can also be None.
847+
step_number (int): i-th step ahead forecast to use for statistics and plotting.
848+
default: None.
849849
"""
850-
if forecast_number is not None:
851-
assert forecast_number <= self.n_forecasts
852-
self.forecast_in_focus = forecast_number
850+
if step_number is not None:
851+
assert step_number <= self.n_forecasts
852+
self.highlight_forecast_step_n = step_number
853853
return self
854854

855855
def add_covariate(self, name, regularization=None, normalize='auto', only_last_value=False):
@@ -1029,7 +1029,7 @@ def plot(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10, 6)):
10291029
include_previous_forecasts=num_forecasts - 1, plot_history_data=True)
10301030
return plotting.plot(
10311031
fcst=fcst, ax=ax, xlabel=xlabel, ylabel=ylabel, figsize=figsize,
1032-
highlight_forecast=self.forecast_in_focus
1032+
highlight_forecast=self.highlight_forecast_step_n
10331033
)
10341034

10351035
def plot_last_forecast(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10, 6),
@@ -1058,7 +1058,7 @@ def plot_last_forecast(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10
10581058
fcst = utils.fcst_df_to_last_forecast(fcst, n_last=1 + include_previous_forecasts)
10591059
return plotting.plot(
10601060
fcst=fcst, ax=ax, xlabel=xlabel, ylabel=ylabel, figsize=figsize,
1061-
highlight_forecast=self.forecast_in_focus, line_per_origin=True,
1061+
highlight_forecast=self.highlight_forecast_step_n, line_per_origin=True,
10621062
)
10631063

10641064
def plot_components(self, fcst, figsize=None):
@@ -1077,7 +1077,7 @@ def plot_components(self, fcst, figsize=None):
10771077
m=self,
10781078
fcst=fcst,
10791079
figsize=figsize,
1080-
forecast_in_focus=self.forecast_in_focus,
1080+
forecast_in_focus=self.highlight_forecast_step_n,
10811081
)
10821082

10831083
def plot_parameters(self, weekly_start=0, yearly_start=0, figsize=None):
@@ -1095,7 +1095,7 @@ def plot_parameters(self, weekly_start=0, yearly_start=0, figsize=None):
10951095
"""
10961096
return plotting.plot_parameters(
10971097
m=self,
1098-
forecast_in_focus=self.forecast_in_focus,
1098+
forecast_in_focus=self.highlight_forecast_step_n,
10991099
weekly_start=weekly_start,
11001100
yearly_start=yearly_start,
11011101
figsize=figsize,

neuralprophet/test_debug.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_ar_net(verbose=True):
6161
weekly_seasonality=False,
6262
daily_seasonality=False,
6363
)
64-
m.set_forecast_in_focus(m.n_forecasts)
64+
m.highlight_nth_step_ahead_of_each_forecast(m.n_forecasts)
6565
m.fit(df, validate_each_epoch=True)
6666
future = m.compose_prediction_df(df, n_historic_predictions=len(df))
6767
forecast = m.predict(df=future)
@@ -123,7 +123,7 @@ def test_lag_reg(verbose=True):
123123
m = m.add_covariate(name='A')
124124
m = m.add_regressor(name='B')
125125
m = m.add_regressor(name='C')
126-
# m.set_forecast_in_focus(m.n_forecasts)
126+
# m.highlight_nth_step_ahead_of_each_forecast(m.n_forecasts)
127127
m.fit(df, validate_each_epoch=True)
128128
future = m.compose_prediction_df(df, n_historic_predictions=365)
129129
forecast = m.predict(future)
@@ -214,7 +214,7 @@ def test_plot(verbose=True):
214214
# daily_seasonality=False,
215215
)
216216
m.fit(df)
217-
m.set_forecast_in_focus(7)
217+
m.highlight_nth_step_ahead_of_each_forecast(7)
218218
future = m.compose_prediction_df(df, n_historic_predictions=10)
219219
forecast = m.predict(future)
220220
# print(future.to_string())

0 commit comments

Comments
 (0)