Transformer 模型的时候处理速度较慢,且会占用大量的显存。自注意力的时间和内存复杂度是与序列长度的平方成正比。在 GPU 上,计算速度已超过内存速度,并且 Transformers 中的大多数操作都受到内存访问的瓶颈。因此,内存访问模式的优化是加速 Transformer 模型的关键。Flash Attention 是一种高效的自注意力实现,它通过将内存访问模式与计算结合起来,减少了内存带宽的使用,从而提高了性能。
在这个系列中,我们将介绍 Flash Attention 系列的原理和实现。
老规矩,我们还是先回顾一下 GPU 的层次结构。GPU 内存层次结构由不同大小和速度的多种形式的内存组成,较小的内存速度较快。例如,A100 GPU 具有 40-80GB 的高带宽内存(HBM),带宽为 1.5-2.0TB/s,并且每个 108 个流处理器有 192KB 的片上 SRAM,其带宽估计约为 19TB/s [44, 45]。片上 SRAM 的速度比 HBM 快一个数量级,但其大小小了多个数量级。
给定输入序列
在标准的注意力机制实现中,矩阵
以一个具体实例来看,在 GPT-2 模型中,序列长度
FlashAttention 的核心思想可以用两个关键词来概括:分块计算 和 动态重计算。这两种技术的结合,使得注意力机制在保持高效的同时,显著减少了内存占用。
传统的 Softmax 需要一次性加载整个输入数据,才能计算全局的最大值和归一化系数。而 FlashAttention 采用了 增量式计算 的方式,将输入数据分成小块,依次加载到 GPU 的片上缓存(SRAM)中。
我们首先定义一些变量方便后续的讨论:
| 变量 | 尺寸(shape) | 说明 |
|---|---|---|
| 输入矩阵 | ||
|
|
||
|
|
||
| 局部注意力分数矩阵 | ||
| 局部行最大值向量 | ||
| 局部未归一化的注意力权重 | ||
| 局部行和向量 | ||
| 更新后的全局行最大值 | ||
| 更新后的全局行和 | ||
| 输出的第 |
首先,FlashAttention 将输入矩阵
对于每一块
为了在分块计算中保持数值稳定性,FlashAttention 维护两个全局统计量:行最大值
在更新输出矩阵
这一公式确保了输出结果与全局计算等价。在每一块计算完成后,更新后的
:::note
本文不探讨公式的具体推导过程,感兴趣的读者可以参考 [2]
:::
上图是 FlashAttention 的分块计算的示意图,外层循环中会对
这里我们以一个最简单的例子来说明更新的过程。
我们以 序列长度
-
$\mathbf{Q} \in \mathbb{R}^{4 \times 2}$ ,分为 2 块:$\mathbf{Q}_1 \in \mathbb{R}^{2 \times 2}$,$\mathbf{Q}_2 \in \mathbb{R}^{2 \times 2}$ (每块行数$B_r = 2$ )。 -
$\mathbf{K}, \mathbf{V} \in \mathbb{R}^{4 \times 2}$ ,分为 2 块:$\mathbf{K}_1, \mathbf{V}_1 \in \mathbb{R}^{2 \times 2}$,$\mathbf{K}_2, \mathbf{V}_2 \in \mathbb{R}^{2 \times 2}$ (每块行数$B_c = 2$ )。
初始状态下:
- 输出矩阵 $\mathbf{O} = \begin{bmatrix} 0 & 0 \ 0 & 0 \ 0 & 0 \ 0 & 0 \end{bmatrix}$。
- 全局统计量:$\ell = [0, 0, 0, 0]^T$,
$m = [-\infty, -\infty, -\infty, -\infty]^T$ 。
步骤 1:外层循环
-
加载
$\mathbf{K}_1$ ,$\mathbf{V}_1$ 到 SRAM:$$ \mathbf{K}1 = \begin{bmatrix} k{11} & k_{12} \ k_{21} & k_{22} \end{bmatrix}, \quad \mathbf{V}1 = \begin{bmatrix} v{11} & v_{12} \ v_{21} & v_{22} \end{bmatrix} $$
-
内层循环
$i=1$ ,处理块$\mathbf{Q}_1$ :- 加载数据: $$ \mathbf{Q}1 = \begin{bmatrix} q{11} & q_{12} \ q_{21} & q_{22} \end{bmatrix}, \quad \mathbf{O}_1 = \begin{bmatrix} 0 & 0 \ 0 & 0 \end{bmatrix}, \quad \ell_1 = [0, 0]^T, \quad m_1 = [-\infty, -\infty]^T $$
- 计算局部注意力分数: $$ \mathbf{S}{11} = \mathbf{Q}1 \mathbf{K}1^T = \begin{bmatrix} q{11}k{11} + q{12}k_{12} & q_{11}k_{21} + q_{12}k_{22} \ q_{21}k_{11} + q_{22}k_{12} & q_{21}k_{21} + q_{22}k_{22} \end{bmatrix} \in \mathbb{R}^{2 \times 2} $$
-
局部统计量:
- 逐行最大值 $\tilde{m}{11} = [\max(\mathbf{S}{11}[1,:]), \max(\mathbf{S}_{11}[2,:])]^T$。
- 未归一化注意力权重 $\tilde{\mathbf{P}}{11} = \exp(\mathbf{S}{11} - \tilde{m}_{11})$。
- 逐行和 $\tilde{\ell}{11} = [\text{sum}(\tilde{\mathbf{P}}{11}[1,:]), \text{sum}(\tilde{\mathbf{P}}_{11}[2,:])]^T$。
-
更新全局统计量:
- 全局最大值
$m_1^{\text{new}} = \max(m_1, \tilde{m}_{11})$ 。 - 全局行和 $\ell_1^{\text{new}} = e^{m_1 - m_1^{\text{new}}} \ell_1 + e^{\tilde{m}{11} - m_1^{\text{new}}} \tilde{\ell}{11}$。
- 全局最大值
- 更新输出: $$ \mathbf{O}1 \leftarrow \text{diag}(\ell_1^{\text{new}})^{-1} \left( \text{diag}(\ell_1) e^{m_1 - m_1^{\text{new}}} \mathbf{O}1 + e^{\tilde{m}{11} - m_1^{\text{new}}} \tilde{\mathbf{P}}{11} \mathbf{V}_1 \right) $$
-
写回 HBM:更新后的
$\mathbf{O}_1$ 对应前两行,$\ell_1$ 和$m_1$ 同步更新。
-
内层循环
$i=2$ ,处理块$\mathbf{Q}_2$ :- 类似地,加载 $\mathbf{Q}2 = \begin{bmatrix} q{31} & q_{32} \ q_{41} & q_{42} \end{bmatrix}$,计算
$\mathbf{S}_{21} = \mathbf{Q}_2 \mathbf{K}_1^T$ ,更新后两行$\mathbf{O}_2$ 。
- 类似地,加载 $\mathbf{Q}2 = \begin{bmatrix} q{31} & q_{32} \ q_{41} & q_{42} \end{bmatrix}$,计算
步骤 2:外层循环
-
加载
$\mathbf{K}_2$ ,$\mathbf{V}_2$ 到 SRAM: $$ \mathbf{K}2 = \begin{bmatrix} k{31} & k_{32} \ k_{41} & k_{42} \end{bmatrix}, \quad \mathbf{V}2 = \begin{bmatrix} v{31} & v_{32} \ v_{41} & v_{42} \end{bmatrix} $$ -
内层循环
$i=1$ ,处理块$\mathbf{Q}_1$ :-
加载数据:当前
$\mathbf{O}_1$ 已包含来自$\mathbf{V}_1$ 的贡献。 - 计算局部注意力分数: $$ \mathbf{S}{12} = \mathbf{Q}1 \mathbf{K}2^T = \begin{bmatrix} q{11}k{31} + q{12}k_{32} & q_{11}k_{41} + q_{12}k_{42} \ q_{21}k_{31} + q_{22}k_{32} & q_{21}k_{41} + q_{22}k_{42} \end{bmatrix} \in \mathbb{R}^{2 \times 2} $$
-
更新统计量:根据
$\mathbf{S}_{12}$ 的局部最大值和行和,更新$m_1^{\text{new}}$ 和$\ell_1^{\text{new}}$ 。 - 更新输出: $$ \mathbf{O}1 \leftarrow \text{diag}(\ell_1^{\text{new}})^{-1} \left( \text{diag}(\ell_1) e^{m_1 - m_1^{\text{new}}} \mathbf{O}1 + e^{\tilde{m}{11} - m_1^{\text{new}}} \tilde{\mathbf{P}}{12} \mathbf{V}_2 \right) $$
-
结果等价于全局 Softmax:最终
$\mathbf{O}_1$ 为前两行注意力结果的加权和。
-
加载数据:当前
-
内层循环
$i=2$ ,处理块$\mathbf{Q}_2$ :- 类似地,计算
$\mathbf{S}_{22} = \mathbf{Q}_2 \mathbf{K}_2^T$ ,更新后两行$\mathbf{O}_2$ 。
- 类似地,计算
通过这种分阶段、分块处理的方式,FlashAttention 在不牺牲计算精度的前提下,显著提升了注意力机制的效率,成为处理长序列任务的利器。
在反向传播阶段,传统的注意力机制需要存储前向传播生成的完整注意力矩阵,这进一步加剧了内存压力。FlashAttention 采用了 动态重计算 的策略:在前向传播中,只存储必要的中间结果(如最大值和归一化系数),而在反向传播时,按需重新计算注意力矩阵。
我们的文章里面展示只实现前向传播的计算,反向传播的详细过程可以参考 [2]。
[1] Andrei Ivanov, Nikoli Dryden, Tal Ben-Nun, Shigang Li, and Torsten Hoefler. Data movement is all you need: A case study on optimizing transformers. Proceedings of Machine Learning and Systems, 3:711–732, 2021 [2] https://zhuanlan.zhihu.com/p/669926191 [3] http://www.zh0ngtian.tech/posts/49b73eba.html


