Commit 74c7ecd

add hyper connections proposed by bytedance ai labs
1 parent 16f73e1 commit 74c7ecd

File tree

4 files changed: +145 -7 lines changed

README.md

Lines changed: 10 additions & 0 deletions
````diff
@@ -239,5 +239,15 @@ sampled = model.generate_text_only(text[:, :1], 1024)
     author = {Xin Dong and Y. Fu and Shizhe Diao and Wonmin Byeon and Zijia Chen and Ameya Mahabaleshwarkar and Shih-Yang Liu and Matthijs Van Keirsbilck and Min-Hung Chen and Yoshi Suhara and Yingyan Lin and Jan Kautz and Pavlo Molchanov},
     year = {2024},
     url = {https://api.semanticscholar.org/CorpusID:274166163}
+}
+```
+```bibtex
+@article{Zhu2024HyperConnections,
+    title = {Hyper-Connections},
+    author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2409.19606},
+    url = {https://api.semanticscholar.org/CorpusID:272987528}
 }
 ```
````

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.6.7"
+version = "0.7.0"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
```

tests/test_transfusion.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -20,10 +20,12 @@
 
 @pytest.mark.parametrize('cache_kv', (False, True))
 @pytest.mark.parametrize('use_flex_attn', (False, True))
+@pytest.mark.parametrize('num_residual_streams', (1, 4))
 @pytest.mark.parametrize('reconstruction_loss_weight', (0., 0.1))
 def test_transfusion(
     cache_kv: bool,
     use_flex_attn: bool,
+    num_residual_streams: int,
     reconstruction_loss_weight: float
 ):
 
@@ -41,7 +43,8 @@ def test_transfusion(
         transformer = dict(
             dim = 64,
             depth = 2,
-            use_flex_attn = use_flex_attn
+            use_flex_attn = use_flex_attn,
+            num_residual_streams = num_residual_streams
         )
     )
 
```
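For reference, a hedged sketch of how the new flag reaches the model from user code. The enclosing constructor is not visible in the hunk above; the `Transfusion` class and the `num_text_tokens` / `dim_latent` keywords below follow the repo README's basic example and are assumptions here, while the `transformer` dict mirrors the test config:

```python
from transfusion_pytorch import Transfusion  # assumed import, per the repo README

# only `num_residual_streams` is introduced by this commit; the surrounding
# keywords are assumed from the README's basic example and may need adjusting
model = Transfusion(
    num_text_tokens = 256,   # assumed
    dim_latent = 384,        # assumed
    transformer = dict(
        dim = 64,
        depth = 2,
        use_flex_attn = False,
        num_residual_streams = 4   # 1 keeps the plain residual path, > 1 enables hyper connections
    )
)
```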

transfusion_pytorch/transfusion.py

Lines changed: 130 additions & 5 deletions
```diff
@@ -17,6 +17,8 @@
 import math
 from collections import defaultdict
 
+from random import randrange
+from itertools import count
 from functools import partial, wraps, cache
 from typing import NamedTuple, Callable, Literal
 
```
```diff
@@ -591,6 +593,97 @@ def forward(
         fourier_embed, _ = pack((times, freqs.sin(), freqs.cos()), 'b n *')
         return fourier_embed
 
+# hyper connections - multiple residual streams
+
+class Residual(Module):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def prepare_with_inverse(self, residuals):
+        branch_input, residuals, residual_kwargs = self.prepare(residuals)
+
+        def inverse(branch_out):
+            return self(branch_out, residuals, **residual_kwargs)
+
+        return branch_input, inverse
+
+    def prepare(self, residuals):
+        return residuals, residuals, dict()
+
+    def forward(self, branch_out, residuals, **kwargs):
+        return branch_out + residuals
+
+class HyperConnections(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        num_residual_streams,
+        layer_index = None,
+        tanh = True,
+        **kwargs
+    ):
+        """
+        https://arxiv.org/abs/2409.19606
+        Appendix J - Algorithm 2, Dynamic only
+        """
+        super().__init__()
+
+        self.act = nn.Tanh() if tanh else nn.Identity()
+
+        self.norm = nn.RMSNorm(dim)
+
+        self.num_residual_streams = num_residual_streams
+        layer_index = default(layer_index, randrange(num_residual_streams)) # just choose one random residual stream if layer index not given
+
+        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+        init_alpha0 = torch.zeros((num_residual_streams, 1))
+        init_alpha0[layer_index % num_residual_streams, 0] = 1.
+
+        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
+        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+    def prepare_with_inverse(self, residuals):
+        branch_input, residuals, residual_kwargs = self.prepare(residuals)
+
+        def inverse(branch_out):
+            return self(branch_out, residuals, **residual_kwargs)
+
+        return branch_input, inverse
+
+    def prepare(self, residuals):
+
+        residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
+
+        normed = self.norm(residuals)
+
+        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
+        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+        alpha = dynamic_alpha + self.static_alpha
+
+        dc_weight = self.act(normed @ self.dynamic_beta_fn)
+        dynamic_beta = dc_weight * self.dynamic_beta_scale
+        beta = dynamic_beta + self.static_beta
+
+        # width connection
+
+        mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
+
+        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+
+        return branch_input, residuals, dict(beta = beta)
+
+    def forward(self, branch_output, residuals, *, beta):
+        # 'depth' connection
+
+        residuals = einsum(branch_output, beta, 'b n d, b n s -> b n s d') + residuals
+        return rearrange(residuals, 'b n s d -> (b s) n d')
+
 # adaptive layernorm and ada-ln zero rolled into one wrapper
 # from DiT paper and sota for time conditioning for now
 
```
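The prepare / inverse contract above is easiest to see outside the transformer. Below is a minimal sketch that drives `HyperConnections` by hand, assuming the class can be imported from `transfusion_pytorch.transfusion` and using a plain `nn.Linear` as a stand-in for a wrapped attention or feedforward block; the `repeat` at the start and `reduce` at the end mirror what the transformer's `forward` does further down in this diff.

```python
import torch
from torch import nn
from einops import repeat, reduce

# assumed import path from this commit; HyperConnections is not a public package export
from transfusion_pytorch.transfusion import HyperConnections

dim, num_streams = 64, 4

branch = nn.Linear(dim, dim)  # stand-in for an attention / feedforward branch
residual_fn = HyperConnections(dim, num_residual_streams = num_streams, layer_index = 0)

x = torch.randn(2, 16, dim)   # (batch, seq, dim)

# expand the single residual into `num_streams` copies, folded into the batch dimension
x = repeat(x, 'b n d -> (b s) n d', s = num_streams)

# width connection: mix the streams into one branch input, keep the mixed streams as residuals
branch_input, add_residual = residual_fn.prepare_with_inverse(x)

# depth connection: run the branch, then scatter its output back across all streams
x = add_residual(branch(branch_input))

# at the end of the stack, the streams are summed back into a single residual
out = reduce(x, '(b s) n d -> b n d', 'sum', s = num_streams)
assert out.shape == (2, 16, dim)
```

The width connection collapses the `s` streams into a single branch input via `alpha`, and the depth connection uses `beta` to write the branch output back onto every stream before they are folded into the batch dimension again.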

```diff
@@ -940,7 +1033,8 @@ def __init__(
         ff_kwargs: dict = dict(),
         attn_laser = False,
         unet_skips = True,
-        use_flex_attn = False
+        use_flex_attn = False,
+        num_residual_streams = 1
     ):
         super().__init__()
         self.use_flex_attn = use_flex_attn
@@ -954,6 +1048,17 @@ def __init__(
             nn.SiLU()
         )
 
+        # hyper connections
+
+        assert num_residual_streams > 0
+        is_hyper_connection = num_residual_streams > 1
+        self.num_residual_streams = num_residual_streams
+
+        counter = count()
+        residual_klass = Residual if num_residual_streams == 1 else HyperConnections
+
+        # layers
+
         layers = ModuleList([])
 
         for ind in range(depth):
@@ -970,7 +1075,10 @@ def __init__(
             attn = AdaptiveWrapper(attn, dim = dim, dim_cond = dim * 4)
             ff = AdaptiveWrapper(ff, dim = dim, dim_cond = dim * 4)
 
-            layers.append(ModuleList([skip_proj, attn, ff]))
+            attn_residual = residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = next(counter))
+            ff_residual = residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = next(counter))
+
+            layers.append(ModuleList([skip_proj, attn, attn_residual, ff, ff_residual]))
 
         self.layers = layers
         self.norm = RMSNorm(dim)
```
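The shared `count()` iterator gives every residual wrapper (two per layer: one for attention, one for feedforward) a distinct `layer_index`, and `init_alpha0[layer_index % num_residual_streams]` makes that the stream which initially feeds the branch. A small illustration of the resulting round-robin assignment, purely for intuition:

```python
from itertools import count

num_residual_streams = 4
counter = count()

for layer in range(3):  # depth = 3, for illustration only
    attn_stream = next(counter) % num_residual_streams
    ff_stream = next(counter) % num_residual_streams
    print(f'layer {layer}: attn initially reads stream {attn_stream}, ff reads stream {ff_stream}')

# layer 0: attn initially reads stream 0, ff reads stream 1
# layer 1: attn initially reads stream 2, ff reads stream 3
# layer 2: attn initially reads stream 0, ff reads stream 1
```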
```diff
@@ -1060,6 +1168,11 @@ def forward(
         cache = default(cache, (None,))
         iter_cache = iter(cache)
 
+        # expand input into multiple residual streams for maybe hyper connection
+
+        if self.num_residual_streams > 1:
+            x = repeat(x, 'b ... -> (b s) ...', s = self.num_residual_streams)
+
         # transformer layers as usual, using mask from above
 
         skips = []
@@ -1069,7 +1182,7 @@ def forward(
 
         depth = len(self.layers)
 
-        for ind, (skip_proj, attn, ff) in enumerate(self.layers):
+        for ind, (skip_proj, attn, attn_residual, ff, ff_residual) in enumerate(self.layers):
             layer = ind + 1
 
             # skip connection
@@ -1089,6 +1202,8 @@ def forward(
 
             # attention and feedforward
 
+            x, add_attn_residual = attn_residual.prepare_with_inverse(x)
+
             (attn_out, attn_values), kv_cache = attn(
                 x,
                 rotary_emb = rotary_emb,
@@ -1104,8 +1219,18 @@ def forward(
 
             new_cache.append(kv_cache)
 
-            x = attn_out + x
-            x = ff(x, **adaptive_kwargs) + x
+            x = add_attn_residual(attn_out)
+
+            x, add_ff_residual = ff_residual.prepare_with_inverse(x)
+
+            ff_out = ff(x, **adaptive_kwargs)
+
+            x = add_ff_residual(ff_out)
+
+        # reduce multiple residual streams for maybe hyper connection
+
+        if self.num_residual_streams > 1:
+            x = reduce(x, '(b s) ... -> b ...', 'sum', s = self.num_residual_streams)
 
         assert len(skips) == 0
 
```
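With `num_residual_streams = 1` the plain `Residual` wrapper is used, so the rewritten loop is numerically identical to the old `x = attn_out + x` / `x = ff(x, **adaptive_kwargs) + x` path. A quick sanity check of that equivalence, assuming `Residual` is importable from `transfusion_pytorch.transfusion`:

```python
import torch
from transfusion_pytorch.transfusion import Residual  # assumed import path

residual_fn = Residual(dim = 64, num_residual_streams = 1)  # extra kwargs are accepted and ignored

x = torch.randn(2, 16, 64)
branch_input, add_residual = residual_fn.prepare_with_inverse(x)

attn_out = torch.randn_like(x)  # stand-in for the attention branch output

assert torch.equal(branch_input, x)                       # prepare() passes the residual through untouched
assert torch.equal(add_residual(attn_out), attn_out + x)  # same result as the old `x = attn_out + x`
```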
