Description
🚀 Feature Request
Add support for multimodal x_o
SBI expects the observed data x_o to be a Tensor, which makes it difficult to work with multimodal data. In many real-world scientific applications, the data contain several modalities, e.g. 2D images, 1D signals, and scalar context data, which should be provided together as x_o. Currently, one has to resort to workarounds such as x_o = torch.cat([data.flatten(1) for data in data_list], 1), which discards the per-modality shapes and is not user-friendly. An ideal solution would support both Tensor and dict[str, Tensor].
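To make the cost of the workaround concrete, here is a minimal sketch. The modality shapes and the batch of 4 are made up for illustration; only the torch.cat line comes from the text above.

```python
import torch

# Hypothetical multimodal observation with a batch of 4 samples.
data_list = [
    torch.randn(4, 1, 8, 8),  # 2D images
    torch.randn(4, 16),       # 1D signals
    torch.randn(4, 3),        # scalar context
]

# Current workaround: flatten each modality and concatenate along the feature axis.
x_o = torch.cat([data.flatten(1) for data in data_list], 1)

# The original shapes are gone, so they have to be tracked by hand
# in order to split x_o back into modalities (e.g. inside an embedding net).
sizes = [data[0].numel() for data in data_list]
shapes = [data.shape[1:] for data in data_list]
parts = [
    chunk.reshape(-1, *shape)
    for chunk, shape in zip(x_o.split(sizes, dim=1), shapes)
]
```

The bookkeeping in `sizes` and `shapes` is exactly what a dict[str, Tensor] input would make unnecessary.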
Describe the solution you'd like
Add a dedicated Data class to wrap x_o:
- Support both Tensor and dict[str, Tensor] in the constructor (or a list of these for iid).
- to_model_input() method that would return a Tensor or dict[str, Tensor] (so no changes in the current models are needed).
- Provide sbi-specific methods, including .batch_size and .is_iid() (instead of checking x_o.shape[0]).
- Add native concatenation for iid mode.
- Register as a pytree.
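A rough sketch of what such a wrapper could look like is given below. It is not existing sbi code: the class body, the rule that passing a list means iid trials, and the assumption that every tensor carries a leading batch dimension are all illustrative choices. Pytree registration is left out here, since torch.utils._pytree is a private module and the details would need to be agreed on separately.

```python
from typing import Dict, List, Union

import torch
from torch import Tensor

TensorOrDict = Union[Tensor, Dict[str, Tensor]]


class Data:
    """Hypothetical wrapper around x_o (illustrative only, not part of sbi)."""

    def __init__(self, x: Union[TensorOrDict, List[TensorOrDict]]):
        # Assumption: a list (or tuple) of observations means iid trials.
        self._iid = isinstance(x, (list, tuple))
        items = list(x) if self._iid else [x]
        if not items:
            raise ValueError("Expected at least one observation.")

        if all(isinstance(item, Tensor) for item in items):
            # Native concatenation of iid trials along the batch dimension.
            self._x: TensorOrDict = torch.cat(items, dim=0)
        elif all(isinstance(item, dict) for item in items):
            keys = items[0].keys()
            if any(item.keys() != keys for item in items):
                raise ValueError("All observations must share the same keys.")
            self._x = {k: torch.cat([item[k] for item in items], dim=0) for k in keys}
        else:
            raise TypeError("Expected Tensor, dict[str, Tensor], or a list of these.")

        # Every modality must agree on the leading (batch) dimension.
        tensors = [self._x] if isinstance(self._x, Tensor) else list(self._x.values())
        sizes = {t.shape[0] for t in tensors}
        if len(sizes) != 1:
            raise ValueError("All modalities must have the same batch size.")
        self._batch_size = sizes.pop()

    @property
    def batch_size(self) -> int:
        return self._batch_size

    def is_iid(self) -> bool:
        return self._iid

    def to_model_input(self) -> TensorOrDict:
        # Returned unchanged, so existing models keep receiving plain tensors or dicts.
        return self._x


# Usage: three iid trials, each with an image and a signal modality.
trial = {"image": torch.zeros(1, 1, 8, 8), "signal": torch.zeros(1, 16)}
x_o = Data([trial, trial, trial])
single = Data(torch.zeros(2, 5))
```

With this shape of API, callers would ask x_o.batch_size and x_o.is_iid() instead of inspecting x_o.shape[0], and models would call x_o.to_model_input() to get the underlying Tensor or dict.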
Describe alternatives you've considered
We could use an existing solution such as tensordict. It seems that tensordict could be integrated into PyTorch in the future, but for now it would be an additional dependency and is probably overkill.
📌 Additional Context
Already discussed with @janfb briefly, but we could talk more about the details. It seems to me that we could make a surgical addition with minimal changes and I would be happy to do that.