Skip to content

Commit ab28cc4

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Introduce ModelDict and MultiModelAcquisitionFunction (#1584)
Summary: Pull Request resolved: #1584 Introduces a lightweight `ModelDict` container, which is simply a `ModuleDict[str, Model]`, and an abstract `MultiModelAcquisitionFunction` class that accepts a `ModelDict` rather than a `Model`. The goal here is to help shape the MBM Surrogate refactor by having a concrete example of how the multiple surrogates would be consumed in BoTorch. Reviewed By: lena-kashtelyan Differential Revision: D41564744 fbshipit-source-id: d18632b94041d3529cc0b98560d7791e7d83d59d
1 parent d7edf20 commit ab28cc4

File tree

4 files changed

+85
-4
lines changed

4 files changed

+85
-4
lines changed

botorch/acquisition/acquisition.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import torch
1616
from botorch.exceptions import BotorchWarning, UnsupportedError
17-
from botorch.models.model import Model
17+
from botorch.models.model import Model, ModelDict
1818
from botorch.posteriors.posterior import Posterior
1919
from botorch.sampling.base import MCSampler
2020
from botorch.sampling.get_sampler import get_sampler
@@ -168,3 +168,34 @@ def get_posterior_samples(self, posterior: Posterior) -> Tensor:
168168
posterior=posterior, sample_shape=self._default_sample_shape
169169
)
170170
return self.sampler(posterior=posterior)
171+
172+
173+
class MultiModelAcquisitionFunction(AcquisitionFunction, ABC):
174+
r"""Abstract base class for acquisition functions that require
175+
multiple types of models.
176+
177+
The intended use case for these acquisition functions are those
178+
where we have multiple models, each serving a distinct purpose.
179+
As an example, we can have a "regression" model that predicts
180+
one or more outcomes, and a "classification" model that predicts
181+
the probabilty that a given parameterization is feasible. The
182+
multi-model acquisition function can then weight the acquisition
183+
value computed with the "regression" model with the feasibility
184+
value predicted by the "classification" model to produce the
185+
composite acquisition value.
186+
187+
This is currently only a placeholder to help with some development
188+
in Ax. We plan to add some acquisition functions utilizing multiple
189+
models in the future.
190+
191+
:meta private:
192+
"""
193+
194+
def __init__(self, model_dict: ModelDict) -> None:
195+
r"""Constructor for the MultiModelAcquisitionFunction base class.
196+
197+
Args:
198+
model_dict: A ModelDict mapping labels to models.
199+
"""
200+
super(AcquisitionFunction, self).__init__()
201+
self.model_dict: ModelDict = model_dict

botorch/models/model.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,14 @@
3232
import numpy as np
3333
import torch
3434
from botorch import settings
35+
from botorch.exceptions.errors import InputDataError
3536
from botorch.models.utils.assorted import fantasize as fantasize_flag
3637
from botorch.posteriors import Posterior, PosteriorList
3738
from botorch.sampling.base import MCSampler
3839
from botorch.utils.datasets import BotorchDataset
3940
from botorch.utils.transforms import is_fully_bayesian
4041
from torch import Tensor
41-
from torch.nn import Module, ModuleList
42+
from torch.nn import Module, ModuleDict, ModuleList
4243

4344
if TYPE_CHECKING:
4445
from botorch.acquisition.objective import PosteriorTransform # pragma: no cover
@@ -514,3 +515,20 @@ def load_state_dict(
514515
}
515516
m.load_state_dict(filtered_dict)
516517
super().load_state_dict(state_dict=state_dict, strict=strict)
518+
519+
520+
class ModelDict(ModuleDict):
521+
r"""A lightweight container mapping model names to models."""
522+
523+
def __init__(self, **models: Model) -> None:
524+
r"""Initialize a `ModelDict`.
525+
526+
Args:
527+
models: An arbitrary number of models. Each model can be any type
528+
of BoTorch `Model`, including multi-output models and `ModelList`.
529+
"""
530+
if any(not isinstance(m, Model) for m in models.values()):
531+
raise InputDataError(
532+
f"Expected all models to be a BoTorch `Model`. Got {models}."
533+
)
534+
super().__init__(modules=models)

test/acquisition/test_acquisition.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
from botorch.acquisition.acquisition import (
99
AcquisitionFunction,
1010
MCSamplerMixin,
11+
MultiModelAcquisitionFunction,
1112
OneShotAcquisitionFunction,
1213
)
14+
from botorch.models.model import ModelDict
1315
from botorch.sampling.normal import IIDNormalSampler
1416
from botorch.sampling.stochastic_samplers import StochasticSampler
1517
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
@@ -25,6 +27,11 @@ def forward(self, X):
2527
raise NotImplementedError
2628

2729

30+
class DummyMultiModelAcqf(MultiModelAcquisitionFunction):
31+
def forward(self, X):
32+
raise NotImplementedError
33+
34+
2835
class TestAcquisitionFunction(BotorchTestCase):
2936
def test_abstract_raises(self):
3037
with self.assertRaises(TypeError):
@@ -48,3 +55,15 @@ def test_mc_sampler_mixin(self):
4855
sampler = IIDNormalSampler(sample_shape=torch.Size([2]))
4956
acqf.sampler = sampler
5057
self.assertIs(acqf.sampler, sampler)
58+
59+
60+
class TestMultiModelAcquisitionFunction(BotorchTestCase):
61+
def test_multi_model_acquisition_function(self):
62+
model_dict = ModelDict(
63+
m1=MockModel(MockPosterior()),
64+
m2=MockModel(MockPosterior()),
65+
)
66+
with self.assertRaises(TypeError):
67+
MultiModelAcquisitionFunction(model_dict=model_dict)
68+
acqf = DummyMultiModelAcqf(model_dict=model_dict)
69+
self.assertIs(acqf.model_dict, model_dict)

test/models/test_model.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88

99
import torch
1010
from botorch.acquisition.objective import PosteriorTransform
11+
from botorch.exceptions.errors import InputDataError
1112
from botorch.models.deterministic import GenericDeterministicModel
12-
from botorch.models.model import Model, ModelList
13+
from botorch.models.model import Model, ModelDict, ModelList
1314
from botorch.models.utils import parse_training_data
1415
from botorch.posteriors.deterministic import DeterministicPosterior
1516
from botorch.posteriors.posterior_list import PosteriorList
16-
from botorch.utils.testing import BotorchTestCase
17+
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
1718

1819

1920
class NotSoAbstractBaseModel(Model):
@@ -117,3 +118,15 @@ def test_posterior_transform(self):
117118
posterior_tf.mean, torch.cat((2 * m1(X) + 1, 2 * m2(X) + 1), dim=-1)
118119
)
119120
)
121+
122+
123+
class TestModelDict(BotorchTestCase):
124+
def test_model_dict(self):
125+
models = {"m1": MockModel(MockPosterior()), "m2": MockModel(MockPosterior())}
126+
model_dict = ModelDict(**models)
127+
self.assertIs(model_dict["m1"], models["m1"])
128+
self.assertIs(model_dict["m2"], models["m2"])
129+
with self.assertRaisesRegex(
130+
InputDataError, "Expected all models to be a BoTorch `Model`."
131+
):
132+
ModelDict(m=MockPosterior())

0 commit comments

Comments
 (0)