
Commit 29ac8e1

fix when video time seq len is less than max time seq len for video acceptor
1 parent e05cd6d · commit 29ac8e1

File tree

2 files changed: +6 -6 lines


setup.py (1 addition, 1 deletion)

@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.11.3',
+  version = '1.11.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description = long_description,

vit_pytorch/accept_video_wrapper.py (5 additions, 5 deletions)

@@ -91,7 +91,7 @@ def forward(

         pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze), pos_emb.shape[-1])

-        embed = embed + pos_emb
+        embed = embed + pos_emb[:, :embed.shape[1]]

         outputs[self.output_pos_add_pos_emb] = embed
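Why this fixes the bug: the time positional embedding is created for the configured maximum time_seq_len, so adding it unsliced fails to broadcast whenever the incoming video has fewer frames than that maximum. Below is a minimal sketch of the shape mismatch and the slice-based fix; the concrete numbers (12 max frames, 7 actual frames, 65 tokens, dim 1024) mirror the updated test further down, while the exact internal layout of pos_emb is an assumption for illustration.

import torch

time_seq_len = 12                                # maximum number of frames the wrapper is configured for
frames, tokens, dim = 7, 65, 1024                # actual video length and per-frame token shape

embed = torch.randn(1, frames, tokens, dim)      # per-frame embeddings coming out of the backbone
pos_emb = torch.randn(1, time_seq_len, 1, dim)   # learned time pos emb, unsqueezed to broadcast over tokens (assumed layout)

# before the fix: (1, 7, 65, 1024) + (1, 12, 1, 1024) cannot broadcast along the time dimension
# embed = embed + pos_emb                        # raises a RuntimeError

# after the fix: truncate the positional embedding to the video's actual number of frames
embed = embed + pos_emb[:, :embed.shape[1]]
assert embed.shape == (1, frames, tokens, dim)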
@@ -114,16 +114,16 @@ def forward(
         emb_dropout = 0.1
     )

-    videos = torch.randn(1, 3, 10, 256, 256)
+    videos = torch.randn(1, 3, 7, 256, 256)

     # step up the difficulty and return embeddings for robotics

     from vit_pytorch.extractor import Extractor
     v = Extractor(v)

-    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024)
+    video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024)

     logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2

-    assert logits.shape == (1, 10, 1000)
-    assert embeddings.shape == (1, 10, 65, 1024)
+    assert logits.shape == (1, 7, 1000)
+    assert embeddings.shape == (1, 7, 65, 1024)
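For completeness, the updated test now exercises a video that is shorter than the configured time_seq_len (7 frames vs. 12). Below is a self-contained sketch of running it end to end, assuming a standard ViT backbone; the ViT hyperparameters are guesses chosen to reproduce the asserted shapes (patch_size 32 on a 256 image gives 64 patch tokens plus the CLS token, hence 65), not values taken from this diff.

import torch
from vit_pytorch import ViT
from vit_pytorch.extractor import Extractor
from vit_pytorch.accept_video_wrapper import AcceptVideoWrapper

# assumed backbone; only image_size, patch_size, dim and num_classes matter for the shape asserts
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

v = Extractor(v)  # also return token embeddings alongside logits

# time_seq_len (12) is deliberately larger than the number of frames (7)
video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024)

videos = torch.randn(1, 3, 7, 256, 256)  # (batch, channels, time, height, width)

logits, embeddings = video_acceptor(videos, eval_with_no_grad = True)

assert logits.shape == (1, 7, 1000)
assert embeddings.shape == (1, 7, 65, 1024)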
