Open
Description
可以考虑增加flashattention
1.flashattention (v1版本)实现要点
matmul和scaled的分块运算就跳过了
flashattention 的关键是softmax分块运算核心是延迟全局依赖计算:
- 分块计算局部量:对每个键块,计算局部分数、局部最大值 m_t 和局部指数和 Z_t ,仅存储归一化分数(分数-局部最大值)及 m_t、Z_t ,释放原始分数内存。
- 全局归约统计量:通过归约所有块的 m_t 得到全局最大值 m_i ,利用指数性质 \exp(a+b)=\exp(a)\exp(b) ,将各块 Z_t 转换为以 m_i 为基准的贡献,累加得到全局指数和 Z_i 。
- 块内收尾计算:用 m_i 和 Z_i 对每个块的归一化分数计算最终权重(\exp(分数-m_i)/Z_i),逐块输出结果,避免存储完整矩阵。通过分层处理,将 O(n^2) 内存依赖降为 O(nb) ,适配有限共享内存。