diff --git a/README.md b/README.md index 253245561b28..87afa88b5223 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ This implementation is provided with [Google's pre-trained models](https://githu ## Installation -This repo was tested on Python 3.6+ and PyTorch 0.4.1 +This repo was tested on Python 3.5+ and PyTorch 0.4.1/1.0.0 ### With pip @@ -46,13 +46,13 @@ python -m pytest -sv tests/ This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme: -- Seven PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file): +- Eight PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file): - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L537) - raw BERT Transformer model (**fully pre-trained**), - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L691) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**), - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**), - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**), - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**), - - [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**), + - [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for task like Swag) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**), - [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L949) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**), - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1015) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**). 
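To make the newly listed `BertForMultipleChoice` class concrete, here is a minimal usage sketch (not part of the diff): it assumes the `bert-base-uncased` weights can be downloaded, reuses the `num_choices=4` setting from the SWAG script further down, and uses purely illustrative tensor shapes.

```python
# Minimal sketch: load the new multiple-choice head and run a forward pass.
import torch
from pytorch_pretrained_bert.modeling import BertForMultipleChoice

model = BertForMultipleChoice.from_pretrained('bert-base-uncased', num_choices=4)
model.eval()

# Multiple-choice inputs are shaped (batch_size, num_choices, sequence_length);
# zeros/ones stand in for real ids and masks produced by BertTokenizer.
input_ids = torch.zeros(1, 4, 32, dtype=torch.long)
token_type_ids = torch.zeros(1, 4, 32, dtype=torch.long)
input_mask = torch.ones(1, 4, 32, dtype=torch.long)

with torch.no_grad():
    logits = model(input_ids, token_type_ids, input_mask)
print(logits.shape)  # (batch_size, num_choices) classification scores
```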
@@ -156,7 +156,7 @@ Here is a detailed documentation of the classes in the package and how to use th | Sub-section | Description | |-|-| | [Loading Google AI's pre-trained weigths](#Loading-Google-AIs-pre-trained-weigths-and-PyTorch-dump) | How to load Google AI's pre-trained weight or a PyTorch saved instance | -| [PyTorch models](#PyTorch-models) | API of the seven PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering` | +| [PyTorch models](#PyTorch-models) | API of the eight PyTorch model classes: `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForMultipleChoice` or `BertForQuestionAnswering` | | [Tokenizer: `BertTokenizer`](#Tokenizer-BertTokenizer) | API of the `BertTokenizer` class| | [Optimizer: `BertAdam`](#Optimizer-BertAdam) | API of the `BertAdam` class | @@ -170,7 +170,7 @@ model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None) where -- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the seven PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification` or `BertForQuestionAnswering`, and +- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the eight PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification`, `BertForTokenClassification`, `BertForMultipleChoice` or `BertForQuestionAnswering`, and - `PRE_TRAINED_MODEL_NAME_OR_PATH` is either: - the shortcut name of a Google AI's pre-trained model selected in the list: @@ -353,14 +353,13 @@ The optimizer accepts the following arguments: BERT-base and BERT-large are respectively 110M and 340M parameters models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most case a batch size of 32). -To help with fine-tuning these models, we have included five techniques that you can activate in the fine-tuning scripts [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py): gradient-accumulation, multi-gpu training, distributed training, optimize on CPU and 16-bits training . For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month. +To help with fine-tuning these models, we have included several techniques that you can activate in the fine-tuning scripts [`run_classifier.py`](./examples/run_classifier.py) and [`run_squad.py`](./examples/run_squad.py): gradient-accumulation, multi-gpu training, distributed training and 16-bits training. For more details on how to use these techniques you can read [the tips on training large batches in PyTorch](https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255) that I published earlier this month.
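Before the per-flag walkthrough below, here is what the gradient-accumulation technique boils down to in plain PyTorch. This is a self-contained toy sketch, not code from the repository: a linear model and SGD stand in for BERT and `BertAdam`.

```python
import torch
from torch import nn, optim
import torch.nn.functional as F

# Toy stand-ins so the sketch runs on its own.
model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

gradient_accumulation_steps = 4  # what --gradient_accumulation_steps 4 requests

for step, (x, y) in enumerate(batches):
    loss = F.mse_loss(model(x), y)
    loss = loss / gradient_accumulation_steps  # average the loss over the accumulated mini-batches
    loss.backward()                            # gradients from each mini-batch add up in .grad
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()                       # one parameter update every gradient_accumulation_steps batches
        optimizer.zero_grad()
```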
Here is how to use these techniques in our scripts: - **Gradient Accumulation**: Gradient accumulation can be used by supplying a integer greater than 1 to the `--gradient_accumulation_steps` argument. The batch at each step will be divided by this integer and gradient will be accumulated over `gradient_accumulation_steps` steps. - **Multi-GPU**: Multi-GPU is automatically activated when several GPUs are detected and the batches are splitted over the GPUs. - **Distributed training**: Distributed training can be activated by supplying an integer greater or equal to 0 to the `--local_rank` argument (see below). -- **Optimize on CPU**: The Adam optimizer stores 2 moving average of the weights of the model. If you keep them on GPU 1 (typical behavior), your first GPU will have to store 3-times the size of the model. This is not optimal for large models like `BERT-large` and means your batch size is a lot lower than it could be. This option will perform the optimization and store the averages on the CPU/RAM to free more room on the GPU(s). As the most computational intensive operation is usually the backward pass, this doesn't have a significant impact on the training time. Activate this option with `--optimize_on_cpu` on the [`run_squad.py`](./examples/run_squad.py) script. - **16-bits training**: 16-bits training, also called mixed-precision training, can reduce the memory requirement of your model on the GPU by using half-precision training, basically allowing to double the batch size. If you have a recent GPU (starting from NVIDIA Volta architecture) you should see no decrease in speed. A good introduction to Mixed precision training can be found [here](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) and a full documentation is [here](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html). In our scripts, this option can be activated by setting the `--fp16` flag and you can play with loss scaling using the `--loss_scaling` flag (see the previously linked documentation for details on loss scaling). If the loss scaling is too high (`Nan` in the gradients) it will be automatically scaled down until the value is acceptable. The default loss scaling is 128 which behaved nicely in our tests. Note: To use *Distributed Training*, you will need to run one training script on each of your machines. This can be done for example by running the following command on each server (see [the above mentioned blog post]((https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255)) for more details): @@ -371,16 +370,21 @@ Where `$THIS_MACHINE_INDEX` is an sequential index assigned to each of your mach ### Fine-tuning with BERT: running the examples -We showcase the same examples as [the original implementation](https://github.com/google-research/bert/): fine-tuning a sequence-level classifier on the MRPC classification corpus and a token-level classifier on the question answering dataset SQuAD. +We showcase several fine-tuning examples based on (and extended from) [the original implementation](https://github.com/google-research/bert/): -Before running these examples you should download the +- a *sequence-level classifier* on the MRPC classification corpus, +- a *token-level classifier* on the question answering dataset SQuAD, and +- a *sequence-level multiple-choice classifier* on the SWAG classification corpus. 
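The 16-bits training option described in the list above is wired up the same way in every example script this diff touches. The sketch below distills that wiring; it assumes NVIDIA apex is installed and simply mirrors the `FusedAdam`/`FP16_Optimizer` hunks added to `run_classifier.py`, `run_squad.py` and `run_swag.py` further down (the helper name is made up for illustration).

```python
def build_fp16_optimizer(optimizer_grouped_parameters, learning_rate, loss_scale=0):
    """Build the --fp16 optimizer the way the example scripts do (requires apex).
    The scripts also call model.half() beforehand to keep the weights in half precision."""
    from apex.optimizers import FP16_Optimizer, FusedAdam  # imported lazily, like the scripts

    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=learning_rate,
                          bias_correction=False,
                          max_grad_norm=1.0)
    if loss_scale == 0:
        # dynamic loss scaling: apex adjusts the scale automatically when gradients overflow
        return FP16_Optimizer(optimizer, dynamic_loss_scale=True)
    # static loss scaling with a fixed, user-chosen value (a positive power of 2)
    return FP16_Optimizer(optimizer, static_loss_scale=loss_scale)
```

In the training loop the `--fp16` branch then calls `optimizer.backward(loss)` instead of `loss.backward()`, followed by `optimizer.step()` and `optimizer.zero_grad()`, as the `run_swag.py` hunk below shows.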
+ +#### MRPC + +This example code fine-tunes BERT on the Microsoft Research Paraphrase +Corpus (MRPC) and runs in less than 10 minutes on a single K-80, and in 27 seconds (!) on a single Tesla V100 16GB with apex installed. + +Before running this example you should download the [GLUE data](https://gluebenchmark.com/tasks) by running [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e) -and unpack it to some directory `$GLUE_DIR`. Please also download the `BERT-Base` -checkpoint, unzip it to some directory `$BERT_BASE_DIR`, and convert it to its PyTorch version as explained in the previous section. - -This example code fine-tunes `BERT-Base` on the Microsoft Research Paraphrase -Corpus (MRPC) corpus and runs in less than 10 minutes on a single K-80. +and unpack it to some directory `$GLUE_DIR`. ```shell export GLUE_DIR=/path/to/glue @@ -401,7 +405,29 @@ python run_classifier.py \ Our test ran on a few seeds with [the original implementation hyper-parameters](https://github.com/google-research/bert#sentence-and-sentence-pair-classification-tasks) gave evaluation results between 84% and 88%. -The second example fine-tunes `BERT-Base` on the SQuAD question answering task. +**Fast run with apex and 16 bit precision: fine-tuning on MRPC in 27 seconds!** +First install apex as indicated [here](https://github.com/NVIDIA/apex). +Then run +```shell +export GLUE_DIR=/path/to/glue + +python run_classifier.py \ + --task_name MRPC \ + --do_train \ + --do_eval \ + --do_lower_case \ + --fp16 \ + --data_dir $GLUE_DIR/MRPC/ \ + --bert_model bert-base-uncased \ + --max_seq_length 128 \ + --train_batch_size 32 \ + --learning_rate 2e-5 \ + --num_train_epochs 3.0 \ + --output_dir /tmp/mrpc_output/ +``` + +#### SQuAD + +This example code fine-tunes BERT on the SQuAD dataset. It runs in 24 min (with BERT-base) or 68 min (with BERT-large) on a single Tesla V100 16GB. The data for SQuAD can be downloaded with the following links and should be saved in a `$SQUAD_DIR` directory.
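All three fine-tuning scripts in this diff now save the fine-tuned weights as `pytorch_model.bin` under `--output_dir` and reload them before evaluation (see the script hunks below). Reusing such a checkpoint afterwards looks roughly like the sketch below; the output path matches the MRPC example above, and `num_labels=2` is an assumption for MRPC's binary labels.

```python
# Sketch: reload a fine-tuned checkpoint saved by run_classifier.py, using the
# state_dict argument added to from_pretrained() in this release.
import torch
from pytorch_pretrained_bert.modeling import BertForSequenceClassification

state_dict = torch.load('/tmp/mrpc_output/pytorch_model.bin', map_location='cpu')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                      state_dict=state_dict,
                                                      num_labels=2)
model.eval()
```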
@@ -432,7 +458,9 @@ Training with the previous hyper-parameters gave us the following results: {"f1": 88.52381567990474, "exact_match": 81.22043519394512} ``` -The data for Swag can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf) +#### SWAG + +The data for SWAG can be downloaded by cloning the following [repository](https://github.com/rowanz/swagaf) ```shell export SWAG_DIR=/path/to/SWAG @@ -440,17 +468,18 @@ export SWAG_DIR=/path/to/SWAG python run_swag.py \ --bert_model bert-base-uncased \ --do_train \ + --do_lower_case \ --do_eval \ - --data_dir $SWAG_DIR/data + --data_dir $SWAG_DIR/data \ --train_batch_size 16 \ --learning_rate 2e-5 \ --num_train_epochs 3.0 \ --max_seq_length 80 \ - --output_dir /tmp/swag_output/ + --output_dir /tmp/swag_output/ \ --gradient_accumulation_steps 4 ``` -Training with the previous hyper-parameters gave us the following results: +Training with the previous hyper-parameters on a single GPU gave us the following results: ``` eval_accuracy = 0.8062081375587323 eval_loss = 0.5966546792367169 diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 000000000000..e47eb548f9a1 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,7 @@ +FROM pytorch/pytorch:latest + +RUN git clone https://github.com/NVIDIA/apex.git && cd apex && python setup.py install --cuda_ext --cpp_ext + +RUN pip install pytorch-pretrained-bert + +WORKDIR /workspace \ No newline at end of file diff --git a/examples/extract_features.py b/examples/extract_features.py index dbab934c0813..4f8812121ea1 100644 --- a/examples/extract_features.py +++ b/examples/extract_features.py @@ -168,7 +168,7 @@ def read_examples(input_file): """Read a list of `InputExample`s from an input file.""" examples = [] unique_id = 0 - with open(input_file, "r") as reader: + with open(input_file, "r", encoding='utf-8') as reader: while True: line = reader.readline() if not line: diff --git a/examples/run_classifier.py b/examples/run_classifier.py index a531ea572554..adf81f4e28b6 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -36,13 +36,6 @@ from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE -try: - from apex.optimizers import FP16_Optimizer - from apex.optimizers import FusedAdam - from apex.parallel import DistributedDataParallel as DDP -except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.") - logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', level = logging.INFO) @@ -98,7 +91,7 @@ def get_labels(self): @classmethod def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" - with open(input_file, "r") as f: + with open(input_file, "r", encoding='utf-8') as f: reader = csv.reader(f, delimiter="\t", quotechar=quotechar) lines = [] for line in reader: @@ -329,7 +322,7 @@ def main(): default=None, type=str, required=True, - help="The output directory where the model checkpoints will be written.") + help="The output directory where the model predictions and checkpoints will be written.") ## Other parameters parser.add_argument("--max_seq_length", @@ -420,7 +413,8 @@ def main(): n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') - logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, 
bool(args.local_rank != -1)) + logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( + device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( @@ -467,6 +461,11 @@ def main(): model.half() model.to(device) if args.local_rank != -1: + try: + from apex.parallel import DistributedDataParallel as DDP + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + model = DDP(model) elif n_gpu > 1: model = torch.nn.DataParallel(model) @@ -482,6 +481,12 @@ def main(): if args.local_rank != -1: t_total = t_total // torch.distributed.get_world_size() if args.fp16: + try: + from apex.optimizers import FP16_Optimizer + from apex.optimizers import FusedAdam + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, @@ -546,6 +551,16 @@ def main(): optimizer.zero_grad() global_step += 1 + # Save a trained model + model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self + output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") + torch.save(model_to_save.state_dict(), output_model_file) + + # Load a trained model that you have fine-tuned + model_state_dict = torch.load(output_model_file) + model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict) + model.to(device) + if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): eval_examples = processor.get_dev_examples(args.data_dir) eval_features = convert_examples_to_features( diff --git a/examples/run_squad.py b/examples/run_squad.py index b96fcece37ce..6a97dd300b0e 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -39,13 +39,6 @@ from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE -try: - from apex.optimizers import FP16_Optimizer - from apex.optimizers import FusedAdam - from apex.parallel import DistributedDataParallel as DDP -except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.") - logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S', level = logging.INFO) @@ -115,7 +108,7 @@ def __init__(self, def read_squad_examples(input_file, is_training): """Read a SQuAD json file into a list of SquadExample.""" - with open(input_file, "r") as reader: + with open(input_file, "r", encoding='utf-8') as reader: input_data = json.load(reader)["data"] def is_whitespace(c): @@ -690,7 +683,7 @@ def main(): help="Bert pre-trained model selected in the list: bert-base-uncased, " "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") parser.add_argument("--output_dir", default=None, type=str, required=True, - help="The output directory where the model checkpoints will be written.") + help="The output directory where the model checkpoints and predictions will be written.") ## Other parameters parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. 
E.g., train-v1.1.json") @@ -764,7 +757,7 @@ def main(): n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') - logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits trainiing: {}".format( + logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: @@ -813,6 +806,11 @@ def main(): model.half() model.to(device) if args.local_rank != -1: + try: + from apex.parallel import DistributedDataParallel as DDP + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + model = DDP(model) elif n_gpu > 1: model = torch.nn.DataParallel(model) @@ -834,6 +832,12 @@ def main(): if args.local_rank != -1: t_total = t_total // torch.distributed.get_world_size() if args.fp16: + try: + from apex.optimizers import FP16_Optimizer + from apex.optimizers import FusedAdam + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + optimizer = FusedAdam(optimizer_grouped_parameters, lr=args.learning_rate, bias_correction=False, @@ -911,6 +915,16 @@ def main(): optimizer.zero_grad() global_step += 1 + # Save a trained model + model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self + output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") + torch.save(model_to_save.state_dict(), output_model_file) + + # Load a trained model that you have fine-tuned + model_state_dict = torch.load(output_model_file) + model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict) + model.to(device) + if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0): eval_examples = read_squad_examples( input_file=args.predict_file, is_training=False) diff --git a/examples/run_swag.py b/examples/run_swag.py index 88297bf80107..caddbee8ab7a 100644 --- a/examples/run_swag.py +++ b/examples/run_swag.py @@ -1,5 +1,6 @@ # coding=utf-8 # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -99,7 +100,7 @@ def __init__(self, def read_swag_examples(input_file, is_training): - with open(input_file, 'r') as f: + with open(input_file, 'r', encoding='utf-8') as f: reader = csv.reader(f) lines = list(reader) @@ -232,34 +233,10 @@ def select_field(features, field): for feature in features ] -def copy_optimizer_params_to_model(named_params_model, named_params_optimizer): - """ Utility function for optimize_on_cpu and 16-bits training. - Copy the parameters optimized on CPU/RAM back to the model on GPU - """ - for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model): - if name_opti != name_model: - logger.error("name_opti != name_model: {} {}".format(name_opti, name_model)) - raise ValueError - param_model.data.copy_(param_opti.data) - -def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_nan=False): - """ Utility function for optimize_on_cpu and 16-bits training. 
- Copy the gradient of the GPU parameters to the CPU/RAMM copy of the model - """ - is_nan = False - for (name_opti, param_opti), (name_model, param_model) in zip(named_params_optimizer, named_params_model): - if name_opti != name_model: - logger.error("name_opti != name_model: {} {}".format(name_opti, name_model)) - raise ValueError - if param_model.grad is not None: - if test_nan and torch.isnan(param_model.grad).sum() > 0: - is_nan = True - if param_opti.grad is None: - param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size())) - param_opti.grad.data.copy_(param_model.grad.data) - else: - param_opti.grad = None - return is_nan +def warmup_linear(x, warmup=0.002): + if x < warmup: + return x/warmup + return 1.0 - x def main(): parser = argparse.ArgumentParser() @@ -335,17 +312,15 @@ def main(): type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.") - parser.add_argument('--optimize_on_cpu', - default=False, - action='store_true', - help="Whether to perform optimization and keep the optimizer averages on CPU") parser.add_argument('--fp16', default=False, action='store_true', help="Whether to use 16-bit float precision instead of 32-bit") parser.add_argument('--loss_scale', - type=float, default=128, - help='Loss scaling, positive power of 2 values can improve fp16 convergence.') + type=float, default=0, + help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" + "0 (default value): dynamic loss scaling.\n" + "Positive power of 2: static loss scaling value.\n") args = parser.parse_args() @@ -353,14 +328,13 @@ def main(): device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") n_gpu = torch.cuda.device_count() else: + torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs torch.distributed.init_process_group(backend='nccl') - if args.fp16: - logger.info("16-bits training currently not supported in distributed training") - args.fp16 = False # (see https://github.com/pytorch/pytorch/pull/13496) - logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) + logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format( + device, n_gpu, bool(args.local_rank != -1), args.fp16)) if args.gradient_accumulation_steps < 1: raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( @@ -393,38 +367,55 @@ def main(): # Prepare model model = BertForMultipleChoice.from_pretrained(args.bert_model, cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank), - num_choices = 4 - ) + num_choices=4) if args.fp16: model.half() model.to(device) if args.local_rank != -1: - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], - output_device=args.local_rank) + try: + from apex.parallel import DistributedDataParallel as DDP + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + + model = DDP(model) elif n_gpu > 1: model = torch.nn.DataParallel(model) # Prepare optimizer - if args.fp16: - param_optimizer = [(n, param.clone().detach().to('cpu').float().requires_grad_()) \ - for n, param in model.named_parameters()] - elif args.optimize_on_cpu: - param_optimizer = [(n, 
param.clone().detach().to('cpu').requires_grad_()) \ - for n, param in model.named_parameters()] - else: - param_optimizer = list(model.named_parameters()) - no_decay = ['bias', 'gamma', 'beta'] + param_optimizer = list(model.named_parameters()) + + # hack to remove pooler, which is not used + # thus it produce None grad that break apex + param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]] + + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] optimizer_grouped_parameters = [ - {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, - {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0} + {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, + {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] t_total = num_train_steps if args.local_rank != -1: t_total = t_total // torch.distributed.get_world_size() - optimizer = BertAdam(optimizer_grouped_parameters, - lr=args.learning_rate, - warmup=args.warmup_proportion, - t_total=t_total) + if args.fp16: + try: + from apex.optimizers import FP16_Optimizer + from apex.optimizers import FusedAdam + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") + + optimizer = FusedAdam(optimizer_grouped_parameters, + lr=args.learning_rate, + bias_correction=False, + max_grad_norm=1.0) + if args.loss_scale == 0: + optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) + else: + optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) + else: + optimizer = BertAdam(optimizer_grouped_parameters, + lr=args.learning_rate, + warmup=args.warmup_proportion, + t_total=t_total) global_step = 0 if args.do_train: @@ -461,30 +452,35 @@ def main(): loss = loss * args.loss_scale if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps - loss.backward() tr_loss += loss.item() nb_tr_examples += input_ids.size(0) nb_tr_steps += 1 + + if args.fp16: + optimizer.backward(loss) + else: + loss.backward() if (step + 1) % args.gradient_accumulation_steps == 0: - if args.fp16 or args.optimize_on_cpu: - if args.fp16 and args.loss_scale != 1.0: - # scale down gradients for fp16 training - for param in model.parameters(): - if param.grad is not None: - param.grad.data = param.grad.data / args.loss_scale - is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True) - if is_nan: - logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling") - args.loss_scale = args.loss_scale / 2 - model.zero_grad() - continue - optimizer.step() - copy_optimizer_params_to_model(model.named_parameters(), param_optimizer) - else: - optimizer.step() - model.zero_grad() + # modify learning rate with special warm up BERT uses + lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion) + for param_group in optimizer.param_groups: + param_group['lr'] = lr_this_step + optimizer.step() + optimizer.zero_grad() global_step += 1 + # Save a trained model + model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self + output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") + torch.save(model_to_save.state_dict(), output_model_file) + + # Load a trained model that you have fine-tuned + model_state_dict = 
torch.load(output_model_file) + model = BertForMultipleChoice.from_pretrained(args.bert_model, + state_dict=model_state_dict, + num_choices=4) + model.to(device) + if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True) eval_features = convert_examples_to_features( diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py index e1ecabf31dbd..0ef826374815 100644 --- a/pytorch_pretrained_bert/__init__.py +++ b/pytorch_pretrained_bert/__init__.py @@ -1,3 +1,4 @@ +__version__ = "0.4.0" from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .modeling import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, diff --git a/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py b/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py index 20fdd8c0d6e8..120624bc1b49 100755 --- a/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py +++ b/pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py @@ -50,7 +50,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor name = name.split('/') # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model - if name[-1] in ["adam_v", "adam_m"]: + if any(n in ["adam_v", "adam_m"] for n in name): print("Skipping {}".format("/".join(name))) continue pointer = model @@ -59,9 +59,9 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor l = re.split(r'_(\d+)', m_name) else: l = [m_name] - if l[0] == 'kernel': + if l[0] == 'kernel' or l[0] == 'gamma': pointer = getattr(pointer, 'weight') - elif l[0] == 'output_bias': + elif l[0] == 'output_bias' or l[0] == 'beta': pointer = getattr(pointer, 'bias') elif l[0] == 'output_weights': pointer = getattr(pointer, 'weight') diff --git a/pytorch_pretrained_bert/file_utils.py b/pytorch_pretrained_bert/file_utils.py index 139418f1a544..43fa8ca87e20 100644 --- a/pytorch_pretrained_bert/file_utils.py +++ b/pytorch_pretrained_bert/file_utils.py @@ -227,7 +227,7 @@ def read_set_from_file(filename: str) -> Set[str]: Expected file format is one item per line. ''' collection = set() - with open(filename, 'r') as file_: + with open(filename, 'r', encoding='utf-8') as file_: for line in file_: collection.add(line.rstrip()) return collection diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 1aeff4dd04c7..acdc741f6da4 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -34,9 +34,6 @@ from .file_utils import cached_path -logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt = '%m/%d/%Y %H:%M:%S', - level = logging.INFO) logger = logging.getLogger(__name__) PRETRAINED_MODEL_ARCHIVE_MAP = { @@ -106,7 +103,7 @@ def __init__(self, initializing all weight matrices. 
""" if isinstance(vocab_size_or_config_json_file, str): - with open(vocab_size_or_config_json_file, "r") as reader: + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: json_config = json.loads(reader.read()) for key, value in json_config.items(): self.__dict__[key] = value @@ -137,7 +134,7 @@ def from_dict(cls, json_object): @classmethod def from_json_file(cls, json_file): """Constructs a `BertConfig` from a json file of parameters.""" - with open(json_file, "r") as reader: + with open(json_file, "r", encoding='utf-8') as reader: text = reader.read() return cls.from_dict(json.loads(text)) @@ -448,9 +445,9 @@ def init_bert_weights(self, module): module.bias.data.zero_() @classmethod - def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): """ - Instantiate a PreTrainedBertModel from a pre-trained model file. + Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. Params: @@ -464,6 +461,8 @@ def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwarg - a path or url to a pretrained model archive containing: . `bert_config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models *inputs, **kwargs: additional input for the specific Bert class (ex: num_labels for BertForSequenceClassification) """ @@ -505,22 +504,23 @@ def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwarg logger.info("Model config {}".format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) - weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load(weights_path) + if state_dict is None: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load(weights_path) old_keys = [] new_keys = [] for key in state_dict.keys(): new_key = None if 'gamma' in key: - new_key = key.replace('gamma','weight') + new_key = key.replace('gamma', 'weight') if 'beta' in key: - new_key = key.replace('beta','bias') + new_key = key.replace('beta', 'bias') if new_key: old_keys.append(key) new_keys.append(new_key) for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key]=state_dict.pop(old_key) + state_dict[new_key] = state_dict.pop(old_key) missing_keys = [] unexpected_keys = [] diff --git a/pytorch_pretrained_bert/tokenization.py b/pytorch_pretrained_bert/tokenization.py index c7ef20ddefcb..5954b78f6833 100644 --- a/pytorch_pretrained_bert/tokenization.py +++ b/pytorch_pretrained_bert/tokenization.py @@ -25,9 +25,6 @@ from .file_utils import cached_path -logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt = '%m/%d/%Y %H:%M:%S', - level = logging.INFO) logger = logging.getLogger(__name__) PRETRAINED_VOCAB_ARCHIVE_MAP = { diff --git a/requirements.txt b/requirements.txt index e9a3640a9b3a..f37f11cc540b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ -# This installs Pytorch for CUDA 8 only. If you are using a newer version, -# please visit http://pytorch.org/ and install the relevant version. 
-torch>=0.4.1,<0.5.0 +# PyTorch +torch>=0.4.1 # progress bars in model download and training scripts tqdm # Accessing files from S3 directly. diff --git a/setup.py b/setup.py index fc793b53e695..dbfeb2c6948e 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,47 @@ +""" +Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py + +To create the package for pypi. + +1. Change the version in __init__.py and setup.py. + +2. Commit these changes with the message: "Release: VERSION" + +3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' " + Push the tag to git: git push --tags origin master + +4. Build both the sources and the wheel. Do not change anything in setup.py between + creating the wheel and the source distribution (obviously). + + For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory. + (this will build a wheel for the python version you use to build it - make sure you use python 3.x). + + For the sources, run: "python setup.py sdist" + You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp. + +5. Check that everything looks correct by uploading the package to the pypi test server: + + twine upload dist/* -r pypitest + (pypi suggest using twine as other methods upload files via plaintext.) + + Check that you can install it in a virtualenv by running: + pip install -i https://testpypi.python.org/pypi allennlp + +6. Upload the final version to actual pypi: + twine upload dist/* -r pypi + +7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. + +""" from setuptools import find_packages, setup setup( name="pytorch_pretrained_bert", - version="0.3.0", + version="0.4.0", author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors", author_email="thomas@huggingface.co", description="PyTorch version of Google AI BERT model with script to load Google pre-trained models", - long_description=open("README.md", "r").read(), + long_description=open("README.md", "r", encoding='utf-8').read(), long_description_content_type="text/markdown", keywords='BERT NLP deep learning google', license='Apache', diff --git a/tests/optimization_test.py b/tests/optimization_test.py index ad13c28d0c5e..848b9d1cf5c2 100644 --- a/tests/optimization_test.py +++ b/tests/optimization_test.py @@ -32,7 +32,7 @@ def assertListAlmostEqual(self, list1, list2, tol): def test_adam(self): w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True) target = torch.tensor([0.4, 0.2, -0.5]) - criterion = torch.nn.MSELoss(reduction='elementwise_mean') + criterion = torch.nn.MSELoss() # No warmup, constant schedule, no gradient clipping optimizer = BertAdam(params=[w], lr=2e-1, weight_decay=0.0,