diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index 50a4ad77f..b26a45f0e 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -10,11 +10,11 @@ from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union import warnings -import lightning.pytorch as pl -from lightning.pytorch import LightningModule, Trainer -from lightning.pytorch.callbacks import BasePredictionWriter, LearningRateFinder -from lightning.pytorch.trainer.states import RunningStage -from lightning.pytorch.utilities.parsing import AttributeDict, get_init_args +import pytorch_lightning as pl +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import BasePredictionWriter, LearningRateFinder +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.parsing import AttributeDict, get_init_args import matplotlib.pyplot as plt import numpy as np from numpy.lib.function_base import iterable