-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
365 lines (321 loc) · 11.9 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
# coding: UTF-8
import numpy as np
import torch
import time
from utils import build_iterator, get_time_dif
from importlib import import_module
from tqdm import tqdm
from generate_data import cut_para_many_times
PAD, CLS = '[PAD]', '[CLS]'  # padding token; BERT's aggregate ("whole-sequence") token
min_length = 64

# Single source of truth for the label tables: (news category, risk tier),
# listed in class-index order -- position i is the model's output index i.
_CATEGORY_RISK = (
    ("财经", "高风险"),
    ("时政", "高风险"),
    ("房产", "中风险"),
    ("科技", "中风险"),
    ("教育", "低风险"),
    ("时尚", "低风险"),
    ("游戏", "低风险"),
    ("家居", "可公开"),
    ("体育", "可公开"),
    ("娱乐", "可公开"),
)

# category -> risk tier
label2class = {category: risk for category, risk in _CATEGORY_RISK}
# category -> class index
label2num = {category: index for index, (category, _) in enumerate(_CATEGORY_RISK)}
# class index -> category (inverse of label2num)
num2label = {index: category for category, index in label2num.items()}
class Predict_Baseline():
    """Prediction method 1: predict each document from its leading tokens only.

    The text is not preprocessed at all; every document is tokenized and simply
    truncated to ``config.pad_size`` tokens (everything after that is dropped).

    Pros: fast, because truncation greatly reduces the amount of data.
    Cons: the model cannot see the whole document.

    Possible follow-ups (not implemented):
    1. Turn each document into several sequences and combine their predictions
       into a final prediction.
    2. Keyword extraction (or similar) per document, then build a graph and run
       spectral clustering (looks hard to implement...).
    """

    def __init__(self, dataset, config):
        self.dataset = dataset
        self.config = config
        # Document ids in prediction order; filled in by `evaluate`.
        self.predict_ids = None

    @staticmethod
    def _pad(token_ids, pad_size):
        """Pad or truncate `token_ids` to `pad_size`.

        Returns ``(token_ids, seq_len, mask)``. Shorter sequences are padded with
        id 0 and their mask is 1 for real tokens and 0 for padding; longer ones
        are cut to `pad_size` tokens (seq_len becomes `pad_size`). When
        `pad_size` is falsy the ids are returned unchanged with an empty mask.
        """
        seq_len = len(token_ids)
        if not pad_size:
            return token_ids, seq_len, []
        if seq_len < pad_size:
            padding = pad_size - seq_len
            return token_ids + [0] * padding, seq_len, [1] * seq_len + [0] * padding
        return token_ids[:pad_size], pad_size, [1] * pad_size

    def load_dataset(self, path, pad_size):
        """Read and encode the prediction CSV at `path`.

        The first non-blank line is treated as a header and skipped; every other
        non-blank line is split at its first comma into ``id,content``.

        :return: list of ``(token_ids, id, seq_len, mask)`` tuples.
        """
        tokenizer = self.config.tokenizer
        contents = []
        header_skipped = False
        with open(path, 'r', encoding='utf-8') as fin:
            for line in tqdm(fin):
                lin = line.strip()
                if not lin:
                    continue
                if not header_skipped:
                    header_skipped = True
                    continue
                doc_id, _, content = lin.partition(',')
                tokens = [CLS] + tokenizer.tokenize(content)
                token_ids = tokenizer.convert_tokens_to_ids(tokens)
                token_ids, seq_len, mask = self._pad(token_ids, pad_size)
                contents.append((token_ids, int(doc_id), seq_len, mask))
        return contents

    def build_dataset(self, path):
        """Load the prediction set from `path` and build the batch iterator."""
        # Samples are [(tokens, int(id), seq_len, mask)]
        config = self.config
        print('\nloading predict set ...')
        predict_data = self.load_dataset(path, config.pad_size)
        print('Done!')
        self.predict_iter = build_iterator(predict_data, config)

    def evaluate(self, model):
        """Run `model` over the prediction iterator.

        :return: int array with the arg-max class index of every sample, in
            iterator order. The matching document ids are stored in
            ``self.predict_ids``.
        """
        model.eval()
        label_chunks = []
        id_chunks = []
        with torch.no_grad():
            for texts, ids in tqdm(self.predict_iter):
                outputs = model(texts)
                label_chunks.append(torch.max(outputs.data, 1)[1].cpu().numpy())
                id_chunks.append(ids.data.cpu().numpy())
        # Concatenate once instead of np.append per batch (which copies every time).
        self.predict_ids = np.concatenate(id_chunks) if id_chunks else np.array([], dtype=int)
        if not label_chunks:
            return np.array([], dtype=int)
        return np.concatenate(label_chunks)

    def predict(self, model):
        """Load the saved weights into `model` and return its predicted labels."""
        config = self.config
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
        model.load_state_dict(torch.load(config.save_path, map_location=config.device))
        model.eval()
        start_time = time.time()
        print('prediction ...')
        predict_labels = self.evaluate(model)
        time_dif = get_time_dif(start_time)
        print('Done !')
        print('prediction usage:', time_dif)
        return predict_labels

    def write_csv(self, labels, path):
        """Write ``id,class_label,rank_label`` rows for `labels` to `path`.

        The ids recorded by `evaluate` are used when they match `labels` in
        length; otherwise rows are numbered 0, 1, 2, ...
        """
        ids = self.predict_ids
        if ids is None or len(ids) != len(labels):
            ids = range(len(labels))
        with open(path, 'w', encoding="utf-8") as fout:
            fout.write('id,class_label,rank_label' + '\n')
            for doc_id, label in zip(ids, labels):
                name = num2label[label]
                fout.write(str(int(doc_id)) + ',' + name + ',' + label2class[name] + '\n')
