
Commit 0eea0b7

SamuelGabriel authored and facebook-github-bot committed
Batching Mixed Optimization (#2895)
Summary:
Pull Request resolved: #2895

So far, our optimization in mixed search spaces worked on each restart separately and sequentially instead of batching them. Here, we change this to batch the restarts, based on the new L-BFGS-B implementation that supports this. This speeds up mixed search spaces a lot (around 3-4x, depending on the problem).

Reviewed By: esantorella

Differential Revision: D76517454

fbshipit-source-id: c2b3ec3bb2bbe8a18212e8247b67391ddba2941b
1 parent 78de86d commit 0eea0b7
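For context, a minimal usage sketch of the entry point whose restarts this commit batches. The model and acquisition setup below are illustrative and not part of the commit, and the exact keyword set of `optimize_acqf_mixed_alternating` may differ across BoTorch versions.

```python
import torch
from botorch.acquisition import qLogExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy 5-dim problem: dims 0-1 binary/integer, dim 2 categorical, dims 3-4 continuous.
train_X = torch.rand(16, 5, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

acqf = qLogExpectedImprovement(model=model, best_f=train_Y.max())
bounds = torch.tensor([[0.0] * 5, [1.0] * 5], dtype=torch.float64)

candidate, acq_value = optimize_acqf_mixed_alternating(
    acq_function=acqf,
    bounds=bounds,
    discrete_dims=[0, 1],
    cat_dims=[2],
    q=1,
    num_restarts=20,  # with this commit, all restarts are optimized as one batch
)
```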

File tree

2 files changed: +373 -114

botorch/optim/optimize_mixed.py

Lines changed: 157 additions & 72 deletions
@@ -50,6 +50,11 @@
 MAX_ITER_INIT = 100
 CONVERGENCE_TOL = 1e-8  # Optimizer convergence tolerance.
 DUPLICATE_TOL = 1e-6  # Tolerance for deduplicating initial candidates.
+STOP_AFTER_SHARE_CONVERGED = 1.0  # We optimize multiple configurations at once
+# in `optimize_acqf_mixed_alternating`. This option controls whether to stop
+# optimizing after the given share of configurations has converged.
+# Convergence is defined as one discrete step followed by a continuous
+# optimization step yielding less than `CONVERGENCE_TOL` improvement.
 
 SUPPORTED_OPTIONS = {
     "initialization_strategy",
@@ -564,63 +569,113 @@ def discrete_step(
         discrete_dims: A tensor of indices corresponding to binary and
             integer parameters.
         cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Batch of starting points. A tensor of shape `b x d`.
 
     Returns:
-        A tuple of two tensors: a (d)-dim tensor of the optimized point
+        A tuple of two tensors: a (b, d)-dim tensor of optimized points
         and a scalar tensor of the corresponding acquisition value.
     """
     with torch.no_grad():
-        current_acqval = opt_inputs.acq_function(current_x.unsqueeze(0))
+        current_acqvals = opt_inputs.acq_function(current_x.unsqueeze(1))
     options = opt_inputs.options or {}
-    for _ in range(
-        assert_is_instance(options.get("maxiter_discrete", MAX_ITER_DISCRETE), int)
-    ):
-        neighbors = []
-        if discrete_dims.numel():
-            x_neighbors_discrete = get_nearest_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                discrete_dims=discrete_dims,
-            )
-            x_neighbors_discrete = _filter_infeasible(
-                X=x_neighbors_discrete,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_discrete)
+    maxiter_discrete = options.get("maxiter_discrete", MAX_ITER_DISCRETE)
+    done = torch.zeros(len(current_x), dtype=torch.bool)
+    for _ in range(assert_is_instance(maxiter_discrete, int)):
+        # We don't batch this, as the number of x_neighbors can differ
+        # for each entry (duplicates are removed), and the most expensive
+        # op is the acq_function, which is batched.
+        # TODO: Finding the set of neighbors is currently done sequentially,
+        # one batch item after the other.
+        x_neighbors_list = [None for _ in done]
+        for i in range(len(done)):
+            # Don't do anything if we are already done.
+            if done[i]:
+                continue
+
+            neighbors = []
+
+            # If we have discrete_dims, look for neighbors by stepping +1/-1.
+            if discrete_dims.numel():
+                x_neighbors_discrete = get_nearest_neighbors(
+                    current_x=current_x[i].detach(),
+                    bounds=opt_inputs.bounds,
+                    discrete_dims=discrete_dims,
+                )
+                x_neighbors_discrete = _filter_infeasible(
+                    X=x_neighbors_discrete,
+                    inequality_constraints=opt_inputs.inequality_constraints,
+                )
+                neighbors.append(x_neighbors_discrete)
+
+            # If we have cat_dims, look for neighbors by changing the categoricals.
+            if cat_dims.numel():
+                x_neighbors_cat = get_categorical_neighbors(
+                    current_x=current_x[i].detach(),
+                    bounds=opt_inputs.bounds,
+                    cat_dims=cat_dims,
+                )
+                x_neighbors_cat = _filter_infeasible(
+                    X=x_neighbors_cat,
+                    inequality_constraints=opt_inputs.inequality_constraints,
+                )
+                neighbors.append(x_neighbors_cat)
 
-        if cat_dims.numel():
-            x_neighbors_cat = get_categorical_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                cat_dims=cat_dims,
-            )
-            x_neighbors_cat = _filter_infeasible(
-                X=x_neighbors_cat,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_cat)
+            x_neighbors = torch.cat(neighbors, dim=0)
+            if x_neighbors.numel() == 0:
+                # If the i-th point has no neighbors, we mark it as done.
+                done[i] = True
+            x_neighbors_list[i] = x_neighbors
 
-        x_neighbors = torch.cat(neighbors, dim=0)
-        if x_neighbors.numel() == 0:
-            # Exit gracefully with last point if there are no feasible neighbors.
+        # Exit if all batch entries converged or have no feasible neighbors left.
+        if done.all():
             break
+
+        all_x_neighbors = torch.cat(
+            [
+                x_neighbors
+                for x_neighbors in x_neighbors_list
+                if x_neighbors is not None
+            ],
+            dim=0,
+        )  # shape: (sum of neighbor counts over the batch, d)
         with torch.no_grad():
+            # This is the most expensive call in this function.
+            # The reason that `discrete_step` uses a batched x
+            # rather than looping over each batch within x is so that
+            # we can batch this call. This leads to an overall speedup
+            # even though the for loops above and below cannot
+            # be sped up by batching.
             acq_vals = torch.cat(
                 [
                     opt_inputs.acq_function(X_.unsqueeze(-2))
-                    for X_ in x_neighbors.split(
+                    for X_ in all_x_neighbors.split(
                         options.get("init_batch_limit", MAX_BATCH_SIZE)
                     )
                 ]
             )
-        argmax = acq_vals.argmax()
-        improvement = acq_vals[argmax] - current_acqval
-        if improvement > 0:
-            current_acqval, current_x = acq_vals[argmax], x_neighbors[argmax]
-        if improvement <= options.get("tol", CONVERGENCE_TOL):
-            break
-    return current_x, current_acqval
+        offset = 0
+        for i in range(len(done)):
+            if done[i]:
+                continue
+
+            # We index into all_x_neighbors this way because it is a flattened
+            # version of x_neighbors_list with the None entries removed.
+            # That is also why offset is not increased when done[i] is True.
+            width = len(x_neighbors_list[i])
+            x_neighbors = all_x_neighbors[offset : offset + width]
+            max_acq, argmax = acq_vals[offset : offset + width].max(dim=0)
+            improvement = max_acq - current_acqvals[i]
+            if improvement > 0:
+                current_acqvals[i], current_x[i] = (
+                    max_acq,
+                    x_neighbors[argmax],
+                )
+            if improvement <= options.get("tol", CONVERGENCE_TOL):
+                done[i] = True
+
+            offset += width
+
+    return current_x, current_acqvals
 
 
 def continuous_step(
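The offset bookkeeping in the selection loop above can be hard to follow, so here is a standalone sketch of the same ragged-batch indexing pattern (illustrative shapes; not code from the commit):

```python
import torch

# Three batch entries with different neighbor counts; entry 1 was already done
# and contributes no rows to the flattened tensor.
x_neighbors_list = [torch.randn(4, 2), None, torch.randn(2, 2)]
all_x_neighbors = torch.cat(
    [x for x in x_neighbors_list if x is not None], dim=0
)  # shape: (6, 2)
acq_vals = torch.randn(len(all_x_neighbors))

offset = 0
for i, x_neighbors in enumerate(x_neighbors_list):
    if x_neighbors is None:
        continue  # skipped entries own no rows, so offset is not advanced
    width = len(x_neighbors)
    # Rows offset:offset+width of the flattened tensor belong to entry i.
    max_acq, argmax = acq_vals[offset : offset + width].max(dim=0)
    print(f"entry {i}: best neighbor {argmax.item()}, value {max_acq.item():.3f}")
    offset += width
```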
@@ -635,39 +690,49 @@ def continuous_step(
         opt_inputs: Common set of arguments for acquisition optimization.
             This function utilizes `acq_function`, `bounds`, `options`,
             `fixed_features` and constraints from `opt_inputs`.
+            `opt_inputs.return_best_only` should be `False`.
         discrete_dims: A tensor of indices corresponding to binary and
             integer parameters.
         cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Starting point. A tensor of shape `b x d`.
 
     Returns:
-        A tuple of two tensors: a (1 x d)-dim tensor of optimized points
-        and a (1)-dim tensor of acquisition values.
+        A tuple of two tensors: a (b x d)-dim tensor of optimized points
+        and a (b)-dim tensor of acquisition values.
     """
+
+    if opt_inputs.return_best_only:
+        raise UnsupportedError(
+            "`continuous_step` does not support `return_best_only=True`."
+        )
+
+    d = current_x.shape[-1]
     options = opt_inputs.options or {}
     non_cont_dims = torch.cat((discrete_dims, cat_dims), dim=0)
 
-    if len(non_cont_dims) == len(current_x):  # nothing continuous to optimize
+    if len(non_cont_dims) == d:  # nothing continuous to optimize
         with torch.no_grad():
-            return current_x, opt_inputs.acq_function(current_x.unsqueeze(0))
+            return current_x, opt_inputs.acq_function(current_x.unsqueeze(1))
 
     updated_opt_inputs = dataclasses.replace(
         opt_inputs,
         q=1,
-        num_restarts=1,
         raw_samples=None,
-        batch_initial_conditions=current_x.unsqueeze(0),
+        # Unsqueeze to add the q dimension.
+        batch_initial_conditions=current_x.unsqueeze(1),
         fixed_features={
-            **dict(zip(non_cont_dims.tolist(), current_x[non_cont_dims])),
+            **{d: current_x[:, d] for d in non_cont_dims.tolist()},
             **(opt_inputs.fixed_features or {}),
         },
         options={
             "maxiter": options.get("maxiter_continuous", MAX_ITER_CONT),
             "tol": options.get("tol", CONVERGENCE_TOL),
             "batch_limit": options.get("batch_limit", MAX_BATCH_SIZE),
+            "max_optimization_problem_aggregation_size": 1,
         },
     )
-    return _optimize_acqf(opt_inputs=updated_opt_inputs)
+    best_X, best_acq_values = _optimize_acqf(opt_inputs=updated_opt_inputs)
+    return best_X.view_as(current_x), best_acq_values
 
 
 def optimize_acqf_mixed_alternating(
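One subtle change above: `fixed_features` now maps each non-continuous dimension to a column of `current_x` rather than to a single scalar, so every restart in the batch keeps its own discrete assignment fixed during the continuous step. A small sketch of the difference (illustrative tensors; not code from the commit):

```python
import torch

current_x = torch.tensor([[0.0, 2.0, 0.7],
                          [1.0, 3.0, 0.1]])  # b=2 restarts, d=3
non_cont_dims = torch.tensor([0, 1])

# Before (single restart): one scalar per fixed dimension.
single = dict(zip(non_cont_dims.tolist(), current_x[0][non_cont_dims]))
print(single)   # {0: tensor(0.), 1: tensor(2.)}

# After (batched): one b-dim column per fixed dimension, so each restart
# is pinned to its own discrete values.
batched = {d: current_x[:, d] for d in non_cont_dims.tolist()}
print(batched)  # {0: tensor([0., 1.]), 1: tensor([2., 3.])}
```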
@@ -761,6 +826,12 @@ def optimize_acqf_mixed_alternating(
 
     fixed_features = fixed_features or {}
     options = options or {}
+    if options.get("max_optimization_problem_aggregation_size", 1) != 1:
+        raise UnsupportedError(
+            "optimize_acqf_mixed_alternating does not support "
+            "max_optimization_problem_aggregation_size != 1. "
+            "You might leave this option empty, though."
+        )
     options.setdefault("batch_limit", MAX_BATCH_SIZE)
     options.setdefault("init_batch_limit", options["batch_limit"])
     if not (keys := set(options.keys())).issubset(SUPPORTED_OPTIONS):
@@ -793,11 +864,18 @@ def optimize_acqf_mixed_alternating(
         fixed_features=fixed_features,
         post_processing_func=post_processing_func,
         batch_initial_conditions=None,
-        return_best_only=True,
+        # We don't want the continuous-step optimization to return only the
+        # best point; this function itself returns only the best at the end.
+        return_best_only=False,
         gen_candidates=gen_candidates_scipy,
-        sequential=sequential,
+        sequential=sequential,  # Only relevant if all dims are continuous.
     )
-    _validate_sequential_inputs(opt_inputs=opt_inputs)
+    if sequential:
+        # Sequential optimization requires return_best_only to be True,
+        # but we turn it off here, as we "manually" perform the sequential
+        # conditioning in the loop below.
+        _validate_sequential_inputs(
+            opt_inputs=dataclasses.replace(opt_inputs, return_best_only=True)
+        )
 
     base_X_pending = acq_function.X_pending if q > 1 else None
     dim = bounds.shape[-1]
@@ -808,7 +886,12 @@
     non_cont_dims = [*discrete_dims, *cat_dims]
     if len(non_cont_dims) == 0:
         # If the problem is fully continuous, fall back to standard optimization.
-        return _optimize_acqf(opt_inputs=opt_inputs)
+        return _optimize_acqf(
+            opt_inputs=dataclasses.replace(
+                opt_inputs,
+                return_best_only=True,
+            )
+        )
     if not (
         isinstance(non_cont_dims, list)
         and len(set(non_cont_dims)) == len(non_cont_dims)
@@ -842,26 +925,28 @@
         cont_dims=cont_dims,
     )
 
-    # TODO: Eliminate this for loop. Tensors being unequal sizes could potentially
-    # be handled by concatenating them rather than stacking, and keeping a list
-    # of indices.
-    for i in range(num_restarts):
-        alternate_steps = 0
-        while alternate_steps < options.get("maxiter_alternating", MAX_ITER_ALTER):
-            starting_acq_val = best_acq_val[i].clone()
-            alternate_steps += 1
-            for step in (discrete_step, continuous_step):
-                best_X[i], best_acq_val[i] = step(
-                    opt_inputs=opt_inputs,
-                    discrete_dims=discrete_dims_t,
-                    cat_dims=cat_dims_t,
-                    current_x=best_X[i],
-                )
+    done = torch.zeros(len(best_X), dtype=torch.bool, device=tkwargs["device"])
+    for _step in range(options.get("maxiter_alternating", MAX_ITER_ALTER)):
+        starting_acq_val = best_acq_val.clone()
+        best_X[~done], best_acq_val[~done] = discrete_step(
+            opt_inputs=opt_inputs,
+            discrete_dims=discrete_dims_t,
+            cat_dims=cat_dims_t,
+            current_x=best_X[~done],
+        )
+
+        best_X[~done], best_acq_val[~done] = continuous_step(
+            opt_inputs=opt_inputs,
+            discrete_dims=discrete_dims_t,
+            cat_dims=cat_dims_t,
+            current_x=best_X[~done],
+        )
 
-            improvement = best_acq_val[i] - starting_acq_val
-            if improvement < options.get("tol", CONVERGENCE_TOL):
-                # Check for convergence
-                break
+        improvement = best_acq_val - starting_acq_val
+        done_now = improvement < options.get("tol", CONVERGENCE_TOL)
+        done = done | done_now
+        if done.float().mean() >= STOP_AFTER_SHARE_CONVERGED:
+            break
 
     new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0)
     candidates = torch.cat([candidates, new_candidate], dim=-2)
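The `best_X[~done]` pattern above updates only the restarts that have not yet converged. A short standalone sketch of this masked in-place batch update (illustrative step function; not code from the commit):

```python
import torch

best_X = torch.zeros(4, 3)  # 4 restarts over a 3-dim space
best_acq_val = torch.zeros(4)
done = torch.tensor([False, True, False, True])

def toy_step(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for discrete_step/continuous_step: "improves" every point it sees.
    return x + 1.0, torch.full((len(x),), 0.5)

# Boolean-mask indexing selects only the active restarts; assigning back
# through the same mask writes the results into the matching rows.
best_X[~done], best_acq_val[~done] = toy_step(best_X[~done])
print(best_X.sum(dim=-1))  # tensor([3., 0., 3., 0.]): rows 1 and 3 untouched
print(best_acq_val)        # tensor([0.5000, 0.0000, 0.5000, 0.0000])
```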
