Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add the audio tool #10695

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/core/tools/provider/builtin/audio/_assets/icon.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions api/core/tools/provider/builtin/audio/audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class AudioToolProvider(BuiltinToolProviderController):
    """Provider controller for the built-in ``audio`` tool group."""

    def _validate_credentials(self, credentials: dict) -> None:
        """Accept any credentials without inspecting them.

        Args:
            credentials: Credential mapping handed to the provider; it is not
                read by this implementation.
        """
        # Nothing is validated here, so every credential set is accepted.
        return None
11 changes: 11 additions & 0 deletions api/core/tools/provider/builtin/audio/audio.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
identity:
author: hjlarry
name: audio
label:
en_US: Audio
description:
en_US: A tool for text-to-speech (TTS) and speech-to-text (ASR).
zh_Hans: 一个用于文本转语音和语音转文本的工具。
icon: icon.svg
tags:
- utilities
70 changes: 70 additions & 0 deletions api/core/tools/provider/builtin/audio/tools/asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import io
from typing import Any

from core.file.enums import FileType
from core.file.file_manager import download
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.tool.builtin_tool import BuiltinTool
from services.model_provider_service import ModelProviderService


class ASRTool(BuiltinTool):
    """Built-in tool that transcribes an audio file with a speech-to-text model.

    The model is selected through the runtime ``model`` parameter, whose value
    is encoded as ``"<provider>#<model>"``.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
        """Transcribe the ``audio_file`` parameter and return the resulting text.

        Args:
            user_id: Identifier forwarded to the model invocation as ``user``.
            tool_parameters: Expected to contain ``audio_file`` (a file object
                exposing ``type``) and ``model`` (``"<provider>#<model>"``).

        Returns:
            A single text message: the transcription, or ``"not a valid audio
            file"`` when the file is missing or is not of type ``AUDIO``.

        Raises:
            ValueError: If ``model`` is missing or is not of the form
                ``"<provider>#<model>"``.
        """
        file = tool_parameters.get("audio_file")
        # A missing file is reported the same way as a non-audio file instead
        # of failing with AttributeError on ``None.type``.
        if file is None or file.type != FileType.AUDIO:
            return [self.create_text_message("not a valid audio file")]

        model_key = tool_parameters.get("model")
        if not model_key or "#" not in model_key:
            raise ValueError("a speech-to-text model must be selected ('<provider>#<model>')")
        # Split on the first "#" only so a "#" inside the model name is kept.
        provider, model = model_key.split("#", 1)

        audio_binary = io.BytesIO(download(file))
        # The in-memory stream is given a file name before being handed to the
        # model. NOTE(review): the ".mp3" suffix is fixed regardless of the
        # uploaded file's real format - confirm this is acceptable downstream.
        audio_binary.name = "temp.mp3"

        model_instance = ModelManager().get_model_instance(
            tenant_id=self.runtime.tenant_id,
            provider=provider,
            model_type=ModelType.SPEECH2TEXT,
            model=model,
        )
        text = model_instance.invoke_speech2text(
            file=audio_binary,
            user=user_id,
        )
        return [self.create_text_message(text)]

    def get_available_models(self) -> list[tuple[str, str]]:
        """Return ``(provider, model)`` pairs for the tenant's speech2text models."""
        models = ModelProviderService().get_models_by_model_type(
            tenant_id=self.runtime.tenant_id, model_type="speech2text"
        )
        return [
            (provider_model.provider, model.model)
            for provider_model in models
            for model in provider_model.models
        ]

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """Build the runtime ``model`` select parameter from the available models.

        Returns:
            A one-element list holding the ``model`` parameter. Its default is
            the first available model, or ``None`` when no model is available.
        """
        options = [
            ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
            for provider, model in self.get_available_models()
        ]

        return [
            ToolParameter(
                name="model",
                label=I18nObject(en_US="Model", zh_Hans="Model"),
                human_description=I18nObject(
                    en_US="All available ASR models",
                    zh_Hans="所有可用的 ASR 模型",
                ),
                type=ToolParameter.ToolParameterType.SELECT,
                form=ToolParameter.ToolParameterForm.FORM,
                required=True,
                # Indexing an empty list would raise IndexError when the tenant
                # has no speech2text model configured.
                default=options[0].value if options else None,
                options=options,
            )
        ]
22 changes: 22 additions & 0 deletions api/core/tools/provider/builtin/audio/tools/asr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
identity:
name: asr
author: hjlarry
label:
en_US: Speech To Text
description:
human:
en_US: Convert audio file to text.
zh_Hans: 将音频文件转换为文本。
llm: Convert audio file to text.
parameters:
- name: audio_file
type: file
required: true
label:
en_US: Audio File
zh_Hans: 音频文件
human_description:
en_US: The audio file to be converted.
zh_Hans: 要转换的音频文件。
llm_description: The audio file to be converted.
form: llm
90 changes: 90 additions & 0 deletions api/core/tools/provider/builtin/audio/tools/tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import io
from typing import Any

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.tool.builtin_tool import BuiltinTool
from services.model_provider_service import ModelProviderService


