From 1307cdf839f32700b4f59b80631cce7f84eb1b2c Mon Sep 17 00:00:00 2001 From: Himanshu-Verma-ds <144148209+Himanshu-Verma-ds@users.noreply.github.com> Date: Tue, 9 Jan 2024 19:14:24 +0530 Subject: [PATCH] Updated base_model.py to account for importing error changed the importing of pytorch-forecasting in order to account for this error "`model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `TemporalFusionTransformer`" --- pytorch_forecasting/models/base_model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index 50a4ad77f..b26a45f0e 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -10,11 +10,11 @@ from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union import warnings -import lightning.pytorch as pl -from lightning.pytorch import LightningModule, Trainer -from lightning.pytorch.callbacks import BasePredictionWriter, LearningRateFinder -from lightning.pytorch.trainer.states import RunningStage -from lightning.pytorch.utilities.parsing import AttributeDict, get_init_args +import pytorch_lightning as pl +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import BasePredictionWriter, LearningRateFinder +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.parsing import AttributeDict, get_init_args import matplotlib.pyplot as plt import numpy as np from numpy.lib.function_base import iterable