为什么需要做模型蒸馏?
Bert类模型精读高,但是推理速度慢,模型蒸馏可以在速度和精读之间做一个平衡。
- 从蒸馏方法
从蒸馏方法来看,一般可以分为三种:
-
参数的共享或者剪枝
-
低秩分解
-
知识蒸馏
对于1和2,可以参考一下 Albert。
而对于知识蒸馏来说,本质是通过一种映射关系,将老师学到的东西映射到或者说传递给学生网络。
在最开始的时候,一般都会有一种疑问? 我有训练数据了,训练数据的准确度肯定比你大模型的输出结构准确度高,为什么还需要从老师网络来学习知识?
我觉得对于这个问题,我在李如的文章看到这样一句话:”好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据“
我觉得写的很好。对于这个问题,我们这么去想,我们的大模型的输出对于logits不仅仅是类别属于哪一个,还有一个特点就是会给出不同类别之间的一个关系。
比如说,在预测”今天天气真不错,现在就决定了,出去浪一波,来顿烧烤哦“。
文本真实标签可能就直接给出了”旅游“这个标签,而我们的模型在概率输出的时候可能会发现”旅游“和”美食“两个标签都还行。
这就是模型从数据中学习到的一种”暗知识“(好像是这么叫,忘了在哪里看到了)、
而且还存在一个问题,有些时候是没有那么多训练数据的,需要的是大模型Bert这种给出无监督数据的伪标签作为冷启动也是不错的。
- 从蒸馏结构
从蒸馏结构来说,我们可以分为两种:
-
从transformer到transformer结构
-
从transformer结构到别的模型(CNN或者lstm结构)
我主要是想聊一下 Bert 到 TextCNN模型的蒸馏。
为啥选择textcnn?最大的原因就是速度快精读还不错。
论文参考 Distilling Task-Specific Knowledge from BERT into Simple Neural Networks
对于这个蒸馏,对于我而言,最重要的掌握一个点就是损失函数的设定,别的地方我暂且不考虑。
对于损失函数,分为两个部分,一个是我当前lstm输出结果和真实标签的交叉熵损失,一个是我的当前lstm输出结果和大模型bert的输出logits的平方损失。
至于为啥一个是交叉熵一个是平方损失,是因为其实前面的看做分类问题,后面的看做回归问题。当然只是谁更合适的选择问题。
因为是加权两个部分做损失,我这边选择为都是0.5。
当然在李如的文章中谈到,可能真实标签这边的权重小一点会更好一点,因为蒸馏本质上还是想多关注bert的输出多一点。
关于这个论文有一个很好的解释:
知识蒸馏论文选读(二) - 小禅心的文章 - 知乎 https://zhuanlan.zhihu.com/p/89420539
关于模型蒸馏,我就简单了解到这里,可能之后会花费大量精力看看背的蒸馏方式,放上开源代码:
bert到textcnn/lstm/lkeras/torch
罗列一下关于Bert模型蒸馏的文章和博客:
首先一个讲的比较好的文章就是下面这个文章,比较系统的讲了一遍
BERT知识蒸馏综述 - 王三火的文章 - 知乎 https://zhuanlan.zhihu.com/p/106810758
还有一个文章讲的比较好的是:BERT 模型蒸馏 Distillation BERT
https://www.jianshu.com/p/ed7942b5207a
这个文章就是比较系统的对比了Bert的两个蒸馏操作:DistilBERT 和 Distilled BiLSTM 我觉得写得还不错
从实战的角度来说,我觉得写得很好的就是:BERT 蒸馏在垃圾舆情识别中的探索 https://blog.csdn.net/alitech2017/article/details/107412038
这个文章是对bert的蒸馏,到textcnn,使用了多种方式并且比较了最终的结果。
接下来是李如的这个文章,很概括,确实大佬,写得很好: 【DL】模型蒸馏Distillation - 李如的文章 - 知乎 https://zhuanlan.zhihu.com/p/71986772