
Commit

adding functionality to convert TF model to HuggingFace compatible (PyTorch) one

keskarnitish committed Oct 30, 2019
1 parent e900f8a commit a0d0b4d
Showing 2 changed files with 102 additions and 0 deletions.
19 changes: 19 additions & 0 deletions README.md
@@ -5,6 +5,25 @@ Authors: [Nitish Shirish Keskar](http://keskarnitish.github.io), [Bryan McCann](

## Updates

**Oct 31, 2019**

Adding functionality to convert a model from TF to HuggingFace/Transformers in response to [a request](https://github.com/huggingface/transformers/issues/1654). To convert the checkpoint, simply run `python -u convert_tf_to_huggingface_pytorch.py --tf_checkpoint <path_to_tensorflow_data_checkpoint> --pytorch_checkpoint <path_to_where_you_want_to_store_pytorch_checkpoint>`.

Then, to use this in HuggingFace:

```
# create folder and contents for HuggingFace/Transformers
mkdir custom_ctrl_model
cd custom_ctrl_model
mv <path_to_pytorch_checkpoint_from_above> .
wget -O config.json https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json
wget -O merges.txt https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-merges.txt
wget -O vocab.json https://raw.githubusercontent.com/salesforce/ctrl/master/ctrl-vocab.json
# run
python examples/run_generation.py --model_type ctrl --model_name <path_to_custom_ctrl_model>/ --temperature 0 --repetition 1.2
```
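Alternatively, once the `custom_ctrl_model` folder above has been assembled, the converted checkpoint can be loaded straight from Python. A minimal sketch (assuming a recent `transformers` release where the CTRL classes and `generate()` are available; the folder name is the one created above):

```
from transformers import CTRLLMHeadModel, CTRLTokenizer

model_dir = 'custom_ctrl_model'  # folder assembled above

tokenizer = CTRLTokenizer.from_pretrained(model_dir)
model = CTRLLMHeadModel.from_pretrained(model_dir)

# CTRL prompts should begin with a control code, e.g. 'Links'
input_ids = tokenizer.encode('Links An article about the solar system.', return_tensors='pt')
output = model.generate(input_ids, max_length=50, repetition_penalty=1.2)
print(tokenizer.decode(output[0]))
```

This is a quick way to sanity-check the conversion without going through `run_generation.py`.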

**Oct 21, 2019**

CTRL is now in [huggingface/transformers](https://github.com/huggingface/transformers)!
83 changes: 83 additions & 0 deletions convert_tf_to_huggingface_pytorch.py
@@ -0,0 +1,83 @@
import tensorflow as tf
import tqdm
import torch
import os
import argparse
import sys

from transformers import CTRLConfig
from transformers import CTRLLMHeadModel, CTRLTokenizer
from tensorflow.python import pywrap_tensorflow

parser = argparse.ArgumentParser(description='Code for converting TF checkpoint to PyTorch')
parser.add_argument('--tf_checkpoint', type=str, required=True,
help='location of the .data file of the TensorFlow checkpoint. This is NOT the model folder. This could be <path>/seqlen256_v1.ckpt/model.ckpt-413000.data-00000-of-00001')
parser.add_argument('--pytorch_checkpoint', type=str, default='pytorch_model.bin',
help='location of where to write the PyTorch checkpoint')
parser.add_argument('--num_layers', type=int, default=48,
help='number of layers in the model being converted')

args = parser.parse_args()

model = CTRLLMHeadModel(CTRLConfig())

if os.path.isfile(args.tf_checkpoint):
    print('INFO :: Found TensorFlow checkpoint')
else:
    print('INFO :: TensorFlow checkpoint not found. Please verify location of the .data file or raise a GitHub issue if the problem persists.')
    sys.exit(1)

if os.path.isfile(args.pytorch_checkpoint):
    print('PyTorch model already exists. Will not overwrite. Please delete the old checkpoint or specify a different file name')
    sys.exit(1)


# The TF checkpoint reader expects the checkpoint prefix (e.g. model.ckpt-413000),
# so strip the trailing '.data-00000-of-00001' shard suffix from the supplied path.
chkpt_for_reader = '.'.join(args.tf_checkpoint.split('.')[:-1])
reader = pywrap_tensorflow.NewCheckpointReader(chkpt_for_reader)

def tensor_read_get(varname, transpose=True):
    # Load a variable from the TF checkpoint. 2-D weights are transposed because
    # TF Dense kernels are stored as (in, out) while torch.nn.Linear expects (out, in).
    loaded_weight = torch.tensor(reader.get_tensor(varname))
    if transpose and len(loaded_weight.shape) > 1:
        return loaded_weight.t()
    else:
        return loaded_weight

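# Weights that live outside the per-layer transformer blocks: the token embedding,
# the LM-head bias, and the final layer normalization.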
model.transformer.w.weight.data = tensor_read_get('w', transpose=False)
model.lm_head.bias.data = tensor_read_get('b')
model.transformer.layernorm.weight.data = tensor_read_get('encoder/layer_normalization_96/gamma')
model.transformer.layernorm.bias.data = tensor_read_get('encoder/layer_normalization_96/beta')

# Collect all model variables from the checkpoint, skipping the Adagrad optimizer slot variables.
list_of_variables = list(filter(lambda x: 'Adagrad' not in x, reader.get_variable_to_shape_map().keys()))

if args.num_layers != 48:
    raise NotImplementedError('Only supports 48 layers at the moment')

for i in tqdm.tqdm(range(args.num_layers)):
    # The first block's variables are named 'layer/...', subsequent ones 'layer_<i>/...'.
    if i == 0:
        layer_variables = sorted(filter(lambda x: 'layer/' in x, list_of_variables))
    else:
        layer_variables = sorted(filter(lambda x: 'layer_'+str(i)+'/' in x, list_of_variables))

    current_layer = model.transformer.h[i]

    # The index-based assignments below rely on the sorted order of a block's variable
    # names: the two layer norms (bias before weight), then the attention projections
    # (Wq, Wk, Wv, output dense), then the two feed-forward layers.
    current_layer.layernorm1.bias.data = tensor_read_get(layer_variables[0])
    current_layer.layernorm1.weight.data = tensor_read_get(layer_variables[1])

    current_layer.layernorm2.bias.data = tensor_read_get(layer_variables[2])
    current_layer.layernorm2.weight.data = tensor_read_get(layer_variables[3])

    current_layer.multi_head_attention.Wq.bias.data = tensor_read_get(layer_variables[4])
    current_layer.multi_head_attention.Wq.weight.data = tensor_read_get(layer_variables[5])
    current_layer.multi_head_attention.Wk.bias.data = tensor_read_get(layer_variables[6])
    current_layer.multi_head_attention.Wk.weight.data = tensor_read_get(layer_variables[7])
    current_layer.multi_head_attention.Wv.bias.data = tensor_read_get(layer_variables[8])
    current_layer.multi_head_attention.Wv.weight.data = tensor_read_get(layer_variables[9])
    current_layer.multi_head_attention.dense.bias.data = tensor_read_get(layer_variables[10])
    current_layer.multi_head_attention.dense.weight.data = tensor_read_get(layer_variables[11])
    current_layer.ffn[0].bias.data = tensor_read_get(layer_variables[12])
    current_layer.ffn[0].weight.data = tensor_read_get(layer_variables[13])
    current_layer.ffn[2].bias.data = tensor_read_get(layer_variables[14])
    current_layer.ffn[2].weight.data = tensor_read_get(layer_variables[15])

torch.save(model.state_dict(), args.pytorch_checkpoint)
print('INFO :: Saved PyTorch model to ', args.pytorch_checkpoint)
