Skip to content

[ENH] Add warning for multi-target forecasting support in BaseModel #1880

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pytorch_forecasting/models/base/_base_model_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
########################################################################################


import inspect
from typing import Optional, Union
from warnings import warn

Expand Down Expand Up @@ -46,6 +47,17 @@
Parameters for the learning rate scheduler.
"""
super().__init__()

# simple check for MultiLoss usage.
if inspect.isclass(loss) and loss.__class__.__name__ == "MultiLoss":
warn(

Check warning on line 53 in pytorch_forecasting/models/base/_base_model_v2.py

View check run for this annotation

Codecov / codecov/patch

pytorch_forecasting/models/base/_base_model_v2.py#L53

Added line #L53 was not covered by tests
"\nIMPORTANT: Multi-target forecasting (MultiLoss) is NOT supported "
"in v2 base models. For multi-target forecasting, please use "
"pytorch_forecasting.models.base.BaseModel (v1) instead. "
"Attempting to use MultiLoss with v2 models will result in runtime errors.", # noqa: E501
UserWarning,
stacklevel=2,
)
self.loss = loss
self.logging_metrics = logging_metrics if logging_metrics is not None else []
self.optimizer = optimizer
Expand Down
Loading