-
Notifications
You must be signed in to change notification settings - Fork 63
Description
https://github.com/feifeibear/long-context-attention/blob/0.6.0/yunchang/kernels/attention.py#L47
torch 版本是 torch 2.3
您好,我实验发现对于下面的 case
batch_size = 1
num_heads = 2
head_dim = 128
seq_len = 16
torch.ops.aten._scaled_dot_product_efficient_attention 的lse返回值有bug
这里建议使用 _scaled_dot_product_flash_attention这个函数,返回lse
import torch

# Reproduction script: compare the log-sum-exp (lse) returned by the flash
# SDPA kernel against the memory-efficient SDPA kernel for a small CUDA input.
# The efficient kernel pads its lse past seq_len (observed as inf entries).
batch_size = 1
num_heads = 2
head_dim = 128
seq_len = 16
dtype = torch.float16
device = 'cuda'

torch.manual_seed(42)
# Three sequential torch.rand calls — same RNG consumption order as creating
# query, key, and value one by one.
shape = (batch_size, num_heads, seq_len, head_dim)
query, key, value = (torch.rand(shape, dtype=dtype, device=device) for _ in range(3))

# Reference output plus the two aten kernels under comparison.
out1 = torch.nn.functional.scaled_dot_product_attention(query, key, value)
out2, lse2 = torch.ops.aten._scaled_dot_product_flash_attention(query, key, value)[:2]
print(f'lse2: {lse2}')
out3, lse3 = torch.ops.aten._scaled_dot_product_efficient_attention(
    query, key, value, attn_bias=None, compute_log_sumexp=True
)[:2]
print(f'lse3: {lse3}')

# The attention outputs agree; the lse comparison exposes the shape/padding
# discrepancy between the two kernels.
for lhs, rhs in ((out1, out2), (out1, out3), (lse2, lse3)):
    print(f'Result: {torch.allclose(lhs, rhs, rtol=1e-3, atol=1e-3)}')
结果:
lse2: tensor([[[5.4660, 5.7685, 5.4396, 5.6371, 5.3691, 5.4923, 5.4666, 5.4459,
5.5439, 5.7281, 5.8045, 5.6622, 5.7923, 5.6987, 5.5620, 5.5473],
[5.4989, 5.8029, 5.5886, 5.5052, 5.6427, 5.5984, 5.7117, 5.4015,
5.6134, 5.5992, 5.4512, 5.8386, 5.8852, 5.3351, 5.6285, 5.6732]]],
device='cuda:0')
lse3: tensor([[[5.4660, 5.7685, 5.4396, 5.6371, 5.3691, 5.4923, 5.4666, 5.4459,
5.5439, 5.7281, 5.8045, 5.6622, 5.7923, 5.6987, 5.5620, 5.5473,
inf, inf, inf, inf, inf, inf, inf, inf,
inf, inf, inf, inf, inf, inf, inf, inf],
[5.4989, 5.8029, 5.5886, 5.5052, 5.6427, 5.5984, 5.7117, 5.4015,
5.6134, 5.5992, 5.4512, 5.8386, 5.8852, 5.3351, 5.6285, 5.6732,
inf, inf, inf, inf, inf, inf, inf, inf,
inf, inf, inf, inf, inf, inf, inf, inf]]],
device='cuda:0')
Result: True
Result: True