-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
115 lines (107 loc) · 4.49 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import paddle
import numpy as np
import os
from datetime import datetime
from utils.model import Model
from utils.decoder import ctc_greedy_decoder, label_to_string, cer
from paddle.io import DataLoader
from utils.data import collate_fn
from utils.data import CustomDataset
from visualdl import LogWriter
# Path of the training data list file
train_data_list_path = 'dataset/train_list.txt'
# Path of the test data list file
test_data_list_path = 'dataset/test_list.txt'
# Path of the vocabulary file
voc_path = 'dataset/vocabulary.txt'
# Directory where the model and optimizer state are saved
save_model = 'models/'
# Number of samples per batch
batch_size = 32
# Directory holding 'model.pdparams' and 'optimizer.pdopt' to resume from,
# or None to train from scratch
pretrained_model = None
# Number of training epochs
num_epoch = 100
# Initial learning rate
learning_rate = 1e-3
# VisualDL log writer; scalars are written under the 'log' directory
writer = LogWriter(logdir='log')
def train():
    """Train the CTC recognition model and save it after every epoch.

    All settings come from the module-level configuration
    (``train_data_list_path``, ``test_data_list_path``, ``voc_path``,
    ``batch_size``, ``num_epoch``, ``learning_rate``, ``pretrained_model``,
    ``save_model``). Training loss, test CER and the learning rate are logged
    to the module-level VisualDL ``writer``.
    """
    # Training data
    train_dataset = CustomDataset(train_data_list_path, voc_path, img_height=32)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)
    # Test data (data enhancement disabled)
    test_dataset = CustomDataset(test_data_list_path, voc_path, img_height=32, is_data_enhance=False)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, collate_fn=collate_fn)
    # Model
    model = Model(train_dataset.vocabulary, image_height=train_dataset.img_height, channel=1)
    paddle.summary(model, input_size=(batch_size, 1, train_dataset.img_height, 500))
    # Optimizer with a piecewise-constant learning rate: every boundary divides
    # the rate by 10. The scheduler is stepped once per epoch, so boundaries are
    # epoch counts; boundaries >= num_epoch are never reached.
    boundaries = [30, 100, 200]
    lr = [0.1 ** i * learning_rate for i in range(len(boundaries) + 1)]
    scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=boundaries, values=lr, verbose=False)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters(),
                                      learning_rate=scheduler,
                                      weight_decay=paddle.regularizer.L2Decay(1e-4))
    # Loss function
    ctc_loss = paddle.nn.CTCLoss()
    # Resume from a pretrained model when one is configured
    if pretrained_model is not None:
        model.set_state_dict(paddle.load(os.path.join(pretrained_model, 'model.pdparams')))
        optimizer.set_state_dict(paddle.load(os.path.join(pretrained_model, 'optimizer.pdopt')))
    train_step = 0
    test_step = 0
    # Training loop
    for epoch in range(num_epoch):
        for batch_id, (inputs, labels, input_lengths, label_lengths) in enumerate(train_loader()):
            out = model(inputs)
            # The model output is time-major (axis 0 = time steps, axis 1 = batch),
            # so every sample has out.shape[0] steps. The batch size is taken from
            # the output rather than the configured batch_size because the last
            # batch of an epoch may be smaller.
            input_lengths = paddle.full(shape=[out.shape[1]], fill_value=out.shape[0], dtype='int64')
            loss = ctc_loss(out, labels, input_lengths, label_lengths)
            loss.backward()
            optimizer.step()
            optimizer.clear_grad()
            # Report the loss every 100 batches
            if batch_id % 100 == 0:
                loss_value = float(loss)
                print('[%s] Train epoch %d, batch %d, loss: %f' % (datetime.now(), epoch, batch_id, loss_value))
                writer.add_scalar('Train loss', loss_value, train_step)
                train_step += 1
        # Evaluate every 10 epochs
        if epoch % 10 == 0:
            model.eval()
            test_cer = evaluate(model, test_loader, train_dataset.vocabulary)
            print('[%s] Test epoch %d, cer: %f' % (datetime.now(), epoch, test_cer))
            writer.add_scalar('Test cer', test_cer, test_step)
            test_step += 1
            model.train()
        # Log the learning rate used in this epoch, then advance the schedule
        writer.add_scalar('Learning rate', scheduler.last_lr, epoch)
        scheduler.step()
        # Save the model and optimizer state (overwrites the previous epoch's files)
        paddle.save(model.state_dict(), os.path.join(save_model, 'model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(save_model, 'optimizer.pdopt'))
def evaluate(model, test_loader, vocabulary):
    """Evaluate the model and return its mean character error rate.

    Args:
        model: Callable network. Its output is transposed with
            ``perm=[1, 0, 2]`` so that it can be iterated sample by sample.
            The caller is responsible for switching it to eval mode.
        test_loader: Callable returning an iterable of
            ``(inputs, labels, _, _)`` batches.
        vocabulary: Passed to ``ctc_greedy_decoder`` and ``label_to_string``
            to turn outputs and labels into strings.

    Returns:
        float: Mean over all samples of ``cer(prediction, label)`` divided by
        the label length (presumably an edit distance - confirm against
        ``utils.decoder``). NaN when the loader yields no samples.
    """
    cer_result = []
    # Inference only: skip gradient tracking to save memory and time.
    with paddle.no_grad():
        for inputs, labels, _, _ in test_loader():
            # Run recognition
            outs = model(inputs)
            outs = paddle.transpose(outs, perm=[1, 0, 2])
            outs = paddle.nn.functional.softmax(outs)
            for out, label in zip(outs, labels):
                # Decode the prediction and the reference label to strings
                out_string = ctc_greedy_decoder(out, vocabulary)
                label_str = label_to_string(label, vocabulary)
                # Normalize by the label length; clamp to 1 so an empty label
                # does not raise ZeroDivisionError.
                c = cer(out_string, label_str) / float(max(len(label_str), 1))
                cer_result.append(c)
    if not cer_result:
        # Same value np.mean([]) would give, without the RuntimeWarning.
        return float('nan')
    return float(np.mean(cer_result))
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    train()