Skip to content

Commit d90aadf

Browse files
authored
Merge pull request #66 from ourownstory/events_regularization
regularization for events
2 parents 9b7854c + 8315ba5 commit d90aadf

File tree

6 files changed

+102
-51
lines changed

6 files changed

+102
-51
lines changed

neuralprophet/neural_prophet.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,11 @@ def _add_batch_regualarizations(self, loss, reg_lambda_ar):
460460
reg_loss += l_season * reg_season
461461
loss += l_season * reg_season
462462

463-
# Regularize holidays: sparsify holiday features coefficients
463+
# Regularize events: sparsify events features coefficients
464464
if self.events_config is not None or self.country_holidays_config is not None:
465-
pass
465+
reg_events_loss = utils.reg_func_events(self.events_config, self.country_holidays_config, self.model)
466+
reg_loss += reg_events_loss
467+
loss += reg_events_loss
466468

467469
return loss, reg_loss
468470

@@ -1059,12 +1061,13 @@ def plot_last_forecast(self, fcst, ax=None, xlabel='ds', ylabel='y', figsize=(10
10591061
highlight_forecast=self.forecast_in_focus, line_per_origin=True,
10601062
)
10611063

1062-
def plot_components(self, fcst, figsize=(10, 6)):
1064+
def plot_components(self, fcst, figsize=None):
10631065
"""Plot the NeuralProphet forecast components.
10641066
10651067
Args:
10661068
fcst (pd.DataFrame): output of self.predict
1067-
figsize (tuple): width, height in inches. default: (10, 6)
1069+
figsize (tuple): width, height in inches.
1070+
None (default): automatic (10, 3 * npanel)
10681071
crop_last_n (int): number of samples to plot (combined future and past)
10691072
None (default) includes entire history. ignored for seasonality.
10701073
Returns:
@@ -1077,15 +1080,16 @@ def plot_components(self, fcst, figsize=(10, 6)):
10771080
forecast_in_focus=self.forecast_in_focus,
10781081
)
10791082

1080-
def plot_parameters(self, weekly_start=0, yearly_start=0, figsize=(10, 6)):
1083+
def plot_parameters(self, weekly_start=0, yearly_start=0, figsize=None):
10811084
"""Plot the NeuralProphet forecast components.
10821085
10831086
Args:
10841087
weekly_start (int): specifying the start day of the weekly seasonality plot.
10851088
0 (default) starts the week on Sunday. 1 shifts by 1 day to Monday, and so on.
10861089
yearly_start (int): specifying the start day of the yearly seasonality plot.
10871090
0 (default) starts the year on Jan 1. 1 shifts by 1 day to Jan 2, and so on.
1088-
figsize (tuple): width, height in inches. default: (10, 6)
1091+
figsize (tuple): width, height in inches.
1092+
None (default): automatic (10, 3 * npanel)
10891093
Returns:
10901094
A matplotlib figure.
10911095
"""

neuralprophet/plotting_utils.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,15 @@ def plot(fcst, ax=None, xlabel='ds', ylabel='y', highlight_forecast=None, line_p
106106
return fig
107107

108108

109-
def plot_components(m, fcst, forecast_in_focus=None, figsize=(10, 6)):
109+
def plot_components(m, fcst, forecast_in_focus=None, figsize=None):
110110
"""Plot the NeuralProphet forecast components.
111111
112112
Args:
113113
m (NeuralProphet): fitted model.
114114
fcst (pd.DataFrame): output of m.predict.
115115
forecast_in_focus (int): n-th step ahead forecast AR-coefficients to plot
116116
figsize (tuple): width, height in inches.
117+
None (default): automatic (10, 3 * npanel)
117118
118119
Returns:
119120
A matplotlib figure.
@@ -177,7 +178,7 @@ def plot_components(m, fcst, forecast_in_focus=None, figsize=(10, 6)):
177178
'bar': True})
178179

