@@ -22,6 +22,25 @@ def default(v, d):
def l2norm(t, dim = -1):
    return F.normalize(t, dim = dim, p = 2)

+# scale
+
+class Scale(Module):
+    """
+    latter part of section 2.5 in the paper
+    """
+    def __init__(
+        self,
+        dim,
+        init = 1.,
+        scale = 1.
+    ):
+        super().__init__()
+        self.scale = nn.Parameter(torch.ones(dim) * scale)
+        self.forward_scale = init / scale
+
+    def forward(self):
+        return self.scale * self.forward_scale
+
# for use with parametrize

class L2Norm(Module):
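For reference (not part of the diff): `Scale` stores its parameter at magnitude `scale`, but `forward()` returns an effective value of `init` at initialization, because it multiplies by `init / scale`. A minimal standalone sketch of that behavior, with made-up sizes, restating the class so it runs on its own:

import torch
from torch import nn
from torch.nn import Module

class Scale(Module):
    # same definition as in the hunk above, restated so this sketch is self-contained
    def __init__(self, dim, init = 1., scale = 1.):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim) * scale)
        self.forward_scale = init / scale

    def forward(self):
        return self.scale * self.forward_scale

s = Scale(dim = 4, init = 0.05, scale = 4 ** -0.5)  # hypothetical dim and init
print(s())       # tensor of 0.05s at initialization (effective value equals init)
print(s.scale)   # the raw parameter is stored at 4 ** -0.5 = 0.5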
@@ -87,13 +106,18 @@ def __init__(
        super().__init__()
        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)

+        dim_sqrt = dim ** 0.5
+        self.dim_sqrt = dim_sqrt
+        self.attn_scale = dim_head ** 0.5
+
        dim_inner = dim_head * heads
        self.to_q = NormLinear_(dim, dim_inner)
        self.to_k = NormLinear_(dim, dim_inner)
        self.to_v = NormLinear_(dim, dim_inner)

        self.rotary_emb = RotaryEmbedding(dim_head)
-        self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** 0.25))
+        self.q_scale = Scale(dim, 1, dim ** -0.5)
+        self.k_scale = Scale(dim, 1, dim ** -0.5)

        self.norm_qk = norm_qk
        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
@@ -107,28 +131,31 @@ def forward(
    ):
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

+        # scaling queries and keys - this would line up with the popular use of qk rmsnorm from google deepmind and now black forest labs
+
+        q = q * self.q_scale()
+        k = k * self.k_scale()
+
+        # split heads
+
        q, k, v = map(self.split_heads, (q, k, v))

        # maybe query key norm

        if self.norm_qk:
            q, k = map(l2norm, (q, k))

-        # scaling queries and keys - this would line up with the popular use of qk rmsnorm from google deepmind and now black forest labs
-
-        q, k = (q * self.qk_scale), (k * self.qk_scale)
-
        # rotary positions

        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)

-        # scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
+        # scale is sqrt(dk)

        out = F.scaled_dot_product_attention(
            q, k, v,
            is_causal = True,
-            scale = 1.
+            scale = self.attn_scale
        )

        out = self.merge_heads(out)
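A note on the `scale = self.attn_scale` change: with q and k rescaled (and optionally l2-normalized) before attention, the softmax temperature moves from the default 1 / sqrt(dk) to sqrt(dk). A rough sanity-check sketch of what that argument does, assuming a PyTorch version whose `scaled_dot_product_attention` accepts `scale`; shapes below are made up:

import torch
import torch.nn.functional as F

b, h, n, d = 1, 2, 8, 16                      # hypothetical (batch, heads, seq, dim_head)
q, k, v = (F.normalize(torch.randn(b, h, n, d), dim = -1) for _ in range(3))

scale = d ** 0.5                              # sqrt(dk), as in the hunk above
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * scale
causal_mask = torch.ones(n, n, dtype = torch.bool).triu(1)
sim = sim.masked_fill(causal_mask, float('-inf'))
out = torch.einsum('b h i j, b h j d -> b h i d', sim.softmax(dim = -1), v)

ref = F.scaled_dot_product_attention(q, k, v, is_causal = True, scale = scale)
assert torch.allclose(out, ref, atol = 1e-5)  # manual attention matches SDPA with the same scale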
@@ -151,16 +178,16 @@ def __init__(
        self.to_hidden = NormLinear_(dim, dim_inner)
        self.to_gate = NormLinear_(dim, dim_inner)

-        self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
-        self.gate_scale = nn.Parameter(torch.ones(dim_inner))
+        self.hidden_scale = Scale(dim_inner)
+        self.gate_scale = Scale(dim_inner)

        self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)

    def forward(self, x):
        hidden, gate = self.to_hidden(x), self.to_gate(x)

-        hidden = hidden * self.hidden_scale
-        gate = gate * self.gate_scale * (self.dim ** 0.5)
+        hidden = hidden * self.hidden_scale()
+        gate = gate * self.gate_scale() * (self.dim ** 0.5)

        hidden = F.silu(gate) * hidden
        return self.to_out(hidden)
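Worth noting (an observation, not part of the diff): with the default `init = 1., scale = 1.`, `Scale(dim_inner)()` starts out identical to the old `nn.Parameter(torch.ones(dim_inner))`, so only the parameter bookkeeping changes, not the initial forward behavior. A quick check, assuming the `Scale` class from the first hunk is in scope:

import torch
from torch import nn

dim_inner = 8                                  # hypothetical width
old_scale = nn.Parameter(torch.ones(dim_inner))
new_scale = Scale(dim_inner)                   # Scale assumed importable from the module above
assert torch.equal(old_scale, new_scale())     # identical values at initialization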
@@ -187,28 +214,23 @@ def __init__(
        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)

        self.dim = dim
-
        residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)

        self.token_embed = NormLinear_(dim, num_tokens)

        self.layers = ModuleList([])
-        self.residual_lerp_scales = nn.ParameterList([])

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, dim_head = dim_head, heads = heads, norm_qk = attn_norm_qk, manual_norm_weights = manual_norm_weights),
                FeedForward(dim, expand_factor = ff_expand_factor, manual_norm_weights = manual_norm_weights),
-            ]))
-
-            self.residual_lerp_scales.append(nn.ParameterList([
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
-                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init),
+                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
+                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
            ]))

        self.to_logits = NormLinear_(dim, num_tokens) if not tied_embedding else None

-        self.logit_scale = nn.Parameter(torch.ones(num_tokens))
+        self.logit_scale = Scale(num_tokens, 1., dim ** -0.5)

        self.ignore_index = ce_ignore_index
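For the residual interpolation weights, each `Scale(dim, residual_lerp_scale_init, dim ** -0.5)` starts with an effective alpha of `residual_lerp_scale_init` (1 / depth by default) while storing the raw parameter at `dim ** -0.5`. A small illustration with made-up sizes, again assuming `Scale` is in scope:

depth, dim = 4, 64                             # hypothetical depth and model dim
alpha = Scale(dim, 1. / depth, dim ** -0.5)
print(alpha())                                 # ~0.25 everywhere at init (the effective lerp weight)
print(alpha.scale[0])                          # raw parameter stored at 64 ** -0.5 = 0.125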
@@ -232,21 +254,21 @@ def forward(
        token_embed = self.token_embed.weight
        tokens = token_embed[ids]

-        for (attn, ff), (attn_alpha, ff_alpha) in zip(self.layers, self.residual_lerp_scales):
+        for attn, ff, attn_alpha, ff_alpha in self.layers:

            attn_out = l2norm(attn(tokens))
-            tokens = l2norm(tokens.lerp(attn_out, attn_alpha))
+            tokens = l2norm(tokens.lerp(attn_out, attn_alpha()))

            ff_out = l2norm(ff(tokens))
-            tokens = l2norm(tokens.lerp(ff_out, ff_alpha))
+            tokens = l2norm(tokens.lerp(ff_out, ff_alpha()))

        if exists(self.to_logits):
            logits = self.to_logits(tokens)
        else:
            # tied embeddings
            logits = einsum(tokens, token_embed, 'b n d, c d -> b n c')

-        logits = logits * self.logit_scale * (self.dim ** 0.5)
+        logits = logits * self.logit_scale()

        if not return_loss:
            return logits
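As a reminder of what the lerp lines compute (a sketch, not the diff's code): `torch.lerp(a, b, w)` is `a + w * (b - a)`, so each block applies `tokens = l2norm(tokens + alpha() * (l2norm(branch(tokens)) - tokens))`. A standalone check with random stand-ins:

import torch
import torch.nn.functional as F

def l2norm(t, dim = -1):
    return F.normalize(t, dim = dim, p = 2)

tokens = l2norm(torch.randn(2, 8, 64))          # hypothetical (batch, seq, dim)
branch_out = l2norm(torch.randn(2, 8, 64))      # stand-in for l2norm(attn(tokens))
alpha = torch.full((64,), 0.25)                 # stand-in for attn_alpha()

via_lerp = l2norm(tokens.lerp(branch_out, alpha))
manual = l2norm(tokens + alpha * (branch_out - tokens))
assert torch.allclose(via_lerp, manual, atol = 1e-6)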