fully support ormsgpack (#518)
* fully support ormsgpack

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* dependency

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
AnyaCoder and pre-commit-ci[bot] authored Sep 8, 2024
1 parent 0956e02 commit 237f4fd
Showing 5 changed files with 64 additions and 77 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ dependencies = [
"funasr==1.1.5",
"opencc-python-reimplemented==0.1.7",
"silero-vad",
"ormsgpack",
]

[project.optional-dependencies]
34 changes: 1 addition & 33 deletions tools/api.py
@@ -39,7 +39,7 @@
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
from tools.commons import ServeReferenceAudio, ServeTTSRequest
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
from tools.llama.generate import (
GenerateRequest,
@@ -156,38 +156,6 @@ def decode_vq_tokens(
routes = MultimethodRoutes(base_class=HttpView)


class ServeReferenceAudio(BaseModel):
audio: bytes
text: str


class ServeTTSRequest(BaseModel):
text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
# Audio format
format: Literal["wav", "pcm", "mp3"] = "wav"
mp3_bitrate: Literal[64, 128, 192] = 128
# References audios for in-context learning
references: list[ServeReferenceAudio] = []
# Reference id
# For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
reference_id: str | None = None
# Normalize text for en & zh, this increase stability for numbers
normalize: bool = True
mp3_bitrate: Optional[int] = 64
opus_bitrate: Optional[int] = -1000
# Balance mode will reduce latency to 300ms, but may decrease stability
latency: Literal["normal", "balanced"] = "normal"
# not usually used below
streaming: bool = False
emotion: Optional[str] = None
max_new_tokens: int = 1024
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7


def get_content_type(audio_format):
if audio_format == "wav":
return "audio/wav"
35 changes: 35 additions & 0 deletions tools/commons.py
@@ -0,0 +1,35 @@
from typing import Annotated, Literal, Optional

from pydantic import BaseModel, Field, conint


class ServeReferenceAudio(BaseModel):
audio: bytes
text: str


class ServeTTSRequest(BaseModel):
text: str
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
# Audio format
format: Literal["wav", "pcm", "mp3"] = "wav"
mp3_bitrate: Literal[64, 128, 192] = 128
# References audios for in-context learning
references: list[ServeReferenceAudio] = []
# Reference id
# For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
reference_id: str | None = None
# Normalize text for en & zh, this increase stability for numbers
normalize: bool = True
mp3_bitrate: Optional[int] = 64
opus_bitrate: Optional[int] = -1000
# Balance mode will reduce latency to 300ms, but may decrease stability
latency: Literal["normal", "balanced"] = "normal"
# not usually used below
streaming: bool = False
emotion: Optional[str] = None
max_new_tokens: int = 1024
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
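
The two models above are now the single shared definition imported by api.py, msgpack_api.py and post_api.py. As a minimal sketch (not part of the diff; the field values are only illustrative), a request built from these models can be packed and unpacked with ormsgpack — OPT_SERIALIZE_PYDANTIC lets packb serialize Pydantic models directly:

import ormsgpack

from tools.commons import ServeReferenceAudio, ServeTTSRequest

request = ServeTTSRequest(
    text="Hello, world.",
    references=[
        ServeReferenceAudio(audio=b"<raw reference audio bytes>", text="reference transcript")
    ],
)

# Pack the model (and its nested models) into msgpack bytes.
payload = ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC)

# Unpacking yields a plain dict, e.g. on the receiving side.
restored = ormsgpack.unpackb(payload)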
36 changes: 1 addition & 35 deletions tools/msgpack_api.py
@@ -1,41 +1,7 @@
from typing import Annotated, AsyncGenerator, Literal, Optional

import httpx
import ormsgpack
from pydantic import AfterValidator, BaseModel, Field, conint


class ServeReferenceAudio(BaseModel):
audio: bytes
text: str


class ServeTTSRequest(BaseModel):
text: str
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
# Audio format
format: Literal["wav", "pcm", "mp3"] = "wav"
mp3_bitrate: Literal[64, 128, 192] = 128
# References audios for in-context learning
references: list[ServeReferenceAudio] = []
# Reference id
# For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
reference_id: str | None = None
# Normalize text for en & zh, this increase stability for numbers
normalize: bool = True
mp3_bitrate: Optional[int] = 64
opus_bitrate: Optional[int] = -1000
# Balance mode will reduce latency to 300ms, but may decrease stability
latency: Literal["normal", "balanced"] = "normal"
# not usually used below
streaming: bool = False
emotion: Optional[str] = None
max_new_tokens: int = 1024
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7

from tools.commons import ServeReferenceAudio, ServeTTSRequest

# priority: ref_id > references
request = ServeTTSRequest(
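
The remainder of msgpack_api.py (truncated above) presumably sends the packed request over HTTP with httpx, which the file already imports. A self-contained sketch of that pattern follows; the endpoint URL and API key are placeholders, not values taken from the diff:

import httpx
import ormsgpack

from tools.commons import ServeTTSRequest

request = ServeTTSRequest(text="Hello, world.")

with httpx.Client() as client:
    response = client.post(
        "http://127.0.0.1:8080/v1/tts",  # placeholder endpoint
        content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
        headers={
            "authorization": "Bearer YOUR_API_KEY",  # placeholder key
            "content-type": "application/msgpack",
        },
    )
    response.raise_for_status()
    audio = response.content  # audio bytes in the requested format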
35 changes: 26 additions & 9 deletions tools/post_api.py
@@ -1,13 +1,14 @@
import argparse
import base64
import wave
from pathlib import Path

import ormsgpack
import pyaudio
import requests
from pydub import AudioSegment
from pydub.playback import play

from tools.commons import ServeReferenceAudio, ServeTTSRequest
from tools.file import audio_to_bytes, read_ref_text


@@ -113,20 +114,26 @@ def parse_args():
idstr: str | None = args.reference_id
# priority: ref_id > [{text, audio},...]
if idstr is None:
base64_audios = [
audio_to_bytes(ref_audio) for ref_audio in args.reference_audio
]
ref_texts = [read_ref_text(ref_text) for ref_text in args.reference_text]
ref_audios = args.reference_audio
ref_texts = args.reference_text
if ref_audios is None:
byte_audios = []
else:
byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
if ref_texts is None:
ref_texts = []
else:
ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
else:
base64_audios = []
byte_audios = []
ref_texts = []
pass # in api.py

data = {
"text": args.text,
"references": [
dict(text=ref_text, audio=ref_audio)
for ref_text, ref_audio in zip(ref_texts, base64_audios)
ServeReferenceAudio(audio=ref_audio, text=ref_text)
for ref_text, ref_audio in zip(ref_texts, byte_audios)
],
"reference_id": idstr,
"normalize": args.normalize,
@@ -143,7 +150,17 @@ def parse_args():
"streaming": args.streaming,
}

response = requests.post(args.url, json=data, stream=args.streaming)
pydantic_data = ServeTTSRequest(**data)

response = requests.post(
args.url,
data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
stream=args.streaming,
headers={
"authorization": "Bearer YOUR_API_KEY",
"content-type": "application/msgpack",
},
)

if response.status_code == 200:
if args.streaming:
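
The matching server-side change in api.py is not part of the hunks shown here; as a rough, assumed sketch of the receiving end, the raw msgpack body can be turned back into a validated model like this (parse_tts_request is a hypothetical helper, not a function from the repository):

import ormsgpack

from tools.commons import ServeTTSRequest

def parse_tts_request(raw_body: bytes) -> ServeTTSRequest:
    # unpackb returns a plain dict; constructing the model re-applies the
    # field constraints (chunk_length, top_p, temperature, ...).
    payload = ormsgpack.unpackb(raw_body)
    return ServeTTSRequest(**payload)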
