Commit
[refactor] Extract to support Whisper variants in future
ftnext committed Dec 29, 2024
1 parent 5b67acf commit a88e099
Showing 1 changed file with 53 additions and 31 deletions.
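
What the extraction buys: the 16 kHz resampling, float32 conversion, fp16 defaulting, and show_dict result handling now live in WhisperCompatibleRecognizer, which only requires a model satisfying the Transcribable protocol, i.e. anything with a transcribe(audio_array, **kwargs) method. A minimal sketch of how a future Whisper variant could plug in; DummyVariant and hello.wav are hypothetical stand-ins, not part of this commit:

from speech_recognition import AudioFile, Recognizer
from speech_recognition.recognizers.whisper_local import (
    WhisperCompatibleRecognizer,
)


class DummyVariant:
    """Hypothetical Whisper variant; any object with a matching
    ``transcribe`` method satisfies ``Transcribable`` -- no inheritance."""

    def transcribe(self, audio_array, **kwargs):
        # A real variant would run inference on ``audio_array``,
        # a float32 numpy array already resampled to 16 kHz.
        return {"text": "hello world", "segments": [], "language": "en"}


r = Recognizer()
with AudioFile("hello.wav") as source:  # hypothetical input file
    audio_data = r.record(source)

recognizer = WhisperCompatibleRecognizer(DummyVariant())
# Passing fp16 explicitly skips the torch.cuda.is_available() default
print(recognizer.recognize(audio_data, fp16=False))  # -> hello world

Because Transcribable is a typing.Protocol, variants are matched structurally; the wrapper never needs to know the concrete model class.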
84 changes: 53 additions & 31 deletions speech_recognition/recognizers/whisper_local.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import io
-from typing import TYPE_CHECKING, Literal, TypedDict
+from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict
 
 from speech_recognition.audio import AudioData
 
@@ -10,6 +10,44 @@
     from typing_extensions import Unpack
 
 
+class Transcribable(Protocol):
+    def transcribe(self, audio_array, **kwargs) -> str | dict[str, Any]:
+        pass
+
+
+class WhisperCompatibleRecognizer:
+    def __init__(self, model: Transcribable) -> None:
+        self.model = model
+
+    def recognize(
+        self, audio_data: AudioData, show_dict: bool = False, **kwargs
+    ):
+        if not isinstance(audio_data, AudioData):
+            raise ValueError(
+                "``audio_data`` must be an ``AudioData`` instance"
+            )
+
+        import numpy as np
+        import soundfile as sf
+
+        # 16 kHz https://github.com/openai/whisper/blob/28769fcfe50755a817ab922a7bc83483159600a9/whisper/audio.py#L98-L99
+        wav_bytes = audio_data.get_wav_data(convert_rate=16000)
+        wav_stream = io.BytesIO(wav_bytes)
+        audio_array, sampling_rate = sf.read(wav_stream)
+        audio_array = audio_array.astype(np.float32)
+
+        if "fp16" not in kwargs:
+            import torch
+
+            kwargs["fp16"] = torch.cuda.is_available()
+        result = self.model.transcribe(audio_array, **kwargs)
+
+        if show_dict:
+            return result
+        else:
+            return result["text"]
+
+
 class LoadModelOptionalParameters(TypedDict, total=False):
     # ref: https://github.com/openai/whisper/blob/v20240930/whisper/__init__.py#L103
     device: str | torch.device
@@ -21,12 +59,20 @@ class TranscribeOptionalParameters(TypedDict, total=False):
"""Transcribe optional parameters & DecodingOptions parameters."""

# ref: https://github.com/openai/whisper/blob/v20240930/whisper/transcribe.py#L38
temperature: float | tuple[float, ...]
# TODO Add others

# ref: https://github.com/openai/whisper/blob/v20240930/whisper/decoding.py#L81
# TODO Add others
task: Literal["transcribe", "translate"]
language: str
fp16: bool
# TODO Add others


class TranscribeOutput(TypedDict):
text: str
segments: list[Any] # TODO Fix Any
language: str


def recognize(
@@ -36,7 +82,7 @@ def recognize(
     show_dict: bool = False,
     load_options: LoadModelOptionalParameters | None = None,
     **transcribe_options: Unpack[TranscribeOptionalParameters],
-):
+) -> str | TranscribeOutput:
"""
Performs speech recognition on ``audio_data`` (an ``AudioData`` instance), using Whisper.
Expand All @@ -55,34 +101,10 @@ def recognize(
Other values are passed directly to whisper. See https://github.com/openai/whisper/blob/main/whisper/transcribe.py for all options
"""

-    import numpy as np
-    import soundfile as sf
-    import torch
     import whisper
 
-    if (
-        load_options
-        or not hasattr(self, "whisper_model")
-        or self.whisper_model.get(model) is None
-    ):
-        self.whisper_model = getattr(self, "whisper_model", {})
-        self.whisper_model[model] = whisper.load_model(
-            model, **load_options or {}
-        )
-
-    # 16 kHz https://github.com/openai/whisper/blob/28769fcfe50755a817ab922a7bc83483159600a9/whisper/audio.py#L98-L99
-    wav_bytes = audio_data.get_wav_data(convert_rate=16000)
-    wav_stream = io.BytesIO(wav_bytes)
-    audio_array, sampling_rate = sf.read(wav_stream)
-    audio_array = audio_array.astype(np.float32)
-
-    result = self.whisper_model[model].transcribe(
-        audio_array,
-        fp16=torch.cuda.is_available(),
-        **transcribe_options,
+    whisper_model = whisper.load_model(model, **load_options or {})
+    recognizer = WhisperCompatibleRecognizer(whisper_model)
+    return recognizer.recognize(
+        audio_data, show_dict=show_dict, **transcribe_options
     )
-
-    if show_dict:
-        return result
-    else:
-        return result["text"]
