-
Notifications
You must be signed in to change notification settings - Fork 4
/
infer_wavenext_onnx.py
170 lines (147 loc) · 6.12 KB
/
infer_wavenext_onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import argparse
import os
import warnings
from pathlib import Path
from time import perf_counter
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
from text import text_to_sequence, sequence_to_text
def intersperse(lst, item):
    """Return a new list with *item* placed before, between and after every element of *lst*.

    Used to add the blank symbol around each token id, e.g.
    ``intersperse([1, 2], 0) == [0, 1, 0, 2, 0]``.
    """
    out = [item]
    for element in lst:
        out.extend((element, item))
    return out
def process_text(i: int, text: str, device: torch.device):
    """Phonetise one line of text and build the token tensors the model consumes.

    Args:
        i: Index used only as a prefix in the printed log lines.
        text: Raw input text, cleaned with the ``catalan_cleaners`` pipeline.
        device: Device on which the returned tensors are allocated.

    Returns:
        dict with ``x_orig`` (the input text), ``x`` (int64 token ids with the
        blank id 0 interspersed, shape ``(1, L)``), ``x_lengths`` (a 1-element
        int64 tensor holding ``L``) and ``x_phones`` (the decoded sequence).
    """
    print(f"[{i}] - Input text: {text}")
    token_ids = intersperse(text_to_sequence(text, ["catalan_cleaners"]), 0)
    # Add a leading batch axis of size 1.
    tokens = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)
    lengths = torch.tensor([tokens.shape[-1]], dtype=torch.long, device=device)
    phones = sequence_to_text(tokens.squeeze(0).tolist())
    # Assumes sequence_to_text yields one symbol per id, so [1::2] skips the
    # interspersed blanks — TODO confirm against the `text` package.
    print(f"[{i}] - Phonetised text: {phones[1::2]}")
    return dict(x_orig=text, x=tokens, x_lengths=lengths, x_phones=phones)
def validate_args(args):
    """Check CLI arguments for consistency and return them unchanged.

    Uses explicit ``raise`` instead of ``assert`` so the checks still run
    under ``python -O``.

    Args:
        args: ``argparse.Namespace`` with ``text``, ``file``, ``temperature``
            and ``speaking_rate`` attributes.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: if neither ``--text`` nor ``--file`` is given, if the
            temperature is negative (or NaN), or if the speaking rate is not
            strictly positive (or NaN).
    """
    if not (args.text or args.file):
        raise ValueError(
            "Either text or file must be provided: Matcha-T(ea)TTS needs some text to whisk the waveforms."
        )
    # The negated comparisons are deliberate: they also reject NaN, exactly
    # like the original `assert x >= 0` did.
    if not args.temperature >= 0:
        raise ValueError("Sampling temperature cannot be negative")
    # The message always promised a strictly positive rate; 0 used to slip through.
    if not args.speaking_rate > 0:
        raise ValueError("Speaking rate must be greater than 0")
    return args
def write_wavs(model, inputs, output_dir, external_vocoder=None, sample_rate=22050, hop_length=256):
    """Synthesize waveforms for a batch and write one WAV file per item.

    Two paths:
      * ``external_vocoder is None``: ``model.run`` is expected to return
        ``(wavs, wav_lengths)`` directly (vocoder embedded in the graph).
      * otherwise: ``model.run`` returns ``(mels, mel_lengths)``; the mels are
        fed to ``external_vocoder`` and sample lengths are taken as
        ``mel_lengths * hop_length``.

    Files are written as ``output_<n>.wav`` (24-bit PCM at ``sample_rate`` Hz)
    into ``output_dir``, which is created if missing. Inference time, generated
    audio duration and real-time factors (inference seconds / audio seconds)
    are printed.

    Args:
        model: ONNX Runtime session for the acoustic model.
        inputs: Feed dict passed to ``model.run``.
        output_dir: Destination directory (str or path-like).
        external_vocoder: Optional ONNX Runtime session for a separate vocoder.
        sample_rate: Output sample rate in Hz (default 22050).
        hop_length: Waveform samples per mel frame; used only on the
            external-vocoder path (default 256).
    """
    if external_vocoder is None:
        print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
        t0 = perf_counter()
        wavs, wav_lengths = model.run(None, inputs)
        infer_secs = perf_counter() - t0
        # Per-stage timings only exist on the two-model path.
        mel_infer_secs = vocoder_infer_secs = None
    else:
        print("[🍵] Generating mel using Matcha")
        mel_t0 = perf_counter()
        mels, mel_lengths = model.run(None, inputs)
        mel_infer_secs = perf_counter() - mel_t0

        print("Generating waveform from mel using external vocoder")
        # The vocoder is fed through its first declared input only.
        vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
        vocoder_t0 = perf_counter()
        wavs = external_vocoder.run(None, vocoder_inputs)[0]
        vocoder_infer_secs = perf_counter() - vocoder_t0

        # NOTE(review): assumes the vocoder returns (batch, 1, samples) — confirm.
        wavs = wavs.squeeze(1)
        wav_lengths = mel_lengths * hop_length
        infer_secs = mel_infer_secs + vocoder_infer_secs

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)):
        output_filename = output_dir.joinpath(f"output_{i + 1}.wav")
        # Trim to this item's own length; the batch may be padded past it.
        audio = wav[:wav_length]
        print(f"Writing audio to {output_filename}")
        sf.write(output_filename, audio, sample_rate, "PCM_24")

    wav_secs = wav_lengths.sum() / sample_rate
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    if not wav_secs:
        # Avoid a division by zero when nothing was generated.
        print("No audio was generated; skipping RTF computation")
        return
    if mel_infer_secs is not None:
        print(f"Matcha RTF: {mel_infer_secs / wav_secs}")
    if vocoder_infer_secs is not None:
        print(f"Vocoder RTF: {vocoder_infer_secs / wav_secs}")
    print(f"Overall RTF: {infer_secs / wav_secs}")
