@@ -147,6 +147,15 @@ def inner(t: Tensor, *args, **kwargs) -> Tensor:
         return out
     return inner
 
+def pack_with_inverse(t, pattern):
+    packed, packed_shape = pack(t, pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return unpack(out, packed_shape, inv_pattern)
+
+    return packed, inverse
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
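A minimal usage sketch of the new pack_with_inverse helper (not part of the commit; assumes einops is installed, and the repo's default helper is stubbed inline). It packs several tensors along one axis and hands back a closure that undoes the packing:

import torch
from einops import pack, unpack

def default(v, d):
    return v if v is not None else d

def pack_with_inverse(t, pattern):
    packed, packed_shape = pack(t, pattern)

    def inverse(out, inv_pattern = None):
        inv_pattern = default(inv_pattern, pattern)
        return unpack(out, packed_shape, inv_pattern)

    return packed, inverse

registers = torch.randn(2, 16, 512)   # (batch, num registers, dim)
tokens    = torch.randn(2, 128, 512)  # (batch, seq, dim)

packed, inverse = pack_with_inverse((registers, tokens), 'b * d')
assert packed.shape == (2, 144, 512)

registers_out, tokens_out = inverse(packed)  # splits back to the original shapes
assert tokens_out.shape == (2, 128, 512)

Unlike the existing pack_one_with_inverse, which wraps a single tensor and indexes [0] on inversion, this variant takes a sequence of tensors and returns the full list when inverted, which is what lets the forward pass further down discard the register slice with a simple tuple unpack.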
@@ -1115,6 +1124,7 @@ def __init__(
         self,
         *,
         num_text_tokens,
+        num_register_tokens = 16,
         transformer: dict | Transformer,
         dim_latent: int | tuple[int, ...] | None = None,
         channel_first_latent: bool | tuple[bool, ...] = False,
@@ -1298,6 +1308,11 @@ def __init__(
         self.latent_to_model_projs = ModuleList(latent_to_model_projs)
         self.model_to_latent_projs = ModuleList(model_to_latent_projs)
 
+        # maybe register tokens (used in hymba, renamed from "meta" to register as "meta" was reserved from above already for the modality meta tag)
+
+        self.register_tokens = nn.Parameter(torch.zeros(num_register_tokens, dim))
+        nn.init.normal_(self.register_tokens, std = 0.02)
+
         # relative positions
 
         self.rotary_emb = RotaryEmbedding(transformer.dim_head)
@@ -2392,6 +2407,7 @@ def inner(pred_flow):
         if modality_positions.numel() == 0:
             modality_positions = F.pad(modality_positions, (0, 0, 0, 1))
 
+
         # sort the modalities tensor and sanitize, readying for noising of modalities
 
         modality_positions, sorted_indices = order_modality_positions_by_seq_offset(modality_positions)
@@ -2415,6 +2431,18 @@ def inner(pred_flow):
 
         tokens = einx.where('b n, b n d, b n d', is_any_modality, modality_tokens, text_tokens)
 
+        # handle maybe meta / register tokens
+
+        register_tokens = repeat(self.register_tokens, '... -> b ...', b = batch)
+
+        num_register_tokens = register_tokens.shape[-2]
+        seq_len += num_register_tokens
+
+        tokens, unpack_register_tokens = pack_with_inverse((register_tokens, tokens), 'b * d')
+        modality_positions[..., 1] += num_register_tokens
+
+        is_modalities = F.pad(is_modalities, (num_register_tokens, 0), value = False)
+
         # derive rotary positions
 
         rotary_positions = derive_rotary_positions_from_modality_positions(seq_len, modality_positions)
@@ -2455,6 +2483,10 @@ def inner(pred_flow):
             return_kv_cache = True
         )
 
+        # remove register tokens
+
+        _, embed = unpack_register_tokens(embed)
+
         # early return for embedding for decoding modality
 
         if return_embed:
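Taken together, the register-token pieces of this commit do the following: create a learned (num_register_tokens, dim) parameter, broadcast it over the batch, pack it in front of the token sequence while shifting the length and position bookkeeping by the same amount, run the transformer, then strip the register slice from the output. Below is a hedged, self-contained sketch of that flow; RegisterTokenWrapper and its names are illustrative only, and the real code additionally adjusts modality_positions and the modality masks as shown in the hunks above.

import torch
from torch import nn
from einops import repeat, pack, unpack

class RegisterTokenWrapper(nn.Module):
    # illustrative wrapper, not the repo's class
    def __init__(self, net: nn.Module, dim: int, num_register_tokens: int = 16):
        super().__init__()
        self.net = net
        self.register_tokens = nn.Parameter(torch.zeros(num_register_tokens, dim))
        nn.init.normal_(self.register_tokens, std = 0.02)

    def forward(self, tokens):  # tokens: (batch, seq, dim)
        batch = tokens.shape[0]

        # broadcast the learned registers over the batch and prepend them
        registers = repeat(self.register_tokens, 'n d -> b n d', b = batch)
        tokens, packed_shape = pack((registers, tokens), 'b * d')

        tokens = self.net(tokens)  # the network now sees registers + sequence

        # drop the register outputs, keep only the original sequence positions
        _, tokens = unpack(tokens, packed_shape, 'b * d')
        return tokens

wrapper = RegisterTokenWrapper(nn.Identity(), dim = 512)
out = wrapper(torch.randn(2, 128, 512))
assert out.shape == (2, 128, 512)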