From 8b53bd65dcb8d915cfd40d168e343ca620d4b412 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Tue, 14 Feb 2023 16:31:42 -0800 Subject: [PATCH 01/17] acquisition function wrapper (#1532) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1532 Add a wrapper for modifying inputs/outputs. This is useful for not only probabilistic reparameterization, but will also simplify other integrated AFs (e.g. MCMC) as well as fixed feature AFs and things like prior-guided AFs Differential Revision: https://internalfb.com/D41629186 fbshipit-source-id: c2d3b339edf44a3167804b095d213b3ba98b5e13 --- botorch/acquisition/fixed_feature.py | 26 ++++-------- botorch/acquisition/penalized.py | 24 +++-------- botorch/acquisition/proximal.py | 15 ++++--- botorch/acquisition/wrapper.py | 55 ++++++++++++++++++++++++++ sphinx/source/acquisition.rst | 9 ++++- test/acquisition/test_fixed_feature.py | 2 +- test/acquisition/test_proximal.py | 8 +++- test/acquisition/test_wrapper.py | 52 ++++++++++++++++++++++++ 8 files changed, 144 insertions(+), 47 deletions(-) create mode 100644 botorch/acquisition/wrapper.py create mode 100644 test/acquisition/test_wrapper.py diff --git a/botorch/acquisition/fixed_feature.py b/botorch/acquisition/fixed_feature.py index 0f3b85faa7..763226799e 100644 --- a/botorch/acquisition/fixed_feature.py +++ b/botorch/acquisition/fixed_feature.py @@ -16,11 +16,11 @@ import torch from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from torch import Tensor -from torch.nn import Module -class FixedFeatureAcquisitionFunction(AcquisitionFunction): +class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper): """A wrapper around AquisitionFunctions to fix a subset of features. Example: @@ -56,8 +56,7 @@ def __init__( combination of `Tensor`s and numbers which can be broadcasted to form a tensor with trailing dimension size of `d_f`. """ - Module.__init__(self) - self.acq_func = acq_function + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function) dtype = torch.float device = torch.device("cpu") self.d = d @@ -126,24 +125,13 @@ def forward(self, X: Tensor): X_full = self._construct_X_full(X) return self.acq_func(X_full) - @property - def X_pending(self): - r"""Return the `X_pending` of the base acquisition function.""" - try: - return self.acq_func.X_pending - except (ValueError, AttributeError): - raise ValueError( - f"Base acquisition function {type(self.acq_func).__name__} " - "does not have an `X_pending` attribute." - ) - - @X_pending.setter - def X_pending(self, X_pending: Optional[Tensor]): + def set_X_pending(self, X_pending: Optional[Tensor]): r"""Sets the `X_pending` of the base acquisition function.""" if X_pending is not None: - self.acq_func.X_pending = self._construct_X_full(X_pending) + full_X_pending = self._construct_X_full(X_pending) else: - self.acq_func.X_pending = X_pending + full_X_pending = None + self.acq_func.set_X_pending(full_X_pending) def _construct_X_full(self, X: Tensor) -> Tensor: r"""Constructs the full input for the base acquisition function. 
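For orientation, a minimal usage sketch of the reworked fixed-feature wrapper follows; it assumes an already-fitted BoTorch model named `model` over a 3-dimensional input space (not defined in this patch) and otherwise uses only public classes touched above:

import torch
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction

# `model` is assumed to be an already-fitted single-outcome BoTorch model (d=3).
ei = ExpectedImprovement(model=model, best_f=0.5)
# Fix the last feature to 0.0; X_pending handling is now inherited from
# AbstractAcquisitionFunctionWrapper via set_X_pending.
ei_ff = FixedFeatureAcquisitionFunction(acq_function=ei, d=3, columns=[2], values=[0.0])
# Evaluate on a 1 x 1 x 2 candidate tensor (only the two free features are passed).
value = ei_ff(torch.rand(1, 1, 2))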
diff --git a/botorch/acquisition/penalized.py b/botorch/acquisition/penalized.py index b114362ea9..9ee8f1fee5 100644 --- a/botorch/acquisition/penalized.py +++ b/botorch/acquisition/penalized.py @@ -15,9 +15,8 @@ import torch from botorch.acquisition.acquisition import AcquisitionFunction -from botorch.acquisition.analytic import AnalyticAcquisitionFunction from botorch.acquisition.objective import GenericMCObjective -from botorch.exceptions import UnsupportedError +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from torch import Tensor @@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor: return regularization_term -class PenalizedAcquisitionFunction(AcquisitionFunction): +class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper): r"""Single-outcome acquisition function regularized by the given penalty. The usage is similar to: @@ -161,29 +160,16 @@ def __init__( penalty_func: The regularization function. regularization_parameter: Regularization parameter used in optimization. """ - super().__init__(model=raw_acqf.model) - self.raw_acqf = raw_acqf + AcquisitionFunction.__init__(self, model=raw_acqf.model) + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf) self.penalty_func = penalty_func self.regularization_parameter = regularization_parameter def forward(self, X: Tensor) -> Tensor: - raw_value = self.raw_acqf(X=X) + raw_value = self.acq_func(X=X) penalty_term = self.penalty_func(X) return raw_value - self.regularization_parameter * penalty_term - @property - def X_pending(self) -> Optional[Tensor]: - return self.raw_acqf.X_pending - - def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None: - if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction): - self.raw_acqf.set_X_pending(X_pending=X_pending) - else: - raise UnsupportedError( - "The raw acquisition function is Analytic and does not account " - "for X_pending yet." - ) - def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor: r"""Computes the group lasso regularization function for the given point. diff --git a/botorch/acquisition/proximal.py b/botorch/acquisition/proximal.py index 9cd4aed7ad..b1d68edef1 100644 --- a/botorch/acquisition/proximal.py +++ b/botorch/acquisition/proximal.py @@ -15,6 +15,8 @@ import torch from botorch.acquisition import AcquisitionFunction + +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from botorch.exceptions.errors import UnsupportedError from botorch.models import ModelListGP from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel @@ -25,7 +27,7 @@ from torch.nn import Module -class ProximalAcquisitionFunction(AcquisitionFunction): +class ProximalAcquisitionFunction(AbstractAcquisitionFunctionWrapper): """A wrapper around AcquisitionFunctions to add proximal weighting of the acquisition function. The acquisition function is weighted via a squared exponential centered at the last training point, @@ -70,9 +72,7 @@ def __init__( beta: If not None, apply a softplus transform to the base acquisition function, allows negative base acquisition function values. """ - Module.__init__(self) - - self.acq_func = acq_function + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function) model = self.acq_func.model if hasattr(acq_function, "X_pending"): @@ -80,7 +80,6 @@ def __init__( raise UnsupportedError( "Proximal acquisition function requires `X_pending` to be None." 
) - self.X_pending = acq_function.X_pending self.register_buffer("proximal_weights", proximal_weights) self.register_buffer( @@ -91,6 +90,12 @@ def __init__( _validate_model(model, proximal_weights) + def set_X_pending(self, X_pending: Optional[Tensor]) -> None: + r"""Sets the `X_pending` of the base acquisition function.""" + raise UnsupportedError( + "Proximal acquisition function does not support `X_pending`." + ) + @t_batch_mode_transform(expected_q=1, assert_output_shape=False) def forward(self, X: Tensor) -> Tensor: r"""Evaluate base acquisition function with proximal weighting. diff --git a/botorch/acquisition/wrapper.py b/botorch/acquisition/wrapper.py new file mode 100644 index 0000000000..08dfbd2849 --- /dev/null +++ b/botorch/acquisition/wrapper.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +A wrapper classes around AcquisitionFunctions to modify inputs and outputs. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + +from botorch.acquisition.acquisition import AcquisitionFunction +from torch import Tensor +from torch.nn import Module + + +class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC): + r"""Abstract acquisition wrapper.""" + + def __init__(self, acq_function: AcquisitionFunction) -> None: + Module.__init__(self) + self.acq_func = acq_function + + @property + def X_pending(self) -> Optional[Tensor]: + r"""Return the `X_pending` of the base acquisition function.""" + try: + return self.acq_func.X_pending + except (ValueError, AttributeError): + raise ValueError( + f"Base acquisition function {type(self.acq_func).__name__} " + "does not have an `X_pending` attribute." + ) + + def set_X_pending(self, X_pending: Optional[Tensor]) -> None: + r"""Sets the `X_pending` of the base acquisition function.""" + self.acq_func.set_X_pending(X_pending) + + @abstractmethod + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate the wrapped acquisition function on the candidate set X. + + Args: + X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim + design points each. + + Returns: + A `(b)`-dim Tensor of acquisition function values at the given + design points `X`. + """ + pass # pragma: no cover diff --git a/sphinx/source/acquisition.rst b/sphinx/source/acquisition.rst index 79f529826a..a3c5eaeb5a 100644 --- a/sphinx/source/acquisition.rst +++ b/sphinx/source/acquisition.rst @@ -21,6 +21,11 @@ Analytic Acquisition Function API .. autoclass:: AnalyticAcquisitionFunction :members: +Acquisition Function Wrapper API +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.acquisition.wrapper + :members: + Cached Cholesky Acquisition Function API ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.cached_cholesky @@ -65,7 +70,7 @@ Multi-Objective Analytic Acquisition Functions .. automodule:: botorch.acquisition.multi_objective.analytic :members: :exclude-members: MultiObjectiveAnalyticAcquisitionFunction - + Multi-Objective Joint Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.joint_entropy_search @@ -86,7 +91,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
automodule:: botorch.acquisition.multi_objective.multi_fidelity :members: - + Multi-Objective Predictive Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search diff --git a/test/acquisition/test_fixed_feature.py b/test/acquisition/test_fixed_feature.py index 8dcc02f1df..b8f570e7e1 100644 --- a/test/acquisition/test_fixed_feature.py +++ b/test/acquisition/test_fixed_feature.py @@ -87,7 +87,7 @@ def test_fixed_features(self): qEI_ff.set_X_pending(X_pending[..., :-1]) self.assertAllClose(qEI.X_pending, X_pending) # test setting to None - qEI_ff.X_pending = None + qEI_ff.set_X_pending(None) self.assertIsNone(qEI_ff.X_pending) # test gradient diff --git a/test/acquisition/test_proximal.py b/test/acquisition/test_proximal.py index 795daa1b34..e17536ddd0 100644 --- a/test/acquisition/test_proximal.py +++ b/test/acquisition/test_proximal.py @@ -209,9 +209,15 @@ def test_proximal(self): # test for x_pending points pending_acq = DummyAcquisitionFunction(model) - pending_acq.set_X_pending(torch.rand(3, 3, device=self.device, dtype=dtype)) + X_pending = torch.rand(3, 3, device=self.device, dtype=dtype) + pending_acq.set_X_pending(X_pending) with self.assertRaises(UnsupportedError): ProximalAcquisitionFunction(pending_acq, proximal_weights) + # test setting pending points + pending_acq.set_X_pending(None) + af = ProximalAcquisitionFunction(pending_acq, proximal_weights) + with self.assertRaises(UnsupportedError): + af.set_X_pending(X_pending) # test model with multi-batch training inputs train_X = torch.rand(5, 2, 3, device=self.device, dtype=dtype) diff --git a/test/acquisition/test_wrapper.py b/test/acquisition/test_wrapper.py new file mode 100644 index 0000000000..e35175fb9b --- /dev/null +++ b/test/acquisition/test_wrapper.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
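Before the new wrapper test module, a short sketch of the proximal behavior exercised by the test additions above; `model` is assumed to be a fitted single-outcome model with 3-dimensional training inputs:

import torch
from botorch.acquisition.analytic import UpperConfidenceBound
from botorch.acquisition.proximal import ProximalAcquisitionFunction
from botorch.exceptions.errors import UnsupportedError

ucb = UpperConfidenceBound(model=model, beta=2.0)  # `model` assumed to exist
prox = ProximalAcquisitionFunction(ucb, proximal_weights=torch.ones(3))
try:
    # With this patch, pending points are explicitly unsupported on the wrapper.
    prox.set_X_pending(torch.rand(2, 3))
except UnsupportedError:
    pass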
+ +import torch +from botorch.acquisition.analytic import ExpectedImprovement +from botorch.acquisition.monte_carlo import qExpectedImprovement +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper +from botorch.exceptions.errors import UnsupportedError +from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior + + +class DummyWrapper(AbstractAcquisitionFunctionWrapper): + def forward(self, X): + return self.acq_func(X) + + +class TestAbstractAcquisitionFunctionWrapper(BotorchTestCase): + def test_abstract_acquisition_function_wrapper(self): + for dtype in (torch.float, torch.double): + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, dtype=dtype, device=self.device), + variance=torch.ones(1, 1, dtype=dtype, device=self.device), + ) + ) + acq_func = ExpectedImprovement(model=mm, best_f=-1.0) + wrapped_af = DummyWrapper(acq_function=acq_func) + self.assertIs(wrapped_af.acq_func, acq_func) + # test forward + X = torch.rand(1, 1, dtype=dtype, device=self.device) + with torch.no_grad(): + wrapped_val = wrapped_af(X) + af_val = acq_func(X) + self.assertEqual(wrapped_val.item(), af_val.item()) + + # test X_pending + with self.assertRaises(ValueError): + self.assertIsNone(wrapped_af.X_pending) + with self.assertRaises(UnsupportedError): + wrapped_af.set_X_pending(X) + acq_func = qExpectedImprovement(model=mm, best_f=-1.0) + wrapped_af = DummyWrapper(acq_function=acq_func) + self.assertIsNone(wrapped_af.X_pending) + wrapped_af.set_X_pending(X) + self.assertTrue(torch.equal(X, wrapped_af.X_pending)) + self.assertTrue(torch.equal(X, acq_func.X_pending)) + wrapped_af.set_X_pending(None) + self.assertIsNone(wrapped_af.X_pending) + self.assertIsNone(acq_func.X_pending) From 8b49e5c210a5a02ffec417f459fcaa3018740535 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Tue, 14 Feb 2023 16:31:42 -0800 Subject: [PATCH 02/17] Add isinstance_af Summary: Creates a new helper method for checking both if a given AF is an instance of a class or if the given AF wraps a base AF that is an instance of a class Differential Revision: D43127722 fbshipit-source-id: 9f5f31b991f15f2b32931f1b9625422c7907495d --- botorch/acquisition/utils.py | 17 ++++++++-- test/acquisition/test_utils.py | 61 +++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 486fdd0cff..ccbbf471b2 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -11,7 +11,7 @@ from __future__ import annotations import math -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from botorch.acquisition import analytic, monte_carlo, multi_objective # noqa F401 @@ -22,6 +22,7 @@ MCAcquisitionObjective, PosteriorTransform, ) +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from botorch.exceptions.errors import UnsupportedError from botorch.models.fully_bayesian import MCMC_DIM from botorch.models.model import Model @@ -253,6 +254,18 @@ def objective(Y: Tensor, X: Optional[Tensor] = None): return -(lb.clamp_max(0.0)) +def isinstance_af( + __obj: object, + __class_or_tuple: Union[type, tuple[Union[type, tuple[Any, ...]], ...]], +) -> bool: + r"""A variant of isinstance first checks for the acq_func attribute on wrapped acquisition functions.""" + if isinstance(__obj, AbstractAcquisitionFunctionWrapper): + isinstance_base_af = isinstance(__obj.acq_func, __class_or_tuple) + else: + 
isinstance_base_af = False + return isinstance_base_af or isinstance(__obj, __class_or_tuple) + + def is_nonnegative(acq_function: AcquisitionFunction) -> bool: r"""Determine whether a given acquisition function is non-negative. @@ -267,7 +280,7 @@ def is_nonnegative(acq_function: AcquisitionFunction) -> bool: >>> qEI = qExpectedImprovement(model, best_f=0.1) >>> is_nonnegative(qEI) # returns True """ - return isinstance( + return isinstance_af( acq_function, ( analytic.ExpectedImprovement, diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index d12b5f6da4..39b8017ea2 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -8,7 +8,8 @@ from unittest import mock import torch -from botorch.acquisition import monte_carlo +from botorch.acquisition import analytic, monte_carlo, multi_objective +from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction from botorch.acquisition.multi_objective import ( MCMultiOutputObjective, monte_carlo as moo_monte_carlo, @@ -18,10 +19,13 @@ MCAcquisitionObjective, ScalarizedPosteriorTransform, ) +from botorch.acquisition.proximal import ProximalAcquisitionFunction from botorch.acquisition.utils import ( expand_trace_observations, get_acquisition_function, get_infeasible_cost, + is_nonnegative, + isinstance_af, project_to_sample_points, project_to_target_fidelity, prune_inferior_points, @@ -606,6 +610,61 @@ def test_get_infeasible_cost(self): self.assertAllClose(M4, torch.tensor([1.0], **tkwargs)) +class TestIsNonnegative(BotorchTestCase): + def test_is_nonnegative(self): + nonneg_afs = ( + analytic.ExpectedImprovement, + analytic.ConstrainedExpectedImprovement, + analytic.ProbabilityOfImprovement, + analytic.NoisyExpectedImprovement, + monte_carlo.qExpectedImprovement, + monte_carlo.qNoisyExpectedImprovement, + monte_carlo.qProbabilityOfImprovement, + multi_objective.analytic.ExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement, + ) + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, device=self.device), + variance=torch.ones(1, 1, device=self.device), + ) + ) + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) + with mock.patch( + "botorch.acquisition.utils.isinstance_af", return_value=True + ) as mock_isinstance_af: + self.assertTrue(is_nonnegative(acq_function=acq_func)) + mock_isinstance_af.assert_called_once() + cargs, _ = mock_isinstance_af.call_args + self.assertIs(cargs[0], acq_func) + self.assertEqual(cargs[1], nonneg_afs) + acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0) + self.assertFalse(is_nonnegative(acq_function=acq_func)) + + +class TestIsinstanceAf(BotorchTestCase): + def test_isinstance_af(self): + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, device=self.device), + variance=torch.ones(1, 1, device=self.device), + ) + ) + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) + self.assertTrue(isinstance_af(acq_func, analytic.ExpectedImprovement)) + self.assertFalse(isinstance_af(acq_func, analytic.UpperConfidenceBound)) + wrapped_af = FixedFeatureAcquisitionFunction( + acq_function=acq_func, d=2, columns=[1], values=[0.0] + ) + # test base af class + self.assertTrue(isinstance_af(wrapped_af, analytic.ExpectedImprovement)) + self.assertFalse(isinstance_af(wrapped_af, analytic.UpperConfidenceBound)) + # test wrapper class + self.assertTrue(isinstance_af(wrapped_af, FixedFeatureAcquisitionFunction)) + 
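+        # Note: for wrapped acquisition functions, isinstance_af matches both the
+        # wrapper class itself and the class of the wrapped base acquisition function.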
self.assertFalse(isinstance_af(wrapped_af, ProximalAcquisitionFunction)) + + class TestPruneInferiorPoints(BotorchTestCase): def test_prune_inferior_points(self): for dtype in (torch.float, torch.double): From 7ce1389269f4c33b7e5a675fc6756c108655c3a8 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Tue, 14 Feb 2023 16:32:02 -0800 Subject: [PATCH 03/17] probabilistic reparameterization (#1533) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1533 Probabilistic reparameterization Differential Revision: D41629217 fbshipit-source-id: a6067c73ce534daf6f6a180fc49720f305827d58 --- .../probabilistic_reparameterization.py | 541 +++++++++++++++++ botorch/models/transforms/factory.py | 82 +++ botorch/models/transforms/input.py | 572 ++++++++++++++++++ sphinx/source/acquisition.rst | 5 + 4 files changed, 1200 insertions(+) create mode 100644 botorch/acquisition/probabilistic_reparameterization.py diff --git a/botorch/acquisition/probabilistic_reparameterization.py b/botorch/acquisition/probabilistic_reparameterization.py new file mode 100644 index 0000000000..5c6428985e --- /dev/null +++ b/botorch/acquisition/probabilistic_reparameterization.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Probabilistic Reparameterization (with gradients) using Monte Carlo estimators. + +See [Daulton2022bopr]_ for details. +""" + +from contextlib import ExitStack +from typing import Dict, List, Optional + +import torch +from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper +from botorch.models.transforms.factory import ( + get_probabilistic_reparameterization_input_transform, +) + +from botorch.models.transforms.input import ( + ChainedInputTransform, + InputTransform, + OneHotToNumeric, +) +from torch import Tensor +from torch.autograd import Function +from torch.nn.functional import one_hot + + +class _MCProbabilisticReparameterization(Function): + r"""Evaluate the acquisition function via probabistic reparameterization. + + This uses a score function gradient estimator. See [Daulton2022bopr]_ for details. + """ + + @staticmethod + def forward( + ctx, + X: Tensor, + acq_function: AcquisitionFunction, + input_tf: InputTransform, + batch_limit: Optional[int], + integer_indices: Tensor, + cont_indices: Tensor, + categorical_indices: Tensor, + use_ma_baseline: bool, + one_hot_to_numeric: Optional[OneHotToNumeric], + ma_counter: Optional[Tensor], + ma_hidden: Optional[Tensor], + ma_decay: Optional[float], + ): + """Evaluate the expectation of the acquisition function under + probabilistic reparameterization. Compute this in chunks of size + batch_limit to enable scaling to large numbers of samples from the + proposal distribution. 
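+
+        The gradient returned by ``backward`` for the discrete dimensions follows
+        the score-function (REINFORCE) identity
+
+            grad_theta E_{p_theta(z)}[ alpha(z) ]
+                = E_{p_theta(z)}[ alpha(z) * grad_theta log p_theta(z) ],
+
+        estimated with the Monte Carlo samples drawn here and, optionally, a
+        moving-average baseline for variance reduction. For rounding variables
+        whose probabilities come from a tempered sigmoid/softmax, p = sigma(u / tau),
+        the per-sample score is (z - p) / tau, which is the quantity accumulated
+        in ``backward``.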
+ """ + with ExitStack() as es: + if ctx.needs_input_grad[0]: + es.enter_context(torch.enable_grad()) + if cont_indices.shape[0] > 0: + # only require gradient for continuous parameters + ctx.cont_X = X[..., cont_indices].detach().requires_grad_(True) + cont_idx = 0 + cols = [] + for col in range(X.shape[-1]): + # cont_indices is sorted in ascending order + if ( + cont_idx < cont_indices.shape[0] + and col == cont_indices[cont_idx] + ): + cols.append(ctx.cont_X[..., cont_idx]) + cont_idx += 1 + else: + cols.append(X[..., col]) + X = torch.stack(cols, dim=-1) + else: + ctx.cont_X = None + ctx.discrete_indices = input_tf["round"].discrete_indices + ctx.cont_indices = cont_indices + ctx.categorical_indices = categorical_indices + ctx.ma_counter = ma_counter + ctx.ma_hidden = ma_hidden + ctx.X_shape = X.shape + tilde_x_samples = input_tf(X.unsqueeze(-3)) + # save the rounding component + + rounding_component = tilde_x_samples.clone() + if integer_indices.shape[0] > 0: + X_integer_params = X[..., integer_indices].unsqueeze(-3) + rounding_component[..., integer_indices] = ( + (tilde_x_samples[..., integer_indices] - X_integer_params > 0) + | (X_integer_params == 1) + ).to(tilde_x_samples) + if categorical_indices.shape[0] > 0: + rounding_component[..., categorical_indices] = tilde_x_samples[ + ..., categorical_indices + ] + ctx.rounding_component = rounding_component[..., ctx.discrete_indices] + ctx.tau = input_tf["round"].tau + if hasattr(input_tf["round"], "base_samples"): + ctx.base_samples = input_tf["round"].base_samples.detach() + # save the probabilities + if "unnormalize" in input_tf: + unnormalized_X = input_tf["unnormalize"](X) + else: + unnormalized_X = X + # this is only for the integer parameters + ctx.prob = input_tf["round"].get_rounding_prob(unnormalized_X) + + if categorical_indices.shape[0] > 0: + ctx.base_samples_categorical = input_tf[ + "round" + ].base_samples_categorical.clone() + # compute the acquisition function where inputs are rounded according to base_samples < prob + ctx.tilde_x_samples = tilde_x_samples + ctx.use_ma_baseline = use_ma_baseline + acq_values_list = [] + start_idx = 0 + if one_hot_to_numeric is not None: + tilde_x_samples = one_hot_to_numeric(tilde_x_samples) + + while start_idx < tilde_x_samples.shape[-3]: + end_idx = min(start_idx + batch_limit, tilde_x_samples.shape[-3]) + acq_values = acq_function(tilde_x_samples[..., start_idx:end_idx, :, :]) + acq_values_list.append(acq_values) + start_idx += batch_limit + acq_values = torch.cat(acq_values_list, dim=-1) + ctx.mean_acq_values = acq_values.mean( + dim=-1 + ) # average over samples from proposal distribution + ctx.acq_values = acq_values + # update moving average baseline + ctx.ma_hidden = ma_hidden.clone() + ctx.ma_counter = ctx.ma_counter.clone() + ctx.ma_decay = ma_decay + # update in place + ma_counter.add_(1) + ma_hidden.sub_((ma_hidden - acq_values.detach().mean()) * (1 - ma_decay)) + return ctx.mean_acq_values.detach() + + @staticmethod + def backward(ctx, grad_output): + """ + Compute the gradient of the expectation of the acquisition function + with respect to the parameters of the proposal distribution using + Monte Carlo. + """ + # this is overwriting the entire gradient w.r.t. 
x' + # x' has shape batch_shape x q x d + if ctx.needs_input_grad[0]: + acq_values = ctx.acq_values + mean_acq_values = ctx.mean_acq_values + cont_indices = ctx.cont_indices + discrete_indices = ctx.discrete_indices + rounding_component = ctx.rounding_component + # retrieve only the ordinal parameters + expanded_acq_values = acq_values.view(*acq_values.shape, 1, 1).expand( + acq_values.shape + rounding_component.shape[-2:] + ) + prob = ctx.prob.unsqueeze(-3) + if not ctx.use_ma_baseline: + sample_level = expanded_acq_values * (rounding_component - prob) + else: + # use reinforce with the moving average baseline + if ctx.ma_counter == 0: + baseline = 0.0 + else: + baseline = ctx.ma_hidden / ( + 1.0 - torch.pow(ctx.ma_decay, ctx.ma_counter) + ) + sample_level = (expanded_acq_values - baseline) * ( + rounding_component - prob + ) + + grads = (sample_level / ctx.tau).mean(dim=-3) + + new_grads = ( + grad_output.view( + *grad_output.shape, + *[1 for _ in range(grads.ndim - grad_output.ndim)], + ) + .expand(*grad_output.shape, *ctx.X_shape[-2:]) + .clone() + ) + # multiply upstream grad_output by new gradients + new_grads[..., discrete_indices] *= grads + # use autograd for gradients w.r.t. the continuous parameters + if ctx.cont_X is not None: + auto_grad = torch.autograd.grad( + # note: this multiplies the gradient of mean_acq_values w.r.t to input + # by grad_output + mean_acq_values, + ctx.cont_X, + grad_outputs=grad_output, + )[0] + # overwrite grad_output since the previous step already applied the chain rule + new_grads[..., cont_indices] = auto_grad + return ( + new_grads, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + return None, None, None, None, None, None, None, None, None, None, None, None + + +class AbstractProbabilisticReparameterization(AbstractAcquisitionFunctionWrapper): + r"""Acquisition Function Wrapper that leverages probabilistic reparameterization. + + The forward method is abstract and must be implemented. + + See [Daulton2022bopr]_ for details. + """ + + input_transform: ChainedInputTransform + + def __init__( + self, + acq_function: AcquisitionFunction, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + batch_limit: int = 32, + apply_numeric: bool = False, + **kwargs, + ) -> None: + r"""Initialize probabilistic reparameterization (PR). + + Args: + acq_function: The acquisition function. + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + batch_limit: The chunk size used in evaluating PR to limit memory + overhead. + apply_numeric: A boolean indicated if categoricals should be supplied + to the underlying acquisition function in numeric representation. + """ + if categorical_features is None and integer_indices is None: + raise NotImplementedError( + "categorical_features or integer indices must be provided." 
+ ) + super().__init__(acq_function=acq_function) + self.batch_limit = batch_limit + + if apply_numeric: + self.one_hot_to_numeric = OneHotToNumeric( + categorical_features=categorical_features, + transform_on_train=False, + transform_on_eval=True, + transform_on_fantasize=False, + ) + self.one_hot_to_numeric.eval() + else: + self.one_hot_to_numeric = None + discrete_indices = [] + if integer_indices is not None: + self.register_buffer( + "integer_indices", + torch.tensor( + integer_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) + discrete_indices.extend(integer_indices) + else: + self.register_buffer( + "integer_indices", + torch.tensor([], dtype=torch.long, device=one_hot_bounds.device), + ) + self.register_buffer( + "integer_bounds", + torch.tensor( + [], dtype=one_hot_bounds.dtype, device=one_hot_bounds.device + ), + ) + dim = one_hot_bounds.shape[1] + if categorical_features is not None and len(categorical_features) > 0: + categorical_indices = list(range(min(categorical_features.keys()), dim)) + discrete_indices.extend(categorical_indices) + self.register_buffer( + "categorical_indices", + torch.tensor( + categorical_indices, + dtype=torch.long, + device=one_hot_bounds.device, + ), + ) + self.categorical_features = categorical_features + else: + self.register_buffer( + "categorical_indices", + torch.tensor( + [], + dtype=torch.long, + device=one_hot_bounds.device, + ), + ) + + self.register_buffer( + "cont_indices", + torch.tensor( + sorted(set(range(dim)) - set(discrete_indices)), + dtype=torch.long, + device=one_hot_bounds.device, + ), + ) + self.model = acq_function.model # for sample_around_best heuristic + # moving average baseline + self.register_buffer( + "ma_counter", + torch.zeros(1, dtype=one_hot_bounds.dtype, device=one_hot_bounds.device), + ) + self.register_buffer( + "ma_hidden", + torch.zeros(1, dtype=one_hot_bounds.dtype, device=one_hot_bounds.device), + ) + self.register_buffer( + "ma_baseline", + torch.zeros(1, dtype=one_hot_bounds.dtype, device=one_hot_bounds.device), + ) + + def sample_candidates(self, X: Tensor) -> Tensor: + if "unnormalize" in self.input_transform: + unnormalized_X = self.input_transform["unnormalize"](X) + else: + unnormalized_X = X.clone() + prob = self.input_transform["round"].get_rounding_prob(X=unnormalized_X) + discrete_idx = 0 + for i in self.integer_indices: + p = prob[..., discrete_idx] + rounding_component = torch.distributions.Bernoulli(probs=p).sample() + unnormalized_X[..., i] = unnormalized_X[..., i].floor() + rounding_component + discrete_idx += 1 + if len(self.integer_indices) > 0: + unnormalized_X[..., self.integer_indices] = torch.minimum( + torch.maximum( + unnormalized_X[..., self.integer_indices], self.integer_bounds[0] + ), + self.integer_bounds[1], + ) + # this is the starting index for the categoricals in unnormalized_X + raw_idx = self.cont_indices.shape[0] + discrete_idx + if self.categorical_indices.shape[0] > 0: + for cardinality in self.categorical_features.values(): + discrete_end = discrete_idx + cardinality + p = prob[..., discrete_idx:discrete_end] + z = one_hot( + torch.distributions.Categorical(probs=p).sample(), + num_classes=cardinality, + ) + raw_end = raw_idx + cardinality + unnormalized_X[..., raw_idx:raw_end] = z + discrete_idx = discrete_end + raw_idx = raw_end + # normalize X + if "normalize" in self.input_transform: + return self.input_transform["normalize"](unnormalized_X) + return unnormalized_X + + +class 
AnalyticProbabilisticReparameterization(AbstractProbabilisticReparameterization): + """Analytic probabilistic reparameterization. + + Note: this is only reasonable from a computation perspective for relatively + small numbers of discrete options (probably less than a few thousand). + """ + + def __init__( + self, + acq_function: AcquisitionFunction, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + batch_limit: int = 32, + apply_numeric: bool = False, + tau: float = 0.1, + ) -> None: + """Initialize probabilistic reparameterization (PR). + + Args: + acq_function: The acquisition function. + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + batch_limit: The chunk size used in evaluating PR to limit memory + overhead. + apply_numeric: A boolean indicated if categoricals should be supplied + to the underlying acquisition function in numeric representation. + tau: The temperature parameter used to determine the probabilities. + + """ + super().__init__( + acq_function=acq_function, + integer_indices=integer_indices, + one_hot_bounds=one_hot_bounds, + categorical_features=categorical_features, + batch_limit=batch_limit, + apply_numeric=apply_numeric, + ) + # create input transform + # need to compute cross product of discrete options and weights + self.input_transform = get_probabilistic_reparameterization_input_transform( + one_hot_bounds=one_hot_bounds, + use_analytic=True, + integer_indices=integer_indices, + categorical_features=categorical_features, + tau=tau, + ) + + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate PR.""" + X_discrete_all = self.input_transform(X.unsqueeze(-3)) + acq_values_list = [] + start_idx = 0 + if self.one_hot_to_numeric is not None: + X_discrete_all = self.one_hot_to_numeric(X_discrete_all) + if X.shape[-2] != 1: + raise NotImplementedError + + # save the probabilities + if "unnormalize" in self.input_transform: + unnormalized_X = self.input_transform["unnormalize"](X) + else: + unnormalized_X = X + # this is batch_shape x n_discrete (after squeezing) + probs = self.input_transform["round"].get_probs(X=unnormalized_X).squeeze(-1) + # TODO: filter discrete configs with zero probability + # this would require padding because there may be a different number in each batch. + while start_idx < X_discrete_all.shape[-3]: + end_idx = min(start_idx + self.batch_limit, X_discrete_all.shape[-3]) + acq_values = self.acq_func(X_discrete_all[..., start_idx:end_idx, :, :]) + acq_values_list.append(acq_values) + start_idx += self.batch_limit + # this is batch_shape x n_discrete + acq_values = torch.cat(acq_values_list, dim=-1) + # now weight the acquisition values by probabilities + return (acq_values * probs).sum(dim=-1) + + +class MCProbabilisticReparameterization(AbstractProbabilisticReparameterization): + r"""MC-based probabilistic reparameterization. + + See [Daulton2022bopr]_ for details. 
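+
+    Example (a sketch, not a drop-in recipe): assumes a fitted `model` over a
+    6-column one-hot space with two continuous inputs, one integer input in
+    [0, 5] at column 2, and one three-level categorical in the last three
+    columns, plus a `test_X` candidate tensor normalized to the unit cube.
+        >>> one_hot_bounds = torch.tensor(
+        ...     [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 5.0, 1.0, 1.0, 1.0]]
+        ... )
+        >>> qEI = qExpectedImprovement(model, best_f=0.2)
+        >>> pr_qEI = MCProbabilisticReparameterization(
+        ...     acq_function=qEI,
+        ...     one_hot_bounds=one_hot_bounds,
+        ...     integer_indices=[2],
+        ...     categorical_features={3: 3},
+        ... )
+        >>> val = pr_qEI(test_X)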
+ """ + + def __init__( + self, + acq_function: AcquisitionFunction, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + batch_limit: int = 32, + apply_numeric: bool = False, + mc_samples: int = 128, + use_ma_baseline: bool = True, + tau: float = 0.1, + ma_decay: float = 0.7, + resample: bool = True, + ) -> None: + """Initialize probabilistic reparameterization (PR). + + Args: + acq_function: The acquisition function. + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + batch_limit: The chunk size used in evaluating PR to limit memory + overhead. + apply_numeric: A boolean indicated if categoricals should be supplied + to the underlying acquisition function in numeric representation. + mc_samples: The number of MC samples for MC probabilistic + reparameterization. + use_ma_baseline: A boolean indicating whether to use a moving average + baseline for variance reduction. + tau: The temperature parameter used to determine the probabilities. + ma_decay: The decay parameter in the moving average baseline. + Default: 0.7 + resample: A boolean indicating whether to resample with MC + probabilistic reparameterization on each forward pass. + + """ + super().__init__( + acq_function=acq_function, + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + batch_limit=batch_limit, + apply_numeric=apply_numeric, + ) + if self.batch_limit is None: + self.batch_limit = mc_samples + self.use_ma_baseline = use_ma_baseline + self._pr_acq_function = _MCProbabilisticReparameterization() + # create input transform + self.input_transform = get_probabilistic_reparameterization_input_transform( + integer_indices=integer_indices, + one_hot_bounds=one_hot_bounds, + categorical_features=categorical_features, + mc_samples=mc_samples, + tau=tau, + resample=resample, + ) + self.ma_decay = ma_decay + + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate MC probabilistic reparameterization.""" + return self._pr_acq_function.apply( + X, + self.acq_func, + self.input_transform, + self.batch_limit, + self.integer_indices, + self.cont_indices, + self.categorical_indices, + self.use_ma_baseline, + self.one_hot_to_numeric, + self.ma_counter, + self.ma_hidden, + self.ma_decay, + ) diff --git a/botorch/models/transforms/factory.py b/botorch/models/transforms/factory.py index 847fdf1b7c..486dbc3125 100644 --- a/botorch/models/transforms/factory.py +++ b/botorch/models/transforms/factory.py @@ -10,7 +10,9 @@ from typing import Dict, List, Optional from botorch.models.transforms.input import ( + AnalyticProbabilisticReparameterizationInputTransform, ChainedInputTransform, + MCProbabilisticReparameterizationInputTransform, Normalize, OneHotToNumeric, Round, @@ -123,3 +125,83 @@ def get_rounding_input_transform( tf.to(dtype=one_hot_bounds.dtype, device=one_hot_bounds.device) tf.eval() return tf + + +def get_probabilistic_reparameterization_input_transform( + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + use_analytic: bool = False, + mc_samples: int = 128, + resample: bool = False, + tau: float = 0.1, +) -> ChainedInputTransform: + r"""Construct 
InputTransform for Probabilistic Reparameterization. + + Note: this is intended to be used only for acquisition optimization + in via the AnalyticProbabilisticReparameterization and + MCProbabilisticReparameterization classes. This is not intended to be + attached to a botorch Model. + + See [Daulton2022bopr]_ for details. + + Args: + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer parameters + categorical_features: A dictionary mapping indices to cardinalities + for the categorical features. + use_analytic: A boolean indicating whether to use analytic + probabilistic reparameterization. + mc_samples: The number of MC samples for MC probabilistic + reparameterization. + resample: A boolean indicating whether to resample with MC + probabilistic reparameterization on each forward pass. + tau: The temperature parameter used to determine the probabilities. + + Returns: + The probabilistic reparameterization input transformation. + """ + tfs = OrderedDict() + if integer_indices is not None and len(integer_indices) > 0: + # unnormalize to integer space + tfs["unnormalize"] = Normalize( + d=one_hot_bounds.shape[1], + bounds=one_hot_bounds, + indices=integer_indices, + transform_on_train=False, + transform_on_eval=True, + transform_on_fantasize=False, + reverse=True, + ) + if use_analytic: + tfs["round"] = AnalyticProbabilisticReparameterizationInputTransform( + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + tau=tau, + ) + else: + tfs["round"] = MCProbabilisticReparameterizationInputTransform( + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + resample=resample, + mc_samples=mc_samples, + tau=tau, + ) + if integer_indices is not None and len(integer_indices) > 0: + # normalize to unit cube + tfs["normalize"] = Normalize( + d=one_hot_bounds.shape[1], + bounds=one_hot_bounds, + indices=integer_indices, + transform_on_train=False, + transform_on_eval=True, + transform_on_fantasize=False, + reverse=False, + ) + tf = ChainedInputTransform(**tfs) + tf.eval() + return tf diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index 09310163b5..0bc649dedf 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -25,6 +25,7 @@ from botorch.models.transforms.utils import subset_transform from botorch.models.utils import fantasize from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE +from botorch.utils.sampling import draw_sobol_samples from gpytorch import Module as GPyTorchModule from gpytorch.constraints import GreaterThan from gpytorch.priors import Prior @@ -1503,3 +1504,574 @@ def equals(self, other: InputTransform) -> bool: and (self.transform_on_fantasize == other.transform_on_fantasize) and self.categorical_features == other.categorical_features ) + + +class AnalyticProbabilisticReparameterizationInputTransform(InputTransform, Module): + r"""An input transform to prepare inputs for analytic PR. + + See [Daulton2022bopr]_ for details. + + This will typically be used in conjunction with normalization as + follows: + + In eval() mode (i.e. after training), the inputs pass + would typically be normalized to the unit cube (e.g. during candidate + optimization). + 1. These are unnormalized back to the raw input space. + 2. 
The discrete values are created. + 3. All values are normalized to the unitcube. + + TODO: consolidate this with MCProbabilisticReparameterizationInputTransform. + + """ + + def __init__( + self, + one_hot_bounds: Tensor = None, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + transform_on_train: bool = False, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + tau: float = 0.1, + ) -> None: + r"""Initialize transform. + + Args: + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer inputs. + categorical_features: The indices and cardinality of + each categorical feature. The features are assumed + to be one-hot encoded. TODO: generalize to support + alternative representations. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: A boolean indicating whether to apply the + transform when called from within a `fantasize` call. Default: True. + mc_samples: The number of MC samples. + resample: A boolean indicating whether to resample base samples + at each forward pass. + tau: The temperature parameter. + """ + super().__init__() + if integer_indices is None and categorical_features is None: + raise ValueError( + "integer_indices and/or categorical_features must be provided." + ) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + discrete_indices = [] + if integer_indices is not None and len(integer_indices) > 0: + self.register_buffer( + "integer_indices", + torch.tensor( + integer_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) + discrete_indices += integer_indices + else: + self.integer_indices = None + self.categorical_features = categorical_features + if self.categorical_features is not None: + self.categorical_start_idx = min(self.categorical_features.keys()) + # check that the trailing dimensions are categoricals + end = self.categorical_start_idx + err_msg = ( + f"{self.__class__.__name__} requires that the categorical " + "parameters are the rightmost elements." 
+ ) + for start, card in self.categorical_features.items(): + # the end of one one-hot representation should be followed + # by the start of the next + if end != start: + raise ValueError(err_msg) + end = start + card + if end != one_hot_bounds.shape[1]: + # check end + raise ValueError(err_msg) + categorical_starts = [] + categorical_ends = [] + if self.categorical_features is not None: + start = None + for i, n_categories in categorical_features.items(): + if start is None: + start = i + end = start + n_categories + categorical_starts.append(start) + categorical_ends.append(end) + discrete_indices += list(range(start, end)) + start = end + self.register_buffer( + "discrete_indices", + torch.tensor( + discrete_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_starts", + torch.tensor( + categorical_starts, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_ends", + torch.tensor( + categorical_ends, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.tau = tau + # create cartesian product of discrete options + discrete_options = [] + dim = one_hot_bounds.shape[1] + # get number of discrete parameters + num_discrete_params = 0 + if self.integer_indices is not None: + num_discrete_params += self.integer_indices.shape[0] + if self.categorical_features is not None: + num_discrete_params += len(self.categorical_features) + # add zeros for continuous params to simplify code + for _ in range(dim - len(discrete_indices)): + discrete_options.append( + torch.zeros( + 1, + dtype=torch.long, + device=one_hot_bounds.device, + ) + ) + if integer_indices is not None: + for i in range(self.integer_bounds.shape[-1]): + discrete_options.append( + torch.arange( + self.integer_bounds[0, i], + self.integer_bounds[1, i] + 1, + dtype=torch.long, + device=one_hot_bounds.device, + ) + ) + categorical_start_idx = len(discrete_options) + if categorical_features is not None: + for idx in sorted(categorical_features.keys()): + cardinality = categorical_features[idx] + discrete_options.append( + torch.arange( + cardinality, dtype=torch.long, device=one_hot_bounds.device + ) + ) + # categoricals are in numeric representation + all_discrete_options = torch.cartesian_prod(*discrete_options) + # one-hot encode the categoricals + if categorical_features is not None and len(categorical_features) > 0: + X_categ = torch.empty( + *all_discrete_options.shape[:-1], sum(categorical_features.values()) + ) + start = 0 + for i, (idx, cardinality) in enumerate( + sorted(categorical_features.items(), key=lambda kv: kv[0]) + ): + start = idx - categorical_start_idx + X_categ[..., start : start + cardinality] = one_hot( + all_discrete_options[..., i], + num_classes=cardinality, + ).to(X_categ) + all_discrete_options = torch.cat( + [all_discrete_options[..., : -len(categorical_features)], X_categ], + dim=-1, + ) + self.register_buffer("all_discrete_options", all_discrete_options) + + def get_rounding_prob(self, X: Tensor) -> Tensor: + # todo consolidate this the MCProbabilisticReparameterizationInputTransform + X_prob = X.detach().clone() + if self.integer_indices is not None: + # compute probabilities for integers + X_int = X_prob[..., self.integer_indices] + X_int_abs = X_int.abs() + offset = X_int_abs.floor() + if self.tau is not None: + X_prob[..., self.integer_indices] = torch.sigmoid( + (X_int_abs - offset - 0.5) / self.tau + ) + else: + X_prob[..., self.integer_indices] = X_int_abs - offset + # compute probabilities for 
categoricals + for start, end in zip(self.categorical_starts, self.categorical_ends): + X_categ = X_prob[..., start:end] + if self.tau is not None: + X_prob[..., start:end] = torch.softmax( + (X_categ - 0.5) / self.tau, dim=-1 + ) + else: + X_prob[..., start:end] = X_categ / X_categ.sum(dim=-1) + return X_prob[..., self.discrete_indices] + + def get_probs(self, X: Tensor) -> Tensor: + """ + Args: + X: a `batch_shape x n x d`-dim tensor + + Returns: + A `batch_shape x n_discrete x n`-dim tensors of probabilities of each discrete config under X. + """ + # note this method should be differentiable + X_prob = torch.ones( + *X.shape[:-2], + self.all_discrete_options.shape[0], + X.shape[-2], + dtype=X.dtype, + device=X.device, + ) + # n_discrete x batch_shape x n x d + all_discrete_options = self.all_discrete_options.view( + *([1] * (X.ndim - 2)), self.all_discrete_options.shape[0], *X.shape[-2:] + ).expand(*X.shape[:-2], self.all_discrete_options.shape[0], *X.shape[-2:]) + X = X.unsqueeze(-3) + if self.integer_indices is not None: + # compute probabilities for integers + X_int = X[..., self.integer_indices] + X_int_abs = X_int.abs() + offset = X_int_abs.floor() + # note we don't actually need the sigmoid here + X_prob_int = torch.sigmoid((X_int_abs - offset - 0.5) / self.tau) + # X_prob_int = X_int_abs - offset + for int_idx, idx in enumerate(self.integer_indices): + offset_i = offset[..., int_idx] + all_discrete_i = all_discrete_options[..., idx] + diff = (offset_i + 1) - all_discrete_i + round_up_mask = diff == 0 + round_down_mask = diff == 1 + neither_mask = ~(round_up_mask | round_down_mask) + prob = X_prob_int[..., int_idx].expand(round_up_mask.shape) + # need to be careful with in-place ops here for autograd + X_prob[round_up_mask] = X_prob[round_up_mask] * prob[round_up_mask] + X_prob[round_down_mask] = X_prob[round_down_mask] * ( + 1 - prob[round_down_mask] + ) + X_prob[neither_mask] = X_prob[neither_mask] * 0 + + # compute probabilities for categoricals + for start, end in zip(self.categorical_starts, self.categorical_ends): + X_categ = X[..., start:end] + X_prob_c = torch.softmax((X_categ - 0.5) / self.tau, dim=-1).expand( + *X_categ.shape[:-3], all_discrete_options.shape[-3], *X_categ.shape[-2:] + ) + for i in range(X_prob_c.shape[-1]): + mask = all_discrete_options[..., start + i] == 1 + X_prob[mask] = X_prob[mask] * X_prob_c[..., i][mask] + + return X_prob + + def transform(self, X: Tensor) -> Tensor: + r"""Round the inputs. + + This is not sample-path differentiable. + + Args: + X: A `batch_shape x 1 x n x d`-dim tensor of inputs. + + Returns: + A `batch_shape x n_discrete x n x d`-dim tensor of rounded inputs. + """ + n_discrete = self.discrete_indices.shape[0] + all_discrete_options = self.all_discrete_options.view( + *([1] * (X.ndim - 3)), self.all_discrete_options.shape[0], *X.shape[-2:] + ).expand(*X.shape[:-3], self.all_discrete_options.shape[0], *X.shape[-2:]) + if X.shape[-1] > n_discrete: + X = X.expand( + *X.shape[:-3], self.all_discrete_options.shape[0], *X.shape[-2:] + ) + return torch.cat( + [X[..., :-n_discrete], all_discrete_options[..., -n_discrete:]], dim=-1 + ) + return all_discrete_options + + def equals(self, other: InputTransform) -> bool: + r"""Check if another input transform is equivalent. + + Args: + other: Another input transform. + + Returns: + A boolean indicating if the other transform is equivalent. 
+ """ + # TODO: update this + return super().equals(other=other) and torch.equal( + self.integer_indices, other.integer_indices + ) + + +class MCProbabilisticReparameterizationInputTransform(InputTransform, Module): + r"""An input transform to prepare inputs for analytic PR. + + See [Daulton2022bopr]_ for details. + + This will typically be used in conjunction with normalization as + follows: + + In eval() mode (i.e. after training), the inputs pass + would typically be normalized to the unit cube (e.g. during candidate + optimization). + 1. These are unnormalized back to the raw input space. + 2. The discrete ordinal valeus are sampled. + 3. All values are normalized to the unitcube. + """ + + def __init__( + self, + one_hot_bounds: Tensor, + integer_indices: Optional[List[int]] = None, + categorical_features: Optional[Dict[int, int]] = None, + transform_on_train: bool = False, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + mc_samples: int = 128, + resample: bool = False, + tau: float = 0.1, + ) -> None: + r"""Initialize transform. + + Args: + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer inputs. + categorical_features: The indices and cardinality of + each categorical feature. The features are assumed + to be one-hot encoded. TODO: generalize to support + alternative representations. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: True. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. Default: True. + transform_on_fantasize: A boolean indicating whether to apply the + transform when called from within a `fantasize` call. Default: True. + mc_samples: The number of MC samples. + resample: A boolean indicating whether to resample base samples + at each forward pass. + tau: The temperature parameter. + """ + super().__init__() + if integer_indices is None and categorical_features is None: + raise ValueError( + "integer_indices and/or categorical_features must be provided." + ) + self.transform_on_train = transform_on_train + self.transform_on_eval = transform_on_eval + self.transform_on_fantasize = transform_on_fantasize + discrete_indices = [] + if integer_indices is not None and len(integer_indices) > 0: + self.register_buffer( + "integer_indices", torch.tensor(integer_indices, dtype=torch.long) + ) + discrete_indices += integer_indices + else: + self.integer_indices = None + self.categorical_features = categorical_features + if self.categorical_features is not None: + self.categorical_start_idx = min(self.categorical_features.keys()) + # check that the trailing dimensions are categoricals + end = self.categorical_start_idx + err_msg = ( + f"{self.__class__.__name__} requires that the categorical " + "parameters are the rightmost elements." 
+ ) + for start, card in self.categorical_features.items(): + # the end of one one-hot representation should be followed + # by the start of the next + if end != start: + raise ValueError(err_msg) + end = start + card + if end != one_hot_bounds.shape[1]: + # check end + raise ValueError(err_msg) + categorical_starts = [] + categorical_ends = [] + if self.categorical_features is not None: + start = None + for i, n_categories in categorical_features.items(): + if start is None: + start = i + end = start + n_categories + categorical_starts.append(start) + categorical_ends.append(end) + discrete_indices += list(range(start, end)) + start = end + self.register_buffer( + "discrete_indices", + torch.tensor( + discrete_indices, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_starts", + torch.tensor( + categorical_starts, dtype=torch.long, device=one_hot_bounds.device + ), + ) + self.register_buffer( + "categorical_ends", + torch.tensor( + categorical_ends, dtype=torch.long, device=one_hot_bounds.device + ), + ) + if integer_indices is None: + self.register_buffer( + "integer_bounds", + torch.tensor([], dtype=torch.long, device=one_hot_bounds.device), + ) + else: + self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) + self.mc_samples = mc_samples + self.resample = resample + self.tau = tau + + def get_rounding_prob(self, X: Tensor) -> Tensor: + X_prob = X.detach().clone() + if self.integer_indices is not None: + # compute probabilities for integers + X_int = X_prob[..., self.integer_indices] + X_int_abs = X_int.abs() + offset = X_int_abs.floor() + if self.tau is not None: + X_prob[..., self.integer_indices] = torch.sigmoid( + (X_int_abs - offset - 0.5) / self.tau + ) + else: + X_prob[..., self.integer_indices] = X_int_abs - offset + # compute probabilities for categoricals + for start, end in zip(self.categorical_starts, self.categorical_ends): + X_categ = X_prob[..., start:end] + if self.tau is not None: + X_prob[..., start:end] = torch.softmax( + (X_categ - 0.5) / self.tau, dim=-1 + ) + else: + X_prob[..., start:end] = X_categ / X_categ.sum(dim=-1) + return X_prob[..., self.discrete_indices] + + def transform(self, X: Tensor) -> Tensor: + r"""Round the inputs. + + This is not sample-path differentiable. + + Args: + X: A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + A `batch_shape x n x d`-dim tensor of rounded inputs. 
+ """ + X_expanded = X.expand(*X.shape[:-3], self.mc_samples, *X.shape[-2:]).clone() + X_prob = self.get_rounding_prob(X=X) + if self.integer_indices is not None: + X_int = X[..., self.integer_indices].detach() + assert X.ndim > 1 + if X.ndim == 2: + X.unsqueeze(-1) + if ( + not hasattr(self, "base_samples") + or self.base_samples.shape[-2:] != X_int.shape[-2:] + or self.resample + ): + # construct sobol base samples + bounds = torch.zeros( + 2, X_int.shape[-1], dtype=X_int.dtype, device=X_int.device + ) + bounds[1] = 1 + self.register_buffer( + "base_samples", + draw_sobol_samples( + bounds=bounds, + n=self.mc_samples, + q=X_int.shape[-2], + seed=torch.randint(0, 100000, (1,)).item(), + ), + ) + X_int_abs = X_int.abs() + # perform exact rounding + is_negative = X_int < 0 + offset = X_int_abs.floor() + prob = X_prob[..., : self.integer_indices.shape[0]] + rounding_component = (prob >= self.base_samples).to( + dtype=X.dtype, + ) + X_abs_rounded = offset + rounding_component + X_int_new = (-1) ** is_negative.to(offset) * X_abs_rounded + # clamp to bounds + X_expanded[..., self.integer_indices] = torch.minimum( + torch.maximum(X_int_new, self.integer_bounds[0]), self.integer_bounds[1] + ) + + # sample for categoricals + if self.categorical_features is not None and len(self.categorical_features) > 0: + if ( + not hasattr(self, "base_samples_categorical") + or self.base_samples_categorical.shape[-2] != X.shape[-2] + or self.resample + ): + bounds = torch.zeros( + 2, len(self.categorical_features), dtype=X.dtype, device=X.device + ) + bounds[1] = 1 + self.register_buffer( + "base_samples_categorical", + draw_sobol_samples( + bounds=bounds, + n=self.mc_samples, + q=X.shape[-2], + seed=torch.randint(0, 100000, (1,)).item(), + ), + ) + + # sample from multinomial as argmin_c [sample_c * exp(-x_c)] + sample_d_start_idx = 0 + X_categ_prob = X_prob + if self.integer_indices is not None: + n_ints = self.integer_indices.shape[0] + if n_ints > 0: + X_categ_prob = X_prob[..., n_ints:] + + for i, cardinality in enumerate(self.categorical_features.values()): + sample_d_end_idx = sample_d_start_idx + cardinality + start = self.categorical_starts[i] + end = self.categorical_ends[i] + cum_prob = X_categ_prob[ + ..., sample_d_start_idx:sample_d_end_idx + ].cumsum(dim=-1) + categories = ( + ( + (cum_prob > self.base_samples_categorical[..., i : i + 1]) + .long() + .cumsum(dim=-1) + == 1 + ) + .long() + .argmax(dim=-1) + ) + # one-hot encode + X_expanded[..., start:end] = one_hot( + categories, num_classes=cardinality + ).to(X) + sample_d_start_idx = sample_d_end_idx + + return X_expanded + + def equals(self, other: InputTransform) -> bool: + r"""Check if another input transform is equivalent. + + Args: + other: Another input transform. + + Returns: + A boolean indicating if the other transform is equivalent. + """ + return ( + super().equals(other=other) + and (self.resample == other.resample) + and torch.equal(self.base_samples, other.base_samples) + and torch.equal(self.integer_indices, other.integer_indices) + ) diff --git a/sphinx/source/acquisition.rst b/sphinx/source/acquisition.rst index a3c5eaeb5a..a5a429e46b 100644 --- a/sphinx/source/acquisition.rst +++ b/sphinx/source/acquisition.rst @@ -180,6 +180,11 @@ Penalized Acquisition Function Wrapper .. automodule:: botorch.acquisition.penalized :members: +Probabilistic Reparameterization +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. 
automodule:: botorch.acquisition.probabilistic_reparameterization + :members: + Proximal Acquisition Function Wrapper ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.proximal From f635524ce0d85b9c5b6cd176cd7afc3d57d1c73b Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Fri, 6 Jun 2025 14:40:00 +0100 Subject: [PATCH 04/17] Move is_nonnegative to optim.initializers again --- botorch/acquisition/utils.py | 33 ------------------------------ test/acquisition/test_utils.py | 36 +-------------------------------- test/optim/test_initializers.py | 35 ++++++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 68 deletions(-) diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 70abc1628a..ef1d60c7b9 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -16,8 +16,6 @@ import torch -from botorch.acquisition import analytic, monte_carlo, multi_objective # noqa F401 -from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.objective import ( MCAcquisitionObjective, PosteriorTransform, @@ -262,37 +260,6 @@ def isinstance_af( return isinstance_base_af or isinstance(__obj, __class_or_tuple) -def is_nonnegative(acq_function: AcquisitionFunction) -> bool: - r"""Determine whether a given acquisition function is non-negative. - - Args: - acq_function: The `AcquisitionFunction` instance. - - Returns: - True if `acq_function` is non-negative, False if not, or if the behavior - is unknown (for custom acquisition functions). - - Example: - >>> qEI = qExpectedImprovement(model, best_f=0.1) - >>> is_nonnegative(qEI) # returns True - """ - return isinstance_af( - acq_function, - ( - analytic.ExpectedImprovement, - analytic.ConstrainedExpectedImprovement, - analytic.ProbabilityOfImprovement, - analytic.NoisyExpectedImprovement, - monte_carlo.qExpectedImprovement, - monte_carlo.qNoisyExpectedImprovement, - monte_carlo.qProbabilityOfImprovement, - multi_objective.analytic.ExpectedHypervolumeImprovement, - multi_objective.monte_carlo.qExpectedHypervolumeImprovement, - multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement, - ), - ) - - def _prune_inferior_shared_processing( model: Model, X: Tensor, diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index 15565af85d..fd2a50ce25 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -9,7 +9,7 @@ from unittest.mock import patch import torch -from botorch.acquisition import analytic, monte_carlo, multi_objective +from botorch.acquisition import analytic from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction from botorch.acquisition.objective import ( ExpectationPosteriorTransform, @@ -25,7 +25,6 @@ get_acquisition_function, get_infeasible_cost, get_optimal_samples, - is_nonnegative, isinstance_af, project_to_sample_points, project_to_target_fidelity, @@ -228,39 +227,6 @@ def test_get_infeasible_cost(self): self.assertAllClose(M4, torch.tensor([1.0], **tkwargs)) -class TestIsNonnegative(BotorchTestCase): - def test_is_nonnegative(self): - nonneg_afs = ( - analytic.ExpectedImprovement, - analytic.ConstrainedExpectedImprovement, - analytic.ProbabilityOfImprovement, - analytic.NoisyExpectedImprovement, - monte_carlo.qExpectedImprovement, - monte_carlo.qNoisyExpectedImprovement, - monte_carlo.qProbabilityOfImprovement, - multi_objective.analytic.ExpectedHypervolumeImprovement, - multi_objective.monte_carlo.qExpectedHypervolumeImprovement, - 
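# Illustrative sketch (hypothetical, not part of the patch), assuming the
# relocated is_nonnegative keeps the wrapper-aware isinstance_af check used in
# the implementation removed above: a FixedFeature-wrapped qExpectedImprovement
# should then still be reported as non-negative, since isinstance_af presumably
# also inspects the wrapped base acquisition function.
import torch
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim.initializers import is_nonnegative

train_X = torch.rand(8, 3, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
qei = qExpectedImprovement(model=SingleTaskGP(train_X, train_Y), best_f=train_Y.max())
# fix the last input dimension to 0.5; the wrapper exposes a 2-dim space
wrapped_qei = FixedFeatureAcquisitionFunction(qei, d=3, columns=[2], values=[0.5])
print(is_nonnegative(qei))          # True
print(is_nonnegative(wrapped_qei))  # expected True via the wrapper-aware check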
multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement, - ) - mm = MockModel( - MockPosterior( - mean=torch.rand(1, 1, device=self.device), - variance=torch.ones(1, 1, device=self.device), - ) - ) - acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) - with mock.patch( - "botorch.acquisition.utils.isinstance_af", return_value=True - ) as mock_isinstance_af: - self.assertTrue(is_nonnegative(acq_function=acq_func)) - mock_isinstance_af.assert_called_once() - cargs, _ = mock_isinstance_af.call_args - self.assertIs(cargs[0], acq_func) - self.assertEqual(cargs[1], nonneg_afs) - acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0) - self.assertFalse(is_nonnegative(acq_function=acq_func)) - - class TestIsinstanceAf(BotorchTestCase): def test_isinstance_af(self): mm = MockModel( diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 83d75cd27b..80f2952843 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -11,6 +11,7 @@ from unittest import mock import torch +from botorch.acquisition import analytic, monte_carlo, multi_objective from botorch.acquisition.analytic import PosteriorMean from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction from botorch.acquisition.knowledge_gradient import qKnowledgeGradient @@ -38,6 +39,7 @@ initialize_q_batch, initialize_q_batch_nonneg, initialize_q_batch_topn, + is_nonnegative, sample_perturbed_subset_dims, sample_points_around_best, sample_q_batches_from_polytope, @@ -84,6 +86,39 @@ def test_constraint_check(self) -> None: self.assertAlmostEqual(result, 0.0, delta=1e-6) +class TestIsNonnegative(BotorchTestCase): + def test_is_nonnegative(self): + nonneg_afs = ( + analytic.ExpectedImprovement, + analytic.ConstrainedExpectedImprovement, + analytic.ProbabilityOfImprovement, + analytic.NoisyExpectedImprovement, + monte_carlo.qExpectedImprovement, + monte_carlo.qNoisyExpectedImprovement, + monte_carlo.qProbabilityOfImprovement, + multi_objective.analytic.ExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement, + ) + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, device=self.device), + variance=torch.ones(1, 1, device=self.device), + ) + ) + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) + with mock.patch( + "botorch.acquisition.utils.isinstance_af", return_value=True + ) as mock_isinstance_af: + self.assertTrue(is_nonnegative(acq_function=acq_func)) + mock_isinstance_af.assert_called_once() + cargs, _ = mock_isinstance_af.call_args + self.assertIs(cargs[0], acq_func) + self.assertEqual(cargs[1], nonneg_afs) + acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0) + self.assertFalse(is_nonnegative(acq_function=acq_func)) + + class TestInitializeQBatch(BotorchTestCase): def test_initialize_q_batch_nonneg(self): for dtype in (torch.float, torch.double): From 88c423067b29b40c11ab1679448f2cafc308f603 Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Fri, 6 Jun 2025 14:50:40 +0100 Subject: [PATCH 05/17] Fix `FixedFeature` feature --- botorch/acquisition/fixed_feature.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/botorch/acquisition/fixed_feature.py b/botorch/acquisition/fixed_feature.py index 626db57c65..6f789d8934 100644 --- a/botorch/acquisition/fixed_feature.py +++ b/botorch/acquisition/fixed_feature.py @@ -154,19 +154,7 @@ def forward(self, X: Tensor): X_full = 
self._construct_X_full(X) return self.acq_func(X_full) - @property - def X_pending(self): - r"""Return the `X_pending` of the base acquisition function.""" - try: - return self.acq_func.X_pending - except (ValueError, AttributeError): - raise ValueError( - f"Base acquisition function {type(self.acq_func).__name__} " - "does not have an `X_pending` attribute." - ) - - @X_pending.setter - def X_pending(self, X_pending: Tensor | None): + def set_X_pending(self, X_pending: Tensor | None): r"""Sets the `X_pending` of the base acquisition function.""" if X_pending is not None: full_X_pending = self._construct_X_full(X_pending) From 49b8edeab69761a3eabc9be671a51598a40bd9b8 Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Fri, 6 Jun 2025 14:54:22 +0100 Subject: [PATCH 06/17] Fix `PenalizedAcquisition` --- botorch/acquisition/penalized.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/botorch/acquisition/penalized.py b/botorch/acquisition/penalized.py index d981514c10..6fe4c68c45 100644 --- a/botorch/acquisition/penalized.py +++ b/botorch/acquisition/penalized.py @@ -16,10 +16,8 @@ import torch from botorch.acquisition.acquisition import AcquisitionFunction -from botorch.acquisition.analytic import AnalyticAcquisitionFunction from botorch.acquisition.objective import GenericMCObjective from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper -from botorch.exceptions.errors import UnsupportedError from torch import Tensor @@ -234,19 +232,6 @@ def forward(self, X: Tensor) -> Tensor: penalty_term = self.penalty_func(X) return raw_value - self.regularization_parameter * penalty_term - @property - def X_pending(self) -> Tensor | None: - return self.raw_acqf.X_pending - - def set_X_pending(self, X_pending: Tensor | None = None) -> None: - if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction): - self.raw_acqf.set_X_pending(X_pending=X_pending) - else: - raise UnsupportedError( - "The raw acquisition function is Analytic and does not account " - "for X_pending yet." - ) - def group_lasso_regularizer(X: Tensor, groups: list[list[int]]) -> Tensor: r"""Computes the group lasso regularization function for the given point. 
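# Illustrative sketch (hypothetical, not part of the patch): with X_pending
# handling now living on AbstractAcquisitionFunctionWrapper, a new wrapper such
# as the made-up ScaledAcquisitionFunction below only needs to store the base
# acquisition function and implement forward; the X_pending property and
# set_X_pending are inherited and forwarded to the wrapped function.
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor


class ScaledAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
    """Multiply a wrapped acquisition function by a positive constant."""

    def __init__(self, acq_function: AcquisitionFunction, scale: float) -> None:
        super().__init__(acq_function=acq_function)
        self.scale = scale

    def forward(self, X: Tensor) -> Tensor:
        # pending-point bookkeeping is handled by the wrapper base class
        return self.scale * self.acq_func(X)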
From a99ac629c1f0f3043d036d5c1b064b7b2cb9b901 Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Fri, 6 Jun 2025 15:09:39 +0100 Subject: [PATCH 07/17] Fix patching of isinstance_af in test_initializers --- botorch/optim/initializers.py | 3 ++- test/optim/test_initializers.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py index a099ec5e0f..23ed4b6c82 100644 --- a/botorch/optim/initializers.py +++ b/botorch/optim/initializers.py @@ -32,6 +32,7 @@ qHypervolumeKnowledgeGradient, qMultiFidelityHypervolumeKnowledgeGradient, ) +from botorch.acquisition.utils import isinstance_af from botorch.exceptions.errors import BotorchTensorDimensionError, UnsupportedError from botorch.exceptions.warnings import ( BadInitialCandidatesWarning, @@ -1270,7 +1271,7 @@ def is_nonnegative(acq_function: AcquisitionFunction) -> bool: >>> qEI = qExpectedImprovement(model, best_f=0.1) >>> is_nonnegative(qEI) # returns True """ - return isinstance( + return isinstance_af( acq_function, ( analytic.ExpectedImprovement, diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py index 80f2952843..341209b5ea 100644 --- a/test/optim/test_initializers.py +++ b/test/optim/test_initializers.py @@ -108,7 +108,7 @@ def test_is_nonnegative(self): ) acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) with mock.patch( - "botorch.acquisition.utils.isinstance_af", return_value=True + "botorch.optim.initializers.isinstance_af", return_value=True ) as mock_isinstance_af: self.assertTrue(is_nonnegative(acq_function=acq_func)) mock_isinstance_af.assert_called_once() From c103033e2960d6a2967eecefeda709ba19b6ff81 Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Fri, 6 Jun 2025 15:46:06 +0100 Subject: [PATCH 08/17] Update types to PEP 604; fix flake8 line length errors --- botorch/acquisition/__init__.py | 6 +++ .../probabilistic_reparameterization.py | 37 ++++++++++--------- botorch/acquisition/wrapper.py | 10 +++-- 3 files changed, 32 insertions(+), 21 deletions(-) diff --git a/botorch/acquisition/__init__.py b/botorch/acquisition/__init__.py index 862897fe11..8d5a0681dd 100644 --- a/botorch/acquisition/__init__.py +++ b/botorch/acquisition/__init__.py @@ -79,12 +79,17 @@ qExpectedUtilityOfBestOption, ) from botorch.acquisition.prior_guided import PriorGuidedAcquisitionFunction +from botorch.acquisition.probabilistic_reparameterization import ( + AnalyticProbabilisticReparameterization, + MCProbabilisticReparameterization, +) from botorch.acquisition.proximal import ProximalAcquisitionFunction __all__ = [ "AcquisitionFunction", "AnalyticAcquisitionFunction", "AnalyticExpectedUtilityOfBestOption", + "AnalyticProbabilisticReparameterization", "ConstrainedExpectedImprovement", "DecoupledAcquisitionFunction", "ExpectedImprovement", @@ -93,6 +98,7 @@ "FixedFeatureAcquisitionFunction", "GenericCostAwareUtility", "InverseCostWeightedUtility", + "MCProbabilisticReparameterization", "NoisyExpectedImprovement", "OneShotAcquisitionFunction", "PairwiseBayesianActiveLearningByDisagreement", diff --git a/botorch/acquisition/probabilistic_reparameterization.py b/botorch/acquisition/probabilistic_reparameterization.py index 5c6428985e..dcc042c67a 100644 --- a/botorch/acquisition/probabilistic_reparameterization.py +++ b/botorch/acquisition/probabilistic_reparameterization.py @@ -11,7 +11,6 @@ """ from contextlib import ExitStack -from typing import Dict, List, Optional import torch from botorch.acquisition.acquisition import AcquisitionFunction 
@@ -42,15 +41,15 @@ def forward( X: Tensor, acq_function: AcquisitionFunction, input_tf: InputTransform, - batch_limit: Optional[int], + batch_limit: int | None, integer_indices: Tensor, cont_indices: Tensor, categorical_indices: Tensor, use_ma_baseline: bool, - one_hot_to_numeric: Optional[OneHotToNumeric], - ma_counter: Optional[Tensor], - ma_hidden: Optional[Tensor], - ma_decay: Optional[float], + one_hot_to_numeric: OneHotToNumeric | None, + ma_counter: Tensor | None, + ma_hidden: Tensor | None, + ma_decay: float | None, ): """Evaluate the expectation of the acquisition function under probabilistic reparameterization. Compute this in chunks of size @@ -114,7 +113,8 @@ def forward( ctx.base_samples_categorical = input_tf[ "round" ].base_samples_categorical.clone() - # compute the acquisition function where inputs are rounded according to base_samples < prob + # compute the acquisition function where inputs are rounded according + # to base_samples < prob ctx.tilde_x_samples = tilde_x_samples ctx.use_ma_baseline = use_ma_baseline acq_values_list = [] @@ -190,13 +190,14 @@ def backward(ctx, grad_output): # use autograd for gradients w.r.t. the continuous parameters if ctx.cont_X is not None: auto_grad = torch.autograd.grad( - # note: this multiplies the gradient of mean_acq_values w.r.t to input - # by grad_output + # note: this multiplies the gradient of mean_acq_values + # w.r.t to input by grad_output mean_acq_values, ctx.cont_X, grad_outputs=grad_output, )[0] - # overwrite grad_output since the previous step already applied the chain rule + # overwrite grad_output since the previous step already + # applied the chain rule new_grads[..., cont_indices] = auto_grad return ( new_grads, @@ -229,8 +230,8 @@ def __init__( self, acq_function: AcquisitionFunction, one_hot_bounds: Tensor, - integer_indices: Optional[List[int]] = None, - categorical_features: Optional[Dict[int, int]] = None, + integer_indices: list[int] | None = None, + categorical_features: dict[int, int] | None = None, batch_limit: int = 32, apply_numeric: bool = False, **kwargs, @@ -384,8 +385,8 @@ def __init__( self, acq_function: AcquisitionFunction, one_hot_bounds: Tensor, - integer_indices: Optional[List[int]] = None, - categorical_features: Optional[Dict[int, int]] = None, + integer_indices: list[int] | None = None, + categorical_features: dict[int, int] | None = None, batch_limit: int = 32, apply_numeric: bool = False, tau: float = 0.1, @@ -442,8 +443,8 @@ def forward(self, X: Tensor) -> Tensor: unnormalized_X = X # this is batch_shape x n_discrete (after squeezing) probs = self.input_transform["round"].get_probs(X=unnormalized_X).squeeze(-1) - # TODO: filter discrete configs with zero probability - # this would require padding because there may be a different number in each batch. + # TODO: filter discrete configs with zero probability. This would require + # padding because there may be a different number in each batch. 
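# Illustrative sketch (hypothetical, not part of the patch) of what the chunked
# loop below accumulates: the analytic PR value is (roughly) the probability-
# weighted average of the base acquisition function over every enumerated
# discrete configuration z, i.e. sum_z p(z | x) * acq(x_cont, z).
import torch

probs = torch.tensor([0.7, 0.2, 0.1])        # p(z | x) for three configurations
acq_values = torch.tensor([1.2, 0.4, 0.05])  # base acquisition value at each one
pr_value = (probs * acq_values).sum()        # tensor(0.9250)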
while start_idx < X_discrete_all.shape[-3]: end_idx = min(start_idx + self.batch_limit, X_discrete_all.shape[-3]) acq_values = self.acq_func(X_discrete_all[..., start_idx:end_idx, :, :]) @@ -465,8 +466,8 @@ def __init__( self, acq_function: AcquisitionFunction, one_hot_bounds: Tensor, - integer_indices: Optional[List[int]] = None, - categorical_features: Optional[Dict[int, int]] = None, + integer_indices: list[int] | None = None, + categorical_features: dict[int, int] | None = None, batch_limit: int = 32, apply_numeric: bool = False, mc_samples: int = 128, diff --git a/botorch/acquisition/wrapper.py b/botorch/acquisition/wrapper.py index 08dfbd2849..8655790153 100644 --- a/botorch/acquisition/wrapper.py +++ b/botorch/acquisition/wrapper.py @@ -11,7 +11,6 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Optional from botorch.acquisition.acquisition import AcquisitionFunction from torch import Tensor @@ -22,11 +21,16 @@ class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC): r"""Abstract acquisition wrapper.""" def __init__(self, acq_function: AcquisitionFunction) -> None: + r"""Initialize the acquisition function wrapper. + + Args: + acq_function: The inner acquisition function to wrap. + """ Module.__init__(self) self.acq_func = acq_function @property - def X_pending(self) -> Optional[Tensor]: + def X_pending(self) -> Tensor | None: r"""Return the `X_pending` of the base acquisition function.""" try: return self.acq_func.X_pending @@ -36,7 +40,7 @@ def X_pending(self) -> Optional[Tensor]: "does not have an `X_pending` attribute." ) - def set_X_pending(self, X_pending: Optional[Tensor]) -> None: + def set_X_pending(self, X_pending: Tensor | None) -> None: r"""Sets the `X_pending` of the base acquisition function.""" self.acq_func.set_X_pending(X_pending) From 7dcf18623e33e53fd534d564bbb3ded85c6e265c Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Wed, 11 Jun 2025 16:57:41 +0100 Subject: [PATCH 09/17] Add test for PR with binary search space --- .../test_probabilistic_reparameterization.py | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 test/acquisition/test_probabilistic_reparameterization.py diff --git a/test/acquisition/test_probabilistic_reparameterization.py b/test/acquisition/test_probabilistic_reparameterization.py new file mode 100644 index 0000000000..b96d47f4cf --- /dev/null +++ b/test/acquisition/test_probabilistic_reparameterization.py @@ -0,0 +1,69 @@ +import torch +from botorch.acquisition import LogExpectedImprovement, qLogExpectedImprovement +from botorch.acquisition.probabilistic_reparameterization import ( + AnalyticProbabilisticReparameterization, + MCProbabilisticReparameterization, +) +from botorch.optim import optimize_acqf +from botorch.test_functions.synthetic import AckleyMixed +from botorch.utils.test_helpers import get_model +from botorch.utils.testing import BotorchTestCase + + +class TestProbabilisticReparameterizationInputTransform(BotorchTestCase): + def test_probabilistic_reparameterization_input_transform(self): + pass + + +class TestProbabilisticReparameterization(BotorchTestCase): + def test_probabilistic_reparameterization_binary( + self, + pr_acq_func_cls=AnalyticProbabilisticReparameterization, + base_acq_func_cls=LogExpectedImprovement, + ): + torch.manual_seed(0) + f = AckleyMixed(dim=13, randomize_optimum=True) + train_X = torch.rand((10, f.dim), dtype=torch.float64) + train_X[:, f.discrete_inds] = train_X[:, f.discrete_inds].round() + train_Y = 
f(train_X).unsqueeze(-1) + model = get_model(train_X, train_Y) + base_acq_func = base_acq_func_cls(model, best_f=train_Y.max()) + + pr_acq_func = pr_acq_func_cls( + acq_function=base_acq_func, + one_hot_bounds=f.bounds, + integer_indices=f.discrete_inds, + batch_limit=32, + ) + + candidate, _ = optimize_acqf( + acq_function=pr_acq_func, + bounds=f.bounds, + q=1, + num_restarts=10, + raw_samples=20, + options={"maxiter": 5}, + ) + + self.assertTrue(candidate.shape == (1, f.dim)) + + def test_probabilistic_reparameterization_binary_analytic_qLogEI(self): + self.test_probabilistic_reparameterization_binary( + pr_acq_func_cls=AnalyticProbabilisticReparameterization, + base_acq_func_cls=qLogExpectedImprovement, + ) + + def test_probabilistic_reparameterization_binary_MC_LogEI(self): + self.test_probabilistic_reparameterization_binary( + pr_acq_func_cls=MCProbabilisticReparameterization, + base_acq_func_cls=LogExpectedImprovement, + ) + + def test_probabilistic_reparameterization_binary_MC_qLogEI(self): + self.test_probabilistic_reparameterization_binary( + pr_acq_func_cls=MCProbabilisticReparameterization, + base_acq_func_cls=qLogExpectedImprovement, + ) + + def test_probabilistic_reparameterization_categorical(self): + pass From 591b7eacf9dc989ac0966ceac9a88f450dd30e40 Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Wed, 11 Jun 2025 19:26:21 +0100 Subject: [PATCH 10/17] Add test for PR with categorical search space --- .../test_probabilistic_reparameterization.py | 120 +++++++++++++++++- 1 file changed, 115 insertions(+), 5 deletions(-) diff --git a/test/acquisition/test_probabilistic_reparameterization.py b/test/acquisition/test_probabilistic_reparameterization.py index b96d47f4cf..f1c45f1aef 100644 --- a/test/acquisition/test_probabilistic_reparameterization.py +++ b/test/acquisition/test_probabilistic_reparameterization.py @@ -1,21 +1,57 @@ +from typing import Any + import torch from botorch.acquisition import LogExpectedImprovement, qLogExpectedImprovement from botorch.acquisition.probabilistic_reparameterization import ( AnalyticProbabilisticReparameterization, MCProbabilisticReparameterization, ) +from botorch.models.transforms.factory import ( + get_probabilistic_reparameterization_input_transform, + get_rounding_input_transform, +) +from botorch.models.transforms.input import OneHotToNumeric from botorch.optim import optimize_acqf -from botorch.test_functions.synthetic import AckleyMixed +from botorch.test_functions.synthetic import Ackley, AckleyMixed +from botorch.utils.sampling import draw_sobol_samples from botorch.utils.test_helpers import get_model from botorch.utils.testing import BotorchTestCase +def get_categorical_features_dict(feature_to_num_categories: dict[int, int]): + r"""Get the mapping of starting index in one-hot space to cardinality. + + This mapping is used to construct the OneHotToNumeric transform. This + requires that all of the categorical parameters are the rightmost elements. + + Args: + feature_to_num_categories: Mapping of feature index to cardinality in the + untransformed space. 
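# Illustrative sketch (hypothetical, not part of the patch) of what this helper
# produces for the categorical test below, where raw features 3 and 4 have 3 and
# 5 categories and the first three columns stay continuous:
#     get_categorical_features_dict({3: 3, 4: 5}) -> {3: 3, 6: 5}
# The same mapping computed directly:
feature_to_num_categories = {3: 3, 4: 5}
start = min(feature_to_num_categories)
categorical_features = {}
for _, cardinality in sorted(feature_to_num_categories.items()):
    categorical_features[start] = cardinality
    start += cardinality
print(categorical_features)  # {3: 3, 6: 5}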
+ + """ + start = None + categorical_features = {} + for idx, cardinality in sorted( + feature_to_num_categories.items(), key=lambda kv: kv[0] + ): + if start is None: + start = idx + categorical_features[start] = cardinality + # add cardinality to start + start += cardinality + return categorical_features + + class TestProbabilisticReparameterizationInputTransform(BotorchTestCase): def test_probabilistic_reparameterization_input_transform(self): - pass + _ = get_probabilistic_reparameterization_input_transform() class TestProbabilisticReparameterization(BotorchTestCase): + def setUp(self): + super().setUp() + self.tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double} + def test_probabilistic_reparameterization_binary( self, pr_acq_func_cls=AnalyticProbabilisticReparameterization, @@ -23,7 +59,7 @@ def test_probabilistic_reparameterization_binary( ): torch.manual_seed(0) f = AckleyMixed(dim=13, randomize_optimum=True) - train_X = torch.rand((10, f.dim), dtype=torch.float64) + train_X = torch.rand((10, f.dim), **self.tkwargs) train_X[:, f.discrete_inds] = train_X[:, f.discrete_inds].round() train_Y = f(train_X).unsqueeze(-1) model = get_model(train_X, train_Y) @@ -65,5 +101,79 @@ def test_probabilistic_reparameterization_binary_MC_qLogEI(self): base_acq_func_cls=qLogExpectedImprovement, ) - def test_probabilistic_reparameterization_categorical(self): - pass + def test_probabilistic_reparameterization_categorical( + self, + pr_acq_func_cls=AnalyticProbabilisticReparameterization, + base_acq_func_cls=LogExpectedImprovement, + ): + torch.manual_seed(0) + # we use Ackley here to ensure the categorical features are the + # rightmost elements + dim = 5 + bounds = [(0.0, 1.0)] * 5 + f = Ackley(dim=dim, bounds=bounds) + # convert the continuous features into categorical features + feature_to_num_categories = {3: 3, 4: 5} + for feature_idx, num_categories in feature_to_num_categories.items(): + f.bounds[1, feature_idx] = num_categories - 1 + + categorical_features = get_categorical_features_dict(feature_to_num_categories) + one_hot_bounds = torch.zeros( + 2, 3 + sum(categorical_features.values()), **self.tkwargs + ) + one_hot_bounds[1, :] = 1.0 + init_exact_rounding_func = get_rounding_input_transform( + one_hot_bounds=one_hot_bounds, + categorical_features=categorical_features, + initialization=True, + ) + one_hot_to_numeric = OneHotToNumeric( + dim=one_hot_bounds.shape[1], categorical_features=categorical_features + ).to(**self.tkwargs) + + raw_X = ( + draw_sobol_samples(one_hot_bounds, n=10, q=1).squeeze(-2).to(**self.tkwargs) + ) + train_X = init_exact_rounding_func(raw_X) + train_Y = f(one_hot_to_numeric(train_X)).unsqueeze(-1) + model = get_model(train_X, train_Y) + base_acq_func = base_acq_func_cls(model, best_f=train_Y.max()) + + pr_acq_func = pr_acq_func_cls( + acq_function=base_acq_func, + one_hot_bounds=one_hot_bounds, + categorical_features=categorical_features, + integer_indices=None, + batch_limit=32, + ) + + raw_candidate, _ = optimize_acqf( + acq_function=pr_acq_func, + bounds=one_hot_bounds, + q=1, + num_restarts=10, + raw_samples=20, + options={"maxiter": 5}, + # gen_candidates=gen_candidates_scipy, + ) + # candidates are generated in the one-hot space + candidate = one_hot_to_numeric(raw_candidate) + self.assertTrue(candidate.shape == (1, f.dim)) + + def test_probabilistic_reparameterization_categorical_analytic_qLogEI(self): + self.test_probabilistic_reparameterization_categorical( + pr_acq_func_cls=AnalyticProbabilisticReparameterization, + 
base_acq_func_cls=qLogExpectedImprovement, + ) + + def test_probabilistic_reparameterization_categorical_MC_LogEI(self): + self.test_probabilistic_reparameterization_categorical( + pr_acq_func_cls=MCProbabilisticReparameterization, + base_acq_func_cls=LogExpectedImprovement, + ) + + def test_probabilistic_reparameterization_categorical_MC_qLogEI(self): + self.test_probabilistic_reparameterization_categorical( + pr_acq_func_cls=MCProbabilisticReparameterization, + base_acq_func_cls=qLogExpectedImprovement, + ) From 883fb8928bd7c8cf969295cd3ed0fa83c483f5f3 Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Thu, 12 Jun 2025 22:26:12 +0100 Subject: [PATCH 11/17] Compare analytic vs MC PR in test --- .../test_probabilistic_reparameterization.py | 63 +++++++++++++------ 1 file changed, 43 insertions(+), 20 deletions(-) diff --git a/test/acquisition/test_probabilistic_reparameterization.py b/test/acquisition/test_probabilistic_reparameterization.py index f1c45f1aef..03739b8543 100644 --- a/test/acquisition/test_probabilistic_reparameterization.py +++ b/test/acquisition/test_probabilistic_reparameterization.py @@ -1,3 +1,4 @@ +import itertools from typing import Any import torch @@ -6,12 +7,13 @@ AnalyticProbabilisticReparameterization, MCProbabilisticReparameterization, ) +from botorch.generation.gen import gen_candidates_scipy, gen_candidates_torch from botorch.models.transforms.factory import ( get_probabilistic_reparameterization_input_transform, get_rounding_input_transform, ) from botorch.models.transforms.input import OneHotToNumeric -from botorch.optim import optimize_acqf +from botorch.optim import optimize_acqf, optimize_acqf_mixed from botorch.test_functions.synthetic import Ackley, AckleyMixed from botorch.utils.sampling import draw_sobol_samples from botorch.utils.test_helpers import get_model @@ -44,6 +46,8 @@ def get_categorical_features_dict(feature_to_num_categories: dict[int, int]): class TestProbabilisticReparameterizationInputTransform(BotorchTestCase): def test_probabilistic_reparameterization_input_transform(self): + # TODO: test this functionality in factory + # test the actual transform here. 
_ = get_probabilistic_reparameterization_input_transform() @@ -54,50 +58,69 @@ def setUp(self): def test_probabilistic_reparameterization_binary( self, - pr_acq_func_cls=AnalyticProbabilisticReparameterization, base_acq_func_cls=LogExpectedImprovement, ): torch.manual_seed(0) - f = AckleyMixed(dim=13, randomize_optimum=True) + f = AckleyMixed(dim=6, randomize_optimum=True) train_X = torch.rand((10, f.dim), **self.tkwargs) train_X[:, f.discrete_inds] = train_X[:, f.discrete_inds].round() train_Y = f(train_X).unsqueeze(-1) model = get_model(train_X, train_Y) base_acq_func = base_acq_func_cls(model, best_f=train_Y.max()) - pr_acq_func = pr_acq_func_cls( + pr_acq_func_params = dict( acq_function=base_acq_func, one_hot_bounds=f.bounds, integer_indices=f.discrete_inds, batch_limit=32, ) - - candidate, _ = optimize_acqf( - acq_function=pr_acq_func, + optimize_acqf_params = dict( bounds=f.bounds, q=1, num_restarts=10, - raw_samples=20, - options={"maxiter": 5}, + raw_samples=512, + options={ + "batch_limit": 5, + "maxiter": 200, + "rel_tol": float("-inf"), + }, ) - self.assertTrue(candidate.shape == (1, f.dim)) + pr_analytic_acq_func = AnalyticProbabilisticReparameterization( + **pr_acq_func_params + ) - def test_probabilistic_reparameterization_binary_analytic_qLogEI(self): - self.test_probabilistic_reparameterization_binary( - pr_acq_func_cls=AnalyticProbabilisticReparameterization, - base_acq_func_cls=qLogExpectedImprovement, + pr_mc_acq_func = MCProbabilisticReparameterization(**pr_acq_func_params) + + candidate_analytic, acq_values_analytic = optimize_acqf( + acq_function=pr_analytic_acq_func, + gen_candidates=gen_candidates_scipy, + **optimize_acqf_params, ) - def test_probabilistic_reparameterization_binary_MC_LogEI(self): - self.test_probabilistic_reparameterization_binary( - pr_acq_func_cls=MCProbabilisticReparameterization, - base_acq_func_cls=LogExpectedImprovement, + candidate_mc, acq_values_mc = optimize_acqf( + acq_function=pr_mc_acq_func, + gen_candidates=gen_candidates_torch, + **optimize_acqf_params, ) - def test_probabilistic_reparameterization_binary_MC_qLogEI(self): + fixed_features_list = [ + {feat_dim: val for feat_dim, val in enumerate(vals)} + for vals in itertools.product([0, 1], repeat=len(f.discrete_inds)) + ] + candidate_exhaustive, acq_values_exhaustive = optimize_acqf_mixed( + acq_function=base_acq_func, + fixed_features_list=fixed_features_list, + **optimize_acqf_params, + ) + + self.assertTrue(candidate_analytic.shape == (1, f.dim)) + self.assertTrue(candidate_mc.shape == (1, f.dim)) + self.assertAllClose(candidate_analytic, candidate_mc) + self.assertAllClose(acq_values_analytic, acq_values_mc) + + def test_probabilistic_reparameterization_binary_qLogEI(self): self.test_probabilistic_reparameterization_binary( - pr_acq_func_cls=MCProbabilisticReparameterization, base_acq_func_cls=qLogExpectedImprovement, ) From f321b1fc1904c0d8cb78399d46fb81e77deba362 Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Wed, 18 Jun 2025 10:54:30 +0100 Subject: [PATCH 12/17] Fix indexing bug when enumerating all discrete options --- botorch/models/transforms/input.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index 7802bfc673..83676b7a95 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -1796,7 +1796,7 @@ def __init__( to be one-hot encoded. TODO: generalize to support alternative representations. 
transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: True. + transforms in train() mode. Default: False. transform_on_eval: A boolean indicating whether to apply the transform in eval() mode. Default: True. transform_on_fantasize: A boolean indicating whether to apply the @@ -1925,7 +1925,7 @@ def __init__( ): start = idx - categorical_start_idx X_categ[..., start : start + cardinality] = one_hot( - all_discrete_options[..., i], + all_discrete_options[..., -len(categorical_features) + i], num_classes=cardinality, ).to(X_categ) all_discrete_options = torch.cat( @@ -2095,7 +2095,7 @@ def __init__( to be one-hot encoded. TODO: generalize to support alternative representations. transform_on_train: A boolean indicating whether to apply the - transforms in train() mode. Default: True. + transforms in train() mode. Default: False. transform_on_eval: A boolean indicating whether to apply the transform in eval() mode. Default: True. transform_on_fantasize: A boolean indicating whether to apply the From ca69474ba13e178c7caac26f88520383ff1fef3c Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Wed, 18 Jun 2025 11:30:35 +0100 Subject: [PATCH 13/17] Test constructing PR input transforms --- botorch/models/transforms/input.py | 3 + .../test_probabilistic_reparameterization.py | 125 ++++++++++++++++-- 2 files changed, 120 insertions(+), 8 deletions(-) diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index 83676b7a95..dab0ec1f10 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -1894,6 +1894,9 @@ def __init__( ) ) if integer_indices is not None: + # FIXME: this assumes that the integer dimensions are after the continuous + # if we want to enforce this, we should test for it similarly to + # categoricals for i in range(self.integer_bounds.shape[-1]): discrete_options.append( torch.arange( diff --git a/test/acquisition/test_probabilistic_reparameterization.py b/test/acquisition/test_probabilistic_reparameterization.py index 03739b8543..f2a3e9b180 100644 --- a/test/acquisition/test_probabilistic_reparameterization.py +++ b/test/acquisition/test_probabilistic_reparameterization.py @@ -8,11 +8,12 @@ MCProbabilisticReparameterization, ) from botorch.generation.gen import gen_candidates_scipy, gen_candidates_torch -from botorch.models.transforms.factory import ( - get_probabilistic_reparameterization_input_transform, - get_rounding_input_transform, +from botorch.models.transforms.factory import get_rounding_input_transform +from botorch.models.transforms.input import ( + AnalyticProbabilisticReparameterizationInputTransform, + MCProbabilisticReparameterizationInputTransform, + OneHotToNumeric, ) -from botorch.models.transforms.input import OneHotToNumeric from botorch.optim import optimize_acqf, optimize_acqf_mixed from botorch.test_functions.synthetic import Ackley, AckleyMixed from botorch.utils.sampling import draw_sobol_samples @@ -45,10 +46,118 @@ def get_categorical_features_dict(feature_to_num_categories: dict[int, int]): class TestProbabilisticReparameterizationInputTransform(BotorchTestCase): - def test_probabilistic_reparameterization_input_transform(self): - # TODO: test this functionality in factory - # test the actual transform here. 
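# Illustrative sketch (hypothetical, not part of the patch) of the enumeration
# the analytic transform builds after the indexing fix above, for one integer in
# {0, ..., 4} and one categorical with 3 levels (the construction test below
# checks the same layout, plus a leading continuous column): every
# (integer, category) pair appears once, with the category one-hot encoded into
# the rightmost columns.
import torch
from torch.nn.functional import one_hot

grid = torch.cartesian_prod(torch.arange(5), torch.arange(3))  # 15 x 2 pairs
all_discrete_options = torch.cat(
    [
        grid[:, :1].to(torch.double),                          # integer column
        one_hot(grid[:, 1], num_classes=3).to(torch.double),   # one-hot block
    ],
    dim=-1,
)  # 15 x 4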
- _ = get_probabilistic_reparameterization_input_transform() + def setUp(self): + super().setUp() + self.tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double} + self.one_hot_bounds = torch.tensor( + [ + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 4.0, 5.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ], + **self.tkwargs, + ) + + self.analytic_params = dict( + transform_on_train=False, + transform_on_eval=True, + transform_on_fantasize=True, + tau=0.1, + ) + + self.mc_params = dict( + **self.analytic_params, + mc_samples=128, + resample=False, + ) + + def test_probabilistic_reparameterization_input_transform_construction(self): + bounds = self.one_hot_bounds + integer_indices = [2, 3] + categorical_features = {4: 2, 6: 3} + + # must provide either categorical or discrete features + with self.assertRaises(ValueError): + _ = AnalyticProbabilisticReparameterizationInputTransform( + one_hot_bounds=bounds, + **self.analytic_params, + ) + + with self.assertRaises(ValueError): + _ = MCProbabilisticReparameterizationInputTransform( + one_hot_bounds=bounds, + **self.mc_params, + ) + + # categorical features must be in the rightmost columns + with self.assertRaisesRegex(ValueError, "rightmost"): + _ = AnalyticProbabilisticReparameterizationInputTransform( + one_hot_bounds=bounds, + integer_indices=integer_indices, + categorical_features={0: 2}, + **self.analytic_params, + ) + with self.assertRaisesRegex(ValueError, "rightmost"): + _ = MCProbabilisticReparameterizationInputTransform( + one_hot_bounds=bounds, + integer_indices=integer_indices, + categorical_features={0: 2}, + **self.mc_params, + ) + + # correct construction passes without raising errors + _ = AnalyticProbabilisticReparameterizationInputTransform( + one_hot_bounds=bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + **self.analytic_params, + ) + _ = MCProbabilisticReparameterizationInputTransform( + one_hot_bounds=bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + **self.mc_params, + ) + + # analytic generates all discrete options correctly + # use subset of features so that we can manually generate all options + sub_bounds = bounds[:, [0, 2, 6, 7, 8]] + sub_integer_indices = [1] + sub_categorical_features = {2: 3} + tf_analytic = AnalyticProbabilisticReparameterizationInputTransform( + one_hot_bounds=sub_bounds, + integer_indices=sub_integer_indices, + categorical_features=sub_categorical_features, + **self.analytic_params, + ) + + num_discrete_options = 5 * 3 + expected_all_discrete_options = torch.zeros( + (num_discrete_options, sub_bounds.shape[-1]) + ) + expected_all_discrete_options[:, 1] = torch.repeat_interleave( + torch.arange(5), 3 + ) + expected_all_discrete_options[:, 2:] = torch.eye(3).repeat([5, 1]) + + self.assertAllClose( + expected_all_discrete_options, tf_analytic.all_discrete_options + ) + + def test_probabilistic_reparameterization_input_transform_forward(self): + bounds = self.one_hot_bounds + integer_indices = [2, 3] + categorical_features = {4: 2, 6: 3} + + tf_analytic = AnalyticProbabilisticReparameterizationInputTransform( + one_hot_bounds=bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + **self.analytic_params, + ) + + X = torch.tensor([[[0.2, 0.8, 3.2, 1.5, 1.0, 0.0, 0.0, 0.0, 1.0]]]) + X_transformed = tf_analytic.transform(X) + print(X_transformed.shape) class TestProbabilisticReparameterization(BotorchTestCase): From ebfc9d7a8db9b51146e5fde447e095750302bd71 Mon Sep 17 00:00:00 
2001 From: Toby Boyne Date: Wed, 18 Jun 2025 14:33:41 +0100 Subject: [PATCH 14/17] Test forward pass of PR input transform --- botorch/models/transforms/input.py | 4 +- .../test_probabilistic_reparameterization.py | 47 +++++++++++++++++-- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index dab0ec1f10..04090fdc76 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -2213,10 +2213,10 @@ def transform(self, X: Tensor) -> Tensor: This is not sample-path differentiable. Args: - X: A `batch_shape x n x d`-dim tensor of inputs. + X: A `batch_shape x 1 x n x d`-dim tensor of inputs. Returns: - A `batch_shape x n x d`-dim tensor of rounded inputs. + A `batch_shape x mc_samples x n x d`-dim tensor of rounded inputs. """ X_expanded = X.expand(*X.shape[:-3], self.mc_samples, *X.shape[-2:]).clone() X_prob = self.get_rounding_prob(X=X) diff --git a/test/acquisition/test_probabilistic_reparameterization.py b/test/acquisition/test_probabilistic_reparameterization.py index f2a3e9b180..4b162e644d 100644 --- a/test/acquisition/test_probabilistic_reparameterization.py +++ b/test/acquisition/test_probabilistic_reparameterization.py @@ -155,9 +155,50 @@ def test_probabilistic_reparameterization_input_transform_forward(self): **self.analytic_params, ) - X = torch.tensor([[[0.2, 0.8, 3.2, 1.5, 1.0, 0.0, 0.0, 0.0, 1.0]]]) - X_transformed = tf_analytic.transform(X) - print(X_transformed.shape) + X = torch.tensor( + [[[0.2, 0.8, 3.2, 1.5, 0.9, 0.05, 0.05, 0.05, 0.95]]], **self.tkwargs + ) + X_transformed_analytic = tf_analytic.transform(X) + + expected_shape = [5 * 6 * 2 * 3, 1, bounds.shape[-1]] + self.assertEqual(X_transformed_analytic.shape, torch.Size(expected_shape)) + + tf_mc = MCProbabilisticReparameterizationInputTransform( + one_hot_bounds=bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + **self.mc_params, + ) + + X_transformed_mc = tf_mc.transform(X) + + expected_shape = [tf_mc.mc_samples, 1, bounds.shape[-1]] + self.assertEqual(X_transformed_mc.shape, torch.Size(expected_shape)) + + continuous_indices = [0, 1] + discrete_indices = [ + d for d in range(bounds.shape[-1]) if d not in continuous_indices + ] + for X_transformed in [X_transformed_analytic, X_transformed_mc]: + self.assertAllClose( + X[..., continuous_indices].repeat([X_transformed.shape[0], 1, 1]), + X_transformed[..., continuous_indices], + ) + + # all discrete indices have been rounded + self.assertAllClose( + X_transformed[..., discrete_indices] % 1, + torch.zeros_like(X_transformed[..., discrete_indices]), + ) + + # for MC, all integer indices should be within [floor(X), ceil(X)] + # categoricals should be approximately proportional to their probability + self.assertTrue( + ((X.floor() <= X_transformed_mc) & (X_transformed_mc <= X.ceil()))[ + ..., integer_indices + ].all() + ) + self.assertAllClose(X_transformed_mc[..., -1].mean().item(), 0.95, atol=0.10) class TestProbabilisticReparameterization(BotorchTestCase): From 6857ddfc6ddfee0945c4e631dd9416607a0f1732 Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Wed, 18 Jun 2025 15:08:40 +0100 Subject: [PATCH 15/17] Consolidate `*PRInputTransform`s --- .../probabilistic_reparameterization.py | 2 +- botorch/models/transforms/input.py | 275 ++++++++---------- 2 files changed, 127 insertions(+), 150 deletions(-) diff --git a/botorch/acquisition/probabilistic_reparameterization.py 
b/botorch/acquisition/probabilistic_reparameterization.py index dcc042c67a..e2cb12f727 100644 --- a/botorch/acquisition/probabilistic_reparameterization.py +++ b/botorch/acquisition/probabilistic_reparameterization.py @@ -512,7 +512,7 @@ def __init__( if self.batch_limit is None: self.batch_limit = mc_samples self.use_ma_baseline = use_ma_baseline - self._pr_acq_function = _MCProbabilisticReparameterization() + self._pr_acq_function = _MCProbabilisticReparameterization # create input transform self.input_transform = get_probabilistic_reparameterization_input_transform( integer_indices=integer_indices, diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index 04090fdc76..70e5057847 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -1755,8 +1755,8 @@ def equals(self, other: InputTransform) -> bool: ) -class AnalyticProbabilisticReparameterizationInputTransform(InputTransform, Module): - r"""An input transform to prepare inputs for analytic PR. +class AbstractProbabilisticReparameterizationInputTransform(InputTransform, ABC): + r"""An abstract input transform to prepare inputs for PR. See [Daulton2022bopr]_ for details. @@ -1769,19 +1769,18 @@ class AnalyticProbabilisticReparameterizationInputTransform(InputTransform, Modu 1. These are unnormalized back to the raw input space. 2. The discrete values are created. 3. All values are normalized to the unitcube. - - TODO: consolidate this with MCProbabilisticReparameterizationInputTransform. - """ def __init__( self, - one_hot_bounds: Tensor = None, + one_hot_bounds: Tensor, integer_indices: list[int] | None = None, categorical_features: dict[int, int] | None = None, transform_on_train: bool = False, transform_on_eval: bool = True, transform_on_fantasize: bool = True, + mc_samples: int = 128, + resample: bool = False, tau: float = 0.1, ) -> None: r"""Initialize transform. 
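# Illustrative sketch (hypothetical, not part of the patch) of the
# torch.autograd.Function pattern behind _MCProbabilisticReparameterization:
# the class itself is stored (note the dropped parentheses above) and presumably
# invoked through its static apply method, with a custom backward supplying
# gradients.
import torch


class _Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


x = torch.tensor([3.0], requires_grad=True)
_Square.apply(x).backward()
print(x.grad)  # tensor([6.])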
@@ -1817,12 +1816,8 @@ def __init__( discrete_indices = [] if integer_indices is not None and len(integer_indices) > 0: self.register_buffer( - "integer_indices", - torch.tensor( - integer_indices, dtype=torch.long, device=one_hot_bounds.device - ), + "integer_indices", torch.tensor(integer_indices, dtype=torch.long) ) - self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) discrete_indices += integer_indices else: self.integer_indices = None @@ -1874,7 +1869,114 @@ def __init__( categorical_ends, dtype=torch.long, device=one_hot_bounds.device ), ) + if integer_indices is None: + self.register_buffer( + "integer_bounds", + torch.tensor([], dtype=torch.long, device=one_hot_bounds.device), + ) + else: + self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) self.tau = tau + + def get_rounding_prob(self, X: Tensor) -> Tensor: + X_prob = X.detach().clone() + if self.integer_indices is not None: + # compute probabilities for integers + X_int = X_prob[..., self.integer_indices] + X_int_abs = X_int.abs() + offset = X_int_abs.floor() + if self.tau is not None: + X_prob[..., self.integer_indices] = torch.sigmoid( + (X_int_abs - offset - 0.5) / self.tau + ) + else: + X_prob[..., self.integer_indices] = X_int_abs - offset + # compute probabilities for categoricals + for start, end in zip(self.categorical_starts, self.categorical_ends): + X_categ = X_prob[..., start:end] + if self.tau is not None: + X_prob[..., start:end] = torch.softmax( + (X_categ - 0.5) / self.tau, dim=-1 + ) + else: + X_prob[..., start:end] = X_categ / X_categ.sum(dim=-1) + return X_prob[..., self.discrete_indices] + + def equals(self, other: InputTransform) -> bool: + r"""Check if another input transform is equivalent. + + Args: + other: Another input transform. + + Returns: + A boolean indicating if the other transform is equivalent. + """ + return ( + super().equals(other=other) + and torch.equal(self.integer_indices, other.integer_indices) + and self.tau == other.tau + ) + + +class AnalyticProbabilisticReparameterizationInputTransform( + AbstractProbabilisticReparameterizationInputTransform +): + r"""An input transform to prepare inputs for analytic PR. + + See [Daulton2022bopr]_ for details. + + This will typically be used in conjunction with normalization as + follows: + + In eval() mode (i.e. after training), the inputs pass + would typically be normalized to the unit cube (e.g. during candidate + optimization). + 1. These are unnormalized back to the raw input space. + 2. The discrete values are created. + 3. All values are normalized to the unitcube. + """ + + def __init__( + self, + one_hot_bounds: Tensor = None, + integer_indices: list[int] | None = None, + categorical_features: dict[int, int] | None = None, + transform_on_train: bool = False, + transform_on_eval: bool = True, + transform_on_fantasize: bool = True, + tau: float = 0.1, + ) -> None: + r"""Initialize transform. + + Args: + one_hot_bounds: The raw search space bounds where categoricals are + encoded in one-hot representation and the integer parameters + are not normalized. + integer_indices: The indices of the integer inputs. + categorical_features: The indices and cardinality of + each categorical feature. The features are assumed + to be one-hot encoded. TODO: generalize to support + alternative representations. + transform_on_train: A boolean indicating whether to apply the + transforms in train() mode. Default: False. + transform_on_eval: A boolean indicating whether to apply the + transform in eval() mode. 
Default: True. + transform_on_fantasize: A boolean indicating whether to apply the + transform when called from within a `fantasize` call. Default: True. + mc_samples: The number of MC samples. + resample: A boolean indicating whether to resample base samples + at each forward pass. + tau: The temperature parameter. + """ + super().__init__( + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + transform_on_train=transform_on_train, + transform_on_eval=transform_on_eval, + transform_on_fantasize=transform_on_fantasize, + tau=tau, + ) # create cartesian product of discrete options discrete_options = [] dim = one_hot_bounds.shape[1] @@ -1885,7 +1987,7 @@ def __init__( if self.categorical_features is not None: num_discrete_params += len(self.categorical_features) # add zeros for continuous params to simplify code - for _ in range(dim - len(discrete_indices)): + for _ in range(dim - len(self.discrete_indices)): discrete_options.append( torch.zeros( 1, @@ -1937,31 +2039,6 @@ def __init__( ) self.register_buffer("all_discrete_options", all_discrete_options) - def get_rounding_prob(self, X: Tensor) -> Tensor: - # todo consolidate this the MCProbabilisticReparameterizationInputTransform - X_prob = X.detach().clone() - if self.integer_indices is not None: - # compute probabilities for integers - X_int = X_prob[..., self.integer_indices] - X_int_abs = X_int.abs() - offset = X_int_abs.floor() - if self.tau is not None: - X_prob[..., self.integer_indices] = torch.sigmoid( - (X_int_abs - offset - 0.5) / self.tau - ) - else: - X_prob[..., self.integer_indices] = X_int_abs - offset - # compute probabilities for categoricals - for start, end in zip(self.categorical_starts, self.categorical_ends): - X_categ = X_prob[..., start:end] - if self.tau is not None: - X_prob[..., start:end] = torch.softmax( - (X_categ - 0.5) / self.tau, dim=-1 - ) - else: - X_prob[..., start:end] = X_categ / X_categ.sum(dim=-1) - return X_prob[..., self.discrete_indices] - def get_probs(self, X: Tensor) -> Tensor: """ Args: @@ -2043,23 +2120,11 @@ def transform(self, X: Tensor) -> Tensor: ) return all_discrete_options - def equals(self, other: InputTransform) -> bool: - r"""Check if another input transform is equivalent. - Args: - other: Another input transform. - - Returns: - A boolean indicating if the other transform is equivalent. - """ - # TODO: update this - return super().equals(other=other) and torch.equal( - self.integer_indices, other.integer_indices - ) - - -class MCProbabilisticReparameterizationInputTransform(InputTransform, Module): - r"""An input transform to prepare inputs for analytic PR. +class MCProbabilisticReparameterizationInputTransform( + AbstractProbabilisticReparameterizationInputTransform +): + r"""An input transform to prepare inputs for Monte Carlo PR. See [Daulton2022bopr]_ for details. @@ -2108,104 +2173,17 @@ def __init__( at each forward pass. tau: The temperature parameter. """ - super().__init__() - if integer_indices is None and categorical_features is None: - raise ValueError( - "integer_indices and/or categorical_features must be provided." 
- ) - self.transform_on_train = transform_on_train - self.transform_on_eval = transform_on_eval - self.transform_on_fantasize = transform_on_fantasize - discrete_indices = [] - if integer_indices is not None and len(integer_indices) > 0: - self.register_buffer( - "integer_indices", torch.tensor(integer_indices, dtype=torch.long) - ) - discrete_indices += integer_indices - else: - self.integer_indices = None - self.categorical_features = categorical_features - if self.categorical_features is not None: - self.categorical_start_idx = min(self.categorical_features.keys()) - # check that the trailing dimensions are categoricals - end = self.categorical_start_idx - err_msg = ( - f"{self.__class__.__name__} requires that the categorical " - "parameters are the rightmost elements." - ) - for start, card in self.categorical_features.items(): - # the end of one one-hot representation should be followed - # by the start of the next - if end != start: - raise ValueError(err_msg) - end = start + card - if end != one_hot_bounds.shape[1]: - # check end - raise ValueError(err_msg) - categorical_starts = [] - categorical_ends = [] - if self.categorical_features is not None: - start = None - for i, n_categories in categorical_features.items(): - if start is None: - start = i - end = start + n_categories - categorical_starts.append(start) - categorical_ends.append(end) - discrete_indices += list(range(start, end)) - start = end - self.register_buffer( - "discrete_indices", - torch.tensor( - discrete_indices, dtype=torch.long, device=one_hot_bounds.device - ), - ) - self.register_buffer( - "categorical_starts", - torch.tensor( - categorical_starts, dtype=torch.long, device=one_hot_bounds.device - ), - ) - self.register_buffer( - "categorical_ends", - torch.tensor( - categorical_ends, dtype=torch.long, device=one_hot_bounds.device - ), + super().__init__( + one_hot_bounds=one_hot_bounds, + integer_indices=integer_indices, + categorical_features=categorical_features, + transform_on_train=transform_on_train, + transform_on_eval=transform_on_eval, + transform_on_fantasize=transform_on_fantasize, + tau=tau, ) - if integer_indices is None: - self.register_buffer( - "integer_bounds", - torch.tensor([], dtype=torch.long, device=one_hot_bounds.device), - ) - else: - self.register_buffer("integer_bounds", one_hot_bounds[:, integer_indices]) self.mc_samples = mc_samples self.resample = resample - self.tau = tau - - def get_rounding_prob(self, X: Tensor) -> Tensor: - X_prob = X.detach().clone() - if self.integer_indices is not None: - # compute probabilities for integers - X_int = X_prob[..., self.integer_indices] - X_int_abs = X_int.abs() - offset = X_int_abs.floor() - if self.tau is not None: - X_prob[..., self.integer_indices] = torch.sigmoid( - (X_int_abs - offset - 0.5) / self.tau - ) - else: - X_prob[..., self.integer_indices] = X_int_abs - offset - # compute probabilities for categoricals - for start, end in zip(self.categorical_starts, self.categorical_ends): - X_categ = X_prob[..., start:end] - if self.tau is not None: - X_prob[..., start:end] = torch.softmax( - (X_categ - 0.5) / self.tau, dim=-1 - ) - else: - X_prob[..., start:end] = X_categ / X_categ.sum(dim=-1) - return X_prob[..., self.discrete_indices] def transform(self, X: Tensor) -> Tensor: r"""Round the inputs. 
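# Illustrative sketch (hypothetical, not part of the patch) of the categorical
# sampling retained in the MC transform method referenced above: a uniform base
# sample is compared against the cumulative category probabilities, the first
# category whose cumulative probability exceeds the sample is selected, and the
# result is one-hot encoded.
import torch
from torch.nn.functional import one_hot

probs = torch.tensor([0.2, 0.5, 0.3])            # category probabilities
base_samples = torch.tensor([0.05, 0.25, 0.95])  # uniform(0, 1) draws
cum_prob = probs.cumsum(dim=-1)                  # [0.2, 0.7, 1.0]
selected = (
    ((cum_prob.unsqueeze(0) > base_samples.unsqueeze(-1)).long().cumsum(dim=-1) == 1)
    .long()
    .argmax(dim=-1)
)
print(selected)                          # tensor([0, 1, 2])
print(one_hot(selected, num_classes=3))  # one row per draw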
@@ -2326,5 +2304,4 @@ def equals(self, other: InputTransform) -> bool: super().equals(other=other) and (self.resample == other.resample) and torch.equal(self.base_samples, other.base_samples) - and torch.equal(self.integer_indices, other.integer_indices) ) From fab81f719a8750d8f40a9f0d3c8f0da15ae62f6a Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Fri, 20 Jun 2025 15:32:36 +0100 Subject: [PATCH 16/17] Change order of integer idxs in PR test --- .../test_probabilistic_reparameterization.py | 64 +++++++++++++------ 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/test/acquisition/test_probabilistic_reparameterization.py b/test/acquisition/test_probabilistic_reparameterization.py index 4b162e644d..378ecf42a1 100644 --- a/test/acquisition/test_probabilistic_reparameterization.py +++ b/test/acquisition/test_probabilistic_reparameterization.py @@ -206,12 +206,27 @@ def setUp(self): super().setUp() self.tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double} + self.acqf_params = dict( + batch_limit=32, + ) + + self.optimize_acqf_params = dict( + num_restarts=10, + raw_samples=512, + options={ + "batch_limit": 5, + "maxiter": 200, + "rel_tol": float("-inf"), + }, + ) + def test_probabilistic_reparameterization_binary( self, base_acq_func_cls=LogExpectedImprovement, ): torch.manual_seed(0) - f = AckleyMixed(dim=6, randomize_optimum=True) + f = AckleyMixed(dim=6, randomize_optimum=False) + f.discrete_inds = [3, 4, 5] train_X = torch.rand((10, f.dim), **self.tkwargs) train_X[:, f.discrete_inds] = train_X[:, f.discrete_inds].round() train_Y = f(train_X).unsqueeze(-1) @@ -222,18 +237,7 @@ def test_probabilistic_reparameterization_binary( acq_function=base_acq_func, one_hot_bounds=f.bounds, integer_indices=f.discrete_inds, - batch_limit=32, - ) - optimize_acqf_params = dict( - bounds=f.bounds, - q=1, - num_restarts=10, - raw_samples=512, - options={ - "batch_limit": 5, - "maxiter": 200, - "rel_tol": float("-inf"), - }, + **self.acqf_params, ) pr_analytic_acq_func = AnalyticProbabilisticReparameterization( @@ -242,32 +246,54 @@ def test_probabilistic_reparameterization_binary( pr_mc_acq_func = MCProbabilisticReparameterization(**pr_acq_func_params) + X = torch.tensor([[[0.3, 0.7, 0.8, 0.0, 0.5, 1.0]]], **self.tkwargs) + X_lb, X_ub = X.clone(), X.clone() + X_lb[..., 4] = 0.0 + X_ub[..., 4] = 1.0 + + acq_value_base_mean = (base_acq_func(X_lb) + base_acq_func(X_ub)) / 2 + acq_value_analytic = pr_analytic_acq_func(X) + acq_value_mc = pr_mc_acq_func(X) + + # this is not exact due to sigmoid transform in discrete probabilities + self.assertAllClose(acq_value_analytic, acq_value_base_mean, rtol=1e-2) + self.assertAllClose(acq_value_mc, acq_value_base_mean, rtol=1e-2) + candidate_analytic, acq_values_analytic = optimize_acqf( acq_function=pr_analytic_acq_func, + bounds=f.bounds, + q=1, gen_candidates=gen_candidates_scipy, - **optimize_acqf_params, + **self.optimize_acqf_params, ) candidate_mc, acq_values_mc = optimize_acqf( acq_function=pr_mc_acq_func, + bounds=f.bounds, + q=1, gen_candidates=gen_candidates_torch, - **optimize_acqf_params, + **self.optimize_acqf_params, ) fixed_features_list = [ - {feat_dim: val for feat_dim, val in enumerate(vals)} + {feat_dim + 3: val for feat_dim, val in enumerate(vals)} for vals in itertools.product([0, 1], repeat=len(f.discrete_inds)) ] candidate_exhaustive, acq_values_exhaustive = optimize_acqf_mixed( acq_function=base_acq_func, fixed_features_list=fixed_features_list, - **optimize_acqf_params, + bounds=f.bounds, + q=1, + 
**self.optimize_acqf_params, ) self.assertTrue(candidate_analytic.shape == (1, f.dim)) self.assertTrue(candidate_mc.shape == (1, f.dim)) - self.assertAllClose(candidate_analytic, candidate_mc) - self.assertAllClose(acq_values_analytic, acq_values_mc) + + self.assertAllClose(candidate_analytic, candidate_exhaustive, rtol=0.1) + self.assertAllClose(acq_values_analytic, acq_values_exhaustive, rtol=0.1) + self.assertAllClose(candidate_mc, candidate_exhaustive, rtol=0.1) + self.assertAllClose(acq_values_mc, acq_values_exhaustive, rtol=0.1) def test_probabilistic_reparameterization_binary_qLogEI(self): self.test_probabilistic_reparameterization_binary( From 5a1b0fc9ce9da2b5e95fff4079f48f2e06dfc128 Mon Sep 17 00:00:00 2001 From: Toby Boyne Date: Sun, 22 Jun 2025 14:13:51 +0100 Subject: [PATCH 17/17] Create test for categorical PR --- .../test_probabilistic_reparameterization.py | 101 +++++++++++++----- 1 file changed, 73 insertions(+), 28 deletions(-) diff --git a/test/acquisition/test_probabilistic_reparameterization.py b/test/acquisition/test_probabilistic_reparameterization.py index 378ecf42a1..1e0c87db27 100644 --- a/test/acquisition/test_probabilistic_reparameterization.py +++ b/test/acquisition/test_probabilistic_reparameterization.py @@ -8,6 +8,7 @@ MCProbabilisticReparameterization, ) from botorch.generation.gen import gen_candidates_scipy, gen_candidates_torch +from botorch.models import MixedSingleTaskGP from botorch.models.transforms.factory import get_rounding_input_transform from botorch.models.transforms.input import ( AnalyticProbabilisticReparameterizationInputTransform, @@ -256,8 +257,8 @@ def test_probabilistic_reparameterization_binary( acq_value_mc = pr_mc_acq_func(X) # this is not exact due to sigmoid transform in discrete probabilities - self.assertAllClose(acq_value_analytic, acq_value_base_mean, rtol=1e-2) - self.assertAllClose(acq_value_mc, acq_value_base_mean, rtol=1e-2) + self.assertAllClose(acq_value_analytic, acq_value_base_mean, rtol=0.1) + self.assertAllClose(acq_value_mc, acq_value_base_mean, rtol=0.1) candidate_analytic, acq_values_analytic = optimize_acqf( acq_function=pr_analytic_acq_func, @@ -302,7 +303,6 @@ def test_probabilistic_reparameterization_binary_qLogEI(self): def test_probabilistic_reparameterization_categorical( self, - pr_acq_func_cls=AnalyticProbabilisticReparameterization, base_acq_func_cls=LogExpectedImprovement, ): torch.manual_seed(0) @@ -327,52 +327,97 @@ def test_probabilistic_reparameterization_categorical( initialization=True, ) one_hot_to_numeric = OneHotToNumeric( - dim=one_hot_bounds.shape[1], categorical_features=categorical_features + dim=one_hot_bounds.shape[1], + categorical_features=categorical_features, + transform_on_train=False, ).to(**self.tkwargs) raw_X = ( draw_sobol_samples(one_hot_bounds, n=10, q=1).squeeze(-2).to(**self.tkwargs) ) train_X = init_exact_rounding_func(raw_X) - train_Y = f(one_hot_to_numeric(train_X)).unsqueeze(-1) - model = get_model(train_X, train_Y) + train_Y = f(one_hot_to_numeric.transform(train_X)).unsqueeze(-1) + model = MixedSingleTaskGP( + train_X=one_hot_to_numeric.transform(train_X), + train_Y=train_Y, + cat_dims=list(feature_to_num_categories.keys()), + input_transform=one_hot_to_numeric, + ) base_acq_func = base_acq_func_cls(model, best_f=train_Y.max()) - pr_acq_func = pr_acq_func_cls( + pr_acq_func_params = dict( acq_function=base_acq_func, one_hot_bounds=one_hot_bounds, categorical_features=categorical_features, - integer_indices=None, - batch_limit=32, + **self.acqf_params, ) - raw_candidate, _ 
= optimize_acqf( - acq_function=pr_acq_func, + pr_analytic_acq_func = AnalyticProbabilisticReparameterization( + **pr_acq_func_params + ) + + pr_mc_acq_func = MCProbabilisticReparameterization(**pr_acq_func_params) + + X = one_hot_bounds[:1, :].clone().unsqueeze(0) + X[..., -1] = 1.0 + X_lb, X_ub = X.clone(), X.clone() + X[..., 3:5] = 0.5 + X_lb[..., 3] = 1.0 + X_ub[..., 4] = 1.0 + + acq_value_base_mean = (base_acq_func(X_lb) + base_acq_func(X_ub)) / 2 + acq_value_analytic = pr_analytic_acq_func(X) + acq_value_mc = pr_mc_acq_func(X) + + # this is not exact due to sigmoid transform in discrete probabilities + self.assertAllClose(acq_value_analytic, acq_value_base_mean, rtol=0.1) + self.assertAllClose(acq_value_mc, acq_value_base_mean, rtol=0.1) + + candidate_analytic, acq_values_analytic = optimize_acqf( + acq_function=pr_analytic_acq_func, bounds=one_hot_bounds, q=1, - num_restarts=10, - raw_samples=20, - options={"maxiter": 5}, - # gen_candidates=gen_candidates_scipy, + gen_candidates=gen_candidates_scipy, + **self.optimize_acqf_params, ) - # candidates are generated in the one-hot space - candidate = one_hot_to_numeric(raw_candidate) - self.assertTrue(candidate.shape == (1, f.dim)) - def test_probabilistic_reparameterization_categorical_analytic_qLogEI(self): - self.test_probabilistic_reparameterization_categorical( - pr_acq_func_cls=AnalyticProbabilisticReparameterization, - base_acq_func_cls=qLogExpectedImprovement, + candidate_mc, acq_values_mc = optimize_acqf( + acq_function=pr_mc_acq_func, + bounds=one_hot_bounds, + q=1, + gen_candidates=gen_candidates_torch, + **self.optimize_acqf_params, ) - def test_probabilistic_reparameterization_categorical_MC_LogEI(self): - self.test_probabilistic_reparameterization_categorical( - pr_acq_func_cls=MCProbabilisticReparameterization, - base_acq_func_cls=LogExpectedImprovement, + fixed_features_list = [ + { + start_dim + i: float(val == i) + for (start_dim, num_cat), val in zip(categorical_features.items(), vals) + for i in range(num_cat) + } + for vals in itertools.product(*map(range, categorical_features.values())) + ] + candidate_exhaustive, acq_values_exhaustive = optimize_acqf_mixed( + acq_function=base_acq_func, + fixed_features_list=fixed_features_list, + bounds=one_hot_bounds, + q=1, + **self.optimize_acqf_params, ) - def test_probabilistic_reparameterization_categorical_MC_qLogEI(self): + self.assertTrue(candidate_analytic.shape == (1, one_hot_bounds.shape[-1])) + self.assertTrue(candidate_mc.shape == (1, one_hot_bounds.shape[-1])) + self.assertTrue(one_hot_to_numeric(candidate_analytic).shape == (1, f.dim)) + + # round the mc candidate to allow for comparison + candidate_mc = init_exact_rounding_func(candidate_mc) + + self.assertAllClose(candidate_analytic, candidate_exhaustive, rtol=0.1) + self.assertAllClose(acq_values_analytic, acq_values_exhaustive, rtol=0.1) + self.assertAllClose(candidate_mc, candidate_exhaustive, rtol=0.1) + self.assertAllClose(acq_values_mc, acq_values_exhaustive, rtol=0.1) + + def test_probabilistic_reparameterization_categorical_qLogEI(self): self.test_probabilistic_reparameterization_categorical( - pr_acq_func_cls=MCProbabilisticReparameterization, base_acq_func_cls=qLogExpectedImprovement, )
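For intuition on what the new assertions check: probabilistic reparameterization scores a relaxed candidate by the expected base acquisition value under the rounding distribution, so a binary parameter sitting at 0.5 should score roughly the average of the base acquisition at its two rounded corners. The sketch below only illustrates that expectation for a single binary dimension; it is not the implementation added in this patch, and `base_acq`, `binary_dim`, and the temperature `tau=0.1` are names and values assumed for the example.

    import torch

    def pr_value_single_binary(base_acq, X, binary_dim, tau=0.1):
        # Rounding probability for a binary parameter x in [0, 1), mirroring the
        # sigmoid((frac - 0.5) / tau) form used for integer dimensions (offset = 0 here).
        p = torch.sigmoid((X[..., binary_dim] - 0.5) / tau)
        X_ub, X_lb = X.clone(), X.clone()
        X_ub[..., binary_dim] = 1.0  # rounded up
        X_lb[..., binary_dim] = 0.0  # rounded down
        # Expected acquisition value over the two rounded configurations.
        return p * base_acq(X_ub) + (1 - p) * base_acq(X_lb)

At X[..., binary_dim] = 0.5 the rounding probability is exactly 0.5, which is why the tests compare the analytic and MC PR values against (base_acq_func(X_lb) + base_acq_func(X_ub)) / 2 with a loose rtol: the remaining discrepancy comes from the other discrete dimensions, whose sigmoid-based rounding distributions are nearly, but not exactly, deterministic.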