@@ -2457,6 +2457,7 @@ def __init__(
2457
2457
shape : Union [torch .Size , int ] = _DEFAULT_SHAPE ,
2458
2458
device : Optional [DEVICE_TYPING ] = None ,
2459
2459
dtype : torch .dtype | None = None ,
2460
+ example_data : Any = None ,
2460
2461
** kwargs ,
2461
2462
):
2462
2463
if isinstance (shape , int ):
@@ -2467,6 +2468,7 @@ def __init__(
2467
2468
super ().__init__ (
2468
2469
shape = shape , space = None , device = device , dtype = dtype , domain = domain , ** kwargs
2469
2470
)
2471
+ self .example_data = example_data
2470
2472
2471
2473
def cardinality (self ) -> Any :
2472
2474
raise RuntimeError ("Cannot enumerate a NonTensorSpec." )
@@ -2485,30 +2487,46 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
2485
2487
dest_device = torch .device (dest )
2486
2488
if dest_device == self .device and dest_dtype == self .dtype :
2487
2489
return self
2488
- return self .__class__ (shape = self .shape , device = dest_device , dtype = None )
2490
+ return self .__class__ (
2491
+ shape = self .shape ,
2492
+ device = dest_device ,
2493
+ dtype = None ,
2494
+ example_data = self .example_data ,
2495
+ )
2489
2496
2490
2497
def clone (self ) -> NonTensor :
2491
- return self .__class__ (shape = self .shape , device = self .device , dtype = self .dtype )
2498
+ return self .__class__ (
2499
+ shape = self .shape ,
2500
+ device = self .device ,
2501
+ dtype = self .dtype ,
2502
+ example_data = self .example_data ,
2503
+ )
2492
2504
2493
2505
def rand (self , shape = None ):
2494
2506
if shape is None :
2495
2507
shape = ()
2496
2508
return NonTensorData (
2497
- data = None , batch_size = (* shape , * self ._safe_shape ), device = self .device
2509
+ data = self .example_data ,
2510
+ batch_size = (* shape , * self ._safe_shape ),
2511
+ device = self .device ,
2498
2512
)
2499
2513
2500
2514
def zero (self , shape = None ):
2501
2515
if shape is None :
2502
2516
shape = ()
2503
2517
return NonTensorData (
2504
- data = None , batch_size = (* shape , * self ._safe_shape ), device = self .device
2518
+ data = self .example_data ,
2519
+ batch_size = (* shape , * self ._safe_shape ),
2520
+ device = self .device ,
2505
2521
)
2506
2522
2507
2523
def one (self , shape = None ):
2508
2524
if shape is None :
2509
2525
shape = ()
2510
2526
return NonTensorData (
2511
- data = None , batch_size = (* shape , * self ._safe_shape ), device = self .device
2527
+ data = self .example_data ,
2528
+ batch_size = (* shape , * self ._safe_shape ),
2529
+ device = self .device ,
2512
2530
)
2513
2531
2514
2532
def is_in (self , val : Any ) -> bool :
@@ -2533,23 +2551,36 @@ def expand(self, *shape):
2533
2551
raise ValueError (
2534
2552
f"The last elements of the expanded shape must match the current one. Got shape={ shape } while self.shape={ self .shape } ."
2535
2553
)
2536
- return self .__class__ (shape = shape , device = self .device , dtype = None )
2554
+ return self .__class__ (
2555
+ shape = shape , device = self .device , dtype = None , example_data = self .example_data
2556
+ )
2537
2557
2538
2558
def _reshape (self , shape ):
2539
- return self .__class__ (shape = shape , device = self .device , dtype = self .dtype )
2559
+ return self .__class__ (
2560
+ shape = shape ,
2561
+ device = self .device ,
2562
+ dtype = self .dtype ,
2563
+ example_data = self .example_data ,
2564
+ )
2540
2565
2541
2566
def _unflatten (self , dim , sizes ):
2542
2567
shape = torch .zeros (self .shape , device = "meta" ).unflatten (dim , sizes ).shape
2543
2568
return self .__class__ (
2544
2569
shape = shape ,
2545
2570
device = self .device ,
2546
2571
dtype = self .dtype ,
2572
+ example_data = self .example_data ,
2547
2573
)
2548
2574
2549
2575
def __getitem__ (self , idx : SHAPE_INDEX_TYPING ):
2550
2576
"""Indexes the current TensorSpec based on the provided index."""
2551
2577
indexed_shape = _size (_shape_indexing (self .shape , idx ))
2552
- return self .__class__ (shape = indexed_shape , device = self .device , dtype = self .dtype )
2578
+ return self .__class__ (
2579
+ shape = indexed_shape ,
2580
+ device = self .device ,
2581
+ dtype = self .dtype ,
2582
+ example_data = self .example_data ,
2583
+ )
2553
2584
2554
2585
def unbind (self , dim : int = 0 ):
2555
2586
orig_dim = dim
@@ -2565,6 +2596,7 @@ def unbind(self, dim: int = 0):
2565
2596
shape = shape ,
2566
2597
device = self .device ,
2567
2598
dtype = self .dtype ,
2599
+ example_data = self .example_data ,
2568
2600
)
2569
2601
for i in range (self .shape [dim ])
2570
2602
)
0 commit comments