-
Notifications
You must be signed in to change notification settings - Fork 61
/
ctc_model.py
118 lines (90 loc) · 4.29 KB
/
ctc_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
def leaky_relu(features, alpha=0.2, name=None):
    """Leaky ReLU activation: elementwise max(x, alpha * x).

    Args:
        features: input tensor (or any value convertible to a tensor).
        alpha: slope applied to negative values (default 0.2).
        name: optional name for the op scope.

    Returns:
        A tensor with the same shape as `features`.
    """
    with ops.name_scope(name, "LeakyRelu", [features, alpha]):
        x = ops.convert_to_tensor(features, name="features")
        slope = ops.convert_to_tensor(alpha, name="alpha")
        return math_ops.maximum(x, slope * x)
#
# params['img_height'] = height of the input image
# params['img_width'] = width of the input image (None = variable width)
def default_model_params(img_height, vocabulary_size):
params = dict()
params['img_height'] = img_height
params['img_width'] = None
params['batch_size'] = 16
params['img_channels'] = 1
params['conv_blocks'] = 4
params['conv_filter_n'] = [32, 64, 128, 256]
params['conv_filter_size'] = [ [3,3], [3,3], [3,3], [3,3] ]
params['conv_pooling_size'] = [ [2,2], [2,2], [2,2], [2,2] ]
params['rnn_units'] = 512
params['rnn_layers'] = 2
params['vocabulary_size'] = vocabulary_size
return params
def ctc_crnn(params):
    """Build the CRNN + CTC graph (TensorFlow 1.x graph mode).

    Architecture: a stack of conv/batch-norm/leaky-ReLU/max-pool blocks,
    reshaped into a time-major sequence, fed to a bidirectional multi-layer
    LSTM, projected to per-timestep logits, with CTC loss and greedy decoding.

    Args:
        params: dict of hyper-parameters as produced by
            default_model_params (conv stack sizes, RNN sizes, image
            geometry, vocabulary size).

    Returns:
        Tuple (input, seq_len, targets, decoded, loss, rnn_keep_prob):
        the input/length/target placeholders, the greedy-decoded sparse
        output, the mean CTC loss, and the dropout keep-prob placeholder.
    """
    # TODO Assert parameters
    input = tf.placeholder(shape=(None,
                                  params['img_height'],
                                  params['img_width'],
                                  params['img_channels']),  # [batch, height, width, channels]
                           dtype=tf.float32,
                           name='model_input')

    input_shape = tf.shape(input)

    # Cumulative downsampling factors of the pooling layers, tracked so the
    # conv output can be reshaped into a sequence below.
    width_reduction = 1
    height_reduction = 1

    # Convolutional blocks
    x = input
    for i in range(params['conv_blocks']):
        x = tf.layers.conv2d(
            inputs=x,
            filters=params['conv_filter_n'][i],
            kernel_size=params['conv_filter_size'][i],
            padding="same",
            activation=None)
        # NOTE(review): batch_normalization is called without a `training`
        # flag, so it uses the layer's default mode — confirm this is the
        # intended behavior for training runs.
        x = tf.layers.batch_normalization(x)
        x = leaky_relu(x)

        x = tf.layers.max_pooling2d(inputs=x,
                                    pool_size=params['conv_pooling_size'][i],
                                    strides=params['conv_pooling_size'][i])

        width_reduction = width_reduction * params['conv_pooling_size'][i][1]
        height_reduction = height_reduction * params['conv_pooling_size'][i][0]

    # Prepare output of conv block for recurrent blocks: image width becomes
    # the time axis.
    features = tf.transpose(x, perm=[2, 0, 3, 1])  # -> [width, batch, height, channels] (time_major=True)
    # NOTE(review): `/` is float division under Python 3; the result is cast
    # to int32 below, which truncates — relies on img_height being an exact
    # multiple of height_reduction.
    feature_dim = params['conv_filter_n'][-1] * (params['img_height'] / height_reduction)
    feature_width = input_shape[2] / width_reduction
    features = tf.reshape(features, tf.stack([tf.cast(feature_width,'int32'), input_shape[0], tf.cast(feature_dim,'int32')]))  # -> [width, batch, features]

    # Stored as graph constants so inference code can recover the expected
    # input height and the width downsampling factor after restoring.
    tf.constant(params['img_height'],name='input_height')
    tf.constant(width_reduction,name='width_reduction')

    # Recurrent block
    rnn_keep_prob = tf.placeholder(dtype=tf.float32, name="keep_prob")
    rnn_hidden_units = params['rnn_units']
    rnn_hidden_layers = params['rnn_layers']

    # Bidirectional stacked LSTM with input dropout on every layer; forward
    # and backward stacks are built with identical shapes.
    rnn_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        tf.contrib.rnn.MultiRNNCell(
            [tf.nn.rnn_cell.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(rnn_hidden_units), input_keep_prob=rnn_keep_prob)
             for _ in range(rnn_hidden_layers)]),
        tf.contrib.rnn.MultiRNNCell(
            [tf.nn.rnn_cell.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(rnn_hidden_units), input_keep_prob=rnn_keep_prob)
             for _ in range(rnn_hidden_layers)]),
        features,
        dtype=tf.float32,
        time_major=True,
    )
    # Concatenate forward and backward outputs along the feature axis.
    rnn_outputs = tf.concat(rnn_outputs, 2)

    # Linear projection to vocabulary_size + 1 classes (the extra class is
    # the CTC blank).
    logits = tf.contrib.layers.fully_connected(
        rnn_outputs,
        params['vocabulary_size'] + 1,  # BLANK
        activation_fn=None,
    )

    tf.add_to_collection("logits",logits)  # for restoring purposes

    # CTC Loss computation
    seq_len = tf.placeholder(tf.int32, [None], name='seq_lengths')
    targets = tf.sparse_placeholder(dtype=tf.int32, name='target')
    ctc_loss = tf.nn.ctc_loss(labels=targets, inputs=logits, sequence_length=seq_len, time_major=True)
    loss = tf.reduce_mean(ctc_loss)

    # CTC decoding
    decoded, log_prob = tf.nn.ctc_greedy_decoder(logits, seq_len)
    # decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits,seq_len,beam_width=50,top_paths=1,merge_repeated=True)

    return input, seq_len, targets, decoded, loss, rnn_keep_prob