-
Notifications
You must be signed in to change notification settings - Fork 161
/
Copy pathtask_relation_extraction_tplinker.py
285 lines (246 loc) · 11.8 KB
/
task_relation_extraction_tplinker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
#! -*- coding:utf-8 -*-
# 三元组抽取任务,tplinker, cat方式实体部分收敛较快,关系部分收敛较慢
# 官方链接:https://github.com/131250208/TPlinker-joint-extraction
# 数据集:http://ai.baidu.com/broad/download?dataset=sked
import json
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.callbacks import Callback
from bert4torch.snippets import sequence_padding, ListDataset
from bert4torch.layers import TplinkerHandshakingKernel
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
# Global hyper-parameters and resource locations used by the whole script.
maxlen = 64  # number of tokens kept per sentence; also fixes the size of the (i, j) -> k lookup tables
batch_size = 64
config_path = 'E:/data/pretrain_ckpt/bert/google@chinese_L-12_H-768_A-12/bert4torch_config.json'
checkpoint_path = 'E:/data/pretrain_ckpt/bert/google@chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'E:/data/pretrain_ckpt/bert/google@chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
loss_weight_recover_steps = 6000 # Controls the loss-weight schedule: entity recognition is weighted higher early on; model.total_steps is also a suggested value
# Load the label dictionary: every distinct predicate of the schema file gets
# the next free integer id, in order of first appearance.
predicate2id, id2predicate = {}, {}
with open('F:/data/corpus/relation_extraction/BD_Knowledge_Extraction/all_50_schemas', encoding='utf-8') as schema_file:
    for schema_line in schema_file:
        predicate = json.loads(schema_line)['predicate']
        if predicate in predicate2id:
            continue  # already registered
        next_id = len(predicate2id)
        id2predicate[next_id] = predicate
        predicate2id[predicate] = next_id
# Build the tokenizer from the vocabulary file (input is lower-cased)
tokenizer = Tokenizer(dict_path, do_lower_case=True)
# Dataset loading
class MyDataset(ListDataset):
    """Dataset read from a file holding one JSON object per line."""

    @staticmethod
    def load_data(filename):
        """Load all records of ``filename``.

        Each returned item has the form
        {'text': text, 'spo_list': [(s, p, o)]}
        """
        with open(filename, encoding='utf-8') as reader:
            return [
                {
                    'text': record['text'],
                    'spo_list': [
                        (triple['subject'], triple['predicate'], triple['object'])
                        for triple in record['spo_list']
                    ],
                }
                for record in map(json.loads, reader)
            ]
