@@ -30,7 +30,6 @@ def attention(
     if exists(attn_bias):
         sim = sim + attn_bias
 
-    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
     mask_value = -torch.finfo(sim.dtype).max
 
     if exists(mask):
@@ -42,6 +41,7 @@ def attention(
         mask = torch.ones(i, j, device = q.device, dtype = torch.bool).triu(j - i + 1)
         sim = sim.masked_fill(mask, mask_value)
 
+    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
     attn = sim.softmax(dim = -1)
 
     out = einsum('b h i j, b h j d -> b h i d', attn, v)
@@ -55,9 +55,6 @@ def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
     if exists(attn_bias_chunk):
         weight = weight + attn_bias_chunk
 
-    weight_max = weight.amax(dim = -1, keepdim = True).detach()
-    weight = weight - weight_max
-
     mask_value = -torch.finfo(weight.dtype).max
 
     if exists(mask):
@@ -67,6 +64,9 @@ def summarize_qkv_chunk(q, k, v, mask, causal_mask, attn_bias_chunk):
     if exists(causal_mask):
         weight = weight.masked_fill(causal_mask, mask_value)
 
+    weight_max = weight.amax(dim = -1, keepdim = True).detach()
+    weight = weight - weight_max
+
     exp_weight = weight.exp()
     weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
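The diff moves the row-wise max subtraction (the usual numerically stable softmax trick) so it runs after the attention bias and the masks have been applied, in both the plain attention path and the chunked summarize_qkv_chunk path. Below is a minimal sketch of that pattern only, not code from this commit; the helper name stable_masked_softmax and the keep_mask argument are hypothetical and introduced here purely for illustration.

import torch

def stable_masked_softmax(sim, keep_mask = None):
    # hypothetical helper, for illustration only
    mask_value = -torch.finfo(sim.dtype).max

    # apply masking first, so the stabilizing max is taken over the
    # scores that actually participate in the softmax
    if keep_mask is not None:
        sim = sim.masked_fill(~keep_mask, mask_value)

    # subtract the detached row max: exp() then sees values <= 0,
    # which cannot overflow, and gradients are unaffected by the shift
    sim = sim - sim.amax(dim = -1, keepdim = True).detach()
    return sim.softmax(dim = -1)

# usage sketch: large logits that would overflow a naive exp()
sim = torch.randn(2, 4, 8, 8) * 1000.
keep_mask = torch.ones(8, 8, dtype = torch.bool).tril()  # causal-style mask, True = keep
attn = stable_masked_softmax(sim, keep_mask)
assert torch.isfinite(attn).all()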