diff --git a/docs/source/models.rst b/docs/source/models.rst index 67dd7c042..41afe8241 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -27,6 +27,7 @@ and you should take into account. Here is an overview over the pros and cons of :py:class:`~pytorch_forecasting.models.rnn.RecurrentNetwork`, "x", "x", "x", "", "", "", "", "x", "", 2 :py:class:`~pytorch_forecasting.models.mlp.DecoderMLP`, "x", "x", "x", "x", "", "x", "", "x", "x", 1 :py:class:`~pytorch_forecasting.models.nbeats.NBeats`, "", "", "x", "", "", "", "", "", "", 1 + :py:class:`~pytorch_forecasting.models.nbeats.NBeatsKAN`, "", "", "x", "", "", "", "", "", "", 1 :py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1 :py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3 :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4 diff --git a/examples/nbeats_with_kan.py b/examples/nbeats_with_kan.py new file mode 100644 index 000000000..6a018ce5d --- /dev/null +++ b/examples/nbeats_with_kan.py @@ -0,0 +1,105 @@ +import sys + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +import pandas as pd + +from pytorch_forecasting import NBeatsKAN, TimeSeriesDataSet +from pytorch_forecasting.data import NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data +from pytorch_forecasting.models.nbeats import GridUpdateCallback + +sys.path.append("..") + + +print("load data") +data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100) +data["static"] = 2 +data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") +validation = data.series.sample(20) + + +max_encoder_length = 150 +max_prediction_length = 20 + +training_cutoff = data["time_idx"].max() - max_prediction_length + +context_length = max_encoder_length +prediction_length = max_prediction_length + +training = TimeSeriesDataSet( + data[lambda x: x.time_idx < training_cutoff], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + min_encoder_length=context_length, + max_encoder_length=context_length, + max_prediction_length=prediction_length, + min_prediction_length=prediction_length, + time_varying_unknown_reals=["value"], + randomize_length=None, + add_relative_time_idx=False, + add_target_scales=False, +) + +validation = TimeSeriesDataSet.from_dataset( + training, data, min_prediction_idx=training_cutoff +) +batch_size = 128 +train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 +) +val_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 +) + + +early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min" +) +# updates KAN layers' grid after every 3 steps during training +grid_update_callback = GridUpdateCallback(update_interval=3) + +trainer = pl.Trainer( + max_epochs=1, + accelerator="auto", + gradient_clip_val=0.1, + callbacks=[early_stop_callback, grid_update_callback], + limit_train_batches=15, + # limit_val_batches=1, + # fast_dev_run=True, + # logger=logger, + # profiler=True, +) + + +net = NBeatsKAN.from_dataset( + training, + learning_rate=3e-2, + log_interval=10, + log_val_interval=1, + log_gradient_flow=False, + weight_decay=1e-2, +) +print(f"Number of parameters in network: 
{net.size() / 1e3:.1f}k") + +# # find optimal learning rate +# # remove logging and artificial epoch size +# net.hparams.log_interval = -1 +# net.hparams.log_val_interval = -1 +# trainer.limit_train_batches = 1.0 +# # run learning rate finder +# res = Tuner(trainer).lr_find( +# net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2 # noqa: E501 +# ) +# print(f"suggested learning rate: {res.suggestion()}") +# fig = res.plot(show=True, suggest=True) +# fig.show() +# net.hparams.learning_rate = res.suggestion() + +trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, +) diff --git a/pytorch_forecasting/__init__.py b/pytorch_forecasting/__init__.py index dede44fd1..47f2be0f6 100644 --- a/pytorch_forecasting/__init__.py +++ b/pytorch_forecasting/__init__.py @@ -43,6 +43,7 @@ DeepAR, MultiEmbedding, NBeats, + NBeatsKAN, NHiTS, RecurrentNetwork, TemporalFusionTransformer, @@ -73,6 +74,7 @@ "TemporalFusionTransformer", "TiDEModel", "NBeats", + "NBeatsKAN", "NHiTS", "Baseline", "DeepAR", diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py index 29aeb24f5..4a7aba3e0 100644 --- a/pytorch_forecasting/models/__init__.py +++ b/pytorch_forecasting/models/__init__.py @@ -11,7 +11,7 @@ from pytorch_forecasting.models.baseline import Baseline from pytorch_forecasting.models.deepar import DeepAR from pytorch_forecasting.models.mlp import DecoderMLP -from pytorch_forecasting.models.nbeats import NBeats +from pytorch_forecasting.models.nbeats import NBeats, NBeatsKAN from pytorch_forecasting.models.nhits import NHiTS from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn from pytorch_forecasting.models.rnn import RecurrentNetwork @@ -22,6 +22,7 @@ __all__ = [ "NBeats", + "NBeatsKAN", "NHiTS", "TemporalFusionTransformer", "RecurrentNetwork", diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index b3264272d..b588093af 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -1,10 +1,21 @@ """N-Beats model for timeseries forecasting without covariates.""" +from pytorch_forecasting.models.nbeats._grid_callback import GridUpdateCallback from pytorch_forecasting.models.nbeats._nbeats import NBeats +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter +from pytorch_forecasting.models.nbeats._nbeatskan import NBeatsKAN from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, ) -__all__ = ["NBeats", "NBEATSGenericBlock", "NBEATSSeasonalBlock", "NBEATSTrendBlock"] +__all__ = [ + "NBeats", + "NBeatsKAN", + "NBEATSGenericBlock", + "NBEATSSeasonalBlock", + "NBEATSTrendBlock", + "NBeatsAdapter", + "GridUpdateCallback", +] diff --git a/pytorch_forecasting/models/nbeats/_grid_callback.py b/pytorch_forecasting/models/nbeats/_grid_callback.py new file mode 100644 index 000000000..d311cfe84 --- /dev/null +++ b/pytorch_forecasting/models/nbeats/_grid_callback.py @@ -0,0 +1,42 @@ +from lightning.pytorch.callbacks import Callback + + +class GridUpdateCallback(Callback): + """ + Custom callback to update the grid of the model during training at regular + intervals. + + Example: + See the full example in: + `examples/nbeats_with_kan.py` + + Attributes: + update_interval (int): The frequency at which the grid is updated. 
+ """ + + def __init__(self, update_interval): + """ + Initializes the callback with the given update interval. + + Args: + update_interval (int): The frequency at which the grid is updated. + """ + self.update_interval = update_interval + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + """ + Hook that is called at the end of each training batch. + Updates the grid of KAN layers if the current step is a multiple of the update + interval. + + Args: + trainer (Trainer): The PyTorch Lightning Trainer object. + pl_module (LightningModule): The model being trained (LightningModule). + outputs (Any): Outputs from the model for the current batch. + batch (Any): The current batch of data. + batch_idx (int): Index of the current batch. + """ + # Check if the current step is a multiple of the update interval + if (trainer.global_step + 1) % self.update_interval == 0: + # Call the model's update_kan_grid method + pl_module.update_kan_grid() diff --git a/pytorch_forecasting/models/nbeats/_kan_layer.py b/pytorch_forecasting/models/nbeats/_kan_layer.py new file mode 100644 index 000000000..1f7a18a1c --- /dev/null +++ b/pytorch_forecasting/models/nbeats/_kan_layer.py @@ -0,0 +1,364 @@ +# The following implementation of KANLayer is inspired by the pykan library. +# Reference: https://github.com/KindXiaoming/pykan/blob/master/kan/KANLayer.py + +import numpy as np +import torch +import torch.nn as nn + + +def B_batch(x, grid, k=0, extend=True): + """ + evaludate x on B-spline bases + + Args: + ----- + x : 2D torch.tensor + inputs, shape (number of splines, number of samples) + grid : 2D torch.tensor + grids, shape (number of splines, number of grid points) + k : int + the piecewise polynomial order of splines. + extend : bool + If True, k points are extended on both ends. If False, no extension + (zero boundary condition). Default: True + + Returns: + -------- + spline values : 3D torch.tensor + shape (batch, in_dim, G+k). G: the number of grid intervals, + k: spline order. + + Example + ------- + >>> from kan.spline import B_batch + >>> x = torch.rand(100,2) + >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11) + >>> B_batch(x, grid, k=3).shape + """ + + x = x.unsqueeze(dim=2) + grid = grid.unsqueeze(dim=0) + + if k == 0: + value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:]) + else: + B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1) + + value = (x - grid[:, :, : -(k + 1)]) / ( + grid[:, :, k:-1] - grid[:, :, : -(k + 1)] + ) * B_km1[:, :, :-1] + (grid[:, :, k + 1 :] - x) / ( + grid[:, :, k + 1 :] - grid[:, :, 1:(-k)] + ) * B_km1[ + :, :, 1: + ] + + # in case grid is degenerate + value = torch.nan_to_num(value) + return value + + +def coef2curve(x_eval, grid, coef, k): + """ + converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves + (summing up B_batch results over B-spline basis). + + Args: + ----- + x_eval : 2D torch.tensor + shape (batch, in_dim) + grid : 2D torch.tensor + shape (in_dim, G+2k). G: the number of grid intervals; k: spline order. + coef : 3D torch.tensor + shape (in_dim, out_dim, G+k) + k : int + the piecewise polynomial order of splines. + + Returns: + -------- + y_eval : 3D torch.tensor + shape (batch, in_dim, out_dim) + + """ + + b_splines = B_batch(x_eval, grid, k=k) + y_eval = torch.einsum("ijk,jlk->ijl", b_splines, coef.to(b_splines)) + + return y_eval + + +def curve2coef(x_eval, y_eval, grid, k): + """ + converting B-spline curves to B-spline coefficients using least squares. 
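+    For each (input, output) pair, the coefficients solve the linear least
+    squares problem min_c || B(x_eval) @ c - y_eval ||^2 via
+    ``torch.linalg.lstsq``, where B holds the basis values computed by
+    ``B_batch``.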
+
+    Args:
+    -----
+    x_eval : 2D torch.tensor
+        shape (batch, in_dim)
+    y_eval : 3D torch.tensor
+        shape (batch, in_dim, out_dim)
+    grid : 2D torch.tensor
+        shape (in_dim, G+2k). G: the number of grid intervals; k: spline order.
+    k : int
+        spline order
+
+    Returns:
+    --------
+    coef : 3D torch.tensor
+        shape (in_dim, out_dim, G+k)
+    """
+    batch = x_eval.shape[0]
+    in_dim = x_eval.shape[1]
+    out_dim = y_eval.shape[2]
+    n_coef = grid.shape[1] - k - 1
+    mat = B_batch(x_eval, grid, k)
+    mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef)
+    y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3)
+    try:
+        coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0]
+    except Exception as e:
+        # re-raise instead of returning an undefined variable
+        raise RuntimeError(f"lstsq failed with error: {e}") from e
+
+    return coef
+
+
+def extend_grid(grid, k_extend=0):
+    """
+    extend grid by k_extend points on both ends
+    """
+    h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
+
+    for _ in range(k_extend):
+        grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
+        grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
+
+    return grid
+
+
+def sparse_mask(in_dim, out_dim):
+    """
+    get sparse mask connecting each input to its nearest output and vice versa
+    """
+    in_coord = torch.arange(in_dim) * 1 / in_dim + 1 / (2 * in_dim)
+    out_coord = torch.arange(out_dim) * 1 / out_dim + 1 / (2 * out_dim)
+
+    dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :])
+    in_nearest = torch.argmin(dist_mat, dim=0)
+    in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1, 0)
+    out_nearest = torch.argmin(dist_mat, dim=1)
+    out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1, 0)
+    all_connection = torch.cat([in_connection, out_connection], dim=0)
+    mask = torch.zeros(in_dim, out_dim)
+    mask[all_connection[:, 0], all_connection[:, 1]] = 1.0
+
+    return mask
+
+
+class KANLayer(nn.Module):
+    """
+    KANLayer class
+    """
+
+    def __init__(
+        self,
+        in_dim=3,
+        out_dim=2,
+        num=5,
+        k=3,
+        noise_scale=0.5,
+        scale_base_mu=0.0,
+        scale_base_sigma=1.0,
+        scale_sp=1.0,
+        base_fun=torch.nn.SiLU(),
+        grid_eps=0.02,
+        grid_range=[-1, 1],
+        sp_trainable=True,
+        sb_trainable=True,
+        sparse_init=False,
+    ):
+        """
+        initialize a KANLayer
+
+        Args:
+        -----
+        in_dim : int
+            input dimension. Default: 3.
+        out_dim : int
+            output dimension. Default: 2.
+        num : int
+            the number of grid intervals = G. Default: 5.
+        k : int
+            the order of piecewise polynomial. Default: 3.
+        noise_scale : float
+            the scale of noise injected at initialization. Default: 0.5.
+        scale_base_mu : float
+            the scale of the residual function b(x) is initialized to be
+            N(scale_base_mu, scale_base_sigma^2).
+        scale_base_sigma : float
+            the scale of the residual function b(x) is initialized to be
+            N(scale_base_mu, scale_base_sigma^2).
+        scale_sp : float
+            the scale of the base function spline(x).
+        base_fun : function
+            residual function b(x). Default: torch.nn.SiLU()
+        grid_eps : float
+            When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is
+            partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates
+            between the two extremes.
+        grid_range : list/np.array of shape (2,)
+            setting the range of grids. Default: [-1,1].
+        sp_trainable : bool
+            If True, scale_sp is trainable
+        sb_trainable : bool
+            If True, scale_base is trainable
+        sparse_init : bool
+            if sparse_init = True, sparse initialization is applied.
+ + Returns: + -------- + self + + Example + ------- + >>> from kan.KANLayer import * + >>> model = KANLayer(in_dim=3, out_dim=5) + >>> (model.in_dim, model.out_dim) + """ + super(KANLayer, self).__init__() + # size + self.out_dim = out_dim + self.in_dim = in_dim + self.num = num + self.k = k + + grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[ + None, : + ].expand(self.in_dim, num + 1) + grid = extend_grid(grid, k_extend=k) + self.grid = torch.nn.Parameter(grid).requires_grad_(False) + noises = ( + (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2) + * noise_scale + / num + ) + + self.coef = torch.nn.Parameter( + curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k) + ) + + if sparse_init: + self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_( + False + ) + else: + self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_( + False + ) + + self.scale_base = torch.nn.Parameter( + scale_base_mu * 1 / np.sqrt(in_dim) + + scale_base_sigma + * (torch.rand(in_dim, out_dim) * 2 - 1) + * 1 + / np.sqrt(in_dim) + ).requires_grad_(sb_trainable) + self.scale_sp = torch.nn.Parameter( + torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask + ).requires_grad_( + sp_trainable + ) # make scale trainable + self.base_fun = base_fun + + self.grid_eps = grid_eps + + def forward(self, x): + """ + KANLayer forward given input x + + Args: + ----- + x : 2D torch.float + inputs, shape (number of samples, input dimension) + + Returns: + -------- + y : 2D torch.float + outputs, shape (number of samples, output dimension) + preacts : 3D torch.float + fan out x into activations, shape (number of sampels, + output dimension, input dimension) + postacts : 3D torch.float + the outputs of activation functions with preacts as inputs + postspline : 3D torch.float + the outputs of spline functions with preacts as inputs + + Example + ------- + >>> from kan.KANLayer import * + >>> model = KANLayer(in_dim=3, out_dim=5) + >>> x = torch.normal(0,1,size=(100,3)) + >>> y, preacts, postacts, postspline = model(x) + >>> y.shape, preacts.shape, postacts.shape, postspline.shape + """ + + base = self.base_fun(x) # (batch, in_dim) + y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k) + y = ( + self.scale_base[None, :, :] * base[:, :, None] + + self.scale_sp[None, :, :] * y + ) + y = self.mask[None, :, :] * y + y = torch.sum(y, dim=1) + return y + + def update_grid_from_samples(self, x, mode="sample"): + """ + update grid from samples + + Args: + ----- + x : 2D torch.float + inputs, shape (number of samples, input dimension) + + Returns: + -------- + None + + Example + ------- + >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) + >>> print(model.grid.data) + >>> x = torch.linspace(-3,3,steps=100)[:,None] + >>> model.update_grid_from_samples(x) + >>> print(model.grid.data) + """ + + batch = x.shape[0] + x_pos = torch.sort(x, dim=0)[0] + y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) + num_interval = self.grid.shape[1] - 1 - 2 * self.k + + def get_grid(num_interval): + ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] + grid_adaptive = x_pos[ids, :].permute(1, 0) + margin = 0.00 + h = ( + grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin + ) / num_interval + grid_uniform = ( + grid_adaptive[:, [0]] + - margin + + h * torch.arange(num_interval + 1, device=h.device)[None, :] + ) + grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive + return grid + + grid = 
get_grid(num_interval) + if mode == "grid": + sample_grid = get_grid(2 * num_interval) + x_pos = sample_grid.permute(1, 0) + y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) + + self.grid.data = extend_grid(grid, k_extend=self.k) + self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py index 3181d818c..3326bb5a9 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats.py +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -2,24 +2,20 @@ N-Beats model for timeseries forecasting without covariates. """ -from typing import Dict, List, Optional +from typing import List, Optional -import torch from torch import nn -from pytorch_forecasting.data import TimeSeriesDataSet -from pytorch_forecasting.data.encoders import NaNLabelEncoder from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.base import BaseModel +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, ) -from pytorch_forecasting.utils._dependencies import _check_matplotlib -class NBeats(BaseModel): +class NBeats(NBeatsAdapter): def __init__( self, stack_types: Optional[List[str]] = None, @@ -47,50 +43,62 @@ def __init__( Based on the article `N-BEATS: Neural basis expansion analysis for interpretable time series - forecasting `_. The network has (if used as ensemble) outperformed all - other methods - including ensembles of traditional statical methods in the M4 competition. The M4 competition is arguably - the most - important benchmark for univariate time series forecasting. + forecasting `_. The network has (if + used as ensemble) outperformed all other methods including ensembles of + traditional statical methods in the M4 competition. The M4 competition is + arguably the most important benchmark for univariate time series forecasting. - The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently shown to consistently outperform - N-BEATS. + The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently + shown to consistently outperform N-BEATS. Args: - stack_types: One of the following values: “generic”, “seasonality" or “trend". A list of strings - of length 1 or ‘num_stacks’. Default and recommended value - for generic mode: [“generic”] Recommended value for interpretable mode: [“trend”,”seasonality”] - num_blocks: The number of blocks per stack. A list of ints of length 1 or ‘num_stacks’. - Default and recommended value for generic mode: [1] Recommended value for interpretable mode: [3] - num_block_layers: Number of fully connected layers with ReLu activation per block. A list of ints of length - 1 or ‘num_stacks’. - Default and recommended value for generic mode: [4] Recommended value for interpretable mode: [4] - width: Widths of the fully connected layers with ReLu activation in the blocks. - A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [512] - Recommended value for interpretable mode: [256, 2048] + stack_types: One of the following values: “generic”, “seasonality" or + “trend". A list of strings of length 1 or 'num_stacks'. Default and + recommended value for generic mode: [“generic”] Recommended value for + interpretable mode: [“trend”,”seasonality”]. + num_blocks: The number of blocks per stack. 
A list of ints of length 1 or + 'num_stacks'. Default and recommended value for generic mode: [1] + Recommended value for interpretable mode: [3] + num_block_layers: Number of fully connected layers with ReLu activation per + block. + A list of ints of length 1 or 'num_stacks'. Default and recommended + value for generic mode: [4] Recommended value for interpretable mode: + [4]. + width: Widths of the fully connected layers with ReLu activation in the + blocks. A list of ints of length 1 or 'num_stacks'. Default and + recommended value for generic mode: [512]. Recommended value for + interpretable mode: [256, 2048] sharing: Whether the weights are shared with the other blocks per stack. - A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [False] - Recommended value for interpretable mode: [True] - expansion_coefficient_length: If the type is “G” (generic), then the length of the expansion - coefficient. - If type is “T” (trend), then it corresponds to the degree of the polynomial. If the type is “S” - (seasonal) then this is the minimum period allowed, e.g. 2 for changes every timestep. - A list of ints of length 1 or ‘num_stacks’. Default value for generic mode: [32] Recommended value for + A list of ints of length 1 or 'num_stacks'. Default and recommended + value for generic mode: [False]. Recommended value for interpretable + mode: [True]. + expansion_coefficient_length: If the type is “G” (generic), then the length + of the expansion coefficient. + If type is “T” (trend), then it corresponds to the degree of the + polynomial. + If the type is “S” (seasonal) then this is the minimum period allowed, + e.g. 2 for changes every timestep. A list of ints of length 1 or + 'num_stacks'. Default value for generic mode: [32] Recommended value for interpretable mode: [3] prediction_length: Length of the prediction. Also known as 'horizon'. - context_length: Number of time units that condition the predictions. Also known as 'lookback period'. + context_length: Number of time units that condition the predictions. + Also known as 'lookback period'. Should be between 1-10 times the prediction length. - backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss. - A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast and - forecast lengths). Defaults to 0.0, i.e. no weight. + backcast_loss_ratio: weight of backcast in comparison to forecast when + calculating the loss. A weight of 1.0 means that forecast and + backcast loss is weighted the same (regardless of backcast and forecast + lengths). Defaults to 0.0, i.e. no weight. loss: loss to optimize. Defaults to MASE(). - log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training - failures - reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10 - logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training. - Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + log_gradient_flow: if to log gradient flow, this takes time and should be + only done to diagnose training failures. + reduce_on_plateau_patience (int): patience after which learning rate is + reduced by a factor of 10 + logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that + are logged during training. 
Defaults to + nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) **kwargs: additional arguments to :py:class:`~BaseModel`. """ # noqa: E501 + if expansion_coefficient_lengths is None: expansion_coefficient_lengths = [3, 7] if sharing is None: @@ -107,9 +115,9 @@ def __init__( logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) if loss is None: loss = MASE() - self.save_hyperparameters() - super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + self.save_hyperparameters(ignore=["loss", "logging_metrics"]) + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) # setup stacks self.net_blocks = nn.ModuleList() for stack_id, stack_type in enumerate(stack_types): @@ -121,7 +129,7 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - dropout=self.hparams.dropout, + dropout=dropout, ) elif stack_type == "seasonality": net_block = NBEATSSeasonalBlock( @@ -129,8 +137,8 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - min_period=self.hparams.expansion_coefficient_lengths[stack_id], - dropout=self.hparams.dropout, + min_period=expansion_coefficient_lengths[stack_id], + dropout=dropout, ) elif stack_type == "trend": net_block = NBEATSTrendBlock( @@ -139,296 +147,9 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - dropout=self.hparams.dropout, + dropout=dropout, ) else: raise ValueError(f"Unknown stack type {stack_type}") self.net_blocks.append(net_block) - - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - Pass forward of network. - - Args: - x (Dict[str, torch.Tensor]): input from dataloader generated from - :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. 
- - Returns: - Dict[str, torch.Tensor]: output of model - """ - target = x["encoder_cont"][..., 0] - - timesteps = self.hparams.context_length + self.hparams.prediction_length - generic_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - trend_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - seasonal_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - forecast = torch.zeros( - (target.size(0), self.hparams.prediction_length), - dtype=torch.float32, - device=self.device, - ) - - backcast = target # initialize backcast - for i, block in enumerate(self.net_blocks): - # evaluate block - backcast_block, forecast_block = block(backcast) - - # add for interpretation - full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1) - if isinstance(block, NBEATSTrendBlock): - trend_forecast.append(full) - elif isinstance(block, NBEATSSeasonalBlock): - seasonal_forecast.append(full) - else: - generic_forecast.append(full) - - # update backcast and forecast - backcast = ( - backcast - backcast_block - ) # do not use backcast -= backcast_block as this signifies an inline operation # noqa : E501 - forecast = forecast + forecast_block - - return self.to_network_output( - prediction=self.transform_output(forecast, target_scale=x["target_scale"]), - backcast=self.transform_output( - prediction=target - backcast, target_scale=x["target_scale"] - ), - trend=self.transform_output( - torch.stack(trend_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - seasonality=self.transform_output( - torch.stack(seasonal_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - generic=self.transform_output( - torch.stack(generic_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - ) - - @classmethod - def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): - """ - Convenience function to create network from :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. - - Args: - dataset (TimeSeriesDataSet): dataset where sole predictor is the target. - **kwargs: additional arguments to be passed to ``__init__`` method. 
- - Returns: - NBeats - """ # noqa: E501 - new_kwargs = { - "prediction_length": dataset.max_prediction_length, - "context_length": dataset.max_encoder_length, - } - new_kwargs.update(kwargs) - - # validate arguments - assert isinstance( - dataset.target, str - ), "only one target is allowed (passed as string to dataset)" - assert not isinstance( - dataset.target_normalizer, NaNLabelEncoder - ), "only regression tasks are supported - target must not be categorical" - assert dataset.min_encoder_length == dataset.max_encoder_length, ( - "only fixed encoder length is allowed," - " but min_encoder_length != max_encoder_length" - ) - - assert dataset.max_prediction_length == dataset.min_prediction_length, ( - "only fixed prediction length is allowed," - " but max_prediction_length != min_prediction_length" - ) - - assert ( - dataset.randomize_length is None - ), "length has to be fixed, but randomize_length is not None" - assert ( - not dataset.add_relative_time_idx - ), "add_relative_time_idx has to be False" - - assert ( - len(dataset.flat_categoricals) == 0 - and len(dataset.reals) == 1 - and len(dataset._time_varying_unknown_reals) == 1 - and dataset._time_varying_unknown_reals[0] == dataset.target - ), ( - "The only variable as input should be the" - " target which is part of time_varying_unknown_reals" - ) - - # initialize class - return super().from_dataset(dataset, **new_kwargs) - - def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]: - """ - Take training / validation step. - """ - log, out = super().step(x, y, batch_idx=batch_idx) - - if ( - self.hparams.backcast_loss_ratio > 0 and not self.predicting - ): # add loss from backcast - backcast = out["backcast"] - backcast_weight = ( - self.hparams.backcast_loss_ratio - * self.hparams.prediction_length - / self.hparams.context_length - ) - backcast_weight = backcast_weight / (backcast_weight + 1) # normalize - forecast_weight = 1 - backcast_weight - if isinstance(self.loss, MASE): - backcast_loss = ( - self.loss(backcast, x["encoder_target"], x["decoder_target"]) - * backcast_weight - ) - else: - backcast_loss = ( - self.loss(backcast, x["encoder_target"]) * backcast_weight - ) - label = ["val", "train"][self.training] - self.log( - f"{label}_backcast_loss", - backcast_loss, - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - self.log( - f"{label}_forecast_loss", - log["loss"], - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - log["loss"] = log["loss"] * forecast_weight + backcast_loss - - self.log_interpretation(x, out, batch_idx=batch_idx) - return log, out - - def log_interpretation(self, x, out, batch_idx): - """ - Log interpretation of network predictions in tensorboard. 
- """ - mpl_available = _check_matplotlib("log_interpretation", raise_error=False) - - # Don't log figures if matplotlib or add_figure is not available - if not mpl_available or not self._logger_supports("add_figure"): - return None - - label = ["val", "train"][self.training] - if self.log_interval > 0 and batch_idx % self.log_interval == 0: - fig = self.plot_interpretation(x, out, idx=0) - name = f"{label.capitalize()} interpretation of item 0 in " - if self.training: - name += f"step {self.global_step}" - else: - name += f"batch {batch_idx}" - self.logger.experiment.add_figure(name, fig, global_step=self.global_step) - - def plot_interpretation( - self, - x: Dict[str, torch.Tensor], - output: Dict[str, torch.Tensor], - idx: int, - ax=None, - plot_seasonality_and_generic_on_secondary_axis: bool = False, - ): - """ - Plot interpretation. - - Plot two pannels: prediction and backcast vs actuals and - decomposition of prediction into trend, seasonality and generic forecast. - - Args: - x (Dict[str, torch.Tensor]): network input - output (Dict[str, torch.Tensor]): network output - idx (int): index of sample for which to plot the interpretation. - ax (List[matplotlib axes], optional): list of two matplotlib axes onto which to plot the interpretation. - Defaults to None. - plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot seasonality and - generic forecast on secondary axis in second panel. Defaults to False. - - Returns: - plt.Figure: matplotlib figure - """ # noqa: E501 - _check_matplotlib("plot_interpretation") - - import matplotlib.pyplot as plt - - if ax is None: - fig, ax = plt.subplots(2, 1, figsize=(6, 8)) - else: - fig = ax[0].get_figure() - - time = torch.arange( - -self.hparams.context_length, self.hparams.prediction_length - ) - - # plot target vs prediction - ax[0].plot( - time, - torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]]) - .detach() - .cpu(), - label="target", - ) - ax[0].plot( - time, - torch.cat( - [ - output["backcast"][idx].detach(), - output["prediction"][idx].detach(), - ], - dim=0, - ).cpu(), - label="prediction", - ) - ax[0].set_xlabel("Time") - - # plot blocks - prop_cycle = iter(plt.rcParams["axes.prop_cycle"]) - next(prop_cycle) # prediction - next(prop_cycle) # observations - if plot_seasonality_and_generic_on_secondary_axis: - ax2 = ax[1].twinx() - ax2.set_ylabel("Seasonality / Generic") - else: - ax2 = ax[1] - for title in ["trend", "seasonality", "generic"]: - if title not in self.hparams.stack_types: - continue - if title == "trend": - ax[1].plot( - time, - output[title][idx].detach().cpu(), - label=title.capitalize(), - c=next(prop_cycle)["color"], - ) - else: - ax2.plot( - time, - output[title][idx].detach().cpu(), - label=title.capitalize(), - c=next(prop_cycle)["color"], - ) - ax[1].set_xlabel("Time") - ax[1].set_ylabel("Decomposition") - - fig.legend() - return fig diff --git a/pytorch_forecasting/models/nbeats/_nbeats_adapter.py b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py new file mode 100644 index 000000000..d08d4c5ca --- /dev/null +++ b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py @@ -0,0 +1,322 @@ +""" +N-Beats model adapter for timeseries forecasting without covariates. 
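+
+It factors out the forward pass, dataset validation, training step and
+interpretation plotting shared by :py:class:`NBeats` and :py:class:`NBeatsKAN`,
+so that subclasses only need to construct their block stacks.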
+""" + +from typing import Dict, List, Optional + +import torch +from torch import nn + +from pytorch_forecasting.data import TimeSeriesDataSet +from pytorch_forecasting.data.encoders import NaNLabelEncoder +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric +from pytorch_forecasting.models.base_model import BaseModel +from pytorch_forecasting.models.nbeats.sub_modules import ( + NBEATSGenericBlock, + NBEATSSeasonalBlock, + NBEATSTrendBlock, +) +from pytorch_forecasting.utils._dependencies import _check_matplotlib + + +class NBeatsAdapter(BaseModel): + def __init__( + self, + **kwargs, + ): + """ + Initialize NBeats Adapter. + + Args: + **kwargs: additional arguments to :py:class:`~BaseModel`. + """ # noqa: E501 + super().__init__(**kwargs) + + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Pass forward of network. + + Args: + x (Dict[str, torch.Tensor]): input from dataloader generated from + :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + + Returns: + Dict[str, torch.Tensor]: output of model + """ + target = x["encoder_cont"][..., 0] + + timesteps = self.hparams.context_length + self.hparams.prediction_length + generic_forecast = [ + torch.zeros( + (target.size(0), timesteps), dtype=torch.float32, device=self.device + ) + ] + trend_forecast = [ + torch.zeros( + (target.size(0), timesteps), dtype=torch.float32, device=self.device + ) + ] + seasonal_forecast = [ + torch.zeros( + (target.size(0), timesteps), dtype=torch.float32, device=self.device + ) + ] + forecast = torch.zeros( + (target.size(0), self.hparams.prediction_length), + dtype=torch.float32, + device=self.device, + ) + + backcast = target # initialize backcast + for i, block in enumerate(self.net_blocks): + # evaluate block + backcast_block, forecast_block = block(backcast) + + # add for interpretation + full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1) + if isinstance(block, NBEATSTrendBlock): + trend_forecast.append(full) + elif isinstance(block, NBEATSSeasonalBlock): + seasonal_forecast.append(full) + else: + generic_forecast.append(full) + + # update backcast and forecast + backcast = ( + backcast - backcast_block + ) # do not use backcast -= backcast_block as this signifies an inline operation # noqa : E501 + forecast = forecast + forecast_block + + return self.to_network_output( + prediction=self.transform_output(forecast, target_scale=x["target_scale"]), + backcast=self.transform_output( + prediction=target - backcast, target_scale=x["target_scale"] + ), + trend=self.transform_output( + torch.stack(trend_forecast, dim=0).sum(0), + target_scale=x["target_scale"], + ), + seasonality=self.transform_output( + torch.stack(seasonal_forecast, dim=0).sum(0), + target_scale=x["target_scale"], + ), + generic=self.transform_output( + torch.stack(generic_forecast, dim=0).sum(0), + target_scale=x["target_scale"], + ), + ) + + @classmethod + def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): + """ + Convenience function to create network from :py:class + `~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. + + Args: + dataset (TimeSeriesDataSet): dataset where sole predictor is the target. + **kwargs: additional arguments to be passed to ``__init__`` method. 
+ + Returns: + NBeats + """ # noqa: E501 + new_kwargs = { + "prediction_length": dataset.max_prediction_length, + "context_length": dataset.max_encoder_length, + } + new_kwargs.update(kwargs) + + # validate arguments + assert isinstance( + dataset.target, str + ), "only one target is allowed (passed as string to dataset)" + assert not isinstance( + dataset.target_normalizer, NaNLabelEncoder + ), "only regression tasks are supported - target must not be categorical" + assert dataset.min_encoder_length == dataset.max_encoder_length, ( + "only fixed encoder length is allowed," + " but min_encoder_length != max_encoder_length" + ) + + assert dataset.max_prediction_length == dataset.min_prediction_length, ( + "only fixed prediction length is allowed," + " but max_prediction_length != min_prediction_length" + ) + + assert ( + dataset.randomize_length is None + ), "length has to be fixed, but randomize_length is not None" + assert ( + not dataset.add_relative_time_idx + ), "add_relative_time_idx has to be False" + + assert ( + len(dataset.flat_categoricals) == 0 + and len(dataset.reals) == 1 + and len(dataset._time_varying_unknown_reals) == 1 + and dataset._time_varying_unknown_reals[0] == dataset.target + ), ( + "The only variable as input should be the" + " target which is part of time_varying_unknown_reals" + ) + + # initialize class + return super().from_dataset(dataset, **new_kwargs) + + def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]: + """ + Take training / validation step. + """ + log, out = super().step(x, y, batch_idx=batch_idx) + + if ( + self.hparams.backcast_loss_ratio > 0 and not self.predicting + ): # add loss from backcast + backcast = out["backcast"] + backcast_weight = ( + self.hparams.backcast_loss_ratio + * self.hparams.prediction_length + / self.hparams.context_length + ) + backcast_weight = backcast_weight / (backcast_weight + 1) # normalize + forecast_weight = 1 - backcast_weight + if isinstance(self.loss, MASE): + backcast_loss = ( + self.loss(backcast, x["encoder_target"], x["decoder_target"]) + * backcast_weight + ) + else: + backcast_loss = ( + self.loss(backcast, x["encoder_target"]) * backcast_weight + ) + label = ["val", "train"][self.training] + self.log( + f"{label}_backcast_loss", + backcast_loss, + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + self.log( + f"{label}_forecast_loss", + log["loss"], + on_epoch=True, + on_step=self.training, + batch_size=len(x["decoder_target"]), + ) + log["loss"] = log["loss"] * forecast_weight + backcast_loss + + self.log_interpretation(x, out, batch_idx=batch_idx) + return log, out + + def log_interpretation(self, x, out, batch_idx): + """ + Log interpretation of network predictions in tensorboard. 
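+
+        Figures are only logged when matplotlib is available, the logger supports
+        ``add_figure`` and ``log_interval`` is positive.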
+ """ + mpl_available = _check_matplotlib("log_interpretation", raise_error=False) + + # Don't log figures if matplotlib or add_figure is not available + if not mpl_available or not self._logger_supports("add_figure"): + return None + + label = ["val", "train"][self.training] + if self.log_interval > 0 and batch_idx % self.log_interval == 0: + fig = self.plot_interpretation(x, out, idx=0) + name = f"{label.capitalize()} interpretation of item 0 in " + if self.training: + name += f"step {self.global_step}" + else: + name += f"batch {batch_idx}" + self.logger.experiment.add_figure(name, fig, global_step=self.global_step) + + def plot_interpretation( + self, + x: Dict[str, torch.Tensor], + output: Dict[str, torch.Tensor], + idx: int, + ax=None, + plot_seasonality_and_generic_on_secondary_axis: bool = False, + ): + """ + Plot interpretation. + + Plot two pannels: prediction and backcast vs actuals and + decomposition of prediction into trend, seasonality and generic forecast. + + Args: + x (Dict[str, torch.Tensor]): network input + output (Dict[str, torch.Tensor]): network output + idx (int): index of sample for which to plot the interpretation. + ax (List[matplotlib axes], optional): list of two matplotlib axes onto which + to plot the interpretation. Defaults to None. + plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot + seasonality and generic forecast on secondary axis in second panel. + Defaults to False. + + Returns: + plt.Figure: matplotlib figure + """ # noqa: E501 + _check_matplotlib("plot_interpretation") + + import matplotlib.pyplot as plt + + if ax is None: + fig, ax = plt.subplots(2, 1, figsize=(6, 8)) + else: + fig = ax[0].get_figure() + + time = torch.arange( + -self.hparams.context_length, self.hparams.prediction_length + ) + + # plot target vs prediction + ax[0].plot( + time, + torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]]) + .detach() + .cpu(), + label="target", + ) + ax[0].plot( + time, + torch.cat( + [ + output["backcast"][idx].detach(), + output["prediction"][idx].detach(), + ], + dim=0, + ).cpu(), + label="prediction", + ) + ax[0].set_xlabel("Time") + + # plot blocks + prop_cycle = iter(plt.rcParams["axes.prop_cycle"]) + next(prop_cycle) # prediction + next(prop_cycle) # observations + if plot_seasonality_and_generic_on_secondary_axis: + ax2 = ax[1].twinx() + ax2.set_ylabel("Seasonality / Generic") + else: + ax2 = ax[1] + for title in ["trend", "seasonality", "generic"]: + if title not in self.hparams.stack_types: + continue + if title == "trend": + ax[1].plot( + time, + output[title][idx].detach().cpu(), + label=title.capitalize(), + c=next(prop_cycle)["color"], + ) + else: + ax2.plot( + time, + output[title][idx].detach().cpu(), + label=title.capitalize(), + c=next(prop_cycle)["color"], + ) + ax[1].set_xlabel("Time") + ax[1].set_ylabel("Decomposition") + + fig.legend() + return fig diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan.py b/pytorch_forecasting/models/nbeats/_nbeatskan.py new file mode 100644 index 000000000..9df6b3d2e --- /dev/null +++ b/pytorch_forecasting/models/nbeats/_nbeatskan.py @@ -0,0 +1,235 @@ +""" +N-Beats model with KAN blocks for timeseries forecasting without covariates. 
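+
+The architecture follows N-BEATS, but the fully connected layers inside each
+block are replaced by Kolmogorov-Arnold Network (KAN) layers whose spline grids
+can be refreshed during training via
+:py:class:`~pytorch_forecasting.models.nbeats.GridUpdateCallback`.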
+""" + +from typing import List, Optional + +import torch +from torch import nn + +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter +from pytorch_forecasting.models.nbeats.sub_modules import ( + NBEATSGenericBlock, + NBEATSSeasonalBlock, + NBEATSTrendBlock, +) + + +class NBeatsKAN(NBeatsAdapter): + def __init__( + self, + stack_types: Optional[List[str]] = None, + num_blocks: Optional[List[int]] = None, + num_block_layers: Optional[List[int]] = None, + widths: Optional[List[int]] = None, + sharing: Optional[List[bool]] = None, + expansion_coefficient_lengths: Optional[List[int]] = None, + prediction_length: int = 1, + context_length: int = 1, + dropout: float = 0.1, + learning_rate: float = 1e-2, + log_interval: int = -1, + log_gradient_flow: bool = False, + log_val_interval: int = None, + weight_decay: float = 1e-3, + loss: MultiHorizonMetric = None, + reduce_on_plateau_patience: int = 1000, + backcast_loss_ratio: float = 0.0, + logging_metrics: nn.ModuleList = None, + num: int = 5, + k: int = 3, + noise_scale: float = 0.5, + scale_base_mu: float = 0.0, + scale_base_sigma: float = 1.0, + scale_sp: float = 1.0, + base_fun: callable = None, + grid_eps: float = 0.02, + grid_range: List[int] = None, + sp_trainable: bool = True, + sb_trainable: bool = True, + sparse_init: bool = False, + **kwargs, + ): + """ + Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible. + + Based on the article + `N-BEATS: Neural basis expansion analysis for interpretable time series + forecasting `_. The network has (if + used as ensemble) outperformed all other methods including ensembles of + traditional statical methods in the M4 competition. The M4 competition is + arguably the most important benchmark for univariate time series forecasting. + + The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently + shown to consistently outperform N-BEATS. + + Args: + stack_types: One of the following values: “generic”, “seasonality" or + “trend". A list of strings of length 1 or 'num_stacks'. Default and + recommended value for generic mode: [“generic”] Recommended value for + interpretable mode: [“trend”,”seasonality”]. + num_blocks: The number of blocks per stack. A list of ints of length 1 or + 'num_stacks'. Default and recommended value for generic mode: [1] + Recommended value for interpretable mode: [3] + num_block_layers: Number of fully connected layers with ReLu activation per + block. + A list of ints of length 1 or 'num_stacks'. Default and recommended + value for generic mode: [4] Recommended value for interpretable mode: + [4]. + width: Widths of the fully connected layers with ReLu activation in the + blocks. A list of ints of length 1 or 'num_stacks'. Default and + recommended value for generic mode: [512]. Recommended value for + interpretable mode: [256, 2048] + sharing: Whether the weights are shared with the other blocks per stack. + A list of ints of length 1 or 'num_stacks'. Default and recommended + value for generic mode: [False]. Recommended value for interpretable + mode: [True]. + expansion_coefficient_length: If the type is “G” (generic), then the length + of the expansion coefficient. + If type is “T” (trend), then it corresponds to the degree of the + polynomial. + If the type is “S” (seasonal) then this is the minimum period allowed, + e.g. 2 for changes every timestep. A list of ints of length 1 or + 'num_stacks'. 
Default value for generic mode: [32]. Recommended value for
+                interpretable mode: [3].
+            prediction_length: Length of the prediction. Also known as 'horizon'.
+            context_length: Number of time units that condition the predictions.
+                Also known as 'lookback period'.
+                Should be between 1-10 times the prediction length.
+            backcast_loss_ratio: weight of backcast in comparison to forecast when
+                calculating the loss. A weight of 1.0 means that forecast and
+                backcast loss is weighted the same (regardless of backcast and forecast
+                lengths). Defaults to 0.0, i.e. no weight.
+            loss: loss to optimize. Defaults to MASE().
+            log_gradient_flow: whether to log gradient flow. This takes time and should
+                only be done to diagnose training failures.
+            reduce_on_plateau_patience (int): patience after which learning rate is
+                reduced by a factor of 10.
+            logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that
+                are logged during training. Defaults to
+                nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
+            num : Parameter for KAN layer. The number of grid intervals = G.
+                Default: 5.
+            k : Parameter for KAN layer. The order of piecewise polynomial. Default: 3.
+            noise_scale : Parameter for KAN layer. The scale of noise injected at
+                initialization. Default: 0.5.
+            scale_base_mu : Parameter for KAN layer. The scale of the residual
+                function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
+                Default: 0.0.
+            scale_base_sigma : Parameter for KAN layer. The scale of the residual
+                function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
+                Default: 1.0.
+            scale_sp : Parameter for KAN layer. The scale of the base function
+                spline(x). Default: 1.0.
+            base_fun : Parameter for KAN layer. Residual function b(x).
+                Default: None, which resolves to torch.nn.SiLU().
+            grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform;
+                when grid_eps = 0, the grid is partitioned using percentiles of samples.
+                0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
+            grid_range : Parameter for KAN layer. List/np.array of shape (2,) setting
+                the range of grids. Default: None, which resolves to [-1, 1].
+            sp_trainable : Parameter for KAN layer. If True, scale_sp is trainable.
+                Default: True.
+            sb_trainable : Parameter for KAN layer. If True, scale_base is trainable.
+                Default: True.
+            sparse_init : Parameter for KAN layer. If True, sparse initialization
+                is applied. Default: False.
+            **kwargs: additional arguments to :py:class:`~BaseModel`.
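+
+        Example:
+            A minimal usage sketch along the lines of ``examples/nbeats_with_kan.py``
+            (assumes ``training`` is a prepared
+            :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`)::
+
+                net = NBeatsKAN.from_dataset(
+                    training, learning_rate=3e-2, weight_decay=1e-2
+                )
+                trainer.fit(
+                    net,
+                    train_dataloaders=train_dataloader,
+                    val_dataloaders=val_dataloader,
+                )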
+ """ # noqa: E501 + + if base_fun is None: + base_fun = torch.nn.SiLU() + if grid_range is None: + grid_range = [-1, 1] + if expansion_coefficient_lengths is None: + expansion_coefficient_lengths = [3, 7] + if sharing is None: + sharing = [True, True] + if widths is None: + widths = [32, 512] + if num_block_layers is None: + num_block_layers = [3, 3] + if num_blocks is None: + num_blocks = [3, 3] + if stack_types is None: + stack_types = ["trend", "seasonality"] + if logging_metrics is None: + logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) + if loss is None: + loss = MASE() + + self.save_hyperparameters(ignore=["loss", "logging_metrics"]) + super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) + + # Bundle KAN parameters into a dictionary + kan_params = { + "num": num, + "k": k, + "noise_scale": noise_scale, + "scale_base_mu": scale_base_mu, + "scale_base_sigma": scale_base_sigma, + "scale_sp": scale_sp, + "base_fun": base_fun, + "grid_eps": grid_eps, + "grid_range": grid_range, + "sp_trainable": sp_trainable, + "sb_trainable": sb_trainable, + "sparse_init": sparse_init, + } + self.kan_params = kan_params + # setup stacks + self.net_blocks = nn.ModuleList() + for stack_id, stack_type in enumerate(stack_types): + for _ in range(num_blocks[stack_id]): + if stack_type == "generic": + net_block = NBEATSGenericBlock( + units=self.hparams.widths[stack_id], + thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id], + num_block_layers=self.hparams.num_block_layers[stack_id], + backcast_length=context_length, + forecast_length=prediction_length, + dropout=dropout, + kan_params=self.kan_params, + use_kan=True, + ) + elif stack_type == "seasonality": + net_block = NBEATSSeasonalBlock( + units=self.hparams.widths[stack_id], + num_block_layers=self.hparams.num_block_layers[stack_id], + backcast_length=context_length, + forecast_length=prediction_length, + min_period=expansion_coefficient_lengths[stack_id], + dropout=dropout, + kan_params=self.kan_params, + use_kan=True, + ) + elif stack_type == "trend": + net_block = NBEATSTrendBlock( + units=self.hparams.widths[stack_id], + thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id], + num_block_layers=self.hparams.num_block_layers[stack_id], + backcast_length=context_length, + forecast_length=prediction_length, + dropout=dropout, + kan_params=self.kan_params, + use_kan=True, + ) + else: + raise ValueError(f"Unknown stack type {stack_type}") + + self.net_blocks.append(net_block) + + def update_kan_grid(self): + """ + Updates grid of KAN layers when using KAN layers in NBEATSBlock. 
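+
+        Typically invoked by :py:class:`GridUpdateCallback` during training rather
+        than called directly; it relies on the activations cached in each block's
+        ``outputs`` attribute during the forward pass.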
+        """
+        for block in self.net_blocks:
+            # update logic taken from
+            # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682
+            for i, layer in enumerate(block.fc):
+                # update basis KAN layers' grid
+                layer.update_grid_from_samples(block.outputs[i])
+            # update theta backward and theta forward KAN layers' grid
+            block.theta_b_fc.update_grid_from_samples(block.outputs[i + 1])
+            block.theta_f_fc.update_grid_from_samples(block.outputs[i + 1])
diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py
index e815ecf40..e1ea1288f 100644
--- a/pytorch_forecasting/models/nbeats/sub_modules.py
+++ b/pytorch_forecasting/models/nbeats/sub_modules.py
@@ -9,8 +9,13 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from pytorch_forecasting.models.nbeats._kan_layer import KANLayer
+
 
 def linear(input_size, output_size, bias=True, dropout: int = None):
+    """
+    Initialize a linear layer for MLP block layers, optionally preceded by dropout.
+    """
     lin = nn.Linear(input_size, output_size, bias=bias)
     if dropout is not None:
         return nn.Sequential(nn.Dropout(dropout), lin)
@@ -21,6 +26,9 @@
 def linspace(
     backcast_length: int, forecast_length: int, centered: bool = False
 ) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Generate linearly spaced values for backcast and forecast.
+    """
     if centered:
         norm = max(backcast_length, forecast_length)
         start = -backcast_length
@@ -45,31 +53,111 @@
     def __init__(
         self,
         units,
         thetas_dim,
         num_block_layers=4,
         backcast_length=10,
         forecast_length=5,
-        share_thetas=False,
         dropout=0.1,
+        kan_params=None,
+        use_kan=False,
     ):
+        """
+        Initialize NBEATSBlock
+
+        Args:
+            units: The number of units in the mlp/kan layers.
+            thetas_dim: The dimension of the parameterized output for the block.
+            num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
+            backcast_length: The length of the backcast. Defines how many time units
+                from the past are used to predict the future. Default: 10.
+            forecast_length: The length of the forecast, i.e., the number of time steps
+                ahead to predict. Default: 5.
+            dropout: The dropout rate applied to the fully connected mlp layers to
+                prevent overfitting. Default: 0.1.
+            kan_params (dict): Parameters specific to the KAN layer
+                (used for modeling using KAN). Default: None.
+                Contains:
+                    num (int): The number of grid intervals for KAN.
+                    k (int): The order of the piecewise polynomial for KAN.
+                    noise_scale (float): The scale of noise injected at initialization.
+                    scale_base_mu (float): The scale of the residual function
+                        initialized to N(scale_base_mu, scale_base_sigma^2).
+                    scale_base_sigma (float): The scale of the residual function
+                        initialized to N(scale_base_mu, scale_base_sigma^2).
+                    scale_sp (float): The scale of the base function spline(x) in KAN.
+                    base_fun (function): The residual function used by
+                        KAN (e.g., torch.nn.SiLU()).
+                    grid_eps (float): Determines the partitioning of the grid. If 1,
+                        the grid is uniform; if 0, the grid is partitioned by
+                        percentiles.
+                    grid_range (list or np.array): The range of the grid, given as
+                        a list of two values.
+                    sp_trainable (bool): If True, the scale_sp is trainable.
+                    sb_trainable (bool): If True, the scale_base is trainable.
+                    sparse_init (bool): If True, applies sparse initialization.
+            use_kan: Whether to use KAN layers in the NBEATS block. If True, KAN
+                layers are used; otherwise MLP layers are used. Default: False.
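+
+        Note:
+            When ``use_kan=True``, the forward pass caches each layer's input in
+            ``self.outputs`` so that ``NBeatsKAN.update_kan_grid`` can refresh the
+            spline grids from recent samples.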
+ """ super().__init__() self.units = units self.thetas_dim = thetas_dim self.backcast_length = backcast_length self.forecast_length = forecast_length - self.share_thetas = share_thetas + self.kan_params = kan_params + self.use_kan = use_kan - fc_stack = [ - nn.Linear(backcast_length, units), - nn.ReLU(), - ] - for _ in range(num_block_layers - 1): - fc_stack.extend([linear(units, units, dropout=dropout), nn.ReLU()]) - self.fc = nn.Sequential(*fc_stack) + if self.use_kan: + layers = [ + KANLayer( + in_dim=backcast_length, + out_dim=units, + **self.kan_params, + ) + ] + + # Add additional layers for deeper structure + for _ in range(num_block_layers - 1): + layers.append( + KANLayer( + in_dim=units, + out_dim=units, + **self.kan_params, + ) + ) + + # Define the fully connected layers + self.fc = nn.Sequential(*layers) + + # Define the theta layers + self.theta_f_fc = self.theta_b_fc = KANLayer( + in_dim=units, + out_dim=thetas_dim, + **self.kan_params, + ) - if share_thetas: - self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) else: - self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) - self.theta_f_fc = nn.Linear(units, thetas_dim, bias=False) + fc_stack = [ + nn.Linear(backcast_length, units), + nn.ReLU(), + ] + for _ in range(num_block_layers - 1): + fc_stack.extend([linear(units, units, dropout=dropout), nn.ReLU()]) + self.fc = nn.Sequential(*fc_stack) + self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) def forward(self, x): + """ + Pass through the fully connected mlp/kan layers and returns the output. + """ + if self.use_kan: + # save outputs to be used in updating grid in kan layers during training + # outputs logic taken from + # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 + self.outputs = [] + self.outputs.append(x.clone().detach()) + for layer in self.fc: + x = layer(x) # Pass data through the current layer + # storing outputs for updating grids of self.fc when using KAN + self.outputs.append(x.clone().detach()) + # storing for updating grids of theta_b_fc and theta_f_fc when using KAN + self.outputs.append(x.clone().detach()) + return x # Return final output return self.fc(x) @@ -84,7 +172,50 @@ def __init__( nb_harmonics=None, min_period=1, dropout=0.1, + kan_params=None, + use_kan=False, ): + """ + Initialize NBeatsSeasonalBlock + + Args: + units: The number of units in the mlp/kan layers. + thetas_dim: The dimension of the parameterized output for the block. + If None, it is inferred. + num_block_layers: Number of fully connected mlp/kan layers. Default: 4. + backcast_length: The length of the backcast. Defines how many time units + from the past are used to predict the future. Default: 10. + forecast_length: The length of the forecast, i.e., the number of time steps + ahead to predict. Default: 5. + nb_harmonics: The number of harmonics in the seasonal function (relevant for + seasonal models). Default: None (no seasonality). + min_period: The minimum period used for seasonal patterns. Default: 1. + dropout: The dropout rate applied to the fully connected mlp layers to + prevent overfitting. Default: 0.1. + kan_params (dict): Parameters specific to the KAN layer + (used for modeling using KAN). Default: None. + Contains: + num_grids (int): The number of grid intervals for KAN. + k (int): The order of the piecewise polynomial for KAN. + noise_scale (float): The scale of noise injected at initialization. 
@@ -84,7 +172,50 @@ def __init__(
         nb_harmonics=None,
         min_period=1,
         dropout=0.1,
+        kan_params=None,
+        use_kan=False,
     ):
+        """
+        Initialize NBeatsSeasonalBlock.
+
+        Args:
+            units: The number of units in the MLP/KAN layers.
+            thetas_dim: The dimension of the parameterized output for the block.
+                If None, it is inferred from nb_harmonics or forecast_length.
+            num_block_layers: Number of fully connected MLP/KAN layers. Default: 4.
+            backcast_length: The length of the backcast, i.e., how many time units
+                from the past are used to predict the future. Default: 10.
+            forecast_length: The length of the forecast, i.e., the number of time
+                steps ahead to predict. Default: 5.
+            nb_harmonics: The number of harmonics in the seasonal function. If
+                provided, it overrides thetas_dim; otherwise thetas_dim falls back
+                to forecast_length. Default: None.
+            min_period: The minimum period used for seasonal patterns. Default: 1.
+            dropout: The dropout rate applied to the fully connected MLP layers to
+                prevent overfitting. Default: 0.1.
+            kan_params (dict): Parameters specific to the KAN layers
+                (used when modeling with KAN). Default: None.
+                Contains:
+                    num_grids (int): The number of grid intervals for KAN.
+                    k (int): The order of the piecewise polynomial for KAN.
+                    noise_scale (float): The scale of noise injected at initialization.
+                    scale_base_mu (float): The mean of the residual function
+                        initialization N(scale_base_mu, scale_base_sigma^2).
+                    scale_base_sigma (float): The standard deviation of the residual
+                        function initialization N(scale_base_mu, scale_base_sigma^2).
+                    scale_sp (float): The scale of the spline function spline(x) in KAN.
+                    base_fun (function): The residual function used by
+                        KAN (e.g., torch.nn.SiLU()).
+                    grid_eps (float): Determines the partitioning of the grid. If 1,
+                        the grid is uniform; if 0, it is partitioned by percentiles.
+                    grid_range (list or np.array): The range of the grid, given as
+                        a list of two values.
+                    sp_trainable (bool): If True, scale_sp is trainable.
+                    sb_trainable (bool): If True, scale_base is trainable.
+                    sparse_init (bool): If True, applies sparse initialization.
+            use_kan: Whether to use KAN layers instead of MLP layers inside the
+                N-BEATS block. Default: False.
+        """
         if nb_harmonics:
             thetas_dim = nb_harmonics
         else:
             thetas_dim = forecast_length
@@ -97,8 +228,9 @@
             num_block_layers=num_block_layers,
             backcast_length=backcast_length,
             forecast_length=forecast_length,
-            share_thetas=True,
             dropout=dropout,
+            kan_params=kan_params,
+            use_kan=use_kan,
         )
 
         backcast_linspace, forecast_linspace = linspace(
@@ -131,6 +263,9 @@
         self.register_buffer("S_forecast", torch.cat([s1_f, s2_f]))
 
     def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute the backcast and forecast outputs for the given input tensor.
+        """
         x = super().forward(x)
         amplitudes_backward = self.theta_b_fc(x)
         backcast = amplitudes_backward.mm(self.S_backcast)
@@ -140,6 +275,9 @@
         return backcast, forecast
 
     def get_frequencies(self, n):
+        """
+        Generate frequency values based on the backcast and forecast lengths.
+        """
         return np.linspace(
             0, (self.backcast_length + self.forecast_length) / self.min_period, n
         )
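The seasonal forward pass above reduces to a matrix product between predicted amplitudes and a fixed harmonic basis. A minimal numpy sketch of that idea, assuming the usual N-BEATS cos/sin harmonics (the exact basis construction lives in context lines not reproduced in this hunk):

    import numpy as np

    forecast_length, thetas_dim = 20, 8
    t = np.linspace(0, 1, forecast_length)  # forecast time grid
    freqs = np.arange(thetas_dim // 2)      # harmonic frequencies
    # rows are basis waveforms: cosine terms stacked on sine terms
    S = np.concatenate([
        np.cos(2 * np.pi * freqs[:, None] * t[None, :]),
        np.sin(2 * np.pi * freqs[:, None] * t[None, :]),
    ])
    theta = np.random.randn(4, thetas_dim)  # a batch of predicted amplitudes
    forecast = theta @ S                    # shape (4, forecast_length)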
@@ -154,15 +292,56 @@ def __init__(
         self,
         units,
         thetas_dim,
         num_block_layers=4,
         backcast_length=10,
         forecast_length=5,
         dropout=0.1,
+        kan_params=None,
+        use_kan=False,
     ):
+        """
+        Initialize NBeatsTrendBlock.
+
+        Args:
+            units: The number of units in the MLP/KAN layers.
+            thetas_dim: The dimension of the parameterized output for the block.
+            num_block_layers: Number of fully connected MLP/KAN layers. Default: 4.
+            backcast_length: The length of the backcast, i.e., how many time units
+                from the past are used to predict the future. Default: 10.
+            forecast_length: The length of the forecast, i.e., the number of time
+                steps ahead to predict. Default: 5.
+            dropout: The dropout rate applied to the fully connected MLP layers to
+                prevent overfitting. Default: 0.1.
+            kan_params (dict): Parameters specific to the KAN layers
+                (used when modeling with KAN). Default: None.
+                Contains:
+                    num_grids (int): The number of grid intervals for KAN.
+                    k (int): The order of the piecewise polynomial for KAN.
+                    noise_scale (float): The scale of noise injected at initialization.
+                    scale_base_mu (float): The mean of the residual function
+                        initialization N(scale_base_mu, scale_base_sigma^2).
+                    scale_base_sigma (float): The standard deviation of the residual
+                        function initialization N(scale_base_mu, scale_base_sigma^2).
+                    scale_sp (float): The scale of the spline function spline(x) in KAN.
+                    base_fun (function): The residual function used by
+                        KAN (e.g., torch.nn.SiLU()).
+                    grid_eps (float): Determines the partitioning of the grid. If 1,
+                        the grid is uniform; if 0, it is partitioned by percentiles.
+                    grid_range (list or np.array): The range of the grid, given as
+                        a list of two values.
+                    sp_trainable (bool): If True, scale_sp is trainable.
+                    sb_trainable (bool): If True, scale_base is trainable.
+                    sparse_init (bool): If True, applies sparse initialization.
+            use_kan: Whether to use KAN layers instead of MLP layers inside the
+                N-BEATS block. Default: False.
+        """
         super().__init__(
             units=units,
             thetas_dim=thetas_dim,
             num_block_layers=num_block_layers,
             backcast_length=backcast_length,
             forecast_length=forecast_length,
-            share_thetas=True,
             dropout=dropout,
+            kan_params=kan_params,
+            use_kan=use_kan,
         )
 
         backcast_linspace, forecast_linspace = linspace(
@@ -184,6 +363,9 @@
         self.register_buffer("T_forecast", coefficients * norm)
 
     def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute the backcast and forecast outputs for the given input tensor.
+        """
         x = super().forward(x)
         backcast = self.theta_b_fc(x).mm(self.T_backcast)
         forecast = self.theta_f_fc(x).mm(self.T_forecast)
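The trend block works the same way with a polynomial basis: each theta entry scales one power of the normalized time index, so theta.mm(T_forecast) evaluates a low-order polynomial per sample. A hedged numpy illustration (the basis construction again sits in context lines not shown here):

    import numpy as np

    forecast_length, thetas_dim = 20, 3
    t = np.linspace(0, 1, forecast_length)
    # rows are 1, t, t^2, ... as in the registered T buffers
    T = np.stack([t**i for i in range(thetas_dim)])
    theta = np.array([[1.0, -0.5, 0.25]])  # intercept, slope, curvature
    forecast = theta @ T                   # shape (1, forecast_length)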
+ """ x = super().forward(x) - theta_b = F.relu(self.theta_b_fc(x)) theta_f = F.relu(self.theta_f_fc(x)) - return self.backcast_fc(theta_b), self.forecast_fc(theta_f)