|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# |
| 4 | +# This source code is licensed under the MIT license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | + |
| 8 | +import torch |
| 9 | +from botorch.utils.multitask import separate_mtmvn |
| 10 | + |
| 11 | +from botorch.utils.testing import BotorchTestCase |
| 12 | +from gpytorch.distributions import MultitaskMultivariateNormal |
| 13 | +from gpytorch.distributions.multivariate_normal import MultivariateNormal |
| 14 | + |
| 15 | + |
| 16 | +class TestSeparateMTMVN(BotorchTestCase): |
| 17 | + |
| 18 | + def _test_separate_mtmvn(self, interleaved=False): |
| 19 | + for dtype in (torch.float, torch.double): |
| 20 | + tkwargs = {"device": self.device, "dtype": dtype} |
| 21 | + mean = torch.rand(2, 2, **tkwargs) |
| 22 | + a = torch.rand(4, 4, **tkwargs) |
| 23 | + covar = a @ a.transpose(-1, -2) + torch.eye(4, **tkwargs) |
| 24 | + mvn = MultitaskMultivariateNormal( |
| 25 | + mean=mean, covariance_matrix=covar, interleaved=interleaved |
| 26 | + ) |
| 27 | + mtmvn_list = separate_mtmvn(mvn) |
| 28 | + |
| 29 | + mean_1 = mean[..., 0] |
| 30 | + mean_2 = mean[..., 1] |
| 31 | + if interleaved: |
| 32 | + covar_1 = covar[::2, ::2] |
| 33 | + covar_2 = covar[1::2, 1::2] |
| 34 | + else: |
| 35 | + covar_1 = covar[:2, :2] |
| 36 | + covar_2 = covar[2:, 2:] |
| 37 | + |
| 38 | + self.assertEqual(len(mtmvn_list), 2) |
| 39 | + for mvn_i, mean_i, covar_i in zip( |
| 40 | + mtmvn_list, (mean_1, mean_2), (covar_1, covar_2) |
| 41 | + ): |
| 42 | + self.assertIsInstance(mvn_i, MultivariateNormal) |
| 43 | + self.assertTrue(torch.equal(mvn_i.mean, mean_i)) |
| 44 | + self.assertAllClose(mvn_i.covariance_matrix, covar_i) |
| 45 | + |
| 46 | + def test_separate_mtmvn_interleaved(self) -> None: |
| 47 | + self._test_separate_mtmvn(interleaved=True) |
| 48 | + |
| 49 | + def test_separate_mtmvn_not_interleaved(self) -> None: |
| 50 | + self._test_separate_mtmvn(interleaved=False) |
0 commit comments