Skip to content

Commit 7daad68

Browse files
Balandatfacebook-github-bot
authored andcommitted
Add unit test test for separate_mtmvn (#2577)
Summary: Get test coverage back up. Pull Request resolved: #2577 Reviewed By: esantorella Differential Revision: D64448451 Pulled By: Balandat fbshipit-source-id: 5bf860680ba2aac0c520edce2d3ae24f8b1b8a36
1 parent 0479873 commit 7daad68

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

test/utils/test_multitask.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import torch
9+
from botorch.utils.multitask import separate_mtmvn
10+
11+
from botorch.utils.testing import BotorchTestCase
12+
from gpytorch.distributions import MultitaskMultivariateNormal
13+
from gpytorch.distributions.multivariate_normal import MultivariateNormal
14+
15+
16+
class TestSeparateMTMVN(BotorchTestCase):
17+
18+
def _test_separate_mtmvn(self, interleaved=False):
19+
for dtype in (torch.float, torch.double):
20+
tkwargs = {"device": self.device, "dtype": dtype}
21+
mean = torch.rand(2, 2, **tkwargs)
22+
a = torch.rand(4, 4, **tkwargs)
23+
covar = a @ a.transpose(-1, -2) + torch.eye(4, **tkwargs)
24+
mvn = MultitaskMultivariateNormal(
25+
mean=mean, covariance_matrix=covar, interleaved=interleaved
26+
)
27+
mtmvn_list = separate_mtmvn(mvn)
28+
29+
mean_1 = mean[..., 0]
30+
mean_2 = mean[..., 1]
31+
if interleaved:
32+
covar_1 = covar[::2, ::2]
33+
covar_2 = covar[1::2, 1::2]
34+
else:
35+
covar_1 = covar[:2, :2]
36+
covar_2 = covar[2:, 2:]
37+
38+
self.assertEqual(len(mtmvn_list), 2)
39+
for mvn_i, mean_i, covar_i in zip(
40+
mtmvn_list, (mean_1, mean_2), (covar_1, covar_2)
41+
):
42+
self.assertIsInstance(mvn_i, MultivariateNormal)
43+
self.assertTrue(torch.equal(mvn_i.mean, mean_i))
44+
self.assertAllClose(mvn_i.covariance_matrix, covar_i)
45+
46+
def test_separate_mtmvn_interleaved(self) -> None:
47+
self._test_separate_mtmvn(interleaved=True)
48+
49+
def test_separate_mtmvn_not_interleaved(self) -> None:
50+
self._test_separate_mtmvn(interleaved=False)

0 commit comments

Comments
 (0)