diff --git a/pytorch_forecasting/models/base/_base_model_v2.py b/pytorch_forecasting/models/base/_base_model_v2.py index aceec0869..0563b88d4 100644 --- a/pytorch_forecasting/models/base/_base_model_v2.py +++ b/pytorch_forecasting/models/base/_base_model_v2.py @@ -294,3 +294,184 @@ def log_metrics( prog_bar=True, logger=True, ) + + def standardize_model_output( + self, + prediction: torch.Tensor, + expected_dims: tuple[int] = None, # noqa: E501 + ) -> torch.Tensor: + """ + Standardize model outputs to a 4-dimensional tensor, with shape + (batch_size, timesteps, num_features, last_dim). + + Parameters + ---------- + prediction : torch.Tensor + The raw prediction tensor from the model. + + - Must be a torch.Tensor (in the future, also accept a list of tensors for + multi-target forecasting). + - Supported dims: 2D, 3D or 4D tensors. + - if 2D: (batch_size, timesteps) - univariate forecasting + - if 3D: + + a) (batch_size, timesteps, n_targets) - multivariate forecasting + b) (batch_size, timesteps, last_dim) - univariate forecasting with quantiles or distribution. + c) (batch_size, timesteps, n_targets * last_dim) - multivariate + forecasting with quantiles, where features and quantiles are flattened in dim 2. + + - if 4D: (batch_size, timesteps, n_targets, last_dim) - multivariate + forecasting with quantiles or distribution parameters. + - In the future, once multi-target forecasting with ``MultiLoss`` is supported, this + will also accept a list of tensors, where each tensor inside the list + is treated as above. Note: In this case, each tensor in the list + will have n_targets = 1, as each tensor corresponds to a single target. + - If anything apart from the above dimensions is provided, an error is raised. + + expected_dims : tuple[int], default= None + A tuple specifying the dimensions: (batch_size, timesteps, n_targets, last_dim). + + batch_size : Optional[int], default=None + + - Position 1: Expected batch size + - When specified: Validates prediction.shape[0] + - When None: Uses actual tensor dimension + + timesteps : Optional[int], default=None + + - Position 2: Expected number of timesteps + - When specified: Validates prediction.shape[1] + - When None: Uses actual tensor dimension + + n_targets : int + + - Position 0: Number of target features + - Must be provided explicitly (cannot be None) + - Used for reshaping 2D and 3D tensors to 4D. + + last_dim : Optional[int], default=None + + - Position 3: Size of the last dimension. + - Common use case - quantile, sample, distribution params. + - When it is specified, it is used to directly reshape. + - When None and model uses QuantileLoss: It is set to the number of quantiles + - When None and no quantile information is available: It defaults to 1. + - If required, this can be extended to handle other cases where the last_dim is None + but its value can be inferred from the loss function or model configuration (apart from + the existing QuantileLoss case, of course). + + Returns + ------- + torch.Tensor + The standardized prediction tensor with shape (batch_size, timesteps, n_targets, last_dim). + The prediction tensor is obtained by reshaping the input tensor. There are + several cases to consider: + + - If the input tensor is 2D, it is reshaped to (batch_size, timesteps, n_targets, 1). + - If the input tensor is 3D, it is reshaped to (batch_size, timesteps, n_targets, 1) for a + non-quantile forecast, or to (batch_size, timesteps, n_targets, last_dim) in case of quantile/distribution. + - If the input tensor is 4D, it is assumed to be in the shape + (batch_size, timesteps, n_targets, last_dim) or (batch_size, timesteps, last_dim, n_targets). + and is reshaped to (batch_size, timesteps, n_targets, last_dim) if needed + by permuting the last two dimensions. + + Notes + ----- + [1] The fourth dimension (last_dim) commonly represents: + + * Quantiles: For quantile regression (e.g., 0.1, 0.5, 0.9) + * Distribution parameters: For parametric forecasts (e.g., mean, variance) + * Samples: For sample-based uncertainty estimates + + The current implementation assumes the most common case of quantile forecasts + when automatically inferring this dimension from the loss function, + but any value can be explicitly provided. A value of 1 is used in case where + no information is available on ``last_dim``. + + [2] This can currently handle situations where a single target is used + either in a univariate or multivariate situation and multiple-targets using the + same loss function. + + In case of multi-target forecasting with separate loss functions for each target, + the input tensor is expected to be a list of tensors. This is not yet supported + in this function, but it is planned for the future. + """ # noqa: E501 + + n_targets, batch_size, timesteps, last_dim = expected_dims + + if not isinstance(prediction, torch.Tensor): + raise TypeError( + f"Expected prediction to be a torch.Tensor, but got {type(prediction)}" + ) + + if n_targets is None: + raise ValueError( + "Expected n_targets to be a positive integer, but got `None`." + ) + + if last_dim is None: + if hasattr(self.loss, "quantiles") and self.loss.quantiles is not None: + last_dim = len(self.loss.quantiles) # Quantile regression case + # we can add more cases here in the future, where we refer to the specific + # loss function to determine the last dimension. For now we are sticking + # to the quantile regression case. + else: + last_dim = 1 + + if batch_size is not None: + if prediction.shape[0] != batch_size: + raise ValueError( + f"Expected batch size {batch_size}, but got {prediction.shape[0]}." + ) + + if timesteps is not None: + if prediction.shape[1] != timesteps: + raise ValueError( + f"Expected timesteps {timesteps}, but got {prediction.shape[1]}." + ) + + if prediction.ndim == 2: + # reshape to (batch_size, timsteps, 1, 1) + prediction = prediction.unsqueeze(-1).unsqueeze(-1) + + elif prediction.ndim == 3: + if prediction.shape[2] == n_targets: + # reshape to (batch_size, timesteps, n_targets, 1) + prediction = prediction.unsqueeze(-1) + elif prediction.shape[2] == last_dim: + # reshape to (batch_size, timesteps, 1, last_dim) + prediction = prediction.unsqueeze(2) + elif prediction.shape[2] == n_targets * last_dim: + # multivariate forecast with quantiles + # where features and quantiles are flattened in dim 2. + # reshape to (batch_size, timesteps, n_targets, last_dim) + prediction = prediction.reshape( + prediction.shape[0], prediction.shape[1], n_targets, last_dim + ) + else: + # reshape to (batch_size, timesteps, n_targets, last_dim) + prediction = prediction.unsqueeze(-1) + + elif prediction.ndim == 4: + # assuming only a single case where n_targets and last_dim are swapped. + if prediction.shape[2] == last_dim and prediction.shape[3] == n_targets: + # reshape to (batch_size, timesteps, n_targets, last_dim) + warn( + "Prediction tensor has shape (batch_size, timesteps, last_dim, n_targets). " # noqa: E501 + "This is not the expected shape. Transposing the last two dimensions." # noqa: E501 + ) + prediction = prediction.permute(0, 1, 3, 2) + + else: + raise ValueError( + f"Expected prediction tensor to have 2, 3, or 4 dimensions, " + f"but got {prediction.ndim} dimensions." + ) + + # final check to ensure the output is 4D + if prediction.ndim != 4: + raise ValueError( + f"Failed to standardize output to 4D tensor. Current shape: {prediction.shape}" # noqa: E501 + ) + + return prediction