Skip to content

Commit 3ced166

Browse files
committed
Extract scalar only when necessary
1 parent 3d47b62 commit 3ced166

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

botorch/optim/optimize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,9 @@ def _optimize_acqf_all_features_fixed(
196196
X = X.expand(q, *X.shape)
197197
with torch.no_grad():
198198
acq_value = acq_function(X)
199-
return X, acq_value[0]
199+
if acq_value.ndim == 1:
200+
acq_value = acq_value[0]
201+
return X, acq_value
200202

201203

202204
def _validate_sequential_inputs(opt_inputs: OptimizeAcqfInputs) -> None:

0 commit comments

Comments
 (0)