From 3bab673c3a39e27223bac9f150df8f673f7809f2 Mon Sep 17 00:00:00 2001
From: SohaibAhmed121
Date: Mon, 13 Jan 2025 11:39:18 +0500
Subject: [PATCH 1/8] Refactored NBeats and added comments for KAN block and
 NBeats.

---
 pytorch_forecasting/models/nbeats/__init__.py | 246 +++++++-
 .../models/nbeats/kan_layer.py                | 532 ++++++++++++++++++
 .../models/nbeats/sub_modules.py              | 224 +++++++-
 3 files changed, 987 insertions(+), 15 deletions(-)
 create mode 100644 pytorch_forecasting/models/nbeats/kan_layer.py

diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py
index 8d00392cc..114e4efac 100644
--- a/pytorch_forecasting/models/nbeats/__init__.py
+++ b/pytorch_forecasting/models/nbeats/__init__.py
@@ -8,7 +8,8 @@
 from torch import nn
 
 from pytorch_forecasting.data import TimeSeriesDataSet
-from pytorch_forecasting.data.encoders import NaNLabelEncoder
+
+# from pytorch_forecasting.data.encoders import NaNLabelEncoder
 from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric
 from pytorch_forecasting.models.base_model import BaseModel
 from pytorch_forecasting.models.nbeats.sub_modules import NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock
@@ -26,6 +27,19 @@ def __init__(
         expansion_coefficient_lengths: Optional[List[int]] = None,
         prediction_length: int = 1,
         context_length: int = 1,
+        use_kan: bool = False,
+        num_grids: int = 5,
+        k: int = 3,
+        noise_scale: float = 0.5,
+        scale_base_mu: float = 0.0,
+        scale_base_sigma: float = 1.0,
+        scale_sp: float = 1.0,
+        base_fun: callable = torch.nn.SiLU(),
+        grid_eps: float = 0.02,
+        grid_range: List[int] = [-1, 1],
+        sp_trainable: bool = True,
+        sb_trainable: bool = True,
+        sparse_init: bool = False,
         dropout: float = 0.1,
         learning_rate: float = 1e-2,
         log_interval: int = -1,
@@ -76,6 +90,25 @@ def __init__(
             prediction_length: Length of the prediction. Also known as 'horizon'.
             context_length: Number of time units that condition the predictions. Also known as 'lookback period'.
                 Should be between 1-10 times the prediction length.
+            use_kan : If True, KAN (Kolmogorov-Arnold Network) blocks are used instead of MLP blocks. Default: False.
+            num_grids : Parameter for KAN layer. The number of grid intervals = G. Default: 5.
+            k : Parameter for KAN layer. The order of the piecewise polynomial. Default: 3.
+            noise_scale : Parameter for KAN layer. The scale of noise injected at initialization. Default: 0.5.
+            scale_base_mu : Parameter for KAN layer. The scale of the residual function b(x) is initialized to be
+                N(scale_base_mu, scale_base_sigma^2). Default: 0.0
+            scale_base_sigma : Parameter for KAN layer. The scale of the residual function b(x) is initialized to be
+                N(scale_base_mu, scale_base_sigma^2). Default: 1.0
+            scale_sp : Parameter for KAN layer. The scale of the base function spline(x). Default: 1.0
+            base_fun : Parameter for KAN layer. Residual function b(x). Default: torch.nn.SiLU()
+            grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform; when grid_eps = 0,
+                the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the
+                two extremes. Default: 0.02
+            grid_range : Parameter for KAN layer. List/np.array of shape (2,) setting the range of grids.
+                Default: [-1,1].
+            sp_trainable : Parameter for KAN layer. If True, scale_sp is trainable. Default: True.
+            sb_trainable : Parameter for KAN layer. If True, scale_base is trainable. Default: True.
+            sparse_init : Parameter for KAN layer. If sparse_init = True, sparse initialization is applied.
+                Default: False.
            backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss.
                A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast
                and forecast lengths). Defaults to 0.0, i.e. no weight.
@@ -103,6 +135,23 @@ def __init__(
             logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
         if loss is None:
             loss = MASE()
+        # Bundle KAN parameters into a dictionary
+        self.kan_params = {
+            "use_kan": use_kan,
+            "num_grids": num_grids,
+            "k": k,
+            "noise_scale": noise_scale,
+            "scale_base_mu": scale_base_mu,
+            "scale_base_sigma": scale_base_sigma,
+            "scale_sp": scale_sp,
+            "base_fun": base_fun,
+            "grid_eps": grid_eps,
+            "grid_range": grid_range,
+            "sp_trainable": sp_trainable,
+            "sb_trainable": sb_trainable,
+            "sparse_init": sparse_init,
+        }
+
         self.save_hyperparameters()
         super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
 
@@ -118,6 +167,7 @@ def __init__(
                     backcast_length=context_length,
                     forecast_length=prediction_length,
                     dropout=self.hparams.dropout,
+                    kan_params=self.hparams.kan_params,
                 )
             elif stack_type == "seasonality":
                 net_block = NBEATSSeasonalBlock(
@@ -127,6 +177,7 @@ def __init__(
                     forecast_length=prediction_length,
                     min_period=self.hparams.expansion_coefficient_lengths[stack_id],
                     dropout=self.hparams.dropout,
+                    kan_params=self.hparams.kan_params,
                 )
             elif stack_type == "trend":
                 net_block = NBEATSTrendBlock(
@@ -136,6 +187,7 @@ def __init__(
                     backcast_length=context_length,
                     forecast_length=prediction_length,
                     dropout=self.hparams.dropout,
+                    kan_params=self.hparams.kan_params,
                 )
             else:
                 raise ValueError(f"Unknown stack type {stack_type}")
@@ -374,3 +426,195 @@ def plot_interpretation(
         fig.legend()
 
         return fig
+
+
+import warnings
+
+warnings.filterwarnings("ignore")
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks import EarlyStopping
+import pandas as pd
+import torch
+
+from pytorch_forecasting import TimeSeriesDataSet
+from pytorch_forecasting.data import NaNLabelEncoder
+from pytorch_forecasting.data.examples import generate_ar_data
+
+
+data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42)
+data["static"] = 2
+data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")
+data.head()
+
+# create dataset and dataloaders
+max_encoder_length = 60
+max_prediction_length = 20
+
+training_cutoff = data["time_idx"].max() - max_prediction_length
+
+context_length = max_encoder_length
+prediction_length = max_prediction_length
+
+training = TimeSeriesDataSet(
+    data[lambda x: x.time_idx <= training_cutoff],
+    time_idx="time_idx",
+    target="value",
+    categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
+    group_ids=["series"],
+    # only unknown variable is "value" - and N-Beats can also not take any additional variables
+    time_varying_unknown_reals=["value"],
+    max_encoder_length=context_length,
+    max_prediction_length=prediction_length,
+)
+
+validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training_cutoff + 1)
+batch_size = 128
+train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
+val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
+
+pl.seed_everything(42)
+trainer = pl.Trainer(accelerator="auto", gradient_clip_val=0.01)
+# net = NBeats.from_dataset(training, learning_rate=3e-2, weight_decay=1e-2, widths=[32, 512], backcast_loss_ratio=0.1)
+net = NBeats.from_dataset(
+    training,
+    learning_rate=1e-3,
+    log_interval=10,
+    log_val_interval=1,
+    weight_decay=1e-2,
+    widths=[32, 512],
+    backcast_loss_ratio=1.0,
+    num_block_layers=[3, 3],
+)
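+# A sketch (commented out, not exercised in this script): the same model with KAN
+# blocks enabled via the KAN keyword arguments defined in NBeats.__init__ above;
+# the hyperparameter values here are illustrative only.
+# net = NBeats.from_dataset(
+#     training,
+#     learning_rate=1e-3,
+#     widths=[32, 512],
+#     backcast_loss_ratio=1.0,
+#     use_kan=True,
+#     num_grids=5,
+#     k=3,
+#     grid_range=[-1, 1],
+# )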
+
+early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
+trainer = pl.Trainer(
+    max_epochs=1,
+    accelerator="auto",
+    enable_model_summary=True,
+    gradient_clip_val=0.1,
+    callbacks=[early_stop_callback],
+    limit_train_batches=150,
+)
+
+trainer.fit(
+    net,
+    train_dataloaders=train_dataloader,
+    val_dataloaders=val_dataloader,
+)
+
+best_model_path = trainer.checkpoint_callback.best_model_path
+best_model = NBeats.load_from_checkpoint(best_model_path)
+
+import matplotlib.pyplot as plt
+
+raw_predictions = best_model.predict(val_dataloader, mode="raw", return_x=True)
+
+for idx in range(10):  # plot 10 examples
+    figure = best_model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=idx, add_loss_to_title=True)
+    plt.show()
diff --git a/pytorch_forecasting/models/nbeats/kan_layer.py b/pytorch_forecasting/models/nbeats/kan_layer.py
new file mode 100644
index 000000000..13aa77802
--- /dev/null
+++ b/pytorch_forecasting/models/nbeats/kan_layer.py
@@ -0,0 +1,532 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def B_batch(x, grid, k=0, extend=True, device="cpu"):
+    """
+    evaluate x on B-spline bases
+
+    Args:
+    -----
+    x : 2D torch.tensor
+        inputs, shape (batch, in_dim)
+    grid : 2D torch.tensor
+        grids, shape (in_dim, number of grid points)
+    k : int
+        the piecewise polynomial order of splines.
+    extend : bool
+        If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True
+    device : str
+        device
+
+    Returns:
+    --------
+    spline values : 3D torch.tensor
+        shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order.
+
+    Example
+    -------
+    >>> from kan.spline import B_batch
+    >>> x = torch.rand(100,2)
+    >>> grid = torch.linspace(-1,1,steps=11)[None, :].expand(2, 11)
+    >>> B_batch(x, grid, k=3).shape
+    """
+
+    x = x.unsqueeze(dim=2)
+    grid = grid.unsqueeze(dim=0)
+
+    if k == 0:
+        value = (x >= grid[:, :, :-1]) * (x < grid[:, :, 1:])
+    else:
+        B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1)
+
+        value = (x - grid[:, :, : -(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, : -(k + 1)]) * B_km1[:, :, :-1] + (
+            grid[:, :, k + 1 :] - x
+        ) / (grid[:, :, k + 1 :] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:]
+
+    # in case grid is degenerate
+    value = torch.nan_to_num(value)
+    return value
+
+
+def coef2curve(x_eval, grid, coef, k, device="cpu"):
+    """
+    converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves
+    (summing up B_batch results over B-spline basis).
+
+    Args:
+    -----
+    x_eval : 2D torch.tensor
+        shape (batch, in_dim)
+    grid : 2D torch.tensor
+        shape (in_dim, G+2k). G: the number of grid intervals; k: spline order.
+    coef : 3D torch.tensor
+        shape (in_dim, out_dim, G+k)
+    k : int
+        the piecewise polynomial order of splines.
+    device : str
+        device
+
+    Returns:
+    --------
+    y_eval : 3D torch.tensor
+        shape (batch, in_dim, out_dim)
+
+    """
+
+    b_splines = B_batch(x_eval, grid, k=k)
+    y_eval = torch.einsum("ijk,jlk->ijl", b_splines, coef.to(b_splines.device))
+
+    return y_eval
+
+
+def curve2coef(x_eval, y_eval, grid, k):
+    """
+    converting B-spline curves to B-spline coefficients using least squares.
+
+    Args:
+    -----
+    x_eval : 2D torch.tensor
+        shape (batch, in_dim)
+    y_eval : 3D torch.tensor
+        shape (batch, in_dim, out_dim)
+    grid : 2D torch.tensor
+        shape (in_dim, G+2k)
+    k : int
+        spline order
+
+    Returns:
+    --------
+    coef : 3D torch.tensor
+        shape (in_dim, out_dim, G+k)
+    """
+    batch = x_eval.shape[0]
+    in_dim = x_eval.shape[1]
+    out_dim = y_eval.shape[2]
+    n_coef = grid.shape[1] - k - 1
+    mat = B_batch(x_eval, grid, k)
+    mat = mat.permute(1, 0, 2)[:, None, :, :].expand(in_dim, out_dim, batch, n_coef)
+    y_eval = y_eval.permute(1, 2, 0).unsqueeze(dim=3)
+    try:
+        coef = torch.linalg.lstsq(mat, y_eval).solution[:, :, :, 0]
+    except Exception as e:
+        print(f"lstsq failed with error: {e}")
+
+        # manual pseudo-inverse
+        """lamb=1e-8
+        XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat)
+        Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval)
+        n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2]
+        identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device)
+        A = XtX + lamb * identity
+        B = Xty
+        coef = (A.pinverse() @ B)[:,:,:,0]"""
+        # re-raise so `coef` is never returned undefined
+        raise
+
+    return coef
+
+
+def extend_grid(grid, k_extend=0):
+    """
+    extend grid
+    """
+    h = (grid[:, [-1]] - grid[:, [0]]) / (grid.shape[1] - 1)
+
+    for i in range(k_extend):
+        grid = torch.cat([grid[:, [0]] - h, grid], dim=1)
+        grid = torch.cat([grid, grid[:, [-1]] + h], dim=1)
+
+    return grid
+
+
+def sparse_mask(in_dim, out_dim):
+    """
+    get sparse mask
+    """
+    in_coord = torch.arange(in_dim) * 1 / in_dim + 1 / (2 * in_dim)
+    out_coord = torch.arange(out_dim) * 1 / out_dim + 1 / (2 * out_dim)
+
+    dist_mat = torch.abs(out_coord[:, None] - in_coord[None, :])
+    in_nearest = torch.argmin(dist_mat, dim=0)
+    in_connection = torch.stack([torch.arange(in_dim), in_nearest]).permute(1, 0)
+    out_nearest = torch.argmin(dist_mat, dim=1)
+    out_connection = torch.stack([out_nearest, torch.arange(out_dim)]).permute(1, 0)
+    all_connection = torch.cat([in_connection, out_connection], dim=0)
+    mask = torch.zeros(in_dim, out_dim)
+    mask[all_connection[:, 0], all_connection[:, 1]] = 1.0
+
+    return mask
+
+
+class KANLayer(nn.Module):
+    """
+    KANLayer class
+
+
+    Attributes:
+    -----------
+    in_dim: int
+        input dimension
+    out_dim: int
+        output dimension
+    num: int
+        the number of grid intervals
+    k: int
+        the piecewise polynomial order of splines
+    noise_scale: float
+        spline scale at initialization
+    coef: 2D torch.tensor
+        coefficients of B-spline bases
+    scale_base_mu: float
+        magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_mu
+    scale_base_sigma: float
+        magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_sigma
+    scale_sp: float
+        mangitude of the spline function spline(x)
+    base_fun: fun
+        residual function b(x)
+    mask: 1D torch.float
+        mask of spline functions. setting some element of the mask to zero means setting the
+        corresponding activation to zero function.
+    grid_eps: float in [0,1]
+        a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform;
+        when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1
+        interpolates between the two extremes.
+ the id of activation functions that are locked + """ + + def __init__( + self, + in_dim=3, + out_dim=2, + num=5, + k=3, + noise_scale=0.5, + scale_base_mu=0.0, + scale_base_sigma=1.0, + scale_sp=1.0, + base_fun=torch.nn.SiLU(), + grid_eps=0.02, + grid_range=[-1, 1], + sp_trainable=True, + sb_trainable=True, + device="cpu", + sparse_init=False, + ): + """' + initialize a KANLayer + + Args: + ----- + in_dim : int + input dimension. Default: 2. + out_dim : int + output dimension. Default: 3. + num : int + the number of grid intervals = G. Default: 5. + k : int + the order of piecewise polynomial. Default: 3. + noise_scale : float + the scale of noise injected at initialization. Default: 0.1. + scale_base_mu : float + the scale of the residual function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). + scale_base_sigma : float + the scale of the residual function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). + scale_sp : float + the scale of the base function spline(x). + base_fun : function + residual function b(x). Default: torch.nn.SiLU() + grid_eps : float + When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using + percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. + grid_range : list/np.array of shape (2,) + setting the range of grids. Default: [-1,1]. + sp_trainable : bool + If true, scale_sp is trainable + sb_trainable : bool + If true, scale_base is trainable + sparse_init : bool + if sparse_init = True, sparse initialization is applied. + + Returns: + -------- + self + + Example + ------- + >>> from kan.KANLayer import * + >>> model = KANLayer(in_dim=3, out_dim=5) + >>> (model.in_dim, model.out_dim) + """ + super(KANLayer, self).__init__() + # size + self.out_dim = out_dim + self.in_dim = in_dim + self.num = num + self.k = k + + grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None, :].expand(self.in_dim, num + 1) + grid = extend_grid(grid, k_extend=k) + self.grid = torch.nn.Parameter(grid).requires_grad_(False) + noises = (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2) * noise_scale / num + + self.coef = torch.nn.Parameter(curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k)) + + if sparse_init: + self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(False) + else: + self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(False) + + self.scale_base = torch.nn.Parameter( + scale_base_mu * 1 / np.sqrt(in_dim) + + scale_base_sigma * (torch.rand(in_dim, out_dim) * 2 - 1) * 1 / np.sqrt(in_dim) + ).requires_grad_(sb_trainable) + self.scale_sp = torch.nn.Parameter( + torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask + ).requires_grad_( + sp_trainable + ) # make scale trainable + self.base_fun = base_fun + + self.grid_eps = grid_eps + + def forward(self, x): + """ + KANLayer forward given input x + + Args: + ----- + x : 2D torch.float + inputs, shape (number of samples, input dimension) + + Returns: + -------- + y : 2D torch.float + outputs, shape (number of samples, output dimension) + preacts : 3D torch.float + fan out x into activations, shape (number of sampels, output dimension, input dimension) + postacts : 3D torch.float + the outputs of activation functions with preacts as inputs + postspline : 3D torch.float + the outputs of spline functions with preacts as inputs + + Example + ------- + >>> from kan.KANLayer import * + >>> model = KANLayer(in_dim=3, out_dim=5) + >>> x = 
torch.normal(0,1,size=(100,3)) + >>> y, preacts, postacts, postspline = model(x) + >>> y.shape, preacts.shape, postacts.shape, postspline.shape + """ + + base = self.base_fun(x) # (batch, in_dim) + y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k) + y = self.scale_base[None, :, :] * base[:, :, None] + self.scale_sp[None, :, :] * y + y = self.mask[None, :, :] * y + y = torch.sum(y, dim=1) + return y + + def update_grid_from_samples(self, x, mode="sample"): + """ + update grid from samples + + Args: + ----- + x : 2D torch.float + inputs, shape (number of samples, input dimension) + + Returns: + -------- + None + + Example + ------- + >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) + >>> print(model.grid.data) + >>> x = torch.linspace(-3,3,steps=100)[:,None] + >>> model.update_grid_from_samples(x) + >>> print(model.grid.data) + """ + + batch = x.shape[0] + # x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size) + # .permute(1, 0) + x_pos = torch.sort(x, dim=0)[0] + y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) + num_interval = self.grid.shape[1] - 1 - 2 * self.k + + def get_grid(num_interval): + ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] + grid_adaptive = x_pos[ids, :].permute(1, 0) + margin = 0.00 + h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_interval + grid_uniform = ( + grid_adaptive[:, [0]] + - margin + + h + * torch.arange( + num_interval + 1, + )[ + None, : + ].to(x.device) + ) + grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive + return grid + + grid = get_grid(num_interval) + + if mode == "grid": + sample_grid = get_grid(2 * num_interval) + x_pos = sample_grid.permute(1, 0) + y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) + + self.grid.data = extend_grid(grid, k_extend=self.k) + # print('x_pos 2', x_pos.shape) + # print('y_eval 2', y_eval.shape) + self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) + + def initialize_grid_from_parent(self, parent, x, mode="sample"): + """ + update grid from a parent KANLayer & samples + + Args: + ----- + parent : KANLayer + a parent KANLayer (whose grid is usually coarser than the current model) + x : 2D torch.float + inputs, shape (number of samples, input dimension) + + Returns: + -------- + None + + Example + ------- + >>> batch = 100 + >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) + >>> print(parent_model.grid.data) + >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3) + >>> x = torch.normal(0,1,size=(batch, 1)) + >>> model.initialize_grid_from_parent(parent_model, x) + >>> print(model.grid.data) + """ + # shrink grid + x_pos = torch.sort(x, dim=0)[0] + y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k) + num_interval = self.grid.shape[1] - 1 - 2 * self.k + + """ + # based on samples + def get_grid(num_interval): + ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] + grid_adaptive = x_pos[ids, :].permute(1,0) + h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval + grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device) + grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive + return grid""" + + # based on interpolating parent grid + def get_grid(num_interval): + x_pos = parent.grid[:, parent.k : -parent.k] + # print('x_pos', x_pos) + sp2 = KANLayer( + in_dim=1, out_dim=self.in_dim, k=1, num=x_pos.shape[1] - 1, scale_base_mu=0.0, 
scale_base_sigma=0.0
+            ).to(x.device)
+
+            # print('sp2_grid', sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim))
+            # print('sp2_coef_shape', sp2.coef.shape)
+            sp2_coef = curve2coef(
+                sp2.grid[:, sp2.k : -sp2.k].permute(1, 0).expand(-1, self.in_dim),
+                x_pos.permute(1, 0).unsqueeze(dim=2),
+                sp2.grid[:, :],
+                k=1,
+            ).permute(1, 0, 2)
+            sp2.coef.data = sp2_coef
+            percentile = torch.linspace(-1, 1, self.num + 1).to(self.device)
+            grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)
+            return grid
+
+        grid = get_grid(num_interval)
+
+        if mode == "grid":
+            sample_grid = get_grid(2 * num_interval)
+            x_pos = sample_grid.permute(1, 0)
+            y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
+
+        grid = extend_grid(grid, k_extend=self.k)
+        self.grid.data = grid
+        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
+
+    def get_subset(self, in_id, out_id):
+        """
+        get a smaller KANLayer from a larger KANLayer (used for pruning)
+
+        Args:
+        -----
+        in_id : list
+            id of selected input neurons
+        out_id : list
+            id of selected output neurons
+
+        Returns:
+        --------
+        spb : KANLayer
+
+        Example
+        -------
+        >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3)
+        >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3])
+        >>> kanlayer_small.in_dim, kanlayer_small.out_dim
+        (2, 3)
+        """
+        spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun)
+        spb.grid.data = self.grid[in_id]
+        spb.coef.data = self.coef[in_id][:, out_id]
+        spb.scale_base.data = self.scale_base[in_id][:, out_id]
+        spb.scale_sp.data = self.scale_sp[in_id][:, out_id]
+        spb.mask.data = self.mask[in_id][:, out_id]
+
+        spb.in_dim = len(in_id)
+        spb.out_dim = len(out_id)
+        return spb
+
+    def swap(self, i1, i2, mode="in"):
+        """
+        swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out')
+
+        Args:
+        -----
+        i1 : int
+        i2 : int
+        mode : str
+            mode = 'in' or 'out'
+
+        Returns:
+        --------
+        None
+
+        Example
+        -------
+        >>> from kan.KANLayer import *
+        >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3)
+        >>> print(model.coef)
+        >>> model.swap(0,1,mode='in')
+        >>> print(model.coef)
+        """
+        with torch.no_grad():
+
+            def swap_(data, i1, i2, mode="in"):
+                if mode == "in":
+                    data[i1], data[i2] = data[i2].clone(), data[i1].clone()
+                elif mode == "out":
+                    data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone()
+
+            if mode == "in":
+                swap_(self.grid.data, i1, i2, mode="in")
+            swap_(self.coef.data, i1, i2, mode=mode)
+            swap_(self.scale_base.data, i1, i2, mode=mode)
+            swap_(self.scale_sp.data, i1, i2, mode=mode)
+            swap_(self.mask.data, i1, i2, mode=mode)
diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py
index e300d452f..f78ba1ac4 100644
--- a/pytorch_forecasting/models/nbeats/sub_modules.py
+++ b/pytorch_forecasting/models/nbeats/sub_modules.py
@@ -8,6 +8,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from kan_layer import KANLayer
 
 
 def linear(input_size, output_size, bias=True, dropout: int = None):
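For orientation before the sub_modules.py changes: a minimal smoke test of the KANLayer added above. This is a sketch, not part of the patch — it relies only on the constructor and forward signature shown in kan_layer.py, and uses the package import path that PATCH 3/8 below settles on.

    import torch

    from pytorch_forecasting.models.nbeats.kan_layer import KANLayer

    # One KAN layer mapping 3 inputs to 5 outputs: G=5 grid intervals, order-3 splines.
    layer = KANLayer(in_dim=3, out_dim=5, num=5, k=3)
    x = torch.normal(0, 1, size=(100, 3))  # (batch, in_dim)
    y = layer(x)  # spline + residual activations, summed over in_dim
    print(y.shape)  # torch.Size([100, 5])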
@@ -43,6 +103,7 @@ def __init__(
         forecast_length=5,
         share_thetas=False,
         dropout=0.1,
+        kan_params={},
     ):
         super().__init__()
         self.units = units
@@ -50,6 +111,52 @@ def __init__(
         self.backcast_length = backcast_length
         self.forecast_length = forecast_length
         self.share_thetas = share_thetas
+        self.kan_params = kan_params
+
+        if self.kan_params["use_kan"]:
+            layers = [
+                KANLayer(
+                    in_dim=backcast_length,
+                    out_dim=units,
+                    num=self.kan_params["num_grids"],
+                    k=self.kan_params["k"],
+                    noise_scale=self.kan_params["noise_scale"],
+                    scale_base_mu=self.kan_params["scale_base_mu"],
+                    scale_base_sigma=self.kan_params["scale_base_sigma"],
+                    scale_sp=self.kan_params["scale_sp"],
+                    base_fun=self.kan_params["base_fun"],
+                    grid_eps=self.kan_params["grid_eps"],
+                    grid_range=self.kan_params["grid_range"],
+                    sp_trainable=self.kan_params["sp_trainable"],
+                    sb_trainable=self.kan_params["sb_trainable"],
+                    sparse_init=self.kan_params["sparse_init"],
+                )
+            ]
+            # Additional KANLayers for deeper structure
+            for _ in range(num_block_layers - 1):
+                layers.extend(
+                    [
+                        KANLayer(
+                            in_dim=units,
+                            out_dim=units,
+                            num=self.kan_params["num_grids"],
+                            k=self.kan_params["k"],
+                            noise_scale=self.kan_params["noise_scale"],
+                            scale_base_mu=self.kan_params["scale_base_mu"],
+                            scale_base_sigma=self.kan_params["scale_base_sigma"],
+                            scale_sp=self.kan_params["scale_sp"],
+                            base_fun=self.kan_params["base_fun"],
+                            grid_eps=self.kan_params["grid_eps"],
+                            grid_range=self.kan_params["grid_range"],
+                            sp_trainable=self.kan_params["sp_trainable"],
+                            sb_trainable=self.kan_params["sb_trainable"],
+                            sparse_init=self.kan_params["sparse_init"],
+                            device="cpu",  # Assuming you are using the "cpu" device
+                        )
+                    ]
+                )
+
+            self.fc = nn.Sequential(*layers)
 
         fc_stack = [
             nn.Linear(backcast_length, units),
@@ -80,7 +187,41 @@ def __init__(
         nb_harmonics=None,
         min_period=1,
         dropout=0.1,
+        kan_params={},
     ):
+        """
+        Initialize NBeatsSeasonalBlock
+
+        Args:
+            units: The number of units in the mlp/kan layers. Default: 256.
+            thetas_dim: The dimension of the parameterized output for the block. If None, it is inferred. Default: None.
+            num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
+            backcast_length: The length of the backcast. Defines how many time units from the past are used to
+                predict the future. Default: 10.
+            forecast_length: The length of the forecast, i.e., the number of time steps ahead to predict. Default: 5.
+            nb_harmonics: The number of harmonics in the seasonal function (relevant for seasonal models).
+                Default: None (no seasonality).
+            min_period: The minimum period used for seasonal patterns. Default: 1.
+            dropout: The dropout rate applied to the fully connected mlp layers to prevent overfitting. Default: 0.1.
+            kan_params (dict): Parameters specific to the KAN layer (used for modeling using KAN).
+                Default: empty dictionary.
+                Contains:
+                    num_grids (int): The number of grid intervals for KAN.
+                    k (int): The order of the piecewise polynomial for KAN.
+                    noise_scale (float): The scale of noise injected at initialization.
+                    scale_base_mu (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_base_sigma (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_sp (float): The scale of the base function spline(x) in KAN.
+                    base_fun (function): The residual function used by KAN (e.g., torch.nn.SiLU()).
+                    grid_eps (float): Determines the partitioning of the grid. If 1, the grid is uniform; if 0,
+                        grid is partitioned by percentiles.
+                    grid_range (list or np.array): The range of the grid, given as a list of two values.
+                    sp_trainable (bool): If True, the scale_sp is trainable.
+                    sb_trainable (bool): If True, the scale_base is trainable.
+                    sparse_init (bool): If True, applies sparse initialization.
+        """
         if nb_harmonics:
             thetas_dim = nb_harmonics
         else:
@@ -95,6 +236,7 @@ def __init__(
             forecast_length=forecast_length,
             share_thetas=True,
             dropout=dropout,
+            kan_params=kan_params,
         )
 
         backcast_linspace, forecast_linspace = linspace(backcast_length, forecast_length, centered=False)
@@ -117,6 +259,7 @@ def __init__(
         self.register_buffer("S_forecast", torch.cat([s1_f, s2_f]))
 
     def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Computes the backcast and forecast outputs for the given input tensor."""
         x = super().forward(x)
         amplitudes_backward = self.theta_b_fc(x)
         backcast = amplitudes_backward.mm(self.S_backcast)
@@ -126,19 +269,46 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
         return backcast, forecast
 
     def get_frequencies(self, n):
+        """
+        Generates frequency values based on the backcast and forecast lengths.
+        """
         return np.linspace(0, (self.backcast_length + self.forecast_length) / self.min_period, n)
 
 
 class NBEATSTrendBlock(NBEATSBlock):
     def __init__(
-        self,
-        units,
-        thetas_dim,
-        num_block_layers=4,
-        backcast_length=10,
-        forecast_length=5,
-        dropout=0.1,
+        self, units, thetas_dim, num_block_layers=4, backcast_length=10, forecast_length=5, dropout=0.1, kan_params={}
     ):
+        """
+        Initialize NBeatsTrendBlock
+
+        Args:
+            units: The number of units in the mlp/kan layers. Default: 256.
+            thetas_dim: The dimension of the parameterized output for the block. If None, it is inferred. Default: None.
+            num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
+            backcast_length: The length of the backcast. Defines how many time units from the past are used to
+                predict the future. Default: 10.
+            forecast_length: The length of the forecast, i.e., the number of time steps ahead to predict. Default: 5.
+            dropout: The dropout rate applied to the fully connected mlp layers to prevent overfitting. Default: 0.1.
+            kan_params (dict): Parameters specific to the KAN layer (used for modeling using KAN).
+                Default: empty dictionary.
+                Contains:
+                    num_grids (int): The number of grid intervals for KAN.
+                    k (int): The order of the piecewise polynomial for KAN.
+                    noise_scale (float): The scale of noise injected at initialization.
+                    scale_base_mu (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_base_sigma (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_sp (float): The scale of the base function spline(x) in KAN.
+                    base_fun (function): The residual function used by KAN (e.g., torch.nn.SiLU()).
+                    grid_eps (float): Determines the partitioning of the grid. If 1, the grid is uniform; if 0,
+                        grid is partitioned by percentiles.
+                    grid_range (list or np.array): The range of the grid, given as a list of two values.
+                    sp_trainable (bool): If True, the scale_sp is trainable.
+                    sb_trainable (bool): If True, the scale_base is trainable.
+                    sparse_init (bool): If True, applies sparse initialization.
+        """
         super().__init__(
             units=units,
             thetas_dim=thetas_dim,
@@ -147,6 +317,7 @@ def __init__(
             forecast_length=forecast_length,
             share_thetas=True,
             dropout=dropout,
+            kan_params=kan_params,
         )
 
         backcast_linspace, forecast_linspace = linspace(backcast_length, forecast_length, centered=True)
@@ -167,14 +338,38 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
 
 class NBEATSGenericBlock(NBEATSBlock):
     def __init__(
-        self,
-        units,
-        thetas_dim,
-        num_block_layers=4,
-        backcast_length=10,
-        forecast_length=5,
-        dropout=0.1,
+        self, units, thetas_dim, num_block_layers=4, backcast_length=10, forecast_length=5, dropout=0.1, kan_params={}
     ):
+        """
+        Initialize NBeatsGenericBlock
+
+        Args:
+            units: The number of units in the mlp/kan layers. Default: 256.
+            thetas_dim: The dimension of the parameterized output for the block. If None, it is inferred. Default: None.
+            num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
+            backcast_length: The length of the backcast. Defines how many time units from the past are used to
+                predict the future. Default: 10.
+            forecast_length: The length of the forecast, i.e., the number of time steps ahead to predict. Default: 5.
+            dropout: The dropout rate applied to the fully connected mlp layers to prevent overfitting. Default: 0.1.
+            kan_params (dict): Parameters specific to the KAN layer (used for modeling using KAN).
+                Default: empty dictionary.
+                Contains:
+                    num_grids (int): The number of grid intervals for KAN.
+                    k (int): The order of the piecewise polynomial for KAN.
+                    noise_scale (float): The scale of noise injected at initialization.
+                    scale_base_mu (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_base_sigma (float): The scale of the residual function initialized to
+                        N(scale_base_mu, scale_base_sigma^2).
+                    scale_sp (float): The scale of the base function spline(x) in KAN.
+                    base_fun (function): The residual function used by KAN (e.g., torch.nn.SiLU()).
+                    grid_eps (float): Determines the partitioning of the grid. If 1, the grid is uniform; if 0,
+                        grid is partitioned by percentiles.
+                    grid_range (list or np.array): The range of the grid, given as a list of two values.
+                    sp_trainable (bool): If True, the scale_sp is trainable.
+                    sb_trainable (bool): If True, the scale_base is trainable.
+                    sparse_init (bool): If True, applies sparse initialization.
+        """
         super().__init__(
             units=units,
             thetas_dim=thetas_dim,
@@ -182,6 +377,7 @@ def __init__(
            backcast_length=backcast_length,
            forecast_length=forecast_length,
            dropout=dropout,
+            kan_params=kan_params,
         )
 
         self.backcast_fc = nn.Linear(thetas_dim, backcast_length)

From 41d74039d9f1f1db5556ad05eb9cd128b0970a12 Mon Sep 17 00:00:00 2001
From: SohaibAhmed121
Date: Mon, 13 Jan 2025 15:17:16 +0500
Subject: [PATCH 2/8] Integrated Kolmogorov-Arnold Networks end to end in
 NBeats. Also refactored NBeats.

---
 pytorch_forecasting/models/nbeats/_nbeats.py |   8 +-
 .../models/nbeats/kan_layer.py               | 146 ++++++++---------
 .../models/nbeats/sub_modules.py             | 152 ++++++++++++------
 3 files changed, 180 insertions(+), 126 deletions(-)

diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py
index d390bc3b1..872e244b1 100644
--- a/pytorch_forecasting/models/nbeats/_nbeats.py
+++ b/pytorch_forecasting/models/nbeats/_nbeats.py
@@ -175,7 +175,7 @@ def __init__(
             "sparse_init": sparse_init,
         }
 
-        self.save_hyperparameters()
+        self.save_hyperparameters(ignore=["loss", "logging_metrics"])
         super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
 
         # setup stacks
@@ -190,7 +190,7 @@ def __init__(
                     backcast_length=context_length,
                     forecast_length=prediction_length,
                     dropout=self.hparams.dropout,
-                    kan_params=self.hparams.kan_params,
+                    kan_params=self.kan_params,
                 )
             elif stack_type == "seasonality":
                 net_block = NBEATSSeasonalBlock(
@@ -200,7 +200,7 @@ def __init__(
                     forecast_length=prediction_length,
                     min_period=self.hparams.expansion_coefficient_lengths[stack_id],
                     dropout=self.hparams.dropout,
-                    kan_params=self.hparams.kan_params,
+                    kan_params=self.kan_params,
                 )
             elif stack_type == "trend":
                 net_block = NBEATSTrendBlock(
@@ -210,7 +210,7 @@ def __init__(
                     backcast_length=context_length,
                     forecast_length=prediction_length,
                     dropout=self.hparams.dropout,
-                    kan_params=self.hparams.kan_params,
+                    kan_params=self.kan_params,
                 )
             else:
                 raise ValueError(f"Unknown stack type {stack_type}")
diff --git a/pytorch_forecasting/models/nbeats/kan_layer.py b/pytorch_forecasting/models/nbeats/kan_layer.py
index 13aa77802..7b5d991df 100644
--- a/pytorch_forecasting/models/nbeats/kan_layer.py
+++ b/pytorch_forecasting/models/nbeats/kan_layer.py
@@ -1,9 +1,9 @@
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 
 
-def B_batch(x, grid, k=0, extend=True, device="cpu"):
+def B_batch(x, grid, k=0, extend=True):
     """
     evaluate x on B-spline bases
 
@@ -16,14 +16,14 @@ def B_batch(x, grid, k=0, extend=True, device="cpu"):
     k : int
         the piecewise polynomial order of splines.
     extend : bool
-        If True, k points are extended on both ends. If False, no extension (zero boundary condition). Default: True
-    device : str
-        device
+        If True, k points are extended on both ends. If False, no extension
+        (zero boundary condition). Default: True
 
     Returns:
     --------
     spline values : 3D torch.tensor
-        shape (batch, in_dim, G+k). G: the number of grid intervals, k: spline order.
+        shape (batch, in_dim, G+k). G: the number of grid intervals,
+        k: spline order.
Example ------- @@ -41,16 +41,20 @@ def B_batch(x, grid, k=0, extend=True, device="cpu"): else: B_km1 = B_batch(x[:, :, 0], grid=grid[0], k=k - 1) - value = (x - grid[:, :, : -(k + 1)]) / (grid[:, :, k:-1] - grid[:, :, : -(k + 1)]) * B_km1[:, :, :-1] + ( - grid[:, :, k + 1 :] - x - ) / (grid[:, :, k + 1 :] - grid[:, :, 1:(-k)]) * B_km1[:, :, 1:] + value = (x - grid[:, :, : -(k + 1)]) / ( + grid[:, :, k:-1] - grid[:, :, : -(k + 1)] + ) * B_km1[:, :, :-1] + (grid[:, :, k + 1 :] - x) / ( + grid[:, :, k + 1 :] - grid[:, :, 1:(-k)] + ) * B_km1[ + :, :, 1: + ] # in case grid is degenerate value = torch.nan_to_num(value) return value -def coef2curve(x_eval, grid, coef, k, device="cpu"): +def coef2curve(x_eval, grid, coef, k): """ converting B-spline coefficients to B-spline curves. Evaluate x on B-spline curves (summing up B_batch results over B-spline basis). @@ -65,8 +69,6 @@ def coef2curve(x_eval, grid, coef, k, device="cpu"): shape (in_dim, out_dim, G+k) k : int the piecewise polynomial order of splines. - device : str - devicde Returns: -------- @@ -76,7 +78,7 @@ def coef2curve(x_eval, grid, coef, k, device="cpu"): """ b_splines = B_batch(x_eval, grid, k=k) - y_eval = torch.einsum("ijk,jlk->ijl", b_splines, coef.to(b_splines.device)) + y_eval = torch.einsum("ijk,jlk->ijl", b_splines, coef.to(b_splines)) return y_eval @@ -121,7 +123,7 @@ def curve2coef(x_eval, y_eval, grid, k): XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat) Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval) n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2] - identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device) + identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n) A = XtX + lamb * identity B = Xty coef = (A.pinverse() @ B)[:,:,:,0]""" @@ -164,38 +166,6 @@ def sparse_mask(in_dim, out_dim): class KANLayer(nn.Module): """ KANLayer class - - - Attributes: - ----------- - in_dim: int - input dimension - out_dim: int - output dimension - num: int - the number of grid intervals - k: int - the piecewise polynomial order of splines - noise_scale: float - spline scale at initialization - coef: 2D torch.tensor - coefficients of B-spline bases - scale_base_mu: float - magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_mu - scale_base_sigma: float - magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_sigma - scale_sp: float - mangitude of the spline function spline(x) - base_fun: fun - residual function b(x) - mask: 1D torch.float - mask of spline functions. setting some element of the mask to zero means setting the - corresponding activation to zero function. - grid_eps: float in [0,1] - a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform; - when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 - interpolates between the two extremes. - the id of activation functions that are locked """ def __init__( @@ -213,7 +183,6 @@ def __init__( grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, - device="cpu", sparse_init=False, ): """' @@ -232,16 +201,19 @@ def __init__( noise_scale : float the scale of noise injected at initialization. Default: 0.1. scale_base_mu : float - the scale of the residual function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). + the scale of the residual function b(x) is intialized to be + N(scale_base_mu, scale_base_sigma^2). 
scale_base_sigma : float - the scale of the residual function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). + the scale of the residual function b(x) is intialized to be + N(scale_base_mu, scale_base_sigma^2). scale_sp : float the scale of the base function spline(x). base_fun : function residual function b(x). Default: torch.nn.SiLU() grid_eps : float - When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using - percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes. + When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is + partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates + between the two extremes. grid_range : list/np.array of shape (2,) setting the range of grids. Default: [-1,1]. sp_trainable : bool @@ -268,21 +240,36 @@ def __init__( self.num = num self.k = k - grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None, :].expand(self.in_dim, num + 1) + grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[ + None, : + ].expand(self.in_dim, num + 1) grid = extend_grid(grid, k_extend=k) self.grid = torch.nn.Parameter(grid).requires_grad_(False) - noises = (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2) * noise_scale / num + noises = ( + (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2) + * noise_scale + / num + ) - self.coef = torch.nn.Parameter(curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k)) + self.coef = torch.nn.Parameter( + curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k) + ) if sparse_init: - self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(False) + self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_( + False + ) else: - self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(False) + self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_( + False + ) self.scale_base = torch.nn.Parameter( scale_base_mu * 1 / np.sqrt(in_dim) - + scale_base_sigma * (torch.rand(in_dim, out_dim) * 2 - 1) * 1 / np.sqrt(in_dim) + + scale_base_sigma + * (torch.rand(in_dim, out_dim) * 2 - 1) + * 1 + / np.sqrt(in_dim) ).requires_grad_(sb_trainable) self.scale_sp = torch.nn.Parameter( torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask @@ -307,7 +294,8 @@ def forward(self, x): y : 2D torch.float outputs, shape (number of samples, output dimension) preacts : 3D torch.float - fan out x into activations, shape (number of sampels, output dimension, input dimension) + fan out x into activations, shape (number of sampels, + output dimension, input dimension) postacts : 3D torch.float the outputs of activation functions with preacts as inputs postspline : 3D torch.float @@ -324,7 +312,10 @@ def forward(self, x): base = self.base_fun(x) # (batch, in_dim) y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k) - y = self.scale_base[None, :, :] * base[:, :, None] + self.scale_sp[None, :, :] * y + y = ( + self.scale_base[None, :, :] * base[:, :, None] + + self.scale_sp[None, :, :] * y + ) y = self.mask[None, :, :] * y y = torch.sum(y, dim=1) return y @@ -352,8 +343,8 @@ def update_grid_from_samples(self, x, mode="sample"): """ batch = x.shape[0] - # x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size) - # .permute(1, 0) + # x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, )) + # .reshape(batch, self.size).permute(1, 0) x_pos = torch.sort(x, 
dim=0)[0] y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) num_interval = self.grid.shape[1] - 1 - 2 * self.k @@ -362,16 +353,16 @@ def get_grid(num_interval): ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] grid_adaptive = x_pos[ids, :].permute(1, 0) margin = 0.00 - h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin) / num_interval + h = ( + grid_adaptive[:, [-1]] - grid_adaptive[:, [0]] + 2 * margin + ) / num_interval grid_uniform = ( grid_adaptive[:, [0]] - margin + h * torch.arange( num_interval + 1, - )[ - None, : - ].to(x.device) + )[None, :] ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive return grid @@ -384,8 +375,6 @@ def get_grid(num_interval): y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) self.grid.data = extend_grid(grid, k_extend=self.k) - # print('x_pos 2', x_pos.shape) - # print('y_eval 2', y_eval.shape) self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) def initialize_grid_from_parent(self, parent, x, mode="sample"): @@ -424,7 +413,8 @@ def get_grid(num_interval): ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] grid_adaptive = x_pos[ids, :].permute(1,0) h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval - grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device) + grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,) + [None, :] grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive return grid""" @@ -433,11 +423,14 @@ def get_grid(num_interval): x_pos = parent.grid[:, parent.k : -parent.k] # print('x_pos', x_pos) sp2 = KANLayer( - in_dim=1, out_dim=self.in_dim, k=1, num=x_pos.shape[1] - 1, scale_base_mu=0.0, scale_base_sigma=0.0 - ).to(x.device) + in_dim=1, + out_dim=self.in_dim, + k=1, + num=x_pos.shape[1] - 1, + scale_base_mu=0.0, + scale_base_sigma=0.0, + ) - # print('sp2_grid', sp2.grid[:,sp2.k:-sp2.k].permute(1,0).expand(-1,self.in_dim)) - # print('sp2_coef_shape', sp2.coef.shape) sp2_coef = curve2coef( sp2.grid[:, sp2.k : -sp2.k].permute(1, 0).expand(-1, self.in_dim), x_pos.permute(1, 0).unsqueeze(dim=2), @@ -445,7 +438,7 @@ def get_grid(num_interval): k=1, ).permute(1, 0, 2) sp2.coef.data = sp2_coef - percentile = torch.linspace(-1, 1, self.num + 1).to(self.device) + percentile = torch.linspace(-1, 1, self.num + 1) grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0) return grid @@ -482,7 +475,9 @@ def get_subset(self, in_id, out_id): >>> kanlayer_small.in_dim, kanlayer_small.out_dim (2, 3) """ - spb = KANLayer(len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun) + spb = KANLayer( + len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun + ) spb.grid.data = self.grid[in_id] spb.coef.data = self.coef[in_id][:, out_id] spb.scale_base.data = self.scale_base[in_id][:, out_id] @@ -495,7 +490,8 @@ def get_subset(self, in_id, out_id): def swap(self, i1, i2, mode="in"): """ - swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out') + swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output + (if mode == 'out') Args: ----- diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index 3e6193d49..b80382c64 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -12,6 +12,9 @@ def linear(input_size, output_size, bias=True, dropout: int = None): + """ + Initialize linear 
layers for MLP block layers.
+    """
     lin = nn.Linear(input_size, output_size, bias=bias)
     if dropout is not None:
         return nn.Sequential(nn.Dropout(dropout), lin)
@@ -22,6 +25,9 @@ def linear(input_size, output_size, bias=True, dropout: int = None):
 def linspace(
     backcast_length: int, forecast_length: int, centered: bool = False
 ) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Generate linearly spaced values for the backcast and forecast.
+    """
     if centered:
         norm = max(backcast_length, forecast_length)
         start = -backcast_length
@@ -46,16 +52,48 @@ def __init__(
         num_block_layers=4,
         backcast_length=10,
         forecast_length=5,
-        share_thetas=False,
         dropout=0.1,
         kan_params={},
     ):
+        """
+        Initialize NBeatsBlock
+
+        Args:
+            units: The number of units in the mlp/kan layers.
+            thetas_dim: The dimension of the parameterized output for the block.
+            num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
+            backcast_length: The length of the backcast. Defines how many time units
+                from the past are used to predict the future. Default: 10.
+            forecast_length: The length of the forecast, i.e., the number of time steps
+                ahead to predict. Default: 5.
+            dropout: The dropout rate applied to the fully connected mlp layers to
+                prevent overfitting. Default: 0.1.
+            kan_params (dict): Parameters specific to the KAN layer
+                (used for modeling using KAN). Default: empty dictionary.
+                Contains:
+                    num_grids (int): The number of grid intervals for KAN.
+                    k (int): The order of the piecewise polynomial for KAN.
+                    noise_scale (float): The scale of noise injected at initialization.
+                    scale_base_mu (float): The scale of the residual function
+                        initialized to N(scale_base_mu, scale_base_sigma^2).
+                    scale_base_sigma (float): The scale of the residual function
+                        initialized to N(scale_base_mu, scale_base_sigma^2).
+                    scale_sp (float): The scale of the base function spline(x) in KAN.
+                    base_fun (function): The residual function used by
+                        KAN (e.g., torch.nn.SiLU()).
+                    grid_eps (float): Determines the partitioning of the grid. If 1,
+                        the grid is uniform; if 0, grid is partitioned by percentiles.
+                    grid_range (list or np.array): The range of the grid, given as
+                        a list of two values.
+                    sp_trainable (bool): If True, the scale_sp is trainable.
+                    sb_trainable (bool): If True, the scale_base is trainable.
+                    sparse_init (bool): If True, applies sparse initialization.
+ """ super().__init__() self.units = units self.thetas_dim = thetas_dim self.backcast_length = backcast_length self.forecast_length = forecast_length - self.share_thetas = share_thetas self.kan_params = kan_params if self.kan_params["use_kan"]: @@ -77,47 +115,63 @@ def __init__( sparse_init=self.kan_params["sparse_init"], ) ] - # Additional KANLayers for deeper structure + + # Add additional layers for deeper structure for _ in range(num_block_layers - 1): - layers.extend( - [ - KANLayer( - in_dim=units, - out_dim=units, - num=self.kan_params["num_grids"], - k=self.kan_params["k"], - noise_scale=self.kan_params["noise_scale"], - scale_base_mu=self.kan_params["scale_base_mu"], - scale_base_sigma=self.kan_params["scale_base_sigma"], - scale_sp=self.kan_params["scale_sp"], - base_fun=self.kan_params["base_fun"], - grid_eps=self.kan_params["grid_eps"], - grid_range=self.kan_params["grid_range"], - sp_trainable=self.kan_params["sp_trainable"], - sb_trainable=self.kan_params["sb_trainable"], - sparse_init=self.kan_params["sparse_init"], - device="cpu", # Assuming you are using the "cpu" device - ) - ] + layers.append( + KANLayer( + in_dim=units, + out_dim=units, + num=self.kan_params["num_grids"], + k=self.kan_params["k"], + noise_scale=self.kan_params["noise_scale"], + scale_base_mu=self.kan_params["scale_base_mu"], + scale_base_sigma=self.kan_params["scale_base_sigma"], + scale_sp=self.kan_params["scale_sp"], + base_fun=self.kan_params["base_fun"], + grid_eps=self.kan_params["grid_eps"], + grid_range=self.kan_params["grid_range"], + sp_trainable=self.kan_params["sp_trainable"], + sb_trainable=self.kan_params["sb_trainable"], + sparse_init=self.kan_params["sparse_init"], + ) ) - self.fc = nn.Sequential(*layers) + # Define the fully connected layers + self.fc = nn.Sequential(*layers) + + # Define the theta layers + self.theta_f_fc = self.theta_b_fc = KANLayer( + in_dim=units, + out_dim=thetas_dim, + num=self.kan_params["num_grids"], + k=self.kan_params["k"], + noise_scale=self.kan_params["noise_scale"], + scale_base_mu=self.kan_params["scale_base_mu"], + scale_base_sigma=self.kan_params["scale_base_sigma"], + scale_sp=self.kan_params["scale_sp"], + base_fun=self.kan_params["base_fun"], + grid_eps=self.kan_params["grid_eps"], + grid_range=self.kan_params["grid_range"], + sp_trainable=self.kan_params["sp_trainable"], + sb_trainable=self.kan_params["sb_trainable"], + sparse_init=self.kan_params["sparse_init"], + ) - fc_stack = [ - nn.Linear(backcast_length, units), - nn.ReLU(), - ] - for _ in range(num_block_layers - 1): - fc_stack.extend([linear(units, units, dropout=dropout), nn.ReLU()]) - self.fc = nn.Sequential(*fc_stack) - - if share_thetas: - self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) else: - self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) - self.theta_f_fc = nn.Linear(units, thetas_dim, bias=False) + fc_stack = [ + nn.Linear(backcast_length, units), + nn.ReLU(), + ] + for _ in range(num_block_layers - 1): + fc_stack.extend([linear(units, units, dropout=dropout), nn.ReLU()]) + self.fc = nn.Sequential(*fc_stack) + self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) def forward(self, x): + """ + Pass through the fully connected mlp/kan layers and returns the output. + """ return self.fc(x) @@ -138,9 +192,9 @@ def __init__( Initialize NBeatsSeasonalBlock Args: - units: The number of units in the mlp/kan layers. Default: 256. + units: The number of units in the mlp/kan layers. 
             thetas_dim: The dimension of the parameterized output for the block.
-                If None, it is inferred. Default: None.
+                If None, it is inferred.
             num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
             backcast_length: The length of the backcast. Defines how many time units
                 from the past are used to predict the future. Default: 10.
@@ -184,7 +238,6 @@ def __init__(
             num_block_layers=num_block_layers,
             backcast_length=backcast_length,
             forecast_length=forecast_length,
-            share_thetas=True,
             dropout=dropout,
             kan_params=kan_params,
         )
@@ -219,7 +272,9 @@ def __init__(
         self.register_buffer("S_forecast", torch.cat([s1_f, s2_f]))
 
     def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Computes the backcast and forecast outputs for the given input tensor."""
+        """
+        Computes the backcast and forecast outputs for the given input tensor.
+        """
         x = super().forward(x)
         amplitudes_backward = self.theta_b_fc(x)
         backcast = amplitudes_backward.mm(self.S_backcast)
@@ -252,9 +307,9 @@ def __init__(
         Initialize NBEATSTrendBlock
 
         Args:
-            units: The number of units in the mlp/kan layers. Default: 256.
+            units: The number of units in the mlp/kan layers.
             thetas_dim: The dimension of the parameterized output for the block.
-                If None, it is inferred. Default: None.
+                If None, it is inferred.
             num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
             backcast_length: The length of the backcast. Defines how many time units
                 from the past are used to predict the future. Default: 10.
@@ -289,7 +344,6 @@ def __init__(
             num_block_layers=num_block_layers,
             backcast_length=backcast_length,
             forecast_length=forecast_length,
-            share_thetas=True,
             dropout=dropout,
             kan_params=kan_params,
         )
@@ -313,6 +367,9 @@ def __init__(
         self.register_buffer("T_forecast", coefficients * norm)
 
     def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Computes the backcast and forecast outputs for the given input tensor.
+        """
         x = super().forward(x)
         backcast = self.theta_b_fc(x).mm(self.T_backcast)
         forecast = self.theta_f_fc(x).mm(self.T_forecast)
@@ -334,9 +391,9 @@ def __init__(
         Initialize NBEATSGenericBlock
 
        Args:
-            units: The number of units in the mlp/kan layers. Default: 256.
+            units: The number of units in the mlp/kan layers.
             thetas_dim: The dimension of the parameterized output for the block.
-                If None, it is inferred. Default: None.
+                If None, it is inferred.
             num_block_layers: Number of fully connected mlp/kan layers. Default: 4.
             backcast_length: The length of the backcast. Defines how many time units
                 from the past are used to predict the future. Default: 10.
@@ -379,9 +436,10 @@ def __init__(
         self.forecast_fc = nn.Linear(thetas_dim, forecast_length)
 
     def forward(self, x):
+        """
+        Computes the backcast and forecast outputs for the given input tensor.
+        """
         x = super().forward(x)
-
         theta_b = F.relu(self.theta_b_fc(x))
         theta_f = F.relu(self.theta_f_fc(x))
-
         return self.backcast_fc(theta_b), self.forecast_fc(theta_f)

From 594102d1d13409d5fb00c7686480b05a3cf77f00 Mon Sep 17 00:00:00 2001
From: SohaibAhmed121
Date: Mon, 13 Jan 2025 15:43:36 +0500
Subject: [PATCH 3/8] Resolved import error.
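
The bare module import resolves only when Python runs from inside the nbeats
package directory; importing the model from an installed pytorch-forecasting
package then fails. A sketch of the change (full diff below):

    # before: module-local import, breaks outside the package directory
    from kan_layer import KANLayer

    # after: absolute import through the package path
    from pytorch_forecasting.models.nbeats.kan_layer import KANLayer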
---
 pytorch_forecasting/models/nbeats/sub_modules.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py
index b80382c64..c236507a1 100644
--- a/pytorch_forecasting/models/nbeats/sub_modules.py
+++ b/pytorch_forecasting/models/nbeats/sub_modules.py
@@ -4,12 +4,13 @@
 
 from typing import Tuple
 
-from kan_layer import KANLayer
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from pytorch_forecasting.models.nbeats.kan_layer import KANLayer
+
 
 def linear(input_size, output_size, bias=True, dropout: int = None):
     """

From c8ccfaf6a51e57af6cf810f5f822bc0752d9c1cd Mon Sep 17 00:00:00 2001
From: Sohaib-Ahmed21
Date: Thu, 23 Jan 2025 04:46:56 -0800
Subject: [PATCH 4/8] Refactored NBEATS and added support for grid updates
 during training when using KAN blocks in NBEATS.

---
 pytorch_forecasting/models/nbeats/_nbeats.py  | 106 ++++++-----
 .../models/nbeats/grid_callback.py | 38 ++++
 .../models/nbeats/kan_layer.py | 167 +----------------
 .../models/nbeats/sub_modules.py | 89 +++++----
 4 files changed, 150 insertions(+), 250 deletions(-)
 create mode 100644 pytorch_forecasting/models/nbeats/grid_callback.py

diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py
index beb66314e..77fffd332 100644
--- a/pytorch_forecasting/models/nbeats/_nbeats.py
+++ b/pytorch_forecasting/models/nbeats/_nbeats.py
@@ -30,19 +30,6 @@ def __init__(
         expansion_coefficient_lengths: Optional[List[int]] = None,
         prediction_length: int = 1,
         context_length: int = 1,
-        use_kan: bool = False,
-        num_grids: int = 5,
-        k: int = 3,
-        noise_scale: float = 0.5,
-        scale_base_mu: float = 0.0,
-        scale_base_sigma: float = 1.0,
-        scale_sp: float = 1.0,
-        base_fun: callable = torch.nn.SiLU(),
-        grid_eps: float = 0.02,
-        grid_range: List[int] = [-1, 1],
-        sp_trainable: bool = True,
-        sb_trainable: bool = True,
-        sparse_init: bool = False,
         dropout: float = 0.1,
         learning_rate: float = 1e-2,
         log_interval: int = -1,
@@ -53,6 +40,19 @@ def __init__(
         reduce_on_plateau_patience: int = 1000,
         backcast_loss_ratio: float = 0.0,
         logging_metrics: nn.ModuleList = None,
+        use_kan: bool = False,
+        num: int = 5,
+        k: int = 3,
+        noise_scale: float = 0.5,
+        scale_base_mu: float = 0.0,
+        scale_base_sigma: float = 1.0,
+        scale_sp: float = 1.0,
+        base_fun: callable = None,
+        grid_eps: float = 0.02,
+        grid_range: List[int] = None,
+        sp_trainable: bool = True,
+        sb_trainable: bool = True,
+        sparse_init: bool = False,
         **kwargs,
     ):
         """
@@ -101,45 +101,49 @@ def __init__(
         context_length: Number of time units that condition the predictions.
             Also known as 'lookback period'.
             Should be between 1-10 times the prediction length.
-            num_grids : Parameter for KAN layer. the number of grid intervals = G.
-                Default: 5.
+            backcast_loss_ratio: weight of backcast in comparison to forecast when
+                calculating the loss. A weight of 1.0 means that forecast and
+                backcast loss is weighted the same (regardless of backcast and forecast
+                lengths). Defaults to 0.0, i.e. no weight.
+            loss: loss to optimize. Defaults to MASE().
+            log_gradient_flow: if to log gradient flow, this takes time and should be
+                only done to diagnose training failures.
+            reduce_on_plateau_patience (int): patience after which learning rate is
+                reduced by a factor of 10
+            logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that
+                are logged during training.
Defaults to
+                nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
+            use_kan: whether to use KAN blocks in N-BEATS. If True, KAN layers
+                are used in the N-BEATS blocks; otherwise MLP layers are used.
+                Default: False.
+            num : Parameter for KAN layer. The number of grid intervals = G.
+                Default: 5. Used when use_kan is True.
             k : Parameter for KAN layer. the order of piecewise polynomial. Default: 3.
+                Used when use_kan is True.
             noise_scale : Parameter for KAN layer. the scale of noise injected at
-                initialization. Default: 0.1.
+                initialization. Default: 0.5. Used when use_kan is True.
             scale_base_mu : Parameter for KAN layer. the scale of the residual
                 function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2).
-                Deafult: 0.0
+                Default: 0.0. Used when use_kan is True.
             scale_base_sigma : Parameter for KAN layer. the scale of the residual
                 function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2).
-                Deafult: 1.0
+                Default: 1.0. Used when use_kan is True.
             scale_sp : Parameter for KAN layer. the scale of the base function
-                spline(x). Deafult: 1.0
+                spline(x). Default: 1.0. Used when use_kan is True.
             base_fun : Parameter for KAN layer. residual function b(x).
-                Default: torch.nn.SiLU()
+                Default: None (resolves to torch.nn.SiLU()). Used when use_kan
+                is True.
             grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform;
                 when grid_eps = 0, the grid is partitioned using percentiles of samples.
-                0 < grid_eps < 1 interpolates between the two extremes. Deafult: 0.02
+                0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
+                Used when use_kan is True.
             grid_range : Parameter for KAN layer. list/np.array of shape (2,). setting
-                the range of grids.
-                Default: [-1,1].
+                the range of grids. Default: None (resolves to [-1, 1]). Used when
+                use_kan is True.
             sp_trainable : Parameter for KAN layer. If true, scale_sp is trainable.
-                Default: True.
+                Default: True. Used when use_kan is True.
             sb_trainable : Parameter for KAN layer. If true, scale_base is trainable.
-                Default: True.
+                Default: True. Used when use_kan is True.
             sparse_init : Parameter for KAN layer. if sparse_init = True, sparse
-                initialization is applied. Default: False.
+                initialization is applied. Default: False. Used when use_kan is True.
             **kwargs: additional arguments to :py:class:`~BaseModel`.
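+
+        Example:
+            A minimal, illustrative sketch of enabling the KAN blocks (assumes
+            ``training`` is a ``TimeSeriesDataSet``; all other KAN arguments
+            keep the defaults documented above):
+
+                net = NBeats.from_dataset(
+                    training,
+                    use_kan=True,  # build blocks from KAN layers instead of MLPs
+                    num=5,         # grid intervals per KAN layer
+                    k=3,           # order of the piecewise polynomial
+                )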
""" # noqa: E501 if expansion_coefficient_lengths is None: @@ -154,14 +158,17 @@ def __init__( num_blocks = [3, 3] if stack_types is None: stack_types = ["trend", "seasonality"] + if base_fun is None: + base_fun = torch.nn.SiLU() + if grid_range is None: + grid_range = [-1, 1] if logging_metrics is None: logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) if loss is None: loss = MASE() # Bundle KAN parameters into a dictionary self.kan_params = { - "use_kan": use_kan, - "num_grids": num_grids, + "num": num, "k": k, "noise_scale": noise_scale, "scale_base_mu": scale_base_mu, @@ -174,6 +181,7 @@ def __init__( "sb_trainable": sb_trainable, "sparse_init": sparse_init, } + self.use_kan = use_kan self.save_hyperparameters(ignore=["loss", "logging_metrics"]) super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) @@ -191,6 +199,7 @@ def __init__( forecast_length=prediction_length, dropout=self.hparams.dropout, kan_params=self.kan_params, + use_kan=use_kan, ) elif stack_type == "seasonality": net_block = NBEATSSeasonalBlock( @@ -201,6 +210,7 @@ def __init__( min_period=self.hparams.expansion_coefficient_lengths[stack_id], dropout=self.hparams.dropout, kan_params=self.kan_params, + use_kan=use_kan, ) elif stack_type == "trend": net_block = NBEATSTrendBlock( @@ -211,6 +221,7 @@ def __init__( forecast_length=prediction_length, dropout=self.hparams.dropout, kan_params=self.kan_params, + use_kan=use_kan, ) else: raise ValueError(f"Unknown stack type {stack_type}") @@ -291,6 +302,21 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: ), ) + def update_kan_grid(self): + """ + Updates grid of KAN layers when using KAN layers in NBEATSBlock. + """ + if self.use_kan: + for block in self.net_blocks: + # updation logic taken from + # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 + for i, layer in enumerate(block.fc): + # update basis KAN layers' grid + layer.update_grid_from_samples(block.outputs[i]) + # update theta backward and theta forward KAN layers' grid + block.theta_b_fc.update_grid_from_samples(block.outputs[i + 1]) + block.theta_f_fc.update_grid_from_samples(block.outputs[i + 1]) + @classmethod def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): """ diff --git a/pytorch_forecasting/models/nbeats/grid_callback.py b/pytorch_forecasting/models/nbeats/grid_callback.py new file mode 100644 index 000000000..3c36b1ef4 --- /dev/null +++ b/pytorch_forecasting/models/nbeats/grid_callback.py @@ -0,0 +1,38 @@ +from lightning.pytorch.callbacks import Callback + + +class GridUpdateCallback(Callback): + """ + Custom callback to update the grid of the model during training at regular + intervals. + + Attributes: + update_interval (int): The frequency at which the grid is updated. + """ + + def __init__(self, update_interval): + """ + Initializes the callback with the given update interval. + + Args: + update_interval (int): The frequency at which the grid is updated. + """ + self.update_interval = update_interval + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + """ + Hook that is called at the end of each training batch. + Updates the grid of KAN layers if the current step is a multiple of the update + interval. + + Args: + trainer (Trainer): The PyTorch Lightning Trainer object. + pl_module (LightningModule): The model being trained (LightningModule). + outputs (Any): Outputs from the model for the current batch. + batch (Any): The current batch of data. 
+ batch_idx (int): Index of the current batch. + """ + # Check if the current step is a multiple of the update interval + if (trainer.global_step + 1) % self.update_interval == 0: + # Call the model's update_kan_grid method + pl_module.update_kan_grid() diff --git a/pytorch_forecasting/models/nbeats/kan_layer.py b/pytorch_forecasting/models/nbeats/kan_layer.py index 7b5d991df..d9357bb46 100644 --- a/pytorch_forecasting/models/nbeats/kan_layer.py +++ b/pytorch_forecasting/models/nbeats/kan_layer.py @@ -1,3 +1,6 @@ +# The following implementation of KANLayer is inspired by the pykan library. +# Reference: https://github.com/KindXiaoming/pykan/blob/master/kan/KANLayer.py + import numpy as np import torch import torch.nn as nn @@ -105,7 +108,6 @@ def curve2coef(x_eval, y_eval, grid, k): coef : 3D torch.tensor shape (in_dim, out_dim, G+k) """ - # print('haha', x_eval.shape, y_eval.shape, grid.shape) batch = x_eval.shape[0] in_dim = x_eval.shape[1] out_dim = y_eval.shape[2] @@ -118,16 +120,6 @@ def curve2coef(x_eval, y_eval, grid, k): except Exception as e: print(f"lstsq failed with error: {e}") - # manual psuedo-inverse - """lamb=1e-8 - XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat) - Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval) - n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2] - identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n) - A = XtX + lamb * identity - B = Xty - coef = (A.pinverse() @ B)[:,:,:,0]""" - return coef @@ -343,8 +335,6 @@ def update_grid_from_samples(self, x, mode="sample"): """ batch = x.shape[0] - # x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, )) - # .reshape(batch, self.size).permute(1, 0) x_pos = torch.sort(x, dim=0)[0] y_eval = coef2curve(x_pos, self.grid, self.coef, self.k) num_interval = self.grid.shape[1] - 1 - 2 * self.k @@ -368,7 +358,6 @@ def get_grid(num_interval): return grid grid = get_grid(num_interval) - if mode == "grid": sample_grid = get_grid(2 * num_interval) x_pos = sample_grid.permute(1, 0) @@ -376,153 +365,3 @@ def get_grid(num_interval): self.grid.data = extend_grid(grid, k_extend=self.k) self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) - - def initialize_grid_from_parent(self, parent, x, mode="sample"): - """ - update grid from a parent KANLayer & samples - - Args: - ----- - parent : KANLayer - a parent KANLayer (whose grid is usually coarser than the current model) - x : 2D torch.float - inputs, shape (number of samples, input dimension) - - Returns: - -------- - None - - Example - ------- - >>> batch = 100 - >>> parent_model = KANLayer(in_dim=1, out_dim=1, num=5, k=3) - >>> print(parent_model.grid.data) - >>> model = KANLayer(in_dim=1, out_dim=1, num=10, k=3) - >>> x = torch.normal(0,1,size=(batch, 1)) - >>> model.initialize_grid_from_parent(parent_model, x) - >>> print(model.grid.data) - """ - # shrink grid - x_pos = torch.sort(x, dim=0)[0] - y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k) - num_interval = self.grid.shape[1] - 1 - 2 * self.k - - """ - # based on samples - def get_grid(num_interval): - ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1] - grid_adaptive = x_pos[ids, :].permute(1,0) - h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval - grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,) - [None, :] - grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive - return grid""" - - # based on interpolating parent grid - def get_grid(num_interval): - x_pos = 
parent.grid[:, parent.k : -parent.k] - # print('x_pos', x_pos) - sp2 = KANLayer( - in_dim=1, - out_dim=self.in_dim, - k=1, - num=x_pos.shape[1] - 1, - scale_base_mu=0.0, - scale_base_sigma=0.0, - ) - - sp2_coef = curve2coef( - sp2.grid[:, sp2.k : -sp2.k].permute(1, 0).expand(-1, self.in_dim), - x_pos.permute(1, 0).unsqueeze(dim=2), - sp2.grid[:, :], - k=1, - ).permute(1, 0, 2) - sp2.coef.data = sp2_coef - percentile = torch.linspace(-1, 1, self.num + 1) - grid = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0) - return grid - - grid = get_grid(num_interval) - - if mode == "grid": - sample_grid = get_grid(2 * num_interval) - x_pos = sample_grid.permute(1, 0) - y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k) - - grid = extend_grid(grid, k_extend=self.k) - self.grid.data = grid - self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k) - - def get_subset(self, in_id, out_id): - """ - get a smaller KANLayer from a larger KANLayer (used for pruning) - - Args: - ----- - in_id : list - id of selected input neurons - out_id : list - id of selected output neurons - - Returns: - -------- - spb : KANLayer - - Example - ------- - >>> kanlayer_large = KANLayer(in_dim=10, out_dim=10, num=5, k=3) - >>> kanlayer_small = kanlayer_large.get_subset([0,9],[1,2,3]) - >>> kanlayer_small.in_dim, kanlayer_small.out_dim - (2, 3) - """ - spb = KANLayer( - len(in_id), len(out_id), self.num, self.k, base_fun=self.base_fun - ) - spb.grid.data = self.grid[in_id] - spb.coef.data = self.coef[in_id][:, out_id] - spb.scale_base.data = self.scale_base[in_id][:, out_id] - spb.scale_sp.data = self.scale_sp[in_id][:, out_id] - spb.mask.data = self.mask[in_id][:, out_id] - - spb.in_dim = len(in_id) - spb.out_dim = len(out_id) - return spb - - def swap(self, i1, i2, mode="in"): - """ - swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output - (if mode == 'out') - - Args: - ----- - i1 : int - i2 : int - mode : str - mode = 'in' or 'out' - - Returns: - -------- - None - - Example - ------- - >>> from kan.KANLayer import * - >>> model = KANLayer(in_dim=2, out_dim=2, num=5, k=3) - >>> print(model.coef) - >>> model.swap(0,1,mode='in') - >>> print(model.coef) - """ - with torch.no_grad(): - - def swap_(data, i1, i2, mode="in"): - if mode == "in": - data[i1], data[i2] = data[i2].clone(), data[i1].clone() - elif mode == "out": - data[:, i1], data[:, i2] = data[:, i2].clone(), data[:, i1].clone() - - if mode == "in": - swap_(self.grid.data, i1, i2, mode="in") - swap_(self.coef.data, i1, i2, mode=mode) - swap_(self.scale_base.data, i1, i2, mode=mode) - swap_(self.scale_sp.data, i1, i2, mode=mode) - swap_(self.mask.data, i1, i2, mode=mode) diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index c236507a1..cb0b32962 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -54,7 +54,8 @@ def __init__( backcast_length=10, forecast_length=5, dropout=0.1, - kan_params={}, + kan_params=None, + use_kan=False, ): """ Initialize NBeatsSeasonalBlock @@ -70,7 +71,7 @@ def __init__( dropout: The dropout rate applied to the fully connected mlp layers to prevent overfitting. Default: 0.1. kan_params (dict): Parameters specific to the KAN layer - (used for modeling using KAN). Default: empty dictionary. + (used for modeling using KAN). Default: None. Contains: num_grids (int): The number of grid intervals for KAN. k (int): The order of the piecewise polynomial for KAN. 
@@ -89,6 +90,9 @@ def __init__( sp_trainable (bool): If True, the scale_sp is trainable. sb_trainable (bool): If True, the scale_base is trainable. sparse_init (bool): If True, applies sparse initialization. + use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true, + kan layers are used in nbeats block else mlp layers are used. Default: + false. """ super().__init__() self.units = units @@ -96,24 +100,14 @@ def __init__( self.backcast_length = backcast_length self.forecast_length = forecast_length self.kan_params = kan_params + self.use_kan = use_kan - if self.kan_params["use_kan"]: + if self.use_kan: layers = [ KANLayer( in_dim=backcast_length, out_dim=units, - num=self.kan_params["num_grids"], - k=self.kan_params["k"], - noise_scale=self.kan_params["noise_scale"], - scale_base_mu=self.kan_params["scale_base_mu"], - scale_base_sigma=self.kan_params["scale_base_sigma"], - scale_sp=self.kan_params["scale_sp"], - base_fun=self.kan_params["base_fun"], - grid_eps=self.kan_params["grid_eps"], - grid_range=self.kan_params["grid_range"], - sp_trainable=self.kan_params["sp_trainable"], - sb_trainable=self.kan_params["sb_trainable"], - sparse_init=self.kan_params["sparse_init"], + **self.kan_params, ) ] @@ -123,18 +117,7 @@ def __init__( KANLayer( in_dim=units, out_dim=units, - num=self.kan_params["num_grids"], - k=self.kan_params["k"], - noise_scale=self.kan_params["noise_scale"], - scale_base_mu=self.kan_params["scale_base_mu"], - scale_base_sigma=self.kan_params["scale_base_sigma"], - scale_sp=self.kan_params["scale_sp"], - base_fun=self.kan_params["base_fun"], - grid_eps=self.kan_params["grid_eps"], - grid_range=self.kan_params["grid_range"], - sp_trainable=self.kan_params["sp_trainable"], - sb_trainable=self.kan_params["sb_trainable"], - sparse_init=self.kan_params["sparse_init"], + **self.kan_params, ) ) @@ -145,18 +128,7 @@ def __init__( self.theta_f_fc = self.theta_b_fc = KANLayer( in_dim=units, out_dim=thetas_dim, - num=self.kan_params["num_grids"], - k=self.kan_params["k"], - noise_scale=self.kan_params["noise_scale"], - scale_base_mu=self.kan_params["scale_base_mu"], - scale_base_sigma=self.kan_params["scale_base_sigma"], - scale_sp=self.kan_params["scale_sp"], - base_fun=self.kan_params["base_fun"], - grid_eps=self.kan_params["grid_eps"], - grid_range=self.kan_params["grid_range"], - sp_trainable=self.kan_params["sp_trainable"], - sb_trainable=self.kan_params["sb_trainable"], - sparse_init=self.kan_params["sparse_init"], + **self.kan_params, ) else: @@ -173,7 +145,17 @@ def forward(self, x): """ Pass through the fully connected mlp/kan layers and returns the output. """ - return self.fc(x) + # outputs logic taken from + # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 + self.outputs = [] + self.outputs.append(x.clone().detach()) + for layer in self.fc: + x = layer(x) # Pass data through the current layer + # store outputs for updating grids of theta_fc when using KAN + self.outputs.append(x.clone().detach()) + # for updating grids of theta_b_fc and theta_f_fc when using KAN + self.outputs.append(x.clone().detach()) + return x # Return final output class NBEATSSeasonalBlock(NBEATSBlock): @@ -187,7 +169,8 @@ def __init__( nb_harmonics=None, min_period=1, dropout=0.1, - kan_params={}, + kan_params=None, + use_kan=False, ): """ Initialize NBeatsSeasonalBlock @@ -207,7 +190,7 @@ def __init__( dropout: The dropout rate applied to the fully connected mlp layers to prevent overfitting. Default: 0.1. 
kan_params (dict): Parameters specific to the KAN layer - (used for modeling using KAN). Default: empty dictionary. + (used for modeling using KAN). Default: None. Contains: num_grids (int): The number of grid intervals for KAN. k (int): The order of the piecewise polynomial for KAN. @@ -226,6 +209,9 @@ def __init__( sp_trainable (bool): If True, the scale_sp is trainable. sb_trainable (bool): If True, the scale_base is trainable. sparse_init (bool): If True, applies sparse initialization. + use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true, + kan layers are used in nbeats block else mlp layers are used. Default: + false. """ if nb_harmonics: thetas_dim = nb_harmonics @@ -241,6 +227,7 @@ def __init__( forecast_length=forecast_length, dropout=dropout, kan_params=kan_params, + use_kan=use_kan, ) backcast_linspace, forecast_linspace = linspace( @@ -302,7 +289,8 @@ def __init__( backcast_length=10, forecast_length=5, dropout=0.1, - kan_params={}, + kan_params=None, + use_kan=False, ): """ Initialize NBeatsSeasonalBlock @@ -319,7 +307,7 @@ def __init__( dropout: The dropout rate applied to the fully connected mlp layers to prevent overfitting. Default: 0.1. kan_params (dict): Parameters specific to the KAN layer - (used for modeling using KAN). Default: empty dictionary. + (used for modeling using KAN). Default: None. Contains: num_grids (int): The number of grid intervals for KAN. k (int): The order of the piecewise polynomial for KAN. @@ -338,6 +326,9 @@ def __init__( sp_trainable (bool): If True, the scale_sp is trainable. sb_trainable (bool): If True, the scale_base is trainable. sparse_init (bool): If True, applies sparse initialization. + use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true, + kan layers are used in nbeats block else mlp layers are used. Default: + false. """ super().__init__( units=units, @@ -347,6 +338,7 @@ def __init__( forecast_length=forecast_length, dropout=dropout, kan_params=kan_params, + use_kan=use_kan, ) backcast_linspace, forecast_linspace = linspace( @@ -386,7 +378,8 @@ def __init__( backcast_length=10, forecast_length=5, dropout=0.1, - kan_params={}, + kan_params=None, + use_kan=False, ): """ Initialize NBeatsSeasonalBlock @@ -403,7 +396,7 @@ def __init__( dropout: The dropout rate applied to the fully connected mlp layers to prevent overfitting. Default: 0.1. kan_params (dict): Parameters specific to the KAN layer - (used for modeling using KAN). Default: empty dictionary. + (used for modeling using KAN). Default: None. Contains: num_grids (int): The number of grid intervals for KAN. k (int): The order of the piecewise polynomial for KAN. @@ -422,6 +415,9 @@ def __init__( sp_trainable (bool): If True, the scale_sp is trainable. sb_trainable (bool): If True, the scale_base is trainable. sparse_init (bool): If True, applies sparse initialization. + use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true, + kan layers are used in nbeats block else mlp layers are used. Default: + false. """ super().__init__( units=units, @@ -431,6 +427,7 @@ def __init__( forecast_length=forecast_length, dropout=dropout, kan_params=kan_params, + use_kan=use_kan, ) self.backcast_fc = nn.Linear(thetas_dim, backcast_length) From 348da97e3ed7f6b93fb011a9903c3965c4457600 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Thu, 23 Jan 2025 05:49:46 -0800 Subject: [PATCH 5/8] Refactored comments. 
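
The comments clarify which cached activations feed the KAN grid updates. For
context, a condensed sketch of the consumer, NBeats.update_kan_grid in
_nbeats.py (abridged; `block` iterates over `self.net_blocks`):

    for i, layer in enumerate(block.fc):
        # each stacked KAN layer refits its grid on the input it saw
        layer.update_grid_from_samples(block.outputs[i])
    # both theta heads consume the last hidden activation
    block.theta_b_fc.update_grid_from_samples(block.outputs[i + 1])
    block.theta_f_fc.update_grid_from_samples(block.outputs[i + 1])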
--- pytorch_forecasting/models/nbeats/sub_modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index cb0b32962..7ddf17a20 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -151,9 +151,9 @@ def forward(self, x): self.outputs.append(x.clone().detach()) for layer in self.fc: x = layer(x) # Pass data through the current layer - # store outputs for updating grids of theta_fc when using KAN + # storing outputs for updating grids of self.fc when using KAN self.outputs.append(x.clone().detach()) - # for updating grids of theta_b_fc and theta_f_fc when using KAN + # storing for updating grids of theta_b_fc and theta_f_fc when using KAN self.outputs.append(x.clone().detach()) return x # Return final output From 1ab0da0dd14577e3d51d3f7756e9215e8bd59057 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Sat, 1 Feb 2025 11:56:57 -0800 Subject: [PATCH 6/8] Added example to use grid_update_callback and added correct device to tensors during training. --- examples/nbeats_with_kan.py | 106 ++++++++++++++++++ .../models/nbeats/grid_callback.py | 4 + .../models/nbeats/kan_layer.py | 5 +- 3 files changed, 111 insertions(+), 4 deletions(-) create mode 100644 examples/nbeats_with_kan.py diff --git a/examples/nbeats_with_kan.py b/examples/nbeats_with_kan.py new file mode 100644 index 000000000..925e8dcf0 --- /dev/null +++ b/examples/nbeats_with_kan.py @@ -0,0 +1,106 @@ +import sys + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +import pandas as pd + +from pytorch_forecasting import NBeats, TimeSeriesDataSet +from pytorch_forecasting.data import NaNLabelEncoder +from pytorch_forecasting.data.examples import generate_ar_data +from pytorch_forecasting.models.nbeats.grid_callback import GridUpdateCallback + +sys.path.append("..") + + +print("load data") +data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100) +data["static"] = 2 +data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") +validation = data.series.sample(20) + + +max_encoder_length = 150 +max_prediction_length = 20 + +training_cutoff = data["time_idx"].max() - max_prediction_length + +context_length = max_encoder_length +prediction_length = max_prediction_length + +training = TimeSeriesDataSet( + data[lambda x: x.time_idx < training_cutoff], + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + min_encoder_length=context_length, + max_encoder_length=context_length, + max_prediction_length=prediction_length, + min_prediction_length=prediction_length, + time_varying_unknown_reals=["value"], + randomize_length=None, + add_relative_time_idx=False, + add_target_scales=False, +) + +validation = TimeSeriesDataSet.from_dataset( + training, data, min_prediction_idx=training_cutoff +) +batch_size = 128 +train_dataloader = training.to_dataloader( + train=True, batch_size=batch_size, num_workers=0 +) +val_dataloader = validation.to_dataloader( + train=False, batch_size=batch_size, num_workers=0 +) + + +early_stop_callback = EarlyStopping( + monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min" +) +# updates KAN layers' grid after every 3 steps during training +grid_update_callback = GridUpdateCallback(update_interval=3) + +trainer = pl.Trainer( + max_epochs=1, + accelerator="auto", + 
gradient_clip_val=0.1, + callbacks=[early_stop_callback, grid_update_callback], + limit_train_batches=15, + # limit_val_batches=1, + # fast_dev_run=True, + # logger=logger, + # profiler=True, +) + + +net = NBeats.from_dataset( + training, + learning_rate=3e-2, + log_interval=10, + log_val_interval=1, + log_gradient_flow=False, + weight_decay=1e-2, + use_kan=True, +) +print(f"Number of parameters in network: {net.size() / 1e3:.1f}k") + +# # find optimal learning rate +# # remove logging and artificial epoch size +# net.hparams.log_interval = -1 +# net.hparams.log_val_interval = -1 +# trainer.limit_train_batches = 1.0 +# # run learning rate finder +# res = Tuner(trainer).lr_find( +# net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2 # noqa: E501 +# ) +# print(f"suggested learning rate: {res.suggestion()}") +# fig = res.plot(show=True, suggest=True) +# fig.show() +# net.hparams.learning_rate = res.suggestion() + +trainer.fit( + net, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, +) diff --git a/pytorch_forecasting/models/nbeats/grid_callback.py b/pytorch_forecasting/models/nbeats/grid_callback.py index 3c36b1ef4..d311cfe84 100644 --- a/pytorch_forecasting/models/nbeats/grid_callback.py +++ b/pytorch_forecasting/models/nbeats/grid_callback.py @@ -6,6 +6,10 @@ class GridUpdateCallback(Callback): Custom callback to update the grid of the model during training at regular intervals. + Example: + See the full example in: + `examples/nbeats_with_kan.py` + Attributes: update_interval (int): The frequency at which the grid is updated. """ diff --git a/pytorch_forecasting/models/nbeats/kan_layer.py b/pytorch_forecasting/models/nbeats/kan_layer.py index d9357bb46..1f7a18a1c 100644 --- a/pytorch_forecasting/models/nbeats/kan_layer.py +++ b/pytorch_forecasting/models/nbeats/kan_layer.py @@ -349,10 +349,7 @@ def get_grid(num_interval): grid_uniform = ( grid_adaptive[:, [0]] - margin - + h - * torch.arange( - num_interval + 1, - )[None, :] + + h * torch.arange(num_interval + 1, device=h.device)[None, :] ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive return grid From 05350c2d193b949ba12d206f181e605e4935b618 Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Thu, 20 Feb 2025 04:53:14 -0800 Subject: [PATCH 7/8] Refactored code for NBEATSKAN and introduced it as separate model/entity using adapter for common functionality. --- docs/source/models.rst | 1 + examples/nbeats_with_kan.py | 5 +- pytorch_forecasting/__init__.py | 2 + pytorch_forecasting/models/__init__.py | 3 +- pytorch_forecasting/models/nbeats/__init__.py | 9 +- pytorch_forecasting/models/nbeats/_nbeats.py | 406 +----------------- .../models/nbeats/_nbeatskan.py | 235 ++++++++++ .../models/nbeats/nbeats_adapter.py | 322 ++++++++++++++ .../models/nbeats/sub_modules.py | 23 +- 9 files changed, 599 insertions(+), 407 deletions(-) create mode 100644 pytorch_forecasting/models/nbeats/_nbeatskan.py create mode 100644 pytorch_forecasting/models/nbeats/nbeats_adapter.py diff --git a/docs/source/models.rst b/docs/source/models.rst index f8ac486af..71569c724 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -27,6 +27,7 @@ and you should take into account. 
Here is an overview over the pros and cons of :py:class:`~pytorch_forecasting.models.rnn.RecurrentNetwork`, "x", "x", "x", "", "", "", "", "x", "", 2 :py:class:`~pytorch_forecasting.models.mlp.DecoderMLP`, "x", "x", "x", "x", "", "x", "", "x", "x", 1 :py:class:`~pytorch_forecasting.models.nbeats.NBeats`, "", "", "x", "", "", "", "", "", "", 1 + :py:class:`~pytorch_forecasting.models.nbeats.NBeatsKAN`, "", "", "x", "", "", "", "", "", "", 1 :py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1 :py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3 :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4 diff --git a/examples/nbeats_with_kan.py b/examples/nbeats_with_kan.py index 925e8dcf0..952a2acce 100644 --- a/examples/nbeats_with_kan.py +++ b/examples/nbeats_with_kan.py @@ -4,7 +4,7 @@ from lightning.pytorch.callbacks import EarlyStopping import pandas as pd -from pytorch_forecasting import NBeats, TimeSeriesDataSet +from pytorch_forecasting import NBeatsKAN, TimeSeriesDataSet from pytorch_forecasting.data import NaNLabelEncoder from pytorch_forecasting.data.examples import generate_ar_data from pytorch_forecasting.models.nbeats.grid_callback import GridUpdateCallback @@ -74,14 +74,13 @@ ) -net = NBeats.from_dataset( +net = NBeatsKAN.from_dataset( training, learning_rate=3e-2, log_interval=10, log_val_interval=1, log_gradient_flow=False, weight_decay=1e-2, - use_kan=True, ) print(f"Number of parameters in network: {net.size() / 1e3:.1f}k") diff --git a/pytorch_forecasting/__init__.py b/pytorch_forecasting/__init__.py index eabfe481f..e1b150d51 100644 --- a/pytorch_forecasting/__init__.py +++ b/pytorch_forecasting/__init__.py @@ -43,6 +43,7 @@ DeepAR, MultiEmbedding, NBeats, + NBeatsKAN, NHiTS, RecurrentNetwork, TemporalFusionTransformer, @@ -71,6 +72,7 @@ "MultiNormalizer", "TemporalFusionTransformer", "NBeats", + "NBeatsKAN", "NHiTS", "Baseline", "DeepAR", diff --git a/pytorch_forecasting/models/__init__.py b/pytorch_forecasting/models/__init__.py index d4173f620..9b92ef30c 100644 --- a/pytorch_forecasting/models/__init__.py +++ b/pytorch_forecasting/models/__init__.py @@ -11,7 +11,7 @@ from pytorch_forecasting.models.baseline import Baseline from pytorch_forecasting.models.deepar import DeepAR from pytorch_forecasting.models.mlp import DecoderMLP -from pytorch_forecasting.models.nbeats import NBeats +from pytorch_forecasting.models.nbeats import NBeats, NBeatsKAN from pytorch_forecasting.models.nhits import NHiTS from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn from pytorch_forecasting.models.rnn import RecurrentNetwork @@ -21,6 +21,7 @@ __all__ = [ "NBeats", + "NBeatsKAN", "NHiTS", "TemporalFusionTransformer", "RecurrentNetwork", diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index b3264272d..87c1fe7fb 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -1,10 +1,17 @@ """N-Beats model for timeseries forecasting without covariates.""" from pytorch_forecasting.models.nbeats._nbeats import NBeats +from pytorch_forecasting.models.nbeats._nbeatskan import NBeatsKAN from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, ) -__all__ = ["NBeats", "NBEATSGenericBlock", "NBEATSSeasonalBlock", 
"NBEATSTrendBlock"] +__all__ = [ + "NBeats", + "NBeatsKAN", + "NBEATSGenericBlock", + "NBEATSSeasonalBlock", + "NBEATSTrendBlock", +] diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py index 77fffd332..f85067e22 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats.py +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -2,24 +2,20 @@ N-Beats model for timeseries forecasting without covariates. """ -from typing import Dict, List, Optional +from typing import List, Optional -import torch from torch import nn -from pytorch_forecasting.data import TimeSeriesDataSet -from pytorch_forecasting.data.encoders import NaNLabelEncoder from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.base_model import BaseModel +from pytorch_forecasting.models.nbeats.nbeats_adapter import NBeatsAdapter from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, ) -from pytorch_forecasting.utils._dependencies import _check_matplotlib -class NBeats(BaseModel): +class NBeats(NBeatsAdapter): def __init__( self, stack_types: Optional[List[str]] = None, @@ -40,19 +36,6 @@ def __init__( reduce_on_plateau_patience: int = 1000, backcast_loss_ratio: float = 0.0, logging_metrics: nn.ModuleList = None, - use_kan: bool = False, - num: int = 5, - k: int = 3, - noise_scale: float = 0.5, - scale_base_mu: float = 0.0, - scale_base_sigma: float = 1.0, - scale_sp: float = 1.0, - base_fun: callable = None, - grid_eps: float = 0.02, - grid_range: List[int] = None, - sp_trainable: bool = True, - sb_trainable: bool = True, - sparse_init: bool = False, **kwargs, ): """ @@ -70,23 +53,23 @@ def __init__( Args: stack_types: One of the following values: “generic”, “seasonality" or - “trend". A list of strings of length 1 or ‘num_stacks’. Default and + “trend". A list of strings of length 1 or 'num_stacks'. Default and recommended value for generic mode: [“generic”] Recommended value for interpretable mode: [“trend”,”seasonality”]. num_blocks: The number of blocks per stack. A list of ints of length 1 or - ‘num_stacks’. Default and recommended value for generic mode: [1] + 'num_stacks'. Default and recommended value for generic mode: [1] Recommended value for interpretable mode: [3] num_block_layers: Number of fully connected layers with ReLu activation per block. - A list of ints of length 1 or ‘num_stacks’. Default and recommended + A list of ints of length 1 or 'num_stacks'. Default and recommended value for generic mode: [4] Recommended value for interpretable mode: [4]. width: Widths of the fully connected layers with ReLu activation in the - blocks. A list of ints of length 1 or ‘num_stacks’. Default and + blocks. A list of ints of length 1 or 'num_stacks'. Default and recommended value for generic mode: [512]. Recommended value for interpretable mode: [256, 2048] sharing: Whether the weights are shared with the other blocks per stack. - A list of ints of length 1 or ‘num_stacks’. Default and recommended + A list of ints of length 1 or 'num_stacks'. Default and recommended value for generic mode: [False]. Recommended value for interpretable mode: [True]. expansion_coefficient_length: If the type is “G” (generic), then the length @@ -95,7 +78,7 @@ def __init__( polynomial. If the type is “S” (seasonal) then this is the minimum period allowed, e.g. 2 for changes every timestep. A list of ints of length 1 or - ‘num_stacks’. 
Default value for generic mode: [32] Recommended value for + 'num_stacks'. Default value for generic mode: [32] Recommended value for interpretable mode: [3] prediction_length: Length of the prediction. Also known as 'horizon'. context_length: Number of time units that condition the predictions. @@ -113,39 +96,9 @@ def __init__( logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training. Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) - use_kan: flag parameter to decide usage of KAN blocks in NBEATS. if true, - kan layers are used in nbeats block else mlp layers are used. Default: - false. - num : Parameter for KAN layer. the number of grid intervals = G. - Default: 5. used when use_kan is True. - k : Parameter for KAN layer. the order of piecewise polynomial. Default: 3. - used when use_kan is True. - noise_scale : Parameter for KAN layer. the scale of noise injected at - initialization. Default: 0.1. used when use_kan is True. - scale_base_mu : Parameter for KAN layer. the scale of the residual - function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). - Deafult: 0.0. used when use_kan is True. - scale_base_sigma : Parameter for KAN layer. the scale of the residual - function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2). - Deafult: 1.0. used when use_kan is True. - scale_sp : Parameter for KAN layer. the scale of the base function - spline(x). Deafult: 1.0. used when use_kan is True. - base_fun : Parameter for KAN layer. residual function b(x). - Default: None. used when use_kan is True. - grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform; - when grid_eps = 0, the grid is partitioned using percentiles of samples. - 0 < grid_eps < 1 interpolates between the two extremes. Deafult: 0.02. - used when use_kan is True. - grid_range : Parameter for KAN layer. list/np.array of shape (2,). setting - the range of grids. Default: None. used when use_kan is True. - sp_trainable : Parameter for KAN layer. If true, scale_sp is trainable. - Default: True. used when use_kan is True. - sb_trainable : Parameter for KAN layer. If true, scale_base is trainable. - Default: True. used when use_kan is True. - sparse_init : Parameter for KAN layer. if sparse_init = True, sparse - initialization is applied. Default: False. used when use_kan is True. **kwargs: additional arguments to :py:class:`~BaseModel`. 
""" # noqa: E501 + if expansion_coefficient_lengths is None: expansion_coefficient_lengths = [3, 7] if sharing is None: @@ -158,34 +111,13 @@ def __init__( num_blocks = [3, 3] if stack_types is None: stack_types = ["trend", "seasonality"] - if base_fun is None: - base_fun = torch.nn.SiLU() - if grid_range is None: - grid_range = [-1, 1] if logging_metrics is None: logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]) if loss is None: loss = MASE() - # Bundle KAN parameters into a dictionary - self.kan_params = { - "num": num, - "k": k, - "noise_scale": noise_scale, - "scale_base_mu": scale_base_mu, - "scale_base_sigma": scale_base_sigma, - "scale_sp": scale_sp, - "base_fun": base_fun, - "grid_eps": grid_eps, - "grid_range": grid_range, - "sp_trainable": sp_trainable, - "sb_trainable": sb_trainable, - "sparse_init": sparse_init, - } - self.use_kan = use_kan self.save_hyperparameters(ignore=["loss", "logging_metrics"]) super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs) - # setup stacks self.net_blocks = nn.ModuleList() for stack_id, stack_type in enumerate(stack_types): @@ -197,9 +129,7 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - dropout=self.hparams.dropout, - kan_params=self.kan_params, - use_kan=use_kan, + dropout=dropout, ) elif stack_type == "seasonality": net_block = NBEATSSeasonalBlock( @@ -207,10 +137,8 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - min_period=self.hparams.expansion_coefficient_lengths[stack_id], - dropout=self.hparams.dropout, - kan_params=self.kan_params, - use_kan=use_kan, + min_period=expansion_coefficient_lengths[stack_id], + dropout=dropout, ) elif stack_type == "trend": net_block = NBEATSTrendBlock( @@ -219,315 +147,9 @@ def __init__( num_block_layers=self.hparams.num_block_layers[stack_id], backcast_length=context_length, forecast_length=prediction_length, - dropout=self.hparams.dropout, - kan_params=self.kan_params, - use_kan=use_kan, + dropout=dropout, ) else: raise ValueError(f"Unknown stack type {stack_type}") self.net_blocks.append(net_block) - - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - Pass forward of network. - - Args: - x (Dict[str, torch.Tensor]): input from dataloader generated from - :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. 
- - Returns: - Dict[str, torch.Tensor]: output of model - """ - target = x["encoder_cont"][..., 0] - - timesteps = self.hparams.context_length + self.hparams.prediction_length - generic_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - trend_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - seasonal_forecast = [ - torch.zeros( - (target.size(0), timesteps), dtype=torch.float32, device=self.device - ) - ] - forecast = torch.zeros( - (target.size(0), self.hparams.prediction_length), - dtype=torch.float32, - device=self.device, - ) - - backcast = target # initialize backcast - for i, block in enumerate(self.net_blocks): - # evaluate block - backcast_block, forecast_block = block(backcast) - - # add for interpretation - full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1) - if isinstance(block, NBEATSTrendBlock): - trend_forecast.append(full) - elif isinstance(block, NBEATSSeasonalBlock): - seasonal_forecast.append(full) - else: - generic_forecast.append(full) - - # update backcast and forecast - backcast = ( - backcast - backcast_block - ) # do not use backcast -= backcast_block as this signifies an inline operation # noqa : E501 - forecast = forecast + forecast_block - - return self.to_network_output( - prediction=self.transform_output(forecast, target_scale=x["target_scale"]), - backcast=self.transform_output( - prediction=target - backcast, target_scale=x["target_scale"] - ), - trend=self.transform_output( - torch.stack(trend_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - seasonality=self.transform_output( - torch.stack(seasonal_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - generic=self.transform_output( - torch.stack(generic_forecast, dim=0).sum(0), - target_scale=x["target_scale"], - ), - ) - - def update_kan_grid(self): - """ - Updates grid of KAN layers when using KAN layers in NBEATSBlock. - """ - if self.use_kan: - for block in self.net_blocks: - # updation logic taken from - # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 - for i, layer in enumerate(block.fc): - # update basis KAN layers' grid - layer.update_grid_from_samples(block.outputs[i]) - # update theta backward and theta forward KAN layers' grid - block.theta_b_fc.update_grid_from_samples(block.outputs[i + 1]) - block.theta_f_fc.update_grid_from_samples(block.outputs[i + 1]) - - @classmethod - def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs): - """ - Convenience function to create network from :py:class - `~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. - - Args: - dataset (TimeSeriesDataSet): dataset where sole predictor is the target. - **kwargs: additional arguments to be passed to ``__init__`` method. 
- - Returns: - NBeats - """ # noqa: E501 - new_kwargs = { - "prediction_length": dataset.max_prediction_length, - "context_length": dataset.max_encoder_length, - } - new_kwargs.update(kwargs) - - # validate arguments - assert isinstance( - dataset.target, str - ), "only one target is allowed (passed as string to dataset)" - assert not isinstance( - dataset.target_normalizer, NaNLabelEncoder - ), "only regression tasks are supported - target must not be categorical" - assert dataset.min_encoder_length == dataset.max_encoder_length, ( - "only fixed encoder length is allowed," - " but min_encoder_length != max_encoder_length" - ) - - assert dataset.max_prediction_length == dataset.min_prediction_length, ( - "only fixed prediction length is allowed," - " but max_prediction_length != min_prediction_length" - ) - - assert ( - dataset.randomize_length is None - ), "length has to be fixed, but randomize_length is not None" - assert ( - not dataset.add_relative_time_idx - ), "add_relative_time_idx has to be False" - - assert ( - len(dataset.flat_categoricals) == 0 - and len(dataset.reals) == 1 - and len(dataset._time_varying_unknown_reals) == 1 - and dataset._time_varying_unknown_reals[0] == dataset.target - ), ( - "The only variable as input should be the" - " target which is part of time_varying_unknown_reals" - ) - - # initialize class - return super().from_dataset(dataset, **new_kwargs) - - def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]: - """ - Take training / validation step. - """ - log, out = super().step(x, y, batch_idx=batch_idx) - - if ( - self.hparams.backcast_loss_ratio > 0 and not self.predicting - ): # add loss from backcast - backcast = out["backcast"] - backcast_weight = ( - self.hparams.backcast_loss_ratio - * self.hparams.prediction_length - / self.hparams.context_length - ) - backcast_weight = backcast_weight / (backcast_weight + 1) # normalize - forecast_weight = 1 - backcast_weight - if isinstance(self.loss, MASE): - backcast_loss = ( - self.loss(backcast, x["encoder_target"], x["decoder_target"]) - * backcast_weight - ) - else: - backcast_loss = ( - self.loss(backcast, x["encoder_target"]) * backcast_weight - ) - label = ["val", "train"][self.training] - self.log( - f"{label}_backcast_loss", - backcast_loss, - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - self.log( - f"{label}_forecast_loss", - log["loss"], - on_epoch=True, - on_step=self.training, - batch_size=len(x["decoder_target"]), - ) - log["loss"] = log["loss"] * forecast_weight + backcast_loss - - self.log_interpretation(x, out, batch_idx=batch_idx) - return log, out - - def log_interpretation(self, x, out, batch_idx): - """ - Log interpretation of network predictions in tensorboard. 
- """ - mpl_available = _check_matplotlib("log_interpretation", raise_error=False) - - # Don't log figures if matplotlib or add_figure is not available - if not mpl_available or not self._logger_supports("add_figure"): - return None - - label = ["val", "train"][self.training] - if self.log_interval > 0 and batch_idx % self.log_interval == 0: - fig = self.plot_interpretation(x, out, idx=0) - name = f"{label.capitalize()} interpretation of item 0 in " - if self.training: - name += f"step {self.global_step}" - else: - name += f"batch {batch_idx}" - self.logger.experiment.add_figure(name, fig, global_step=self.global_step) - - def plot_interpretation( - self, - x: Dict[str, torch.Tensor], - output: Dict[str, torch.Tensor], - idx: int, - ax=None, - plot_seasonality_and_generic_on_secondary_axis: bool = False, - ): - """ - Plot interpretation. - - Plot two pannels: prediction and backcast vs actuals and - decomposition of prediction into trend, seasonality and generic forecast. - - Args: - x (Dict[str, torch.Tensor]): network input - output (Dict[str, torch.Tensor]): network output - idx (int): index of sample for which to plot the interpretation. - ax (List[matplotlib axes], optional): list of two matplotlib axes onto which - to plot the interpretation. Defaults to None. - plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot - seasonality and generic forecast on secondary axis in second panel. - Defaults to False. - - Returns: - plt.Figure: matplotlib figure - """ # noqa: E501 - _check_matplotlib("plot_interpretation") - - import matplotlib.pyplot as plt - - if ax is None: - fig, ax = plt.subplots(2, 1, figsize=(6, 8)) - else: - fig = ax[0].get_figure() - - time = torch.arange( - -self.hparams.context_length, self.hparams.prediction_length - ) - - # plot target vs prediction - ax[0].plot( - time, - torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]]) - .detach() - .cpu(), - label="target", - ) - ax[0].plot( - time, - torch.cat( - [ - output["backcast"][idx].detach(), - output["prediction"][idx].detach(), - ], - dim=0, - ).cpu(), - label="prediction", - ) - ax[0].set_xlabel("Time") - - # plot blocks - prop_cycle = iter(plt.rcParams["axes.prop_cycle"]) - next(prop_cycle) # prediction - next(prop_cycle) # observations - if plot_seasonality_and_generic_on_secondary_axis: - ax2 = ax[1].twinx() - ax2.set_ylabel("Seasonality / Generic") - else: - ax2 = ax[1] - for title in ["trend", "seasonality", "generic"]: - if title not in self.hparams.stack_types: - continue - if title == "trend": - ax[1].plot( - time, - output[title][idx].detach().cpu(), - label=title.capitalize(), - c=next(prop_cycle)["color"], - ) - else: - ax2.plot( - time, - output[title][idx].detach().cpu(), - label=title.capitalize(), - c=next(prop_cycle)["color"], - ) - ax[1].set_xlabel("Time") - ax[1].set_ylabel("Decomposition") - - fig.legend() - return fig diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan.py b/pytorch_forecasting/models/nbeats/_nbeatskan.py new file mode 100644 index 000000000..ea9591645 --- /dev/null +++ b/pytorch_forecasting/models/nbeats/_nbeatskan.py @@ -0,0 +1,235 @@ +""" +N-Beats model with KAN blocks for timeseries forecasting without covariates. 
+""" + +from typing import List, Optional + +import torch +from torch import nn + +from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric +from pytorch_forecasting.models.nbeats.nbeats_adapter import NBeatsAdapter +from pytorch_forecasting.models.nbeats.sub_modules import ( + NBEATSGenericBlock, + NBEATSSeasonalBlock, + NBEATSTrendBlock, +) + + +class NBeatsKAN(NBeatsAdapter): + def __init__( + self, + stack_types: Optional[List[str]] = None, + num_blocks: Optional[List[int]] = None, + num_block_layers: Optional[List[int]] = None, + widths: Optional[List[int]] = None, + sharing: Optional[List[bool]] = None, + expansion_coefficient_lengths: Optional[List[int]] = None, + prediction_length: int = 1, + context_length: int = 1, + dropout: float = 0.1, + learning_rate: float = 1e-2, + log_interval: int = -1, + log_gradient_flow: bool = False, + log_val_interval: int = None, + weight_decay: float = 1e-3, + loss: MultiHorizonMetric = None, + reduce_on_plateau_patience: int = 1000, + backcast_loss_ratio: float = 0.0, + logging_metrics: nn.ModuleList = None, + num: int = 5, + k: int = 3, + noise_scale: float = 0.5, + scale_base_mu: float = 0.0, + scale_base_sigma: float = 1.0, + scale_sp: float = 1.0, + base_fun: callable = None, + grid_eps: float = 0.02, + grid_range: List[int] = None, + sp_trainable: bool = True, + sb_trainable: bool = True, + sparse_init: bool = False, + **kwargs, + ): + """ + Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible. + + Based on the article + `N-BEATS: Neural basis expansion analysis for interpretable time series + forecasting `_. The network has (if + used as ensemble) outperformed all other methods including ensembles of + traditional statical methods in the M4 competition. The M4 competition is + arguably the most important benchmark for univariate time series forecasting. + + The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently + shown to consistently outperform N-BEATS. + + Args: + stack_types: One of the following values: “generic”, “seasonality" or + “trend". A list of strings of length 1 or 'num_stacks'. Default and + recommended value for generic mode: [“generic”] Recommended value for + interpretable mode: [“trend”,”seasonality”]. + num_blocks: The number of blocks per stack. A list of ints of length 1 or + 'num_stacks'. Default and recommended value for generic mode: [1] + Recommended value for interpretable mode: [3] + num_block_layers: Number of fully connected layers with ReLu activation per + block. + A list of ints of length 1 or 'num_stacks'. Default and recommended + value for generic mode: [4] Recommended value for interpretable mode: + [4]. + width: Widths of the fully connected layers with ReLu activation in the + blocks. A list of ints of length 1 or 'num_stacks'. Default and + recommended value for generic mode: [512]. Recommended value for + interpretable mode: [256, 2048] + sharing: Whether the weights are shared with the other blocks per stack. + A list of ints of length 1 or 'num_stacks'. Default and recommended + value for generic mode: [False]. Recommended value for interpretable + mode: [True]. + expansion_coefficient_length: If the type is “G” (generic), then the length + of the expansion coefficient. + If type is “T” (trend), then it corresponds to the degree of the + polynomial. + If the type is “S” (seasonal) then this is the minimum period allowed, + e.g. 2 for changes every timestep. A list of ints of length 1 or + 'num_stacks'. 
+                Default value for generic mode: [32]. Recommended value for
+                interpretable mode: [3].
+            prediction_length: Length of the prediction. Also known as 'horizon'.
+            context_length: Number of time units that condition the predictions.
+                Also known as 'lookback period'.
+                Should be between 1-10 times the prediction length.
+            backcast_loss_ratio: weight of backcast in comparison to forecast when
+                calculating the loss. A weight of 1.0 means that forecast and
+                backcast loss is weighted the same (regardless of backcast and forecast
+                lengths). Defaults to 0.0, i.e. no weight.
+            loss: loss to optimize. Defaults to MASE().
+            log_gradient_flow: whether to log gradient flow; this takes time and should
+                only be done to diagnose training failures.
+            reduce_on_plateau_patience (int): patience after which learning rate is
+                reduced by a factor of 10.
+            logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that
+                are logged during training. Defaults to
+                nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()]).
+            num : Parameter for KAN layer. The number of grid intervals = G.
+                Default: 5.
+            k : Parameter for KAN layer. The order of piecewise polynomial. Default: 3.
+            noise_scale : Parameter for KAN layer. The scale of noise injected at
+                initialization. Default: 0.5.
+            scale_base_mu : Parameter for KAN layer. The scale of the residual
+                function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
+                Default: 0.0.
+            scale_base_sigma : Parameter for KAN layer. The scale of the residual
+                function b(x) is initialized to be N(scale_base_mu, scale_base_sigma^2).
+                Default: 1.0.
+            scale_sp : Parameter for KAN layer. The scale of the spline function
+                spline(x). Default: 1.0.
+            base_fun : Parameter for KAN layer. The residual function b(x).
+                Default: None (treated as torch.nn.SiLU()).
+            grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform;
+                when grid_eps = 0, the grid is partitioned using percentiles of samples.
+                0 < grid_eps < 1 interpolates between the two extremes. Default: 0.02.
+            grid_range : Parameter for KAN layer. List/np.array of shape (2,) setting
+                the range of grids. Default: None (treated as [-1, 1]).
+            sp_trainable : Parameter for KAN layer. If true, scale_sp is trainable.
+                Default: True.
+            sb_trainable : Parameter for KAN layer. If true, scale_base is trainable.
+                Default: True.
+            sparse_init : Parameter for KAN layer. If sparse_init = True, sparse
+                initialization is applied. Default: False.
+            **kwargs: additional arguments to :py:class:`~BaseModel`.
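+
+        Example:
+            Illustrative sketch only; assumes an already constructed
+            :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`
+            named ``training``::
+
+                model = NBeatsKAN.from_dataset(
+                    training,
+                    stack_types=["trend", "seasonality"],
+                    num=7,  # finer KAN grid than the default of 5
+                    k=3,  # cubic piecewise polynomials
+                )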
+        """  # noqa: E501
+
+        if base_fun is None:
+            base_fun = torch.nn.SiLU()
+        if grid_range is None:
+            grid_range = [-1, 1]
+        if expansion_coefficient_lengths is None:
+            expansion_coefficient_lengths = [3, 7]
+        if sharing is None:
+            sharing = [True, True]
+        if widths is None:
+            widths = [32, 512]
+        if num_block_layers is None:
+            num_block_layers = [3, 3]
+        if num_blocks is None:
+            num_blocks = [3, 3]
+        if stack_types is None:
+            stack_types = ["trend", "seasonality"]
+        if logging_metrics is None:
+            logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
+        if loss is None:
+            loss = MASE()
+
+        self.save_hyperparameters(ignore=["loss", "logging_metrics"])
+        super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
+
+        # Bundle KAN parameters into a dictionary
+        kan_params = {
+            "num": num,
+            "k": k,
+            "noise_scale": noise_scale,
+            "scale_base_mu": scale_base_mu,
+            "scale_base_sigma": scale_base_sigma,
+            "scale_sp": scale_sp,
+            "base_fun": base_fun,
+            "grid_eps": grid_eps,
+            "grid_range": grid_range,
+            "sp_trainable": sp_trainable,
+            "sb_trainable": sb_trainable,
+            "sparse_init": sparse_init,
+        }
+        self.kan_params = kan_params
+        # setup stacks
+        self.net_blocks = nn.ModuleList()
+        for stack_id, stack_type in enumerate(stack_types):
+            for _ in range(num_blocks[stack_id]):
+                if stack_type == "generic":
+                    net_block = NBEATSGenericBlock(
+                        units=self.hparams.widths[stack_id],
+                        thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id],
+                        num_block_layers=self.hparams.num_block_layers[stack_id],
+                        backcast_length=context_length,
+                        forecast_length=prediction_length,
+                        dropout=dropout,
+                        kan_params=self.kan_params,
+                        use_kan=True,
+                    )
+                elif stack_type == "seasonality":
+                    net_block = NBEATSSeasonalBlock(
+                        units=self.hparams.widths[stack_id],
+                        num_block_layers=self.hparams.num_block_layers[stack_id],
+                        backcast_length=context_length,
+                        forecast_length=prediction_length,
+                        min_period=expansion_coefficient_lengths[stack_id],
+                        dropout=dropout,
+                        kan_params=self.kan_params,
+                        use_kan=True,
+                    )
+                elif stack_type == "trend":
+                    net_block = NBEATSTrendBlock(
+                        units=self.hparams.widths[stack_id],
+                        thetas_dim=self.hparams.expansion_coefficient_lengths[stack_id],
+                        num_block_layers=self.hparams.num_block_layers[stack_id],
+                        backcast_length=context_length,
+                        forecast_length=prediction_length,
+                        dropout=dropout,
+                        kan_params=self.kan_params,
+                        use_kan=True,
+                    )
+                else:
+                    raise ValueError(f"Unknown stack type {stack_type}")
+
+                self.net_blocks.append(net_block)
+
+    def update_kan_grid(self):
+        """
+        Update the grids of the KAN layers used in the NBEATSBlock modules.
+        """
+        for block in self.net_blocks:
+            # update logic taken from
+            # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682
+            for i, layer in enumerate(block.fc):
+                # update basis KAN layers' grid
+                layer.update_grid_from_samples(block.outputs[i])
+            # update theta backward and theta forward KAN layers' grid
+            block.theta_b_fc.update_grid_from_samples(block.outputs[i + 1])
+            block.theta_f_fc.update_grid_from_samples(block.outputs[i + 1])
diff --git a/pytorch_forecasting/models/nbeats/nbeats_adapter.py b/pytorch_forecasting/models/nbeats/nbeats_adapter.py
new file mode 100644
index 000000000..d08d4c5ca
--- /dev/null
+++ b/pytorch_forecasting/models/nbeats/nbeats_adapter.py
@@ -0,0 +1,322 @@
+"""
+N-Beats model adapter for timeseries forecasting without covariates.
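+
+Bundles the forward pass, the dataset validation in ``from_dataset``, the
+training/validation step with optional backcast loss, and the interpretation
+plotting shared by the NBeats and NBeatsKAN models.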
+
+"""
+
+from typing import Dict, List, Optional
+
+import torch
+from torch import nn
+
+from pytorch_forecasting.data import TimeSeriesDataSet
+from pytorch_forecasting.data.encoders import NaNLabelEncoder
+from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric
+from pytorch_forecasting.models.base_model import BaseModel
+from pytorch_forecasting.models.nbeats.sub_modules import (
+    NBEATSGenericBlock,
+    NBEATSSeasonalBlock,
+    NBEATSTrendBlock,
+)
+from pytorch_forecasting.utils._dependencies import _check_matplotlib
+
+
+class NBeatsAdapter(BaseModel):
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        """
+        Initialize NBeats Adapter.
+
+        Args:
+            **kwargs: additional arguments to :py:class:`~BaseModel`.
+        """  # noqa: E501
+        super().__init__(**kwargs)
+
+    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        Forward pass of the network.
+
+        Args:
+            x (Dict[str, torch.Tensor]): input from dataloader generated from
+                :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
+
+        Returns:
+            Dict[str, torch.Tensor]: output of model
+        """
+        target = x["encoder_cont"][..., 0]
+
+        timesteps = self.hparams.context_length + self.hparams.prediction_length
+        generic_forecast = [
+            torch.zeros(
+                (target.size(0), timesteps), dtype=torch.float32, device=self.device
+            )
+        ]
+        trend_forecast = [
+            torch.zeros(
+                (target.size(0), timesteps), dtype=torch.float32, device=self.device
+            )
+        ]
+        seasonal_forecast = [
+            torch.zeros(
+                (target.size(0), timesteps), dtype=torch.float32, device=self.device
+            )
+        ]
+        forecast = torch.zeros(
+            (target.size(0), self.hparams.prediction_length),
+            dtype=torch.float32,
+            device=self.device,
+        )
+
+        backcast = target  # initialize backcast
+        for i, block in enumerate(self.net_blocks):
+            # evaluate block
+            backcast_block, forecast_block = block(backcast)
+
+            # add for interpretation
+            full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1)
+            if isinstance(block, NBEATSTrendBlock):
+                trend_forecast.append(full)
+            elif isinstance(block, NBEATSSeasonalBlock):
+                seasonal_forecast.append(full)
+            else:
+                generic_forecast.append(full)
+
+            # update backcast and forecast
+            backcast = (
+                backcast - backcast_block
+            )  # do not use backcast -= backcast_block as this is an in-place operation  # noqa : E501
+            forecast = forecast + forecast_block
+
+        return self.to_network_output(
+            prediction=self.transform_output(forecast, target_scale=x["target_scale"]),
+            backcast=self.transform_output(
+                prediction=target - backcast, target_scale=x["target_scale"]
+            ),
+            trend=self.transform_output(
+                torch.stack(trend_forecast, dim=0).sum(0),
+                target_scale=x["target_scale"],
+            ),
+            seasonality=self.transform_output(
+                torch.stack(seasonal_forecast, dim=0).sum(0),
+                target_scale=x["target_scale"],
+            ),
+            generic=self.transform_output(
+                torch.stack(generic_forecast, dim=0).sum(0),
+                target_scale=x["target_scale"],
+            ),
+        )
+
+    @classmethod
+    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
+        """
+        Convenience function to create network from
+        :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
+
+        Args:
+            dataset (TimeSeriesDataSet): dataset where the sole predictor is the target.
+            **kwargs: additional arguments to be passed to ``__init__`` method.
+
+        Returns:
+            NBeatsAdapter: the instantiated model
+        """  # noqa: E501
+        new_kwargs = {
+            "prediction_length": dataset.max_prediction_length,
+            "context_length": dataset.max_encoder_length,
+        }
+        new_kwargs.update(kwargs)
+
+        # validate arguments
+        assert isinstance(
+            dataset.target, str
+        ), "only one target is allowed (passed as string to dataset)"
+        assert not isinstance(
+            dataset.target_normalizer, NaNLabelEncoder
+        ), "only regression tasks are supported - target must not be categorical"
+        assert dataset.min_encoder_length == dataset.max_encoder_length, (
+            "only fixed encoder length is allowed,"
+            " but min_encoder_length != max_encoder_length"
+        )
+
+        assert dataset.max_prediction_length == dataset.min_prediction_length, (
+            "only fixed prediction length is allowed,"
+            " but max_prediction_length != min_prediction_length"
+        )
+
+        assert (
+            dataset.randomize_length is None
+        ), "length has to be fixed, but randomize_length is not None"
+        assert (
+            not dataset.add_relative_time_idx
+        ), "add_relative_time_idx has to be False"
+
+        assert (
+            len(dataset.flat_categoricals) == 0
+            and len(dataset.reals) == 1
+            and len(dataset._time_varying_unknown_reals) == 1
+            and dataset._time_varying_unknown_reals[0] == dataset.target
+        ), (
+            "The only input variable should be the"
+            " target, which must be part of time_varying_unknown_reals"
+        )
+
+        # initialize class
+        return super().from_dataset(dataset, **new_kwargs)
+
+    def step(self, x, y, batch_idx) -> Dict[str, torch.Tensor]:
+        """
+        Take training / validation step.
+        """
+        log, out = super().step(x, y, batch_idx=batch_idx)
+
+        if (
+            self.hparams.backcast_loss_ratio > 0 and not self.predicting
+        ):  # add loss from backcast
+            backcast = out["backcast"]
+            backcast_weight = (
+                self.hparams.backcast_loss_ratio
+                * self.hparams.prediction_length
+                / self.hparams.context_length
+            )
+            backcast_weight = backcast_weight / (backcast_weight + 1)  # normalize
+            forecast_weight = 1 - backcast_weight
+            if isinstance(self.loss, MASE):
+                backcast_loss = (
+                    self.loss(backcast, x["encoder_target"], x["decoder_target"])
+                    * backcast_weight
+                )
+            else:
+                backcast_loss = (
+                    self.loss(backcast, x["encoder_target"]) * backcast_weight
+                )
+            label = ["val", "train"][self.training]
+            self.log(
+                f"{label}_backcast_loss",
+                backcast_loss,
+                on_epoch=True,
+                on_step=self.training,
+                batch_size=len(x["decoder_target"]),
+            )
+            self.log(
+                f"{label}_forecast_loss",
+                log["loss"],
+                on_epoch=True,
+                on_step=self.training,
+                batch_size=len(x["decoder_target"]),
+            )
+            log["loss"] = log["loss"] * forecast_weight + backcast_loss
+
+        self.log_interpretation(x, out, batch_idx=batch_idx)
+        return log, out
+
+    def log_interpretation(self, x, out, batch_idx):
+        """
+        Log interpretation of network predictions in TensorBoard.
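+
+        Figures are only logged if matplotlib is available, the logger supports
+        ``add_figure`` and ``batch_idx`` is a multiple of a positive ``log_interval``.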
+        """
+        mpl_available = _check_matplotlib("log_interpretation", raise_error=False)
+
+        # Don't log figures if matplotlib or add_figure is not available
+        if not mpl_available or not self._logger_supports("add_figure"):
+            return None
+
+        label = ["val", "train"][self.training]
+        if self.log_interval > 0 and batch_idx % self.log_interval == 0:
+            fig = self.plot_interpretation(x, out, idx=0)
+            name = f"{label.capitalize()} interpretation of item 0 in "
+            if self.training:
+                name += f"step {self.global_step}"
+            else:
+                name += f"batch {batch_idx}"
+            self.logger.experiment.add_figure(name, fig, global_step=self.global_step)
+
+    def plot_interpretation(
+        self,
+        x: Dict[str, torch.Tensor],
+        output: Dict[str, torch.Tensor],
+        idx: int,
+        ax=None,
+        plot_seasonality_and_generic_on_secondary_axis: bool = False,
+    ):
+        """
+        Plot interpretation.
+
+        Plot two panels: prediction and backcast vs actuals, and
+        decomposition of prediction into trend, seasonality and generic forecast.
+
+        Args:
+            x (Dict[str, torch.Tensor]): network input
+            output (Dict[str, torch.Tensor]): network output
+            idx (int): index of sample for which to plot the interpretation.
+            ax (List[matplotlib axes], optional): list of two matplotlib axes onto which
+                to plot the interpretation. Defaults to None.
+            plot_seasonality_and_generic_on_secondary_axis (bool, optional): whether to
+                plot seasonality and generic forecast on secondary axis in second panel.
+                Defaults to False.
+
+        Returns:
+            plt.Figure: matplotlib figure
+        """  # noqa: E501
+        _check_matplotlib("plot_interpretation")
+
+        import matplotlib.pyplot as plt
+
+        if ax is None:
+            fig, ax = plt.subplots(2, 1, figsize=(6, 8))
+        else:
+            fig = ax[0].get_figure()
+
+        time = torch.arange(
+            -self.hparams.context_length, self.hparams.prediction_length
+        )
+
+        # plot target vs prediction
+        ax[0].plot(
+            time,
+            torch.cat([x["encoder_target"][idx], x["decoder_target"][idx]])
+            .detach()
+            .cpu(),
+            label="target",
+        )
+        ax[0].plot(
+            time,
+            torch.cat(
+                [
+                    output["backcast"][idx].detach(),
+                    output["prediction"][idx].detach(),
+                ],
+                dim=0,
+            ).cpu(),
+            label="prediction",
+        )
+        ax[0].set_xlabel("Time")
+
+        # plot blocks
+        prop_cycle = iter(plt.rcParams["axes.prop_cycle"])
+        next(prop_cycle)  # prediction
+        next(prop_cycle)  # observations
+        if plot_seasonality_and_generic_on_secondary_axis:
+            ax2 = ax[1].twinx()
+            ax2.set_ylabel("Seasonality / Generic")
+        else:
+            ax2 = ax[1]
+        for title in ["trend", "seasonality", "generic"]:
+            if title not in self.hparams.stack_types:
+                continue
+            if title == "trend":
+                ax[1].plot(
+                    time,
+                    output[title][idx].detach().cpu(),
+                    label=title.capitalize(),
+                    c=next(prop_cycle)["color"],
+                )
+            else:
+                ax2.plot(
+                    time,
+                    output[title][idx].detach().cpu(),
+                    label=title.capitalize(),
+                    c=next(prop_cycle)["color"],
+                )
+        ax[1].set_xlabel("Time")
+        ax[1].set_ylabel("Decomposition")
+
+        fig.legend()
+        return fig
diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py
index 7ddf17a20..492017e5b 100644
--- a/pytorch_forecasting/models/nbeats/sub_modules.py
+++ b/pytorch_forecasting/models/nbeats/sub_modules.py
@@ -145,17 +145,20 @@ def forward(self, x):
         """
         Pass through the fully connected mlp/kan layers and returns the output.
""" - # outputs logic taken from - # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 - self.outputs = [] - self.outputs.append(x.clone().detach()) - for layer in self.fc: - x = layer(x) # Pass data through the current layer - # storing outputs for updating grids of self.fc when using KAN + if self.use_kan: + # save outputs to be used in updating grid in kan layers during training + # outputs logic taken from + # https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L2682 + self.outputs = [] + self.outputs.append(x.clone().detach()) + for layer in self.fc: + x = layer(x) # Pass data through the current layer + # storing outputs for updating grids of self.fc when using KAN + self.outputs.append(x.clone().detach()) + # storing for updating grids of theta_b_fc and theta_f_fc when using KAN self.outputs.append(x.clone().detach()) - # storing for updating grids of theta_b_fc and theta_f_fc when using KAN - self.outputs.append(x.clone().detach()) - return x # Return final output + return x # Return final output + return self.fc(x) class NBEATSSeasonalBlock(NBEATSBlock): From 7070f8b429854ef040b82b5cf4659f489738dcaa Mon Sep 17 00:00:00 2001 From: Sohaib-Ahmed21 Date: Sat, 22 Feb 2025 23:13:35 -0800 Subject: [PATCH 8/8] Made modules private. --- examples/nbeats_with_kan.py | 2 +- pytorch_forecasting/models/nbeats/__init__.py | 4 ++++ .../models/nbeats/{grid_callback.py => _grid_callback.py} | 0 .../models/nbeats/{kan_layer.py => _kan_layer.py} | 0 pytorch_forecasting/models/nbeats/_nbeats.py | 2 +- .../models/nbeats/{nbeats_adapter.py => _nbeats_adapter.py} | 0 pytorch_forecasting/models/nbeats/_nbeatskan.py | 2 +- pytorch_forecasting/models/nbeats/sub_modules.py | 2 +- 8 files changed, 8 insertions(+), 4 deletions(-) rename pytorch_forecasting/models/nbeats/{grid_callback.py => _grid_callback.py} (100%) rename pytorch_forecasting/models/nbeats/{kan_layer.py => _kan_layer.py} (100%) rename pytorch_forecasting/models/nbeats/{nbeats_adapter.py => _nbeats_adapter.py} (100%) diff --git a/examples/nbeats_with_kan.py b/examples/nbeats_with_kan.py index 952a2acce..6a018ce5d 100644 --- a/examples/nbeats_with_kan.py +++ b/examples/nbeats_with_kan.py @@ -7,7 +7,7 @@ from pytorch_forecasting import NBeatsKAN, TimeSeriesDataSet from pytorch_forecasting.data import NaNLabelEncoder from pytorch_forecasting.data.examples import generate_ar_data -from pytorch_forecasting.models.nbeats.grid_callback import GridUpdateCallback +from pytorch_forecasting.models.nbeats import GridUpdateCallback sys.path.append("..") diff --git a/pytorch_forecasting/models/nbeats/__init__.py b/pytorch_forecasting/models/nbeats/__init__.py index 87c1fe7fb..b588093af 100644 --- a/pytorch_forecasting/models/nbeats/__init__.py +++ b/pytorch_forecasting/models/nbeats/__init__.py @@ -1,6 +1,8 @@ """N-Beats model for timeseries forecasting without covariates.""" +from pytorch_forecasting.models.nbeats._grid_callback import GridUpdateCallback from pytorch_forecasting.models.nbeats._nbeats import NBeats +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter from pytorch_forecasting.models.nbeats._nbeatskan import NBeatsKAN from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, @@ -14,4 +16,6 @@ "NBEATSGenericBlock", "NBEATSSeasonalBlock", "NBEATSTrendBlock", + "NBeatsAdapter", + "GridUpdateCallback", ] diff --git a/pytorch_forecasting/models/nbeats/grid_callback.py b/pytorch_forecasting/models/nbeats/_grid_callback.py similarity index 100% rename from 
pytorch_forecasting/models/nbeats/grid_callback.py rename to pytorch_forecasting/models/nbeats/_grid_callback.py diff --git a/pytorch_forecasting/models/nbeats/kan_layer.py b/pytorch_forecasting/models/nbeats/_kan_layer.py similarity index 100% rename from pytorch_forecasting/models/nbeats/kan_layer.py rename to pytorch_forecasting/models/nbeats/_kan_layer.py diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py index f85067e22..3326bb5a9 100644 --- a/pytorch_forecasting/models/nbeats/_nbeats.py +++ b/pytorch_forecasting/models/nbeats/_nbeats.py @@ -7,7 +7,7 @@ from torch import nn from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.nbeats.nbeats_adapter import NBeatsAdapter +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, diff --git a/pytorch_forecasting/models/nbeats/nbeats_adapter.py b/pytorch_forecasting/models/nbeats/_nbeats_adapter.py similarity index 100% rename from pytorch_forecasting/models/nbeats/nbeats_adapter.py rename to pytorch_forecasting/models/nbeats/_nbeats_adapter.py diff --git a/pytorch_forecasting/models/nbeats/_nbeatskan.py b/pytorch_forecasting/models/nbeats/_nbeatskan.py index ea9591645..9df6b3d2e 100644 --- a/pytorch_forecasting/models/nbeats/_nbeatskan.py +++ b/pytorch_forecasting/models/nbeats/_nbeatskan.py @@ -8,7 +8,7 @@ from torch import nn from pytorch_forecasting.metrics import MAE, MAPE, MASE, RMSE, SMAPE, MultiHorizonMetric -from pytorch_forecasting.models.nbeats.nbeats_adapter import NBeatsAdapter +from pytorch_forecasting.models.nbeats._nbeats_adapter import NBeatsAdapter from pytorch_forecasting.models.nbeats.sub_modules import ( NBEATSGenericBlock, NBEATSSeasonalBlock, diff --git a/pytorch_forecasting/models/nbeats/sub_modules.py b/pytorch_forecasting/models/nbeats/sub_modules.py index 492017e5b..e1ea1288f 100644 --- a/pytorch_forecasting/models/nbeats/sub_modules.py +++ b/pytorch_forecasting/models/nbeats/sub_modules.py @@ -9,7 +9,7 @@ import torch.nn as nn import torch.nn.functional as F -from pytorch_forecasting.models.nbeats.kan_layer import KANLayer +from pytorch_forecasting.models.nbeats._kan_layer import KANLayer def linear(input_size, output_size, bias=True, dropout: int = None):
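
A usage sketch for the refactored API (illustrative only: the dataset,
dataloaders and trainer settings below are assumptions, and the
`update_interval` argument of `GridUpdateCallback` is an assumed constructor
parameter, not a signature confirmed by this patch):

    import lightning.pytorch as pl

    from pytorch_forecasting import NBeatsKAN
    from pytorch_forecasting.models.nbeats import GridUpdateCallback

    # `training` and `validation` are assumed TimeSeriesDataSet objects whose
    # only input variable is the target, as NBeatsAdapter.from_dataset enforces.
    net = NBeatsKAN.from_dataset(training, learning_rate=1e-2)
    train_dataloader = training.to_dataloader(train=True, batch_size=64)
    val_dataloader = validation.to_dataloader(train=False, batch_size=64)

    # The callback is expected to invoke net.update_kan_grid() periodically so
    # that the KAN spline grids are refitted from the activations each block
    # stores in `outputs` during its forward pass.
    trainer = pl.Trainer(
        max_epochs=10,
        gradient_clip_val=0.01,
        callbacks=[GridUpdateCallback(update_interval=100)],  # assumed parameter
    )
    trainer.fit(
        net,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )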