
Commit 34f6c59

just tempt some student into trying laser attention
1 parent 0791dfe commit 34f6c59

3 files changed (+42, -1 lines)

README.md

Lines changed: 9 additions & 0 deletions
````diff
@@ -514,3 +514,12 @@ docker run -v .:/data --gpus all -it af3
     url = {https://api.semanticscholar.org/CorpusID:273532030}
 }
 ```
+
+```bibtex
+@inproceedings{Duvvuri2024LASERAW,
+    title = {LASER: Attention with Exponential Transformation},
+    author = {Sai Surya Duvvuri and Inderjit S. Dhillon},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273849947}
+}
+```
````

alphafold3_pytorch/attention.py

Lines changed: 32 additions & 0 deletions
```diff
@@ -40,6 +40,9 @@ def pack_one(t, pattern):
 def unpack_one(t, ps, pattern):
     return unpack(t, ps, pattern)[0]
 
+def log(t, eps = 1e-20):
+    return t.clamp(min = eps).log()
+
 def softclamp(t, value):
     return (t / value).tanh() * value
 
@@ -181,6 +184,7 @@ def __init__(
         query_bias = True,
         window_size = None,
         num_memory_kv: int = 0,
+        laser = False,
         enable_attn_softclamp = False,
         attn_softclamp_value = 50.,
         softmax_full_precision = False
@@ -202,6 +206,7 @@ def __init__(
         dim_inner = dim_head * heads
 
         self.attend = Attend(
+            laser = laser,
             dropout = dropout,
             window_size = window_size,
             enable_attn_softclamp = enable_attn_softclamp,
@@ -299,6 +304,7 @@ class Attend(Module):
     def __init__(
         self,
         dropout = 0.,
+        laser = False,
         window_size = None,
         scale: float | None = None,
         enable_attn_softclamp = False,
@@ -327,6 +333,10 @@ def __init__(
 
         self.attn_dropout = nn.Dropout(dropout)
 
+        # laser attention
+
+        self.laser = laser
+
         # softclamp attention logits
         # being adopted by a number of recent llms (gemma, grok)
 
@@ -447,10 +457,21 @@ def local_attn(
 
         attn = sim.softmax(dim = -1)
 
+        # maybe laser
+
+        if self.laser:
+            v_max = v.amax(dim = -2, keepdim = True)
+            v = (v - v_max).exp()
+
         # aggregate
 
         out = einsum(attn, v, "... i j, ... j d -> ... i d")
 
+        # maybe laser
+
+        if self.laser:
+            out = log(out) + v_max
+
         # un-window the output
 
         out = rearrange(out, "b h n w d -> b h (n w) d")
@@ -546,8 +567,19 @@ def forward(
 
         attn = self.attn_dropout(attn)
 
+        # maybe laser
+
+        if self.laser:
+            v_max = v.amax(dim = -2, keepdim = True)
+            v = (v - v_max).exp()
+
         # aggregate values
 
         out = einsum(attn, v, "b h i j, b h j d -> b h i d")
 
+        # maybe laser
+
+        if self.laser:
+            out = log(out) + v_max
+
         return out
```
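
For context, the `laser = True` path implements LASER attention from the cited paper: the attention weights aggregate `exp(v)` instead of `v`, and the result is mapped back to log space, i.e. roughly `out = log(softmax(sim) @ exp(v))`, with the per-key max of `v` subtracted before exponentiation for numerical stability (the `v_max` trick in the diff). Below is a minimal, standalone sketch of that computation in plain PyTorch; the function name `laser_attend` and the tensor shapes are hypothetical and not part of this repository's API.

```python
# a minimal sketch of the LASER aggregation, not the repository's code;
# laser_attend and the toy shapes below are hypothetical
import torch

def laser_attend(q, k, v, eps = 1e-20):
    # ordinary scaled dot-product attention weights
    scale = q.shape[-1] ** -0.5
    sim = torch.einsum('bhid,bhjd->bhij', q, k) * scale
    attn = sim.softmax(dim = -1)

    # LASER: aggregate exp(v) rather than v, then return to log space.
    # subtracting the per-key max of v keeps exp() from overflowing,
    # mirroring the v_max trick added in the diff above
    v_max = v.amax(dim = -2, keepdim = True)
    out = torch.einsum('bhij,bhjd->bhid', attn, (v - v_max).exp())
    return out.clamp(min = eps).log() + v_max

# toy usage with shape (batch, heads, seq, dim_head)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
out = laser_attend(q, k, v)   # -> (1, 2, 4, 8)
```

Per the diff, the feature is toggled in the library itself by passing `laser = True` to the `Attention` / `Attend` constructors in `alphafold3_pytorch/attention.py`.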

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.6.8"
+version = "0.6.10"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" },
```
