Skip to content

Latest commit

 

History

History
98 lines (48 loc) · 4.14 KB

bert2textcnn模型蒸馏.md

File metadata and controls

98 lines (48 loc) · 4.14 KB

为什么需要做模型蒸馏?

Bert类模型精读高,但是推理速度慢,模型蒸馏可以在速度和精读之间做一个平衡。

  1. 从蒸馏方法

从蒸馏方法来看,一般可以分为三种:

  1. 参数的共享或者剪枝

  2. 低秩分解

  3. 知识蒸馏

对于1和2,可以参考一下 Albert。

而对于知识蒸馏来说,本质是通过一种映射关系,将老师学到的东西映射到或者说传递给学生网络。

在最开始的时候,一般都会有一种疑问? 我有训练数据了,训练数据的准确度肯定比你大模型的输出结构准确度高,为什么还需要从老师网络来学习知识?

我觉得对于这个问题,我在李如的文章看到这样一句话:”好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据“

我觉得写的很好。对于这个问题,我们这么去想,我们的大模型的输出对于logits不仅仅是类别属于哪一个,还有一个特点就是会给出不同类别之间的一个关系。

比如说,在预测”今天天气真不错,现在就决定了,出去浪一波,来顿烧烤哦“。

文本真实标签可能就直接给出了”旅游“这个标签,而我们的模型在概率输出的时候可能会发现”旅游“和”美食“两个标签都还行。

这就是模型从数据中学习到的一种”暗知识“(好像是这么叫,忘了在哪里看到了)、

而且还存在一个问题,有些时候是没有那么多训练数据的,需要的是大模型Bert这种给出无监督数据的伪标签作为冷启动也是不错的。

  1. 从蒸馏结构

从蒸馏结构来说,我们可以分为两种:

  1. 从transformer到transformer结构

  2. 从transformer结构到别的模型(CNN或者lstm结构)

我主要是想聊一下 Bert 到 TextCNN模型的蒸馏。

为啥选择textcnn?最大的原因就是速度快精读还不错。

论文参考 Distilling Task-Specific Knowledge from BERT into Simple Neural Networks

对于这个蒸馏,对于我而言,最重要的掌握一个点就是损失函数的设定,别的地方我暂且不考虑。

对于损失函数,分为两个部分,一个是我当前lstm输出结果和真实标签的交叉熵损失,一个是我的当前lstm输出结果和大模型bert的输出logits的平方损失。

至于为啥一个是交叉熵一个是平方损失,是因为其实前面的看做分类问题,后面的看做回归问题。当然只是谁更合适的选择问题。

因为是加权两个部分做损失,我这边选择为都是0.5。

当然在李如的文章中谈到,可能真实标签这边的权重小一点会更好一点,因为蒸馏本质上还是想多关注bert的输出多一点。

关于这个论文有一个很好的解释:

知识蒸馏论文选读(二) - 小禅心的文章 - 知乎 https://zhuanlan.zhihu.com/p/89420539

关于模型蒸馏,我就简单了解到这里,可能之后会花费大量精力看看背的蒸馏方式,放上开源代码:

bert到lstm的蒸馏

bert到textcnn/lstm/lkeras/torch

一个pytorch实现的模型蒸馏库

罗列一下关于Bert模型蒸馏的文章和博客:

首先一个讲的比较好的文章就是下面这个文章,比较系统的讲了一遍

BERT知识蒸馏综述 - 王三火的文章 - 知乎 https://zhuanlan.zhihu.com/p/106810758

还有一个文章讲的比较好的是:BERT 模型蒸馏 Distillation BERT

https://www.jianshu.com/p/ed7942b5207a

这个文章就是比较系统的对比了Bert的两个蒸馏操作:DistilBERT 和 Distilled BiLSTM 我觉得写得还不错

从实战的角度来说,我觉得写得很好的就是:BERT 蒸馏在垃圾舆情识别中的探索 https://blog.csdn.net/alitech2017/article/details/107412038

这个文章是对bert的蒸馏,到textcnn,使用了多种方式并且比较了最终的结果。

接下来是李如的这个文章,很概括,确实大佬,写得很好: 【DL】模型蒸馏Distillation - 李如的文章 - 知乎 https://zhuanlan.zhihu.com/p/71986772