 from botorch.fit import fit_gpytorch_mll
 from botorch.models import ModelListGP
 from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
-from botorch.models.transforms import Standardize
 from botorch.models.transforms.input import Normalize
-from botorch.posteriors import GPyTorchPosterior
+from botorch.models.transforms.outcome import ChainedOutcomeTransform, Log, Standardize
+from botorch.posteriors import GPyTorchPosterior, PosteriorList, TransformedPosterior
 from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.testing import _get_random_data, BotorchTestCase
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
...
 from gpytorch.priors import GammaPrior


-def _get_model(fixed_noise=False, use_octf=False, use_intf=False, **tkwargs):
+def _get_model(
+    fixed_noise=False, outcome_transform: str = "None", use_intf=False, **tkwargs
+) -> ModelListGP:
     train_x1, train_y1 = _get_random_data(
         batch_shape=torch.Size(), m=1, n=10, **tkwargs
     )
+    train_y1 = torch.exp(train_y1)
     train_x2, train_y2 = _get_random_data(
         batch_shape=torch.Size(), m=1, n=11, **tkwargs
     )
-    octfs = [Standardize(m=1), Standardize(m=1)] if use_octf else [None, None]
+    if outcome_transform == "Standardize":
+        octfs = [Standardize(m=1), Standardize(m=1)]
+    elif outcome_transform == "Log":
+        octfs = [Log(), Standardize(m=1)]
+    elif outcome_transform == "Chained":
+        octfs = [
+            ChainedOutcomeTransform(
+                chained=ChainedOutcomeTransform(log=Log(), standardize=Standardize(m=1))
+            ),
+            Standardize(m=1),
+        ]
+    elif outcome_transform == "None":
+        octfs = [None, None]
+    else:
+        raise KeyError(  # pragma: no cover
+            "outcome_transform must be one of 'Standardize', 'Log', 'Chained', or "
+            "'None'."
+        )
     intfs = [Normalize(d=1), Normalize(d=1)] if use_intf else [None, None]
     if fixed_noise:
         train_y1_var = 0.1 + 0.1 * torch.rand_like(train_y1, **tkwargs)
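For orientation (an editorial aside, not part of the diff): a minimal sketch of what the new `outcome_transform` options construct for a single submodel, assuming current BoTorch APIs. `Log` expects strictly positive targets, which is why the diff adds `train_y1 = torch.exp(train_y1)` above.

```python
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.outcome import (
    ChainedOutcomeTransform,
    Log,
    Standardize,
)

# Strictly positive targets, mirroring torch.exp(train_y1) in the diff.
train_X = torch.rand(10, 1, dtype=torch.double)
train_Y = torch.exp(torch.randn(10, 1, dtype=torch.double))

# "Log" models log(Y); "Chained" applies its transforms in order,
# here standardize(log(Y)).
octf = ChainedOutcomeTransform(log=Log(), standardize=Standardize(m=1))
model = SingleTaskGP(train_X, train_Y, outcome_transform=octf)
```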
@@ -73,10 +93,12 @@ def _get_model(fixed_noise=False, use_octf=False, use_intf=False, **tkwargs):

 class TestModelListGP(BotorchTestCase):
     def _base_test_ModelListGP(
-        self, fixed_noise: bool, dtype, use_octf: bool
+        self, fixed_noise: bool, dtype, outcome_transform: str
     ) -> ModelListGP:
         tkwargs = {"device": self.device, "dtype": dtype}
-        model = _get_model(fixed_noise=fixed_noise, use_octf=use_octf, **tkwargs)
+        model = _get_model(
+            fixed_noise=fixed_noise, outcome_transform=outcome_transform, **tkwargs
+        )
         self.assertIsInstance(model, ModelListGP)
         self.assertIsInstance(model.likelihood, LikelihoodList)
         for m in model.models:
@@ -85,8 +107,12 @@ def _base_test_ModelListGP(
             matern_kernel = m.covar_module.base_kernel
             self.assertIsInstance(matern_kernel, MaternKernel)
             self.assertIsInstance(matern_kernel.lengthscale_prior, GammaPrior)
-            if use_octf:
-                self.assertIsInstance(m.outcome_transform, Standardize)
+            if outcome_transform != "None":
+                self.assertIsInstance(
+                    m.outcome_transform, (Log, Standardize, ChainedOutcomeTransform)
+                )
+            else:
+                assert not hasattr(m, "outcome_transform")

         # test constructing likelihood wrapper
         mll = SumMarginalLogLikelihood(model.likelihood, model)
@@ -121,9 +147,19 @@ def _base_test_ModelListGP(
         # test posterior
         test_x = torch.tensor([[0.25], [0.75]], **tkwargs)
         posterior = model.posterior(test_x)
-        self.assertIsInstance(posterior, GPyTorchPosterior)
-        self.assertIsInstance(posterior.distribution, MultitaskMultivariateNormal)
-        if use_octf:
+        gpytorch_posterior_expected = outcome_transform in ("None", "Standardize")
+        expected_type = (
+            GPyTorchPosterior if gpytorch_posterior_expected else PosteriorList
+        )
+        self.assertIsInstance(posterior, expected_type)
+        submodel = model.models[0]
+        p0 = submodel.posterior(test_x)
+        self.assertTrue(torch.allclose(posterior.mean[:, [0]], p0.mean))
+        self.assertTrue(torch.allclose(posterior.variance[:, [0]], p0.variance))
+
+        if gpytorch_posterior_expected:
+            self.assertIsInstance(posterior.distribution, MultitaskMultivariateNormal)
+        if outcome_transform != "None":
             # ensure un-transformation is applied
             submodel = model.models[0]
             p0 = submodel.posterior(test_x)
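Context for the type switch above (my reading of the behavior the test encodes): `Standardize` un-transforms a Gaussian posterior into another Gaussian, so `ModelListGP` can still assemble a joint `GPyTorchPosterior`; `Log` un-transforms through `exp`, giving a non-Gaussian (log-normal) posterior, so the model list falls back to a `PosteriorList` containing `TransformedPosterior` entries. A small sketch:

```python
import torch
from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.transforms.outcome import Log, Standardize
from botorch.posteriors import PosteriorList, TransformedPosterior

X = torch.rand(8, 1, dtype=torch.double)
Y = torch.exp(torch.randn(8, 1, dtype=torch.double))  # positive, for Log

m_log = SingleTaskGP(X, Y, outcome_transform=Log())
m_std = SingleTaskGP(X, Y, outcome_transform=Standardize(m=1))

test_X = torch.tensor([[0.25], [0.75]], dtype=torch.double)
posterior = ModelListGP(m_log, m_std).posterior(test_X)

# The Log submodel yields a TransformedPosterior, so the joint posterior
# is a PosteriorList rather than a GPyTorchPosterior ...
assert isinstance(posterior, PosteriorList)
assert isinstance(posterior.posteriors[0], TransformedPosterior)
# ... but elementwise moments are still available, as the test checks.
mean, variance = posterior.mean, posterior.variance
```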
@@ -136,8 +172,9 @@ def _base_test_ModelListGP(

         # test output_indices
         posterior = model.posterior(test_x, output_indices=[0], observation_noise=True)
-        self.assertIsInstance(posterior, GPyTorchPosterior)
-        self.assertIsInstance(posterior.distribution, MultivariateNormal)
+        self.assertIsInstance(posterior, expected_type)
+        if gpytorch_posterior_expected:
+            self.assertIsInstance(posterior.distribution, MultivariateNormal)

         # test condition_on_observations
         f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
@@ -176,39 +213,50 @@ def _base_test_ModelListGP(
         X = torch.rand(3, 1, **tkwargs)
         weights = torch.tensor([1, 2], **tkwargs)
         post_tf = ScalarizedPosteriorTransform(weights=weights)
-        posterior_tf = model.posterior(X, posterior_transform=post_tf)
-        self.assertTrue(
-            torch.allclose(
-                posterior_tf.mean,
-                model.posterior(X).mean @ weights.unsqueeze(-1),
+        if gpytorch_posterior_expected:
+            posterior_tf = model.posterior(X, posterior_transform=post_tf)
+            self.assertTrue(
+                torch.allclose(
+                    posterior_tf.mean,
+                    model.posterior(X).mean @ weights.unsqueeze(-1),
+                )
             )
-        )

         return model

     def test_ModelListGP(self) -> None:
-        for dtype, use_octf in itertools.product(
-            (torch.float, torch.double), (False, True)
+        for dtype, outcome_transform in itertools.product(
+            (torch.float, torch.double), ("None", "Standardize", "Log", "Chained")
         ):

             model = self._base_test_ModelListGP(
-                fixed_noise=False, dtype=dtype, use_octf=use_octf
+                fixed_noise=False, dtype=dtype, outcome_transform=outcome_transform
             )
             tkwargs = {"device": self.device, "dtype": dtype}

             # test observation_noise
             test_x = torch.tensor([[0.25], [0.75]], **tkwargs)
             posterior = model.posterior(test_x, observation_noise=True)
-            self.assertIsInstance(posterior, GPyTorchPosterior)
-            self.assertIsInstance(posterior.distribution, MultitaskMultivariateNormal)
+
+            gpytorch_posterior_expected = outcome_transform in ("None", "Standardize")
+            expected_type = (
+                GPyTorchPosterior if gpytorch_posterior_expected else PosteriorList
+            )
+            self.assertIsInstance(posterior, expected_type)
+            if gpytorch_posterior_expected:
+                self.assertIsInstance(
+                    posterior.distribution, MultitaskMultivariateNormal
+                )
+            else:
+                self.assertIsInstance(posterior.posteriors[0], TransformedPosterior)

     def test_ModelListGP_fixed_noise(self) -> None:

-        for dtype, use_octf in itertools.product(
-            (torch.float, torch.double), (False, True)
+        for dtype, outcome_transform in itertools.product(
+            (torch.float, torch.double), ("None", "Standardize")
         ):
             model = self._base_test_ModelListGP(
-                fixed_noise=True, dtype=dtype, use_octf=use_octf
+                fixed_noise=True, dtype=dtype, outcome_transform=outcome_transform
             )
             tkwargs = {"device": self.device, "dtype": dtype}
             f_x = [torch.rand(2, 1, **tkwargs) for _ in range(2)]
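A note on the narrower fixed-noise parameterization (my inference, not stated in the diff): `Standardize` knows how to rescale the supplied observation variances along with the targets, whereas `Log` does not accept a `Yvar`, so only `"None"` and `"Standardize"` are exercised with `FixedNoiseGP`. A sketch of the supported combination:

```python
import torch
from botorch.models import FixedNoiseGP
from botorch.models.transforms.outcome import Standardize

X = torch.rand(10, 1, dtype=torch.double)
Y = torch.randn(10, 1, dtype=torch.double)
Yvar = 0.1 + 0.1 * torch.rand_like(Y)  # matches the noise used in _get_model

# Standardize rescales Yvar together with Y; Log would reject the Yvar.
model = FixedNoiseGP(X, Y, Yvar, outcome_transform=Standardize(m=1))
```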