Skip to content

Commit

Permalink
Merge pull request #604 from ShiromiyaG/rework-batch
Browse files Browse the repository at this point in the history
Rework batch inference
  • Loading branch information
blaisewf authored Aug 17, 2024
2 parents c090fbd + 48fabab commit c0f1968
Show file tree
Hide file tree
Showing 5 changed files with 282 additions and 38 deletions.
65 changes: 27 additions & 38 deletions core.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,44 +136,33 @@ def run_batch_infer_script(
f for f in os.listdir(input_folder) if f.endswith((".mp3", ".wav", ".flac"))
]
print(f"Detected {len(audio_files)} audio files for inference.")

for audio_file in audio_files:
if "_output" in audio_file:
pass
else:
input_path = os.path.join(input_folder, audio_file)
output_file_name = os.path.splitext(os.path.basename(audio_file))[0]
output_path = os.path.join(
output_folder,
f"{output_file_name}_output{os.path.splitext(audio_file)[1]}",
)
infer_pipeline = import_voice_converter()
print(f"Inferring {input_path}...")
infer_pipeline.convert_audio(
pitch=pitch,
filter_radius=filter_radius,
index_rate=index_rate,
volume_envelope=volume_envelope,
protect=protect,
hop_length=hop_length,
f0_method=f0_method,
audio_input_path=input_path,
audio_output_path=output_path,
model_path=pth_path,
index_path=index_path,
split_audio=split_audio,
f0_autotune=f0_autotune,
clean_audio=clean_audio,
clean_strength=clean_strength,
export_format=export_format,
upscale_audio=upscale_audio,
f0_file=f0_file,
embedder_model=embedder_model,
embedder_model_custom=embedder_model_custom,
formant_shifting=formant_shifting,
formant_qfrency=formant_qfrency,
formant_timbre=formant_timbre,
)
infer_pipeline = import_voice_converter()
infer_pipeline.convert_audio_batch(
pitch=pitch,
filter_radius=filter_radius,
index_rate=index_rate,
volume_envelope=volume_envelope,
protect=protect,
hop_length=hop_length,
f0_method=f0_method,
audio_input_paths=input_folder,
audio_output_path=output_folder,
model_path=pth_path,
index_path=index_path,
split_audio=split_audio,
f0_autotune=f0_autotune,
clean_audio=clean_audio,
clean_strength=clean_strength,
export_format=export_format,
upscale_audio=upscale_audio,
f0_file=f0_file,
embedder_model=embedder_model,
embedder_model_custom=embedder_model_custom,
formant_shifting=formant_shifting,
formant_qfrency=formant_qfrency,
formant_timbre=formant_timbre,
pid_file_path=os.path.join(now_dir, "assets", "infer_pid.txt"),
)

return f"Files from {input_folder} inferred successfully."

Expand Down
215 changes: 215 additions & 0 deletions rvc/infer/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logging.getLogger("faiss").setLevel(logging.WARNING)
logging.getLogger("faiss.loader").setLevel(logging.WARNING)


