-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
119 lines (98 loc) · 4.96 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# %%
# Model setup: load six pretrained masked-language models once at import
# time so get_all_predictions() can reuse them across requests.
# .eval() switches each model to inference mode (disables dropout).
import string

import torch
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import BertTokenizer, BertForMaskedLM
from transformers import ElectraTokenizer, ElectraForMaskedLM
from transformers import RobertaTokenizer, RobertaForMaskedLM
from transformers import XLMRobertaTokenizer, XLMRobertaForMaskedLM
from transformers import XLNetTokenizer, XLNetLMHeadModel

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased').eval()
xlnet_tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
xlnet_model = XLNetLMHeadModel.from_pretrained('xlnet-base-cased').eval()
xlmroberta_tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
xlmroberta_model = XLMRobertaForMaskedLM.from_pretrained(
    'xlm-roberta-base').eval()
# 'bart-large' was a legacy shortcut name removed from newer transformers
# releases; 'facebook/bart-large' is the canonical hub id and resolves to
# the same checkpoint on both old and new versions.
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
bart_model = BartForConditionalGeneration.from_pretrained(
    'facebook/bart-large').eval()
electra_tokenizer = ElectraTokenizer.from_pretrained(
    'google/electra-small-generator')
electra_model = ElectraForMaskedLM.from_pretrained(
    'google/electra-small-generator').eval()
roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base').eval()

# Number of raw candidates taken from each model before cleaning/filtering.
top_k = 10
def decode(tokenizer, pred_idx, top_clean):
    """Turn predicted token ids into a newline-separated list of clean words.

    Args:
        tokenizer: any object exposing ``decode(token_id) -> str``
            (an HF tokenizer in practice).
        pred_idx: iterable of predicted token ids (the top-k candidates).
        top_clean: maximum number of cleaned tokens to return.

    Returns:
        Up to ``top_clean`` decoded tokens joined by newlines.
    """
    # Exact-match set of tokens to skip.  The previous implementation
    # concatenated string.punctuation + '[PAD]' into one string, so the
    # membership check was a *substring* test that wrongly discarded
    # single-letter predictions such as 'P', 'A' and 'D'.
    ignore_tokens = set(string.punctuation) | {'[PAD]'}
    tokens = []
    for w in pred_idx:
        # Collapse internal whitespace the tokenizer may emit.
        token = ''.join(tokenizer.decode(w).split())
        if token not in ignore_tokens:
            # Strip BERT-style word-piece continuation markers.
            tokens.append(token.replace('##', ''))
    return '\n'.join(tokens[:top_clean])
def encode(tokenizer, text_sentence, add_special_tokens=True):
    """Tokenize a sentence containing a generic ``<mask>`` placeholder.

    Args:
        tokenizer: HF-style tokenizer exposing ``mask_token``,
            ``mask_token_id`` and ``encode()``.
        text_sentence: input text containing ``<mask>`` where a word
            should be predicted.
        add_special_tokens: forwarded to ``tokenizer.encode()``.

    Returns:
        Tuple ``(input_ids, mask_idx)``: a ``(1, seq_len)`` LongTensor and
        the position of the first mask token within it.

    Raises:
        ValueError: if the sentence contains no mask placeholder.
    """
    text_sentence = text_sentence.replace('<mask>', tokenizer.mask_token)
    words = text_sentence.split()
    # If <mask> is the last token, append a '.' so models don't just
    # predict punctuation.  The `words` guard also avoids the IndexError
    # the original raised on empty/whitespace-only input.
    if words and words[-1] == tokenizer.mask_token:
        text_sentence += ' .'
    input_ids = torch.tensor(
        [tokenizer.encode(text_sentence, add_special_tokens=add_special_tokens)])
    mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1].tolist()
    if not mask_positions:
        # Clearer failure than the bare IndexError from indexing [0].
        raise ValueError("text_sentence must contain a '<mask>' placeholder")
    return input_ids, mask_positions[0]
def _topk_mlm(tokenizer, model, text_sentence, top_clean,
              add_special_tokens=True):
    """Shared masked-LM path: encode, forward pass, decode top-k at the mask."""
    input_ids, mask_idx = encode(tokenizer, text_sentence,
                                 add_special_tokens=add_special_tokens)
    with torch.no_grad():
        predict = model(input_ids)[0]
    return decode(tokenizer,
                  predict[0, mask_idx, :].topk(top_k).indices.tolist(),
                  top_clean)


def get_all_predictions(text_sentence, top_clean=5):
    """Run all six models on one masked sentence.

    Args:
        text_sentence: text containing a ``<mask>`` placeholder.
        top_clean: max candidates returned per model (after filtering).

    Returns:
        Dict mapping model name ('bert', 'xlnet', 'xlm', 'bart',
        'electra', 'roberta') to a newline-separated candidate string.
    """
    print(text_sentence)  # debug trace of the incoming sentence

    # BERT, XLM-R, BART, ELECTRA and RoBERTa all share the plain
    # masked-LM forward pass, factored into _topk_mlm above.
    bert = _topk_mlm(bert_tokenizer, bert_model, text_sentence, top_clean)

    # ========================= XLNET =================================
    # XLNet is autoregressive: it needs a permutation mask and a target
    # mapping instead of a plain masked forward pass.
    input_ids, mask_idx = encode(xlnet_tokenizer, text_sentence, False)
    seq_len = input_ids.shape[1]
    perm_mask = torch.zeros((1, seq_len, seq_len), dtype=torch.float)
    perm_mask[:, :, mask_idx] = 1.0  # no token may attend to the masked position
    # Shape [1, 1, seq_len] => predict exactly one token,
    # the one at the masked position.
    target_mapping = torch.zeros((1, 1, seq_len), dtype=torch.float)
    target_mapping[0, 0, mask_idx] = 1.0
    with torch.no_grad():
        predict = xlnet_model(input_ids, perm_mask=perm_mask,
                              target_mapping=target_mapping)[0]
    xlnet = decode(xlnet_tokenizer,
                   predict[0, 0, :].topk(top_k).indices.tolist(),
                   top_clean)

    return {'bert': bert,
            'xlnet': xlnet,
            'xlm': _topk_mlm(xlmroberta_tokenizer, xlmroberta_model,
                             text_sentence, top_clean),
            'bart': _topk_mlm(bart_tokenizer, bart_model,
                              text_sentence, top_clean),
            'electra': _topk_mlm(electra_tokenizer, electra_model,
                                 text_sentence, top_clean),
            'roberta': _topk_mlm(roberta_tokenizer, roberta_model,
                                 text_sentence, top_clean)}