Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
chenzomi12 committed Jun 19, 2024
2 parents 23943ad + 01ac355 commit 7a3ce1f
Show file tree
Hide file tree
Showing 27 changed files with 646 additions and 156 deletions.
2 changes: 1 addition & 1 deletion 00Others/Editors.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

AI 系统概述:[苏统华教授](https://homepage.hit.edu.cn/tonghuasu), [ZOMI 酱](https://github.com/chenzomi12)

AI 硬件体系架构:[苏统华教授](https://homepage.hit.edu.cn/tonghuasu), [ZOMI 酱](https://github.com/chenzomi12), [@刘军](https://github.com/AI-LJ), [@张晓天](), [@李明慧](https://github.com/xxx), [@张圣航](), [@刘纬恒](), [@李宗泽](), [@赵文千]()
AI 硬件体系架构:[苏统华教授](https://homepage.hit.edu.cn/tonghuasu), [ZOMI 酱](https://github.com/chenzomi12), [@刘军](https://github.com/AI-LJ), [@张晓天](), [@李明慧](https://github.com/xxx), [@张圣航](), [@刘纬恒](), [@李宗泽](https://freelulul.github.io/), [@赵文千]()

AI 编译原理:[苏统华教授](https://homepage.hit.edu.cn/tonghuasu), [ZOMI 酱](https://github.com/chenzomi12), [@史家兴](), [@宋一帆](https://github.com/sfs999), [@韩昊知](https://github.com/haozhihan), [@李行](), [@武震卿](), [@张浩](), [@陈庚天](),

Expand Down
103 changes: 67 additions & 36 deletions 03Compiler/03Frontend/07ConstantFold.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

# 常量折叠原理

常量折叠(Constant Folding)是编译器的一种优化技术,它通过在编译期间对常量表达式进行计算,将其结果替换为常量值,从而减少程序运行时的计算和开销。下面我们将分别介绍传统编译器和 AI 编译器的常量折叠优化。

============= 建议后面的案例统一用 python 或者 C++,不要引入 JAVA 或者 GO
常量折叠(Constant Folding)是编译器的一种优化技术,它通过在编译期间对常量表达式进行计算,将其结果替换为常量值,从而减少程序运行时的计算和开销。

## 传统编译器的常量折叠

Expand Down Expand Up @@ -37,11 +35,7 @@ dis.dis("day_sec=24*60*60")

上述的 CPython 的字节码表明,python 在对 day_sec 赋值是直接加载计算结果 86400,相比于 3 次载入数据和两次乘法,更加地高效。

常量折叠发生的条件:

- 必须是编译期常量之间进行运算才会进行常量折叠。

- 编译期常量就是“编译的时候就可以确定其值的常量”。即字面量(数字字面量,字符串字面量等等),常量传播可达的变量。
表达式e可进行常量折叠,当且仅当表达式e的所有子表达式都是常量。而子表达式被判断为常量通常需要常量传播的帮助。

举个例子:

Expand All @@ -59,25 +53,49 @@ int y = 7 - 14 / 2;
return y * (28 / 14 + 2);
```

对于 y = 7 - 14 / 2 这个表达式,所有参与计算的都是常量(字面量),所以会被常量折叠成 y = 0
对于 $7 - 14 /2$ 和 $28/14+2$ 这两个表达式,由于其所有的子表达式都是常量(字面值),所以这两个表达式可以进常量折叠优化,优化后得到

```c++
int x = 14;
int y = 0;
return y * (28 / 14 + 2);
return y * 4;
```

编译器再次对 y 进行常量传播以及对于 y * (28 / 14 + 2)进行常量折叠后得到:
编译器再次对 y 进行常量传播,将所有对y的可达引用都替换 0 得到:

```c++
int x = 14;
int y = 0;
return 0 * 4;
```

对于表达式 $0 * 4$,由于其所有的子表达式都是常量(字面值),所以这个表达式可以进行常量折叠优化,优化后得到:

```c++
int x = 14;
int y = 0;
return 0;
```

通过上述的例子可以得知一个表达式如果可以进行常量折叠,那么其参与计算的所有值最后都应该是字面值,因为常量变量可以通过常量传播转化为字面值。
由例子可见,常量传播对于常量折叠的重要性。在传统编译器中,常量传播主要是通过对控制流图(CFG)进行可达性分析,为每个基本块维护一个可达集合,记为$Reaches(n)$。其含义为若定义$d\in Reaches(n)$,则意味着存在一条从入口基本块到基本块 $b_n$的路径,d没有被重新定义。计算公式如下:

$$
Reaches(n) = \bigcup_{m\in preds(n)}(DEDef(m)\ \cup (Reaches(m)\ \cap\ \overline{DefKill(m)}))
$$

方程的初始条件为:$Reaches(n) = \emptyset$, $\forall n$

其中:

- $preds(n)$ 表示n的前趋结点集。
- $DEDef(m)$ 表示基本块 $b_m$ 中向下展示的定义,其含义为若定义$d\in DEDef(m)$,则意味着从d定义处到 $b_m$ 的出口处都没有被重新定义。
- $DefKill(m)$ 表示在基本块 $b_m$ 中被杀死的定义。其含义若定义$d\in DefKill(m)$,则意味着从d定义处到 $b_m$ 的出口处被重新定义。因此$\overline{DefKill(m)}$包含了m中可见的所有定义位置。

从公式上看,如果定义d在基本块的出口处是可达的,当且仅当定义d是基本块中向下展示的定义,或者定义d在基本块的入口处是可定义的,并且在基本块内没有被杀死。根据入口可达集合的定义,存在一条路径即可,所以定义d在基本块的入口处是可达的,只需要在其任意前趋结点的出口处是可达的即可。

不同的编译器对于编译期常量的定义不同,有些编译器需要特殊的关键字标识才会被认定为编译期常量。这将会影响一个表达式能否进行常量折叠的判断。比如下面 java 进行字符串计算的例子:
当已知基本块入口处的可达定义集合,对于基本块中某个定义引用,若从引用处到基本块的入口都没有重定义,且该定义引用在可达定义集合中,则可以用可达定义集合中的值替换该定义引用。如果有重定义,则用重定义的值替换该定义引用,从而达到传播的目的。

如果传播的定义为常量定义,则称常量传播。但不同的编译器对于常量的定义不同,有些编译器需要特殊的关键字标识才会被认定为常量。这将会影响一个表达式能否进行常量折叠的判断。比如下面 java 进行字符串计算的例子 :

```java
String a = "a";
Expand All @@ -95,32 +113,39 @@ String s2 = new StringBuilder(a).append(b).toString();

- 当常量表达式计算的结果溢出时,编译器不会进行常量折叠。比如 2 的 64 次幂会被折叠,而 4 的 64 次幂不会被折叠:

======= 格式参考我上面改的一样,输出不要用这个格式哦。
```python
>>> dis.dis("day_sec=2**64")
```

```python
>>> dis.dis("day_sec=2**64")
```text
1 0 LOAD_CONST 0 (18446744073709551616)
2 STORE_NAME 0 (day_sec)
4 LOAD_CONST 1 (None)
6 RETURN_VALUE
```
2 STORE_NAME 0 (day_sec)
4 LOAD_CONST 1 (None)
6 RETURN_VALUE
```

```python
>>> dis.dis("day_sec=4**64")
```

```python
>>> dis.dis("day_sec=4**64")
```text
1 0 LOAD_CONST 0 (4)
2 LOAD_CONST 1 (64)
4 BINARY_POWER
6 STORE_NAME 0 (day_sec)
8 LOAD_CONST 2 (None)
10 RETURN_VALUE
```
2 LOAD_CONST 1 (64)
4 BINARY_POWER
6 STORE_NAME 0 (day_sec)
8 LOAD_CONST 2 (None)
10 RETURN_VALUE
```

上面两图分别是 2 的 64 次幂和 4 的 64 次幂的 CPython 字节码,可以看出 2 的 64 次幂会被折叠成一个具体的数,而 4 的 64 次幂不会。

- 当进行字符串运算的时候,比如两个字符串相加,当且仅当字符串相加的结果大小小于等于 4096 时,该常量表达式才会被折叠,否者不会进行折叠:

```python
>>> dis.dis("day_sec='-'*4097")
```

```text
1 0 LOAD_CONST 0 ('-')
2 LOAD_CONST 1 (4097)
4 BINARY_MULTIPLY
Expand All @@ -131,11 +156,7 @@ String s2 = new StringBuilder(a).append(b).toString();

上面给出了结果大小为 4097 的字节码,可以看出并不会被折叠。而大小为 4096 会被折叠,由于会被折叠,得到的字符串是非常长的,这里就不给出例子了,读者可以自行尝试。

不同编译器的常量折叠的实现细节不尽相同,下面以 python 为例来描述传统编译器的常量折叠的一种实现。

在 python 中,CPython 会调用 astfold_expr 来对表达式进行常量折叠。astfold_expr 以递归的方式遍历 AST(抽象语法树),并尝试折叠所有的表达式。比如二值运算操作,astfold_expr 会先递归处理该二值操作的左操作数和右操作数,然后将此折叠操作代理给特定的折叠函数 fold_binop。在 fold_binop 中,首先会判断左操作数和右操作数是否都是常量,如果为常量,则判断该二值操作的具体操作类型,然后调用对应基本运算操作,比如 ADD 运算,会调用 PyNumber_Add。最后将计算出来的结果更新到 AST 对应的节点中。

## AI 编译器的常量折叠
## AI编译器的常量折叠

常量折叠作为传统编译器的一种优化技术,其迁移到 AI 编译器依然适用。传统编译器通常是对抽象语法树进行常量折叠优化,而 AI 编译器是对计算图进行常量折叠优化。AI 编译器会对计算图中的每个操作节点进行分析,判断其是否可进行常量折叠。如果可以,则通过计算得到结果替换该节点。

Expand All @@ -161,7 +182,17 @@ String s2 = new StringBuilder(a).append(b).toString();

与传统编译器相同,AI 编译器在进行常量折叠的时候也会被诸多因素所影响,比如 如果常量的大小(以字节为单位)太大,则不替换它。这可以防止图的大小变得过大。

不同的 AI 编译器对于常量折叠的细节不尽相同,这里以 tensorflow 为例,描述其常量折叠的实现细节:
## 实现案例

### 传统编译器实现案例

不同编译器的常量折叠的实现细节不尽相同,下面以python为例来描述传统编译器的常量折叠的一种实现。

在python中,CPython会调用astfold_expr来对表达式进行常量折叠。astfold_expr以递归的方式遍历AST(抽象语法树),并尝试折叠所有的表达式。比如二值运算操作,astfold_expr会先递归处理该二值操作的左操作数和右操作数,然后将此折叠操作代理给特定的折叠函数fold_binop。在fold_binop中,首先会判断左操作数和右操作数是否都是常量,如果为常量,则判断该二值操作的具体操作类型,然后调用对应基本运算操作,比如ADD运算,会调用PyNumber_Add。最后将计算出来的结果更新到AST对应的节点中。

### AI编译器实现案例

不同的AI编译器对于常量折叠的细节不尽相同,这里以tensorflow为例,描述其常量折叠的实现细节:

1. 获得逆后续节点集。使用反向深度优先搜索对计算图进行遍历,获取逆后续节点集。这样可以保证在处理当前节点时,其所有输入的节点都已经处理完毕了。

Expand All @@ -173,9 +204,9 @@ String s2 = new StringBuilder(a).append(b).toString();

5. 常量折叠。判断需要可常量折叠节点集是否为空,如果为空,则说明常量图中的所有常量节点无论替换与否都不会影响原图的计算,这些点通常可能是死代码或者冗余节点,会被后续的优化处理给删除掉,这种情况则不需要进行常量折叠;如果不为空,那么会将生成一个值与该节点的计算结果相同的常量数据节点来替换该节点。

![third test](images/07constant_fold03.png)
![third test](images/07constant_fold03.png)

如上图中,AddN2 的输出为 Conv(不在常量图中),tensorflow 会计算 AddN2 的输出结果,由于可以肯定其输入要么是编译期常量要么是可折叠常量节点,递归处理其所有输入节点即可。
如上图中,AddN2 的输出为 Conv(不在常量图中),tensorflow 会计算 AddN2 的输出结果,由于可以肯定其输入要么是编译期常量要么是可折叠常量节点,递归处理其所有输入节点即可。

## 小结与思考

Expand Down
Loading

0 comments on commit 7a3ce1f

Please sign in to comment.