Commit 4f5b68d

Added full attention option
1 parent e5352fc

4 files changed (+40, -11 lines)


README.md

Lines changed: 4 additions & 1 deletion
```diff
@@ -47,6 +47,7 @@ model = Linformer(
         checkpoint_level="C0", # What checkpoint level to use. For more information, see below.
         parameter_sharing="layerwise", # What level of parameter sharing to use. For more information, see below.
         k_reduce_by_layer=0, # Going down `depth`, how much to reduce `dim_k` by, for the `E` and `F` matrices. Will have a minimum value of 1.
+        full_attention=False, # Use full attention instead, for O(n^2) time and space complexity. Included here just for comparison
         ).cuda()
 x = torch.randn(1, 262144, 64).cuda()
 y = model(x)
@@ -70,6 +71,7 @@ model = MHAttention(
         checkpoint_level="C2", # If C2, checkpoint each of the heads
         parameter_sharing="layerwise", # What level of parameter sharing to do
         E_proj, F_proj, # The E and F projection matrices
+        full_attention=False, # Use full attention instead
         )
 x = torch.randn(1, 512, 64)
 y = model(x)
@@ -85,7 +87,8 @@ import torch
 model = LinearAttentionHead(
         dim=64, # Dim 2 of the input
         dropout=0.1, # Dropout of the P matrix
-        E_proj, F_proj # The E and F layers
+        E_proj, F_proj, # The E and F layers
+        full_attention=False, # Use Full Attention instead
         )
 x = torch.randn(1, 512, 64)
 y = model(x, x, x)
```
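The README snippets above document the new flag at each level of the API. As a quick illustration of how the `Linformer`-level flag might be exercised, here is a minimal sketch that builds the same model twice, once per setting; only the `full_attention` keyword comes from this commit, the specific sizes are illustrative:

```python
import torch
from linformer_pytorch import Linformer

# Two identically shaped models: one using the Linformer E/F projections,
# one falling back to ordinary O(n^2) attention via the new flag.
linear_model = Linformer(
        input_size=512,        # illustrative sizes, small enough to run on CPU
        channels=16,
        dim_k=16,
        full_attention=False,  # default: linear attention through the E/F projections
        )
full_model = Linformer(
        input_size=512,
        channels=16,
        dim_k=16,
        full_attention=True,   # new option: skip the E/F projections entirely
        )

x = torch.randn(1, 512, 16)
print(linear_model(x).shape)  # torch.Size([1, 512, 16])
print(full_model(x).shape)    # torch.Size([1, 512, 16]); same shape, different complexity
```

Both configurations produce the same output shape; the difference is that the full-attention path materializes an (input_size x input_size) attention matrix per head.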

examples/example_full_attn.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -0,0 +1,21 @@
+import sys
+import torch
+
+sys.path.insert(0, "../")
+from linformer_pytorch import Linformer
+
+model = Linformer(
+        input_size=512,
+        channels=16,
+        dim_k=16,
+        dim_ff=32,
+        nhead=4,
+        depth=3,
+        activation="relu",
+        checkpoint_level="C1",
+        parameter_sharing="none",
+        k_reduce_by_layer=1,
+        )
+x = torch.randn(1, 512, 16)
+y = model(x)
+print(y) # (1, 512, 16)
```
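The new example keeps `full_attention` at its default of `False`, so it still runs the linear path; to exercise the O(n^2) path added in this commit the flag would presumably be flipped. A sketch of that variant, reusing the committed configuration:

```python
import torch
from linformer_pytorch import Linformer

model = Linformer(
        input_size=512,
        channels=16,
        dim_k=16,              # E/F projections of this size are still built, but not applied in forward
        dim_ff=32,
        nhead=4,
        depth=3,
        activation="relu",
        checkpoint_level="C1",
        parameter_sharing="none",
        k_reduce_by_layer=1,
        full_attention=True,   # enable the O(n^2) path added in this commit
        )
x = torch.randn(1, 512, 16)
y = model(x)
print(y.shape)  # torch.Size([1, 512, 16])
```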

examples/pretrain_tutorial.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,7 +11,7 @@
         batch_size=16,
         lr=0.1,
         no_cuda=True,
-        num_epochs=10,
+        num_epochs=30,
         output_dir="./output",
         seed=2222,
```

linformer_pytorch/linformer_pytorch.py

Lines changed: 14 additions & 9 deletions
```diff
@@ -60,7 +60,7 @@ class LinearAttentionHead(nn.Module):
     """
     Linear attention, as proposed by the linformer paper
     """
-    def __init__(self, dim, dropout, E_proj, F_proj):
+    def __init__(self, dim, dropout, E_proj, F_proj, full_attention=False):
         super(LinearAttentionHead, self).__init__()
         self.w_k = nn.Linear(dim, dim)
         self.w_q = nn.Linear(dim, dim)
@@ -70,6 +70,7 @@ def __init__(self, dim, dropout, E_proj, F_proj):
         self.dim = dim
         self.dropout = nn.Dropout(dropout)
         self.P_bar = None
+        self.full_attention = full_attention
 
     def forward(self, Q, K, V, **kwargs):
         """
@@ -78,23 +79,27 @@ def forward(self, Q, K, V, **kwargs):
         """
         KW = self.w_k(K)
         KW = torch.transpose(KW, 1, 2)
-        KW = self.E(KW)
+        if not self.full_attention:
+            KW = self.E(KW)
         QW = self.w_q(Q)
         QW = torch.matmul(QW, KW)
 
         P_bar = QW/torch.sqrt(torch.tensor(self.dim).type(Q.type()))
         P_bar = P_bar.softmax(dim=-1)
 
+        print(P_bar.shape)
         # Only save this when visualizing
         if "visualize" in kwargs and kwargs["visualize"] == True:
             self.P_bar = P_bar
 
         P_bar = self.dropout(P_bar)
 
         VW = self.w_v(V)
-        VW = torch.transpose(VW, 1, 2)
-        VW = self.F(VW)
-        VW = torch.transpose(VW, 1, 2)
+
+        if not self.full_attention:
+            VW = torch.transpose(VW, 1, 2)
+            VW = self.F(VW)
+            VW = torch.transpose(VW, 1, 2)
         out_tensor = torch.matmul(P_bar, VW)
 
         return out_tensor
```
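The branch above is the heart of the change: when `full_attention` is set, the `E` and `F` projections are simply skipped, so the attention matrix keeps its full (n, n) shape (which the newly added `print(P_bar.shape)` debug line will report) instead of being compressed to (n, k). A standalone sketch of the two shape paths, omitting the learned `w_q`/`w_k`/`w_v` layers and dropout for brevity:

```python
import torch
import torch.nn as nn

n, d, k = 512, 64, 128           # sequence length, head dim, projected dim (illustrative)
Q = torch.randn(1, n, d)
K = torch.randn(1, n, d)
V = torch.randn(1, n, d)
E = nn.Linear(n, k, bias=False)  # stand-ins for the E/F projections over the length dimension
F = nn.Linear(n, k, bias=False)

# Linear (Linformer) path: K^T and V^T are projected from length n down to k.
KW = E(K.transpose(1, 2))                               # (1, d, k)
P_lin = (Q @ KW / d ** 0.5).softmax(dim=-1)             # (1, n, k)
out_lin = P_lin @ F(V.transpose(1, 2)).transpose(1, 2)  # (1, n, d)

# Full-attention path (full_attention=True): no projection, the usual O(n^2) matrix.
P_full = (Q @ K.transpose(1, 2) / d ** 0.5).softmax(dim=-1)  # (1, n, n)
out_full = P_full @ V                                        # (1, n, d)

print(P_lin.shape, P_full.shape)  # torch.Size([1, 512, 128]) torch.Size([1, 512, 512])
```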
```diff
@@ -104,7 +109,7 @@ class MHAttention(nn.Module):
     Multihead attention, with each head being a Linformer Head
     This feeds directly into a feed forward head
     """
-    def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, F_proj):
+    def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, F_proj, full_attention):
         super(MHAttention, self).__init__()
         self.heads = nn.ModuleList()
         self.input_size = input_size
@@ -118,7 +123,7 @@ def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation,
             if parameter_sharing == "none":
                 E_proj = get_EF(input_size, dim_k)
                 F_proj = get_EF(input_size, dim_k)
-            attn = LinearAttentionHead(dim, dropout, E_proj, F_proj)
+            attn = LinearAttentionHead(dim, dropout, E_proj, F_proj, full_attention)
             self.heads.append(attn)
         self.w_o = nn.Linear(dim*nhead, channels)
         self.to_q = nn.Linear(channels, dim, bias=False)
```
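For completeness, a sketch of driving `MHAttention` directly with the new argument. The keyword names follow the signature shown above; the import path and the `nn.Linear` stand-ins for the `E_proj`/`F_proj` matrices (which the library itself builds via `get_EF(input_size, dim_k)`) are assumptions:

```python
import torch
import torch.nn as nn
from linformer_pytorch import MHAttention  # assumed to be exported, as in the README example

# Assumed stand-ins for get_EF(input_size, dim_k); with full_attention=True
# they are stored but not applied in the forward pass.
E_proj = nn.Linear(512, 128)
F_proj = nn.Linear(512, 128)

attn = MHAttention(
        input_size=512, dim=16, channels=64, dim_k=128, nhead=4,
        dropout=0.1, activation="gelu", checkpoint_level="C0",
        parameter_sharing="layerwise", E_proj=E_proj, F_proj=F_proj,
        full_attention=True,  # the argument added in this commit
        )
x = torch.randn(1, 512, 64)
y = attn(x)        # per the README, a single tensor attends to itself
print(y.shape)     # expected: torch.Size([1, 512, 64])
```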
```diff
@@ -147,7 +152,7 @@ class Linformer(nn.Module):
     My attempt at reproducing the Linformer Paper
     https://arxiv.org/pdf/2006.04768.pdf
     """
-    def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", use_pos_emb=True, checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0):
+    def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", use_pos_emb=True, checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False):
         super(Linformer, self).__init__()
         assert activation == "gelu" or activation == "relu", "Only gelu and relu activations supported for now"
         assert checkpoint_level == "C0" or checkpoint_level == "C1" or checkpoint_level == "C2", "Checkpoint level has to be either C0, C1, or C2."
@@ -167,7 +172,7 @@ def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=No
         self.E = get_EF(input_size, dim_k)
         self.F = self.E
 
-        get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, self.E, self.F)
+        get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, self.E, self.F, full_attention)
         get_ff = lambda: FeedForward(channels, dim_ff, dropout_ff)
         norm_attn = lambda: nn.LayerNorm(channels)
         norm_ff = lambda: nn.LayerNorm(channels)
```
