From a88e09953c3d5f030385ec9735423b7584439efd Mon Sep 17 00:00:00 2001
From: ftnext
Date: Mon, 30 Dec 2024 01:39:42 +0900
Subject: [PATCH] [refactor] Extract to support Whisper variants in future

---
 .../recognizers/whisper_local.py | 84 ++++++++++++-------
 1 file changed, 53 insertions(+), 31 deletions(-)

diff --git a/speech_recognition/recognizers/whisper_local.py b/speech_recognition/recognizers/whisper_local.py
index 9f8b5163..f2be01e6 100644
--- a/speech_recognition/recognizers/whisper_local.py
+++ b/speech_recognition/recognizers/whisper_local.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import io
-from typing import TYPE_CHECKING, Literal, TypedDict
+from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict
 
 from speech_recognition.audio import AudioData
 
@@ -10,6 +10,44 @@
     from typing_extensions import Unpack
 
 
+class Transcribable(Protocol):
+    def transcribe(self, audio_array, **kwargs) -> str | dict[str, Any]:
+        pass
+
+
+class WhisperCompatibleRecognizer:
+    def __init__(self, model: Transcribable) -> None:
+        self.model = model
+
+    def recognize(
+        self, audio_data: AudioData, show_dict: bool = False, **kwargs
+    ):
+        if not isinstance(audio_data, AudioData):
+            raise ValueError(
+                "``audio_data`` must be an ``AudioData`` instance"
+            )
+
+        import numpy as np
+        import soundfile as sf
+
+        # 16 kHz https://github.com/openai/whisper/blob/28769fcfe50755a817ab922a7bc83483159600a9/whisper/audio.py#L98-L99
+        wav_bytes = audio_data.get_wav_data(convert_rate=16000)
+        wav_stream = io.BytesIO(wav_bytes)
+        audio_array, sampling_rate = sf.read(wav_stream)
+        audio_array = audio_array.astype(np.float32)
+
+        if "fp16" not in kwargs:
+            import torch
+
+            kwargs["fp16"] = torch.cuda.is_available()
+        result = self.model.transcribe(audio_array, **kwargs)
+
+        if show_dict:
+            return result
+        else:
+            return result["text"]
+
+
 class LoadModelOptionalParameters(TypedDict, total=False):
     # ref: https://github.com/openai/whisper/blob/v20240930/whisper/__init__.py#L103
     device: str | torch.device
@@ -21,12 +59,20 @@ class TranscribeOptionalParameters(TypedDict, total=False):
     """Transcribe optional parameters & DecodingOptions parameters."""
 
     # ref: https://github.com/openai/whisper/blob/v20240930/whisper/transcribe.py#L38
+    temperature: float | tuple[float, ...]
     # TODO Add others
 
     # ref: https://github.com/openai/whisper/blob/v20240930/whisper/decoding.py#L81
-    # TODO Add others
     task: Literal["transcribe", "translate"]
     language: str
+    fp16: bool
+    # TODO Add others
+
+
+class TranscribeOutput(TypedDict):
+    text: str
+    segments: list[Any]  # TODO Fix Any
+    language: str
 
 
 def recognize(
@@ -36,7 +82,7 @@
     show_dict: bool = False,
     load_options: LoadModelOptionalParameters | None = None,
     **transcribe_options: Unpack[TranscribeOptionalParameters],
-):
+) -> str | TranscribeOutput:
     """
     Performs speech recognition on ``audio_data`` (an ``AudioData`` instance),
     using Whisper.
@@ -55,34 +101,10 @@
     Other values are passed directly to whisper.
     See https://github.com/openai/whisper/blob/main/whisper/transcribe.py for all options
     """
-    import numpy as np
-    import soundfile as sf
-    import torch
     import whisper
 
-    if (
-        load_options
-        or not hasattr(self, "whisper_model")
-        or self.whisper_model.get(model) is None
-    ):
-        self.whisper_model = getattr(self, "whisper_model", {})
-        self.whisper_model[model] = whisper.load_model(
-            model, **load_options or {}
-        )
-
-    # 16 kHz https://github.com/openai/whisper/blob/28769fcfe50755a817ab922a7bc83483159600a9/whisper/audio.py#L98-L99
-    wav_bytes = audio_data.get_wav_data(convert_rate=16000)
-    wav_stream = io.BytesIO(wav_bytes)
-    audio_array, sampling_rate = sf.read(wav_stream)
-    audio_array = audio_array.astype(np.float32)
-
-    result = self.whisper_model[model].transcribe(
-        audio_array,
-        fp16=torch.cuda.is_available(),
-        **transcribe_options,
+    whisper_model = whisper.load_model(model, **load_options or {})
+    recognizer = WhisperCompatibleRecognizer(whisper_model)
+    return recognizer.recognize(
+        audio_data, show_dict=show_dict, **transcribe_options
     )
-
-    if show_dict:
-        return result
-    else:
-        return result["text"]
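
Note (reviewer sketch, not part of the patch): with this extraction, any model
object exposing a Whisper-style ``transcribe`` method satisfies the
``Transcribable`` protocol, so a future Whisper variant can reuse
``WhisperCompatibleRecognizer`` without touching the OpenAI-specific loading
code in ``recognize``. A minimal sketch of that seam follows; ``StubWhisper``
is a hypothetical stand-in, not part of this patch, and ``fp16=False`` is
passed explicitly so the recognizer skips its ``torch.cuda.is_available()``
probe.

    from speech_recognition.recognizers.whisper_local import (
        WhisperCompatibleRecognizer,
    )

    class StubWhisper:
        """Hypothetical Whisper-variant model, for illustration only."""

        def transcribe(self, audio_array, **kwargs):
            # A real variant would run inference on the float32 16 kHz array;
            # the recognizer only depends on this Whisper-shaped return value.
            return {"text": "hello world", "segments": [], "language": "en"}

    recognizer = WhisperCompatibleRecognizer(StubWhisper())
    # ``audio`` would be an AudioData instance, e.g. from Recognizer.listen():
    # text = recognizer.recognize(audio, fp16=False)
    # full = recognizer.recognize(audio, show_dict=True, fp16=False)

Because the protocol allows ``transcribe`` to return either ``str`` or a dict,
callers that need segments should pass ``show_dict=True`` and handle the
``TranscribeOutput`` shape, as the annotated return type of ``recognize``
indicates.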