Save_pre_trained_locally.py
import os

import torch
from transformers import BertModel, BertTokenizer  # explicit imports instead of `import *`

# Transformers has a unified API
# for 8 transformer architectures and 30 pretrained weights.
#   Model     | Tokenizer     | Pretrained weights shortcut
MODELS = [(BertModel, BertTokenizer, 'bert-base-multilingual-cased')]

# To use TensorFlow 2.0 versions of the models, simply prefix the class names
# with 'TF', e.g. `TFRobertaModel` is the TF 2.0 counterpart of the PyTorch
# model `RobertaModel`.

# Create the target directory up front (some transformers versions expect it to exist).
os.makedirs('Pre_trained_BERT/', exist_ok=True)

# Load each pretrained model/tokenizer pair and save it locally:
for model_class, tokenizer_class, pretrained_weights in MODELS:
    # Load pretrained model/tokenizer from the Hugging Face hub
    tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
    model = model_class.from_pretrained(pretrained_weights,
                                        output_hidden_states=True,
                                        output_attentions=False)
    # Save both to the local directory so they can be reloaded offline
    model.save_pretrained('Pre_trained_BERT/')  # save
    tokenizer.save_pretrained('Pre_trained_BERT/')  # save
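
# A minimal reload sketch (not part of the original script), assuming the
# 'Pre_trained_BERT/' directory saved above: load the model/tokenizer back
# from disk and encode one illustrative sentence into hidden states.
local_tokenizer = BertTokenizer.from_pretrained('Pre_trained_BERT/')
local_model = BertModel.from_pretrained('Pre_trained_BERT/')
input_ids = torch.tensor([local_tokenizer.encode("Hello, world!",
                                                 add_special_tokens=True)])
with torch.no_grad():
    # The first element of the model output is the last hidden layer,
    # with shape (batch_size, sequence_length, hidden_size).
    last_hidden_states = local_model(input_ids)[0]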