241 changes: 241 additions & 0 deletions pytorch_forecasting/base/_base_pkg.py
@@ -0,0 +1,241 @@
from pathlib import Path
import pickle
from typing import Any, Optional, Union

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.core.datamodule import LightningDataModule
import torch
from torch.utils.data import DataLoader

from pytorch_forecasting.data import TimeSeries
from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2


class Base_pkg(_BasePtForecasterV2):
"""
Base model package class acting as a high-level wrapper for the Lightning workflow.

This class simplifies the user experience by managing model, datamodule, and trainer
configurations, and providing streamlined ``fit`` and ``predict`` methods.

Parameters
----------
    model_cfg : dict, optional
        Configs used to initialise the model. Required if not loading from a
        checkpoint. Defaults to ``{}``.
    trainer_cfg : dict, optional
        Configs to initialise ``lightning.Trainer``. Defaults to ``{}``.
    datamodule_cfg : Union[dict, str, Path], optional
        Configs to initialise a ``LightningDataModule``.

        - If dict, the keys and values are used as configuration parameters.
        - If str or Path, it should be a path to a ``.pkl`` file containing
          the serialized configuration dictionary. Required for reproducibility
          when loading a model for inference.

        Defaults to ``{}``.
    ckpt_path : Union[str, Path], optional
        Path to the checkpoint from which to load the model. If provided,
        ``model_cfg`` is ignored. Defaults to None.
    """

Collaborator: minor formatting issue: please have newlines around bullet point lists

Member Author: Thanks for pointing it out. I will make the changes to the PR soon.

def __init__(
self,
model_cfg: Optional[dict[str, Any]] = None,
trainer_cfg: Optional[dict[str, Any]] = None,
datamodule_cfg: Optional[Union[dict[str, Any], str, Path]] = None,
ckpt_path: Optional[Union[str, Path]] = None,
):
self.model_cfg = model_cfg or {}
self.trainer_cfg = trainer_cfg or {}
self.ckpt_path = Path(ckpt_path) if ckpt_path else None

if isinstance(datamodule_cfg, (str, Path)):
with open(datamodule_cfg, "rb") as f:
                self.datamodule_cfg = pickle.load(f)  # noqa: S301
else:
self.datamodule_cfg = datamodule_cfg or {}

self.model = None
self.trainer = None
self.datamodule = None
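A minimal sketch of the pickle round trip this constructor accepts for ``datamodule_cfg`` (the file name and config keys are hypothetical; ``Base_pkg`` stands in for a concrete subclass):

import pickle

dm_cfg = {"batch_size": 64, "num_workers": 2}  # hypothetical keys
with open("datamodule_cfg.pkl", "wb") as f:
    pickle.dump(dm_cfg, f)

# Passing the path instead of the dict exercises the pickle.load branch above.
pkg = Base_pkg(datamodule_cfg="datamodule_cfg.pkl")
assert pkg.datamodule_cfg == dm_cfg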

@classmethod
def get_cls(cls):
"""Get the underlying model class."""
raise NotImplementedError("Subclasses must implement `get_cls`.")

@classmethod
def get_datamodule_cls(cls):
"""Get the underlying DataModule class."""
raise NotImplementedError("Subclasses must implement `get_datamodule_cls`.")

@classmethod
def get_test_dataset_from(cls, **kwargs):
"""
        Creates and returns D1 ``TimeSeries`` dataset objects for testing.
"""
from pytorch_forecasting.tests._data_scenarios import (
data_with_covariates_v2,
make_datasets_v2,
)

raw_data = data_with_covariates_v2()

datasets_info = make_datasets_v2(raw_data, **kwargs)

return {
"train": datasets_info["training_dataset"],
"predict": datasets_info["validation_dataset"],
}

def _build_model(self, metadata: dict):
"""Instantiates the model, either from a checkpoint or from config."""
model_cls = self.get_cls()
if self.ckpt_path:
self.model = model_cls.load_from_checkpoint(self.ckpt_path)
elif self.model_cfg:
self.model = model_cls(**self.model_cfg, metadata=metadata)
else:
self.model = None

def _build_datamodule(self, data: TimeSeries) -> LightningDataModule:
"""Constructs a DataModule from a D1 layer object."""
if not self.datamodule_cfg:
raise ValueError("`datamodule_cfg` must be provided to build a datamodule.")
datamodule_cls = self.get_datamodule_cls()
return datamodule_cls(data, **self.datamodule_cfg)

def _load_dataloader(
self, data: Union[TimeSeries, LightningDataModule, DataLoader]
) -> DataLoader:
"""Converts various data input types into a DataLoader for prediction."""
if isinstance(data, TimeSeries): # D1 Layer
dm = self._build_datamodule(data)
dm.setup(stage="predict")
return dm.predict_dataloader()
elif isinstance(data, LightningDataModule): # D2 Layer
data.setup(stage="predict")
return data.predict_dataloader()
elif isinstance(data, DataLoader):
return data
else:
raise TypeError(
f"Unsupported data type for prediction: {type(data).__name__}. "
"Expected TimeSeriesDataSet, LightningDataModule, or DataLoader."
)

def fit(
self,
data: Union[TimeSeries, LightningDataModule],
# todo: we should create a base data_module for different data_modules
save_ckpt: bool = True,
ckpt_dir: Union[str, Path] = "checkpoints",
**trainer_fit_kwargs,
):
"""
Fit the model to the training data.

Parameters
----------
data : Union[TimeSeries, LightningDataModule]
The data to fit on (D1 or D2 layer). This object is responsible
for providing both training and validation data.
save_ckpt : bool, default=True
If True, save the best model checkpoint and the `datamodule_cfg`.
ckpt_dir : Union[str, Path], default="checkpoints"
Directory to save artifacts.
**trainer_fit_kwargs :
Additional keyword arguments passed to `trainer.fit()`.

Returns
-------
Optional[Path]
The path to the best model checkpoint if `save_ckpt=True`, else None.
"""
if isinstance(data, TimeSeries):
self.datamodule = self._build_datamodule(data)
else:
self.datamodule = data
self.datamodule.setup(stage="fit")

        if self.model is None:
            if not self.model_cfg and not self.ckpt_path:
                raise RuntimeError(
                    "`model_cfg` or `ckpt_path` must be provided to build a model."
                )
metadata = self.datamodule.metadata
self._build_model(metadata)

callbacks = self.trainer_cfg.get("callbacks", []).copy()
checkpoint_cb = None
if save_ckpt:
ckpt_dir = Path(ckpt_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True)
checkpoint_cb = ModelCheckpoint(
dirpath=ckpt_dir,
filename="best-{epoch}-{val_loss:.2f}",
save_top_k=1,
monitor="val_loss",
mode="min",
)
callbacks.append(checkpoint_cb)
trainer_init_cfg = self.trainer_cfg.copy()
trainer_init_cfg.pop("callbacks", None)

self.trainer = Trainer(**trainer_init_cfg, callbacks=callbacks)

self.trainer.fit(self.model, datamodule=self.datamodule, **trainer_fit_kwargs)
if save_ckpt and checkpoint_cb:
best_model_path = Path(checkpoint_cb.best_model_path)
dm_cfg_path = best_model_path.parent / "datamodule_cfg.pkl"
with open(dm_cfg_path, "wb") as f:
pickle.dump(self.datamodule_cfg, f)
print(f"Best model saved to: {best_model_path}")
print(f"DataModule config saved to: {dm_cfg_path}")
return best_model_path
return None

def predict(
self,
data: Union[TimeSeries, LightningDataModule, DataLoader],
output_dir: Optional[Union[str, Path]] = None,
**kwargs,
) -> Union[dict[str, torch.Tensor], None]:
"""
Generate predictions by wrapping the model's predict method.

This method prepares the data by resolving it into a DataLoader and then
delegates the prediction task to the underlying model's ``.predict()`` method.

Parameters
----------
        data : Union[TimeSeries, LightningDataModule, DataLoader]
            The data to predict on (D1, D2, or DataLoader).
        output_dir : Union[str, Path], optional
            If provided, predictions are pickled to ``output_dir/predictions.pkl``
            and ``None`` is returned.
        **kwargs :
            Additional keyword arguments passed directly to the model's
            ``.predict()`` method, e.g. ``mode``, ``return_info``, and any
            ``trainer_kwargs``.

        Returns
        -------
        Union[dict[str, torch.Tensor], None]
            A dictionary of prediction tensors, or ``None`` if ``output_dir``
            is specified.
"""
if self.model is None:
raise RuntimeError(
"Model is not initialized. Provide `model_cfg` or `ckpt_path`."
)

dataloader = self._load_dataloader(data)
predictions = self.model.predict(dataloader, **kwargs)

if output_dir:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
output_file = output_path / "predictions.pkl"
with open(output_file, "wb") as f:
pickle.dump(predictions, f)
print(f"Predictions saved to {output_file}")
return None

return predictions
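Taken together, a typical round trip with a hypothetical concrete subclass (called ``TFTPkg`` here; all config keys are illustrative, and ``train_ts``/``test_ts`` are assumed to be ``TimeSeries`` objects) might look like:

# Train, saving the best checkpoint plus the datamodule config next to it.
pkg = TFTPkg(
    model_cfg={"hidden_size": 16},       # hypothetical model kwargs
    trainer_cfg={"max_epochs": 5},
    datamodule_cfg={"batch_size": 64},   # hypothetical datamodule kwargs
)
best_ckpt = pkg.fit(train_ts, save_ckpt=True, ckpt_dir="checkpoints")

# Restore for inference from the two artifacts fit() wrote to ckpt_dir.
pkg2 = TFTPkg(
    ckpt_path=best_ckpt,
    datamodule_cfg=best_ckpt.parent / "datamodule_cfg.pkl",
)
preds = pkg2.predict(test_ts)  # dict of prediction tensors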
Empty file.
111 changes: 111 additions & 0 deletions pytorch_forecasting/callbacks/predict.py
@@ -0,0 +1,111 @@
from typing import Any, Optional
from warnings import warn

from lightning import Trainer
from lightning.pytorch import LightningModule
from lightning.pytorch.callbacks import BasePredictionWriter
import torch


class PredictCallback(BasePredictionWriter):
"""
Callback to capture predictions and related information internally.

This callback is used by ``BaseModel.predict()`` to process raw model outputs
into the desired format (``prediction``, ``quantiles``, or ``raw``) and collect
any additional requested info (``x``, ``y``, ``index``, etc.). The results are
collated and stored in memory, accessible via the ``.result`` property.

Parameters
----------
    mode : str
        The prediction mode (``"prediction"``, ``"quantiles"``, or ``"raw"``).
    return_info : list[str], optional
        Additional information to return, e.g. ``"x"``, ``"y"``, ``"index"``,
        or ``"decoder_lengths"``.
    mode_kwargs : dict, optional
        Additional keyword arguments for ``to_prediction`` or ``to_quantiles``.
"""

def __init__(
self,
mode: str = "prediction",
return_info: Optional[list[str]] = None,
        mode_kwargs: Optional[dict[str, Any]] = None,
):
super().__init__(write_interval="epoch")
self.mode = mode
self.return_info = return_info or []
self.mode_kwargs = mode_kwargs or {}
self._reset_data()

def _reset_data(self, result: bool = True):
"""Clear collected data for a new prediction run."""
self.predictions = []
self.info = {key: [] for key in self.return_info}
if result:
self._result = None

def on_predict_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: Any,
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
):
"""Process and store predictions for a single batch."""
x, y = batch

if self.mode == "raw":
processed_output = outputs
elif self.mode == "prediction":
processed_output = pl_module.to_prediction(outputs, **self.mode_kwargs)
elif self.mode == "quantiles":
processed_output = pl_module.to_quantiles(outputs, **self.mode_kwargs)
else:
raise ValueError(f"Invalid prediction mode: {self.mode}")

self.predictions.append(processed_output)

for key in self.return_info:
if key == "x":
self.info[key].append(x)
elif key == "y":
self.info[key].append(y[0])
elif key == "index":
self.info[key].append(y[1])
elif key == "decoder_lengths":
self.info[key].append(x["decoder_lengths"])
else:
warn(f"Unknown return_info key: {key}")

def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
"""Collate all batch results into final tensors."""
if self.mode == "raw" and isinstance(self.predictions[0], dict):
keys = self.predictions[0].keys()
collated_preds = {
key: torch.cat([p[key] for p in self.predictions]) for key in keys
}
else:
collated_preds = {"prediction": torch.cat(self.predictions)}

final_result = collated_preds

        for key, data_list in self.info.items():
            if not data_list:  # unknown keys warned about above collect nothing
                continue
            if isinstance(data_list[0], dict):
collated_info = {
k: torch.cat([d[k] for d in data_list]) for k in data_list[0].keys()
}
else:
collated_info = torch.cat(data_list)
final_result[key] = collated_info

self._result = final_result
self._reset_data(result=False)
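A toy illustration of the collation above, outside any Trainer run: per-batch dicts are merged key-wise with ``torch.cat`` along the batch dimension (shapes invented for the example):

import torch

batches = [{"prediction": torch.zeros(4, 6)}, {"prediction": torch.ones(2, 6)}]
collated = {k: torch.cat([b[k] for b in batches]) for k in batches[0]}
assert collated["prediction"].shape == (6, 6)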

@property
    def result(self) -> dict[str, torch.Tensor]:
        """Collated prediction results, available after the predict epoch ends."""
        if self._result is None:
raise RuntimeError("Prediction results are not yet available.")
return self._result
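Because ``PredictCallback`` is a regular Lightning callback, it can also be driven directly. A minimal sketch, assuming ``model`` is a pytorch-forecasting module exposing ``to_prediction`` and ``dataloader`` yields ``(x, y)`` batches:

from lightning import Trainer

cb = PredictCallback(mode="prediction", return_info=["y"])
trainer = Trainer(callbacks=[cb], logger=False, enable_checkpointing=False)
trainer.predict(model, dataloaders=dataloader)
print(cb.result["prediction"].shape)  # predictions collated across all batches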