Skip to content

Commit

Permalink
fix gpu-onnx infer (#2562)
Browse files Browse the repository at this point in the history
* fix gpu-onnx infer

* fix lint error

* update

* 修复实时推理时hist_enc为None时报错

* update

* update

* update

* update

* remove double blank lines

* remove wqtrailing whitespace

---------

Co-authored-by: unknown <Simon@DESKTOP-6VPSU4Q>
Co-authored-by: linrrry <linry23@mail2.sysu.edu.cn>
  • Loading branch information
3 people authored Jul 9, 2024
1 parent dec409b commit fae6a8c
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 16 deletions.
21 changes: 14 additions & 7 deletions runtime/gpu/model_repo/scoring/1/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def initialize(self, args):
def init_ctc_rescore(self, parameters):
num_processes = multiprocessing.cpu_count()
cutoff_prob = 0.9999
blank_id = 0
alpha = 2.0
beta = 1.0
bidecoder = 0
Expand Down Expand Up @@ -104,8 +103,12 @@ def init_ctc_rescore(self, parameters):

self.num_processes = num_processes
self.cutoff_prob = cutoff_prob
self.blank_id = blank_id
_, vocab = self.load_vocab(vocab_path)
ret = self.load_vocab(vocab_path)
id2vocab, vocab, space_id, blank_id, sos_eos = ret
self.space_id = space_id if space_id else -1
self.blank_id = blank_id if blank_id else 0
self.eos = self.sos = sos_eos if sos_eos else len(vocab) - 1

if lm_path and os.path.exists(lm_path):
self.lm = Scorer(alpha, beta, lm_path, vocab)
print("Successfully load language model!")
Expand All @@ -125,24 +128,28 @@ def init_ctc_rescore(self, parameters):
)
self.vocabulary = vocab
self.bidecoder = bidecoder
sos = eos = len(vocab) - 1
self.sos = sos
self.eos = eos

def load_vocab(self, vocab_file):
"""
load lang_char.txt
"""
id2vocab = {}
space_id, blank_id, sos_eos = None, None, None
with open(vocab_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
char, id = line.split()
id2vocab[int(id)] = char
if char == " ":
space_id = int(id)
elif char == "<blank>":
blank_id = int(id)
elif char == "<sos/eos>":
sos_eos = int(id)
vocab = [0] * len(id2vocab)
for id, char in id2vocab.items():
vocab[id] = char
return id2vocab, vocab
return (id2vocab, vocab, space_id, blank_id, sos_eos)

def load_hotwords(self, hotwords_file):
"""
Expand Down
3 changes: 2 additions & 1 deletion runtime/gpu/model_repo_stateful/wenet/1/wenet_onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,12 @@ def infer(self, batch_log_probs, batch_log_probs_idx, seq_lens,
hist_enc = batch_encoder_hist[idx]
if hist_enc is None:
cur_enc = cur_encoder_out[idx]
cur_mask_len = int(0 + seq_lens[idx])
else:
cur_enc = torch.cat([hist_enc, cur_encoder_out[idx]],
axis=0)
cur_mask_len = int(len(hist_enc) + seq_lens[idx])
rescore_encoder_hist.append(cur_enc)
cur_mask_len = int(len(hist_enc) + seq_lens[idx])
rescore_encoder_lens.append(cur_mask_len)
rescore_hyps.append(score_hyps[idx])
if cur_enc.shape[0] > max_length:
Expand Down
2 changes: 1 addition & 1 deletion wenet/bin/export_onnx_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,7 @@ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
configs['cmvn_conf'] = {}
else:
assert configs['cmvn'] == "global_cmvn"
assert configs['cmvn']['cmvn_conf'] is not None
assert configs['cmvn_conf'] is not None
configs['cmvn_conf']["cmvn_file"] = args.cmvn_file
if (args.reverse_weight != -1.0
and "reverse_weight" in configs["model_conf"]):
Expand Down
18 changes: 11 additions & 7 deletions wenet/bin/recognize_onnx_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
'https://github.com/Slyne/ctc_decoder.git')
sys.exit(1)


def get_args():
parser = argparse.ArgumentParser(description='recognize with your model')
parser.add_argument('--config', required=True, help='config file')
Expand Down Expand Up @@ -106,10 +105,8 @@ def get_args():
action='store_true',
help='whether to export fp16 model, default false')
args = parser.parse_args()
print(args)
return args


def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
Expand All @@ -122,6 +119,7 @@ def main():
configs = override_config(configs, args.override_config)

reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
special_tokens = configs.get('tokenizer_conf', {}).get('special_tokens', None)
test_conf = copy.deepcopy(configs['dataset_conf'])
test_conf['filter_conf']['max_length'] = 102400
test_conf['filter_conf']['min_length'] = 0
Expand All @@ -145,7 +143,6 @@ def main():
tokenizer,
test_conf,
partition=False)

test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

# Init asr model from configs
Expand All @@ -171,10 +168,18 @@ def main():
assert len(arr) == 2
char_dict[int(arr[1])] = arr[0]
vocabulary.append(arr[0])
eos = sos = len(char_dict) - 1

vocab_size = len(char_dict)
sos = (vocab_size - 1 if special_tokens is None else
special_tokens.get("<sos>", vocab_size - 1))
eos = (vocab_size - 1 if special_tokens is None else
special_tokens.get("<eos>", vocab_size - 1))

with torch.no_grad(), open(args.result_file, 'w') as fout:
for _, batch in enumerate(test_data_loader):
keys, feats, _, feats_lengths, _ = batch
keys = batch['keys']
feats = batch['feats']
feats_lengths = batch['feats_lengths']
feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
if args.fp16:
feats = feats.astype(np.float16)
Expand Down Expand Up @@ -288,6 +293,5 @@ def main():
logging.info('{} {}'.format(key, content))
fout.write('{} {}\n'.format(key, content))


if __name__ == '__main__':
main()

0 comments on commit fae6a8c

Please sign in to comment.