@@ -60,7 +60,7 @@ class LinearAttentionHead(nn.Module):
     """
     Linear attention, as proposed by the linformer paper
     """
-    def __init__(self, dim, dropout, E_proj, F_proj):
+    def __init__(self, dim, dropout, E_proj, F_proj, full_attention=False):
         super(LinearAttentionHead, self).__init__()
         self.w_k = nn.Linear(dim, dim)
         self.w_q = nn.Linear(dim, dim)
@@ -70,6 +70,7 @@ def __init__(self, dim, dropout, E_proj, F_proj):
         self.dim = dim
         self.dropout = nn.Dropout(dropout)
         self.P_bar = None
+        self.full_attention = full_attention
 
     def forward(self, Q, K, V, **kwargs):
         """
@@ -78,23 +79,27 @@ def forward(self, Q, K, V, **kwargs):
         """
         KW = self.w_k(K)
         KW = torch.transpose(KW, 1, 2)
-        KW = self.E(KW)
+        if not self.full_attention:
+            KW = self.E(KW)
         QW = self.w_q(Q)
         QW = torch.matmul(QW, KW)
 
         P_bar = QW / torch.sqrt(torch.tensor(self.dim).type(Q.type()))
         P_bar = P_bar.softmax(dim=-1)
 
+        print(P_bar.shape)
         # Only save this when visualizing
         if "visualize" in kwargs and kwargs["visualize"] == True:
             self.P_bar = P_bar
 
         P_bar = self.dropout(P_bar)
 
         VW = self.w_v(V)
-        VW = torch.transpose(VW, 1, 2)
-        VW = self.F(VW)
-        VW = torch.transpose(VW, 1, 2)
+
+        if not self.full_attention:
+            VW = torch.transpose(VW, 1, 2)
+            VW = self.F(VW)
+            VW = torch.transpose(VW, 1, 2)
         out_tensor = torch.matmul(P_bar, VW)
 
         return out_tensor
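For context, here is a minimal, self-contained sketch of the shape flow the two branches of `forward()` produce. It omits the `w_q`/`w_k` projections, scaling, and softmax, and the sizes `n`, `d`, `k` plus the standalone `nn.Linear` standing in for the `E` projection are illustrative assumptions, not values from this commit.

```python
import torch
import torch.nn as nn

# Toy sizes, chosen only for illustration: sequence length n, head dim d, projected dim k.
n, d, k = 512, 64, 128
Q = torch.randn(1, n, d)
K = torch.randn(1, n, d)
E = nn.Linear(n, k)  # stands in for the E projection returned by get_EF

# Linformer branch (full_attention=False): the sequence axis of K is compressed to k.
KW = torch.transpose(K, 1, 2)   # (1, d, n)
KW = E(KW)                      # (1, d, k)
P_bar = torch.matmul(Q, KW)     # (1, n, k) -- linear in n

# Full-attention branch (full_attention=True): no E projection, quadratic in n.
P_full = torch.matmul(Q, torch.transpose(K, 1, 2))  # (1, n, n)

print(P_bar.shape, P_full.shape)
```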
@@ -104,7 +109,7 @@ class MHAttention(nn.Module):
     Multihead attention, with each head being a Linformer Head
     This feeds directly into a feed forward head
     """
-    def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, F_proj):
+    def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, E_proj, F_proj, full_attention):
         super(MHAttention, self).__init__()
         self.heads = nn.ModuleList()
         self.input_size = input_size
@@ -118,7 +123,7 @@ def __init__(self, input_size, dim, channels, dim_k, nhead, dropout, activation,
             if parameter_sharing == "none":
                 E_proj = get_EF(input_size, dim_k)
                 F_proj = get_EF(input_size, dim_k)
-            attn = LinearAttentionHead(dim, dropout, E_proj, F_proj)
+            attn = LinearAttentionHead(dim, dropout, E_proj, F_proj, full_attention)
             self.heads.append(attn)
         self.w_o = nn.Linear(dim*nhead, channels)
         self.to_q = nn.Linear(channels, dim, bias=False)
@@ -147,7 +152,7 @@ class Linformer(nn.Module):
     My attempt at reproducing the Linformer Paper
     https://arxiv.org/pdf/2006.04768.pdf
     """
-    def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", use_pos_emb=True, checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0):
+    def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=None, dropout_ff=0.15, nhead=4, depth=1, dropout=0.1, activation="gelu", use_pos_emb=True, checkpoint_level="C0", parameter_sharing="layerwise", k_reduce_by_layer=0, full_attention=False):
         super(Linformer, self).__init__()
         assert activation == "gelu" or activation == "relu", "Only gelu and relu activations supported for now"
         assert checkpoint_level == "C0" or checkpoint_level == "C1" or checkpoint_level == "C2", "Checkpoint level has to be either C0, C1, or C2."
@@ -167,7 +172,7 @@ def __init__(self, input_size=8192, channels=128, dim_k=64, dim_ff=256, dim_d=No
             self.E = get_EF(input_size, dim_k)
             self.F = self.E
 
-        get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, self.E, self.F)
+        get_attn = lambda curr_dim_k: MHAttention(input_size, head_dim, channels, curr_dim_k, nhead, dropout, activation, checkpoint_level, parameter_sharing, self.E, self.F, full_attention)
         get_ff = lambda: FeedForward(channels, dim_ff, dropout_ff)
         norm_attn = lambda: nn.LayerNorm(channels)
         norm_ff = lambda: nn.LayerNorm(channels)
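A quick way to exercise the new flag end to end is a sketch like the one below. The `linformer_pytorch` import path and the `(batch, input_size, channels)` input layout are assumptions about the surrounding repo, not something shown in this diff; the keyword names come from the constructor above.

```python
import torch
from linformer_pytorch import Linformer  # import path assumed, not shown in this diff

# Small, hypothetical sizes for a smoke test.
kwargs = dict(input_size=512, channels=64, dim_k=128, dim_ff=128, nhead=4, depth=1)
linear_model = Linformer(**kwargs)                      # default: Linformer E/F projections
full_model = Linformer(**kwargs, full_attention=True)   # new flag: plain softmax attention

x = torch.randn(1, 512, 64)  # assumed (batch, input_size, channels) layout
print(linear_model(x).shape, full_model(x).shape)
```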