From 4de49a21f4a590796693f6d2943498639a50e095 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 19 Jun 2025 12:59:24 +0530 Subject: [PATCH 1/2] refactor metrics --- pytorch_forecasting/metrics/base_metrics.py | 177 +++++++++++++++++++- tests/test_metrics.py | 2 +- 2 files changed, 170 insertions(+), 9 deletions(-) diff --git a/pytorch_forecasting/metrics/base_metrics.py b/pytorch_forecasting/metrics/base_metrics.py index 9396a897c..e81da4886 100644 --- a/pytorch_forecasting/metrics/base_metrics.py +++ b/pytorch_forecasting/metrics/base_metrics.py @@ -3,7 +3,7 @@ """ import inspect -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import warnings from sklearn.base import BaseEstimator @@ -87,7 +87,7 @@ def rescale_parameters( """ # noqa: E501 return encoder(dict(prediction=parameters, target_scale=target_scale)) - def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor: + def _to_prediction_3d(self, y_pred: torch.Tensor) -> torch.Tensor: """ Convert network prediction into a point prediction. @@ -107,7 +107,28 @@ def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor: y_pred = y_pred.mean(-1) return y_pred - def to_quantiles( + def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor: + """ + Convert network prediction into a point prediction. + + Args: + y_pred: prediction output of network. + + Returns: + torch.Tensor: point prediction. + """ + if y_pred.ndim == 4: + predictions = [ + self._to_prediction_3d(y_pred[:, :, i, :]) + for i in range(y_pred.shape[2]) + ] + return torch.stack(predictions, dim=2) + elif y_pred.ndim == 3: + return self._to_prediction_3d(y_pred) + else: + return y_pred + + def _to_quantiles_3d( self, y_pred: torch.Tensor, quantiles: list[float] = None ) -> torch.Tensor: """ @@ -138,6 +159,35 @@ def to_quantiles( f"prediction has 1 or more than 3 dimensions: {y_pred.ndim}" ) + def to_quantiles( + self, y_pred: torch.Tensor, quantiles: list[float] = None + ) -> torch.Tensor: + """ + Convert network prediction into a quantile prediction. + + Args: + y_pred: prediction output of network + quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as + as defined in the class initialization. + + Returns: + torch.Tensor: prediction quantiles. + """ # noqa: E501 + if y_pred.ndim == 4: + quantile_preds = [ + self._to_quantiles_3d(y_pred[:, :, i, :], quantiles=quantiles) + for i in range(y_pred.shape[2]) + ] + return torch.stack(quantile_preds, dim=2) + elif y_pred.ndim == 3: + return self._to_quantiles_3d(y_pred, quantiles=quantiles) + elif y_pred.ndim == 2: + return self._to_quantiles_3d(y_pred, quantiles=quantiles) + else: + raise ValueError( + f"prediction has unsupported number of dimensions: {y_pred.ndim}" + ) + def __add__(self, metric: LightningMetric): composite_metric = CompositeMetric(metrics=[self]) new_metric = composite_metric + metric @@ -320,6 +370,35 @@ def __len__(self) -> int: """ return len(self.metrics) + def _prepare_y_pred( + self, y_pred: Union[torch.Tensor, list[torch.Tensor]] + ) -> list[torch.Tensor]: + """ + Ensure y_pred is a list of tensors. + """ + if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 4: + if y_pred.shape[2] != len(self.metrics): + raise ValueError( + f"The number of targets in the 4D prediction " + f"tensor ({y_pred.shape[2]}) " + f"does not match the number of metrics in " + f"MultiLoss ({len(self.metrics)})." 
+ ) + return list(torch.unbind(y_pred, dim=2)) + elif isinstance(y_pred, (list, tuple)): + if len(y_pred) != len(self.metrics): + raise ValueError( + f"The number of predictions in the list ({len(y_pred)}) " + f"does not match the number of metrics in " + f"MultiLoss ({len(self.metrics)})." + ) + return y_pred + else: + raise TypeError( + "y_pred for MultiLoss must be a list of tensors or a single 4D tensor, " + f"but got {type(y_pred)}" + ) + def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs) -> None: """ Update composite metric @@ -329,6 +408,7 @@ def update(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs) -> None y_actual: actual values **kwargs: arguments to update function """ + y_pred = self._prepare_y_pred(y_pred) for idx, metric in enumerate(self.metrics): try: metric.update( @@ -372,6 +452,7 @@ def forward(self, y_pred: torch.Tensor, y_actual: torch.Tensor, **kwargs): Returns: torch.Tensor: metric value on which backpropagation can be applied """ + y_pred = self._prepare_y_pred(y_pred) results = [] for idx, metric in enumerate(self.metrics): try: @@ -882,8 +963,39 @@ def update(self, y_pred, target): dtype=torch.long, device=target.device, ) + if target.ndim == 3: + num_targets = target.shape[2] + if y_pred.ndim == 4: + if y_pred.shape[2] != num_targets: + raise ValueError( + f"Target and 4D prediction have inconsistent number of targets:" + f" {num_targets} vs {y_pred.shape[2]}" + ) + y_pred_slices = [y_pred[:, :, i, :] for i in range(num_targets)] + elif y_pred.ndim == 3: + y_pred_slices = [y_pred] * num_targets + else: + raise ValueError( + f"Unsupported prediction dimensionality {y_pred.ndim} for " + f"multi-target case." + ) + target_losses = [] + for i in range(num_targets): + pred_slice = y_pred_slices[i] + target_slice = target[:, :, i] + target_losses.append(self.loss(pred_slice, target_slice)) + losses = torch.stack(target_losses, dim=-1).sum(dim=-1) - losses = self.loss(y_pred, target) + else: + if y_pred.ndim == 4: + if y_pred.shape[2] == 1: + y_pred = y_pred.squeeze(2) + else: + raise ValueError( + f"4D prediction ({y_pred.shape}) cannot be used with a single " + f"2D target ({target.shape})." + ) + losses = self.loss(y_pred, target) # weight samples if weight is not None: losses = losses * unsqueeze_like(weight, losses) @@ -1046,7 +1158,9 @@ def loss(self, y_pred: torch.Tensor, y_actual: torch.Tensor) -> torch.Tensor: loss = -distribution.log_prob(y_actual) return loss - def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Tensor: + def _to_prediction_3d( + self, y_pred: torch.Tensor, n_samples: int = 100 + ) -> torch.Tensor: """ Convert network prediction into a point prediction. @@ -1062,6 +1176,25 @@ def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Ten except NotImplementedError: return self.sample(y_pred, n_samples=n_samples).mean(-1) + def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Tensor: + """ + Convert network prediction into a point prediction. + + Args: + y_pred: prediction output of network + n_samples (int): number of samples to draw + Returns: + torch.Tensor: mean prediction + """ + if y_pred.ndim == 4: + predictions = [ + self._to_prediction_3d(y_pred[:, :, i, :], n_samples=n_samples) + for i in range(y_pred.shape[2]) + ] + return torch.stack(predictions, dim=2) + else: + return self._to_prediction_3d(y_pred, n_samples=n_samples) + def sample(self, y_pred, n_samples: int) -> torch.Tensor: """ Sample from distribution. 
@@ -1075,13 +1208,13 @@ def sample(self, y_pred, n_samples: int) -> torch.Tensor: """ # noqa: E501 dist = self.map_x_to_distribution(y_pred) samples = dist.sample((n_samples,)) - if samples.ndim == 3: - samples = samples.permute(1, 2, 0) + if samples.ndim > 2: + samples = samples.permute(1, 2, 0, *range(3, samples.ndim)) elif samples.ndim == 2: samples = samples.transpose(0, 1) return samples - def to_quantiles( + def _to_quantiles_3d( self, y_pred: torch.Tensor, quantiles: list[float] = None, n_samples: int = 100 ) -> torch.Tensor: """ @@ -1110,6 +1243,34 @@ def to_quantiles( ).permute(1, 2, 0) return quantiles + def to_quantiles( + self, y_pred: torch.Tensor, quantiles: list[float] = None, n_samples: int = 100 + ) -> torch.Tensor: + """ + Convert network prediction into a quantile prediction. + + Args: + y_pred: prediction output of network + quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as + as defined in the class initialization. + n_samples (int): number of samples to draw for quantiles. Defaults to 100. + + Returns: + torch.Tensor: prediction quantiles (last dimension) + """ # noqa: E501 + if y_pred.ndim == 4: + quantile_preds = [ + self._to_quantiles_3d( + y_pred[:, :, i, :], quantiles=quantiles, n_samples=n_samples + ) + for i in range(y_pred.shape[2]) + ] + return torch.stack(quantile_preds, dim=2) + else: + return self._to_quantiles_3d( + y_pred, quantiles=quantiles, n_samples=n_samples + ) + class MultivariateDistributionLoss(DistributionLoss): """Base class for multivariate distribution losses. diff --git a/tests/test_metrics.py b/tests/test_metrics.py index ac9fc27c9..7a27db75f 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -59,7 +59,7 @@ def test_composite_metric(): ], ) def test_aggregation_metric(decoder_lengths, y): - y_pred = torch.tensor([[0.0, 2.0], [4.0, 3.0]]) + y_pred = torch.tensor([[[0.0], [2.0]], [[4.0], [3.0]]]) if (decoder_lengths != y_pred.size(-1)).any(): y_packed = rnn.pack_padded_sequence( y, lengths=decoder_lengths, batch_first=True, enforce_sorted=False From 312ca7bf46be3947eae0da54feaaca4dcbe8465d Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sun, 22 Jun 2025 15:01:07 +0530 Subject: [PATCH 2/2] add docstrings and debug --- pytorch_forecasting/metrics/base_metrics.py | 320 +++++++++++++++---- pytorch_forecasting/metrics/distributions.py | 8 +- 2 files changed, 256 insertions(+), 72 deletions(-) diff --git a/pytorch_forecasting/metrics/base_metrics.py b/pytorch_forecasting/metrics/base_metrics.py index e81da4886..cbdd4ab72 100644 --- a/pytorch_forecasting/metrics/base_metrics.py +++ b/pytorch_forecasting/metrics/base_metrics.py @@ -88,14 +88,29 @@ def rescale_parameters( return encoder(dict(prediction=parameters, target_scale=target_scale)) def _to_prediction_3d(self, y_pred: torch.Tensor) -> torch.Tensor: - """ - Convert network prediction into a point prediction. + """Convert network prediction into a point prediction. - Args: - y_pred: prediction output of network + This is an internal helper method. - Returns: - torch.Tensor: point prediction + Parameters + ---------- + y_pred: prediction output of network + it can either be 2D or 3D: + + - if 2D [batch, time]: returns `y_pred` as is + - if 3D [batch, time, params]: + + - if `self.quantiles` is None: + it assumes the last dimension is 1 and removes it, returning a + 2D tensor [batch, time]. + - if `self.quantiles` is not None: + it takes the mean along the last dimension, returning a + 2D tensor [batch, time]. 
+ + Returns + ------- + torch.Tensor: point prediction + 2D point prediction tensor [batch, time]. """ if y_pred.ndim == 3: if self.quantiles is None: @@ -108,14 +123,28 @@ def _to_prediction_3d(self, y_pred: torch.Tensor) -> torch.Tensor: return y_pred def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor: - """ - Convert network prediction into a point prediction. + """Convert network prediction into a point prediction. - Args: - y_pred: prediction output of network. + Parameters + ---------- + y_pred: prediction output of network + it can be 2D, 3D or 4D: - Returns: - torch.Tensor: point prediction. + - if 4D [batch, time, num_targets (or output_channels), params]: + The method iterates over the `num_targets` dimension. + For each target, it passes a 3D slice [batch, time, params] to + `self._to_prediction_3d`. + - if 3D [batch, time, params]: + directly pass the `y_pred` to `self._to_prediction_3d`. + - if 2D [batch, time]: + return `y_pred` as is + + Returns + ------- + torch.Tensor: point prediction + + - For a 4D input, the output is a 3D tensor [batch, time, num_targets]. + - For a 3D or 2D input, the output is a 2D tensor [batch, time]. """ if y_pred.ndim == 4: predictions = [ @@ -131,16 +160,31 @@ def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor: def _to_quantiles_3d( self, y_pred: torch.Tensor, quantiles: list[float] = None ) -> torch.Tensor: - """ - Convert network prediction into a quantile prediction. + """Convert network prediction into a quantile prediction. - Args: - y_pred: prediction output of network - quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as - as defined in the class initialization. + This is an internal helper method. - Returns: - torch.Tensor: prediction quantiles + Parameters + ---------- + y_pred: prediction output of network + it can either be 2D or 3D: + + - if 2D [batch, time]: returns `y_pred` after it is unsqueezed to + [batch, time, 1]. + - if 3D [batch, time, params]: + + - If `params > 1`, it assumes `params` are samples and calculates + the specified `quantiles` along this dimension. + - If `params == 1`, it is treated as a single quantile forecast and + returned as is. + + quantiles (List[float], optional): quantiles for probability range. + Defaults to `self.quantiles`. + + Returns + ------- + torch.Tensor: prediction quantiles + 3D prediction quantiles tensor [batch, time, n_quantiles]. """ # noqa: E501 if quantiles is None: quantiles = self.quantiles @@ -162,16 +206,30 @@ def _to_quantiles_3d( def to_quantiles( self, y_pred: torch.Tensor, quantiles: list[float] = None ) -> torch.Tensor: - """ - Convert network prediction into a quantile prediction. - - Args: - y_pred: prediction output of network - quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as - as defined in the class initialization. - - Returns: - torch.Tensor: prediction quantiles. + """Convert network prediction into a quantile prediction. + + Parameters + ---------- + y_pred: prediction output of network + it can be 2D, 3D or 4D: + + - if 4D [batch, time, num_targets (or output_channels), params]: The method + iterates over the `num_targets` dimension. For each target, it passes a + 3D slice [batch, time, params] to `self._to_quantiles_3d`. + - if 3D [batch, time, params] or 2D [batch, time]: The tensor is + passed directly to `self._to_quantiles_3d`. + + quantiles (List[float], optional): quantiles for probability range. + Defaults to `self.quantiles`. 
+ + Returns + ------- + torch.Tensor: prediction quantiles + + - For a 4D input, the output is a 4D tensor + [batch, time, num_targets, n_quantiles]. + - For a 3D or 2D input, the output is a 3D tensor + [batch, time, n_quantiles]. """ # noqa: E501 if y_pred.ndim == 4: quantile_preds = [ @@ -179,9 +237,7 @@ def to_quantiles( for i in range(y_pred.shape[2]) ] return torch.stack(quantile_preds, dim=2) - elif y_pred.ndim == 3: - return self._to_quantiles_3d(y_pred, quantiles=quantiles) - elif y_pred.ndim == 2: + elif y_pred.ndim == 3 or y_pred.ndim == 2: return self._to_quantiles_3d(y_pred, quantiles=quantiles) else: raise ValueError( @@ -375,6 +431,29 @@ def _prepare_y_pred( ) -> list[torch.Tensor]: """ Ensure y_pred is a list of tensors. + + Parameters + ---------- + y_pred: prediction output of network + it can be either 4D or a list/tuple of multiple predictions: + + - if 4D [batch, time, num_target, params]: + here num_target >1 + checks if num_target = number of metrics provided MultiLoss instance: + + - if true: + returns a list with each element of 3D shape tensor + [batch, time, params] and size of list is num_target + else: ValueError + + - if y_pred is list: + This means there are predicitions for each metric, so return the y_pred + as it is + + Returns + ------- + list of predictions for each metric in MultiLoss + """ if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 4: if y_pred.shape[2] != len(self.metrics): @@ -938,12 +1017,50 @@ def update(self, y_pred, target): Do not override this method but :py:meth:`~loss` instead - Args: - y_pred (Dict[str, torch.Tensor]): network output - target (Union[torch.Tensor, rnn.PackedSequence]): actual values + Parameters + ---------- + y_pred : torch.Tensor + The prediction tensor from the model. It can be 3D or 4D: - Returns: - torch.Tensor: loss as a single number for backpropagation + - if 3D [batch, time, params]: + + - If `target` is 2D, this is treated as a standard single-target + forecast. The `loss()` method is called once with the full tensors. + - If `target` is 3D, this is treated as a broadcasted multi-target + forecast. This single 3D prediction is evaluated against each + of the targets defined in the `target` tensor. + + - 4D [batch, time, targets, params]: + + - If `target` is 3D, this is the multi-target forecast. + The method iterates through the `targets` dimension, calling `loss()` + for each `[batch, time, params]` slice. + - If `target` is 2D, this will raise a `ValueError` unless the + `targets` dimension has a size of 1, in which case it's squeezed + and treated as a single-target forecast. + + target : Union[torch.Tensor, rnn.PackedSequence, tuple] + actual values. It can: + + - A tensor: + + - 2D tensor [batch, time]: Defines a single-target problem. + The metric will call `self.loss()` once. + - 3D tensor [batch, time, targets]: Defines a multi-target + problem. The metric will loop over the last dimension and call + `self.loss()` for each target. + + - A `torch.nn.utils.rnn.PackedSequence`: unpack this object to get the + required tensors + + - A tuple `(target_tensor, weight_tensor)`: `target_tensor` is one of the + tensor formats above, and `weight_tensor` is a `[batch]` or + `[batch, time]` tensor whose values are multiplied with the calculated + loss for each sample. 
+ + Returns + ------- + torch.Tensor: loss as a single number for backpropagation """ # unpack weight if isinstance(target, (list, tuple)) and not isinstance( @@ -1164,11 +1281,25 @@ def _to_prediction_3d( """ Convert network prediction into a point prediction. - Args: - y_pred: prediction output of network - n_samples (int): number of samples to draw - Returns: - torch.Tensor: mean prediction + This is an internal helper method. + The method first attempts to create an explicit probability distribution + via `self.map_x_to_distribution()`. + + - If successful, it returns the mean (`distribution.mean`) of + that distribution. + - If `map_x_to_distribution()` is not implemented, + it falls back to empirical estimation by drawing `n_samples` and + calculating their mean. + + Parameters + ---------- + y_pred: prediction output of network + A 3D prediction tensor of shape `[batch_size, time_steps, n_params]` + n_samples (int): number of samples to draw + Returns + ------- + torch.Tensor: mean prediction + 2D tensor [batch, time] """ distribution = self.map_x_to_distribution(y_pred) try: @@ -1180,11 +1311,26 @@ def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Ten """ Convert network prediction into a point prediction. - Args: - y_pred: prediction output of network - n_samples (int): number of samples to draw - Returns: - torch.Tensor: mean prediction + Parameters + ---------- + y_pred: prediction output of network + it can be 3D or 4D: + + - if 4D [batch, time, num_targets (or output_channels), params]: + The method iterates over the `num_targets` dimension. + For each target, it passes a 3D slice [batch, time, params] to + `self._to_prediction_3d`. + - if 3D [batch, time, params]: + directly pass the `y_pred` to `self._to_prediction_3d`. + + n_samples (int): number of samples to draw + Returns + ------- + torch.Tensor: mean prediction + + - For a 4D input, the output is a 3D tensor [batch, time, num_targets]. + - For a 3D input, the output is a 2D tensor [batch, time]. + """ if y_pred.ndim == 4: predictions = [ @@ -1199,12 +1345,20 @@ def sample(self, y_pred, n_samples: int) -> torch.Tensor: """ Sample from distribution. - Args: - y_pred: prediction output of network (shape batch_size x n_timesteps x n_paramters) - n_samples (int): number of samples to draw - - Returns: - torch.Tensor: tensor with samples (shape batch_size x n_timesteps x n_samples) + Parameters + ---------- + y_pred : torch.Tensor + A 3D prediction tensor of shape [batch_size, time_steps, n_params] + containing the parameters of the forecast distribution. + n_samples : int + The number of random samples to draw from the distribution for each + point in the [batch_size, time_steps] grid. + + Returns + ------- + torch.Tensor + A 3D tensor of shape [batch_size, time_steps, n_samples] containing + the drawn samples. """ # noqa: E501 dist = self.map_x_to_distribution(y_pred) samples = dist.sample((n_samples,)) @@ -1220,14 +1374,29 @@ def _to_quantiles_3d( """ Convert network prediction into a quantile prediction. - Args: - y_pred: prediction output of network - quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as - as defined in the class initialization. - n_samples (int): number of samples to draw for quantiles. Defaults to 100. - - Returns: - torch.Tensor: prediction quantiles (last dimension) + The method first attempts to use the analytical inverse CDF (`icdf`) of + the distribution defined by `self.map_x_to_distribution()`. 
+ + - If implemented, it calculates the quantiles directly. + - If `map_x_to_distribution()` or `icdf()` is not implemented, it falls back to + an empirical estimation. + + Parameters + ---------- + y_pred : torch.Tensor + A 3D prediction tensor of shape [batch_size, time_steps, n_params] + for a single target. + quantiles : list[float], optional + quantiles for probability range. Defaults to quantiles as + as defined in the class initialization. + n_samples : int, default=100 + number of samples to draw for quantiles. + + Returns + ------- + torch.Tensor + A 3D tensor of shape [batch_size, time_steps, n_quantiles] containing + the predicted values for each requested quantile. """ # noqa: E501 if quantiles is None: quantiles = self.quantiles @@ -1249,14 +1418,27 @@ def to_quantiles( """ Convert network prediction into a quantile prediction. - Args: - y_pred: prediction output of network - quantiles (List[float], optional): quantiles for probability range. Defaults to quantiles as - as defined in the class initialization. - n_samples (int): number of samples to draw for quantiles. Defaults to 100. + Parameters + ---------- + y_pred: prediction output of network + it can be 3D or 4D: - Returns: - torch.Tensor: prediction quantiles (last dimension) + - if 4D [batch, time, num_targets (or output_channels), params]: + The method iterates over the `num_targets` dimension. + For each target, it passes a 3D slice [batch, time, params] to + `self._to_quantiles_3d`. + - if 3D [batch, time, params]: + directly pass the `y_pred` to `self._to_quantiles_3d`. + + n_samples (int): number of samples to draw + + Returns + ------- + torch.Tensor: mean prediction + + - For a 4D input, the output is a 4D tensor + [batch, time, targets, n_quantiles]. + - For a 3D input, the output is a 3D tensor [batch, time, n_quantiles]. """ # noqa: E501 if y_pred.ndim == 4: quantile_preds = [ diff --git a/pytorch_forecasting/metrics/distributions.py b/pytorch_forecasting/metrics/distributions.py index 8db52651a..e46e3de5a 100644 --- a/pytorch_forecasting/metrics/distributions.py +++ b/pytorch_forecasting/metrics/distributions.py @@ -24,7 +24,7 @@ class NormalDistributionLoss(DistributionLoss): distribution_arguments = ["loc", "scale"] def map_x_to_distribution(self, x: torch.Tensor) -> distributions.Normal: - distr = self.distribution_class(loc=x[..., 2], scale=x[..., 3]) + distr = self.distribution_class(loc=x[..., 2], scale=F.softplus(x[..., 3])) scaler = distributions.AffineTransform(loc=x[..., 0], scale=x[..., 1]) if self._transformation is None: return distributions.TransformedDistribution(distr, [scaler]) @@ -47,7 +47,7 @@ def rescale_parameters( ) -> torch.Tensor: self._transformation = encoder.transformation loc = parameters[..., 0] - scale = F.softplus(parameters[..., 1]) + scale = parameters[..., 1] return torch.concat( [ target_scale.unsqueeze(1).expand(-1, loc.size(1), -1), @@ -629,7 +629,9 @@ def rescale_parameters( dim=-1, ) - def to_prediction(self, y_pred: torch.Tensor, n_samples: int = 100) -> torch.Tensor: + def _to_prediction_3d( + self, y_pred: torch.Tensor, n_samples: int = 100 + ) -> torch.Tensor: if n_samples is None: return self.to_quantiles(y_pred, quantiles=[0.5]).squeeze(-1) else:
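Below is a minimal, illustrative sketch (not part of the patch itself) of the per-target slicing pattern that the refactored `to_prediction` / `to_quantiles` methods apply to 4D predictions of shape [batch, time, num_targets, params]: the tensor is split along dim=2, each 3D slice is reduced by the existing 3D helper, and the results are re-stacked on dim=2. The helper name `reduce_3d` is a stand-in for `_to_prediction_3d` / `_to_quantiles_3d`; shapes and values here are made up for demonstration only.

import torch

def reduce_3d(y_pred_3d: torch.Tensor) -> torch.Tensor:
    # stand-in for `_to_prediction_3d`: collapse the params axis to a point forecast
    return y_pred_3d.mean(-1)

def to_prediction_like(y_pred: torch.Tensor) -> torch.Tensor:
    # mirrors the 4D branch added in this patch: slice per target, reduce, re-stack
    if y_pred.ndim == 4:
        per_target = [reduce_3d(y_pred[:, :, i, :]) for i in range(y_pred.shape[2])]
        return torch.stack(per_target, dim=2)  # [batch, time, num_targets]
    return reduce_3d(y_pred)  # [batch, time]

batch, time, num_targets, params = 4, 10, 3, 7
y_pred_4d = torch.randn(batch, time, num_targets, params)
assert to_prediction_like(y_pred_4d).shape == (batch, time, num_targets)
assert to_prediction_like(torch.randn(batch, time, params)).shape == (batch, time)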
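A similarly hedged sketch of the input normalisation that `MultiLoss` now performs via `_prepare_y_pred` before `update` / `forward`: a single 4D prediction tensor is unbound along dim=2 into the per-target list that the individual metrics expect, while a list/tuple input is passed through after a length check. `num_metrics` and the tensor shapes are illustrative assumptions, not values from the patch.

import torch

num_metrics = 3
y_pred = torch.randn(4, 10, num_metrics, 7)  # [batch, time, num_targets, params]

# normalise to a list with one 3D tensor [batch, time, params] per metric
if isinstance(y_pred, torch.Tensor) and y_pred.ndim == 4:
    if y_pred.shape[2] != num_metrics:
        raise ValueError("number of targets must match number of MultiLoss metrics")
    y_pred_list = list(torch.unbind(y_pred, dim=2))
else:
    y_pred_list = list(y_pred)

assert len(y_pred_list) == num_metrics
assert y_pred_list[0].shape == (4, 10, 7)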