[ENH] Refactor metrics to be polymorphic #1897
base: main
Changes from 2 commits
@@ -3,7 +3,7 @@
 """

 import inspect
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 import warnings

 from sklearn.base import BaseEstimator
@@ -87,7 +87,7 @@ def rescale_parameters(
         """  # noqa: E501
         return encoder(dict(prediction=parameters, target_scale=target_scale))

-    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
+    def _to_prediction_3d(self, y_pred: torch.Tensor) -> torch.Tensor:
         """
         Convert network prediction into a point prediction.

@@ -107,7 +107,28 @@ def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
             y_pred = y_pred.mean(-1)
         return y_pred

-    def to_quantiles(
+    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
+        """
+        Convert network prediction into a point prediction.
+
+        Args:
[Review comment: this docstring should be clarified - assumed input type, guaranteed output type]
+            y_pred: prediction output of network.
+
+        Returns:
+            torch.Tensor: point prediction.
+        """
+        if y_pred.ndim == 4:
+            predictions = [
+                self._to_prediction_3d(y_pred[:, :, i, :])
+                for i in range(y_pred.shape[2])
+            ]
+            return torch.stack(predictions, dim=2)
+        elif y_pred.ndim == 3:
+            return self._to_prediction_3d(y_pred)
+        else:
+            return y_pred
+
+    def _to_quantiles_3d(
         self, y_pred: torch.Tensor, quantiles: list[float] = None
     ) -> torch.Tensor:
         """
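The dispatch above can be exercised standalone. A minimal sketch of the pattern, where ToyMetric and its mean-based 3D helper are hypothetical stand-ins for illustration, not the pytorch-forecasting classes:

import torch

class ToyMetric:
    def _to_prediction_3d(self, y_pred: torch.Tensor) -> torch.Tensor:
        # (batch, time, params) -> (batch, time): collapse the parameter axis
        return y_pred.mean(-1)

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        if y_pred.ndim == 4:
            # (batch, time, n_targets, params): handle each target slice
            # separately, then re-stack along the target axis
            preds = [
                self._to_prediction_3d(y_pred[:, :, i, :])
                for i in range(y_pred.shape[2])
            ]
            return torch.stack(preds, dim=2)
        elif y_pred.ndim == 3:
            return self._to_prediction_3d(y_pred)
        return y_pred

metric = ToyMetric()
print(metric.to_prediction(torch.rand(8, 24, 5)).shape)     # torch.Size([8, 24])
print(metric.to_prediction(torch.rand(8, 24, 2, 5)).shape)  # torch.Size([8, 24, 2])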
@@ -138,6 +159,35 @@ def to_quantiles(
                 f"prediction has 1 or more than 3 dimensions: {y_pred.ndim}"
             )

+    def to_quantiles(
+        self, y_pred: torch.Tensor, quantiles: list[float] = None
+    ) -> torch.Tensor:
+        """
[Review comment: this docstring should be clarified - assumed input type, guaranteed output type]
+        Convert network prediction into a quantile prediction.
+
+        Args:
+            y_pred: prediction output of network
+            quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles
+                as defined in the class initialization.
+
+        Returns:
+            torch.Tensor: prediction quantiles.
+        """  # noqa: E501
+        if y_pred.ndim == 4:
+            quantile_preds = [
+                self._to_quantiles_3d(y_pred[:, :, i, :], quantiles=quantiles)
+                for i in range(y_pred.shape[2])
+            ]
+            return torch.stack(quantile_preds, dim=2)
+        elif y_pred.ndim == 3:
+            return self._to_quantiles_3d(y_pred, quantiles=quantiles)
+        elif y_pred.ndim == 2:
+            return self._to_quantiles_3d(y_pred, quantiles=quantiles)
+        else:
+            raise ValueError(
+                f"prediction has unsupported number of dimensions: {y_pred.ndim}"
+            )

     def __add__(self, metric: LightningMetric):
         composite_metric = CompositeMetric(metrics=[self])
         new_metric = composite_metric + metric
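The quantile path uses the same per-target loop as to_prediction; only the 3D helper differs. A shape-only sketch, with torch.quantile standing in for the library's 3D quantile computation (an illustrative assumption, not the actual helper):

import torch

def to_quantiles_3d(y_pred: torch.Tensor, quantiles: list[float]) -> torch.Tensor:
    # (batch, time, params) -> (batch, time, n_quantiles)
    q = torch.tensor(quantiles, dtype=y_pred.dtype)
    return torch.quantile(y_pred, q, dim=2).permute(1, 2, 0)

y_pred_4d = torch.rand(8, 24, 2, 100)  # (batch, time, n_targets, params)
out = torch.stack(
    [to_quantiles_3d(y_pred_4d[:, :, i, :], [0.1, 0.5, 0.9]) for i in range(2)],
    dim=2,
)
print(out.shape)  # torch.Size([8, 24, 2, 3]): quantiles on the last axis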
@@ -320,6 +370,35 @@ def __len__(self) -> int:
         """
         return len(self.metrics)

+    def _prepare_y_pred(
[Review comment: this docstring should be clarified - assumed input type, guaranteed output type]
+        self, y_pred: Union[torch.Tensor, list[torch.Tensor]]
+    ) -> list[torch.Tensor]:
+        """
+        Ensure y_pred is a list of tensors.
+        """
+        if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 4:
+            if y_pred.shape[2] != len(self.metrics):
+                raise ValueError(
+                    f"The number of targets in the 4D prediction "
+                    f"tensor ({y_pred.shape[2]}) "
+                    f"does not match the number of metrics in "
+                    f"MultiLoss ({len(self.metrics)})."
+                )
+            return list(torch.unbind(y_pred, dim=2))
+        elif isinstance(y_pred, (list, tuple)):
+            if len(y_pred) != len(self.metrics):
+                raise ValueError(
+                    f"The number of predictions in the list ({len(y_pred)}) "
+                    f"does not match the number of metrics in "
+                    f"MultiLoss ({len(self.metrics)})."
+                )
+            return y_pred
+        else:
+            raise TypeError(
+                "y_pred for MultiLoss must be a list of tensors or a single 4D tensor, "
+                f"but got {type(y_pred)}"
+            )

     def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs) -> None:
         """
         Update composite metric
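Both update() and forward() below funnel through this normalizer, so a MultiLoss caller can pass either the legacy per-metric list or the new stacked tensor. A minimal sketch of the normalization itself, assuming a MultiLoss over two metrics:

import torch

y_pred_4d = torch.rand(8, 24, 2, 7)          # (batch, time, n_targets, params); n_targets == len(self.metrics)
per_target = list(torch.unbind(y_pred_4d, dim=2))
print(len(per_target), per_target[0].shape)  # 2 torch.Size([8, 24, 7])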
@@ -329,6 +408,7 @@ def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs) -> None
             y_actual: actual values
             **kwargs: arguments to update function
         """
+        y_pred = self._prepare_y_pred(y_pred)
         for idx, metric in enumerate(self.metrics):
             try:
                 metric.update(
@@ -372,6 +452,7 @@ def forward(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs):
         Returns:
             torch.Tensor: metric value on which backpropagation can be applied
         """
+        y_pred = self._prepare_y_pred(y_pred)
         results = []
         for idx, metric in enumerate(self.metrics):
             try:
@@ -882,8 +963,39 @@ def update(self, y_pred, target):
                 dtype=torch.long,
                 device=target.device,
             )
+        if target.ndim == 3:
+            num_targets = target.shape[2]
+            if y_pred.ndim == 4:
+                if y_pred.shape[2] != num_targets:
+                    raise ValueError(
+                        f"Target and 4D prediction have inconsistent number of targets:"
+                        f" {num_targets} vs {y_pred.shape[2]}"
+                    )
+                y_pred_slices = [y_pred[:, :, i, :] for i in range(num_targets)]
+            elif y_pred.ndim == 3:
+                y_pred_slices = [y_pred] * num_targets
+            else:
+                raise ValueError(
+                    f"Unsupported prediction dimensionality {y_pred.ndim} for "
+                    f"multi-target case."
+                )
+            target_losses = []
+            for i in range(num_targets):
+                pred_slice = y_pred_slices[i]
+                target_slice = target[:, :, i]
+                target_losses.append(self.loss(pred_slice, target_slice))
+            losses = torch.stack(target_losses, dim=-1).sum(dim=-1)
-        losses = self.loss(y_pred, target)
+        else:
+            if y_pred.ndim == 4:
+                if y_pred.shape[2] == 1:
+                    y_pred = y_pred.squeeze(2)
+                else:
+                    raise ValueError(
+                        f"4D prediction ({y_pred.shape}) cannot be used with a single "
+                        f"2D target ({target.shape})."
+                    )
+            losses = self.loss(y_pred, target)
         # weight samples
         if weight is not None:
             losses = losses * unsqueeze_like(weight, losses)
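A shape sketch of the multi-target reduction above, with a per-element L1 standing in for self.loss (an illustrative stand-in; the real loss depends on the metric subclass):

import torch

def l1(pred_slice: torch.Tensor, target_slice: torch.Tensor) -> torch.Tensor:
    # pred_slice: (batch, time, params), target_slice: (batch, time)
    return (pred_slice.mean(-1) - target_slice).abs()

y_pred = torch.rand(8, 24, 2, 7)   # (batch, time, n_targets, params)
target = torch.rand(8, 24, 2)      # (batch, time, n_targets)
per_target = [l1(y_pred[:, :, i, :], target[:, :, i]) for i in range(target.shape[2])]
losses = torch.stack(per_target, dim=-1).sum(dim=-1)
print(losses.shape)                # torch.Size([8, 24]): summed over targets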
@@ -1046,7 +1158,9 @@ def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor:
         loss = -distribution.log_prob(y_actual)
         return loss

-    def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Tensor:
+    def _to_prediction_3d(
+        self, y_pred: torch.Tensor, n_samples: int = 100
+    ) -> torch.Tensor:
[Review comment: this docstring should be clarified - assumed input type, guaranteed output type]
         """
         Convert network prediction into a point prediction.

@@ -1062,6 +1176,25 @@ def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Ten
         except NotImplementedError:
             return self.sample(y_pred, n_samples=n_samples).mean(-1)

+    def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Tensor:
+        """
+        Convert network prediction into a point prediction.
+
+        Args:
+            y_pred: prediction output of network
+            n_samples (int): number of samples to draw
+
+        Returns:
+            torch.Tensor: mean prediction
+        """
+        if y_pred.ndim == 4:
+            predictions = [
+                self._to_prediction_3d(y_pred[:, :, i, :], n_samples=n_samples)
+                for i in range(y_pred.shape[2])
+            ]
+            return torch.stack(predictions, dim=2)
+        else:
+            return self._to_prediction_3d(y_pred, n_samples=n_samples)

     def sample(self, y_pred, n_samples: int) -> torch.Tensor:
         """
         Sample from distribution.
@@ -1075,13 +1208,13 @@ def sample(self, y_pred, n_samples: int) -> torch.Tensor:
         """  # noqa: E501
         dist = self.map_x_to_distribution(y_pred)
         samples = dist.sample((n_samples,))
-        if samples.ndim == 3:
-            samples = samples.permute(1, 2, 0)
+        if samples.ndim > 2:
+            samples = samples.permute(1, 2, 0, *range(3, samples.ndim))
         elif samples.ndim == 2:
             samples = samples.transpose(0, 1)
         return samples

-    def to_quantiles(
+    def _to_quantiles_3d(
         self, y_pred: torch.Tensor, quantiles: list[float] = None, n_samples: int = 100
     ) -> torch.Tensor:
         """
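The generalized permute keeps the old behavior for 3D samples while letting any trailing target axes tag along: dist.sample((n_samples,)) puts the sample axis first, and the permute moves it behind batch and time. Illustrative shapes only:

import torch

samples = torch.rand(100, 8, 24, 2)  # (n_samples, batch, time, n_targets) from dist.sample((n_samples,))
out = samples.permute(1, 2, 0, *range(3, samples.ndim))
print(out.shape)  # torch.Size([8, 24, 100, 2]): sample axis moved behind batch and time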
@@ -1110,6 +1243,34 @@ def to_quantiles(
         ).permute(1, 2, 0)
         return quantiles

+    def to_quantiles(
+        self, y_pred: torch.Tensor, quantiles: list[float] = None, n_samples: int = 100
+    ) -> torch.Tensor:
+        """
+        Convert network prediction into a quantile prediction.
+
+        Args:
+            y_pred: prediction output of network
+            quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles
+                as defined in the class initialization.
+            n_samples (int): number of samples to draw for quantiles. Defaults to 100.
+
+        Returns:
+            torch.Tensor: prediction quantiles (last dimension)
+        """  # noqa: E501
+        if y_pred.ndim == 4:
+            quantile_preds = [
+                self._to_quantiles_3d(
+                    y_pred[:, :, i, :], quantiles=quantiles, n_samples=n_samples
+                )
+                for i in range(y_pred.shape[2])
+            ]
+            return torch.stack(quantile_preds, dim=2)
+        else:
+            return self._to_quantiles_3d(
+                y_pred, quantiles=quantiles, n_samples=n_samples
+            )


 class MultivariateDistributionLoss(DistributionLoss):
     """Base class for multivariate distribution losses.