Skip to content

Commit 9b7854c

Browse files
authored
Merge pull request #69 from ourownstory/bugfixes-italo
small bugfixes
2 parents b74ac19 + 26780f6 commit 9b7854c

File tree

4 files changed

+96
-77
lines changed

4 files changed

+96
-77
lines changed

neuralprophet/neural_prophet.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(
150150
## Trend
151151
self.n_changepoints = n_changepoints
152152
self.trend_smoothness = trend_smoothness
153-
# self.growth = "linear" # Prophet Trend related, only linear currently implemented
153+
# self.growth = "linear" # OG Prophet Trend related, only linear currently implemented
154154
# if self.growth != 'linear':
155155
# raise NotImplementedError
156156
if self.n_changepoints > 0 and self.trend_smoothness > 0:
@@ -617,22 +617,22 @@ def test(self, df):
617617
val_metrics_df = self._evaluate(loader)
618618
return val_metrics_df
619619

620-
def compose_prediction_df(self, df, events_df=None, future_periods=None, n_history=0):
621-
assert n_history >= 0
620+
def compose_prediction_df(self, df, events_df=None, future_periods=None, n_historic_predictions=0):
621+
assert n_historic_predictions >= 0
622622
if future_periods is not None:
623623
assert future_periods >= 0
624-
if future_periods == 0 and n_history == 0:
624+
if future_periods == 0 and n_historic_predictions == 0:
625625
raise ValueError("Set either history or future to contain more than zero values.")
626626

627627
n_lags = 0 if self.n_lags is None else self.n_lags
628628

629629
if len(df) < n_lags:
630630
raise ValueError("Insufficient data for a prediction")
631-
elif len(df) < n_lags + n_history:
631+
elif len(df) < n_lags + n_historic_predictions:
632632
print("Warning: insufficient data for {} historic forecasts, reduced to {}.".format(
633-
n_history, len(df) - n_lags))
634-
n_history = len(df) - n_lags
635-
df = df[-(n_lags + n_history):]
633+
n_historic_predictions, len(df) - n_lags))
634+
n_historic_predictions = len(df) - n_lags
635+
df = df[-(n_lags + n_historic_predictions):]
636636

