From 313b7758a346d34a585f34e798849c5e0b130d06 Mon Sep 17 00:00:00 2001
From: Kirill Saidov
Date: Thu, 17 Oct 2024 15:41:24 +0500
Subject: [PATCH 1/5] multiprocessing.dummy.Pool() replaced with safer
 concurrent.futures.ThreadPoolExecutor().

I tried loading the whisper model and running transcription in a separate
multiprocessing.Process(), but it failed: the process froze and became
unresponsive when loading audio into memory with multiprocessing.dummy.Pool().
Replaced it with concurrent.futures.ThreadPoolExecutor() and it works now!
---
 whisper_s2t/audio.py | 36 ++++++++++++++++++++++++++++++++----
 1 file changed, 32 insertions(+), 4 deletions(-)

diff --git a/whisper_s2t/audio.py b/whisper_s2t/audio.py
index 06d3092..64847da 100644
--- a/whisper_s2t/audio.py
+++ b/whisper_s2t/audio.py
@@ -8,7 +8,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from multiprocessing.dummy import Pool
+import concurrent.futures
+# from multiprocessing.dummy import Pool
 
 from . import BASE_PATH
 from .configs import *
@@ -66,9 +67,36 @@ def load_audio(input_file, sr=16000, return_duration=False):
         return audio_signal
 
 
-THREAD_POOL_AUDIO_LOADER = Pool(2)
-def audio_batch_generator(audio_files):
-    return THREAD_POOL_AUDIO_LOADER.imap(load_audio, audio_files)
+# THREAD_POOL_AUDIO_LOADER = Pool(2)
+# def audio_batch_generator(audio_files):
+#     return THREAD_POOL_AUDIO_LOADER.imap(load_audio, audio_files)
+
+
+def audio_batch_generator(audio_files: list, parallel: bool = True, max_workers: int = 2):
+    """
+    Generate batches of loaded audio files, with the option of parallel or sequential loading.
+
+    Args:
+        audio_files (list): list of paths to audio files
+        parallel (bool, optional, default=True): try parallel loading if True, otherwise load sequentially
+        max_workers (int, optional, default=2): maximum number of parallel workers (only used if parallel=True)
+
+    Returns:
+        Iterator of loaded audio data
+    """
+    # try parallel loading with ThreadPoolExecutor (safer than multiprocessing.dummy.Pool)
+    if parallel:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            try:
+                yield from executor.map(load_audio, audio_files)
+                return # if parallel loading succeeded, we are done
+            except Exception as e:
+                print(f'Parallel audio loading failed: {str(e)}. Falling back to sequential loading...')
+                parallel = False
+
+    # sequential loading (fallback)
+    for audio_file in audio_files:
+        yield load_audio(audio_file)
 
 
 def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):

From 2ef9e9144c75dbd9dc021133a3897525ac36a830 Mon Sep 17 00:00:00 2001
From: Kirill Saidov
Date: Mon, 21 Oct 2024 11:15:53 +0500
Subject: [PATCH 2/5] torch==2.1.2+cu121 not found by pip

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 1d29b3c..8543987 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 tqdm==4.66.2
 rich==13.7.0
-torch==2.1.2+cu121
+torch==2.1.2
 numpy==1.26.4
 platformdirs==4.2.0
 ctranslate2==4.0.0

From 8dfb58244ddfbf357c998c6758ae9641f1fdc8e0 Mon Sep 17 00:00:00 2001
From: Kirill Saidov
Date: Mon, 21 Oct 2024 11:16:33 +0500
Subject: [PATCH 3/5] add flag progress_bar=False to turn off tqdm progress bar
 in WhisperModel.transcribe_with_vad()

---
 whisper_s2t/backends/__init__.py | 46 +++++++++++++++++++++-----------
 1 file changed, 30 insertions(+), 16 deletions(-)

diff --git a/whisper_s2t/backends/__init__.py b/whisper_s2t/backends/__init__.py
index f45513b..4f84644 100644
--- a/whisper_s2t/backends/__init__.py
+++ b/whisper_s2t/backends/__init__.py
@@ -153,32 +153,46 @@ def transcribe(self, audio_files, lang_codes=None, tasks=None, initial_prompts=N
 
             pbar.update(pbar.total-pbar_pos)
 
-        return responses
+        return responses
 
     @torch.no_grad()
-    def transcribe_with_vad(self, audio_files, lang_codes=None, tasks=None, initial_prompts=None, batch_size=8):
-
+    def transcribe_with_vad(self, audio_files, lang_codes=None, tasks=None, initial_prompts=None, batch_size=8, progress_bar=True):
         lang_codes = fix_batch_param(lang_codes, 'en', len(audio_files))
         tasks = fix_batch_param(tasks, 'transcribe', len(audio_files))
         initial_prompts = fix_batch_param(initial_prompts, None, len(audio_files))
 
         responses = [[] for _ in audio_files]
 
-        pbar_pos = 0
-        with tqdm(total=len(audio_files)*100, desc=f"Transcribing") as pbar:
-            for signals, prompts, seq_len, seg_metadata, pbar_update in self.data_loader(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size):
+        if progress_bar:
+            pbar_pos = 0
+            with tqdm(total=len(audio_files)*100, desc=f"Transcribing") as pbar:
+                for signals, prompts, seq_len, seg_metadata, pbar_update in self.data_loader(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size):
+                    mels, seq_len = self.preprocessor(signals, seq_len)
+                    res = self.generate_segment_batched(mels.to(self.device), prompts, seq_len, seg_metadata)
+
+                    for res_idx, _seg_metadata in enumerate(seg_metadata):
+                        responses[_seg_metadata['file_id']].append({
+                            **res[res_idx],
+                            'start_time': round(_seg_metadata['start_time'], 3),
+                            'end_time': round(_seg_metadata['end_time'], 3)
+                        })
+
+                    if (pbar_pos) <= pbar.total:
+                        pbar_pos += pbar_update
+                        pbar.update(pbar_update)
+
+                pbar.update(pbar.total-pbar_pos)
+        else:
+            for signals, prompts, seq_len, seg_metadata, _ in self.data_loader(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size):
                 mels, seq_len = self.preprocessor(signals, seq_len)
                 res = self.generate_segment_batched(mels.to(self.device), prompts, seq_len, seg_metadata)
 
                 for res_idx, _seg_metadata in enumerate(seg_metadata):
-                    responses[_seg_metadata['file_id']].append({**res[res_idx],
-                        'start_time': round(_seg_metadata['start_time'], 3),
-                        'end_time': round(_seg_metadata['end_time'], 3)})
-
-                if (pbar_pos) <= pbar.total:
-                    pbar_pos += pbar_update
-                    pbar.update(pbar_update)
-
-            pbar.update(pbar.total-pbar_pos)
+                    responses[_seg_metadata['file_id']].append({
+                        **res[res_idx],
+                        'start_time': round(_seg_metadata['start_time'], 3),
+                        'end_time': round(_seg_metadata['end_time'], 3)
+                    })
 
-        return responses
\ No newline at end of file
+        return responses
+

From a3e5de2849aaa2015b2f873c3976cc038869db24 Mon Sep 17 00:00:00 2001
From: Kirill Saidov
Date: Tue, 22 Oct 2024 10:35:28 +0500
Subject: [PATCH 4/5] whisper_s2t.audio.load_audio() can now load audio from
 memory and accept a pre-loaded numpy ndarray

---
 whisper_s2t/audio.py | 118 ++++++++++++++++++++++++++++++++++++-------
 1 file changed, 99 insertions(+), 19 deletions(-)

diff --git a/whisper_s2t/audio.py b/whisper_s2t/audio.py
index 64847da..d0a8f0a 100644
--- a/whisper_s2t/audio.py
+++ b/whisper_s2t/audio.py
@@ -1,3 +1,4 @@
+import io
 import os
 import wave
 import tempfile
@@ -39,32 +40,111 @@
     print(f"Using 'swr' resampler. This may degrade performance.")
 
 
-def load_audio(input_file, sr=16000, return_duration=False):
+# def load_audio(input_file, sr=16000, return_duration=False):
 
-    try:
-        with wave.open(input_file, 'rb') as wf:
+#     try:
+#         with wave.open(input_file, 'rb') as wf:
+#             if (wf.getframerate() != sr) or (wf.getnchannels() != 1):
+#                 raise Exception("Not a 16kHz wav mono channel file!")
+
+#             frames = wf.getnframes()
+#             x = wf.readframes(int(frames))
+#     except:
+#         with tempfile.TemporaryDirectory() as tmpdir:
+#             wav_file = f"{tmpdir}/tmp.wav"
+#             ret_code = os.system(f'ffmpeg -hide_banner -loglevel panic -i "{input_file}" -threads 1 -acodec pcm_s16le -ac 1 -af aresample=resampler={RESAMPLING_ENGINE} -ar {sr} "{wav_file}" -y')
+#             if ret_code != 0: raise RuntimeError("ffmpeg failed to resample the input audio file, make sure ffmpeg is compiled properly!")
+
+#             with wave.open(wav_file, 'rb') as wf:
+#                 frames = wf.getnframes()
+#                 x = wf.readframes(int(frames))
+
+#     audio_signal = np.frombuffer(x, np.int16).flatten().astype(np.float32)/32768.0
+#     audio_duration = len(audio_signal)/sr
+
+#     if return_duration:
+#         return audio_signal, audio_duration
+#     else:
+#         return audio_signal
+
+
+def load_audio(input_file: str | bytes | np.ndarray, sr: int = 16000, return_duration: bool = False) -> np.ndarray | tuple[np.ndarray, float]:
+    """Load audio from disk or memory
+
+    Args:
+        input_file (str | bytes | np.ndarray): path to a file, in-memory audio bytes, or a pre-loaded numpy ndarray
+        sr (int, optional): sample rate. Defaults to 16000.
+        return_duration (bool, optional): return audio duration. Defaults to False.
+
+    Returns:
+        (np.ndarray | tuple[np.ndarray, float]): audio signal as a numpy ndarray, plus the audio duration if return_duration is True
+    """
+    def _load_audio_as_ndarray(input_file: str | bytes, sr: int = 16000) -> tuple[np.ndarray, float]:
+        """Load audio from a WAV file on disk or in memory
+
+        Args:
+            input_file (str | bytes): path to a file or an in-memory object
+            sr (int, optional): sample rate. Defaults to 16000.
+
+        Raises:
+            Exception: Not a 16kHz wav mono channel file!
+ + Returns: + tuple[np.ndarray, float]: audio signal as numpy ndarray, audio duration + """ + with wave.open(input_file if isinstance(input_file, str) else io.BytesIO(input_file), 'rb') as wf: if (wf.getframerate() != sr) or (wf.getnchannels() != 1): raise Exception("Not a 16kHz wav mono channel file!") - + frames = wf.getnframes() x = wf.readframes(int(frames)) - except: - with tempfile.TemporaryDirectory() as tmpdir: - wav_file = f"{tmpdir}/tmp.wav" - ret_code = os.system(f'ffmpeg -hide_banner -loglevel panic -i "{input_file}" -threads 1 -acodec pcm_s16le -ac 1 -af aresample=resampler={RESAMPLING_ENGINE} -ar {sr} "{wav_file}" -y') - if ret_code != 0: raise RuntimeError("ffmpeg failed to resample the input audio file, make sure ffmpeg is compiled properly!") - - with wave.open(wav_file, 'rb') as wf: - frames = wf.getnframes() - x = wf.readframes(int(frames)) + + # convert to numpy and calculate audio duration + audio_signal = np.frombuffer(x, np.int16).flatten().astype(np.float32)/32768.0 + audio_duration = len(audio_signal)/sr + + return audio_signal, audio_duration - audio_signal = np.frombuffer(x, np.int16).flatten().astype(np.float32)/32768.0 - audio_duration = len(audio_signal)/sr + def _ffmpeg_convert_to_wav(input_file: str, wav_file: str, sr: int = 16000): + """Converts audio file into WAV file format + + Args: + input_file (str): input file + wav_file (str): wav file name + sr (int, optional): sample rate. Defaults to 16000. + + Raises: + RuntimeError: ffmpeg failed to resample the input audio file, make sure ffmpeg is compiled properly! + """ + ret_code = os.system(f'ffmpeg -hide_banner -loglevel panic -i "{input_file}" -threads 1 -acodec pcm_s16le -ac 1 -af aresample=resampler={RESAMPLING_ENGINE} -ar {sr} "{wav_file}" -y') + if ret_code != 0: raise RuntimeError("ffmpeg failed to resample the input audio file, make sure ffmpeg is compiled properly!") - if return_duration: - return audio_signal, audio_duration - else: - return audio_signal + # load audio from disk or memory + audio_signal = None + audio_duration = None + if isinstance(input_file, (str, bytes)): + try: + audio_signal, audio_duration = _load_audio_as_ndarray(input_file=input_file, sr=sr) + except: + with tempfile.TemporaryDirectory() as tmpdir: + # save bytes to file + if isinstance(input_file, bytes): + tmp_file = os.path.join(tmpdir, 'audio') + with open(tmp_file, 'wb') as f: + f.write(input_file) + input_file = tmp_file + + # convert to wav + wav_file = os.path.join(tmpdir, 'tmp.wav') + _ffmpeg_convert_to_wav(input_file=input_file, wav_file=wav_file, sr=sr) + audio_signal, audio_duration = _load_audio_as_ndarray(input_file=wav_file, sr=sr) + + # already preprocessed into numpy ndarray + else: + audio_signal = input_file + audio_duration = len(input_file) / sr + + return (audio_signal, audio_duration) if return_duration else audio_signal # THREAD_POOL_AUDIO_LOADER = Pool(2) From 550d4aaff6bb830bb568f91a4b26420d521e0b6e Mon Sep 17 00:00:00 2001 From: Kirill Saidov Date: Wed, 30 Oct 2024 12:43:33 +0500 Subject: [PATCH 5/5] update requirements.txt: torch==2.1.2 not found --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8543987..5546531 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ tqdm==4.66.2 rich==13.7.0 -torch==2.1.2 +torch>=2.1.2 numpy==1.26.4 platformdirs==4.2.0 ctranslate2==4.0.0
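
Usage notes (editor's sketch, not part of the patches): a minimal example of how the
APIs touched by this series might be exercised. It assumes the package is importable
as whisper_s2t with a CTranslate2 backend installed; the model identifier, file paths,
and the synthetic ndarray below are hypothetical placeholders.

    import numpy as np

    import whisper_s2t
    from whisper_s2t.audio import load_audio, audio_batch_generator

    # PATCH 1/5: audio_batch_generator() loads files in a thread pool,
    # falling back to sequential loading if the parallel path fails
    for signal in audio_batch_generator(["a.wav", "b.wav"], parallel=True, max_workers=2):
        print(signal.shape)

    # PATCH 4/5: load_audio() accepts a path, raw bytes, or a pre-loaded ndarray
    signal = load_audio("sample.wav")        # from disk
    with open("sample.mp3", "rb") as f:
        signal = load_audio(f.read())        # from memory (bytes)
    signal, duration = load_audio(
        np.zeros(16000, dtype=np.float32),   # pre-loaded: 1 s of silence at 16 kHz
        return_duration=True,
    )

    # PATCH 3/5: the tqdm progress bar can now be turned off during VAD transcription
    model = whisper_s2t.load_model(model_identifier="large-v2", backend="CTranslate2")
    out = model.transcribe_with_vad(
        ["sample.wav"], lang_codes=["en"], tasks=["transcribe"],
        batch_size=8, progress_bar=False,
    )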