Skip to content

Commit 18c80ff

Browse files
committed
Add test verifying output dimension in degenerate case
1 parent 9a7c517 commit 18c80ff

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

test/optim/test_optimize.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414

1515
import numpy as np
1616
import torch
17+
1718
from botorch.acquisition.acquisition import (
1819
AcquisitionFunction,
1920
OneShotAcquisitionFunction,
2021
)
22+
from botorch.acquisition.analytic import LogExpectedImprovement
2123
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
2224
from botorch.acquisition.monte_carlo import qExpectedImprovement
2325
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
@@ -1147,6 +1149,23 @@ def nlc(x):
11471149
),
11481150
)
11491151

1152+
def test_optimize_acqf_all_fixed_features(self):
1153+
train_X = torch.rand(3, 2)
1154+
train_Y = torch.rand(3, 1)
1155+
gp = SingleTaskGP(train_X=train_X, train_Y=train_Y)
1156+
gp.eval()
1157+
logEI = LogExpectedImprovement(model=gp, best_f=train_Y.max())
1158+
bounds = torch.stack([torch.zeros(2), torch.ones(2)])
1159+
_, acqf_value = optimize_acqf(
1160+
logEI,
1161+
bounds,
1162+
q=1,
1163+
num_restarts=1,
1164+
raw_samples=1,
1165+
fixed_features={0: 0, 1: 0},
1166+
)
1167+
self.assertEqual(acqf_value.ndim, 0)
1168+
11501169
def test_constraint_caching(self):
11511170
def nlc(x):
11521171
return 4 - x.sum(dim=-1)

0 commit comments

Comments
 (0)