
[ENH] standardize model output format to (batch, timesteps, features, quantiles) in v2 #1894

Open
@PranavBhatP


Is your feature request related to a problem? Please describe.
In its current state, each model in PTF has its own format for the output tensor returned by the forward method. The most common format is (batch_size, timesteps, n_features) for models which support multivariate forecasting and (batch_size, timesteps) for univariate forecasts. This inconsistency is something we can rework with v2.

Mainly, the issues are:

  • Loss functions need to handle different input formats for every type of forecasting scenario, which requires conditional checks within the model to transform the model outputs (i.e., the loss function inputs) into the correct format. IMO, this is tedious for the developer, and it requires a comprehensive read-through of the docs/code just to understand the input format a loss function expects for a simple training step.
  • Models handle quantiles in different ways: some keep quantiles as a separate dimension (a 4D tensor), while others fold them into the feature dimension (n_features * n_quantiles).
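To make the second point concrete, here is a minimal sketch of converting the folded layout into the separate-dimension one. The helper name unflatten_quantiles and the quantile_major flag are hypothetical, and which ordering a given PTF model uses for its folded last dimension is an assumption that would need to be checked per model, so the sketch handles both orderings.

```python
import torch


def unflatten_quantiles(
    y: torch.Tensor, n_features: int, n_quantiles: int, quantile_major: bool = False
) -> torch.Tensor:
    """Hypothetical helper: (batch, timesteps, n_features * n_quantiles)
    -> (batch, timesteps, n_features, n_quantiles).

    Assumed orderings of the folded last dimension:
      quantile_major=False: index = f * n_quantiles + q (feature-major)
      quantile_major=True:  index = q * n_features + f (quantile-major)
    """
    batch, timesteps, last = y.shape
    if last != n_features * n_quantiles:
        raise ValueError(f"expected last dim {n_features * n_quantiles}, got {last}")
    if quantile_major:
        # Unfold to (B, T, Q, F), then swap the last two dims to get (B, T, F, Q).
        return y.reshape(batch, timesteps, n_quantiles, n_features).transpose(2, 3)
    # Feature-major folding unfolds directly into (B, T, F, Q).
    return y.reshape(batch, timesteps, n_features, n_quantiles)


# 2 series, 5 timesteps, 3 target features, 7 quantiles folded into the last dim.
flat = torch.randn(2, 5, 3 * 7)
print(unflatten_quantiles(flat, n_features=3, n_quantiles=7).shape)  # torch.Size([2, 5, 3, 7])
```

Getting the ordering wrong would not raise an error, it would silently mix features and quantiles, which is one more argument for a single agreed layout.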

Describe the solution you'd like
After discussion with @agobbifbk, it has been decided to standardize the output tensor format as (batch, timesteps, n_features, quantiles). This would make adding new models easier, since developers would not need to transform (squeeze/unsqueeze) tensors at every step.
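A minimal sketch of coercing the common output shapes into the proposed layout, assuming a point forecast is represented by a quantile dimension of size 1 and a univariate forecast by a feature dimension of size 1. The function name to_standard_output and its n_quantiles argument are hypothetical, not part of PTF. A 3D tensor is ambiguous on its own (features vs. quantiles), so in this sketch the caller disambiguates it via n_quantiles.

```python
import torch


def to_standard_output(y: torch.Tensor, n_quantiles: int = 1) -> torch.Tensor:
    """Hypothetical helper returning (batch, timesteps, n_features, n_quantiles).

    Assumed input conventions:
      2D (batch, timesteps):             univariate point forecast
      3D with n_quantiles == 1:          (batch, timesteps, n_features) point forecast
      3D with n_quantiles > 1:           (batch, timesteps, n_quantiles) univariate
      4D:                                already in the standard layout
    """
    if y.ndim == 2:
        if n_quantiles != 1:
            raise ValueError("a 2D output cannot carry more than one quantile")
        # Add a feature dim and a quantile dim of size 1 each.
        return y[:, :, None, None]
    if y.ndim == 3:
        if n_quantiles == 1:
            # Multivariate point forecast: quantile dim of size 1.
            return y.unsqueeze(-1)
        if y.shape[-1] != n_quantiles:
            raise ValueError(f"expected {n_quantiles} quantiles, got {y.shape[-1]}")
        # Univariate quantile forecast: feature dim of size 1.
        return y.unsqueeze(2)
    if y.ndim == 4:
        if y.shape[-1] != n_quantiles:
            raise ValueError(f"expected {n_quantiles} quantiles, got {y.shape[-1]}")
        return y
    raise ValueError(f"unsupported output rank {y.ndim}")


print(to_standard_output(torch.randn(4, 10)).shape)  # torch.Size([4, 10, 1, 1])
print(to_standard_output(torch.randn(4, 10, 3)).shape)  # torch.Size([4, 10, 3, 1])
```

Once every model goes through a step like this, downstream code can index features with dim 2 and quantiles with dim 3 without checking the rank first.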

I've gone through the parts of the codebase that might require changes to accommodate a standard output format. Here's a high-level workflow:

  • Add a standardize_output function in the BaseModel (v2) that always converts the model output into the standard 4D tensor.
  • At the loss function stage, we have two options.
    • For backward compatibility, create a separate Metric module for v2 that takes this standard output format and computes the loss only on the required dimensions. This could include some basic loss functions in v2 and would let us start from scratch.
    • Or modify MultiHorizonMetric in the existing v1 to handle this standard output. This change might be confined to the to_prediction and to_quantiles methods, though I'm not fully sure about this.
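For the first loss-function option, a v2 loss could consume the 4D output directly. The sketch below is a plain torch.nn.Module pinball (quantile) loss, not PTF's Metric base class; the class name QuantileLossV2, its to_prediction method, and the assumption that targets arrive as (batch, timesteps, n_features) are all hypothetical.

```python
from typing import Sequence

import torch
from torch import nn


class QuantileLossV2(nn.Module):
    """Hypothetical v2 pinball loss on (batch, timesteps, n_features, n_quantiles) outputs."""

    def __init__(self, quantiles: Sequence[float]):
        super().__init__()
        # Registered as a buffer so the quantile levels follow the module across devices.
        self.register_buffer("quantiles", torch.tensor(list(quantiles), dtype=torch.float32))

    def forward(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # y_pred: (B, T, F, Q); target: (B, T, F) (assumed target layout).
        if y_pred.ndim != 4 or y_pred.shape[-1] != self.quantiles.numel():
            raise ValueError("y_pred must be (batch, timesteps, n_features, n_quantiles)")
        errors = target.unsqueeze(-1) - y_pred  # broadcast the target over the quantile dim
        q = self.quantiles  # shape (Q,), broadcasts against the last dim
        # Pinball loss: q * e when under-predicting (e > 0), (q - 1) * e otherwise.
        loss = torch.maximum(q * errors, (q - 1) * errors)
        return loss.mean()

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        # Point forecast: the quantile level closest to the median, giving (B, T, F).
        idx = int(torch.argmin((self.quantiles - 0.5).abs()))
        return y_pred[..., idx]


loss_fn = QuantileLossV2([0.1, 0.5, 0.9])
y_pred = torch.randn(8, 12, 2, 3)  # 8 series, 12 steps, 2 targets, 3 quantiles
target = torch.randn(8, 12, 2)
loss = loss_fn(y_pred, target)
point = loss_fn.to_prediction(y_pred)  # shape (8, 12, 2)
```

Because the quantile dimension is always last, the loss needs a single unsqueeze on the target and no per-model branching; univariate or point-forecast models would simply pass F = 1 or Q = 1.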

Additional context
Refer to the discussion in PR #1874 and issue #1883.

Metadata


Assignees

No one assigned

    Labels

    enhancement (New feature or request)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
