diff --git a/docs/source/tutorials/tslib_v2_example.ipynb b/docs/source/tutorials/tslib_v2_example.ipynb
new file mode 100644
index 000000000..8af751de9
--- /dev/null
+++ b/docs/source/tutorials/tslib_v2_example.ipynb
@@ -0,0 +1,1431 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "b5d44943",
+ "metadata": {},
+ "source": [
+ "# TSLib for v2 - Example notebook for full pipeline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b7d27b55",
+ "metadata": {},
+ "source": [
+ "## Basic imports for getting started\n",
+ "\n",
+ "This notebook is a basic vignette for the usage of the `tslib` data module on the `TimeXer` model for the v2 of PyTorch Forecasting. This is an experimental version and is an unstable version of the API.\n",
+ "\n",
+ "Feedback and suggestions on this pipeline - PR [#1836](https://github.com/sktime/pytorch-forecasting/pull/1836)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "550a3fbf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from typing import Any, Optional, Union\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from sklearn.preprocessing import RobustScaler, StandardScaler\n",
+ "import torch\n",
+ "from torch.optim import Optimizer\n",
+ "from torch.utils.data import Dataset\n",
+ "\n",
+ "from pytorch_forecasting.data._tslib_data_module import TslibDataModule\n",
+ "from pytorch_forecasting.data.encoders import (\n",
+ " EncoderNormalizer,\n",
+ " NaNLabelEncoder,\n",
+ " TorchNormalizer,\n",
+ ")\n",
+ "from pytorch_forecasting.data.timeseries import TimeSeries\n",
+ "from pytorch_forecasting.models.timexer._timexer_v2 import TimeXer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2625ed3d",
+ "metadata": {},
+ "source": [
+ "## Construct a time series dataset\n",
+ "\n",
+ "This step requires us to build a `TimeSeries` object for creating a time series dataset, which identifies the features from a raw time series dataset. As you can see below, we are initialising a sample time series dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "a0058487",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.microsoft.datawrangler.viewer.v0+json": {
+ "columns": [
+ {
+ "name": "index",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "series_id",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "time_idx",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "x",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "y",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "category",
+ "rawType": "int64",
+ "type": "integer"
+ },
+ {
+ "name": "future_known_feature",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "static_feature",
+ "rawType": "float64",
+ "type": "float"
+ },
+ {
+ "name": "static_feature_cat",
+ "rawType": "int64",
+ "type": "integer"
+ }
+ ],
+ "ref": "9a040c8c-9b72-4d64-ad12-c5d4702ced35",
+ "rows": [
+ [
+ "0",
+ "0",
+ "0",
+ "-0.03319064433379144",
+ "0.22982012179859285",
+ "0",
+ "1.0",
+ "0.4945926741169627",
+ "0"
+ ],
+ [
+ "1",
+ "0",
+ "1",
+ "0.22982012179859285",
+ "0.4612874019620733",
+ "0",
+ "0.9950041652780258",
+ "0.4945926741169627",
+ "0"
+ ],
+ [
+ "2",
+ "0",
+ "2",
+ "0.4612874019620733",
+ "0.5387362265604877",
+ "0",
+ "0.9800665778412416",
+ "0.4945926741169627",
+ "0"
+ ],
+ [
+ "3",
+ "0",
+ "3",
+ "0.5387362265604877",
+ "0.8368343109148751",
+ "0",
+ "0.955336489125606",
+ "0.4945926741169627",
+ "0"
+ ],
+ [
+ "4",
+ "0",
+ "4",
+ "0.8368343109148751",
+ "0.7705107068068119",
+ "0",
+ "0.9210609940028851",
+ "0.4945926741169627",
+ "0"
+ ]
+ ],
+ "shape": {
+ "columns": 8,
+ "rows": 5
+ }
+ },
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " series_id | \n",
+ " time_idx | \n",
+ " x | \n",
+ " y | \n",
+ " category | \n",
+ " future_known_feature | \n",
+ " static_feature | \n",
+ " static_feature_cat | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0 | \n",
+ " 0 | \n",
+ " -0.033191 | \n",
+ " 0.229820 | \n",
+ " 0 | \n",
+ " 1.000000 | \n",
+ " 0.494593 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 0.229820 | \n",
+ " 0.461287 | \n",
+ " 0 | \n",
+ " 0.995004 | \n",
+ " 0.494593 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0 | \n",
+ " 2 | \n",
+ " 0.461287 | \n",
+ " 0.538736 | \n",
+ " 0 | \n",
+ " 0.980067 | \n",
+ " 0.494593 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0 | \n",
+ " 3 | \n",
+ " 0.538736 | \n",
+ " 0.836834 | \n",
+ " 0 | \n",
+ " 0.955336 | \n",
+ " 0.494593 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0 | \n",
+ " 4 | \n",
+ " 0.836834 | \n",
+ " 0.770511 | \n",
+ " 0 | \n",
+ " 0.921061 | \n",
+ " 0.494593 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " series_id time_idx x y category future_known_feature \\\n",
+ "0 0 0 -0.033191 0.229820 0 1.000000 \n",
+ "1 0 1 0.229820 0.461287 0 0.995004 \n",
+ "2 0 2 0.461287 0.538736 0 0.980067 \n",
+ "3 0 3 0.538736 0.836834 0 0.955336 \n",
+ "4 0 4 0.836834 0.770511 0 0.921061 \n",
+ "\n",
+ " static_feature static_feature_cat \n",
+ "0 0.494593 0 \n",
+ "1 0.494593 0 \n",
+ "2 0.494593 0 \n",
+ "3 0.494593 0 \n",
+ "4 0.494593 0 "
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "num_series = 100\n",
+ "seq_length = 50\n",
+ "data_list = []\n",
+ "for i in range(num_series):\n",
+ " x = np.arange(seq_length)\n",
+ " y = np.sin(x / 5.0) + np.random.normal(scale=0.1, size=seq_length)\n",
+ " category = i % 5\n",
+ " static_value = np.random.rand()\n",
+ " for t in range(seq_length - 1):\n",
+ " data_list.append(\n",
+ " {\n",
+ " \"series_id\": i,\n",
+ " \"time_idx\": t,\n",
+ " \"x\": y[t],\n",
+ " \"y\": y[t + 1],\n",
+ " \"category\": category,\n",
+ " \"future_known_feature\": np.cos(t / 10),\n",
+ " \"static_feature\": static_value,\n",
+ " \"static_feature_cat\": i % 3,\n",
+ " }\n",
+ " )\n",
+ "data_df = pd.DataFrame(data_list)\n",
+ "data_df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c7c04ff5",
+ "metadata": {},
+ "source": [
+ "## Feature Categories and Definitions\n",
+ "\n",
+ "### **`time_idx`**\n",
+ "- **Definition**: The temporal index column that orders observations chronologically\n",
+ "- **Example**: Sequential time steps (0, 1, 2, ...) or timestamps\n",
+ "- **Usage**: Identifies the temporal ordering of data points within each time series\n",
+ "\n",
+ "### **`target`** \n",
+ "- **Definition**: The variable you want to predict/forecast\n",
+ "- **Example**: Sales volume, stock price, temperature readings\n",
+ "- **Usage**: The dependent variable that the model learns to forecast\n",
+ "\n",
+ "### **`group`**\n",
+ "- **Definition**: Categorical variables that identify different time series entities\n",
+ "- **Example**: `series_id`, `store_id`, `product_id`, `customer_id`\n",
+ "- **Usage**: Distinguishes between multiple time series in the dataset\n",
+ "\n",
+ "### **`num`**\n",
+ "- **Definition**: Numerical/continuous features used as model inputs\n",
+ "- **Example**: Price, quantity, weather data, economic indicators \n",
+ "- **Usage**: Continuous variables that provide numerical context for predictions\n",
+ "\n",
+ "### **`cat`**\n",
+ "- **Definition**: Categorical features that represent discrete classes or labels\n",
+ "- **Example**: Product category, day of week, seasonal indicators, region\n",
+ "- **Usage**: Discrete variables that provide categorical context for predictions\n",
+ "\n",
+ "### **`known`**\n",
+ "- **Definition**: Future values that are known at prediction time (exogenous variables)\n",
+ "- **Example**: Holidays, planned promotions, scheduled events, calendar features\n",
+ "- **Usage**: Information available for both historical and future periods\n",
+ "\n",
+ "### **`unknown`**\n",
+ "- **Definition**: Variables only available during training/historical periods\n",
+ "- **Example**: Past weather conditions, historical prices, competitor actions\n",
+ "- **Usage**: Features that help with training but aren't available for future predictions\n",
+ "\n",
+ "### **`static`**\n",
+ "- **Definition**: Time-invariant features that remain constant for each time series\n",
+ "- **Example**: Store size, product attributes, geographic location, customer demographics\n",
+ "- **Usage**: Entity-specific characteristics that don't change over time"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "89a5adbe",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\pytorch_forecasting\\data\\timeseries\\_timeseries_v2.py:105: UserWarning: TimeSeries is part of an experimental rework of the pytorch-forecasting data layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. For beta testing, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n",
+ " warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = TimeSeries(\n",
+ " data=data_df,\n",
+ " time=\"time_idx\",\n",
+ " target=\"y\",\n",
+ " group=[\"series_id\"],\n",
+ " num=[\"x\", \"future_know_feature\", \"static_feature\"],\n",
+ " cat=[\"category\", \"static_feature_cat\"],\n",
+ " known=[\"future_known_feature\"],\n",
+ " unknown=[\"x\", \"category\"],\n",
+ " static=[\"static_feature\", \"static_feature_cat\"],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f8753a6a",
+ "metadata": {},
+ "source": [
+ "## Initialise the `TslibDataModule` using the dataset\n",
+ "\n",
+ "This steps initialises a basic data module built specially for `tslib` modules and provides all the metadata required to train and implement the `tslib` of your choice!\n",
+ "You can refer the implementation for `TslibDataModule` for more information."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "5eae9035",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\pytorch_forecasting\\data\\_tslib_data_module.py:271: UserWarning: TslibDataModule is experimental and subject to change. The API is not stable and may change without prior warning.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "data_module = TslibDataModule(\n",
+ " time_series_dataset=dataset,\n",
+ " context_length=30,\n",
+ " prediction_length=1,\n",
+ " add_relative_time_idx=True,\n",
+ " target_normalizer=TorchNormalizer(),\n",
+ " categorical_encoders={\n",
+ " \"category\": NaNLabelEncoder(add_nan=True),\n",
+ " \"static_feature_cat\": NaNLabelEncoder(add_nan=True),\n",
+ " },\n",
+ " scalers={\n",
+ " \"x\": StandardScaler(),\n",
+ " \"future_known_feature\": StandardScaler(),\n",
+ " \"static_feature\": StandardScaler(),\n",
+ " },\n",
+ " batch_size=32,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "b1843233",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'feature_names': {'categorical': ['category', 'static_feature_cat'],\n",
+ " 'continuous': ['x', 'future_known_feature', 'static_feature'],\n",
+ " 'static': ['static_feature', 'static_feature_cat'],\n",
+ " 'known': ['future_known_feature'],\n",
+ " 'unknown': ['x', 'category', 'static_feature', 'static_feature_cat'],\n",
+ " 'target': ['y'],\n",
+ " 'all': ['x',\n",
+ " 'category',\n",
+ " 'future_known_feature',\n",
+ " 'static_feature',\n",
+ " 'static_feature_cat'],\n",
+ " 'static_categorical': ['static_feature_cat'],\n",
+ " 'static_continuous': ['static_feature']},\n",
+ " 'feature_indices': {'categorical': [1, 4],\n",
+ " 'continuous': [0, 2, 3],\n",
+ " 'static': [],\n",
+ " 'known': [2],\n",
+ " 'unknown': [0, 1, 3, 4],\n",
+ " 'target': [0]},\n",
+ " 'n_features': {'categorical': 2,\n",
+ " 'continuous': 3,\n",
+ " 'static': 2,\n",
+ " 'known': 1,\n",
+ " 'unknown': 4,\n",
+ " 'target': 1,\n",
+ " 'all': 5,\n",
+ " 'static_categorical': 1,\n",
+ " 'static_continuous': 1},\n",
+ " 'context_length': 30,\n",
+ " 'prediction_length': 1,\n",
+ " 'freq': 'h',\n",
+ " 'features': 'MS'}"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data_module.metadata"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "dd9451ee",
+ "metadata": {},
+ "source": [
+ "## Initialise the model\n",
+ "\n",
+ "We shall try out two versions of this model, one using `MAE()` and one with `QuantileLoss()`.\n",
+ "\n",
+ "Let us quickly import the required packages for the next steps."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "f6b568a5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch.nn as nn\n",
+ "\n",
+ "from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "429b5f15",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\pytorch_forecasting\\models\\base\\_base_model_v2.py:58: UserWarning: The Model 'TimeXer' is part of an experimental reworkof the pytorch-forecasting model layer, scheduled for release with v2.0.0. The API is not stable and may change without prior warning. This class is intended for beta testing and as a basic skeleton, but not for stable production use. Feedback and suggestions are very welcome in pytorch-forecasting issue 1736, https://github.com/sktime/pytorch-forecasting/issues/1736\n",
+ " warn(\n",
+ "C:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\pytorch_forecasting\\models\\base\\_tslib_base_model_v2.py:60: UserWarning: The Model 'TimeXer' is part of an experimental implementationof the pytorch-forecasting model layer for Time Series Library, scheduledfor release with v2.0.0. The API is not stableand may change without prior warning. This class is intended for betatesting, not for stable production use.\n",
+ " warn(\n",
+ "C:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\pytorch_forecasting\\models\\timexer\\_timexer_v2.py:133: UserWarning: TimeXer is an experimental model implemented on TslibBaseModelV2. It is an unstable version and maybe subject to unannouced changes.Please use with caution. Feedback on the design and implementation iswelcome. On the issue #1833 - https://github.com/sktime/pytorch-forecasting/issues/1833\n",
+ " warn.warn(\n",
+ "C:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\pytorch_forecasting\\models\\timexer\\_timexer_v2.py:180: UserWarning: Context length (30) is not divisible by patch length. This may lead to unexpected behavior, as sometime steps will not be used in the model.\n",
+ " warn.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "model1 = TimeXer(\n",
+ " loss=nn.MSELoss(),\n",
+ " hidden_size=64,\n",
+ " nhead=4,\n",
+ " e_layers=2,\n",
+ " d_ff=256,\n",
+ " dropout=0.1,\n",
+ " patch_length=4,\n",
+ " logging_metrics=[MAE(), SMAPE()],\n",
+ " optimizer=\"adam\",\n",
+ " optimizer_params={\"lr\": 1e-3},\n",
+ " lr_scheduler=\"reduce_lr_on_plateau\",\n",
+ " lr_scheduler_params={\n",
+ " \"mode\": \"min\",\n",
+ " \"factor\": 0.5,\n",
+ " \"patience\": 5,\n",
+ " },\n",
+ " metadata=data_module.metadata,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "0aa21f48",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model2 = TimeXer(\n",
+ " loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]), # quantiles of 0.1, 0.5 and 0.9 used.\n",
+ " hidden_size=64,\n",
+ " nhead=4,\n",
+ " e_layers=2,\n",
+ " d_ff=256,\n",
+ " dropout=0.1,\n",
+ " patch_length=4,\n",
+ " logging_metrics=[MAE(), SMAPE()],\n",
+ " optimizer=\"adam\",\n",
+ " optimizer_params={\"lr\": 1e-3},\n",
+ " lr_scheduler=\"reduce_lr_on_plateau\",\n",
+ " lr_scheduler_params={\n",
+ " \"mode\": \"min\",\n",
+ " \"factor\": 0.5,\n",
+ " \"patience\": 5,\n",
+ " },\n",
+ " metadata=data_module.metadata,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "02605f9b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n",
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "HPU available: False, using: 0 HPUs\n",
+ "Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.\n",
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ }
+ ],
+ "source": [
+ "from lightning.pytorch import Trainer\n",
+ "\n",
+ "trainer1 = Trainer(\n",
+ " max_epochs=5,\n",
+ " accelerator=\"auto\",\n",
+ " devices=1,\n",
+ " enable_progress_bar=True,\n",
+ " enable_model_summary=True,\n",
+ ")\n",
+ "\n",
+ "trainer2 = Trainer(\n",
+ " max_epochs=4,\n",
+ " accelerator=\"auto\",\n",
+ " devices=1,\n",
+ " enable_progress_bar=True,\n",
+ " enable_model_summary=True,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e22756b2",
+ "metadata": {},
+ "source": [
+ "## Fit the trainer on the model and feed data using the data module"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "6e9117d2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You are using a CUDA device ('NVIDIA GeForce RTX 3050 6GB Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "\n",
+ " | Name | Type | Params | Mode \n",
+ "----------------------------------------------------------------\n",
+ "0 | loss | MSELoss | 0 | train\n",
+ "1 | en_embedding | EnEmbedding | 320 | train\n",
+ "2 | ex_embedding | DataEmbedding_inverted | 2.0 K | train\n",
+ "3 | encoder | Encoder | 133 K | train\n",
+ "4 | head | FlattenHead | 513 | train\n",
+ "----------------------------------------------------------------\n",
+ "136 K Trainable params\n",
+ "0 Non-trainable params\n",
+ "136 K Total params\n",
+ "0.546 Total estimated model params size (MB)\n",
+ "57 Modules in train mode\n",
+ "0 Modules in eval mode\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ee7a04f4538241e9b735e9a48752f106",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Sanity Checking: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\.venv\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n",
+ "c:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\.venv\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n",
+ "c:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\.venv\\Lib\\site-packages\\lightning\\pytorch\\loops\\fit_loop.py:310: The number of training batches (42) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "319963d7730f4c0d8047009dbd9167ca",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Training: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "bf4329366a31411691efaf82f6ed16a5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7f264696d9c1404cb2ffa81f4f7a95b1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4bc9933c01884586be7e44e7c58a53a8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f3229660e28a41c4b0d884e08d99b565",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2545a9ecf028495297a4d3dcd918d618",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "`Trainer.fit` stopped: `max_epochs=5` reached.\n"
+ ]
+ }
+ ],
+ "source": [
+ "trainer1.fit(model1, data_module)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e8fb4f31",
+ "metadata": {},
+ "source": [
+ "Now let us train the model using `QuantileLoss`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "3c67d86f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "\n",
+ " | Name | Type | Params | Mode \n",
+ "----------------------------------------------------------------\n",
+ "0 | loss | QuantileLoss | 0 | train\n",
+ "1 | en_embedding | EnEmbedding | 320 | train\n",
+ "2 | ex_embedding | DataEmbedding_inverted | 2.0 K | train\n",
+ "3 | encoder | Encoder | 133 K | train\n",
+ "4 | head | FlattenHead | 1.5 K | train\n",
+ "----------------------------------------------------------------\n",
+ "137 K Trainable params\n",
+ "0 Non-trainable params\n",
+ "137 K Total params\n",
+ "0.550 Total estimated model params size (MB)\n",
+ "57 Modules in train mode\n",
+ "0 Modules in eval mode\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0c087d79f3584548ac84a2838066a731",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Sanity Checking: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\.venv\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n",
+ "c:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\.venv\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n",
+ "c:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\.venv\\Lib\\site-packages\\lightning\\pytorch\\loops\\fit_loop.py:310: The number of training batches (42) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "05fc57fe977b4cb0b8695857c5bdfb57",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Training: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ecaf412ba30f43019c3e154d42a1671d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6536856597a34b79ae3bb65b2ee6a0d4",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "89590ef6629c46449060c9bbc8747765",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "777c414a2b294d319aabc6500ff61566",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "`Trainer.fit` stopped: `max_epochs=4` reached.\n"
+ ]
+ }
+ ],
+ "source": [
+ "trainer2.fit(model2, data_module)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "16e2d445",
+ "metadata": {},
+ "source": [
+ "## Test the model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "dbf1ace6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "c:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\.venv\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "31779b6cf5cc4049aaedced8d2a2e956",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Testing: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
+ " Test metric DataLoader 0\n",
+ "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
+ " test_MAE 0.46785134077072144\n",
+ " test_SMAPE 1.0638009309768677\n",
+ " test_loss 0.014495044946670532\n",
+ "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
+ ]
+ }
+ ],
+ "source": [
+ "test_metrics = trainer1.test(model1, data_module)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "250b128a",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TimeXer(\n",
+ " (loss): MSELoss()\n",
+ " (en_embedding): EnEmbedding(\n",
+ " (value_embedding): Linear(in_features=4, out_features=64, bias=False)\n",
+ " (position_embedding): PositionalEmbedding()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (ex_embedding): DataEmbedding_inverted(\n",
+ " (value_embedding): Linear(in_features=30, out_features=64, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): Encoder(\n",
+ " (layers): ModuleList(\n",
+ " (0-1): 2 x EncoderLayer(\n",
+ " (self_attention): AttentionLayer(\n",
+ " (inner_attention): FullAttention(\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (query_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (key_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (value_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (out_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " )\n",
+ " (cross_attention): AttentionLayer(\n",
+ " (inner_attention): FullAttention(\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (query_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (key_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (value_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (out_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " )\n",
+ " (conv1): Conv1d(64, 256, kernel_size=(1,), stride=(1,))\n",
+ " (conv2): Conv1d(256, 64, kernel_size=(1,), stride=(1,))\n",
+ " (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (norm3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (head): FlattenHead(\n",
+ " (flatten): Flatten(start_dim=-2, end_dim=-1)\n",
+ " (linear): Linear(in_features=512, out_features=1, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model1.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "f730b49a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Prediction: tensor([[[-3.8579e-02]],\n",
+ "\n",
+ " [[ 1.3515e-01]],\n",
+ "\n",
+ " [[ 2.7090e-01]],\n",
+ "\n",
+ " [[ 4.3945e-01]],\n",
+ "\n",
+ " [[ 5.7105e-01]],\n",
+ "\n",
+ " [[ 7.0694e-01]],\n",
+ "\n",
+ " [[ 8.1090e-01]],\n",
+ "\n",
+ " [[ 8.7570e-01]],\n",
+ "\n",
+ " [[ 9.0934e-01]],\n",
+ "\n",
+ " [[ 9.0872e-01]],\n",
+ "\n",
+ " [[ 8.6581e-01]],\n",
+ "\n",
+ " [[ 7.9358e-01]],\n",
+ "\n",
+ " [[ 6.9972e-01]],\n",
+ "\n",
+ " [[ 5.8747e-01]],\n",
+ "\n",
+ " [[ 4.4550e-01]],\n",
+ "\n",
+ " [[ 2.9315e-01]],\n",
+ "\n",
+ " [[ 1.5351e-01]],\n",
+ "\n",
+ " [[-5.8678e-04]],\n",
+ "\n",
+ " [[-1.5129e-01]],\n",
+ "\n",
+ " [[ 1.4533e-02]],\n",
+ "\n",
+ " [[ 1.7025e-01]],\n",
+ "\n",
+ " [[ 3.5256e-01]],\n",
+ "\n",
+ " [[ 5.0771e-01]],\n",
+ "\n",
+ " [[ 6.4501e-01]],\n",
+ "\n",
+ " [[ 7.4584e-01]],\n",
+ "\n",
+ " [[ 8.4855e-01]],\n",
+ "\n",
+ " [[ 8.7391e-01]],\n",
+ "\n",
+ " [[ 9.2469e-01]],\n",
+ "\n",
+ " [[ 8.8924e-01]],\n",
+ "\n",
+ " [[ 8.6606e-01]],\n",
+ "\n",
+ " [[ 7.7753e-01]],\n",
+ "\n",
+ " [[ 6.8279e-01]]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " test_batch = next(iter(data_module.test_dataloader()))\n",
+ " x_test, y_test = test_batch\n",
+ " y_pred = model1(x_test)\n",
+ "\n",
+ " print(\"Prediction:\", y_pred[\"prediction\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "e316c047",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([32, 1, 1])"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_pred[\"prediction\"].shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a01927d4",
+ "metadata": {},
+ "source": [
+ "Let us do the same for `QuantileLoss` predictions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "22bd191f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "c:\\Users\\prana\\Desktop\\code\\pytorch-forecasting\\.venv\\Lib\\site-packages\\lightning\\pytorch\\trainer\\connectors\\data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "69e76b8fe0c841ca8e40f2569373d0f1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Testing: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
+ " Test metric DataLoader 0\n",
+ "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
+ " test_MAE 14.947474479675293\n",
+ " test_SMAPE 32.57101821899414\n",
+ " test_loss 5.774611473083496\n",
+ "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
+ ]
+ }
+ ],
+ "source": [
+ "test_metrics = trainer2.test(model2, data_module)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "a1d857db",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TimeXer(\n",
+ " (loss): QuantileLoss(quantiles=[0.1, 0.5, 0.9])\n",
+ " (en_embedding): EnEmbedding(\n",
+ " (value_embedding): Linear(in_features=4, out_features=64, bias=False)\n",
+ " (position_embedding): PositionalEmbedding()\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (ex_embedding): DataEmbedding_inverted(\n",
+ " (value_embedding): Linear(in_features=30, out_features=64, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): Encoder(\n",
+ " (layers): ModuleList(\n",
+ " (0-1): 2 x EncoderLayer(\n",
+ " (self_attention): AttentionLayer(\n",
+ " (inner_attention): FullAttention(\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (query_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (key_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (value_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (out_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " )\n",
+ " (cross_attention): AttentionLayer(\n",
+ " (inner_attention): FullAttention(\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (query_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (key_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (value_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " (out_projection): Linear(in_features=64, out_features=64, bias=True)\n",
+ " )\n",
+ " (conv1): Conv1d(64, 256, kernel_size=(1,), stride=(1,))\n",
+ " (conv2): Conv1d(256, 64, kernel_size=(1,), stride=(1,))\n",
+ " (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (norm3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (head): FlattenHead(\n",
+ " (flatten): Flatten(start_dim=-2, end_dim=-1)\n",
+ " (linear): Linear(in_features=512, out_features=3, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model2.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "52e2a36a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Prediction: tensor([[[[-0.1741, -0.0312, 0.2449]]],\n",
+ "\n",
+ "\n",
+ " [[[-0.0194, 0.1198, 0.3921]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.1472, 0.2544, 0.5401]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.3183, 0.4101, 0.6707]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.4626, 0.5497, 0.8223]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.5880, 0.6819, 0.9794]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.7212, 0.7909, 1.0700]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.8104, 0.8627, 1.1342]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.8615, 0.9050, 1.1836]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.8919, 0.9103, 1.1939]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.8414, 0.8754, 1.1404]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.7774, 0.8125, 1.0497]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.6535, 0.7326, 0.9382]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.5000, 0.6076, 0.7917]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.3172, 0.4677, 0.6275]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.1383, 0.3008, 0.4571]]],\n",
+ "\n",
+ "\n",
+ " [[[-0.0549, 0.1177, 0.2809]]],\n",
+ "\n",
+ "\n",
+ " [[[-0.2488, -0.0911, 0.0679]]],\n",
+ "\n",
+ "\n",
+ " [[[-0.4082, -0.2451, -0.0699]]],\n",
+ "\n",
+ "\n",
+ " [[[-0.2056, -0.0571, 0.2309]]],\n",
+ "\n",
+ "\n",
+ " [[[-0.0128, 0.0945, 0.3519]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.1674, 0.2839, 0.5486]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.3257, 0.4233, 0.7065]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.4488, 0.5480, 0.8359]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.5644, 0.6576, 0.9174]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.6968, 0.7615, 1.0404]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.7988, 0.8487, 1.1348]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.8528, 0.8676, 1.1572]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.8239, 0.8639, 1.1581]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.8162, 0.8525, 1.1129]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.7303, 0.8025, 1.0176]]],\n",
+ "\n",
+ "\n",
+ " [[[ 0.6131, 0.7197, 0.9090]]]])\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " test_batch = next(iter(data_module.test_dataloader()))\n",
+ " x_test, y_test = test_batch\n",
+ " y_pred = model2(x_test)\n",
+ "\n",
+ " print(\"Prediction:\", y_pred[\"prediction\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "a4e6e4b1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([32, 1, 1, 3])"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_pred[\"prediction\"].shape"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pytorch_forecasting/data/_tslib_data_module.py b/pytorch_forecasting/data/_tslib_data_module.py
new file mode 100644
index 000000000..a33e1e709
--- /dev/null
+++ b/pytorch_forecasting/data/_tslib_data_module.py
@@ -0,0 +1,820 @@
+"""
+Experimmental data module for integrating `tslib` time series deep learning library.
+"""
+
+from typing import Any, Optional, Union
+import warnings
+
+from lightning.pytorch import LightningDataModule
+import numpy as np
+import pandas as pd
+from sklearn.preprocessing import RobustScaler, StandardScaler
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+from pytorch_forecasting.data.encoders import (
+ EncoderNormalizer,
+ NaNLabelEncoder,
+ TorchNormalizer,
+)
+from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries
+from pytorch_forecasting.utils._coerce import _coerce_to_dict
+
+NORMALIZER = Union[TorchNormalizer, EncoderNormalizer, NaNLabelEncoder]
+
+
+class _TslibDataset(Dataset):
+ """
+ Dataset class for `tslib` time series dataset.
+
+ Parameters
+ ----------
+ dataset : TimeSeries
+ The time series dataset to be used for training and validation.
+ data_module : TslibDataModule
+ The data module that contains the metadata and other configurations for the
+ dataset.
+ windows: list[tuple[int, int, int, int]]
+ A list of tuples where each tuple contains:
+ - series_idx: Index of time series in the dataset
+ - start_idx: Start index of the window
+ - context_length: Length of the context/encoder window
+ - prediction_length: Length of the prediction/decoder window
+ add_relative_time_idx: bool
+ Whether to add relative time index to the dataset.
+ """
+
+ def __init__(
+ self,
+ dataset: TimeSeries,
+ data_module: "TslibDataModule",
+ windows: list[tuple[int, int, int, int]],
+ add_relative_time_idx: bool = False,
+ ):
+ self.dataset = dataset
+ self.data_module = data_module
+ self.windows = windows
+ self.add_relative_time_idx = add_relative_time_idx
+
+ def __len__(self) -> int:
+ return len(self.windows)
+
+ def __getitem__(self, idx: int) -> dict[str, Any]:
+ """
+ Get the processed dataset item at the given index.
+
+ Parameters
+ ----------
+ idx : int
+ The index of the dataset item to be retrieved.
+
+ Returns
+ -------
+ x: dict[str, torch.Tensor]
+ A dictionary containing the processed data.
+ y: torch.Tensor
+ The target variable.
+ """
+
+ series_idx, start_idx, context_length, prediction_length = self.windows[idx]
+
+ processed_data = self.data_module._preprocess_data(series_idx)
+
+ continous_features = processed_data["features"]["continuous"]
+ categorical_features = processed_data["features"]["categorical"]
+
+ end_idx = start_idx + context_length + prediction_length
+ history_indices = slice(start_idx, start_idx + context_length)
+ future_indices = slice(start_idx + context_length, end_idx)
+
+ metadata = self.data_module.metadata
+
+ history_cont = continous_features[history_indices]
+ history_cat = categorical_features[history_indices]
+
+ future_cont = continous_features[future_indices]
+ future_cat = categorical_features[future_indices]
+
+ known_features = set(metadata["feature_names"]["known"])
+ continuous_feature_names = metadata["feature_names"]["continuous"]
+ categorical_feature_names = metadata["feature_names"]["categorical"]
+
+ # use masking to filter out known and unknow features.
+ cont_known_mask = torch.tensor(
+ [feat in known_features for feat in continuous_feature_names],
+ dtype=torch.bool,
+ )
+
+ cat_known_mask = torch.tensor(
+ [feat in known_features for feat in categorical_feature_names],
+ dtype=torch.bool,
+ )
+
+ future_cont = (
+ future_cont[:, cont_known_mask]
+ if len(cont_known_mask) > 0
+ else torch.zeros((future_cont.shape[0], 0))
+ ) # noqa: E501
+ future_cat = (
+ future_cat[:, cat_known_mask]
+ if len(cat_known_mask) > 0
+ else torch.zeros((future_cat.shape[0], 0))
+ ) # noqa: E501
+
+ history_mask = (
+ processed_data["time_mask"][history_indices]
+ if "time_mask" in processed_data
+ else torch.ones(context_length, dtype=torch.bool)
+ )
+
+ future_mask = (
+ processed_data["time_mask"][future_indices]
+ if "time_mask" in processed_data
+ else torch.ones(prediction_length, dtype=torch.bool)
+ )
+
+ history_target = processed_data["target"][history_indices]
+ future_target = processed_data["target"][future_indices]
+
+ # history_time_idx = processed_data["timestep"][history_indices]
+ # future_time_idx = processed_data["timestep"][future_indices]
+
+ x = {
+ "history_cont": history_cont,
+ "history_cat": history_cat,
+ "future_cont": future_cont,
+ "future_cat": future_cat,
+ "history_length": torch.tensor(context_length),
+ "future_length": torch.tensor(prediction_length),
+ "history_mask": history_mask,
+ "future_mask": future_mask,
+ "groups": processed_data["group"],
+ "history_time_idx": torch.arange(context_length),
+ "future_time_idx": torch.arange(
+ context_length, context_length + prediction_length
+ ),
+ "history_target": history_target,
+ "future_target": future_target,
+ "future_target_len": torch.tensor(prediction_length),
+ }
+
+ if self.add_relative_time_idx:
+ x["history_relative_time_idx"] = torch.arange(-context_length, 0)
+ x["future_relative_time_idx"] = torch.arange(0, prediction_length)
+
+ if processed_data["static"] is not None:
+ x["static_categorical_features"] = processed_data["static"].unsqueeze(0)
+ x["static_continuous_features"] = processed_data["static"].unsqueeze(0)
+
+ if "target_scale" in processed_data:
+ x["target_scale"] = processed_data["target_scale"]
+
+ y = processed_data["target"][future_indices]
+
+ return x, y
+
+
+class TslibDataModule(LightningDataModule):
+ """
+ Experimental data module for integrating `tslib` time series into
+ PyTorch Forecasting.
+
+ This module serves as the D2 layer for `tslib` models including transformer-based
+ architectures like Informer, AutoFormer, TimeXer and other model deep learning model
+ architectures.
+
+ Parameters
+ ----------
+ time_series_dataset: TimeSeries
+ The time series dataset to be used for training and validation. This is the
+ newly implemented D1 layer.
+ context_length: int
+ The length of the context window for the model. This is the number of time steps
+ used as input to the model.
+ prediction_length: int
+ The length of the prediction window for the model. This is the number of time
+ steps to be predicted by the model.
+ freq: str, default = "h"
+ The frequency of the time series data. This is used to determine the time steps
+ for the model.
+ features: str = "MS"
+ Feature combination mode:
+ - "S": Single variable forecasting (target only)
+ - "M": Multivariate forecasting, using all variables
+ - "MS": Multivariate to single, using all variables to predict target
+ add_relative_time_idx: bool = False
+ Whether to allow the relative time index to be used with the model.
+ add_target_scales: bool = False
+ Whether to add target scaling info.
+ target_normalizer :
+ Union[NORMALIZER, str, list[NORMALIZER], tuple[NORMALIZER], None],
+ default="auto"
+ Normalizer for the target variable. If "auto", uses `RobustScaler`.
+ scalers : Optional[dict[str, Union[StandardScaler, RobustScaler, TorchNormalizer]]], default=None #noqa: E501
+ Dictionary of feature scalers.
+ shuffle : bool, default=True
+ Whether to shuffle the data at every epoch.
+ window_stride : int, default=1
+ The stride for the sliding window. This is used to create overlapping windows
+ for the data.
+ batch_size : int, default=32
+ Batch size for dataloader.
+ num_workers : int, default=0
+ Number of workers for dataloader.
+ train_val_test_split : tuple, default=(0.7, 0.15, 0.15)
+ Proportions for train, validation, and test dataset splits.
+ collate_fn : Optional[callable], default=None
+ Custom collate function for the dataloader.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ time_series_dataset: TimeSeries,
+ context_length: int,
+ prediction_length: int,
+ freq: str = "h",
+ add_relative_time_idx: bool = False,
+ add_target_scales: bool = False,
+ target_normalizer: Union[
+ NORMALIZER, str, list[NORMALIZER], tuple[NORMALIZER], None
+ ] = "auto", # noqa: E501
+ scalers: Optional[
+ dict[
+ str,
+ Union[StandardScaler, RobustScaler, TorchNormalizer, EncoderNormalizer],
+ ]
+ ] = None, # noqa: E501
+ shuffle: bool = True,
+ window_stride: int = 1,
+ batch_size: int = 32,
+ num_workers: int = 0,
+ train_val_test_split: tuple[float, float, float] = (0.7, 0.15, 0.15),
+ collate_fn: Optional[callable] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ self.time_series_dataset = time_series_dataset
+ self.context_length = context_length
+ self.prediction_length = prediction_length
+ self.freq = freq
+ self.add_relative_time_idx = add_relative_time_idx
+ self.add_target_scales = add_target_scales
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.train_val_test_split = train_val_test_split
+ self.collate_fn = (
+ collate_fn if collate_fn is not None else self.__class__.collate_fn
+ ) # noqa: E501
+ self.kwargs = kwargs
+
+ warnings.warn(
+ "TslibDataModule is experimental and subject to change. "
+ "The API is not stable and may change without prior warning.",
+ UserWarning,
+ )
+
+ if isinstance(target_normalizer, str) and target_normalizer.lower() == "auto":
+ self._target_normalizer = RobustScaler()
+ else:
+ self._target_normalizer = target_normalizer
+
+ self._metadata = None
+
+ self.scalers = scalers or {}
+ self.shuffle = shuffle
+
+ self.continuous_indices = []
+ self.categorical_indices = []
+
+ self.train_dataset = None
+ self.val_dataset = None
+ self.test_dataset = None
+
+ self.window_stride = window_stride
+
+ self.time_series_metadata = time_series_dataset.get_metadata()
+
+ for idx, col in enumerate(self.time_series_metadata["cols"]["x"]):
+ if self.time_series_metadata["col_type"].get(col) == "C":
+ self.categorical_indices.append(idx)
+ else:
+ self.continuous_indices.append(idx)
+
+ self._validate_indices()
+
+ def _validate_indices(self):
+ """
+ Validate that we have meaningful features for training.
+ Raises warnings for missing features or indices.
+ """
+
+ has_continuous = self.continuous_indices and len(self.continuous_indices) > 0
+ has_categorical = self.categorical_indices and len(self.categorical_indices) > 0
+ has_targets = len(self.time_series_metadata.get("cols", {}).get("y", [])) > 0
+ if not has_targets:
+ raise ValueError(
+ "No target variables found in the dataset. "
+ "Cannot proceed with model training."
+ )
+
+ if not has_continuous and not has_categorical and has_targets:
+ warnings.warn(
+ "No continuous or categorical features found. "
+ "Proceeding with pure univariate forecasting "
+ "using target history only.",
+ UserWarning,
+ )
+ return
+
+ if not has_continuous:
+ warnings.warn(
+ "No continuous features found in the dataset. "
+ "Some models (TimeXer) requires continous features. "
+ "Consider adding continous featuresinto the dataset.",
+ UserWarning,
+ )
+
+ if not has_categorical:
+ warnings.warn(
+ "No categorical features found in the dataset. "
+ "This may limit the model capabilities and and restrict "
+ "the usage to continuous features only.",
+ UserWarning,
+ )
+
+ def _prepare_metadata(self) -> dict[str, Any]:
+ """
+ Prepare metadata for `tslib` time series data module.
+
+ Returns
+ -------
+ dict containing the following as keys:
+ - feature_names: dict[str, list[str]]
+ Dictionary of feature names for each feature type.
+ - feature_indices: dict[str, list[int]]
+ Dictionary of feature indices for each feature type.
+ - n_features: dict[str, int]
+ Dictionary of number of features for each feature type.
+ - context_length: int
+ Length of the context window for the model, as set in the data module.
+ - prediction_length: int
+ Length of the prediction window for the model, as set in the data
+ module.
+ - freq: str or None
+ - features: str
+ Feature combination mode.
+ """
+ # TODO: include handling for datasets without get_metadata()
+ ds_metadata = self.time_series_metadata
+
+ feature_names = {
+ "categorical": [],
+ "continuous": [],
+ "static": [],
+ "known": [],
+ "unknown": [],
+ "target": [],
+ "all": [],
+ }
+
+ feature_indices = {
+ "categorical": [],
+ "continuous": [],
+ "static": [],
+ "known": [],
+ "unknown": [],
+ "target": [],
+ }
+
+ cols = ds_metadata.get("cols", {})
+ col_type = ds_metadata.get("col_type", {})
+ col_known = ds_metadata.get("col_known", {})
+
+ all_features = cols.get("x", [])
+ static_features = cols.get("st", [])
+ target_features = cols.get("y", [])
+
+ if len(target_features) == 0:
+ raise ValueError(
+ "The time series dataset must have at least one target variable. "
+ "Please provide a dataset with a target variable."
+ )
+
+ feature_names["all"] = list(all_features)
+ feature_names["static"] = list(static_features)
+ feature_names["target"] = list(target_features)
+
+ for idx, col in enumerate(all_features):
+ if col_type.get(col, "F") == "C":
+ feature_names["categorical"].append(col)
+ feature_indices["categorical"].append(idx)
+ else:
+ feature_names["continuous"].append(col)
+ feature_indices["continuous"].append(idx)
+
+ if col_known.get(col, "U") == "K":
+ feature_names["known"].append(col)
+ feature_indices["known"].append(idx)
+ else:
+ feature_names["unknown"].append(col)
+ feature_indices["unknown"].append(idx)
+
+ static_cat_names, static_cont_names = [], []
+ for col in static_features:
+ if col_type.get(col, "F") == "C":
+ static_cat_names.append(col)
+ else:
+ static_cont_names.append(col)
+
+ feature_indices["target"] = list(range(len(target_features)))
+
+ feature_names["static_categorical"] = static_cat_names
+ feature_names["static_continuous"] = static_cont_names
+
+ n_features = {k: len(v) for k, v in feature_names.items()}
+
+ # detect the feature mode - S/MS/M
+
+ n_targets = n_features["target"]
+ n_cont = n_features["continuous"]
+ n_cat = n_features["categorical"]
+
+ if n_targets == 1 and (n_cont + n_cat) == 0:
+ self.features = "S"
+ elif n_targets == 1 and (n_cont + n_cat) >= 1:
+ self.features = "MS"
+ elif n_targets > 1 and (n_cont + n_cat) > 0:
+ self.features = "M"
+ else:
+ self.features = "M"
+
+ metadata = {
+ "feature_names": feature_names,
+ "feature_indices": feature_indices,
+ "n_features": n_features,
+ "context_length": self.context_length,
+ "prediction_length": self.prediction_length,
+ "freq": self.freq,
+ "features": self.features,
+ }
+
+ return metadata
+
+ @property
+ def metadata(self) -> dict[str, Any]:
+ """ "
+ Compute the metadata via the `_prepare_metadata` method.
+ This method is called when the `metadata` property is accessed for the first.
+ Returns
+ -------
+ dict
+ Metadata for the data module. Refer to the `_prepare_metadata` method for
+ the keys and values in the metadata dictionary.
+ """
+ if self._metadata is None:
+ self._metadata = self._prepare_metadata()
+ return self._metadata
+
+ def _preprocess_data(self, idx: torch.Tensor) -> list[dict[str, Any]]:
+ """
+ Process the the time series data at the given index, before feeding it
+ to the `_TslibDataset` class.
+
+ Parameters
+ ----------
+ idx : torch.Tensor
+ The index of the time series data to be processed.
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ A dictionary containing the processed data.
+
+ Notes
+ -----
+ - The target data `y` and features `x` are converted to torch.float32 tensors.
+ - The timepoints before the cutoff time are masked off.
+ - Splits data into categorical and continous features, which are grouped based on the indices.
+ """ # noqa: E501
+
+ series = self.time_series_dataset[idx]
+ if series is None:
+ raise ValueError(f"series at index {idx} is None. Check the dataset.")
+ target = series["y"]
+ features = series["x"]
+ timestep = series["t"]
+ cutoff_time = series["cutoff_time"]
+
+ mask_timestep = torch.tensor(timestep <= cutoff_time, dtype=torch.bool)
+
+ if isinstance(target, torch.Tensor):
+ target = target.detach().clone().float()
+ else:
+ target = torch.tensor(target, dtype=torch.float32)
+
+ if isinstance(features, torch.Tensor):
+ features = features.detach().clone().float()
+ else:
+ features = torch.tensor(features, dtype=torch.float32)
+
+ # scaling and normlization
+ target_scale = {}
+
+ categorical_features = (
+ features[:, self.categorical_indices]
+ if self.categorical_indices
+ else torch.zeros((features.shape[0], 0))
+ )
+
+ continuous_features = (
+ features[:, self.continuous_indices]
+ if self.continuous_indices
+ else torch.zeros((features.shape[0], 0))
+ )
+
+ res = {
+ "features": {
+ "categorical": categorical_features,
+ "continuous": continuous_features,
+ },
+ "target": target,
+ "static": series["st"],
+ "group": series.get("group", torch.tensor([0])),
+ "length": len(series),
+ "time_mask": mask_timestep,
+ "cutoff_time": cutoff_time,
+ "timestep": timestep,
+ }
+
+ if target_scale:
+ res["target_scale"] = target_scale
+
+ return res
+
+ def _create_windows(self, indices: torch.Tensor) -> list[tuple[int, int, int, int]]:
+ """
+ Create windows for the data in the given indices, for training, testing
+ and validation.
+
+ Parameters
+ ----------
+ indices : torch.Tensor
+ The indices of the time series data to be processed.
+
+ Returns
+ -------
+ list[tuple[int, int, int, int]]
+ A list of tuples where each tuple contains:
+ - series_idx: Index of time series in the dataset
+ - start_idx: Start index of the window
+ - context_length: Length of the context/encoder window
+ - prediction_length: Length of the prediction/decoder window
+ """
+
+ windows = []
+
+ min_seq_length = self.context_length + self.prediction_length
+
+ for idx in indices:
+ series_idx = idx.item() if isinstance(idx, torch.Tensor) else idx
+ sample = self.time_series_dataset[series_idx]
+ sequence_length = len(sample["t"])
+
+ if sequence_length < min_seq_length:
+ continue
+
+ effective_min_prediction_idx = self.context_length
+
+ max_prediction_idx = sequence_length - self.prediction_length + 1
+
+ if max_prediction_idx <= effective_min_prediction_idx:
+ continue
+
+ stride = self.window_stride
+
+ for start_idx in range(
+ 0, max_prediction_idx - effective_min_prediction_idx, stride
+ ): # noqa: E501
+ if start_idx + self.context_length + self.prediction_length <= (
+ sequence_length
+ ):
+ windows.append(
+ (
+ series_idx,
+ start_idx,
+ self.context_length,
+ self.prediction_length,
+ )
+ )
+
+ return windows
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ """
+ Setup the data module by preparing the datasets for training,
+ testing and validation.
+
+ Parameters
+ ----------
+ stage: Optional[str]
+ The stage of the data module. This can be "fit", "test" or "predict".
+ If None, the data module will be setup for training.
+ """
+
+ # TODO: Add support for temporal/random/group splits.
+ # Currently, it only supports random splits.
+ # Handle the case where the dataset is empty.
+
+ total_series = len(self.time_series_dataset)
+
+ if total_series == 0:
+ raise ValueError(
+ "The time series dataset is empty. "
+ "Please provide a non-empty dataset."
+ )
+
+ # this is a very rudimentary way to handle the splits when
+ # the dataset is of size equal to 1 or 2.
+ self._indices = torch.randperm(total_series)
+ if total_series == 1:
+ self._train_indices = self._indices
+ self._val_indices = self._indices
+ self._test_indices = self._indices
+ elif total_series == 2:
+ self._train_indices = self._indices[0:1]
+ self._val_indices = self._indices[1:2]
+ self._test_indices = self._indices[1:2]
+ else:
+ self._train_size = int(self.train_val_test_split[0] * total_series)
+ self._val_size = int(self.train_val_test_split[1] * total_series)
+
+ self._train_indices = self._indices[: self._train_size]
+ self._val_indices = self._indices[
+ self._train_size : self._train_size + self._val_size
+ ]
+
+ self._test_indices = self._indices[
+ self._train_size + self._val_size : total_series
+ ]
+
+ if stage == "fit" or stage is None:
+ if not hasattr(self, "_train_dataset") or not hasattr(self, "_val_dataset"):
+ self._train_windows = self._create_windows(self._train_indices)
+ self._val_windows = self._create_windows(self._val_indices)
+
+ self.train_dataset = _TslibDataset(
+ dataset=self.time_series_dataset,
+ data_module=self,
+ windows=self._train_windows,
+ add_relative_time_idx=self.add_relative_time_idx,
+ )
+
+ self.val_dataset = _TslibDataset(
+ dataset=self.time_series_dataset,
+ data_module=self,
+ windows=self._val_windows,
+ add_relative_time_idx=self.add_relative_time_idx,
+ )
+ elif stage == "test":
+ if not hasattr(self, "_test_dataset"):
+ self._test_windows = self._create_windows(self._test_indices)
+
+ self.test_dataset = _TslibDataset(
+ dataset=self.time_series_dataset,
+ data_module=self,
+ windows=self._test_windows,
+ add_relative_time_idx=self.add_relative_time_idx,
+ )
+
+ elif stage == "predict":
+ predict_indices = torch.arange(len(self.time_series_dataset))
+ self._predict_windows = self._create_windows(predict_indices)
+
+ self.predict_dataset = _TslibDataset(
+ dataset=self.time_series_dataset,
+ data_module=self,
+ windows=self._predict_windows,
+ add_relative_time_idx=self.add_relative_time_idx,
+ )
+
+ def train_dataloader(self) -> DataLoader:
+ """
+ Create the train dataloader.
+
+ Returns
+ -------
+ DataLoader
+ The train dataloader.
+ """
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ collate_fn=self.collate_fn,
+ )
+
+ def val_dataloader(self) -> DataLoader:
+ """
+ Create the validation dataloader.
+ Returns
+ -------
+ DataLoader
+ The validation dataloader.
+ """
+ return DataLoader(
+ self.val_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ collate_fn=self.collate_fn,
+ )
+
+ def test_dataloader(self) -> DataLoader:
+ """
+ Create the test dataloader.
+
+ Returns
+ -------
+ DataLoader
+ The test dataloader.
+ """
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ collate_fn=self.collate_fn,
+ )
+
+ def predict_dataloader(self) -> DataLoader:
+ """
+ Create the prediction dataloader.
+
+ Returns
+ -------
+ DataLoader
+ The prediction dataloader.
+ """
+ return DataLoader(
+ self.predict_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ collate_fn=self.collate_fn,
+ )
+
+ @staticmethod
+ def collate_fn(batch):
+ """
+ Custom collate function for the dataloader.
+
+ Parameters
+ ----------
+ batch: list[tuple[dict[str, Any]]]
+ The batch of data to be collated.
+
+ Returns
+ -------
+ tuple[dict[str, torch.Tensor], torch.Tensor]
+ A tuple containing the collated data and the target variable.
+ """
+
+ x_batch = {
+ "history_cont": torch.stack([x["history_cont"] for x, _ in batch]),
+ "history_cat": torch.stack([x["history_cat"] for x, _ in batch]),
+ "future_cont": torch.stack([x["future_cont"] for x, _ in batch]),
+ "future_cat": torch.stack([x["future_cat"] for x, _ in batch]),
+ "history_length": torch.stack([x["history_length"] for x, _ in batch]),
+ "future_length": torch.stack([x["future_length"] for x, _ in batch]),
+ "history_mask": torch.stack([x["history_mask"] for x, _ in batch]),
+ "future_mask": torch.stack([x["future_mask"] for x, _ in batch]),
+ "groups": torch.stack([x["groups"] for x, _ in batch]),
+ "history_time_idx": torch.stack([x["history_time_idx"] for x, _ in batch]),
+ "future_time_idx": torch.stack([x["future_time_idx"] for x, _ in batch]),
+ "history_target": torch.stack([x["history_target"] for x, _ in batch]),
+ "future_target": torch.stack([x["future_target"] for x, _ in batch]),
+ "future_target_len": torch.stack(
+ [x["future_target_len"] for x, _ in batch]
+ ),
+ }
+
+ if "target_scale" in batch[0][0]:
+ x_batch["target_scale"] = torch.stack([x["target_scale"] for x, _ in batch])
+
+ if "history_relative_time_idx" in batch[0][0]:
+ x_batch["history_relative_time_idx"] = torch.stack(
+ [x["history_relative_time_idx"] for x, _ in batch]
+ )
+ x_batch["future_relative_time_idx"] = torch.stack(
+ [x["future_relative_time_idx"] for x, _ in batch]
+ )
+
+ if "static_categorical_features" in batch[0][0]:
+ x_batch["static_categorical_features"] = torch.stack(
+ [x["static_categorical_features"] for x, _ in batch]
+ )
+ x_batch["static_continuous_features"] = torch.stack(
+ [x["static_continuous_features"] for x, _ in batch]
+ )
+
+ y_batch = torch.stack([y for _, y in batch])
+ return x_batch, y_batch
diff --git a/pytorch_forecasting/data/tests/__init__.py b/pytorch_forecasting/data/tests/__init__.py
new file mode 100644
index 000000000..162591895
--- /dev/null
+++ b/pytorch_forecasting/data/tests/__init__.py
@@ -0,0 +1 @@
+"""Tests for data modules and dataloaders in pytorch_forecasting.data package."""
diff --git a/pytorch_forecasting/data/tests/test_tslib_data_module.py b/pytorch_forecasting/data/tests/test_tslib_data_module.py
new file mode 100644
index 000000000..6822a9fa3
--- /dev/null
+++ b/pytorch_forecasting/data/tests/test_tslib_data_module.py
@@ -0,0 +1,532 @@
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+
+from pytorch_forecasting.data._tslib_data_module import TslibDataModule
+from pytorch_forecasting.data.timeseries import TimeSeries
+
+
+@pytest.fixture(scope="session")
+def sample_timeseries_data():
+ """Fixture to generate a sample TimeSeries."""
+
+ np.random.seed(42)
+ n_series = 20
+ n_timesteps = 50
+
+ data = []
+
+ for series_id in range(n_series):
+ for time_idx in range(n_timesteps):
+ # Generate a target variable with some noise
+ target = (
+ 10
+ + 0.1 * time_idx
+ + np.sin(2 * np.pi * time_idx / 12)
+ + np.random.randn() * 0.5
+ ) # noqa: E501
+
+ cat_a = np.random.choice([0, 1, 2])
+
+ feature_1 = np.random.randn() + time_idx * 0.01
+ feature_2 = target * 0.8 + np.random.randn() * 0.2
+ feature_3 = np.sin(time_idx / 5) + np.random.randn() * 0.1
+
+ static_feature = series_id * 2.5
+
+ data.append(
+ {
+ "series_id": series_id,
+ "time_idx": time_idx,
+ "target": target,
+ "cat_a": cat_a,
+ "feature_1": feature_1,
+ "feature_2": feature_2,
+ "feature_3": feature_3,
+ "static_feature": static_feature,
+ }
+ )
+
+ df = pd.DataFrame(data)
+
+ time_series = TimeSeries(
+ data=df,
+ time="time_idx",
+ target="target",
+ group=["series_id"],
+ num=["feature_1", "feature_2", "feature_3"],
+ cat=["cat_a"],
+ unknown=["feature_2", "target", "cat_a"],
+ static=["static_feature"],
+ known=["feature_1", "feature_3"],
+ )
+ return time_series
+
+
+@pytest.fixture
+def tslib_data_module(sample_timeseries_data):
+ """Fixture for TSLibDataModule."""
+ return TslibDataModule(
+ time_series_dataset=sample_timeseries_data,
+ context_length=8,
+ prediction_length=4,
+ batch_size=2, # Smaller batch size for faster testing
+ num_workers=0, # Avoid multiprocessing issues in tests
+ )
+
+
+def test_init(sample_timeseries_data):
+ """Test the initialization of the data module."""
+
+ tslib_dm = TslibDataModule(
+ time_series_dataset=sample_timeseries_data,
+ context_length=32,
+ prediction_length=16,
+ batch_size=8,
+ )
+
+ assert tslib_dm.time_series_dataset == sample_timeseries_data
+ assert tslib_dm.context_length == 32
+ assert tslib_dm.prediction_length == 16
+ assert tslib_dm.batch_size == 8
+ assert tslib_dm.train_val_test_split == (0.7, 0.15, 0.15)
+
+ assert isinstance(tslib_dm.time_series_metadata, dict)
+ assert "cols" in tslib_dm.time_series_metadata
+
+
+def test_prepare_metadata(tslib_data_module):
+ """Test the metadata preparation to ensure correct metadata extraction
+ and structure."""
+
+ metadata = tslib_data_module.metadata
+
+ assert isinstance(metadata, dict)
+
+ assert "feature_names" in metadata
+ assert "feature_indices" in metadata
+ assert "n_features" in metadata
+ assert "context_length" in metadata
+ assert "prediction_length" in metadata
+ assert "freq" in metadata
+ assert "features" in metadata
+
+ assert "categorical" in metadata["feature_names"]
+ assert "continuous" in metadata["feature_names"]
+ assert "static" in metadata["feature_names"]
+ assert "known" in metadata["feature_names"]
+ assert "unknown" in metadata["feature_names"]
+ assert "target" in metadata["feature_names"]
+ assert "all" in metadata["feature_names"]
+ assert "static_categorical" in metadata["feature_names"]
+ assert "static_continuous" in metadata["feature_names"]
+
+ assert "categorical" in metadata["feature_indices"]
+ assert "continuous" in metadata["feature_indices"]
+ assert "static" in metadata["feature_indices"]
+ assert "known" in metadata["feature_indices"]
+ assert "unknown" in metadata["feature_indices"]
+ assert "target" in metadata["feature_indices"]
+
+ for k in metadata["n_features"]:
+ assert k in metadata["n_features"]
+ assert metadata["n_features"][k] == len(metadata["feature_names"][k])
+
+ assert metadata["context_length"] == tslib_data_module.context_length
+ assert metadata["prediction_length"] == tslib_data_module.prediction_length
+
+
+def test_setup(tslib_data_module):
+ """Test the setup method to ensure datamodule is setup for training,
+ testing, and validation."""
+
+ tslib_data_module.setup(stage="fit")
+ assert hasattr(tslib_data_module, "train_dataset")
+ assert hasattr(tslib_data_module, "val_dataset")
+ assert len(tslib_data_module._train_windows) > 0
+ assert len(tslib_data_module._val_windows) > 0
+
+ tslib_data_module.setup(stage="test")
+ assert hasattr(tslib_data_module, "test_dataset")
+ assert len(tslib_data_module._test_windows) > 0
+
+ tslib_data_module.setup(stage="predict")
+ assert hasattr(tslib_data_module, "predict_dataset")
+ assert len(tslib_data_module._predict_windows) > 0
+
+
+def test_train_dataloader(tslib_data_module):
+ """Test the train dataloader to ensure it returns the batches of the data,
+ and all hyperparameters are correctly set."""
+
+ tslib_data_module.setup(stage="fit")
+ train_data_loader = tslib_data_module.train_dataloader()
+
+ assert hasattr(train_data_loader, "batch_size")
+ assert train_data_loader.batch_size == tslib_data_module.batch_size
+ assert train_data_loader.num_workers == tslib_data_module.num_workers
+
+ val_data_loader = tslib_data_module.val_dataloader()
+ assert hasattr(val_data_loader, "batch_size")
+
+
+def test_test_dataloader(tslib_data_module):
+ """Test the test dataloader to ensure it returns the batches of the data,
+ and all hyperparameters are correctly set."""
+
+ tslib_data_module.setup(stage="test")
+ test_data_loader = tslib_data_module.test_dataloader()
+
+ assert hasattr(test_data_loader, "batch_size")
+ assert test_data_loader.batch_size == tslib_data_module.batch_size
+ assert test_data_loader.num_workers == tslib_data_module.num_workers
+
+
+def test_predict_dataloader(tslib_data_module):
+ """Test the predict dataloader to ensure it returns the batches of the data,
+ and all hyperparameters are correctly set."""
+
+ tslib_data_module.setup(stage="predict")
+ predict_data_loader = tslib_data_module.predict_dataloader()
+
+ assert hasattr(predict_data_loader, "batch_size")
+ assert predict_data_loader.batch_size == tslib_data_module.batch_size
+ assert predict_data_loader.num_workers == tslib_data_module.num_workers
+
+
+def test_tslib_dataset(tslib_data_module):
+ """Test the _TslibDataset to ensure it is correctly initialized
+ and ensure correct outputs from __getitem__."""
+
+ tslib_data_module.setup(stage="fit")
+ assert hasattr(tslib_data_module, "train_dataset")
+ train_dataset = tslib_data_module.train_dataset
+
+ assert len(train_dataset) > 0, "The train dataset is empty!"
+
+ sample_x, sample_y = train_dataset[0]
+
+ assert isinstance(sample_x, dict), "Sample x should be a dictionary."
+ assert isinstance(sample_y, torch.Tensor), "Sample y should be a PyTorch tensor."
+
+ expected_keys = [
+ "history_cont",
+ "history_cat",
+ "future_cont",
+ "future_cat",
+ "history_length",
+ "future_length",
+ "history_mask",
+ "future_mask",
+ "groups",
+ "history_time_idx",
+ "future_time_idx",
+ "future_target",
+ "future_target_len",
+ ]
+
+ for key in expected_keys:
+ assert key in sample_x, f"Key '{key}' not found in sample_x."
+
+ context_length = tslib_data_module.context_length
+ prediction_length = tslib_data_module.prediction_length
+ metadata = tslib_data_module.metadata
+
+ assert sample_x["history_cont"].shape[0] == context_length
+ assert sample_x["history_cat"].shape[0] == context_length
+ assert sample_x["future_cont"].shape[0] == prediction_length
+ assert sample_x["future_cat"].shape[0] == prediction_length
+ assert sample_x["history_target"].shape[0] == context_length
+ assert sample_x["future_target"].shape[0] == prediction_length
+
+ known_cat_count = len(
+ [
+ name
+ for name in metadata["feature_names"]["known"]
+ if name in metadata["feature_names"]["categorical"]
+ ]
+ )
+ known_cont_count = len(
+ [
+ name
+ for name in metadata["feature_names"]["known"]
+ if name in metadata["feature_names"]["continuous"]
+ ]
+ )
+
+ print(sample_x["future_cont"].shape)
+
+ assert sample_x["future_cont"].shape[1] == known_cont_count
+ assert sample_x["future_cat"].shape[1] == known_cat_count
+
+ assert sample_y.shape[0] == prediction_length
+
+ assert sample_x["history_cont"].dtype == torch.float32
+ assert sample_x["future_cont"].dtype == torch.float32
+ assert sample_x["history_target"].dtype == torch.float32
+
+ assert sample_y.dtype == torch.float32
+
+
+def test_collate_fn(tslib_data_module):
+ """Test the collate function in the TslibDataModule to ensure it correctly
+ collates the data into batches and properly handles stacking of batches."""
+
+ tslib_data_module.setup(stage="fit")
+ batch_size = 2
+
+ batches = [tslib_data_module.train_dataset[i] for i in range(batch_size)]
+
+ x_batch, y_batch = tslib_data_module.collate_fn(batches)
+
+ for key in x_batch:
+ assert x_batch[key].shape[0] == batch_size
+
+ metadata = tslib_data_module.metadata
+
+ known_cat_count = len(
+ [
+ name
+ for name in metadata["feature_names"]["known"]
+ if name in metadata["feature_names"]["categorical"]
+ ]
+ )
+ known_cont_count = len(
+ [
+ name
+ for name in metadata["feature_names"]["known"]
+ if name in metadata["feature_names"]["continuous"]
+ ]
+ )
+
+ assert x_batch["future_cont"].shape[2] == known_cont_count
+ assert x_batch["future_cat"].shape[2] == known_cat_count
+ # print(x_batch["future_cont"].shape)
+ assert y_batch.shape[0] == batch_size
+ assert y_batch.shape[1] == tslib_data_module.prediction_length
+
+
+def test_create_windows(tslib_data_module):
+ """Test the _create_windows method to ensures correct creation
+ of windows for training, validation and testing."""
+
+ tslib_data_module.setup(stage="fit")
+ train_indices = tslib_data_module._train_indices
+ train_windows = tslib_data_module._create_windows(train_indices)
+
+ assert len(train_windows) > 0, "No training windows created!"
+
+ for windows in train_windows:
+ assert isinstance(windows, tuple), "Windows should be a tuple."
+
+ assert len(windows) == 4, "Each window should have 4 elements."
+
+ series_idx, start_idx, context_length, prediction_length = windows
+
+ assert isinstance(series_idx, int), "series_idx should be an integer."
+
+ assert isinstance(start_idx, int), "start_idx should be an integer."
+
+ assert (
+ context_length == tslib_data_module.context_length
+ ), "context_length should match the datamodule's context_length."
+
+ assert (
+ prediction_length == tslib_data_module.prediction_length
+ ), "prediction_length should match the datamodule's prediction_length."
+
+ assert (
+ 0 <= series_idx < len(tslib_data_module.time_series_dataset)
+ ), "series_idx should be within the range of the dataset length."
+
+ min_required_length = context_length + prediction_length
+
+ time_series_dataset = tslib_data_module.time_series_dataset
+ # print(type(time_series_dataset[series_idx]))
+ sample = time_series_dataset[series_idx]
+
+ if "t" in sample:
+ series_length = len(sample["t"])
+ elif "y" in sample:
+ series_length = len(sample["y"])
+ else:
+ series_length = len(sample)
+ assert (
+ start_idx + min_required_length <= series_length
+ ), "Window extended beyond series length."
+
+ all_indices = torch.arange(len(tslib_data_module.time_series_dataset))
+ all_windows = tslib_data_module._create_windows(all_indices)
+ assert len(all_windows) >= len(
+ train_windows
+ ), "Should have more windows than all indices."
+
+ empty_windows = tslib_data_module._create_windows(torch.tensor([]))
+
+ assert len(empty_windows) == 0, "Should return empty list for empty index."
+
+
+def test_dataloader_pipeline(tslib_data_module):
+ """Test for a single iteration of the dataloader pipeline to
+ perform batch retrival and ensure correct data shapes and types."""
+
+ tslib_data_module.setup(stage="fit")
+ train_dataloader = tslib_data_module.train_dataloader()
+
+ x_batch, y_batch = next(iter(train_dataloader))
+
+ assert isinstance(x_batch, dict), "x_batch should be a dictionary."
+ assert isinstance(y_batch, torch.Tensor), "y_batch should be a PyTorch tensor."
+
+ assert x_batch["history_cont"].shape[1] == tslib_data_module.context_length
+ assert x_batch["history_cat"].shape[1] == tslib_data_module.context_length
+
+ metadata = tslib_data_module.metadata
+
+ known_cat_count = len(
+ [
+ name
+ for name in metadata["feature_names"]["known"]
+ if name in metadata["feature_names"]["categorical"]
+ ]
+ )
+
+ known_cont_count = len(
+ [
+ name
+ for name in metadata["feature_names"]["known"]
+ if name in metadata["feature_names"]["continuous"]
+ ]
+ )
+
+ assert x_batch["future_cont"].shape[0] == tslib_data_module.batch_size
+ assert x_batch["future_cat"].shape[2] == known_cat_count
+ assert x_batch["future_cont"].shape[2] == known_cont_count
+ assert x_batch["future_cont"].shape[0] == tslib_data_module.batch_size
+
+ assert y_batch.shape[0] == tslib_data_module.batch_size
+ assert y_batch.shape[1] == tslib_data_module.prediction_length
+
+
+def test_different_split_ratios(sample_timeseries_data):
+ """Test the TslibDataModule with different train/val/test split ratios."""
+
+ custom_split = (0.6, 0.2, 0.2)
+ dm_custom = TslibDataModule(
+ time_series_dataset=sample_timeseries_data,
+ context_length=8,
+ prediction_length=4,
+ batch_size=2,
+ train_val_test_split=custom_split,
+ )
+
+ dm_custom.setup(stage="fit")
+
+ total_series = len(sample_timeseries_data)
+ expected_train = int(total_series * 0.6)
+ expected_val = int(total_series * 0.2)
+ expected_test = total_series - expected_train - expected_val
+
+ assert len(dm_custom._train_indices) == expected_train
+ assert len(dm_custom._val_indices) == expected_val
+ assert len(dm_custom._test_indices) == expected_test
+
+ assert dm_custom.train_val_test_split == custom_split
+
+ total_split = (
+ len(dm_custom._train_indices)
+ + len(dm_custom._val_indices)
+ + len(dm_custom._test_indices)
+ )
+ assert (
+ total_split == total_series
+ ), "Total split indices should match the dataset length."
+
+
+def test_preprocess_data(tslib_data_module, sample_timeseries_data):
+ """Test the preprocess_data method.
+ Ensures alignment and presence of all required features."""
+
+ if not hasattr(tslib_data_module, "_indices"):
+ tslib_data_module.setup()
+
+ sample_series_idx = tslib_data_module._train_indices[0]
+ processed = tslib_data_module._preprocess_data(sample_series_idx)
+
+ assert "features" in processed
+ assert "target" in processed
+ assert "static" in processed
+ assert "group" in processed
+ assert "time_mask" in processed
+ assert "continuous" in processed["features"]
+ assert "categorical" in processed["features"]
+ assert "length" in processed
+ assert "timestep" in processed
+
+ original_sample = sample_timeseries_data[sample_series_idx]
+
+ expected_length = len(original_sample["t"])
+
+ assert processed["features"]["categorical"].shape[0] == expected_length
+ assert processed["features"]["continuous"].shape[0] == expected_length
+ assert processed["target"].shape[0] == expected_length
+
+
+def test_static_features(tslib_data_module):
+ """Test with static features included.
+
+ Validates the static feature support in the TslibDataModule."""
+
+ tslib_data_module.setup(stage="fit")
+
+ metadata = tslib_data_module.metadata
+
+ assert metadata["n_features"]["static_continuous"] == 1
+
+ x, y = tslib_data_module.train_dataset[0]
+
+ assert "static_continuous_features" in x
+ assert (
+ x["static_continuous_features"].shape[1]
+ == metadata["n_features"]["static_continuous"]
+ )
+
+
+def test_multivariate_target():
+ """Test with multivariate target (multiple target columns).
+
+ Verifies correct handling of multivariate targets in data pipeline."""
+ df = pd.DataFrame(
+ {
+ "group": np.repeat([0, 1], 50),
+ "time": np.tile(pd.date_range("2020-01-01", periods=50), 2),
+ "target1": np.random.normal(0, 1, 100),
+ "target2": np.random.normal(5, 2, 100),
+ "feature1": np.random.normal(0, 1, 100),
+ "feature2": np.random.normal(0, 1, 100),
+ }
+ )
+
+ ts = TimeSeries(
+ data=df,
+ time="time",
+ target=["target1", "target2"],
+ group=["group"],
+ num=["feature1", "feature2"],
+ )
+
+ dm = TslibDataModule(
+ time_series_dataset=ts,
+ context_length=8,
+ prediction_length=4,
+ batch_size=2,
+ )
+
+ dm.setup(stage="fit")
+
+ x, y = dm.train_dataset[0]
+
+ assert (
+ y.shape[-1] == 2
+ ), "Target should have two dimensions for n_features for multivariate target."
diff --git a/pytorch_forecasting/layers/__init__.py b/pytorch_forecasting/layers/__init__.py
new file mode 100644
index 000000000..43a8db84c
--- /dev/null
+++ b/pytorch_forecasting/layers/__init__.py
@@ -0,0 +1,29 @@
+"""
+Architectural deep learning layers from `nn.Module`.
+"""
+
+from pytorch_forecasting.layers._attention import AttentionLayer, FullAttention
+from pytorch_forecasting.layers._embeddings import (
+ DataEmbedding_inverted,
+ EnEmbedding,
+ PositionalEmbedding,
+)
+from pytorch_forecasting.layers._encoders import (
+ Encoder,
+ EncoderLayer,
+)
+from pytorch_forecasting.layers._output._flatten_head import (
+ FlattenHead,
+)
+
+__all__ = [
+ "FullAttention",
+ "TriangularCausalMask",
+ "AttentionLayer",
+ "DataEmbedding_inverted",
+ "EnEmbedding",
+ "PositionalEmbedding",
+ "Encoder",
+ "EncoderLayer",
+ "FlattenHead",
+]
diff --git a/pytorch_forecasting/layers/_attention/__init__.py b/pytorch_forecasting/layers/_attention/__init__.py
new file mode 100644
index 000000000..cdfc6c3e2
--- /dev/null
+++ b/pytorch_forecasting/layers/_attention/__init__.py
@@ -0,0 +1,8 @@
+"""
+Attention Layers for pytorch-forecasting models.
+"""
+
+from pytorch_forecasting.layers._attention._attention_layer import AttentionLayer
+from pytorch_forecasting.layers._attention._full_attention import FullAttention
+
+__all__ = ["AttentionLayer", "FullAttention"]
diff --git a/pytorch_forecasting/layers/_attention/_attention_layer.py b/pytorch_forecasting/layers/_attention/_attention_layer.py
new file mode 100644
index 000000000..3f6072a93
--- /dev/null
+++ b/pytorch_forecasting/layers/_attention/_attention_layer.py
@@ -0,0 +1,58 @@
+"""
+Implementation of attention layers from `nn.Module`.
+"""
+
+from math import sqrt
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AttentionLayer(nn.Module):
+ """
+ Attention layer that combines query, key, and value projections with an attention
+ mechanism.
+ Args:
+ attention (nn.Module): Attention mechanism to use.
+ d_model (int): Dimension of the model.
+ n_heads (int): Number of attention heads.
+ d_keys (int, optional): Dimension of the keys. Defaults to d_model // n_heads.
+ d_values (int, optional):
+ Dimension of the values. Defaults to d_model // n_heads.
+ """
+
+ def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
+ super().__init__()
+
+ d_keys = d_keys or (d_model // n_heads)
+ d_values = d_values or (d_model // n_heads)
+
+ self.inner_attention = attention
+ self.query_projection = nn.Linear(d_model, d_keys * n_heads)
+ self.key_projection = nn.Linear(d_model, d_keys * n_heads)
+ self.value_projection = nn.Linear(d_model, d_values * n_heads)
+ self.out_projection = nn.Linear(d_values * n_heads, d_model)
+ self.n_heads = n_heads
+
+ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+ B, L, _ = queries.shape
+ _, S, _ = keys.shape
+ H = self.n_heads
+
+ if S == 0:
+ # skip the cross attention process since there is no exogenous variables
+ queries = self.query_projection(queries)
+ return self.out_projection(queries), None
+
+ queries = self.query_projection(queries).view(B, L, H, -1)
+ keys = self.key_projection(keys).view(B, S, H, -1)
+ values = self.value_projection(values).view(B, S, H, -1)
+
+ out, attn = self.inner_attention(
+ queries, keys, values, attn_mask, tau=tau, delta=delta
+ )
+ out = out.view(B, L, -1)
+
+ return self.out_projection(out), attn
diff --git a/pytorch_forecasting/layers/_attention/_full_attention.py b/pytorch_forecasting/layers/_attention/_full_attention.py
new file mode 100644
index 000000000..def9b5214
--- /dev/null
+++ b/pytorch_forecasting/layers/_attention/_full_attention.py
@@ -0,0 +1,71 @@
+"""
+Full Attention Layer.
+"""
+
+from math import sqrt
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class TriangularCausalMask:
+ """
+ Triangular causal mask for attention mechanism.
+ """
+
+ def __init__(self, B, L, device="cpu"):
+ mask_shape = [B, 1, L, L]
+ with torch.no_grad():
+ self._mask = torch.triu(
+ torch.ones(mask_shape, dtype=torch.bool), diagonal=1
+ ).to(device)
+
+ @property
+ def mask(self):
+ return self._mask
+
+
+class FullAttention(nn.Module):
+ """
+ Full attention mechanism with optional masking and dropout.
+ Args:
+ mask_flag (bool): Whether to apply masking.
+ factor (int): Factor for scaling the attention scores.
+ scale (float): Scaling factor for attention scores.
+ attention_dropout (float): Dropout rate for attention scores.
+ output_attention (bool): Whether to output attention weights."""
+
+ def __init__(
+ self,
+ mask_flag=True,
+ factor=5,
+ scale=None,
+ attention_dropout=0.1,
+ output_attention=False,
+ ):
+ super().__init__()
+ self.scale = scale
+ self.mask_flag = mask_flag
+ self.output_attention = output_attention
+ self.dropout = nn.Dropout(attention_dropout)
+
+ def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
+ B, L, H, E = queries.shape
+ _, S, _, D = values.shape
+ scale = self.scale or 1.0 / sqrt(E)
+
+ scores = torch.einsum("blhe,bshe->bhls", queries, keys)
+
+ if self.mask_flag:
+ if attn_mask is None:
+ attn_mask = TriangularCausalMask(B, L, device=queries.device)
+ scores.masked_fill_(attn_mask.mask, -np.abs)
+ A = self.dropout(torch.softmax(scale * scores, dim=-1))
+ V = torch.einsum("bhls,bshd->blhd", A, values)
+
+ if self.output_attention:
+ return V.contiguous(), A
+ else:
+ return V.contiguous(), None
diff --git a/pytorch_forecasting/layers/_embeddings/__init__.py b/pytorch_forecasting/layers/_embeddings/__init__.py
new file mode 100644
index 000000000..e18bd88ce
--- /dev/null
+++ b/pytorch_forecasting/layers/_embeddings/__init__.py
@@ -0,0 +1,13 @@
+"""
+Implementation of embedding layers for PTF models imported from `nn.Modules`
+"""
+
+from pytorch_forecasting.layers._embeddings._data_embedding import (
+ DataEmbedding_inverted,
+)
+from pytorch_forecasting.layers._embeddings._en_embedding import EnEmbedding
+from pytorch_forecasting.layers._embeddings._positional_embedding import (
+ PositionalEmbedding,
+)
+
+__all__ = ["PositionalEmbedding", "DataEmbedding_inverted", "EnEmbedding"]
diff --git a/pytorch_forecasting/layers/_embeddings/_data_embedding.py b/pytorch_forecasting/layers/_embeddings/_data_embedding.py
new file mode 100644
index 000000000..9e33e65f7
--- /dev/null
+++ b/pytorch_forecasting/layers/_embeddings/_data_embedding.py
@@ -0,0 +1,38 @@
+"""
+Data embedding layer for exogenous variables.
+"""
+
+import math
+from math import sqrt
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DataEmbedding_inverted(nn.Module):
+ """
+ Data embedding module for time series data.
+ Args:
+ c_in (int): Number of input features.
+ d_model (int): Dimension of the model.
+ embed_type (str): Type of embedding to use. Defaults to "fixed".
+ freq (str): Frequency of the time series data. Defaults to "h".
+ dropout (float): Dropout rate. Defaults to 0.1.
+ """
+
+ def __init__(self, c_in, d_model, dropout=0.1):
+ super().__init__()
+ self.value_embedding = nn.Linear(c_in, d_model)
+ self.dropout = nn.Dropout(p=dropout)
+
+ def forward(self, x, x_mark):
+ x = x.permute(0, 2, 1)
+ # x: [Batch Variate Time]
+ if x_mark is None:
+ x = self.value_embedding(x)
+ else:
+ x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
+ # x: [Batch Variate d_model]
+ return self.dropout(x)
diff --git a/pytorch_forecasting/layers/_embeddings/_en_embedding.py b/pytorch_forecasting/layers/_embeddings/_en_embedding.py
new file mode 100644
index 000000000..573f0571b
--- /dev/null
+++ b/pytorch_forecasting/layers/_embeddings/_en_embedding.py
@@ -0,0 +1,52 @@
+"""
+Implementation of endogenous embedding layers from `nn.Module`.
+"""
+
+import math
+from math import sqrt
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from pytorch_forecasting.layers._embeddings._positional_embedding import (
+ PositionalEmbedding,
+)
+
+
+class EnEmbedding(nn.Module):
+ """
+ Encoder embedding module for time series data. Handles endogenous feature
+ embeddings in this case.
+ Args:
+ n_vars (int): Number of input features.
+ d_model (int): Dimension of the model.
+ patch_len (int): Length of the patches.
+ dropout (float): Dropout rate. Defaults to 0.1.
+ """
+
+ def __init__(self, n_vars, d_model, patch_len, dropout):
+ super().__init__()
+
+ self.patch_len = patch_len
+
+ self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
+ self.glb_token = nn.Parameter(torch.randn(1, n_vars, 1, d_model))
+ self.position_embedding = PositionalEmbedding(d_model)
+
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 1)
+ n_vars = x.shape[1]
+ glb = self.glb_token.repeat((x.shape[0], 1, 1, 1))
+
+ x = x.unfold(dimension=-1, size=self.patch_len, step=self.patch_len)
+ x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
+ # Input encoding
+ x = self.value_embedding(x) + self.position_embedding(x)
+ x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1]))
+ x = torch.cat([x, glb], dim=2)
+ x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
+ return self.dropout(x), n_vars
diff --git a/pytorch_forecasting/layers/_embeddings/_positional_embedding.py b/pytorch_forecasting/layers/_embeddings/_positional_embedding.py
new file mode 100644
index 000000000..82b107315
--- /dev/null
+++ b/pytorch_forecasting/layers/_embeddings/_positional_embedding.py
@@ -0,0 +1,39 @@
+"""
+Positional Embedding Layer for PTF.
+"""
+
+import math
+from math import sqrt
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PositionalEmbedding(nn.Module):
+ """
+ Positional embedding module for time series data.
+ Args:
+ d_model (int): Dimension of the model.
+ max_len (int): Maximum length of the input sequence. Defaults to 5000."""
+
+ def __init__(self, d_model, max_len=5000):
+ super().__init__()
+ # Compute the positional encodings once in log space.
+ pe = torch.zeros(max_len, d_model).float()
+ pe.require_grad = False
+
+ position = torch.arange(0, max_len).float().unsqueeze(1)
+ div_term = (
+ torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
+ ).exp()
+
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+
+ pe = pe.unsqueeze(0)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x):
+ return self.pe[:, : x.size(1)]
diff --git a/pytorch_forecasting/layers/_encoders/__init__.py b/pytorch_forecasting/layers/_encoders/__init__.py
new file mode 100644
index 000000000..da20c4f3f
--- /dev/null
+++ b/pytorch_forecasting/layers/_encoders/__init__.py
@@ -0,0 +1,8 @@
+"""
+Encoder layers for neural network models.
+"""
+
+from pytorch_forecasting.layers._encoders._encoder import Encoder
+from pytorch_forecasting.layers._encoders._encoder_layer import EncoderLayer
+
+__all__ = ["Encoder", "EncoderLayer"]
diff --git a/pytorch_forecasting/layers/_encoders/_encoder.py b/pytorch_forecasting/layers/_encoders/_encoder.py
new file mode 100644
index 000000000..3b54a0838
--- /dev/null
+++ b/pytorch_forecasting/layers/_encoders/_encoder.py
@@ -0,0 +1,40 @@
+"""
+Implementation of encoder layers from `nn.Module`.
+"""
+
+import math
+from math import sqrt
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Encoder(nn.Module):
+ """
+ Encoder module for the TimeXer model.
+ Args:
+ layers (list): List of encoder layers.
+ norm_layer (nn.Module, optional): Normalization layer. Defaults to None.
+ projection (nn.Module, optional): Projection layer. Defaults to None.
+ """
+
+ def __init__(self, layers, norm_layer=None, projection=None):
+ super().__init__()
+ self.layers = nn.ModuleList(layers)
+ self.norm = norm_layer
+ self.projection = projection
+
+ def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
+ for layer in self.layers:
+ x = layer(
+ x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
+ )
+
+ if self.norm is not None:
+ x = self.norm(x)
+
+ if self.projection is not None:
+ x = self.projection(x)
+ return x
diff --git a/pytorch_forecasting/layers/_encoders/_encoder_layer.py b/pytorch_forecasting/layers/_encoders/_encoder_layer.py
new file mode 100644
index 000000000..a246edc91
--- /dev/null
+++ b/pytorch_forecasting/layers/_encoders/_encoder_layer.py
@@ -0,0 +1,73 @@
+"""
+Implementation of EncoderLayer for encoder-decoder architectures from `nn.Module`.
+"""
+
+import math
+from math import sqrt
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class EncoderLayer(nn.Module):
+ """
+ Encoder layer for the TimeXer model.
+ Args:
+ self_attention (nn.Module): Self-attention mechanism.
+ cross_attention (nn.Module): Cross-attention mechanism.
+ d_model (int): Dimension of the model.
+ d_ff (int, optional):
+ Dimension of the feedforward layer. Defaults to 4 * d_model.
+ dropout (float): Dropout rate. Defaults to 0.1.
+ activation (str): Activation function. Defaults to "relu".
+ """
+
+ def __init__(
+ self,
+ self_attention,
+ cross_attention,
+ d_model,
+ d_ff=None,
+ dropout=0.1,
+ activation="relu",
+ ):
+ super().__init__()
+ d_ff = d_ff or 4 * d_model
+ self.self_attention = self_attention
+ self.cross_attention = cross_attention
+ self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
+ self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
+ self.norm1 = nn.LayerNorm(d_model)
+ self.norm2 = nn.LayerNorm(d_model)
+ self.norm3 = nn.LayerNorm(d_model)
+ self.dropout = nn.Dropout(dropout)
+ self.activation = F.relu if activation == "relu" else F.gelu
+
+ def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
+ B, L, D = cross.shape
+ x = x + self.dropout(
+ self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
+ )
+ x = self.norm1(x)
+
+ x_glb_ori = x[:, -1, :].unsqueeze(1)
+ x_glb = torch.reshape(x_glb_ori, (B, -1, D))
+ x_glb_attn = self.dropout(
+ self.cross_attention(
+ x_glb, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
+ )[0]
+ )
+ x_glb_attn = torch.reshape(
+ x_glb_attn, (x_glb_attn.shape[0] * x_glb_attn.shape[1], x_glb_attn.shape[2])
+ ).unsqueeze(1)
+ x_glb = x_glb_ori + x_glb_attn
+ x_glb = self.norm2(x_glb)
+
+ y = x = torch.cat([x[:, :-1, :], x_glb], dim=1)
+
+ y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
+ y = self.dropout(self.conv2(y).transpose(-1, 1))
+
+ return self.norm3(x + y)
diff --git a/pytorch_forecasting/layers/_output/__init__.py b/pytorch_forecasting/layers/_output/__init__.py
new file mode 100644
index 000000000..eb3b686a3
--- /dev/null
+++ b/pytorch_forecasting/layers/_output/__init__.py
@@ -0,0 +1,7 @@
+"""
+Implementation of output layers for PyTorch Forecasting.
+"""
+
+from pytorch_forecasting.layers._output._flatten_head import FlattenHead
+
+__all__ = ["FlattenHead"]
diff --git a/pytorch_forecasting/layers/_output/_flatten_head.py b/pytorch_forecasting/layers/_output/_flatten_head.py
new file mode 100644
index 000000000..71823b162
--- /dev/null
+++ b/pytorch_forecasting/layers/_output/_flatten_head.py
@@ -0,0 +1,45 @@
+"""
+Implementation of output layers from `nn.Module` for TimeXer model.
+"""
+
+import math
+from math import sqrt
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FlattenHead(nn.Module):
+ """
+ Flatten head for the output of the model.
+ Args:
+ n_vars (int): Number of input features.
+ nf (int): Number of features in the last layer.
+ target_window (int): Target window size.
+ head_dropout (float): Dropout rate for the head. Defaults to 0.
+ n_quantiles (int, optional): Number of quantiles. Defaults to None."""
+
+ def __init__(self, n_vars, nf, target_window, head_dropout=0, n_quantiles=None):
+ super().__init__()
+ self.n_vars = n_vars
+ self.flatten = nn.Flatten(start_dim=-2)
+ self.linear = nn.Linear(nf, target_window)
+ self.n_quantiles = n_quantiles
+
+ if self.n_quantiles is not None:
+ self.linear = nn.Linear(nf, target_window * n_quantiles)
+ else:
+ self.linear = nn.Linear(nf, target_window)
+ self.dropout = nn.Dropout(head_dropout)
+
+ def forward(self, x):
+ x = self.flatten(x)
+ x = self.linear(x)
+ x = self.dropout(x)
+
+ if self.n_quantiles is not None:
+ batch_size, n_vars = x.shape[0], x.shape[1]
+ x = x.reshape(batch_size, n_vars, -1, self.n_quantiles)
+ return x
diff --git a/pytorch_forecasting/models/base/_tslib_base_model_v2.py b/pytorch_forecasting/models/base/_tslib_base_model_v2.py
new file mode 100644
index 000000000..b502c46bf
--- /dev/null
+++ b/pytorch_forecasting/models/base/_tslib_base_model_v2.py
@@ -0,0 +1,190 @@
+"""
+Experimental implementation of a base class for `tslib` models.
+"""
+
+from typing import Optional, Union
+from warnings import warn
+
+import torch
+import torch.nn as nn
+from torch.optim import Optimizer
+
+from pytorch_forecasting.models.base._base_model_v2 import BaseModel
+
+
+class TslibBaseModel(BaseModel):
+ """
+ Base class for `tslib` models.
+
+ Parameters
+ ----------
+ loss : nn.Module
+ Loss function to use for training.
+ logging_metrics : Optional[list[nn.Module]], optional
+ list of metrics to log during training, validation, and testing.
+ optimizer : Optional[Union[Optimizer, str]], optional
+ Optimizer to use for training.
+ optimizer_params : Optional[dict], optional
+ Parameters for the optimizer.
+ lr_scheduler : Optional[str], optional
+ Learning rate scheduler to use.
+ lr_scheduler_params : Optional[dict], optional
+ Parameters for the learning rate scheduler.
+ metadata : Optional[dict], default=None
+ Metadata for the model from TslibDataModule.
+ """
+
+ def __init__(
+ self,
+ loss: nn.Module,
+ logging_metrics: Optional[list[nn.Module]] = None,
+ optimizer: Optional[Union[Optimizer, str]] = "adam",
+ optimizer_params: Optional[dict] = None,
+ lr_scheduler: Optional[str] = None,
+ lr_scheduler_params: Optional[dict] = None,
+ metadata: Optional[dict] = None,
+ ):
+ super().__init__(
+ loss=loss,
+ logging_metrics=logging_metrics,
+ optimizer=optimizer,
+ optimizer_params=optimizer_params,
+ lr_scheduler=lr_scheduler,
+ lr_scheduler_params=lr_scheduler_params,
+ )
+ self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"])
+ self.metadata = metadata or {}
+ self.model_name = self.__class__.__name__
+
+ warn(
+ f"The Model '{self.model_name}' is part of an experimental implementation"
+ "of the pytorch-forecasting model layer for Time Series Library, scheduled"
+ "for release with v2.0.0. The API is not stable"
+ "and may change without prior warning. This class is intended for beta"
+ "testing, not for stable production use.",
+ UserWarning,
+ )
+
+ self.context_length = self.metadata.get("context_length", 0)
+ self.prediction_length = self.metadata.get("prediction_length", 0)
+
+ feature_indices = metadata.get("feature_indices", {})
+ self.cont_indices = feature_indices.get("continuous", [])
+ self.cat_indices = feature_indices.get("categorical", [])
+ self.known_indices = feature_indices.get("known", [])
+ self.unknown_indices = feature_indices.get("unknown", [])
+ self.target_indices = feature_indices.get("target", [])
+
+ feature_dims = metadata.get("n_features", {})
+ self.cont_dim = feature_dims.get("continuous", 0)
+ self.cat_dim = feature_dims.get("categorical", 0)
+ self.static_cat_dim = feature_dims.get("static_categorical", 0)
+ self.static_cont_dim = feature_dims.get("static_continuous", 0)
+ self.target_dim = feature_dims.get("target", 1)
+
+ self.feature_names = metadata.get("feature_names", {})
+
+ # feature-mode
+ self.features = metadata.get("features", "MS")
+
+ def _init_network(self):
+ """
+ Initialize the network architecture.
+ This method should be implemented in subclasses to define the specific layers
+ and sub_modules of the model.
+ """
+ raise NotImplementedError("Subclasses must implement _init_network method.")
+
+ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Forward pass of the model.
+
+ Parameters
+ ----------
+ x: dict[str, torch.Tensor]
+ Dictionary containing input tensors.
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ Dictionary containing output tensors. These can include
+ - predictions:
+ Prediction_output of shape (batch_size, prediction_length, target_dim)
+ - attention_weights: Optionally, output attention weights
+ """
+
+ raise NotImplementedError("Subclasses must implement forward method.")
+
+ def predict_step(
+ self,
+ batch: tuple[dict[str, torch.Tensor]],
+ batch_idx: int,
+ dataloader_idx: int = 0,
+ ) -> torch.Tensor:
+ """
+ Prediction step for the model.
+
+ Parameters
+ ----------
+ batch : tuple[dict[str, torch.Tensor]]
+ Batch of data containing input tensors.
+ batch_idx : int
+ Index of the batch.
+ dataloader_idx : int
+ Index of the dataloader.
+
+ Returns
+ -------
+ torch.Tensor
+ Predicted output tensor.
+ """
+ x, _ = batch
+ y_hat = self(x)
+
+ if "target" in x:
+ y_hat["target"] = x["target"]
+
+ return y_hat
+
+ def transform_output(
+ self,
+ y_hat: Union[
+ torch.Tensor, list[torch.Tensor]
+ ], # evidenced from TimeXer implementation - in PR #1797 # noqa: E501
+ target_scale: Optional[dict[str, torch.Tensor]],
+ ) -> Union[torch.Tensor, list[torch.Tensor]]:
+ """
+ Transform the output of the model to the original scale.
+
+ Parameters
+ ----------
+ y_hat : Union[torch.Tensor, list[torch.Tensor]]
+ Dictionary containing the model output.
+ target_scale : Optional[dict[str, torch.Tensor]]
+ Dictionary containing the target scale for inverse transformation.
+
+ Returns
+ -------
+ Union[torch.Tensor, list[torch.Tensor]]
+ Dictionary containing the transformed output.
+
+ Notes
+ -----
+ WARNING! : This is a temporary implementation and is meant to be replaced with
+ a more robust scaling and normalization module for v2 of PTF.
+ """
+
+ scale = None
+ center = None
+
+ if "scale" in target_scale and "center" in target_scale:
+ scale = target_scale["scale"]
+ center = target_scale["center"]
+ else:
+ raise ValueError("Cannot transform output without scale and center.")
+
+ while scale.dim() < y_hat.dim():
+ scale = scale.unsqueeze(0)
+ center = center.unsqueeze(0)
+
+ return y_hat * scale + center
diff --git a/pytorch_forecasting/models/timexer/__init__.py b/pytorch_forecasting/models/timexer/__init__.py
index 8ec9c7dd2..43703d33b 100644
--- a/pytorch_forecasting/models/timexer/__init__.py
+++ b/pytorch_forecasting/models/timexer/__init__.py
@@ -4,6 +4,7 @@
from pytorch_forecasting.models.timexer._timexer import TimeXer
from pytorch_forecasting.models.timexer._timexer_pkg import TimeXer_pkg
+from pytorch_forecasting.models.timexer._timexer_pkg_v2 import TimeXer_pkg_v2
from pytorch_forecasting.models.timexer.sub_modules import (
AttentionLayer,
DataEmbedding_inverted,
@@ -28,4 +29,5 @@
"Encoder",
"EncoderLayer",
"TimeXer_pkg",
+ "TimeXer_pkg_v2",
]
diff --git a/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py
new file mode 100644
index 000000000..22f7f83ba
--- /dev/null
+++ b/pytorch_forecasting/models/timexer/_timexer_pkg_v2.py
@@ -0,0 +1,162 @@
+"""
+Metadata container for TimeXer v2.
+"""
+
+from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2
+
+
+class TimeXer_pkg_v2(_BasePtForecasterV2):
+ """TimeXer metadata container."""
+
+ _tags = {
+ "info:name": "TimeXer",
+ "authors": ["PranavBhatP"],
+ "capability:exogenous": True,
+ "capability:multivariate": True,
+ "capability:pred_int": True,
+ "capability:flexible_history_length": False,
+ }
+
+ @classmethod
+ def get_model_cls(cls):
+ """Get model class."""
+ from pytorch_forecasting.models.timexer._timexer_v2 import TimeXer
+
+ return TimeXer
+
+ @classmethod
+ def _get_test_datamodule_from(cls, trainer_kwargs):
+ """Create test dataloaders from trainer_kwargs - following v1 pattern."""
+ from pytorch_forecasting.data._tslib_data_module import TslibDataModule
+ from pytorch_forecasting.tests._data_scenarios import (
+ data_with_covariates_v2,
+ make_datasets_v2,
+ )
+
+ data_with_covariates = data_with_covariates_v2()
+
+ data_loader_default_kwargs = dict(
+ target="target",
+ group_ids=["agency_encoded", "sku_encoded"],
+ add_relative_time_idx=True,
+ )
+
+ data_loader_kwargs = trainer_kwargs.get("data_loader_kwargs", {})
+ data_loader_default_kwargs.update(data_loader_kwargs)
+
+ datasets_info = make_datasets_v2(
+ data_with_covariates, **data_loader_default_kwargs
+ )
+
+ training_dataset = datasets_info["training_dataset"]
+ validation_dataset = datasets_info["validation_dataset"]
+
+ context_length = data_loader_kwargs.get("context_length", 12)
+ prediction_length = data_loader_kwargs.get("prediction_length", 4)
+ batch_size = data_loader_kwargs.get("batch_size", 2)
+
+ train_datamodule = TslibDataModule(
+ time_series_dataset=training_dataset,
+ context_length=context_length,
+ prediction_length=prediction_length,
+ add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
+ batch_size=batch_size,
+ train_val_test_split=(0.8, 0.2, 0.0),
+ )
+
+ val_datamodule = TslibDataModule(
+ time_series_dataset=validation_dataset,
+ context_length=context_length,
+ prediction_length=prediction_length,
+ add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
+ batch_size=batch_size,
+ train_val_test_split=(0.0, 1.0, 0.0),
+ )
+
+ test_datamodule = TslibDataModule(
+ time_series_dataset=validation_dataset,
+ context_length=context_length,
+ prediction_length=prediction_length,
+ add_relative_time_idx=data_loader_kwargs.get("add_relative_time_idx", True),
+ batch_size=1,
+ train_val_test_split=(0.0, 0.0, 1.0),
+ )
+
+ train_datamodule.setup("fit")
+ val_datamodule.setup("fit")
+ test_datamodule.setup("test")
+
+ train_dataloader = train_datamodule.train_dataloader()
+ val_dataloader = val_datamodule.val_dataloader()
+ test_dataloader = test_datamodule.test_dataloader()
+
+ return {
+ "train": train_dataloader,
+ "val": val_dataloader,
+ "test": test_dataloader,
+ "data_module": train_datamodule,
+ }
+
+ @classmethod
+ def get_test_train_params(cls):
+ """Return testing parameter settings for the trainer.
+
+ Returns
+ -------
+ params : dict or list of dict, default = {}
+ Parameters to create testing instances of the class
+ Each dict are parameters to construct an "interesting" test instance, i.e.,
+ `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
+ `create_test_instance` uses the first (or only) dictionary in `params`
+ """
+ return [
+ {},
+ dict(
+ hidden_size=64,
+ n_heads=4,
+ ),
+ dict(data_loader_kwargs=dict(context_length=12, prediction_length=3)),
+ dict(
+ hidden_size=32,
+ n_heads=2,
+ data_loader_kwargs=dict(
+ context_length=12,
+ prediction_length=3,
+ add_relative_time_idx=False,
+ ),
+ ),
+ dict(
+ hidden_size=128,
+ patch_length=12,
+ data_loader_kwargs=dict(context_length=16, prediction_length=4),
+ ),
+ dict(
+ n_heads=2,
+ e_layers=1,
+ patch_length=6,
+ ),
+ dict(
+ hidden_size=256,
+ n_heads=8,
+ e_layers=3,
+ d_ff=1024,
+ patch_length=8,
+ factor=3,
+ activation="gelu",
+ dropout=0.2,
+ ),
+ dict(
+ hidden_size=32,
+ n_heads=2,
+ e_layers=1,
+ d_ff=64,
+ patch_length=4,
+ factor=2,
+ activation="relu",
+ dropout=0.05,
+ data_loader_kwargs=dict(
+ context_length=16,
+ prediction_length=4,
+ ),
+ ),
+ ]
diff --git a/pytorch_forecasting/models/timexer/_timexer_v2.py b/pytorch_forecasting/models/timexer/_timexer_v2.py
new file mode 100644
index 000000000..671b2a383
--- /dev/null
+++ b/pytorch_forecasting/models/timexer/_timexer_v2.py
@@ -0,0 +1,333 @@
+"""
+Time Series Transformer with eXogenous variables (TimeXer)
+----------------------------------------------------------
+"""
+
+################################################################
+# NOTE: This implementation of TimeXer derives from PR #1797. #
+# It is experimental and seeks to clarify design decisions. #
+# IT IS STRICTLY A PART OF THE v2 design of PTF. It overrides #
+# the v1 version introduced in PTF by PR #1797 #
+################################################################
+
+from typing import Any, Optional, Union
+import warnings as warn
+
+import torch
+import torch.nn as nn
+from torch.optim import Optimizer
+
+from pytorch_forecasting.models.base._tslib_base_model_v2 import TslibBaseModel
+
+
+class TimeXer(TslibBaseModel):
+ """
+ An implementation of TimeXer model for v2 of pytorch-forecasting.
+
+ TimeXer empowers the canonical transformer with the ability to reconcile
+ endogenous and exogenous information without any architectural modifications
+ and achieves consistent state-of-the-art performance across twelve real-world
+ forecasting benchmarks.
+
+ TimeXer employs patch-level and variate-level representations respectively for
+ endogenous and exogenous variables, with an endogenous global token as a bridge
+ in-between. With this design, TimeXer can jointly capture intra-endogenous
+ temporal dependencies and exogenous-to-endogenous correlations.
+
+ Parameters
+ ----------
+ loss: nn.Module
+ Loss function to use for training.
+ enc_in: int, optional
+ Number of input features for the encoder. If not provided, it will be set to
+ the number of continuous features in the dataset.
+ hidden_size: int, default=512
+ Dimension of the model embeddings and hidden representations of features.
+ n_heads: int, default=8
+ Number of attention heads in the multi-head attention mechanism.\
+ e_layers: int, default=2
+ Number of encoder layers in the transformer architecture.
+ d_ff: int, default=2048
+ Dimension of the feed-forward network in the transformer architecture.
+ dropout: float, default=0.1
+ Dropout rate for regularization. This is used throughout the model to prevent overfitting.
+ patch_length: int, default=24
+ Length of each non-overlapping patch for endogenous variable tokenization.
+ factor: int, default=5
+ Factor for the attention mechanism, controlling the number of keys and values.
+ activation: str, default='relu'
+ Activation function to use in the feed-forward network. Common choices are 'relu', 'gelu', etc.
+ endogenous_vars: Optional[list[str]], default=None
+ List of endogenous variable names to be used in the model. If None, all historical values
+ for the target variable are used.
+ exogenous_vars: Optional[list[str]], default=None
+ List of exogenous variable names to be used in the model. If None, all historical values
+ for continous variables are used.
+ logging_metrics: Optional[list[nn.Module]], default=None
+ List of metrics to log during training, validation, and testing.
+ optimizer: Optional[Union[Optimizer, str]], default='adam'
+ Optimizer to use for training. Can be a string name or an instance of an optimizer.
+ optimizer_params: Optional[dict], default=None
+ Parameters for the optimizer. If None, default parameters for the optimizer will be used.
+ lr_scheduler: Optional[str], default=None
+ Learning rate scheduler to use. If None, no scheduler is used.
+ lr_scheduler_params: Optional[dict], default=None
+ Parameters for the learning rate scheduler. If None, default parameters for the scheduler will be used.
+ metadata: Optional[dict], default=None
+ Metadata for the model from TslibDataModule. This can include information about the dataset,
+ such as the number of time steps, number of features, etc. It is used to initialize the model
+ and ensure it is compatible with the data being used.
+
+ References
+ ----------
+ [1] https://arxiv.org/abs/2402.19072
+ [2] https://github.com/thuml/TimeXer
+
+ Notes
+ -----
+ [1] This implementation handles only continous variables in the context length. Categorical variables
+ support will be added in the future.
+ [2] The `TimeXer` model obtains many of its attributes from the `TslibBaseModel` class, which is a base class
+ where a lot of the boiler plate code for metadata handling and model initialization is implemented.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ loss: nn.Module,
+ enc_in: int = None,
+ hidden_size: int = 512,
+ n_heads: int = 8,
+ e_layers: int = 2,
+ d_ff: int = 2048,
+ dropout: float = 0.1,
+ patch_length: int = 4,
+ factor: int = 5,
+ activation: str = "relu",
+ endogenous_vars: Optional[list[str]] = None,
+ exogenous_vars: Optional[list[str]] = None,
+ logging_metrics: Optional[list[nn.Module]] = None,
+ optimizer: Optional[Union[Optimizer, str]] = "adam",
+ optimizer_params: Optional[dict] = None,
+ lr_scheduler: Optional[str] = None,
+ lr_scheduler_params: Optional[dict] = None,
+ metadata: Optional[dict] = None,
+ **kwargs: Any,
+ ):
+ super().__init__(
+ loss=loss,
+ logging_metrics=logging_metrics,
+ optimizer=optimizer,
+ optimizer_params=optimizer_params,
+ lr_scheduler=lr_scheduler,
+ lr_scheduler_params=lr_scheduler_params,
+ metadata=metadata,
+ )
+
+ warn.warn(
+ "TimeXer is an experimental model implemented on TslibBaseModelV2. "
+ "It is an unstable version and maybe subject to unannouced changes."
+ "Please use with caution. Feedback on the design and implementation is"
+ ""
+ "welcome. On the issue #1833 - https://github.com/sktime/pytorch-forecasting/issues/1833",
+ )
+
+ self.enc_in = enc_in
+ self.hidden_size = hidden_size
+ self.n_heads = n_heads
+ self.e_layers = e_layers
+ self.d_ff = d_ff
+ self.dropout = dropout
+ self.patch_length = patch_length
+ self.activation = activation
+ self.factor = factor
+ self.endogenous_vars = endogenous_vars
+ self.exogenous_vars = exogenous_vars
+ self.save_hyperparameters(ignore=["loss", "logging_metrics", "metadata"])
+
+ self._init_network()
+
+ def _init_network(self):
+ """
+ Initialize the network for TimeXer's architecture.
+ """
+
+ from pytorch_forecasting.layers import (
+ AttentionLayer,
+ DataEmbedding_inverted,
+ Encoder,
+ EncoderLayer,
+ EnEmbedding,
+ FlattenHead,
+ FullAttention,
+ )
+
+ if self.context_length <= self.patch_length:
+ raise ValueError(
+ f"Context length ({self.context_length}) must be greater than patch"
+ "length. Patches of ({self.patch_length}) will end up being longer than"
+ "the sequence length."
+ )
+
+ if self.context_length % self.patch_length != 0:
+ warn.warn(
+ f"Context length ({self.context_length}) is not divisible by"
+ " patch length. This may lead to unexpected behavior, as some"
+ "time steps will not be used in the model."
+ )
+
+ self.patch_num = max(1, int(self.context_length // self.patch_length))
+
+ if self.target_dim > 1 and self.features == "M":
+ self.n_target_vars = self.target_dim
+ else:
+ self.n_target_vars = 1
+
+ # currently enc_in is set only to cont_dim since
+ # the data module doesn't fully support categorical
+ # variables in the context length and modele expects
+ # float values.
+ self.enc_in = self.enc_in or self.cont_dim
+
+ self.n_quantiles = None
+
+ if hasattr(self.loss, "quantiles") and self.loss.quantiles is not None:
+ self.n_quantiles = len(self.loss.quantiles)
+
+ if self.hidden_size % self.n_heads != 0:
+ raise ValueError(
+ f"hidden_size ({self.hidden_size}) must be divisible by n_heads ({self.n_heads}) " # noqa: E501
+ f"for multi-head attention mechanism to work properly."
+ )
+
+ self.en_embedding = EnEmbedding(
+ self.n_target_vars, self.hidden_size, self.patch_length, self.dropout
+ )
+
+ self.ex_embedding = DataEmbedding_inverted(
+ self.context_length, self.hidden_size, self.dropout
+ )
+
+ encoder_layers = []
+
+ for _ in range(self.e_layers):
+ encoder_layers.append(
+ EncoderLayer(
+ AttentionLayer(
+ FullAttention(
+ False,
+ self.factor,
+ attention_dropout=self.dropout,
+ output_attention=False,
+ ),
+ self.hidden_size,
+ self.n_heads,
+ ),
+ AttentionLayer(
+ FullAttention(
+ False,
+ self.factor,
+ attention_dropout=self.dropout,
+ output_attention=False,
+ ),
+ self.hidden_size,
+ self.n_heads,
+ ),
+ self.hidden_size,
+ self.d_ff,
+ dropout=self.dropout,
+ activation=self.activation,
+ )
+ )
+
+ self.encoder = Encoder(
+ encoder_layers, norm_layer=torch.nn.LayerNorm(self.hidden_size)
+ )
+
+ # Initialize output head
+ self.head_nf = self.hidden_size * (self.patch_num + 1)
+ self.head = FlattenHead(
+ self.enc_in,
+ self.head_nf,
+ self.prediction_length,
+ head_dropout=self.dropout,
+ n_quantiles=self.n_quantiles,
+ )
+
+ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Forward pass of the TimeXer model.
+ Args:
+ x (dict[str, torch.Tensor]): Input data.
+ Returns:
+ dict[str, torch.Tensor]: Model predictions.
+ """
+ batch_size = x["history_cont"].shape[0]
+ history_cont = x["history_cont"]
+ history_time_idx = x.get("history_time_idx", None)
+
+ history_target = x.get(
+ "history_target",
+ torch.zeros(batch_size, self.context_length, 1, device=self.device),
+ ) # noqa: E501
+
+ if history_time_idx is not None and history_time_idx.dim() == 2:
+ # change [batch_size, time_steps] to [batch_size, time_steps, features]
+ history_time_idx = history_time_idx.unsqueeze(-1)
+
+ # explicitly set endogenous and exogenous variables
+ endogenous_cont = history_target
+ if self.endogenous_vars:
+ endogenous_indices = [
+ self.feature_names["continuous"].index(var)
+ for var in self.endogenous_vars # noqa: E501
+ ]
+ endogenous_cont = history_cont[..., endogenous_indices]
+
+ exogenous_cont = history_cont
+ if self.exogenous_vars:
+ exogenous_indices = [
+ self.feature_names["continuous"].index(var)
+ for var in self.exogenous_vars # noqa: E501
+ ]
+ exogenous_cont = history_cont[..., exogenous_indices]
+
+ en_embed, n_vars = self.en_embedding(endogenous_cont)
+ ex_embed = self.ex_embedding(exogenous_cont, history_time_idx)
+
+ enc_out = self.encoder(en_embed, ex_embed)
+
+ enc_out = torch.reshape(
+ enc_out, (-1, n_vars, enc_out.shape[-2], enc_out.shape[-1])
+ )
+
+ enc_out = enc_out.permute(0, 1, 3, 2)
+
+ dec_out = self.head(enc_out)
+
+ if self.n_quantiles is not None:
+ dec_out = dec_out.permute(0, 2, 1, 3)
+ else:
+ dec_out = dec_out.permute(0, 2, 1)
+
+ return dec_out
+
+ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+ """
+ Forward pass of the TimeXer model.
+ Args:
+ x (dict[str, torch.Tensor]): Input data.
+ Returns:
+ dict[str, torch.Tensor]: Model predictions.
+ """
+
+ out = self._forecast(x)
+ prediction = out[:, : self.prediction_length, :]
+
+ # check to see if the output shape is equal to number of targets
+ if prediction.size(2) != self.target_dim:
+ prediction = prediction[:, :, : self.target_dim]
+
+ if "target_scale" in x:
+ prediction = self.transform_output(prediction, x["target_scale"])
+
+ return {"prediction": prediction}
diff --git a/tests/test_models/test_timexer_v2.py b/tests/test_models/test_timexer_v2.py
new file mode 100644
index 000000000..3285bd0e7
--- /dev/null
+++ b/tests/test_models/test_timexer_v2.py
@@ -0,0 +1,400 @@
+"""
+Basic test frameowrk for TimeXer v2 model.
+TODO:
+- Add tests for testing the scaling of features, once that is implemented in the D1/D2
+ level.
+- Add tests for the M mode (multiple series) once that is implemented.
+"""
+
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+import torch.nn as nn
+
+from pytorch_forecasting.data import TimeSeries
+from pytorch_forecasting.data._tslib_data_module import TslibDataModule
+from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss
+from pytorch_forecasting.models.timexer._timexer_v2 import TimeXer
+
+
+@pytest.fixture
+def sample_multivariate_data():
+ """Sample multivariate data for testing."""
+
+ np.random.seed(42)
+
+ series_len = 30
+ num_groups = 3
+ data = []
+
+ for i in range(num_groups):
+ time_idx = np.arange(series_len, dtype=np.int64)
+
+ trend = 100 + i * 20 + 0.5 * time_idx
+ seasonal = 10 * np.sin(2 * np.pi * time_idx / 12)
+ noise = np.random.normal(0, 5, series_len)
+
+ target = trend + seasonal + noise
+
+ temperature = (
+ 20
+ + 15 * np.sin(2 * np.pi * time_idx / 365)
+ + np.random.normal(0, 3, series_len)
+ ) # noqa: E501
+ humidity = (
+ 30
+ + 20 * np.cos(2 * np.pi * time_idx / 7)
+ + np.random.normal(0, 5, series_len)
+ ) # noqa: E501
+ pressure = (
+ 1013
+ + 10 * np.sin(2 * np.pi * time_idx / 30)
+ + np.random.normal(0, 2, series_len)
+ ) # noqa: E501
+
+ static_cont_val = np.float32(i * 10.0)
+ static_cat_code = np.float32(i % 2)
+
+ df_group = pd.DataFrame(
+ {
+ "time_idx": time_idx,
+ "group_id": f"group_{i}",
+ "value": target.astype(np.float32),
+ "temperature": temperature.astype(np.float32),
+ "humidity": humidity.astype(np.float32),
+ "pressure": pressure.astype(np.float32),
+ "static_cont_feat": np.full(
+ series_len, static_cont_val, dtype=np.float32
+ ),
+ "static_cat_feat": np.full(
+ series_len, static_cat_code, dtype=np.float32
+ ),
+ }
+ )
+ data.append(df_group)
+
+ df = pd.concat(data, ignore_index=True)
+ df["group_id"] = df["group_id"].astype("category")
+
+ return df
+
+
+@pytest.fixture
+def sample_multivariate_multi_series_data():
+ """Create sample data for M mode (multiple series) testing."""
+ np.random.seed(123)
+
+ series_len = 30
+ num_groups = 5
+ data = []
+
+ for i in range(num_groups):
+ time_idx = np.arange(series_len, dtype=np.int64)
+ base_level = 50 + i * 15
+ trend_slope = 0.2 + i * 0.1
+ seasonal_amp = 5 + i * 2
+
+ # Target variables (multiple targets for M mode)
+ target1 = (
+ base_level
+ + trend_slope * time_idx
+ + seasonal_amp * np.sin(2 * np.pi * time_idx / 7)
+ + np.random.normal(0, 1, series_len)
+ )
+ target2 = (
+ base_level * 0.8
+ + trend_slope * 0.5 * time_idx
+ + seasonal_amp * 0.7 * np.cos(2 * np.pi * time_idx / 7)
+ + np.random.normal(0, 1.5, series_len)
+ ) # noqa: E501
+
+ # Exogenous variables
+ temperature = (
+ 18
+ + 12 * np.sin(2 * np.pi * time_idx / 365)
+ + np.random.normal(0, 2, series_len)
+ ) # noqa: E501
+ humidity = (
+ 45
+ + 25 * np.cos(2 * np.pi * time_idx / 7 + i * np.pi / 4)
+ + np.random.normal(0, 4, series_len)
+ ) # noqa: E501
+ pressure = (
+ 1010
+ + 8 * np.sin(2 * np.pi * time_idx / 30)
+ + np.random.normal(0, 1.5, series_len)
+ ) # noqa: E501
+ wind_speed = (
+ 5
+ + 3 * np.sin(2 * np.pi * time_idx / 14)
+ + np.random.normal(0, 1, series_len)
+ ) # noqa: E501
+
+ df_group = pd.DataFrame(
+ {
+ "time_idx": time_idx,
+ "group_id": f"series_{i}",
+ "target1": target1.astype(np.float32),
+ "target2": target2.astype(np.float32),
+ "temperature": temperature.astype(np.float32),
+ "humidity": humidity.astype(np.float32),
+ "pressure": pressure.astype(np.float32),
+ "wind_speed": wind_speed.astype(np.float32),
+ }
+ )
+ data.append(df_group)
+
+ df = pd.concat(data, ignore_index=True)
+ df["group_id"] = df["group_id"].astype("category")
+
+ return df
+
+
+@pytest.fixture
+def basic_timeseries_dataset(sample_multivariate_data):
+ """Create a basic TimeSeries dataset for testing."""
+ return TimeSeries(
+ data=sample_multivariate_data,
+ time="time_idx",
+ target="value",
+ group=["group_id"],
+ num=[
+ "value",
+ "temperature",
+ "humidity",
+ "pressure",
+ "static_cont_feat",
+ "static_cat_feat",
+ ],
+ cat=[],
+ known=["temperature", "humidity", "pressure", "time_idx"],
+ static=["static_cont_feat", "static_cat_feat"],
+ )
+
+
+@pytest.fixture
+def basic_tslib_data_module(basic_timeseries_dataset):
+ """Create a basic TslibDataModule for testing."""
+ return TslibDataModule(
+ time_series_dataset=basic_timeseries_dataset,
+ batch_size=2,
+ context_length=12,
+ prediction_length=8,
+ train_val_test_split=(0.7, 0.15, 0.15),
+ )
+
+
+@pytest.fixture
+def basic_metadata(basic_tslib_data_module):
+ """Basic metadata from data module for model initialization."""
+ basic_tslib_data_module.setup()
+
+ # Return the generated metadata
+ return basic_tslib_data_module.metadata
+
+
+@pytest.fixture
+def model(basic_metadata):
+ """Initialize a TimeXer model for testing."""
+ return TimeXer(
+ loss=MAE(),
+ hidden_size=64,
+ n_heads=8,
+ e_layers=2,
+ d_ff=256,
+ dropout=0.1,
+ patch_length=4,
+ logging_metrics=[SMAPE()],
+ optimizer="adam",
+ optimizer_params={"lr": 1e-3},
+ lr_scheduler="reduce_lr_on_plateau",
+ lr_scheduler_params={
+ "mode": "min",
+ "factor": 0.5,
+ "patience": 5,
+ },
+ metadata=basic_metadata,
+ )
+
+
+def test_basic_model_initialization(model, basic_metadata):
+ """Test the basic model initialization."""
+
+ assert isinstance(model, TimeXer)
+
+ assert model.hidden_size == 64
+ assert model.n_heads == 8
+ assert model.e_layers == 2
+ assert model.d_ff == 256
+ assert model.patch_length == 4
+ assert model.dropout == 0.1
+
+ assert model.patch_num == 3
+ assert model.n_target_vars == 1
+ assert model.head_nf == 64 * (3 + 1)
+
+ assert model.context_length == basic_metadata["context_length"]
+ assert model.prediction_length == basic_metadata["prediction_length"]
+ assert model.cont_dim == basic_metadata["n_features"]["continuous"]
+ assert model.cat_dim == basic_metadata["n_features"]["categorical"]
+ assert model.target_dim == basic_metadata["n_features"]["target"]
+ assert model.features == basic_metadata["features"]
+
+
+def test_multivariate_single_series(model, basic_tslib_data_module):
+ basic_tslib_data_module.setup()
+ train_dataloader = basic_tslib_data_module.train_dataloader()
+ batch = next(iter(train_dataloader))[0]
+
+ model.eval()
+ with torch.no_grad():
+ output = model(batch)
+
+ assert "prediction" in output
+ predictions = output["prediction"]
+
+ batch_size = batch["history_cont"].shape[0]
+ assert predictions.shape == (batch_size, model.prediction_length, model.target_dim)
+
+ assert not torch.isnan(predictions).any()
+ assert not torch.isinf(predictions).any()
+
+
+def test_quantile_predictions(basic_metadata):
+ """Test quantile predictions with TimeXer model."""
+
+ quantiles = [0.1, 0.5, 0.9]
+
+ model = TimeXer(
+ loss=QuantileLoss(quantiles=quantiles),
+ hidden_size=64,
+ n_heads=8,
+ e_layers=2,
+ d_ff=256,
+ dropout=0.1,
+ patch_length=4,
+ metadata=basic_metadata,
+ )
+
+ assert model.n_quantiles == 3
+
+ batch_size = 4
+
+ # sample input data as a substitute for x
+ sample_input_data = {
+ "history_cont": torch.randn(
+ batch_size, 12, basic_metadata["n_features"]["continuous"]
+ ),
+ "history_target": torch.randn(
+ batch_size, 12, basic_metadata["n_features"]["target"]
+ ),
+ "history_time_idx": torch.arange(12).unsqueeze(0).repeat(batch_size, 1),
+ }
+
+ model.eval()
+ with torch.no_grad():
+ output = model(sample_input_data)
+
+ predictions = output["prediction"]
+ assert predictions.shape == (batch_size, 8, 1, 3)
+
+
+def test_missing_history_target_handling(basic_metadata):
+ """Test handling of missing history_target in TimeXer model."""
+
+ model = TimeXer(
+ loss=MAE(),
+ hidden_size=64,
+ n_heads=8,
+ e_layers=2,
+ d_ff=256,
+ dropout=0.1,
+ patch_length=4,
+ metadata=basic_metadata,
+ )
+
+ batch_size = 4
+ sample_input = {
+ "history_cont": torch.randn(
+ batch_size, 12, basic_metadata["n_features"]["continuous"]
+ ), # noqa: E501
+ "history_time_idx": torch.arange(12).unsqueeze(0).repeat(batch_size, 1),
+ }
+
+ model.eval()
+ with torch.no_grad():
+ output = model(sample_input)
+
+ predictions = output["prediction"]
+ assert predictions.shape == (batch_size, 8, basic_metadata["n_features"]["target"])
+ assert not torch.isnan(predictions).any()
+
+
+def test_endogenous_exogenous_variable_selection(basic_metadata):
+ """Test explicit endogenous and exogenous variable selection in TimeXer model."""
+
+ endo_names = basic_metadata["feature_names"]["continuous"][0]
+ exog_names = basic_metadata["feature_names"]["continuous"][1]
+
+ model = TimeXer(
+ loss=MAE(),
+ hidden_size=64,
+ n_heads=8,
+ endogenous_vars=[endo_names],
+ exogenous_vars=[exog_names],
+ e_layers=2,
+ metadata=basic_metadata,
+ )
+
+ batch_size = 4
+ sample_input = {
+ "history_cont": torch.randn(
+ batch_size, 12, basic_metadata["n_features"]["continuous"]
+ ),
+ "history_target": torch.randn(
+ batch_size, 12, basic_metadata["n_features"]["target"]
+ ),
+ "history_time_idx": torch.arange(12).unsqueeze(0).repeat(batch_size, 1),
+ }
+
+ model.eval()
+ with torch.no_grad():
+ output = model(sample_input)
+
+ predictions = output["prediction"]
+ assert predictions.shape == (batch_size, 8, 1)
+ assert not torch.isnan(predictions).any()
+
+
+def test_integration_with_datamodule(model, basic_tslib_data_module):
+ """Test integration of TimeXer model with TslibDataModule."""
+
+ basic_tslib_data_module.setup(stage="fit")
+ basic_tslib_data_module.setup(stage="test")
+
+ train_loader = basic_tslib_data_module.train_dataloader()
+ test_loader = basic_tslib_data_module.test_dataloader()
+ val_loader = basic_tslib_data_module.val_dataloader()
+
+ model.eval()
+ with torch.no_grad():
+ train_batch = next(iter(train_loader))[0]
+ train_output = model(train_batch)
+ assert train_output["prediction"].shape[1] == model.prediction_length
+
+ # Check if validation and test sets are not empty
+ # If they are empty, skip the validation and test checks
+ try:
+ val_batch = next(iter(val_loader))[0]
+ val_output = model(val_batch)
+ assert val_output["prediction"].shape[1] == model.prediction_length
+ except StopIteration:
+ print("Validation set is empty, skipping validation testing")
+
+ try:
+ test_batch = next(iter(test_loader))[0]
+ test_output = model(test_batch)
+ assert test_output["prediction"].shape[1] == model.prediction_length
+ except StopIteration:
+ print("Test set is empty, skipping test testing")