forked from tensorpack/tensorpack
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain-timit.py
executable file
·131 lines (106 loc) · 4.52 KB
/
train-timit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: train-timit.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
import os
import sys
import argparse
from collections import Counter
import operator
import six
from six.moves import map, range
from tensorpack import *
from tensorpack.tfutils.gradproc import SummaryGradient, GlobalNormClip
from tensorpack.utils.globvars import globalns as param
import tensorpack.tfutils.symbolic_functions as symbf
import tensorflow as tf
from timitdata import TIMITBatch
# ---- hyper-parameters shared by the model and the dataflow ----
BATCH = 64           # sequences per batch (passed to TIMITBatch)
NLAYER = 2           # number of stacked LSTM layers
HIDDEN = 128         # hidden units in each LSTM layer
NR_CLASS = 61 + 1    # 61 phonemes + 1 epsilon (CTC blank) class
FEATUREDIM = 39      # MFCC feature dimension per frame
class Model(ModelDesc):
    """Stacked-LSTM phoneme recognizer trained with CTC loss on MFCC features."""

    def _get_inputs(self):
        # The label is fed as the three parts of a tf.SparseTensor
        # (indices, values, dense shape) and reassembled in _build_graph.
        return [InputDesc(tf.float32, [None, None, FEATUREDIM], 'feat'),  # b x maxseq x FEATUREDIM
                InputDesc(tf.int64, None, 'labelidx'),  # label is b x maxlen, sparse
                InputDesc(tf.int32, None, 'labelvalue'),
                InputDesc(tf.int64, None, 'labelshape'),
                InputDesc(tf.int32, [None], 'seqlen'),  # b
                ]

    def _build_graph(self, inputs):
        """Build the LSTM + CTC graph; defines `self.cost` and an 'error' tensor.

        'error' is the batch mean of the normalized edit distance between the
        decoded phoneme sequence and the label.
        """
        feat, labelidx, labelvalue, labelshape, seqlen = inputs
        label = tf.SparseTensor(labelidx, labelvalue, labelshape)

        # Every layer needs its own cell object. Reusing one instance
        # ([cell] * NLAYER) would make all layers share a single weight set,
        # even though layer 1 sees FEATUREDIM-wide input and later layers
        # see HIDDEN-wide input.
        cell = tf.contrib.rnn.MultiRNNCell(
            [tf.contrib.rnn.BasicLSTMCell(num_units=HIDDEN) for _ in range(NLAYER)])

        batch_size = tf.shape(feat)[0]
        initial = cell.zero_state(batch_size, tf.float32)
        outputs, _ = tf.nn.dynamic_rnn(cell, feat,
                                       seqlen, initial,
                                       dtype=tf.float32, scope='rnn')
        # outputs: b x t x HIDDEN
        output = tf.reshape(outputs, [-1, HIDDEN])  # (b*t) x HIDDEN
        logits = FullyConnected('fc', output, NR_CLASS, nl=tf.identity,
                                W_init=tf.truncated_normal_initializer(stddev=0.01))
        # Reshape back with the dynamic batch size instead of the BATCH
        # constant, so batches of any size are handled correctly.
        logits = tf.reshape(logits, tf.stack([batch_size, -1, NR_CLASS]))

        loss = tf.nn.ctc_loss(label, logits, seqlen, time_major=False)
        self.cost = tf.reduce_mean(loss, name='cost')

        # The CTC decoders take time-major input: t x b x NR_CLASS.
        logits = tf.transpose(logits, [1, 0, 2])

        if get_current_tower_context().is_training:
            # beam search is too slow to run in training
            decoded = tf.nn.ctc_greedy_decoder(logits, seqlen)[0][0]
        else:
            decoded = tf.nn.ctc_beam_search_decoder(logits, seqlen)[0][0]
        # Cast the decoded sparse tensor to int32 to match the label values.
        predictions = tf.to_int32(decoded)

        err = tf.edit_distance(predictions, label, normalize=True)
        err.set_shape([None])
        err = tf.reduce_mean(err, name='error')
        summary.add_moving_summary(err, self.cost)

    def _get_optimizer(self):
        # 'learning_rate' is the variable the LR callbacks in get_config adjust.
        lr = symbf.get_scalar_var('learning_rate', 5e-3, summary=True)
        opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
        # Clip gradients by global norm 5 and add gradient summaries.
        return optimizer.apply_grad_processors(
            opt, [GlobalNormClip(5), SummaryGradient()])
def get_data(path, isTrain, stat_file):
    """Build the TIMIT dataflow read from an LMDB database.

    Args:
        path: path to the LMDB file.
        isTrain: if True, shuffle the datapoints and wrap the flow in
            PrefetchDataZMQ with one worker.
        stat_file: file holding the serialized (mean, std) pair used to
            normalize the features.

    Returns:
        The dataflow, batched by TIMITBatch with batch size BATCH.
    """
    ds = LMDBDataPoint(path, shuffle=isTrain)
    # The statistics file is serialized binary data. Read it as bytes, and
    # close the handle afterwards (the original text-mode open() leaked it).
    with open(stat_file, 'rb') as f:
        mean, std = serialize.loads(f.read())
    # Standardize the datapoint component MapDataComponent selects by default
    # (presumably the feature matrix -- confirm against the LMDB layout).
    ds = MapDataComponent(ds, lambda x: (x - mean) / std)
    ds = TIMITBatch(ds, BATCH)
    if isTrain:
        ds = PrefetchDataZMQ(ds, 1)
    return ds
def get_config(ds_train, ds_test):
    """Return the TrainConfig: train on ds_train, validate on ds_test."""
    callbacks = [
        ModelSaver(),
        # Scale the learning rate by 0.2 when 'error' stops improving.
        # The positional 0 and 5 are StatMonitorParamSetter's threshold and
        # last_k arguments.
        StatMonitorParamSetter('learning_rate', 'error',
                               lambda x: x * 0.2, 0, 5),
        # Allows the learning rate to be set by hand while training runs.
        HumanHyperParamSetter('learning_rate'),
        # Evaluate 'error' on the test flow every 2 epochs.
        PeriodicTrigger(
            InferenceRunner(ds_test, [ScalarStats('error')]),
            every_k_epochs=2),
    ]
    config = TrainConfig(
        dataflow=ds_train,
        callbacks=callbacks,
        model=Model(),
        max_epoch=70,
    )
    return config
if __name__ == '__main__':
    # ---- command line ----
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    parser.add_argument('--train', help='path to training lmdb', required=True)
    parser.add_argument('--test', help='path to testing lmdb', required=True)
    parser.add_argument('--stat', help='path to the mean/std statistics file',
                        default='stats.data')
    args = parser.parse_args()

    # ---- environment: pick GPUs before training starts, then the log dir ----
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    logger.auto_set_dir()

    # ---- data + config (training flow is built first, then the test flow) ----
    config = get_config(get_data(args.train, True, args.stat),
                        get_data(args.test, False, args.stat))
    if args.load:
        config.session_init = SaverRestore(args.load)

    QueueInputTrainer(config).train()