diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py index 9070f2dacb..a8a60bf80d 100755 --- a/TTS/bin/remove_silence_using_vad.py +++ b/TTS/bin/remove_silence_using_vad.py @@ -1,51 +1,24 @@ import argparse import glob -import multiprocessing import os import pathlib -from tqdm.contrib.concurrent import process_map +from tqdm import tqdm +from TTS.utils.vad import get_vad_model_and_utils, remove_silence -from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave - -def remove_silence(filepath): - output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, "")) +def adjust_path_and_remove_silence(audio_path): + output_path = audio_path.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, "")) # ignore if the file exists if os.path.exists(output_path) and not args.force: - return + return output_path # create all directory structure pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True) - # load wave - audio, sample_rate = read_wave(filepath) - - # get speech segments - segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=args.aggressiveness) + # remove the silence and save the audio + output_path = remove_silence(model_and_utils, audio_path, output_path, trim_just_beginning_and_end=args.trim_just_beginning_and_end, use_cuda=args.use_cuda) - segments = list(segments) - num_segments = len(segments) - flag = False - # create the output wave - if num_segments != 0: - for i, segment in reversed(list(enumerate(segments))): - if i >= 1: - if not flag: - concat_segment = segment - flag = True - else: - concat_segment = segment + concat_segment - else: - if flag: - segment = segment + concat_segment - # print("Saving: ", output_path) - write_wave(output_path, segment, sample_rate) - return - else: - print("> Just Copying the file to:", output_path) - # if fail to remove silence just write the file - write_wave(output_path, audio, sample_rate) - return + return output_path def preprocess_audios(): @@ -54,17 +27,24 @@ def preprocess_audios(): if not args.force: print("> Ignoring files that already exist in the output directory.") + if args.trim_just_beginning_and_end: + print("> Trimming just the beginning and the end with nonspeech parts.") + else: + print("> Trimming all nonspeech parts.") + if files: # create threads - num_threads = multiprocessing.cpu_count() - process_map(remove_silence, files, max_workers=num_threads, chunksize=15) + # num_threads = multiprocessing.cpu_count() + # process_map(adjust_path_and_remove_silence, files, max_workers=num_threads, chunksize=15) + for f in tqdm(files): + adjust_path_and_remove_silence(f) else: print("> No files Found !") if __name__ == "__main__": parser = argparse.ArgumentParser( - description="python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2" + description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True" ) parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir") parser.add_argument( @@ -79,11 +59,20 @@ def preprocess_audios(): help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav", ) parser.add_argument( - "-a", - "--aggressiveness", - type=int, - default=2, - help="set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.", + "-t", + "--trim_just_beginning_and_end", + type=bool, + default=True, + help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trim. Default True", + ) + parser.add_argument( + "-c", + "--use_cuda", + type=bool, + default=False, + help="If True use cuda", ) args = parser.parse_args() + # load the model and utils + model_and_utils = get_vad_model_and_utils(use_cuda=args.use_cuda) preprocess_audios() diff --git a/TTS/utils/vad.py b/TTS/utils/vad.py index 923544d0b4..8879020282 100644 --- a/TTS/utils/vad.py +++ b/TTS/utils/vad.py @@ -1,144 +1,71 @@ -# This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py -import collections -import contextlib -import wave - -import webrtcvad - - -def read_wave(path): - """Reads a .wav file. - - Takes the path, and returns (PCM audio data, sample rate). - """ - with contextlib.closing(wave.open(path, "rb")) as wf: - num_channels = wf.getnchannels() - assert num_channels == 1 - sample_width = wf.getsampwidth() - assert sample_width == 2 - sample_rate = wf.getframerate() - assert sample_rate in (8000, 16000, 32000, 48000) - pcm_data = wf.readframes(wf.getnframes()) - return pcm_data, sample_rate - - -def write_wave(path, audio, sample_rate): - """Writes a .wav file. - - Takes path, PCM audio data, and sample rate. - """ - with contextlib.closing(wave.open(path, "wb")) as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - wf.setframerate(sample_rate) - wf.writeframes(audio) - - -class Frame(object): - """Represents a "frame" of audio data.""" - - def __init__(self, _bytes, timestamp, duration): - self.bytes = _bytes - self.timestamp = timestamp - self.duration = duration - - -def frame_generator(frame_duration_ms, audio, sample_rate): - """Generates audio frames from PCM audio data. - - Takes the desired frame duration in milliseconds, the PCM data, and - the sample rate. - - Yields Frames of the requested duration. - """ - n = int(sample_rate * (frame_duration_ms / 1000.0) * 2) - offset = 0 - timestamp = 0.0 - duration = (float(n) / sample_rate) / 2.0 - while offset + n < len(audio): - yield Frame(audio[offset : offset + n], timestamp, duration) - timestamp += duration - offset += n - - -def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames): - """Filters out non-voiced audio frames. - - Given a webrtcvad.Vad and a source of audio frames, yields only - the voiced audio. - - Uses a padded, sliding window algorithm over the audio frames. - When more than 90% of the frames in the window are voiced (as - reported by the VAD), the collector triggers and begins yielding - audio frames. Then the collector waits until 90% of the frames in - the window are unvoiced to detrigger. - - The window is padded at the front and back to provide a small - amount of silence or the beginnings/endings of speech around the - voiced frames. - - Arguments: - - sample_rate - The audio sample rate, in Hz. - frame_duration_ms - The frame duration in milliseconds. - padding_duration_ms - The amount to pad the window, in milliseconds. - vad - An instance of webrtcvad.Vad. - frames - a source of audio frames (sequence or generator). - - Returns: A generator that yields PCM audio data. - """ - num_padding_frames = int(padding_duration_ms / frame_duration_ms) - # We use a deque for our sliding window/ring buffer. - ring_buffer = collections.deque(maxlen=num_padding_frames) - # We have two states: TRIGGERED and NOTTRIGGERED. We start in the - # NOTTRIGGERED state. - triggered = False - - voiced_frames = [] - for frame in frames: - is_speech = vad.is_speech(frame.bytes, sample_rate) - - # sys.stdout.write('1' if is_speech else '0') - if not triggered: - ring_buffer.append((frame, is_speech)) - num_voiced = len([f for f, speech in ring_buffer if speech]) - # If we're NOTTRIGGERED and more than 90% of the frames in - # the ring buffer are voiced frames, then enter the - # TRIGGERED state. - if num_voiced > 0.9 * ring_buffer.maxlen: - triggered = True - # sys.stdout.write('+(%s)' % (ring_buffer[0][0].timestamp,)) - # We want to yield all the audio we see from now until - # we are NOTTRIGGERED, but we have to start with the - # audio that's already in the ring buffer. - for f, _ in ring_buffer: - voiced_frames.append(f) - ring_buffer.clear() - else: - # We're in the TRIGGERED state, so collect the audio data - # and add it to the ring buffer. - voiced_frames.append(frame) - ring_buffer.append((frame, is_speech)) - num_unvoiced = len([f for f, speech in ring_buffer if not speech]) - # If more than 90% of the frames in the ring buffer are - # unvoiced, then enter NOTTRIGGERED and yield whatever - # audio we've collected. - if num_unvoiced > 0.9 * ring_buffer.maxlen: - # sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration)) - triggered = False - yield b"".join([f.bytes for f in voiced_frames]) - ring_buffer.clear() - voiced_frames = [] - # If we have any leftover voiced audio when we run out of input, - # yield it. - if voiced_frames: - yield b"".join([f.bytes for f in voiced_frames]) - - -def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300): - - vad = webrtcvad.Vad(int(aggressiveness)) - frames = list(frame_generator(30, audio, sample_rate)) - segments = vad_collector(sample_rate, 30, padding_duration_ms, vad, frames) - - return segments +import torch +import torchaudio + +def read_audio(path): + wav, sr = torchaudio.load(path) + + if wav.size(0) > 1: + wav = wav.mean(dim=0, keepdim=True) + + return wav.squeeze(0), sr + +def resample_wav(wav, sr, new_sr): + wav = wav.unsqueeze(0) + transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr) + wav = transform(wav) + return wav.squeeze(0) + +def map_timestamps_to_new_sr(vad_sr, new_sr, timestamps, just_begging_end=False): + factor = new_sr / vad_sr + new_timestamps = [] + if just_begging_end: + # get just the start and end timestamps + new_dict = {'start': int(timestamps[0]['start']*factor), 'end': int(timestamps[-1]['end']*factor)} + new_timestamps.append(new_dict) + else: + for ts in timestamps: + # map to the new SR + new_dict = {'start': int(ts['start']*factor), 'end': int(ts['end']*factor)} + new_timestamps.append(new_dict) + + return new_timestamps + +def get_vad_model_and_utils(use_cuda=False): + model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', + model='silero_vad', + force_reload=True, + onnx=False) + if use_cuda: + model = model.cuda() + + get_speech_timestamps, save_audio, _, _, collect_chunks = utils + return model, get_speech_timestamps, save_audio, collect_chunks + +def remove_silence(model_and_utils, audio_path, out_path, vad_sample_rate=8000, trim_just_beginning_and_end=True, use_cuda=False): + + # get the VAD model and utils functions + model, get_speech_timestamps, save_audio, collect_chunks = model_and_utils + + # read ground truth wav and resample the audio for the VAD + wav, gt_sample_rate = read_audio(audio_path) + + # if needed, resample the audio for the VAD model + if gt_sample_rate != vad_sample_rate: + wav_vad = resample_wav(wav, gt_sample_rate, vad_sample_rate) + else: + wav_vad = wav + + if use_cuda: + wav_vad = wav_vad.cuda() + + # get speech timestamps from full audio file + speech_timestamps = get_speech_timestamps(wav_vad, model, sampling_rate=vad_sample_rate, window_size_samples=768) + + # map the current speech_timestamps to the sample rate of the ground truth audio + new_speech_timestamps = map_timestamps_to_new_sr(vad_sample_rate, gt_sample_rate, speech_timestamps, trim_just_beginning_and_end) + + # save audio + save_audio(out_path, + collect_chunks(new_speech_timestamps, wav), sampling_rate=gt_sample_rate) + + return out_path diff --git a/requirements.txt b/requirements.txt index c35992203a..f735c57ad9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,5 +34,3 @@ mecab-python3==1.0.3 unidic-lite==1.0.8 # gruut+supported langs gruut[cs,de,es,fr,it,nl,pt,ru,sv]==2.2.3 -# others -webrtcvad # for VAD