@@ -91,7 +91,7 @@ def forward(
 
         pos_emb = pos_emb.reshape(*pos_emb.shape[:2], *((1,) * dims_to_unsqueeze), pos_emb.shape[-1])
 
-        embed = embed + pos_emb
+        embed = embed + pos_emb[:, :embed.shape[1]]
 
         outputs[self.output_pos_add_pos_emb] = embed
 
@@ -114,16 +114,16 @@ def forward(
     emb_dropout = 0.1
 )
 
-videos = torch.randn(1, 3, 10, 256, 256)
+videos = torch.randn(1, 3, 7, 256, 256)
 
 # step up the difficulty and return embeddings for robotics
 
 from vit_pytorch.extractor import Extractor
 v = Extractor(v)
 
-video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 10, dim_emb = 1024)
+video_acceptor = AcceptVideoWrapper(v, add_time_pos_emb = True, output_pos_add_pos_emb = 1, time_seq_len = 12, dim_emb = 1024)
 
 logits, embeddings = video_acceptor(videos, eval_with_no_grad = True) # always (batch, channels, time, height, width) - time is always dimension 2
 
-assert logits.shape == (1, 10, 1000)
-assert embeddings.shape == (1, 10, 65, 1024)
+assert logits.shape == (1, 7, 1000)
+assert embeddings.shape == (1, 7, 65, 1024)
0 commit comments