-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathadd_CTC
16 lines (12 loc) · 802 Bytes
/
add_CTC
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# A connectionist temporal classification (CTC) "layer" is added by computing
# the CTC loss during training: the loss is wrapped in a Lambda layer (via
# ctc_lambda_func below) and a separate training model is built whose output
# is that loss.
# NOTE(review): `inputs`, `outputs`, `max_label_len`, and the Keras names
# `Input`, `Lambda`, `Model`, `K` are expected to be defined/imported
# elsewhere -- confirm.

# Extra inputs that are only needed to compute the loss.
labels = Input(name='the_labels', shape=[max_label_len], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')


def ctc_lambda_func(args):
    """Compute the batch CTC loss from a packed list of tensors.

    Args:
        args: Sequence ``[y_pred, labels, input_length, label_length]``, in
            that order, as passed in by the ``Lambda`` layer below.

    Returns:
        The value of ``K.ctc_batch_cost(labels, y_pred, input_length,
        label_length)``.
    """
    predictions, true_labels, pred_lengths, true_lengths = args
    # ctc_batch_cost takes the ground-truth labels first, then the predictions.
    return K.ctc_batch_cost(true_labels, predictions, pred_lengths, true_lengths)


# CTC loss exposed as a layer named 'ctc' with output shape (1,).
# NOTE(review): assumes `outputs` is the base model's per-timestep class
# prediction tensor -- confirm against where it is defined.
loss_out = Lambda(
    ctc_lambda_func,
    output_shape=(1,),
    name='ctc',
)([outputs, labels, input_length, label_length])

# Model to be used at training time: its single output is the CTC loss.
training_model = Model(
    inputs=[inputs, labels, input_length, label_length],
    outputs=loss_out,
)