Skip to content

Commit 5066540

Browse files
esantorella authored and facebook-github-bot committed
Support for outcome transforms that return a TransformedPosterior in ModelListGP (#1563)
Summary: Pull Request resolved: #1563 Replaces D41860896, but adds support for the outcome transforms that weren't working for `ModelListGP.posterior` by calling `ModelList.posterior`. See #1519 for more context on the issue that this is fixing. Reviewed By: saitcakmak Differential Revision: D42019721 fbshipit-source-id: f2f566c53f327a02e26008428a187a6b9abf0c90
1 parent 14199e6 commit 5066540

File tree

6 files changed

+169
-38
lines changed

6 files changed

+169
-38
lines changed

botorch/models/gpytorch.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from torch import Tensor
4040

4141
if TYPE_CHECKING:
42+
from botorch.posteriors.posterior_list import PosteriorList # pragma: no cover
4243
from botorch.posteriors.transformed import TransformedPosterior # pragma: no cover
4344
from gpytorch.likelihoods import Likelihood # pragma: no cover
4445

@@ -555,14 +556,15 @@ def batch_shape(self) -> torch.Size:
555556
raise NotImplementedError(msg + " that are not broadcastble.")
556557
return next(iter(batch_shapes))
557558

559+
# pyre-fixme[15]: Inconsistent override in return types
558560
def posterior(
559561
self,
560562
X: Tensor,
561563
output_indices: Optional[List[int]] = None,
562564
observation_noise: Union[bool, Tensor] = False,
563565
posterior_transform: Optional[PosteriorTransform] = None,
564566
**kwargs: Any,
565-
) -> GPyTorchPosterior:
567+
) -> Union[GPyTorchPosterior, PosteriorList]:
566568
r"""Computes the posterior over model outputs at the provided points.
567569
568570
Args:
@@ -582,11 +584,38 @@ def posterior(
582584
posterior_transform: An optional PosteriorTransform.
583585
584586
Returns:
585-
A `GPyTorchPosterior` or `FullyBayesianPosterior` object, representing
586-
`batch_shape` joint distributions over `q` points and the outputs selected
587-
by `output_indices` each. Includes measurement noise if
588-
`observation_noise` is specified.
587+
- If no `posterior_transform` is provided and the component models have no
588+
`outcome_transform`, or if the component models only use linear outcome
589+
transforms like `Standardize` (i.e. not `Log`), returns a
590+
`GPyTorchPosterior` or `FullyBayesianPosterior` object,
591+
representing `batch_shape` joint distributions over `q` points
592+
and the outputs selected by `output_indices` each. Includes
593+
measurement noise if `observation_noise` is specified.
594+
- If no `posterior_transform` is provided and component models have
595+
nonlinear transforms like `Log`, returns a `PosteriorList` with
596+
sub-posteriors of type `TransformedPosterior`
597+
- If `posterior_transform` is provided, that posterior transform will be
598+
applied and will determine the return type. This could potentially be
599+
any subclass of `Posterior`, but common choices give a
600+
`GPyTorchPosterior`.
589601
"""
602+
603+
# Nonlinear transforms untransform to a `TransformedPosterior`,
604+
# which can't be made into a `GPyTorchPosterior`
605+
returns_untransformed = any(
606+
hasattr(mod, "outcome_transform") and (not mod.outcome_transform._is_linear)
607+
for mod in self.models
608+
)
609+
if returns_untransformed:
610+
return ModelList.posterior(
611+
self,
612+
X,
613+
output_indices,
614+
observation_noise,
615+
posterior_transform,
616+
**kwargs,
617+
)
618+
590619
self.eval() # make sure model is in eval mode
591620
# input transforms are applied at `posterior` in `eval` mode, and at
592621
# `model.forward()` at the training time
@@ -628,10 +657,10 @@ def posterior(
628657
# apply output transforms of individual models if present
629658
mvns = []
630659
for i, mvn in mvn_gen:
631-
try:
660+
if hasattr(self.models[i], "outcome_transform"):
632661
oct = self.models[i].outcome_transform
633662
tf_mvn = oct.untransform_posterior(GPyTorchPosterior(mvn)).distribution
634-
except AttributeError:
663+
else:
635664
tf_mvn = mvn
636665
mvns.append(tf_mvn)
637666
# return result as a GPyTorchPosteriors/FullyBayesianPosterior

botorch/models/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def posterior(
391391
X: Tensor,
392392
output_indices: Optional[List[int]] = None,
393393
observation_noise: bool = False,
394-
posterior_transform: Optional[Callable[[Posterior], Posterior]] = None,
394+
posterior_transform: Optional[Callable[[PosteriorList], Posterior]] = None,
395395
**kwargs: Any,
396396
) -> Posterior:
397397
r"""Computes the posterior over model outputs at the provided points.

botorch/models/transforms/outcome.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,22 @@ def untransform(
101101
f"{self.__class__.__name__} does not implement the `untransform` method"
102102
)
103103

104+
@property
105+
def _is_linear(self) -> bool:
106+
"""
107+
True for transformations such as `Standardize`; these should be able to apply
108+
`untransform_posterior` to a GPyTorchPosterior and return a GPyTorchPosterior,
109+
because a multivariate normal distribution should remain multivariate normal
110+
after applying the transform.
111+
"""
112+
return False
113+
104114
def untransform_posterior(self, posterior: Posterior) -> Posterior:
105-
r"""Un-transform a posterior
115+
r"""Un-transform a posterior.
116+
117+
Posteriors with `_is_linear=True` should return a `GPyTorchPosterior` when
118+
`posterior` is a `GPyTorchPosterior`. Posteriors with `_is_linear=False`
119+
likely return a `TransformedPosterior` instead.
106120
107121
Args:
108122
posterior: A posterior in the transformed space.
@@ -182,6 +196,14 @@ def untransform(
182196
Y, Yvar = tf.untransform(Y, Yvar)
183197
return Y, Yvar
184198

199+
@property
200+
def _is_linear(self) -> bool:
201+
"""
202+
A `ChainedOutcomeTransform` is linear only if all of the component transforms
203+
are linear.
204+
"""
205+
return all((octf._is_linear for octf in self.values()))
206+
185207
def untransform_posterior(self, posterior: Posterior) -> Posterior:
186208
r"""Un-transform a posterior
187209
@@ -255,7 +277,10 @@ def forward(
255277
if Y.shape[:-2] != self._batch_shape:
256278
raise RuntimeError("wrong batch shape")
257279
if Y.size(-1) != self._m:
258-
raise RuntimeError("wrong output dimension")
280+
raise RuntimeError(
281+
f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
282+
f"{self._m}."
283+
)
259284
stdvs = Y.std(dim=-2, keepdim=True)
260285
stdvs = stdvs.where(stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0))
261286
means = Y.mean(dim=-2, keepdim=True)
@@ -331,6 +356,10 @@ def untransform(
331356
Yvar_utf = self._stdvs_sq * Yvar if Yvar is not None else None
332357
return Y_utf, Yvar_utf
333358

359+
@property
360+
def _is_linear(self) -> bool:
361+
return True
362+
334363
def untransform_posterior(
335364
self, posterior: Posterior
336365
) -> Union[GPyTorchPosterior, TransformedPosterior]:

botorch/posteriors/gpytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
class GPyTorchPosterior(TorchPosterior):
3737
r"""A posterior based on GPyTorch's multi-variate Normal distributions."""
38+
distribution: MultivariateNormal
3839

3940
def __init__(
4041
self,

test/models/test_model_list_gp_regression.py

Lines changed: 75 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
from botorch.fit import fit_gpytorch_mll
1515
from botorch.models import ModelListGP
1616
from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
17-
from botorch.models.transforms import Standardize
1817
from botorch.models.transforms.input import Normalize
19-
from botorch.posteriors import GPyTorchPosterior
18+
from botorch.models.transforms.outcome import ChainedOutcomeTransform, Log, Standardize
19+
from botorch.posteriors import GPyTorchPosterior, PosteriorList, TransformedPosterior
2020
from botorch.sampling.normal import IIDNormalSampler
2121
from botorch.utils.testing import _get_random_data, BotorchTestCase
2222
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
@@ -28,14 +28,34 @@
2828
from gpytorch.priors import GammaPrior
2929

3030

31-
def _get_model(fixed_noise=False, use_octf=False, use_intf=False, **tkwargs):
31+
def _get_model(
32+
fixed_noise=False, outcome_transform: str = "None", use_intf=False, **tkwargs
33+
) -> ModelListGP:
3234
train_x1, train_y1 = _get_random_data(
3335
batch_shape=torch.Size(), m=1, n=10, **tkwargs
3436
)
37+
train_y1 = torch.exp(train_y1)
3538
train_x2, train_y2 = _get_random_data(
3639
batch_shape=torch.Size(), m=1, n=11, **tkwargs
3740
)
38-
octfs = [Standardize(m=1), Standardize(m=1)] if use_octf else [None, None]
41+
if outcome_transform == "Standardize":
42+
octfs = [Standardize(m=1), Standardize(m=1)]
43+
elif outcome_transform == "Log":
44+
octfs = [Log(), Standardize(m=1)]
45+
elif outcome_transform == "Chained":
46+
octfs = [
47+
ChainedOutcomeTransform(
48+
chained=ChainedOutcomeTransform(log=Log(), standardize=Standardize(m=1))
49+
),
50+
Standardize(m=1),
51+
]
52+
elif outcome_transform == "None":
53+
octfs = [None, None]
54+
else:
55+
raise KeyError( # pragma: no cover
56+
"outcome_transform must be one of 'Standardize', 'Log', 'Chained', or "
57+
"'None'."
58+
)
3959
intfs = [Normalize(d=1), Normalize(d=1)] if use_intf else [None, None]
4060
if fixed_noise:
4161
train_y1_var = 0.1 + 0.1 * torch.rand_like(train_y1, **tkwargs)
@@ -73,10 +93,12 @@ def _get_model(fixed_noise=False, use_octf=False, use_intf=False, **tkwargs):
7393

7494
class TestModelListGP(BotorchTestCase):
7595
def _base_test_ModelListGP(
76-
self, fixed_noise: bool, dtype, use_octf: bool
96+
self, fixed_noise: bool, dtype, outcome_transform: str
7797
) -> ModelListGP:
7898
tkwargs = {"device": self.device, "dtype": dtype}
79-
model = _get_model(fixed_noise=fixed_noise, use_octf=use_octf, **tkwargs)
99+
model = _get_model(
100+
fixed_noise=fixed_noise, outcome_transform=outcome_transform, **tkwargs
101+
)
80102
self.assertIsInstance(model, ModelListGP)
81103
self.assertIsInstance(model.likelihood, LikelihoodList)
82104
for m in model.models:
@@ -85,8 +107,12 @@ def _base_test_ModelListGP(
85107
matern_kernel = m.covar_module.base_kernel
86108
self.assertIsInstance(matern_kernel, MaternKernel)
87109
self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
88-
if use_octf:
89-
self.assertIsInstance(m.outcome_transform, Standardize)
110+
if outcome_transform != "None":
111+
self.assertIsInstance(
112+
m.outcome_transform, (Log, Standardize, ChainedOutcomeTransform)
113+
)
114+
else:
115+
assert not hasattr(m, "outcome_transform")
90116

91117
# test constructing likelihood wrapper
92118
mll = SumMarginalLogLikelihood(model.likelihood, model)
@@ -121,9 +147,19 @@ def _base_test_ModelListGP(
121147
# test posterior
122148
test_x = torch.tensor([[0.25], [0.75]], **tkwargs)
123149
posterior = model.posterior(test_x)
124-
self.assertIsInstance(posterior, GPyTorchPosterior)
125-
self.assertIsInstance(posterior.distribution, MultitaskMultivariateNormal)
126-
if use_octf:
150+
gpytorch_posterior_expected = outcome_transform in ("None", "Standardize")
151+
expected_type = (
152+
GPyTorchPosterior if gpytorch_posterior_expected else PosteriorList
153+
)
154+
self.assertIsInstance(posterior, expected_type)
155+
submodel = model.models[0]
156+
p0 = submodel.posterior(test_x)
157+
self.assertTrue(torch.allclose(posterior.mean[:, [0]], p0.mean))
158+
self.assertTrue(torch.allclose(posterior.variance[:, [0]], p0.variance))
159+
160+
if gpytorch_posterior_expected:
161+
self.assertIsInstance(posterior.distribution, MultitaskMultivariateNormal)
162+
if outcome_transform != "None":
127163
# ensure un-transformation is applied
128164
submodel = model.models[0]
129165
p0 = submodel.posterior(test_x)
@@ -136,8 +172,9 @@ def _base_test_ModelListGP(
136172

137173
# test output_indices
138174
posterior = model.posterior(test_x, output_indices=[0], observation_noise=True)
139-
self.assertIsInstance(posterior, GPyTorchPosterior)
140-
self.assertIsInstance(posterior.distribution, MultivariateNormal)
175+
self.assertIsInstance(posterior, expected_type)
176+
if gpytorch_posterior_expected:
177+
self.assertIsInstance(posterior.distribution, MultivariateNormal)
141178

142179
# test condition_on_observations
143180
f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
@@ -176,39 +213,50 @@ def _base_test_ModelListGP(
176213
X = torch.rand(3, 1, **tkwargs)
177214
weights = torch.tensor([1, 2], **tkwargs)
178215
post_tf = ScalarizedPosteriorTransform(weights=weights)
179-
posterior_tf = model.posterior(X, posterior_transform=post_tf)
180-
self.assertTrue(
181-
torch.allclose(
182-
posterior_tf.mean,
183-
model.posterior(X).mean @ weights.unsqueeze(-1),
216+
if gpytorch_posterior_expected:
217+
posterior_tf = model.posterior(X, posterior_transform=post_tf)
218+
self.assertTrue(
219+
torch.allclose(
220+
posterior_tf.mean,
221+
model.posterior(X).mean @ weights.unsqueeze(-1),
222+
)
184223
)
185-
)
186224

187225
return model
188226

189227
def test_ModelListGP(self) -> None:
190-
for dtype, use_octf in itertools.product(
191-
(torch.float, torch.double), (False, True)
228+
for dtype, outcome_transform in itertools.product(
229+
(torch.float, torch.double), ("None", "Standardize", "Log", "Chained")
192230
):
193231

194232
model = self._base_test_ModelListGP(
195-
fixed_noise=False, dtype=dtype, use_octf=use_octf
233+
fixed_noise=False, dtype=dtype, outcome_transform=outcome_transform
196234
)
197235
tkwargs = {"device": self.device, "dtype": dtype}
198236

199237
# test observation_noise
200238
test_x = torch.tensor([[0.25], [0.75]], **tkwargs)
201239
posterior = model.posterior(test_x, observation_noise=True)
202-
self.assertIsInstance(posterior, GPyTorchPosterior)
203-
self.assertIsInstance(posterior.distribution, MultitaskMultivariateNormal)
240+
241+
gpytorch_posterior_expected = outcome_transform in ("None", "Standardize")
242+
expected_type = (
243+
GPyTorchPosterior if gpytorch_posterior_expected else PosteriorList
244+
)
245+
self.assertIsInstance(posterior, expected_type)
246+
if gpytorch_posterior_expected:
247+
self.assertIsInstance(
248+
posterior.distribution, MultitaskMultivariateNormal
249+
)
250+
else:
251+
self.assertIsInstance(posterior.posteriors[0], TransformedPosterior)
204252

205253
def test_ModelListGP_fixed_noise(self) -> None:
206254

207-
for dtype, use_octf in itertools.product(
208-
(torch.float, torch.double), (False, True)
255+
for dtype, outcome_transform in itertools.product(
256+
(torch.float, torch.double), ("None", "Standardize")
209257
):
210258
model = self._base_test_ModelListGP(
211-
fixed_noise=True, dtype=dtype, use_octf=use_octf
259+
fixed_noise=True, dtype=dtype, outcome_transform=outcome_transform
212260
)
213261
tkwargs = {"device": self.device, "dtype": dtype}
214262
f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]

test/models/transforms/test_outcome.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,34 @@ def test_standardize_raises_when_mean_not_set(self) -> None:
9393
):
9494
transform.untransform(y)
9595

96+
def test_is_linear(self) -> None:
97+
posterior = _get_test_posterior(
98+
shape=torch.Size([1, 1]), device=self.device, dtype=torch.float64
99+
)
100+
y = torch.arange(2, dtype=torch.float64, device=self.device)[:, None]
101+
standardize_tf = Standardize(m=1)
102+
standardize_tf(y)
103+
104+
for transform in [
105+
standardize_tf,
106+
Power(power=0.5),
107+
Log(),
108+
ChainedOutcomeTransform(
109+
chained=ChainedOutcomeTransform(stand=standardize_tf)
110+
),
111+
ChainedOutcomeTransform(log=Log()),
112+
]:
113+
posterior_is_gpt = isinstance(
114+
transform.untransform_posterior(posterior), GPyTorchPosterior
115+
)
116+
self.assertEqual(posterior_is_gpt, transform._is_linear)
117+
96118
def test_standardize(self):
97119
# test error on incompatible dim
98120
tf = Standardize(m=1)
99-
with self.assertRaises(RuntimeError):
121+
with self.assertRaises(
122+
RuntimeError, msg="Wrong output dimension. Y.size(-1) is 2; expected 1."
123+
):
100124
tf(torch.zeros(3, 2, device=self.device), None)
101125
# test error on incompatible batch shape
102126
with self.assertRaises(RuntimeError):

0 commit comments

Comments (0)