Skip to content

Commit ffcad4a

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
Adding logmeanexp and logdiffexp numerical utilities (#1657)
Summary: Pull Request resolved: #1657 This commit defines `logdiffexp` and `logmeanexp`, numerical utility functions that are going to be more generally useful for log-space computations. Reviewed By: Balandat Differential Revision: D43061278 fbshipit-source-id: a279a2ab7da5e1eb23a20fba0c3498dd52ea37b6
1 parent 6475503 commit ffcad4a

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

botorch/acquisition/analytic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
ndtr as Phi,
3535
phi,
3636
)
37-
from botorch.utils.safe_math import log1mexp
37+
from botorch.utils.safe_math import log1mexp, logmeanexp
3838
from botorch.utils.transforms import convert_to_target_pre_hook, t_batch_mode_transform
3939
from torch import Tensor
4040

@@ -587,7 +587,7 @@ def forward(self, X: Tensor) -> Tensor:
587587
u = _scaled_improvement(mean, sigma, self.best_f, self.maximize)
588588
log_ei = _log_ei_helper(u) + sigma.log()
589589
# this is mathematically - though not numerically - equivalent to log(mean(ei))
590-
return torch.logsumexp(log_ei, dim=-1) - math.log(log_ei.shape[-1])
590+
return logmeanexp(log_ei, dim=-1)
591591

592592

593593
class NoisyExpectedImprovement(ExpectedImprovement):

botorch/utils/probability/utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, Union
1515

1616
import torch
17-
from botorch.utils.safe_math import log1mexp
17+
from botorch.utils.safe_math import logdiffexp
1818
from numpy.polynomial.legendre import leggauss as numpy_leggauss
1919
from torch import BoolTensor, LongTensor, Tensor
2020

@@ -214,9 +214,7 @@ def log_prob_normal_in(a: Tensor, b: Tensor) -> Tensor:
214214
c = torch.where(rev_cond, -b, a)
215215
b = torch.where(rev_cond, -a, b)
216216
a = c # after we updated b, can assign c to a
217-
log_Phi_b = log_ndtr(b)
218-
# Phi(b) > Phi(a), so 0 > log(Phi(a) / Phi(b)) and we can use log1mexp
219-
return log_Phi_b + log1mexp(log_ndtr(a) - log_Phi_b)
217+
return logdiffexp(log_a=log_ndtr(a), log_b=log_ndtr(b))
220218

221219

222220
def swap_along_dim_(

botorch/utils/safe_math.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,30 @@ def log1mexp(x: Tensor) -> Tensor:
7272
(-x.expm1()).log(),
7373
(-x.exp()).log1p(),
7474
)
75+
76+
77+
def logdiffexp(log_a: Tensor, log_b: Tensor) -> Tensor:
78+
"""Computes log(b - a) accurately given log(a) and log(b).
79+
Assumes, log_b > log_a, i.e. b > a > 0.
80+
81+
Args:
82+
log_a (Tensor): The logarithm of a, assumed to be less than log_b.
83+
log_b (Tensor): The logarithm of b, assumed to be larger than log_a.
84+
85+
Returns:
86+
A Tensor of values corresponding to log(b - a).
87+
"""
88+
return log_b + log1mexp(log_a - log_b)
89+
90+
91+
def logmeanexp(X: Tensor, dim: int = -1) -> Tensor:
92+
"""Computes log(mean(exp(X), dim=dim)).
93+
94+
Args:
95+
X (Tensor): The logarithm of a, assumed to be less than log_b.
96+
dim (int): The dimension over which to compute the mean. Default is -1.
97+
98+
Returns:
99+
A Tensor of values corresponding to log(mean(exp(X), dim=dim)).
100+
"""
101+
return torch.logsumexp(X, dim=dim) - math.log(X.shape[dim])

0 commit comments

Comments
 (0)