637637
if len(df) > 0:
638638
if len(df.columns) == 1 and 'ds' in df:
@@ -1012,7 +1012,7 @@ def plot(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10, 6)):
10121012
ax (matplotlib axes): Optional, matplotlib axes on which to plot.
10131013
xlabel (string): label name on X-axis
10141014
ylabel (string): label name on Y-axis
1015-
figsize (tuple): width, height in inches.
1015+
figsize (tuple): width, height in inches. default: (10, 6)
10161016
10171017
Returns:
10181018
A matplotlib figure.
@@ -1024,7 +1024,7 @@ def plot(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10, 6)):
10241024
"Plotting a line per forecast origin instead.")
10251025
return self.plot_last_forecast(
10261026
fcst, ax=ax, xlabel=xlabel, ylabel=ylabel, figsize=figsize,
1027-
include_previous_forecasts=num_forecasts - 1)
1027+
include_previous_forecasts=num_forecasts - 1, plot_history_data=True)
10281028
return plotting.plot(
10291029
fcst=fcst, ax=ax, xlabel=xlabel, ylabel=ylabel, figsize=figsize,
10301030
highlight_forecast=self.forecast_in_focus
@@ -1039,7 +1039,7 @@ def plot_last_forecast(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10
10391039
ax (matplotlib axes): Optional, matplotlib axes on which to plot.
10401040
xlabel (string): label name on X-axis
10411041
ylabel (string): label name on Y-axis
1042-
figsize (tuple): width, height in inches.
1042+
figsize (tuple): width, height in inches. default: (10, 6)
10431043
include_previous_forecasts (int): number of previous forecasts to include in plot
10441044
plot_history_data
10451045
Returns:
@@ -1056,38 +1056,36 @@ def plot_last_forecast(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10
10561056
fcst = utils.fcst_df_to_last_forecast(fcst, n_last=1 + include_previous_forecasts)
10571057
return plotting.plot(
10581058
fcst=fcst, ax=ax, xlabel=xlabel, ylabel=ylabel, figsize=figsize,
1059-
highlight_forecast=1
1059+
highlight_forecast=self.forecast_in_focus, line_per_origin=True,
10601060
)
10611061

1062-
def plot_components(self, fcst, crop_last_n=None, figsize=None):
1063-
"""Plot the Prophet forecast components.
1062+
def plot_components(self, fcst, figsize=(10, 6)):
1063+
"""Plot the NeuralProphet forecast components.
10641064
10651065
Args:
10661066
fcst (pd.DataFrame): output of self.predict
1067-
figsize (tuple): width, height in inches.
1067+
figsize (tuple): width, height in inches. default: (10, 6)
10681068
crop_last_n (int): number of samples to plot (combined future and past)
10691069
None (default) includes entire history. ignored for seasonality.
10701070
Returns:
10711071
A matplotlib figure.
10721072
"""
1073-
if crop_last_n is not None:
1074-
fcst = fcst[-crop_last_n:]
10751073
return plotting.plot_components(
10761074
m=self,
10771075
fcst=fcst,
10781076
figsize=figsize,
10791077
forecast_in_focus=self.forecast_in_focus,
10801078
)
10811079

1082-
def plot_parameters(self, weekly_start=0, yearly_start=0, figsize=None,):
1083-
"""Plot the Prophet forecast components.
1080+
def plot_parameters(self, weekly_start=0, yearly_start=0, figsize=(10, 6)):
1081+
"""Plot the NeuralProphet forecast components.
10841082
10851083
Args:
10861084
weekly_start (int): specifying the start day of the weekly seasonality plot.
10871085
0 (default) starts the week on Sunday. 1 shifts by 1 day to Monday, and so on.
10881086
yearly_start (int): specifying the start day of the yearly seasonality plot.
10891087
0 (default) starts the year on Jan 1. 1 shifts by 1 day to Jan 2, and so on.
1090-
figsize (tuple): width, height in inches.
1088+
figsize (tuple): width, height in inches. default: (10, 6)
10911089
Returns:
10921090
A matplotlib figure.
10931091
"""

neuralprophet/plotting_utils.py

Lines changed: 62 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import pandas as pd
3+
import warnings
34

45
try:
56
from matplotlib import pyplot as plt
@@ -26,13 +27,18 @@ def set_y_as_percent(ax):
2627
Returns:
2728
ax
2829
"""
29-
yticks = 100 * ax.get_yticks()
30-
yticklabels = ['{0:.4g}%'.format(y) for y in yticks]
31-
ax.set_yticklabels(yticklabels)
32-
return ax
33-
34-
35-
def plot(fcst, ax=None, xlabel='ds', ylabel='y', highlight_forecast=None, figsize=(10, 6)):
30+
warnings.filterwarnings("error")
31+
try:
32+
yticks = 100 * ax.get_yticks()
33+
yticklabels = ['{0:.4g}%'.format(y) for y in yticks]
34+
ax.set_yticklabels(yticklabels)
35+
except UserWarning:
36+
pass # workaround until there is clear direction how to handle this recent matplotlib bug
37+
finally:
38+
return ax
39+
40+
41+
def plot(fcst, ax=None, xlabel='ds', ylabel='y', highlight_forecast=None, line_per_origin=False, figsize=(10, 6)):
3642
"""Plot the NeuralProphet forecast
3743
3844
Args:
@@ -46,34 +52,45 @@ def plot(fcst, ax=None, xlabel='ds', ylabel='y', highlight_forecast=None, figsiz
4652
Returns:
4753
A matplotlib figure.
4854
"""
55+
fcst = fcst.fillna(value=np.nan)
4956
if ax is None:
5057
fig = plt.figure(facecolor='w', figsize=figsize)
5158
ax = fig.add_subplot(111)
5259
else:
5360
fig = ax.get_figure()
5461
ds = fcst['ds'].dt.to_pydatetime()
5562
yhat_col_names = [col_name for col_name in fcst.columns if 'yhat' in col_name]
56-
for i in range(len(yhat_col_names)):
57-
ax.plot(ds, fcst['yhat{}'.format(i + 1)], ls='-', c='#0072B2', alpha=0.2 + 2.0/(i+2.5))
58-
# Future Todo: use fill_between for all but highlight_forecast
59-
"""
60-
col1 = 'yhat{}'.format(i+1)
61-
col2 = 'yhat{}'.format(i+2)
62-
no_na1 = fcst.copy()[col1].notnull().values
63-
no_na2 = fcst.copy()[col2].notnull().values
64-
no_na = [x1 and x2 for x1, x2 in zip(no_na1, no_na2)]
65-
fcst_na = fcst.copy()[no_na]
66-
fcst_na_t = fcst_na['ds'].dt.to_pydatetime()
67-
ax.fill_between(
68-
fcst_na_t,
69-
fcst_na[col1],
70-
fcst_na[col2],
71-
color='#0072B2', alpha=1.0/(i+1)
72-
)
73-
"""
63+
64+
if highlight_forecast is None or line_per_origin:
65+
for i in range(len(yhat_col_names)):
66+
ax.plot(ds, fcst['yhat{}'.format(i + 1)], ls='-', c='#0072B2', alpha=0.2 + 2.0 / (i + 2.5))
67+
# Future Todo: use fill_between for all but highlight_forecast
68+
"""
69+
col1 = 'yhat{}'.format(i+1)
70+
col2 = 'yhat{}'.format(i+2)
71+
no_na1 = fcst.copy()[col1].notnull().values
72+
no_na2 = fcst.copy()[col2].notnull().values
73+
no_na = [x1 and x2 for x1, x2 in zip(no_na1, no_na2)]
74+
fcst_na = fcst.copy()[no_na]
75+
fcst_na_t = fcst_na['ds'].dt.to_pydatetime()
76+
ax.fill_between(
77+
fcst_na_t,
78+
fcst_na[col1],
79+
fcst_na[col2],
80+
color='#0072B2', alpha=1.0/(i+1)
81+
)
82+
"""
7483
if highlight_forecast is not None:
75-
ax.plot(ds, fcst['yhat{}'.format(highlight_forecast)], ls='-', c='b')
76-
ax.plot(ds, fcst['yhat{}'.format(highlight_forecast)], 'bx')
84+
if line_per_origin:
85+
num_forecast_steps = sum(fcst['yhat1'].notna())
86+
steps_from_last = num_forecast_steps - highlight_forecast
87+
for i in range(len(yhat_col_names)):
88+
x = ds[-(1 + i + steps_from_last)]
89+
y = fcst['yhat{}'.format(i + 1)].values[-(1 + i + steps_from_last)]
90+
ax.plot(x, y, 'bx')
91+
else:
92+
ax.plot(ds, fcst['yhat{}'.format(highlight_forecast)], ls='-', c='b')
93+
ax.plot(ds, fcst['yhat{}'.format(highlight_forecast)], 'bx')
7794

7895
ax.plot(ds, fcst['y'], 'k.')
7996

@@ -89,7 +106,7 @@ def plot(fcst, ax=None, xlabel='ds', ylabel='y', highlight_forecast=None, figsiz
89106
return fig
90107

91108

92-
def plot_components(m, fcst, forecast_in_focus=None, figsize=None):
109+
def plot_components(m, fcst, forecast_in_focus=None, figsize=(10, 6)):
93110
"""Plot the NeuralProphet forecast components.
94111
95112
Args:
@@ -101,6 +118,7 @@ def plot_components(m, fcst, forecast_in_focus=None, figsize=None):
101118
Returns:
102119
A matplotlib figure.
103120
"""
121+
fcst = fcst.fillna(value=np.nan)
104122
# Identify components to be plotted
105123
# as dict, minimum: {plot_name, comp_name}
106124
components = [{'plot_name': 'Trend',
@@ -123,8 +141,8 @@ def plot_components(m, fcst, forecast_in_focus=None, figsize=None):
123141
'bar': True})
124142
else:
125143
components.append({'plot_name': 'AR ({})-ahead'.format(forecast_in_focus),
126-
'comp_name': 'ar{}'.format(forecast_in_focus),
127-
'add_x': True})
144+
'comp_name': 'ar{}'.format(forecast_in_focus), })
145+
# 'add_x': True})
128146

129147
# Add Covariates
130148
if m.covar_config is not None:
@@ -136,8 +154,8 @@ def plot_components(m, fcst, forecast_in_focus=None, figsize=None):
136154
'bar': True})
137155
else:
138156
components.append({'plot_name': 'COV "{}" ({})-ahead'.format(name, forecast_in_focus),
139-
'comp_name': 'covar_{}{}'.format(name, forecast_in_focus),
140-
'add_x': True})
157+
'comp_name': 'covar_{}{}'.format(name, forecast_in_focus), })
158+
# 'add_x': True})
141159
# Add Events
142160
if 'events_additive' in fcst.columns:
143161
components.append({'plot_name': 'Additive Events',
@@ -156,7 +174,7 @@ def plot_components(m, fcst, forecast_in_focus=None, figsize=None):
156174
elif fcst['residual{}'.format(forecast_in_focus)].count() > 0:
157175
components.append({'plot_name': 'Residuals ({})-ahead'.format(forecast_in_focus),
158176
'comp_name': 'residual{}'.format(forecast_in_focus),
159-
'add_x': True})
177+
'bar': True})
160178

161179
npanel = len(components)
162180
figsize = figsize if figsize else (9, 3 * npanel)
@@ -199,14 +217,15 @@ def plot_forecast_component(fcst, comp_name, plot_name=None, ax=None, figsize=(1
199217
comp_name (str): Name of the component to plot.
200218
plot_name (str): Name of the plot Title.
201219
ax (matplotlib axis): matplotlib Axes to plot on.
202-
figsize (tuple): width, height in inches.
220+
figsize (tuple): width, height in inches. default: (10, 6)
203221
multiplicative (bool): set y axis as percentage
204222
bar (bool): make barplot
205223
rolling (int): rolling average underplot
206224
207225
Returns:
208226
a list of matplotlib artists
209227
"""
228+
fcst = fcst.fillna(value=np.nan)
210229
artists = []
211230
if not ax:
212231
fig = plt.figure(facecolor='w', figsize=figsize)
@@ -224,7 +243,7 @@ def plot_forecast_component(fcst, comp_name, plot_name=None, ax=None, figsize=(1
224243
artists += ax.bar(fcst_t, fcst[comp_name], width=1.00, color='#0072B2')
225244
else:
226245
artists += ax.plot(fcst_t, fcst[comp_name], ls='-', c='#0072B2')
227-
if add_x:
246+
if add_x or sum(fcst[comp_name].notna()) == 1:
228247
artists += ax.plot(fcst_t, fcst[comp_name], 'bx')
229248
# Specify formatting to workaround matplotlib issue #12925
230249
locator = AutoDateLocator(interval_multiples=False)
@@ -248,7 +267,7 @@ def plot_multiforecast_component(fcst, comp_name, plot_name=None, ax=None, figsi
248267
comp_name (str): Name of the component to plot.
249268
plot_name (str): Name of the plot Title.
250269
ax (matplotlib axis): matplotlib Axes to plot on.
251-
figsize (tuple): width, height in inches.
270+
figsize (tuple): width, height in inches. default: (10, 6)
252271
multiplicative (bool): set y axis as percentage
253272
bar (bool): make barplot
254273
focus (int): forecast number to portray in detail.
@@ -296,7 +315,7 @@ def plot_multiforecast_component(fcst, comp_name, plot_name=None, ax=None, figsi
296315
return artists
297316

298317

299-
def plot_parameters(m, forecast_in_focus=None, weekly_start=0, yearly_start=0, figsize=None,):
318+
def plot_parameters(m, forecast_in_focus=None, weekly_start=0, yearly_start=0, figsize=(10, 6)):
300319
"""Plot the parameters that the model is composed of, visually.
301320
302321
Args:
@@ -308,7 +327,7 @@ def plot_parameters(m, forecast_in_focus=None, weekly_start=0, yearly_start=0, f
308327
yearly_start (int): specifying the start day of the yearly seasonality plot.
309328
0 (default) starts the year on Jan 1.
310329
1 shifts by 1 day to Jan 2, and so on.
311-
figsize (tuple): width, height in inches.
330+
figsize (tuple): width, height in inches.default: (10, 6)
312331
313332
Returns:
314333
A matplotlib figure.
@@ -443,7 +462,7 @@ def plot_trend_change(m, ax=None, plot_name='Trend Change', figsize=(10, 6)):
443462
ax (matplotlib axis): matplotlib Axes to plot on.
444463
One will be created if this is not provided.
445464
plot_name (str): Name of the plot Title.
446-
figsize (tuple): width, height in inches.
465+
figsize (tuple): width, height in inches. default: (10, 6)
447466
448467
Returns:
449468
a list of matplotlib artists
@@ -515,7 +534,7 @@ def plot_scalar_weights(weights, plot_name, focus=None, ax=None, figsize=(10, 6)
515534
One will be created if this is not provided.
516535
focus (int): if provided, show weights for this forecast
517536
None (default) plot average
518-
figsize (tuple): width, height in inches.
537+
figsize (tuple): width, height in inches. default: (10, 6)
519538
Returns:
520539
a list of matplotlib artists
521540
"""
@@ -559,7 +578,7 @@ def plot_lagged_weights(weights, comp_name, focus=None, ax=None, figsize=(10, 6)
559578
None (default) sum over all forecasts and plot as relative percentage
560579
ax (matplotlib axis): matplotlib Axes to plot on.
561580
One will be created if this is not provided.
562-
figsize (tuple): width, height in inches.
581+
figsize (tuple): width, height in inches. default: (10, 6)
563582
Returns:
564583
a list of matplotlib artists
565584
"""
@@ -601,7 +620,7 @@ def plot_yearly(m, ax=None, yearly_start=0, figsize=(10, 6), comp_name='yearly')
601620
yearly_start (int): specifying the start day of the yearly seasonality plot.
602621
0 (default) starts the year on Jan 1.
603622
1 shifts by 1 day to Jan 2, and so on.
604-
figsize (tuple): width, height in inches.
623+
figsize (tuple): width, height in inches. default: (10, 6)
605624
comp_name (str): Name of seasonality component if previously changed from default 'yearly'.
606625
607626
Returns:
@@ -637,7 +656,7 @@ def plot_weekly(m, ax=None, weekly_start=0, figsize=(10, 6), comp_name='weekly')
637656
weekly_start (int): specifying the start day of the weekly seasonality plot.
638657
0 (default) starts the week on Sunday.
639658
1 shifts by 1 day to Monday, and so on.
640-
figsize (tuple): width, height in inches.
659+
figsize (tuple): width, height in inches. default: (10, 6)
641660
comp_name (str): Name of seasonality component if previously changed from default 'weekly'.
642661
643662
Returns:

0 commit comments

Comments
 (0)