
Add support for multimodal data #1672

@StarostinV

Description


🚀 Feature Request

Add support for multimodal x_o

sbi expects the observed data x_o to be a Tensor, which makes it difficult to work with multimodal data. In many real-world scientific applications, the data contain several modalities (e.g., 2D images, 1D signals, and scalar context data) that should be provided together as x_o. Currently, one has to resort to workarounds such as x_o = torch.cat([data.flatten(1) for data in data_list], 1), which is not user-friendly. An ideal solution would support both Tensor and dict[str, Tensor].
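For concreteness, here is roughly what the flatten-and-concatenate workaround looks like in practice. The shapes and modality names (image, signal, context) are made up for illustration. The point is that the flattened x_o loses each modality's shape, so downstream code has to split and reshape it using hard-coded sizes.

```python
import torch

# Hypothetical multimodal observation with a batch size of 1.
image = torch.randn(1, 32, 32)  # 2D image
signal = torch.randn(1, 100)    # 1D signal
context = torch.randn(1, 3)     # scalar context values
data_list = [image, signal, context]

# Current workaround: flatten every modality and concatenate along dim 1.
x_o = torch.cat([data.flatten(1) for data in data_list], 1)
print(x_o.shape)  # torch.Size([1, 1127])

# Downstream code must undo this with hard-coded per-modality sizes and shapes.
sizes = [data[0].numel() for data in data_list]  # [1024, 100, 3]
parts = torch.split(x_o, sizes, dim=1)
recovered = [part.reshape(data.shape) for part, data in zip(parts, data_list)]
```

With dict[str, Tensor] support, the same observation could be passed as {"image": image, "signal": signal, "context": context}, and no offset bookkeeping would be needed.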

Describe the solution you'd like

Add a dedicated Data class to wrap x_o:

  • Support both Tensor and dict[str, Tensor] in the constructor (or a list of these for iid data).
  • A to_model_input() method that returns a Tensor or dict[str, Tensor], so that no changes to the current models are needed.
  • Provide sbi-specific methods, such as .batch_size and .is_iid() (instead of checking x_o.shape[0]).
  • Add native concatenation for iid mode.
  • Register the class as a pytree.
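A rough sketch of how such a wrapper could look is below. It is a minimal illustration, not sbi's API, and it assumes that every tensor carries a leading batch dimension holding iid trials. The helper _cat and the method name cat are hypothetical. Pytree registration is left out because the relevant helpers live in the private torch.utils._pytree module and differ between PyTorch versions.

```python
from typing import Dict, List, Sequence, Union

import torch
from torch import Tensor

# Either a single tensor or a mapping from modality name to tensor.
TensorOrDict = Union[Tensor, Dict[str, Tensor]]


def _cat(items: Sequence[TensorOrDict]) -> TensorOrDict:
    """Hypothetical helper: concatenate observations along the batch dimension."""
    if not items:
        raise ValueError("Need at least one observation to concatenate.")
    if all(isinstance(item, Tensor) for item in items):
        return torch.cat(list(items), dim=0)
    if all(isinstance(item, dict) for item in items):
        keys = items[0].keys()
        if any(item.keys() != keys for item in items):
            raise ValueError("All dict observations must have the same keys.")
        # Concatenate each modality separately so its shape is preserved.
        return {k: torch.cat([item[k] for item in items], dim=0) for k in keys}
    raise TypeError("Cannot mix Tensor and dict observations.")


class Data:
    """Hypothetical wrapper for x_o holding a Tensor or a dict[str, Tensor].

    Assumes the leading dimension of every tensor is the batch of iid trials.
    """

    def __init__(self, x: Union[TensorOrDict, List[TensorOrDict]]):
        if isinstance(x, list):
            # A list of observations is interpreted as iid trials.
            x = _cat(x)
        if isinstance(x, dict):
            batch_sizes = {v.shape[0] for v in x.values()}
            if len(batch_sizes) != 1:
                raise ValueError("Need at least one modality, all with the same batch size.")
        elif not isinstance(x, Tensor):
            raise TypeError("x must be a Tensor, a dict[str, Tensor], or a list of these.")
        self._x = x

    def to_model_input(self) -> TensorOrDict:
        # Return the raw Tensor or dict so that existing networks need no changes.
        return self._x

    @property
    def batch_size(self) -> int:
        if isinstance(self._x, Tensor):
            return self._x.shape[0]
        return next(iter(self._x.values())).shape[0]

    def is_iid(self) -> bool:
        # More than one trial in the batch means iid observations.
        return self.batch_size > 1

    def cat(self, other: "Data") -> "Data":
        # Native concatenation of iid trials, for both Tensor and dict data.
        return Data(_cat([self._x, other._x]))


# Example: two iid trials of a two-modality observation.
trial_1 = {"image": torch.randn(1, 32, 32), "signal": torch.randn(1, 100)}
trial_2 = {"image": torch.randn(1, 32, 32), "signal": torch.randn(1, 100)}
data = Data([trial_1, trial_2])
print(data.batch_size, data.is_iid())  # 2 True
```

Keeping the dict as-is inside the wrapper means per-modality embedding networks can consume it directly, while batch-level logic (batch size, iid checks, concatenation) lives in one place.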

Describe alternatives you've considered

We could use an existing solution such as tensordict. It may eventually be integrated into PyTorch, but for now it would be an additional dependency and probably overkill.

📌 Additional Context

I already discussed this briefly with @janfb, but we could talk through the details further. I think this could be a surgical addition with minimal changes, and I would be happy to implement it.

Metadata

Assignees: none

Labels: enhancement (New feature or request), feature (adding new features to the toolbox)

Projects: No projects

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests
