Commit

[feat] Refine transcribe parameters
ftnext committed Dec 29, 2024
1 parent fadc9a5 commit 5b67acf
Showing 2 changed files with 11 additions and 12 deletions.
17 changes: 9 additions & 8 deletions speech_recognition/recognizers/whisper_local.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import io
from typing import TYPE_CHECKING, TypedDict
from typing import TYPE_CHECKING, Literal, TypedDict

from speech_recognition.audio import AudioData

@@ -25,29 +25,32 @@ class TranscribeOptionalParameters(TypedDict, total=False):

# ref: https://github.com/openai/whisper/blob/v20240930/whisper/decoding.py#L81
# TODO Add others
task: Literal["transcribe", "translate"]
language: str


def recognize(
self,
audio_data: AudioData,
model: str = "base",
show_dict: bool = False,
language: str | None = None,
translate: bool = False,
load_options: LoadModelOptionalParameters | None = None,
**transcribe_options: Unpack[TranscribeOptionalParameters],
):
"""
Performs speech recognition on ``audio_data`` (an ``AudioData`` instance), using Whisper.
The recognition language is determined by ``language``, an uncapitalized full language name like "english" or "chinese". See the full language list at https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
Pick ``model`` from output of :command:`python -c 'import whisper; print(whisper.available_models())'`.
See also https://github.com/openai/whisper?tab=readme-ov-file#available-models-and-languages.
If ``show_dict`` is true, returns the full dict response from Whisper, including the detected language. Otherwise returns only the transcription.
You can translate the result to english with Whisper by passing ``translate=True``.
You can specify:
* ``language``: recognition language, an uncapitalized full language name like "english" or "chinese". See the full language list at https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
* ``task``
* If you want transcribe + **translate**, set ``task="translate"``.
Other values are passed directly to whisper. See https://github.com/openai/whisper/blob/main/whisper/transcribe.py for all options
"""
@@ -75,8 +78,6 @@ def recognize(

result = self.whisper_model[model].transcribe(
audio_array,
language=language,
task="translate" if translate else None,
fp16=torch.cuda.is_available(),
**transcribe_options,
)
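A minimal usage sketch of the refined API (assuming the function is exposed as ``Recognizer.recognize_whisper`` and using an illustrative audio file name): ``language`` and the new ``task`` option now travel through ``**transcribe_options`` instead of the removed ``language``/``translate`` keyword arguments.

```python
import speech_recognition as sr

r = sr.Recognizer()
with sr.AudioFile("speech.wav") as source:  # illustrative file name
    audio = r.record(source)

# ``language`` and ``task`` are forwarded verbatim to whisper's transcribe();
# task="translate" replaces the old translate=True flag.
text = r.recognize_whisper(
    audio,
    model="base",
    language="japanese",
    task="translate",
)
print(text)
```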
6 changes: 2 additions & 4 deletions tests/test_whisper_recognition.py
@@ -34,8 +34,6 @@ def test_default_parameters(
audio_array.astype.assert_called_once_with(np.float32)
whisper_model.transcribe.assert_called_once_with(
audio_array.astype.return_value,
language=None,
task=None,
fp16=is_available.return_value,
)
transcript.__getitem__.assert_called_once_with("text")
@@ -64,16 +62,16 @@ def test_pass_parameters(self, load_model, is_available, sf_read, BytesIO):
audio_data,
model="small",
language="english",
translate=True,
task="translate",
temperature=0,
)

self.assertEqual(actual, transcript.__getitem__.return_value)
load_model.assert_called_once_with("small")
whisper_model.transcribe.assert_called_once_with(
audio_array.astype.return_value,
fp16=is_available.return_value,
language="english",
task="translate",
fp16=is_available.return_value,
temperature=0,
)
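
The updated assertion reflects that ``language``, ``task``, and any extra keyword arguments (here ``temperature=0``) are now passed straight to whisper. For reference, a sketch of the roughly equivalent direct call against the ``openai-whisper`` package (file name illustrative):

```python
import torch
import whisper

model = whisper.load_model("small")
result = model.transcribe(
    "speech.wav",                    # illustrative input file
    fp16=torch.cuda.is_available(),  # matches the wrapper's default
    language="english",
    task="translate",                # transcribe and translate to English
    temperature=0,
)
print(result["text"])
```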