class Predict_Cut_Paras():
    """Prediction method 2: cut each document into paragraphs, predict every
    paragraph, and combine the paragraph results into one label per document.

    type = 1: vote over the paragraph labels
    type = 2: sum of the softmax of the paragraph scores
    type = 3: sum of the raw paragraph scores
    any other value is not implemented and raises ValueError
    """

    def __init__(self, dataset, config, type=1):
        if type not in (1, 2, 3):
            raise ValueError('type must be 1, 2 or 3, got %r' % (type,))
        self.dataset = dataset
        self.config = config
        self.type = type

    @staticmethod
    def _pad(token_ids, pad_size):
        """Pad or truncate `token_ids` to `pad_size`.

        Returns ``(token_ids, seq_len, mask)``. Shorter sequences are padded with
        id 0 and their mask is 1 for real tokens and 0 for padding; longer ones
        are cut to `pad_size` tokens (seq_len becomes `pad_size`). When
        `pad_size` is falsy the ids are returned unchanged with an empty mask.
        """
        seq_len = len(token_ids)
        if not pad_size:
            return token_ids, seq_len, []
        if seq_len < pad_size:
            padding = pad_size - seq_len
            return token_ids + [0] * padding, seq_len, [1] * seq_len + [0] * padding
        return token_ids[:pad_size], pad_size, [1] * pad_size

    def load_dataset(self, path, pad_size):
        """Read the prediction CSV at `path`, cut each document into paragraphs
        and encode every paragraph.

        The first non-blank line is treated as a header and skipped; every other
        non-blank line is split at its first comma into ``id,content``.

        :return: list of ``(token_ids, document_id, seq_len, mask)`` tuples, one
            per paragraph; paragraphs of a document are adjacent.
        """
        tokenizer = self.config.tokenizer
        # Cut the documents into paragraphs.
        print('cut paras ...')
        start_time = time.time()
        data = []
        header_skipped = False
        with open(path, 'r', encoding='utf-8') as fin:
            for line in tqdm(fin):
                lin = line.strip()
                # Test the stripped line: `line` still holds '\n', so the old
                # `if not line` let blank lines through to int('').
                if not lin:
                    continue
                if not header_skipped:
                    header_skipped = True
                    continue
                doc_id, _, content = lin.partition(',')
                doc_id = int(doc_id)
                # Short paragraphs are kept (a `min_length` filter was tried and disabled).
                for para in cut_para_many_times(content):
                    data.append((doc_id, para))
        print('Done!')
        print('\nparas:', len(data))
        print('Time usage:', get_time_dif(start_time))

        print('\n Getting tokens ...')
        contents = []
        for doc_id, content in tqdm(data):
            tokens = [CLS] + tokenizer.tokenize(content)
            token_ids = tokenizer.convert_tokens_to_ids(tokens)
            token_ids, seq_len, mask = self._pad(token_ids, pad_size)
            contents.append((token_ids, doc_id, seq_len, mask))
        return contents

    def build_dataset(self, path):
        """Load the prediction set from `path` and build the batch iterator."""
        # Samples are [(tokens, int(id), seq_len, mask)]
        config = self.config
        print('\nloading predict set ...')
        predict_data = self.load_dataset(path, config.pad_size)
        print('Done!')
        self.predict_iter = build_iterator(predict_data, config)

    def evaluate(self, model):
        """Run `model` over every paragraph.

        :return: ``(labels, ids)`` for type 1, where labels holds the arg-max
            class per paragraph; ``(scores, ids)`` for types 2 and 3, where
            scores is the flattened row-major model output (one row of class
            scores per paragraph). ids holds the document id of each paragraph.
        """
        model.eval()
        label_chunks = []
        id_chunks = []
        score_chunks = []
        with torch.no_grad():
            for texts, ids in tqdm(self.predict_iter):
                outputs = model(texts)
                label_chunks.append(torch.max(outputs.data, 1)[1].cpu().numpy())
                id_chunks.append(ids.data.cpu().numpy())
                score_chunks.append(outputs.data.cpu().numpy().ravel())
        # Concatenate once instead of np.append per batch (which copies every time).
        id_all = np.concatenate(id_chunks) if id_chunks else np.array([], dtype=int)
        if self.type == 1:
            predict_all = np.concatenate(label_chunks) if label_chunks else np.array([], dtype=int)
            return predict_all, id_all
        if score_chunks:
            score_all = np.concatenate(score_chunks).astype(np.float64)
        else:
            score_all = np.array([], dtype=np.float64)
        return score_all, id_all

    def predict(self, model):
        """Load the saved weights into `model` and return `evaluate`'s output."""
        config = self.config
        # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
        model.load_state_dict(torch.load(config.save_path, map_location=config.device))
        model.eval()
        start_time = time.time()
        print('prediction ...')
        predict_labels, ids = self.evaluate(model)
        time_dif = get_time_dif(start_time)
        print('Done !')
        print('prediction usage:', time_dif)
        return predict_labels, ids

    def softmax(self, score):
        """Return the softmax of the 1-D scores `score` as a float array."""
        score = np.asarray(score, dtype=float)
        if score.size == 0:
            return score
        # Subtracting the max leaves the result unchanged but prevents exp overflow.
        exp = np.exp(score - np.max(score))
        return exp / np.sum(exp)

    def _aggregate(self, ids, labels, num_classes):
        """Combine per-paragraph predictions into one class index per document.

        Paragraphs of a document are expected to be adjacent: each run of equal
        consecutive values in `ids` is treated as one document.

        :param ids: document id of every paragraph, in prediction order.
        :param labels: type 1 - one class index per paragraph; types 2/3 - flat
            scores, `num_classes` per paragraph.
        :param num_classes: number of classes.
        :return: list of ``(document_id, class_index)`` tuples.
        :raises ValueError: if the score array length does not match `ids`.
        """
        votes_only = self.type == 1
        if not votes_only and len(labels) != num_classes * len(ids):
            raise ValueError('expected %d scores for %d paragraphs, got %d'
                             % (num_classes * len(ids), len(ids), len(labels)))
        results = []
        start, n = 0, len(ids)
        while start < n:
            end = start + 1
            while end < n and ids[end] == ids[start]:
                end += 1
            score = np.zeros(num_classes, dtype=int if votes_only else float)
            for k in range(start, end):
                if votes_only:
                    score[labels[k]] += 1
                else:
                    row = np.asarray(labels[num_classes * k:num_classes * (k + 1)], dtype=float)
                    # Every paragraph (including the first) goes through the same
                    # transform; previously the first one was summed without softmax.
                    score += self.softmax(row) if self.type == 2 else row
            # argmax picks the first maximum on ties, like list.index(max(...)).
            results.append((int(ids[start]), int(np.argmax(score))))
            start = end
        return results

    def write_csv(self, ids, labels, path):
        """Aggregate the paragraph predictions per document and write
        ``id,class_label,rank_label`` rows to `path`.

        :param ids: document id of every paragraph (as returned by `predict`).
        :param labels: paragraph labels (type 1) or flat scores (types 2/3), as
            returned by `predict`.
        :param path: output CSV path.
        """
        print('ids:', len(ids))
        print('labels', len(labels))
        print(labels)
        results = self._aggregate(ids, labels, len(num2label))
        with open(path, 'w', encoding="utf-8") as fout:
            fout.write('id,class_label,rank_label' + '\n')
            for doc_id, label in results:
                name = num2label[label]
                # Write the document's own id rather than a running counter, so rows
                # stay aligned even if a document yielded no paragraphs or ids are
                # not 0..N-1.
                fout.write(str(doc_id) + ',' + name + ',' + label2class[name] + '\n')
        print('cnt:', len(results))
