feat: audio chat (#231)
Fedir Zadniprovskyi committed Jan 26, 2025
1 parent abf6e91 commit c308555
Showing 3 changed files with 282 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/speaches/config.py
@@ -245,6 +245,7 @@ class Config(BaseSettings):

chat_completion_base_url: str = "https://api.openai.com/v1"
chat_completion_api_key: str | None = None
chat_completion_model: str | None = None

speech_base_url: str | None = None
speech_api_key: str | None = None
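    # A usage sketch (an assumption, not part of this diff): with pydantic-settings'
    # default env mapping, the new override would typically be supplied as the
    # CHAT_COMPLETION_MODEL environment variable and read via Config().chat_completion_model.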
4 changes: 4 additions & 0 deletions src/speaches/main.py
@@ -10,6 +10,9 @@

from speaches.dependencies import ApiKeyDependency, get_config
from speaches.logger import setup_logger
from speaches.routers.chat import (
router as chat_router,
)
from speaches.routers.misc import (
router as misc_router,
)
@@ -56,6 +59,7 @@ def create_app() -> FastAPI:

app = FastAPI(dependencies=dependencies, openapi_tags=TAGS_METADATA)

app.include_router(chat_router)
app.include_router(stt_router)
app.include_router(models_router)
app.include_router(misc_router)
277 changes: 277 additions & 0 deletions src/speaches/routers/chat.py
@@ -0,0 +1,277 @@
import base64
from collections.abc import AsyncGenerator
from datetime import UTC, datetime, timedelta
from io import BytesIO
import logging
from typing import Annotated, Self
from uuid import uuid4

import aiostream
from cachetools import TTLCache
from fastapi import APIRouter, Body, Response
from fastapi.responses import StreamingResponse
from openai import AsyncStream
from openai.resources.audio import AsyncSpeech
from openai.types.chat import (
ChatCompletion,
ChatCompletionAudio,
ChatCompletionChunk,
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from pydantic import Field, model_validator

from speaches.dependencies import (
CompletionClientDependency,
ConfigDependency,
SpeechClientDependency,
TranscriptionClientDependency,
)
from speaches.routers.stt import format_as_sse
from speaches.text_utils import SentenceChunker
from speaches.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartTextParam,
)
from speaches.types.chat import (
CompletionCreateParamsBase as OpenAICompletionCreateParamsBase,
)

# Resources:
# - https://platform.openai.com/docs/guides/audio

# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-model
DEFAULT_SPEECH_MODEL = "tts-1" # or "tts-1-hd"
# https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-model
DEFAULT_TRANSCRIPTION_MODEL = "whisper-1"
AUDIO_TRANSCRIPTION_CACHE_SIZE = 4096
AUDIO_TRANSCRIPTION_TTL_SECONDS = 60 * 60

logger = logging.getLogger(__name__)
router = APIRouter()
cache: TTLCache[str, str] = TTLCache(maxsize=AUDIO_TRANSCRIPTION_CACHE_SIZE, ttl=AUDIO_TRANSCRIPTION_TTL_SECONDS)
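# Illustrative cache semantics (hypothetical values): a transcript stored under an audio id
# stays readable for AUDIO_TRANSCRIPTION_TTL_SECONDS after insertion, then lookups fail:
#   cache["audio_123"] = "Hello there."
#   cache["audio_123"]  # -> "Hello there." within the hour; KeyError once the TTL lapses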


# NOTE: OpenAI doesn't use UUIDs


def generate_audio_id() -> str:
return "audio_" + str(uuid4())


def generate_chat_completion_id() -> str:
return "chatcmpl-" + str(uuid4())


class CompletionCreateParamsBase(OpenAICompletionCreateParamsBase):
stream: bool = False
    transcription_model: str = DEFAULT_TRANSCRIPTION_MODEL
transcription_extra_body: dict | None = None
speech_model: str = DEFAULT_SPEECH_MODEL
speech_extra_body: dict | None = Field(default_factory=lambda: {"sample_rate": 24000})

@model_validator(mode="after")
def validate_audio_format_when_stream(self) -> Self:
        # NOTE: OpenAI only supports the pcm format for streaming. We could support any format, but this is kept hardcoded for consistency  # noqa: E501
if self.stream and self.audio is not None and self.audio.format != "pcm16":
raise ValueError(
f"Unsupported value: 'audio.format' does not support '{self.audio.format}' when stream=true. Supported values are: 'pcm16'." # noqa: E501
)
return self
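# For illustration (hypothetical payloads): with stream=True, audio={"voice": "alloy",
# "format": "pcm16"} passes this validator, while audio={"voice": "alloy", "format": "mp3"}
# raises the ValueError above.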


def transform_choice_delta(choice_delta: ChoiceDelta) -> ChoiceDelta:
if choice_delta.content is None:
return choice_delta

content = choice_delta.content
choice_delta.content = None
choice_delta.audio = { # pyright: ignore[reportAttributeAccessIssue]
"transcript": content,
}
return choice_delta
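# e.g. ChoiceDelta(content="Hi there.") comes back as
# ChoiceDelta(content=None, audio={"transcript": "Hi there."}).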


# FIXME: do not pass in `body`
async def transform_choice(speech_client: AsyncSpeech, choice: Choice, body: CompletionCreateParamsBase) -> Choice:
assert body.audio is not None
    # HACK: OpenAI alternates between `pcm16` (/v1/chat/completions) and `pcm` (/v1/audio/speech)
audio_format = "pcm" if body.audio.format == "pcm16" else body.audio.format

if choice.message.content is None:
return choice
res = await speech_client.create(
input=choice.message.content,
model=body.speech_model,
voice=body.audio.voice, # pyright: ignore[reportArgumentType]
response_format=audio_format,
)
audio_bytes = res.read()
audio_id = generate_audio_id()
cache[audio_id] = choice.message.content
choice.message.audio = ChatCompletionAudio(
id=audio_id,
data=base64.b64encode(audio_bytes).decode("utf-8"),
transcript=choice.message.content,
expires_at=int((datetime.now(UTC) + timedelta(seconds=AUDIO_TRANSCRIPTION_TTL_SECONDS)).timestamp()),
)
choice.message.content = None
return choice
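# After this transformation, a non-streaming choice looks roughly like (shape per the
# OpenAI audio-output docs; values are illustrative):
#   {"message": {"role": "assistant", "content": null,
#                "audio": {"id": "audio_...", "data": "<base64>", "transcript": "Hi there.",
#                          "expires_at": 1737900000}}}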


# TODO: document minor deviations from OAI
# TODO: rework modalities handling
class AudioChatStream:
def __init__(
self,
chat_completion_chunk_stream: AsyncStream[ChatCompletionChunk],
speech_client: AsyncSpeech,
sentence_chunker: SentenceChunker,
body: CompletionCreateParamsBase, # FIXME: do not pass in `body`
) -> None:
self.chat_completion_chunk_stream = chat_completion_chunk_stream
self.speech_client = speech_client
        self.sentence_chunker = sentence_chunker  # NOTE: this should be per choice if I want to support n > 1
self.body = body
self.audio_id = generate_audio_id()
self.expires_at = int((datetime.now(UTC) + timedelta(seconds=AUDIO_TRANSCRIPTION_TTL_SECONDS)).timestamp())
self.chat_completion_id: str
self.created: int

async def text_chat_completion_chunk_stream(self) -> AsyncGenerator[ChatCompletionChunk]:
async for chunk in self.chat_completion_chunk_stream:
assert len(chunk.choices) == 1
self.chat_completion_id = chunk.id
self.created = chunk.created
choice = chunk.choices[0]
assert self.body.modalities is not None
if "audio" not in self.body.modalities: # do not transform the choice if audio is not in the modalities
yield chunk
continue
if choice.delta.content is not None:
self.sentence_chunker.add_token(choice.delta.content)
choice.delta = transform_choice_delta(choice.delta)
choice.delta.audio["id"] = self.audio_id # pyright: ignore[reportAttributeAccessIssue]
choice.delta.audio["expires_at"] = self.expires_at # pyright: ignore[reportAttributeAccessIssue]
# TODO: consider not sending the chunk if there's a finish_reason
# if choice.finish_reason is None:
yield chunk
self.sentence_chunker.close()

async def audio_chat_completion_chunk_stream(self) -> AsyncGenerator[ChatCompletionChunk]:
assert self.body.audio is not None

# TODO: parallelize
async for sentence in self.sentence_chunker:
res = await self.speech_client.create(
input=sentence,
model=self.body.speech_model,
voice=self.body.audio.voice, # pyright: ignore[reportArgumentType]
response_format="pcm",
extra_body={"sample_rate": 24000},
)
audio_bytes = res.read()
audio_data = base64.b64encode(audio_bytes).decode("utf-8")
delta = ChoiceDelta()
delta.audio = { # pyright: ignore[reportAttributeAccessIssue]
"id": self.audio_id,
"data": audio_data,
"expires_at": self.expires_at,
}
yield ChatCompletionChunk(
id=self.chat_completion_id,
choices=[ChunkChoice(delta=delta, index=0)],
created=self.created,
model=self.body.speech_model,
object="chat.completion.chunk",
)
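    # On the wire, each of these chunks carries base64-encoded PCM in delta.audio["data"]
    # (24 kHz per the sample_rate above; 16-bit mono is an assumption about the pcm format), e.g.
    #   {"delta": {"audio": {"id": "audio_...", "data": "<base64 pcm>", "expires_at": ...}}}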

async def __aiter__(self) -> AsyncGenerator[ChatCompletionChunk]:
assert self.body.modalities is not None

if "audio" in self.body.modalities:
merged_stream = aiostream.stream.merge(
self.text_chat_completion_chunk_stream(), self.audio_chat_completion_chunk_stream()
)
            # NOTE: consuming the merged stream without the `.stream()` context manager raised a warning
async with merged_stream.stream() as stream:
async for chunk in stream:
yield chunk
else:
async for chunk in self.text_chat_completion_chunk_stream():
yield chunk
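# A minimal sketch of the merge pattern used above: aiostream.stream.merge interleaves
# items from both generators as each becomes ready (relative order across streams is
# not guaranteed), e.g.
#   merged = aiostream.stream.merge(gen_a(), gen_b())
#   async with merged.stream() as s:
#       async for item in s:
#           ...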


# TODO: maybe propagate 400 errors


# https://platform.openai.com/docs/api-reference/chat/create
@router.post("/v1/chat/completions", response_model=ChatCompletion | ChatCompletionChunk)
async def handle_completions( # noqa: C901
config: ConfigDependency,
chat_completion_client: CompletionClientDependency,
transcript_client: TranscriptionClientDependency,
speech_client: SpeechClientDependency,
body: Annotated[CompletionCreateParamsBase, Body()],
) -> Response | StreamingResponse:
assert body.n is None or body.n == 1, "Multiple choices (`n` > 1) are not supported"

for i, message in enumerate(body.messages):
if message.role == "user":
content = message.content
# per https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in#quickstart, input audio should be within the `message.content` list # noqa: E501
if not isinstance(content, list):
continue

# TODO: parallelize
for j in range(len(content)):
content_part = content[j]
if content_part.type == "input_audio":
audio_bytes = base64.b64decode(content_part.input_audio.data)
                    # TODO: how does the endpoint determine the audio format?
transcript = await transcript_client.create(
file=BytesIO(audio_bytes),
                        model=body.transcription_model,
response_format="text",
)
content[j] = ChatCompletionContentPartTextParam(text=transcript, type="text")
logger.info(f"Transcript for message {i} content part {j}: {transcript}")

elif message.role == "assistant" and message.audio is not None:
transcript = cache[message.audio.id]
body.messages[i] = ChatCompletionAssistantMessageParam(
role="assistant",
content=transcript,
# NOTE: I believe the fields below will always be `None`
name=message.name,
refusal=message.refusal,
tool_calls=message.tool_calls,
function_call=message.function_call,
)
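            # Illustrative follow-up turn (shape per the OpenAI audio guide; id is hypothetical):
            # the client resends only {"role": "assistant", "audio": {"id": "audio_..."}} and the
            # branch above swaps it for the cached transcript.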

# NOTE: rather than doing a `model_copy` it might be better to override the fields when doing the `model_dump` and destructuring # noqa: E501
proxied_body = body.model_copy(deep=True)
proxied_body.modalities = ["text"]
proxied_body.audio = None
proxied_body.model = (
config.chat_completion_model if config.chat_completion_model is not None else body.model
) # HACK: this is different from the one chat completion api sends
# NOTE: Adding --use-one-literal-as-default breaks the `exclude_defaults=True` behavior
chat_completion = await chat_completion_client.create(**proxied_body.model_dump(exclude_defaults=True))
if isinstance(chat_completion, AsyncStream):

async def inner() -> AsyncGenerator[str]:
audio_chat_stream = AudioChatStream(chat_completion, speech_client, SentenceChunker(), body)
async for chunk in audio_chat_stream:
yield format_as_sse(chunk.model_dump_json())

return StreamingResponse(inner(), media_type="text/event-stream")
elif isinstance(chat_completion, ChatCompletion):
for i in range(len(chat_completion.choices)):
if body.modalities is None or "audio" not in body.modalities:
continue
chat_completion.choices[i] = await transform_choice(speech_client, chat_completion.choices[i], body)
return Response(content=chat_completion.model_dump_json(), media_type="application/json")

raise ValueError(f"Unexpected chat completion type: {type(chat_completion)}")
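
A hedged end-to-end sketch of calling the new endpoint with the OpenAI Python SDK. The base URL, API key, voice, and model name below are assumptions for illustration, not part of this commit:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
completion = client.chat.completions.create(
    model="gpt-4o-mini",  # hypothetical; whatever the upstream chat backend serves
    modalities=["text", "audio"],
    audio={"voice": "alloy", "format": "pcm16"},  # "alloy" is an assumed voice name
    messages=[{"role": "user", "content": "Say hello."}],
)
print(completion.choices[0].message.audio.transcript)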
