diff --git a/pyproject.toml b/pyproject.toml
index 5dbd64eb..274e9e81 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
     "funasr==1.1.5",
     "opencc-python-reimplemented==0.1.7",
     "silero-vad",
+    "ormsgpack",
 ]
 
 [project.optional-dependencies]
diff --git a/tools/api.py b/tools/api.py
index bf27f25b..945b9739 100644
--- a/tools/api.py
+++ b/tools/api.py
@@ -39,7 +39,7 @@
 from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
 from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
 from fish_speech.utils import autocast_exclude_mps
-from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
+from tools.commons import ServeReferenceAudio, ServeTTSRequest
 from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
 from tools.llama.generate import (
     GenerateRequest,
@@ -156,38 +156,6 @@ def decode_vq_tokens(
 routes = MultimethodRoutes(base_class=HttpView)
 
 
-class ServeReferenceAudio(BaseModel):
-    audio: bytes
-    text: str
-
-
-class ServeTTSRequest(BaseModel):
-    text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
-    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
-    # Audio format
-    format: Literal["wav", "pcm", "mp3"] = "wav"
-    mp3_bitrate: Literal[64, 128, 192] = 128
-    # References audios for in-context learning
-    references: list[ServeReferenceAudio] = []
-    # Reference id
-    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
-    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
-    reference_id: str | None = None
-    # Normalize text for en & zh, this increase stability for numbers
-    normalize: bool = True
-    mp3_bitrate: Optional[int] = 64
-    opus_bitrate: Optional[int] = -1000
-    # Balance mode will reduce latency to 300ms, but may decrease stability
-    latency: Literal["normal", "balanced"] = "normal"
-    # not usually used below
-    streaming: bool = False
-    emotion: Optional[str] = None
-    max_new_tokens: int = 1024
-    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
-    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-
-
 def get_content_type(audio_format):
     if audio_format == "wav":
         return "audio/wav"
diff --git a/tools/commons.py b/tools/commons.py
new file mode 100644
index 00000000..f81cadec
--- /dev/null
+++ b/tools/commons.py
@@ -0,0 +1,34 @@
+from typing import Annotated, Literal, Optional
+
+from pydantic import BaseModel, Field, conint
+
+
+class ServeReferenceAudio(BaseModel):
+    audio: bytes
+    text: str
+
+
+class ServeTTSRequest(BaseModel):
+    text: str
+    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+    # Audio format
+    format: Literal["wav", "pcm", "mp3"] = "wav"
+    mp3_bitrate: Literal[64, 128, 192] = 128
+    # Reference audios for in-context learning
+    references: list[ServeReferenceAudio] = []
+    # Reference id
+    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+    reference_id: str | None = None
+    # Normalize text for en & zh; this increases stability for numbers
+    normalize: bool = True
+    opus_bitrate: Optional[int] = -1000
+    # Balanced mode reduces latency to 300ms, but may decrease stability
+    latency: Literal["normal", "balanced"] = "normal"
+    # Fields below are not usually used
+    streaming: bool = False
+    emotion: Optional[str] = None
+    max_new_tokens: int = 1024
+    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
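With the request models now shared via tools/commons.py, both clients encode requests the same way. A minimal sketch of building and packing a request; the field values are illustrative, and OPT_SERIALIZE_PYDANTIC is the same ormsgpack option the clients below rely on:

    import ormsgpack

    from tools.commons import ServeReferenceAudio, ServeTTSRequest

    # Pydantic enforces the conint/Field bounds at construction time, so an
    # out-of-range value such as chunk_length=50 raises a ValidationError.
    req = ServeTTSRequest(
        text="Hello, world!",
        chunk_length=200,
        references=[ServeReferenceAudio(audio=b"\x00\x01", text="a transcript")],
    )

    # The audio bytes travel natively in msgpack, with no base64 detour.
    payload = ormsgpack.packb(req, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)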
diff --git a/tools/msgpack_api.py b/tools/msgpack_api.py
index 52f2220d..67f907bf 100644
--- a/tools/msgpack_api.py
+++ b/tools/msgpack_api.py
@@ -1,41 +1,7 @@
-from typing import Annotated, AsyncGenerator, Literal, Optional
-
 import httpx
 import ormsgpack
 
-from pydantic import AfterValidator, BaseModel, Field, conint
-
-
-class ServeReferenceAudio(BaseModel):
-    audio: bytes
-    text: str
-
-
-class ServeTTSRequest(BaseModel):
-    text: str
-    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
-    # Audio format
-    format: Literal["wav", "pcm", "mp3"] = "wav"
-    mp3_bitrate: Literal[64, 128, 192] = 128
-    # References audios for in-context learning
-    references: list[ServeReferenceAudio] = []
-    # Reference id
-    # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
-    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
-    reference_id: str | None = None
-    # Normalize text for en & zh, this increase stability for numbers
-    normalize: bool = True
-    mp3_bitrate: Optional[int] = 64
-    opus_bitrate: Optional[int] = -1000
-    # Balance mode will reduce latency to 300ms, but may decrease stability
-    latency: Literal["normal", "balanced"] = "normal"
-    # not usually used below
-    streaming: bool = False
-    emotion: Optional[str] = None
-    max_new_tokens: int = 1024
-    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
-    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+from tools.commons import ServeReferenceAudio, ServeTTSRequest
 
 # priority: ref_id > references
 request = ServeTTSRequest(
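After this change msgpack_api.py keeps only its httpx client logic on top of the shared models. A hedged sketch of how such a client can send the packed request; the endpoint URL and API key are placeholders, not values taken from this diff:

    import httpx
    import ormsgpack

    from tools.commons import ServeTTSRequest

    request = ServeTTSRequest(text="Hello, world!")

    # Send raw msgpack bytes with the same headers post_api.py uses.
    response = httpx.post(
        "http://127.0.0.1:8080/v1/tts",  # placeholder endpoint
        content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
        headers={
            "authorization": "Bearer YOUR_API_KEY",
            "content-type": "application/msgpack",
        },
    )
    response.raise_for_status()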
diff --git a/tools/post_api.py b/tools/post_api.py
index 79c03cb6..c20dc455 100644
--- a/tools/post_api.py
+++ b/tools/post_api.py
@@ -1,13 +1,14 @@
 import argparse
 import base64
 import wave
-from pathlib import Path
 
+import ormsgpack
 import pyaudio
 import requests
 from pydub import AudioSegment
 from pydub.playback import play
 
+from tools.commons import ServeReferenceAudio, ServeTTSRequest
 from tools.file import audio_to_bytes, read_ref_text
 
 
@@ -113,20 +114,26 @@ def parse_args():
     idstr: str | None = args.reference_id
     # priority: ref_id > [{text, audio},...]
     if idstr is None:
-        base64_audios = [
-            audio_to_bytes(ref_audio) for ref_audio in args.reference_audio
-        ]
-        ref_texts = [read_ref_text(ref_text) for ref_text in args.reference_text]
+        ref_audios = args.reference_audio
+        ref_texts = args.reference_text
+        if ref_audios is None:
+            byte_audios = []
+        else:
+            byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
+        if ref_texts is None:
+            ref_texts = []
+        else:
+            ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
     else:
-        base64_audios = []
+        byte_audios = []
         ref_texts = []
         pass
     # in api.py
     data = {
         "text": args.text,
         "references": [
-            dict(text=ref_text, audio=ref_audio)
-            for ref_text, ref_audio in zip(ref_texts, base64_audios)
+            ServeReferenceAudio(audio=ref_audio, text=ref_text)
+            for ref_text, ref_audio in zip(ref_texts, byte_audios)
         ],
         "reference_id": idstr,
         "normalize": args.normalize,
@@ -143,7 +150,17 @@ def parse_args():
         "streaming": args.streaming,
     }
 
-    response = requests.post(args.url, json=data, stream=args.streaming)
+    pydantic_data = ServeTTSRequest(**data)
+
+    response = requests.post(
+        args.url,
+        data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+        stream=args.streaming,
+        headers={
+            "authorization": "Bearer YOUR_API_KEY",
+            "content-type": "application/msgpack",
+        },
+    )
 
     if response.status_code == 200:
         if args.streaming:
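Because post_api.py now posts application/msgpack rather than JSON, the server must unpack the body before Pydantic sees it. The api.py side of that is not part of this diff; a minimal sketch of the decode step, with the helper name purely illustrative:

    import ormsgpack

    from tools.commons import ServeTTSRequest

    def decode_tts_body(raw_body: bytes) -> ServeTTSRequest:
        # unpackb yields plain dicts/lists/bytes, and constructing the model
        # re-runs validation (chunk_length bounds, top_p range, and so on).
        return ServeTTSRequest(**ormsgpack.unpackb(raw_body))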