Merge pull request #1622 from h2oai/openai_stt
OpenAI proxy STT
pseudotensor authored May 16, 2024
2 parents 87c352e + 88d19bd commit 588d8ee
Showing 7 changed files with 243 additions and 23 deletions.
2 changes: 1 addition & 1 deletion docs/README_InferenceServers.md
@@ -276,7 +276,7 @@ where `<key>` should be replaced by your OpenAI key that probably starts with `s
### Text to Speech
-h2oGPT can do text-to-speech and speech-to-text if `--enable_tts=True` and `--enable_stt=True`, respecitively. h2oGPT's OpenAI Proxy server follows OpenAI API for [Text to Speech](https://platform.openai.com/docs/guides/text-to-speech), e.g.:
h2oGPT can do text-to-speech and speech-to-text if `--enable_tts=True` and `--enable_stt=True` as well as `--pre_load_image_audio_models=True`, respectively. h2oGPT's OpenAI Proxy server follows OpenAI API for [Text to Speech](https://platform.openai.com/docs/guides/text-to-speech), e.g.:
```python
from openai import OpenAI
from pathlib import Path
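For the new speech-to-text side this PR adds, a minimal client sketch against the proxy might look as follows (the base URL, port, and model name are illustrative assumptions, not values taken from this diff):

```python
# Hypothetical STT call through the h2oGPT OpenAI proxy; assumes the
# proxy listens at http://localhost:5000/v1 and that speech.wav exists.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5000/v1", api_key="<key>")
with open("speech.wav", "rb") as f:
    transcription = client.audio.transcriptions.create(model="whisper", file=f)
print(transcription.text)
```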
89 changes: 88 additions & 1 deletion openai_server/backend.py
@@ -8,7 +8,6 @@
from collections import deque

import filelock
-from pydub import AudioSegment

from log import logger
from openai_server.backend_utils import convert_messages_to_structure
@@ -456,6 +455,93 @@ def get_model_list():
    return dict(model_names=base_models)


def split_audio_on_silence(audio_bytes):
    from pydub import AudioSegment
    from pydub.silence import split_on_silence

    audio = AudioSegment.from_file(io.BytesIO(audio_bytes), format="wav")
    chunks = split_on_silence(audio, min_silence_len=500, silence_thresh=-40, keep_silence=200)

    chunk_bytes = []
    for chunk in chunks:
        chunk_buffer = io.BytesIO()
        chunk.export(chunk_buffer, format="wav")
        chunk_bytes.append(chunk_buffer.getvalue())

    return chunk_bytes


def split_audio_fixed_intervals(audio_bytes, interval_ms=10000):
    from pydub import AudioSegment

    audio = AudioSegment.from_file(io.BytesIO(audio_bytes), format="wav")
    chunks = [audio[i:i + interval_ms] for i in range(0, len(audio), interval_ms)]

    chunk_bytes = []
    for chunk in chunks:
        chunk_buffer = io.BytesIO()
        chunk.export(chunk_buffer, format="wav")
        chunk_bytes.append(chunk_buffer.getvalue())

    return chunk_bytes


def audio_to_text(model, audio_file, stream, response_format, chunk, **kwargs):
    if chunk != 'none':
        # break-up audio file
        if chunk == 'silence':
            audio_files = split_audio_on_silence(audio_file)
        else:
            audio_files = split_audio_fixed_intervals(audio_file, interval_ms=chunk)

        for audio_file1 in audio_files:
            for text in _audio_to_text(model, audio_file1, stream, response_format, chunk, **kwargs):
                yield text
    else:
        for text in _audio_to_text(model, audio_file, stream, response_format, chunk, **kwargs):
            yield text


def _audio_to_text(model, audio_file, stream, response_format, chunk, **kwargs):
    # assumes enable_stt=True set for h2oGPT
    if os.getenv('GRADIO_H2OGPT_H2OGPT_KEY') and not kwargs.get('h2ogpt_key'):
        kwargs.update(dict(h2ogpt_key=os.getenv('GRADIO_H2OGPT_H2OGPT_KEY')))

    client = get_gradio_client(kwargs.get('user'))
    h2ogpt_key = kwargs.get('h2ogpt_key', '')

    # string of dict for input
    if not isinstance(audio_file, str):
        audio_file = base64.b64encode(audio_file).decode('utf-8')

    inputs = dict(audio_file=audio_file, stream_output=stream, h2ogpt_key=h2ogpt_key)
    if stream:
        job = client.submit(*tuple(list(inputs.values())), api_name='/transcribe_audio_api')

        # ensure no immediate failure (only required for testing)
        import concurrent.futures
        try:
            e = job.exception(timeout=0.2)
            if e is not None:
                raise RuntimeError(e)
        except concurrent.futures.TimeoutError:
            pass

        n = 0
        for text in job:
            yield dict(text=text.strip())
            n += 1

        # get rest after job done
        outputs = job.outputs().copy()
        for text in outputs[n:]:
            yield dict(text=text.strip())
            n += 1
    else:
        text = client.predict(*tuple(list(inputs.values())), api_name='/transcribe_audio_api')
        yield dict(text=text.strip())


def text_to_audio(model, voice, input, stream, format, **kwargs):
    # tts_model = 'microsoft/speecht5_tts'
    # tts_model = 'tts_models/multilingual/multi-dataset/xtts_v2'
@@ -521,6 +607,7 @@ def audio_str_to_bytes(audio_str1, format='wav'):
    sample_width = 2  # Assuming 16-bit samples (2 bytes), adjust if necessary

    # Use from_raw to correctly interpret the raw audio data
    from pydub import AudioSegment
    audio_segment = AudioSegment.from_raw(
        s,
        sample_width=sample_width,
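A quick standalone sanity check of the two new chunk helpers (a sketch, assuming `pydub` and `ffmpeg` are installed; exact chunk counts depend on the silence thresholds hard-coded above):

```python
# Synthesize tone-silence-tone audio and feed it to the new helpers.
import io

from pydub import AudioSegment
from pydub.generators import Sine

from openai_server.backend import split_audio_fixed_intervals, split_audio_on_silence

tone = Sine(440).to_audio_segment(duration=1000)  # 1 s of 440 Hz tone
audio = tone + AudioSegment.silent(duration=500) + tone  # 2.5 s total

buf = io.BytesIO()
audio.export(buf, format="wav")
wav_bytes = buf.getvalue()

# The 500 ms gap sits exactly at min_silence_len=500, so expect about 2 chunks.
print(len(split_audio_on_silence(wav_bytes)))
# 2500 ms at 1000 ms intervals -> 3 chunks, the last one 500 ms long.
print(len(split_audio_fixed_intervals(wav_bytes, interval_ms=1000)))
```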
25 changes: 20 additions & 5 deletions openai_server/server.py
@@ -342,16 +342,31 @@ async def handle_list_models():
    return JSONResponse(content=[dict(id=x) for x in get_model_list()])


# Define your request data model
class AudiotoTextRequest(BaseModel):
    model: str = ''
    file: str
    response_format: str = 'text'  # FIXME unused
    stream: bool = True  # NOTE: No effect on OpenAI API client, would have to use direct API
    timestamp_granularities: list = ["word"]  # FIXME unused
    chunk: Union[str, int] = 'silence'  # or 'interval' No effect on OpenAI API client, would have to use direct API


@app.post('/v1/audio/transcriptions', dependencies=check_key)
-async def handle_audio_transcription(request: Request, request_data: TextRequest):
async def handle_audio_transcription(request: Request):
    form = await request.form()
    audio_file = await form["file"].read()
    model = form["model"]
    stream = form.get("stream", False)
    response_format = form.get("response_format", 'text')
    chunk = form.get("chunk", 'interval')
    request_data = dict(model=model, stream=stream, audio_file=audio_file, response_format=response_format, chunk=chunk)

-    if request_data.stream:
    if stream:
        from openai_server.backend import audio_to_text

        async def generator():
-            response = audio_to_text(audio_file, **dict(request_data))
            response = audio_to_text(**request_data)
            for resp in response:
                disconnected = await request.is_disconnected()
                if disconnected:
@@ -361,9 +376,9 @@ async def generator():

        return EventSourceResponse(generator())
    else:
-        from openai_server.backend import audio_to_text
        from openai_server.backend import _audio_to_text
        response = ''
-        for response1 in audio_to_text(audio_file, **dict(request_data)):
        for response1 in _audio_to_text(**request_data):
            response = response1
        return JSONResponse(response)

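The OpenAI client itself ignores the extra `stream` and `chunk` fields (as the comments in `AudiotoTextRequest` note), so exercising them takes a direct request. A sketch, assuming the server listens on localhost port 5000; `stream` is omitted here because form values arrive as strings and any non-empty string would be truthy:

```python
# Hypothetical direct multipart call to the new transcription endpoint.
import requests

with open("speech.wav", "rb") as f:
    r = requests.post(
        "http://localhost:5000/v1/audio/transcriptions",
        headers={"Authorization": "Bearer <key>"},
        files={"file": f},
        # Omit "stream" so the non-streaming branch runs.
        data={"model": "whisper", "chunk": "silence"},
    )
print(r.json())
```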
25 changes: 25 additions & 0 deletions src/gradio_runner.py
@@ -6704,6 +6704,31 @@ def stop_audio_func():
load_event6 = load_event5.then(**get_viewable_sources_args_login)
load_event7 = load_event6.then(**viewable_kwargs)

def wrap_transcribe_func_api(audio_obj1, stream_output1, h2ogpt_key1, requests_state1):
    # check key
    valid_key = is_valid_key(kwargs['enforce_h2ogpt_api_key'],
                             kwargs['enforce_h2ogpt_ui_key'],
                             kwargs['h2ogpt_api_keys'],
                             h2ogpt_key1,
                             requests_state1=requests_state1)
    kwargs['from_ui'] = is_from_ui(requests_state1)
    if not valid_key:
        raise ValueError(invalid_key_msg)

    audio_api_state0 = ['', '', None, 'on']
    state_text = kwargs['transcriber_func'](audio_api_state0, audio_obj1)
    text = state_text[1]
    yield text

audio_api_output = gr.Textbox(value='', visible=False)
audio_api_input = gr.Textbox(value='', visible=False)
audio_api_btn = gr.Button(visible=False)
audio_api_btn.click(fn=wrap_transcribe_func_api,
                    inputs=[audio_api_input, stream_output, h2ogpt_key, requests_state],
                    outputs=[audio_api_output],
                    api_name='transcribe_audio_api',
                    show_progress='hidden')

demo.queue(**queue_kwargs, api_open=kwargs['api_open'])
favicon_file = "h2o-logo.svg"
favicon_path = kwargs['favicon_path'] or favicon_file
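The hidden button exposes `/transcribe_audio_api` to `gradio_client`, which is the route `backend.py` takes above. A standalone sketch of that call (server address and key are placeholders):

```python
# Mirrors the client.submit(...) call in openai_server/backend.py.
import base64

from gradio_client import Client

client = Client("http://localhost:7860")
with open("speech.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

# Positional inputs follow backend.py: audio_file, stream_output, h2ogpt_key.
job = client.submit(audio_b64, True, "<h2ogpt_key>", api_name="/transcribe_audio_api")
for text in job:
    print(text)
```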
45 changes: 44 additions & 1 deletion src/stt.py
@@ -1,6 +1,10 @@
import base64
import io
import traceback

import numpy as np
from pydub import AudioSegment

from src.utils import get_device


@@ -18,8 +22,31 @@ def get_transcriber(model="openai/whisper-base.en", use_gpu=True, gpu_id='auto')
    return transcriber


def audio_bytes_to_numpy(audio_bytes):
    # Load the audio bytes into a BytesIO object
    audio_stream = io.BytesIO(audio_bytes)

    # Use pydub to read the audio data from the BytesIO object
    audio = AudioSegment.from_file(audio_stream)

    # Convert pydub AudioSegment to a numpy array
    samples = np.array(audio.get_array_of_samples())

    # Get the sampling rate
    sr = audio.frame_rate

    # If the audio is stereo, we need to reshape the numpy array to [n_samples, n_channels]
    if audio.channels > 1:
        samples = samples.reshape((-1, audio.channels))

    return sr, samples


def transcribe(audio_state1, new_chunk, transcriber=None, max_chunks=None, sst_floor=100.0, reject_no_new_text=True,
               debug=False):
    if debug:
        print("start transcribe", flush=True)

    if audio_state1[0] is None:
        audio_state1[0] = ''
    if audio_state1[2] is None:
@@ -33,7 +60,15 @@ def transcribe(audio_state1, new_chunk, transcriber=None, max_chunks=None, sst_f
        return audio_state1, audio_state1[1]
    # assume sampling rate always same
    # keep chunks so don't normalize on noise periods, which would then saturate noise with non-noise
-    sr, y = new_chunk
    if isinstance(new_chunk, str):
        audio_bytes = base64.b64decode(new_chunk.encode('utf-8'))
        sr, y = audio_bytes_to_numpy(audio_bytes)
    else:
        sr, y = new_chunk

    if debug:
        print("post encode", flush=True)

    if y.shape[0] == 0:
        avg = 0.0
    else:
@@ -56,7 +91,11 @@ def transcribe(audio_state1, new_chunk, transcriber=None, max_chunks=None, sst_f
    stream = stream.astype(np.float32)
    max_stream = np.max(np.abs(stream) + 1E-7)
    stream /= max_stream
    if debug:
        print("pre transcriber", flush=True)
    text = transcriber({"sampling_rate": sr, "raw": stream})["text"]
    if debug:
        print("post transcriber", flush=True)

    if audio_state1[2]:
        try:
@@ -67,7 +106,11 @@ def transcribe(audio_state1, new_chunk, transcriber=None, max_chunks=None, sst_f
        stream0 = stream0.astype(np.float32)
        max_stream0 = np.max(np.abs(stream0) + 1E-7)
        stream0 /= max_stream0
        if debug:
            print("pre stranscriber", flush=True)
        text_y = transcriber({"sampling_rate": sr, "raw": stream0})["text"]
        if debug:
            print("post stranscriber", flush=True)
    else:
        text_y = None

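`audio_bytes_to_numpy` is what lets `transcribe` accept base64-encoded files in addition to gradio's native `(sr, np.ndarray)` tuples. A minimal check of the helper (assumes `pydub`/`ffmpeg` and an existing wav file):

```python
# Round-trip a wav file through the new helper.
from src.stt import audio_bytes_to_numpy

with open("speech.wav", "rb") as f:
    sr, samples = audio_bytes_to_numpy(f.read())
# Mono input -> (n_samples,); stereo -> (n_samples, 2).
print(sr, samples.shape)
```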
2 changes: 1 addition & 1 deletion src/version.py
@@ -1 +1 @@
-__version__ = "560f93d7983c0f3f195f9ffe146bfa17c1d97724"
__version__ = "560f93d7983c0f3f195f9ffe146bfa17c1d97724"
78 changes: 64 additions & 14 deletions tests/test_client_calls.py

Large diffs are not rendered by default.
