
Commit 1bfa633

Authored by ourownstory, SimonWittner and MaiBe-ctrl
[Major] Dataloader: Just-In-Time tabularization (#1529)
* minimal pytest
* move_func_getitem
* slicing
* predict_mode
* typos
* lr-finder
* drop_missing
* predict_v2
* predict_v3
* samples
* lagged regressor n_lags
* preliminary: events, holidays
* adjusted pytests
* selective forecasting
* black
* ruff
* lagged_regressors
* Note down df path to TimeDataset
* complete notes on TimeDataset, move meta
* Big rewrite with real and pseudocode
* create_target_start_end_mask
* boolean mask
* combine masks into map
* notes for nan check
* bypass NaN filter
* rework index to point at prediction origin, not first forecast
* tabularize: converted time and lags to single sample extraction
* convert lagged regressors
* consolidate seasonality computation in one script
* finish Seasonality conversion
* update todos
* complete targets and future regressors
* convert events
* finish events and holidays conversion
* debug timedataset
* debugging
* make_country_specific_holidays_df
* remove uses of df.loc[...].values
* debug time
* debugging types
* debug timedata
* debugging time_dataset variable shapes
* address indexing and slicing issues, .loc
* fix dimensions except nonstationary components
* integrate torch formatting into tabularize
* check shapes
* AirPassengers test working!
* fix dataset generator
* fixed all performance tests but Energy due to nonstationary components
* fixed nonstationary issue, all performance tests running
* refactor tabularize function
* fix bug
* initial build of GlobalTimeDataset
* refactor TimeDataset not to use kwargs passthrough
* debugged seasonal components call of TimeDataset
* fix numpy object type error
* fix seasonality condition bugs
* fix events and future regressor cases
* fixing prediction frequency filter
* performance_test_energy
* debug events
* convert new energy test to daily data
* fix events util reference
* fix test_get_country_holidays
* fix test_timedataset_minima
* fix selective forecasting
* cleanup timedataset
* refactor tabularize_univariate
* daily_data
* start nan check for sample mask
* working on time nan2
* fix tests
* finish nan-check
* fix dims
* pass self.df to indexing
* fix zero dim lagged regressors
* close figures in tests
* fix typings
* black
* ruff
* linting
* linting
* modify logs
* add benchmarking script for computational time
* speed up uncertainty tests
* fix unit test for multiple countries
* reduce tests log level to ERROR
* reduce log level to ERROR and fix adding multiple countries
* bypass intentional glocal test error log
* fix prev
* benchmark dataloader time
* remove hourly energy test
* add debug notebook for energy hourly
* set model performance logging to INFO
* address config_regressors.regressors
* clean up create_nan_mask
* clean up create_nan_mask params
* clean TimeDataframe
* update prediction frequency documentation
* improve prediction frequency documentation
* further improve prediction frequency documentation
* fix test errors
* fix df_names call
* fix selective prediction assertion
* normalize holiday names
* fix linting
* fix tests
* update to use new holiday functions in event_utils.py
* fix seasonality_local_reg test
* limit holidays to less than 1.0
* changed holidays
* update lock
* changed tests
* adjusted tests
* fix reserved names
* fixed ruff linting
* changed test
* translate holidays to English if possible
* exclude py3.13
* update lock
* Merge all holidays related tests in one file
* add deterministic flag
* fixed ruff linting issues
* fixed glocal test
* fix lock file
* update poetry
* moved the deterministic flag to the train method
* update lock file
---------
Co-authored-by: Simon W <simon.wittner@gmx.net>
Co-authored-by: MaiBe-ctrl <maiisabensalah@gmail.com>
Co-authored-by: Maisa Ben Salah <76703998+MaiBe-ctrl@users.noreply.github.com>
1 parent e90bf5a commit 1bfa633

29 files changed, +4251 −881 lines
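
The core change of this PR is moving tabularization out of a single up-front pass over the whole dataframe and into the dataset's __getitem__, so each training sample is assembled just in time, with the sample index pointing at the prediction origin rather than the first forecast. A minimal sketch of that pattern (hypothetical names and simplified features, not the actual TimeDataset implementation):

import numpy as np
import torch
from torch.utils.data import Dataset


class JITWindowDataset(Dataset):
    """Sketch: tabularize one (lags, targets) sample per __getitem__ call."""

    def __init__(self, y: np.ndarray, n_lags: int, n_forecasts: int):
        self.y = y.astype(np.float32)
        self.n_lags = n_lags
        self.n_forecasts = n_forecasts
        # each valid index is a prediction origin, i.e. the last observed lag
        self.valid_origins = range(n_lags - 1, len(y) - n_forecasts)

    def __len__(self):
        return len(self.valid_origins)

    def __getitem__(self, idx):
        origin = self.valid_origins[idx]
        # slice lags and targets on the fly instead of precomputing all windows
        lags = torch.from_numpy(self.y[origin - self.n_lags + 1 : origin + 1])
        targets = torch.from_numpy(self.y[origin + 1 : origin + 1 + self.n_forecasts])
        return {"lags": lags, "targets": targets}

The idea is that only the sample indices and validity masks (e.g. the NaN and prediction-frequency masks mentioned in the commit list) need to be computed up front; the per-sample tensors are built lazily, which keeps memory roughly flat for large global datasets.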

docs/source/code/forecaster.rst

Lines changed: 17 additions & 2 deletions
@@ -1,5 +1,20 @@
-NeuralProphet Class
------------------------
+Core Module Documentation
+==========================
+
+.. toctree::
+    :hidden:
+    :maxdepth: 1
+
+    configure.py <configure>
+    df_utils.py <df_utils>
+    event_utils.py <event_utils>
+    plot_forecast_plotly.py <plot_forecast_plotly>
+    plot_forecast_matplotlib.py <plot_forecast_matplotlib>
+    plot_model_parameters_plotly.py <plot_model_parameters_plotly>
+    plot_model_parameters_matplotlib.py <plot_model_parameters_matplotlib>
+    time_dataset.py <time_dataset>
+    time_net.py <time_net>
+    utils.py <utils>
 
 .. automodule:: neuralprophet.forecaster
     :members:

docs/source/code/hdays_utils.rst

Lines changed: 0 additions & 5 deletions
This file was deleted.

docs/source/how-to-guides/feature-guides/mlflow.ipynb

Lines changed: 0 additions & 3 deletions
@@ -177,7 +177,6 @@
     "# Start a new MLflow run\n",
     "if local:\n",
     "    with mlflow.start_run():\n",
-    "\n",
     "        # Create a new MLflow experiment\n",
     "        mlflow.set_experiment(\"NP-MLflow Quickstart_v1\")\n",
     "\n",
@@ -259,7 +258,6 @@
     "from mlflow.data.pandas_dataset import PandasDataset\n",
     "\n",
     "if local:\n",
-    "\n",
     "    mlflow.pytorch.autolog(\n",
     "        log_every_n_epoch=1,\n",
     "        log_every_n_step=None,\n",
@@ -279,7 +277,6 @@
     "    model_name = \"NeuralProphet\"\n",
     "\n",
     "    with mlflow.start_run() as run:\n",
-    "\n",
     "        dataset: PandasDataset = mlflow.data.from_pandas(df, source=\"AirPassengersDataset\")\n",
     "\n",
     "        # Log the dataset to the MLflow Run. Specify the \"training\" context to indicate that the\n",

neuralprophet/components/future_regressors/linear.py

Lines changed: 4 additions & 2 deletions
@@ -52,8 +52,10 @@ def scalar_features_effects(self, features, params, indices=None):
         if indices is not None:
             features = features[:, :, indices]
             params = params[:, indices]
-
-        return torch.sum(features.unsqueeze(dim=2) * params.unsqueeze(dim=0).unsqueeze(dim=0), dim=-1)
+        # features dims: (batch, n_forecasts, n_features) -> (batch, n_forecasts, 1, n_features)
+        # params dims: (n_quantiles, n_features) -> (batch, 1, n_quantiles, n_features)
+        out = torch.sum(features.unsqueeze(dim=2) * params.unsqueeze(dim=0).unsqueeze(dim=0), dim=-1)
+        return out  # dims (batch, n_forecasts, n_quantiles)
 
     def get_reg_weights(self, name):
         """

neuralprophet/configure.py

Lines changed: 5 additions & 6 deletions
@@ -15,7 +15,7 @@
 
 from neuralprophet import df_utils, np_types, utils_torch
 from neuralprophet.custom_loss_metrics import PinballLoss
-from neuralprophet.hdays_utils import get_holidays_from_country
+from neuralprophet.event_utils import get_holiday_names
 
 log = logging.getLogger("NP.config")
 
@@ -42,10 +42,9 @@ def init_data_params(
         config_events: Optional[ConfigEvents] = None,
         config_seasonality: Optional[ConfigSeasonality] = None,
     ):
-        if len(df["ID"].unique()) == 1:
-            if not self.global_normalization:
-                log.info("Setting normalization to global as only one dataframe provided for training.")
-                self.global_normalization = True
+        if len(df["ID"].unique()) == 1 and not self.global_normalization:
+            log.info("Setting normalization to global as only one dataframe provided for training.")
+            self.global_normalization = True
         self.local_data_params, self.global_data_params = df_utils.init_data_params(
             df=df,
             normalize=self.normalize,
@@ -508,7 +507,7 @@ class Holidays:
     holiday_names: set = field(init=False)
 
     def init_holidays(self, df=None):
-        self.holiday_names = get_holidays_from_country(self.country, df)
+        self.holiday_names = get_holiday_names(self.country, df)
 
 
 ConfigCountryHolidays = Holidays
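
The replacement helper is the new event_utils.get_holiday_names shown in full further down in this diff; a quick usage sketch, assuming the holidays package resolves these country codes:

from neuralprophet.event_utils import get_holiday_names

# Without a dataframe, names are collected over the default year range 1995-2044.
us_holidays = get_holiday_names("US")
print("Independence Day" in us_holidays)  # expected: True

# Multiple countries are merged into a single set of names.
names_multi = get_holiday_names(["US", "DE"])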

neuralprophet/data/process.py

Lines changed: 13 additions & 12 deletions
@@ -333,18 +333,18 @@ def _validate_column_name(
     """
     reserved_names = [
         "trend",
-        "additive_terms",
         "daily",
         "weekly",
         "yearly",
         "events",
         "holidays",
-        "zeros",
-        "extra_regressors_additive",
         "yhat",
-        "extra_regressors_multiplicative",
-        "multiplicative_terms",
         "ID",
+        "y_scaled",
+        "ds",
+        "t",
+        "y",
+        "index",
     ]
     rn_l = [n + "_lower" for n in reserved_names]
     rn_u = [n + "_upper" for n in reserved_names]
@@ -434,14 +434,14 @@
 
 def _handle_missing_data(
     df: pd.DataFrame,
-    freq: Optional[str],
+    freq: str,
     n_lags: int,
     n_forecasts: int,
    config_missing,
-    config_regressors: Optional[ConfigFutureRegressors],
-    config_lagged_regressors: Optional[ConfigLaggedRegressors],
-    config_events: Optional[ConfigEvents],
-    config_seasonality: Optional[ConfigSeasonality],
+    config_regressors: Optional[ConfigFutureRegressors] = None,
+    config_lagged_regressors: Optional[ConfigLaggedRegressors] = None,
+    config_events: Optional[ConfigEvents] = None,
+    config_seasonality: Optional[ConfigSeasonality] = None,
     predicting: bool = False,
 ) -> pd.DataFrame:
     """
@@ -618,12 +618,13 @@ def _create_dataset(model, df, predict_mode, prediction_frequency=None):
         predict_mode=predict_mode,
         n_lags=model.n_lags,
         n_forecasts=model.n_forecasts,
+        prediction_frequency=prediction_frequency,
         predict_steps=model.predict_steps,
         config_seasonality=model.config_seasonality,
         config_events=model.config_events,
         config_country_holidays=model.config_country_holidays,
-        config_lagged_regressors=model.config_lagged_regressors,
         config_regressors=model.config_regressors,
+        config_lagged_regressors=model.config_lagged_regressors,
         config_missing=model.config_missing,
-        prediction_frequency=prediction_frequency,
+        # config_train=model.config_train, # no longer needed since JIT tabularization.
     )
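
The reserved-name list now covers the raw input columns (ds, y, t, index, y_scaled) instead of decomposition outputs, and the validator also blocks the _lower/_upper variants of every entry. A rough standalone approximation of that check (simplified, not the full _validate_column_name):

reserved_names = [
    "trend", "daily", "weekly", "yearly", "events", "holidays",
    "yhat", "ID", "y_scaled", "ds", "t", "y", "index",
]
# block the base names plus their quantile-suffixed variants
reserved = set(reserved_names)
reserved |= {n + "_lower" for n in reserved_names}
reserved |= {n + "_upper" for n in reserved_names}

def check_column_name(name: str) -> None:
    if name in reserved:
        raise ValueError(f"Name {name!r} is reserved.")

check_column_name("temperature")  # passes silently
# check_column_name("y_scaled")   # would raise ValueError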

neuralprophet/df_utils.py

Lines changed: 10 additions & 13 deletions
@@ -88,28 +88,27 @@ def return_df_in_original_format(df, received_ID_col=False, received_single_time
     return new_df
 
 
-def get_max_num_lags(config_lagged_regressors: Optional[ConfigLaggedRegressors], n_lags: int) -> int:
+def get_max_num_lags(n_lags: int, config_lagged_regressors: Optional[ConfigLaggedRegressors]) -> int:
     """Get the greatest number of lags between the autoregression lags and the covariates lags.
 
     Parameters
     ----------
-        config_lagged_regressors : configure.ConfigLaggedRegressors
-            Configurations for lagged regressors
         n_lags : int
             number of lagged values of series to include as model inputs
+        config_lagged_regressors : configure.ConfigLaggedRegressors
+            Configurations for lagged regressors
 
     Returns
     -------
         int
            Maximum number of lags between the autoregression lags and the covariates lags.
     """
     if config_lagged_regressors is not None:
-        log.debug("config_lagged_regressors exists")
-        max_n_lags = max([n_lags] + [val.n_lags for key, val in config_lagged_regressors.items()])
+        # log.debug("config_lagged_regressors exists")
+        return max([n_lags] + [val.n_lags for key, val in config_lagged_regressors.items()])
     else:
-        log.debug("config_lagged_regressors does not exist")
-        max_n_lags = n_lags
-    return max_n_lags
+        # log.debug("config_lagged_regressors does not exist")
+        return n_lags
 
 
 def merge_dataframes(df: pd.DataFrame) -> pd.DataFrame:
@@ -508,14 +507,12 @@ def check_dataframe(
     for name in columns:
         if name not in df:
             raise ValueError(f"Column {name!r} missing from dataframe")
-        if df.loc[df.loc[:, name].notnull()].shape[0] < 1:
+        if sum(df.loc[:, name].notnull().values) < 1:
            raise ValueError(f"Dataframe column {name!r} only has NaN rows.")
         if not np.issubdtype(df[name].dtype, np.number):
             df[name] = pd.to_numeric(df[name])
         if np.isinf(df.loc[:, name].values).any():
             df.loc[:, name] = df[name].replace([np.inf, -np.inf], np.nan)
-        if df.loc[df.loc[:, name].notnull()].shape[0] < 1:
-            raise ValueError(f"Dataframe column {name!r} only has NaN rows.")
 
     if future:
         return df, regressors_to_remove, lag_regressors_to_remove
@@ -1541,10 +1538,10 @@ def drop_missing_from_df(df, drop_missing, predict_steps, n_lags):
             if all_nan_idx[i + 1] - all_nan_idx[i] > 1:
                 break
         # drop NaN window
-        df = df.drop(df.index[window[0] : window[-1] + 1]).reset_index().drop("index", axis=1)
+        df = df.drop(df.index[window[0] : window[-1] + 1]).reset_index(drop=True)
         # drop lagged values if window does not occur at the beginning of df
         if window[0] - (n_lags - 1) >= 0:
-            df = df.drop(df.index[(window[0] - (n_lags - 1)) : window[0]]).reset_index().drop("index", axis=1)
+            df = df.drop(df.index[(window[0] - (n_lags - 1)) : window[0]]).reset_index(drop=True)
     return df
 
 
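
The reordered get_max_num_lags simply takes the maximum over the AR lag count and each lagged regressor's own lag count. A toy check with stand-in config objects (SimpleNamespace here is a hypothetical substitute for real ConfigLaggedRegressors entries, which expose an n_lags attribute):

from types import SimpleNamespace

from neuralprophet.df_utils import get_max_num_lags

config_lagged_regressors = {
    "temperature": SimpleNamespace(n_lags=5),
    "humidity": SimpleNamespace(n_lags=12),
}

print(get_max_num_lags(n_lags=3, config_lagged_regressors=config_lagged_regressors))  # 12
print(get_max_num_lags(n_lags=3, config_lagged_regressors=None))  # 3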

neuralprophet/event_utils.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+from collections import defaultdict
+from typing import Iterable, Union
+
+import numpy as np
+import pandas as pd
+from holidays import country_holidays
+
+
+def get_holiday_names(country: Union[str, Iterable[str]], df=None):
+    """
+    Return all possible holiday names for a list of countries over time period in df
+
+    Parameters
+    ----------
+        country : str, list
+            List of country names to retrieve country specific holidays
+        df : pd.Dataframe
+            Dataframe from which datestamps will be retrieved from
+
+    Returns
+    -------
+        set
+            All possible holiday names of given country
+    """
+    if df is None:
+        years = np.arange(1995, 2045)
+    else:
+        dates = df["ds"].copy(deep=True)
+        years = pd.unique(dates.apply(lambda x: x.year))
+        # years = list({x.year for x in dates})
+    # support multiple countries, convert to list if not already
+    if isinstance(country, str):
+        country = [country]
+
+    all_holidays = get_all_holidays(years=years, country=country)
+    return set(all_holidays.keys())
+
+
+def get_all_holidays(years, country):
+    """
+    Make dataframe of country specific holidays for given years and countries
+    Parameters
+    ----------
+        year_list : list
+            List of years
+        country : str, list, dict
+            List of country names and optional subdivisions
+    Returns
+    -------
+        pd.DataFrame
+            Containing country specific holidays df with columns 'ds' and 'holiday'
+    """
+    # convert to list if not already
+    if isinstance(country, str):
+        country = {country: None}
+    elif isinstance(country, list):
+        country = dict(zip(country, [None] * len(country)))
+
+    all_holidays = defaultdict(list)
+    # iterate over countries and get holidays for each country
+    for single_country, subdivision in country.items():
+        # For compatibility with Turkey as "TU" cases.
+        single_country = "TUR" if single_country == "TU" else single_country
+        # get dict of dates and their holiday name
+        single_country_specific_holidays = country_holidays(
+            country=single_country, subdiv=subdivision, years=years, expand=True, observed=False, language="en"
+        )
+        # invert order - for given holiday, store list of dates
+        for date, name in single_country_specific_holidays.items():
+            all_holidays[name].append(pd.to_datetime(date))
+    return all_holidays
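
A quick sanity check of the new helpers (illustrative output, assuming the holidays package provides English names for these country codes):

from neuralprophet.event_utils import get_all_holidays, get_holiday_names

# mapping of holiday name -> list of pd.Timestamp across the requested years
all_holidays = get_all_holidays(years=[2023, 2024], country=["US", "DE"])
print(sorted(all_holidays)[:3])           # a few holiday names, in English
print(all_holidays["Christmas Day"][:2])  # dates on which it falls

# get_holiday_names wraps the same call and returns only the set of names
assert "Christmas Day" in get_holiday_names("US")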
