Skip to content

Commit 7f4b718

Browse files
author
Vincent Moens
committed
[Feature] example_data for NonTensor spec
ghstack-source-id: 9af4a9e Pull Request resolved: pytorch/rl#2698
1 parent 256a700 commit 7f4b718

File tree

1 file changed

+40
-8
lines changed

1 file changed

+40
-8
lines changed

torchrl/data/tensor_specs.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2457,6 +2457,7 @@ def __init__(
24572457
shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
24582458
device: Optional[DEVICE_TYPING] = None,
24592459
dtype: torch.dtype | None = None,
2460+
example_data: Any = None,
24602461
**kwargs,
24612462
):
24622463
if isinstance(shape, int):
@@ -2467,6 +2468,7 @@ def __init__(
24672468
super().__init__(
24682469
shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
24692470
)
2471+
self.example_data = example_data
24702472

24712473
def cardinality(self) -> Any:
24722474
raise RuntimeError("Cannot enumerate a NonTensorSpec.")
@@ -2485,30 +2487,46 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
24852487
dest_device = torch.device(dest)
24862488
if dest_device == self.device and dest_dtype == self.dtype:
24872489
return self
2488-
return self.__class__(shape=self.shape, device=dest_device, dtype=None)
2490+
return self.__class__(
2491+
shape=self.shape,
2492+
device=dest_device,
2493+
dtype=None,
2494+
example_data=self.example_data,
2495+
)
24892496

24902497
def clone(self) -> NonTensor:
2491-
return self.__class__(shape=self.shape, device=self.device, dtype=self.dtype)
2498+
return self.__class__(
2499+
shape=self.shape,
2500+
device=self.device,
2501+
dtype=self.dtype,
2502+
example_data=self.example_data,
2503+
)
24922504

24932505
def rand(self, shape=None):
24942506
if shape is None:
24952507
shape = ()
24962508
return NonTensorData(
2497-
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
2509+
data=self.example_data,
2510+
batch_size=(*shape, *self._safe_shape),
2511+
device=self.device,
24982512
)
24992513

25002514
def zero(self, shape=None):
25012515
if shape is None:
25022516
shape = ()
25032517
return NonTensorData(
2504-
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
2518+
data=self.example_data,
2519+
batch_size=(*shape, *self._safe_shape),
2520+
device=self.device,
25052521
)
25062522

25072523
def one(self, shape=None):
25082524
if shape is None:
25092525
shape = ()
25102526
return NonTensorData(
2511-
data=None, batch_size=(*shape, *self._safe_shape), device=self.device
2527+
data=self.example_data,
2528+
batch_size=(*shape, *self._safe_shape),
2529+
device=self.device,
25122530
)
25132531

25142532
def is_in(self, val: Any) -> bool:
@@ -2533,23 +2551,36 @@ def expand(self, *shape):
25332551
raise ValueError(
25342552
f"The last elements of the expanded shape must match the current one. Got shape={shape} while self.shape={self.shape}."
25352553
)
2536-
return self.__class__(shape=shape, device=self.device, dtype=None)
2554+
return self.__class__(
2555+
shape=shape, device=self.device, dtype=None, example_data=self.example_data
2556+
)
25372557

25382558
def _reshape(self, shape):
2539-
return self.__class__(shape=shape, device=self.device, dtype=self.dtype)
2559+
return self.__class__(
2560+
shape=shape,
2561+
device=self.device,
2562+
dtype=self.dtype,
2563+
example_data=self.example_data,
2564+
)
25402565

25412566
def _unflatten(self, dim, sizes):
25422567
shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
25432568
return self.__class__(
25442569
shape=shape,
25452570
device=self.device,
25462571
dtype=self.dtype,
2572+
example_data=self.example_data,
25472573
)
25482574

25492575
def __getitem__(self, idx: SHAPE_INDEX_TYPING):
25502576
"""Indexes the current TensorSpec based on the provided index."""
25512577
indexed_shape = _size(_shape_indexing(self.shape, idx))
2552-
return self.__class__(shape=indexed_shape, device=self.device, dtype=self.dtype)
2578+
return self.__class__(
2579+
shape=indexed_shape,
2580+
device=self.device,
2581+
dtype=self.dtype,
2582+
example_data=self.example_data,
2583+
)
25532584

25542585
def unbind(self, dim: int = 0):
25552586
orig_dim = dim
@@ -2565,6 +2596,7 @@ def unbind(self, dim: int = 0):
25652596
shape=shape,
25662597
device=self.device,
25672598
dtype=self.dtype,
2599+
example_data=self.example_data,
25682600
)
25692601
for i in range(self.shape[dim])
25702602
)

0 commit comments

Comments
 (0)