Commit fcc4b16

committed: address #11 with a hparam
1 parent 6129dee commit fcc4b16

File tree

3 files changed: +20 -4 lines changed

nGPT_pytorch/nGPT.py

Lines changed: 9 additions & 1 deletion
@@ -191,7 +191,8 @@ def __init__(
             enable_mem_efficient = True
         ),
         norm_eps = 0.,
-        num_hyperspheres = 1
+        num_hyperspheres = 1,
+        mask_value: float | None = None
     ):
         super().__init__()
         self.heads = heads
@@ -214,6 +215,8 @@ def __init__(
         sdpa_backends = [SDP_BACKEND_MAP[enable_str] for enable_str, enable in flash_kwargs.items() if enable]
         self.sdpa_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
 
+        self.mask_value = mask_value
+
         # qk rmsnorm + scale
 
         self.norm_qk = norm_qk
@@ -263,6 +266,9 @@ def forward(
         if exists(mask):
             mask = rearrange(mask, 'b j -> b 1 1 j')
 
+            if exists(self.mask_value):
+                mask = mask * self.mask_value
+
         # scale is sqrt(dk)
 
         with self.sdpa_context_manager():
@@ -339,6 +345,7 @@ def __init__(
         num_hyperspheres = 1,
         causal = True,
         add_value_residual = True,
+        attn_mask_value: float | None = None, # address some issue with sdpa
         # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
         alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
         s_logit_init: float = 1.,
@@ -414,6 +421,7 @@ def __init__(
                 s_qk_init = s_qk_init_,
                 s_qk_scale = s_qk_scale_,
                 flash_kwargs = attn_flash_kwargs,
+                mask_value = attn_mask_value,
                 norm_eps = norm_eps,
                 num_hyperspheres = num_hyperspheres
             )
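
Taken together, the nGPT.py changes add a single hyperparameter: attn_mask_value on the nGPT constructor, passed down to Attention as mask_value, where a provided boolean mask is multiplied by it before scaled_dot_product_attention is called. A minimal usage sketch follows; the constructor arguments other than attn_mask_value, the mask keyword on forward, and the -1e4 value are assumptions for illustration, not taken from this commit.

import torch
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,        # assumed vocab size for the sketch
    dim = 512,               # assumed model width
    depth = 4,               # assumed number of layers
    attn_mask_value = -1e4   # new hparam from this commit; the value here is only illustrative
)

ids  = torch.randint(0, 256, (2, 1024))
mask = torch.ones(2, 1024, dtype = torch.bool)   # per-key mask, shape (batch, seq)

# assumed forward signature; inside Attention the mask is rearranged to (b, 1, 1, j)
# and, because attn_mask_value is set, multiplied by that value before sdpa runs
logits = model(ids, mask = mask)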

nGPT_pytorch/nGPTExperimental.py

Lines changed: 10 additions & 2 deletions
@@ -190,7 +190,8 @@ def __init__(
             enable_mem_efficient = True
         ),
         norm_eps = 0.,
-        num_hyperspheres = 1
+        num_hyperspheres = 1,
+        mask_value = None
     ):
         super().__init__()
         self.heads = heads
@@ -213,6 +214,8 @@ def __init__(
         sdpa_backends = [SDP_BACKEND_MAP[enable_str] for enable_str, enable in flash_kwargs.items() if enable]
         self.sdpa_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
 
+        self.mask_value = mask_value
+
         # rotary
 
         self.rotary_emb = RotaryEmbedding(dim_head)
@@ -263,6 +266,9 @@ def forward(
         if exists(mask):
             mask = rearrange(mask, 'b j -> b 1 1 j')
 
+            if exists(self.mask_value):
+                mask = mask * self.mask_value
+
         # scale is sqrt(dk)
 
         with self.sdpa_context_manager():
@@ -335,6 +341,7 @@ def __init__(
         tied_embedding = False,
         num_hyperspheres = 1,
         causal = True,
+        attn_mask_value: float | None = None,
         # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
         alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
         s_logit_init: float = 1.,
@@ -407,7 +414,8 @@ def __init__(
                 s_qk_scale = s_qk_scale_,
                 flash_kwargs = attn_flash_kwargs,
                 norm_eps = norm_eps,
-                num_hyperspheres = num_hyperspheres
+                num_hyperspheres = num_hyperspheres,
+                mask_value = attn_mask_value
             )
 
             ff = FeedForward(

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.2.2"
+version = "0.2.3"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
