feat: Add Minimax model #774

Merged 2 commits on May 23, 2023
3 changes: 3 additions & 0 deletions config_example.json
@@ -5,6 +5,9 @@
"usage_limit": 120, // API Key的当月限额,单位:美元
// 你的xmchat API Key,与OpenAI API Key不同
"xmchat_api_key": "",
// MiniMax API key (see the account management page at https://api.minimax.chat/basic-information) and Group ID, used for the MiniMax chat models
"minimax_api_key": "",
"minimax_group_id": "",
"language": "auto",
// To use a proxy, uncomment the two lines below and replace the proxy URL
// "https_proxy": "http://127.0.0.1:1079",
5 changes: 5 additions & 0 deletions modules/config.py
@@ -76,6 +76,11 @@
xmchat_api_key = config.get("xmchat_api_key", "")
os.environ["XMCHAT_API_KEY"] = xmchat_api_key

minimax_api_key = config.get("minimax_api_key", "")
os.environ["MINIMAX_API_KEY"] = minimax_api_key
minimax_group_id = config.get("minimax_group_id", "")
os.environ["MINIMAX_GROUP_ID"] = minimax_group_id

render_latex = config.get("render_latex", True)

if render_latex:
3 changes: 3 additions & 0 deletions modules/models/base_model.py
@@ -34,6 +34,7 @@ class ModelType(Enum):
StableLM = 4
MOSS = 5
YuanAI = 6
Minimax = 7

@classmethod
def get_type(cls, model_name: str):
@@ -53,6 +54,8 @@ def get_type(cls, model_name: str):
model_type = ModelType.MOSS
elif "yuanai" in model_name_lower:
model_type = ModelType.YuanAI
elif "minimax" in model_name_lower:
model_type = ModelType.Minimax
else:
model_type = ModelType.Unknown
return model_type
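
With this mapping in place, any model name containing "minimax" resolves to the new enum member. A quick sketch of the expected behavior, assuming the repository layout shown in this diff:

from modules.models.base_model import ModelType

# "minimax" appears in the lowered model name, so the new branch matches
assert ModelType.get_type("minimax-abab5-chat") == ModelType.Minimax
# Names matching no known family still fall through to ModelType.Unknown
assert ModelType.get_type("some-other-model") == ModelType.Unknown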
161 changes: 161 additions & 0 deletions modules/models/minimax.py
@@ -0,0 +1,161 @@
import json
import os

import colorama
import requests
import logging

from modules.models.base_model import BaseLLMModel
from modules.presets import STANDARD_ERROR_MSG, GENERAL_ERROR_MSG, TIMEOUT_STREAMING, TIMEOUT_ALL, i18n

group_id = os.environ.get("MINIMAX_GROUP_ID", "")


class MiniMax_Client(BaseLLMModel):
"""
MiniMax Client
API documentation: https://api.minimax.chat/document/guides/chat
"""

def __init__(self, model_name, api_key, user_name="", system_prompt=None):
super().__init__(model_name=model_name, user=user_name)
self.url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}'
self.history = []
self.api_key = api_key
self.system_prompt = system_prompt
self.headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}

def get_answer_at_once(self):
# MiniMax's temperature range is (0,1] while the base model's is [0,2]; map base 1.0 -> 0.9 and base 2.0 -> 1.0
temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10

request_body = {
"model": self.model_name.replace('minimax-', ''),
"temperature": temperature,
"skip_info_mask": True,
'messages': [{"sender_type": "USER", "text": self.history[-1]['content']}]
}
if self.n_choices:
request_body['beam_width'] = self.n_choices
if self.system_prompt:
request_body['prompt'] = self.system_prompt
if self.max_generation_token:
request_body['tokens_to_generate'] = self.max_generation_token
if self.top_p:
request_body['top_p'] = self.top_p

response = requests.post(self.url, headers=self.headers, json=request_body)

res = response.json()
answer = res['reply']
total_token_count = res["usage"]["total_tokens"]
return answer, total_token_count

def get_answer_stream_iter(self):
    response = self._get_response(stream=True)
    if response is not None:
        partial_text = ""
        for delta in self._decode_chat_response(response):
            partial_text += delta
            yield partial_text
    else:
        yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG

def _get_response(self, stream=False):
minimax_api_key = self.api_key
history = self.history
logging.debug(colorama.Fore.YELLOW +
f"{history}" + colorama.Fore.RESET)
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {minimax_api_key}",
}

# Same (0,1] temperature mapping as in get_answer_at_once: base 1.0 -> 0.9, base 2.0 -> 1.0
temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10

messages = []
for msg in self.history:
if msg['role'] == 'user':
messages.append({"sender_type": "USER", "text": msg['content']})
else:
messages.append({"sender_type": "BOT", "text": msg['content']})

request_body = {
"model": self.model_name.replace('minimax-', ''),
"temperature": temperature,
"skip_info_mask": True,
'messages': messages
}
if self.n_choices:
request_body['beam_width'] = self.n_choices
if self.system_prompt:
    lines = self.system_prompt.splitlines()
    # A first line of the form "user_name:bot_name" is treated as role metadata
    if lines[0].find(":") != -1 and len(lines[0]) < 20:
        request_body["role_meta"] = {
            "user_name": lines[0].split(":")[0],
            "bot_name": lines[0].split(":")[1]
        }
        lines.pop(0)  # drop the role-meta line itself, not the prompt's last line
    request_body["prompt"] = "\n".join(lines)
if self.max_generation_token:
request_body['tokens_to_generate'] = self.max_generation_token
else:
request_body['tokens_to_generate'] = 512
if self.top_p:
request_body['top_p'] = self.top_p

if stream:
timeout = TIMEOUT_STREAMING
request_body['stream'] = True
request_body['use_standard_sse'] = True
else:
timeout = TIMEOUT_ALL
try:
    response = requests.post(
        self.url,
        headers=headers,
        json=request_body,
        stream=stream,
        timeout=timeout,
    )
except requests.RequestException as e:
    logging.error(f"Error: {e}")
    return None

return response

def _decode_chat_response(self, response):
    error_msg = ""
    for chunk in response.iter_lines():
        if chunk:
            chunk = chunk.decode()
            chunk_length = len(chunk)
            logging.debug(chunk)
            try:
                # SSE lines are prefixed with "data: "; strip the 6-character prefix before parsing
                chunk = json.loads(chunk[6:])
            except json.JSONDecodeError:
                logging.error(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
                error_msg += chunk
                continue
            if chunk_length > 6 and "delta" in chunk["choices"][0]:
                if "finish_reason" in chunk["choices"][0] and chunk["choices"][0]["finish_reason"] == "stop":
                    # The final event carries cumulative usage; record only this turn's share
                    self.all_token_counts.append(chunk["usage"]["total_tokens"] - sum(self.all_token_counts))
                    break
                try:
                    yield chunk["choices"][0]["delta"]
                except Exception as e:
                    logging.error(f"Error: {e}")
                    continue
if error_msg:
try:
error_msg = json.loads(error_msg)
if 'base_resp' in error_msg:
status_code = error_msg['base_resp']['status_code']
status_msg = error_msg['base_resp']['status_msg']
raise Exception(f"{status_code} - {status_msg}")
except json.JSONDecodeError:
pass
raise Exception(error_msg)
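
For reference, the non-streaming request that get_answer_at_once issues reduces to a single POST. The following is a minimal sketch using only the endpoint, headers, and response fields shown in this file; the message text is a placeholder, and the key and group ID are assumed to have been exported by modules/config.py:

import os

import requests

group_id = os.environ["MINIMAX_GROUP_ID"]
api_key = os.environ["MINIMAX_API_KEY"]

response = requests.post(
    f"https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}",
    headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
    json={
        "model": "abab5-chat",  # the client strips the "minimax-" prefix before sending
        "temperature": 0.9,     # a base temperature of 1.0 after the (0,1] conversion
        "skip_info_mask": True,
        "messages": [{"sender_type": "USER", "text": "Hello"}],
    },
    timeout=30,
)
result = response.json()
print(result["reply"], result["usage"]["total_tokens"])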
5 changes: 5 additions & 0 deletions modules/models/models.py
@@ -602,6 +602,11 @@ def get_model(
elif model_type == ModelType.YuanAI:
from .inspurai import Yuan_Client
model = Yuan_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
elif model_type == ModelType.Minimax:
    from .minimax import MiniMax_Client
    # Use a default so a missing env var does not pass None as the key
    if os.environ.get("MINIMAX_API_KEY", "") != "":
        access_key = os.environ.get("MINIMAX_API_KEY")
    model = MiniMax_Client(model_name, api_key=access_key, user_name=user_name, system_prompt=system_prompt)
elif model_type == ModelType.Unknown:
raise ValueError(f"未知模型: {model_name}")
logging.info(msg)
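
The new ModelType.Minimax branch just forwards to the constructor defined in modules/models/minimax.py, so the dispatch is roughly equivalent to the sketch below; the key is a placeholder that modules/config.py would normally export:

from modules.models.minimax import MiniMax_Client

# What the ModelType.Minimax branch constructs once MINIMAX_API_KEY is set
client = MiniMax_Client("minimax-abab5-chat", api_key="your-minimax-api-key")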
2 changes: 2 additions & 0 deletions modules/presets.py
@@ -72,6 +72,8 @@
"yuanai-1.0-translate",
"yuanai-1.0-dialog",
"yuanai-1.0-rhythm_poems",
"minimax-abab4-chat",
"minimax-abab5-chat",
]

LOCAL_MODELS = [