FP32: https://en.wikipedia.org/wiki/Single-precision_floating-point_format
FP16: https://en.wikipedia.org/wiki/Half-precision_floating-point_format
BF16: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
FP8: https://en.wikipedia.org/wiki/Minifloat
FP8 in deep learning: https://arxiv.org/abs/2206.02915
The paper arxiv.org/abs/2206.02915 uses different exponent biases for weights and activations;
to keep things simple, we use the same bias for both.
1.5.2 has fewer mantissa bits (m_cnt), so its 2bit x 2bit mantissa multiply is cheap to implement on an FPGA (see the sketch below).
1.4.3 has more precision (one extra mantissa bit) than 1.5.2, but less dynamic range.
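Why the tiny mantissa matters for hardware: for normal 1.5.2 values the significand is 1.mm, i.e. an integer 4..7 in units of 1/4, so even with the implicit leading 1 the significand product is only a 3bit x 3bit integer multiply. Below is a minimal Python sketch of that idea (my own illustration, not code from the library; it ignores subnormals, rounding and saturation, and uses bias 15 as in the 1.5.2 table further down):

def decode_fp8_152(e, m, e_oft=15):
    # decode a normal (e > 0) fp8 1.5.2 value, same formula as the table generator below
    return (1.0 + m / 4) * 2.0 ** (e - e_oft)

def mul_fp8_152_normals(ea, ma, eb, mb, e_oft=15):
    # exact product of two normal 1.5.2 values using only tiny integer operations
    sa, sb = 4 + ma, 4 + mb       # significands in units of 1/4 (integers 4..7)
    s = sa * sb                   # 3bit x 3bit multiply, result fits in 6 bits (<= 49)
    e = ea + eb - 2 * e_oft       # unbiased exponent of the product
    return (s / 16) * 2.0 ** e    # /16 undoes the two 1/4 scalings

# sanity check against plain float multiplication
for ea, ma, eb, mb in [(15, 2, 16, 1), (9, 3, 20, 0)]:
    ref = decode_fp8_152(ea, ma) * decode_fp8_152(eb, mb)
    assert mul_fp8_152_normals(ea, ma, eb, mb) == ref

The product would of course still need normalizing and rounding back to fp8; the point is only that the multiplier itself stays tiny.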
Note that fp8 here has no INF/NaN encodings; out-of-range values are simply clamped (saturated) to the largest representable value.
Here are the fp8 value tables for both formats, generated by the following snippet:
# parameters for the 1.4.3 table; see the note below for the 1.5.2 setting
e_cnt, m_cnt, e_oft = 4, 3, 9

print("    ", end="")                      # pad the header to align with the "NN: " row labels
for m in range(1 << m_cnt):
    print("%02d\t" % m, end="")            # column header: mantissa code
print("")
for e in range(1 << e_cnt):                # one row per exponent code
    print("%02d: " % e, end="")
    for m in range(1 << m_cnt):
        if e == 0:                         # subnormal: no implicit leading 1
            data = m / (1 << m_cnt) * pow(2, -e_oft + 1)
        else:                              # normal: implicit leading 1
            data = (1.0 + m / (1 << m_cnt)) * pow(2, e - e_oft)
        print("%3.4f\t" % data, end="")
    print("")
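Running the snippet with (e_cnt, m_cnt, e_oft) = (4, 3, 9) reproduces the first table below (1.4.3), and (5, 2, 15) reproduces the second one (1.5.2). These offsets are the ones that match the printed values; they are not necessarily named constants in the implementation.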
fp8 1.4.3 table (rows = exponent code, columns = mantissa code):
    00      01      02      03      04      05      06      07
00: 0.0000 0.0005 0.0010 0.0015 0.0020 0.0024 0.0029 0.0034
01: 0.0039 0.0044 0.0049 0.0054 0.0059 0.0063 0.0068 0.0073
02: 0.0078 0.0088 0.0098 0.0107 0.0117 0.0127 0.0137 0.0146
03: 0.0156 0.0176 0.0195 0.0215 0.0234 0.0254 0.0273 0.0293
04: 0.0312 0.0352 0.0391 0.0430 0.0469 0.0508 0.0547 0.0586
05: 0.0625 0.0703 0.0781 0.0859 0.0938 0.1016 0.1094 0.1172
06: 0.1250 0.1406 0.1562 0.1719 0.1875 0.2031 0.2188 0.2344
07: 0.2500 0.2812 0.3125 0.3438 0.3750 0.4062 0.4375 0.4688
08: 0.5000 0.5625 0.6250 0.6875 0.7500 0.8125 0.8750 0.9375
09: 1.0000 1.1250 1.2500 1.3750 1.5000 1.6250 1.7500 1.8750
10: 2.0000 2.2500 2.5000 2.7500 3.0000 3.2500 3.5000 3.7500
11: 4.0000 4.5000 5.0000 5.5000 6.0000 6.5000 7.0000 7.5000
12: 8.0000 9.0000 10.0000 11.0000 12.0000 13.0000 14.0000 15.0000
13: 16.0000 18.0000 20.0000 22.0000 24.0000 26.0000 28.0000 30.0000
14: 32.0000 36.0000 40.0000 44.0000 48.0000 52.0000 56.0000 60.0000
15: 64.0000 72.0000 80.0000 88.0000 96.0000 104.0000 112.0000 120.0000
fp8 1.5.2 table (rows = exponent code, columns = mantissa code):
    00      01      02      03
00: 0.0000 0.0000 0.0000 0.0000
01: 0.0001 0.0001 0.0001 0.0001
02: 0.0001 0.0002 0.0002 0.0002
03: 0.0002 0.0003 0.0004 0.0004
04: 0.0005 0.0006 0.0007 0.0009
05: 0.0010 0.0012 0.0015 0.0017
06: 0.0020 0.0024 0.0029 0.0034
07: 0.0039 0.0049 0.0059 0.0068
08: 0.0078 0.0098 0.0117 0.0137
09: 0.0156 0.0195 0.0234 0.0273
10: 0.0312 0.0391 0.0469 0.0547
11: 0.0625 0.0781 0.0938 0.1094
12: 0.1250 0.1562 0.1875 0.2188
13: 0.2500 0.3125 0.3750 0.4375
14: 0.5000 0.6250 0.7500 0.8750
15: 1.0000 1.2500 1.5000 1.7500
16: 2.0000 2.5000 3.0000 3.5000
17: 4.0000 5.0000 6.0000 7.0000
18: 8.0000 10.0000 12.0000 14.0000
19: 16.0000 20.0000 24.0000 28.0000
20: 32.0000 40.0000 48.0000 56.0000
21: 64.0000 80.0000 96.0000 112.0000
22: 128.0000 160.0000 192.0000 224.0000
23: 256.0000 320.0000 384.0000 448.0000
24: 512.0000 640.0000 768.0000 896.0000
25: 1024.0000 1280.0000 1536.0000 1792.0000
26: 2048.0000 2560.0000 3072.0000 3584.0000
27: 4096.0000 5120.0000 6144.0000 7168.0000
28: 8192.0000 10240.0000 12288.0000 14336.0000
29: 16384.0000 20480.0000 24576.0000 28672.0000
30: 32768.0000 40960.0000 49152.0000 57344.0000
31: 65536.0000 81920.0000 98304.0000 114688.0000
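As noted above, there is no INF/NaN: anything at or above the largest code (120 for 1.4.3, 114688 for 1.5.2 in these tables) simply saturates to it, and values below the smallest subnormal flush to zero. Here is a rough encoder sketch for the 1.5.2 layout (my own illustration using the bias that reproduces the table above, not the library's actual converter):

import math

def encode_fp8_152(x, e_cnt=5, m_cnt=2, e_oft=15):
    # encode a float into (sign, exponent, mantissa) fields of fp8 1.5.2;
    # out-of-range values saturate to the largest code instead of becoming INF/NaN
    sign = 1 if x < 0 else 0
    x = abs(x)
    e_max = (1 << e_cnt) - 1                  # 31: the top exponent code is a normal value here
    m_max = (1 << m_cnt) - 1                  # 3
    max_val = (1.0 + m_max / (1 << m_cnt)) * 2.0 ** (e_max - e_oft)   # 114688 for 1.5.2
    if x >= max_val:
        return sign, e_max, m_max             # saturate ("cutoff") instead of INF
    if x == 0.0:
        return sign, 0, 0
    e = math.floor(math.log2(x)) + e_oft      # biased exponent
    if e <= 0:                                # subnormal range: no implicit leading 1
        m = min(round(x * (1 << m_cnt) / 2.0 ** (1 - e_oft)), m_max)
        return sign, 0, m
    m = round((x / 2.0 ** (e - e_oft) - 1.0) * (1 << m_cnt))
    if m > m_max:                             # rounding carried into the exponent
        m, e = 0, e + 1
    return sign, e, m

print(encode_fp8_152(3.1))        # -> (0, 16, 2), i.e. 3.0 in the table above
print(encode_fp8_152(200000.0))   # -> (0, 31, 3), saturated to 114688

Decoding a code back to a float is exactly the formula used by the table generator above.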
Tests pass on simple MNIST models, but still fail on MobileNet...
To try it, enter the example/mnist directory and set the following in tm_port.h:
#define TM_ARCH TM_ARCH_CPU
#define TM_OPT_LEVEL TM_OPT0
#define TM_MDL_TYPE TM_MDL_FP8_152
Compile and run; here is the result of the MNIST fp8 model:
mnist demo
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 36, 56,137,201,199, 95, 37, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45,152,234,254,254,254,254,254,250,211,151, 6, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 46,153,240,254,254,227,166,133,251,200,254,229,225,104, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0,153,234,254,254,187,142, 8, 0, 0,191, 40,198,246,223,253, 21, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 8,126,253,254,233,128, 11, 0, 0, 0, 0,210, 43, 70,254,254,254, 21, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 72,243,254,228, 54, 0, 0, 0, 0, 3, 32,116,225,242,254,255,162, 5, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 75,240,254,223,109,138,178,178,169,210,251,231,254,254,254,232, 38, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 9,175,244,253,255,254,254,251,254,254,254,254,254,252,171, 25, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 16,136,195,176,146,153,200,254,254,254,254,150, 16, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,162,254,254,241, 99, 3, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,118,250,254,254, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,100,242,254,254,211, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 54,241,254,254,242, 59, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,131,254,254,244, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13,249,254,254,152, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 12,228,254,254,208, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 78,255,254,254, 66, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,209,254,254,137, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,227,255,233, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0,113,255,108, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
================================ model stat ================================
mdl_type=5 (fp8 1.5.2))
out_deq=1
input_cnt=1, output_cnt=1, layer_cnt=7
input 3dims: (28, 28, 1)
output 1dims: (1, 1, 10)
main buf size 3528; sub buf size 0
//Note: PARAM is layer param size, include align padding
Idx Layer outshape inoft outoft PARAM MEMOUT OPS
--- Input 28, 28, 1 - 0 0 784 0
000 Conv2D 14, 14, 12 0 1176 128 2352 21168
001 Conv2D 7, 7, 24 1176 0 2616 1176 127008
002 Conv2D 4, 4, 48 0 2760 10416 768 165888
003 Conv2D 1, 1, 96 2760 0 73824 96 73728
004 Reshape 1, 1, 96 0 0 0 96 0
005 FC 1, 1, 10 0 3512 976 10 960
006 Softmax 1, 1, 10 3512 0 0 10 60
Total param ~85.9 KB, OPS ~0.39 MOPS, buffer 3.4 KB
===tm_run use 12.668 ms
0: 0.000
1: 0.000
2: 0.000
3: 0.000
4: 0.000
5: 0.000
6: 0.000
7: 0.000
8: 0.000
9: 1.000
### Predict output is: Number 9, prob 1.000