Merge pull request #652 from ShiromiyaG/fix-preprocess
Fix preprocess taking too long
blaisewf authored Aug 30, 2024
2 parents 5e529a3 + e2febc7 commit b928451
Showing 1 changed file with 46 additions and 63 deletions.
109 changes: 46 additions & 63 deletions rvc/train/preprocess/preprocess.py
@@ -1,15 +1,13 @@
 import os
 import sys
 import time
-import torchaudio
-import torch
 from scipy import signal
 from scipy.io import wavfile
 import numpy as np
 import multiprocessing
-from pydub import AudioSegment
 import json
 from distutils.util import strtobool
+import librosa

 multiprocessing.set_start_method("spawn", force=True)

@@ -57,58 +55,54 @@ def __init__(self, sr: int, exp_dir: str, per: float):
         os.makedirs(self.gt_wavs_dir, exist_ok=True)
         os.makedirs(self.wavs16k_dir, exist_ok=True)

-    def _normalize_audio(self, audio: torch.Tensor):
-        tmp_max = torch.abs(audio).max()
+    def _normalize_audio(self, audio: np.ndarray):
+        tmp_max = np.abs(audio).max()
         if tmp_max > 2.5:
             return None
         return (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio

-    def _write_audio(self, audio: torch.Tensor, filename: str, sr: int):
-        audio = audio.cpu().numpy()
-        wavfile.write(filename, sr, audio.astype(np.float32))
-
     def process_audio_segment(
         self,
-        audio_segment: torch.Tensor,
+        audio_segment: np.ndarray,
         idx0: int,
         idx1: int,
         process_effects: bool,
     ):
-        if process_effects == False:
-            normalized_audio = audio_segment
-        else:
-            normalized_audio = self._normalize_audio(audio_segment)
+        normalized_audio = (
+            self._normalize_audio(audio_segment) if process_effects else audio_segment
+        )
         if normalized_audio is None:
             print(f"{idx0}-{idx1}-filtered")
             return

-        gt_wav_path = os.path.join(self.gt_wavs_dir, f"{idx0}_{idx1}.wav")
-        self._write_audio(normalized_audio, gt_wav_path, self.sr)
-
-        resampler = torchaudio.transforms.Resample(
-            orig_freq=self.sr, new_freq=SAMPLE_RATE_16K
-        ).to(self.device)
-        audio_16k = resampler(normalized_audio.float())
-        wav_16k_path = os.path.join(self.wavs16k_dir, f"{idx0}_{idx1}.wav")
-        self._write_audio(audio_16k, wav_16k_path, SAMPLE_RATE_16K)
+        wavfile.write(
+            os.path.join(self.gt_wavs_dir, f"{idx0}_{idx1}.wav"),
+            self.sr,
+            normalized_audio.astype(np.float32),
+        )
+        audio_16k = librosa.resample(
+            normalized_audio, orig_sr=self.sr, target_sr=SAMPLE_RATE_16K
+        )
+        wavfile.write(
+            os.path.join(self.wavs16k_dir, f"{idx0}_{idx1}.wav"),
+            SAMPLE_RATE_16K,
+            audio_16k.astype(np.float32),
+        )

     def process_audio(
-        self, path: str, idx0: int, cut_preprocess: bool, process_effects: bool
+        self,
+        path: str,
+        idx0: int,
+        cut_preprocess: bool,
+        process_effects: bool,
     ):
+        audio_length = 0
         try:
             audio = load_audio(path, self.sr)
-            if process_effects == False:
-                audio = torch.tensor(audio, device=self.device).float()
-            else:
-                audio = torch.tensor(
-                    signal.lfilter(self.b_high, self.a_high, audio), device=self.device
-                ).float()
+            audio_length = librosa.get_duration(y=audio, sr=self.sr)
+            if process_effects:
+                audio = signal.lfilter(self.b_high, self.a_high, audio)
             idx1 = 0
             if cut_preprocess:
-                for audio_segment in self.slicer.slice(audio.cpu().numpy()):
-                    audio_segment = torch.tensor(
-                        audio_segment, device=self.device
-                    ).float()
+                for audio_segment in self.slicer.slice(audio):
                     i = 0
                     while True:
                         start = int(self.sr * (self.per - OVERLAP) * i)
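
Note: the main speedup in this hunk is dropping the per-segment torchaudio path, which built a fresh Resample transform and bounced each segment through a device tensor before every write; the replacement stays in NumPy and resamples on the CPU with librosa. A minimal sketch of the new write-and-resample pattern, assuming a mono float32 array (the write_segment helper and file names are illustrative, not part of the commit):

import os

import librosa
import numpy as np
from scipy.io import wavfile

SAMPLE_RATE_16K = 16000


def write_segment(audio: np.ndarray, sr: int, out_dir: str, name: str):
    # Ground-truth copy at the original sample rate.
    wavfile.write(os.path.join(out_dir, f"{name}.wav"), sr, audio.astype(np.float32))
    # CPU resample via librosa: no transform object, no device transfers.
    audio_16k = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE_16K)
    wavfile.write(
        os.path.join(out_dir, f"{name}_16k.wav"),
        SAMPLE_RATE_16K,
        audio_16k.astype(np.float32),
    )


write_segment(np.zeros(48000, dtype=np.float32), 48000, ".", "0_0")  # 1 s of silence at 48 kHz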
@@ -130,17 +124,9 @@ def process_audio(
                             break
             else:
                 self.process_audio_segment(audio, idx0, idx1, process_effects)
-        except Exception as error:
-            print(f"An error occurred on {path} path: {error}")
-
-    def process_audio_file(self, file_path_idx, cut_preprocess, process_effects):
-        file_path, idx0 = file_path_idx
-        ext = os.path.splitext(file_path)[1].lower()
-        if ext not in [".wav"]:
-            audio = AudioSegment.from_file(file_path)
-            file_path = os.path.join("/tmp", f"{idx0}.wav")
-            audio.export(file_path, format="wav")
-        self.process_audio(file_path, idx0, cut_preprocess, process_effects)
+        except Exception as e:
+            print(f"Error processing audio: {e}")
+        return audio_length


 def format_duration(seconds):
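
Note: returning audio_length from process_audio lets the duration bookkeeping ride along with the decode the preprocessor already performs, instead of re-reading every file with pydub afterwards. A rough sketch of that single-pass idea, with librosa.load standing in for the project's load_audio:

import librosa


def process_one(path: str, sr: int) -> float:
    # Decode once; the duration falls out of the samples already in memory.
    audio, _ = librosa.load(path, sr=sr)
    duration = librosa.get_duration(y=audio, sr=sr)
    # ... slicing, filtering, and writing would happen here ...
    return duration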
@@ -150,7 +136,7 @@ def format_duration(seconds):
     return f"{hours:02}:{minutes:02}:{seconds:02}"


-def save_dataset_duration(file_path, dataset_duration=0):
+def save_dataset_duration(file_path, dataset_duration):
     try:
         with open(file_path, "r") as f:
             data = json.load(f)
@@ -168,14 +154,10 @@ def save_dataset_duration(file_path, dataset_duration=0):
         json.dump(data, f, indent=4)


-def get_audio_duration(file_path):
-    audio = AudioSegment.from_file(file_path)
-    return len(audio) / 1000.0
-
-
-def process_file(args):
+def process_audio_wrapper(args):
     pp, file, cut_preprocess, process_effects = args
-    pp.process_audio_file(file, cut_preprocess, process_effects)
+    file_path, idx0 = file
+    return pp.process_audio(file_path, idx0, cut_preprocess, process_effects)


 def preprocess_training_set(
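
Note: Pool.map hands each worker exactly one argument, so process_audio_wrapper's job is to unpack a (pp, file, flags) tuple into the method call and pass the per-file duration back to the parent. A self-contained sketch of the same pattern (Worker and its process method are hypothetical stand-ins):

import multiprocessing


class Worker:
    def process(self, path: str, idx: int) -> float:
        return float(idx)  # stand-in for the real per-file work


def wrapper(args):
    worker, (path, idx) = args  # unpack the single tuple Pool.map delivers
    return worker.process(path, idx)


if __name__ == "__main__":
    files = [("a.wav", 0), ("b.wav", 1)]
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        results = pool.map(wrapper, [(Worker(), f) for f in files])
    print(sum(results))  # 1.0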
@@ -188,7 +170,6 @@ def preprocess_training_set(
     process_effects: bool,
 ):
     start_time = time.time()
-
     pp = PreProcess(sr, exp_dir, per)
     print(f"Starting preprocess with {num_processes} processes...")

@@ -197,18 +178,20 @@
         for idx, f in enumerate(os.listdir(input_root))
         if f.lower().endswith((".wav", ".mp3", ".flac", ".ogg"))
     ]
-    file_paths = [file[0] for file in files]
     ctx = multiprocessing.get_context("spawn")
     with ctx.Pool(processes=num_processes) as pool:
-        pool.map(
-            process_file,
+        audio_length = pool.map(
+            process_audio_wrapper,
             [(pp, file, cut_preprocess, process_effects) for file in files],
         )
-        durations = pool.map(get_audio_duration, file_paths)
-        total_duration = sum(durations)
-        save_dataset_duration(os.path.join(exp_dir, "model_info.json"), total_duration)
+    audio_length = sum(audio_length)
+    save_dataset_duration(
+        os.path.join(exp_dir, "model_info.json"), dataset_duration=audio_length
+    )
     elapsed_time = time.time() - start_time
-    print(f"Preprocess completed in {elapsed_time:.2f} seconds.")
+    print(
+        f"Preprocess completed in {elapsed_time:.2f} seconds. Dataset duration: {format_duration(audio_length)}."
+    )


 if __name__ == "__main__":
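
Note: with each map result now a duration in seconds, the dataset total is a plain sum, and format_duration (which appears to truncate to whole seconds) renders it for the completion log. For example:

durations = [12.5, 300.0, 47.25]  # per-file seconds, as returned by pool.map
total = sum(durations)            # 359.75
hours, rem = divmod(int(total), 3600)
minutes, seconds = divmod(rem, 60)
print(f"{hours:02}:{minutes:02}:{seconds:02}")  # 00:05:59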
