Skip to content

Commit

Permalink
不再强制检查API,新增非流API请求 (Sourcery refactored) (#1048)
Browse files Browse the repository at this point in the history
* 修复程序开启http模式关不掉的问题

* 修复部分api错误

修复了加载预设导致engine重置为None的问题

* 不再强制检查API,新增非流API请求

* 格式化文件

* 'Refactored by Sourcery'

---------

Co-authored-by: Matt Gideon <117586514+Haibersut@users.noreply.github.com>
Co-authored-by: Sourcery AI <>
  • Loading branch information
sourcery-ai[bot] and Haibersut committed Jul 12, 2023
1 parent 63f7111 commit 101e406
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 52 deletions.
112 changes: 65 additions & 47 deletions adapter/chatgpt/api.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import json
import time
from typing import AsyncGenerator

import aiohttp
import async_timeout

import tiktoken
from loguru import logger
from typing import AsyncGenerator

from adapter.botservice import BotAdapter
from config import OpenAIAPIKey
from constants import botManager, config
import tiktoken

DEFAULT_ENGINE: str = "gpt-3.5-turbo"

Expand Down Expand Up @@ -53,12 +51,7 @@ async def rollback(self, session_id: str = "default", n: int = 1) -> None:
logger.error(f"未知错误: {e}")
raise

def add_to_conversation(
self,
message: str,
role: str,
session_id: str = "default",
) -> None:
def add_to_conversation(self, message: str, role: str, session_id: str = "default") -> None:
self.conversation[session_id].append({"role": role, "content": message})

# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
Expand All @@ -71,26 +64,8 @@ def count_tokens(self, session_id: str = "default", model: str = DEFAULT_ENGINE)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")

if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-3.5-turbo",
"gpt-4",
"gpt-4-32k"
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows {role/name}\n{content}\n
tokens_per_name = -1 # if there's a name, the role is omitted
else:
logger.warning("未找到相应模型计算方法,使用默认方法进行计算")
tokens_per_message = 3
tokens_per_name = 1
tokens_per_message = 4
tokens_per_name = 1

num_tokens = 0
for message in self.conversation[session_id]:
Expand Down Expand Up @@ -168,40 +143,80 @@ async def on_reset(self):
self.bot.engine = self.api_info.model
self.__conversation_keep_from = 0

async def request_with_stream(self, session_id: str = None, messages: list = None) -> AsyncGenerator[str, None]:
self.api_info = botManager.pick('openai-api')
api_key = self.api_info.api_key
proxy = self.api_info.proxy
api_endpoint = config.openai.api_endpoint or "https://api.openai.com/v1"

if not messages:
messages = self.bot.conversation[session_id]

def construct_data(self, messages: list = None, api_key: str = None, stream: bool = True):
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
data = {
'model': self.bot.engine,
'messages': messages,
'stream': True,
'stream': stream,
'temperature': self.bot.temperature,
'top_p': self.bot.top_p,
'presence_penalty': self.bot.presence_penalty,
'frequency_penalty': self.bot.frequency_penalty,
"user": 'user',
'max_tokens': self.bot.get_max_tokens(self.session_id, self.bot.engine),
}
return headers, data

def _prepare_request(self, session_id: str = None, messages: list = None, stream: bool = False):
self.api_info = botManager.pick('openai-api')
api_key = self.api_info.api_key
proxy = self.api_info.proxy
api_endpoint = config.openai.api_endpoint or "https://api.openai.com/v1"

if not messages:
messages = self.bot.conversation[session_id]

headers, data = self.construct_data(messages, api_key, stream)

return api_key, proxy, api_endpoint, headers, data

async def _process_response(self, resp, session_id: str = None):

result = await resp.json()

total_tokens = result.get('usage', {}).get('total_tokens', None)
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 使用 token 数:{total_tokens}")
if total_tokens is None:
raise Exception("Response does not contain 'total_tokens'")

content = result.get('choices', [{}])[0].get('message', {}).get('content', None)
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 响应:{content}")
if content is None:
raise Exception("Response does not contain 'content'")

response_role = result.get('choices', [{}])[0].get('message', {}).get('role', None)
if response_role is None:
raise Exception("Response does not contain 'role'")

self.bot.add_to_conversation(content, response_role, session_id)

return content

async def request(self, session_id: str = None, messages: list = None) -> str:
api_key, proxy, api_endpoint, headers, data = self._prepare_request(session_id, messages, stream=False)

async with aiohttp.ClientSession() as session:
with async_timeout.timeout(self.bot.timeout):
async with session.post(f'{api_endpoint}/chat/completions', headers=headers, data=json.dumps(data),
proxy=proxy) as resp:
async with session.post(f'{api_endpoint}/chat/completions', headers=headers,
data=json.dumps(data)) as resp:
if resp.status != 200:
response_text = await resp.text()
raise Exception(
f"{resp.status} {resp.reason} {response_text}",
)
return await self._process_response(resp, session_id)

async def request_with_stream(self, session_id: str = None, messages: list = None) -> AsyncGenerator[str, None]:
api_key, proxy, api_endpoint, headers, data = self._prepare_request(session_id, messages, stream=True)

async with aiohttp.ClientSession() as session:
with async_timeout.timeout(self.bot.timeout):
async with session.post(f'{api_endpoint}/chat/completions', headers=headers, data=json.dumps(data),
proxy=proxy) as resp:
response_role: str = ''
completion_text: str = ''

Expand Down Expand Up @@ -273,12 +288,15 @@ async def ask(self, prompt: str) -> AsyncGenerator[str, None]:
self.bot.add_to_conversation(prompt, "user", session_id=self.session_id)
start_time = time.time()

async for completion_text in self.request_with_stream(session_id=self.session_id):
yield completion_text
if config.openai.gpt_params.stream:
async for completion_text in self.request_with_stream(session_id=self.session_id):
yield completion_text

token_count = self.bot.count_tokens(self.session_id, self.bot.engine)
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 响应:{completion_text}")
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 使用 token 数:{token_count}")
token_count = self.bot.count_tokens(self.session_id, self.bot.engine)
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 响应:{completion_text}")
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 使用 token 数:{token_count}")
else:
yield await self.request(session_id=self.session_id)
event_time = time.time() - start_time
if event_time is not None:
logger.debug(f"[ChatGPT-API:{self.bot.engine}] 接收到全部消息花费了{event_time:.2f}秒")
Expand Down
1 change: 1 addition & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class OpenAIParams(BaseModel):
min_tokens: int = 1000
compressed_session: bool = False
compressed_tokens: int = 1000
stream: bool = True


class OpenAIAuths(BaseModel):
Expand Down
11 changes: 6 additions & 5 deletions manager/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,15 @@ async def handle_openai(self):
openai.api_base = self.config.openai.api_endpoint or openai.api_base
if openai.api_base.endswith("/"):
openai.api_base.removesuffix("/")
logger.info(f"当前的 api_endpoint 为:{openai.api_base}")

pattern = r'^https://[^/]+/v1$'
if match := re.match(pattern, openai.api_base):
logger.info(f"当前的 api_endpoint 为:{openai.api_base}")
await self.login_openai()
else:

if not re.match(pattern, openai.api_base):
logger.error("API反代地址填写错误,正确格式应为 'https://<网址>/v1'")
raise ValueError("API反代地址填写错误,正确格式应为 'https://<网址>/v1'")

await self.login_openai()


async def login(self):
self.bots = {
Expand Down

0 comments on commit 101e406

Please sign in to comment.