@@ -946,7 +946,7 @@ def __init__(
         attn_laser = False,
         unet_skips = True,
         use_flex_attn = False,
-        num_residual_streams = 1
+        num_residual_streams = 4
     ):
         super().__init__()
         self.use_flex_attn = use_flex_attn
@@ -1160,7 +1160,6 @@ def __init__(
         self,
         *,
         num_text_tokens,
-        num_register_tokens = 16,
         transformer: dict | Transformer,
         dim_latent: int | tuple[int, ...] | None = None,
         channel_first_latent: bool | tuple[bool, ...] = False,
@@ -1344,11 +1343,6 @@ def __init__(
         self.latent_to_model_projs = ModuleList(latent_to_model_projs)
         self.model_to_latent_projs = ModuleList(model_to_latent_projs)
 
-        # maybe register tokens (used in hymba, renamed from "meta" to register as "meta" was reserved from above already for the modality meta tag)
-
-        self.register_tokens = nn.Parameter(torch.zeros(num_register_tokens, dim))
-        nn.init.normal_(self.register_tokens, std = 0.02)
-
         # relative positions
 
         self.rotary_emb = RotaryEmbedding(transformer.dim_head)
@@ -2467,18 +2461,6 @@ def inner(pred_flow):
 
         tokens = einx.where('b n, b n d, b n d', is_any_modality, modality_tokens, text_tokens)
 
-        # handle maybe meta / register tokens
-
-        register_tokens = repeat(self.register_tokens, '... -> b ...', b = batch)
-
-        num_register_tokens = register_tokens.shape[-2]
-        seq_len += num_register_tokens
-
-        tokens, unpack_register_tokens = pack_with_inverse((register_tokens, tokens), 'b * d')
-        modality_positions[..., 1] += num_register_tokens
-
-        is_modalities = F.pad(is_modalities, (num_register_tokens, 0), value = False)
-
         # derive rotary positions
 
         rotary_positions = derive_rotary_positions_from_modality_positions(seq_len, modality_positions)
@@ -2519,11 +2501,6 @@ def inner(pred_flow):
             return_kv_cache = True
         )
 
-        if not exists(decode_length):
-            # remove register tokens
-
-            _, embed = unpack_register_tokens(embed)
-
         # early return for embedding for decoding modality
 
         if return_embed:
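For reference, the num_residual_streams default moving from 1 to 4 means the transformer now carries several parallel residual streams by default (the hyper-connections style of residual). The sketch below only illustrates the shape bookkeeping of expanding to and reducing from 4 streams; the helper names expand_streams / reduce_streams and the mean reduction are assumptions for illustration, not this repository's actual implementation (which learns the stream mixing):

import torch

# minimal sketch, assuming only the tensor shapes involved -- not the repo's real hyper-connections code

def expand_streams(tokens, num_residual_streams):
    # duplicate the single residual into parallel streams: (b, n, d) -> (b * s, n, d)
    return tokens.repeat(num_residual_streams, 1, 1)

def reduce_streams(tokens, num_residual_streams):
    # collapse the streams back into one residual by averaging: (b * s, n, d) -> (b, n, d)
    return tokens.reshape(num_residual_streams, -1, *tokens.shape[1:]).mean(dim = 0)

tokens = torch.randn(2, 128, 512)       # (batch, seq, dim)
streams = expand_streams(tokens, 4)     # 4 matches the new default above
assert reduce_streams(streams, 4).shape == tokens.shape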