#! -*- coding: utf-8 -*-
# Seq2seq with BERT, using the UNILM scheme
# Introduction: https://kexue.fm/archives/6933
# Single-machine multi-GPU version; read it alongside task_seq2seq_autotitle.py
from __future__ import print_function
import os
os.environ['TF_KERAS'] = '1'  # tf.keras must be used
import glob
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.layers import Loss
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer, load_vocab
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, open
from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
from keras.models import Model
import tensorflow as tf  # import tf for later use

# Basic parameters
maxlen = 256
batch_size = 64
steps_per_epoch = 1000
epochs = 10000

# BERT configuration
config_path = '/root/kg/bert/chinese_wwm_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_wwm_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_wwm_L-12_H-768_A-12/vocab.txt'

# Training samples: the THUCNews dataset, one sample saved per txt file.
txts = glob.glob('/root/thuctc/THUCNews/*/*.txt')

# Load and trim the vocabulary, then build the tokenizer
token_dict, keep_tokens = load_vocab(
    dict_path=dict_path,
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)

tokenizer = Tokenizer(token_dict, do_lower_case=True)
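
# Illustrative only (not in the original script): for a (content, title) pair,
# tokenizer.encode(content, title) returns the UNILM input layout:
#     token_ids   -> [CLS] content tokens [SEP] title tokens [SEP]
#     segment_ids -> 0 over the content half, 1 over the title half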

class data_generator(DataGenerator):
    """Data generator
    (only needs to yield one sample at a time)
    """
    def __iter__(self, random=False):
        for is_end, txt in self.sample(random):
            text = open(txt, encoding='utf-8').read()
            text = text.split('\n')
            if len(text) > 1:
                title = text[0]
                content = '\n'.join(text[1:])
                token_ids, segment_ids = tokenizer.encode(
                    content, title, maxlen=maxlen
                )
                # yield one sample
                yield token_ids, segment_ids
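
# Illustrative only (assumption about the bert4keras DataGenerator API): the
# generator yields one (token_ids, segment_ids) pair per step, e.g.
#     gen = data_generator(txts, batch_size)
#     token_ids, segment_ids = next(iter(gen))
# Batching and padding are deferred to to_dataset(padded_batch=True) below.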

class CrossEntropy(Loss):
    """Cross entropy as the loss, with the input part masked out
    """
    def compute_loss(self, inputs, mask=None):
        y_true, y_mask, y_pred = inputs
        y_true = y_true[:, 1:]  # target token_ids
        y_mask = y_mask[:, 1:]  # segment_ids, which happen to mark the part to predict
        y_pred = y_pred[:, :-1]  # predicted sequence, shifted by one position
        loss = K.sparse_categorical_crossentropy(y_true, y_pred)
        loss = K.sum(loss * y_mask) / K.sum(y_mask)
        return loss
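
# A minimal sketch (not in the original script) of what the shift-and-mask
# above computes for one sample:
#     token_ids   : [CLS] c1 c2 [SEP] t1 t2 [SEP]
#     segment_ids :   0   0  0    0    1  1    1
# After dropping the first position of the targets/mask and the last position
# of the predictions, only title tokens (segment_id == 1) contribute to the loss.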

strategy = tf.distribute.MirroredStrategy()  # set up the single-machine multi-GPU strategy

with strategy.scope():  # build everything under this strategy

    bert = build_transformer_model(
        config_path,
        checkpoint_path=None,  # no need to load pretrained weights at this point
        application='unilm',
        keep_tokens=keep_tokens,  # keep only the tokens in keep_tokens, trimming the vocab
        return_keras_model=False,  # return the bert4keras class rather than the keras model
    )

    model = bert.model  # this is the actual keras model
    output = CrossEntropy(2)(model.inputs + model.outputs)
    model = Model(model.inputs, output)
    model.compile(optimizer=Adam(1e-5))
    model.summary()
    bert.load_weights_from_checkpoint(checkpoint_path)  # pretrained weights must be loaded last
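
# Quick sanity check (illustrative, not in the original script): the replica
# count should match the number of visible GPUs, e.g.
#     print('replicas in sync:', strategy.num_replicas_in_sync)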

class AutoTitle(AutoRegressiveDecoder):
    """seq2seq decoder
    """
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        token_ids, segment_ids = inputs
        token_ids = np.concatenate([token_ids, output_ids], 1)
        segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
        return self.last_token(model).predict([token_ids, segment_ids])

    def generate(self, text, topk=1):
        max_c_len = maxlen - self.maxlen
        token_ids, segment_ids = tokenizer.encode(text, maxlen=max_c_len)
        output_ids = self.beam_search([token_ids, segment_ids],
                                      topk=topk)  # beam search decoding
        return tokenizer.decode(output_ids)

autotitle = AutoTitle(start_id=None, end_id=tokenizer._token_end_id, maxlen=32)
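
# Illustrative usage (assumption, mirroring just_show below; article_text is a
# hypothetical placeholder string): a wider beam trades speed for quality, e.g.
#     title = autotitle.generate(article_text, topk=3)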

def just_show():
    s1 = u'夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及时就医。'
    s2 = u'8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午,华住集团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。'
    for s in [s1, s2]:
        print(u'Generated title:', autotitle.generate(s))
        print()

class Evaluator(keras.callbacks.Callback):
    """Evaluation and checkpointing
    """
    def __init__(self):
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        # save the best (lowest-loss) weights
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            model.save_weights('./best_model.weights')
        # demonstrate the current results
        just_show()

if __name__ == '__main__':

    evaluator = Evaluator()
    train_generator = data_generator(txts, batch_size)
    dataset = train_generator.to_dataset(
        types=('float32', 'float32'),
        shapes=([None], [None]),  # together with padded_batch=True, enables automatic padding
        names=('Input-Token', 'Input-Segment'),
        padded_batch=True
    )  # the data must be converted to tf.data.Dataset; names match the input layer names
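
    # Illustrative only (not in the original script): with padded_batch=True,
    # each batch is zero-padded to the length of its longest sample, so both
    # Input-Token and Input-Segment arrive with shape
    # (batch_size, longest_len_in_batch).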

    model.fit(
        dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        callbacks=[evaluator]
    )

else:

    model.load_weights('./best_model.weights')