-
This repository is an unofficial re-implementation of Poly-encoders: Transformer Architectures and Pre-training Strategies for Fast and Accurate Multi-sentence Scoring.
-
Special thanks to sfzhou5678! Some of the data preprocessing (dataset.py) and training loop code is adapted from his GitHub repo. However, the model architecture and data representation in that repository do not follow the paper exactly, which leads to worse performance. I re-implemented the Bi-Encoder and Poly-Encoder models in encoder.py. In addition, I implemented the Cross-Encoder model and its data processing pipeline.
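As described in the paper, a Poly-Encoder summarizes the context with m learned codes: each code attends over the context token outputs to produce one global feature, the candidate embedding then attends over those m features, and the score is a dot product. The NumPy sketch below illustrates only this scoring step on already-encoded vectors. It is not the code in encoder.py, and the function names and array shapes are assumptions.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def poly_encoder_score(ctx_tokens: np.ndarray,
                       codes: np.ndarray,
                       cand_emb: np.ndarray) -> float:
    """Score one candidate against one context (hypothetical helper).

    ctx_tokens: (N, d) encoder outputs for the N context tokens
    codes:      (m, d) learned context codes c_1..c_m
    cand_emb:   (d,)   candidate embedding (e.g. its first-token output)
    """
    # Step 1: each code attends over the context tokens -> m global features.
    attn = softmax(codes @ ctx_tokens.T, axis=-1)   # (m, N)
    global_feats = attn @ ctx_tokens                # (m, d)
    # Step 2: the candidate attends over the m global features.
    weights = softmax(global_feats @ cand_emb)      # (m,)
    ctx_emb = weights @ global_feats                # (d,)
    # Step 3: the final score is a dot product with the candidate.
    return float(ctx_emb @ cand_emb)


rng = np.random.default_rng(0)
score = poly_encoder_score(rng.normal(size=(5, 8)),   # 5 context tokens, d=8
                           rng.normal(size=(16, 8)),  # 16 codes, as in --poly_m 16
                           rng.normal(size=8))
print(score)
```

Because candidates only enter in step 2, their embeddings can be precomputed, which is what makes the Poly-Encoder faster than a Cross-Encoder at inference time.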
-
Most of the training code in run.py is adapted from examples in the Hugging Face repository.
-
The most important architectural difference between this implementation and the original paper is that a single BERT encoder is shared by contexts and candidates (instead of two separate ones). Please refer to this issue for details. However, this should not affect performance much.
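With one shared encoder, the context and every candidate pass through the same weights, and a Bi-Encoder scores each candidate by the dot product of the two vectors. The sketch below illustrates this with a toy, hypothetical encoder in NumPy (mean-pooled embeddings followed by a tanh projection) standing in for BERT; the class and function names are invented for illustration, and the original paper would keep two separate parameter sets instead of one.

```python
from typing import List, Sequence

import numpy as np


class ToySharedEncoder:
    """Hypothetical stand-in for BERT: ONE set of weights for both sides."""

    def __init__(self, vocab_size: int, dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.embeddings = rng.normal(size=(vocab_size, dim))
        self.proj = rng.normal(size=(dim, dim)) / np.sqrt(dim)

    def encode(self, token_ids: Sequence[int]) -> np.ndarray:
        # Mean-pool the token embeddings, then apply a tanh projection -> (dim,)
        pooled = self.embeddings[list(token_ids)].mean(axis=0)
        return np.tanh(pooled @ self.proj)


def bi_encoder_scores(encoder: ToySharedEncoder,
                      context_ids: Sequence[int],
                      candidate_ids: List[Sequence[int]]) -> np.ndarray:
    """Score every candidate for one context using the same encoder."""
    ctx = encoder.encode(context_ids)                                # (d,)
    cands = np.stack([encoder.encode(c) for c in candidate_ids])     # (K, d)
    return cands @ ctx                                               # (K,)


enc = ToySharedEncoder(vocab_size=50, dim=16)
print(bi_encoder_scores(enc, [1, 2, 3], [[4, 5], [1, 2, 3], [7]]))
```

Sharing weights halves the encoder parameters, which matters on a small GPU; the trade-off is that contexts and candidates can no longer be encoded by specialized networks.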
-
This repository does not implement every detail of the original paper; for example, decaying the learning rate by a factor of 0.4 when performance plateaus is not implemented. Also, due to limited computing resources, I cannot use the exact settings from the paper, such as batch size or context length. In addition, a much smaller BERT model is used. Feel free to tune these settings or use a larger model if you have more computing resources.
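If you want to add the missing learning-rate schedule, the idea is to multiply the learning rate by 0.4 whenever a validation metric stops improving. The pure-Python sketch below shows that logic; the class name PlateauDecay and the patience and mode arguments are assumptions, not values from the paper or from run.py. In PyTorch, torch.optim.lr_scheduler.ReduceLROnPlateau with mode="max" and factor=0.4 provides the same behavior.

```python
class PlateauDecay:
    """Multiply the learning rate by `factor` when a metric stops improving.

    Hypothetical helper: `patience` counts how many non-improving
    evaluations are tolerated before decaying.
    """

    def __init__(self, lr: float, factor: float = 0.4, patience: int = 0,
                 mode: str = "max"):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.mode = mode          # "max" for R@1/MRR, "min" for a loss
        self.best = None
        self.bad_epochs = 0

    def step(self, metric: float) -> float:
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = metric > self.best
        else:
            improved = metric < self.best
        if improved:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor   # plateau detected: decay
                self.bad_epochs = 0
        return self.lr


sched = PlateauDecay(lr=5e-5)
for dev_r1 in [0.40, 0.42, 0.42, 0.43]:   # illustrative dev R@1 values
    print(sched.step(dev_r1))
```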
- Please see requirements.txt.
-
Download a pre-trained BERT model from Google.
-
Pick the model you like (I use uncased_L-4_H-512_A-8.zip), move it into bert_model/, and unzip it.
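Google's checkpoint names encode the model size: L is the number of Transformer layers, H the hidden size, and A the number of attention heads. For example, the model used here has 4 layers and hidden size 512, compared with 12 layers and hidden size 768 for BERT-base (uncased_L-12_H-768_A-12). The small helper below, which is hypothetical and not part of this repository, parses those fields out of a file name.

```python
import re
from typing import NamedTuple


class BertSize(NamedTuple):
    layers: int   # L: number of Transformer layers
    hidden: int   # H: hidden size
    heads: int    # A: number of attention heads


def parse_bert_name(name: str) -> BertSize:
    """Parse names such as 'uncased_L-4_H-512_A-8.zip'."""
    m = re.search(r"L-(\d+)_H-(\d+)_A-(\d+)", name)
    if m is None:
        raise ValueError(f"not a recognised BERT checkpoint name: {name!r}")
    return BertSize(*(int(g) for g in m.groups()))


print(parse_bert_name("uncased_L-4_H-512_A-8.zip"))
# BertSize(layers=4, hidden=512, heads=8)
```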
-
cd bert_model/ then bash run.sh
-
Download and unzip the Ubuntu data.
-
Rename valid.txt to dev.txt for consistency.
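The rename can also be scripted. The sketch below assumes the data was unzipped into a folder named ubuntu/; the function name rename_valid_to_dev is hypothetical, and it refuses to overwrite an existing dev.txt.

```python
from pathlib import Path


def rename_valid_to_dev(data_dir: str) -> Path:
    """Rename <data_dir>/valid.txt to <data_dir>/dev.txt and return the new path."""
    src = Path(data_dir) / "valid.txt"
    dst = Path(data_dir) / "dev.txt"
    if dst.exists():
        raise FileExistsError(dst)   # do not silently overwrite
    return src.rename(dst)           # raises FileNotFoundError if src is missing


# Usage, assuming the Ubuntu data lives in ubuntu/:
# rename_valid_to_dev("ubuntu")
```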
-
Download the subtask 1 data from the official competition site: the train (ubuntu_train_subtask_1.json), valid (ubuntu_dev_subtask_1.json), and test (ubuntu_responses_subtask_1.tsv, ubuntu_test_subtask_1.json) splits. Put them in the dstc7/ folder.
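Before running parse.sh, a quick check that all four files listed above are in place can save a failed run. This is a hypothetical helper, not part of the repository; the file names are exactly the ones listed above.

```python
from pathlib import Path
from typing import List

# The subtask 1 files listed above, expected in dstc7/ before parse.sh runs.
DSTC7_FILES = [
    "ubuntu_train_subtask_1.json",
    "ubuntu_dev_subtask_1.json",
    "ubuntu_responses_subtask_1.tsv",
    "ubuntu_test_subtask_1.json",
]


def missing_dstc7_files(folder: str = "dstc7") -> List[str]:
    """Return the names of required files that are not present in `folder`."""
    return [name for name in DSTC7_FILES if not (Path(folder) / name).is_file()]


missing = missing_dstc7_files()
print("all files present" if not missing else f"missing: {missing}")
```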
-
cd dstc7/ then bash parse.sh
-
This dataset setting does not work for the Cross-Encoder. For details, please refer to this issue.
-
Download the data from the ParlAI website and keep only ubuntu_train_subtask_1_augmented.json.
-
Move ubuntu_train_subtask_1_augmented.json into dstc7_aug/, then run python3 parse.py.
-
Copy dev.txt and test.txt from dstc7/ into dstc7_aug/, since only the training file is augmented.
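The copy step can be scripted with the standard library. A minimal sketch, assuming both folders sit in the repository root; the function name prepare_dstc7_aug is hypothetical.

```python
import shutil
from pathlib import Path


def prepare_dstc7_aug(src: str = "dstc7", dst: str = "dstc7_aug") -> None:
    """Copy the non-augmented dev and test splits next to the augmented train file."""
    Path(dst).mkdir(parents=True, exist_ok=True)
    for name in ("dev.txt", "test.txt"):
        # copy2 also preserves file metadata such as modification time
        shutil.copy2(Path(src) / name, Path(dst) / name)


# Usage: prepare_dstc7_aug()  # after running parse.sh in dstc7/
```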
-
You can refer to the original post discussing the construction of this augmented data.
-
Train a Bi-Encoder:
python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture bi
-
Train a Poly-Encoder with 16 codes:
python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture poly --poly_m 16
-
Train a Cross-Encoder:
python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture cross
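Unlike the Bi-Encoder and Poly-Encoder, the Cross-Encoder reads the context and one candidate together in a single BERT input, so candidate embeddings cannot be precomputed. The sketch below shows the standard BERT sentence-pair packing ([CLS] context [SEP] candidate [SEP], with segment ids 0 for the context part and 1 for the candidate part). The helper build_cross_input, its left-truncation policy, and max_len are hypothetical and are not taken from this repository's data pipeline.

```python
from typing import List, Tuple


def build_cross_input(ctx_tokens: List[str], cand_tokens: List[str],
                      max_len: int = 128, cls: str = "[CLS]",
                      sep: str = "[SEP]") -> Tuple[List[str], List[int]]:
    """Pack one (context, candidate) pair into a single BERT input.

    The context is truncated from the left so the most recent turns are
    kept (an illustrative choice, not necessarily what run.py does).
    """
    budget = max_len - len(cand_tokens) - 3   # room left after [CLS], [SEP], [SEP]
    if budget < 0:
        raise ValueError("candidate is too long for max_len")
    ctx = ctx_tokens[-budget:] if budget > 0 else []
    tokens = [cls] + ctx + [sep] + cand_tokens + [sep]
    segment_ids = [0] * (len(ctx) + 2) + [1] * (len(cand_tokens) + 1)
    return tokens, segment_ids


tokens, segments = build_cross_input(["hi", "my", "wifi", "is", "down"],
                                     ["try", "rebooting"], max_len=8)
print(tokens)
print(segments)
```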
-
To run experiments on the Ubuntu dataset, simply replace dstc7 with ubuntu in the directory names.
-
Evaluate the Bi-Encoder:
python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture bi --eval
-
Evaluate the Poly-Encoder with 16 codes:
python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture poly --poly_m 16 --eval
-
Evaluate the Cross-Encoder:
python3 run.py --bert_model bert_model/ --output_dir output_dstc7/ --train_dir dstc7/ --use_pretrain --architecture cross --eval
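To run several of these configurations in a row, the command lines above can be generated programmatically. The sketch below only uses the flags shown above; the helper name run_py_command is hypothetical, and the output_<dataset>/ naming for the Ubuntu runs is an assumption.

```python
import subprocess
from typing import List, Optional


def run_py_command(architecture: str, dataset: str = "dstc7",
                   poly_m: Optional[int] = None,
                   evaluate: bool = False) -> List[str]:
    """Build a run.py argument list matching the commands above."""
    if architecture not in ("bi", "poly", "cross"):
        raise ValueError(f"unknown architecture: {architecture}")
    cmd = ["python3", "run.py",
           "--bert_model", "bert_model/",
           "--output_dir", f"output_{dataset}/",   # assumed naming for other datasets
           "--train_dir", f"{dataset}/",
           "--use_pretrain",
           "--architecture", architecture]
    if poly_m is not None:
        cmd += ["--poly_m", str(poly_m)]
    if evaluate:
        cmd.append("--eval")
    return cmd


if __name__ == "__main__":
    # Train, then evaluate, a Poly-Encoder with 16 codes on DSTC7.
    for evaluate in (False, True):
        subprocess.run(run_py_command("poly", poly_m=16, evaluate=evaluate), check=True)
```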
-
All experiments were run on a single GTX 1080 GPU (8 GB memory) with an i7-6700K CPU @ 4.00GHz.
-
Default parameters in run.py are used; please refer to run.py for details.
-
The results are calculated on a sampled portion (1000 instances) of the dev set.
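The tables report Recall@k (R@k: the fraction of examples whose correct response is ranked among the top k candidates) and mean reciprocal rank (MRR). When a context has only 10 candidates, as in the Ubuntu data, R@10 is always 1.00. The NumPy sketch below computes both metrics, assuming exactly one correct candidate per example; the function name and the optimistic handling of ties are assumptions, and run.py may compute them differently.

```python
import numpy as np


def recall_at_k_and_mrr(scores: np.ndarray, gold: np.ndarray, ks=(1, 2, 5, 10)):
    """Compute R@k and MRR for a batch of ranking problems.

    scores: (num_examples, num_candidates) model scores
    gold:   (num_examples,) index of the correct candidate in each row
    """
    gold_scores = scores[np.arange(len(scores)), gold][:, None]   # (n, 1)
    # Rank of the gold candidate = 1 + number of candidates scored strictly
    # higher (ties are resolved in the gold candidate's favour).
    ranks = 1 + (scores > gold_scores).sum(axis=1)
    metrics = {f"R@{k}": float((ranks <= k).mean()) for k in ks}
    metrics["MRR"] = float((1.0 / ranks).mean())
    return metrics


scores = np.array([[0.9, 0.1, 0.5],
                   [0.2, 0.8, 0.3]])
print(recall_at_k_and_mrr(scores, np.array([0, 2]), ks=(1, 2)))
# {'R@1': 0.5, 'R@2': 1.0, 'MRR': 0.75}
```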
-
da = data augmentation. We report only one result with data augmentation (dstc7_aug), obtained with a Poly-Encoder using 64 poly codes and BERT-base (uncased_L-12_H-768_A-12). This result is very close to the numbers reported in the original paper.

Ubuntu:
Model | R@1 | R@2 | R@5 | R@10 | MRR |
---|---|---|---|---|---|
Bi-Encoder | 0.760 | 0.855 | 0.971 | 1.00 | 0.844 |
Poly-Encoder 16 | 0.766 | 0.868 | 0.974 | 1.00 | 0.851 |
Poly-Encoder 64 | 0.767 | 0.880 | 0.979 | 1.00 | 0.854 |
Poly-Encoder 360 | 0.754 | 0.858 | 0.970 | 1.00 | 0.842 |

DSTC7:
Model | R@1 | R@2 | R@5 | R@10 | MRR |
---|---|---|---|---|---|
Bi-Encoder | 0.437 | 0.524 | 0.644 | 0.753 | 0.538 |
Poly-Encoder 16 | 0.447 | 0.534 | 0.668 | 0.760 | 0.550 |
Poly-Encoder 64 | 0.438 | 0.540 | 0.668 | 0.755 | 0.546 |
Poly-Encoder 360 | 0.453 | 0.553 | 0.665 | 0.751 | 0.545 |
Cross-Encoder | 0.502 | 0.595 | 0.712 | 0.790 | 0.599 |
da + bert base | 0.561 | 0.659 | 0.765 | 0.858 | 0.659 |