class TTSTool(BuiltinTool):
    """Built-in tool that converts text to audio with a text-to-speech model.

    The model is selected through the runtime ``model`` parameter, encoded as
    ``"<provider>#<model>"``. The voice for that model is read from the
    parameter named ``"voice#<provider>#<model>"``.
    """

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
        """Synthesize the ``text`` parameter and return it as an audio blob.

        Args:
            user_id: Identifier forwarded to the model invocation as ``user``.
            tool_parameters: Expected to contain ``text``, ``model``
                (``"<provider>#<model>"``) and optionally the matching
                ``"voice#<provider>#<model>"`` entry.

        Returns:
            A text message confirming success, followed by a blob message
            carrying the generated audio bytes.

        Raises:
            ValueError: If ``model`` is missing or is not of the form
                ``"<provider>#<model>"``.
        """
        model_key = tool_parameters.get("model")
        if not model_key or "#" not in model_key:
            raise ValueError("a text-to-speech model must be selected ('<provider>#<model>')")
        # Split on the first "#" only so a "#" inside the model name is kept;
        # this mirrors how the voice parameter name is built below.
        provider, model = model_key.split("#", 1)
        voice = tool_parameters.get(f"voice#{provider}#{model}")

        model_instance = ModelManager().get_model_instance(
            tenant_id=self.runtime.tenant_id,
            provider=provider,
            model_type=ModelType.TTS,
            model=model,
        )
        tts = model_instance.invoke_tts(
            content_text=tool_parameters.get("text"),
            user=user_id,
            tenant_id=self.runtime.tenant_id,
            voice=voice,
        )
        # The model yields the audio in chunks; concatenate them into one blob.
        audio_bytes = b"".join(tts)

        return [
            self.create_text_message("Audio generated successfully"),
            self.create_blob_message(
                blob=audio_bytes,
                # NOTE(review): the MIME type is fixed to WAV; confirm every
                # TTS provider actually returns WAV data.
                meta={"mime_type": "audio/x-wav"},
                save_as=self.VariableKey.AUDIO,
            ),
        ]

    def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
        """Return ``(provider, model, voices)`` triples for the tenant's TTS models.

        ``voices`` is the model's ``ModelPropertyKey.VOICES`` property, or an
        empty list when the model does not declare one.
        """
        models = ModelProviderService().get_models_by_model_type(
            tenant_id=self.runtime.tenant_id, model_type="tts"
        )
        return [
            (provider_model.provider, model.model, model.model_properties.get(ModelPropertyKey.VOICES, []))
            for provider_model in models
            for model in provider_model.models
        ]

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """Build the runtime parameters: one model selector plus a voice selector per model.

        Returns:
            The ``model`` select parameter first, followed by one
            ``"voice#<provider>#<model>"`` select parameter for each available
            model. The ``model`` default is the first available model, or
            ``None`` when no model is available.
        """
        model_options = []
        voice_parameters = []
        for provider, model, voices in self.get_available_models():
            model_options.append(
                ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
            )
            voice_parameters.append(
                ToolParameter(
                    name=f"voice#{provider}#{model}",
                    label=I18nObject(en_US=f"Voice of {model}({provider})"),
                    type=ToolParameter.ToolParameterType.SELECT,
                    form=ToolParameter.ToolParameterForm.FORM,
                    options=[
                        ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
                        for voice in voices
                    ],
                )
            )

        model_parameter = ToolParameter(
            name="model",
            label=I18nObject(en_US="Model", zh_Hans="Model"),
            human_description=I18nObject(
                en_US="All available TTS models",
                zh_Hans="所有可用的 TTS 模型",
            ),
            type=ToolParameter.ToolParameterType.SELECT,
            form=ToolParameter.ToolParameterForm.FORM,
            required=True,
            # Indexing an empty list would raise IndexError when the tenant has
            # no TTS model configured.
            default=model_options[0].value if model_options else None,
            options=model_options,
        )
        # The model selector comes first, ahead of the per-model voice selectors.
        return [model_parameter, *voice_parameters]
22 changes: 22 additions & 0 deletions api/core/tools/provider/builtin/audio/tools/tts.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
identity:
name: tts
author: hjlarry
label:
en_US: Text To Speech
description:
human:
en_US: Convert text to audio file.
zh_Hans: 将文本转换为音频文件。
llm: Convert text to audio file.
parameters:
- name: text
type: string
required: true
label:
en_US: Text
zh_Hans: 文本
human_description:
en_US: The text to be converted.
zh_Hans: 要转换的文本。
llm_description: The text to be converted.
form: llm