Changes to enable benchmarking with PFNs within Ax (#2915)
Summary:
X-link: facebook/Ax#4003
Pull Request resolved: #2915
These are all the basic changes needed for the PFNs to work within Ax.
Besides resolving a number of problems encountered along the way, this also fixes a particular bug in the way EI is computed that led to worse optimization performance. It additionally adds tests that check the acquisition functions compute the right values when approximating a normal distribution.
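For context, here is a minimal sketch of the kind of check such a test performs (the names, bucket counts, and tolerance below are illustrative, not the actual test code): discretize a standard normal into buckets, compute EI from the bucket probabilities, and compare against the closed-form normal EI.

```python
import torch

# Hypothetical check: EI from a bucketized (Riemann) approximation of N(0, 1)
# should be close to the analytic normal EI.
borders = torch.linspace(-5.0, 5.0, 1001)  # 1000 buckets
normal = torch.distributions.Normal(0.0, 1.0)
cdf = normal.cdf(borders)
bucket_probs = cdf[1:] - cdf[:-1]  # probability mass per bucket

best_f = 0.5
lower, upper = borders[:-1], borders[1:]
clamped = torch.clamp(torch.full_like(lower, best_f), lower, upper)
# Per-bucket expected improvement, assuming a uniform density within each bucket.
contrib = ((upper**2 - clamped**2) / 2 - best_f * (upper - clamped)) / (upper - lower)
ei_bucketized = (bucket_probs * contrib).sum()

# Closed-form EI of N(0, 1) at best_f (maximization): E[max(X - best_f, 0)].
u = torch.tensor(-best_f)  # (mu - best_f) / sigma with mu=0, sigma=1
ei_analytic = u * normal.cdf(u) + torch.exp(normal.log_prob(u))

assert torch.isclose(ei_bucketized, ei_analytic, atol=1e-3)
```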
### Discussion (cc saitcakmak)
Should we get rid of the ag_integrate logic? I took it over from whoever wrote this before, but I have to say that implementing acquisition functions directly on the raw logits seems easier to me.
I believe it was implemented this way so that we can define acquisition functions based on a posterior in an elegant way. I would propose doing it slightly less elegantly, using a function like the one below, where we access the logits and the borders from the posterior (`posterior.borders`) instead of having the posterior provide the integrate function.
This is how EI would be implemented, which I believe is simpler than our current EI implementation (which has already had bugs twice). It is about the same number of lines, but does not require understanding how the integration is split up into a product that is handled separately.
```python
def ei(
    self,
    logits: torch.Tensor,
    best_f: float | torch.Tensor,
    *,
    maximize: bool = True,
) -> torch.Tensor:  # logits: evaluation_points x batch x feature_dim
    bucket_diffs = self.borders[1:] - self.borders[:-1]
    assert maximize
    # Broadcast a scalar best_f to the evaluation_points x batch shape.
    if not torch.is_tensor(best_f) or not len(best_f.shape):  # type: ignore
        best_f = torch.full(
            logits[..., 0].shape, best_f, device=logits.device
        )  # type: ignore
    # Repeat best_f along the bucket dimension so it broadcasts against the borders.
    best_f = best_f[..., None].repeat(
        *[1] * len(best_f.shape), logits.shape[-1]
    )  # type: ignore
    clamped_best_f = best_f.clamp(self.borders[:-1], self.borders[1:])
    # > bucket_contributions =
    # >     (best_f[..., None] < self.borders[:-1]).float() * bucket_means
    # true bucket contributions, i.e. the expected improvement within each bucket
    # under a uniform density per bucket:
    bucket_contributions = (
        (self.borders[1:] ** 2 - clamped_best_f**2) / 2
        - best_f * (self.borders[1:] - clamped_best_f)
    ) / bucket_diffs
    p = torch.softmax(logits, -1)
    # EI is the probability-weighted sum of the per-bucket contributions.
    return torch.einsum("...b,...b->...", p, bucket_contributions)
```
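For illustration, here is a minimal runnable sketch of the same computation in standalone form, showing the shape conventions (the `ei_from_logits` helper name and the dummy shapes are assumptions for this example, not part of the proposal):

```python
import torch

def ei_from_logits(
    logits: torch.Tensor, borders: torch.Tensor, best_f: float
) -> torch.Tensor:
    """Same computation as the `ei` method above, with borders passed explicitly."""
    bucket_diffs = borders[1:] - borders[:-1]
    best_f_t = torch.full(logits[..., 0].shape, best_f, device=logits.device)
    best_f_t = best_f_t[..., None].repeat(*[1] * best_f_t.dim(), logits.shape[-1])
    clamped = best_f_t.clamp(borders[:-1], borders[1:])
    contributions = (
        (borders[1:] ** 2 - clamped**2) / 2 - best_f_t * (borders[1:] - clamped)
    ) / bucket_diffs
    return torch.einsum("...b,...b->...", torch.softmax(logits, -1), contributions)

# Dummy shapes: 4 candidate points, batch of 1, 32 buckets (33 borders).
borders = torch.linspace(-3.0, 3.0, 33)
logits = torch.randn(4, 1, 32)
print(ei_from_logits(logits, borders, best_f=0.1).shape)  # torch.Size([4, 1])
```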
Reviewed By: saitcakmak
Differential Revision: D77884839
fbshipit-source-id: b138037cd5252733a0bf13db7f42631f0b6959a0