def write_mels(model, inputs, output_dir, hop_length=256, sample_rate=22050):
    """Run the acoustic model and save each mel output as ``output_<n>.npy``.

    Fallback for graphs without an embedded vocoder when no external vocoder
    was supplied. ``main`` calls it but it was previously not defined
    anywhere, so that path crashed with ``NameError``.

    Args:
        model: ONNX Runtime session whose ``run`` returns ``(mels, mel_lengths)``.
        inputs: Feed dict passed to ``model.run``.
        output_dir: Destination directory (str or path-like); created if missing.
        hop_length: Waveform samples per mel frame, used only for the RTF estimate.
        sample_rate: Sample rate in Hz, used only for the RTF estimate.
    """
    t0 = perf_counter()
    mels, mel_lengths = model.run(None, inputs)
    infer_secs = perf_counter() - t0

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for i, mel in enumerate(mels):
        output_filename = output_dir.joinpath(f"output_{i + 1}.npy")
        print(f"Writing mel to {output_filename}")
        # NOTE(review): saves the full batch entry, including any padding;
        # trim with mel_lengths[i] once the frame axis is confirmed.
        np.save(output_filename, mel)

    wav_secs = (mel_lengths * hop_length).sum() / sample_rate
    print(f"Inference seconds: {infer_secs}")
    print(f"Generated wav seconds: {wav_secs}")
    if wav_secs:
        print(f"RTF: {infer_secs / wav_secs}")


def main():
    """Command-line entry point: synthesize text with an ONNX Matcha-TTS model.

    Parses arguments, phonetises each non-empty input line, pads the batch,
    runs the ONNX graph and writes ``output_<n>.wav`` files into
    ``--output-dir``. If there is no vocoder (embedded or external), it
    writes ``output_<n>.npy`` mel files instead.
    """
    parser = argparse.ArgumentParser(
        description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
    )
    parser.add_argument(
        "model",
        type=str,
        help="ONNX model to use",
    )
    parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)")
    parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
    parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
    parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.667,
        help="Variance of the x0 noise (default: 0.667)",
    )
    parser.add_argument(
        "--speaking-rate",
        type=float,
        default=1.0,
        help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
    )
    # The flag enables CUDA; the old help text wrongly said it selected the CPU.
    parser.add_argument("--gpu", action="store_true", help="Use GPU (CUDA) for inference (default: use CPU)")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=os.getcwd(),
        help="Output folder to save results (default: current dir)",
    )
    args = parser.parse_args()
    args = validate_args(args)

    # CPU is listed after CUDA so ONNX Runtime can fall back to it for
    # anything the CUDA provider cannot run.
    if args.gpu:
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    model = ort.InferenceSession(args.model, providers=providers)
    model_inputs = model.get_inputs()
    model_outputs = list(model.get_outputs())

    if args.text:
        text_lines = args.text.splitlines()
    else:
        with open(args.file, encoding="utf-8") as file:
            text_lines = file.read().splitlines()
    # Blank lines carry no text to synthesize; skip them.
    text_lines = [line for line in text_lines if line.strip()]
    if not text_lines:
        parser.error("No non-empty text lines to synthesize")

    # Number lines from 1 so the log prefix lines up with output_<n> file names.
    processed_lines = [process_text(i, line, "cpu") for i, line in enumerate(text_lines, start=1)]

    # Each x is (1, L). squeeze(0) drops only the batch axis; a bare squeeze()
    # would also collapse a length-1 sequence to a 0-d tensor.
    # pad_sequence then right-pads to (batch, max_L).
    x = torch.nn.utils.rnn.pad_sequence(
        [line["x"].squeeze(0) for line in processed_lines], batch_first=True
    )
    x = x.detach().cpu().numpy()
    x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
    inputs = {
        "x": x,
        "x_lengths": x_lengths,
        "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
    }

    # A fourth graph input is taken to mean the model expects speaker ids.
    is_multi_speaker = len(model_inputs) == 4
    if is_multi_speaker:
        if args.spk is None:
            args.spk = 0
            warn = "[!] Speaker ID not provided! Using speaker ID 0"
            warnings.warn(warn, UserWarning)
        inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)

    has_vocoder_embedded = model_outputs[0].name == "wav"
    if has_vocoder_embedded:
        write_wavs(model, inputs, args.output_dir)
    elif args.vocoder:
        external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
        write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
    else:
        warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
        warnings.warn(warn, UserWarning)
        write_mels(model, inputs, args.output_dir)


if __name__ == "__main__":
    main()