-
Notifications
You must be signed in to change notification settings - Fork 161
/
Copy pathbasic_gibbs_sampling_via_mlm.py
54 lines (46 loc) · 1.98 KB
/
basic_gibbs_sampling_via_mlm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#! -*- coding: utf-8 -*-
# Smoke test for the library: Gibbs sampling driven by a masked language model (MLM)
from tqdm import tqdm
import numpy as np
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
import torch
# Directory holding the pretrained checkpoint; the vocab, config and weight
# files are all expected directly inside it.
root_model_path = "E:/data/pretrain_ckpt/bert/google@chinese_L-12_H-768_A-12"
vocab_path = f"{root_model_path}/vocab.txt"
config_path = f"{root_model_path}/bert4torch_config.json"
checkpoint_path = f"{root_model_path}/pytorch_model.bin"

# Build the tokenizer.
tokenizer = Tokenizer(vocab_path, do_lower_case=True)

# Build the model and load its weights. The sampling loop further down feeds
# the model's second output to np.random.choice as a probability vector, so
# with_mlm='softmax' presumably makes the MLM head emit normalized
# probabilities -- confirm against the bert4torch documentation.
model = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    with_mlm='softmax',
)
sentences = []  # receives one decoded sentence per sampling step
init_sent = '科学技术是第一生产力。'  # seed sentence, or None to start from all-[MASK]
minlen, maxlen = 8, 32  # inclusive length bounds used only when init_sent is None
steps = 10000  # total number of sampling steps
converged_steps = 1000  # samples before this index are skipped when printing results

vocab_size = tokenizer._vocab_size

if init_sent is not None:
    # Start the chain from the given sentence.
    token_ids, segment_ids = tokenizer.encode(init_sent)
    # Number of resampleable positions; presumably encode() wraps the text in
    # one leading and one trailing special token, hence the "- 2".
    length = len(token_ids) - 2
else:
    # No seed sentence: draw a random length and start from a fully masked
    # sequence framed by [CLS] ... [SEP].
    length = np.random.randint(minlen, maxlen + 1)
    tokens = ['[CLS]', *(['[MASK]'] * length), '[SEP]']
    token_ids = tokenizer.tokens_to_ids(tokens)
    segment_ids = [0] * len(token_ids)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()

# segment_ids is never modified while sampling, so its tensor is built once
# here instead of being rebuilt on every step.
segment_ids_tensor = torch.tensor([segment_ids], device=device)

with torch.no_grad():
    for _ in tqdm(range(steps), desc='Sampling'):
        # One Gibbs step: mask a single randomly chosen token, then resample
        # that token from the MLM's predicted distribution.
        # The +1 shifts the draw to indices 1..length, so index 0 and the
        # final index (the special tokens) are never resampled.
        i = np.random.choice(length) + 1
        token_ids[i] = tokenizer._token_mask_id
        token_ids_tensor = torch.tensor([token_ids], device=device)
        _unused, probas = model([token_ids_tensor, segment_ids_tensor])
        # Keep only the distribution for the masked position of the single
        # sequence in the batch.
        position_probas = probas[0, i].cpu().numpy()
        token = np.random.choice(vocab_size, p=position_probas)
        # Store a plain Python int so token_ids stays a homogeneous int list
        # rather than accumulating numpy scalars.
        token_ids[i] = int(token)
        sentences.append(tokenizer.decode(token_ids))
# Show a handful of samples drawn uniformly from the steps recorded after the
# first `converged_steps` (burn-in) entries.
print(u'部分随机采样结果: ')
# Slice once instead of on every iteration.
converged_sentences = sentences[converged_steps:]
if converged_sentences:
    for _ in range(10):
        print(np.random.choice(converged_sentences))
else:
    # np.random.choice raises ValueError on an empty population; report the
    # misconfiguration explicitly instead.
    print('No samples available after burn-in: steps must be greater than converged_steps.')