@@ -24,6 +24,23 @@ def cast_tuple(val, depth = 1):
 
 # classes
 
+# https://arxiv.org/abs/2103.17239
+class LayerScale(nn.Module):
+    def __init__(self, dim, depth, fn):
+        super().__init__()
+        if depth <= 18:
+            init_eps = 0.1
+        elif depth > 18 and depth <= 24:
+            init_eps = 1e-5
+        else:
+            init_eps = 1e-6
+
+        scale = torch.zeros(1, 1, dim).fill_(init_eps)
+        self.scale = nn.Parameter(scale)
+        self.fn = fn
+    def forward(self, x, **kwargs):
+        return self.fn(x, **kwargs) * self.scale
+
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
@@ -77,7 +94,7 @@ def __init__(
         attn_types = cast_tuple(attn_types)
         attn_type_layer = islice(cycle(attn_types), depth)
 
-        for _, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
+        for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
             if attn_type == 'full':
                 attn_class = Attention
             elif attn_type == 'sparse':
@@ -92,8 +109,8 @@ def __init__(
                 raise ValueError(f'attention type "{attn_type}" is not valid')
 
             layers.append(nn.ModuleList([
-                PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)),
-                PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout))
+                LayerScale(dim, ind + 1, PreNorm(dim, attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout))),
+                LayerScale(dim, ind + 1, PreNorm(dim, FeedForward(dim, mult = ff_mult, dropout = ff_dropout)))
             ]))
 
         execute_type = ReversibleSequence if reversible else SequentialSequence
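For reference, here is a minimal standalone sketch (not part of the commit) of what the LayerScale wrapper added above does: it multiplies the output of the wrapped sub-layer by a small learned per-channel scale whose initial value shrinks with depth, per https://arxiv.org/abs/2103.17239. The dummy feed-forward sub-layer, tensor shapes, and depth value below are illustrative assumptions, not taken from the repository.

import torch
from torch import nn

# LayerScale as introduced in this commit
class LayerScale(nn.Module):
    def __init__(self, dim, depth, fn):
        super().__init__()
        # depth-dependent initial scale, following CaiT: 0.1 / 1e-5 / 1e-6
        if depth <= 18:
            init_eps = 0.1
        elif depth > 18 and depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6

        scale = torch.zeros(1, 1, dim).fill_(init_eps)
        self.scale = nn.Parameter(scale)
        self.fn = fn

    def forward(self, x, **kwargs):
        # scale the sub-layer output before it rejoins the residual stream
        return self.fn(x, **kwargs) * self.scale

# hypothetical usage with a dummy feed-forward sub-layer
dim = 512
ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
block = LayerScale(dim, depth = 3, fn = ff)   # depth 3 -> init_eps = 0.1

x = torch.randn(1, 64, dim)                   # (batch, seq_len, dim)
out = x + block(x)                            # residual update starts close to identity
print(out.shape)                              # torch.Size([1, 64, 512])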