Skip to content

Commit

Permalink
feat: 加入gpt-3.5-turbo-instruct模型支持
Browse files Browse the repository at this point in the history
  • Loading branch information
GaiZhenbiao committed Oct 12, 2023
1 parent fc2938f commit e99bd71
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 15 deletions.
12 changes: 6 additions & 6 deletions modules/models/OpenAI.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,13 @@ def _get_response(self, stream=False):
timeout = TIMEOUT_ALL

# 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
if shared.state.completion_url != COMPLETION_URL:
logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")

with retrieve_proxy():
try:
response = requests.post(
shared.state.completion_url,
shared.state.chat_completion_url,
headers=headers,
json=payload,
stream=stream,
Expand Down Expand Up @@ -237,12 +237,12 @@ def _single_query_at_once(self, history, temperature=1.0):
"messages": history,
}
# 如果有自定义的api-host,使用自定义host发送请求,否则使用默认设置发送请求
if shared.state.completion_url != COMPLETION_URL:
logging.debug(f"使用自定义API URL: {shared.state.completion_url}")
if shared.state.chat_completion_url != CHAT_COMPLETION_URL:
logging.debug(f"使用自定义API URL: {shared.state.chat_completion_url}")

with retrieve_proxy():
response = requests.post(
shared.state.completion_url,
shared.state.chat_completion_url,
headers=headers,
json=payload,
stream=False,
Expand Down
27 changes: 27 additions & 0 deletions modules/models/OpenAIInstruct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import openai
from .base_model import BaseLLMModel
from .. import shared
from ..config import retrieve_proxy


class OpenAI_Instruct_Client(BaseLLMModel):
    """Client for OpenAI completion-style (instruct) models such as
    gpt-3.5-turbo-instruct, which use the legacy /v1/completions endpoint
    with a single flat text prompt instead of a chat message list."""

    def __init__(self, model_name, api_key, user_name="") -> None:
        super().__init__(model_name=model_name, user=user_name)
        self.api_key = api_key

    def _get_instruct_style_input(self):
        # Flatten the chat history into one plain-text prompt, separating
        # each message body with a blank line.
        parts = (message["content"] for message in self.history)
        return "\n\n".join(parts)

    @shared.state.switching_api_key
    def get_answer_at_once(self):
        # Build the prompt, then issue a (non-streaming) completion request
        # through the configured proxy, if any.
        text_prompt = self._get_instruct_style_input()
        with retrieve_proxy():
            result = openai.Completion.create(
                api_key=self.api_key,
                api_base=shared.state.openai_api_base,
                model=self.model_name,
                prompt=text_prompt,
                temperature=self.temperature,
                top_p=self.top_p,
            )
        # Return (answer text, total token usage) as the base model expects.
        answer = result.choices[0].text.strip()
        total_tokens = result.usage["total_tokens"]
        return answer, total_tokens
8 changes: 6 additions & 2 deletions modules/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,17 @@ class ModelType(Enum):
LangchainChat = 10
Midjourney = 11
Spark = 12
OpenAIInstruct = 13

@classmethod
def get_type(cls, model_name: str):
model_type = None
model_name_lower = model_name.lower()
if "gpt" in model_name_lower:
model_type = ModelType.OpenAI
if "instruct" in model_name_lower:
model_type = ModelType.OpenAIInstruct
else:
model_type = ModelType.OpenAI
elif "chatglm" in model_name_lower:
model_type = ModelType.ChatGLM
elif "llama" in model_name_lower or "alpaca" in model_name_lower:
Expand Down Expand Up @@ -247,7 +251,7 @@ def get_answer_at_once(self):

def billing_info(self):
"""get billing infomation, inplement if needed"""
logging.warning("billing info not implemented, using default")
# logging.warning("billing info not implemented, using default")
return BILLING_NOT_APPLICABLE_MSG

def count_token(self, user_input):
Expand Down
6 changes: 6 additions & 0 deletions modules/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def get_model(
top_p=top_p,
user_name=user_name,
)
elif model_type == ModelType.OpenAIInstruct:
logging.info(f"正在加载OpenAI Instruct模型: {model_name}")
from .OpenAIInstruct import OpenAI_Instruct_Client
access_key = os.environ.get("OPENAI_API_KEY", access_key)
model = OpenAI_Instruct_Client(
model_name, api_key=access_key, user_name=user_name)
elif model_type == ModelType.ChatGLM:
logging.info(f"正在加载ChatGLM模型: {model_name}")
from .ChatGLM import ChatGLM_Client
Expand Down
7 changes: 5 additions & 2 deletions modules/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
# ChatGPT 设置
INITIAL_SYSTEM_PROMPT = "You are a helpful assistant."
API_HOST = "api.openai.com"
COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
OPENAI_API_BASE = "https://api.openai.com/v1"
CHAT_COMPLETION_URL = "https://api.openai.com/v1/chat/completions"
COMPLETION_URL = "https://api.openai.com/v1/completions"
BALANCE_API_URL="https://api.openai.com/dashboard/billing/credit_grants"
USAGE_API_URL="https://api.openai.com/dashboard/billing/usage"
HISTORY_DIR = Path("history")
Expand Down Expand Up @@ -50,10 +52,11 @@

ONLINE_MODELS = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-instruct",
"gpt-3.5-turbo-16k",
"gpt-4",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
Expand Down
12 changes: 7 additions & 5 deletions modules/shared.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from modules.presets import COMPLETION_URL, BALANCE_API_URL, USAGE_API_URL, API_HOST
from modules.presets import CHAT_COMPLETION_URL, BALANCE_API_URL, USAGE_API_URL, API_HOST, OPENAI_API_BASE
import os
import queue
import openai

class State:
interrupted = False
multi_api_key = False
completion_url = COMPLETION_URL
chat_completion_url = CHAT_COMPLETION_URL
balance_api_url = BALANCE_API_URL
usage_api_url = USAGE_API_URL
openai_api_base = OPENAI_API_BASE

def interrupt(self):
self.interrupted = True
Expand All @@ -22,21 +23,22 @@ def set_api_host(self, api_host: str):
api_host = f"https://{api_host}"
if api_host.endswith("/v1"):
api_host = api_host[:-3]
self.completion_url = f"{api_host}/v1/chat/completions"
self.chat_completion_url = f"{api_host}/v1/chat/completions"
self.openai_api_base = f"{api_host}/v1"
self.balance_api_url = f"{api_host}/dashboard/billing/credit_grants"
self.usage_api_url = f"{api_host}/dashboard/billing/usage"
os.environ["OPENAI_API_BASE"] = api_host

def reset_api_host(self):
self.completion_url = COMPLETION_URL
self.chat_completion_url = CHAT_COMPLETION_URL
self.balance_api_url = BALANCE_API_URL
self.usage_api_url = USAGE_API_URL
os.environ["OPENAI_API_BASE"] = f"https://{API_HOST}"
return API_HOST

def reset_all(self):
self.interrupted = False
self.completion_url = COMPLETION_URL
self.chat_completion_url = CHAT_COMPLETION_URL

def set_api_key_queue(self, api_key_list):
self.multi_api_key = True
Expand Down

0 comments on commit e99bd71

Please sign in to comment.