[ENH] Add predict to v2 models
#1984
Draft: phoeenniixx wants to merge 11 commits into sktime:main from phoeenniixx:predict-v2
Commits
cff05b9 initial design (phoeenniixx)
1371c68 Merge branch 'main' into predict-v2 (phoeenniixx)
13c12b9 preliminary design (phoeenniixx)
7cfb26c Merge branch 'main' into predict-v2 (phoeenniixx)
35e2447 add predict to all v2 models (phoeenniixx)
882caba update base model (phoeenniixx)
c53f881 update base model (phoeenniixx)
f8d06fe update test_integration (phoeenniixx)
a1aad82 Merge branch 'main' into predict-v2 (phoeenniixx)
e7fcc77 Merge branch 'main' into predict-v2 (phoeenniixx)
5024fcb hadnle kwargs (phoeenniixx)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,241 @@ | ||
| from pathlib import Path | ||
| import pickle | ||
| from typing import Any, Optional, Union | ||
|
|
||
| from lightning import Trainer | ||
| from lightning.pytorch.callbacks import ModelCheckpoint | ||
| from lightning.pytorch.core.datamodule import LightningDataModule | ||
| import torch | ||
| from torch.utils.data import DataLoader | ||
|
|
||
| from pytorch_forecasting.data import TimeSeries | ||
| from pytorch_forecasting.models.base._base_object import _BasePtForecasterV2 | ||
|
|
||
|
|
||
| class Base_pkg(_BasePtForecasterV2): | ||
| """ | ||
| Base model package class acting as a high-level wrapper for the Lightning workflow. | ||
|
|
||
| This class simplifies the user experience by managing model, datamodule, and trainer | ||
| configurations, and providing streamlined ``fit`` and ``predict`` methods. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| model_cfg : dict, optional | ||
| Model configs for the initialisation of the model. Required if not loading | ||
| from a checkpoint. Defaults to ``{}``. | ||
| trainer_cfg : dict, optional | ||
| Configs to initialise ``lightning.Trainer``. Defaults to {}. | ||
| datamodule_cfg : Union[dict, str, Path], optional | ||
| Configs to initialise a ``LightningDataModule``. | ||
| - If dict, the keys and values are used as configuration parameters. | ||
| - If str or Path, it should be a path to a ``.pkl`` file containing | ||
| the serialized configuration dictionary. Required for reproducibility | ||
| when loading a model for inference. Defaults to {}. | ||
| ckpt_path : Union[str, Path], optional | ||
| Path to the checkpoint from which to load the model. If provided, `model_cfg` | ||
| is ignored. Defaults to None. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| model_cfg: Optional[dict[str, Any]] = None, | ||
| trainer_cfg: Optional[dict[str, Any]] = None, | ||
| datamodule_cfg: Optional[Union[dict[str, Any], str, Path]] = None, | ||
| ckpt_path: Optional[Union[str, Path]] = None, | ||
| ): | ||
| self.model_cfg = model_cfg or {} | ||
| self.trainer_cfg = trainer_cfg or {} | ||
| self.ckpt_path = Path(ckpt_path) if ckpt_path else None | ||
|
|
||
| if isinstance(datamodule_cfg, (str, Path)): | ||
| with open(datamodule_cfg, "rb") as f: | ||
| self.datamodule_cfg = pickle.load(f) # noqa: S301 | ||
| else: | ||
| self.datamodule_cfg = datamodule_cfg or {} | ||
|
|
||
| self.model = None | ||
| self.trainer = None | ||
| self.datamodule = None | ||
|
|
||
| @classmethod | ||
| def get_cls(cls): | ||
| """Get the underlying model class.""" | ||
| raise NotImplementedError("Subclasses must implement `get_cls`.") | ||
|
|
||
| @classmethod | ||
| def get_datamodule_cls(cls): | ||
| """Get the underlying DataModule class.""" | ||
| raise NotImplementedError("Subclasses must implement `get_datamodule_cls`.") | ||
|
|
||
| @classmethod | ||
| def get_test_dataset_from(cls, **kwargs): | ||
| """ | ||
| Creates and returns D1 TimeSeries dataSet objects for testing. | ||
| """ | ||
| from pytorch_forecasting.tests._data_scenarios import ( | ||
| data_with_covariates_v2, | ||
| make_datasets_v2, | ||
| ) | ||
|
|
||
| raw_data = data_with_covariates_v2() | ||
|
|
||
| datasets_info = make_datasets_v2(raw_data, **kwargs) | ||
|
|
||
| return { | ||
| "train": datasets_info["training_dataset"], | ||
| "predict": datasets_info["validation_dataset"], | ||
| } | ||
|
|
||
| def _build_model(self, metadata: dict): | ||
| """Instantiates the model, either from a checkpoint or from config.""" | ||
| model_cls = self.get_cls() | ||
| if self.ckpt_path: | ||
| self.model = model_cls.load_from_checkpoint(self.ckpt_path) | ||
| elif self.model_cfg: | ||
| self.model = model_cls(**self.model_cfg, metadata=metadata) | ||
| else: | ||
| self.model = None | ||
|
|
||
| def _build_datamodule(self, data: TimeSeries) -> LightningDataModule: | ||
| """Constructs a DataModule from a D1 layer object.""" | ||
| if not self.datamodule_cfg: | ||
| raise ValueError("`datamodule_cfg` must be provided to build a datamodule.") | ||
| datamodule_cls = self.get_datamodule_cls() | ||
| return datamodule_cls(data, **self.datamodule_cfg) | ||
|
|
||
| def _load_dataloader( | ||
| self, data: Union[TimeSeries, LightningDataModule, DataLoader] | ||
| ) -> DataLoader: | ||
| """Converts various data input types into a DataLoader for prediction.""" | ||
| if isinstance(data, TimeSeries): # D1 Layer | ||
| dm = self._build_datamodule(data) | ||
| dm.setup(stage="predict") | ||
| return dm.predict_dataloader() | ||
| elif isinstance(data, LightningDataModule): # D2 Layer | ||
| data.setup(stage="predict") | ||
| return data.predict_dataloader() | ||
| elif isinstance(data, DataLoader): | ||
| return data | ||
| else: | ||
| raise TypeError( | ||
| f"Unsupported data type for prediction: {type(data).__name__}. " | ||
| "Expected TimeSeries, LightningDataModule, or DataLoader." | ||
| ) | ||
|
|
||
| def fit( | ||
| self, | ||
| data: Union[TimeSeries, LightningDataModule], | ||
| # todo: we should create a base data_module for different data_modules | ||
| save_ckpt: bool = True, | ||
| ckpt_dir: Union[str, Path] = "checkpoints", | ||
| **trainer_fit_kwargs, | ||
| ): | ||
| """ | ||
| Fit the model to the training data. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data : Union[TimeSeries, LightningDataModule] | ||
| The data to fit on (D1 or D2 layer). This object is responsible | ||
| for providing both training and validation data. | ||
| save_ckpt : bool, default=True | ||
| If True, save the best model checkpoint and the `datamodule_cfg`. | ||
| ckpt_dir : Union[str, Path], default="checkpoints" | ||
| Directory to save artifacts. | ||
| **trainer_fit_kwargs : | ||
| Additional keyword arguments passed to `trainer.fit()`. | ||
|
|
||
| Returns | ||
| ------- | ||
| Optional[Path] | ||
| The path to the best model checkpoint if `save_ckpt=True`, else None. | ||
| """ | ||
| if isinstance(data, TimeSeries): | ||
| self.datamodule = self._build_datamodule(data) | ||
| else: | ||
| self.datamodule = data | ||
| self.datamodule.setup(stage="fit") | ||
|
|
||
| if self.model is None: | ||
| if not self.model_cfg: | ||
| raise RuntimeError( | ||
| "`model_cfg` must be provided to train from scratch." | ||
| ) | ||
| metadata = self.datamodule.metadata | ||
| self._build_model(metadata) | ||
|
|
||
| callbacks = self.trainer_cfg.get("callbacks", []).copy() | ||
| checkpoint_cb = None | ||
| if save_ckpt: | ||
| ckpt_dir = Path(ckpt_dir) | ||
| ckpt_dir.mkdir(parents=True, exist_ok=True) | ||
| checkpoint_cb = ModelCheckpoint( | ||
| dirpath=ckpt_dir, | ||
| filename="best-{epoch}-{val_loss:.2f}", | ||
| save_top_k=1, | ||
| monitor="val_loss", | ||
| mode="min", | ||
| ) | ||
| callbacks.append(checkpoint_cb) | ||
| trainer_init_cfg = self.trainer_cfg.copy() | ||
| trainer_init_cfg.pop("callbacks", None) | ||
|
|
||
| self.trainer = Trainer(**trainer_init_cfg, callbacks=callbacks) | ||
|
|
||
| self.trainer.fit(self.model, datamodule=self.datamodule, **trainer_fit_kwargs) | ||
| if save_ckpt and checkpoint_cb: | ||
| best_model_path = Path(checkpoint_cb.best_model_path) | ||
| dm_cfg_path = best_model_path.parent / "datamodule_cfg.pkl" | ||
| with open(dm_cfg_path, "wb") as f: | ||
| pickle.dump(self.datamodule_cfg, f) | ||
| print(f"Best model saved to: {best_model_path}") | ||
| print(f"DataModule config saved to: {dm_cfg_path}") | ||
| return best_model_path | ||
| return None | ||
|
|
||
| def predict( | ||
| self, | ||
| data: Union[TimeSeries, LightningDataModule, DataLoader], | ||
| output_dir: Optional[Union[str, Path]] = None, | ||
| **kwargs, | ||
| ) -> Union[dict[str, torch.Tensor], None]: | ||
| """ | ||
| Generate predictions by wrapping the model's predict method. | ||
|
|
||
| This method prepares the data by resolving it into a DataLoader and then | ||
| delegates the prediction task to the underlying model's ``.predict()`` method. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data : Union[TimeSeries, LightningDataModule, DataLoader] | ||
| The data to predict on (D1, D2, or DataLoader). | ||
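| output_dir : Union[str, Path], optional | ||
| If given, predictions are pickled to ``predictions.pkl`` in this | ||
| directory and the method returns None. Defaults to None. | ||
| **kwargs : | ||
| Additional keyword arguments passed directly to the model's ``.predict()`` | ||
| method. This includes `mode`, `return_info`, and any `trainer_kwargs`. | ||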
|
|
||
| Returns | ||
| ------- | ||
| Union[Dict[str, torch.Tensor], None] | ||
| A dictionary of prediction tensors, or `None` if `output_dir` is | ||
| specified. | ||
| """ | ||
| if self.model is None: | ||
| raise RuntimeError( | ||
| "Model is not initialized. Provide `model_cfg` or `ckpt_path`." | ||
| ) | ||
|
|
||
| dataloader = self._load_dataloader(data) | ||
| predictions = self.model.predict(dataloader, **kwargs) | ||
|
|
||
| if output_dir: | ||
| output_path = Path(output_dir) | ||
| output_path.mkdir(parents=True, exist_ok=True) | ||
| output_file = output_path / "predictions.pkl" | ||
| with open(output_file, "wb") as f: | ||
| pickle.dump(predictions, f) | ||
| print(f"Predictions saved to {output_file}") | ||
| return None | ||
|
|
||
| return predictions | ||
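The callback handling in `fit` above copies the user's `callbacks` list, appends the checkpoint callback, and then removes `callbacks` from a copy of `trainer_cfg` so that it is not passed to `Trainer` twice. A minimal sketch of that step in isolation follows. The helper name `split_trainer_cfg` is hypothetical and not part of this PR, and plain strings stand in for real Lightning callback objects.

```python
from typing import Any


def split_trainer_cfg(
    trainer_cfg: dict[str, Any], extra_callbacks: list[Any]
) -> tuple[dict[str, Any], list[Any]]:
    """Return (init kwargs without "callbacks", merged callback list).

    Hypothetical helper, not part of the PR: it mirrors the copy/pop steps
    in ``fit`` so that the caller's config is left untouched.
    """
    # Copy the user's list so that extending it does not mutate trainer_cfg.
    callbacks = list(trainer_cfg.get("callbacks", []))
    callbacks.extend(extra_callbacks)
    # Shallow-copy the config and drop "callbacks" to avoid passing it twice.
    init_cfg = dict(trainer_cfg)
    init_cfg.pop("callbacks", None)
    return init_cfg, callbacks


# Strings stand in for callback instances in this illustration.
user_cfg = {"max_epochs": 3, "callbacks": ["early_stopping"]}
init_cfg, callbacks = split_trainer_cfg(user_cfg, ["checkpoint"])
print(init_cfg)
print(callbacks)
print(user_cfg)
```

Because both the list and the dict are copied, calling `fit` repeatedly with the same configuration should not accumulate extra checkpoint callbacks in the user's `trainer_cfg`.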
Empty file.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,111 @@ | ||
| from typing import Any, Optional | ||
| from warnings import warn | ||
|
|
||
| from lightning import Trainer | ||
| from lightning.pytorch import LightningModule | ||
| from lightning.pytorch.callbacks import BasePredictionWriter | ||
| import torch | ||
|
|
||
|
|
||
| class PredictCallback(BasePredictionWriter): | ||
| """ | ||
| Callback to capture predictions and related information internally. | ||
|
|
||
| This callback is used by ``BaseModel.predict()`` to process raw model outputs | ||
| into the desired format (``prediction``, ``quantiles``, or ``raw``) and collect | ||
| any additional requested info (``x``, ``y``, ``index``, etc.). The results are | ||
| collated and stored in memory, accessible via the ``.result`` property. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| mode : str | ||
| The prediction mode ("prediction", "quantiles", or "raw"). | ||
| return_info : list[str], optional | ||
| Additional information to return. | ||
| mode_kwargs : dict, optional | ||
| Additional keyword arguments for `to_prediction` or `to_quantiles`. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| mode: str = "prediction", | ||
| return_info: Optional[list[str]] = None, | ||
| mode_kwargs: Optional[dict[str, Any]] = None, | ||
| ): | ||
| super().__init__(write_interval="epoch") | ||
| self.mode = mode | ||
| self.return_info = return_info or [] | ||
| self.mode_kwargs = mode_kwargs or {} | ||
| self._reset_data() | ||
|
|
||
| def _reset_data(self, result: bool = True): | ||
| """Clear collected data for a new prediction run.""" | ||
| self.predictions = [] | ||
| self.info = {key: [] for key in self.return_info} | ||
| if result: | ||
| self._result = None | ||
|
|
||
| def on_predict_batch_end( | ||
| self, | ||
| trainer: Trainer, | ||
| pl_module: LightningModule, | ||
| outputs: Any, | ||
| batch: Any, | ||
| batch_idx: int, | ||
| dataloader_idx: int = 0, | ||
| ): | ||
| """Process and store predictions for a single batch.""" | ||
| x, y = batch | ||
|
|
||
| if self.mode == "raw": | ||
| processed_output = outputs | ||
| elif self.mode == "prediction": | ||
| processed_output = pl_module.to_prediction(outputs, **self.mode_kwargs) | ||
| elif self.mode == "quantiles": | ||
| processed_output = pl_module.to_quantiles(outputs, **self.mode_kwargs) | ||
| else: | ||
| raise ValueError(f"Invalid prediction mode: {self.mode}") | ||
|
|
||
| self.predictions.append(processed_output) | ||
|
|
||
| for key in self.return_info: | ||
| if key == "x": | ||
| self.info[key].append(x) | ||
| elif key == "y": | ||
| self.info[key].append(y[0]) | ||
| elif key == "index": | ||
| self.info[key].append(y[1]) | ||
| elif key == "decoder_lengths": | ||
| self.info[key].append(x["decoder_lengths"]) | ||
| else: | ||
| warn(f"Unknown return_info key: {key}") | ||
|
|
||
| def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule): | ||
| """Collate all batch results into final tensors.""" | ||
| if self.mode == "raw" and isinstance(self.predictions[0], dict): | ||
| keys = self.predictions[0].keys() | ||
| collated_preds = { | ||
| key: torch.cat([p[key] for p in self.predictions]) for key in keys | ||
| } | ||
| else: | ||
| collated_preds = {"prediction": torch.cat(self.predictions)} | ||
|
|
||
| final_result = collated_preds | ||
|
|
||
| for key, data_list in self.info.items(): | ||
| if isinstance(data_list[0], dict): | ||
| collated_info = { | ||
| k: torch.cat([d[k] for d in data_list]) for k in data_list[0].keys() | ||
| } | ||
| else: | ||
| collated_info = torch.cat(data_list) | ||
| final_result[key] = collated_info | ||
|
|
||
| self._result = final_result | ||
| self._reset_data(result=False) | ||
|
|
||
| @property | ||
| def result(self) -> dict[str, torch.Tensor]: | ||
| if self._result is None: | ||
| raise RuntimeError("Prediction results are not yet available.") | ||
| return self._result |
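`on_predict_epoch_end` above collates per-batch outputs by concatenating them, key by key when the batches are dicts. As written it indexes `self.predictions[0]`, so an empty prediction run would raise an `IndexError`. The sketch below shows the same collation pattern with an explicit guard. It is an illustration only: `collate_batches` and `concat_lists` are hypothetical names, and plain Python lists stand in for tensors (with real tensors, `torch.cat` would be passed as `concat`).

```python
from typing import Any, Callable


def collate_batches(
    batches: list[Any],
    concat: Callable[[list[Any]], Any],
) -> Any:
    """Concatenate per-batch outputs; dict batches are merged key by key."""
    if not batches:
        # Guard that the callback in the diff does not have.
        raise ValueError("No batches were collected; nothing to collate.")
    if isinstance(batches[0], dict):
        return {key: concat([b[key] for b in batches]) for key in batches[0]}
    return concat(batches)


def concat_lists(chunks: list[list[Any]]) -> list[Any]:
    """Stand-in for concatenation along the batch dimension."""
    return [item for chunk in chunks for item in chunk]


# Two "raw" batches with dict outputs, then two plain batches.
raw_batches = [
    {"prediction": [1, 2], "scale": [0.5, 0.5]},
    {"prediction": [3], "scale": [0.1]},
]
collated = collate_batches(raw_batches, concat_lists)
flat = collate_batches([[1, 2], [3]], concat_lists)
print(collated)
print(flat)
```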
minor formatting issue: please have newlines around bullet point lists
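One way the requested layout could look for the `datamodule_cfg` entry is sketched here, with a blank line before and after the bullet list. The function `example` is a hypothetical stub that exists only to carry the docstring; the wording of the entry is taken from the diff.

```python
def example(datamodule_cfg=None):
    """
    Hypothetical stub used only to show the docstring layout.

    Parameters
    ----------
    datamodule_cfg : Union[dict, str, Path], optional
        Configs to initialise a ``LightningDataModule``.

        - If dict, the keys and values are used as configuration parameters.
        - If str or Path, it should be a path to a ``.pkl`` file containing
          the serialized configuration dictionary.

        Required for reproducibility when loading a model for inference.
        Defaults to {}.
    """
    return datamodule_cfg


# Print the docstring to inspect the spacing around the list.
print(example.__doc__)
```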
Thanks for pointing it out. I will make the changes to the PR soon.