Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
[FEATURE]Horovod support for training transformer + add mirror data f…
Browse files Browse the repository at this point in the history
…or wmt (PART 1) (#1284)

* set default shuffle=True for boundedbudgetsampler

* fix

* fix log condition

* use horovod to train transformer

* fix

* add mirror wmt dataset

* fix

* rename wmt.txt to wmt.json and remove part of urls

* fix

* tuning params

* use get_repo_url()

* update average checkpoint cli

* paste result of transformer large

* fix

* fix logging in train_transformer

* fix

* fix

* fix

* add transformer base config

Co-authored-by: Hu <huta@a483e74650ff.ant.amazon.com>
  • Loading branch information
hutao965 and Hu committed Aug 7, 2020
1 parent ded0f99 commit c33e62e
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 68 deletions.
24 changes: 21 additions & 3 deletions scripts/datasets/machine_translation/prepare_wmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import functools
import tarfile
import gzip
import json
from xml.etree import ElementTree
from gluonnlp.data.filtering import ProfanityFilter
from gluonnlp.utils.misc import file_line_number, download, load_checksum_stats
from gluonnlp.base import get_data_home_dir
from gluonnlp.base import get_data_home_dir, get_repo_url
from gluonnlp.registry import DATA_PARSER_REGISTRY, DATA_MAIN_REGISTRY

# The datasets are provided by WMT2014-WMT2019 and can be freely used for research purposes.
Expand Down Expand Up @@ -336,6 +337,15 @@
}
}

with open(os.path.join(_CURR_DIR, '..', 'url_checksums', 'mirror', 'wmt.json')) as wmt_mirror_map_f:
_WMT_MIRROR_URL_MAP = json.load(wmt_mirror_map_f)

def _download_with_mirror(url, path, sha1_hash):
return download(
get_repo_url() + _WMT_MIRROR_URL_MAP[url] if url in _WMT_MIRROR_URL_MAP else url,
path=path,
sha1_hash=sha1_hash
)

