import math
from collections import defaultdict

+ from random import randrange
+ from itertools import count
from functools import partial, wraps, cache
from typing import NamedTuple, Callable, Literal

@@ -591,6 +593,97 @@ def forward(
        fourier_embed, _ = pack((times, freqs.sin(), freqs.cos()), 'b n *')
        return fourier_embed

+ # hyper connections - multiple residual streams
+
+ class Residual(Module):
+     def __init__(self, **kwargs):
+         super().__init__()
+
+     def prepare_with_inverse(self, residuals):
+         branch_input, residuals, residual_kwargs = self.prepare(residuals)
+
+         def inverse(branch_out):
+             return self(branch_out, residuals, **residual_kwargs)
+
+         return branch_input, inverse
+
+     def prepare(self, residuals):
+         return residuals, residuals, dict()
+
+     def forward(self, branch_out, residuals, **kwargs):
+         return branch_out + residuals
+
+ class HyperConnections(Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         num_residual_streams,
+         layer_index = None,
+         tanh = True,
+         **kwargs
+     ):
+         """
+         https://arxiv.org/abs/2409.19606
+         Appendix J - Algorithm 2, Dynamic only
+         """
+         super().__init__()
+
+         self.act = nn.Tanh() if tanh else nn.Identity()
+
+         self.norm = nn.RMSNorm(dim)
+
+         self.num_residual_streams = num_residual_streams
+         layer_index = default(layer_index, randrange(num_residual_streams)) # just choose one random residual stream if layer index not given
+
+         self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
+
+         init_alpha0 = torch.zeros((num_residual_streams, 1))
+         init_alpha0[layer_index % num_residual_streams, 0] = 1.
+
+         self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+
+         self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
+         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
+         self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
+         self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
+
+     def prepare_with_inverse(self, residuals):
+         branch_input, residuals, residual_kwargs = self.prepare(residuals)
+
+         def inverse(branch_out):
+             return self(branch_out, residuals, **residual_kwargs)
+
+         return branch_input, inverse
+
+     def prepare(self, residuals):
+
+         residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
+
+         normed = self.norm(residuals)
+
+         wc_weight = self.act(normed @ self.dynamic_alpha_fn)
+         dynamic_alpha = wc_weight * self.dynamic_alpha_scale
+         alpha = dynamic_alpha + self.static_alpha
+
+         dc_weight = self.act(normed @ self.dynamic_beta_fn)
+         dynamic_beta = dc_weight * self.dynamic_beta_scale
+         beta = dynamic_beta + self.static_beta
+
+         # width connection
+
+         mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
+
+         branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
+
+         return branch_input, residuals, dict(beta = beta)
+
+     def forward(self, branch_output, residuals, *, beta):
+         # 'depth' connection
+
+         residuals = einsum(branch_output, beta, 'b n d, b n s -> b n s d') + residuals
+         return rearrange(residuals, 'b n s d -> (b s) n d')
+

# adaptive layernorm and ada-ln zero rolled into one wrapper
# from DiT paper and sota for time conditioning for now

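A quick, hedged sanity check of the class added above, not part of the commit: the width connection in `prepare` mixes the `s` residual streams into a single branch input, and the closure returned by `prepare_with_inverse` applies the depth connection that writes the branch output back into every stream. The batch size, sequence length, dim, stream count, and the `nn.Linear` stand-in branch below are illustrative only, and the sketch assumes `HyperConnections` (plus the repo's `default`, `rearrange`, and `einsum` helpers it relies on) is importable.

import torch
from torch import nn

# 2 samples, 16 tokens, dim 512, 4 residual streams (illustrative values)
hc = HyperConnections(512, num_residual_streams = 4, layer_index = 0)

branch = nn.Linear(512, 512)             # stand-in for the attention or feedforward branch

residuals = torch.randn(2 * 4, 16, 512)  # streams folded into batch: (batch * streams, seq, dim)

branch_input, add_residual = hc.prepare_with_inverse(residuals)
print(branch_input.shape)                # torch.Size([2, 16, 512]) - width connection mixes streams down

out = add_residual(branch(branch_input)) # depth connection adds the branch output back to every stream
print(out.shape)                         # torch.Size([8, 16, 512])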
@@ -940,7 +1033,8 @@ def __init__(
        ff_kwargs: dict = dict(),
        attn_laser = False,
        unet_skips = True,
-         use_flex_attn = False
+         use_flex_attn = False,
+         num_residual_streams = 1
    ):
        super().__init__()
        self.use_flex_attn = use_flex_attn
@@ -954,6 +1048,17 @@ def __init__(
            nn.SiLU()
        )

+         # hyper connections
+
+         assert num_residual_streams > 0
+         is_hyper_connection = num_residual_streams > 1
+         self.num_residual_streams = num_residual_streams
+
+         counter = count()
+         residual_klass = Residual if num_residual_streams == 1 else HyperConnections
+
+         # layers
+
        layers = ModuleList([])

        for ind in range(depth):
@@ -970,7 +1075,10 @@ def __init__(
            attn = AdaptiveWrapper(attn, dim = dim, dim_cond = dim * 4)
            ff = AdaptiveWrapper(ff, dim = dim, dim_cond = dim * 4)

-             layers.append(ModuleList([skip_proj, attn, ff]))
+             attn_residual = residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = next(counter))
+             ff_residual = residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = next(counter))
+
+             layers.append(ModuleList([skip_proj, attn, attn_residual, ff, ff_residual]))

        self.layers = layers
        self.norm = RMSNorm(dim)
@@ -1060,6 +1168,11 @@ def forward(
        cache = default(cache, (None,))
        iter_cache = iter(cache)

+         # expand input into multiple residual streams for maybe hyper connection
+
+         if self.num_residual_streams > 1:
+             x = repeat(x, 'b ... -> (b s) ...', s = self.num_residual_streams)
+
        # transformer layers as usual, using mask from above

        skips = []
@@ -1069,7 +1182,7 @@ def forward(

        depth = len(self.layers)

-         for ind, (skip_proj, attn, ff) in enumerate(self.layers):
+         for ind, (skip_proj, attn, attn_residual, ff, ff_residual) in enumerate(self.layers):
            layer = ind + 1

            # skip connection
@@ -1089,6 +1202,8 @@ def forward(

            # attention and feedforward

+             x, add_attn_residual = attn_residual.prepare_with_inverse(x)
+
            (attn_out, attn_values), kv_cache = attn(
                x,
                rotary_emb = rotary_emb,
@@ -1104,8 +1219,18 @@ def forward(

            new_cache.append(kv_cache)

-             x = attn_out + x
-             x = ff(x, **adaptive_kwargs) + x
+             x = add_attn_residual(attn_out)
+
+             x, add_ff_residual = ff_residual.prepare_with_inverse(x)
+
+             ff_out = ff(x, **adaptive_kwargs)
+
+             x = add_ff_residual(ff_out)
+
+         # reduce multiple residual streams for maybe hyper connection
+
+         if self.num_residual_streams > 1:
+             x = reduce(x, '(b s) ... -> b ...', 'sum', s = self.num_residual_streams)

        assert len(skips) == 0

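Outside the layers themselves, the change is only bookkeeping on the stream axis: the input is copied into `num_residual_streams` streams folded into the batch, and the streams are summed back out after the last layer. A standalone sketch of that round trip, with illustrative shapes:

import torch
from einops import repeat, reduce

s = 4                                              # number of residual streams (illustrative)
x = torch.randn(2, 16, 512)                        # (batch, seq, dim)

x = repeat(x, 'b ... -> (b s) ...', s = s)         # expand: copy each sample into s streams -> (8, 16, 512)
# ... the transformer layers above operate on the expanded tensor ...
x = reduce(x, '(b s) ... -> b ...', 'sum', s = s)  # collapse: sum the streams back -> (2, 16, 512)

assert x.shape == (2, 16, 512)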