Commit 8b2d0be

remove register tokens for now and turn on hyper connections
1 parent 88287b7 commit 8b2d0be

2 files changed: +2 additions, −25 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.7.2"
+version = "0.8.0"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 1 addition & 24 deletions
@@ -946,7 +946,7 @@ def __init__(
         attn_laser = False,
         unet_skips = True,
         use_flex_attn = False,
-        num_residual_streams = 1
+        num_residual_streams = 4
     ):
         super().__init__()
         self.use_flex_attn = use_flex_attn
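
For context on the second half of the commit message: setting num_residual_streams above 1 turns on hyper-connections, where the single residual stream is widened into several parallel streams that each block reads from and writes back to with learned weights. The snippet below is only a minimal sketch of that idea, with assumed shapes and a made-up HyperConnectionSketch module; it is not the implementation this argument wires up in transfusion-pytorch.

import torch
from torch import nn

class HyperConnectionSketch(nn.Module):
    # illustrative only: static read/write weights over num_streams residual streams
    def __init__(self, num_streams = 4):
        super().__init__()
        self.read_weights = nn.Parameter(torch.ones(num_streams) / num_streams)  # mix streams into one branch input
        self.write_weights = nn.Parameter(torch.ones(num_streams))               # scale branch output per stream

    def forward(self, streams, branch):
        # streams: (num_streams, batch, seq, dim); branch: e.g. an attention or feedforward module
        branch_input = torch.einsum('sbnd,s->bnd', streams, self.read_weights)
        branch_output = branch(branch_input)
        return streams + torch.einsum('bnd,s->sbnd', branch_output, self.write_weights)

# at the model boundary one would expand the hidden states into num_streams copies
# before the first block and reduce them (e.g. sum) back to a single stream at the end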
@@ -1160,7 +1160,6 @@ def __init__(
         self,
         *,
         num_text_tokens,
-        num_register_tokens = 16,
         transformer: dict | Transformer,
         dim_latent: int | tuple[int, ...] | None = None,
         channel_first_latent: bool | tuple[bool, ...] = False,
@@ -1344,11 +1343,6 @@ def __init__(
         self.latent_to_model_projs = ModuleList(latent_to_model_projs)
         self.model_to_latent_projs = ModuleList(model_to_latent_projs)
 
-        # maybe register tokens (used in hymba, renamed from "meta" to register as "meta" was reserved from above already for the modality meta tag)
-
-        self.register_tokens = nn.Parameter(torch.zeros(num_register_tokens, dim))
-        nn.init.normal_(self.register_tokens, std = 0.02)
-
         # relative positions
 
         self.rotary_emb = RotaryEmbedding(transformer.dim_head)
@@ -2467,18 +2461,6 @@ def inner(pred_flow):
 
         tokens = einx.where('b n, b n d, b n d', is_any_modality, modality_tokens, text_tokens)
 
-        # handle maybe meta / register tokens
-
-        register_tokens = repeat(self.register_tokens, '... -> b ...', b = batch)
-
-        num_register_tokens = register_tokens.shape[-2]
-        seq_len += num_register_tokens
-
-        tokens, unpack_register_tokens = pack_with_inverse((register_tokens, tokens), 'b * d')
-        modality_positions[..., 1] += num_register_tokens
-
-        is_modalities = F.pad(is_modalities, (num_register_tokens, 0), value = False)
-
         # derive rotary positions
 
         rotary_positions = derive_rotary_positions_from_modality_positions(seq_len, modality_positions)
@@ -2519,11 +2501,6 @@ def inner(pred_flow):
             return_kv_cache = True
         )
 
-        if not exists(decode_length):
-            # remove register tokens
-
-            _, embed = unpack_register_tokens(embed)
-
         # early return for embedding for decoding modality
 
         if return_embed:
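
For reference, the register-token path removed above amounted to the standard pattern of prepending a small set of learned tokens to the sequence before the transformer and stripping those positions from the output embeddings afterwards. Below is a minimal standalone sketch of that pattern, using plain torch.cat and slicing in place of the repo's repeat / pack_with_inverse helpers; it mirrors the deleted lines rather than any remaining API. (The deleted code also shifted the modality position offsets and the modality mask by the number of register tokens, which this sketch leaves out.)

import torch
from torch import nn

class RegisterTokenSketch(nn.Module):
    def __init__(self, dim, num_register_tokens = 16):
        super().__init__()
        # learned register tokens, initialized like the removed parameter
        self.register_tokens = nn.Parameter(torch.zeros(num_register_tokens, dim))
        nn.init.normal_(self.register_tokens, std = 0.02)

    def prepend(self, tokens):
        # tokens: (batch, seq, dim) -> (batch, num_register_tokens + seq, dim)
        batch = tokens.shape[0]
        registers = self.register_tokens.unsqueeze(0).expand(batch, -1, -1)
        return torch.cat((registers, tokens), dim = 1)

    def strip(self, embed):
        # drop the register positions from the output embeddings
        return embed[:, self.register_tokens.shape[0]:]

# usage around a transformer call:
# tokens = reg.prepend(tokens)           # before attention
# embed = reg.strip(transformer_output)  # before returning embeddings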
