Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

不再强制检查API,新增非流API请求 (Sourcery refactored) #1048

Merged
merged 8 commits into from
Jul 12, 2023
112 changes: 65 additions & 47 deletions adapter/chatgpt/api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import json
import time
from typing import AsyncGenerator

import aiohttp
import async_timeout

import tiktoken
from loguru import logger
from typing import AsyncGenerator

from adapter.botservice import BotAdapter
from config import OpenAIAPIKey
from constants import botManager, config
import tiktoken

DEFAULT_ENGINE: str = "gpt-3.5-turbo"

Expand Down Expand Up @@ -53,12 +51,7 @@ async def rollback(self, session_id: str = "default", n: int = 1) -> None:
logger.error(f"未知错误: {e}")
raise

def add_to_conversation(self, message: str, role: str, session_id: str = "default") -> None:
    """Append a single chat message to the given session's history.

    :param message: the message text.
    :param role: OpenAI chat role for the message ("user", "assistant" or "system").
    :param session_id: conversation the message belongs to.
    """
    self.conversation[session_id].append({"role": role, "content": message})

# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
Expand All @@ -71,26 +64,8 @@ def count_tokens(self, session_id: str = "default", model: str = DEFAULT_ENGINE)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")

if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-3.5-turbo",
"gpt-4",
"gpt-4-32k"
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows {role/name}\n{content}\n
tokens_per_name = -1 # if there's a name, the role is omitted
else:
logger.warning("未找到相应模型计算方法,使用默认方法进行计算")
tokens_per_message = 3
tokens_per_name = 1
tokens_per_message = 4
tokens_per_name = 1

num_tokens = 0
for message in self.conversation[session_id]:
Expand Down Expand Up @@ -168,40 +143,80 @@ async def on_reset(self):
self.bot.engine = self.api_info.model
self.__conversation_keep_from = 0

async def request_with_stream(self, session_id: str = None, messages: list = None) -> AsyncGenerator[str, None]:
self.api_info = botManager.pick('openai-api')
api_key = self.api_info.api_key
proxy = self.api_info.proxy
api_endpoint = config.openai.api_endpoint or "https://api.openai.com/v1"

if not messages:
messages = self.bot.conversation[session_id]

