Skip to content

Commit

Permalink
Merge pull request chenzomi12#260 from engin-work/master
Browse files Browse the repository at this point in the history
Update 03Extend.md
  • Loading branch information
chenzomi12 authored Jun 25, 2024
2 parents 543d2d1 + 9ea14f3 commit de23bcc
Show file tree
Hide file tree
Showing 14 changed files with 85 additions and 16 deletions.
101 changes: 85 additions & 16 deletions 04Inference/05Optimize/03Extend.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,41 +84,80 @@ Recomputation(重算)则是一种以计算能力为代价来节省存储空

主要的算法实现:

步骤一:计算分子块的大小
![flashAttention 算法方式](images/03Extend04.png)

### 步骤一:计算分子块的大小

首先,我们需要获取 GPU 硬件 SRAM 的大小,我们假设为 M。为了让 Q、K、V 在计算中可以存放在 SRAM 中,我们需要设定分块的大小尺寸。

其次,在 SRAM 上需要存在的数据包括,Q 子块,K 子块,V 子块,其次还应包括计算过程中的中间输出 O,O 的大小应该与 Q、K、V 子块大小一致。

所以,在这里我们计算出子块的列大小 Bc =[M/4d], d 为矩阵维度。
![flashAttention 算法方式](images/03Extend05.png)
1. Set block sizes B_c = [M / 4d], B_r = min([M / 4d],d)

所以,在这里我们计算出子块的列大小 Bc =[M/4d], d 为矩阵维度。当然,需要注意的是,上面的设置子块的大小并非唯一的,只有保证子块大小不超过SRAM的大小即可。

### 步骤二:初始化输出矩阵 O

步骤二:初始化输出矩阵 O
![flashAttention 算法方式](images/03Extend06.png)

SRAM 上的输出 O 矩阵赋值为全 0,它将作为一个累加器保存 softmax 的累积分母。
SRAM 上的输出 O 矩阵赋值为全 0,它将作为一个累加器保存 softmax 的累积分母,其中 l 也类似。m 用于记录每一行行最大分数,其初始化为-inf

步骤三:切分子块
### 步骤三:切分子块

![flashAttention 算法方式](images/03Extend07.png)

将 Q 划分成 Tr 个 Bolck,K、V 划分成 Tc 个 Block,初始化 attention output O,并划分成 Tr 个 Block。

步骤四:外循环加载 K、V 内循环加载 Q 子块
### 步骤四:外循环加载 K、V 内循环加载 Q 子块

![flashattention cycle](images/03Extend04.png)
![flashattention cycle](images/03Extend08.png)

如图所示:
上图完美解释了这个循环过程,

1. 外循环:对于每一个 Block Key 和 Value,从 HBM 加载进 SRAM
2. 内循环:对于每个 Block Query,从 HBM 加载进 SRAM
3. 在 SRAM 上完成 Block S 的计算

步骤五:实现分块 SoftMax 算法
![flashattention cycle](images/03Extend09.png)

这里要注意的是,Oi, li, mi其中存储的可能是上一个循环计算的中间结果。

### 步骤五:实现分块 SoftMax 算法

下面我们看看原版的证明公式:

假如有切片向量 x = [x^(1), x^(2)],切片后 softmax 的计算方式:
#### 标准 softmax 计算方式

![flashattention cycle](images/03Extend10.png)

在实际硬件中,因为浮点数表示的范围是有限的,对于 float32 和 bfloat16 来说,当 z ≥ 89 时,exp(z) 就会变成inf,发生数据上溢的问题。

为了确保数值计算的稳定性,避免溢出问题,通常采用一种称为“safe softmax”的计算策略。在此方法中,通过减去最大值来缩放输入数据,以保证数值的相对稳定性。

所以说,现有所有的深度学习框架中都采用了“safe softmax”这种计算方式,其计算公式如下:

![flashattention cycle](images/03Extend11.png)

计算举例:a = [0.1, 0.2, 0.3, 0.4]; m(a) = 0.4

则可以得到:f(a) = [e^(0.1-0.4), e^(0.2-0.4), e^(0.3-0.4), e^(0.4-0.4)]

然后计算得到:l(a) = Σf(a), softmax(a) = f(a) / l(a)

从上面可以看出,首先在分子上“safe softmax”需要获取当前区间的最大值来缩放输入数据,而在分母上需要累加所有的分子 f(a)。

由于 flashAttention 已经采取了分块计算的策略,也就意味着在计算 softmax 时,并不能拿到所有数据列的最大值和全部 f(a) 的和。

#### flashAttention 改进方式

![softmax 计算方式 1](images/03Extend05.png)
虽然softmax与K的列是耦合的,但如果分开计算每个子块的softmax再将最后的结果进行收集转换是否可以等价呢?下面我们看看原版的证明公式:

update m(x),根据更新后的 m(x),根据上一步计算结果重新计算 f(x), l(x)。假设存在 x^(3), 那么便可以将 x^(1)和 x^(2)合并成一个序列,重复步骤 1 即可。
1. 假如有切片向量x = [x^(1), x^(2)],切片后softmax 的计算方式:

![softmax 计算方式 1](images/03Extend12.png)

