Implementation of the paper "Jointly Learning to Align and Translate with Transformer Models" (#877)

Summary: Pull Request resolved: fairinternal/fairseq-py#877

This PR implements the guided alignment training described in "Jointly Learning to Align and Translate with Transformer Models" (https://arxiv.org/abs/1909.02074). In summary, it allows selected heads of the Transformer model to be trained with external alignments computed by statistical alignment toolkits. During inference, the attention probabilities of the trained heads can be used to extract reliable alignments. In our work, we did not see any regression in translation performance due to guided alignment training.

Pull Request resolved: #1095
Differential Revision: D17170337
Pulled By: myleott
fbshipit-source-id: daa418bef70324d7088dbb30aa2adf9f95774859
1 parent acb6fba · commit 1c66792
Showing 20 changed files with 899 additions and 61 deletions.
examples/joint_alignment_translation/README.md (89 additions)
# Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)

This page includes instructions for training models described in [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](https://arxiv.org/abs/1909.02074).

## Training a joint alignment-translation model on WMT'18 En-De

##### 1. Extract and preprocess the WMT'18 En-De data
```bash
./prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
```

##### 2. Generate alignments with a statistical alignment toolkit, e.g. Giza++ or FastAlign.
In this example, we use FastAlign.
```bash
git clone git@github.com:clab/fast_align.git
pushd fast_align
mkdir build
cd build
cmake ..
make
popd
ALIGN=fast_align/build/fast_align
paste bpe.32k/train.en bpe.32k/train.de | awk -F '\t' '{print $1 " ||| " $2}' > bpe.32k/train.en-de
$ALIGN -i bpe.32k/train.en-de -d -o -v > bpe.32k/train.align
```
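
As an aside (not part of the fairseq example): FastAlign emits one line of Pharaoh-format alignment points per sentence pair, e.g. `0-0 1-2 2-1`, where `i-j` means source token `i` aligns to target token `j`. A minimal Python sketch for loading `bpe.32k/train.align` into index pairs:

```python
# Minimal sketch (not fairseq code): parse Pharaoh-format alignments such as
# "0-0 1-2 2-1" into (src_idx, tgt_idx) pairs, one list per sentence pair.
def read_alignments(path):
    alignments = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            pairs = []
            for point in line.split():
                s, t = point.split("-")
                pairs.append((int(s), int(t)))
            alignments.append(pairs)
    return alignments

if __name__ == "__main__":
    aligns = read_alignments("bpe.32k/train.align")  # path from the commands above
    print(len(aligns), aligns[0][:5])
```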

##### 3. Preprocess the dataset with the above generated alignments.
```bash
fairseq-preprocess \
    --source-lang en --target-lang de \
    --trainpref bpe.32k/train \
    --validpref bpe.32k/valid \
    --testpref bpe.32k/test \
    --align-suffix align \
    --destdir binarized/ \
    --joined-dictionary \
    --workers 32
```

##### 4. Train a model
```bash
fairseq-train \
    binarized \
    --arch transformer_wmt_en_de_big_align --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu \
    --lr 0.0002 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
    --max-tokens 3500 --label-smoothing 0.1 \
    --save-dir ./checkpoints --log-interval 1000 --max-update 60000 \
    --keep-interval-updates -1 --save-interval-updates 0 \
    --load-alignments --criterion label_smoothed_cross_entropy_with_alignment \
    --fp16
```

Note that the `--fp16` flag requires that you have CUDA 9.1 or greater and a Volta GPU or newer.

If you want to train the above model with big batches (assuming your machine has 8 GPUs):
- add `--update-freq 8` to simulate training on 8x8=64 GPUs (see the batch-size arithmetic sketch below)
- increase the learning rate; 0.0007 works well for big batches
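
To make the big-batch suggestion concrete, here is the rough arithmetic (an illustrative sketch, not fairseq code): with `--max-tokens 3500` per GPU, 8 GPUs, and `--update-freq 8`, each optimizer step sees roughly 224k tokens, which is the regime where the larger learning rate of 0.0007 works well.

```python
# Back-of-the-envelope check of the effective batch size (illustrative only).
max_tokens_per_gpu = 3500   # --max-tokens
num_gpus = 8                # GPUs on the machine
update_freq = 8             # --update-freq

effective_tokens_per_update = max_tokens_per_gpu * num_gpus * update_freq
print(f"~{effective_tokens_per_update:,} tokens per optimizer step")  # ~224,000
```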

##### 5. Evaluate and generate the alignments (BPE level)
```bash
fairseq-generate \
    binarized --gen-subset test --print-alignment \
    --source-lang en --target-lang de \
    --path checkpoints/checkpoint_best.pt --beam 5 --nbest 1
```
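
If you want to post-process the predicted alignments, the generation log can be filtered for the alignment lines. The sketch below assumes the usual `fairseq-generate` output convention, where hypotheses are prefixed with `H-<id>` and, with `--print-alignment`, alignments with `A-<id>`; verify this against the output of your fairseq version before relying on it.

```python
# Illustrative sketch: collect per-sentence alignments from a saved
# fairseq-generate log (e.g. produced with `... > gen.out`).
# Assumes "A-<id>\t<i-j pairs>" lines; check your fairseq version's format.
import sys

def read_generated_alignments(path):
    alignments = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.startswith("A-"):
                key, value = line.rstrip("\n").split("\t", 1)
                sent_id = int(key[len("A-"):])
                pairs = [tuple(map(int, p.split("-"))) for p in value.split()]
                alignments[sent_id] = pairs
    return alignments

if __name__ == "__main__":
    aligns = read_generated_alignments(sys.argv[1])
    print(f"read alignments for {len(aligns)} sentences")
```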

##### 6. Other resources.
The code for the following can be found [here](https://github.com/lilt/alignment-scripts):
1. preparing alignment test sets
2. converting BPE level alignments to token level alignments (a rough sketch is given below)
3. symmetrizing bidirectional alignments
4. evaluating alignments using AER metric
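
The reference implementations live in the repository linked above. As a rough illustration of item 2, the sketch below converts BPE-level alignment pairs to word-level pairs, assuming fastBPE-style `@@` continuation markers (the suffix on every non-final subword of a word); it is a simplification, not the lilt/alignment-scripts code.

```python
# Illustrative sketch: map BPE-level alignment pairs to word-level pairs,
# assuming "@@"-suffixed continuation subwords (fastBPE / subword-nmt style).
def subword_to_word_map(bpe_tokens):
    """Return, for each subword position, the index of the word it belongs to."""
    mapping, word_idx = [], 0
    for tok in bpe_tokens:
        mapping.append(word_idx)
        if not tok.endswith("@@"):   # last piece of the current word
            word_idx += 1
    return mapping

def bpe_align_to_word_align(bpe_pairs, src_bpe, tgt_bpe):
    src_map = subword_to_word_map(src_bpe)
    tgt_map = subword_to_word_map(tgt_bpe)
    # Deduplicate: several subword pairs can collapse onto the same word pair.
    return sorted({(src_map[s], tgt_map[t]) for s, t in bpe_pairs})

# Example: "Haus@@ katze" is one German word split into two subwords.
src = ["the", "house", "cat"]
tgt = ["die", "Haus@@", "katze"]
pairs = [(0, 0), (1, 1), (2, 1), (2, 2)]
print(bpe_align_to_word_align(pairs, src, tgt))  # [(0, 0), (1, 1), (2, 1)]
```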

## Citation

```bibtex
@inproceedings{garg2019jointly,
  title = {Jointly Learning to Align and Translate with Transformer Models},
  author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
  booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
  address = {Hong Kong},
  month = {November},
  url = {https://arxiv.org/abs/1909.02074},
  year = {2019},
}
```
examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh (118 additions, 0 deletions)
#!/bin/bash

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl

URLS=(
    "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
    "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
    "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
    "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
    "http://data.statmt.org/wmt17/translation-task/dev.tgz"
    "http://statmt.org/wmt14/test-full.tgz"
)
CORPORA=(
    "training/europarl-v7.de-en"
    "commoncrawl.de-en"
    "training-parallel-nc-v13/news-commentary-v13.de-en"
    "rapid2016.de-en"
)

if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit 1
fi

src=en
tgt=de
lang=en-de
prep=wmt18_en_de
tmp=$prep/tmp
orig=orig
dev=dev/newstest2012
codes=32000
bpe=bpe.32k

mkdir -p $orig $tmp $prep $bpe

cd $orig

for ((i=0;i<${#URLS[@]};++i)); do
    url=${URLS[i]}
    file=$(basename $url)
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        wget "$url"
        if [ -f $file ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            exit 1
        fi
        if [ ${file: -4} == ".tgz" ]; then
            tar zxvf $file
        elif [ ${file: -4} == ".tar" ]; then
            tar xvf $file
        fi
    fi
done
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
    rm -rf $tmp/train.tags.$lang.tok.$l
    for f in "${CORPORA[@]}"; do
        cat $orig/$f.$l | \
            perl $REM_NON_PRINT_CHAR | \
            perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/train.tags.$lang.tok.$l
    done
done

echo "pre-processing test data..."
for l in $src $tgt; do
    if [ "$l" == "$src" ]; then
        t="src"
    else
        t="ref"
    fi
    grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
        sed -e 's/<seg id="[0-9]*">\s*//g' | \
        sed -e 's/\s*<\/seg>\s*//g' | \
        sed -e "s/\’/\'/g" | \
        perl $TOKENIZER -threads 8 -l $l -no-escape > $tmp/test.$l
    echo ""
done

# apply length filtering before BPE
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 100

# use newstest2012 for valid
echo "pre-processing valid data..."
for l in $src $tgt; do
    rm -rf $tmp/valid.$l
    cat $orig/$dev.$l | \
        perl $REM_NON_PRINT_CHAR | \
        perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/valid.$l
done

mkdir output
mv $tmp/{train,valid,test}.{$src,$tgt} output

# BPE
git clone git@github.com:glample/fastBPE.git
pushd fastBPE
g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
popd
fastBPE/fast learnbpe $codes output/train.$src output/train.$tgt > $bpe/codes
for split in {train,valid,test}; do for lang in {en,de}; do fastBPE/fast applybpe $bpe/$split.$lang output/$split.$lang $bpe/codes; done; done
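
After the script finishes you should have parallel BPE files `bpe.32k/{train,valid,test}.{en,de}` plus the learned codes in `bpe.32k/codes`. A small sanity check, sketched here under the assumption of those paths, is to confirm each split has the same number of lines on both sides before running FastAlign and `fairseq-preprocess`:

```python
# Illustrative sanity check (not part of the script): the .en and .de sides of
# each split must be line-parallel, otherwise alignment and binarization break.
import os

def count_lines(path):
    with open(path, encoding="utf-8") as f:
        return sum(1 for _ in f)

bpe_dir = "bpe.32k"  # path produced by the script above
for split in ("train", "valid", "test"):
    en = count_lines(os.path.join(bpe_dir, f"{split}.en"))
    de = count_lines(os.path.join(bpe_dir, f"{split}.de"))
    status = "OK" if en == de else "MISMATCH"
    print(f"{split}: en={en} de={de} [{status}]")
```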
fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py (90 additions, 0 deletions)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from fairseq import utils

from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from . import register_criterion


@register_criterion('label_smoothed_cross_entropy_with_alignment')
class LabelSmoothedCrossEntropyCriterionWithAlignment(LabelSmoothedCrossEntropyCriterion):

    def __init__(self, args, task):
        super().__init__(args, task)
        self.alignment_lambda = args.alignment_lambda

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        super(LabelSmoothedCrossEntropyCriterionWithAlignment,
              LabelSmoothedCrossEntropyCriterionWithAlignment).add_args(parser)
        parser.add_argument('--alignment-lambda', default=0.05, type=float, metavar='D',
                            help='weight for the alignment loss')

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }

        alignment_loss = None

        # Compute alignment loss only for training set and non dummy batches.
        if 'alignments' in sample and sample['alignments'] is not None:
            alignment_loss = self.compute_alignment_loss(sample, net_output)

        if alignment_loss is not None:
            logging_output['alignment_loss'] = utils.item(alignment_loss.data)
            loss += self.alignment_lambda * alignment_loss

        return loss, sample_size, logging_output

    def compute_alignment_loss(self, sample, net_output):
        attn_prob = net_output[1]['attn']
        bsz, tgt_sz, src_sz = attn_prob.shape
        attn = attn_prob.view(bsz * tgt_sz, src_sz)

        align = sample['alignments']
        align_weights = sample['align_weights'].float()

        if len(align) > 0:
            # Alignment loss computation. align (shape [:, 2]) contains the src-tgt index
            # pairs corresponding to the alignments. align_weights (shape [:]) contains
            # the 1 / frequency of a tgt index for normalizing.
            loss = -((attn[align[:, 1][:, None], align[:, 0][:, None]]).log() * align_weights[:, None]).sum()
        else:
            return None

        return loss

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        return {
            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
            'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0.,
            'alignment_loss': sum(log.get('alignment_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
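
For intuition about the `alignments` and `align_weights` tensors this criterion expects, the sketch below builds them for a single sentence pair the way the code's comment describes: one row per (src, tgt) alignment point, with each point weighted by 1 / (number of alignment points sharing its target index). This is only an illustration of the expected shapes; the real tensors are assembled by fairseq's data loading (with batch-wide flattened indices) when `--load-alignments` is used.

```python
# Illustrative sketch of the per-sentence inputs to compute_alignment_loss:
# `align` holds (src_idx, tgt_idx) pairs, `align_weights` holds 1/freq(tgt_idx).
from collections import Counter

import torch

def make_alignment_tensors(pairs):
    """pairs: list of (src_idx, tgt_idx) alignment points for one sentence pair."""
    align = torch.tensor(pairs, dtype=torch.long)                     # shape [num_points, 2]
    tgt_counts = Counter(t for _, t in pairs)
    weights = torch.tensor([1.0 / tgt_counts[t] for _, t in pairs])   # shape [num_points]
    return align, weights

align, align_weights = make_alignment_tensors([(0, 0), (1, 1), (2, 1)])
print(align)          # tensor([[0, 0], [1, 1], [2, 1]])
print(align_weights)  # tensor([1.0000, 0.5000, 0.5000])
```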
Review comment on 1c66792: does not compile, error on line 242 with `**extra`