Skip to content

Commit c4c93ba

Browse files
Vincent Moens authored and osalpekar committed
[Feature] Device transform (#1472)
1 parent 2b5d612 commit c4c93ba

File tree

8 files changed

+287
-38
lines changed

8 files changed

+287
-38
lines changed

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ to be able to create this other composition:
450450
CatTensors
451451
CenterCrop
452452
Compose
453+
DeviceCastTransform
453454
DiscreteActionProjection
454455
DoubleToFloat
455456
DTypeCastTransform

test/mocking_classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(
9191
**kwargs,
9292
):
9393
super().__init__(
94-
device="cpu",
94+
device=kwargs.pop("device", "cpu"),
9595
dtype=torch.get_default_dtype(),
9696
)
9797
self.set_seed(seed)

test/test_transforms.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@
3838
from torchrl.data import (
3939
BoundedTensorSpec,
4040
CompositeSpec,
41+
LazyMemmapStorage,
4142
LazyTensorStorage,
4243
ReplayBuffer,
4344
TensorDictReplayBuffer,
45+
TensorStorage,
4446
UnboundedContinuousTensorSpec,
4547
)
4648
from torchrl.envs import (
@@ -49,6 +51,7 @@
4951
CatTensors,
5052
CenterCrop,
5153
Compose,
54+
DeviceCastTransform,
5255
DiscreteActionProjection,
5356
DoubleToFloat,
5457
EnvBase,
@@ -8133,6 +8136,105 @@ def test_kl_lstm(self):
81338136
klt(env.rollout(3, policy))
81348137

81358138

8139+
class TestDeviceCastTransform(TransformBase):
8140+
def test_single_trans_env_check(self):
8141+
env = ContinuousActionVecMockEnv(device="cpu:0")
8142+
env = TransformedEnv(env, DeviceCastTransform("cpu:1"))
8143+
assert env.device == torch.device("cpu:1")
8144+
check_env_specs(env)
8145+
8146+
def test_serial_trans_env_check(self):
8147+
def make_env():
8148+
return TransformedEnv(
8149+
ContinuousActionVecMockEnv(device="cpu:0"), DeviceCastTransform("cpu:1")
8150+
)
8151+
8152+
env = SerialEnv(2, make_env)
8153+
assert env.device == torch.device("cpu:1")
8154+
check_env_specs(env)
8155+
8156+
def test_parallel_trans_env_check(self):
8157+
def make_env():
8158+
return TransformedEnv(
8159+
ContinuousActionVecMockEnv(device="cpu:0"), DeviceCastTransform("cpu:1")
8160+
)
8161+
8162+
env = ParallelEnv(2, make_env)
8163+
assert env.device == torch.device("cpu:1")
8164+
check_env_specs(env)
8165+
8166+
def test_trans_serial_env_check(self):
8167+
def make_env():
8168+
return ContinuousActionVecMockEnv(device="cpu:0")
8169+
8170+
env = TransformedEnv(SerialEnv(2, make_env), DeviceCastTransform("cpu:1"))
8171+
assert env.device == torch.device("cpu:1")
8172+
check_env_specs(env)
8173+
8174+
def test_trans_parallel_env_check(self):
8175+
def make_env():
8176+
return ContinuousActionVecMockEnv(device="cpu:0")
8177+
8178+
env = TransformedEnv(ParallelEnv(2, make_env), DeviceCastTransform("cpu:1"))
8179+
assert env.device == torch.device("cpu:1")
8180+
check_env_specs(env)
8181+
8182+
def test_transform_no_env(self):
8183+
t = DeviceCastTransform("cpu:1", "cpu:0")
8184+
assert t._call(TensorDict({}, [], device="cpu:0")).device == torch.device(
8185+
"cpu:1"
8186+
)
8187+
8188+
def test_transform_compose(self):
8189+
t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
8190+
assert t._call(TensorDict({}, [], device="cpu:0")).device == torch.device(
8191+
"cpu:1"
8192+
)
8193+
assert t._inv_call(TensorDict({}, [], device="cpu:1")).device == torch.device(
8194+
"cpu:0"
8195+
)
8196+
8197+
def test_transform_env(self):
8198+
env = ContinuousActionVecMockEnv(device="cpu:0")
8199+
assert env.device == torch.device("cpu:0")
8200+
env = TransformedEnv(env, DeviceCastTransform("cpu:1"))
8201+
assert env.device == torch.device("cpu:1")
8202+
assert env.transform.device == torch.device("cpu:1")
8203+
assert env.transform.orig_device == torch.device("cpu:0")
8204+
8205+
def test_transform_model(self):
8206+
t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
8207+
m = nn.Sequential(t)
8208+
assert t(TensorDict({}, [], device="cpu:0")).device == torch.device("cpu:1")
8209+
8210+
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
8211+
@pytest.mark.parametrize(
8212+
"storage", [TensorStorage, LazyTensorStorage, LazyMemmapStorage]
8213+
)
8214+
def test_transform_rb(self, rbclass, storage):
8215+
t = Compose(DeviceCastTransform("cpu:1", "cpu:0"))
8216+
storage_kwargs = (
8217+
{
8218+
"storage": TensorDict(
8219+
{"a": torch.zeros(20, 1, device="cpu:0")}, [20], device="cpu:0"
8220+
)
8221+
}
8222+
if storage is TensorStorage
8223+
else {}
8224+
)
8225+
rb = rbclass(storage=storage(max_size=20, device="auto", **storage_kwargs))
8226+
rb.append_transform(t)
8227+
rb.add(TensorDict({"a": [1]}, [], device="cpu:1"))
8228+
assert rb._storage._storage.device == torch.device("cpu:0")
8229+
assert rb.sample(4).device == torch.device("cpu:1")
8230+
8231+
def test_transform_inverse(self):
8232+
t = DeviceCastTransform("cpu:1", "cpu:0")
8233+
assert t._inv_call(TensorDict({}, [], device="cpu:1")).device == torch.device(
8234+
"cpu:0"
8235+
)
8236+
8237+
81368238
if __name__ == "__main__":
81378239
args, unknown = argparse.ArgumentParser().parse_known_args()
81388240
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,13 @@ def add(self, data: Any) -> int:
248248
Returns:
249249
index where the data lives in the replay buffer.
250250
"""
251+
if self._transform is not None and (
252+
is_tensor_collection(data) or len(self._transform)
253+
):
254+
data = self._transform.inv(data)
255+
return self._add(data)
256+
257+
def _add(self, data):
251258
with self._replay_lock:
252259
index = self._writer.add(data)
253260
self._sampler.add(index)
@@ -271,9 +278,9 @@ def extend(self, data: Sequence) -> torch.Tensor:
271278
Returns:
272279
Indices of the data added to the replay buffer.
273280
"""
274-
if self._transform is not None and is_tensor_collection(data):
275-
data = self._transform.inv(data)
276-
elif self._transform is not None and len(self._transform):
281+
if self._transform is not None and (
282+
is_tensor_collection(data) or len(self._transform)
283+
):
277284
data = self._transform.inv(data)
278285
return self._extend(data)
279286

@@ -675,19 +682,24 @@ def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
675682
return priority
676683

677684
def add(self, data: TensorDictBase) -> int:
685+
if self._transform is not None:
686+
data = self._transform.inv(data)
687+
678688
if is_tensor_collection(data):
679689
data_add = TensorDict(
680690
{
681691
"_data": data,
682692
},
683693
batch_size=[],
694+
device=data.device,
684695
)
685696
if data.batch_size:
686697
data_add["_rb_batch_size"] = torch.tensor(data.batch_size)
687698

688699
else:
689700
data_add = data
690-
index = super().add(data_add)
701+
702+
index = super()._add(data_add)
691703
if is_tensor_collection(data_add):
692704
data_add.set("index", index)
693705

@@ -699,7 +711,8 @@ def add(self, data: TensorDictBase) -> int:
699711
def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
700712
if is_tensor_collection(tensordicts):
701713
tensordicts = TensorDict(
702-
{"_data": tensordicts}, batch_size=tensordicts.batch_size[:1]
714+
{"_data": tensordicts},
715+
batch_size=tensordicts.batch_size[:1],
703716
)
704717
if tensordicts.batch_dims > 1:
705718
# we want the tensordict to have one dimension only. The batch size
@@ -730,14 +743,12 @@ def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
730743
stacked_td = tensordicts
731744

732745
if self._transform is not None:
733-
stacked_td.set("_data", self._transform.inv(stacked_td.get("_data")))
746+
tensordicts = self._transform.inv(stacked_td.get("_data"))
747+
stacked_td.set("_data", tensordicts)
748+
if tensordicts.device is not None:
749+
stacked_td = stacked_td.to(tensordicts.device)
734750

735751
index = super()._extend(stacked_td)
736-
# stacked_td.set(
737-
# "index",
738-
# torch.tensor(index, dtype=torch.int, device=stacked_td.device),
739-
# inplace=True,
740-
# )
741752
self.update_tensordict_priority(stacked_td)
742753
return index
743754

torchrl/data/replay_buffers/storages.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,14 @@ class TensorStorage(Storage):
171171
"""A storage for tensors and tensordicts.
172172
173173
Args:
174-
data (tensor or TensorDict): the data buffer to be used.
174+
storage (tensor or TensorDict): the data buffer to be used.
175175
max_size (int): size of the storage, i.e. maximum number of elements stored
176176
in the buffer.
177177
device (torch.device, optional): device where the sampled tensors will be
178178
stored and sent. Default is :obj:`torch.device("cpu")`.
179+
If "auto" is passed, the device is automatically gathered from the
180+
first batch of data passed. This is not enabled by default to avoid
181+
data placed on GPU by mistake, causing OOM issues.
179182
180183
Examples:
181184
>>> data = TensorDict({
@@ -230,7 +233,7 @@ def __new__(cls, *args, **kwargs):
230233
cls._storage = None
231234
return super().__new__(cls)
232235

233-
def __init__(self, storage, max_size=None, device=None):
236+
def __init__(self, storage, max_size=None, device="cpu"):
234237
if not ((storage is None) ^ (max_size is None)):
235238
if storage is None:
236239
raise ValueError("Expected storage to be non-null.")
@@ -247,7 +250,13 @@ def __init__(self, storage, max_size=None, device=None):
247250
self._len = max_size
248251
else:
249252
self._len = 0
250-
self.device = device if device else torch.device("cpu")
253+
self.device = (
254+
torch.device(device)
255+
if device != "auto"
256+
else storage.device
257+
if storage is not None
258+
else "auto"
259+
)
251260
self._storage = storage
252261

253262
def state_dict(self) -> Dict[str, Any]:
@@ -345,6 +354,9 @@ class LazyTensorStorage(TensorStorage):
345354
in the buffer.
346355
device (torch.device, optional): device where the sampled tensors will be
347356
stored and sent. Default is :obj:`torch.device("cpu")`.
357+
If "auto" is passed, the device is automatically gathered from the
358+
first batch of data passed. This is not enabled by default to avoid
359+
data placed on GPU by mistake, causing OOM issues.
348360
349361
Examples:
350362
>>> data = TensorDict({
@@ -396,12 +408,14 @@ class LazyTensorStorage(TensorStorage):
396408
397409
"""
398410

399-
def __init__(self, max_size, device=None):
400-
super().__init__(None, max_size, device=device)
411+
def __init__(self, max_size, device="cpu"):
412+
super().__init__(storage=None, max_size=max_size, device=device)
401413

402414
def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
403415
if VERBOSE:
404416
print("Creating a TensorStorage...")
417+
if self.device == "auto":
418+
self.device = data.device
405419
if isinstance(data, torch.Tensor):
406420
# if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
407421
out = torch.empty(
@@ -436,6 +450,9 @@ class LazyMemmapStorage(LazyTensorStorage):
436450
scratch_dir (str or path): directory where memmap-tensors will be written.
437451
device (torch.device, optional): device where the sampled tensors will be
438452
stored and sent. Default is :obj:`torch.device("cpu")`.
453+
If ``"auto"`` is passed, the device is automatically gathered from the
454+
first batch of data passed. This is not enabled by default to avoid
455+
data placed on GPU by mistake, causing OOM issues.
439456
440457
Examples:
441458
>>> data = TensorDict({
@@ -486,15 +503,15 @@ class LazyMemmapStorage(LazyTensorStorage):
486503
487504
"""
488505

489-
def __init__(self, max_size, scratch_dir=None, device=None):
506+
def __init__(self, max_size, scratch_dir=None, device="cpu"):
490507
super().__init__(max_size)
491508
self.initialized = False
492509
self.scratch_dir = None
493510
if scratch_dir is not None:
494511
self.scratch_dir = str(scratch_dir)
495512
if self.scratch_dir[-1] != "/":
496513
self.scratch_dir += "/"
497-
self.device = device if device else torch.device("cpu")
514+
self.device = torch.device(device) if device != "auto" else device
498515
self._len = 0
499516

500517
def state_dict(self) -> Dict[str, Any]:
@@ -552,6 +569,8 @@ def load_state_dict(self, state_dict):
552569
def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
553570
if VERBOSE:
554571
print("Creating a MemmapStorage...")
572+
if self.device == "auto":
573+
self.device = data.device
555574
if isinstance(data, torch.Tensor):
556575
# if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
557576
out = MemmapTensor(
@@ -682,7 +701,7 @@ def _get_default_collate(storage, _is_tensordict=False):
682701
return torch.utils.data._utils.collate.default_collate
683702
elif isinstance(storage, LazyMemmapStorage):
684703
return _collate_as_tensor
685-
elif isinstance(storage, (LazyTensorStorage,)):
704+
elif isinstance(storage, (TensorStorage,)):
686705
return _collate_contiguous
687706
else:
688707
raise NotImplementedError(

torchrl/envs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
CatTensors,
1414
CenterCrop,
1515
Compose,
16+
DeviceCastTransform,
1617
DiscreteActionProjection,
1718
DoubleToFloat,
1819
DTypeCastTransform,

torchrl/envs/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
CatTensors,
1212
CenterCrop,
1313
Compose,
14+
DeviceCastTransform,
1415
DiscreteActionProjection,
1516
DoubleToFloat,
1617
DTypeCastTransform,

0 commit comments

Comments (0)