Format code & implement Gradio WebUI
Foxify52 committed Feb 16, 2024
1 parent 167f067 commit 0367fbf
Showing 10 changed files with 2,670 additions and 1,216 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -16,7 +16,7 @@ You can then place them into the model folder with the corresponding names:
 - `(rvc .pth model name)` -> `rvc_model.pth`
 - `(rvc .index model name)` -> `rvc_index.index` (optional)
 
-Once you have all of these, you can run the `RVG.py` file with your desired arguments over CLI or you can include this code in your own project and import the `rvg_tts` function from `RVG.py`.
+Once you have all of these, you can run the `RVG.py` file with your desired arguments over CLI, run it without any arguments to launch the Gradio WebUI, or include this code in your own project and import the `rvg_tts` function from `RVG.py`.
 
 ## Current feature set
 - RVC v1 and v2 model support
@@ -29,6 +29,7 @@ Once you have all of these, you can run the `RVG.py` file with your desired arguments
 - [X] Create a proper importable package
 - [X] Support calling from CLI
 - [X] Further code condensing
+- [X] Gradio WebUI
 - [ ] Multi-lang support
 
 ## Other languages
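The import route the README mentions looks roughly like this — a hypothetical caller sketch, not part of this commit, assuming `forward.pt`, `rvc_model.pth`, and (optionally) `rvc_index.index` are already in the `models` folder:

```python
# Hypothetical caller script placed next to RVG.py (not from this commit).
# Parameter names match the rvg_tts signature introduced below.
from RVG import rvg_tts

# With the new default silent_mode=True, rvg_tts writes output.wav and
# returns its path instead of playing it through winsound.
wav_path = rvg_tts(
    input_text="hello world!",
    voice_transform=2,  # transpose up two semitones (range -12 to 12)
)
print(wav_path)  # "output.wav"
```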
278 changes: 205 additions & 73 deletions RVG.py
@@ -1,71 +1,75 @@
-import torch, os, sys, argparse, winsound, numpy as np
+import torch, os, sys, argparse, winsound, numpy as np, gradio as gr
 from typing import Callable
+from fairseq import checkpoint_utils
+from scipy.io import wavfile
+from multiprocessing import cpu_count
 
 from lib.fwt.dsp import DSP
 from lib.fwt.forward_tacotron import ForwardTacotron
 from lib.fwt.text_utils import Cleaner, Tokenizer
-from multiprocessing import cpu_count
 from lib.rvc.vc_infer_pipeline import VC
-from lib.rvc.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono, SynthesizerTrnMs768NSFsid, SynthesizerTrnMs768NSFsid_nono
-from fairseq import checkpoint_utils
-from scipy.io import wavfile
+from lib.rvc.models import (
+    SynthesizerTrnMs256NSFsid,
+    SynthesizerTrnMs256NSFsid_nono,
+    SynthesizerTrnMs768NSFsid,
+    SynthesizerTrnMs768NSFsid_nono,
+)
 
+
 class Synthesizer:
-
-    def __init__(self,
-                 tts_path: str,
-                 device='cuda'):
+    def __init__(self, tts_path: str, device="cuda"):
         self.device = torch.device(device)
         tts_checkpoint = torch.load(tts_path, map_location=self.device)
-        tts_config = tts_checkpoint['config']
+        tts_config = tts_checkpoint["config"]
         tts_model = ForwardTacotron.from_config(tts_config)
-        tts_model.load_state_dict(tts_checkpoint['model'])
+        tts_model.load_state_dict(tts_checkpoint["model"])
         self.tts_model = tts_model
-        self.vocoder = torch.hub.load('seungwonpark/melgan', 'melgan', verbose=False)
+        self.vocoder = torch.hub.load("seungwonpark/melgan", "melgan", verbose=False)
         self.vocoder.to(device).eval()
         self.cleaner = Cleaner.from_config(tts_config)
         self.tokenizer = Tokenizer()
         self.dsp = DSP.from_config(tts_config)
 
-    def __call__(self,
-                 text: str,
-                 voc_model: str,
-                 alpha=1.0,
-                 pitch_function: Callable[[torch.tensor], torch.tensor] = lambda x: x,
-                 energy_function: Callable[[torch.tensor], torch.tensor] = lambda x: x,
-                 ) -> np.array:
+    def __call__(
+        self,
+        text: str,
+        alpha=1.0,
+        pitch_function: Callable[[torch.tensor], torch.tensor] = lambda x: x,
+        energy_function: Callable[[torch.tensor], torch.tensor] = lambda x: x,
+    ) -> np.array:
         x = self.cleaner(text)
         x = self.tokenizer(x)
         x = torch.tensor(x).unsqueeze(0)
-        gen = self.tts_model.generate(x,
-                                      alpha=alpha,
-                                      pitch_function=pitch_function,
-                                      energy_function=energy_function)
-        m = gen['mel_post'].cpu()
-        if voc_model == 'melgan':
-            m = m.cuda()
-            with torch.no_grad():
-                wav = self.vocoder.inference(m).cpu().numpy()
-        else:
-            print("Specified vocoder isn't supported")
-            exit()
+        gen = self.tts_model.generate(
+            x,
+            alpha=alpha,
+            pitch_function=pitch_function,
+            energy_function=energy_function,
+        )
+        m = gen["mel_post"].cpu()
+        m = m.cuda()
+        with torch.no_grad():
+            wav = self.vocoder.inference(m).cpu().numpy()
         return wav
 
-def pcm2float(sig, dtype='float32'):
+
+def pcm2float(sig, dtype="float32"):
     sig = np.asarray(sig)
-    if sig.dtype.kind not in 'iu':
+    if sig.dtype.kind not in "iu":
         raise TypeError("'sig' must be an array of integers")
     dtype = np.dtype(dtype)
-    if dtype.kind != 'f':
+    if dtype.kind != "f":
         raise TypeError("'dtype' must be a floating point type")
 
     i = np.iinfo(sig.dtype)
     abs_max = 2 ** (i.bits - 1)
     offset = i.min + abs_max
     return (sig.astype(dtype) - offset) / abs_max
 
+
 class Config:
-    def __init__(self,device,is_half):
+    def __init__(self, device, is_half):
         self.device = device
         self.is_half = is_half
         self.n_cpu = 0
@@ -113,9 +117,13 @@ def device_config(self) -> tuple:

         return x_pad, x_query, x_center, x_max
 
+
 def load_hubert():
     global hubert_model
-    models, _, _ = checkpoint_utils.load_model_ensemble_and_task([f'{now_dir}\\models\\hubert.pt'],suffix="",)
+    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
+        [f"{now_dir}\\models\\hubert.pt"],
+        suffix="",
+    )
     hubert_model = models[0]
     hubert_model = hubert_model.to(config.device)
     if config.is_half:
@@ -124,20 +132,21 @@ def load_hubert():
         hubert_model = hubert_model.float()
     hubert_model.eval()
 
+
 def vc_single(sid, audio, f0_up_key, f0_file, file_index, index_rate):
-    global tgt_sr,net_g,vc,hubert_model, version
+    global tgt_sr, net_g, vc, hubert_model, version
     f0_up_key = int(f0_up_key)
     times = [0, 0, 0]
-    if(hubert_model==None):
+    if hubert_model == None:
         load_hubert()
     if_f0 = cpt.get("f0", 1)
-    audio_opt=vc.pipeline(
-        model=hubert_model,
-        net_g=net_g,
-        sid=sid,
-        audio=audio,
-        times=times,
-        f0_up_key=f0_up_key,
+    audio_opt = vc.pipeline(
+        model=hubert_model,
+        net_g=net_g,
+        sid=sid,
+        audio=audio,
+        times=times,
+        f0_up_key=f0_up_key,
         file_index=file_index,
         index_rate=index_rate,
         if_f0=if_f0,
@@ -146,13 +155,14 @@ def vc_single(sid, audio, f0_up_key, f0_file, file_index, index_rate):
         rms_mix_rate=0.25,
         version=version,
         protect=0.5,
-        f0_file=f0_file
+        f0_file=f0_file,
     )
     print(times)
     return audio_opt
 
+
 def get_vc(model_path):
-    global n_spk,tgt_sr,net_g,vc,cpt,device,is_half, version
+    global n_spk, tgt_sr, net_g, vc, cpt, version
     cpt = torch.load(model_path, map_location="cpu")
     tgt_sr = cpt["config"][-1]
     cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
@@ -176,31 +186,32 @@ def get_vc(model_path):
     else:
         net_g = net_g.float()
     vc = VC(tgt_sr, config)
-    n_spk = cpt["config"][-3]
+    n_spk = cpt["config"][-3]
 
+
 def rvg_tts(
-        input_text="hello world!",
-        voice_transform=0,
-        tts_model=f"{os.getcwd()}\\models\\forward.pt",
-        rvc_model=f"{os.getcwd()}\\models\\rvc_model.pth",
-        rvc_index=f"{os.getcwd()}\\models\\rvc_index.index",
-        device="cuda:0",
-        is_half=True,
-        silent_mode=False,
-        persist=True
-):
+    input_text="hello world!",
+    voice_transform=0,
+    tts_model=f"{os.getcwd()}\\models\\forward.pt",
+    rvc_model=f"{os.getcwd()}\\models\\rvc_model.pth",
+    rvc_index=f"{os.getcwd()}\\models\\rvc_index.index",
+    device="cuda:0",
+    is_half=True,
+    silent_mode=True,
+    persist=True,
+):
     global now_dir, config, hubert_model
     now_dir = os.getcwd()
     sys.path.append(now_dir)
-    config=Config(device,is_half)
-    hubert_model = None if persist else vars().get('hubert_model', None)
+    config = Config(device, is_half)
+    hubert_model = None if persist else vars().get("hubert_model", None)
 
     synth_forward = Synthesizer(tts_model)
-    synth_output = pcm2float(synth_forward(input_text, voc_model='melgan', alpha=1.3), dtype=np.float32)
+    synth_output = pcm2float(synth_forward(input_text, alpha=1.3), dtype=np.float32)
 
     get_vc(rvc_model)
-    wav_opt=vc_single(
-        sid=0,
+    wav_opt = vc_single(
+        sid=0,
         audio=synth_output,
         f0_up_key=voice_transform,
         f0_file=None,
@@ -210,15 +221,136 @@ def rvg_tts(
     wavfile.write("output.wav", tgt_sr, wav_opt)
     if silent_mode == False:
         winsound.PlaySound("output.wav", winsound.SND_FILENAME)
+    else:
+        return "output.wav"
 
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description = "A retrieval based voice generation text to speech system")
-    parser.add_argument("--input_text", default="hello world!", type=str, help="The input text to be converted to speech")
-    parser.add_argument("--voice_transform", default=0, type=int, help="The voice transposition to be applied (Ranges from -12 to 12)")
-    parser.add_argument("--tts_model", default=f"{os.getcwd()}\\models\\forward.pt", type=str, help="The path to the text-to-speech model")
-    parser.add_argument("--rvc_model", default=f"{os.getcwd()}\\models\\rvc_model.pth", type=str, help="The path to the RVC model")
-    parser.add_argument("--rvc_index", default=f"{os.getcwd()}\\models\\rvc_index.index", type=str, help="The path to the RVC index")
-    parser.add_argument("--device", default="cuda:0", type=str, help="The device to run the models on")
-    parser.add_argument("--is_half", action="store_false", help="Whether to use half precision for the models")
-    parser.add_argument("--silent_mode", action="store_false", help="Whether to suppress the output sound")
-    args = parser.parse_args()
-    rvg_tts(**vars(args))
+    cli_args = [
+        "--input_text",
+        "--voice_transform",
+        "--tts_model",
+        "--rvc_model",
+        "--rvc_index",
+        "--device",
+        "--is_half",
+        "--silent_mode",
+    ]
+
+    if any(arg in sys.argv for arg in cli_args):
+        parser = argparse.ArgumentParser(
+            description="A retrieval based voice generation text to speech system"
+        )
+        parser.add_argument(
+            "--input_text",
+            default="hello world!",
+            type=str,
+            help="The input text to be converted to speech",
+        )
+        parser.add_argument(
+            "--voice_transform",
+            default=0,
+            type=int,
+            help="The voice transposition to be applied (Ranges from -12 to 12)",
+        )
+        parser.add_argument(
+            "--tts_model",
+            default=f"{os.getcwd()}\\models\\forward.pt",
+            type=str,
+            help="The path to the text-to-speech model",
+        )
+        parser.add_argument(
+            "--rvc_model",
+            default=f"{os.getcwd()}\\models\\rvc_model.pth",
+            type=str,
+            help="The path to the RVC model",
+        )
+        parser.add_argument(
+            "--rvc_index",
+            default=f"{os.getcwd()}\\models\\rvc_index.index",
+            type=str,
+            help="The path to the RVC index",
+        )
+        parser.add_argument(
+            "--device",
+            default="cuda:0",
+            type=str,
+            help="The device to run the models on",
+        )
+        parser.add_argument(
+            "--is_half",
+            action="store_false",
+            help="Whether to use half precision for the models",
+        )
+        parser.add_argument(
+            "--silent_mode",
+            action="store_false",
+            help="Whether to suppress the output sound",
+        )
+        args = parser.parse_args()
+
+        rvg_tts(**vars(args))
+
+    else:
+        modelDir = f"{os.getcwd()}\\models\\"
+
+        ptList, pthList, indexList = [], [], []
+
+        for x in os.listdir(modelDir):
+            if x.endswith(".pt"):
+                ptList.append(f".\\models\\{x}")
+            if x.endswith(".pth"):
+                pthList.append(f".\\models\\{x}")
+            if x.endswith(".index"):
+                indexList.append(f".\\models\\{x}")
+
+        device_choices = ["cpu"] + [
+            f"cuda:{i}" for i in range(torch.cuda.device_count())
+        ]
+
+        interface = gr.Interface(
+            fn=rvg_tts,
+            inputs=[
+                gr.Textbox(
+                    value="hello world!",
+                    label="Input text",
+                    info="Text to be converted to speech",
+                    lines=3,
+                ),
+                gr.Slider(
+                    minimum=-12,
+                    maximum=12,
+                    value=0,
+                    step=1,
+                    label="Voice transform",
+                    info="The voice transposition to be applied",
+                ),
+                gr.Dropdown(
+                    choices=ptList,
+                    label="Text-to-speech model",
+                    info="Forward tacotron model",
+                ),
+                gr.Dropdown(
+                    choices=pthList, label="RVC voice model", info="RVC voice model"
+                ),
+                gr.Dropdown(
+                    choices=indexList,
+                    label="RVC index model",
+                    info="RVC index model (optional)",
+                ),
+                gr.Dropdown(
+                    choices=device_choices,
+                    label="Device",
+                    info="Device to run the model on",
+                ),
+                gr.Checkbox(
+                    value=True,
+                    label="Use half precision",
+                    info="Whether to use half precision for the models",
+                ),
+            ],
+            outputs=gr.Audio(label="Output audio"),
+            allow_flagging=False,
+        )
+
+        interface.launch()
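With this commit, passing any of the flags in `cli_args` routes `RVG.py` down the argparse path, while a bare `python RVG.py` scans the `models` folder and launches the Gradio WebUI. One subtlety in the CLI path: both boolean options use `action="store_false"`, so they default to on and supplying the flag turns the behaviour off. A standalone sketch of that argparse semantics (not from the repo):

```python
# Minimal demonstration of action="store_false": the option defaults to
# True, and passing the flag on the command line flips it to False.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--is_half", action="store_false")

print(parser.parse_args([]).is_half)             # True: half precision stays on
print(parser.parse_args(["--is_half"]).is_half)  # False: the flag disables it
```

So `--is_half` actually disables half precision and `--silent_mode` re-enables playback, which is worth keeping in mind when reading the help strings.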
