Exploring various text classification models based on PyTorch.
Text classification is one of the basic and classic tasks of NLP. This repository implements various text classification models based on PyTorch, so that anyone can easily learn how to build a text classification model and apply it to various datasets. Besides this repo, we also provide another repo, TCPapers, which collects worth-reading papers and related resources on text classification. Contributions of any kind are welcome!
Features of this repo:
- Supports various models such as FastText, TextCNN, TextRNN, TextRCNN, Attn-BiLSTM, Transformer, etc.
- Supports various attention mechanisms such as (scaled) dot-product, MLP, bilinear, and multi-head self-attention (MHSA); see the attention sketch after this list.
- Supports pretrained embeddings like word2vec, GloVe, the Tencent AILab Chinese Embedding, etc.
- Provides various preprocessed English and Chinese benchmark datasets such as SST-1, SST-2, TREC, etc.
- Supports multiple optimization methods such as Adam, SGD, Adadelta, etc.
- Supports multiple loss functions for text classification such as softmax cross-entropy, label smoothing, and Focal Loss; see the focal-loss sketch after this list.
- Supports multiple text classification tasks such as binary and multi-class classification.
- Supports various tricks like highway networks, position embeddings, customized features, etc.
- Supports pretrained language models like BERT, ELMo, and so on.
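As a quick illustration of the attention options above, here is a minimal sketch of scaled dot-product attention; it is not the repo's actual implementation, and the shapes and names are illustrative:

```python
import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, seq_len, dim); mask: 0 where positions should be ignored
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)  # attention distribution over keys
    return weights @ v, weights
```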
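Likewise, a minimal sketch of Focal Loss for multi-class classification; the gamma default and mean reduction are assumptions, not necessarily this repo's settings:

```python
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    # logits: (batch, num_classes); targets: (batch,) of class indices
    log_probs = F.log_softmax(logits, dim=-1)
    pt = log_probs.exp().gather(1, targets.unsqueeze(1)).squeeze(1)  # prob of true class
    ce = F.nll_loss(log_probs, targets, reduction="none")            # per-sample CE
    return ((1.0 - pt) ** gamma * ce).mean()  # down-weight easy examples
```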
Run `run.py` with the specified arguments to train a model.
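For example, a hypothetical invocation (only `--dataset` is documented here; other flags may be needed, so check the argument parser in `run.py`):

```
python run.py --dataset SST-2
```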
In this repository, we have preprocessed various well-known datasets such as SST-1, SST-2, and TREC.
dataset | avg length | #classes | #train | #val | task | download |
---|---|---|---|---|---|---|
SST-2 | 19 | 2 | 67349 | 872 | sentiment | link |
If you want to apply the code to your own dataset, please follow these steps:
- Reformat your data file so that each line looks like `{label}\t{tokens}` (see the loader sketch after this list).
- Split your dataset into a labeled train set and test set.
- Add the name and directories of your dataset in `main()` of `run.py` (from line 205).
- Change the value of the argument `--dataset` to your dataset name and then run the Python file.
- To get the best result, wait a few more steps when the accuracy fluctuates instead of early stopping.
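As a concrete illustration of the `{label}\t{tokens}` format, here is a minimal, hypothetical loader; the function name and label handling are assumptions, not the repo's actual code:

```python
# Each line of a data file is {label}\t{tokens}, e.g.
#   1\tthis movie is great
#   0\tthe plot makes no sense
def read_dataset(path):
    labels, texts = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:
                continue  # skip blank lines
            label, tokens = line.split("\t", 1)
            labels.append(label)
            texts.append(tokens.split())
    return labels, texts
```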
- Convolutional Neural Networks for Sentence Classification. Yoon Kim. (EMNLP 2014) [paper] - TextCNN
- Recurrent Neural Network for Text Classification with Multi-Task Learning. Pengfei Liu, Xipeng Qiu, Xuanjing Huang. (IJCAI 2016) [paper] - TextRNN
- Recurrent Convolutional Neural Networks for Text Classification. Siwei Lai, Liheng Xu, Kang Liu, Jun Zhao. (AAAI 2015) [paper] - TextRCNN
- Bag of Tricks for Efficient Text Classification. Armand Joulin, Edouard Grave, Piotr Bojanowski, Tomas Mikolov. (EACL 2017) [paper] - FastText
- Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification. Peng Zhou, Wei Shi, Jun Tian, Zhenyu Qi, Bingchen Li, Hongwei Hao, Bo Xu. (ACL 2016) [paper] - Attn-BiLSTM
You can specify your embedding file with the argument `--embed_file`. If its value is None, the model initializes the embedding matrix randomly.
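Below is a minimal sketch of loading such a file, assuming a GloVe-style text format (one word followed by its vector per line); the function name, dimension, and random-init range are illustrative assumptions, not the repo's actual code:

```python
import numpy as np
import torch

def build_embedding_matrix(vocab, embed_file=None, dim=300):
    # vocab: dict mapping word -> row index; rows default to random vectors
    matrix = np.random.uniform(-0.25, 0.25, (len(vocab), dim)).astype("float32")
    if embed_file is not None:
        with open(embed_file, encoding="utf-8") as f:
            for line in f:
                parts = line.rstrip().split(" ")
                word, values = parts[0], parts[1:]
                if word in vocab and len(values) == dim:
                    matrix[vocab[word]] = np.asarray(values, dtype="float32")
    return torch.from_numpy(matrix)  # e.g. for nn.Embedding.from_pretrained(...)
```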
- For non-BERT-based models, the training procedure is not very stable. Model performance is affected by many factors such as the initializer, pretrained embeddings, the learning rate scheduler, the random seed, and so on.
- Models are sensitive to where the dropout layer is applied. In our experience, the best practice is to apply it before the final dense layer or after the embedding layer (see the sketch after these tips). Try different positions to get the best results.
- We recommend setting the learning rate between 1e-2 and 1e-4 for non-BERT-based models and between 1e-4 and 1e-5 for BERT-based models.
- Because non-BERT-based models are unstable, you should train for many epochs and try different random seeds.
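To make the dropout tip concrete, here is an illustrative toy classifier (not the repo's actual model code) with dropout applied after the embedding layer and before the final dense layer:

```python
import torch.nn as nn

class AvgPoolClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_classes, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.embed_dropout = nn.Dropout(dropout)  # dropout after embedding
        self.fc_dropout = nn.Dropout(dropout)     # dropout before final dense layer
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, token_ids):
        x = self.embed_dropout(self.embedding(token_ids))  # (batch, seq, embed)
        x = x.mean(dim=1)                                  # simple average pooling
        return self.fc(self.fc_dropout(x))
```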
model | acc / F1 score (SST-2) |
---|---|
FastText | 0.7959 / 0.7959 |
TextCNN | 0.8612 / 0.8608 |
TextRNN | 0.8544 / 0.8541 |
TextRCNN | 0.8635 / 0.8635 |
Attn-BiLSTM | 0.8716 / 0.8715 |
We have only tested our code in the environment below.
- Python 3
- PyTorch 1.0+
- scikit-learn 0.23+
- numpy 1.18+
MIT