Commit 9539d45

move hyper connections to external reusable lib
1 parent 088f148

File tree

2 files changed: +11 additions, -101 deletions


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.7.0"
+version = "0.7.1"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

transfusion_pytorch/transfusion.py

Lines changed: 10 additions & 100 deletions
@@ -42,6 +42,8 @@
 
 from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
 
+from hyper_connections import HyperConnections
+
 from tqdm import tqdm
 from loguru import logger

@@ -594,97 +596,6 @@ def forward(
         fourier_embed, _ = pack((times, freqs.sin(), freqs.cos()), 'b n *')
         return fourier_embed
 
-# hyper connections - multiple residual streams
-
-class Residual(Module):
-    def __init__(self, **kwargs):
-        super().__init__()
-
-    def prepare_with_inverse(self, residuals):
-        branch_input, residuals, residual_kwargs = self.prepare(residuals)
-
-        def inverse(branch_out):
-            return self(branch_out, residuals, **residual_kwargs)
-
-        return branch_input, inverse
-
-    def prepare(self, residuals):
-        return residuals, residuals, dict()
-
-    def forward(self, branch_out, residuals, **kwargs):
-        return branch_out + residuals
-
-class HyperConnections(Module):
-    def __init__(
-        self,
-        dim,
-        *,
-        num_residual_streams,
-        layer_index = None,
-        tanh = True,
-        **kwargs
-    ):
-        """
-        https://arxiv.org/abs/2409.19606
-        Appendix J - Algorithm 2, Dynamic only
-        """
-        super().__init__()
-
-        self.act = nn.Tanh() if tanh else nn.Identity()
-
-        self.norm = nn.RMSNorm(dim)
-
-        self.num_residual_streams = num_residual_streams
-        layer_index = default(layer_index, randrange(num_residual_streams)) # just choose one random residual stream if layer index not given
-
-        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
-
-        init_alpha0 = torch.zeros((num_residual_streams, 1))
-        init_alpha0[layer_index % num_residual_streams, 0] = 1.
-
-        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
-
-        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
-        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
-        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
-        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
-
-    def prepare_with_inverse(self, residuals):
-        branch_input, residuals, residual_kwargs = self.prepare(residuals)
-
-        def inverse(branch_out):
-            return self(branch_out, residuals, **residual_kwargs)
-
-        return branch_input, inverse
-
-    def prepare(self, residuals):
-
-        residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
-
-        normed = self.norm(residuals)
-
-        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
-        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
-        alpha = dynamic_alpha + self.static_alpha
-
-        dc_weight = self.act(normed @ self.dynamic_beta_fn)
-        dynamic_beta = dc_weight * self.dynamic_beta_scale
-        beta = dynamic_beta + self.static_beta
-
-        # width connection
-
-        mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
-
-        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
-
-        return branch_input, residuals, dict(beta = beta)
-
-    def forward(self, branch_output, residuals, *, beta):
-        # 'depth' connection
-
-        residuals = einsum(branch_output, beta, 'b n d, b n s -> b n s d') + residuals
-        return rearrange(residuals, 'b n s d -> (b s) n d')
-
 # adaptive layernorm and ada-ln zero rolled into one wrapper
 # from DiT paper and sota for time conditioning for now

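For reference, a minimal self-contained sketch of the computation that the deleted block implemented (https://arxiv.org/abs/2409.19606, Appendix J, Algorithm 2, dynamic only): the "width" connection mixes the residual streams into a single branch input, and the "depth" connection adds the branch output back onto every stream. Toy sizes are used, the dynamic weights are shown at their zero initialization, and the shapes and einsum patterns mirror the removed code.

# a self-contained sketch of the removed width / depth connections; toy sizes,
# dynamic weights shown as the zero tensors they are initialized to in the module
import torch
from torch import nn
from einops import rearrange, einsum

batch, seq, dim, streams = 2, 16, 64, 4
layer_index = 0

x = torch.randn(batch * streams, seq, dim)                        # streams folded into batch: (b s) n d
residuals = rearrange(x, '(b s) n d -> b n s d', s = streams)
normed = nn.RMSNorm(dim)(residuals)

# static weights: identity routing for the carried residuals, one-hot pick of the branch input

init_alpha0 = torch.zeros(streams, 1)
init_alpha0[layer_index % streams, 0] = 1.
static_alpha = torch.cat([init_alpha0, torch.eye(streams)], dim = 1)                # (s, s + 1)
static_beta = torch.ones(streams)

# dynamic weights are input-dependent; the parameters start at zero in the removed module

alpha = torch.tanh(normed @ torch.zeros(dim, streams + 1)) * 1e-2 + static_alpha    # b n s (s + 1)
beta = torch.tanh(normed @ torch.zeros(dim)) * 1e-2 + static_beta                   # b n s

# width connection - mix the streams into one branch input plus the carried residuals

mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]

# depth connection - add the branch output (attention / feedforward) back to every stream

branch_output = branch_input                                      # stand-in for the wrapped branch
residuals = einsum(branch_output, beta, 'b n d, b n s -> b n s d') + residuals
x = rearrange(residuals, 'b n s d -> (b s) n d')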
@@ -1056,7 +967,8 @@ def __init__(
         self.num_residual_streams = num_residual_streams
 
         counter = count()
-        residual_klass = Residual if num_residual_streams == 1 else HyperConnections
+
+        init_residual_fn, self.expand_stream, self.reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         # layers

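The factory called above comes from the external package imported earlier in this diff (presumably lucidrains' hyper-connections, installable as pip install hyper-connections). A short sketch of the construction side, using only the calls that appear in this commit; the concrete dim, stream count and layer indices are illustrative:

# a sketch of the factory usage, inferred from the calls visible in this diff
from hyper_connections import HyperConnections

num_residual_streams = 4
dim = 512

# disable = True is passed when there is only one stream, matching the old fallback to a plain Residual
init_residual_fn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1
)

# one residual module per branch, each tagged with its own layer index
attn_residual = init_residual_fn(dim = dim, layer_index = 0)
ff_residual = init_residual_fn(dim = dim, layer_index = 1)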
@@ -1076,8 +988,8 @@ def __init__(
             attn = AdaptiveWrapper(attn, dim = dim, dim_cond = dim * 4)
             ff = AdaptiveWrapper(ff, dim = dim, dim_cond = dim * 4)
 
-            attn_residual = residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_id = next(counter))
-            ff_residual = residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_id = next(counter))
+            attn_residual = init_residual_fn(dim = dim, layer_index = next(counter))
+            ff_residual = init_residual_fn(dim = dim, layer_index = next(counter))
 
             layers.append(ModuleList([skip_proj, attn, attn_residual, ff, ff_residual]))

@@ -1171,8 +1083,7 @@ def forward(
 
         # expand input into multiple residual streams for maybe hyper connection
 
-        if self.num_residual_streams > 1:
-            x = repeat(x, 'b ... -> (b s) ...', s = self.num_residual_streams)
+        x = self.expand_stream(x)
 
         # transformer layers as usual, using mask from above

@@ -1203,7 +1114,7 @@ def forward(
 
             # attention and feedforward
 
-            x, add_attn_residual = attn_residual.prepare_with_inverse(x)
+            x, add_attn_residual = attn_residual(x)
 
             (attn_out, attn_values), kv_cache = attn(
                 x,
@@ -1222,16 +1133,15 @@ def forward(
 
             x = add_attn_residual(attn_out)
 
-            x, add_ff_residual = ff_residual.prepare_with_inverse(x)
+            x, add_ff_residual = ff_residual(x)
 
             ff_out = ff(x, **adaptive_kwargs)
 
             x = add_ff_residual(ff_out)
 
         # reduce multiple residual streams for maybe hyper connection
 
-        if self.num_residual_streams > 1:
-            x = reduce(x, '(b s) ... -> b ...', 'sum', s = self.num_residual_streams)
+        x = self.reduce_stream(x)
 
         assert len(skips) == 0

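Taken together with the construction above, the forward-pass hunks reduce to this call pattern: expand the input into the residual streams once, treat each residual module as a callable that returns the branch input plus an add-back closure, and reduce the streams at the end (the old code summed them). A sketch assembled from the hunks in this diff, with nn.Identity standing in for the attention and feedforward branches:

# a sketch of the new forward-pass usage; Identity modules are placeholders for the real branches
import torch
from torch import nn
from hyper_connections import HyperConnections

dim, num_residual_streams = 512, 4

init_residual_fn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1
)

attn_residual = init_residual_fn(dim = dim, layer_index = 0)
ff_residual = init_residual_fn(dim = dim, layer_index = 1)
attn, ff = nn.Identity(), nn.Identity()

x = torch.randn(2, 16, dim)

x = expand_stream(x)                          # maybe expand into multiple residual streams

x, add_attn_residual = attn_residual(x)       # width connection, returns an add-back closure
x = add_attn_residual(attn(x))                # depth connection after the attention branch

x, add_ff_residual = ff_residual(x)
x = add_ff_residual(ff(x))

x = reduce_stream(x)                          # maybe reduce the streams back to (batch, seq, dim)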