Skip to content

Commit d80b556

Browse files
sdaultonfacebook-github-bot
authored andcommitted
make indices a buffer in AffineInputTransform (#1656)
Summary: Pull Request resolved: #1656 see title. This is important for being able to set the device with `to()` Reviewed By: Balandat, SebastianAment Differential Revision: D43048429 fbshipit-source-id: 6a290099752b150492162cdc0ff3ed110523ee04
1 parent 52f39cf commit d80b556

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

botorch/models/transforms/input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def __init__(
365365
raise ValueError("Elements of `indices` have to be smaller than `d`!")
366366
if len(indices.unique()) != len(indices):
367367
raise ValueError("Elements of `indices` tensor must be unique!")
368-
self.indices = indices
368+
self.register_buffer("indices", indices)
369369
torch.broadcast_shapes(coefficient.shape, offset.shape)
370370

371371
self._d = d

test/models/transforms/test_input.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,13 @@ def test_normalize(self):
197197
self.assertEqual(nlz.mins.shape, torch.Size([1, 1]))
198198
self.assertEqual(nlz.ranges.shape, torch.Size([1, 1]))
199199
self.assertEqual(len(nlz.indices), 1)
200-
self.assertTrue((nlz.indices == torch.tensor([0], dtype=torch.long)).all())
200+
nlz.to(device=self.device)
201+
self.assertTrue(
202+
(
203+
nlz.indices
204+
== torch.tensor([0], dtype=torch.long, device=self.device)
205+
).all()
206+
)
201207

202208
# test .to
203209
other_dtype = torch.float if dtype == torch.double else torch.double
@@ -382,17 +388,25 @@ def test_standardize(self):
382388
self.assertEqual(stdz.means.shape, torch.Size([1, 1]))
383389
self.assertEqual(stdz.stds.shape, torch.Size([1, 1]))
384390
self.assertEqual(len(stdz.indices), 1)
391+
stdz.to(device=self.device)
385392
self.assertTrue(
386-
torch.equal(stdz.indices, torch.tensor([0], dtype=torch.long))
393+
torch.equal(
394+
stdz.indices,
395+
torch.tensor([0], dtype=torch.long, device=self.device),
396+
)
387397
)
388398
stdz = InputStandardize(d=2, indices=[0], batch_shape=torch.Size([3]))
399+
stdz.to(device=self.device)
389400
self.assertTrue(stdz.training)
390401
self.assertEqual(stdz._d, 2)
391402
self.assertEqual(stdz.means.shape, torch.Size([3, 1, 1]))
392403
self.assertEqual(stdz.stds.shape, torch.Size([3, 1, 1]))
393404
self.assertEqual(len(stdz.indices), 1)
394405
self.assertTrue(
395-
torch.equal(stdz.indices, torch.tensor([0], dtype=torch.long))
406+
torch.equal(
407+
stdz.indices,
408+
torch.tensor([0], device=self.device, dtype=torch.long),
409+
)
396410
)
397411

398412
# test jitter

0 commit comments

Comments
 (0)