1
1
import numpy as np
2
2
import pandas as pd
3
+ import warnings
3
4
4
5
try :
5
6
from matplotlib import pyplot as plt
@@ -26,13 +27,18 @@ def set_y_as_percent(ax):
26
27
Returns:
27
28
ax
28
29
"""
29
- yticks = 100 * ax .get_yticks ()
30
- yticklabels = ['{0:.4g}%' .format (y ) for y in yticks ]
31
- ax .set_yticklabels (yticklabels )
32
- return ax
33
-
34
-
35
- def plot (fcst , ax = None , xlabel = 'ds' , ylabel = 'y' , highlight_forecast = None , figsize = (10 , 6 )):
30
+ warnings .filterwarnings ("error" )
31
+ try :
32
+ yticks = 100 * ax .get_yticks ()
33
+ yticklabels = ['{0:.4g}%' .format (y ) for y in yticks ]
34
+ ax .set_yticklabels (yticklabels )
35
+ except UserWarning :
36
+ pass # workaround until there is clear direction how to handle this recent matplotlib bug
37
+ finally :
38
+ return ax
39
+
40
+
41
+ def plot (fcst , ax = None , xlabel = 'ds' , ylabel = 'y' , highlight_forecast = None , line_per_origin = False , figsize = (10 , 6 )):
36
42
"""Plot the NeuralProphet forecast
37
43
38
44
Args:
@@ -46,34 +52,45 @@ def plot(fcst, ax=None, xlabel='ds', ylabel='y', highlight_forecast=None, figsiz
46
52
Returns:
47
53
A matplotlib figure.
48
54
"""
55
+ fcst = fcst .fillna (value = np .nan )
49
56
if ax is None :
50
57
fig = plt .figure (facecolor = 'w' , figsize = figsize )
51
58
ax = fig .add_subplot (111 )
52
59
else :
53
60
fig = ax .get_figure ()
54
61
ds = fcst ['ds' ].dt .to_pydatetime ()
55
62
yhat_col_names = [col_name for col_name in fcst .columns if 'yhat' in col_name ]
56
- for i in range (len (yhat_col_names )):
57
- ax .plot (ds , fcst ['yhat{}' .format (i + 1 )], ls = '-' , c = '#0072B2' , alpha = 0.2 + 2.0 / (i + 2.5 ))
58
- # Future Todo: use fill_between for all but highlight_forecast
59
- """
60
- col1 = 'yhat{}'.format(i+1)
61
- col2 = 'yhat{}'.format(i+2)
62
- no_na1 = fcst.copy()[col1].notnull().values
63
- no_na2 = fcst.copy()[col2].notnull().values
64
- no_na = [x1 and x2 for x1, x2 in zip(no_na1, no_na2)]
65
- fcst_na = fcst.copy()[no_na]
66
- fcst_na_t = fcst_na['ds'].dt.to_pydatetime()
67
- ax.fill_between(
68
- fcst_na_t,
69
- fcst_na[col1],
70
- fcst_na[col2],
71
- color='#0072B2', alpha=1.0/(i+1)
72
- )
73
- """
63
+
64
+ if highlight_forecast is None or line_per_origin :
65
+ for i in range (len (yhat_col_names )):
66
+ ax .plot (ds , fcst ['yhat{}' .format (i + 1 )], ls = '-' , c = '#0072B2' , alpha = 0.2 + 2.0 / (i + 2.5 ))
67
+ # Future Todo: use fill_between for all but highlight_forecast
68
+ """
69
+ col1 = 'yhat{}'.format(i+1)
70
+ col2 = 'yhat{}'.format(i+2)
71
+ no_na1 = fcst.copy()[col1].notnull().values
72
+ no_na2 = fcst.copy()[col2].notnull().values
73
+ no_na = [x1 and x2 for x1, x2 in zip(no_na1, no_na2)]
74
+ fcst_na = fcst.copy()[no_na]
75
+ fcst_na_t = fcst_na['ds'].dt.to_pydatetime()
76
+ ax.fill_between(
77
+ fcst_na_t,
78
+ fcst_na[col1],
79
+ fcst_na[col2],
80
+ color='#0072B2', alpha=1.0/(i+1)
81
+ )
82
+ """
74
83
if highlight_forecast is not None :
75
- ax .plot (ds , fcst ['yhat{}' .format (highlight_forecast )], ls = '-' , c = 'b' )
76
- ax .plot (ds , fcst ['yhat{}' .format (highlight_forecast )], 'bx' )
84
+ if line_per_origin :
85
+ num_forecast_steps = sum (fcst ['yhat1' ].notna ())
86
+ steps_from_last = num_forecast_steps - highlight_forecast
87
+ for i in range (len (yhat_col_names )):
88
+ x = ds [- (1 + i + steps_from_last )]
89
+ y = fcst ['yhat{}' .format (i + 1 )].values [- (1 + i + steps_from_last )]
90
+ ax .plot (x , y , 'bx' )
91
+ else :
92
+ ax .plot (ds , fcst ['yhat{}' .format (highlight_forecast )], ls = '-' , c = 'b' )
93
+ ax .plot (ds , fcst ['yhat{}' .format (highlight_forecast )], 'bx' )
77
94
78
95
ax .plot (ds , fcst ['y' ], 'k.' )
79
96
@@ -89,7 +106,7 @@ def plot(fcst, ax=None, xlabel='ds', ylabel='y', highlight_forecast=None, figsiz
89
106
return fig
90
107
91
108
92
- def plot_components (m , fcst , forecast_in_focus = None , figsize = None ):
109
+ def plot_components (m , fcst , forecast_in_focus = None , figsize = ( 10 , 6 ) ):
93
110
"""Plot the NeuralProphet forecast components.
94
111
95
112
Args:
@@ -101,6 +118,7 @@ def plot_components(m, fcst, forecast_in_focus=None, figsize=None):
101
118
Returns:
102
119
A matplotlib figure.
103
120
"""
121
+ fcst = fcst .fillna (value = np .nan )
104
122
# Identify components to be plotted
105
123
# as dict, minimum: {plot_name, comp_name}
106
124
components = [{'plot_name' : 'Trend' ,
@@ -123,8 +141,8 @@ def plot_components(m, fcst, forecast_in_focus=None, figsize=None):
123
141
'bar' : True })
124
142
else :
125
143
components .append ({'plot_name' : 'AR ({})-ahead' .format (forecast_in_focus ),
126
- 'comp_name' : 'ar{}' .format (forecast_in_focus ),
127
- 'add_x' : True })
144
+ 'comp_name' : 'ar{}' .format (forecast_in_focus ), })
145
+ # 'add_x': True})
128
146
129
147
# Add Covariates
130
148
if m .covar_config is not None :
@@ -136,8 +154,8 @@ def plot_components(m, fcst, forecast_in_focus=None, figsize=None):
136
154
'bar' : True })
137
155
else :
138
156
components .append ({'plot_name' : 'COV "{}" ({})-ahead' .format (name , forecast_in_focus ),
139
- 'comp_name' : 'covar_{}{}' .format (name , forecast_in_focus ),
140
- 'add_x' : True })
157
+ 'comp_name' : 'covar_{}{}' .format (name , forecast_in_focus ), })
158
+ # 'add_x': True})
141
159
# Add Events
142
160
if 'events_additive' in fcst .columns :
143
161
components .append ({'plot_name' : 'Additive Events' ,
@@ -156,7 +174,7 @@ def plot_components(m, fcst, forecast_in_focus=None, figsize=None):
156
174
elif fcst ['residual{}' .format (forecast_in_focus )].count () > 0 :
157
175
components .append ({'plot_name' : 'Residuals ({})-ahead' .format (forecast_in_focus ),
158
176
'comp_name' : 'residual{}' .format (forecast_in_focus ),
159
- 'add_x ' : True })
177
+ 'bar ' : True })
160
178
161
179
npanel = len (components )
162
180
figsize = figsize if figsize else (9 , 3 * npanel )
@@ -199,14 +217,15 @@ def plot_forecast_component(fcst, comp_name, plot_name=None, ax=None, figsize=(1
199
217
comp_name (str): Name of the component to plot.
200
218
plot_name (str): Name of the plot Title.
201
219
ax (matplotlib axis): matplotlib Axes to plot on.
202
- figsize (tuple): width, height in inches.
220
+ figsize (tuple): width, height in inches. default: (10, 6)
203
221
multiplicative (bool): set y axis as percentage
204
222
bar (bool): make barplot
205
223
rolling (int): rolling average underplot
206
224
207
225
Returns:
208
226
a list of matplotlib artists
209
227
"""
228
+ fcst = fcst .fillna (value = np .nan )
210
229
artists = []
211
230
if not ax :
212
231
fig = plt .figure (facecolor = 'w' , figsize = figsize )
@@ -224,7 +243,7 @@ def plot_forecast_component(fcst, comp_name, plot_name=None, ax=None, figsize=(1
224
243
artists += ax .bar (fcst_t , fcst [comp_name ], width = 1.00 , color = '#0072B2' )
225
244
else :
226
245
artists += ax .plot (fcst_t , fcst [comp_name ], ls = '-' , c = '#0072B2' )
227
- if add_x :
246
+ if add_x or sum ( fcst [ comp_name ]. notna ()) == 1 :
228
247
artists += ax .plot (fcst_t , fcst [comp_name ], 'bx' )
229
248
# Specify formatting to workaround matplotlib issue #12925
230
249
locator = AutoDateLocator (interval_multiples = False )
@@ -248,7 +267,7 @@ def plot_multiforecast_component(fcst, comp_name, plot_name=None, ax=None, figsi
248
267
comp_name (str): Name of the component to plot.
249
268
plot_name (str): Name of the plot Title.
250
269
ax (matplotlib axis): matplotlib Axes to plot on.
251
- figsize (tuple): width, height in inches.
270
+ figsize (tuple): width, height in inches. default: (10, 6)
252
271
multiplicative (bool): set y axis as percentage
253
272
bar (bool): make barplot
254
273
focus (int): forecast number to portray in detail.
@@ -296,7 +315,7 @@ def plot_multiforecast_component(fcst, comp_name, plot_name=None, ax=None, figsi
296
315
return artists
297
316
298
317
299
- def plot_parameters (m , forecast_in_focus = None , weekly_start = 0 , yearly_start = 0 , figsize = None , ):
318
+ def plot_parameters (m , forecast_in_focus = None , weekly_start = 0 , yearly_start = 0 , figsize = ( 10 , 6 ) ):
300
319
"""Plot the parameters that the model is composed of, visually.
301
320
302
321
Args:
@@ -308,7 +327,7 @@ def plot_parameters(m, forecast_in_focus=None, weekly_start=0, yearly_start=0, f
308
327
yearly_start (int): specifying the start day of the yearly seasonality plot.
309
328
0 (default) starts the year on Jan 1.
310
329
1 shifts by 1 day to Jan 2, and so on.
311
- figsize (tuple): width, height in inches.
330
+ figsize (tuple): width, height in inches.default: (10, 6)
312
331
313
332
Returns:
314
333
A matplotlib figure.
@@ -443,7 +462,7 @@ def plot_trend_change(m, ax=None, plot_name='Trend Change', figsize=(10, 6)):
443
462
ax (matplotlib axis): matplotlib Axes to plot on.
444
463
One will be created if this is not provided.
445
464
plot_name (str): Name of the plot Title.
446
- figsize (tuple): width, height in inches.
465
+ figsize (tuple): width, height in inches. default: (10, 6)
447
466
448
467
Returns:
449
468
a list of matplotlib artists
@@ -515,7 +534,7 @@ def plot_scalar_weights(weights, plot_name, focus=None, ax=None, figsize=(10, 6)
515
534
One will be created if this is not provided.
516
535
focus (int): if provided, show weights for this forecast
517
536
None (default) plot average
518
- figsize (tuple): width, height in inches.
537
+ figsize (tuple): width, height in inches. default: (10, 6)
519
538
Returns:
520
539
a list of matplotlib artists
521
540
"""
@@ -559,7 +578,7 @@ def plot_lagged_weights(weights, comp_name, focus=None, ax=None, figsize=(10, 6)
559
578
None (default) sum over all forecasts and plot as relative percentage
560
579
ax (matplotlib axis): matplotlib Axes to plot on.
561
580
One will be created if this is not provided.
562
- figsize (tuple): width, height in inches.
581
+ figsize (tuple): width, height in inches. default: (10, 6)
563
582
Returns:
564
583
a list of matplotlib artists
565
584
"""
@@ -601,7 +620,7 @@ def plot_yearly(m, ax=None, yearly_start=0, figsize=(10, 6), comp_name='yearly')
601
620
yearly_start (int): specifying the start day of the yearly seasonality plot.
602
621
0 (default) starts the year on Jan 1.
603
622
1 shifts by 1 day to Jan 2, and so on.
604
- figsize (tuple): width, height in inches.
623
+ figsize (tuple): width, height in inches. default: (10, 6)
605
624
comp_name (str): Name of seasonality component if previously changed from default 'yearly'.
606
625
607
626
Returns:
@@ -637,7 +656,7 @@ def plot_weekly(m, ax=None, weekly_start=0, figsize=(10, 6), comp_name='weekly')
637
656
weekly_start (int): specifying the start day of the weekly seasonality plot.
638
657
0 (default) starts the week on Sunday.
639
658
1 shifts by 1 day to Monday, and so on.
640
- figsize (tuple): width, height in inches.
659
+ figsize (tuple): width, height in inches. default: (10, 6)
641
660
comp_name (str): Name of seasonality component if previously changed from default 'weekly'.
642
661
643
662
Returns:
0 commit comments