Skip to content

cuda算子:flashattention #34

Open
@miaobyte

Description

@miaobyte

可以考虑增加flashattention

1.flashattention (v1版本)实现要点

matmul和scaled的分块运算就跳过了

flashattention 的关键是softmax分块运算核心是延迟全局依赖计算:

  •  分块计算局部量:对每个键块,计算局部分数、局部最大值 m_t 和局部指数和 Z_t ,仅存储归一化分数(分数-局部最大值)及 m_t、Z_t ,释放原始分数内存。
  • 全局归约统计量:通过归约所有块的 m_t 得到全局最大值 m_i ,利用指数性质 \exp(a+b)=\exp(a)\exp(b) ,将各块 Z_t 转换为以 m_i 为基准的贡献,累加得到全局指数和 Z_i 。
  •  块内收尾计算:用 m_i 和 Z_i 对每个块的归一化分数计算最终权重(\exp(分数-m_i)/Z_i),逐块输出结果,避免存储完整矩阵。通过分层处理,将 O(n^2) 内存依赖降为 O(nb) ,适配有限共享内存。

Metadata

Metadata

Assignees

No one assigned

    Labels

    todo待实现

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions