Commit bb54873: throw in laser attention
1 parent: bc5f10f

5 files changed, +30 and -4 lines

README.md

Lines changed: 9 additions & 0 deletions

@@ -234,3 +234,12 @@ sampled = model.generate_text_only(text[:, :1], 1024)
     url = {https://api.semanticscholar.org/CorpusID:273532030}
 }
 ```
+
+```bibtex
+@inproceedings{Duvvuri2024LASERAW,
+    title = {LASER: Attention with Exponential Transformation},
+    author = {Sai Surya Duvvuri and Inderjit S. Dhillon},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273849947}
+}
+```

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [project]
 name = "transfusion-pytorch"
-version = "0.6.0"
+version = "0.6.3"
 description = "Transfusion in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

train_image_only.py

Lines changed: 2 additions & 1 deletion

@@ -51,7 +51,8 @@ def forward(self, x):
         dim = 64,
         depth = 4,
         dim_head = 32,
-        heads = 8
+        heads = 8,
+        attn_laser = True
     )
 ).cuda()

train_text_only.py

Lines changed: 2 additions & 1 deletion

@@ -50,7 +50,8 @@ def decode_tokens(tokens):
         dim = 384,
         depth = 8,
         dim_head = 64,
-        heads = 8
+        heads = 8,
+        attn_laser = True
     )
 ).cuda()
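As the two training-script diffs above show, the new flag is switched on by passing `attn_laser = True` inside the transformer kwargs of `Transfusion`; the transformer then forwards it to each attention layer as `laser = attn_laser` (see the transfusion.py hunks below). A hedged usage sketch follows: only the nested transformer kwargs are taken from this commit, while the surrounding constructor arguments (such as `num_text_tokens`) follow the repo's README-style usage and are assumptions here, with illustrative values.

```python
# Hypothetical usage sketch mirroring train_text_only.py: `attn_laser = True`
# goes inside the transformer kwargs. Arguments outside the transformer dict
# are assumptions based on typical usage of this repo, not part of this diff.
import torch
from transfusion_pytorch import Transfusion

model = Transfusion(
    num_text_tokens = 256,      # vocabulary size; illustrative value
    transformer = dict(
        dim = 384,
        depth = 8,
        dim_head = 64,
        heads = 8,
        attn_laser = True       # flag added in this commit
    )
)

if torch.cuda.is_available():   # the training scripts call .cuda() directly
    model = model.cuda()
```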

transfusion_pytorch/transfusion.py

Lines changed: 16 additions & 1 deletion

@@ -756,6 +756,7 @@ def __init__(
         softcap_value = 50.,
         use_flex_attn = False,
         gate_values = True,
+        laser = False,
         learned_value_residual_mix = False
     ):
         super().__init__()

@@ -783,6 +784,8 @@ def __init__(

         self.softcap_value = softcap_value

+        self.laser = laser
+
         self.dropout = nn.Dropout(dropout)

         self.to_out = nn.Sequential(

@@ -844,6 +847,12 @@ def forward(
         if exists(rotary_emb):
             q, k = tuple(apply_rotary_emb(rotary_emb, t, freqs_seq_dim = -2) for t in (q, k))

+        # laser attention
+
+        if self.laser:
+            v_max = v.amax(dim = -2, keepdim = True).detach()
+            v = (v - v_max).exp()
+
         # whether to use flex attention or not

         if should_use_flex_attn:

@@ -878,6 +887,11 @@ def forward(

         out = einsum(attn, v, 'b h i j, b h j d -> b h i d')

+        # laser attention
+
+        if self.laser:
+            out = log(out) + v_max
+
         # maybe gate values

         if exists(self.to_gates):

@@ -908,6 +922,7 @@ def __init__(
         ff_expansion_factor = 4,
         attn_kwargs: dict = dict(),
         ff_kwargs: dict = dict(),
+        attn_laser = False,
         unet_skips = True,
         use_flex_attn = False
     ):

@@ -932,7 +947,7 @@ def __init__(

            skip_proj = Linear(dim * 2, dim, bias = False) if is_latter_half and unet_skips else None

-           attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout, use_flex_attn = use_flex_attn, learned_value_residual_mix = not is_first, **attn_kwargs)
+           attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout, use_flex_attn = use_flex_attn, learned_value_residual_mix = not is_first, laser = attn_laser, **attn_kwargs)

            ff = FeedForward(dim = dim, expansion_factor = ff_expansion_factor, **ff_kwargs)