[Bug]: botorch_community VBLLModel posterior doesn't work with single value tensor #2899

@CarloMittelbach

What happened?

VBLLModel's posterior method seems to squeeze the mean and variance down to 0-dimensional (scalar) tensors when the input contains a single point.
A 0-dimensional tensor isn't compatible with torch.diag_embed(), which the posterior method calls next.
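
The shape problem can be reproduced with plain tensors, independent of the model. This is only a sketch: it assumes the predictive variance has shape `(N, 1)` for `N` input points and one output, and that the current code applies a bare `.squeeze()` to it. The tensor values are made up for illustration.

```python
import torch

# Assumed shape of the predictive variance: (N, 1) for N points, one output.
var_single = torch.tensor([[0.3]])        # N = 1
var_multi = torch.tensor([[0.3], [0.4]])  # N = 2

# A bare squeeze() drops every size-1 dimension, so N = 1 collapses to 0-dim.
print(var_single.squeeze().dim())  # 0
print(var_multi.squeeze().dim())   # 1

# diag_embed needs at least one dimension to place on the diagonal.
try:
    torch.diag_embed(var_single.squeeze())
except IndexError as e:
    print(f"diag_embed failed: {e}")

# squeeze(-1) removes only the trailing dimension, keeping a 1-D tensor.
cov_single = torch.diag_embed(var_single.squeeze(-1))
cov_multi = torch.diag_embed(var_multi.squeeze(-1))
print(tuple(cov_single.shape))  # (1, 1)
print(tuple(cov_multi.shape))   # (2, 2)
```

With `squeeze(-1)` the single-point case yields a `1 x 1` covariance matrix, matching the layout of the multi-point case.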

Please provide a minimal, reproducible example of the unexpected behavior.

import torch
import traceback
from botorch_community.models.vblls import VBLLModel


def main():
    torch.set_default_dtype(torch.float64)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model = VBLLModel(in_features=1, out_features=1).to(device)
    test_X_single = torch.tensor([[0.5]], device=device)
    try:
        posterior = model.posterior(test_X_single)
    except IndexError as e:
        traceback.print_exc()
        
    test_X_multi = torch.tensor([[0.25], [0.75]], device=device)
    try:
        posterior_multi = model.posterior(test_X_multi)
        print("Multi worked as expected!")
    except Exception as e:
        print(f"You won't get here: {e}")


if __name__ == "__main__":
    main()

Please paste any relevant traceback/logs produced by the example provided.

Output:

Traceback (most recent call last):
  File "/home/carlo/uni/SoSe25/automahfic/automahfic/bo/surrogates/vbbl/min_example_failure.py", line 13, in main
    posterior = model.posterior(test_X_single)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/carlo/uni/SoSe25/automahfic/venv/lib/python3.12/site-packages/botorch_community/models/vblls.py", line 492, in posterior
    cov = torch.diag_embed(variance)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)

Multi worked as expected!

BoTorch Version

0.14.0

Python Version

3.12.3

Operating System

Linux Ubuntu 24.04

(Optional) Describe any potential fixes you've considered to the issue outlined above.

In vblls.py, VBLLModel.posterior:

    def posterior(
        self,
        X: Tensor,
        output_indices=None,
        observation_noise=None,
        posterior_transform=None,
    ) -> Posterior:
        if X.dim() > 3:
            raise ValueError(f"Input must have at most 3 dimensions, got {X.dim()}.")

        # Determine if the input is batched
        batched = X.dim() == 3

        if not batched:
            N, D = X.shape
            B = 1
        else:
            B, N, D = X.shape
            X = X.reshape(B * N, D)

        posterior = self.model(X).predictive
        # Extract mean and variance
        mean = posterior.mean.squeeze()
        # Squeeze only the last dim so the variance isn't converted to a
        # scalar for single value input.
        variance = posterior.variance.squeeze(-1)
        cov = torch.diag_embed(variance)

        K = self.num_outputs
        mean = mean.reshape(B, N * K)

        # Cov must be `(B, N*K, N*K)`
        cov = cov.reshape(B, N, K, B, N, K)
        cov = torch.einsum("bnkbrl->bnkrl", cov)  # (B, N, K, N, K)
        cov = cov.reshape(B, N * K, N * K)

        # Remove fake batch dimension if not batched
        if not batched:
            mean = mean.squeeze(0)
            cov = cov.squeeze(0)

        # pass as MultivariateNormal to GPyTorchPosterior
        mvn_dist = MultivariateNormal(mean, cov)
        post_pred = GPyTorchPosterior(mvn_dist)
        return BLLPosterior(
            posterior=post_pred, model=self, X=X, output_dim=self.num_outputs
        )

Making sure that the variance stays a tensor with at least one dimension seems to fix the issue.

Pull Request

Yes

Code of Conduct

  • I agree to follow BoTorch's Code of Conduct

Labels: bug (Something isn't working)
