From c33e62e1a5930ad3460308531a9cfec705fca8ab Mon Sep 17 00:00:00 2001 From: ht Date: Fri, 7 Aug 2020 11:39:05 +0800 Subject: [PATCH] [FEATURE]Horovod support for training transformer + add mirror data for wmt (PART 1) (#1284) * set default shuffle=True for boundedbudgetsampler * fix * fix log condition * use horovod to train transformer * fix * add mirror wmt dataset * fix * rename wmt.txt to wmt.json and remove part of urls * fix * tuning params * use get_repo_url() * update average checkpoint cli * paste result of transformer large * fix * fix logging in train_transformer * fix * fix * fix * add transformer base config Co-authored-by: Hu --- .../machine_translation/prepare_wmt.py | 24 +++- .../machine_translation/wmt2014_ende.sh | 1 - .../datasets/url_checksums/mirror/wmt.json | 48 ++++++++ scripts/machine_translation/README.md | 109 +++++++++++++++--- .../machine_translation/train_transformer.py | 65 ++++++++--- .../wmt2014_back_translation.sh | 15 ++- src/gluonnlp/cli/average_checkpoint.py | 53 +++++---- src/gluonnlp/data/sampler.py | 16 ++- 8 files changed, 263 insertions(+), 68 deletions(-) create mode 100644 scripts/datasets/url_checksums/mirror/wmt.json diff --git a/scripts/datasets/machine_translation/prepare_wmt.py b/scripts/datasets/machine_translation/prepare_wmt.py index 1e910a70ca..2ac5f77772 100644 --- a/scripts/datasets/machine_translation/prepare_wmt.py +++ b/scripts/datasets/machine_translation/prepare_wmt.py @@ -7,10 +7,11 @@ import functools import tarfile import gzip +import json from xml.etree import ElementTree from gluonnlp.data.filtering import ProfanityFilter from gluonnlp.utils.misc import file_line_number, download, load_checksum_stats -from gluonnlp.base import get_data_home_dir +from gluonnlp.base import get_data_home_dir, get_repo_url from gluonnlp.registry import DATA_PARSER_REGISTRY, DATA_MAIN_REGISTRY # The datasets are provided by WMT2014-WMT2019 and can be freely used for research purposes. @@ -336,6 +337,15 @@ } } +with open(os.path.join(_CURR_DIR, '..', 'url_checksums', 'mirror', 'wmt.json')) as wmt_mirror_map_f: + _WMT_MIRROR_URL_MAP = json.load(wmt_mirror_map_f) + +def _download_with_mirror(url, path, sha1_hash): + return download( + get_repo_url() + _WMT_MIRROR_URL_MAP[url] if url in _WMT_MIRROR_URL_MAP else url, + path=path, + sha1_hash=sha1_hash + ) def _clean_space(s: str): """Removes trailing and leading spaces and collapses multiple consecutive internal spaces to a single one. 
@@ -626,7 +636,11 @@ def fetch_mono_dataset(selection: Union[str, List[str], List[List[str]]], save_path_l = [path] + selection + [matched_lang, original_filename] else: save_path_l = [path] + selection + [original_filename] - download_fname = download(url, path=os.path.join(*save_path_l), sha1_hash=sha1_hash) + download_fname = _download_with_mirror( + url, + path=os.path.join(*save_path_l), + sha1_hash=sha1_hash + ) download_fname_l.append(download_fname) if len(download_fname_l) > 1: data_path = concatenate_files(download_fname_l) @@ -792,7 +806,11 @@ def fetch_wmt_parallel_dataset(selection: Union[str, List[str], List[List[str]]] save_path_l = [path] + selection + [matched_pair, original_filename] else: save_path_l = [path] + selection + [original_filename] - download_fname = download(url, path=os.path.join(*save_path_l), sha1_hash=sha1_hash) + download_fname = _download_with_mirror( + url, + path=os.path.join(*save_path_l), + sha1_hash=sha1_hash + ) download_fname_l.append(download_fname) if len(download_fname_l) > 1: data_path = concatenate_files(download_fname_l) diff --git a/scripts/datasets/machine_translation/wmt2014_ende.sh b/scripts/datasets/machine_translation/wmt2014_ende.sh index 028b796ed2..f319db2163 100644 --- a/scripts/datasets/machine_translation/wmt2014_ende.sh +++ b/scripts/datasets/machine_translation/wmt2014_ende.sh @@ -34,7 +34,6 @@ nlp_preprocess clean_tok_para_corpus --src-lang ${SRC} \ --tgt-corpus dev.raw.${TGT} \ --min-num-words 1 \ --max-num-words 100 \ - --max-ratio 1.5 \ --src-save-path dev.tok.${SRC} \ --tgt-save-path dev.tok.${TGT} diff --git a/scripts/datasets/url_checksums/mirror/wmt.json b/scripts/datasets/url_checksums/mirror/wmt.json new file mode 100644 index 0000000000..fa695f6bd9 --- /dev/null +++ b/scripts/datasets/url_checksums/mirror/wmt.json @@ -0,0 +1,48 @@ +{ + "http://www.statmt.org/europarl/v7/cs-en.tgz" : "datasets/third_party_mirror/cs-en-28bad3e096923694fb776b6cd6ba1079546a9e58.tgz", + "http://www.statmt.org/europarl/v7/de-en.tgz" : "datasets/third_party_mirror/de-en-53bb5408d22977c89284bd755717e6bbb5b12bc5.tgz", + "http://data.statmt.org/wmt18/translation-task/training-parallel-ep-v8.tgz" : "datasets/third_party_mirror/training-parallel-ep-v8-2f5c2c2c98b72921474a3f1837dc5b61dd44ba88.tgz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.cs-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.cs-en.tsv-e36a1bfe634379ec813b399b57a38093df2349ef.gz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.de-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.de-en.tsv-d553d0c8189642c1c7ae6ed3c265c847e432057c.gz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.fi-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.fi-en.tsv-c5d2f6aad04e88dda6ad11a110f4ca24150edca3.gz", + "http://www.statmt.org/europarl/v9/training/europarl-v9.lt-en.tsv.gz" : "datasets/third_party_mirror/europarl-v9.lt-en.tsv-a6343d8fc158f44714ea7d01c0eb65b34640841d.gz", + "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz" : "datasets/third_party_mirror/training-parallel-commoncrawl-1c0ad85f0ebaf1d543acb009607205f5dae6627d.tgz", + "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz" : "datasets/third_party_mirror/training-parallel-nc-v9-c7ae7f50cd45c2f3014d78ddba25a4a8a851e27a.tgz", + "http://www.statmt.org/wmt15/training-parallel-nc-v10.tgz" : "datasets/third_party_mirror/training-parallel-nc-v10-6c3c45b0f34d5e84a4d0b75a5edcca226ba7d6c2.tgz", + "http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz" : 
"datasets/third_party_mirror/training-parallel-nc-v11-f51a1f03908e790d23d10001e92e09ce9555a790.tgz", + "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz" : "datasets/third_party_mirror/training-parallel-nc-v12-d98afc59e1d753485530b377ff65f1f891d3bced.tgz", + "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz" : "datasets/third_party_mirror/training-parallel-nc-v13-cbaa7834e58d36f228336e3caee6a9056029ff5d.tgz", + "http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.de-en.tsv.gz" : "datasets/third_party_mirror/news-commentary-v14.de-en.tsv-c1fd94c7c9ff222968cbd45100bdd8dbeb5ab2aa.gz", + "http://data.statmt.org/news-commentary/v14/training/news-commentary-v14.en-zh.tsv.gz" : "datasets/third_party_mirror/news-commentary-v14.en-zh.tsv-4ca5c01deeba5425646d42f9598d081cd662908b.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.cs-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.cs-en.tsv-6e094d218dfd8f987fa1a18ea7b4cb127cfb1763.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.cs-pl.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.cs-pl.tsv-dc93d346d151bf73e4165d6db425b903fc21a5b0.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.de-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.de-en.tsv-e141c55c43a474e06c259c3fa401288b39cd4315.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.es-pt.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.es-pt.tsv-c3bd398d57471ee4ab33323393977b8d475a368c.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.fi-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.fi-en.tsv-5668b004567ca286d1aad9c2b45862a441d79667.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.gu-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.gu-en.tsv-95b9f15b6a86bfed6dc9bc91597368fd334f436e.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.hi-ne.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.hi-ne.tsv-6d63908950c72bc8cc69ca470deccff11354afc2.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.kk-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.kk-en.tsv-56ee1e450ef98fe92ea2116c3ce7acc7c7c42b39.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.lt-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.lt-en.tsv-b8829928686727165eec6c591d2875d12d7c0cfe.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.ru-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.ru-en.tsv-16d8d231fdf6347b4cc7834654adec80153ff7a4.gz", + "http://data.statmt.org/wikititles/v1/wikititles-v1.zh-en.tsv.gz" : "datasets/third_party_mirror/wikititles-v1.zh-en.tsv-5829097ff7dd61752f29fb306b04d790a1a1cfd7.gz", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.00" : "datasets/third_party_mirror/UNv1.0.en-ru-98c4e01e16070567d27da0ab4fe401f309dd3678.tar.gz.00", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.01" : "datasets/third_party_mirror/UNv1.0.en-ru-86c6013dc88f353d2d6e591928e7549060fcb949.tar.gz.01", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-ru.tar.gz.02" : "datasets/third_party_mirror/UNv1.0.en-ru-bf6b18a33c8cafa6889fd463fa8a2850d8877d35.tar.gz.02", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.00" : "datasets/third_party_mirror/UNv1.0.en-zh-1bec5f10297512183e483fdd4984d207700657d1.tar.gz.00", + "https://stuncorpusprod.blob.core.windows.net/corpusfiles/UNv1.0.en-zh.tar.gz.01" : 
"datasets/third_party_mirror/UNv1.0.en-zh-15df2968bc69ef7662cf3029282bbb62cbf107b1.tar.gz.01", + "http://data.statmt.org/wmt17/translation-task/rapid2016.tgz" : "datasets/third_party_mirror/rapid2016-8b173ce0bc77f2a1a57c8134143e3b5ae228a6e2.tgz", + "http://data.statmt.org/wmt19/translation-task/dev.tgz" : "datasets/third_party_mirror/dev-451ce2cae815c8392212ccb3f54f5dcddb9b2b9e.tgz", + "http://data.statmt.org/wmt19/translation-task/test.tgz" : "datasets/third_party_mirror/test-ce02a36fb2cd41abfa19d36eb8c8d50241ed3346.tgz", + "http://data.statmt.org/news-crawl/de/news.2007.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2007.de.shuffled.deduped-9d746b9df345f764e6e615119113c70e3fb0858c.gz", + "http://data.statmt.org/news-crawl/de/news.2008.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2008.de.shuffled.deduped-185a24e8833844486aee16cb5decf9a64da1c101.gz", + "http://data.statmt.org/news-crawl/de/news.2009.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2009.de.shuffled.deduped-9f7645fc6467de88f4205d94f483194838bad8ce.gz", + "http://data.statmt.org/news-crawl/de/news.2010.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2010.de.shuffled.deduped-f29b761194e9606f086102cfac12813931575818.gz", + "http://data.statmt.org/news-crawl/de/news.2011.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2011.de.shuffled.deduped-613b16e7a1cb8559dd428525a4c3b42c8a4dc278.gz", + "http://data.statmt.org/news-crawl/de/news.2012.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2012.de.shuffled.deduped-1bc419364ea3fe2f9ba4236947c012d4198d9282.gz", + "http://data.statmt.org/news-crawl/de/news.2013.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2013.de.shuffled.deduped-3edd84a7f105907608371c81babc7a9078f40aac.gz", + "http://data.statmt.org/news-crawl/de/news.2014.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2014.de.shuffled.deduped-1466c67b330c08ab5ab7d48e666c1d3a0bb4e479.gz", + "http://data.statmt.org/news-crawl/de/news.2015.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2015.de.shuffled.deduped-2c6d5ec9f8fe51e9eb762be8ff7107c6116c00c4.gz", + "http://data.statmt.org/news-crawl/de/news.2016.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2016.de.shuffled.deduped-e7d235c5d28e36dcf6382f1aa12c6ff37d4529bb.gz", + "http://data.statmt.org/news-crawl/de/news.2017.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2017.de.shuffled.deduped-f70b4a67bc04c0fdc2ec955b737fa22681e8c038.gz", + "http://data.statmt.org/news-crawl/de/news.2018.de.shuffled.deduped.gz" : "datasets/third_party_mirror/news.2018.de.shuffled.deduped-43f8237de1e219276c0682255def13aa2cb80e35.gz" +} \ No newline at end of file diff --git a/scripts/machine_translation/README.md b/scripts/machine_translation/README.md index 8b5d0695f1..061ba0658d 100644 --- a/scripts/machine_translation/README.md +++ b/scripts/machine_translation/README.md @@ -10,38 +10,113 @@ You may first run the following command in [datasets/machine_translation](../dat bash wmt2014_ende.sh yttm ``` -Then, you can run the experiment, we use the -"transformer_base" configuration. +Then, you can run the experiment. 
+For the "transformer_base" configuration:
 
 ```bash
-SUBWORD_MODEL=yttm
+SUBWORD_ALGO=yttm
+SRC=en
+TGT=de
+datapath=../datasets/machine_translation
 python train_transformer.py \
-    --train_src_corpus ../datasets/machine_translation/wmt2014_ende/train.tok.${SUBWORD_MODEL}.en \
-    --train_tgt_corpus ../datasets/machine_translation/wmt2014_ende/train.tok.${SUBWORD_MODEL}.de \
-    --dev_src_corpus ../datasets/machine_translation/wmt2014_ende/dev.tok.${SUBWORD_MODEL}.en \
-    --dev_tgt_corpus ../datasets/machine_translation/wmt2014_ende/dev.tok.${SUBWORD_MODEL}.de \
+    --train_src_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${SRC} \
+    --train_tgt_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${TGT} \
+    --dev_src_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${SRC} \
+    --dev_tgt_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${TGT} \
+    --src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
+    --src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
+    --tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
+    --tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
+    --save_dir transformer_base_wmt2014_en_de_${SUBWORD_ALGO} \
+    --cfg transformer_base \
+    --lr 0.002 \
+    --batch_size 2700 \
+    --num_averages 5 \
+    --warmup_steps 4000 \
+    --warmup_init_lr 0.0 \
+    --seed 123 \
+    --gpus 0,1,2,3
+```
+
+Use the average_checkpoint CLI to average the last 10 checkpoints:
+
+```bash
+gluon_average_checkpoint --checkpoints transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/epoch*.params \
+    --begin 21 \
+    --end 30 \
+    --save-path transformer_base_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params
+```
+
+Use the following command to inference/evaluate the Transformer model:
+
+```bash
+SUBWORD_MODEL=yttm
+python evaluate_transformer.py \
+    --param_path transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/avg_21_30.params \
+    --src_lang en \
+    --tgt_lang de \
+    --cfg transformer_base_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \
+    --src_tokenizer ${SUBWORD_MODEL} \
+    --tgt_tokenizer ${SUBWORD_MODEL} \
     --src_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \
-    --src_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \
     --tgt_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \
+    --src_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \
     --tgt_vocab_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.vocab \
-    --save_dir transformer_wmt2014_ende_${SUBWORD_MODEL} \
-    --cfg transformer_base \
-    --lr 0.002 \
+    --src_corpus ../datasets/machine_translation/wmt2014_ende/test.raw.en \
+    --tgt_corpus ../datasets/machine_translation/wmt2014_ende/test.raw.de
+```
+
+For the "transformer_wmt_en_de_big" configuration:
+
+```bash
+SUBWORD_ALGO=yttm
+SRC=en
+TGT=de
+datapath=../datasets/machine_translation
+python train_transformer.py \
+    --train_src_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${SRC} \
+    --train_tgt_corpus ${datapath}/wmt2014_ende/train.tok.${SUBWORD_ALGO}.${TGT} \
+    --dev_src_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${SRC} \
+    --dev_tgt_corpus ${datapath}/wmt2014_ende/dev.tok.${SUBWORD_ALGO}.${TGT} \
+    --src_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
+    --src_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
+    --tgt_subword_model_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.model \
+    --tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \
+    --save_dir transformer_big_wmt2014_en_de_${SUBWORD_ALGO} \
+    --cfg transformer_wmt_en_de_big \
+    --lr 0.001 \
+    --sampler BoundedBudgetSampler \
+    --max_num_tokens 3584 \
+    --max_update 15000 \
     --warmup_steps 4000 \
     --warmup_init_lr 0.0 \
     --seed 123 \
     --gpus 0,1,2,3
 ```
 
+Use the average_checkpoint CLI to average the last 10 checkpoints:
+
+```bash
+gluon_average_checkpoint --checkpoints transformer_big_wmt2014_en_de_${SUBWORD_ALGO}/update*.params \
+    --begin 21 \
+    --end 30 \
+    --save-path transformer_big_wmt2014_en_de_${SUBWORD_ALGO}/avg_21_30.params
+```
+
 Use the following command to inference/evaluate the Transformer model:
 
 ```bash
 SUBWORD_MODEL=yttm
 python evaluate_transformer.py \
-    --param_path transformer_wmt2014_ende_${SUBWORD_MODEL}/average.params \
+    --param_path transformer_big_wmt2014_en_de_${SUBWORD_MODEL}/avg_21_30.params \
     --src_lang en \
     --tgt_lang de \
-    --cfg transformer_wmt2014_ende_${SUBWORD_MODEL}/config.yml \
+    --cfg transformer_big_wmt2014_en_de_${SUBWORD_MODEL}/config.yml \
     --src_tokenizer ${SUBWORD_MODEL} \
     --tgt_tokenizer ${SUBWORD_MODEL} \
     --src_subword_model_path ../datasets/machine_translation/wmt2014_ende/${SUBWORD_MODEL}.model \
@@ -59,6 +134,14 @@ Test BLEU score with 3 seeds (evaluated via sacre BLEU):
 
 | Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
 |---------------|------------|-------------|-------------|--------------|-------------|
-| yttm | | 26.63 | 26.73 | | - |
+| yttm | | - | - | - | - |
+| hf_bpe | | - | - | - | - |
+| spm | | - | - | - | - |
+
+- transformer_wmt_en_de_big
+
+| Subword Model | #Params | Seed = 123 | Seed = 1234 | Seed = 12345 | Mean±std |
+|---------------|------------|-------------|-------------|--------------|-------------|
+| yttm | | 27.99 | - | - | - |
 | hf_bpe | | - | - | - | - |
 | spm | | - | - | - | - |
diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py
index 3b0a6565e5..2a3ef665ab 100644
--- a/scripts/machine_translation/train_transformer.py
+++ b/scripts/machine_translation/train_transformer.py
@@ -44,7 +44,7 @@
 from mxnet import gluon
 from gluonnlp.models.transformer import TransformerModel
 from gluonnlp.utils.misc import logging_config, AverageSGDTracker, count_parameters,\
-    md5sum, grouper
+    md5sum, grouper, init_comm
 from gluonnlp.data.sampler import (
     ConstWidthBucket,
     LinearWidthBucket,
@@ -58,6 +58,11 @@
 from gluonnlp.data.tokenizers import BaseTokenizerWithVocab
 from gluonnlp.lr_scheduler import InverseSquareRootScheduler
 from gluonnlp.loss import LabelSmoothCrossEntropyLoss
+try:
+    import horovod.mxnet as hvd
+except ImportError:
+    hvd = None
+
 
 mx.npx.set_np()
 
@@ -131,9 +136,9 @@ def parse_args():
                              '"exp": the width of bucket increases exponentially')
     parser.add_argument('--bucket_ratio', type=float, default=0.0,
                         help='Ratio for increasing the throughput of the bucketing')
-    parser.add_argument('--max_tokens', type=int, default=-1,
+    parser.add_argument('--max_num_tokens', type=int, default=-1,
                         help='max tokens num of each batch, applicable while using BoundedBudgetSampler')
-    parser.add_argument('--max_sentences', type=int, default=-1,
+    parser.add_argument('--max_num_sentences', type=int, default=-1,
                         help='max sentences num of each batch, applicable while using BoundedBudgetSampler')
     parser.add_argument('--lr', type=float, default=0.002,
                         help='The learning rate at the end of the warmup stage. 
' @@ -151,10 +156,10 @@ def parse_args(): 'This is useful to mimic large batch training with limited gpu memory') parser.add_argument('--magnitude', type=float, default=3.0, help='Magnitude of Xavier initialization') - parser.add_argument('--num_averages', type=int, default=5, + parser.add_argument('--num_averages', type=int, default=-1, help='Perform final testing based on the ' 'average of last num_averages checkpoints. ' - 'This is only used if average_checkpoint is True') + 'Use num_average will cause extra gpu memory usage.') parser.add_argument('--log_interval', type=int, default=10, metavar='N', help='report interval') parser.add_argument('--save_dir', type=str, default='transformer_out', @@ -162,6 +167,9 @@ def parse_args(): parser.add_argument('--overwrite_cache', action='store_true') parser.add_argument('--fp16', action='store_true', help='Whether to use dtype float16') + parser.add_argument('--comm_backend', type=str, default='device', + choices=['horovod', 'dist_sync_device', 'device'], + help='Communication backend.') parser.add_argument('--gpus', type=str, help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.') args = parser.parse_args() @@ -280,6 +288,8 @@ def create_tokenizer(tokenizer_type, model_path, vocab_path): def train(args): + store, num_parts, rank, local_rank, is_master_node, ctx_l = init_comm( + args.comm_backend, args.gpus) src_tokenizer = create_tokenizer(args.src_tokenizer, args.src_subword_model_path, args.src_vocab_path) @@ -304,8 +314,6 @@ def train(args): data_val = gluon.data.SimpleDataset( [(src_tokens, tgt_tokens, len(src_tokens), len(tgt_tokens), i) for i, (src_tokens, tgt_tokens) in enumerate(zip(dev_src_data, dev_tgt_data))]) - ctx_l = [mx.cpu()] if args.gpus is None or args.gpus == ''\ - else [mx.gpu(int(x)) for x in args.gpus.split(',')] # Construct the model + loss function if args.cfg.endswith('.yml'): cfg = TransformerModel.get_cfg().clone_merge(args.cfg) @@ -322,7 +330,8 @@ def train(args): model.initialize(mx.init.Xavier(magnitude=args.magnitude), ctx=ctx_l) model.hybridize() - logging.info(model) + if local_rank == 0: + logging.info(model) with open(os.path.join(args.save_dir, 'config.yml'), 'w') as cfg_f: cfg_f.write(cfg.dump()) label_smooth_loss = LabelSmoothCrossEntropyLoss(num_labels=len(tgt_vocab), @@ -330,6 +339,10 @@ def train(args): from_logits=False) label_smooth_loss.hybridize() rescale_loss = 100.0 + + if args.comm_backend == 'horovod': + hvd.broadcast_parameters(model.collect_params(), root_rank=0) + # Construct the trainer # TODO(sxjscience) Support AMP if args.lr is None: @@ -338,16 +351,25 @@ def train(args): base_lr = args.lr lr_scheduler = InverseSquareRootScheduler(warmup_steps=args.warmup_steps, base_lr=base_lr, warmup_init_lr=args.warmup_init_lr) - trainer = gluon.Trainer(model.collect_params(), 'adam', + trainer_settings = (model.collect_params(), 'adam', {'learning_rate': args.lr, 'beta1': 0.9, 'beta2': 0.98, 'epsilon': 1e-9, 'lr_scheduler': lr_scheduler}) + if args.comm_backend == 'horovod': + trainer = hvd.DistributedTrainer(*trainer_settings) + else: + trainer = gluon.Trainer(*trainer_settings) # Load Data if args.sampler == 'BoundedBudgetSampler': train_batch_sampler = BoundedBudgetSampler(lengths=[(ele[2], ele[3]) for ele in data_train], - max_tokens=args.max_tokens, - max_sentences=args.max_sentences, - seed=args.seed) + max_num_tokens=args.max_num_tokens, + max_num_sentences=args.max_num_sentences, + seed=args.seed, + num_parts=num_parts, + part_index=rank) elif args.sampler == 'FixedBucketSampler': + 
if args.comm_backend == 'horovod': + raise NotImplementedError('FixedBucketSampler does not support horovod at present') + if args.bucket_scheme == 'constant': bucket_scheme = ConstWidthBucket() elif args.bucket_scheme == 'linear': @@ -368,12 +390,15 @@ def train(args): else: raise NotImplementedError + if local_rank == 0: + logging.info(train_batch_sampler) + batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack()) train_data_loader = gluon.data.DataLoader(data_train, batch_sampler=train_batch_sampler, batchify_fn=batchify_fn, num_workers=0) - logging.info(train_batch_sampler) + val_data_loader = gluon.data.DataLoader(data_val, batch_size=args.val_batch_size, batchify_fn=batchify_fn, @@ -432,7 +457,7 @@ def train(args): sample_data_l = next(train_multi_data_loader) except StopIteration: is_last_batch = True - if num_params is None: + if local_rank == 0 and num_params is None: num_params, num_fixed_params = count_parameters(model.collect_params()) logging.info('Total Number of Parameters (not-fixed/fixed): {}/{}' .format(num_params, num_fixed_params)) @@ -450,27 +475,29 @@ def train(args): if (args.epochs > 0 and epoch_id >= args.epochs - args.num_averages) or \ (args.max_update > 0 and n_train_iters >= args.max_update - args.num_averages * args.save_interval_update): model_averager.step() - if n_epoch_train_iters % args.log_interval == 0: + if local_rank == 0 and \ + (n_epoch_train_iters % args.log_interval == 0 or is_last_batch): log_end_time = time.time() log_wc = log_wc.asnumpy() wps = log_wc / (log_end_time - log_start_time) log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy() logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' 'throughput={:.2f}K wps, wc={:.2f}K, LR={}' - .format(epoch_id, processed_batch_num, len(train_data_loader), + .format(epoch_id, processed_batch_num * num_parts, len(train_data_loader), log_avg_loss, np.exp(log_avg_loss), wps / 1000, log_wc / 1000, trainer.learning_rate)) log_start_time = time.time() log_avg_loss = 0 log_loss_denom = 0 log_wc = 0 - if args.max_update > 0 and n_train_iters % args.save_interval_update == 0: + if local_rank == 0 and \ + (args.max_update > 0 and n_train_iters % args.save_interval_update == 0): model.save_parameters(os.path.join(args.save_dir, - '{:d}.params'.format(n_train_iters // args.save_interval_update)), + 'update{:d}.params'.format(n_train_iters // args.save_interval_update)), deduplicate=True) if args.max_update > 0 and n_train_iters >= args.max_update: break - if args.epochs > 0: + if local_rank == 0 and args.epochs > 0: model.save_parameters(os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)), deduplicate=True) diff --git a/scripts/machine_translation/wmt2014_back_translation.sh b/scripts/machine_translation/wmt2014_back_translation.sh index 9e12f8be3c..db7b702e52 100644 --- a/scripts/machine_translation/wmt2014_back_translation.sh +++ b/scripts/machine_translation/wmt2014_back_translation.sh @@ -126,13 +126,16 @@ python train_transformer.py \ --tgt_vocab_path ${datapath}/wmt2014_ende/${SUBWORD_ALGO}.vocab \ --save_dir backtranslation_transformer_wmt2014_ende_${SUBWORD_ALGO} \ --cfg transformer_base \ - --lr 0.002 \ - --batch_size 2700 \ - --max_update 60000 \ + --lr 0.003 \ + --max_num_tokens 4096 \ + --sampler BoundedBudgetSampler \ + --comm_backend horovod \ + --max_update 30000 \ --save_interval_update 1000 \ - --warmup_steps 4000 \ + --warmup_steps 6000 \ --warmup_init_lr 0.0 \ - --seed 100 \ + --num_averages -1 \ + --seed 123 \ --gpus 0,1,2,3 # TODO 
nlp_average_checkpoint @@ -142,7 +145,7 @@ nlp_nmt average_checkpoint --prefix range() \ # Finally, we can evaluate the model python evaluate_transformer.py \ - --param_path backtranslation_transformer_wmt2014_ende_${SUBWORD_ALGO}/average.params \ + --param_path backtranslation_transformer_wmt2014_ende_${SUBWORD_ALGO}/avg_20_29.params \ --src_lang ${SRC} \ --tgt_lang ${TGT} \ --cfg transformer_base \ diff --git a/src/gluonnlp/cli/average_checkpoint.py b/src/gluonnlp/cli/average_checkpoint.py index 8a4ce86b63..a660244bfe 100644 --- a/src/gluonnlp/cli/average_checkpoint.py +++ b/src/gluonnlp/cli/average_checkpoint.py @@ -1,39 +1,44 @@ import argparse import mxnet as mx +import re mx.npx.set_np() def get_parser(): parser = argparse.ArgumentParser(description='Script to average the checkpoints') - parser.add_argument('--checkpoints', type=str, required=True, - help='path of checkpoints, use * to represent the numbers, ' - 'e.g. --checkpoints folder/epoch*.prams') - parser.add_argument('--range', type=str, nargs='+', required=True, - help='number of checkpoints, supports range and list format at present, ' - 'e.g. --range range(3) [4,7, 5] range(8,100,2)') + parser.add_argument('--checkpoints', type=str, required=True, nargs='+', + help='checkpoint file paths, supports two format, ' + '--checkpoints folder/epoch*.params or --checkpoints folder/update*.param') + parser.add_argument('--begin', type=int, required=True, help='begin number of checkpoints') + parser.add_argument('--end', type=int, required=True, help='end number of checkpoints') parser.add_argument('--save-path', type=str, required=True, help='Path of the output file') return parser def main(args): - temp_range = [] - try: - for r in args.range: - if len(r) > 5 and r[:5] == 'range': - r = r[5:].strip()[1:-1].split(',') - r = tuple([int(n.strip()) for n in r]) - assert len(r) >= 1 and len(r) <= 3 - temp_range.extend(range(*r)) - elif r[0] == '[' and r[-1] == ']': - r = r[1:-1].split(',') - r = [int(n.strip()) for n in r] - temp_range.extend(r) - else: - raise NotImplementedError - except: - raise Exception('wrong range format') - args.range = temp_range - ckpt_paths = [args.checkpoints.replace('*', str(i)) for i in args.range] + assert args.begin >= 0 + assert args.end >= args.begin + args.range = list(range(args.begin, args.end + 1)) + + ckpt_epochs_regexp = re.compile(r'(.*\/)?epoch(\d+)\.params') + ckpt_updates_regexp = re.compile(r'(.*\/)?update(\d+)\.params') + ckpt_path = args.checkpoints[0] + if ckpt_epochs_regexp.fullmatch(ckpt_path) is not None: + ckpt_regexp = ckpt_epochs_regexp + elif ckpt_updates_regexp.fullmatch(ckpt_path) is not None: + ckpt_regexp = ckpt_updates_regexp + else: + raise Exception('Wrong checkpoints path format') + + ckpt_paths = [] + for path in args.checkpoints: + m = ckpt_regexp.fullmatch(path) + assert m is not None, 'Wrong checkpoints path format' + num = int(m.group(2)) + if num >= args.begin and num <= args.end: + ckpt_paths.append(path) + + assert len(ckpt_paths) > 0 res = mx.npx.load(ckpt_paths[0]) keys = res.keys() for ckpt_path in ckpt_paths[1:]: diff --git a/src/gluonnlp/data/sampler.py b/src/gluonnlp/data/sampler.py index 5879ea453a..6d817dcdef 100644 --- a/src/gluonnlp/data/sampler.py +++ b/src/gluonnlp/data/sampler.py @@ -285,14 +285,20 @@ class BoundedBudgetSampler(BaseSampler): Whether to shuffle the batches. 
seed The seed of the sampler + num_parts + Number of partitions which the data is split into (default: 1) + part_index + The index of the part to read from """ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], max_num_tokens: int = -1, max_num_sentences: int = -1, required_batch_size_multiple: int = 1, - shuffle: bool = False, seed: Optional[int] = None): + shuffle: bool = True, seed: Optional[int] = None, + num_parts: int = 1, part_index: int = 0): assert len(lengths) > 0, 'BoundedBudgetSampler does not support empty lengths.' assert max_num_tokens > 0 or max_num_sentences > 0, \ 'One of max_num_tokens and max_num_sentences must be larger than 0' + assert part_index < num_parts, 'part_index should be less than num_parts' self._lengths = np.array(lengths) if self._lengths.ndim == 2: self._lengths = self._lengths.max(axis=1) @@ -302,6 +308,8 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], self._batches = [] self._shuffle = shuffle self._rng = np.random.RandomState(seed) + self._num_parts = num_parts + self._part_index = part_index # sort self._indices = self._indices[np.argsort(self._lengths, kind='mergesort')] batch = [] @@ -332,7 +340,11 @@ def __init__(self, lengths: Union[Sequence[int], Sequence[Sequence[int]]], def __iter__(self): if self._shuffle: self._rng.shuffle(self._batches) - for batch in self._batches: + part_batches = [] + for i in range(len(self._batches)): + if i % self._num_parts == self._part_index: + part_batches.append(self._batches[i]) + for batch in part_batches: yield batch def __len__(self):
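
Note on the `num_parts` / `part_index` arguments added to `BoundedBudgetSampler` above: batch `i` is yielded only by the worker whose `part_index` equals `i % num_parts`, so Horovod ranks iterate over disjoint shards of the batches. Below is a minimal sketch of that behaviour, assuming the patched gluonnlp from this PR is importable; the toy sentence lengths and the token budget are made up for illustration.

```python
# Minimal sketch of how num_parts/part_index shard batches across workers.
# Assumes the patched gluonnlp from this PR is importable; the toy sentence
# lengths and the token budget below are made up for illustration.
from gluonnlp.data.sampler import BoundedBudgetSampler

# Toy corpus: (src_len, tgt_len) for 10 sentence pairs.
lengths = [(5, 7), (8, 6), (12, 11), (3, 4), (9, 10),
           (15, 14), (7, 7), (6, 5), (11, 13), (4, 4)]

num_parts = 2  # e.g. hvd.size() when launched through horovodrun
samplers = [
    BoundedBudgetSampler(lengths=lengths,
                         max_num_tokens=24,       # token budget per batch
                         shuffle=False,           # deterministic output for the demo
                         num_parts=num_parts,
                         part_index=part_index)   # e.g. hvd.rank()
    for part_index in range(num_parts)
]

# Batch i is kept only by the worker with i % num_parts == part_index,
# so the two workers see disjoint subsets of the batches.
for part_index, sampler in enumerate(samplers):
    print('part_index', part_index, ':', [list(batch) for batch in sampler])
```

Because every rank constructs its sampler with the same seed, the shuffle in `__iter__` produces the same batch order on all ranks before each rank takes its own round-robin slice, which keeps the per-rank shards disjoint across an epoch.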