
Commit 4dac9de

offer cosine sim attention, used in conjunction with numerically unstable memory efficient attention
1 parent ee594d3 commit 4dac9de
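In cosine sim attention, queries and keys are l2-normalized before the dot product, so every attention logit is a cosine similarity in [-1, 1] times a learned scale. Because the logits are bounded, the chunked attention added below can accumulate raw exponentials across key chunks without tracking a running maximum. A minimal sketch of that bound (shapes and the scale value are illustrative, not taken from this commit):

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64)    # (batch, heads, seq, dim_head), illustrative shapes
k = torch.randn(2, 8, 1024, 64)

q, k = F.normalize(q, dim = -1), F.normalize(k, dim = -1)
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)

assert sim.abs().max() <= 1 + 1e-5              # cosine similarities stay in [-1, 1]

scale = 10.                                     # any modest scale keeps exp(scale * sim) finite
assert (scale * sim).exp().isfinite().all()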

2 files changed: +212 -0 lines changed

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
  from memory_efficient_attention_pytorch.memory_efficient_attention import Attention
+ from memory_efficient_attention_pytorch.memory_efficient_cosine_sim_attention import CosineSimAttention
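A minimal usage sketch of the newly exported class; the hyperparameters below are illustrative, not values prescribed by this commit:

import torch
from memory_efficient_attention_pytorch import CosineSimAttention

attn = CosineSimAttention(
    dim = 512,
    seq_len = 2048,
    heads = 8,
    dim_head = 64,
    causal = True,
    memory_efficient = True,    # route through the chunked attention path
    q_bucket_size = 512,
    k_bucket_size = 1024
)

x = torch.randn(1, 2048, 512)
out = attn(x)                   # (1, 2048, 512)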
Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
import math
import torch
import torch.nn.functional as F
from functools import partial
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint

from einops import rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def l2norm(t):
    return F.normalize(t, dim = -1)

# regular attention
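# reference implementation, used when memory_efficient = False;
# materializes the full (i, j) attention matrix in a single pass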

def attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    **kwargs
):
    scale = q.shape[-1] ** -0.5
    q = q * scale

    sim = einsum('b h i d, b h j d -> b h i j', q, k)

    if exists(attn_bias):
        sim = sim + attn_bias

    mask_value = -torch.finfo(sim.dtype).max

    if exists(mask):
        mask = rearrange(mask, 'b j -> b 1 1 j')
        sim = sim.masked_fill(~mask, mask_value)

    if causal:
        i, j = sim.shape[-2:]
        mask = torch.ones(i, j, device = q.device).triu(j - i + 1).bool()
        sim = sim.masked_fill(mask, mask_value)

    attn = sim.softmax(dim = -1)

    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    return out

# memory efficient attention
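# queries are split into buckets of q_bucket_size and keys / values into buckets of k_bucket_size;
# each (q chunk, k chunk) pair contributes an unnormalized exp-weight sum and an exp-weighted value,
# which are accumulated over key chunks and normalized only at the end.
# no running maximum is subtracted before exponentiating, hence "numerically unstable";
# with l2-normalized queries and keys the logits stay bounded, so the exponentials cannot overflow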

def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
    weight = einsum('b h i d, b h j d -> b h i j', q, k)

    if exists(attn_bias_chunk):
        weight = weight + attn_bias_chunk

    mask_value = -torch.finfo(weight.dtype).max

    if exists(mask):
        mask = rearrange(mask, 'b j -> b 1 1 j')
        weight = weight.masked_fill(~mask, mask_value)

    if exists(causal_mask):
        weight = weight.masked_fill(causal_mask, mask_value)

    exp_weight = weight.exp()
    weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)

    return exp_weight.sum(dim = -1), weighted_value

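# gradient checkpointing trades compute for memory: the per-chunk attention weights
# are not stored for the backward pass, but recomputed from q, k, v when needed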
checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)

def numerically_unstable_memory_efficient_attention(
    q, k, v,
    mask = None,
    causal = False,
    attn_bias = None,
    q_bucket_size = 512,
    k_bucket_size = 1024,
    eps = 1e-8
):
    scale = q.shape[-1] ** -0.5
    q = q * scale

    # chunk all the inputs

    q_chunks = q.split(q_bucket_size, dim = -2)
    k_chunks = k.split(k_bucket_size, dim = -2)
    v_chunks = v.split(k_bucket_size, dim = -2)
    mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))

    if causal:
        i, j = q.shape[-2], k.shape[-2]
        causal_mask = torch.ones(i, j, device = q.device).triu(j - i + 1).bool()
        causal_mask_chunks = causal_mask.split(q_bucket_size, dim = 0)
        causal_mask_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), causal_mask_chunks))

    if exists(attn_bias):
        i, j = attn_bias.shape[-2:]
        attn_bias_chunks = attn_bias.split(q_bucket_size, dim = -2)
        attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim = -1), attn_bias_chunks))

    # loop through all chunks and accumulate

    out = []
    for q_index, q_chunk in enumerate(q_chunks):
        exp_weights = []
        weighted_values = []

        for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):

            causal_mask_chunk = causal_mask_chunks[q_index][k_index] if causal else None

            if exists(causal_mask_chunk) and torch.all(causal_mask_chunk):
                # if chunk is to be all masked out causally, skip
                continue

            attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None

            exp_weight_chunk, weighted_value_chunk = checkpointed_summarize_qkv_chunk(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
                causal_mask_chunk,
                attn_bias_chunk
            )

            exp_weights.append(exp_weight_chunk)
            weighted_values.append(weighted_value_chunk)

        all_values = sum(weighted_values)
        all_weights = sum(exp_weights)

        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
        out.append(normalized_values)

    return torch.cat(out, dim = -2)

# main class
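# l2-normalizes queries and keys so the attention logits are cosine similarities,
# scaled by a learned per-head temperature kept in log space and exponentiated in forward;
# routes to either the regular or the chunked attention function above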

class CosineSimAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        seq_len,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        causal = False,
        memory_efficient = False,
        q_bucket_size = 512,
        k_bucket_size = 1024
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal

        inner_dim = heads * dim_head

        scale_init_value = -math.log(math.log2(seq_len ** 2 - seq_len))
        self.scale = nn.Parameter(torch.full((1, heads, 1, 1), scale_init_value))

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        # memory efficient attention related parameters
        # can be overridden on forward
        self.memory_efficient = memory_efficient
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        x,
        context = None,
        mask = None,
        attn_bias = None,
        memory_efficient = None,
        q_bucket_size = None,
        k_bucket_size = None,
    ):
        memory_efficient = default(memory_efficient, self.memory_efficient)
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads
        context = default(context, x)

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        q, k = map(l2norm, (q, k))

        q = q * self.scale.exp()

        attn_fn = attention if not memory_efficient else numerically_unstable_memory_efficient_attention

        out = attn_fn(q, k, v, mask = mask, attn_bias = attn_bias, causal = self.causal, q_bucket_size = q_bucket_size, k_bucket_size = k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
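If the package is installed with this commit, the chunked path is expected to match the regular path up to floating point error; a rough sanity-check sketch (sizes and bucket values are illustrative):

import torch
from memory_efficient_attention_pytorch import CosineSimAttention

attn = CosineSimAttention(dim = 512, seq_len = 1024, causal = True)
x = torch.randn(1, 1024, 512)

plain = attn(x, memory_efficient = False)
chunked = attn(x, memory_efficient = True, q_bucket_size = 256, k_bucket_size = 512)

assert torch.allclose(plain, chunked, atol = 1e-5)    # both paths compute the same attention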