def _clean_space(s: str):
"""Removes trailing and leading spaces and collapses multiple consecutive internal spaces to a single one.
Expand Down Expand Up @@ -626,7 +636,11 @@ def fetch_mono_dataset(selection: Union[str, List[str], List[List[str]]],
save_path_l = [path] + selection + [matched_lang, original_filename]
else:
save_path_l = [path] + selection + [original_filename]
download_fname = download(url, path=os.path.join(*save_path_l), sha1_hash=sha1_hash)
download_fname = _download_with_mirror(
url,
path=os.path.join(*save_path_l),
sha1_hash=sha1_hash
)
download_fname_l.append(download_fname)
if len(download_fname_l) > 1:
data_path = concatenate_files(download_fname_l)
Expand Down Expand Up @@ -792,7 +806,11 @@ def fetch_wmt_parallel_dataset(selection: Union[str, List[str], List[List[str]]]
save_path_l = [path] + selection + [matched_pair, original_filename]
else:
save_path_l = [path] + selection + [original_filename]
download_fname = download(url, path=os.path.join(*save_path_l), sha1_hash=sha1_hash)
download_fname = _download_with_mirror(
url,
path=os.path.join(*save_path_l),
sha1_hash=sha1_hash
)
download_fname_l.append(download_fname)
if len(download_fname_l) > 1:
data_path = concatenate_files(download_fname_l)
Expand Down
1 change: 0 additions & 1 deletion scripts/datasets/machine_translation/wmt2014_ende.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \
--tgt-corpus dev.raw.${TGT} \
--min-num-words 1 \
--max-num-words 100 \
--max-ratio 1.5 \
--src-save-path dev.tok.${SRC} \
--tgt-save-path dev.tok.${TGT}

Expand Down
48 changes: 48 additions & 0 deletions scripts/datasets/url_checksums/mirror/wmt.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
{
"http://www.statmt.org/europarl/v7/cs-en.tgz" : "datasets/third_party_mirror/cs-en-28bad3e096923694fb776b6cd6ba1079546a9e58.tgz",
"http://www.statmt.org/europarl/v7/de-en.tgz" : "datasets/third_party_mirror/de-en-53bb5408d22977c89284bd755717e6bbb5b12bc5.tgz",
"http://data.statmt.org/wmt18/translation-task/training-parallel-ep-v8.tgz" : "datasets/third_party_mirror/training-parallel-ep-v8-2f5c2c2c98b72921474a3f1837dc5b61dd44ba88.tgz",
"http://www.statmt.org/europarl/v9/training/europarl-v9.cs-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.cs-en.tsv-e36a1bfe634379ec813b399b57a38093df2349ef.gz",
"http://www.statmt.org/europarl/v9/training/europarl-v9.de-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.de-en.tsv-d553d0c8189642c1c7ae6ed3c265c847e432057c.gz",
"http://www.statmt.org/europarl/v9/training/europarl-v9.fi-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.fi-en.tsv-c5d2f6aad04e88dda6ad11a110f4ca24150edca3.gz",
"http://www.statmt.org/europarl/v9/training/europarl-v9.lt-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.lt-en.tsv-a6343d8fc158f44714ea7d01c0eb65b34640841d.gz",
"http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz" : "datasets/third_party_mirror/training-parallel-commoncrawl-1c0ad85f0ebaf1d543acb009607205f5dae6627d.tgz",
"http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz" : "datasets/third_party_mirror/training-parallel-nc-v9-c7ae7f50cd45c2f3014d78ddba25a4a8a851e27a.tgz",
"http://www.statmt.org/wmt15/training-parallel-nc-v10.tgz" : "datasets/third_party_mirror/training-parallel-nc-v10-6c3c45b0f34d5e84a4d0b75a5edcca226ba7d6c2.tgz",
"http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz" : "datasets/third_party_mirror/training-parallel-nc-v11-f51a1f03908e790d23d10001e92e09ce9555a790.tgz",
"http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz" : "datasets/third_party_mirror/training-parallel-nc-v12-d98afc59e1d753485530b377ff65f1f891d3bced.tgz",
"http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz" : "datasets/third_party_mirror/training-parallel-nc-v13-cbaa7834e58d36f228336e3caee6a9056029ff5d.tgz",
"http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.de-en.tsv.gz" : "datasets/third_party_mirror/news-commentary-v14.de-en.tsv-c1fd94c7c9ff222968cbd45100bdd8dbeb5ab2aa.gz",
"http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.en-zh.tsv.gz" : "datasets/third_party_mirror/news-commentary-v14.en-zh.tsv-4ca5c01deeba5425646d42f9598d081cd662908b.gz",
"http://data.statmt.org/wikititles/v1/wikititles-v1.cs-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.cs-en.tsv-6e094d218dfd8f987fa1a18ea7b4cb127cfb1763.gz",
"http://data.statmt.org/wikititles/v1/wikititles-v1.cs-pl.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.cs-pl.tsv-dc93d346d151bf73e4165d6db425b903fc21a5b0.gz",
"http://data.statmt.org/wikititles/v1/wikititles-v1.de-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.de-en.tsv-e141c55c43a474e06c259c3fa401288b39cd4315.gz",
"http://data.statmt.org/wikititles/v1/wikititles-v1.es-pt.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.es-pt.tsv-c3bd398d57471ee4ab33323393977b8d475a368c.gz",
"http://data.statmt.org/wikititles/v1/wikititles-v1.fi-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.fi-en.tsv-5668b004567ca286d1aad9c2b45862a441d79667.gz",
"http://data.statmt.org/wikititles/v1/wikititles-v1.gu-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.gu-en.tsv-95b9f15b6a86bfed6dc9bc91597368fd334f436e.gz",
"http://data.statmt.org/wikititles/v1/wikititles-v1.hi-ne.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.hi-ne.tsv-6d63908950c72bc8cc69ca470deccff11354afc2.gz",
"http://data.statmt.org/wikititles/v1/wikititles-v1.kk-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.kk-en.tsv-56ee1e450ef98fe92ea2116c3ce7acc7c7c42b39.gz",
"http://data.statmt.org/wikititles/v1/wikititles-v1.lt-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.lt-en.tsv-b8829928686727165eec6c591d2875d12d7c0cfe.gz",
"http://data.statmt.org/wikititles/v1/wikititles-v1.ru-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.ru-en.tsv-16d8d231fdf6347b4cc7834654adec80153ff7a4.gz",
"http://data.statmt.org/wikititles/v1/wikititles-v1.zh-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.zh-en.tsv-5829097ff7dd61752f29fb306b04d790a1a1cfd7.gz",
"https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.00" : "datasets/third_party_mirror/UNv1.0.en-ru-98c4e01e16070567d27da0ab4fe401f309dd3678.tar.gz.00",
"https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.01" : "datasets/third_party_mirror/UNv1.0.en-ru-86c6013dc88f353d2d6e591928e7549060fcb949.tar.gz.01",
"https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.02" : "datasets/third_party_mirror/UNv1.0.en-ru-bf6b18a33c8cafa6889fd463fa8a2850d8877d35.tar.gz.02",
"https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.00" : "datasets/third_party_mirror/UNv1.0.en-zh-1bec5f10297512183e483fdd4984d207700657d1.tar.gz.00",
"https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.01" : "datasets/third_party_mirror/UNv1.0.en-zh-15df2968bc69ef7662cf3029282bbb62cbf107b1.tar.gz.01",
"http://data.statmt.org/wmt17/translation-task/rapid2016.tgz" : "datasets/third_party_mirror/rapid2016-8b173ce0bc77f2a1a57c8134143e3b5ae228a6e2.tgz",
"http://data.statmt.org/wmt19/translation-task/dev.tgz" : "datasets/third_party_mirror/dev-451ce2cae815c8392212ccb3f54f5dcddb9b2b9e.tgz",
"http://data.statmt.org/wmt19/translation-task/test.tgz" : "datasets/third_party_mirror/test-ce02a36fb2cd41abfa19d36eb8c8d50241ed3346.tgz",
"http://data.statmt.org/news-crawl/de/news.2007.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2007.de.shuffled.deduped-9d746b9df345f764e6e615119113c70e3fb0858c.gz",
"http://data.statmt.org/news-crawl/de/news.2008.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2008.de.shuffled.deduped-185a24e8833844486aee16cb5decf9a64da1c101.gz",
"http://data.statmt.org/news-crawl/de/news.2009.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2009.de.shuffled.deduped-9f7645fc6467de88f4205d94f483194838bad8ce.gz",
"http://data.statmt.org/news-crawl/de/news.2010.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2010.de.shuffled.deduped-f29b761194e9606f086102cfac12813931575818.gz",
"http://data.statmt.org/news-crawl/de/news.2011.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2011.de.shuffled.deduped-613b16e7a1cb8559dd428525a4c3b42c8a4dc278.gz",
"http://data.statmt.org/news-crawl/de/news.2012.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2012.de.shuffled.deduped-1bc419364ea3fe2f9ba4236947c012d4198d9282.gz",
"http://data.statmt.org/news-crawl/de/news.2013.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2013.de.shuffled.deduped-3edd84a7f105907608371c81babc7a9078f40aac.gz",
"http://data.statmt.org/news-crawl/de/news.2014.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2014.de.shuffled.deduped-1466c67b330c08ab5ab7d48e666c1d3a0bb4e479.gz",
"http://data.statmt.org/news-crawl/de/news.2015.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2015.de.shuffled.deduped-2c6d5ec9f8fe51e9eb762be8ff7107c6116c00c4.gz",
"http://data.statmt.org/news-crawl/de/news.2016.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2016.de.shuffled.deduped-e7d235c5d28e36dcf6382f1aa12c6ff37d4529bb.gz",
"http://data.statmt.org/news-crawl/de/news.2017.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2017.de.shuffled.deduped-f70b4a67bc04c0fdc2ec955b737fa22681e8c038.gz",
"http://data.statmt.org/news-crawl/de/news.2018.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2018.de.shuffled.deduped-43f8237de1e219276c0682255def13aa2cb80e35.gz"
}
109 changes: 96 additions & 13 deletions scripts/machine_translation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,113 @@ You may first run the following command in [datasets/machine_translation](../dat
bash wmt2014_ende.sh yttm
```

Then, you can run the experiment, we use the
"transformer_base" configuration.
Then, you can run the experiment.
For "transformer_base" configuration

# TODO
```bash
SUBWORD_MODEL=yttm
SRC=en
TGT=de
datapath=../datasets/machine_translation
python train_transformer.py \
--train_src_corpus ../datasets/machine_translation/wmt2014_ende/train.tok.${SUBWORD_MODEL}.en \
--train_tgt_corpus ../datasets/machine_translation/wmt2014_ende/train.tok.${SUBWORD_MODEL}.de \
--dev_src_corpus ../datasets/machine_translation/wmt2014_ende/dev.tok.${SUBWORD_MODEL}.en \
--dev_tgt_corpus ../datasets/machine_translation/wmt2014_ende/dev.tok.${SUBWORD_MODEL}.de \
--train_src_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${SRC} \
--train_tgt_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${TGT} \
--dev_src_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${SRC} \
--dev_tgt_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${TGT} \
--src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
--src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
--tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
--tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
--save_dir transformer_base_wmt2014_en_de_${SUBWORD_ALGO} \
--cfg transformer_base \
--lr 0.002 \
--batch_size 2700 \
--num_averages 5 \
--warmup_steps 4000 \
--warmup_init_lr 0.0 \
--seed 123 \
--gpus 0,1,2,3
```

Use the average_checkpoint cli to average the last 10 checkpoints

```bash
gluon_average_checkpoint --checkpoints transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch*.params \
--begin 21 \
--end 30 \
--save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params
```


Use the following command to inference/evaluate the Transformer model:

```bash
SUBWORD_MODEL=yttm
python evaluate_transformer.py \
--param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/average_21_30.params \
--src_lang en \
--tgt_lang de \
--cfg transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \
--src_tokenizer ${SUBWORD_MODEL} \
--tgt_tokenizer ${SUBWORD_MODEL} \
--src_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \
--src_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \
--tgt_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \
--src_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \
--tgt_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \
--save_dir transformer_wmt2014_ende_${SUBWORD_MODEL} \
--cfg transformer_base \
--lr 0.002 \
--src_corpus ../datasets/machine_translation/wmt2014_ende/test.raw.en \
--tgt_corpus ../datasets/machine_translation/wmt2014_ende/test.raw.de
```



For "transformer_wmt_en_de_big" configuration

```bash
SUBWORD_MODEL=yttm
SRC=en
TGT=de
datapath=../datasets/machine_translation
python train_transformer.py \
--train_src_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${SRC} \
--train_tgt_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${TGT} \
--dev_src_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${SRC} \
--dev_tgt_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${TGT} \
--src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
--src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
--tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
--tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
--save_dir transformer_big_wmt2014_en_de_${SUBWORD_ALGO} \
--cfg transformer_wmt_en_de_big \
--lr 0.001 \
--sampler BoundedBudgetSampler \
--max_num_tokens 3584 \
--max_update 15000 \
--warmup_steps 4000 \
--warmup_init_lr 0.0 \
--seed 123 \
--gpus 0,1,2,3
```

Use the average_checkpoint cli to average the last 10 checkpoints

```bash
gluon_average_checkpoint --checkpoints transformer_big_wmt2014_en_de_${SUBWORD_ALGO}/update*.params \
--begin 21 \
--end 30 \
--save-path transformer_big_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params
```


Use the following command to inference/evaluate the Transformer model:

```bash
SUBWORD_MODEL=yttm
python evaluate_transformer.py \
--param_path transformer_wmt2014_ende_${SUBWORD_MODEL}/average.params \
--param_path transformer_big_wmt2014_en_de_${SUBWORD_MODEL}/average_21_30.params \
--src_lang en \
--tgt_lang de \
--cfg transformer_wmt2014_ende_${SUBWORD_MODEL}/config.yml \
--cfg transformer_big_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \
--src_tokenizer ${SUBWORD_MODEL} \
--tgt_tokenizer ${SUBWORD_MODEL} \
--src_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \
Expand All @@ -59,6 +134,14 @@ Test BLEU score with 3 seeds (evaluated via sacre BLEU):

| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
|---------------|------------|-------------|-------------|--------------|-------------|
| yttm | | 26.63 | 26.73 | | - |
| yttm | | - | - | - | - |
| hf_bpe | | - | - | - | - |
| spm | | - | - | - | - |

- transformer_wmt_en_de_big

| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
|---------------|------------|-------------|-------------|--------------|-------------|
| yttm | | 27.99 | - | - | - |
| hf_bpe | | - | - | - | - |
| spm | | - | - | - | - |
Loading

0 comments on commit c33e62e

Please sign in to comment.