diff --git a/runtime/gpu/model_repo/scoring/1/model.py b/runtime/gpu/model_repo/scoring/1/model.py
index 98b331ad5d..e67bad16a6 100644
--- a/runtime/gpu/model_repo/scoring/1/model.py
+++ b/runtime/gpu/model_repo/scoring/1/model.py
@@ -75,7 +75,6 @@ def initialize(self, args):
     def init_ctc_rescore(self, parameters):
         num_processes = multiprocessing.cpu_count()
         cutoff_prob = 0.9999
-        blank_id = 0
         alpha = 2.0
         beta = 1.0
         bidecoder = 0
@@ -104,8 +103,12 @@ def init_ctc_rescore(self, parameters):

         self.num_processes = num_processes
         self.cutoff_prob = cutoff_prob
-        self.blank_id = blank_id
-        _, vocab = self.load_vocab(vocab_path)
+        ret = self.load_vocab(vocab_path)
+        id2vocab, vocab, space_id, blank_id, sos_eos = ret
+        self.space_id = space_id if space_id else -1
+        self.blank_id = blank_id if blank_id else 0
+        self.eos = self.sos = sos_eos if sos_eos else len(vocab) - 1
+
         if lm_path and os.path.exists(lm_path):
             self.lm = Scorer(alpha, beta, lm_path, vocab)
             print("Successfully load language model!")
@@ -125,24 +128,28 @@ def init_ctc_rescore(self, parameters):
         )
         self.vocabulary = vocab
         self.bidecoder = bidecoder
-        sos = eos = len(vocab) - 1
-        self.sos = sos
-        self.eos = eos

     def load_vocab(self, vocab_file):
         """
         load lang_char.txt
         """
         id2vocab = {}
+        space_id, blank_id, sos_eos = None, None, None
         with open(vocab_file, "r", encoding="utf-8") as f:
             for line in f:
                 line = line.strip()
                 char, id = line.split()
                 id2vocab[int(id)] = char
+                if char == " ":
+                    space_id = int(id)
+                elif char == "<blank>":
+                    blank_id = int(id)
+                elif char == "<sos/eos>":
+                    sos_eos = int(id)
         vocab = [0] * len(id2vocab)
         for id, char in id2vocab.items():
             vocab[id] = char
-        return id2vocab, vocab
+        return (id2vocab, vocab, space_id, blank_id, sos_eos)

     def load_hotwords(self, hotwords_file):
         """
diff --git a/runtime/gpu/model_repo_stateful/wenet/1/wenet_onnx_model.py b/runtime/gpu/model_repo_stateful/wenet/1/wenet_onnx_model.py
index 6ea503591d..4e0dc35076 100644
--- a/runtime/gpu/model_repo_stateful/wenet/1/wenet_onnx_model.py
+++ b/runtime/gpu/model_repo_stateful/wenet/1/wenet_onnx_model.py
@@ -183,11 +183,12 @@ def infer(self, batch_log_probs, batch_log_probs_idx, seq_lens,

                 hist_enc = batch_encoder_hist[idx]
                 if hist_enc is None:
                     cur_enc = cur_encoder_out[idx]
+                    cur_mask_len = int(0 + seq_lens[idx])
                 else:
                     cur_enc = torch.cat([hist_enc, cur_encoder_out[idx]], axis=0)
+                    cur_mask_len = int(len(hist_enc) + seq_lens[idx])
                 rescore_encoder_hist.append(cur_enc)
-                cur_mask_len = int(len(hist_enc) + seq_lens[idx])
                 rescore_encoder_lens.append(cur_mask_len)
                 rescore_hyps.append(score_hyps[idx])
                 if cur_enc.shape[0] > max_length:
diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py
index 9b46208027..ab6c1dbe05 100644
--- a/wenet/bin/export_onnx_gpu.py
+++ b/wenet/bin/export_onnx_gpu.py
@@ -1200,7 +1200,7 @@ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
         configs['cmvn_conf'] = {}
     else:
         assert configs['cmvn'] == "global_cmvn"
-        assert configs['cmvn']['cmvn_conf'] is not None
+        assert configs['cmvn_conf'] is not None
     configs['cmvn_conf']["cmvn_file"] = args.cmvn_file
     if (args.reverse_weight != -1.0
             and "reverse_weight" in configs["model_conf"]):
diff --git a/wenet/bin/recognize_onnx_gpu.py b/wenet/bin/recognize_onnx_gpu.py
index 3fb0d8bbba..373c3ddbec 100644
--- a/wenet/bin/recognize_onnx_gpu.py
+++ b/wenet/bin/recognize_onnx_gpu.py
@@ -62,7 +62,6 @@
           'https://github.com/Slyne/ctc_decoder.git')
     sys.exit(1)

-
 def get_args():
     parser = argparse.ArgumentParser(description='recognize with your model')
     parser.add_argument('--config', required=True, help='config file')
@@ -106,10 +105,8 @@ def get_args():
                         action='store_true',
                         help='whether to export fp16 model, default false')
     args = parser.parse_args()
-    print(args)
     return args

-
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
@@ -122,6 +119,7 @@ def main():
         configs = override_config(configs, args.override_config)

     reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
+    special_tokens = configs.get('tokenizer_conf', {}).get('special_tokens', None)
     test_conf = copy.deepcopy(configs['dataset_conf'])
     test_conf['filter_conf']['max_length'] = 102400
     test_conf['filter_conf']['min_length'] = 0
@@ -145,7 +143,6 @@ def main():
                            tokenizer,
                            test_conf,
                            partition=False)
-
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

     # Init asr model from configs
@@ -171,10 +168,18 @@ def main():
             assert len(arr) == 2
             char_dict[int(arr[1])] = arr[0]
             vocabulary.append(arr[0])
-    eos = sos = len(char_dict) - 1
+
+    vocab_size = len(char_dict)
+    sos = (vocab_size - 1 if special_tokens is None else
+           special_tokens.get("<sos>", vocab_size - 1))
+    eos = (vocab_size - 1 if special_tokens is None else
+           special_tokens.get("<eos>", vocab_size - 1))
+
     with torch.no_grad(), open(args.result_file, 'w') as fout:
         for _, batch in enumerate(test_data_loader):
-            keys, feats, _, feats_lengths, _ = batch
+            keys = batch['keys']
+            feats = batch['feats']
+            feats_lengths = batch['feats_lengths']
             feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
             if args.fp16:
                 feats = feats.astype(np.float16)
@@ -288,6 +293,5 @@ def main():
                 logging.info('{} {}'.format(key, content))
                 fout.write('{} {}\n'.format(key, content))

-
 if __name__ == '__main__':
     main()
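
Reviewer note: a minimal sketch of the special-token resolution these hunks share, assuming a WeNet-style lang_char.txt vocabulary and tokenizer_conf. The fixture dict below is made up for illustration and is not taken from the patch.

# Hedged sketch of the fallback logic, with an illustrative fixture vocab.
units = {"<blank>": 0, "a": 1, "b": 2, "<sos/eos>": 3}  # token -> id

blank_id = units.get("<blank>")                      # None when absent
sos_eos = units.get("<sos/eos>")
blank_id = blank_id if blank_id else 0               # patch default: blank at 0
sos = eos = sos_eos if sos_eos else len(units) - 1   # old behavior as fallback

# recognize_onnx_gpu.py prefers explicit special_tokens when configured.
configs = {"tokenizer_conf": {"special_tokens": {"<sos>": 3, "<eos>": 3}}}
special_tokens = configs.get("tokenizer_conf", {}).get("special_tokens", None)
vocab_size = len(units)
sos = (vocab_size - 1 if special_tokens is None else
       special_tokens.get("<sos>", vocab_size - 1))
eos = (vocab_size - 1 if special_tokens is None else
       special_tokens.get("<eos>", vocab_size - 1))
print(blank_id, sos, eos)  # -> 0 3 3 for this fixture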
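
The wenet_onnx_model.py hunk also fixes a first-chunk crash: the old code called len(hist_enc) unconditionally, but hist_enc is None until encoder history accumulates. A minimal reproduction of the repaired control flow (tensor shapes are made up):

import torch

def mask_len(hist_enc, cur_len):
    # Mirrors the patched branches: history length (0 on the first chunk,
    # where hist_enc is None) plus the current chunk's length.
    if hist_enc is None:
        return int(0 + cur_len)
    return int(len(hist_enc) + cur_len)

assert mask_len(None, 16) == 16                 # old code raised TypeError here
assert mask_len(torch.zeros(8, 256), 16) == 24  # 8 frames of history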