
[ENH] Refactor metrics to be polymorphic #1897


Draft · wants to merge 4 commits into main
Changes from 2 commits
177 changes: 169 additions & 8 deletions pytorch_forecasting/metrics/base_metrics.py
@@ -3,7 +3,7 @@
"""

import inspect
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union
import warnings

from sklearn.base import BaseEstimator
@@ -87,7 +87,7 @@ def rescale_parameters(
""" # noqa: E501
return encoder(dict(prediction=parameters, target_scale=target_scale))

def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
def _to_prediction_3d(self, y_pred: torch.Tensor) -> torch.Tensor:
"""
Convert network prediction into a point prediction.

@@ -107,7 +107,28 @@ def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
y_pred = y_pred.mean(-1)
return y_pred

def to_quantiles(
def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
"""
Convert network prediction into a point prediction.

Args:
[Collaborator review comment] this docstring should be clarified - assumed input type, guaranteed output type

y_pred: prediction output of network; a 3D tensor of shape
(batch, time, n_params), a 4D multi-target tensor of shape
(batch, time, n_targets, n_params), or an already-reduced 2D
point prediction, which is returned unchanged.

Returns:
torch.Tensor: point prediction with the parameter dimension reduced.
"""
if y_pred.ndim == 4:
predictions = [
self._to_prediction_3d(y_pred[:, :, i, :])
for i in range(y_pred.shape[2])
]
return torch.stack(predictions, dim=2)
elif y_pred.ndim == 3:
return self._to_prediction_3d(y_pred)
else:
return y_pred
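An illustrative sketch of the shape convention this dispatch assumes — 4D predictions laid out as (batch, time, n_targets, n_params), with dim 2 indexing targets — using a hypothetical stand-in for the 3D reduction:

import torch

def point_prediction(y_pred: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in for the dispatch above: reduce the last
    # (parameter/quantile) dimension to a point forecast, per target
    if y_pred.ndim == 4:  # (batch, time, n_targets, n_params)
        return torch.stack(
            [y_pred[:, :, i, :].mean(-1) for i in range(y_pred.shape[2])], dim=2
        )
    elif y_pred.ndim == 3:  # (batch, time, n_params)
        return y_pred.mean(-1)
    return y_pred  # already a point prediction

y = torch.randn(8, 24, 2, 7)      # 8 series, 24 steps, 2 targets, 7 params
print(point_prediction(y).shape)  # torch.Size([8, 24, 2])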

def _to_quantiles_3d(
self, y_pred: torch.Tensor, quantiles: list[float] = None
) -> torch.Tensor:
"""
@@ -138,6 +159,35 @@ def to_quantiles(
f"prediction has 1 or more than 3 dimensions: {y_pred.ndim}"
)

def to_quantiles(
self, y_pred: torch.Tensor, quantiles: list[float] = None
) -> torch.Tensor:
"""
[Collaborator review comment] this docstring should be clarified - assumed input type, guaranteed output type

Convert network prediction into a quantile prediction.

Args:
y_pred: prediction output of network; a 2D or 3D tensor, or a 4D
multi-target tensor of shape (batch, time, n_targets, n_params).
quantiles (List[float], optional): quantiles for probability range.
Defaults to quantiles as defined in the class initialization.

Returns:
torch.Tensor: prediction quantiles (last dimension).
""" # noqa: E501
if y_pred.ndim == 4:
quantile_preds = [
self._to_quantiles_3d(y_pred[:, :, i, :], quantiles=quantiles)
for i in range(y_pred.shape[2])
]
return torch.stack(quantile_preds, dim=2)
elif y_pred.ndim == 3:
return self._to_quantiles_3d(y_pred, quantiles=quantiles)
elif y_pred.ndim == 2:
return self._to_quantiles_3d(y_pred, quantiles=quantiles)
else:
raise ValueError(
f"prediction has unsupported number of dimensions: {y_pred.ndim}"
)
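The quantile path follows the same per-target dispatch; a shape-only sketch, assuming a 4D (batch, time, n_targets, n_samples) input and a hypothetical reducer in place of _to_quantiles_3d:

import torch

def quantiles_3d(t: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in: reduce the last (sample/param) dimension
    # to three fixed quantiles, quantile dim last
    q = torch.quantile(t, torch.tensor([0.1, 0.5, 0.9]), dim=-1)
    return q.permute(1, 2, 0)  # (batch, time, n_quantiles)

y = torch.randn(8, 24, 2, 100)  # (batch, time, n_targets, n_samples)
out = torch.stack([quantiles_3d(y[:, :, i, :]) for i in range(y.shape[2])], dim=2)
print(out.shape)                # torch.Size([8, 24, 2, 3])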

def __add__(self, metric: LightningMetric):
composite_metric = CompositeMetric(metrics=[self])
new_metric = composite_metric + metric
@@ -320,6 +370,35 @@ def __len__(self) -> int:
"""
return len(self.metrics)

def _prepare_y_pred(
[Collaborator review comment] this docstring should be clarified - assumed input type, guaranteed output type
self, y_pred: Union[torch.Tensor, list[torch.Tensor]]
) -> list[torch.Tensor]:
"""
Ensure y_pred is a list of tensors.
"""
if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 4:
if y_pred.shape[2] != len(self.metrics):
raise ValueError(
f"The number of targets in the 4D prediction "
f"tensor ({y_pred.shape[2]}) "
f"does not match the number of metrics in "
f"MultiLoss ({len(self.metrics)})."
)
return list(torch.unbind(y_pred, dim=2))
elif isinstance(y_pred, (list, tuple)):
if len(y_pred) != len(self.metrics):
raise ValueError(
f"The number of predictions in the list ({len(y_pred)}) "
f"does not match the number of metrics in "
f"MultiLoss ({len(self.metrics)})."
)
return y_pred
else:
raise TypeError(
"y_pred for MultiLoss must be a list of tensors or a single 4D tensor, "
f"but got {type(y_pred)}"
)
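The torch.unbind call is the key normalization step: a single 4D tensor becomes a plain list of 3D per-target tensors, which is the form the per-metric loop below already handles. A quick shape check under that assumed layout:

import torch

y_pred = torch.randn(8, 24, 3, 7)    # (batch, time, n_targets=3, n_params)
slices = list(torch.unbind(y_pred, dim=2))
print(len(slices), slices[0].shape)  # 3 torch.Size([8, 24, 7])
# each element lines up with one sub-metric of the MultiLoss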

def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs) -> None:
"""
Update composite metric
@@ -329,6 +408,7 @@ def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs) -> None
y_actual: actual values
**kwargs: arguments to update function
"""
y_pred = self._prepare_y_pred(y_pred)
for idx, metric in enumerate(self.metrics):
try:
metric.update(
@@ -372,6 +452,7 @@ def forward(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs):
Returns:
torch.Tensor: metric value on which backpropagation can be applied
"""
y_pred = self._prepare_y_pred(y_pred)
results = []
for idx, metric in enumerate(self.metrics):
try:
@@ -882,8 +963,39 @@ def update(self, y_pred, target):
dtype=torch.long,
device=target.device,
)
if target.ndim == 3:
num_targets = target.shape[2]
if y_pred.ndim == 4:
if y_pred.shape[2] != num_targets:
raise ValueError(
f"Target and 4D prediction have inconsistent number of targets:"
f" {num_targets} vs {y_pred.shape[2]}"
)
y_pred_slices = [y_pred[:, :, i, :] for i in range(num_targets)]
elif y_pred.ndim == 3:
y_pred_slices = [y_pred] * num_targets
else:
raise ValueError(
f"Unsupported prediction dimensionality {y_pred.ndim} for "
f"multi-target case."
)
target_losses = []
for i in range(num_targets):
pred_slice = y_pred_slices[i]
target_slice = target[:, :, i]
target_losses.append(self.loss(pred_slice, target_slice))
losses = torch.stack(target_losses, dim=-1).sum(dim=-1)

losses = self.loss(y_pred, target)
else:
if y_pred.ndim == 4:
if y_pred.shape[2] == 1:
y_pred = y_pred.squeeze(2)
else:
raise ValueError(
f"4D prediction ({y_pred.shape}) cannot be used with a single "
f"2D target ({target.shape})."
)
losses = self.loss(y_pred, target)
# weight samples
if weight is not None:
losses = losses * unsqueeze_like(weight, losses)
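The multi-target reduction keeps each per-target loss in its (batch, time) layout, stacks the targets on a new trailing dimension, and sums it away, so the downstream weighting and masking logic is unchanged. A toy check of that reduction with two assumed targets:

import torch

loss_t0 = torch.ones(4, 10)          # per-element loss, target 0
loss_t1 = torch.full((4, 10), 2.0)   # per-element loss, target 1
losses = torch.stack([loss_t0, loss_t1], dim=-1).sum(dim=-1)
print(losses.shape, losses[0, 0])    # torch.Size([4, 10]) tensor(3.)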
@@ -1046,7 +1158,9 @@ def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
loss = -distribution.log_prob(y_actual)
return loss

def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Tensor:
def _to_prediction_3d(
self, y_pred: torch.Tensor, n_samples: int = 100
) -> torch.Tensor:
[Collaborator review comment] this docstring should be clarified - assumed input type, guaranteed output type
"""
Convert network prediction into a point prediction.

@@ -1062,6 +1176,25 @@
except NotImplementedError:
return self.sample(y_pred, n_samples=n_samples).mean(-1)

def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Tensor:
"""
Convert network prediction into a point prediction.

Args:
y_pred: prediction output of network; distribution parameters as a
3D tensor of shape (batch, time, n_params) or a 4D multi-target
tensor of shape (batch, time, n_targets, n_params).
n_samples (int): number of samples to draw.

Returns:
torch.Tensor: mean prediction.
"""
if y_pred.ndim == 4:
predictions = [
self._to_prediction_3d(y_pred[:, :, i, :], n_samples=n_samples)
for i in range(y_pred.shape[2])
]
return torch.stack(predictions, dim=2)
else:
return self._to_prediction_3d(y_pred, n_samples=n_samples)
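A toy illustration of the closed-form mean versus the sampling fallback, using a Normal head with a hypothetical (loc, scale) parameterization — not necessarily the library's exact layout:

import torch
from torch.distributions import Normal

params = torch.randn(8, 24, 2)          # assumed (batch, time, [loc, raw_scale])
dist = Normal(params[..., 0], params[..., 1].exp())
point = dist.mean                       # closed form, shape (batch, time)
fallback = dist.sample((100,)).mean(0)  # sampling fallback, same shape
print(point.shape, fallback.shape)      # torch.Size([8, 24]) twice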

def sample(self, y_pred, n_samples: int) -> torch.Tensor:
"""
Sample from distribution.
@@ -1075,13 +1208,13 @@ def sample(self, y_pred, n_samples: int) -> torch.Tensor:
""" # noqa: E501
dist = self.map_x_to_distribution(y_pred)
samples = dist.sample((n_samples,))
if samples.ndim == 3:
samples = samples.permute(1, 2, 0)
if samples.ndim > 2:
samples = samples.permute(1, 2, 0, *range(3, samples.ndim))
elif samples.ndim == 2:
samples = samples.transpose(0, 1)
return samples
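The generalized permute keeps the old convention — samples move from the leading dimension to dimension 2 — while letting any extra trailing dimensions (such as a target dimension) ride along. A shape check under that assumption:

import torch

samples = torch.randn(100, 8, 24, 2)  # (n_samples, batch, time, n_targets)
out = samples.permute(1, 2, 0, *range(3, samples.ndim))
print(out.shape)                      # torch.Size([8, 24, 100, 2])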

def to_quantiles(
def _to_quantiles_3d(
self, y_pred: torch.Tensor, quantiles: list[float] = None, n_samples: int = 100
) -> torch.Tensor:
"""
@@ -1110,6 +1243,34 @@
).permute(1, 2, 0)
return quantiles

def to_quantiles(
self, y_pred: torch.Tensor, quantiles: list[float] = None, n_samples: int = 100
) -> torch.Tensor:
"""
Convert network prediction into a quantile prediction.

Args:
y_pred: prediction output of network
quantiles (List[float], optional): quantiles for probability range.
Defaults to quantiles as defined in the class initialization.
n_samples (int): number of samples to draw for quantiles. Defaults to 100.

Returns:
torch.Tensor: prediction quantiles (last dimension)
""" # noqa: E501
if y_pred.ndim == 4:
quantile_preds = [
self._to_quantiles_3d(
y_pred[:, :, i, :], quantiles=quantiles, n_samples=n_samples
)
for i in range(y_pred.shape[2])
]
return torch.stack(quantile_preds, dim=2)
else:
return self._to_quantiles_3d(
y_pred, quantiles=quantiles, n_samples=n_samples
)
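For sampling-based quantiles, the recipe matches the 3D path above: draw samples, take torch.quantile across the sample dimension, then permute quantiles to the last dimension. A sketch with assumed shapes and a toy Normal head:

import torch
from torch.distributions import Normal

dist = Normal(torch.zeros(8, 24), torch.ones(8, 24))
samples = dist.sample((500,))    # (n_samples, batch, time)
q = torch.quantile(samples, torch.tensor([0.02, 0.5, 0.98]), dim=0)
print(q.permute(1, 2, 0).shape)  # (batch, time, n_q) -> torch.Size([8, 24, 3])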


class MultivariateDistributionLoss(DistributionLoss):
"""Base class for multivariate distribution losses.
2 changes: 1 addition & 1 deletion tests/test_metrics.py
@@ -59,7 +59,7 @@ def test_composite_metric():
],
)
def test_aggregation_metric(decoder_lengths, y):
y_pred = torch.tensor([[0.0, 2.0], [4.0, 3.0]])
y_pred = torch.tensor([[[0.0], [2.0]], [[4.0], [3.0]]])
if (decoder_lengths != y_pred.size(-1)).any():
y_packed = rnn.pack_padded_sequence(
y, lengths=decoder_lengths, batch_first=True, enforce_sorted=False