The codebase of the paper "Learning Light-Weight Translation Models from Deep Transformer", accepted at the AAAI 2021 conference.

Learning Light-Weight Translation Models from Deep Transformer

Bei Li, Ziyang Wang, Hui Liu, Quan Du, Tong Xiao, Chunliang Zhang, Jingbo Zhu. In Proceedings of AAAI, 2021. [paper]

GPKD Method on Fairseq

The GPKD method is built on Fairseq v0.6.2, the Transformer-based sequence modeling toolkit implemented by Facebook.

Runtime Environment

This system has been tested in the following environment.

  • Python version >=3.6
  • PyTorch version >= 1.0.0

Group-Permutation Training

First, we train the teacher network with the group-permutation training strategy, which rectifies the information flow.

For the --arch argument, group_ should be used as the prefix for the teacher network, e.g. --arch transformer_t2t_wmt_en_de -> --arch group_transformer_t2t_wmt_en_de

We can set the detailed arguments of the different architectures in group_transformer.py.

Example of the training script:

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
max_tokens=2048
data_dir=google
lr_1=
save_dir_1=
python3 -u train.py data-bin/$data_dir \
--distributed-world-size 8 -s en -t de \
--ddp-backend no_c10d \
--arch group_transformer_t2t_wmt_en_de \
--optimizer adam --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 16000 \
--lr $lr_1 --min-lr 1e-09 \
--weight-decay 0.0 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens $max_tokens \
--update-freq 4 \
--no-progress-bar \
--fp16 \
--adam-betas '(0.9, 0.997)' \
--log-interval 100 \
--share-all-embeddings \
--max-epoch 21 \
--save-dir $save_dir_1 \
--keep-last-epochs 5 \
--tensorboard-logdir $save_dir_1 > $save_dir_1/train.log

We can use the group-permutation method to train a teacher network with train.sh:

sh train.sh

Generating SKD Data

Given the training dataset { X, Y }, the teacher network translates the source inputs X into the targets Z. The sequence-level knowledge distillation (SKD) data is then the collection { X, Z }.

We can set the save directory of the teacher model and the other arguments in translation.sh to generate the targets of the SKD data:

sh translation.sh
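
The script translation.sh is expected to decode the training sources with the teacher checkpoint (e.g. via Fairseq's generate.py). As a rough illustration only (this helper is not part of the repository, and the file names are placeholders), the { X, Z } pairs can be recovered from a Fairseq generation log, whose S- and H- lines carry the source sentence and the teacher hypothesis for each sentence id:

# build_skd_data.py -- hypothetical helper, not part of the repository.
# Pairs each source sentence (S-<id> line) with the teacher hypothesis
# (H-<id> line) from a Fairseq generation log to form the SKD set { X, Z }.
import sys

def build_skd(log_path, src_out, tgt_out):
    sources, hyps = {}, {}
    with open(log_path, encoding="utf-8") as f:
        for line in f:
            if line.startswith("S-"):
                tag, text = line.rstrip("\n").split("\t", 1)
                sources[int(tag[2:])] = text
            elif line.startswith("H-"):
                tag, _score, text = line.rstrip("\n").split("\t", 2)
                hyps[int(tag[2:])] = text
    with open(src_out, "w", encoding="utf-8") as fs, \
         open(tgt_out, "w", encoding="utf-8") as ft:
        for i in sorted(hyps):
            fs.write(sources[i] + "\n")
            ft.write(hyps[i] + "\n")

if __name__ == "__main__":
    # e.g. python3 build_skd_data.py gen.log skd.en skd.de
    build_skd(sys.argv[1], sys.argv[2], sys.argv[3])

The resulting source/target files would then be binarized with preprocess.py in the usual way before student training.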

Student Training

First, we need to extract the student network from the teacher network with extract.py. We can set the indices of the encoder layers or decoder layers that will be extracted.

python3 extract.py teacher_dir/checkpoint_last.pt student_dir/checkpoint_last.pt
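
Conceptually, the extraction step keeps only the selected teacher layers and renumbers them to form a shallow student. The snippet below is not the repository's extract.py; it is a minimal sketch assuming the standard Fairseq checkpoint layout (weights stored under the model key with names like encoder.layers.<i>.* and decoder.layers.<i>.*), and the kept layer indices are placeholders:

# extract_sketch.py -- illustrative only; the actual logic lives in extract.py.
import torch

def extract_layers(teacher_ckpt, student_ckpt, enc_keep=(0, 2, 4), dec_keep=(0, 2, 4)):
    state = torch.load(teacher_ckpt, map_location="cpu")
    new_model = {}
    for key, value in state["model"].items():
        parts = key.split(".")
        if len(parts) > 2 and parts[0] in ("encoder", "decoder") and parts[1] == "layers":
            keep = enc_keep if parts[0] == "encoder" else dec_keep
            layer = int(parts[2])
            if layer not in keep:
                continue                       # drop layers not selected for the student
            parts[2] = str(keep.index(layer))  # renumber kept layers consecutively
            key = ".".join(parts)
        new_model[key] = value
    state["model"] = new_model
    torch.save(state, student_ckpt)

extract_layers("teacher_dir/checkpoint_last.pt", "student_dir/checkpoint_last.pt")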

Then we can fine-tune the student network on the SKD data with train.sh:

sh train.sh

Note that we need to reset --arch, --save-dir, and --max-epoch.

The architecture of the student model is the standard Transformer, so we reset --arch: --arch group_transformer_t2t_wmt_en_de -> --arch transformer_t2t_wmt_en_de

Skipping Sub-Layers Method

To further enhance the teacher model, we present a Skipping Sub-Layer method that randomly omits sub-layers to introduce perturbations into training.

For the --arch argument, skipping_sublayer_ should be used as the prefix for the teacher network, e.g. --arch transformer_t2t_wmt_en_de -> --arch skipping_sublayer_transformer_t2t_wmt_en_de

We can set the detailed arguments of the different architectures in skipping_sublayer_transformer.py.

We can use the Skipping Sub-Layer method to train a network with train.sh:

sh train.sh
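
The repository's implementation lives in skipping_sublayer_transformer.py; the snippet below is only a sketch of the underlying idea: at training time, each sub-layer (self-attention or feed-forward) is skipped with some probability, so that only its residual connection is kept for that forward pass.

# Sketch of stochastic sub-layer skipping (not the repository's implementation).
import torch
import torch.nn as nn

class StochasticSubLayer(nn.Module):
    """Wraps a residual sub-layer and skips it with probability p_skip during training."""

    def __init__(self, sublayer: nn.Module, p_skip: float = 0.2):
        super().__init__()
        self.sublayer = sublayer
        self.p_skip = p_skip

    def forward(self, x, *args, **kwargs):
        if self.training and torch.rand(1).item() < self.p_skip:
            return x                                  # skip: keep only the residual path
        return x + self.sublayer(x, *args, **kwargs)  # usual residual + sub-layer output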

Citation

Please cite as:

@article{li2020learning,
  title={Learning Light-Weight Translation Models from Deep Transformer},
  author={Li, Bei and Wang, Ziyang and Liu, Hui and Du, Quan and Xiao, Tong and Zhang, Chunliang and Zhu, Jingbo},
  journal={arXiv preprint arXiv:2012.13866},
  year={2020}
}



Fairseq


Fairseq(-py) is a sequence modeling toolkit that allows researchers and developers to train custom models for translation, summarization, language modeling and other text generation tasks.

Features:

Fairseq provides reference implementations of various sequence-to-sequence models. Additionally, it supports:

  • multi-GPU (distributed) training on one machine or across multiple machines
  • fast generation on both CPU and GPU with multiple search algorithms implemented (beam search, sampling, etc.)
  • large mini-batch training even on a single GPU via delayed updates
  • mixed precision training (trains faster with less GPU memory on NVIDIA tensor cores)
  • extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers

We also provide pre-trained models for translation and language modeling with a convenient torch.hub interface:

en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
en2de.translate('Hello world', beam=5)
# 'Hallo Welt'

See the PyTorch Hub tutorials for translation and RoBERTa for more examples.

Requirements and Installation

  • PyTorch version >= 1.4.0
  • Python version >= 3.6
  • For training new models, you'll also need an NVIDIA GPU and NCCL
  • For faster training install NVIDIA's apex library with the --cuda_ext and --deprecated_fused_adam options

To install fairseq:

pip install fairseq

On MacOS:

CFLAGS="-stdlib=libc++" pip install fairseq

If you use Docker make sure to increase the shared memory size either with --ipc=host or --shm-size as command line options to nvidia-docker run.

Installing from source

To install fairseq from source and develop locally:

git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable .

Getting Started

The full documentation contains instructions for getting started, training new models and extending fairseq with new model types and tasks.

Pre-trained models and examples

We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below, as well as example training and evaluation commands.

  • Translation: convolutional and transformer models are available
  • Language Modeling: convolutional and transformer models are available
  • wav2vec: wav2vec large model is available

We also have more detailed READMEs to reproduce results from specific papers.

License

fairseq(-py) is BSD-licensed. The license applies to the pre-trained models as well. We also provide an additional patent grant.