179180
npanel = len(components)
180-
figsize = figsize if figsize else (9, 3 * npanel)
181+
figsize = figsize if figsize else (10, 3 * npanel)
181182
fig, axes = plt.subplots(npanel, 1, facecolor='w', figsize=figsize)
182183
if npanel == 1:
183184
axes = [axes]
@@ -217,7 +218,8 @@ def plot_forecast_component(fcst, comp_name, plot_name=None, ax=None, figsize=(1
217218
comp_name (str): Name of the component to plot.
218219
plot_name (str): Name of the plot Title.
219220
ax (matplotlib axis): matplotlib Axes to plot on.
220-
figsize (tuple): width, height in inches. default: (10, 6)
221+
figsize (tuple): width, height in inches. Ignored if ax is not None.
222+
default: (10, 6)
221223
multiplicative (bool): set y axis as percentage
222224
bar (bool): make barplot
223225
rolling (int): rolling average underplot
@@ -267,7 +269,8 @@ def plot_multiforecast_component(fcst, comp_name, plot_name=None, ax=None, figsi
267269
comp_name (str): Name of the component to plot.
268270
plot_name (str): Name of the plot Title.
269271
ax (matplotlib axis): matplotlib Axes to plot on.
270-
figsize (tuple): width, height in inches. default: (10, 6)
272+
figsize (tuple): width, height in inches. Ignored if ax is not None.
273+
default: (10, 6)
271274
multiplicative (bool): set y axis as percentage
272275
bar (bool): make barplot
273276
focus (int): forecast number to portray in detail.
@@ -315,7 +318,7 @@ def plot_multiforecast_component(fcst, comp_name, plot_name=None, ax=None, figsi
315318
return artists
316319

317320

318-
def plot_parameters(m, forecast_in_focus=None, weekly_start=0, yearly_start=0, figsize=(10, 6)):
321+
def plot_parameters(m, forecast_in_focus=None, weekly_start=0, yearly_start=0, figsize=None):
319322
"""Plot the parameters that the model is composed of, visually.
320323
321324
Args:
@@ -327,7 +330,8 @@ def plot_parameters(m, forecast_in_focus=None, weekly_start=0, yearly_start=0, f
327330
yearly_start (int): specifying the start day of the yearly seasonality plot.
328331
0 (default) starts the year on Jan 1.
329332
1 shifts by 1 day to Jan 2, and so on.
330-
figsize (tuple): width, height in inches.default: (10, 6)
333+
figsize (tuple): width, height in inches.
334+
None (default): automatic (10, 3 * npanel)
331335
332336
Returns:
333337
A matplotlib figure.
@@ -412,7 +416,7 @@ def plot_parameters(m, forecast_in_focus=None, weekly_start=0, yearly_start=0, f
412416
components.append({'plot_name': 'Multiplicative event'})
413417

414418
npanel = len(components)
415-
figsize = figsize if figsize else (9, 3 * npanel)
419+
figsize = figsize if figsize else (10, 3 * npanel)
416420
fig, axes = plt.subplots(npanel, 1, facecolor='w', figsize=figsize)
417421
if npanel == 1:
418422
axes = [axes]
@@ -462,7 +466,8 @@ def plot_trend_change(m, ax=None, plot_name='Trend Change', figsize=(10, 6)):
462466
ax (matplotlib axis): matplotlib Axes to plot on.
463467
One will be created if this is not provided.
464468
plot_name (str): Name of the plot Title.
465-
figsize (tuple): width, height in inches. default: (10, 6)
469+
figsize (tuple): width, height in inches. Ignored if ax is not None.
470+
default: (10, 6)
466471
467472
Returns:
468473
a list of matplotlib artists
@@ -490,7 +495,8 @@ def plot_trend(m, ax=None, plot_name='Trend', figsize=(10, 6)):
490495
ax (matplotlib axis): matplotlib Axes to plot on.
491496
One will be created if this is not provided.
492497
plot_name (str): Name of the plot Title.
493-
figsize (tuple): width, height in inches.
498+
figsize (tuple): width, height in inches. Ignored if ax is not None.
499+
default: (10, 6)
494500
495501
Returns:
496502
a list of matplotlib artists
@@ -534,7 +540,8 @@ def plot_scalar_weights(weights, plot_name, focus=None, ax=None, figsize=(10, 6)
534540
One will be created if this is not provided.
535541
focus (int): if provided, show weights for this forecast
536542
None (default) plot average
537-
figsize (tuple): width, height in inches. default: (10, 6)
543+
figsize (tuple): width, height in inches. Ignored if ax is not None.
544+
default: (10, 6)
538545
Returns:
539546
a list of matplotlib artists
540547
"""
@@ -560,7 +567,10 @@ def plot_scalar_weights(weights, plot_name, focus=None, ax=None, figsize=(10, 6)
560567
artists += ax.bar(names, values, width=0.8, color='#0072B2')
561568
ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
562569
ax.set_xlabel(plot_name + " name")
563-
plt.xticks(rotation=90)
570+
# only rotates last subplot!
571+
# TODO fix
572+
if len("_".join(names)) > 100:
573+
plt.xticks(rotation=45)
564574
if focus is None:
565575
ax.set_ylabel(plot_name + ' weight (avg)')
566576
else:
@@ -578,7 +588,8 @@ def plot_lagged_weights(weights, comp_name, focus=None, ax=None, figsize=(10, 6)
578588
None (default) sum over all forecasts and plot as relative percentage
579589
ax (matplotlib axis): matplotlib Axes to plot on.
580590
One will be created if this is not provided.
581-
figsize (tuple): width, height in inches. default: (10, 6)
591+
figsize (tuple): width, height in inches. Ignored if ax is not None.
592+
default: (10, 6)
582593
Returns:
583594
a list of matplotlib artists
584595
"""
@@ -606,7 +617,7 @@ def plot_lagged_weights(weights, comp_name, focus=None, ax=None, figsize=(10, 6)
606617
return artists
607618

608619

609-
def plot_custom_season(m, ax=None, comp_name=None):
620+
def plot_custom_season():
610621
raise NotImplementedError
611622

612623

@@ -620,7 +631,8 @@ def plot_yearly(m, ax=None, yearly_start=0, figsize=(10, 6), comp_name='yearly')
620631
yearly_start (int): specifying the start day of the yearly seasonality plot.
621632
0 (default) starts the year on Jan 1.
622633
1 shifts by 1 day to Jan 2, and so on.
623-
figsize (tuple): width, height in inches. default: (10, 6)
634+
figsize (tuple): width, height in inches. Ignored if ax is not None.
635+
default: (10, 6)
624636
comp_name (str): Name of seasonality component if previously changed from default 'yearly'.
625637
626638
Returns:
@@ -656,7 +668,8 @@ def plot_weekly(m, ax=None, weekly_start=0, figsize=(10, 6), comp_name='weekly')
656668
weekly_start (int): specifying the start day of the weekly seasonality plot.
657669
0 (default) starts the week on Sunday.
658670
1 shifts by 1 day to Monday, and so on.
659-
figsize (tuple): width, height in inches. default: (10, 6)
671+
figsize (tuple): width, height in inches. Ignored if ax is not None.
672+
default: (10, 6)
660673
comp_name (str): Name of seasonality component if previously changed from default 'weekly'.
661674
662675
Returns:
@@ -680,3 +693,7 @@ def plot_weekly(m, ax=None, weekly_start=0, figsize=(10, 6), comp_name='weekly')
680693
if m.season_config.mode == 'multiplicative':
681694
ax = set_y_as_percent(ax)
682695
return artists
696+
697+
698+
def plot_daily():
699+
raise NotImplementedError

neuralprophet/test_debug.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def test_lag_reg(verbose=True):
137137
plt.show()
138138

139139

140-
def test_holidays(verbose=True):
140+
def test_events(verbose=True):
141141
df = pd.read_csv('../data/example_wp_log_peyton_manning.csv')
142142
playoffs = pd.DataFrame({
143143
'event': 'playoff',
@@ -162,10 +162,10 @@ def test_holidays(verbose=True):
162162
daily_seasonality=False
163163
)
164164
# set event windows
165-
m = m.add_events(["superbowl", "playoff"], lower_window=-1, upper_window=1, mode="additive")
165+
m = m.add_events(["superbowl", "playoff"], lower_window=-1, upper_window=1, mode="multiplicative", regularization=0.5)
166166

167167
# add the country specific holidays
168-
m = m.add_country_holidays("US", mode="multiplicative")
168+
m = m.add_country_holidays("US", mode="additive", regularization=0.5)
169169

170170
history_df = m.create_df_with_events(df, events_df)
171171
m.fit(history_df)
@@ -234,7 +234,7 @@ def test_all(verbose=False):
234234
test_ar_net(verbose)
235235
test_seasons(verbose)
236236
test_lag_reg(verbose)
237-
test_holidays(verbose)
237+
test_events(verbose)
238238
test_predict(verbose)
239239

240240

@@ -251,9 +251,9 @@ def test_all(verbose=False):
251251
# test_ar_net()
252252
# test_seasons()
253253
# test_lag_reg()
254-
# test_holidays()
254+
test_events()
255255
# test_predict()
256-
test_plot()
256+
# test_plot()
257257

258258
# test cases: predict (on fitting data, on future data, on completely new data), train_eval, test function, get_last_forecasts, plotting
259259

neuralprophet/time_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,17 +194,17 @@ def _stride_lagged_features(df_col_name, feature_dims):
194194
events = OrderedDict({})
195195
if n_lags == 0:
196196
if additive_events is not None:
197-
events["additive_events"] = np.expand_dims(additive_events, axis=1)
197+
events["additive"] = np.expand_dims(additive_events, axis=1)
198198
if multiplicative_events is not None:
199-
events["multiplicative_events"] = np.expand_dims(multiplicative_events, axis=1)
199+
events["multiplicative"] = np.expand_dims(multiplicative_events, axis=1)
200200
else:
201201
if additive_events is not None:
202202
additive_event_feature_windows = []
203203
for i in range(0, additive_events.shape[1]):
204204
# stride into num_forecast at dim=1 for each sample, just like we did with time
205205
additive_event_feature_windows.append(_stride_time_features_for_forecasts(additive_events[:, i]))
206206
additive_events = np.dstack(additive_event_feature_windows)
207-
events["additive_events"] = additive_events
207+
events["additive"] = additive_events
208208

209209
if multiplicative_events is not None:
210210
multiplicative_event_feature_windows = []
@@ -213,7 +213,7 @@ def _stride_lagged_features(df_col_name, feature_dims):
213213
multiplicative_event_feature_windows.append(
214214
_stride_time_features_for_forecasts(multiplicative_events[:, i]))
215215
multiplicative_events = np.dstack(multiplicative_event_feature_windows)
216-
events["multiplicative_events"] = multiplicative_events
216+
events["multiplicative"] = multiplicative_events
217217

218218
inputs["events"] = events
219219

neuralprophet/time_net.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ def __init__(self,
118118
else:
119119
n_multiplicative_event_params += len(configs['event_indices'])
120120

121-
self.event_params["additive_event_params"] = new_param(dims=[n_additive_event_params])
122-
self.event_params["multiplicative_event_params"] = new_param(dims=[n_multiplicative_event_params])
121+
self.event_params["additive"] = new_param(dims=[n_additive_event_params])
122+
self.event_params["multiplicative"] = new_param(dims=[n_multiplicative_event_params])
123123
else:
124124
self.event_params = None
125125

@@ -203,9 +203,9 @@ def get_event_weights(self, name):
203203
mode = event_dims["mode"]
204204

205205
if mode == "additive":
206-
event_params = self.event_params["additive_event_params"]
206+
event_params = self.event_params["additive"]
207207
if mode == "multiplicative":
208-
event_params = self.event_params["multiplicative_event_params"]
208+
event_params = self.event_params["multiplicative"]
209209

210210
event_param_dict = OrderedDict({})
211211
for event_delim, indices in zip(event_dims["event_delim"], event_dims["event_indices"]):
@@ -408,12 +408,12 @@ def forward(self, inputs):
408408
# else: assert self.season_dims is None
409409

410410
if 'events' in inputs:
411-
if "additive_events" in inputs["events"].keys():
411+
if "additive" in inputs["events"].keys():
412412
additive_components += self.event_effects(
413-
inputs["events"]["additive_events"], self.event_params["additive_event_params"])
414-
if "multiplicative_events" in inputs["events"].keys():
413+
inputs["events"]["additive"], self.event_params["additive"])
414+
if "multiplicative" in inputs["events"].keys():
415415
multiplicative_components += self.event_effects(
416-
inputs["events"]["multiplicative_events"], self.event_params["multiplicative_event_params"])
416+
inputs["events"]["multiplicative"], self.event_params["multiplicative"])
417417

418418
out = trend + trend * multiplicative_components + additive_components
419419

@@ -452,21 +452,21 @@ def compute_components(self, inputs):
452452
for name, lags in inputs['covariates'].items():
453453
components['covar_{}'.format(name)] = self.covariate(lags=lags, name=name)
454454
if "events" in inputs:
455-
if 'additive_events' in inputs["events"].keys():
456-
components['events_additive'] = self.event_effects(features=inputs["events"]["additive_events"],
457-
params=self.event_params["additive_event_params"])
458-
if 'multiplicative_events' in inputs["events"].keys():
459-
components['events_multiplicative'] = self.event_effects(features=inputs["events"]["multiplicative_events"],
460-
params=self.event_params["multiplicative_event_params"])
455+
if 'additive' in inputs["events"].keys():
456+
components['events_additive'] = self.event_effects(features=inputs["events"]["additive"],
457+
params=self.event_params["additive"])
458+
if 'multiplicative' in inputs["events"].keys():
459+
components['events_multiplicative'] = self.event_effects(features=inputs["events"]["multiplicative"],
460+
params=self.event_params["multiplicative"])
461461
for event, configs in self.events_dims.items():
462462
mode = configs["mode"]
463463
indices = configs["event_indices"]
464464
if mode == "additive":
465-
features = inputs["events"]["additive_events"]
466-
params = self.event_params["additive_event_params"]
465+
features = inputs["events"]["additive"]
466+
params = self.event_params["additive"]
467467
else:
468-
features = inputs["events"]["multiplicative_events"]
469-
params = self.event_params["multiplicative_event_params"]
468+
features = inputs["events"]["multiplicative"]
469+
params = self.event_params["multiplicative"]
470470
components['event_{}'.format(event)] = self.event_effects(features=features, params=params, indices=indices)
471471
return components
472472

neuralprophet/utils.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,38 @@ def reg_func_season(weights):
7373
return reg_func_abs(weights)
7474

7575

76-
def reg_func_holidays(weights):
77-
return reg_func_abs(weights)
76+
def reg_func_events(events_config, country_holidays_config, model):
77+
"""
78+
Regularization of events coefficients to induce sparcity
79+
80+
Args:
81+
events_config (OrderedDict): Configurations (upper, lower windows, regularization) for user specified events
82+
country_holidays_config (OrderedDict): Configurations (holiday_names, upper, lower windows, regularization)
83+
for country specific holidays
84+
model (TimeNet): The TimeNet model object
85+
86+
Returns:
87+
regularization loss, scalar
88+
"""
89+
reg_events_loss = 0.0
90+
if events_config is not None:
91+
for event, configs in events_config.items():
92+
reg_lambda = configs["reg_lambda"]
93+
if reg_lambda is not None:
94+
weights = model.get_event_weights(event)
95+
for offset in weights.keys():
96+
reg_events_loss += reg_lambda * reg_func_abs(weights[offset])
97+
98+
if country_holidays_config is not None:
99+
reg_lambda = country_holidays_config["reg_lambda"]
100+
if reg_lambda is not None:
101+
for holiday in country_holidays_config["holiday_names"]:
102+
weights = model.get_event_weights(holiday)
103+
for offset in weights.keys():
104+
reg_events_loss += reg_lambda * reg_func_abs(weights[offset])
105+
106+
return reg_events_loss
107+
78108

79109

80110
def symmetric_total_percentage_error(values, estimates):

0 commit comments

Comments
 (0)