From a0d0b4d2f38ae55a1396dfad4d6bff7cc9435c2d Mon Sep 17 00:00:00 2001
From: keskarnitish
Date: Wed, 30 Oct 2019 11:44:18 -0700
Subject: [PATCH] adding functionality to convert TF model to HuggingFace compatible (PyTorch) one

---
 README.md                            | 19 +++++++
 convert_tf_to_huggingface_pytorch.py | 83 ++++++++++++++++++++++++++++
 2 files changed, 102 insertions(+)
 create mode 100644 convert_tf_to_huggingface_pytorch.py

diff --git a/README.md b/README.md
index 9ea23d6..90ee7fd 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,25 @@ Authors: [Nitish Shirish Keskar](http://keskarnitish.github.io), [Bryan McCann](

## Updates

**Oct 31, 2019**

Adding functionality to convert a model from TF to a HuggingFace/Transformers-compatible (PyTorch) one, in response to [a request](https://github.com/huggingface/transformers/issues/1654). To convert the checkpoint, simply run `python -u convert_tf_to_huggingface_pytorch.py --tf_checkpoint <path_to_tf_data_checkpoint> --pytorch_checkpoint <path_to_write_pytorch_checkpoint>`

Then, to use the converted model with HuggingFace/Transformers:

```
# create folder and contents for HuggingFace/Transformers
mkdir custom_ctrl_model
cd custom_ctrl_model
mv <path_to_converted_pytorch_checkpoint> .
wget -O config.json https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json
wget -O merges.txt https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt
wget -O vocab.json https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json

# run
python examples/run_generation.py --model_type ctrl --model_name_or_path <path_to_custom_ctrl_model>/ --temperature 0 --repetition_penalty 1.2
```

**Oct 21, 2019**

CTRL is now in [huggingface/transformers](https://github.com/huggingface/transformers)!

diff --git a/convert_tf_to_huggingface_pytorch.py b/convert_tf_to_huggingface_pytorch.py
new file mode 100644
index 0000000..6f2ab92
--- /dev/null
+++ b/convert_tf_to_huggingface_pytorch.py
@@ -0,0 +1,83 @@
import tensorflow as tf
import tqdm
import torch
import os
import argparse
import sys

from transformers import CTRLConfig
from transformers import CTRLLMHeadModel, CTRLTokenizer
from tensorflow.python import pywrap_tensorflow

parser = argparse.ArgumentParser(description='Code for converting TF checkpoint to PyTorch')
parser.add_argument('--tf_checkpoint', type=str, required=True,
                    help='location of the .data file of the TensorFlow checkpoint. This is NOT the model folder. This could be <path_to_model_folder>/seqlen256_v1.ckpt/model.ckpt-413000.data-00000-of-00001')
parser.add_argument('--pytorch_checkpoint', type=str, default='pytorch_model.bin',
                    help='location of where to write the PyTorch checkpoint')
parser.add_argument('--num_layers', type=int, default=48,
                    help='number of layers in the model being converted')

args = parser.parse_args()

model = CTRLLMHeadModel(CTRLConfig())

if os.path.isfile(args.tf_checkpoint):
    print('INFO :: Found TensorFlow checkpoint')
else:
    print('INFO :: TensorFlow checkpoint not found. Please verify the location of the .data file or raise a GitHub issue if the problem persists.')
    sys.exit(1)

if os.path.isfile(args.pytorch_checkpoint):
    print('PyTorch model already exists. Will not overwrite. Please delete the old checkpoint or specify a different file name.')
    sys.exit(1)


chkpt_for_reader = '.'.join(args.tf_checkpoint.split('.')[:-1])  # strip the .data-* suffix to get the checkpoint prefix
reader = pywrap_tensorflow.NewCheckpointReader(chkpt_for_reader)

def tensor_read_get(varname, transpose=True):
    loaded_weight = torch.tensor(reader.get_tensor(varname))
    if transpose and len(loaded_weight.shape) > 1:
        return loaded_weight.t()  # TF stores dense kernels as (in, out); PyTorch nn.Linear expects (out, in)
    else:
        return loaded_weight

model.transformer.w.weight.data = tensor_read_get('w', transpose=False)  # shared token embedding, kept as (vocab, d_model)
model.lm_head.bias.data = tensor_read_get('b')
model.transformer.layernorm.weight.data = tensor_read_get('encoder/layer_normalization_96/gamma')
model.transformer.layernorm.bias.data = tensor_read_get('encoder/layer_normalization_96/beta')

list_of_variables = list(filter(lambda x: 'Adagrad' not in x, reader.get_variable_to_shape_map().keys()))  # drop optimizer slot variables

if args.num_layers != 48:
    raise NotImplementedError('Only supports 48 layers at the moment')

for i in tqdm.tqdm(range(args.num_layers)):
    if i == 0:
        layer_variables = sorted(filter(lambda x: 'layer/' in x, list_of_variables))
    else:
        layer_variables = sorted(filter(lambda x: 'layer_' + str(i) + '/' in x, list_of_variables))

    current_layer = model.transformer.h[i]

    current_layer.layernorm1.bias.data = tensor_read_get(layer_variables[0])
    current_layer.layernorm1.weight.data = tensor_read_get(layer_variables[1])

    current_layer.layernorm2.bias.data = tensor_read_get(layer_variables[2])
    current_layer.layernorm2.weight.data = tensor_read_get(layer_variables[3])


    current_layer.multi_head_attention.Wq.bias.data = tensor_read_get(layer_variables[4])
    current_layer.multi_head_attention.Wq.weight.data = tensor_read_get(layer_variables[5])
    current_layer.multi_head_attention.Wk.bias.data = tensor_read_get(layer_variables[6])
    current_layer.multi_head_attention.Wk.weight.data = tensor_read_get(layer_variables[7])
    current_layer.multi_head_attention.Wv.bias.data = tensor_read_get(layer_variables[8])
    current_layer.multi_head_attention.Wv.weight.data = tensor_read_get(layer_variables[9])
    current_layer.multi_head_attention.dense.bias.data = tensor_read_get(layer_variables[10])
    current_layer.multi_head_attention.dense.weight.data = tensor_read_get(layer_variables[11])
    current_layer.ffn[0].bias.data = tensor_read_get(layer_variables[12])
    current_layer.ffn[0].weight.data = tensor_read_get(layer_variables[13])
    current_layer.ffn[2].bias.data = tensor_read_get(layer_variables[14])
    current_layer.ffn[2].weight.data = tensor_read_get(layer_variables[15])

torch.save(model.state_dict(), args.pytorch_checkpoint)
print('INFO :: Saved PyTorch model to', args.pytorch_checkpoint)
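As a quick sanity check of the converted checkpoint, a minimal sketch along these lines should work. It assumes the `custom_ctrl_model/` folder was assembled exactly as in the README snippet above (converted `pytorch_model.bin` plus the downloaded `config.json`, `merges.txt`, and `vocab.json`); the prompt string is arbitrary, with `Links` being one of CTRL's control codes:

```
# Minimal sanity check for the converted checkpoint (assumes the
# custom_ctrl_model/ folder was assembled as described in the README above).
import torch
from transformers import CTRLLMHeadModel, CTRLTokenizer

tokenizer = CTRLTokenizer.from_pretrained('custom_ctrl_model')
model = CTRLLMHeadModel.from_pretrained('custom_ctrl_model')
model.eval()

# Encode a prompt that starts with a CTRL control code and run one forward pass.
input_ids = torch.tensor([tokenizer.encode('Links My neighbor is')])
with torch.no_grad():
    outputs = model(input_ids)
logits = outputs[0]  # next-token logits, shape (batch, sequence_length, vocab_size)
print(logits.shape)
```

If the weights were copied correctly, the logits should be finite and greedy decoding via `examples/run_generation.py` (as in the README command above) should produce coherent text.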