Skip to content

Commit 90f6e6e

Browse files
authored
fixed matplotlib bug (#1395)
1 parent 89e3433 commit 90f6e6e

File tree

3 files changed

+155
-142
lines changed

3 files changed

+155
-142
lines changed

neuralprophet/plot_model_parameters_matplotlib.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,12 @@ def plot_trend(m, quantile, ax=None, plot_name="Trend", figsize=(10, 6), df_name
265265
trend_1 = trend_0
266266
else:
267267
if m.model.config_trend.trend_global_local == "local":
268-
trend_1 = trend_0 + m.model.trend.trend_k0[quantile_index, m.model.id_dict[df_name]].detach().numpy()
268+
trend_1 = (
269+
trend_0
270+
+ m.model.trend.trend_k0[quantile_index, m.model.id_dict[df_name]].detach().numpy().squeeze()
271+
)
269272
else:
270-
trend_1 = trend_0 + m.model.trend.trend_k0[quantile_index, 0].detach().numpy()
273+
trend_1 = trend_0 + m.model.trend.trend_k0[quantile_index, 0].detach().numpy().squeeze()
271274

272275
data_params = m.config_normalization.get_data_params(df_name)
273276
shift = data_params["y"].shift

0 commit comments

Comments
 (0)