Skip to content

Commit 7da91b1

Browse files
committed
Extract scalar only when necessary
1 parent f23e90b commit 7da91b1

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

botorch/optim/optimize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ def _optimize_acqf_all_features_fixed(
192192
X = X.expand(q, *X.shape)
193193
with torch.no_grad():
194194
acq_value = acq_function(X)
195-
return X, acq_value[0]
195+
if acq_value.ndim == 1:
196+
acq_value = acq_value[0]
197+
return X, acq_value
196198

197199

198200
def _validate_sequential_inputs(opt_inputs: OptimizeAcqfInputs) -> None:

0 commit comments

Comments
 (0)