jsksxs360/bin2ckpt

bin2ckpt

Convert a Hugging Face PyTorch checkpoint to a TensorFlow checkpoint

For a detailed explanation, see the article 《将 PyTorch 版 bin 模型转换成 Tensorflow 版 ckpt》 ("Converting a PyTorch bin model to a TensorFlow ckpt").

Environment

  • torch
  • tensorflow
  • transformers

Usage

bin_path = './pretrained_model/pytorch_model/'
bin_model = 'pytorch_model.bin'
ckpt_path = './pretrained_model/tensorflow_model/'
ckpt_model = 'bert_model.ckpt'

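convert(bin_path, bin_model, ckpt_path, ckpt_model)

Parameters:

  • bin_path: directory containing the PyTorch model
  • bin_model: file name of the PyTorch checkpoint
  • ckpt_path: directory in which to save the TensorFlow checkpoint
  • ckpt_model: file name of the TensorFlow checkpoint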
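
Because convert() takes the directory and the file name as separate arguments, it can help to check the inputs before calling it. The sketch below assumes that each directory and file name are simply joined into one path; prepare_paths is a hypothetical helper and not part of this repository.

```python
import os


def prepare_paths(bin_path, bin_model, ckpt_path, ckpt_model):
    """Hypothetical helper: resolve the source and target files for convert().

    Assumption: the script joins each directory with its file name.
    """
    src = os.path.join(bin_path, bin_model)
    if not os.path.isfile(src):
        raise FileNotFoundError(f"PyTorch checkpoint not found: {src}")
    # Create the output directory if it does not exist yet.
    os.makedirs(ckpt_path, exist_ok=True)
    dst = os.path.join(ckpt_path, ckpt_model)
    return src, dst
```

Calling prepare_paths(bin_path, bin_model, ckpt_path, ckpt_model) right before convert(...) surfaces a wrong path early, before any model weights are loaded.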

Note: this script only supports converting BERT models. To convert other models, modify the function to_tf_var_name() and the variable tensors_to_transpose.
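
To give a sense of what those two pieces typically look like, here is a minimal sketch modeled on the BERT conversion script that ships with Hugging Face transformers. It is an assumption that this repository uses the same rules; the substitution table var_map and the helper needs_transpose are illustrative names, so compare against the actual script before adapting it.

```python
# Substrings marking tensors whose weight matrices are stored transposed
# in PyTorch relative to the TensorFlow BERT layout (assumed values).
tensors_to_transpose = (
    "dense.weight",
    "attention.self.query",
    "attention.self.key",
    "attention.self.value",
)

# Ordered (pattern, replacement) pairs; order matters because later
# rules operate on the output of earlier ones.
var_map = (
    ("layer.", "layer_"),
    ("word_embeddings.weight", "word_embeddings"),
    ("position_embeddings.weight", "position_embeddings"),
    ("token_type_embeddings.weight", "token_type_embeddings"),
    (".", "/"),
    ("LayerNorm/weight", "LayerNorm/gamma"),
    ("LayerNorm/bias", "LayerNorm/beta"),
    ("weight", "kernel"),
)


def to_tf_var_name(name):
    """Map a PyTorch state-dict key to a TensorFlow BERT variable name."""
    for pattern, replacement in var_map:
        name = name.replace(pattern, replacement)
    return "bert/" + name


def needs_transpose(name):
    """Hypothetical helper: True if the tensor under this key should be transposed."""
    return any(key in name for key in tensors_to_transpose)


if __name__ == "__main__":
    key = "encoder.layer.0.attention.self.query.weight"
    print(to_tf_var_name(key))  # bert/encoder/layer_0/attention/self/query/kernel
    print(needs_transpose(key))  # True
```

Supporting another architecture usually means changing the "bert/" prefix, extending var_map with that model's layer names, and listing any additional linear layers in tensors_to_transpose.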

Converted Models (ckpt)

  • SpanBERT

    • SpanBERT (base & cased): 12-layer, 768-hidden, 12-heads, 110M parameters
    • SpanBERT (large & cased): 24-layer, 1024-hidden, 16-heads, 340M parameters

    Download: GoogleDrive | BaiduDrive (Code: wtyr) | CTDrive

  • SimCSE

    • supervised SimCSE (base & uncased)
    • supervised SimCSE (large & uncased)

    Download: GoogleDrive | BaiduDrive (Code: gnq3) | CTDrive
