-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #797 from ftnext/feature/groq-support
Support Groq whisper
- Loading branch information
Showing
7 changed files
with
139 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
from typing import Literal, TypedDict | ||
from typing_extensions import Unpack | ||
|
||
from speech_recognition.audio import AudioData | ||
from speech_recognition.exceptions import SetupError | ||
from speech_recognition.recognizers.whisper_api import ( | ||
OpenAICompatibleRecognizer, | ||
) | ||
|
||
# Model identifiers accepted by the ``model`` parameter of ``recognize_groq``.
# Reference: https://console.groq.com/docs/speech-text#supported-models
GroqModel = Literal[
    "whisper-large-v3-turbo",
    "whisper-large-v3",
    "distil-whisper-large-v3-en",
]
|
||
|
||
class GroqOptionalParameters(TypedDict, total=False):
    """Groq speech transcription's optional parameters.

    https://console.groq.com/docs/speech-text#transcription-endpoint-usage

    Declared with ``total=False`` so that every key may be omitted; when
    this type is used as ``**kwargs: Unpack[GroqOptionalParameters]`` a
    caller can pass any subset of these keyword arguments (or none).
    """

    prompt: str
    response_format: str
    temperature: float
    language: str
|
||
|
||
def recognize_groq(
    recognizer,
    audio_data: "AudioData",
    *,
    model: GroqModel = "whisper-large-v3-turbo",
    **kwargs: Unpack[GroqOptionalParameters],
) -> str:
    """
    Performs speech recognition on ``audio_data`` (an ``AudioData`` instance), using the Groq Whisper API.

    This function requires login to Groq; visit https://console.groq.com/login, then generate API Key in `API Keys <https://console.groq.com/keys>`__ menu.
    The ``GROQ_API_KEY`` environment variable must be set.

    Detail: https://console.groq.com/docs/speech-text

    ``model`` is the name of the Groq model to transcribe with. Any further
    keyword arguments (see ``GroqOptionalParameters``) are forwarded
    unchanged to the transcription request.

    Returns the transcription text.

    Raises a ``speech_recognition.exceptions.SetupError`` exception if there are any issues with the groq installation, or the environment variable is missing.
    """
    # NOTE: the ``recognizer`` argument is accepted but not used here.
    if os.environ.get("GROQ_API_KEY") is None:
        raise SetupError("Set environment variable ``GROQ_API_KEY``")

    # Imported lazily so that the groq package is only needed when this
    # function is actually called.
    try:
        import groq
    except ImportError as exc:
        raise SetupError(
            "missing groq module: ensure that groq is set up correctly."
        ) from exc

    # Use a distinct local name instead of rebinding the ``recognizer``
    # parameter, and pass the optional parameters through so they are not
    # silently dropped.
    groq_recognizer = OpenAICompatibleRecognizer(groq.Groq())
    return groq_recognizer.recognize(audio_data, model, **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from io import BytesIO | ||
|
||
from speech_recognition.audio import AudioData | ||
|
||
|
||
class OpenAICompatibleRecognizer:
    """Transcribe audio via a client exposing ``audio.transcriptions.create``."""

    def __init__(self, client) -> None:
        """Store ``client``, the object used to issue transcription requests."""
        self.client = client

    def recognize(self, audio_data: "AudioData", model: str, **kwargs) -> str:
        """Send ``audio_data`` as WAV to the client and return the transcript text.

        ``model`` and any extra keyword arguments are passed straight to
        ``client.audio.transcriptions.create``.

        Raises ``ValueError`` if ``audio_data`` is not an ``AudioData``.
        """
        if isinstance(audio_data, AudioData):
            wav_buffer = BytesIO(audio_data.get_wav_data())
            # Give the in-memory buffer a file name
            # (presumably used by the client to identify the upload format).
            wav_buffer.name = "SpeechRecognition_audio.wav"

            response = self.client.audio.transcriptions.create(
                file=wav_buffer,
                model=model,
                **kwargs,
            )
            return response.text

        raise ValueError("``audio_data`` must be an ``AudioData`` instance")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from unittest.mock import MagicMock | ||
|
||
import httpx | ||
import respx | ||
|
||
from speech_recognition import AudioData, Recognizer | ||
from speech_recognition.recognizers import groq | ||
|
||
|
||
@respx.mock(assert_all_called=True, assert_all_mocked=True)
def test_transcribe_with_groq_whisper(respx_mock, monkeypatch):
    """``recognize_groq`` returns the ``text`` field of the mocked response."""
    monkeypatch.setenv("GROQ_API_KEY", "gsk_grok_api_key")

    # Stub the transcription endpoint so no real HTTP request is made.
    endpoint = "https://api.groq.com/openai/v1/audio/transcriptions"
    payload = {
        "text": "Transcription by Groq Whisper",
        "x_groq": {"id": "req_unique_id"},
    }
    respx_mock.post(endpoint).mock(
        return_value=httpx.Response(200, json=payload)
    )

    # A spec'd mock stands in for a real AudioData instance.
    fake_audio = MagicMock(spec=AudioData)
    fake_audio.get_wav_data.return_value = b"audio_data"
    fake_recognizer = MagicMock(spec=Recognizer)

    transcription = groq.recognize_groq(
        fake_recognizer, fake_audio, model="whisper-large-v3"
    )

    assert transcription == "Transcription by Groq Whisper"