Implementation of the paper "Jointly Learning to Align and Translate with Transformer Models" #1095

Closed · wants to merge 7 commits · changes from 3 commits
89 changes: 89 additions & 0 deletions examples/joint_alignment_translation/README.md
@@ -0,0 +1,89 @@
# Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)

This page includes instructions for training models described in [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](https://arxiv.org/abs/1909.02074).

## Training a joint alignment-translation model on WMT'18 En-De

##### 1. Extract and preprocess the WMT'18 En-De data
```bash
./prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
```

##### 2. Generate alignments with a statistical alignment toolkit, e.g. Giza++ or FastAlign.
In this example, we use FastAlign.
```bash
git clone git@github.com:clab/fast_align.git
pushd fast_align
mkdir build
cd build
cmake ..
make
popd
ALIGN=fast_align/build/fast_align
paste bpe.32k/train.en bpe.32k/train.de | awk -F '\t' '{print $1 " ||| " $2}' > bpe.32k/train.en-de
$ALIGN -i bpe.32k/train.en-de -d -o -v > bpe.32k/train.align
```
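
As a quick sanity check, you can look at the first line of the pasted corpus and of the generated alignments. The formats shown in the comments below are assumptions based on the command above and on fast_align's usual Pharaoh-style output (0-based `source-target` index pairs); the example text is illustrative, not real data.
```bash
# Quick sanity check (a sketch). Expected formats (illustrative):
#   train.en-de :  bpe-tokenized source ||| bpe-tokenized target
#   train.align :  0-0 1-1 2-3 ...   (0-based source-target index pairs)
head -n 1 bpe.32k/train.en-de
head -n 1 bpe.32k/train.align
```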

##### 3. Preprocess the dataset with the above generated alignments.
```bash
fairseq-preprocess \
    --source-lang en --target-lang de \
    --trainpref bpe.32k/train \
    --validpref bpe.32k/valid \
    --testpref bpe.32k/test \
    --align-suffix align \
    --destdir binarized/ \
    --joined-dictionary \
    --workers 32
```
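
If the alignments were picked up, the binarized directory should contain alignment files next to the bitext. The naming sketched below (`train.align.en-de.*`) is an assumption based on this PR's change to `fairseq/tasks/translation.py`, which loads `'{split}.align.{src}-{tgt}'`.
```bash
# Sanity check (a sketch): confirm the binarized alignment dataset exists.
# Expected names are assumed from the '{split}.align.{src}-{tgt}' pattern, e.g.:
#   train.align.en-de.bin  train.align.en-de.idx
ls binarized/ | grep align
```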

##### 4. Train a model
```bash
fairseq-train \
    binarized \
    --arch transformer_wmt_en_de_big_align --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu \
    --lr 0.0002 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
    --max-tokens 3500 --label-smoothing 0.1 \
    --save-dir ./checkpoints --log-interval 1000 --max-update 60000 \
    --keep-interval-updates -1 --save-interval-updates 0 \
    --load-alignments --criterion label_smoothed_cross_entropy_with_alignment \
    --fp16
```

> **Contributor** (on the `--criterion` line): \ is missing at the end of this line :)
>
> **Contributor (Author):** Fixed.

Note that the `--fp16` flag requires you to have CUDA 9.1 or greater and a Volta GPU or newer.

If you want to train the above model with big batches (assuming your machine has 8 GPUs):
- add `--update-freq 8` to simulate training on 8x8=64 GPUs
- increase the learning rate; 0.0007 works well for big batches (see the example command below)
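
For example, the step-4 command with those two adjustments applied looks roughly as follows (a sketch; every other flag is unchanged):
```bash
# --update-freq 8 accumulates gradients over 8 steps per GPU (8 GPUs x 8 = 64 simulated GPUs)
# --lr 0.0007 is the larger learning rate suggested above for big batches
fairseq-train \
    binarized \
    --arch transformer_wmt_en_de_big_align --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu \
    --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
    --max-tokens 3500 --label-smoothing 0.1 \
    --save-dir ./checkpoints --log-interval 1000 --max-update 60000 \
    --keep-interval-updates -1 --save-interval-updates 0 \
    --load-alignments --criterion label_smoothed_cross_entropy_with_alignment \
    --update-freq 8 \
    --fp16
```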

##### 5. Evaluate and generate the alignments (BPE level)
```bash
fairseq-generate \
    binarized --gen-subset test --print-alignment \
    --source-lang en --target-lang de \
    --path checkpoints/checkpoint_best.pt --beam 5 --nbest 1
```
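
To post-process the predicted alignments, it helps to save the generation output to a file first. The sketch below assumes the output file name `gen.out` and that the BPE-level alignments appear on lines of the form `A-<sentence id>` followed by a tab and the alignment itself (analogous to fairseq's other `H-`/`T-` output lines); the exact format can differ between fairseq versions, so adjust as needed.
```bash
# A sketch: run generation, save the log, and pull out the BPE-level alignments
# in the original sentence order (sorting numerically on the sentence id).
fairseq-generate \
    binarized --gen-subset test --print-alignment \
    --source-lang en --target-lang de \
    --path checkpoints/checkpoint_best.pt --beam 5 --nbest 1 > gen.out

grep '^A-' gen.out | sed 's/^A-//' | sort -n | cut -f2 > test.bpe.align
```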

##### 6. Other resources.
The code for:
1. preparing alignment test sets,
2. converting BPE-level alignments to token-level alignments (a rough sketch of this step is also given below),
3. symmetrizing bidirectional alignments, and
4. evaluating alignments using the AER metric
can be found [here](https://github.com/lilt/alignment-scripts).
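
As a rough illustration of item 2 above, the sketch below maps BPE-level alignments back to word level by assigning every subword index to the index of the word it belongs to and de-duplicating the resulting pairs. It assumes fastBPE's `@@` continuation markers, alignments given as space-separated `src-tgt` index pairs, and the (hypothetical) file names from the previous step; for tested implementations, use the alignment-scripts repository linked above.
```bash
# Convert BPE-level alignments to word-level alignments (a sketch).
# SRC_BPE / TGT_BPE must be the exact BPE token sequences the alignments refer to,
# one sentence per line; '@@' is assumed to mark non-final subwords.
SRC_BPE=bpe.32k/test.en
TGT_BPE=bpe.32k/test.de      # use whichever target side the alignments were produced against
ALN=test.bpe.align
paste $SRC_BPE $TGT_BPE $ALN | awk -F'\t' '{
    ns = split($1, s, " "); nt = split($2, t, " "); na = split($3, a, " ");
    w = 0; for (i = 1; i <= ns; i++) { smap[i-1] = w; if (s[i] !~ /@@$/) w++ }   # subword -> source word index
    w = 0; for (i = 1; i <= nt; i++) { tmap[i-1] = w; if (t[i] !~ /@@$/) w++ }   # subword -> target word index
    out = ""
    for (i = 1; i <= na; i++) {
        split(a[i], p, "-"); pair = smap[p[1]] "-" tmap[p[2]]
        if (!(pair in seen)) { seen[pair] = 1; out = out (out == "" ? "" : " ") pair }
    }
    print out
    delete smap; delete tmap; delete seen
}' > test.word.align
```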

## Citation

```bibtex
@inproceedings{garg2019jointly,
title = {Jointly Learning to Align and Translate with Transformer Models},
author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
address = {Hong Kong},
month = {November},
url = {https://arxiv.org/abs/1909.02074},
year = {2019},
}
```
118 changes: 118 additions & 0 deletions examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
@@ -0,0 +1,118 @@
#!/bin/bash

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl

URLS=(
"http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
"http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
"http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
"http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
"http://data.statmt.org/wmt17/translation-task/dev.tgz"
"http://statmt.org/wmt14/test-full.tgz"
)
CORPORA=(
"training/europarl-v7.de-en"
"commoncrawl.de-en"
"training-parallel-nc-v13/news-commentary-v13.de-en"
"rapid2016.de-en"
)

if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit 1
fi

src=en
tgt=de
lang=en-de
prep=wmt18_en_de
tmp=$prep/tmp
orig=orig
dev=dev/newstest2012
codes=32000
bpe=bpe.32k

mkdir -p $orig $tmp $prep $bpe

cd $orig

for ((i=0;i<${#URLS[@]};++i)); do
    url=${URLS[i]}
    file=$(basename $url)
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        wget "$url"
        if [ -f $file ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            exit -1
        fi
        if [ ${file: -4} == ".tgz" ]; then
            tar zxvf $file
        elif [ ${file: -4} == ".tar" ]; then
            tar xvf $file
        fi
    fi
done
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
    rm -rf $tmp/train.tags.$lang.tok.$l
    for f in "${CORPORA[@]}"; do
        cat $orig/$f.$l | \
            perl $REM_NON_PRINT_CHAR | \
            perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/train.tags.$lang.tok.$l
    done
done

echo "pre-processing test data..."
for l in $src $tgt; do
    if [ "$l" == "$src" ]; then
        t="src"
    else
        t="ref"
    fi
    grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
        sed -e 's/<seg id="[0-9]*">\s*//g' | \
        sed -e 's/\s*<\/seg>\s*//g' | \
        sed -e "s/\’/\'/g" | \
        perl $TOKENIZER -threads 8 -l $l -no-escape > $tmp/test.$l
    echo ""
done

# apply length filtering before BPE
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 100

# use newstest2012 for valid
echo "pre-processing valid data..."
for l in $src $tgt; do
    rm -rf $tmp/valid.$l
    cat $orig/$dev.$l | \
        perl $REM_NON_PRINT_CHAR | \
        perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/valid.$l
done

mkdir output
mv $tmp/{train,valid,test}.{$src,$tgt} output

#BPE
git clone git@github.com:glample/fastBPE.git
pushd fastBPE
g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
popd
fastBPE/fast learnbpe $codes output/train.$src output/train.$tgt > $bpe/codes
for split in {train,valid,test}; do for lang in {en,de}; do fastBPE/fast applybpe $bpe/$split.$lang output/$split.$lang $bpe/codes; done; done
@@ -1,9 +1,7 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch.nn.functional as F
10 changes: 6 additions & 4 deletions fairseq/data/language_pair_dataset.py
@@ -98,11 +98,13 @@ def compute_alignment_weights(alignments):

    alignments = [alignment + offset for align_idx, offset, src_len, tgt_len in zip(sort_order, offsets, src_lengths, tgt_lengths) \
                  for alignment in [samples[align_idx]['alignment'].view(-1, 2)] if check_alignment(alignment, src_len, tgt_len)]
    alignments = torch.cat(alignments, dim=0)
    align_weights = compute_alignment_weights(alignments)

    batch['alignments'] = alignments
    batch['align_weights'] = align_weights
    if len(alignments) > 0:
        alignments = torch.cat(alignments, dim=0)
        align_weights = compute_alignment_weights(alignments)

        batch['alignments'] = alignments
        batch['align_weights'] = align_weights

    return batch

2 changes: 1 addition & 1 deletion fairseq/tasks/translation.py
@@ -70,7 +70,7 @@ def split_exists(split, src, tgt, lang, data_path):

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, '{}.align'.format(split))
        align_path = os.path.join(data_path, '{}.align.{}-{}'.format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(align_path, None, dataset_impl)

13 changes: 7 additions & 6 deletions preprocess.py
@@ -307,9 +307,8 @@ def consumer(tensor):


def binarize_alignments(args, filename, parse_alignment, output_prefix, offset, end):
    ds = indexed_dataset.IndexedDatasetBuilder(
        dataset_dest_file(args, output_prefix, None, "bin")
    )
    ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"),
                                      impl=args.dataset_impl, vocab_size=None)

    def consumer(tensor):
        ds.add_item(tensor)
@@ -322,9 +321,11 @@ def consumer(tensor):

def dataset_dest_prefix(args, output_prefix, lang):
    base = "{}/{}".format(args.destdir, output_prefix)
    lang_part = (
        ".{}-{}.{}".format(args.source_lang, args.target_lang, lang) if lang is not None else ""
    )
    if lang is not None:
        lang_part = ".{}-{}.{}".format(args.source_lang, args.target_lang, lang)
    else:
        lang_part = ".{}-{}".format(args.source_lang, args.target_lang)

    return "{}{}".format(base, lang_part)


36 changes: 35 additions & 1 deletion tests/test_binaries.py
@@ -203,6 +203,22 @@ def test_mixture_of_experts(self):
'--gen-expert', '0'
])

    def test_alignment(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_alignment') as data_dir:
                create_dummy_data(data_dir, alignment=True)
                preprocess_translation_data(data_dir, ['--align-suffix', 'align'])
                train_translation_model(data_dir, 'transformer_align',
                    ['--encoder-layers', '2',
                     '--decoder-layers', '2',
                     '--encoder-embed-dim', '8',
                     '--decoder-embed-dim', '8',
                     '--load-alignments',
                     '--alignment-layer', '1',
                     '--criterion', 'label_smoothed_cross_entropy_with_alignment'],
                    run_validation=True)
                generate_main(data_dir)


class TestStories(unittest.TestCase):

@@ -421,7 +437,7 @@ def test_optimizers(self):
generate_main(data_dir)


def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
def create_dummy_data(data_dir, num_examples=1000, maxlen=20, alignment=False):

    def _create_dummy_data(filename):
        data = torch.rand(num_examples * maxlen)

@@ -434,13 +450,31 @@ def _create_dummy_data(filename):
                print(ex_str, file=h)
                offset += ex_len

    def _create_dummy_alignment_data(filename_src, filename_tgt, filename):
        with open(os.path.join(data_dir, filename_src), 'r') as src_f, \
                open(os.path.join(data_dir, filename_tgt), 'r') as tgt_f, \
                open(os.path.join(data_dir, filename), 'w') as h:
            for src, tgt in zip(src_f, tgt_f):
                src_len = len(src.split())
                tgt_len = len(tgt.split())
                avg_len = (src_len + tgt_len) // 2
                num_alignments = random.randint(avg_len // 2, 2 * avg_len)
                src_indices = torch.floor(torch.rand(num_alignments) * src_len).int()
                tgt_indices = torch.floor(torch.rand(num_alignments) * tgt_len).int()
                ex_str = ' '.join(["{}-{}".format(src, tgt) for src, tgt in zip(src_indices, tgt_indices)])
                print(ex_str, file=h)

    _create_dummy_data('train.in')
    _create_dummy_data('train.out')
    _create_dummy_data('valid.in')
    _create_dummy_data('valid.out')
    _create_dummy_data('test.in')
    _create_dummy_data('test.out')

    if alignment:
        _create_dummy_alignment_data('train.in', 'train.out', 'train.align')
        _create_dummy_alignment_data('valid.in', 'valid.out', 'valid.align')
        _create_dummy_alignment_data('test.in', 'test.out', 'test.align')

def preprocess_translation_data(data_dir, extra_flags=None):
    preprocess_parser = options.get_preprocessing_parser()