Skip to content

Commit

Permalink
[feat] Users can pass Google endpoint (ref Uberi#751)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ftnext committed May 5, 2024
1 parent 765e2cf commit c09f15f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
26 changes: 19 additions & 7 deletions speech_recognition/recognizers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@ class GoogleResponse(TypedDict):
ProfanityFilterLevel = Literal[0, 1]
RequestHeaders = Dict[str, str]

ENDPOINT = "http://www.google.com/speech-api/v2/recognize"

class RequestBuilder:
endpoint = "http://www.google.com/speech-api/v2/recognize"

class RequestBuilder:
def __init__(
self, *, key: str, language: str, filter_level: ProfanityFilterLevel
self,
*,
endpoint: str,
key: str,
language: str,
filter_level: ProfanityFilterLevel,
) -> None:
self.endpoint = endpoint
self.key = key
self.language = language
self.filter_level = filter_level
Expand All @@ -53,7 +59,7 @@ def build(self, audio_data: AudioData) -> Request:

def build_url(self) -> str:
"""
>>> builder = RequestBuilder(key="awesome-key", language="en-US", filter_level=0)
>>> builder = RequestBuilder(endpoint="http://www.google.com/speech-api/v2/recognize", key="awesome-key", language="en-US", filter_level=0)
>>> builder.build_url()
'http://www.google.com/speech-api/v2/recognize?client=chromium&lang=en-US&key=awesome-key&pFilter=0'
"""
Expand All @@ -69,7 +75,7 @@ def build_url(self) -> str:

def build_headers(self, audio_data: AudioData) -> RequestHeaders:
"""
>>> builder = RequestBuilder(key="", language="", filter_level=1)
>>> builder = RequestBuilder(endpoint="", key="", language="", filter_level=1)
>>> audio_data = AudioData(b"", 16_000, 1)
>>> builder.build_headers(audio_data)
{'Content-Type': 'audio/x-flac; rate=16000'}
Expand Down Expand Up @@ -99,6 +105,7 @@ def to_convert_rate(sample_rate: int) -> int:

def create_request_builder(
*,
endpoint: str,
key: str | None = None,
language: str = "en-US",
filter_level: ProfanityFilterLevel = 0,
Expand All @@ -111,7 +118,10 @@ def create_request_builder(
if key is None:
key = "AIzaSyBOti4mM-6x9WDnZIjIeyEU21OpBXqWBgw"
return RequestBuilder(
key=key, language=language, filter_level=filter_level
endpoint=endpoint,
key=key,
language=language,
filter_level=filter_level,
)


Expand Down Expand Up @@ -220,6 +230,8 @@ def recognize_legacy(
pfilter: ProfanityFilterLevel = 0,
show_all: bool = False,
with_confidence: bool = False,
*,
endpoint: str = ENDPOINT,
):
"""
Performs speech recognition on ``audio_data`` (an ``AudioData`` instance), using the Google Speech Recognition API.
Expand All @@ -237,7 +249,7 @@ def recognize_legacy(
Raises a ``speech_recognition.UnknownValueError`` exception if the speech is unintelligible. Raises a ``speech_recognition.RequestError`` exception if the speech recognition operation failed, if the key isn't valid, or if there is no internet connection.
"""
request_builder = create_request_builder(
key=key, language=language, filter_level=pfilter
endpoint=endpoint, key=key, language=language, filter_level=pfilter
)
request = request_builder.build(audio_data)

Expand Down
19 changes: 15 additions & 4 deletions tests/recognizers/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ class RequestBuilderTestCase(TestCase):
@patch(f"{CLASS_UNDER_TEST}.build_url")
def test_build(self, build_url, build_headers, build_data, Request):
audio_data = MagicMock(spec=AudioData)
sut = google.RequestBuilder(key="", language="", filter_level=0)
sut = google.RequestBuilder(
endpoint="", key="", language="", filter_level=0
)

actual = sut.build(audio_data)

Expand All @@ -36,7 +38,9 @@ def test_build(self, build_url, build_headers, build_data, Request):
def test_build_data(self, to_convert_rate):
# mock has AudioData's attributes (e.g. sample_rate)
audio_data = MagicMock(spec=AudioData(None, 1, 1))
sut = google.RequestBuilder(key="", language="", filter_level=0)
sut = google.RequestBuilder(
endpoint="", key="", language="", filter_level=0
)

actual = sut.build_data(audio_data)

Expand Down Expand Up @@ -131,7 +135,10 @@ def test_default_values(

self.assertEqual(actual, output_parser.parse.return_value)
create_request_builder.assert_called_once_with(
key=None, language="en-US", filter_level=0
endpoint="http://www.google.com/speech-api/v2/recognize",
key=None,
language="en-US",
filter_level=0,
)
request_builder.build.assert_called_once_with(audio_data)
obtain_transcription.assert_called_once_with(
Expand Down Expand Up @@ -160,11 +167,15 @@ def test_specified_values(
pfilter=1,
show_all=True,
with_confidence=False,
endpoint="https://www.google.com/speech-api/v2/recognize",
)

self.assertEqual(actual, output_parser.parse.return_value)
create_request_builder.assert_called_once_with(
key="awesome-key", language="zh-CN", filter_level=1
endpoint="https://www.google.com/speech-api/v2/recognize",
key="awesome-key",
language="zh-CN",
filter_level=1,
)
request_builder.build.assert_called_once_with(audio_data)
obtain_transcription.assert_called_once_with(
Expand Down

0 comments on commit c09f15f

Please sign in to comment.