@@ -163,6 +163,7 @@ def __init__(
         dim_head = 64,
         dropout = 0.,
         emb_dropout = 0.,
+        num_registers = 4,
         token_dropout_prob: float | None = None
     ):
         super().__init__()
@@ -193,9 +194,18 @@ def __init__(
             nn.LayerNorm(dim),
         )

-        self.pos_embed_frame = nn.Parameter(torch.randn(patch_frame_dim, dim))
-        self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
-        self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))
+        self.pos_embed_frame = nn.Parameter(torch.zeros(patch_frame_dim, dim))
+        self.pos_embed_height = nn.Parameter(torch.zeros(patch_height_dim, dim))
+        self.pos_embed_width = nn.Parameter(torch.zeros(patch_width_dim, dim))
+
+        # register tokens
+
+        self.register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
+
+        nn.init.normal_(self.pos_embed_frame, std = 0.02)
+        nn.init.normal_(self.pos_embed_height, std = 0.02)
+        nn.init.normal_(self.pos_embed_width, std = 0.02)
+        nn.init.normal_(self.register_tokens, std = 0.02)

         self.dropout = nn.Dropout(emb_dropout)

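For context: register tokens, introduced in "Vision Transformers Need Registers" (Darcet et al., 2023), are a few extra learnable embeddings prepended to the patch sequence that act as scratch space for attention and are discarded before pooling. Note the commit also switches the positional embeddings from unit-variance randn to zeros plus a std 0.02 normal init, matching how the register tokens are initialized. A minimal standalone sketch of the idea, with illustrative names and sizes not taken from this file:

import torch
import torch.nn as nn

num_registers, dim = 4, 512

# learnable register tokens, initialized like the positional embeddings above
register_tokens = nn.Parameter(torch.zeros(num_registers, dim))
nn.init.normal_(register_tokens, std = 0.02)

patch_tokens = torch.randn(196, dim)                  # patch embeddings for one image
tokens = torch.cat((register_tokens, patch_tokens))   # (4 + 196, dim) fed to the transformer

# ... attention layers run over all tokens ...

patch_tokens = tokens[num_registers:]                 # registers stripped before pooling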
@@ -275,8 +285,6 @@ def forward(

         pos_embed = frame_embed + height_embed + width_embed

-        # use nested tensor for transformers and save on padding computation
-
         tokens = torch.cat(tokens)

         # linear projection to patch embeddings
@@ -287,7 +295,15 @@ def forward(

         tokens = tokens + pos_embed

-        tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)
+        # add register tokens
+
+        tokens = tokens.split(seq_lens.tolist())
+
+        tokens = [torch.cat((self.register_tokens, one_tokens)) for one_tokens in tokens]
+
+        # use nested tensor for transformers and save on padding computation
+
+        tokens = nested_tensor(tokens, layout = torch.jagged, device = device)

         # embedding dropout

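In the forward pass, the same register tokens are concatenated onto every sample's variable-length token sequence before the list is packed into a jagged nested tensor, so the transformer never computes over padding. A minimal sketch of that packing step, assuming a PyTorch recent enough to support layout = torch.jagged; the sequence lengths here are made up:

import torch
from torch.nested import nested_tensor

dim = 512
registers = torch.randn(4, dim)

# variable-length sequences, e.g. inputs with different numbers of patches
seqs = [torch.randn(n, dim) for n in (196, 49, 100)]

# prepend the shared register tokens to each sequence
seqs = [torch.cat((registers, s)) for s in seqs]

# pack without padding into a jagged nested tensor
nt = nested_tensor(seqs, layout = torch.jagged)

for s in nt.unbind():
    print(s.shape)   # (200, 512), (53, 512), (104, 512)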