MAX_ITER_INIT = 100
CONVERGENCE_TOL = 1e-8  # Optimizer convergence tolerance.
DUPLICATE_TOL = 1e-6  # Tolerance for deduplicating initial candidates.
+ STOP_AFTER_SHARE_CONVERGED = 1.0  # We optimize multiple configurations at once
+ # in `optimize_acqf_mixed_alternating`. This option controls whether to stop
+ # optimizing once the given share of configurations has converged.
+ # Convergence is defined as one discrete step followed by one continuous
+ # optimization step together yielding less than `CONVERGENCE_TOL` improvement.


SUPPORTED_OPTIONS = {
    "initialization_strategy",
@@ -564,63 +569,113 @@ def discrete_step(
        discrete_dims: A tensor of indices corresponding to binary and
            integer parameters.
        cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Batch of starting points. A tensor of shape `b x d`.

    Returns:
-        A tuple of two tensors: a (d)-dim tensor of the optimized point
-        and a scalar tensor of the corresponding acquisition value.
+        A tuple of two tensors: a (b, d)-dim tensor of optimized points
+        and a (b)-dim tensor of corresponding acquisition values.
    """
    with torch.no_grad():
-        current_acqval = opt_inputs.acq_function(current_x.unsqueeze(0))
+        current_acqvals = opt_inputs.acq_function(current_x.unsqueeze(1))
    options = opt_inputs.options or {}
-    for _ in range(
-        assert_is_instance(options.get("maxiter_discrete", MAX_ITER_DISCRETE), int)
-    ):
-        neighbors = []
-        if discrete_dims.numel():
-            x_neighbors_discrete = get_nearest_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                discrete_dims=discrete_dims,
-            )
-            x_neighbors_discrete = _filter_infeasible(
-                X=x_neighbors_discrete,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_discrete)
+    maxiter_discrete = options.get("maxiter_discrete", MAX_ITER_DISCRETE)
+    done = torch.zeros(len(current_x), dtype=torch.bool)
+    for _ in range(assert_is_instance(maxiter_discrete, int)):
+        # We don't batch this, as the number of x_neighbors can differ for
+        # each entry (duplicates are removed), and the most expensive op is
+        # the acq_function, which is batched.
+        # TODO: Finding the set of neighbors is currently done sequentially,
+        # one item in the batch after the other.
+        x_neighbors_list = [None for _ in done]
+        for i in range(len(done)):
+            # Don't do anything if we are already done.
+            if done[i]:
+                continue
+
+            neighbors = []
+
+            # If we have discrete_dims, look for neighbors by stepping +1/-1.
+            if discrete_dims.numel():
+                x_neighbors_discrete = get_nearest_neighbors(
+                    current_x=current_x[i].detach(),
+                    bounds=opt_inputs.bounds,
+                    discrete_dims=discrete_dims,
+                )
+                x_neighbors_discrete = _filter_infeasible(
+                    X=x_neighbors_discrete,
+                    inequality_constraints=opt_inputs.inequality_constraints,
+                )
+                neighbors.append(x_neighbors_discrete)
+
+            # If we have cat_dims, look for neighbors by changing the categories.
+            if cat_dims.numel():
+                x_neighbors_cat = get_categorical_neighbors(
+                    current_x=current_x[i].detach(),
+                    bounds=opt_inputs.bounds,
+                    cat_dims=cat_dims,
+                )
+                x_neighbors_cat = _filter_infeasible(
+                    X=x_neighbors_cat,
+                    inequality_constraints=opt_inputs.inequality_constraints,
+                )
+                neighbors.append(x_neighbors_cat)

-        if cat_dims.numel():
-            x_neighbors_cat = get_categorical_neighbors(
-                current_x=current_x.detach(),
-                bounds=opt_inputs.bounds,
-                cat_dims=cat_dims,
-            )
-            x_neighbors_cat = _filter_infeasible(
-                X=x_neighbors_cat,
-                inequality_constraints=opt_inputs.inequality_constraints,
-            )
-            neighbors.append(x_neighbors_cat)
+            x_neighbors = torch.cat(neighbors, dim=0)
+            if x_neighbors.numel() == 0:
+                # If the i'th point has no neighbors, we mark it as done.
+                done[i] = True
+            x_neighbors_list[i] = x_neighbors

-        x_neighbors = torch.cat(neighbors, dim=0)
-        if x_neighbors.numel() == 0:
-            # Exit gracefully with last point if there are no feasible neighbors.
+        # Exit if all batch entries converged or have no feasible neighbors left.
+        if done.all():
            break
+
+        all_x_neighbors = torch.cat(
+            [
+                x_neighbors
+                for x_neighbors in x_neighbors_list
+                if x_neighbors is not None
+            ],
+            dim=0,
+        )  # shape: (sum(#neighbors of the items in the batch), d)
        with torch.no_grad():
+            # This is the most expensive call in this function.
+            # The reason that `discrete_step` uses a batched x rather than
+            # looping over each batch entry within x is so that we can batch
+            # this call. This leads to an overall speedup even though the
+            # surrounding for loops cannot be sped up by batching.
            acq_vals = torch.cat(
                [
                    opt_inputs.acq_function(X_.unsqueeze(-2))
-                    for X_ in x_neighbors.split(
+                    for X_ in all_x_neighbors.split(
                        options.get("init_batch_limit", MAX_BATCH_SIZE)
                    )
                ]
            )
-        argmax = acq_vals.argmax()
-        improvement = acq_vals[argmax] - current_acqval
-        if improvement > 0:
-            current_acqval, current_x = acq_vals[argmax], x_neighbors[argmax]
-        if improvement <= options.get("tol", CONVERGENCE_TOL):
-            break
-    return current_x, current_acqval
+        offset = 0
+        for i in range(len(done)):
+            if done[i]:
+                continue
+
+            # We index into all_x_neighbors in this roundabout way because it
+            # is a flattened version of x_neighbors_list with the None entries
+            # removed. That is also why we do not advance offset when done[i].
+            width = len(x_neighbors_list[i])
+            x_neighbors = all_x_neighbors[offset : offset + width]
+            max_acq, argmax = acq_vals[offset : offset + width].max(dim=0)
+            improvement = max_acq - current_acqvals[i]
+            if improvement > 0:
+                current_acqvals[i], current_x[i] = (
+                    max_acq,
+                    x_neighbors[argmax],
+                )
+            if improvement <= options.get("tol", CONVERGENCE_TOL):
+                done[i] = True
+
+            offset += width
+
+    return current_x, current_acqvals
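
The offset bookkeeping at the end of `discrete_step` is easiest to see on a toy example. In the standalone sketch below (a row sum stands in for the real `acq_function`, and the tensor sizes are invented), variable-length neighbor sets are concatenated into one flat tensor, scored in a single batched call, and sliced back out per item:

import torch

x_neighbors_list = [torch.randn(3, 2), None, torch.randn(5, 2)]  # None: item done
all_x_neighbors = torch.cat(
    [x for x in x_neighbors_list if x is not None], dim=0
)  # shape: (8, 2)
acq_vals = all_x_neighbors.sum(dim=-1)  # stand-in for the batched acq_function

offset = 0
for i, x_neighbors in enumerate(x_neighbors_list):
    if x_neighbors is None:
        continue  # note: offset is not advanced for skipped items
    width = len(x_neighbors)
    # The flat slice recovers exactly this item's neighbor set.
    assert torch.equal(all_x_neighbors[offset : offset + width], x_neighbors)
    best_val, best_idx = acq_vals[offset : offset + width].max(dim=0)
    print(f"item {i}: best of {width} neighbors has value {best_val.item():.3f}")
    offset += width
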


def continuous_step(
@@ -635,39 +690,49 @@ def continuous_step(
        opt_inputs: Common set of arguments for acquisition optimization.
            This function utilizes `acq_function`, `bounds`, `options`,
            `fixed_features` and constraints from `opt_inputs`.
+            `opt_inputs.return_best_only` should be `False`.
        discrete_dims: A tensor of indices corresponding to binary and
            integer parameters.
        cat_dims: A tensor of indices corresponding to categorical parameters.
-        current_x: Starting point. A tensor of shape `d`.
+        current_x: Starting point. A tensor of shape `b x d`.

    Returns:
-        A tuple of two tensors: a (1 x d)-dim tensor of optimized points
-        and a (1)-dim tensor of acquisition values.
+        A tuple of two tensors: a (b x d)-dim tensor of optimized points
+        and a (b)-dim tensor of acquisition values.
    """
+
+    if opt_inputs.return_best_only:
+        raise UnsupportedError(
+            "`continuous_step` does not support `return_best_only=True`."
+        )
+
+    d = current_x.shape[-1]
    options = opt_inputs.options or {}
    non_cont_dims = torch.cat((discrete_dims, cat_dims), dim=0)

-    if len(non_cont_dims) == len(current_x):  # nothing continuous to optimize
+    if len(non_cont_dims) == d:  # nothing continuous to optimize
        with torch.no_grad():
-            return current_x, opt_inputs.acq_function(current_x.unsqueeze(0))
+            return current_x, opt_inputs.acq_function(current_x.unsqueeze(1))

    updated_opt_inputs = dataclasses.replace(
        opt_inputs,
        q=1,
-        num_restarts=1,
        raw_samples=None,
-        batch_initial_conditions=current_x.unsqueeze(0),
+        # Unsqueeze to add the q dimension.
+        batch_initial_conditions=current_x.unsqueeze(1),
        fixed_features={
-            **dict(zip(non_cont_dims.tolist(), current_x[non_cont_dims])),
+            **{d: current_x[:, d] for d in non_cont_dims.tolist()},
            **(opt_inputs.fixed_features or {}),
        },
        options={
            "maxiter": options.get("maxiter_continuous", MAX_ITER_CONT),
            "tol": options.get("tol", CONVERGENCE_TOL),
            "batch_limit": options.get("batch_limit", MAX_BATCH_SIZE),
+            "max_optimization_problem_aggregation_size": 1,
        },
    )
-    return _optimize_acqf(opt_inputs=updated_opt_inputs)
+    best_X, best_acq_values = _optimize_acqf(opt_inputs=updated_opt_inputs)
+    return best_X.view_as(current_x), best_acq_values
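
The shape bookkeeping behind the batched `continuous_step` can be checked in isolation. In this sketch (plain tensors; `b`, `d`, and the pinned dimensions are made up), `unsqueeze(1)` adds the q dimension of size 1 expected by the optimizer, and each non-continuous dimension is fixed to a tensor of per-restart values rather than a single scalar:

import torch

b, d = 4, 3
current_x = torch.rand(b, d)
batch_initial_conditions = current_x.unsqueeze(1)  # shape: b x q(=1) x d
non_cont_dims = [0, 2]
fixed_features = {dim: current_x[:, dim] for dim in non_cont_dims}

print(batch_initial_conditions.shape)  # torch.Size([4, 1, 3])
print({k: tuple(v.shape) for k, v in fixed_features.items()})  # {0: (4,), 2: (4,)}
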


def optimize_acqf_mixed_alternating(
@@ -761,6 +826,12 @@ def optimize_acqf_mixed_alternating(
    fixed_features = fixed_features or {}
    options = options or {}
+    if options.get("max_optimization_problem_aggregation_size", 1) != 1:
+        raise UnsupportedError(
+            "optimize_acqf_mixed_alternating does not support "
+            "max_optimization_problem_aggregation_size != 1. "
+            "You may leave this option unset, though."
+        )
    options.setdefault("batch_limit", MAX_BATCH_SIZE)
    options.setdefault("init_batch_limit", options["batch_limit"])
    if not (keys := set(options.keys())).issubset(SUPPORTED_OPTIONS):
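
The new guard follows the same validate-then-default pattern as the surrounding option handling. A condensed, standalone sketch (with a stand-in option set and plain `ValueError` instead of the library's error types):

SUPPORTED = {
    "batch_limit",
    "init_batch_limit",
    "maxiter_alternating",
    "max_optimization_problem_aggregation_size",
    "tol",
}

def check_options(options=None):
    options = dict(options or {})
    if options.get("max_optimization_problem_aggregation_size", 1) != 1:
        raise ValueError("aggregation size != 1 is unsupported; leave it unset")
    options.setdefault("batch_limit", 1024)
    options.setdefault("init_batch_limit", options["batch_limit"])
    if not (keys := set(options)).issubset(SUPPORTED):
        raise ValueError(f"unsupported options: {keys - SUPPORTED}")
    return options

print(check_options({"batch_limit": 16}))
# {'batch_limit': 16, 'init_batch_limit': 16}
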
@@ -793,11 +864,18 @@ def optimize_acqf_mixed_alternating(
        fixed_features=fixed_features,
        post_processing_func=post_processing_func,
        batch_initial_conditions=None,
-        return_best_only=True,
+        return_best_only=False,  # We don't want the continuous optimization
+        # step to return only the best point; this function itself returns
+        # only the best at the very end.
        gen_candidates=gen_candidates_scipy,
-        sequential=sequential,
+        sequential=sequential,  # Only relevant if all dims are continuous.
    )
-    _validate_sequential_inputs(opt_inputs=opt_inputs)
+    if sequential:
+        # Sequential optimization requires return_best_only to be True,
+        # but we turn it off here, as we "manually" perform the sequential
+        # conditioning in the loop below.
+        _validate_sequential_inputs(
+            opt_inputs=dataclasses.replace(opt_inputs, return_best_only=True)
+        )

    base_X_pending = acq_function.X_pending if q > 1 else None
    dim = bounds.shape[-1]
@@ -808,7 +886,12 @@ def optimize_acqf_mixed_alternating(
    non_cont_dims = [*discrete_dims, *cat_dims]
    if len(non_cont_dims) == 0:
        # If the problem is fully continuous, fall back to standard optimization.
-        return _optimize_acqf(opt_inputs=opt_inputs)
+        return _optimize_acqf(
+            opt_inputs=dataclasses.replace(
+                opt_inputs,
+                return_best_only=True,
+            )
+        )
    if not (
        isinstance(non_cont_dims, list)
        and len(set(non_cont_dims)) == len(non_cont_dims)
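
The fully continuous fallback relies on `dataclasses.replace`, which copies a dataclass instance with selected fields overridden and leaves the original untouched. A minimal illustration with a made-up stand-in for the real inputs dataclass:

import dataclasses

@dataclasses.dataclass
class FakeOptInputs:  # hypothetical stand-in, not the library's class
    q: int = 1
    return_best_only: bool = False

opt_inputs = FakeOptInputs()
fallback = dataclasses.replace(opt_inputs, return_best_only=True)
print(opt_inputs)  # FakeOptInputs(q=1, return_best_only=False), unchanged
print(fallback)    # FakeOptInputs(q=1, return_best_only=True)
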
@@ -842,26 +925,28 @@ def optimize_acqf_mixed_alternating(
        cont_dims=cont_dims,
    )

-    # TODO: Eliminate this for loop. Tensors being unequal sizes could potentially
-    # be handled by concatenating them rather than stacking, and keeping a list
-    # of indices.
-    for i in range(num_restarts):
-        alternate_steps = 0
-        while alternate_steps < options.get("maxiter_alternating", MAX_ITER_ALTER):
-            starting_acq_val = best_acq_val[i].clone()
-            alternate_steps += 1
-            for step in (discrete_step, continuous_step):
-                best_X[i], best_acq_val[i] = step(
-                    opt_inputs=opt_inputs,
-                    discrete_dims=discrete_dims_t,
-                    cat_dims=cat_dims_t,
-                    current_x=best_X[i],
-                )
+    done = torch.zeros(len(best_X), dtype=torch.bool, device=tkwargs["device"])
+    for _step in range(options.get("maxiter_alternating", MAX_ITER_ALTER)):
+        starting_acq_val = best_acq_val.clone()
+        best_X[~done], best_acq_val[~done] = discrete_step(
+            opt_inputs=opt_inputs,
+            discrete_dims=discrete_dims_t,
+            cat_dims=cat_dims_t,
+            current_x=best_X[~done],
+        )
+
+        best_X[~done], best_acq_val[~done] = continuous_step(
+            opt_inputs=opt_inputs,
+            discrete_dims=discrete_dims_t,
+            cat_dims=cat_dims_t,
+            current_x=best_X[~done],
+        )

-        improvement = best_acq_val[i] - starting_acq_val
-        if improvement < options.get("tol", CONVERGENCE_TOL):
-            # Check for convergence
-            break
+        improvement = best_acq_val - starting_acq_val
+        done_now = improvement < options.get("tol", CONVERGENCE_TOL)
+        done = done | done_now
+        if done.float().mean() >= STOP_AFTER_SHARE_CONVERGED:
+            break

    new_candidate = best_X[torch.argmax(best_acq_val)].unsqueeze(0)
    candidates = torch.cat([candidates, new_candidate], dim=-2)
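
The rewritten loop leans on boolean-mask assignment so that converged restarts drop out of the remaining work while the rest stay batched. A runnable toy version (a trivial clamped hill-climb stands in for the discrete and continuous steps; sizes and tolerances are invented):

import torch

best_X = torch.rand(5, 2)
best_acq_val = best_X.sum(dim=-1)
done = torch.zeros(5, dtype=torch.bool)

for step in range(50):
    starting_acq_val = best_acq_val.clone()
    # Stand-in for discrete_step/continuous_step acting on the active subset.
    best_X[~done] = (best_X[~done] + 0.5**step).clamp(max=1.0)
    best_acq_val[~done] = best_X[~done].sum(dim=-1)
    improvement = best_acq_val - starting_acq_val
    done = done | (improvement < 1e-8)
    if done.float().mean() >= 1.0:  # STOP_AFTER_SHARE_CONVERGED
        break

print(step, done)  # every restart converges once the updates stop helping
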