Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace multiprocessing.dummy.Pool() with concurrent.futures.ThreadPoolExecutor() so whisper_s2t instance can run separately with multiprocessing.Process() #74

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
tqdm==4.66.2
rich==13.7.0
torch==2.1.2+cu121
torch>=2.1.2
numpy==1.26.4
platformdirs==4.2.0
ctranslate2==4.0.0
Expand Down
154 changes: 131 additions & 23 deletions whisper_s2t/audio.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import os
import wave
import tempfile
Expand All @@ -8,7 +9,8 @@
import torch.nn as nn
import torch.nn.functional as F

from multiprocessing.dummy import Pool
import concurrent
# from multiprocessing.dummy import Pool

from . import BASE_PATH
from .configs import *
Expand Down Expand Up @@ -38,37 +40,143 @@
print(f"Using 'swr' resampler. This may degrade performance.")


def load_audio(input_file, sr=16000, return_duration=False):
# def load_audio(input_file, sr=16000, return_duration=False):

try:
with wave.open(input_file, 'rb') as wf:
# try:
# with wave.open(input_file, 'rb') as wf:
# if (wf.getframerate() != sr) or (wf.getnchannels() != 1):
# raise Exception("Not a 16kHz wav mono channel file!")

# frames = wf.getnframes()
# x = wf.readframes(int(frames))
# except:
# with tempfile.TemporaryDirectory() as tmpdir:
# wav_file = f"{tmpdir}/tmp.wav"
# ret_code = os.system(f'ffmpeg -hide_banner -loglevel panic -i "{input_file}" -threads 1 -acodec pcm_s16le -ac 1 -af aresample=resampler={RESAMPLING_ENGINE} -ar {sr} "{wav_file}" -y')
# if ret_code != 0: raise RuntimeError("ffmpeg failed to resample the input audio file, make sure ffmpeg is compiled properly!")

# with wave.open(wav_file, 'rb') as wf:
# frames = wf.getnframes()
# x = wf.readframes(int(frames))

# audio_signal = np.frombuffer(x, np.int16).flatten().astype(np.float32)/32768.0
# audio_duration = len(audio_signal)/sr

# if return_duration:
# return audio_signal, audio_duration
# else:
# return audio_signal


def load_audio(input_file: str | bytes | np.ndarray, sr: int = 16000, return_duration: bool = False) -> np.ndarray | tuple[np.ndarray, float]:
"""Load audio from disk or memory

Args:
input_file (str | bytes | np.ndarray): path to file, audio object in memory or numpy pre-loaded ndarray
sr (int, optional): sample rate. Defaults to 16000.
return_duration (bool, optional): return audio duration. Defaults to False.

Returns:
(np.ndarray | tuple[np.ndarray, float]): audio signal as numpy ndarray, audio duration
"""
def _load_audio_as_ndarray(input_file: str | bytes, sr: int = 16000) -> tuple[np.ndarray, float]:
"""Load audio from WAV file

Args:
input_file (str | bytes): path to file or object in memory
sr (int, optional): sample rate. Defaults to 16000.

Raises:
Exception: Not a 16kHz wav mono channel file!

Returns:
tuple[np.ndarray, float]: audio signal as numpy ndarray, audio duration
"""
with wave.open(input_file if isinstance(input_file, str) else io.BytesIO(input_file), 'rb') as wf:
if (wf.getframerate() != sr) or (wf.getnchannels() != 1):
raise Exception("Not a 16kHz wav mono channel file!")

frames = wf.getnframes()
x = wf.readframes(int(frames))
except:
with tempfile.TemporaryDirectory() as tmpdir:
wav_file = f"{tmpdir}/tmp.wav"
ret_code = os.system(f'ffmpeg -hide_banner -loglevel panic -i "{input_file}" -threads 1 -acodec pcm_s16le -ac 1 -af aresample=resampler={RESAMPLING_ENGINE} -ar {sr} "{wav_file}" -y')
if ret_code != 0: raise RuntimeError("ffmpeg failed to resample the input audio file, make sure ffmpeg is compiled properly!")

with wave.open(wav_file, 'rb') as wf:
frames = wf.getnframes()
x = wf.readframes(int(frames))

# convert to numpy and calculate audio duration
audio_signal = np.frombuffer(x, np.int16).flatten().astype(np.float32)/32768.0
audio_duration = len(audio_signal)/sr

return audio_signal, audio_duration

audio_signal = np.frombuffer(x, np.int16).flatten().astype(np.float32)/32768.0
audio_duration = len(audio_signal)/sr
def _ffmpeg_convert_to_wav(input_file: str, wav_file: str, sr: int = 16000):
"""Converts audio file into WAV file format

Args:
input_file (str): input file
wav_file (str): wav file name
sr (int, optional): sample rate. Defaults to 16000.

Raises:
RuntimeError: ffmpeg failed to resample the input audio file, make sure ffmpeg is compiled properly!
"""
ret_code = os.system(f'ffmpeg -hide_banner -loglevel panic -i "{input_file}" -threads 1 -acodec pcm_s16le -ac 1 -af aresample=resampler={RESAMPLING_ENGINE} -ar {sr} "{wav_file}" -y')
if ret_code != 0: raise RuntimeError("ffmpeg failed to resample the input audio file, make sure ffmpeg is compiled properly!")