def baseline_method(x, config, dataset):
    """Run prediction method 1 (truncate each document) on the dataset's test set.

    Reads ``<dataset>/data/test_data.csv`` and writes
    ``<dataset>/data/result_baseline.csv``.

    :param x: model module providing ``Model(config)``.
    :param config: model configuration; its ``device`` receives the model.
    :param dataset: dataset directory prefix.
    """
    # Build the prediction set and report how long that took.
    t0 = time.time()
    predictor = Predict_Baseline(dataset=dataset, config=config)
    predictor.build_dataset(path=dataset + '/data/test_data.csv')
    print('Time usage:', get_time_dif(t0))

    # Predict with the trained weights and write the result file.
    net = x.Model(config).to(config.device)
    predictor.write_csv(
        labels=predictor.predict(model=net),
        path=dataset + '/data/result_baseline.csv',
    )
def cut_paras_method(x, config, dataset, type):
    """Run prediction method 2 (cut documents into paragraphs and combine the
    paragraph predictions) on the dataset's test set.

    Reads ``<dataset>/data/test_data.csv`` and writes
    ``<dataset>/data/result_cut_paras<type>.csv``.

    :param x: model module providing ``Model(config)``.
    :param config: model configuration; its ``device`` receives the model.
    :param dataset: dataset directory prefix.
    :param type: aggregation mode for `Predict_Cut_Paras` (1 = label vote,
        2 = sum of softmax scores, 3 = sum of raw scores).
    """
    # Build the prediction set and report how long that took.
    t0 = time.time()
    predictor = Predict_Cut_Paras(dataset=dataset, config=config, type=type)
    predictor.build_dataset(path=dataset + '/data/test_data.csv')
    print('Time usage:', get_time_dif(t0))

    # Predict with the trained weights and write the result file.
    net = x.Model(config).to(config.device)
    labels_or_scores, para_ids = predictor.predict(model=net)
    out_path = dataset + '/data/result_cut_paras' + str(type) + '.csv'
    predictor.write_csv(para_ids, labels_or_scores, path=out_path)
def main():
    """Predict the 'real' test set with the 'bert' model using paragraph
    cutting with label voting (type=1)."""
    dataset = 'real'       # dataset directory
    model_name = 'bert'    # model name
    # Load the model module and its configuration.
    x = import_module('models.' + model_name)
    config = x.Config(dataset)

    # Fix every random seed so repeated runs give the same result.
    seed = 1
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    # Alternative: prediction method 1 (truncation)
    #     baseline_method(x, config, dataset)
    cut_paras_method(x, config, dataset, type=1)


if __name__ == '__main__':
    main()