forked from GeorgeFedoseev/DeepSpeech
-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer.py
113 lines (66 loc) · 2.98 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# -*- coding:utf-8 -*-
import DeepSpeech
import tensorflow as tf
import sys
import re
import time
import os
from util import text as text_utils
import const
# Absolute directory containing this source file; used to locate bundled data.
current_dir_path = os.path.dirname(os.path.realpath(__file__))
# Default data directory next to this file (note: __main__ below uses
# const.DATA_DIR instead — this path is kept for callers that import it).
data_path = os.path.join(current_dir_path, "data")
# Module-level guard so init() configures DeepSpeech globals only once.
initialized = False
def init(n_hidden=const.DEEP_SPEECH_N_HIDDEN, checkpoint_dir=const.DEEP_SPEECH_CHECKPOINT_DIR, alphabet_config_path=const.DEEP_SPEECH_ALPHABET_PATH, use_lm=False, language_tool_language=''):
    """Configure DeepSpeech's global flags exactly once per process.

    DeepSpeech reads its configuration from command-line flags, so the
    chosen settings are injected by appending them to sys.argv before
    DeepSpeech.initialize_globals() parses them. Subsequent calls are
    no-ops thanks to the module-level `initialized` guard.
    """
    global initialized
    if initialized:
        return
    # Flag/value pairs consumed by DeepSpeech's flag parser.
    sys.argv.extend([
        "--alphabet_config_path", alphabet_config_path,
        "--n_hidden", str(n_hidden),
        "--checkpoint_dir", checkpoint_dir,
        "--infer_use_lm=" + ("1" if use_lm else "0"),
        "--lt_lang=" + language_tool_language,
    ])
    DeepSpeech.initialize_globals()
    initialized = True
def init_session():
    """Create a TF session with the inference graph and restore weights.

    Returns:
        (session, inputs, outputs) — the tuple expected by infer().

    Exits the process with status 1 when the configured checkpoint
    directory holds no valid checkpoint state.
    """
    print('Use Language Model: %s' % str(DeepSpeech.FLAGS.infer_use_lm))
    session = tf.Session(config=DeepSpeech.session_config)
    inputs, outputs = DeepSpeech.create_inference_graph(batch_size=1, use_new_decoder=DeepSpeech.FLAGS.infer_use_lm)
    # Saver built from the freshly created inference graph's variables.
    saver = tf.train.Saver(tf.global_variables())
    # TODO: this restores the most recent checkpoint; if validation is used
    # to counteract over-fitting, an earlier checkpoint may be preferable.
    ckpt_state = tf.train.get_checkpoint_state(DeepSpeech.FLAGS.checkpoint_dir)
    if not ckpt_state:
        print('Checkpoint directory ({}) does not contain a valid checkpoint state.'.format(DeepSpeech.FLAGS.checkpoint_dir))
        sys.exit(1)
    saver.restore(session, ckpt_state.model_checkpoint_path)
    return session, inputs, outputs
def infer(wav_path, session_tuple):
    """Transcribe a single WAV file using a prepared inference session.

    Args:
        wav_path: path to the audio file to transcribe.
        session_tuple: the (session, inputs, outputs) tuple from init_session().

    Returns:
        The decoded transcript as a string.
    """
    session, inputs, outputs = session_tuple
    # Fix: the original assigned start_time twice and never used it (the
    # timing print was commented out) — the dead timing code is removed.
    mfcc = DeepSpeech.audiofile_to_input_vector(wav_path, DeepSpeech.n_input, DeepSpeech.n_context)
    output = session.run(outputs['outputs'], feed_dict={
        inputs['input']: [mfcc],
        inputs['input_lengths']: [len(mfcc)],
    })
    # output[0][0] is the top decode for the single batch item.
    return DeepSpeech.ndarray_to_text(output[0][0], DeepSpeech.alphabet)
if __name__ == "__main__":
    # Smoke test: initialize, restore a session, then run the same file
    # ten times to observe warm-up vs. steady-state inference latency.
    start_time = time.time()
    init(use_lm=True)
    print("DeepSpeech init took %.2f sec" % (time.time() - start_time))
    start_time = time.time()
    session = init_session()
    print("session init took %.2f sec" % (time.time() - start_time))
    test_file_path = os.path.join(const.DATA_DIR, "infer_test_3.wav")
    for _ in range(10):
        start_time = time.time()
        # Fix: `print infer(...)` was a Python 2 print *statement*,
        # inconsistent with the print() calls above and a SyntaxError on
        # Python 3; print(...) with a single argument behaves identically.
        print(infer(test_file_path, session))
        print("infer took %.2f sec" % (time.time() - start_time))