-
Notifications
You must be signed in to change notification settings - Fork 2
/
tokenizer_check.py
45 lines (37 loc) · 1.62 KB
/
tokenizer_check.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import argparse
import paddle
# Select the CPU as Paddle's device.
paddle.set_device("cpu")
# Sample sentences that run_check feeds to both tokenizers.
text = [
    "it is a nice day today!",  # official example
    "My friends are cool but they eat too many carbs.",  # official example
    "My 'but' they:@ eat too many carbs:)",
]
def run_check(model_name):
    """Compare the two tokenizers' ``input_ids`` on every sample sentence.

    For each entry of the module-level ``text`` list, the sentence is
    tokenized with ``PTTokenizer`` and with ``PDTokenizer`` and the resulting
    ``"input_ids"`` are compared. ``"Passed"`` is printed on a match; on a
    mismatch ``"Failed"`` is printed followed by both id sequences.

    ``PTTokenizer`` and ``PDTokenizer`` are looked up as module globals; the
    script's ``__main__`` section binds them before calling this function.

    Args:
        model_name: Name passed to ``PDTokenizer.from_pretrained``; the same
            name prefixed with ``'facebook/'`` is passed to
            ``PTTokenizer.from_pretrained``.
    """
    # Loading does not depend on the input sentence, so do it once instead of
    # once per sentence.
    pt_tokenizer = PTTokenizer.from_pretrained('facebook/' + model_name)
    pd_tokenizer = PDTokenizer.from_pretrained(model_name)

    for t in text:
        print("input text:", t)
        pt_inputs = pt_tokenizer(t)["input_ids"]
        pd_inputs = pd_tokenizer(t)["input_ids"]
        if pt_inputs == pd_inputs:
            print("Passed")
        else:
            # Report mismatches explicitly; staying silent makes a failure
            # indistinguishable from a sentence that was never checked.
            print("Failed")
            print("torch tokenizer: ", pt_inputs)
            print("paddle tokenizer: ", pd_inputs)
# Script entry point: pick the tokenizer pair for the requested model and run
# the comparison.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default='blenderbot_small-90M',
                        help="blenderbot_small-90M or blenderbot-400M-distill")
    args = parser.parse_args()
    model_name = args.model_name

    # The imports are conditional so that only the tokenizer classes for the
    # selected model are imported. They bind PDTokenizer / PTTokenizer at
    # module scope, which is where run_check looks them up.
    if model_name == 'blenderbot_small-90M':
        from paddlenlp.transformers import BlenderbotSmallTokenizer as PDTokenizer
        from transformers import BlenderbotSmallTokenizer as PTTokenizer
    elif model_name == 'blenderbot-400M-distill':
        from paddlenlp.transformers import BlenderbotTokenizer as PDTokenizer
        from transformers import BlenderbotTokenizer as PTTokenizer
    else:
        # A bare string cannot be raised in Python 3 (it fails with
        # "TypeError: exceptions must derive from BaseException"), so wrap the
        # message in a real exception type.
        raise ValueError(
            f"model name not in {['blenderbot_small-90M', 'blenderbot-400M-distill']} "
        )

    run_check(model_name=model_name)