diff --git a/rvc/train/preprocess/preprocess.py b/rvc/train/preprocess/preprocess.py
index bcef9dfc..898a42d2 100644
--- a/rvc/train/preprocess/preprocess.py
+++ b/rvc/train/preprocess/preprocess.py
@@ -1,15 +1,13 @@
 import os
 import sys
 import time
-import torchaudio
-import torch
 from scipy import signal
 from scipy.io import wavfile
 import numpy as np
 import multiprocessing
-from pydub import AudioSegment
 import json
 from distutils.util import strtobool
+import librosa
 
 multiprocessing.set_start_method("spawn", force=True)
 
@@ -57,58 +55,54 @@ def __init__(self, sr: int, exp_dir: str, per: float):
         os.makedirs(self.gt_wavs_dir, exist_ok=True)
         os.makedirs(self.wavs16k_dir, exist_ok=True)
 
-    def _normalize_audio(self, audio: torch.Tensor):
-        tmp_max = torch.abs(audio).max()
+    def _normalize_audio(self, audio: np.ndarray):
+        tmp_max = np.abs(audio).max()
         if tmp_max > 2.5:
             return None
         return (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio
 
-    def _write_audio(self, audio: torch.Tensor, filename: str, sr: int):
-        audio = audio.cpu().numpy()
-        wavfile.write(filename, sr, audio.astype(np.float32))
-
     def process_audio_segment(
         self,
-        audio_segment: torch.Tensor,
+        audio_segment: np.ndarray,
         idx0: int,
         idx1: int,
         process_effects: bool,
     ):
-        if process_effects == False:
-            normalized_audio = audio_segment
-        else:
-            normalized_audio = self._normalize_audio(audio_segment)
+        normalized_audio = (
+            self._normalize_audio(audio_segment) if process_effects else audio_segment
+        )
         if normalized_audio is None:
-            print(f"{idx0}-{idx1}-filtered")
             return
-
-        gt_wav_path = os.path.join(self.gt_wavs_dir, f"{idx0}_{idx1}.wav")
-        self._write_audio(normalized_audio, gt_wav_path, self.sr)
-
-        resampler = torchaudio.transforms.Resample(
-            orig_freq=self.sr, new_freq=SAMPLE_RATE_16K
-        ).to(self.device)
-        audio_16k = resampler(normalized_audio.float())
-        wav_16k_path = os.path.join(self.wavs16k_dir, f"{idx0}_{idx1}.wav")
-        self._write_audio(audio_16k, wav_16k_path, SAMPLE_RATE_16K)
+        wavfile.write(
+            os.path.join(self.gt_wavs_dir, f"{idx0}_{idx1}.wav"),
+            self.sr,
+            normalized_audio.astype(np.float32),
+        )
+        audio_16k = librosa.resample(
+            normalized_audio, orig_sr=self.sr, target_sr=SAMPLE_RATE_16K
+        )
+        wavfile.write(
+            os.path.join(self.wavs16k_dir, f"{idx0}_{idx1}.wav"),
+            SAMPLE_RATE_16K,
+            audio_16k.astype(np.float32),
+        )
 
     def process_audio(
-        self, path: str, idx0: int, cut_preprocess: bool, process_effects: bool
+        self,
+        path: str,
+        idx0: int,
+        cut_preprocess: bool,
+        process_effects: bool,
     ):
+        audio_length = 0
         try:
             audio = load_audio(path, self.sr)
-            if process_effects == False:
-                audio = torch.tensor(audio, device=self.device).float()
-            else:
-                audio = torch.tensor(
-                    signal.lfilter(self.b_high, self.a_high, audio), device=self.device
-                ).float()
+            audio_length = librosa.get_duration(y=audio, sr=self.sr)
+            if process_effects:
+                audio = signal.lfilter(self.b_high, self.a_high, audio)
             idx1 = 0
             if cut_preprocess:
-                for audio_segment in self.slicer.slice(audio.cpu().numpy()):
-                    audio_segment = torch.tensor(
-                        audio_segment, device=self.device
-                    ).float()
+                for audio_segment in self.slicer.slice(audio):
                     i = 0
                     while True:
                         start = int(self.sr * (self.per - OVERLAP) * i)
@@ -130,17 +124,9 @@ def process_audio(
                             break
             else:
                 self.process_audio_segment(audio, idx0, idx1, process_effects)
-        except Exception as error:
-            print(f"An error occurred on {path} path: {error}")
-
-    def process_audio_file(self, file_path_idx, cut_preprocess, process_effects):
-        file_path, idx0 = file_path_idx
-        ext = os.path.splitext(file_path)[1].lower()
-        if ext not in [".wav"]:
-            audio = AudioSegment.from_file(file_path)
-            file_path = os.path.join("/tmp", f"{idx0}.wav")
-            audio.export(file_path, format="wav")
-        self.process_audio(file_path, idx0, cut_preprocess, process_effects)
+        except Exception as e:
+            print(f"Error processing audio: {e}")
+        return audio_length
 
 
 def format_duration(seconds):
@@ -150,7 +136,7 @@
     return f"{hours:02}:{minutes:02}:{seconds:02}"
 
 
-def save_dataset_duration(file_path, dataset_duration=0):
+def save_dataset_duration(file_path, dataset_duration):
     try:
         with open(file_path, "r") as f:
             data = json.load(f)
@@ -168,14 +154,10 @@ def save_dataset_duration(file_path, dataset_duration=0):
         json.dump(data, f, indent=4)
 
 
-def get_audio_duration(file_path):
-    audio = AudioSegment.from_file(file_path)
-    return len(audio) / 1000.0
-
-
-def process_file(args):
+def process_audio_wrapper(args):
     pp, file, cut_preprocess, process_effects = args
-    pp.process_audio_file(file, cut_preprocess, process_effects)
+    file_path, idx0 = file
+    return pp.process_audio(file_path, idx0, cut_preprocess, process_effects)
 
 
 def preprocess_training_set(
@@ -188,7 +170,6 @@
     process_effects: bool,
 ):
     start_time = time.time()
-
     pp = PreProcess(sr, exp_dir, per)
     print(f"Starting preprocess with {num_processes} processes...")
 
@@ -197,18 +178,20 @@
         for idx, f in enumerate(os.listdir(input_root))
         if f.lower().endswith((".wav", ".mp3", ".flac", ".ogg"))
     ]
-    file_paths = [file[0] for file in files]
 
     ctx = multiprocessing.get_context("spawn")
     with ctx.Pool(processes=num_processes) as pool:
-        pool.map(
-            process_file,
+        audio_length = pool.map(
+            process_audio_wrapper,
             [(pp, file, cut_preprocess, process_effects) for file in files],
         )
-        durations = pool.map(get_audio_duration, file_paths)
-        total_duration = sum(durations)
-    save_dataset_duration(os.path.join(exp_dir, "model_info.json"), total_duration)
+    audio_length = sum(audio_length)
+    save_dataset_duration(
+        os.path.join(exp_dir, "model_info.json"), dataset_duration=audio_length
+    )
     elapsed_time = time.time() - start_time
-    print(f"Preprocess completed in {elapsed_time:.2f} seconds.")
+    print(
+        f"Preprocess completed in {elapsed_time:.2f} seconds. Dataset duration: {format_duration(audio_length)}."
+    )
 
 if __name__ == "__main__":
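
Note (commentary, not part of the patch): below is a minimal, self-contained sketch of the librosa/scipy path this diff switches to, for reviewers unfamiliar with the replaced torchaudio/pydub calls. The helper name preprocess_one and the constant values (MAX_AMPLITUDE, ALPHA, HIGH_PASS_CUTOFF, FILTER_ORDER) are illustrative assumptions rather than values taken from this file; only the library calls mirror the patched code.

import librosa
import numpy as np
from scipy import signal
from scipy.io import wavfile

SAMPLE_RATE_16K = 16000  # defined in the patched module; value assumed here
MAX_AMPLITUDE = 0.9      # assumed peak target for the normalization blend
ALPHA = 0.75             # assumed mix ratio between normalized and raw audio
HIGH_PASS_CUTOFF = 48    # Hz; assumed cutoff for the b_high/a_high filter
FILTER_ORDER = 5         # assumed order of the Butterworth high-pass


def preprocess_one(path: str, sr: int, gt_path: str, path_16k: str) -> float:
    # librosa decodes wav/mp3/flac/ogg directly, which is why the patch can
    # drop the pydub conversion of non-wav inputs to a temporary wav file.
    audio, _ = librosa.load(path, sr=sr)
    duration = librosa.get_duration(y=audio, sr=sr)

    # High-pass filter, as in the patched signal.lfilter(self.b_high, ...) call.
    b_high, a_high = signal.butter(
        N=FILTER_ORDER, Wn=HIGH_PASS_CUTOFF, btype="high", fs=sr
    )
    audio = signal.lfilter(b_high, a_high, audio)

    # Same rule as _normalize_audio: skip audio whose peak exceeds 2.5,
    # otherwise blend a peak-normalized copy with the original signal.
    tmp_max = np.abs(audio).max()
    if tmp_max > 2.5:
        return 0.0
    audio = (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio

    # Write the ground-truth-rate copy, then a 16 kHz copy resampled with
    # librosa.resample (replacing torchaudio.transforms.Resample).
    wavfile.write(gt_path, sr, audio.astype(np.float32))
    audio_16k = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE_16K)
    wavfile.write(path_16k, SAMPLE_RATE_16K, audio_16k.astype(np.float32))
    return duration

In the patch itself, the filtering and normalization steps are both gated behind process_effects, and the per-file durations returned by process_audio are summed across the worker pool to produce the total written to model_info.json.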