@@ -42,6 +42,8 @@
 
 from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
 
+from hyper_connections import HyperConnections
+
 from tqdm import tqdm
 from loguru import logger
 
@@ -594,97 +596,6 @@ def forward(
         fourier_embed, _ = pack((times, freqs.sin(), freqs.cos()), 'b n *')
         return fourier_embed
 
-# hyper connections - multiple residual streams
-
-class Residual(Module):
-    def __init__(self, **kwargs):
-        super().__init__()
-
-    def prepare_with_inverse(self, residuals):
-        branch_input, residuals, residual_kwargs = self.prepare(residuals)
-
-        def inverse(branch_out):
-            return self(branch_out, residuals, **residual_kwargs)
-
-        return branch_input, inverse
-
-    def prepare(self, residuals):
-        return residuals, residuals, dict()
-
-    def forward(self, branch_out, residuals, **kwargs):
-        return branch_out + residuals
-
-class HyperConnections(Module):
-    def __init__(
-        self,
-        dim,
-        *,
-        num_residual_streams,
-        layer_index = None,
-        tanh = True,
-        **kwargs
-    ):
-        """
-        https://arxiv.org/abs/2409.19606
-        Appendix J - Algorithm 2, Dynamic only
-        """
-        super().__init__()
-
-        self.act = nn.Tanh() if tanh else nn.Identity()
-
-        self.norm = nn.RMSNorm(dim)
-
-        self.num_residual_streams = num_residual_streams
-        layer_index = default(layer_index, randrange(num_residual_streams)) # just choose one random residual stream if layer index not given
-
-        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
-
-        init_alpha0 = torch.zeros((num_residual_streams, 1))
-        init_alpha0[layer_index % num_residual_streams, 0] = 1.
-
-        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
-
-        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
-        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
-        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
-        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
-
-    def prepare_with_inverse(self, residuals):
-        branch_input, residuals, residual_kwargs = self.prepare(residuals)
-
-        def inverse(branch_out):
-            return self(branch_out, residuals, **residual_kwargs)
-
-        return branch_input, inverse
-
-    def prepare(self, residuals):
-
-        residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
-
-        normed = self.norm(residuals)
-
-        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
-        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
-        alpha = dynamic_alpha + self.static_alpha
-
-        dc_weight = self.act(normed @ self.dynamic_beta_fn)
-        dynamic_beta = dc_weight * self.dynamic_beta_scale
-        beta = dynamic_beta + self.static_beta
-
-        # width connection
-
-        mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
-
-        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
-
-        return branch_input, residuals, dict(beta = beta)
-
-    def forward(self, branch_output, residuals, *, beta):
-        # 'depth' connection
-
-        residuals = einsum(branch_output, beta, 'b n d, b n s -> b n s d') + residuals
-        return rearrange(residuals, 'b n s d -> (b s) n d')
-
 # adaptive layernorm and ada-ln zero rolled into one wrapper
 # from DiT paper and sota for time conditioning for now
 
@@ -1056,7 +967,8 @@ def __init__(
         self.num_residual_streams = num_residual_streams
 
         counter = count()
-        residual_klass = Residual if num_residual_streams == 1 else HyperConnections
+
+        init_residual_fn, self.expand_stream, self.reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         # layers
 
@@ -1076,8 +988,8 @@ def __init__(
             attn = AdaptiveWrapper(attn, dim = dim, dim_cond = dim * 4)
             ff = AdaptiveWrapper(ff, dim = dim, dim_cond = dim * 4)
 
-            attn_residual = residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_id = next(counter))
-            ff_residual = residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_id = next(counter))
+            attn_residual = init_residual_fn(dim = dim, layer_index = next(counter))
+            ff_residual = init_residual_fn(dim = dim, layer_index = next(counter))
 
             layers.append(ModuleList([skip_proj, attn, attn_residual, ff, ff_residual]))
 
@@ -1171,8 +1083,7 @@ def forward(
 
         # expand input into multiple residual streams for maybe hyper connection
 
-        if self.num_residual_streams > 1:
-            x = repeat(x, 'b ... -> (b s) ...', s = self.num_residual_streams)
+        x = self.expand_stream(x)
 
         # transformer layers as usual, using mask from above
 
@@ -1203,7 +1114,7 @@ def forward(
 
             # attention and feedforward
 
-            x, add_attn_residual = attn_residual.prepare_with_inverse(x)
+            x, add_attn_residual = attn_residual(x)
 
             (attn_out, attn_values), kv_cache = attn(
                 x,
@@ -1222,16 +1133,15 @@ def forward(
 
             x = add_attn_residual(attn_out)
 
-            x, add_ff_residual = ff_residual.prepare_with_inverse(x)
+            x, add_ff_residual = ff_residual(x)
 
             ff_out = ff(x, **adaptive_kwargs)
 
             x = add_ff_residual(ff_out)
 
         # reduce multiple residual streams for maybe hyper connection
 
-        if self.num_residual_streams > 1:
-            x = reduce(x, '(b s) ... -> b ...', 'sum', s = self.num_residual_streams)
+        x = self.reduce_stream(x)
 
         assert len(skips) == 0
 
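For reference, the residual-stream wiring this commit switches to can be exercised on its own. The following is a minimal sketch, assuming only the hyper_connections API already visible in the hunks above; the toy Linear branch, dimensions, and layer_index value are illustrative and not part of this repository.

# minimal sketch of the hyper-connections residual pattern used above
# (assumes the hyper_connections package API shown in this diff; the toy
#  Linear branch and dimensions below are illustrative only)

import torch
from torch import nn
from hyper_connections import HyperConnections

num_residual_streams = 4
dim = 512

init_residual_fn, expand_stream, reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1   # falls back to a plain residual when only one stream
)

branch = nn.Linear(dim, dim)                      # stand-in for an attention or feedforward block
residual = init_residual_fn(dim = dim, layer_index = 0)

x = torch.randn(2, 16, dim)

x = expand_stream(x)                              # replicate the input into multiple residual streams

branch_input, add_residual = residual(x)          # width connection: mix streams into the branch input
x = add_residual(branch(branch_input))            # depth connection: write the branch output back into the streams

x = reduce_stream(x)                              # sum the streams back down to (batch, seq, dim)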