Skip to content

Commit

Permalink
Merge pull request #778 from liuyuyan2717/faster-chinese-mlm
Browse files Browse the repository at this point in the history
Avoid the `pipeline` API to achieve faster generation of Chinese mask replacements
  • Loading branch information
jxmorris12 authored Mar 5, 2024
2 parents 2fc025b + 894499c commit 62fad01
Showing 1 changed file with 15 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,24 @@ class ChineseWordSwapMaskedLM(WordSwap):
model."""

def __init__(self, task="fill-mask", model="xlm-roberta-base", **kwargs):
    """Load a BERT-style masked-LM checkpoint for word substitution.

    Args:
        task: Unused; kept for backward compatibility with the older
            `pipeline`-based implementation so existing callers still work.
        model: Name or path of a masked-LM checkpoint loadable by
            ``BertTokenizer`` / ``BertForMaskedLM``.
        **kwargs: Forwarded to ``WordSwap``.
    """
    # Load the tokenizer and model directly instead of going through the
    # `pipeline` API, which is slower for repeated single-mask queries.
    from transformers import BertForMaskedLM, BertTokenizer
    import torch

    # NOTE(review): the default checkpoint name is "xlm-roberta-base" but is
    # loaded with BERT classes — confirm callers pass a BERT checkpoint.
    self.tt = BertTokenizer.from_pretrained(model)
    self.mm = BertForMaskedLM.from_pretrained(model)
    # Fall back to CPU when CUDA is unavailable instead of crashing.
    self.mm.to("cuda" if torch.cuda.is_available() else "cpu")
    super().__init__(**kwargs)

def get_replacement_words(self, current_text, indice_to_modify):
    """Return candidate replacement words for one position in the text.

    Masks the word at ``indice_to_modify``, runs the masked LM, and returns
    the model's highest-scoring all-CJK tokens as candidates.

    Args:
        current_text: An ``AttackedText``-like object supporting
            ``replace_word_at_index`` and ``.text``.
        indice_to_modify: Word index (not token index) to replace.

    Returns:
        list[str]: Candidate replacement tokens, best first.
    """
    # Use the tokenizer's own mask token ("[MASK]" for BERT) rather than a
    # hard-coded literal, so any compatible checkpoint works.
    masked_text = current_text.replace_word_at_index(
        indice_to_modify, self.tt.mask_token
    )
    tokens = self.tt.tokenize(masked_text.text)
    input_ids = self.tt.convert_tokens_to_ids(tokens)
    # Run on whatever device the model was placed on in __init__.
    device = next(self.mm.parameters()).device
    input_tensor = torch.tensor([input_ids]).to(device)
    with torch.no_grad():
        outputs = self.mm(input_tensor)
    logits = outputs.logits
    # A word index does not generally equal a token index once the tokenizer
    # splits words into subwords — locate the mask token's actual position
    # instead of reusing `indice_to_modify` to index the logits.
    mask_index = tokens.index(self.tt.mask_token)
    top_ids = torch.argsort(logits[0, mask_index], descending=True)[:50]
    # Drop the single top prediction — presumably to skip the most likely
    # (often the original) token; TODO confirm this is intentional.
    candidates = self.tt.convert_ids_to_tokens(top_ids.tolist()[1:])
    # Keep only fully-CJK tokens as substitutions, matching this class's
    # stated purpose of producing Chinese replacements.
    return [tok for tok in candidates if all(is_cjk(ch) for ch in tok)]

def _get_transformations(self, current_text, indices_to_modify):
words = current_text.words
Expand Down

0 comments on commit 62fad01

Please sign in to comment.