New fix #40

Open: wants to merge 4 commits into base: main
63 changes: 63 additions & 0 deletions README.md
@@ -104,6 +104,69 @@
```bash
export HF_ENDPOINT=https://hf-mirror.com
```
## API Usage

Start the API server:

```bash
python3 api.py
```

GET request:

```
http://localhost:9880/?text=我问她月饼爱吃咸的还是甜的,那天老色鬼说,你身上的月饼,自然是甜过了蜜糖。&seed=2
```
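
An equivalent call from the command line, letting curl handle the URL encoding (a minimal sketch; the text and output filename are placeholders):

```bash
# -G turns the --data-urlencode pairs into a GET query string
curl -G "http://localhost:9880/" \
  --data-urlencode "text=你好,世界" \
  --data-urlencode "seed=2" \
  -o output.wav
```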

POST request:

```python
import requests
import json

url = "http://localhost:9880"
headers = {"Content-Type": "application/json"}
data = json.dumps({"text": "我问她月饼爱吃咸的还是甜的,那天老色鬼说,你身上的月饼,自然是甜过了蜜糖", "seed": 2})

r = requests.post(url, data=data, headers=headers)

# the response body is the raw WAV audio
with open("测试音频.wav", "wb") as f:
    f.write(r.content)
```

The API also supports a voice-embedding role parameter, `roleid` (see the webui for the available roles), and streaming output via `streaming=1`, as sketched below.
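
A sketch of a GET call with a fixed voice and streaming enabled; `roleid=1` is an assumed example value, so check the webui for the IDs actually defined in `slct_voice_240605.json`:

```bash
# roleid=1 is a placeholder role ID; streaming=1 returns audio as it is generated
curl -G "http://localhost:9880/" \
  --data-urlencode "text=你好,世界" \
  --data-urlencode "seed=2" \
  --data-urlencode "roleid=1" \
  --data-urlencode "streaming=1" \
  -o streamed.wav
```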

## Generating character scripts with a local Ollama LLM

Download the weights: https://huggingface.co/shenzhi-wang/Mistral-7B-v0.3-Chinese-Chat-4bit

Create an Ollama Modelfile template named `Modelfile_mistral7b0.3.txt`:

```
FROM .\Mistral-7B-v0.3-Chinese-Chat-q4_0.gguf

# model temperature (lower values give more precise answers, higher values more creative ones)
PARAMETER temperature 0.8

# context window size in tokens
PARAMETER num_ctx 4096

TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]
"""

SYSTEM """你是一名作家,文辞优美,擅长转写话剧和影视剧的剧本"""
```

Import the model:

```bash
ollama create nsfw -f Modelfile_mistral7b0.3.txt
```

Start the server:

```bash
ollama serve
```

## Contributing

38 changes: 30 additions & 8 deletions api.py
@@ -22,6 +22,8 @@

from starlette.middleware.cors import CORSMiddleware  # import the CORS middleware module

import json

# domains allowed to access the API
origins = ["*"]  # "*" means all origins

@@ -59,8 +61,9 @@ def deterministic(seed=0):

class TTS_Request(BaseModel):
    text: str = None
    seed: int = 2581
    seed: int = 7750
    speed: int = 3
    roleid: str = None
    media_type: str = "wav"
    streaming: int = 0

@@ -156,7 +159,7 @@ def pack_audio(io_buffer:BytesIO, data:np.ndarray, rate:int, media_type:str):
    return io_buffer


def generate_tts_audio(text_file,seed=2581,speed=3, oral=0, laugh=0, bk=4, min_length=10, batch_size=5, temperature=0.3, top_P=0.7,
def generate_tts_audio(text_file,seed=7750,roleid=None,speed=3, oral=0, laugh=0, bk=4, min_length=10, batch_size=5, temperature=0.3, top_P=0.7,
top_K=20,streaming=0,cur_tqdm=None):

from utils import combine_audio, save_audio, batch_split
@@ -178,8 +181,24 @@ def generate_tts_audio(text_file,seed=2581,speed=3, oral=0, laugh=0, bk=4, min_l
    refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"


    deterministic(seed)
    rnd_spk_emb = chat.sample_random_speaker()

    if not roleid:
        deterministic(seed)
        rnd_spk_emb = chat.sample_random_speaker()
    else:
        # load the stored voice data from the JSON file
        with open('./slct_voice_240605.json', 'r', encoding='utf-8') as json_file:
            slct_idx_loaded = json.load(json_file)

        # convert the stored lists back into Tensor objects
        for key in slct_idx_loaded:
            tensor_list = slct_idx_loaded[key]["tensor"]
            slct_idx_loaded[key]["tensor"] = torch.tensor(tensor_list)

        # pack this voice tensor into params_infer_code below so the fixed voice is always used; keep temperature low
        rnd_spk_emb = slct_idx_loaded[roleid]["tensor"]

    params_infer_code = {
        'spk_emb': rnd_spk_emb,
        'prompt': f'[speed_{speed}]',
@@ -265,8 +284,10 @@ async def tts_handle(req:dict):
print(req["streaming"])

if not req["streaming"]:

print(req["roleid"])

audio_data = next(generate_tts_audio(req["text"],req["seed"]))
audio_data = next(generate_tts_audio(req["text"],req["seed"]),req["roleid"])

# print(audio_data)

@@ -282,7 +303,7 @@

    else:

        tts_generator = generate_tts_audio(req["text"],req["seed"],streaming=1)
        tts_generator = generate_tts_audio(req["text"],req["seed"],roleid=req["roleid"],streaming=1)

        sr = 24000

Expand All @@ -297,14 +318,15 @@ def streaming_generator(tts_generator:Generator, media_type:str):


@app.get("/")
async def tts_get(text: str = None,media_type:str = "wav",seed:int = 2581,streaming:int = 0):
async def tts_get(text: str = None,media_type:str = "wav",seed:int = 2581,streaming:int = 0,roleid:str = None):
    req = {
        "text": text,
        "media_type": media_type,
        "seed": seed,
        "streaming": streaming,
        "roleid": roleid,
    }

    return await tts_handle(req)


38 changes: 38 additions & 0 deletions llm_utils.py
@@ -9,6 +9,44 @@
from tqdm import tqdm
from config import LLM_RETRIES, LLM_REQUEST_INTERVAL, LLM_RETRY_DELAY, LLM_MAX_TEXT_LENGTH, LLM_PROMPT

import requests
import json


def ollama_generate(url, model, text):
    # LLM_PROMPT comes from config (imported at the top of this module);
    # append an instruction to return JSON only, then the user's text
    prompt = LLM_PROMPT + f"\n注意,只返回json即可,不要返回其他格式\n{text}"

    data = json.dumps({
        "model": model,
        "prompt": prompt,
        "stream": False
    })
    headers = {"Content-Type": "application/json"}  # declare a JSON body
    r = requests.post(url, data=data, headers=headers)

    try:
        res = json.loads(r.text)

        # the model may wrap its JSON in a markdown code fence; strip it before parsing
        clist = json.loads(res["response"].replace("```json", "").replace("```", ""))

        texts = ""
        for x in clist:
            texts += f"{x['character']}::{x['txt']}\n"
    except Exception as e:
        print(str(e))
        texts = "报错了,请重新生成 \n"

    return texts


def send_request(client, prompt, text, model):
text = remove_json_escape_characters(text)
27 changes: 24 additions & 3 deletions webui_mix.py
@@ -17,6 +17,8 @@
from config import DEFAULT_BATCH_SIZE, DEFAULT_SPEED, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P, DEFAULT_ORAL, \
DEFAULT_LAUGH, DEFAULT_BK, DEFAULT_SEG_LENGTH

from llm_utils import ollama_generate

parser = argparse.ArgumentParser(description="Gradio ChatTTS MIX")
parser.add_argument("--source", type=str, default="huggingface", help="Model source: 'huggingface' or 'local'.")
parser.add_argument("--local_path", type=str, help="Path to local model if source is 'local'.")
@@ -55,6 +57,15 @@
# chat.load_models(source="local", local_path="models")
# torch.cuda.empty_cache()

def file_show(file):
    # read an uploaded text file and return its contents for display
    if file is None:
        return ""
    try:
        with open(file.name, "r", encoding="utf-8") as f:
            text = f.read()
        return text
    except Exception as error:
        return str(error)

# load seeds
def load_seeds():
@@ -435,7 +446,8 @@ def inser_token(text, btn):
break_button = gr.Button("+停顿", variant="secondary")
laugh_button = gr.Button("+笑声", variant="secondary")
refine_button = gr.Button("Refine Text(预处理 加入停顿词、笑声等)", variant="secondary")

# upload a .txt file and mirror its contents into the script textbox
srt_file = gr.File(label="上传文本", file_types=['.txt'], file_count='single')
srt_file.change(file_show, inputs=[srt_file], outputs=[text_file_input])
with gr.Column():
gr.Markdown("### 配置参数")
with gr.Row():
@@ -587,7 +599,8 @@ def llm_change(model):
"gpt-3.5-turbo-0125": ["https://api.openai.com/v1"],
"gpt-4o": ["https://api.openai.com/v1"],
"deepseek-chat": ["https://api.deepseek.com"],
"yi-large": ["https://api.lingyiwanwu.com/v1"]
"yi-large": ["https://api.lingyiwanwu.com/v1"],
"本地ollama": ["http://localhost:11434/api/generate"]
}
if model in llm_setting:
return llm_setting[model][0]
@@ -775,20 +788,22 @@ def batch(iterable, batch_size):

with gr.Row(equal_height=True):
# model selector: gpt-4o, deepseek-chat, yi-large, or 本地ollama (local Ollama)
model_select = gr.Radio(label="选择模型", choices=["gpt-4o", "deepseek-chat", "yi-large"],
model_select = gr.Radio(label="选择模型", choices=["gpt-4o", "deepseek-chat", "yi-large","本地ollama"],
value="gpt-4o", interactive=True, )
with gr.Row(equal_height=True):
openai_api_base_input = gr.Textbox(label="OpenAI API Base URL",
placeholder="请输入API Base URL",
value=r"https://api.openai.com/v1")
openai_api_key_input = gr.Textbox(label="OpenAI API Key", placeholder="请输入API Key",
value="sk-xxxxxxx",type="password")
ollama_model = gr.Textbox(label="ollama本地模型名称",value="nsfw")
# AI prompt
ai_text_input = gr.Textbox(label="剧情简介或者一段故事", placeholder="请输入文本...", lines=2,
value=ai_text_default)

# script generation buttons
ai_script_generate_button = gr.Button("AI脚本生成")
ai_script_generate_button_ollama = gr.Button("本地Ollama模型AI脚本生成")

with gr.Column(scale=3):
gr.Markdown("### 脚本")
Expand Down Expand Up @@ -844,6 +859,12 @@ def batch(iterable, batch_size):
inputs=[model_select, openai_api_base_input, openai_api_key_input, ai_text_input],
outputs=[script_text_input]
)
# AI script generation (local Ollama)
ai_script_generate_button_ollama.click(
    ollama_generate,
    inputs=[openai_api_base_input, ollama_model, ai_text_input],
    outputs=[script_text_input]
)
# audio generation
script_generate_audio.click(
generate_script_audio,