# FP8 Introduction

## Basics

- FP32: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
- FP16: https://en.wikipedia.org/wiki/Half-precision_floating-point_format
- BF16: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
- FP8: https://en.wikipedia.org/wiki/Minifloat
- FP8 in deep learning: https://arxiv.org/abs/2206.02915

## FP8 config

The paper arxiv.org/abs/2206.02915 uses different exponent biases for weights and activations; to keep things simple, we use the same bias for both.
Formats are named sign.exponent.mantissa_bias, so 1.4.3_9 means 1 sign bit, 4 exponent bits, 3 mantissa bits, and exponent bias 9.
The 1.5.2 format has a smaller m_cnt, and its 2-bit × 2-bit mantissa multiply is cheap to compute on an FPGA; 1.4.3 gives finer precision than 1.5.2, but less dynamic range.
Note that this fp8 has no INF/NaN encodings; out-of-range values are simply cut off (saturated).
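
To make the "cut off" behavior concrete, here is a minimal encoder sketch for the 1.5.2_15 layout. The function name `fp8_152_encode` and the round-to-nearest choice are ours for illustration, not the library's actual conversion code; out-of-range values are clamped to the largest representable number instead of becoming INF.

```python
import math

def fp8_152_encode(x, e_cnt=5, m_cnt=2, e_oft=15):
    """Illustrative float -> fp8 (1.5.2_15) encoder; not library code."""
    sign = 1 if x < 0 else 0
    x = abs(x)
    # Largest representable value: e = 2^e_cnt - 1, m = 2^m_cnt - 1 (no INF/NaN).
    max_val = (1.0 + ((1 << m_cnt) - 1) / (1 << m_cnt)) * 2.0 ** ((1 << e_cnt) - 1 - e_oft)
    x = min(x, max_val)                       # "cut off": saturate instead of INF
    if x < 2.0 ** (1 - e_oft):                # subnormal range (biased exponent 0)
        e = 0
        m = int(round(x / 2.0 ** (1 - e_oft) * (1 << m_cnt)))
        if m == (1 << m_cnt):                 # rounded up into the smallest normal
            e, m = 1, 0
    else:                                     # normal range, implicit leading 1
        e_real = math.floor(math.log2(x))
        e = e_real + e_oft
        m = int(round((x / 2.0 ** e_real - 1.0) * (1 << m_cnt)))
        if m == (1 << m_cnt):                 # mantissa rounding overflow
            e, m = e + 1, 0
    return (sign << (e_cnt + m_cnt)) | (e << m_cnt) | m
```

For example, `fp8_152_encode(3.0)` yields biased exponent 16 and mantissa 2, which matches row 16, column 02 of the 1.5.2_15 table below.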

The fp8 value tables below are generated with this snippet:

print("      ", end="")
for m in range(1<<m_cnt):
    print("%02d\t"%m, end="")
print("")
for e in range(1<<e_cnt):
    print("%02d: "%e, end="")
    for m in range(1<<m_cnt):
        if e==0:
            data = m/(1<<m_cnt)*pow(2,-e_oft+1)
        else:
            data = (1.0+m/(1<<m_cnt))*pow(2,e-e_oft) 
        print("%3.4f\t"%(data), end="")
    print("")

### 1.4.3_9

```
    00      01      02      03      04      05      06      07
00: 0.0000	0.0005	0.0010	0.0015	0.0020	0.0024	0.0029	0.0034	
01: 0.0039	0.0044	0.0049	0.0054	0.0059	0.0063	0.0068	0.0073	
02: 0.0078	0.0088	0.0098	0.0107	0.0117	0.0127	0.0137	0.0146	
03: 0.0156	0.0176	0.0195	0.0215	0.0234	0.0254	0.0273	0.0293	
04: 0.0312	0.0352	0.0391	0.0430	0.0469	0.0508	0.0547	0.0586	
05: 0.0625	0.0703	0.0781	0.0859	0.0938	0.1016	0.1094	0.1172	
06: 0.1250	0.1406	0.1562	0.1719	0.1875	0.2031	0.2188	0.2344	
07: 0.2500	0.2812	0.3125	0.3438	0.3750	0.4062	0.4375	0.4688	
08: 0.5000	0.5625	0.6250	0.6875	0.7500	0.8125	0.8750	0.9375	
09: 1.0000	1.1250	1.2500	1.3750	1.5000	1.6250	1.7500	1.8750	
10: 2.0000	2.2500	2.5000	2.7500	3.0000	3.2500	3.5000	3.7500	
11: 4.0000	4.5000	5.0000	5.5000	6.0000	6.5000	7.0000	7.5000	
12: 8.0000	9.0000	10.0000	11.0000	12.0000	13.0000	14.0000	15.0000	
13: 16.0000	18.0000	20.0000	22.0000	24.0000	26.0000	28.0000	30.0000	
14: 32.0000	36.0000	40.0000	44.0000	48.0000	52.0000	56.0000	60.0000	
15: 64.0000	72.0000	80.0000	88.0000	96.0000	104.0000	112.0000	120.0000	
```
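
As a spot check of the 1.4.3_9 table: the entry at row 12, column 03 is (1 + 3/8) × 2^(12 − 9) = 1.375 × 8 = 11.0, and the subnormal row 00 steps in increments of 1/8 × 2^(1 − 9) ≈ 0.00049.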

### 1.5.2_15

```
    00      01      02      03
00: 0.0000	0.0000	0.0000	0.0000	
01: 0.0001	0.0001	0.0001	0.0001	
02: 0.0001	0.0002	0.0002	0.0002	
03: 0.0002	0.0003	0.0004	0.0004	
04: 0.0005	0.0006	0.0007	0.0009	
05: 0.0010	0.0012	0.0015	0.0017	
06: 0.0020	0.0024	0.0029	0.0034	
07: 0.0039	0.0049	0.0059	0.0068	
08: 0.0078	0.0098	0.0117	0.0137	
09: 0.0156	0.0195	0.0234	0.0273	
10: 0.0312	0.0391	0.0469	0.0547	
11: 0.0625	0.0781	0.0938	0.1094	
12: 0.1250	0.1562	0.1875	0.2188	
13: 0.2500	0.3125	0.3750	0.4375	
14: 0.5000	0.6250	0.7500	0.8750	
15: 1.0000	1.2500	1.5000	1.7500	
16: 2.0000	2.5000	3.0000	3.5000	
17: 4.0000	5.0000	6.0000	7.0000	
18: 8.0000	10.0000	12.0000	14.0000	
19: 16.0000	20.0000	24.0000	28.0000	
20: 32.0000	40.0000	48.0000	56.0000	
21: 64.0000	80.0000	96.0000	112.0000	
22: 128.0000	160.0000	192.0000	224.0000	
23: 256.0000	320.0000	384.0000	448.0000	
24: 512.0000	640.0000	768.0000	896.0000	
25: 1024.0000	1280.0000	1536.0000	1792.0000	
26: 2048.0000	2560.0000	3072.0000	3584.0000	
27: 4096.0000	5120.0000	6144.0000	7168.0000	
28: 8192.0000	10240.0000	12288.0000	14336.0000	
29: 16384.0000	20480.0000	24576.0000	28672.0000	
30: 32768.0000	40960.0000	49152.0000	57344.0000	
31: 65536.0000	81920.0000	98304.0000	114688.0000	
```
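
Going the other way, a small decoder sketch (mirroring the table formulas above; `fp8_decode` is our illustrative helper, not a library function) reconstructs a float from the packed sign/exponent/mantissa fields:

```python
def fp8_decode(b, e_cnt, m_cnt, e_oft):
    """Illustrative fp8 -> float decoder, using the same formulas as the table generator."""
    sign = (b >> (e_cnt + m_cnt)) & 1
    e = (b >> m_cnt) & ((1 << e_cnt) - 1)
    m = b & ((1 << m_cnt) - 1)
    if e == 0:   # subnormal: no implicit leading 1
        val = m / (1 << m_cnt) * 2.0 ** (1 - e_oft)
    else:        # normal: implicit leading 1
        val = (1.0 + m / (1 << m_cnt)) * 2.0 ** (e - e_oft)
    return -val if sign else val

# Spot checks against the tables above:
assert fp8_decode(0b0_1100_011, 4, 3, 9) == 11.0    # 1.4.3_9:  e=12, m=3
assert fp8_decode(0b0_10000_10, 5, 2, 15) == 3.0    # 1.5.2_15: e=16, m=2
```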

## FP8 model test

The test succeeds on simple MNIST models, but still fails on MobileNet.
You can enter the example/mnist directory and set the following in tm_port.h:

```c
#define TM_ARCH         TM_ARCH_CPU
#define TM_OPT_LEVEL    TM_OPT0
#define TM_MDL_TYPE     TM_MDL_FP8_152
```

Compile and run it; here is the mnist fp8 model result:

```
mnist demo
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 36, 56,137,201,199, 95, 37,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 45,152,234,254,254,254,254,254,250,211,151,  6,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0, 46,153,240,254,254,227,166,133,251,200,254,229,225,104,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,153,234,254,254,187,142,  8,  0,  0,191, 40,198,246,223,253, 21,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  8,126,253,254,233,128, 11,  0,  0,  0,  0,210, 43, 70,254,254,254, 21,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0, 72,243,254,228, 54,  0,  0,  0,  0,  3, 32,116,225,242,254,255,162,  5,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0, 75,240,254,223,109,138,178,178,169,210,251,231,254,254,254,232, 38,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  9,175,244,253,255,254,254,251,254,254,254,254,254,252,171, 25,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0, 16,136,195,176,146,153,200,254,254,254,254,150, 16,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,162,254,254,241, 99,  3,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,118,250,254,254, 90,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,100,242,254,254,211,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 54,241,254,254,242, 59,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,131,254,254,244, 64,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 13,249,254,254,152,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0, 12,228,254,254,208,  8,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0, 78,255,254,254, 66,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,209,254,254,137,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,227,255,233, 25,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,113,255,108,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
================================ model stat ================================
mdl_type=5 (fp8 1.5.2))
out_deq=1 
input_cnt=1, output_cnt=1, layer_cnt=7
input 3dims: (28, 28, 1)
output 1dims: (1, 1, 10)
main buf size 3528; sub buf size 0
//Note: PARAM is layer param size, include align padding

Idx     Layer            outshape       inoft   outoft  PARAM   MEMOUT OPS
---     Input            28, 28,  1     -       0       0       784     0
000     Conv2D           14, 14, 12     0       1176    128     2352    21168
001     Conv2D            7,  7, 24     1176    0       2616    1176    127008
002     Conv2D            4,  4, 48     0       2760    10416   768     165888
003     Conv2D            1,  1, 96     2760    0       73824   96      73728
004     Reshape           1,  1, 96     0       0       0       96      0
005     FC                1,  1, 10     0       3512    976     10      960
006     Softmax           1,  1, 10     3512    0       0       10      60

Total param ~85.9 KB, OPS ~0.39 MOPS, buffer 3.4 KB

===tm_run use 12.668 ms
0: 0.000
1: 0.000
2: 0.000
3: 0.000
4: 0.000
5: 0.000
6: 0.000
7: 0.000
8: 0.000
9: 1.000
### Predict output is: Number 9, prob 1.000
```
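
As a sanity check on the stats above, the OPS column appears to count one multiply-accumulate per output element per kernel tap: layer 001 produces 7 × 7 × 24 = 1176 outputs, each from a 3 × 3 × 12 = 108-tap window, giving 1176 × 108 = 127008 OPS, and the per-layer values sum to roughly 0.39 MOPS as reported.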