Commit eb6f2ec

nataliemaus authored and facebook-github-bot committed
Scalable Constrained Bayesian Optimization (#1257)
Summary:

## Motivation

I created this PR to add support for Scalable Constrained Bayesian Optimization (SCBO), as described in [1], to BoTorch. Since BoTorch already supports TuRBO, relatively few modifications were needed to add support for SCBO. I have found the SCBO method to be very useful in my research, and I hope that adding SCBO to BoTorch will make it easier for others to use it in the future.

[1] David Eriksson and Matthias Poloczek. Scalable constrained Bayesian optimization. In International Conference on Artificial Intelligence and Statistics, pages 730–738. PMLR, 2021. (https://doi.org/10.48550/arxiv.2002.08526)

To implement SCBO, this PR includes two additions to BoTorch:

1. A new sampling class called ConstrainedMaxPosteriorSampling that we use to implement the constrained Thompson sampling described in [1].
2. A tutorial Jupyter notebook that walks the user through a simple example of a constrained optimization problem solved with SCBO.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #1257

Test Plan: I tested the added code in botorch.generation.sampling.ConstrainedMaxPosteriorSampling by running the tutorial notebook (tutorials/scalable_constrained_bo.ipynb), which uses ConstrainedMaxPosteriorSampling.

## Related PRs

See the SCBO tutorial included in this PR (tutorials/scalable_constrained_bo.ipynb) for a description of the added implementation.

Additionally, I am more than happy to add more formal documentation, specifically for the ConstrainedMaxPosteriorSampling class; I just wasn't sure where this documentation should go within botorch/docs. jacobrgardner

Reviewed By: Balandat

Differential Revision: D37669540

Pulled By: dme65

fbshipit-source-id: c2c602e9cd2af2730c0df5553246e793499d2f18
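As a rough illustration of the constrained Thompson sampling rule that ConstrainedMaxPosteriorSampling is meant to implement, here is a minimal pure-Python sketch for a single posterior draw. The function name `select_candidate` and the toy numbers are hypothetical and not part of this commit; the fallback follows the rule stated in the class docstring (minimum predicted c(x) summed over all constraints), which may differ in detail from the paper.

```python
def select_candidate(obj_sample, con_samples):
    """Pick one candidate index from a single joint posterior draw.

    obj_sample[i]     -- sampled objective value at candidate i
    con_samples[i][j] -- sampled value of constraint j at candidate i
    A candidate counts as feasible when every sampled constraint is <= 0.
    """
    feasible = [
        i for i, cs in enumerate(con_samples) if all(c <= 0 for c in cs)
    ]
    if feasible:
        # Among feasible candidates, take the largest sampled objective.
        return max(feasible, key=lambda i: obj_sample[i])
    # No feasible candidate: take the smallest summed constraint value.
    return min(range(len(con_samples)), key=lambda i: sum(con_samples[i]))


# Toy draw over four candidates and two constraints (made-up values).
obj = [0.2, 0.9, 0.5, 0.7]
con = [[-0.1, -0.3], [0.4, -0.2], [-0.5, -0.1], [-0.2, 0.1]]
best = select_candidate(obj, con)  # candidates 0 and 2 are feasible

all_violating = [[0.3, 0.2], [0.1, 0.05], [0.9, -0.1], [0.4, 0.4]]
fallback = select_candidate(obj, all_violating)
```

Candidate 1 has the highest sampled objective but violates the first constraint, so the sketch returns candidate 2; in the second call nothing is feasible and the candidate with the smallest summed constraint value is returned.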
1 parent 21ed85b commit eb6f2ec

4 files changed: +930 -2 lines changed
botorch/generation/sampling.py

Lines changed: 115 additions & 2 deletions
@@ -17,7 +17,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any, Optional, Union

 import torch
 from botorch.acquisition.acquisition import AcquisitionFunction
@@ -29,6 +29,9 @@
 )
 from botorch.generation.utils import _flip_sub_unique
 from botorch.models.model import Model
+
+from botorch.models.model_list_gp_regression import ModelListGP
+from botorch.models.multitask import MultiTaskGP
 from botorch.utils.sampling import batched_multinomial
 from botorch.utils.transforms import standardize
 from torch import Tensor
@@ -122,9 +125,11 @@ def forward(
             observation_noise=observation_noise,
             posterior_transform=self.posterior_transform,
         )
-
         # num_samples x batch_shape x N x m
         samples = posterior.rsample(sample_shape=torch.Size([num_samples]))
+        return self.maximize_samples(X, samples, num_samples)
+
+    def maximize_samples(self, X: Tensor, samples: Tensor, num_samples: int = 1):
         obj = self.objective(samples, X=X)  # num_samples x batch_shape x N
         if self.replacement:
             # if we allow replacement then things are simple(r)
@@ -232,3 +237,111 @@ def forward(self, X: Tensor, num_samples: int = 1) -> Tensor:
         )
         # now do some gathering acrobatics to select the right elements from X
         return torch.gather(X, -2, idcs.unsqueeze(-1).expand(*idcs.shape, X.size(-1)))
+
+
+class ConstrainedMaxPosteriorSampling(MaxPosteriorSampling):
+    r"""Sample from a set of points according to
+    their max posterior value,
+    restricted to points that are also likely to meet a set of constraints
+    c1(x) <= 0, c2(x) <= 0, ..., cm(x) <= 0,
+    where c1, c2, ..., cm are black-box constraint functions.
+    Each constraint function is modeled by a separate
+    surrogate GP constraint model.
+    We sample points for which the sampled posterior value
+    of each constraint model is <= 0,
+    as described in https://doi.org/10.48550/arxiv.2002.08526
+
+    Example:
+        >>> CMPS = ConstrainedMaxPosteriorSampling(model,
+                    constraint_model=ModelListGP(cmodel1, cmodel2,
+                    ..., cmodelm))  # models w/ feature dim d=3
+        >>> X = torch.rand(2, 100, 3)
+        >>> sampled_X = CMPS(X, num_samples=5)
+    """
+
+    def __init__(
+        self,
+        model: Model,
+        constraint_model: Union[ModelListGP, MultiTaskGP],
+        objective: Optional[MCAcquisitionObjective] = None,
+        posterior_transform: Optional[PosteriorTransform] = None,
+        replacement: bool = True,
+        minimize_constraints_only: bool = False,
+    ) -> None:
+        r"""Constructor for the ConstrainedMaxPosteriorSampling class.
+
+        Args:
+            model: A fitted model.
+            objective: The MCAcquisitionObjective under
+                which the samples are evaluated.
+                Defaults to `IdentityMCObjective()`.
+            posterior_transform: An optional PosteriorTransform.
+            replacement: If True, sample with replacement.
+            constraint_model: Either a ModelListGP where each submodel
+                is a GP model for one constraint function,
+                or a MultiTaskGP model where each task is one
+                constraint function.
+                All constraints are of the form c(x) <= 0.
+                In the case when the constraint model predicts
+                that all candidates violate constraints,
+                we pick the candidates with minimum violation.
+            minimize_constraints_only: False by default. If True,
+                we will automatically return the candidates
+                with minimum posterior constraint values
+                (minimum predicted c(x) summed over all constraints),
+                regardless of predicted objective values.
+        """
+        super().__init__(
+            model=model,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            replacement=replacement,
+        )
+        self.constraint_model = constraint_model
+        self.minimize_constraints_only = minimize_constraints_only
+
+    def forward(
+        self, X: Tensor, num_samples: int = 1, observation_noise: bool = False
+    ) -> Tensor:
+        r"""Sample from the model posterior.
+
+        Args:
+            X: A `batch_shape x N x d`-dim Tensor
+                from which to sample (in the `N`
+                dimension) according to the maximum
+                posterior value under the objective.
+            num_samples: The number of samples to draw.
+            observation_noise: If True, sample with observation noise.
+
+        Returns:
+            A `batch_shape x num_samples x d`-dim
+            Tensor of samples from `X`, where
+            `X[..., i, :]` is the `i`-th sample.
+        """
+        posterior = self.model.posterior(X, observation_noise=observation_noise)
+        samples = posterior.rsample(sample_shape=torch.Size([num_samples]))
+
+        c_posterior = self.constraint_model.posterior(
+            X, observation_noise=observation_noise
+        )
+        constraint_samples = c_posterior.rsample(sample_shape=torch.Size([num_samples]))
+        valid_samples = constraint_samples <= 0
+        if valid_samples.shape[-1] > 1:  # if more than one constraint
+            valid_samples = torch.all(valid_samples, dim=-1).unsqueeze(-1)
+        if (valid_samples.sum() == 0) or self.minimize_constraints_only:
+            # if none of the samples meet the constraints
+            # we pick the one that minimizes total violation
+            constraint_samples = constraint_samples.sum(dim=-1)
+            idcs = torch.argmin(constraint_samples, dim=-1)
+            if idcs.ndim > 1:
+                idcs = idcs.permute(*range(1, idcs.ndim), 0)
+            idcs = idcs.unsqueeze(-1).expand(*idcs.shape, X.size(-1))
+            Xe = X.expand(*constraint_samples.shape[1:], X.size(-1))
+            return torch.gather(Xe, -2, idcs)
+        # replace all violators with -infinity so it will never choose them
+        replacement_infs = -torch.inf * torch.ones(samples.shape).to(X.device).to(
+            X.dtype
+        )
+        samples = torch.where(valid_samples, samples, replacement_infs)
+
+        return self.maximize_samples(X, samples, num_samples)
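One detail of `forward` in this diff is worth spelling out: the minimum-violation fallback is gated on `valid_samples.sum() == 0`, which is computed over all posterior draws at once. A draw in which no candidate is feasible therefore still goes through the `-inf` masking path whenever some other draw has a feasible candidate. The pure-Python sketch below mimics that masking step with plain lists; `mask_and_argmax` and the toy values are hypothetical, and it assumes ties are broken toward the first index, which may not match every tensor backend.

```python
NEG_INF = float("-inf")


def mask_and_argmax(obj_draws, feasible_draws):
    """Return the chosen candidate index for each posterior draw.

    obj_draws[s][i]      -- sampled objective of candidate i in draw s
    feasible_draws[s][i] -- True if candidate i meets all constraints in draw s
    Infeasible entries are replaced by -inf before taking the argmax.
    """
    chosen = []
    for obj_row, ok_row in zip(obj_draws, feasible_draws):
        masked = [v if ok else NEG_INF for v, ok in zip(obj_row, ok_row)]
        # max() with a key returns the first maximal index on ties.
        chosen.append(max(range(len(masked)), key=masked.__getitem__))
    return chosen


obj_draws = [[0.2, 0.9, 0.5], [0.4, 0.1, 0.8]]
feasible_draws = [[True, False, True], [False, False, False]]

# At least one entry is feasible overall, so a global "nothing is feasible"
# check would not trigger the minimum-violation fallback here.
any_feasible = any(any(row) for row in feasible_draws)
chosen = mask_and_argmax(obj_draws, feasible_draws)
```

In this toy case the first draw picks candidate 2, while the second draw is entirely `-inf` and falls back to whatever index the argmax returns on ties (index 0 in this sketch) rather than the minimum-violation candidate. Whether that matters in practice depends on how often individual draws are fully infeasible.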

test/generation/test_sampling.py

Lines changed: 130 additions & 0 deletions
@@ -19,9 +19,12 @@
 )
 from botorch.generation.sampling import (
     BoltzmannSampling,
+    ConstrainedMaxPosteriorSampling,
     MaxPosteriorSampling,
     SamplingStrategy,
 )
+from botorch.models import SingleTaskGP
+from botorch.models.model_list_gp_regression import ModelListGP
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior


@@ -190,3 +193,130 @@ def test_boltzmann_sampling(self):
             BS = BoltzmannSampling(acqf, eta=10.0)
             samples = BS(X, num_samples=1)
             self.assertTrue(torch.equal(samples, X[max_idx, :]))
+
+
+class TestConstrainedMaxPosteriorSampling(BotorchTestCase):
+    def test_init(self):
+        mm = MockModel(MockPosterior(mean=None))
+        cmms = MockModel(MockPosterior(mean=None))
+        MPS = ConstrainedMaxPosteriorSampling(mm, cmms)
+        self.assertEqual(MPS.model, mm)
+        self.assertTrue(MPS.replacement)
+        self.assertIsInstance(MPS.objective, IdentityMCObjective)
+        obj = LinearMCObjective(torch.rand(2))
+        MPS = ConstrainedMaxPosteriorSampling(
+            mm, cmms, objective=obj, replacement=False
+        )
+        self.assertEqual(MPS.objective, obj)
+        self.assertFalse(MPS.replacement)
+
+    def test_constrained_max_posterior_sampling(self):
+        batch_shapes = (torch.Size(), torch.Size([3]), torch.Size([3, 2]))
+        dtypes = (torch.float, torch.double)
+        for batch_shape, dtype, N, num_samples, d in itertools.product(
+            batch_shapes, dtypes, (5, 6), (1, 2), (1, 2)
+        ):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            # X is `batch_shape x N x d`.
+            X = torch.randn(*batch_shape, N, d, **tkwargs)
+            # the event shape is `num_samples x batch_shape x N x m`
+            psamples = torch.zeros(num_samples, *batch_shape, N, 1, **tkwargs)
+            psamples[..., 0, :] = 1.0
+
+            # IdentityMCObjective, with replacement
+            with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
+                mp = MockPosterior(None)
+                with mock.patch.object(MockModel, "posterior", return_value=mp):
+                    mm = MockModel(None)
+                    c_model1 = SingleTaskGP(X, torch.randn(X.shape[0:-1]).unsqueeze(-1))
+                    c_model2 = SingleTaskGP(X, torch.randn(X.shape[0:-1]).unsqueeze(-1))
+                    c_model3 = SingleTaskGP(X, torch.randn(X.shape[0:-1]).unsqueeze(-1))
+                    cmms1 = MockModel(MockPosterior(mean=None))
+                    cmms2 = ModelListGP(c_model1, c_model2)
+                    cmms3 = ModelListGP(c_model1, c_model2, c_model3)
+                    for cmms in [cmms1, cmms2, cmms3]:
+                        MPS = ConstrainedMaxPosteriorSampling(mm, cmms)
+                        s1 = MPS(X, num_samples=num_samples)
+                        # run again with minimize_constraints_only
+                        MPS = ConstrainedMaxPosteriorSampling(
+                            mm, cmms, minimize_constraints_only=True
+                        )
+                        s2 = MPS(X, num_samples=num_samples)
+                        assert s1.shape == s2.shape
+
+            # ScalarizedObjective, with replacement
+            with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
+                mp = MockPosterior(None)
+                with mock.patch.object(MockModel, "posterior", return_value=mp):
+                    mm = MockModel(None)
+                    cmms = MockModel(None)
+                    with mock.patch.object(
+                        ScalarizedObjective, "forward", return_value=mp
+                    ):
+                        obj = ScalarizedObjective(torch.rand(2, **tkwargs))
+                        MPS = ConstrainedMaxPosteriorSampling(mm, cmms, objective=obj)
+                        s = MPS(X, num_samples=num_samples)
+                        self.assertTrue(s.shape[-2] == num_samples)
+
+            # ScalarizedPosteriorTransform w/ replacement
+            with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
+                mp = MockPosterior(None)
+                with mock.patch.object(MockModel, "posterior", return_value=mp):
+                    mm = MockModel(None)
+                    cmms = MockModel(None)
+                    with mock.patch.object(
+                        ScalarizedPosteriorTransform, "forward", return_value=mp
+                    ):
+                        post_tf = ScalarizedPosteriorTransform(torch.rand(2, **tkwargs))
+                        MPS = ConstrainedMaxPosteriorSampling(
+                            mm, cmms, posterior_transform=post_tf
+                        )
+                        s = MPS(X, num_samples=num_samples)
+                        self.assertTrue(s.shape[-2] == num_samples)
+
+            # ScalarizedPosteriorTransform and Scalarized obj
+            mp = MockPosterior(None)
+            mm = MockModel(posterior=mp)
+            mp = MockPosterior(None)
+            cmms = MockModel(posterior=mp)
+            obj = ScalarizedObjective(torch.rand(2, **tkwargs))
+            post_tf = ScalarizedPosteriorTransform(torch.rand(2, **tkwargs))
+            with self.assertRaises(RuntimeError):
+                ConstrainedMaxPosteriorSampling(
+                    mm, cmms, posterior_transform=post_tf, objective=obj
+                )
+
+            # without replacement
+            psamples[..., 1, 0] = 1e-6
+            with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
+                mp = MockPosterior(None)
+                with mock.patch.object(MockModel, "posterior", return_value=mp):
+                    mm = MockModel(None)
+                    cmms = MockModel(None)
+                    MPS = ConstrainedMaxPosteriorSampling(mm, cmms, replacement=False)
+                    if len(batch_shape) > 1:
+                        with self.assertRaises(NotImplementedError):
+                            MPS(X, num_samples=num_samples)
+                    else:
+                        s = MPS(X, num_samples=num_samples)
+                        self.assertTrue(s.shape[-2] == num_samples)
+
+            # ScalarizedMCObjective, without replacement
+            with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
+                mp = MockPosterior(None)
+                with mock.patch.object(MockModel, "posterior", return_value=mp):
+                    mm = MockModel(None)
+                    cmms = MockModel(None)
+                    with mock.patch.object(
+                        ScalarizedObjective, "forward", return_value=mp
+                    ):
+                        obj = ScalarizedObjective(torch.rand(2, **tkwargs))
+                        MPS = ConstrainedMaxPosteriorSampling(
+                            mm, cmms, objective=obj, replacement=False
+                        )
+                        if len(batch_shape) > 1:
+                            with self.assertRaises(NotImplementedError):
+                                MPS(X, num_samples=num_samples)
+                        else:
+                            s = MPS(X, num_samples=num_samples)
+                            self.assertTrue(s.shape[-2] == num_samples)
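To get a feel for how many configurations the `itertools.product` loop in `test_constrained_max_posterior_sampling` covers, the grid can be enumerated without torch. This is only a sketch: plain tuples and strings stand in for `torch.Size` and the dtypes, and `expected_shape` is a hypothetical helper encoding the `batch_shape x num_samples x d` return shape stated in the `forward` docstring.

```python
import itertools

# Stand-ins for torch.Size(), torch.Size([3]), torch.Size([3, 2]).
batch_shapes = ((), (3,), (3, 2))
# Stand-ins for torch.float and torch.double.
dtypes = ("float", "double")

# Same argument order as the test: batch_shape, dtype, N, num_samples, d.
grid = list(itertools.product(batch_shapes, dtypes, (5, 6), (1, 2), (1, 2)))


def expected_shape(batch_shape, num_samples, d):
    """Shape the sampler is documented to return for one configuration."""
    return tuple(batch_shape) + (num_samples, d)


# Distinct output shapes across the grid (dtype and N do not affect them).
shapes = {expected_shape(b, n, d) for b, _dtype, _N, n, d in grid}
```

The test itself only asserts `s.shape[-2] == num_samples` (or that the two modes return equal shapes), so a stricter check against the full documented shape would be one possible extension.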
