diff --git a/rfdetr/models/backbone/dinov2_with_windowed_attn.py b/rfdetr/models/backbone/dinov2_with_windowed_attn.py index b315c46..8bbc398 100644 --- a/rfdetr/models/backbone/dinov2_with_windowed_attn.py +++ b/rfdetr/models/backbone/dinov2_with_windowed_attn.py @@ -312,7 +312,7 @@ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: Optional[torch.Te num_w_patches_per_window = num_w_patches // self.config.num_windows num_h_patches_per_window = num_h_patches // self.config.num_windows num_windows = self.config.num_windows - windowed_pixel_tokens = pixel_tokens_with_pos_embed.view(batch_size, num_windows, num_h_patches_per_window, num_windows, num_h_patches_per_window, -1) + windowed_pixel_tokens = pixel_tokens_with_pos_embed.view(batch_size, num_windows, num_h_patches_per_window, num_windows, num_w_patches_per_window, -1) windowed_pixel_tokens = windowed_pixel_tokens.permute(0, 1, 3, 2, 4, 5) windowed_pixel_tokens = windowed_pixel_tokens.reshape(batch_size * num_windows ** 2, num_h_patches_per_window * num_w_patches_per_window, -1) windowed_cls_token_with_pos_embed = cls_token_with_pos_embed.repeat(num_windows ** 2, 1, 1)