def trans_ij2k(seq_len, i, j):
    """Convert (row i, column j) into the index of the flattened upper triangle.

    The upper triangle (diagonal included) of a ``seq_len x seq_len`` matrix is
    flattened row by row, so row ``i`` starts at ``(2*seq_len - i + 1) * i / 2``.

    Args:
        seq_len: side length of the square matrix.
        i: row index.
        j: column index.

    Returns:
        The flattened index as an ``int``. Cells outside the matrix or below
        the diagonal (``i > j``) yield 0, which is also the index of (0, 0).
    """
    if (i > seq_len - 1) or (j > seq_len - 1) or (i > j):
        return 0
    # (2*seq_len - i + 1) * i is always even, so floor division is exact and
    # avoids the rounding error of the floating-point 0.5 * ... form.
    return int((2 * seq_len - i + 1) * i // 2 + (j - i))
# Lookup tables for a maxlen x maxlen matrix: upper-triangle cell (i, j) with j >= i -> flattened index k, and the inverse k -> (i, j)
map_ij2k = {(i, j): trans_ij2k(maxlen, i, j) for i in range(maxlen) for j in range(maxlen) if j >= i}
map_k2ij = {v: k for k, v in map_ij2k.items()}
def search(pattern, sequence):
    """Find ``pattern`` as a contiguous sub-sequence of ``sequence``.

    Args:
        pattern: list of items to look for.
        sequence: list to search in; a ``torch.Tensor`` is converted to a
            Python list first.

    Returns:
        The index of the first match, or -1 when there is none. An empty
        pattern is reported as not found, so callers never derive an end
        position of ``start + len(pattern) - 1 == -1`` from it.
    """
    n = len(pattern)
    if isinstance(sequence, torch.Tensor):
        sequence = sequence.cpu().tolist()
    if n == 0:
        return -1
    # Stop at the last start position where the pattern still fits.
    for i in range(len(sequence) - n + 1):
        if sequence[i:i + n] == pattern:
            return i
    return -1
def collate_fn(batch):
    """Turn a list of samples into padded token ids and TPLinker label tensors.

    Args:
        batch: list of dicts of the form
            {'text': text, 'spo_list': [(s, p, o)]}.

    Returns:
        A pair ``([batch_token_ids], [batch_entity_labels, batch_head_labels,
        batch_tail_labels])`` where

        * batch_token_ids: token ids padded to ``maxlen``,
        * batch_entity_labels: [btz, pair_len], 1 at the (head, tail) cell of
          every subject/object span,
        * batch_head_labels: [btz, rel_size, pair_len], 1 at (sh, oh) when the
          subject head is not after the object head, otherwise 2 at (oh, sh),
        * batch_tail_labels: [btz, rel_size, pair_len], same scheme for the
          span tails.

        ``pair_len`` is the size of the flattened upper triangle of a
        ``maxlen x maxlen`` matrix. All tensors are created on ``device``.
    """
    pair_len = maxlen * (maxlen + 1) // 2
    rel_size = len(predicate2id)
    batch_entity_labels = torch.zeros((len(batch), pair_len), dtype=torch.long, device=device)
    batch_head_labels = torch.zeros((len(batch), rel_size, pair_len), dtype=torch.long, device=device)
    batch_tail_labels = torch.zeros((len(batch), rel_size, pair_len), dtype=torch.long, device=device)

    batch_token_ids = []
    for i, d in enumerate(batch):
        # Drop the first/last token of the encoding and keep at most maxlen tokens.
        token_ids = tokenizer.encode(d['text'])[0][1:-1][:maxlen]
        batch_token_ids.append(token_ids)
        for s, p, o in d['spo_list']:
            s = tokenizer.encode(s)[0][1:-1]
            p = predicate2id[p]
            o = tokenizer.encode(o)[0][1:-1]
            if not s or not o:
                # An entity that tokenizes to nothing would "match" at index 0
                # and give a tail index of -1, which is not a key of map_ij2k.
                continue
            sh = search(s, token_ids)  # not found when the span lies beyond maxlen
            oh = search(o, token_ids)
            if sh == -1 or oh == -1:
                continue
            st, ot = sh + len(s) - 1, oh + len(o) - 1
            batch_entity_labels[i, map_ij2k[sh, st]] = 1
            batch_entity_labels[i, map_ij2k[oh, ot]] = 1
            # Only the upper triangle is stored, so the tag (1 or 2) records
            # whether the subject or the object comes first.
            if sh <= oh:
                batch_head_labels[i, p, map_ij2k[sh, oh]] = 1
            else:
                batch_head_labels[i, p, map_ij2k[oh, sh]] = 2
            if st <= ot:
                batch_tail_labels[i, p, map_ij2k[st, ot]] = 1
            else:
                batch_tail_labels[i, p, map_ij2k[ot, st]] = 2

    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, length=maxlen), dtype=torch.long, device=device)
    return [batch_token_ids], [batch_entity_labels, batch_head_labels, batch_tail_labels]
# Training and validation data; both loaders batch through collate_fn, only the training loader shuffles.
train_dataset = MyDataset('F:/data/corpus/relation_extraction/BD_Knowledge_Extraction/train_data.json')
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataset = MyDataset('F:/data/corpus/relation_extraction/BD_Knowledge_Extraction/dev_data.json')
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate_fn)
# Model structure built on top of BERT
class Model(BaseModel):
    """BERT encoder followed by a handshaking kernel and three tagging heads.

    ``forward`` returns, in this order:
        * entity logits            [btz, pair_len, 2]
        * head-relation logits     [btz, predicate_num, pair_len, 3]
        * tail-relation logits     [btz, predicate_num, pair_len, 3]
    """

    def __init__(self):
        super().__init__()
        self.bert = build_transformer_model(config_path, checkpoint_path, segment_vocab_size=0)
        # Not used in forward(); kept so the parameter set / state-dict keys are unchanged.
        self.combine_fc = nn.Linear(768*2, 768)
        self.ent_fc = nn.Linear(768, 2)
        self.head_rel_fc = nn.Linear(768, len(predicate2id)*3)
        self.tail_rel_fc = nn.Linear(768, len(predicate2id)*3)
        self.handshaking_kernel = TplinkerHandshakingKernel(768, shaking_type='cat')

    @staticmethod
    def _to_relation_major(rel_logits):
        """Rearrange [btz, pair_len, predicate_num*3] into [btz, predicate_num, pair_len, 3].

        The linear layer emits all relation logits of one token pair next to
        each other, so the last axis is split first and the axes are swapped
        afterwards. Reshaping straight to (btz, -1, pair_len, 3) would only
        reinterpret memory and attach logits to the wrong token pair.
        ``contiguous()`` keeps the result usable with ``Tensor.view``.
        """
        btz, pair_len = rel_logits.shape[:2]
        return rel_logits.reshape(btz, pair_len, -1, 3).permute(0, 2, 1, 3).contiguous()

    def forward(self, *inputs):
        last_hidden_state = self.bert(inputs)  # [btz, seq_len, hdsz]
        shaking_hiddens = self.handshaking_kernel(last_hidden_state)  # [btz, pair_len, hdsz]
        ent_shaking_outputs = self.ent_fc(shaking_hiddens)  # [btz, pair_len, 2]

        head_rel_shaking_outputs = self._to_relation_major(self.head_rel_fc(shaking_hiddens))
        tail_rel_shaking_outputs = self._to_relation_major(self.tail_rel_fc(shaking_hiddens))
        return ent_shaking_outputs, head_rel_shaking_outputs, tail_rel_shaking_outputs