def construct_data(self, messages: list = None, api_key: str = None, stream: bool = True):
    """Build the HTTP headers and JSON payload for a chat-completions request.

    :param messages: messages to send, in OpenAI chat format
        (list of ``{"role": ..., "content": ...}`` dicts).
    :param api_key: OpenAI API key placed in the ``Authorization`` header.
    :param stream: whether the API should stream the response.
    :return: ``(headers, data)`` tuple ready for the POST request.
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}'
    }
    data = {
        'model': self.bot.engine,
        'messages': messages,
        # The duplicate hard-coded 'stream': True entry was dead code
        # (shadowed by this key); only the parameter value is kept.
        'stream': stream,
        'temperature': self.bot.temperature,
        'top_p': self.bot.top_p,
        'presence_penalty': self.bot.presence_penalty,
        'frequency_penalty': self.bot.frequency_penalty,
        "user": 'user',
        'max_tokens': self.bot.get_max_tokens(self.session_id, self.bot.engine),
    }
    return headers, data

def _prepare_request(self, session_id: str = None, messages: list = None, stream: bool = False):
    """Resolve credentials and assemble everything needed for one API call.

    Picks a fresh OpenAI API account from the bot manager, falls back to the
    stored conversation history when no explicit messages are supplied, and
    delegates header/payload construction to ``construct_data``.

    :return: ``(api_key, proxy, api_endpoint, headers, data)``
    """
    self.api_info = botManager.pick('openai-api')
    account = self.api_info
    endpoint = config.openai.api_endpoint or "https://api.openai.com/v1"

    msg_list = messages if messages else self.bot.conversation[session_id]

    headers, data = self.construct_data(msg_list, account.api_key, stream)

    return account.api_key, account.proxy, endpoint, headers, data

async def _process_response(self, resp, session_id: str = None):
    """Parse a non-stream chat-completions response and record the reply.

    Validates that the JSON body carries the expected usage and message
    fields, appends the assistant message to the conversation history, and
    returns its text.

    :param resp: aiohttp response object for the completed request.
    :param session_id: conversation the assistant reply is appended to.
    :return: the assistant message content.
    :raises Exception: when the response body is missing an expected field.
    """
    result = await resp.json()

    # Validate before logging so we never log a value we are about to reject.
    total_tokens = result.get('usage', {}).get('total_tokens', None)
    if total_tokens is None:
        raise Exception("Response does not contain 'total_tokens'")
    logger.debug(f"[ChatGPT-API:{self.bot.engine}] 使用 token 数:{total_tokens}")

    # Guard against an empty 'choices' list as well as a missing key;
    # the previous `result.get('choices', [{}])[0]` raised an opaque
    # IndexError when 'choices' was present but empty.
    choices = result.get('choices') or [{}]
    message = choices[0].get('message', {})

    content = message.get('content', None)
    if content is None:
        raise Exception("Response does not contain 'content'")
    logger.debug(f"[ChatGPT-API:{self.bot.engine}] 响应:{content}")

    response_role = message.get('role', None)
    if response_role is None:
        raise Exception("Response does not contain 'role'")

    self.bot.add_to_conversation(content, response_role, session_id)

    return content

async def request(self, session_id: str = None, messages: list = None) -> str:
    """Send a non-streaming chat-completions request and return the reply.

    :param session_id: conversation whose stored history is sent when
        *messages* is not supplied.
    :param messages: explicit message list overriding the stored history.
    :return: the assistant reply content.
    :raises Exception: on a non-200 HTTP status or a malformed response body.
    """
    api_key, proxy, api_endpoint, headers, data = self._prepare_request(session_id, messages, stream=False)

    async with aiohttp.ClientSession() as session:
        # NOTE(review): plain `with` on async_timeout matches the streaming
        # path in this file; newer async_timeout releases require `async with`.
        with async_timeout.timeout(self.bot.timeout):
            # Pass the account's proxy through — the non-stream path must
            # honour the configured proxy exactly as the streaming path does.
            async with session.post(f'{api_endpoint}/chat/completions', headers=headers,
                                    data=json.dumps(data), proxy=proxy) as resp:
                if resp.status != 200:
                    response_text = await resp.text()
                    raise Exception(
                        f"{resp.status} {resp.reason} {response_text}",
                    )
                return await self._process_response(resp, session_id)

async def request_with_stream(self, session_id: str = None, messages: list = None) -> AsyncGenerator[str, None]:
api_key, proxy, api_endpoint, headers, data = self._prepare_request(session_id, messages, stream=True)

async with aiohttp.ClientSession() as session:
with async_timeout.timeout(self.bot.timeout):
async with session.post(f'{api_endpoint}/chat/completions', headers=headers, data=json.dumps(data),
proxy=proxy) as resp:
response_role: str = ''
completion_text: str = ''

Expand Down Expand Up @@ -273,12 +288,15 @@ async def ask(self, prompt: str) -> AsyncGenerator[str, None]:
self.bot.add_to_conversation(prompt, "user", session_id=self.session_id)
start_time = time.time()

async for completion_text in self.request_with_stream(session_id=self.session_id):
yield completion_text
if config.openai.gpt_params.stream:
async for completion_text in self.request_with_stream(session_id=self.session_id):
yield completion_text

token_count = self.bot.count_tokens(self.session_id, self.bot.engine)
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 响应:{completion_text}")
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 使用 token 数:{token_count}")
token_count = self.bot.count_tokens(self.session_id, self.bot.engine)
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 响应:{completion_text}")
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 使用 token 数:{token_count}")
else:
yield await self.request(session_id=self.session_id)
event_time = time.time() - start_time
if event_time is not None:
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 接收到全部消息花费了{event_time:.2f}秒")
Expand Down
1 change: 1 addition & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class OpenAIParams(BaseModel):
min_tokens: int = 1000
compressed_session: bool = False
compressed_tokens: int = 1000
stream: bool = True


class OpenAIAuths(BaseModel):
Expand Down
11 changes: 6 additions & 5 deletions manager/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,15 @@ async def handle_openai(self):
openai.api_base = self.config.openai.api_endpoint or openai.api_base
if openai.api_base.endswith("/"):
openai.api_base.removesuffix("/")
logger.info(f"当前的 api_endpoint 为:{openai.api_base}")

pattern = r'^https://[^/]+/v1$'
if match := re.match(pattern, openai.api_base):
logger.info(f"当前的 api_endpoint 为:{openai.api_base}")
await self.login_openai()
else:

if not re.match(pattern, openai.api_base):
logger.error("API反代地址填写错误,正确格式应为 'https://<网址>/v1'")
raise ValueError("API反代地址填写错误,正确格式应为 'https://<网址>/v1'")

await self.login_openai()


async def login(self):
self.bots = {
Expand Down