_scaled_dot_product_efficient_attention: a bug in the returned lse #122

@neonhuang

Description

https://github.com/feifeibear/long-context-attention/blob/0.6.0/yunchang/kernels/attention.py#L47
The torch version is torch 2.3.
Hi, in my experiments I found that for the following case:
batch_size = 1
num_heads = 2
head_dim = 128
seq_len = 16

the lse returned by torch.ops.aten._scaled_dot_product_efficient_attention has a bug.
I suggest using _scaled_dot_product_flash_attention here instead, which returns the correct lse.

import torch

batch_size = 1
num_heads = 2
head_dim = 128
seq_len = 16

dtype = torch.float16
device = 'cuda'

torch.manual_seed(42)
query = torch.rand(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
key = torch.rand(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
value = torch.rand(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)

# Reference output from the public API.
out1 = torch.nn.functional.scaled_dot_product_attention(query, key, value)

# Flash-attention kernel: lse has shape (batch, heads, seq_len).
out2, lse2 = torch.ops.aten._scaled_dot_product_flash_attention(query, key, value)[:2]
print(f'lse2: {lse2}')

# Efficient-attention kernel: the returned lse is padded with inf past seq_len.
out3, lse3 = torch.ops.aten._scaled_dot_product_efficient_attention(
    query, key, value, attn_bias=None, compute_log_sumexp=True)[:2]
print(f'lse3: {lse3}')

print(f'Result: {torch.allclose(out1, out2, rtol=1e-3, atol=1e-3)}')
print(f'Result: {torch.allclose(out1, out3, rtol=1e-3, atol=1e-3)}')
print(f'Result: {torch.allclose(lse2, lse3, rtol=1e-3, atol=1e-3)}')

Output:
lse2: tensor([[[5.4660, 5.7685, 5.4396, 5.6371, 5.3691, 5.4923, 5.4666, 5.4459,
5.5439, 5.7281, 5.8045, 5.6622, 5.7923, 5.6987, 5.5620, 5.5473],
[5.4989, 5.8029, 5.5886, 5.5052, 5.6427, 5.5984, 5.7117, 5.4015,
5.6134, 5.5992, 5.4512, 5.8386, 5.8852, 5.3351, 5.6285, 5.6732]]],
device='cuda:0')

lse3: tensor([[[5.4660, 5.7685, 5.4396, 5.6371, 5.3691, 5.4923, 5.4666, 5.4459,
5.5439, 5.7281, 5.8045, 5.6622, 5.7923, 5.6987, 5.5620, 5.5473,
inf, inf, inf, inf, inf, inf, inf, inf,
inf, inf, inf, inf, inf, inf, inf, inf],
[5.4989, 5.8029, 5.5886, 5.5052, 5.6427, 5.5984, 5.7117, 5.4015,
5.6134, 5.5992, 5.4512, 5.8386, 5.8852, 5.3351, 5.6285, 5.6732,
inf, inf, inf, inf, inf, inf, inf, inf,
inf, inf, inf, inf, inf, inf, inf, inf]]],
device='cuda:0')

Result: True
Result: True
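The trailing inf values suggest that the efficient-attention kernel pads the last dimension of the lse up to a multiple of 32 (here 16 → 32). This is an observation from the output above, not documented behavior. A minimal sketch of a slicing workaround, using a plain Python list in place of the CUDA tensor so it runs without a GPU; with torch the equivalent would be `lse3 = lse3[..., :seq_len]`:

```python
# Hypothetical illustration: the efficient-attention kernel appears to pad the
# lse length up to a multiple of 32 (assumption based on the output above);
# slicing back to seq_len recovers the valid entries.

def pad_to_multiple(n, multiple=32):
    # Round n up to the nearest multiple (32 matches the observed padding).
    return ((n + multiple - 1) // multiple) * multiple

seq_len = 16
padded_len = pad_to_multiple(seq_len)  # 32, matching the 32-wide lse above

# Simulated padded lse row: valid values followed by inf padding.
lse_padded = [5.4660] * seq_len + [float("inf")] * (padded_len - seq_len)

# Workaround: keep only the first seq_len entries.
lse_valid = lse_padded[:seq_len]
assert float("inf") not in lse_valid
```

With the padding sliced off, the remaining values match the flash-attention lse to within tolerance, so the two kernels appear to agree on the valid region.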