class VoiceConverter:
Expand Down Expand Up @@ -315,6 +316,220 @@ def convert_audio(
print(f"An error occurred during audio conversion: {error}")
print(traceback.format_exc())

def convert_audio_batch(
    self,
    audio_input_paths: str,
    audio_output_path: str,
    model_path: str,
    index_path: str,
    embedder_model: str,
    pitch: int,
    f0_file: str,
    f0_method: str,
    index_rate: float,
    volume_envelope: int,
    protect: float,
    hop_length: int,
    split_audio: bool,
    f0_autotune: bool,
    filter_radius: int,
    embedder_model_custom: str,
    clean_audio: bool,
    clean_strength: float,
    export_format: str,
    upscale_audio: bool,
    formant_shifting: bool,
    formant_qfrency: float,
    formant_timbre: float,
    resample_sr: int = 0,
    sid: int = 0,
    pid_file_path: str = None,
):
    """
    Performs voice conversion on every audio file found in a folder.

    Each file in ``audio_input_paths`` whose name ends with a supported
    extension is converted and written to ``audio_output_path`` as
    ``<name>_output.<export_format>``. Files whose output already exists are
    skipped, so an interrupted batch can be resumed.

    Args:
        audio_input_paths (str): Folder containing the input audio files.
        audio_output_path (str): Folder that receives the converted files.
        model_path (str): Path to the voice conversion model.
        index_path (str): Path to the index file.
        embedder_model (str): Embedder model passed to ``load_hubert``.
        pitch (int): Pitch value forwarded to the conversion pipeline.
        f0_file (str): Path to the F0 file (not used for split segments).
        f0_method (str): Method for F0 extraction.
        index_rate (float): Rate for index matching.
        volume_envelope (int): Volume envelope forwarded to the pipeline.
        protect (float): Protect value forwarded to the pipeline.
        hop_length (int): Hop length for audio processing.
        split_audio (bool): Whether to split each file into segments,
            convert the segments and merge them back.
        f0_autotune (bool): Whether to use F0 autotune.
        filter_radius (int): Radius for filtering.
        embedder_model_custom (str): Custom embedder model passed to
            ``load_hubert``.
        clean_audio (bool): Whether to run noise removal on the output.
        clean_strength (float): Strength of the audio cleaning.
        export_format (str): Format for exporting the audio (e.g. "WAV").
        upscale_audio (bool): Whether to upscale the input before converting.
        formant_shifting (bool): Whether to shift the formants.
        formant_qfrency (float): Formant quefrency value.
        formant_timbre (float): Formant timbre value.
        resample_sr (int, optional): Resample sampling rate. Default is 0.
        sid (int, optional): Speaker ID. Default is 0.
        pid_file_path (str, optional): File in which the current process ID
            is recorded while the batch runs. Default is None (no PID file).

    Returns:
        None on normal completion (including when an error was caught and
        printed). ``("Error with Split Audio", None)`` or ``"Error <msg>"``
        when the split-audio path fails.
    """
    # The PID file is optional: only create it when a path was supplied.
    if pid_file_path is not None:
        with open(pid_file_path, "w") as pid_file:
            pid_file.write(str(os.getpid()))
    try:
        if not self.hubert_model:
            self.load_hubert(embedder_model, embedder_model_custom)
        self.get_vc(model_path, sid)
        file_index = (
            index_path.strip()
            .strip('"')
            .strip("\n")
            .strip('"')
            .strip()
            .replace("trained", "added")
        )
        start_time = time.time()
        print(f"Converting audio batch '{audio_input_paths}'...")
        audio_files = [
            f
            for f in os.listdir(audio_input_paths)
            if f.endswith((".mp3", ".wav", ".flac", ".m4a", ".ogg", ".opus"))
        ]
        print(f"Detected {len(audio_files)} audio files for inference.")
        for file_name in audio_files:
            output_file = os.path.join(
                audio_output_path,
                f"{os.path.splitext(file_name)[0]}_output.{export_format.lower()}",
            )
            # Already converted in a previous (possibly interrupted) run.
            if os.path.exists(output_file):
                continue
            print(f"Converting audio '{file_name}'...")
            input_file = os.path.join(audio_input_paths, file_name)

            if upscale_audio:
                upscale(input_file, input_file)
            audio = load_audio_infer(
                input_file,
                16000,
                formant_shifting,
                formant_qfrency,
                formant_timbre,
            )
            audio_max = np.abs(audio).max() / 0.95

            if audio_max > 1:
                audio /= audio_max

            if self.tgt_sr != resample_sr >= 16000:
                self.tgt_sr = resample_sr

            if split_audio:
                result, new_dir_path = process_audio(input_file)
                if result == "Error":
                    return "Error with Split Audio", None

                dir_path = (
                    new_dir_path.strip().strip('"').strip("\n").strip('"').strip()
                )
                if dir_path:
                    paths = [
                        os.path.join(root, name)
                        for root, _, files in os.walk(dir_path, topdown=False)
                        for name in files
                        if name.endswith(".wav") and root == dir_path
                    ]
                try:
                    for path in paths:
                        # Segments are converted in place; the formant
                        # settings are forwarded so split and non-split
                        # conversions receive the same arguments.
                        self.convert_audio(
                            audio_input_path=path,
                            audio_output_path=path,
                            model_path=model_path,
                            index_path=index_path,
                            sid=sid,
                            pitch=pitch,
                            f0_file=None,
                            f0_method=f0_method,
                            index_rate=index_rate,
                            resample_sr=resample_sr,
                            volume_envelope=volume_envelope,
                            protect=protect,
                            hop_length=hop_length,
                            split_audio=False,
                            f0_autotune=f0_autotune,
                            filter_radius=filter_radius,
                            export_format=export_format,
                            upscale_audio=upscale_audio,
                            embedder_model=embedder_model,
                            embedder_model_custom=embedder_model_custom,
                            clean_audio=clean_audio,
                            clean_strength=clean_strength,
                            formant_shifting=formant_shifting,
                            formant_qfrency=formant_qfrency,
                            formant_timbre=formant_timbre,
                        )
                except Exception as error:
                    print(
                        f"An error occurred processing the segmented audio: {error}"
                    )
                    print(traceback.format_exc())
                    return f"Error {error}"
                print("Finished processing segmented audio, now merging audio...")
                merge_timestamps_file = os.path.join(
                    os.path.dirname(new_dir_path),
                    f"{os.path.basename(input_file).split('.')[0]}_timestamps.txt",
                )
                self.tgt_sr, audio_opt = merge_audio(merge_timestamps_file)
                os.remove(merge_timestamps_file)
            else:
                audio_opt = self.vc.pipeline(
                    model=self.hubert_model,
                    net_g=self.net_g,
                    sid=sid,
                    audio=audio,
                    input_audio_path=input_file,
                    pitch=pitch,
                    f0_method=f0_method,
                    file_index=file_index,
                    index_rate=index_rate,
                    pitch_guidance=self.use_f0,
                    filter_radius=filter_radius,
                    tgt_sr=self.tgt_sr,
                    resample_sr=resample_sr,
                    volume_envelope=volume_envelope,
                    version=self.version,
                    protect=protect,
                    hop_length=hop_length,
                    f0_autotune=f0_autotune,
                    f0_file=f0_file,
                )

            sf.write(output_file, audio_opt, self.tgt_sr, format="WAV")

            if clean_audio:
                cleaned_audio = self.remove_audio_noise(output_file, clean_strength)
                if cleaned_audio is not None:
                    sf.write(output_file, cleaned_audio, self.tgt_sr, format="WAV")

            # Derive the export path from this file's output path, not from
            # the output folder, so the format conversion targets a file.
            output_path_format = output_file.replace(
                ".wav", f".{export_format.lower()}"
            )
            output_file = self.convert_audio_format(
                output_file, output_path_format, export_format
            )
            print(f"Conversion completed at '{output_file}'.")
        elapsed_time = time.time() - start_time
        print(f"Batch conversion completed in {elapsed_time:.2f} seconds.")
    except Exception as error:
        print(f"An error occurred during audio conversion: {error}")
        print(traceback.format_exc())
    finally:
        # Always clear the PID file (success, error or early return) so a
        # stale PID is never left behind for a later stop request.
        if pid_file_path is not None:
            try:
                os.remove(pid_file_path)
            except FileNotFoundError:
                pass

def get_vc(self, weight_root, sid):
"""
Loads the voice conversion model and sets up the pipeline.
Expand Down
3 changes: 3 additions & 0 deletions rvc/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging

logging.getLogger("fairseq").setLevel(logging.WARNING)
logging.getLogger("faiss.loader").setLevel(logging.WARNING)

now_dir = os.getcwd()
sys.path.append(now_dir)
Expand Down Expand Up @@ -40,6 +41,8 @@ def load_audio_infer(
):
try:
file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
if not os.path.isfile(file):
raise FileNotFoundError(f"File not found: {file}")
audio, sr = sf.read(file)
if len(audio.shape) > 1:
audio = librosa.to_mono(audio.T)
Expand Down
25 changes: 25 additions & 0 deletions tabs/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from assets.i18n.i18n import I18nAuto

from rvc.lib.utils import format_title
from tabs.settings.restart import stop_infer

i18n = I18nAuto()

Expand Down Expand Up @@ -937,6 +938,8 @@ def inference_tab():
)

convert_button2 = gr.Button(i18n("Convert"))
stop_button = gr.Button(i18n("Stop convert"), visible=False)
stop_button.click(fn=stop_infer, inputs=[], outputs=[])

with gr.Row():
vc_output3 = gr.Textbox(
Expand All @@ -957,6 +960,18 @@ def toggle_visible_embedder_custom(embedder_model):
return {"visible": True, "__type__": "update"}
return {"visible": False, "__type__": "update"}

def enable_stop_convert_button():
    """Return the update pair that hides the first output and shows the second.

    The click handler wires the outputs as (convert button, stop button), so
    this swaps the visible control to the stop button.
    """
    hide_convert = dict(visible=False, __type__="update")
    show_stop = dict(visible=True, __type__="update")
    return hide_convert, show_stop

def disable_stop_convert_button():
    """Return the update pair that shows the first output and hides the second.

    The click handler wires the outputs as (convert button, stop button), so
    this swaps the visible control back to the convert button.
    """
    show_convert = dict(visible=True, __type__="update")
    hide_stop = dict(visible=False, __type__="update")
    return show_convert, hide_stop

def toggle_visible_formant_shifting(checkbox):
if checkbox:
return (
Expand Down Expand Up @@ -1170,3 +1185,13 @@ def toggle_visible_formant_shifting(checkbox):
],
outputs=[vc_output3],
)
convert_button2.click(
fn=enable_stop_convert_button,
inputs=[],
outputs=[convert_button2, stop_button],
)
stop_button.click(
fn=disable_stop_convert_button,
inputs=[],
outputs=[convert_button2, stop_button],
)
12 changes: 12 additions & 0 deletions tabs/settings/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ def stop_train(model_name: str):
pass


def stop_infer():
    """Kill the inference process(es) recorded in the PID file, then delete it.

    Reads ``<now_dir>/assets/infer_pid.txt`` (one PID per line). This is a
    best-effort operation: a missing PID file means nothing is running and is
    ignored; malformed lines and processes that can no longer be signalled
    are skipped so the remaining PIDs are still handled and the stale PID
    file is always removed.
    """
    pid_file_path = os.path.join(now_dir, "assets", "infer_pid.txt")
    try:
        with open(pid_file_path, "r") as pid_file:
            lines = pid_file.readlines()
    except FileNotFoundError:
        # No PID file: no batch inference is currently registered.
        return
    except OSError as error:
        print(f"Could not read inference PID file: {error}")
        return

    for line in lines:
        try:
            pid = int(line)
        except ValueError:
            # Blank or corrupted line; ignore it and keep going.
            continue
        if pid <= 0:
            # On POSIX, 0 and negative values address process groups rather
            # than a single process; never signal those from a PID file.
            continue
        try:
            os.kill(pid, 9)
        except (OSError, OverflowError) as error:
            # Process already exited or cannot be signalled.
            print(f"Could not stop inference process {pid}: {error}")

    try:
        os.remove(pid_file_path)
    except OSError:
        pass


def restart_applio():
if os.name != "nt":
os.system("clear")
Expand Down

0 comments on commit c0f1968

Please sign in to comment.