2. update m(x),根据更新后的 m(x),根据上一步计算结果重新计算 f(x), l(x)。假设存在 x^(3), 那么便可以将 x^(1)和 x^(2)合并成一个序列,重复步骤 1 即可。

计算举例:a = [0.1, 0.2, 0.3, 0.4] = [a1, a2]

Expand All @@ -128,17 +167,47 @@ update m(x),根据更新后的 m(x),根据上一步计算结果重新计算

最终得到 softmax(a) = f(a) / l(a)

通过上述的转换可知,softmax 与分块 softmax 是在数学上是等价的关系。不过由于真实计算中次数变多,精度上也可能存在一定丢失。
需要注意的是,可以利用GPU多线程同时并行计算多个block的softmax。

可见通过上述的转换可知,softmax 与分块 softmax是在数学上是等价的关系。不过由于真实计算中次数变多,精度上也可能存在一定丢失。

在介绍完 flashAttention 中 softmax 的改进后,我们继续围绕论文中的代码进行分析:

![softmax 计算方式 1](images/03Extend13.png)

首先,根据上一步计算的子块 Sij,来计算当前块的行最大值 mij,当前块 Pij (即 softmax 的分子),lij 为 Pij 的累积值。

其次,计算子块与子块间的最大值 m^new 和多个子块的 Pij 的累积值 l^new。

最后,根据 softmax 公式计算最终的 softmax,经结果写到 SRAM 的 Oi 中,并写出到 HBM,同时将最后的最后的l^new 赋值给 li 写出到 HBM,m^new 赋值到 mi 写出到 HBM,开始下一轮循环。

到此前向计算就算完成,我们可以通过下图来总结下flashAttention的前向计算过程,这里就不做过多解释了。

![softmax 计算方式 1](images/03Extend14.png)

步骤六:反向计算

从上面的前向过程中,我们知道前向过程中只将 Oi, li, mi 写出到了 HBM,而S和P的保存则主要在反向的重算中实现。

1. 前向过程会保留 Q,K,V,O, l, m 在 HBM 中,dO 由反向计算获取后,按照前向相同的分块模式重新分块。

2. 初始化 dQ,dK,dV 为全 0 矩阵,并按照对等 Q,K,V 的分割方式分割 dQ,dK,dV。

3. 分别从 HBM 中 Load K V block on SRAM,再 Load Q block on SRAM。根据前向过程重新计算对应 block 的 S 和 P;按分块矩阵的方式分别计算对应梯度,完成参数更新。

代码实现:https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py#L17
最终可以看到,在将三个kernel进行合并后,flashAttention v1实现了中间计算完全基于SRAM的目的。

### flashAttention性能分析

FlashAttention 节省的访存次数计算如下:

首先,K,V (Nxd)的每个block都需要Load 进SRAM,因此该过程的HBM访问次数为 O(Nxd)。

其次,Q 也需要分 block Load 进 SRAM,该过程一共持续外循环 Tc 次,因此该过程的 HBM 访问次数为 O(TcNd)

最后,而 Tc = N/(Bc) = 4Nd/M(向上取整)。因此 flash attention 的 HBM 访问次数为 O(N^2d^2/M)

下面来看 flash Attention 代码的具体实现:https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py#L17

```python
def _fwd_kernel(
Expand Down Expand Up @@ -232,15 +301,15 @@ Flash Attention 的优点在于充分考虑了在计算任务中 IO 的重要性

具体示例如下:

![数据节点转换](images/03Extend06.png)
![数据节点转换](images/03Extend15.png)

内存优化是一种计算机系统优化技术,主要目的是提高系统的运行性能,通过更有效地使用和管理内存资源来达到这个目的。

Inplace operation:是一种内存优化手段,它在当前的内存块上直接进行操作,而不需要额外开辟新的内存。如果一块内存不再需要,且下一个操作是 element-wise(元素级操作,比如加法、乘法等),我们就可以使用原地操作,直接在原内存上进行计算,覆盖原有的数据。这样做的好处是可以节省内存,减少内存的分配和回收开销,从而提高程序的运行效率。

Memory sharing:是另一种内存优化策略。它在内存使用上进行优化,当两个数据的内存大小相同,且有一个数据参与计算后不再需要时,我们可以让后一个数据直接覆盖前一个数据的内存。这样做的好处是可以减少内存的开销,节省内存空间,提高内存的使用效率。

![内存优化](images/03Extend07.png)
![内存优化](images/03Extend16.png)

## 小结与思考

Expand Down
Binary file modified 04Inference/05Optimize/images/03Extend04.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified 04Inference/05Optimize/images/03Extend05.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified 04Inference/05Optimize/images/03Extend06.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified 04Inference/05Optimize/images/03Extend07.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added 04Inference/05Optimize/images/03Extend08.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added 04Inference/05Optimize/images/03Extend09.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added 04Inference/05Optimize/images/03Extend10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added 04Inference/05Optimize/images/03Extend11.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added 04Inference/05Optimize/images/03Extend12.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added 04Inference/05Optimize/images/03Extend13.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added 04Inference/05Optimize/images/03Extend14.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added 04Inference/05Optimize/images/03Extend15.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added 04Inference/05Optimize/images/03Extend16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit de23bcc

Please sign in to comment.