
Commit e7fef3b

SamuelGabriel authored and facebook-github-bot committed
Changes to enable benchmarking with PFNs within Ax (#2915)
Summary:
X-link: facebook/Ax#4003
Pull Request resolved: #2915

These are all the basic changes needed for PFNs to work within Ax. Besides removing a number of problems encountered along the way, this also fixes a particular bug that led to worse optimization performance: an error in the way EI is computed. It also adds tests verifying that the acquisition functions compute the right values when approximating a normal distribution.

### Discussion (cc saitcakmak)

Should we get rid of the ag_integrate logic? I took it over from whoever wrote this before, but I have to say that implementing acquisition functions directly on the raw logits seems easier to me. I believe it was implemented this way so that acquisition functions can be defined elegantly on top of a posterior. I would propose to do it slightly less elegantly and use a function like the one below, where we access the logits and the borders from the posterior (`posterior.borders`) instead of having the posterior provide the integrate function. This is how you would implement EI, which I believe is simpler than our current EI implementation (which has already had a bug twice). It is even about the same number of lines, but does not require understanding the trick of splitting the integration into pieces that are handled separately.

```python
def ei(
    self,
    logits: torch.Tensor,
    best_f: float | torch.Tensor,
    *,
    maximize: bool = True,
) -> torch.Tensor:
    # logits: evaluation_points x batch x feature_dim
    bucket_diffs = self.borders[1:] - self.borders[:-1]
    assert maximize
    if not torch.is_tensor(best_f) or not len(best_f.shape):  # type: ignore
        best_f = torch.full(
            logits[..., 0].shape, best_f, device=logits.device
        )  # type: ignore
    best_f = best_f[..., None].repeat(
        *[1] * len(best_f.shape), logits.shape[-1]
    )  # type: ignore
    clamped_best_f = best_f.clamp(self.borders[:-1], self.borders[1:])
    # > bucket_contributions =
    # >     (best_f[..., None] < self.borders[:-1]).float() * bucket_means
    # true bucket contributions
    bucket_contributions = (
        (self.borders[1:] ** 2 - clamped_best_f**2) / 2
        - best_f * (self.borders[1:] - clamped_best_f)
    ) / bucket_diffs
    p = torch.softmax(logits, -1)
    return torch.einsum("...b,...b->...", p, bucket_contributions)
```

Reviewed By: saitcakmak

Differential Revision: D77884839

fbshipit-source-id: b138037cd5252733a0bf13db7f42631f0b6959a0
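For reference, the per-bucket term in `bucket_contributions` above is the exact integral of the improvement against the uniform density the Riemann posterior places within each bucket. With borders $b_i, b_{i+1}$, incumbent $f^*$, and $c = \operatorname{clamp}(f^*, b_i, b_{i+1})$, a short derivation (this restates what the code computes, nothing new is assumed):

```math
\int_{b_i}^{b_{i+1}} \max(y - f^*,\, 0)\,\frac{dy}{b_{i+1} - b_i}
= \frac{1}{b_{i+1} - b_i}\int_{c}^{b_{i+1}} (y - f^*)\,dy
= \frac{\tfrac{1}{2}\bigl(b_{i+1}^2 - c^2\bigr) - f^*\bigl(b_{i+1} - c\bigr)}{b_{i+1} - b_i}.
```

Weighting these per-bucket contributions by `softmax(logits)` and summing, as the `einsum` does, yields EI.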
1 parent 07ce376 commit e7fef3b

File tree: 8 files changed, +512 −68 lines changed

botorch_community/acquisition/discretized.py

Lines changed: 80 additions & 19 deletions
```diff
@@ -11,8 +11,13 @@
 from abc import ABC, abstractmethod
 
 import torch
-
 from botorch.acquisition import AcquisitionFunction
+from botorch.acquisition.objective import (
+    PosteriorTransform,
+    ScalarizedPosteriorTransform,
+)
+
+from botorch.exceptions.errors import UnsupportedError
 from botorch.models.model import Model
 from botorch.utils.transforms import (
     average_over_ensemble_models,
@@ -34,17 +39,33 @@ class DiscretizedAcquistionFunction(AcquisitionFunction, ABC):
     be implemented by subclasses to define the specific acquisition functions.
     """
 
-    def __init__(self, model: Model) -> None:
+    def __init__(self, model: Model, posterior_transform: PosteriorTransform) -> None:
         r"""
         Initialize the DiscretizedAcquistionFunction
 
         Args:
             model: A fitted model that is used to compute the posterior
                 distribution over the outcomes of interest.
                 The model should be a `PFNModel`.
+            posterior_transform: A ScalarizedPosteriorTransform that can only
+                indicate minimization or maximization of the objective.
         """
-
         super().__init__(model=model)
+        self.maximize = True
+        if posterior_transform is not None:
+            unsupported_error_message = (
+                "Only scalarized posterior transforms with a "
+                "single objective and 0.0 offset are supported."
+            )
+            if (
+                not isinstance(posterior_transform, ScalarizedPosteriorTransform)
+                or (posterior_transform.offset != 0.0)
+                or len(posterior_transform.weights) != 1
+                or posterior_transform.weights[0] not in [-1.0, 1.0]
+            ):
+                raise UnsupportedError(unsupported_error_message)
+
+            self.maximize = posterior_transform.weights[0] == 1.0
 
     @t_batch_mode_transform(expected_q=1)
     @average_over_ensemble_models
@@ -59,9 +80,13 @@ def forward(self, X: Tensor) -> Tensor:
             A `(b)`-dim Tensor of the acquisition function at the given
             design points `X`.
         """
-        self.to(device=X.device)
-
         discrete_posterior = self.model.posterior(X)
+        if not self.maximize:
+            discrete_posterior.borders = -torch.flip(discrete_posterior.borders, [0])
+            discrete_posterior.probabilities = torch.flip(
+                discrete_posterior.probabilities, [-1]
+            )
+
         result = discrete_posterior.integrate(self.ag_integrate)
         # remove q dimension
         return result.squeeze(-1)
@@ -87,18 +112,19 @@ def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor:
         """
         pass  # pragma: no cover
 
-r"""DiscretizedExpectedImprovement is an acquisition function that computes
-the expected improvement over the current best observed value for a Riemann
-distribution."""
-
 
 class DiscretizedExpectedImprovement(DiscretizedAcquistionFunction):
     r"""DiscretizedExpectedImprovement is an acquisition function that
     computes the expected improvement over the current best observed value
     for a Riemann distribution.
     """
 
-    def __init__(self, model: Model, best_f: Tensor) -> None:
+    def __init__(
+        self,
+        model: Model,
+        best_f: Tensor,
+        posterior_transform: PosteriorTransform | None = None,
+    ) -> None:
         r"""
         Initialize the DiscretizedExpectedImprovement
 
@@ -108,7 +134,7 @@ def __init__(self, model: Model, best_f: Tensor) -> None:
                 The model should be a `PFNModel`.
             best_f: A tensor representing the current best observed value.
         """
-        super().__init__(model)
+        super().__init__(model=model, posterior_transform=posterior_transform)
         self.register_buffer("best_f", torch.as_tensor(best_f))
 
     def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor:
@@ -127,11 +153,38 @@ def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor:
             A `(b)`-dim Tensor of acquisition function derivatives at the given
             design points `X`.
         """
-        max_lower_bound_and_f = torch.max(self.best_f, lower_bound)
-        bucket_average = (upper_bound + max_lower_bound_and_f) / 2
-        improvement = bucket_average - self.best_f
+        best_f = self.best_f.to(lower_bound)
+
+        # Case 1: best_f >= upper_bound, entire interval gives 0 improvement
+        case1_mask = best_f >= upper_bound
+
+        # Case 2: best_f <= lower_bound, entire interval gives improvement
+        case2_mask = best_f <= lower_bound
+
+        # Case 3: lower_bound < best_f < upper_bound, partial improvement
+        case3_mask = ~(case1_mask | case2_mask)
+
+        # Initialize result tensor
+        result = torch.zeros_like(lower_bound)
+
+        # Case 1: result is already 0
+
+        # Case 2: integral = (
+        #     ((upper_bound + lower_bound)/2 - best_f)
+        #     * (upper_bound - lower_bound)
+        # )
+        if case2_mask.any():
+            bucket_width = upper_bound - lower_bound
+            bucket_center = (upper_bound + lower_bound) / 2
+            result = torch.where(
+                case2_mask, (bucket_center - best_f) * bucket_width, result
+            )
+
+        # Case 3: integral = (upper_bound - best_f)²/2
+        if case3_mask.any():
+            result = torch.where(case3_mask, (upper_bound - best_f).pow(2) / 2, result)
 
-        return improvement.clamp_min(0)
+        return result.clamp_min(0)
 
 
 class DiscretizedProbabilityOfImprovement(DiscretizedAcquistionFunction):
```
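The rewritten `ag_integrate` above distinguishes three ways a bucket `[lower_bound, upper_bound]` can sit relative to `best_f`. As a standalone sanity check (a hypothetical helper, not part of the diff), the closed-form values agree with a Monte Carlo estimate of the bucket integral of `max(y - best_f, 0)`:

```python
import torch


def ei_bucket_integral(
    lower: torch.Tensor, upper: torch.Tensor, best_f: float
) -> torch.Tensor:
    """Closed form of the integral of max(y - best_f, 0) over [lower, upper]."""
    best_f_t = torch.as_tensor(best_f).to(lower)
    # Case 2 (best_f <= lower): the full bucket improves.
    full = ((upper + lower) / 2 - best_f_t) * (upper - lower)
    # Case 3 (lower < best_f < upper): only [best_f, upper] improves.
    partial = (upper - best_f_t).pow(2) / 2
    result = torch.where(best_f_t <= lower, full, partial)
    # Case 1 (best_f >= upper): no improvement.
    return torch.where(best_f_t >= upper, torch.zeros_like(result), result)


lower = torch.tensor([0.0, 1.0, 2.0])
upper = torch.tensor([1.0, 2.0, 3.0])
best_f = 1.5
y = lower + (upper - lower) * torch.rand(1_000_000, 3)  # uniform samples per bucket
mc = (y - best_f).clamp_min(0).mean(0) * (upper - lower)
print(ei_bucket_integral(lower, upper, best_f))  # tensor([0.0000, 0.1250, 1.0000])
print(mc)  # close to the exact values above
```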
```diff
@@ -140,7 +193,12 @@ class DiscretizedProbabilityOfImprovement(DiscretizedAcquistionFunction):
     for a Riemann distribution.
     """
 
-    def __init__(self, model: Model, best_f: Tensor) -> None:
+    def __init__(
+        self,
+        model: Model,
+        best_f: Tensor,
+        posterior_transform: PosteriorTransform | None = None,
+    ) -> None:
         r"""
         Initialize the DiscretizedProbabilityOfImprovement
 
@@ -151,7 +209,7 @@ def __init__(self, model: Model, best_f: Tensor) -> None:
             best_f: A tensor representing the current best observed value.
         """
 
-        super().__init__(model)
+        super().__init__(model, posterior_transform)
         self.register_buffer("best_f", torch.as_tensor(best_f))
 
     def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor:
@@ -174,5 +232,8 @@ def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor:
             A `(b)`-dim Tensor of acquisition function derivatives at the given
             design points `X`.
         """
-        proportion = (upper_bound - self.best_f) / (upper_bound - lower_bound)
-        return proportion.clamp(0, 1)
+        best_f = self.best_f.to(lower_bound)
+        # two separate clamps needed below, as one is a tensor and one a scalar
+        return (
+            (upper_bound - best_f).clamp(min=0.0).clamp(max=upper_bound - lower_bound)
+        )
```
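A quick way to read the double clamp in the new probability-of-improvement integrand: it returns the length of the part of each bucket that lies above `best_f`, which `integrate` presumably converts into a probability by weighting with the bucket's density. A tiny illustration with hypothetical borders:

```python
import torch

lower = torch.tensor([0.0, 1.0, 2.0])
upper = torch.tensor([1.0, 2.0, 3.0])
best_f = 1.5
# Length of the overlap of [best_f, inf) with each bucket [lower, upper].
print((upper - best_f).clamp(min=0.0).clamp(max=upper - lower))
# tensor([0.0000, 0.5000, 1.0000])
```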

botorch_community/acquisition/input_constructors.py

Lines changed: 52 additions & 3 deletions
```diff
@@ -13,22 +13,71 @@
 
 from __future__ import annotations
 
-from typing import List, Optional, Tuple
+from typing import Any, Hashable, List, Optional, Tuple
 
 import torch
-from botorch.acquisition.input_constructors import acqf_input_constructor
-from botorch.acquisition.objective import ScalarizedPosteriorTransform
+
+from botorch.acquisition.input_constructors import (
+    acqf_input_constructor,
+    get_best_f_analytic,
+)
+from botorch.acquisition.objective import (
+    PosteriorTransform,
+    ScalarizedPosteriorTransform,
+)
 from botorch.acquisition.utils import get_optimal_samples
 from botorch.models.model import Model
+
+from botorch.utils.datasets import SupervisedDataset
 from botorch_community.acquisition.bayesian_active_learning import (
     qBayesianQueryByComittee,
     qBayesianVarianceReduction,
     qStatisticalDistanceActiveLearning,
 )
+
+from botorch_community.acquisition.discretized import (
+    DiscretizedExpectedImprovement,
+    DiscretizedProbabilityOfImprovement,
+)
 from botorch_community.acquisition.scorebo import qSelfCorrectingBayesianOptimization
 from torch import Tensor
 
 
+@acqf_input_constructor(
+    DiscretizedExpectedImprovement, DiscretizedProbabilityOfImprovement
+)
+def construct_inputs_best_f(
+    model: Model,
+    training_data: SupervisedDataset | dict[Hashable, SupervisedDataset],
+    posterior_transform: PosteriorTransform | None = None,
+    best_f: float | Tensor | None = None,
+) -> dict[str, Any]:
+    r"""Construct kwargs for the acquisition functions requiring `best_f`.
+
+    Args:
+        model: The model to be used in the acquisition function.
+        training_data: Dataset(s) used to train the model.
+            Used to determine default value for `best_f`.
+        best_f: Threshold above (or below) which improvement is defined.
+        posterior_transform: The posterior transform to be used in the
+            acquisition function.
+
+    Returns:
+        A dict mapping kwarg names of the constructor to values.
+    """
+    if best_f is None:
+        best_f = get_best_f_analytic(
+            training_data=training_data,
+            posterior_transform=posterior_transform,
+        )
+
+    return {
+        "model": model,
+        "posterior_transform": posterior_transform,
+        "best_f": best_f,
+    }
+
+
 @acqf_input_constructor(
     qBayesianQueryByComittee,
     qBayesianVarianceReduction,
```
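For context, a minimal usage sketch of the new constructor (the data here is hypothetical, and `pfn_model` stands in for a fitted `PFNModel`): BoTorch's dispatcher resolves the function registered above and fills in `best_f` from the training data when it is not supplied.

```python
import torch
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.utils.datasets import SupervisedDataset
from botorch_community.acquisition.discretized import DiscretizedExpectedImprovement

# Hypothetical training data; pfn_model stands in for a fitted PFNModel.
train_X = torch.rand(10, 3)
train_Y = torch.rand(10, 1)
dataset = SupervisedDataset(
    X=train_X, Y=train_Y, feature_names=["x1", "x2", "x3"], outcome_names=["y"]
)

# Resolve the registered constructor; best_f defaults via get_best_f_analytic.
input_constructor = get_acqf_input_constructor(DiscretizedExpectedImprovement)
kwargs = input_constructor(model=pfn_model, training_data=dataset)
acqf = DiscretizedExpectedImprovement(**kwargs)
```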

botorch_community/models/prior_fitted_network.py

Lines changed: 98 additions & 10 deletions
```diff
@@ -31,21 +31,61 @@ def __init__(
         train_X: Tensor,
         train_Y: Tensor,
         model: nn.Module,
+        train_Yvar: Tensor | None = None,
+        batch_first: bool = False,
+        constant_model_kwargs: dict | None = None,
     ) -> None:
         """Initialize a PFNModel.
 
         Args:
-            train_X: A `batch_shape x n x d` tensor of training features.
-            train_Y: A `batch_shape x n x m` tensor of training observations.
+            train_X: A `n x d` tensor of training features.
+            train_Y: A `n x m` tensor of training observations.
             model: A pre-trained PFN model with the following
                 forward(train_X, train_Y, X) -> logit predictions of shape
                 `n x b x c` where c is the number of discrete buckets
                 borders: A `c+1`-dim tensor of bucket borders
+            train_Yvar: Not yet supported.
+            batch_first: Whether the batch dimension is the first dimension of
+                the input tensors. This is needed to support different PFN
+                models. For batch-first x has shape `batch x seq_len x features`
+                and for non-batch-first it has shape `seq_len x batch x features`.
+            constant_model_kwargs: A dictionary of model kwargs that
+                will be passed to the model in each forward pass.
         """
         super().__init__()
-        self.train_X = train_X
-        self.train_Y = train_Y
-        self.pfn = model.to(train_X)
+
+        if train_Yvar is not None:
+            raise UnsupportedError("train_Yvar is not supported for PFNModel.")
+
+        if not (1 <= train_Y.dim() <= 3):
+            raise UnsupportedError("train_Y must be 1- to 3-dimensional.")
+
+        if not (2 <= train_X.dim() <= 3):
+            raise UnsupportedError("train_X must be 2- to 3-dimensional.")
+
+        if train_Y.dim() == train_X.dim():
+            if train_Y.shape[-1] > 1:
+                raise UnsupportedError("Only 1 target allowed for PFNModel.")
+            train_Y = train_Y.squeeze(-1)
+
+        if (len(train_X.shape) != len(train_Y.shape) + 1) or (
+            train_Y.shape != train_X.shape[:-1]
+        ):
+            raise UnsupportedError(
+                "train_X and train_Y must have the same shape except "
+                "for the last dimension."
+            )
+
+        if len(train_X.shape) == 2:
+            # adding batch dimension
+            train_X = train_X.unsqueeze(0)
+            train_Y = train_Y.unsqueeze(0)
+
+        self.train_X = train_X  # shape: `b x n x d`
+        self.train_Y = train_Y  # shape: `b x n`
+        self.pfn = model
+        self.batch_first = batch_first
+        self.constant_model_kwargs = constant_model_kwargs
 
     def posterior(
         self,
@@ -61,7 +101,7 @@ def posterior(
             any `model.forward` or `model.likelihood` calls.
 
         Args:
-            X: A `b x q x d`-dim Tensor, where `d` is the dimension of the
+            X: A `b'? x b? x q x d`-dim Tensor, where `d` is the dimension of the
                 feature space, `q` is the number of points considered jointly,
                 and `b` is the batch dimension.
                 We only allow `q=1` for PFNModel, so q can also be omitted, i.e.
@@ -86,11 +126,59 @@
         if posterior_transform is not None:
             raise UnsupportedError("posterior_transform is not supported for PFNModel.")
 
-        if len(X.shape) > 2 and X.shape[-2] > 1:
-            raise NotImplementedError("q must be 1 for PFNModel.")  # add support later
+        if not (1 <= len(X.shape) <= 4):
+            raise UnsupportedError("X must be 1- to 4-dimensional.")
+
+        # X has shape b'? x b? x q? x d
+
+        orig_X_shape = X.shape
+        q_in_orig_X_shape = len(X.shape) > 2
+
+        if len(X.shape) == 1:
+            X = X.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # shape `b'=1 x b=1 x q=1 x d`
+        elif len(X.shape) == 2:
+            X = X.unsqueeze(1).unsqueeze(1)  # shape `b' x b=1 x q=1 x d`
+        elif len(X.shape) == 3:
+            if self.train_X.shape[0] == 1:
+                X = X.unsqueeze(1)  # shape `b' x b=1 x q x d`
+            else:
+                X = X.unsqueeze(0)  # shape `b'=1 x b x q x d`
+
+        # X has shape `b' x b x q x d`
+
+        if X.shape[2] != 1:
+            raise UnsupportedError("Only q=1 is supported for PFNModel.")
+
+        # X has shape `b' x b x q=1 x d`
+
+        train_X = self.train_X  # shape `b x n x d`
+        train_Y = self.train_Y  # shape `b x n`
+        folded_X = X.transpose(0, 2).squeeze(0)  # shape `b x b' x d`
+
+        constant_model_kwargs = self.constant_model_kwargs or {}
+
+        if self.batch_first:
+            logits = self.pfn(
+                train_X.float(),
+                train_Y.float(),
+                folded_X.float(),
+                **constant_model_kwargs,
+            ).transpose(0, 1)
+        else:
+            logits = self.pfn(
+                train_X.float().transpose(0, 1),
+                train_Y.float().transpose(0, 1),
+                folded_X.float().transpose(0, 1),
+                **constant_model_kwargs,
+            )
+
+        # logits shape `b' x b x logits_dim`
 
-        # flatten batch dimensions for PFN input
-        logits = self.pfn(self.train_X, self.train_Y, X)
+        logits = logits.view(
+            *orig_X_shape[:-1], -1
+        )  # orig shape w/o q but logits_dim at end: `b'? x b? x q? x logits_dim`
+        if q_in_orig_X_shape:
+            logits = logits.squeeze(-2)  # shape `b'? x b? x logits_dim`
 
         probabilities = logits.softmax(dim=-1)
```