Skip to content

Commit

Permalink
Merge branch 'refs/heads/main' into eugene/TLK-1771_improve_agent-cre…
Browse files Browse the repository at this point in the history
…ation_flow
  • Loading branch information
EugeneLightsOn committed Oct 16, 2024
2 parents ccd36e6 + 7e2e993 commit 3d23402
Show file tree
Hide file tree
Showing 30 changed files with 603 additions and 190 deletions.
4 changes: 4 additions & 0 deletions .env-template
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,7 @@ GOOGLE_DRIVE_CLIENT_ID=
GOOGLE_DRIVE_CLIENT_SECRET=
NEXT_PUBLIC_GOOGLE_DRIVE_CLIENT_ID=${GOOGLE_DRIVE_CLIENT_ID}
NEXT_PUBLIC_GOOGLE_DRIVE_DEVELOPER_KEY=

# Google Cloud

GOOGLE_CLOUD_API_KEY=<API_KEY_HERE>
44 changes: 37 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ hyperframe = "^6.0.1"
llama-index = "^0.11.10"
llama-index-llms-cohere = "^0.3.0"
llama-index-embeddings-cohere = "^0.2.1"
google-cloud-texttospeech = "^2.18.0"

[tool.poetry.group.dev]
optional = true
Expand Down
4 changes: 3 additions & 1 deletion src/backend/config/secrets.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,6 @@ auth:
oidc:
client_id:
client_secret:
well_known_endpoint:
well_known_endpoint:
google_cloud:
api_key:
8 changes: 8 additions & 0 deletions src/backend/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,13 @@ class RedisSettings(BaseSettings, BaseModel):
)


class GoogleCloudSettings(BaseSettings, BaseModel):
model_config = SETTINGS_CONFIG
api_key: Optional[str] = Field(
default=None, validation_alias=AliasChoices("GOOGLE_CLOUD_API_KEY", "api_key")
)


class SageMakerSettings(BaseSettings, BaseModel):
model_config = SETTINGS_CONFIG
endpoint_name: Optional[str] = Field(
Expand Down Expand Up @@ -331,6 +338,7 @@ class Settings(BaseSettings):
tools: Optional[ToolSettings] = Field(default=ToolSettings())
database: Optional[DatabaseSettings] = Field(default=DatabaseSettings())
redis: Optional[RedisSettings] = Field(default=RedisSettings())
google_cloud: Optional[GoogleCloudSettings] = Field(default=GoogleCloudSettings())
deployments: Optional[DeploymentSettings] = Field(default=DeploymentSettings())
logger: Optional[LoggerSettings] = Field(default=LoggerSettings())

Expand Down
21 changes: 21 additions & 0 deletions src/backend/crud/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,27 @@ def get_messages(
)


@validate_transaction
def get_conversation_message(db: Session, conversation_id: str, message_id: str, user_id: str) -> Message | None:
"""
Get a message based on the conversation ID, message ID, and user ID.
Args:
db (Session): Database session.
conversation_id (str): Conversation ID.
message_id (str): Message ID.
user_id (str): User ID.
Returns:
Message | None: Message with the given conversation ID, message ID, and user ID or None if not found.
"""
return (
db.query(Message)
.filter(Message.conversation_id == conversation_id, Message.id == message_id, Message.user_id == user_id)
.first()
)


@validate_transaction
def get_messages_by_conversation_id(
db: Session, conversation_id: str, user_id: str
Expand Down
45 changes: 45 additions & 0 deletions src/backend/routers/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from fastapi import APIRouter, Depends, Form, HTTPException, Request
from fastapi import File as RequestFile
from fastapi import UploadFile as FastAPIUploadFile
from starlette.responses import Response

from backend.chat.custom.utils import get_deployment
from backend.config.routers import RouterName
from backend.crud import agent as agent_crud
from backend.crud import conversation as conversation_crud
from backend.crud import message as message_crud
from backend.database_models import Conversation as ConversationModel
from backend.database_models.database import DBSessionDep
from backend.schemas.agent import Agent
Expand Down Expand Up @@ -39,6 +41,7 @@
get_file_service,
validate_file,
)
from backend.services.synthesizer import synthesize

router = APIRouter(
prefix="/v1/conversations",
Expand Down Expand Up @@ -543,3 +546,45 @@ async def generate_title(
title=title,
error=error,
)


# SYNTHESIZE
@router.get("/{conversation_id}/synthesize/{message_id}")
async def synthesize_message(
conversation_id: str,
message_id: str,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> Response:
"""
Generate a synthesized audio for a specific message in a conversation.
Args:
conversation_id (str): Conversation ID.
message_id (str): Message ID.
session (DBSessionDep): Database session.
ctx (Context): Context object.
Returns:
Response: Synthesized audio file.
Raises:
HTTPException: If the message with the given ID is not found or synthesis fails.
"""
user_id = ctx.get_user_id()
message = message_crud.get_conversation_message(session, conversation_id, message_id, user_id)

if not message:
raise HTTPException(
status_code=404,
detail=f"Message with ID: {message_id} not found.",
)

try:
synthesized_audio = synthesize(message.text)
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error while message synthesis: {e}"
)

return Response(synthesized_audio, media_type="audio/mp3")
4 changes: 2 additions & 2 deletions src/backend/routers/experimental_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@router.get("/")
def list_experimental_features(ctx: Context = Depends(get_context)):
def list_experimental_features(ctx: Context = Depends(get_context)) -> dict[str, bool]:
"""
List all experimental features and if they are enabled
Expand All @@ -22,8 +22,8 @@ def list_experimental_features(ctx: Context = Depends(get_context)):
Returns:
Dict[str, bool]: Experimental feature and their isEnabled state
"""

experimental_features = {
"USE_AGENTS_VIEW": Settings().feature_flags.use_agents_view,
"USE_TEXT_TO_SPEECH_SYNTHESIS": bool(Settings().google_cloud.api_key),
}
return experimental_features
1 change: 1 addition & 0 deletions src/backend/schemas/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ class StreamToolCallsGeneration(ChatResponse):


class StreamEnd(ChatResponse):
message_id: str | None = Field(default=None)
response_id: str | None = Field(default=None)
event_type: ClassVar[StreamEvent] = StreamEvent.STREAM_END
generation_id: str | None = Field(default=None)
Expand Down
1 change: 1 addition & 0 deletions src/backend/services/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,7 @@ async def generate_chat_stream(
user_id = ctx.get_user_id()

stream_end_data = {
"message_id": response_message.id,
"conversation_id": conversation_id,
"response_id": ctx.get_trace_id(),
"text": "",
Expand Down
79 changes: 79 additions & 0 deletions src/backend/services/synthesizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from google.cloud.texttospeech import (
AudioConfig,
AudioEncoding,
SynthesisInput,
TextToSpeechClient,
VoiceSelectionParams,
)
from googleapiclient.discovery import build

from backend.config import Settings


def synthesize(text: str) -> bytes:
"""
Synthesizes speech from the input text.
Args:
text (str): The input text to be synthesized into speech.
Returns:
bytes: The audio content generated from the input text in MP3 format.
Raises:
ValueError: If the Google Cloud API key from the settings is not valid.
"""
client = TextToSpeechClient(client_options={
"api_key": _validate_google_cloud_api_key()
})

language = detect_language(text)

response = client.synthesize_speech(
input=SynthesisInput(text=text),
voice=VoiceSelectionParams(language_code=language),
audio_config=AudioConfig(audio_encoding=AudioEncoding.MP3)
)

return response.audio_content


def detect_language(text: str) -> str:
"""
Detect the language of the given text.
Args:
text (str): The text for which the language needs to be detected.
Returns:
str: The language code of the detected language (e.g., 'en', 'es').
Raises:
ValueError: If the Google Cloud API key from the settings is not valid.
"""
client = build("translate", "v2", developerKey=_validate_google_cloud_api_key())

response = client.detections().list(q=text).execute()

return response["detections"][0][0]["language"]


def _validate_google_cloud_api_key() -> str:
"""
Validates the Google Cloud API key from the settings.
Returns:
str: The validated API key.
Raises:
ValueError: If the API key is not found in the settings or is empty.
"""
google_cloud = Settings().google_cloud

if not google_cloud:
raise ValueError("google_cloud in secrets.yaml is missing.")

if not google_cloud.api_key:
raise ValueError("google_cloud.api_key in secrets.yaml is missing.")

return google_cloud.api_key
Loading

0 comments on commit 3d23402

Please sign in to comment.