From 101e406af0dc59e0e25d28db62f2edffc94e92ff Mon Sep 17 00:00:00 2001
From: "sourcery-ai[bot]" <58596630+sourcery-ai[bot]@users.noreply.github.com>
Date: Wed, 12 Jul 2023 23:43:33 +0800
Subject: [PATCH] No longer force API checks; add non-streaming API requests
 (Sourcery refactored) (#1048)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Fix the issue where HTTP mode could not be turned off once enabled

* Fix some API errors

Fixed the issue where loading a preset reset engine to None

* No longer force API checks; add non-streaming API requests

* Format files

* 'Refactored by Sourcery'

---------

Co-authored-by: Matt Gideon <117586514+Haibersut@users.noreply.github.com>
Co-authored-by: Sourcery AI <>
---
 adapter/chatgpt/api.py | 112 ++++++++++++++++++++++++-----------------
 config.py              |   1 +
 manager/bot.py         |  11 ++--
 3 files changed, 72 insertions(+), 52 deletions(-)

diff --git a/adapter/chatgpt/api.py b/adapter/chatgpt/api.py
index 5343568a..b4c2b462 100644
--- a/adapter/chatgpt/api.py
+++ b/adapter/chatgpt/api.py
@@ -1,16 +1,14 @@
 import json
 import time
-from typing import AsyncGenerator
-
 import aiohttp
 import async_timeout
-
+import tiktoken
 from loguru import logger
+from typing import AsyncGenerator
 
 from adapter.botservice import BotAdapter
 from config import OpenAIAPIKey
 from constants import botManager, config
-import tiktoken
 
 DEFAULT_ENGINE: str = "gpt-3.5-turbo"
 
@@ -53,12 +51,7 @@ async def rollback(self, session_id: str = "default", n: int = 1) -> None:
             logger.error(f"未知错误: {e}")
             raise
 
-    def add_to_conversation(
-        self,
-        message: str,
-        role: str,
-        session_id: str = "default",
-    ) -> None:
+    def add_to_conversation(self, message: str, role: str, session_id: str = "default") -> None:
         self.conversation[session_id].append({"role": role, "content": message})
 
     # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@@ -71,26 +64,8 @@ def count_tokens(self, session_id: str = "default", model: str = DEFAULT_ENGINE)
         except KeyError:
             encoding = tiktoken.get_encoding("cl100k_base")
 
-        if model in {
-            "gpt-3.5-turbo-0613",
-            "gpt-3.5-turbo-16k-0613",
-            "gpt-4-0314",
-            "gpt-4-32k-0314",
-            "gpt-4-0613",
-            "gpt-4-32k-0613",
-            "gpt-3.5-turbo",
-            "gpt-4",
-            "gpt-4-32k"
-        }:
-            tokens_per_message = 3
-            tokens_per_name = 1
-        elif model == "gpt-3.5-turbo-0301":
-            tokens_per_message = 4  # every message follows {role/name}\n{content}\n
-            tokens_per_name = -1  # if there's a name, the role is omitted
-        else:
-            logger.warning("未找到相应模型计算方法,使用默认方法进行计算")
-            tokens_per_message = 3
-            tokens_per_name = 1
+        tokens_per_message = 4
+        tokens_per_name = 1
 
         num_tokens = 0
         for message in self.conversation[session_id]:
@@ -168,15 +143,7 @@ async def on_reset(self):
         self.bot.engine = self.api_info.model
         self.__conversation_keep_from = 0
 
-    async def request_with_stream(self, session_id: str = None, messages: list = None) -> AsyncGenerator[str, None]:
-        self.api_info = botManager.pick('openai-api')
-        api_key = self.api_info.api_key
-        proxy = self.api_info.proxy
-        api_endpoint = config.openai.api_endpoint or "https://api.openai.com/v1"
-
-        if not messages:
-            messages = self.bot.conversation[session_id]
-
+    def construct_data(self, messages: list = None, api_key: str = None, stream: bool = True):
        headers = {
             'Content-Type': 'application/json',
             'Authorization': f'Bearer {api_key}'
@@ -184,7 +151,7 @@ async def request_with_stream(self, session_id: str = None, messages: list = Non
         data = {
             'model': self.bot.engine,
             'messages': messages,
-            'stream': True,
+            'stream': stream,
             'temperature': self.bot.temperature,
             'top_p': self.bot.top_p,
             'presence_penalty': self.bot.presence_penalty,
@@ -192,16 +159,64 @@ async def request_with_stream(self, session_id: str = None, messages: list = Non
             "user": 'user',
             'max_tokens': self.bot.get_max_tokens(self.session_id, self.bot.engine),
         }
+        return headers, data
+
+    def _prepare_request(self, session_id: str = None, messages: list = None, stream: bool = False):
+        self.api_info = botManager.pick('openai-api')
+        api_key = self.api_info.api_key
+        proxy = self.api_info.proxy
+        api_endpoint = config.openai.api_endpoint or "https://api.openai.com/v1"
+
+        if not messages:
+            messages = self.bot.conversation[session_id]
+
+        headers, data = self.construct_data(messages, api_key, stream)
+
+        return api_key, proxy, api_endpoint, headers, data
+
+    async def _process_response(self, resp, session_id: str = None):
+
+        result = await resp.json()
+
+        total_tokens = result.get('usage', {}).get('total_tokens', None)
+        logger.debug(f"[ChatGPT-API:{self.bot.engine}] 使用 token 数:{total_tokens}")
+        if total_tokens is None:
+            raise Exception("Response does not contain 'total_tokens'")
+
+        content = result.get('choices', [{}])[0].get('message', {}).get('content', None)
+        logger.debug(f"[ChatGPT-API:{self.bot.engine}] 响应:{content}")
+        if content is None:
+            raise Exception("Response does not contain 'content'")
+
+        response_role = result.get('choices', [{}])[0].get('message', {}).get('role', None)
+        if response_role is None:
+            raise Exception("Response does not contain 'role'")
+
+        self.bot.add_to_conversation(content, response_role, session_id)
+
+        return content
+
+    async def request(self, session_id: str = None, messages: list = None) -> str:
+        api_key, proxy, api_endpoint, headers, data = self._prepare_request(session_id, messages, stream=False)
+
         async with aiohttp.ClientSession() as session:
             with async_timeout.timeout(self.bot.timeout):
-                async with session.post(f'{api_endpoint}/chat/completions', headers=headers, data=json.dumps(data),
-                                        proxy=proxy) as resp:
+                async with session.post(f'{api_endpoint}/chat/completions', headers=headers,
+                                        data=json.dumps(data)) as resp:
                     if resp.status != 200:
                         response_text = await resp.text()
                         raise Exception(
                             f"{resp.status} {resp.reason} {response_text}",
                         )
+                    return await self._process_response(resp, session_id)
+
+    async def request_with_stream(self, session_id: str = None, messages: list = None) -> AsyncGenerator[str, None]:
+        api_key, proxy, api_endpoint, headers, data = self._prepare_request(session_id, messages, stream=True)
 
+        async with aiohttp.ClientSession() as session:
+            with async_timeout.timeout(self.bot.timeout):
+                async with session.post(f'{api_endpoint}/chat/completions', headers=headers, data=json.dumps(data),
+                                        proxy=proxy) as resp:
 
                     response_role: str = ''
                     completion_text: str = ''
@@ -273,12 +288,15 @@ async def ask(self, prompt: str) -> AsyncGenerator[str, None]:
         self.bot.add_to_conversation(prompt, "user", session_id=self.session_id)
 
         start_time = time.time()
-        async for completion_text in self.request_with_stream(session_id=self.session_id):
-            yield completion_text
+        if config.openai.gpt_params.stream:
+            async for completion_text in self.request_with_stream(session_id=self.session_id):
+                yield completion_text
 
-        token_count = self.bot.count_tokens(self.session_id, self.bot.engine)
-        logger.debug(f"[ChatGPT-API:{self.bot.engine}] 响应:{completion_text}")
-        logger.debug(f"[ChatGPT-API:{self.bot.engine}] 使用 token 数:{token_count}")
+            token_count = self.bot.count_tokens(self.session_id, self.bot.engine)
+            logger.debug(f"[ChatGPT-API:{self.bot.engine}] 响应:{completion_text}")
+            logger.debug(f"[ChatGPT-API:{self.bot.engine}] 使用 token 数:{token_count}")
+        else:
+            yield await self.request(session_id=self.session_id)
         event_time = time.time() - start_time
         if event_time is not None:
             logger.debug(f"[ChatGPT-API:{self.bot.engine}] 接收到全部消息花费了{event_time:.2f}秒")
diff --git a/config.py b/config.py
index c428707c..ef77bbee 100644
--- a/config.py
+++ b/config.py
@@ -85,6 +85,7 @@ class OpenAIParams(BaseModel):
     min_tokens: int = 1000
     compressed_session: bool = False
     compressed_tokens: int = 1000
+    stream: bool = True
 
 
 class OpenAIAuths(BaseModel):
diff --git a/manager/bot.py b/manager/bot.py
index 795a4db4..f04154dd 100644
--- a/manager/bot.py
+++ b/manager/bot.py
@@ -120,14 +120,15 @@ async def handle_openai(self):
         openai.api_base = self.config.openai.api_endpoint or openai.api_base
         if openai.api_base.endswith("/"):
             openai.api_base.removesuffix("/")
+        logger.info(f"当前的 api_endpoint 为:{openai.api_base}")
 
         pattern = r'^https://[^/]+/v1$'
-        if match := re.match(pattern, openai.api_base):
-            logger.info(f"当前的 api_endpoint 为:{openai.api_base}")
-            await self.login_openai()
-        else:
+
+        if not re.match(pattern, openai.api_base):
             logger.error("API反代地址填写错误,正确格式应为 'https://<网址>/v1'")
-            raise ValueError("API反代地址填写错误,正确格式应为 'https://<网址>/v1'")
+
+        await self.login_openai()
+
 
     async def login(self):
         self.bots = {
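
Note for reviewers trying the change locally: with this patch applied, config.openai.gpt_params.stream (default True) decides whether ask() streams the reply in chunks via request_with_stream() or returns it in one shot via the new non-streaming request(). Below is a minimal sketch, not part of the patch: it assumes the adapter class in adapter/chatgpt/api.py is ChatGPTAPIAdapter, that it accepts a session_id argument, and that the configured accounts let botManager register an 'openai-api' key.

    import asyncio

    from constants import botManager, config
    from adapter.chatgpt.api import ChatGPTAPIAdapter  # class name assumed; the diff does not show it


    async def demo() -> None:
        # Assumes config.cfg provides at least one OpenAI API key so login() can register it.
        await botManager.login()

        # Toggle the new non-streaming path added by this patch.
        config.openai.gpt_params.stream = False

        adapter = ChatGPTAPIAdapter(session_id="demo")  # session id argument assumed from the rest of the project
        async for reply in adapter.ask("ping"):
            # With stream=False this loop runs once with the finished reply;
            # with stream=True it yields partial text as it arrives.
            print(reply)


    asyncio.run(demo())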