
Commit 343a696

allow for exploration into non-autoregressive
1 parent e00c3a0

File tree

2 files changed: +19 −4 lines changed


nGPT_pytorch/nGPT.py

Lines changed: 18 additions & 3 deletions
@@ -134,6 +134,7 @@ def __init__(
         dim_head = 64,
         heads = 8,
         norm_qk = True,
+        causal = True,
         manual_norm_weights = False,
         s_qk_init = 1.,
         s_qk_scale = None,
@@ -145,6 +146,8 @@ def __init__(
         norm_eps = 0.
     ):
         super().__init__()
+        self.causal = causal
+
         NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps)
         self.l2norm = partial(l2norm, norm_eps = norm_eps)
 
@@ -179,7 +182,8 @@ def __init__(
 
     def forward(
         self,
-        x
+        x,
+        mask = None
     ):
         q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
 
@@ -202,12 +206,18 @@ def forward(
         q = self.rotary_emb.rotate_queries_or_keys(q)
         k = self.rotary_emb.rotate_queries_or_keys(k)
 
+        # for non-autoregressive masking
+
+        if exists(mask):
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+
         # scale is sqrt(dk)
 
         with self.sdpa_context_manager():
             out = F.scaled_dot_product_attention(
                 q, k, v,
-                is_causal = True,
+                attn_mask = mask,
+                is_causal = self.causal,
                 scale = self.attn_scale
             )
 
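The added mask handling above broadcasts a per-sequence boolean key mask from shape (batch, seq) to (batch, 1, 1, seq), so the same mask applies across every head and query position inside F.scaled_dot_product_attention. A minimal standalone sketch of that broadcasting follows; the shapes and the hypothetical padding pattern are illustrative, not part of the commit.

# sketch: broadcasting a key-padding mask for scaled_dot_product_attention
# (illustrative shapes; True = attend, False = masked out)

import torch
import torch.nn.functional as F
from einops import rearrange

b, h, n, d = 2, 8, 16, 64

q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

mask = torch.ones(b, n, dtype = torch.bool)   # key padding mask, shape (batch, seq)
mask[:, -4:] = False                          # pretend the last 4 keys are padding

mask = rearrange(mask, 'b j -> b 1 1 j')      # -> (batch, 1, 1, seq), broadcasts over heads and queries

# non-autoregressive: the padding mask supplies all masking, so is_causal stays False
out = F.scaled_dot_product_attention(q, k, v, attn_mask = mask, is_causal = False)

print(out.shape)  # torch.Size([2, 8, 16, 64])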
@@ -268,6 +278,7 @@ def __init__(
         ce_ignore_index = -1,
         manual_norm_weights = False,
         tied_embedding = False,
+        causal = True,
         # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
         alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
         s_logit_init: float = 1.,
@@ -294,6 +305,7 @@ def __init__(
         self.l2norm = partial(l2norm, norm_eps = norm_eps)
 
         self.dim = dim
+        self.causal = causal
         alpha_init = default(alpha_init, 1. / depth)
 
         self.token_embed = NormLinear_(dim, num_tokens)
@@ -332,6 +344,7 @@ def __init__(
                 dim,
                 dim_head = dim_head,
                 heads = heads,
+                causal = causal,
                 norm_qk = attn_norm_qk,
                 manual_norm_weights = manual_norm_weights,
                 s_qk_init = s_qk_init_,
@@ -382,18 +395,20 @@ def norm_weights_(self):
     def forward(
         self,
         ids,
+        mask = None,
         return_loss = False
     ):
         token_embed, l2norm = self.token_embed.weight, self.l2norm
 
         if return_loss:
+            assert self.causal
             ids, labels = ids[:, :-1], ids[:, 1:]
 
         tokens = token_embed[ids]
 
         for attn, ff, attn_alpha, ff_alpha in self.layers:
 
-            attn_out = l2norm(attn(tokens))
+            attn_out = l2norm(attn(tokens, mask = mask))
             tokens = l2norm(tokens.lerp(attn_out, attn_alpha()))
 
             ff_out = l2norm(ff(tokens))
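With the causal flag threaded through both the attention layers and the top-level model, the network can now be run in a non-autoregressive mode. Below is a hedged usage sketch: it assumes the nGPT class exported by this package and the constructor arguments visible in the diff above; the specific hyperparameter values and the padding pattern are made up for illustration.

# sketch of the new non-autoregressive path (hypothetical values)

import torch
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    causal = False        # new flag introduced by this commit
)

ids = torch.randint(0, 256, (2, 1024))

mask = torch.ones(2, 1024, dtype = torch.bool)   # True = real token, False = padding
mask[:, 512:] = False

logits = model(ids, mask = mask)   # mask is passed down to every attention layer

# note: return_loss = True now asserts self.causal, so the shifted next-token
# objective remains restricted to the autoregressive (causal = True) setting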

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.1.0"
+version = "0.1.1"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
