+from __future__ import annotations
+
from functools import partial

import torch
@@ -19,6 +21,11 @@ def exists(v):
def default(v, d):
    return v if exists(v) else d

+def cast_tuple(t, length = 1):
+    out = t if isinstance(t, tuple) else ((t,) * length)
+    assert len(out) == length
+    return out
+
def l2norm(t, dim = -1):
    return F.normalize(t, dim = dim, p = 2)

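As a quick illustration of the new `cast_tuple` helper (the values below are made up, not from the diff): a scalar hyperparameter is broadcast to one value per layer, while a tuple must already contain exactly `length` entries.

# hypothetical values, just to show the broadcasting behavior
cast_tuple(1., 4)            # -> (1., 1., 1., 1.)
cast_tuple((1., 2., 3.), 3)  # -> (1., 2., 3.)
cast_tuple((1., 2.), 3)      # -> AssertionError, wrong number of per-layer values
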
@@ -101,7 +108,9 @@ def __init__(
        dim_head = 64,
        heads = 8,
        norm_qk = True,
-        manual_norm_weights = False
+        manual_norm_weights = False,
+        s_qk_init = 1.,
+        s_qk_scale = None
    ):
        super().__init__()
        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
@@ -167,7 +176,11 @@ def __init__(
        dim,
        *,
        expand_factor = 4,
-        manual_norm_weights = False
+        manual_norm_weights = False,
+        s_hidden_init = 1.,
+        s_hidden_scale = 1.,
+        s_gate_init = 1.,
+        s_gate_scale = 1.
    ):
        super().__init__()
        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
@@ -178,8 +191,8 @@ def __init__(
        self.to_hidden = NormLinear_(dim, dim_inner)
        self.to_gate = NormLinear_(dim, dim_inner)

-        self.hidden_scale = Scale(dim_inner)
-        self.gate_scale = Scale(dim_inner)
+        self.hidden_scale = Scale(dim_inner, s_hidden_init, s_hidden_scale)
+        self.gate_scale = Scale(dim_inner, s_gate_init, s_gate_scale)

        self.to_out = NormLinear_(dim_inner, dim, norm_dim_in = False)

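The `Scale` calls above now take explicit init and scale arguments. `Scale` itself is defined outside this diff; the following is only a sketch of the pattern its `(dim, init, scale)` call signature suggests (the file's actual definition may differ): the parameter is stored at magnitude `scale` and rescaled so the forward value starts at `init`, which lets the effective learning rate of the scale be tuned independently of its initial value.

# sketch only -- Scale is defined elsewhere in this file; this mirrors the
# (dim, init, scale) call signature used above, not the exact implementation
import torch
from torch import nn
from torch.nn import Module

class Scale(Module):
    def __init__(self, dim, init = 1., scale = 1.):
        super().__init__()
        # parameter is stored at magnitude `scale` ...
        self.scale = nn.Parameter(torch.ones(dim) * scale)
        # ... and multiplied back so forward() returns `init` at initialization
        self.forward_scale = init / scale

    def forward(self):
        return self.scale * self.forward_scale
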
@@ -206,31 +219,98 @@ def __init__(
        attn_norm_qk = True, # they say the query/key normalization is optional
        ff_expand_factor = 4.,
        ce_ignore_index = -1,
-        residual_lerp_scale_init = None,
        manual_norm_weights = False,
-        tied_embedding = False
+        tied_embedding = False,
+        # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
+        alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
+        s_logit_init: float = 1.,
+        s_logit_scale: float | None = None,
+        alpha_attn_init: float | tuple[float, ...] | None = None,
+        alpha_attn_scale: float | tuple[float, ...] | None = None,
+        alpha_ff_init: float | tuple[float, ...] | None = None,
+        alpha_ff_scale: float | tuple[float, ...] | None = None,
+        s_qk_init: float | tuple[float, ...] = 1.,
+        s_qk_scale: float | tuple[float, ...] | None = None,
+        s_ff_hidden_init: float | tuple[float, ...] = 1.,
+        s_ff_hidden_scale: float | tuple[float, ...] = 1.,
+        s_ff_gate_init: float | tuple[float, ...] = 1.,
+        s_ff_gate_scale: float | tuple[float, ...] = 1.
    ):
        super().__init__()
        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)

        self.dim = dim
-        residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)
+        alpha_init = default(alpha_init, 1. / depth)

        self.token_embed = NormLinear_(dim, num_tokens)

        self.layers = ModuleList([])

-        for _ in range(depth):
-            self.layers.append(ModuleList([
-                Attention(dim, dim_head = dim_head, heads = heads, norm_qk = attn_norm_qk, manual_norm_weights = manual_norm_weights),
-                FeedForward(dim, expand_factor = ff_expand_factor, manual_norm_weights = manual_norm_weights),
-                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
-                Scale(dim, residual_lerp_scale_init, dim ** -0.5),
-            ]))
+        scale_hparams = (
+            alpha_attn_init,
+            alpha_attn_scale,
+            alpha_ff_init,
+            alpha_ff_scale,
+            s_qk_init,
+            s_qk_scale,
+            s_ff_hidden_init,
+            s_ff_hidden_scale,
+            s_ff_gate_init,
+            s_ff_gate_scale
+        )
+
+        scale_hparams = tuple(cast_tuple(hparam, depth) for hparam in scale_hparams)
+
+        for (
+            alpha_attn_init_,
+            alpha_attn_scale_,
+            alpha_ff_init_,
+            alpha_ff_scale_,
+            s_qk_init_,
+            s_qk_scale_,
+            s_ff_hidden_init_,
+            s_ff_hidden_scale_,
+            s_ff_gate_init_,
+            s_ff_gate_scale_
+        ) in zip(*scale_hparams):
+
+            attn = Attention(
+                dim,
+                dim_head = dim_head,
+                heads = heads,
+                norm_qk = attn_norm_qk,
+                manual_norm_weights = manual_norm_weights,
+                s_qk_init = s_qk_init_,
+                s_qk_scale = s_qk_scale_,
+            )
+
+            ff = FeedForward(
+                dim,
+                expand_factor = ff_expand_factor,
+                manual_norm_weights = manual_norm_weights,
+                s_hidden_init = s_ff_hidden_init_,
+                s_hidden_scale = s_ff_hidden_scale_,
+                s_gate_init = s_ff_gate_init_,
+                s_gate_scale = s_ff_gate_scale_
+            )
+
+            attn_interp_factor = Scale(
+                dim,
+                default(alpha_attn_init_, alpha_init),
+                default(alpha_attn_scale_, dim ** -0.5)
+            )
+
+            ff_interp_factor = Scale(
+                dim,
+                default(alpha_ff_init_, alpha_init),
+                default(alpha_ff_scale_, dim ** -0.5)
+            )
+
+            self.layers.append(ModuleList([attn, ff, attn_interp_factor, ff_interp_factor]))

        self.to_logits = NormLinear_(dim, num_tokens) if not tied_embedding else None

-        self.logit_scale = Scale(num_tokens, 1., dim ** -0.5)
+        self.logit_scale = Scale(num_tokens, s_logit_init, default(s_logit_scale, dim ** -0.5))

        self.ignore_index = ce_ignore_index
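Assuming the class modified in the last hunk is the main `nGPT` transformer (the class name sits outside the shown hunks, so that name is an assumption), a minimal usage sketch of the new hyperparameters might look like the following: a scalar is shared by every layer, while a tuple supplies one value per layer and must have length `depth` (enforced by `cast_tuple`).

# sketch only -- assumes the surrounding class is nGPT and that the
# num_tokens / dim / depth keyword names match its full signature
model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    alpha_init = 1. / 4,             # one residual interpolation init for all layers
    s_qk_init = (1., 1., 0.5, 0.5),  # per-layer query/key scale inits, length must equal depth
    s_ff_hidden_scale = 2.           # scalars are broadcast to every layer via cast_tuple
)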