Skip to content

Commit 71f03ea

Browse files
saitcakmak and facebook-github-bot
authored and committed
Eliminate expensive indexing in separate_mtmvn (#2920)
Summary: Pull Request resolved: #2920. Fixes #2919.

The batched indexing of the covariance tensor was consuming extraordinary amounts of memory, enough to crash servers with moderately sized tensors. This change instead extracts each task's covariance submatrix inside the per-task loop, indexing the rows and then the columns with a 1-D index tensor.

`separate_mtmvn` was profiled with `n_test=64` (256 was crashing my server) using the repro from the above issue:

Before: using >7 GB of RAM and taking >800 ms {F1980143884}
After: using ~5e-5 GB of RAM and taking ~20 ms {F1980143899}

Reviewed By: Balandat
Differential Revision: D78038356
fbshipit-source-id: c347c45ef666f211521416299a393dc3fa923a0a
1 parent 44299a1 commit 71f03ea

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

botorch/utils/multitask.py

Lines changed: 20 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -24,23 +24,28 @@ def separate_mtmvn(mvn: MultitaskMultivariateNormal) -> list[MultivariateNormal]
2424
# T150340766 Upstream as a class method on gpytorch MultitaskMultivariateNormal.
2525
full_covar = mvn.lazy_covariance_matrix
2626
num_data, num_tasks = mvn.mean.shape[-2:]
27-
if mvn._interleaved:
28-
data_indices = torch.arange(
29-
0, num_data * num_tasks, num_tasks, device=full_covar.device
30-
).view(-1, 1, 1)
31-
task_indices = torch.arange(num_tasks, device=full_covar.device)
32-
else:
33-
data_indices = torch.arange(num_data, device=full_covar.device).view(-1, 1, 1)
34-
task_indices = torch.arange(
35-
0, num_data * num_tasks, num_data, device=full_covar.device
36-
)
37-
slice_ = (data_indices + task_indices).transpose(-1, -3)
38-
data_covars = full_covar[..., slice_, slice_.transpose(-1, -2)]
27+
3928
mvns = []
4029
for c in range(num_tasks):
41-
mvns.append(
42-
MultivariateNormal(
43-
mvn.mean[..., c], to_linear_operator(data_covars[..., c, :, :])
30+
# Compute indices for task c's data points
31+
if mvn._interleaved:
32+
# For interleaved: task c data points are at positions
33+
# c, c+num_tasks, c+2*num_tasks, ...
34+
task_indices = torch.arange(
35+
c, num_data * num_tasks, num_tasks, device=full_covar.device
36+
)
37+
else:
38+
# For non-interleaved: task c data points are at positions
39+
# c*num_data to (c+1)*num_data
40+
task_indices = torch.arange(
41+
c * num_data, (c + 1) * num_data, device=full_covar.device
4442
)
43+
44+
# Extract covariance submatrix for task c
45+
task_covar = full_covar[..., task_indices, :]
46+
task_covar = task_covar[..., :, task_indices]
47+
48+
mvns.append(
49+
MultivariateNormal(mvn.mean[..., c], to_linear_operator(task_covar))
4550
)
4651
return mvns

test/utils/test_multitask.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77

88
import torch
99
from botorch.utils.multitask import separate_mtmvn
10-
1110
from botorch.utils.testing import BotorchTestCase
1211
from gpytorch.distributions import MultitaskMultivariateNormal
1312
from gpytorch.distributions.multivariate_normal import MultivariateNormal

0 commit comments

Comments
 (0)