from rotary_embedding_torch import RotaryEmbedding

+ # constants
+
+ from torch.nn.attention import SDPBackend
+
+ SDP_BACKEND_MAP = dict(
+     enable_flash = SDPBackend.FLASH_ATTENTION,
+     enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+     enable_math = SDPBackend.MATH,
+     enable_cudnn = SDPBackend.CUDNN_ATTENTION
+ )
+
# functions

def exists(v):
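
Aside (not part of the diff): a minimal sketch of how a backend map like SDP_BACKEND_MAP above is typically consumed. torch.nn.attention.sdpa_kernel (PyTorch >= 2.3) takes a list of SDPBackend members and restricts scaled_dot_product_attention to those implementations while the context manager is active; the flag names and tensor shapes below are illustrative.

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# map enable-style flags to SDPA backends, mirroring SDP_BACKEND_MAP above
backend_map = dict(
    enable_flash = SDPBackend.FLASH_ATTENTION,
    enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
    enable_math = SDPBackend.MATH
)

flags = dict(enable_flash = True, enable_math = True, enable_mem_efficient = False)
backends = [backend_map[name] for name, enabled in flags.items() if enabled]

q = k = v = torch.randn(1, 8, 128, 64)   # (batch, heads, seq, dim_head)

with sdpa_kernel(backends):
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
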
@@ -26,8 +37,19 @@ def cast_tuple(t, length = 1):
    assert len(out) == length
    return out

- def l2norm(t, dim = -1):
-     return F.normalize(t, dim = dim, p = 2)
+ def l2norm(
+     t,
+     dim = -1,
+     norm_eps = 0.05, # allow vectors to inhabit a small distance below and above the hypersphere if greater than 0.
+     eps = 1e-10
+ ):
+     if norm_eps == 0.:
+         return F.normalize(t, dim = dim, p = 2, eps = eps)
+
+     norm = t.norm(dim = dim, keepdim = True)
+     target_norm = norm.detach().clamp(min = 1. - norm_eps, max = 1. + norm_eps)
+     divisor = norm / target_norm
+     return t / divisor.clamp(min = eps)

# scale
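
Aside (not part of the diff): a small worked example of what the relaxed l2norm above does when norm_eps > 0. The norm is only clamped toward the unit hypersphere rather than forced exactly onto it, and because the target norm is detached, gradients still flow through the un-detached norm in the divisor.

import torch

t = torch.tensor([[3., 4.]])                    # norm = 5
norm = t.norm(dim = -1, keepdim = True)

norm_eps = 0.05
target_norm = norm.detach().clamp(min = 1. - norm_eps, max = 1. + norm_eps)   # -> 1.05
out = t / (norm / target_norm)

print(out.norm(dim = -1))                       # tensor([1.0500]), not exactly 1.
# with norm_eps = 0. the function falls back to F.normalize, i.e. an exact unit norm
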
@@ -51,26 +73,28 @@ def forward(self):
# for use with parametrize

class L2Norm(Module):
-     def __init__(self, dim = -1):
+     def __init__(self, dim = -1, norm_eps = 0.):
        super().__init__()
        self.dim = dim
+         self.norm_eps = norm_eps

    def forward(self, t):
-         return l2norm(t, dim = self.dim)
+         return l2norm(t, dim = self.dim, norm_eps = self.norm_eps)

class NormLinear(Module):
    def __init__(
        self,
        dim,
        dim_out,
        norm_dim_in = True,
-         parametrize = True
+         parametrize = True,
+         norm_eps = 0.
    ):
        super().__init__()
        self.linear = nn.Linear(dim, dim_out, bias = False)

        self.parametrize = parametrize
-         self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0)
+         self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0, norm_eps = norm_eps)

        if parametrize:
            register_parametrization(
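
Aside (not part of the diff): NormLinear relies on torch's parametrize API, which re-applies a module to a parameter on every access, which is how the weight rows stay l2-normalized. A self-contained sketch of the same pattern, with a norm_eps-free stand-in for the L2Norm above:

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.parametrize import register_parametrization

class UnitNorm(nn.Module):                      # stand-in for the L2Norm module above
    def __init__(self, dim = -1):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return F.normalize(t, dim = self.dim, p = 2)

linear = nn.Linear(512, 256, bias = False)
register_parametrization(linear, 'weight', UnitNorm(dim = -1))

print(linear.weight.norm(dim = -1))             # every row ~1., re-normalized on each access
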
@@ -98,7 +122,7 @@ def weight(self):
    def forward(self, x):
        return self.linear(x)

- # attention and feedforward
+ # attention

class Attention(Module):
    def __init__(
@@ -110,10 +134,17 @@ def __init__(
        norm_qk = True,
        manual_norm_weights = False,
        s_qk_init = 1.,
-         s_qk_scale = None
+         s_qk_scale = None,
+         flash_kwargs: dict = dict(
+             enable_flash = True,
+             enable_math = True,
+             enable_mem_efficient = True
+         ),
+         norm_eps = 0.
    ):
        super().__init__()
-         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
+         self.l2norm = partial(l2norm, norm_eps = norm_eps)

        dim_sqrt = dim ** 0.5
        self.dim_sqrt = dim_sqrt
@@ -124,11 +155,21 @@ def __init__(
        self.to_k = NormLinear_(dim, dim_inner)
        self.to_v = NormLinear_(dim, dim_inner)

+         # flash attention related context manager
+
+         sdpa_backends = [SDP_BACKEND_MAP[enable_str] for enable_str, enable in flash_kwargs.items() if enable]
+         self.sdpa_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
+
+         # rotary
+
        self.rotary_emb = RotaryEmbedding(dim_head)
-         self.q_scale = Scale(dim, 1, dim ** -0.5)
-         self.k_scale = Scale(dim, 1, dim ** -0.5)
+
+         # qk rmsnorm + scale

        self.norm_qk = norm_qk
+         self.q_scale = Scale(dim, s_qk_init, default(s_qk_scale, dim ** -0.5))
+         self.k_scale = Scale(dim, s_qk_init, default(s_qk_scale, dim ** -0.5))
+
        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')
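
Aside (not part of the diff, and an assumption about this repo's Scale class, whose body is not shown here): the change from Scale(dim, 1, dim ** -0.5) to Scale(dim, s_qk_init, default(s_qk_scale, dim ** -0.5)) exposes an nGPT-style (init, scale) pair. The sketch below shows the convention that signature suggests, using a hypothetical ScaleSketch in place of the real Scale: the parameter is stored at magnitude scale but multiplied by init / scale in forward, so its effective value starts at init while its raw magnitude (and hence its effective learning rate) is set by scale.

import torch
from torch import nn

class ScaleSketch(nn.Module):                   # hypothetical stand-in, not the repo's Scale
    def __init__(self, dim, init = 1., scale = 1.):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim) * scale)
        self.forward_scale = init / scale

    def forward(self):
        return self.scale * self.forward_scale

s_qk = ScaleSketch(64, init = 1., scale = 64 ** -0.5)
print(s_qk().mean())                            # ~1. at initialization
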
@@ -152,7 +193,7 @@ def forward(
        # maybe query key norm

        if self.norm_qk:
-             q, k = map(l2norm, (q, k))
+             q, k = map(self.l2norm, (q, k))

        # rotary positions
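
Aside (not part of the diff): the "scale is sqrt(dk)" comment in the following hunk reflects that q and k are l2-normalized, so per-head dot products are bounded in [-1, 1]; the softmax temperature is therefore set to sqrt(d_k) rather than the usual 1 / sqrt(d_k). A minimal sketch with a plain float scale (the model itself passes self.attn_scale):

import torch
import torch.nn.functional as F

d_k = 64
q = F.normalize(torch.randn(1, 8, 128, d_k), dim = -1)
k = F.normalize(torch.randn(1, 8, 128, d_k), dim = -1)
v = torch.randn(1, 8, 128, d_k)

# the scale sharpens bounded logits instead of shrinking unbounded ones
out = F.scaled_dot_product_attention(q, k, v, is_causal = True, scale = d_k ** 0.5)
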
@@ -161,15 +202,18 @@ def forward(
        # scale is sqrt(dk)

-         out = F.scaled_dot_product_attention(
-             q, k, v,
-             is_causal = True,
-             scale = self.attn_scale
-         )
+         with self.sdpa_context_manager():
+             out = F.scaled_dot_product_attention(
+                 q, k, v,
+                 is_causal = True,
+                 scale = self.attn_scale
+             )

        out = self.merge_heads(out)
        return self.to_out(out)

+ # feedforward
+
class FeedForward(Module):
    def __init__(
        self,
@@ -180,10 +224,11 @@ def __init__(
        s_hidden_init = 1.,
        s_hidden_scale = 1.,
        s_gate_init = 1.,
-         s_gate_scale = 1.
+         s_gate_scale = 1.,
+         norm_eps = 0.
    ):
        super().__init__()
-         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)

        self.dim = dim
        dim_inner = int(dim * expand_factor * 2 / 3)
@@ -234,10 +279,17 @@ def __init__(
        s_ff_hidden_init: float | tuple[float, ...] = 1.,
        s_ff_hidden_scale: float | tuple[float, ...] = 1.,
        s_ff_gate_init: float | tuple[float, ...] = 1.,
-         s_ff_gate_scale: float | tuple[float, ...] = 1.
+         s_ff_gate_scale: float | tuple[float, ...] = 1.,
+         attn_flash_kwargs: dict = dict(
+             enable_flash = True,
+             enable_math = True,
+             enable_mem_efficient = True
+         ),
+         norm_eps = 0. # greater than 0 allows the norm to be around (1. - norm_eps) to (1. + norm_eps)
    ):
        super().__init__()
-         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
+         self.l2norm = partial(l2norm, norm_eps = norm_eps)

        self.dim = dim
        alpha_init = default(alpha_init, 1. / depth)
@@ -282,6 +334,8 @@ def __init__(
                manual_norm_weights = manual_norm_weights,
                s_qk_init = s_qk_init_,
                s_qk_scale = s_qk_scale_,
+                 flash_kwargs = attn_flash_kwargs,
+                 norm_eps = norm_eps
            )

            ff = FeedForward(
@@ -291,7 +345,8 @@ def __init__(
                s_hidden_init = s_ff_hidden_init_,
                s_hidden_scale = s_ff_hidden_scale_,
                s_gate_init = s_ff_gate_init_,
-                 s_gate_scale = s_ff_gate_scale_
+                 s_gate_scale = s_ff_gate_scale_,
+                 norm_eps = norm_eps
            )

            attn_interp_factor = Scale(
@@ -327,11 +382,11 @@ def forward(
        ids,
        return_loss = False
    ):
+         token_embed, l2norm = self.token_embed.weight, self.l2norm

        if return_loss:
            ids, labels = ids[:, :-1], ids[:, 1:]

-         token_embed = self.token_embed.weight
        tokens = token_embed[ids]

        for attn, ff, attn_alpha, ff_alpha in self.layers: