Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TTS to OpenAI_API_Compatible #11071

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
"""
Model class for OpenAI Speech to text model.
Model class for OpenAI text2speech model.
"""

def _invoke(
Expand Down
2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/gitee_ai/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
"""
Model class for OpenAI Speech to text model.
Model class for OpenAI text2speech model.
"""

def _invoke(
Expand Down
2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/openai/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
"""
Model class for OpenAI Speech to text model.
Model class for OpenAI text2speech model.
"""

def _invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ supported_model_types:
- text-embedding
- speech2text
- rerank
- tts
configurate_methods:
- customizable-model
model_credential_schema:
Expand Down Expand Up @@ -67,7 +68,7 @@ model_credential_schema:
- variable: __model_type
value: llm
type: text-input
default: '4096'
default: "4096"
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
Expand All @@ -80,7 +81,7 @@ model_credential_schema:
- variable: __model_type
value: text-embedding
type: text-input
default: '4096'
default: "4096"
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
Expand All @@ -93,7 +94,7 @@ model_credential_schema:
- variable: __model_type
value: rerank
type: text-input
default: '4096'
default: "4096"
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
Expand All @@ -104,7 +105,7 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
default: '4096'
default: "4096"
type: text-input
- variable: function_calling_type
show_on:
Expand Down Expand Up @@ -174,3 +175,19 @@ model_credential_schema:
value: llm
default: '\n\n'
type: text-input
- variable: voices
show_on:
- variable: __model_type
value: tts
label:
en_US: Available Voices (comma-separated)
zh_Hans: 可用声音(用英文逗号分隔)
type: text-input
required: false
default: "alloy"
placeholder:
en_US: "alloy,echo,fable,onyx,nova,shimmer"
zh_Hans: "alloy,echo,fable,onyx,nova,shimmer"
help:
en_US: "List voice names separated by commas. First voice will be used as default."
zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from collections.abc import Iterable
from typing import Optional
from urllib.parse import urljoin

import requests

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat


class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel):
"""
Model class for OpenAI-compatible text2speech model.
"""

def _invoke(
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: Optional[str] = None,
) -> Iterable[bytes]:
"""
Invoke TTS model

:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: model voice/speaker
:param user: unique user id
:return: audio data as bytes iterator
"""
# Set up headers with authentication if provided
headers = {}
if api_key := credentials.get("api_key"):
headers["Authorization"] = f"Bearer {api_key}"

# Construct endpoint URL
endpoint_url = credentials.get("endpoint_url")
if not endpoint_url.endswith("/"):
endpoint_url += "/"
endpoint_url = urljoin(endpoint_url, "audio/speech")

# Get audio format from model properties
audio_format = self._get_model_audio_type(model, credentials)

# Split text into chunks if needed based on word limit
word_limit = self._get_model_word_limit(model, credentials)
sentences = self._split_text_into_sentences(content_text, word_limit)

for sentence in sentences:
# Prepare request payload
payload = {"model": model, "input": sentence, "voice": voice, "response_format": audio_format}

# Make POST request
response = requests.post(endpoint_url, headers=headers, json=payload, stream=True)

if response.status_code != 200:
raise InvokeBadRequestError(response.text)

# Stream the audio data
for chunk in response.iter_content(chunk_size=4096):
if chunk:
yield chunk

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

:param model: model name
:param credentials: model credentials
:return:
"""
try:
# Get default voice for validation
voice = self._get_model_default_voice(model, credentials)

# Test with a simple text
next(
self._invoke(
model=model, tenant_id="validate", credentials=credentials, content_text="Test.", voice=voice
)
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema
"""
# Parse voices from comma-separated string
voice_names = credentials.get("voices", "alloy").strip().split(",")
voices = []

for voice in voice_names:
voice = voice.strip()
if not voice:
continue

# Use en-US for all voices
voices.append(
{
"name": voice,
"mode": voice,
"language": "en-US",
}
)

# If no voices provided or all voices were empty strings, use 'alloy' as default
if not voices:
voices = [{"name": "Alloy", "mode": "alloy", "language": "en-US"}]

return AIModelEntity(
model=model,
label=I18nObject(en_US=model),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TTS,
model_properties={
ModelPropertyKey.AUDIO_TYPE: credentials.get("audio_type", "mp3"),
ModelPropertyKey.WORD_LIMIT: int(credentials.get("word_limit", 4096)),
ModelPropertyKey.DEFAULT_VOICE: voices[0]["mode"],
ModelPropertyKey.VOICES: voices,
},
)

def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
"""
Override base get_tts_model_voices to handle customizable voices
"""
model_schema = self.get_customizable_model_schema(model, credentials)

if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
raise ValueError("this model does not support voice")

voices = model_schema.model_properties[ModelPropertyKey.VOICES]

# Always return all voices regardless of language
return [{"name": d["name"], "value": d["mode"]} for d in voices]
Loading