@@ -3082,6 +3082,31 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        else:
            return self.unfolding(tensordict)

+    def _apply_same_padding(self, dim, data, done_mask):
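+        # "same" padding: copy the first non-padded frame into every position
+        # flagged as padding in done_mask, instead of writing a constant value.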
+        d = data.ndim + dim - 1
+        res = data.clone()
+        num_repeats_per_sample = done_mask.sum(dim=-1)
+
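+        # collapse any extra leading dims so samples can be enumerated along a single dim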
+        if num_repeats_per_sample.dim() > 2:
+            extra_dims = num_repeats_per_sample.dim() - 2
+            num_repeats_per_sample = num_repeats_per_sample.flatten(0, extra_dims)
+            res_flat_series = res.flatten(0, extra_dims)
+        else:
+            extra_dims = 0
+            res_flat_series = res
+
+        if d - 1 > extra_dims:
+            res_flat_series_flat_batch = res_flat_series.flatten(1, d - 1)
+        else:
+            res_flat_series_flat_batch = res_flat_series[:, None]
+
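+        # for each sample, overwrite the leading padded positions with the first
+        # non-padded frame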
+        for sample_idx, num_repeats in enumerate(num_repeats_per_sample):
+            if num_repeats > 0:
+                res_slice = res_flat_series_flat_batch[sample_idx]
+                res_slice[:, :num_repeats] = res_slice[:, num_repeats : num_repeats + 1]
+
+        return res
+
    @set_lazy_legacy(False)
    def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
        # it is assumed that the last dimension of the tensordict is the time dimension
@@ -3192,24 +3217,7 @@ def unfold_done(done, N):
            if self.padding != "same":
                data = torch.where(done_mask_expand, self.padding_value, data)
            else:
-                # TODO: This is a pretty bad implementation, could be
-                # made more efficient but it works!
-                reset_any = reset.any(-1, False)
-                reset_vals = list(data_orig[reset_any].unbind(0))
-                j_ = float("inf")
-                reps = []
-                d = data.ndim + self.dim - 1
-                n_feat = data.shape[data.ndim + self.dim :].numel()
-                for j in done_mask_expand.flatten(d, -1).sum(-1).view(-1) // n_feat:
-                    if j > j_:
-                        reset_vals = reset_vals[1:]
-                    reps.extend([reset_vals[0]] * int(j))
-                    j_ = j
-                if reps:
-                    reps = torch.stack(reps)
-                    data = torch.masked_scatter(
-                        data, done_mask_expand, reps.reshape(-1)
-                    )
+                data = self._apply_same_padding(self.dim, data, done_mask)

            if first_val is not None:
                # Aggregate reset along last dim
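For reference only, a minimal standalone sketch of the "same"-padding rule the new helper applies (not part of this commit; the toy tensors and the padded variable are made up): positions flagged as padding take the value of the first valid frame that follows them.

import torch

frames = torch.tensor([[10.0, 11.0, 12.0, 13.0]])       # one sample, 4 stacked frames
done_mask = torch.tensor([[True, True, False, False]])   # first two frames are padding

num_repeats = int(done_mask.sum(dim=-1)[0])              # 2 padded positions
padded = frames.clone()
padded[:, :num_repeats] = padded[:, num_repeats : num_repeats + 1]
print(padded)                                            # tensor([[12., 12., 12., 13.]])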