Skip to content

Commit 68da7d3

Browse files
author
Oskar Triebe
committed
changed name
1 parent 9b7854c commit 68da7d3

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

neuralprophet/neural_prophet.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def __init__(
195195
## set during prediction
196196
self.future_periods = None
197197
## later set by user (optional)
198-
self.forecast_in_focus = None
198+
self.highlight_forecast_step_n = None
199199
self.true_ar_weights = None
200200

201201
def _init_model(self):
@@ -495,8 +495,8 @@ def _train(self, df, df_val=None):
495495
loader = self._init_train_loader(df)
496496
val = df_val is not None
497497
## Metrics
498-
if self.forecast_in_focus is not None:
499-
self.metrics.add_specific_target(target_pos=self.forecast_in_focus - 1)
498+
if self.highlight_forecast_step_n is not None:
499+
self.metrics.add_specific_target(target_pos=self.highlight_forecast_step_n - 1)
500500
if self.normalize_y:
501501
self.metrics.set_shift_scale((self.data_params['y'].shift, self.data_params['y'].scale))
502502
if val:
@@ -526,12 +526,12 @@ def _train(self, df, df_val=None):
526526

527527
def _eval_true_ar(self, verbose=False):
528528
assert self.n_lags > 0
529-
if self.forecast_in_focus is None:
529+
if self.highlight_forecast_step_n is None:
530530
if self.n_lags > 1:
531531
raise ValueError("Please define forecast_lag for sTPE computation")
532532
forecast_pos = 1
533533
else:
534-
forecast_pos = self.forecast_in_focus
534+
forecast_pos = self.highlight_forecast_step_n
535535
weights = self.model.ar_weights.detach().numpy()
536536
weights = weights[forecast_pos - 1, :][::-1]
537537
sTPE = utils.symmetric_total_percentage_error(self.true_ar_weights, weights)
@@ -550,8 +550,8 @@ def _evaluate(self, loader, verbose=None):
550550
if self.fitted is False: raise Exception('Model object needs to be fit first.')
551551
if verbose is None: verbose = self.verbose
552552
val_metrics = metrics.MetricsCollection([m.new() for m in self.metrics.batch_metrics])
553-
if self.forecast_in_focus is not None:
554-
val_metrics.add_specific_target(target_pos=self.forecast_in_focus - 1)
553+
if self.highlight_forecast_step_n is not None:
554+
val_metrics.add_specific_target(target_pos=self.highlight_forecast_step_n - 1)
555555
## Run
556556
val_metrics_dict = self._evaluate_epoch(loader, val_metrics)
557557

@@ -838,16 +838,16 @@ def set_true_ar_for_eval(self, true_ar_weights):
838838
"""
839839
self.true_ar_weights = true_ar_weights
840840

841-
def set_forecast_in_focus(self, forecast_number=None):
841+
def highlight_nth_step_ahead_of_each_forecast(self, step_number=None):
842842
"""Set which forecast step to focus on for metrics evaluation and plotting.
843843
844844
Args:
845-
forecast_number (int): i-th step ahead forecast to use for performance statistics evaluation.
846-
Can also be None.
845+
step_number (int): i-th step ahead forecast to use for statistics and plotting.
846+
default: None.
847847
"""
848-
if forecast_number is not None:
849-
assert forecast_number <= self.n_forecasts
850-
self.forecast_in_focus = forecast_number
848+
if step_number is not None:
849+
assert step_number <= self.n_forecasts
850+
self.highlight_forecast_step_n = step_number
851851
return self
852852

853853
def add_covariate(self, name, regularization=None, normalize='auto', only_last_value=False):
@@ -1027,7 +1027,7 @@ def plot(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10, 6)):
10271027
include_previous_forecasts=num_forecasts - 1, plot_history_data=True)
10281028
return plotting.plot(
10291029
fcst=fcst, ax=ax, xlabel=xlabel, ylabel=ylabel, figsize=figsize,
1030-
highlight_forecast=self.forecast_in_focus
1030+
highlight_forecast=self.highlight_forecast_step_n
10311031
)
10321032

10331033
def plot_last_forecast(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10, 6),
@@ -1056,7 +1056,7 @@ def plot_last_forecast(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10
10561056
fcst = utils.fcst_df_to_last_forecast(fcst, n_last=1 + include_previous_forecasts)
10571057
return plotting.plot(
10581058
fcst=fcst, ax=ax, xlabel=xlabel, ylabel=ylabel, figsize=figsize,
1059-
highlight_forecast=self.forecast_in_focus, line_per_origin=True,
1059+
highlight_forecast=self.highlight_forecast_step_n, line_per_origin=True,
10601060
)
10611061

10621062
def plot_components(self, fcst, figsize=(10, 6)):
@@ -1074,7 +1074,7 @@ def plot_components(self, fcst, figsize=(10, 6)):
10741074
m=self,
10751075
fcst=fcst,
10761076
figsize=figsize,
1077-
forecast_in_focus=self.forecast_in_focus,
1077+
forecast_in_focus=self.highlight_forecast_step_n,
10781078
)
10791079

10801080
def plot_parameters(self, weekly_start=0, yearly_start=0, figsize=(10, 6)):
@@ -1091,7 +1091,7 @@ def plot_parameters(self, weekly_start=0, yearly_start=0, figsize=(10, 6)):
10911091
"""
10921092
return plotting.plot_parameters(
10931093
m=self,
1094-
forecast_in_focus=self.forecast_in_focus,
1094+
forecast_in_focus=self.highlight_forecast_step_n,
10951095
weekly_start=weekly_start,
10961096
yearly_start=yearly_start,
10971097
figsize=figsize,

neuralprophet/test_debug.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_ar_net(verbose=True):
6161
weekly_seasonality=False,
6262
daily_seasonality=False,
6363
)
64-
m.set_forecast_in_focus(m.n_forecasts)
64+
m.highlight_nth_step_ahead_of_each_forecast(m.n_forecasts)
6565
m.fit(df, validate_each_epoch=True)
6666
future = m.compose_prediction_df(df, n_historic_predictions=len(df))
6767
forecast = m.predict(df=future)
@@ -123,7 +123,7 @@ def test_lag_reg(verbose=True):
123123
m = m.add_covariate(name='A')
124124
m = m.add_regressor(name='B')
125125
m = m.add_regressor(name='C')
126-
# m.set_forecast_in_focus(m.n_forecasts)
126+
# m.highlight_nth_step_ahead_of_each_forecast(m.n_forecasts)
127127
m.fit(df, validate_each_epoch=True)
128128
future = m.compose_prediction_df(df, n_historic_predictions=365)
129129
forecast = m.predict(future)
@@ -214,7 +214,7 @@ def test_plot(verbose=True):
214214
# daily_seasonality=False,
215215
)
216216
m.fit(df)
217-
m.set_forecast_in_focus(7)
217+
m.highlight_nth_step_ahead_of_each_forecast(7)
218218
future = m.compose_prediction_df(df, n_historic_predictions=10)
219219
forecast = m.predict(future)
220220
# print(future.to_string())

0 commit comments

Comments
 (0)