Chinese-CLIP训练现已支持通过FlashAttention加速训练进程。
- Turing、Ampere、Ada、Hopper架构的Nvidia GPU显卡(如H100、A100、RTX 3090、T4、RTX 2080),Nvidia各架构对应显卡型号可参见此文档表格。
- CUDA 11.4及以上版本。
- Pytorch 1.12及以上版本。
- FlashAttention:通过执行
pip install flash-attn
安装FlashAttention。
更多信息可参见FlashAttention项目仓库。
在Chinese-CLIP finetune中应用FlashAttention非常简单,只需要在finetune的sh脚本中加入--use-flash-attention
配置项即可。我们提供了样例脚本run_scripts/muge_finetune_vit-b-16_rbt-base_flashattn.sh
。
启用FlashAttention可在不影响效果的条件下为Chinese-CLIP的finetune过程显著提速以及降低显存占用。我们的实验在一台8卡A100 GPU(80GB显存)机器进行,FlashAttention 0.2.8,Pytorch 1.10.1。
我们分别列出finetune过程中,相同batch size下启用FlashAttention前后每个规模模型的FP16精度finetune的batch time和显存占用对比,可以看到启用FlashAttention后,训练速度有所提升,也更加节约显存。对于更大规模模型的训练速度提升和显存占用降低更为显著。
Batch Time | ||||
---|---|---|---|---|
单位: 秒/it | Batch size | w/o FlashAttention | w/ FlashAttention | Speedup |
CN-CLIPRN50 | 1200*8 | 1.710 | 1.680 | 1.02× |
CN-CLIPViT-B/16 | 450*8 | 1.477 | 0.960 | 1.54× |
CN-CLIPViT-L/14 | 128*8 | 1.293 | 0.785 | 1.65× |
CN-CLIPViT-L/14@336px | 40*8 | 1.397 | 0.587 | 2.38× |
CN-CLIPViT-H/14 | 64*8 | 1.265 | 0.845 | 1.50× |
显存 | ||||
---|---|---|---|---|
单位: GB | Batch size | w/o FlashAttention | w/ FlashAttention | |
CN-CLIPRN50 | 1200*8 | 79 | 75 | |
CN-CLIPViT-B/16 | 450*8 | 80 | 56 | |
CN-CLIPViT-L/14 | 128*8 | 77 | 50 | |
CN-CLIPViT-L/14@336px | 40*8 | 78 | 37 | |
CN-CLIPViT-H/14 | 64*8 | 76 | 57 |