if return_duration:
return audio_signal, audio_duration
else:
return audio_signal
# load audio from disk or memory
audio_signal = None
audio_duration = None
if isinstance(input_file, (str, bytes)):
try:
audio_signal, audio_duration = _load_audio_as_ndarray(input_file=input_file, sr=sr)
except:
with tempfile.TemporaryDirectory() as tmpdir:
# save bytes to file
if isinstance(input_file, bytes):
tmp_file = os.path.join(tmpdir, 'audio')
with open(tmp_file, 'wb') as f:
f.write(input_file)
input_file = tmp_file

# convert to wav
wav_file = os.path.join(tmpdir, 'tmp.wav')
_ffmpeg_convert_to_wav(input_file=input_file, wav_file=wav_file, sr=sr)
audio_signal, audio_duration = _load_audio_as_ndarray(input_file=wav_file, sr=sr)

# already preprocessed into numpy ndarray
else:
audio_signal = input_file
audio_duration = len(input_file) / sr

return (audio_signal, audio_duration) if return_duration else audio_signal


def audio_batch_generator(audio_files: list, parallel: bool = True, max_workers: int = 2):
    """Generate loaded audio signals for a list of inputs.

    Replaces the old module-level ``multiprocessing.dummy.Pool(2)`` so the
    thread pool is created (and shut down) per call, allowing whisper_s2t
    instances to run inside ``multiprocessing.Process()``.

    Args:
        audio_files (list): list of inputs accepted by ``load_audio``
            (paths, raw bytes, or numpy arrays).
        parallel (bool, optional, default=True): tries parallel loading if True
            else uses sequential loading.
        max_workers (int, optional, default=2): maximum number of parallel
            workers (only used if parallel=True).

    Yields:
        Loaded audio data, one item per input, in input order.
    """
    # `import concurrent` alone does not load the `futures` submodule;
    # import it explicitly so ThreadPoolExecutor is guaranteed to resolve.
    from concurrent.futures import ThreadPoolExecutor

    done = 0  # items already yielded — the fallback must not re-yield them
    if parallel:
        try:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # executor.map preserves input order
                for audio in executor.map(load_audio, audio_files):
                    yield audio
                    done += 1
            return  # if parallel loading succeeded, we are done
        except Exception as e:
            print(f'Parallel audio loading failed: {str(e)}. Fall back to sequential loading...')

    # sequential loading (fallback), resuming after any already-yielded items
    for audio_file in audio_files[done:]:
        yield load_audio(audio_file)


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
Expand Down
46 changes: 30 additions & 16 deletions whisper_s2t/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,32 +153,46 @@ def transcribe(self, audio_files, lang_codes=None, tasks=None, initial_prompts=N

pbar.update(pbar.total-pbar_pos)

return responses
return responses

@torch.no_grad()
def transcribe_with_vad(self, audio_files, lang_codes=None, tasks=None, initial_prompts=None, batch_size=8, progress_bar=True):
    """Transcribe audio files using VAD-segmented batches.

    Args:
        audio_files: list of audio inputs accepted by the data loader.
        lang_codes: language code(s), broadcast to all files. Defaults to 'en'.
        tasks: task(s), broadcast to all files. Defaults to 'transcribe'.
        initial_prompts: optional prompt(s), broadcast to all files.
        batch_size: decoding batch size. Defaults to 8.
        progress_bar: show a tqdm progress bar when True. Defaults to True.

    Returns:
        list[list[dict]]: for each input file, a list of segment results with
        'start_time'/'end_time' (seconds, rounded to 3 decimals) merged into
        the decoder output dict.
    """
    lang_codes = fix_batch_param(lang_codes, 'en', len(audio_files))
    tasks = fix_batch_param(tasks, 'transcribe', len(audio_files))
    initial_prompts = fix_batch_param(initial_prompts, None, len(audio_files))

    responses = [[] for _ in audio_files]

    # tqdm's `disable` flag keeps a single code path for both progress_bar
    # settings instead of duplicating the whole transcription loop.
    pbar_pos = 0
    with tqdm(total=len(audio_files)*100, desc=f"Transcribing", disable=not progress_bar) as pbar:
        for signals, prompts, seq_len, seg_metadata, pbar_update in self.data_loader(audio_files, lang_codes, tasks, initial_prompts, batch_size=batch_size):
            mels, seq_len = self.preprocessor(signals, seq_len)
            res = self.generate_segment_batched(mels.to(self.device), prompts, seq_len, seg_metadata)

            for res_idx, _seg_metadata in enumerate(seg_metadata):
                responses[_seg_metadata['file_id']].append({
                    **res[res_idx],
                    'start_time': round(_seg_metadata['start_time'], 3),
                    'end_time': round(_seg_metadata['end_time'], 3)
                })

            if (pbar_pos) <= pbar.total:
                pbar_pos += pbar_update
                pbar.update(pbar_update)

        pbar.update(pbar.total-pbar_pos)

    return responses