This code is based on Fairseq v0.6.2. Note that the summarization task requires a newer version (e.g., Fairseq v0.10.2); we will release that code soon.
- PyTorch version >= 1.2.0
- Python version >= 3.6
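If you are setting up an environment from scratch, here is a minimal sketch; the repository URL is a placeholder and the editable-install step is an assumption based on standard Fairseq-fork practice:

```bash
# Assumed setup: install PyTorch first, then this Fairseq v0.6.2 fork in editable mode.
pip install "torch>=1.2.0"
git clone <repository-url>   # placeholder; substitute this repository's URL
cd <repository-directory>
pip install -e .
```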
1. Download the WMT'14 En-De and WMT'14 En-Fr datasets
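Once the raw WMT data is tokenized and BPE-segmented, binarize it with Fairseq's preprocess.py. A minimal sketch, assuming a joined source/target dictionary and train/valid/test file prefixes (all paths here are assumptions):

```bash
# Hypothetical paths; point these at your tokenized, BPE-applied WMT'14 files.
python3 preprocess.py -s src -t tgt \
  --trainpref wmt14_en_de/train --validpref wmt14_en_de/valid --testpref wmt14_en_de/test \
  --destdir data-bin/wmt-en2de \
  --joined-dictionary
```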
1. Download the CNN and Daily Mail datasets
bash preprocess_cnndaily_bin.sh path/to/cnndm_raw_data
1. Download the FCE v2.1, Lang-8 Corpus of Learner English, NUCLE, and W&I+LOCNESS v2.1 datasets
bash prepare_conll14_test_data.sh
bash preprocess_gec.sh
bash preprocess_gec_bin.sh
bash train_wmt_en_de.sh
python3 -u train.py data-bin/$data_dir \
--distributed-world-size 8 -s src -t tgt \
--arch transformer_ode_t2t_wmt_en_de_big \
--share-all-embeddings \
--optimizer adam --clip-norm 0.0 \
--adam-betas '(0.9, 0.997)' \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 16000 \
--lr 0.002 --min-lr 1e-09 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 4096 \
--update-freq 4 \
--max-epoch 20 \
--dropout 0.3 --attention-dropout 0.1 --relu-dropout 0.1 \
--no-progress-bar \
--log-interval 100 \
--ddp-backend no_c10d \
--seed 1 \
--save-dir $save_dir \
--keep-last-epochs 10
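Since `--keep-last-epochs 10` retains the last ten epoch checkpoints, you can average several of them before decoding, which typically gives a small BLEU gain. A sketch using Fairseq's stock averaging script (averaging the last 5 checkpoints is an assumption, not a value prescribed by this repo):

```bash
# Average the last 5 epoch checkpoints into a single model for decoding.
python3 scripts/average_checkpoints.py \
  --inputs $save_dir \
  --num-epoch-checkpoints 5 \
  --output $save_dir/checkpoint.avg5.pt
```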
bash train_wmt_en_fr.sh
python3 -u train.py data-bin/$data_dir \
--distributed-world-size 8 -s src -t tgt \
--arch transformer_ode_t2t_wmt_en_de_big \
--share-all-embeddings \
--optimizer adam --clip-norm 0.0 \
--adam-betas '(0.9, 0.997)' \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 16000 \
--lr 0.002 --min-lr 1e-09 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 4096 \
--update-freq 8 \
--max-epoch 20 \
--dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
--no-progress-bar \
--log-interval 100 \
--ddp-backend no_c10d \
--seed 1 \
--save-dir $save_dir \
--keep-last-epochs 10
bash train_cnn_daily.sh
python3 -u train.py data-bin/$data_dir \
--distributed-world-size 8 -s src -t tgt \
--arch transformer_ode_t2t_wmt_en_de \
--share-all-embeddings \
--optimizer adam --clip-norm 0.0 \
--adam-betas '(0.9, 0.997)' \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 8000 \
--lr 0.002 --min-lr 1e-09 \
--weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 4096 \
--update-freq 4 \
--max-epoch 20 \
--dropout 0.1 --attention-dropout 0.1 --relu-dropout 0.1 \
--truncate-source --skip-invalid-size-inputs-valid-test --max-source-positions 500 \
--no-progress-bar \
--log-interval 100 \
--ddp-backend no_c10d \
--seed 1 \
--save-dir $save_dir \
--keep-last-epochs 10
bash train_gec.sh
python3 -u train.py data-bin/$data_dir \
--distributed-world-size 8 -s src -t tgt \
--arch transformer_ode_t2t_wmt_en_de \
--share-all-embeddings \
--optimizer adam --clip-norm 0.0 \
--adam-betas '(0.9, 0.98)' \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
--lr 0.0015 --min-lr 1e-09 \
--weight-decay 0.0001 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--max-tokens 4096 \
--update-freq 2 \
--max-epoch 55 \
--dropout 0.2 --attention-dropout 0.1 --relu-dropout 0.1 \
--no-progress-bar \
--log-interval 100 \
--ddp-backend no_c10d \
--seed 1 \
--save-dir $save_dir \
--keep-last-epochs 10 \
--tensorboard-logdir $save_dir
We measure performance with multi-bleu and sacrebleu.
python3 generate.py \
data-bin/wmt-en2de \
--path $model_dir/$checkpoint \
--gen-subset test \
--batch-size 64 \
--beam 4 \
--lenpen 0.6 \
--output hypo.txt \
--quiet \
--remove-bpe
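A scoring sketch for the two metrics; the reference file names are assumptions. Note that multi-bleu.perl expects tokenized hypotheses and references, while sacrebleu scores detokenized output against its built-in test set:

```bash
# Tokenized BLEU with the Moses multi-bleu script (script and reference paths are assumptions).
perl multi-bleu.perl path/to/newstest2014.tok.de < hypo.txt
# Detokenized BLEU with sacrebleu on the full WMT14 En-De test set.
sacrebleu -t wmt14/full -l en-de < hypo.txt
```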
We measure performance with multi-bleu and sacrebleu.
python3 generate.py \
data-bin/wmt-en2fr \
--path $model_dir/$checkpoint \
--gen-subset test \
--batch-size 64 \
--beam 4 \
--lenpen 0.6 \
--output hypo.txt \
--quiet \
--remove-bpe
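The same scoring recipe applies here with the language pair switched, e.g. (still assuming detokenized hypotheses in hypo.txt):

```bash
sacrebleu -t wmt14/full -l en-fr < hypo.txt
```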
We use pyrouge as the scoring script.
python3 generate.py \
data-bin/$data_dir \
--path $model_dir/$checkpoint \
--gen-subset test \
--truncate-source \
--batch-size 32 \
--lenpen 2.0 \
--min-len 55 \
--max-len-b 140 \
--max-source-positions 500 \
--beam 4 \
--no-repeat-ngram-size 3 \
--remove-bpe
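generate.py prints hypotheses as H-&lt;id&gt; lines in shuffled batch order, so hypo.sorted.tok must first be restored to test-set order. One common recipe, assuming you redirected the generation output to gen.out (the log file name is an assumption):

```bash
# Extract hypothesis lines, sort them back into test-set order by id, keep only the text.
grep ^H gen.out | sed 's/^H-//' | sort -n | cut -f 3 > $model_dir/hypo.sorted.tok
```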
python3 get_rouge.py --decodes_filename $model_dir/hypo.sorted.tok --targets_filename cnndm.test.target.tok
We use m2scorer as the scoring script.
python3 generate.py \
data-bin/$data_dir \
--path $model_dir/$checkpoint \
--gen-subset test \
--batch-size 64 \
--beam 4 \
--lenpen 2.0 \
--output hypo.txt \
--quiet \
--remove-bpe
path/to/m2scorer path/to/model_output path/to/conll14st-test.m2
Model | Layer | En-De | En-Fr |
---|---|---|---|
Residual-block (baseline) | 6-6 | 29.21 | 42.89 |
RK2-block (learnable $\gamma$) | 6-6 | 30.53 | 43.59 |
Residual-block (baseline) | 12-6 | 29.91 | 43.22 |
RK2-block (learnable $\gamma$) | 12-6 | 30.76 | 44.11 |
Model | RG-1 | RG-2 | RG-L |
---|---|---|---|
Residual-block | 40.47 | 17.73 | 37.29 |
RK2-block (learnable $\gamma$) | 41.58 | 18.57 | 38.41 |
RK4-block | 41.83 | 18.84 | 38.68 |
Model | Prec. | Recall | F_0.5 |
---|---|---|---|
Residual-block | 67.97 | 32.17 | 55.61 |
RK2-block (learnable $\gamma$) | 68.21 | 35.30 | 57.49 |
RK4-block | 66.20 | 38.13 | 57.71 |