
Commit 8b14cd4

add the meta/register tokens used successfully in Hymba
1 parent b766ef6 commit 8b14cd4


3 files changed: +42 -1 lines changed


README.md

Lines changed: 9 additions & 0 deletions
@@ -232,3 +232,12 @@ sampled = model.generate_text_only(text[:, :1], 1024)
     url = {https://api.semanticscholar.org/CorpusID:273849947}
 }
 ```
+
+```bibtex
+@inproceedings{Dong2024HymbaAH,
+    title = {Hymba: A Hybrid-head Architecture for Small Language Models},
+    author = {Xin Dong and Y. Fu and Shizhe Diao and Wonmin Byeon and Zijia Chen and Ameya Mahabaleshwarkar and Shih-Yang Liu and Matthijs Van Keirsbilck and Min-Hung Chen and Yoshi Suhara and Yingyan Lin and Jan Kautz and Pavlo Molchanov},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:274166163}
+}
+```

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.6.5"
+version = "0.6.6"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 32 additions & 0 deletions
@@ -147,6 +147,15 @@ def inner(t: Tensor, *args, **kwargs) -> Tensor:
         return out
     return inner
 
+def pack_with_inverse(t, pattern):
+    packed, packed_shape = pack(t, pattern)
+
+    def inverse(out, inv_pattern = None):
+        inv_pattern = default(inv_pattern, pattern)
+        return unpack(out, packed_shape, inv_pattern)
+
+    return packed, inverse
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)
 
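The new `pack_with_inverse` generalizes the existing `pack_one_with_inverse` to pack several tensors at once along one axis, returning a closure that undoes the pack. A minimal standalone sketch of the intended usage (the `default` helper is redeclared here only so the snippet runs on its own; the shapes are illustrative):

```python
import torch
from einops import pack, unpack

def default(val, d):
    # fall back to d when val is None (mirrors the repo's helper)
    return val if val is not None else d

def pack_with_inverse(t, pattern):
    packed, packed_shape = pack(t, pattern)

    def inverse(out, inv_pattern = None):
        inv_pattern = default(inv_pattern, pattern)
        return unpack(out, packed_shape, inv_pattern)

    return packed, inverse

registers = torch.randn(2, 16, 64)   # e.g. register tokens already expanded over the batch
tokens    = torch.randn(2, 10, 64)   # the token sequence

# pack both tensors along the sequence axis
packed, inverse = pack_with_inverse((registers, tokens), 'b * d')
assert packed.shape == (2, 26, 64)

# after the transformer has run over the packed sequence, undo the pack
registers_out, tokens_out = inverse(packed)
assert tokens_out.shape == (2, 10, 64)
```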
@@ -1115,6 +1124,7 @@ def __init__(
         self,
         *,
         num_text_tokens,
+        num_register_tokens = 16,
         transformer: dict | Transformer,
         dim_latent: int | tuple[int, ...] | None = None,
         channel_first_latent: bool | tuple[bool, ...] = False,
@@ -1298,6 +1308,11 @@ def __init__(
         self.latent_to_model_projs = ModuleList(latent_to_model_projs)
         self.model_to_latent_projs = ModuleList(model_to_latent_projs)
 
+        # maybe register tokens (used in hymba, renamed from "meta" to register as "meta" was reserved from above already for the modality meta tag)
+
+        self.register_tokens = nn.Parameter(torch.zeros(num_register_tokens, dim))
+        nn.init.normal_(self.register_tokens, std = 0.02)
+
         # relative positions
 
         self.rotary_emb = RotaryEmbedding(transformer.dim_head)
@@ -2392,6 +2407,7 @@ def inner(pred_flow):
         if modality_positions.numel() == 0:
             modality_positions = F.pad(modality_positions, (0, 0, 0, 1))
 
+
         # sort the modalities tensor and sanitize, readying for noising of modalities
 
         modality_positions, sorted_indices = order_modality_positions_by_seq_offset(modality_positions)
@@ -2415,6 +2431,18 @@ def inner(pred_flow):
 
         tokens = einx.where('b n, b n d, b n d', is_any_modality, modality_tokens, text_tokens)
 
+        # handle maybe meta / register tokens
+
+        register_tokens = repeat(self.register_tokens, '... -> b ...', b = batch)
+
+        num_register_tokens = register_tokens.shape[-2]
+        seq_len += num_register_tokens
+
+        tokens, unpack_register_tokens = pack_with_inverse((register_tokens, tokens), 'b * d')
+        modality_positions[..., 1] += num_register_tokens
+
+        is_modalities = F.pad(is_modalities, (num_register_tokens, 0), value = False)
+
         # derive rotary positions
 
         rotary_positions = derive_rotary_positions_from_modality_positions(seq_len, modality_positions)
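The bookkeeping in this hunk is the subtle part: once the registers are packed in front of the sequence, everything indexed by absolute position has to shift by `num_register_tokens`. A toy illustration of the two adjustments, assuming (as the indexing above suggests) that column 1 of `modality_positions` holds each modality's sequence offset and that `is_modalities` masks sequence positions along its last dimension; the tensor shapes here are illustrative, not the repo's exact ones:

```python
import torch
import torch.nn.functional as F

num_register_tokens = 16
seq_len = 10

# one batch element, one modality type, a single image spanning positions 4..6
modality_positions = torch.tensor([[[0, 4, 3]]])               # (batch, modalities, (type, offset, length)) - assumed layout
is_modalities = torch.zeros(1, 1, seq_len, dtype = torch.bool) # boolean mask over sequence positions - assumed shape
is_modalities[..., 4:7] = True

# registers sit at the front, so every modality now starts num_register_tokens later
modality_positions[..., 1] += num_register_tokens

# and the per-position mask gains that many leading False entries
is_modalities = F.pad(is_modalities, (num_register_tokens, 0), value = False)

assert modality_positions[0, 0, 1].item() == 4 + num_register_tokens
assert is_modalities.shape[-1] == seq_len + num_register_tokens
assert is_modalities[..., num_register_tokens + 4].all()
```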
@@ -2455,6 +2483,10 @@ def inner(pred_flow):
             return_kv_cache = True
         )
 
+        # remove register tokens
+
+        _, embed = unpack_register_tokens(embed)
+
         # early return for embedding for decoding modality
 
         if return_embed:
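Taken together, the round trip is: broadcast the learned register tokens over the batch, pack them in front of the token sequence, run the transformer over the longer sequence, then strip the register positions again so the downstream projections still see the original length. A self-contained sketch of that flow, with an identity stand-in where the real attention stack would run, and einops `unpack` used directly in place of the inverse closure returned by `pack_with_inverse`:

```python
import torch
from torch import nn
from einops import repeat, pack, unpack

batch, seq_len, dim = 2, 10, 64
num_register_tokens = 16

# learned register tokens, initialized as in the commit
register_tokens = nn.Parameter(torch.zeros(num_register_tokens, dim))
nn.init.normal_(register_tokens, std = 0.02)

tokens = torch.randn(batch, seq_len, dim)

# broadcast over the batch and pack in front of the sequence
registers = repeat(register_tokens, 'n d -> b n d', b = batch)
packed, packed_shape = pack((registers, tokens), 'b * d')   # (batch, 16 + 10, dim)

# stand-in for attention over the packed sequence
embed = packed

# strip the register positions again before any output projections
_, embed = unpack(embed, packed_shape, 'b * d')
assert embed.shape == (batch, seq_len, dim)
```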