# Single global model instance, moved to the selected device; the loss and the evaluation code below refer to it.
model = Model().to(device)
class MyLoss(nn.CrossEntropyLoss):
    """Weighted sum of the entity, head-relation and tail-relation cross entropies.

    The weights depend on ``model.global_step``: the entity weight decreases
    from ``1/z + 1`` down to ``1/z`` while the relation weight grows from 0 up
    to ``len(predicate2id)/z``, both over ``loss_weight_recover_steps`` steps,
    with ``z = 2*len(predicate2id) + 1``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, y_preds, y_trues):
        """Return the combined loss together with its three components."""
        # Bind the parent implementation once; zero-argument super() is not
        # usable inside the comprehension below.
        cross_entropy = super().forward
        losses = [
            cross_entropy(pred.view(-1, pred.size()[-1]), true.view(-1))
            for pred, true in zip(y_preds, y_trues)
        ]

        rel_size = len(predicate2id)
        z = (2 * rel_size + 1)
        w_ent = max(1 / z + 1 - model.global_step / loss_weight_recover_steps, 1 / z)
        w_rel = min((rel_size / z) * model.global_step / loss_weight_recover_steps, (rel_size / z))

        total = w_ent*losses[0] + w_rel*losses[1] + w_rel*losses[2]
        return {'loss': total, 'entity_loss': losses[0], 'head_loss': losses[1], 'tail_loss': losses[2]}
# Attach the combined loss and an Adam optimizer (learning rate 1e-4) to the model.
model.compile(loss=MyLoss(), optimizer=optim.Adam(model.parameters(), 1e-4))
def extract_spoes(text):
    """Extract the triples contained in ``text``.

    Returns:
        A tuple ``(spoes, token_ids, ent_text)``:

        * spoes: list of (subject, predicate, object) strings,
        * token_ids: token ids of ``text`` without the first/last token,
          truncated to ``maxlen`` exactly like the training inputs, so callers
          can tell which gold entities the model was able to see,
        * ent_text: set of predicted entity strings.
    """
    def get_spots_fr_shaking_tag(shaking_tag):
        """Decode a [rel, pair] tag matrix into (rel_id, subject_pos, object_pos) spots."""
        spots = []
        for shaking_inds in shaking_tag.nonzero():
            rel_id = shaking_inds[0].item()
            flat_ind = shaking_inds[1].item()
            tag_id = shaking_tag[rel_id][flat_ind].item()
            row, col = map_k2ij[flat_ind]
            # Make sure the subject position comes first, the object position second.
            if tag_id == 1:
                spots.append((rel_id, row, col))
            elif tag_id == 2:
                spots.append((rel_id, col, row))
        return spots

    tokens = tokenizer.tokenize(text)[1:-1]
    mapping = tokenizer.rematch(text, tokens)
    # Truncate to maxlen, consistent with collate_fn.
    token_ids = tokenizer.encode(text)[0][1:-1][:maxlen]
    token_ids_ts = torch.tensor(sequence_padding([token_ids], length=maxlen), dtype=torch.long, device=device)
    outputs = model.predict([token_ids_ts])
    outputs = [o[0].argmax(dim=-1) for o in outputs]

    # Extract entities
    ent_matrix_spots = set()
    ent_text = set()
    for shaking_ind in outputs[0].nonzero():
        spot = map_k2ij[shaking_ind[0].item()]
        # Keep only spans whose head and tail lie inside the mapped tokens.
        if (spot[0] < len(mapping)) and (spot[1] < len(mapping)):
            ent_matrix_spots.add(spot)
            ent_text.add(text[mapping[spot[0]][0]:mapping[spot[1]][-1] + 1])

    # Identify the predicate of each pair
    head_rel_matrix_spots = get_spots_fr_shaking_tag(outputs[1])
    tail_rel_matrix_spots = get_spots_fr_shaking_tag(outputs[2])

    # Group tail spots by relation (order preserved) so that each head spot is
    # only compared with tails of the same relation.
    tails_by_rel = {}
    for rel_t, st, ot in tail_rel_matrix_spots:
        tails_by_rel.setdefault(rel_t, []).append((st, ot))

    spoes = []
    for rel_h, sh, oh in head_rel_matrix_spots:
        for st, ot in tails_by_rel.get(rel_h, ()):
            # Both (sh, st) and (oh, ot) must be predicted entity spans.
            if ((sh, st) in ent_matrix_spots) and ((oh, ot) in ent_matrix_spots):
                spoes.append((text[mapping[sh][0]:mapping[st][-1] + 1], id2predicate[rel_h], text[mapping[oh][0]:mapping[ot][-1] + 1]))
    return spoes, token_ids, ent_text
class SPO(tuple):
    """Tuple subclass used to store a triple.

    It behaves like a plain tuple except that ``__hash__`` and ``__eq__`` work
    on a tokenized form of subject and object (``spox``), which makes the
    comparison of two triples more tolerant.
    """

    def __init__(self, spo):
        subject_tokens = tuple(tokenizer.tokenize(spo[0]))
        object_tokens = tuple(tokenizer.tokenize(spo[2]))
        self.spox = (subject_tokens, spo[1], object_tokens)

    def __hash__(self):
        return hash(self.spox)

    def __eq__(self, spo):
        return self.spox == spo.spox
def evaluate(data):
    """Evaluation function: compute f1, precision and recall of the extracted triples.

    For every sample the predictions, the gold triples and their differences
    are appended as one JSON document to ``dev_pred.json``; running metrics
    (including an entity-level score) are shown on the progress bar.

    Args:
        data: iterable of dicts of the form
            {'text': text, 'spo_list': [(s, p, o)]}.

    Returns:
        ``(f1, precision, recall)`` accumulated over all samples; all zeros
        when ``data`` is empty.
    """
    X, Y, Z = 0, 1e-10, 1e-10
    E1, E2 = 0, 1e-10
    f1, precision, recall = 0.0, 0.0, 0.0  # defined even when data is empty
    # Context managers close the file and the progress bar on errors as well.
    with open('dev_pred.json', 'w', encoding='utf-8') as f, tqdm() as pbar:
        for d in data:
            spoes, token_ids, ent_text_pred = extract_spoes(d['text'])

            # Keep only the gold triples whose subject and object are found in token_ids.
            spo_list = []
            for s, p, o in d['spo_list']:
                s_ = tokenizer.encode(s)[0][1:-1]
                o_ = tokenizer.encode(o)[0][1:-1]
                sh = search(s_, token_ids)  # not found when beyond the kept tokens
                oh = search(o_, token_ids)
                if sh != -1 and oh != -1:
                    spo_list.append((s, p, o))

            # Triple-level f1
            R = {SPO(spo) for spo in spoes}
            T = {SPO(spo) for spo in spo_list}
            X += len(R & T)
            Y += len(R)
            Z += len(T)
            f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z

            # Entity-level score: share of gold entity strings that were predicted
            ent_text_truth = {spo[0] for spo in spo_list} | {spo[-1] for spo in spo_list}
            E1 += len(ent_text_pred & ent_text_truth)
            E2 += len(ent_text_truth)
            E_acc = E1 / E2

            pbar.update()
            pbar.set_description('f1: %.5f, precision: %.5f, recall: %.5f, ent_acc: %.5f' % (f1, precision, recall, E_acc))
            record = json.dumps({'text': d['text'], 'spo_list': list(T), 'spo_list_pred': list(R),
                                 'new': list(R - T), 'lack': list(T - R)}, ensure_ascii=False, indent=4)
            f.write(record + '\n')
    return f1, precision, recall
class Evaluator(Callback):
    """Evaluate on the validation set after every epoch and track the best f1."""

    def __init__(self):
        self.best_val_f1 = 0.

    def on_epoch_end(self, steps, epoch, logs=None):
        f1, precision, recall = evaluate(valid_dataset.data)
        # Saving the best weights is currently disabled:
        # model.save_weights('best_model.pt')
        self.best_val_f1 = max(self.best_val_f1, f1)
        print('f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' % (f1, precision, recall, self.best_val_f1))
if __name__ != '__main__':
    # Imported as a module: restore previously saved weights.
    model.load_weights('best_model.pt')
else:
    # Run as a script: train and evaluate after every epoch.
    evaluator = Evaluator()
    model.fit(train_dataloader, steps_per_epoch=None, epochs=20, callbacks=[evaluator])