
Commit d4842fe

[Performance] Faster CatFrames.unfolding with padding="same" (#2407)
1 parent ca6eae4 · commit d4842fe
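
For context, CatFrames with padding="same" repeats the first valid frame into the slots that precede an episode start instead of writing a constant value. The unfolding path touched here is the one used on batched, time-stacked data, for instance when the transform is attached to a replay buffer. A minimal usage sketch (parameter values are illustrative, not taken from this commit):

from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.transforms import CatFrames

# Stack the last 4 observations along the channel dimension; frames that
# precede an episode start are filled by repeating the first valid frame.
transform = CatFrames(N=4, dim=-3, padding="same")

# Applying the transform to batched trajectories (here via a replay buffer)
# goes through CatFrames.unfolding, the code path changed by this commit.
rb = ReplayBuffer(storage=LazyTensorStorage(1_000), transform=transform)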

File tree: 1 file changed, +26 −18 lines

torchrl/envs/transforms/transforms.py

Lines changed: 26 additions & 18 deletions
@@ -3082,6 +3082,31 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         else:
             return self.unfolding(tensordict)
 
+    def _apply_same_padding(self, dim, data, done_mask):
+        d = data.ndim + dim - 1
+        res = data.clone()
+        num_repeats_per_sample = done_mask.sum(dim=-1)
+
+        if num_repeats_per_sample.dim() > 2:
+            extra_dims = num_repeats_per_sample.dim() - 2
+            num_repeats_per_sample = num_repeats_per_sample.flatten(0, extra_dims)
+            res_flat_series = res.flatten(0, extra_dims)
+        else:
+            extra_dims = 0
+            res_flat_series = res
+
+        if d - 1 > extra_dims:
+            res_flat_series_flat_batch = res_flat_series.flatten(1, d - 1)
+        else:
+            res_flat_series_flat_batch = res_flat_series[:, None]
+
+        for sample_idx, num_repeats in enumerate(num_repeats_per_sample):
+            if num_repeats > 0:
+                res_slice = res_flat_series_flat_batch[sample_idx]
+                res_slice[:, :num_repeats] = res_slice[:, num_repeats : num_repeats + 1]
+
+        return res
+
     @set_lazy_legacy(False)
     def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
         # it is assumed that the last dimension of the tensordict is the time dimension
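
The new helper walks over flattened samples and, wherever the done mask marks leading positions within a stack of frames, overwrites them with the first valid frame. A minimal standalone sketch of that fill (invented shapes and values, not the library code):

import torch

# Toy data: [batch=2, frames=4, features=1]. In sample 0, the first two
# frame slots fall before an episode start and must be padded "same",
# i.e. filled with the first valid frame (index 2).
data = torch.arange(2 * 4 * 1, dtype=torch.float32).reshape(2, 4, 1)
num_repeats_per_sample = torch.tensor([2, 0])

res = data.clone()
for sample_idx, num_repeats in enumerate(num_repeats_per_sample):
    if num_repeats > 0:
        # Broadcast the first valid frame into the masked leading slots.
        res[sample_idx, :num_repeats] = res[sample_idx, num_repeats : num_repeats + 1]

print(res[:, :, 0])
# tensor([[2., 2., 2., 3.],
#         [4., 5., 6., 7.]])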
@@ -3192,24 +3217,7 @@ def unfold_done(done, N):
         if self.padding != "same":
             data = torch.where(done_mask_expand, self.padding_value, data)
         else:
-            # TODO: This is a pretty bad implementation, could be
-            #  made more efficient but it works!
-            reset_any = reset.any(-1, False)
-            reset_vals = list(data_orig[reset_any].unbind(0))
-            j_ = float("inf")
-            reps = []
-            d = data.ndim + self.dim - 1
-            n_feat = data.shape[data.ndim + self.dim :].numel()
-            for j in done_mask_expand.flatten(d, -1).sum(-1).view(-1) // n_feat:
-                if j > j_:
-                    reset_vals = reset_vals[1:]
-                reps.extend([reset_vals[0]] * int(j))
-                j_ = j
-            if reps:
-                reps = torch.stack(reps)
-                data = torch.masked_scatter(
-                    data, done_mask_expand, reps.reshape(-1)
-                )
+            data = self._apply_same_padding(self.dim, data, done_mask)
 
         if first_val is not None:
             # Aggregate reset along last dim
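
For comparison, the removed branch built a flat tensor of repeated reset frames and relied on torch.masked_scatter, which writes source values into the positions selected by a boolean mask in row-major order. A tiny illustration of that primitive (toy tensors, unrelated to the shapes CatFrames actually handles):

import torch

data = torch.zeros(2, 3)
mask = torch.tensor([[True, True, False], [False, True, False]])
source = torch.tensor([1.0, 1.0, 2.0])  # one value per masked position, in order

print(torch.masked_scatter(data, mask, source))
# tensor([[1., 1., 0.],
#         [0., 2., 0.]])

The new helper avoids materializing that intermediate stack of repeated frames and the full masked scatter, which is presumably where the speedup comes from.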
