@@ -106,14 +106,15 @@ def plot(fcst, ax=None, xlabel='ds', ylabel='y', highlight_forecast=None, line_p
106
106
return fig
107
107
108
108
109
- def plot_components (m , fcst , forecast_in_focus = None , figsize = ( 10 , 6 ) ):
109
+ def plot_components (m , fcst , forecast_in_focus = None , figsize = None ):
110
110
"""Plot the NeuralProphet forecast components.
111
111
112
112
Args:
113
113
m (NeuralProphet): fitted model.
114
114
fcst (pd.DataFrame): output of m.predict.
115
115
forecast_in_focus (int): n-th step ahead forecast AR-coefficients to plot
116
116
figsize (tuple): width, height in inches.
117
+ None (default): automatic (10, 3 * npanel)
117
118
118
119
Returns:
119
120
A matplotlib figure.
@@ -177,7 +178,7 @@ def plot_components(m, fcst, forecast_in_focus=None, figsize=(10, 6)):
177
178
'bar' : True })
178
179
179
180
npanel = len (components )
180
- figsize = figsize if figsize else (9 , 3 * npanel )
181
+ figsize = figsize if figsize else (10 , 3 * npanel )
181
182
fig , axes = plt .subplots (npanel , 1 , facecolor = 'w' , figsize = figsize )
182
183
if npanel == 1 :
183
184
axes = [axes ]
@@ -217,7 +218,8 @@ def plot_forecast_component(fcst, comp_name, plot_name=None, ax=None, figsize=(1
217
218
comp_name (str): Name of the component to plot.
218
219
plot_name (str): Name of the plot Title.
219
220
ax (matplotlib axis): matplotlib Axes to plot on.
220
- figsize (tuple): width, height in inches. default: (10, 6)
221
+ figsize (tuple): width, height in inches. Ignored if ax is not None.
222
+ default: (10, 6)
221
223
multiplicative (bool): set y axis as percentage
222
224
bar (bool): make barplot
223
225
rolling (int): rolling average underplot
@@ -267,7 +269,8 @@ def plot_multiforecast_component(fcst, comp_name, plot_name=None, ax=None, figsi
267
269
comp_name (str): Name of the component to plot.
268
270
plot_name (str): Name of the plot Title.
269
271
ax (matplotlib axis): matplotlib Axes to plot on.
270
- figsize (tuple): width, height in inches. default: (10, 6)
272
+ figsize (tuple): width, height in inches. Ignored if ax is not None.
273
+ default: (10, 6)
271
274
multiplicative (bool): set y axis as percentage
272
275
bar (bool): make barplot
273
276
focus (int): forecast number to portray in detail.
@@ -315,7 +318,7 @@ def plot_multiforecast_component(fcst, comp_name, plot_name=None, ax=None, figsi
315
318
return artists
316
319
317
320
318
- def plot_parameters (m , forecast_in_focus = None , weekly_start = 0 , yearly_start = 0 , figsize = ( 10 , 6 ) ):
321
+ def plot_parameters (m , forecast_in_focus = None , weekly_start = 0 , yearly_start = 0 , figsize = None ):
319
322
"""Plot the parameters that the model is composed of, visually.
320
323
321
324
Args:
@@ -327,7 +330,8 @@ def plot_parameters(m, forecast_in_focus=None, weekly_start=0, yearly_start=0, f
327
330
yearly_start (int): specifying the start day of the yearly seasonality plot.
328
331
0 (default) starts the year on Jan 1.
329
332
1 shifts by 1 day to Jan 2, and so on.
330
- figsize (tuple): width, height in inches.default: (10, 6)
333
+ figsize (tuple): width, height in inches.
334
+ None (default): automatic (10, 3 * npanel)
331
335
332
336
Returns:
333
337
A matplotlib figure.
@@ -412,7 +416,7 @@ def plot_parameters(m, forecast_in_focus=None, weekly_start=0, yearly_start=0, f
412
416
components .append ({'plot_name' : 'Multiplicative event' })
413
417
414
418
npanel = len (components )
415
- figsize = figsize if figsize else (9 , 3 * npanel )
419
+ figsize = figsize if figsize else (10 , 3 * npanel )
416
420
fig , axes = plt .subplots (npanel , 1 , facecolor = 'w' , figsize = figsize )
417
421
if npanel == 1 :
418
422
axes = [axes ]
@@ -462,7 +466,8 @@ def plot_trend_change(m, ax=None, plot_name='Trend Change', figsize=(10, 6)):
462
466
ax (matplotlib axis): matplotlib Axes to plot on.
463
467
One will be created if this is not provided.
464
468
plot_name (str): Name of the plot Title.
465
- figsize (tuple): width, height in inches. default: (10, 6)
469
+ figsize (tuple): width, height in inches. Ignored if ax is not None.
470
+ default: (10, 6)
466
471
467
472
Returns:
468
473
a list of matplotlib artists
@@ -490,7 +495,8 @@ def plot_trend(m, ax=None, plot_name='Trend', figsize=(10, 6)):
490
495
ax (matplotlib axis): matplotlib Axes to plot on.
491
496
One will be created if this is not provided.
492
497
plot_name (str): Name of the plot Title.
493
- figsize (tuple): width, height in inches.
498
+ figsize (tuple): width, height in inches. Ignored if ax is not None.
499
+ default: (10, 6)
494
500
495
501
Returns:
496
502
a list of matplotlib artists
@@ -534,7 +540,8 @@ def plot_scalar_weights(weights, plot_name, focus=None, ax=None, figsize=(10, 6)
534
540
One will be created if this is not provided.
535
541
focus (int): if provided, show weights for this forecast
536
542
None (default) plot average
537
- figsize (tuple): width, height in inches. default: (10, 6)
543
+ figsize (tuple): width, height in inches. Ignored if ax is not None.
544
+ default: (10, 6)
538
545
Returns:
539
546
a list of matplotlib artists
540
547
"""
@@ -560,7 +567,10 @@ def plot_scalar_weights(weights, plot_name, focus=None, ax=None, figsize=(10, 6)
560
567
artists += ax .bar (names , values , width = 0.8 , color = '#0072B2' )
561
568
ax .grid (True , which = 'major' , c = 'gray' , ls = '-' , lw = 1 , alpha = 0.2 )
562
569
ax .set_xlabel (plot_name + " name" )
563
- plt .xticks (rotation = 90 )
570
+ # only rotates last subplot!
571
+ # TODO fix
572
+ if len ("_" .join (names )) > 100 :
573
+ plt .xticks (rotation = 45 )
564
574
if focus is None :
565
575
ax .set_ylabel (plot_name + ' weight (avg)' )
566
576
else :
@@ -578,7 +588,8 @@ def plot_lagged_weights(weights, comp_name, focus=None, ax=None, figsize=(10, 6)
578
588
None (default) sum over all forecasts and plot as relative percentage
579
589
ax (matplotlib axis): matplotlib Axes to plot on.
580
590
One will be created if this is not provided.
581
- figsize (tuple): width, height in inches. default: (10, 6)
591
+ figsize (tuple): width, height in inches. Ignored if ax is not None.
592
+ default: (10, 6)
582
593
Returns:
583
594
a list of matplotlib artists
584
595
"""
@@ -606,7 +617,7 @@ def plot_lagged_weights(weights, comp_name, focus=None, ax=None, figsize=(10, 6)
606
617
return artists
607
618
608
619
609
- def plot_custom_season (m , ax = None , comp_name = None ):
620
+ def plot_custom_season ():
610
621
raise NotImplementedError
611
622
612
623
@@ -620,7 +631,8 @@ def plot_yearly(m, ax=None, yearly_start=0, figsize=(10, 6), comp_name='yearly')
620
631
yearly_start (int): specifying the start day of the yearly seasonality plot.
621
632
0 (default) starts the year on Jan 1.
622
633
1 shifts by 1 day to Jan 2, and so on.
623
- figsize (tuple): width, height in inches. default: (10, 6)
634
+ figsize (tuple): width, height in inches. Ignored if ax is not None.
635
+ default: (10, 6)
624
636
comp_name (str): Name of seasonality component if previously changed from default 'yearly'.
625
637
626
638
Returns:
@@ -656,7 +668,8 @@ def plot_weekly(m, ax=None, weekly_start=0, figsize=(10, 6), comp_name='weekly')
656
668
weekly_start (int): specifying the start day of the weekly seasonality plot.
657
669
0 (default) starts the week on Sunday.
658
670
1 shifts by 1 day to Monday, and so on.
659
- figsize (tuple): width, height in inches. default: (10, 6)
671
+ figsize (tuple): width, height in inches. Ignored if ax is not None.
672
+ default: (10, 6)
660
673
comp_name (str): Name of seasonality component if previously changed from default 'weekly'.
661
674
662
675
Returns:
@@ -680,3 +693,7 @@ def plot_weekly(m, ax=None, weekly_start=0, figsize=(10, 6), comp_name='weekly')
680
693
if m .season_config .mode == 'multiplicative' :
681
694
ax = set_y_as_percent (ax )
682
695
return artists
696
+
697
+
698
+ def plot_daily ():
699
+ raise NotImplementedError
0 commit comments