
Commit 79c0ac5

committed
flash attention, and move the quasi-hypersphere idea into the same file, given the simplicity is already lost with all the scaling hparams
1 parent c58093f commit 79c0ac5

File tree

4 files changed: +81 -277 lines

nGPT_pytorch/nGPT.py

Lines changed: 78 additions & 23 deletions
@@ -13,6 +13,17 @@
 
 from rotary_embedding_torch import RotaryEmbedding
 
+# constants
+
+from torch.nn.attention import SDPBackend
+
+SDP_BACKEND_MAP = dict(
+    enable_flash = SDPBackend.FLASH_ATTENTION,
+    enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
+    enable_math = SDPBackend.MATH,
+    enable_cudnn = SDPBackend.CUDNN_ATTENTION
+)
+
 # functions
 
 def exists(v):
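
The new SDP_BACKEND_MAP only translates boolean enable_* flags into the SDPBackend enum values that torch.nn.attention.sdpa_kernel expects. A minimal sketch of that translation, assuming a recent PyTorch (2.3+ for sdpa_kernel, newer still for SDPBackend.CUDNN_ATTENTION); the flash_kwargs dict below is illustrative, the real one is added to Attention.__init__ later in this diff:

# sketch: filter the enable_* flags down to a backend list for sdpa_kernel
from functools import partial

from torch.nn.attention import SDPBackend, sdpa_kernel

SDP_BACKEND_MAP = dict(
    enable_flash = SDPBackend.FLASH_ATTENTION,
    enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION,
    enable_math = SDPBackend.MATH,
    enable_cudnn = SDPBackend.CUDNN_ATTENTION
)

flash_kwargs = dict(enable_flash = True, enable_math = True, enable_mem_efficient = True)

# keep only the backends whose flag is set
sdpa_backends = [SDP_BACKEND_MAP[name] for name, enabled in flash_kwargs.items() if enabled]

# partial(...) yields a reusable context manager, as the commit does inside Attention
sdpa_context_manager = partial(sdpa_kernel, sdpa_backends)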
@@ -26,8 +37,19 @@ def cast_tuple(t, length = 1):
     assert len(out) == length
     return out
 
-def l2norm(t, dim = -1):
-    return F.normalize(t, dim = dim, p = 2)
+def l2norm(
+    t,
+    dim = -1,
+    norm_eps = 0.05, # allow vectors to inhabit a small distance below and above the hypersphere if greater than 0.
+    eps = 1e-10
+):
+    if norm_eps == 0.:
+        return F.normalize(t, dim = dim, p = 2, eps = eps)
+
+    norm = t.norm(dim = dim, keepdim = True)
+    target_norm = norm.detach().clamp(min = 1. - norm_eps, max = 1. + norm_eps)
+    divisor = norm / target_norm
+    return t / divisor.clamp(min = eps)
 
 # scale
 
@@ -51,26 +73,28 @@ def forward(self):
 # for use with parametrize
 
 class L2Norm(Module):
-    def __init__(self, dim = -1):
+    def __init__(self, dim = -1, norm_eps = 0.):
         super().__init__()
         self.dim = dim
+        self.norm_eps = norm_eps
 
     def forward(self, t):
-        return l2norm(t, dim = self.dim)
+        return l2norm(t, dim = self.dim, norm_eps = self.norm_eps)
 
 class NormLinear(Module):
     def __init__(
         self,
         dim,
         dim_out,
         norm_dim_in = True,
-        parametrize = True
+        parametrize = True,
+        norm_eps = 0.
     ):
         super().__init__()
         self.linear = nn.Linear(dim, dim_out, bias = False)
 
         self.parametrize = parametrize
-        self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0)
+        self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0, norm_eps = norm_eps)
 
         if parametrize:
             register_parametrization(
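
For context on the parametrize path this hunk ends on: register_parametrization re-applies L2Norm every time .weight is read, so the effective weight stays on (or, with norm_eps > 0, near) the hypersphere no matter how the raw parameter drifts during training. A minimal sketch, assuming L2Norm is importable from nGPT_pytorch.nGPT:

# sketch: an L2Norm parametrization keeps each weight row unit-norm on every access
import torch
from torch import nn
from torch.nn.utils.parametrize import register_parametrization

from nGPT_pytorch.nGPT import L2Norm

linear = nn.Linear(512, 256, bias = False)
register_parametrization(linear, 'weight', L2Norm(dim = -1, norm_eps = 0.))

print(linear.weight.norm(dim = -1))  # ~1. for every row, even after optimizer steps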
@@ -98,7 +122,7 @@ def weight(self):
     def forward(self, x):
         return self.linear(x)
 
-# attention and feedforward
+# attention
 
 class Attention(Module):
     def __init__(
@@ -110,10 +134,17 @@ def __init__(
         norm_qk = True,
         manual_norm_weights = False,
         s_qk_init = 1.,
-        s_qk_scale = None
+        s_qk_scale = None,
+        flash_kwargs: dict = dict(
+            enable_flash = True,
+            enable_math = True,
+            enable_mem_efficient = True
+        ),
+        norm_eps = 0.
     ):
         super().__init__()
-        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
+        self.l2norm = partial(l2norm, norm_eps = norm_eps)
 
         dim_sqrt = dim ** 0.5
         self.dim_sqrt = dim_sqrt
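
The s_qk_init / s_qk_scale pair follows the nGPT scaling-parameter convention: the learnable scale is stored at magnitude s_qk_scale (so its effective learning rate can differ), then multiplied back by init / scale at forward time so its effective starting value is s_qk_init. A hedged sketch of that reparameterization; the actual Scale class lives elsewhere in this file and may differ in detail:

# sketch of an nGPT-style scaling parameter: stored at `scale`, effective value starts at `init`
import torch
from torch import nn

class ScaleSketch(nn.Module):
    def __init__(self, dim, init = 1., scale = 1.):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim) * scale)  # what the optimizer actually updates
        self.forward_scale = init / scale                   # restores the intended initial value

    def forward(self):
        return self.scale * self.forward_scale

# with s_qk_init = 1. and s_qk_scale defaulting to dim ** -0.5, the effective scale starts at 1.
dim = 512
q_scale = ScaleSketch(dim, init = 1., scale = dim ** -0.5)
assert torch.allclose(q_scale(), torch.ones(dim))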
@@ -124,11 +155,21 @@ def __init__(
         self.to_k = NormLinear_(dim, dim_inner)
         self.to_v = NormLinear_(dim, dim_inner)
 
+        # flash attention related context manager
+
+        sdpa_backends = [SDP_BACKEND_MAP[enable_str] for enable_str, enable in flash_kwargs.items() if enable]
+        self.sdpa_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
+
+        # rotary
+
         self.rotary_emb = RotaryEmbedding(dim_head)
-        self.q_scale = Scale(dim, 1, dim ** -0.5)
-        self.k_scale = Scale(dim, 1, dim ** -0.5)
+
+        # qk rmsnorm + scale
 
         self.norm_qk = norm_qk
+        self.q_scale = Scale(dim, s_qk_init, default(s_qk_scale, dim ** -0.5))
+        self.k_scale = Scale(dim, s_qk_init, default(s_qk_scale, dim ** -0.5))
+
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
@@ -152,7 +193,7 @@ def forward(
         # maybe query key norm
 
         if self.norm_qk:
-            q, k = map(l2norm, (q, k))
+            q, k = map(self.l2norm, (q, k))
 
         # rotary positions
 
@@ -161,15 +202,18 @@ def forward(
 
         # scale is sqrt(dk)
 
-        out = F.scaled_dot_product_attention(
-            q, k, v,
-            is_causal = True,
-            scale = self.attn_scale
-        )
+        with self.sdpa_context_manager():
+            out = F.scaled_dot_product_attention(
+                q, k, v,
+                is_causal = True,
+                scale = self.attn_scale
+            )
 
         out = self.merge_heads(out)
         return self.to_out(out)
 
+# feedforward
+
 class FeedForward(Module):
     def __init__(
         self,
@@ -180,10 +224,11 @@ def __init__(
         s_hidden_init = 1.,
         s_hidden_scale = 1.,
         s_gate_init = 1.,
-        s_gate_scale = 1.
+        s_gate_scale = 1.,
+        norm_eps = 0.
     ):
         super().__init__()
-        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
 
         self.dim = dim
         dim_inner = int(dim * expand_factor * 2 / 3)
@@ -234,10 +279,17 @@ def __init__(
         s_ff_hidden_init: float | tuple[float, ...] = 1.,
         s_ff_hidden_scale: float | tuple[float, ...] = 1.,
         s_ff_gate_init: float | tuple[float, ...] = 1.,
-        s_ff_gate_scale: float | tuple[float, ...] = 1.
+        s_ff_gate_scale: float | tuple[float, ...] = 1.,
+        attn_flash_kwargs: dict = dict(
+            enable_flash = True,
+            enable_math = True,
+            enable_mem_efficient = True
+        ),
+        norm_eps = 0. # greater than 0 allows the norm to be around (1. - norm_eps) to (1. + norm_eps)
     ):
         super().__init__()
-        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights)
+        NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
+        self.l2norm = partial(l2norm, norm_eps = norm_eps)
 
         self.dim = dim
         alpha_init = default(alpha_init, 1. / depth)
@@ -282,6 +334,8 @@ def __init__(
                 manual_norm_weights = manual_norm_weights,
                 s_qk_init = s_qk_init_,
                 s_qk_scale = s_qk_scale_,
+                flash_kwargs = attn_flash_kwargs,
+                norm_eps = norm_eps
             )
 
             ff = FeedForward(
@@ -291,7 +345,8 @@ def __init__(
                 s_hidden_init = s_ff_hidden_init_,
                 s_hidden_scale = s_ff_hidden_scale_,
                 s_gate_init = s_ff_gate_init_,
-                s_gate_scale = s_ff_gate_scale_
+                s_gate_scale = s_ff_gate_scale_,
+                norm_eps = norm_eps
             )
 
             attn_interp_factor = Scale(
@@ -327,11 +382,11 @@ def forward(
         ids,
         return_loss = False
     ):
+        token_embed, l2norm = self.token_embed.weight, self.l2norm
 
         if return_loss:
             ids, labels = ids[:, :-1], ids[:, 1:]
 
-        token_embed = self.token_embed.weight
         tokens = token_embed[ids]
 
         for attn, ff, attn_alpha, ff_alpha in self.layers:
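
End to end, the two new hyperparameters surface as plain constructor kwargs. A hedged usage sketch; the nGPT class name and its num_tokens / dim / depth arguments are assumed from the rest of the repository rather than shown in this diff:

# sketch: threading attn_flash_kwargs and norm_eps through the top-level model (args assumed)
import torch
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    attn_flash_kwargs = dict(
        enable_flash = True,
        enable_math = True,
        enable_mem_efficient = True
    ),
    norm_eps = 0.05  # quasi-hypersphere: norms may live anywhere in (0.95, 1.05)
)

ids = torch.randint(0, 256, (2, 1024))
loss = model(ids, return_loss = True)
loss.backward()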

0 comments