提供3大功能:
- LLM模型预训练:支持常见模型的预训练,包括:decoder结构(LLaMA、GPT)、encoder结构(GLM)
- LLM模型评测:参考GPT类模型,基于ZeroShot和FewShot实现
- ChatGPT模型训练pipeline:根据Learning to Summarize from human feedback ,实现3大流程: SFT、Reward Model和RLHF
- 支持RLHF阶段 (1) 联合优化reward和policy (2) 单独优化policy,冻结reward
- 支持DPO作为Reward+RLHF的替代方案,可显著降低显存占用,同时实现RL的效果
git clone https://github.com/microsoft/DeepSpeed.git
cd deepspeed
rm -rf build
TORCH_CUDA_ARCH_LIST="7.0" DS_BUILD_OPS=1 pip install -e . --global-option="build_ext" --global-option="-j8" --no-cache -v --disable-pip-version-check 2>&1 | tee build.log
如果想创建binary wheel,方便在其他机器上安装,可使用如下命令,会在dist
目录生成类似可安装文件deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl
git clone https://github.com/microsoft/DeepSpeed.git
cd deepspeed
rm -rf build
TORCH_CUDA_ARCH_LIST="7.0" DS_BUILD_OPS=1 python setup.py build_ext -j8 bdist_wheel 2>&1 | tee build.log
PS:需要根据下图,调整TORCH_CUDA_ARCH_LIST="7.0"
为自己对应的NVIDIA GPU架构
或运行torch.cuda.get_device_capability()
获取自己GPU的架构
在使用Pangu类模型的时候,其special_token格式为<sep>
、<pad>
等,而tokenization_gptpangu.py中tokenize()
函数会使用jieba
进行分词。但直接pip install jieba
,默认会将<
和>
直接切分开,使用jieba.add_word("<sep>")
也没有作用,因为jieba
直接hardcode了会自动切分的token,其中就包括了<
和>
。
因此需要执行:
git clone https://github.com/fxsjy/jieba.git
cd jieba
将代码clone到本地,修改jieba/__init__.py
中re_han_default
的取值,具体改动如下:
- 改动前:
re_han_default = re.compile("([\u4E00-\u9FD5a-zA-Z0-9+#&\._%\-]+)", re.U)
- 改动后:
re_han_default = re.compile("([\u4E00-\u9FD5a-zA-Z0-9+#&\._%\-<>]+)", re.U)
修改完成后使用pip install .
进行本地编译安装,替换原有jieba
。安装完成后,在代码中加入jieba.add_word("<sep>")
(该代码已加入tokenization_gptpangu.py),即可解决将<sep>
一类的special token切分为多个id的情况
git clone https://github.com/NVIDIA/apex
cd apex
pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check . 2>&1 | tee build.log
如果想创建binary wheel,方便在其他机器上安装,可使用如下命令,会在dist
目录生成类似可安装文件apex-0.0.1+7150e20-cp38-cp38-linux_x86_64.whl
git clone https://github.com/NVIDIA/apex
cd apex
python setup.py --cpp_ext --cuda_ext bdist_wheel 2>&1 | tee build.log
模型 | size | huggingface地址 | 百度网盘地址 | 提取码 |
---|---|---|---|---|
Pangu-350M | 659MB | sunzeyeah/pangu-350M | Pangu-350M | c5jj |
Pangu-2.6B | 9.8GB | sunzeyeah/pangu-2_6B | Pangu-2.6B | 2rad |
Pangu-13B | 23.6GB | sunzeyeah/pangu-13B | Pangu-13B | u3dx |
GLM-350M-chinese | 679MB | sunzeyeah/glm-350M-chinese | GLM-350M-chinese | ii8e |
GLM-10B-chinese | 18.4G | sunzeyeah/glm-10B-chinese | GLM-10B-chinese | fynj |
ChatGLM-6B | 25.6G | sunzeyeah/chatglm-6B | ChatGLM-6B | uq1k |
PS: 本repo提供的预训练模型下载中,
- 对于pytorch_model*.bin
- 如果源文件已包括,则不做改动
- 如果源文件不包括,则根据其提供的checkpoint转换为pytorch_model*.bin
- 其余文件可能相对原文件有改动,包括:modeling_*.py、tokenization_*.py、configuration_*.py、config.json和tokenizer.config
数据集 | size | huggingface地址 | 百度网盘地址 | 提取码 |
---|---|---|---|---|
CLUE Benchmark | 500MB | CLUE Benchmark | m6gt | |
SFT & Reward Data | 5GB | sunzeyeah/chinese_chatgpt_corpus | SFT & Reward Data | ecyc |
百科 | 652MB | baike_qa_2019 | 7jad | |
知道问答 | 847MB | zhidao | neds | |
对联 | 221MB | couplets | 54ey | |
古文 | 125MB | Classical & Modern | a4cr | |
古诗词 | 87MB | chinese poetry | 5zzj | |
微博新闻评论 | 522MB | weibo summary comments | w0g1 |
PS: SFT & Reward Data基于百科、知道问答、对联、古文、古诗词、微博新闻评论数据构造,可直接用于SFT和Reward阶段训练。详见data_prepare.py
对开源LLM进行增量预训练,基于deepspeed实现。目前支持2类模型架构:
- decoder结构:LLaMA、Baichuan、Pangu
- encoder结构:GLM、ChatGLM
cd examples
bash pretrain.sh
对开源中文LLM进行ZeroShot、OneShot或FewShot的评测。详见eval_pretrain.py 和 data.py。
目前支持的评测任务:
- C-Eval
- MMLU
- CLUEBenchmark :评测方法和prompt模板参考Pangu-alpha论文
目前支持的开源模型:
- LLaMA及相关衍生模型
- ChatGLM(1和2)
- Baichuan
- Qwen
- Pangu
- GLM
cd examples
bash eval_pretrain.sh
使用开源LLM + SFT&Reward数据进行SFT训练
cd examples
bash train_sft.sh
使用SFT模型 + SFT&Reward数据进行Reward模型训练
cd examples
bash train_reward.sh
利用PPO算法和Reward Model,进一步更新SFT模型。基于开源框架DeepSpeedChat 实现
cd examples
bash train_rlhf.sh
利用DPO算法替代Reward+RLHF的pipeline,免去训练Reward模型,同时达到RL训练的效果,该方法可显著降低显存占用。基于开源框架trl 实现
cd examples
bash train_dpo.sh
C-Eval 5-shot测试集(test)结果
Model | Avg | Avg(Hard) | STEM | Social Science | Humanities | Other |
Baichuan2-13B-Chat | 56.30 | 34.20 | 48.20 | 70.00 | 60.50 | 54.20 |
xverse-13B | 55.30 | 32.50 | 45.90 | 66.70 | 59.50 | 57.60 |
Qwen-7B-Chat | 54.70 | 35.40 | 47.90 | 68.30 | 58.70 | 50.00 |
Baichuan-13B-Base | 53.70 | 35.60 | 46.80 | 65.80 | 58.00 | 50.80 |
Baichuan2-7B-Chat | 52.50 | 33.80 | 45.70 | 64.20 | 56.60 | 50.20 |
ChatGLM2-6B | 51.20 | 33.40 | 46.90 | 63.00 | 51.60 | 47.70 |
Baichuan-13B-Chat | 47.90 | 31.50 | 41.40 | 56.80 | 53.00 | 46.50 |
Baichuan-7B | 44.20 | 31.70 | 39.20 | 53.30 | 47.30 | 41.90 |
Ziya-LLaMA-13B-v1.1 | 40.10 | 30.30 | 35.80 | 47.30 | 42.80 | 38.50 |
ChatGLM1.1-6B | 38.10 | 28.60 | 33.60 | 46.70 | 40.90 | 35.70 |
AtomGPT-13B-56k | 37.60 | 25.30 | 32.00 | 44.70 | 42.80 | 36.10 |
LLaMA2-13B-chat | 37.10 | 29.30 | 34.60 | 43.60 | 35.90 | 37.00 |
ChatGLM-6B | 36.30 | 27.20 | 32.90 | 42.80 | 38.10 | 34.90 |
LLaMA-30B | 35.90 | 29.90 | 34.40 | 42.40 | 33.30 | 35.60 |
LLaMA2-7B-chat | 33.50 | 27.30 | 31.60 | 38.10 | 33.80 | 32.70 |
Ziya-LLaMA-13B-Pretrain-v1 | 31.10 | 22.20 | 27.40 | 36.50 | 33.80 | 30.40 |
LLaMA-13B | 29.8 | 24.20 | 28.40 | 33.70 | 29.60 | 29.00 |
LLaMA-7B | 26.80 | 26.70 | 26.20 | 27.60 | 25.70 | 28.10 |
MMLU 5-shot测试集(test)结果
Model | Avg | STEM | Social Science | Humanities | Other |
Baichuan2-13B-Chat | 56.90 | 47.28 | 66.23 | 52.90 | 63.50 |
LLaMA-30B | 56.33 | 44.68 | 65.64 | 54.60 | 61.57 |
xverse-13B | 55.24 | 45.60 | 64.51 | 50.32 | 63.27 |
Qwen-7B-Chat | 54.13 | 41.76 | 63.43 | 50.81 | 62.50 |
LLaMA2-13B-chat | 53.98 | 44.52 | 63.40 | 49.37 | 61.21 |
Baichuan-13B-Base | 53.46 | 43.86 | 63.14 | 49.73 | 59.28 |
Baichuan2-7B-Chat | 53.11 | 43.51 | 62.26 | 49.58 | 59.12 |
Baichuan-13B-Chat | 51.12 | 41.61 | 59.11 | 47.52 | 58.31 |
Ziya-LLaMA-13B-v1.1 | 51.06 | 41.89 | 57.71 | 49.22 | 56.54 |
LLaMA2-7B-chat | 48.10 | 39.64 | 56.28 | 43.61 | 55.39 |
LLaMA-13B | 46.51 | 37.23 | 52.71 | 44.35 | 53.04 |
ChatGLM2-6B | 45.83 | 38.75 | 52.06 | 43.20 | 50.82 |
AtomGPT-13B-56k | 42.75 | 36.02 | 49.04 | 38.80 | 49.30 |
Baichuan-7B | 41.96 | 36.63 | 47.77 | 37.55 | 48.31 |
Ziya-LLaMA-13B-Pretrain-v1 | 41.61 | 33.61 | 46.01 | 39.85 | 48.05 |
ChatGLM1.1-6B | 40.07 | 32.95 | 44.55 | 39.23 | 44.12 |
ChatGLM-6B | 37.87 | 32.41 | 43.80 | 35.60 | 41.00 |
LLaMA-7B | 28.53 | 26.10 | 28.76 | 28.52 | 24.81 |
CLUEBenchmark 验证集(dev.json)结果
Dataset | Method | Metrics | Task Type | Zero-shot | Few-shot | ||||||||
GLM-350M-chinese | Pangu-350M | Pangu-2.6B | GLM-10B-chinese | Pangu-13B | GLM-350M-chinese | Pangu-350M | Pangu-2.6B | GLM-10B-chinese | Pangu-13B | ||||
OCNLI | PPL | acc | NLI | 0.3074 | 0.3369 | 0.3061 | 0.3288 | 0.3301 | 0.3298 | 0.3352 | 0.3216 | ||
CMNLI | PPL | acc | NLI | 0.3279 | 0.3302 | 0.3310 | 0.3338 | 0.3358 | 0.3356 | 0.3328 | 0.3300 | ||
CHID | PPL | acc | Cloze(multi-choices) | 0.0734 | 0.0916 | 0.0670 | 0.1016 | 0.1018 | 0.0979 | 0.1007 | 0.0996 | ||
CMRC2018 | generation | f1 | MRC | 0.093 | 0.0979 | 0.1007 | 0.1392 | 0.021 | 0.09345 | 0.097 | 0.1007 | ||
CLUEWSC2020 | PPL | acc | WSC | 0.4934 | 0.5328 | 0.5592 | 0.5131 | 0.4671 | 0.5526 | 0.4473 | 0.4671 | ||
C3 | PPL | acc | Common sense reasoning | 0.2360 | 0.2426 | 0.2418 | 0.2573 | 0.2567 | 0.2476 | 0.2559 | 0.2515 | ||
AFQMC | PPL | acc | Text classification | 0.6306 | 0.4582 | 0.4914 | 0.4960 | 0.5000 | 0.4872 | 0.4993 | 0.5018 | ||
CSL | PPL | acc | Text classification | 0.4943 | 0.4913 | 0.4666 | 0.5126 | 0.4996 | 0.5140 | 0.5036 | 0.4973 | ||
IFLYTEK | PPL | acc | Text classification | 0.1292 | 0.3058 | 0.265 | 0.2620 | 0.2408 | 0.2539 | 0.2535 | 0.2524 | ||
TNEWS | PPL | acc | Text classification | 0.1582 | 0.2022 | 0.2449 | 0.2489 | 0.2527 | 0.2555 | 0.2466 | 0.2494 |
模型训练参数:
模型 | 可训练参数量 | 数据量 | batch size | sequence length | 硬件 | 显存占用 | speed | Hours per epoch |
---|---|---|---|---|---|---|---|---|
GLM-350M-chinese | 355M | 5.4M | 4 | 512 | V100 16G | 13G | 3.7 s/iter | 88h |
Pangu-350M | 345M | 5.4M | 48 | 512 | A100 80G | 78G | 1.91 s/iter | 27.5h |
Pangu-2.6B | 2.6B | 5.4M | 8 | 512 | A100 80G | 79.4G | 9.61 s/iter | 116h |
SFT模型下载:
预训练模型 | SFT数据集 | size | huggingface地址 | 百度网盘地址 | 提取码 |
---|---|---|---|---|---|
Pangu-350M | SFT & Reward Data | 1.3GB | sunzeyeah/pangu-350M-sft | Pangu-350M-SFT | 0f8j |
Pangu-2.6B | SFT & Reward Data | 9.8GB | sunzeyeah/pangu-2_6B-sft | Pangu-2.6B-SFT | r13w |
Pangu-2.6B-SFT生成示例
# 生成参数
max_length = 512
max_length_generation = 100
top_k = 10
temperature = 0.8
do_sample=True
[Input] 我不能确定对方是不是喜欢我,我却想分分秒秒跟他在一起,有谁能告诉我如何能想他少一点<sep>回答:
[Generated] 你爱他就多关心他,少在他面前表现出你的脆弱。这样他才更会爱你。
[Input] 一次次买彩,一次次倾听那比分,一次次兴奋,一次次失落,一次次无归而返!最终还是任8~~~~<sep>回答:
[Generated] 我也是这样,买彩的时间也不长,但是每次都是这样。。。。。。。
[Input] 现代文:太祖很高兴地说:凭着这样的部下去作战,还担心不能取胜么?<sep>古文:
[Generated] 太祖悦曰:此自是以去,何忧不胜?
[Input] 现代文:十一年秋,丞相伯颜与阿术会师襄阳,遣阿剌罕率诸翼军攻郢、复诸州,十月夺取郢州之南门堡。<sep>古文:
[Generated] 十一年秋,丞相伯颜与阿术会师襄阳,遣阿剌罕率诸翼军攻郢、复诸州,十月夺郢州之南门堡。
[Input] 以孔⁇的风格,写一篇经书,题为中庸<sep>
[Generated] 我闻昔者夫子之言也,曰:吾道一以贯之。其或行之或不行,吾未知之。是故君子慎其独也。吾未见好德如好色者也。好德如好色者,未之有也。好德如好色者,既已之有也。君子不以言举人,不以人废言。好德如好色者,其如命何!吾
[Input] 以陆游的风格,写一首七言诗,题为晨起<sep>
[Generated] 秋光满纸清⁇明,睡醒一江万马奔。自笑身如枯木鹤,何妨食饱作闲人。
[Input] 上联:霜风几笔雕残画<sep>下联:
[Generated] 云海千山叠翠峰
[Input] 新闻内容:据悉,这名14岁的女生在Facebook上发布了戏谑考试的博文,其中使用了胡志明1946年反对法国殖民者的讲话,随后被越南当局停课一年。越南官员称,该学生歪曲了历史事件,侮辱了老师的情感。多数网友虽对越南当局做法不满,但都表示,至少他们还能上Facebook.<sep>摘要:越南女生Facebook上模仿胡志明讲话被停课<sep>评论:
[Generated] 这女的,还真有勇气,我喜欢你!
[Input] 新闻内容:习近平天津考察时勉励大学生志存高远、脚踏实地,转变择业观念,坚持从实际出发,勇于到基层一线和艰苦地方去,把人生的路一步步走稳走实,善于在平凡岗位上创造不平凡的业绩<sep>摘要:习近平勉励大学生转变择业观念勇于到基层和艰苦地方去<sep>评论:
[Generated] 现在的社会,很多年轻人都是在父母的安排下,去了不同的单位,在不同的地方工作,在不同的城市生活,这就导致了年轻人的思维方式,价值观都不一样,这也导致了年轻人的成长道路,也有很多的不安全感,这也是很正常的事情,年轻人需要正确的引导,正确的引导,也需要正确的引导!
[Input] 问题:劳动仲裁申请法院强制执行,法院强制执行需要多长时间<sep>回答:
[Generated] 一般申请法院强制执行需要3-6个月时间,具体需要看法院的执行措施。
[Input] 问题:银行贷款50万,30年!年利率是4.41怎么算一个月还多少钱<sep>回答:
[Generated] 若申请的是招行贷款,要计算贷款利息或者每月还款金额,需要知道贷款本金,贷款期限,还款方式,贷款年利率。若上述信息都确认到,可以通过我行贷款计算器尝试计算,登录招行官网右下方找到“理财计算器”--“个人贷款计算器”可以计算。
模型训练参数:
模型 | 可训练参数量 | 数据量 | batch size | sequence length | 硬件 | 显存占用 | speed | Hours per epoch |
---|---|---|---|---|---|---|---|---|
Pangu-350M | 131M | 12M | 32 | 512 | A100 80G | 72.6G | 1.91 s/iter | 105h |
Pangu-2.6B | 815M | 12M | 8 | 512 | A100 80G | 80.7G | 423h |
Reward模型下载:
预训练模型 | SFT数据集 | size | huggingface地址 | 百度网盘地址 | 提取码 |
---|---|---|---|---|---|
Pangu-350M | SFT & Reward Data | 1.3GB | sunzeyeah/pangu-350M-reward | Pangu-350M-Reward | 4gju |
To be updated
为验证不同预训练模型使用deepspeed的训练效率是否能达到官方宣称的效果(加速、节省GPU等),进行了benchmarking
- 实验场景:SFT阶段训练
- 实验参数:
max_sequence_length=512
DeepSpeed实验结果
模型 | 数据 | 整体耗时/epoch | 单条样本耗时 | 内存使用量 | 显存使用量 | GPU型号和数量 | fp16 | bf16 | deepspeed stage | offload optimizer | pin memory | offloard param | overlap comm | allgather bucket size | stage3 max live parameters | batch size | gradient accumulation steps | gradient checkpointing | model half |
T5-large | wmt16-en-ro, 共计61万条样本 | 43h | 0.5s/it | 7.1G | 1*14529MB | 1*V100 16G | true | - | - | - | - | - | - | - | - | 2 | 8 | false | false |
152h | 1.78s/it | 38.26G | 1*11663MB | 1*V100 16G | true | - | 2 | true | true | - | false | 2e8 | - | 2 | 8 | false | false | ||
250h | 2.95s/it | 38.74G | 1*7255MB | 1*V100 16G | true | - | 2 | true | true | - | false | 1e5 | - | 2 | 8 | false | false | ||
62h | 5.8s/it | 86.81G | 8*7811MB | 8*V100 16G | true | - | 2 | true | true | - | false | 1e5 | - | 2 | 8 | false | false | ||
- | - | - | OOM | 1*V100 16G | true | - | 2 | true | true | - | false | 2e8 | - | 16 | 8 | false | false | ||
- | - | - | OOM | 1*V100 16G | true | - | 2 | true | true | - | false | 1e5 | - | 16 | 8 | false | false | ||
290h | 3.48s/it | 46.53G | 1*6655MB | 1*V100 16G | true | - | 3 | true | true | true | false | 2e8 | 2e8 | 2 | 8 | false | false | ||
380h | 4.5s/it | 43.48G | 1*5263MB | 1*V100 16G | true | - | 3 | true | true | true | false | 1e5 | 1e5 | 2 | 8 | false | false | ||
215h | 4.9s/it | 47.31G | 2*5019MB | 2*V100 16G | true | - | 3 | true | true | true | false | 1e5 | 1e5 | 2 | 8 | false | false | ||
1370h | 64s/it | 57.55G | 4*4701MB | 4*V100 16G | true | - | 3 | true | true | true | false | 1e5 | 1e5 | 2 | 8 | false | false | ||
948h | 90s/it | 72.54G | 8*4585MB | 8*V100 16G | true | - | 3 | true | true | true | false | 1e5 | 1e5 | 2 | 8 | false | false | ||
Pangu-2.6B | SFT & Reward Data的验证集,共1万条样本 | 2h | 5.76s/it | 67.86G | 1*15631MB | 1*V100 16G | true | - | 2 | true | true | - | false | 2e8 | - | 2 | 8 | false | false |
2.1h | 6.15s/it | 67.88G | 1*15705MB | 1*V100 16G | true | - | 2 | true | true | - | false | 1e5 | - | 2 | 8 | false | false | ||
4.5h | 13.3s/it | 81.02G | 1*15449MB | 1*V100 16G | true | - | 3 | true | true | true | false | 2e8 | 2e8 | 2 | 8 | false | false | ||
11.5h | 8.2s/it | 75.89G | 1*15299MB | 1*V100 16G | true | - | 3 | true | true | true | false | 1e5 | 1e5 | 2 | 8 | false | false | ||
5.5h | 7.8s/it | 81.16G | 2*14851MB | 2*V100 16G | true | - | 3 | true | true | true | false | 1e5 | 1e5 | 2 | 8 | false | false | ||
6.2h | 18.3s/it | 97.31G | 4*14389MB | 4*V100 16G | true | - | 3 | true | true | true | false | 1e5 | 1e5 | 2 | 8 | false | false | ||
6.6h | 38s/it | 118.82G | 8*14335MB | 8*V100 16G | true | - | 3 | true | true | true | false | 1e5 | 1e5 | 2 | 8 | false | false | ||
ChatGLM-6B | SFT & Reward Data的验证集,共1万条样本 | - | - | 120.45G | OOM | 1*V100 16G | true | - | 2 | true | true | - | false | 1e5 | - | 1 | 8 | true | true |
- | - | 120.48G | OOM | 1*V100 16G | true | - | 2 | true | true | - | false | 1e3 | - | 1 | 8 | false | true | ||
- | - | 153.02G | OOM | 1*V100 16G | true | - | 3 | true | true | true | false | 1e2 | 1e2 | 1 | 8 | false | true | ||
- | - | 154G | OOM | 1*V100 16G | true | - | 3 | true | true | true | false | 2e8 | 2e8 | 1 | 8 | true | true | ||
21.2h | 60s/it | 154G | 1*10443MB | 1*V100 16G | true | - | 3 | true | true | true | false | 2e8 | auto | 1 | 8 | true | true | ||
21.5h | 60s/it | 152.81G | 1*10409MB | 1*V100 16G | true | - | 3 | true | true | true | false | 1e5 | 1e5 | 1 | 8 | true | true | ||
23.5h | 65s/it | 153.36G | 1*9229MB | 1*V100 16G | true | - | 3 | true | true | true | false | 1e3 | 1e3 | 1 | 8 | true | true | ||
14h | 80s/it | 158.21G | 2*8631MB | 2*V100 16G | true | - | 3 | true | true | true | false | 1e3 | 1e3 | 1 | 8 | true | true | ||
7.8h | 90s/it | 168.38G | 4*6743MB | 4*V100 16G | true | - | 3 | true | true | true | false | 1e3 | 1e3 | 1 | 8 | true | true | ||
4h | 90s/it | 189.34G | 8*6729MB | 8*V100 16G | true | - | 3 | true | true | true | false | 1e3 | 1e3 | 1 | 8 | true | true | ||
1h | 100s/it | 189.38G | 8*10047MB | 8*V100 16G | true | - | 3 | true | true | true | false | 1e3 | 1e3 | 4 | 8 | true | true | ||
50min | 40s/it | 189.39G | 8*14763MB | 8*V100 16G | true | - | 3 | true | true | true | false | 1e3 | 1e3 | 8 | 2 | true | true | ||
35min | 113s/it | 189.39G | 8*14763MB | 8*V100 16G | true | - | 3 | true | true | true | false | 1e3 | 1e3 | 8 | 8 | true | true | ||
- | - | 189.34G | OOM | 8*V100 16G | true | - | 3 | true | true | true | false | 1e3 | 1e3 | 10 | 8 | true | true | ||
GLM-10B-Chinese | SFT & Reward Data的验证集,共1万条样本 | - | - | - | OOM | 1*V100 16G | true | - | 3 | true | true | true | false | 2e8 | 2e8 | 1 | 8 | true | false |
- | - | - | OOM | 1*V100 16G | true | - | 3 | true | true | true | false | 2e8 | auto | 1 | 8 | true | false | ||
- | - | - | OOM | 1*V100 16G | true | - | 3 | true | true | true | false | 1e5 | 1e5 | 1 | 8 | true | false | ||
- | - | - | OOM | 1*V100 16G | true | - | 3 | true | true | true | false | 1e3 | 1e3 | 1 | 8 | true | false | ||
- | - | - | OOM | 1*V100 16G | true | - | 3 | true | true | true | false | 1e2 | 1e2 | 1 | 8 | true | false | ||
- | - | - | OOM | 2*V100 16G | true | - | 3 | true | true | true | false | 1e2 | 1e2 | 1 | 8 | true | false | ||
- | - | - | OOM | 4*V100 16G | true | - | 3 | true | true | true | false | 1e2 | 1e2 | 1 | 8 | true | false | ||
- | - | OOM | - | 8*V100 16G | true | - | 3 | true | true | true | false | 1e2 | 1e2 | 1 | 8 | true | false | ||
- | - | - | OOM | 4*V100 16G | true | - | 3 | true | true | true | false | 1e2 | 1e2 | 1 | 8 | true | true | ||
- | - | - | OOM | 6*V100 16G | true | - | 3 | true | true | true | false | 1e2 | 1e2 | 1 | 8 | true | true | ||
- | - | OOM | - | 8*V100 16G | true | - | 3 | true | true | true | false | 1e2 | 1e2 | 1 | 8 | true | true |
PS: deepspeed的参数介绍和调优经验,可参见DeepSpeed Configuration
为验证LoRA的训练效率提升,进行了benchmarking
- 实验场景:SFT阶段训练
- 实验数据:SFT & Reward Data的验证集,共1万条样本
- 实验参数:
max_sequence_length=512, lora_alpha=1, lora_train_bias='none'
LoRA实验结果
模型 | LoRA rank | 可训练参数量 | deepspeed | batch size | GPU型号和数量 | 显存使用量 | 单条样本耗时 | 整体耗时/epoch |
Pangu-2.6B | - | 2.6B | - | 8 | 1*A100 80G | 1*79421MB | 9.66s/it | 12.5min |
1000 | 1.5B | - | 8 | 1*A100 80G | 1*76129MB | 11.61s/it | 15min | |
500 | 758MB | - | 12 | 1*A100 80G | 1*77179MB | 16.2s/it | 14min | |
100 | 151MB | - | 16 | 1*A100 80G | 1*81103MB | 18.6s/it | 12min | |
50 | 75MB | - | 16 | 1*A100 80G | 1*80809MB | 17.8s/it | 11.5min | |
10 | 15MB | - | 16 | 1*A100 80G | 1*78735MB | 17.6s/it | 11.5min | |
100 | 151MB | stage=2, w offloading | 24 | 1*A100 80G | 1*76933MB | 25.5s/it | 11min | |
100 | 151MB | stage=3, w offloading | 24 | 1*A100 80G | 1*77259MB | 46.5s/it | 20min | |
ChatGLM-6B | - | 6.2B | - | 3 | 1*A100 80G | 1*79206MB | 6.7s/it | 23.5min |
1000 | 1.9B | - | 6 | 1*A100 80G | 1*78840MB | 12.8s/it | 22.5min | |
500 | 994MB | - | 6 | 1*A100 80G | 1*68832MB | 12.4s/it | 21.5min |