-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
110 additions
and
224 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,165 +1,90 @@ | ||
import asyncio | ||
import json | ||
import re | ||
from contextlib import suppress | ||
from typing import Generator, List | ||
import enum | ||
from typing import Generator, List, Dict | ||
|
||
import aiohttp | ||
from EdgeGPT.EdgeGPT import Chatbot as EdgeChatbot, ConversationStyle, NotAllowedToAccess | ||
from EdgeGPT.ImageGen import ImageGenAsync | ||
from graia.ariadne.message.element import Image as GraiaImage | ||
import openai | ||
from loguru import logger | ||
|
||
import constants | ||
from framework.accounts import account_manager | ||
from framework.drawing import DrawAI | ||
from framework.exceptions import LlmOperationNotSupportedException, LlmRequestTimeoutException, \ | ||
LLmAuthenticationFailedException, DrawingFailedException | ||
from framework.llm.llm import Llm | ||
from framework.llm.microsoft.models import BingCookieAuth | ||
from framework.utils.tokenutils import get_token_count | ||
|
||
image_pattern = r"!\[.*\]\((.*)\)" | ||
|
||
|
||
class BingAdapter(Llm, DrawAI): | ||
cookieData = None | ||
count: int = 0 | ||
class ConversationStyle(enum.Enum): | ||
Creative = 'creative' | ||
Balanced = 'balanced' | ||
Precise = 'precise' | ||
|
||
|
||
class BingAdapter(Llm, DrawAI): | ||
conversation_style: ConversationStyle = None | ||
|
||
bot: EdgeChatbot | ||
"""底层实现""" | ||
account: BingCookieAuth | ||
messages: List[Dict[str, str]] | ||
|
||
def __init__(self, session_id: str = "unknown", conversation_style: ConversationStyle = ConversationStyle.creative): | ||
def __init__(self, session_id: str = "unknown", conversation_style: ConversationStyle = ConversationStyle.Creative): | ||
super().__init__(session_id) | ||
self.account = account_manager.pick('bing') | ||
self.session_id = session_id | ||
self.conversation_style = conversation_style | ||
account = account_manager.pick('bing') | ||
self.cookieData = json.loads(account.cookie_content) | ||
try: | ||
self.bot = EdgeChatbot(cookies=self.cookieData, proxy=constants.proxy) | ||
except NotAllowedToAccess as e: | ||
raise LLmAuthenticationFailedException("bing") from e | ||
self.__conversation_keep_from = 0 | ||
self.messages = [] | ||
self.max_tokens = 7000 | ||
|
||
async def rollback(self): | ||
raise LlmOperationNotSupportedException() | ||
self.messages = self.messages[:-2 or None] | ||
|
||
async def on_destoryed(self): | ||
... | ||
|
||
async def ask(self, prompt: str) -> Generator[str, None, None]: | ||
self.count = self.count + 1 | ||
parsed_content = '' | ||
image_urls = [] | ||
try: | ||
async for final, response in self.bot.ask_stream(prompt=prompt, | ||
conversation_style=self.conversation_style, | ||
wss_link=constants.config.bing.wss_link, | ||
locale="zh-cn"): | ||
if not response: | ||
continue | ||
|
||
if final: | ||
# 最后一条消息 | ||
max_messages = constants.config.bing.max_messages | ||
with suppress(KeyError): | ||
max_messages = response["item"]["throttling"]["maxNumUserMessagesInConversation"] | ||
|
||
with suppress(KeyError): | ||
raw_text = response["item"]["messages"][1]["adaptiveCards"][0]["body"][0]["text"] | ||
image_urls = re.findall(image_pattern, raw_text) | ||
|
||
remaining_conversations = f'\n剩余回复数:{self.count} / {max_messages} ' \ | ||
if constants.config.bing.show_remaining_count else '' | ||
|
||
if len(response["item"].get('messages', [])) > 1 and constants.config.bing.show_suggestions: | ||
suggestions = response["item"]["messages"][-1].get("suggestedResponses", []) | ||
if len(suggestions) > 0: | ||
parsed_content = parsed_content + '\n猜你想问: \n' | ||
for suggestion in suggestions: | ||
parsed_content = f"{parsed_content}* {suggestion.get('text')} \n" | ||
|
||
parsed_content = parsed_content + remaining_conversations | ||
|
||
if parsed_content == remaining_conversations: # No content | ||
yield "Bing 已结束本次会话。继续发送消息将重新开启一个新会话。" | ||
self.count = 0 | ||
await self.bot.reset() | ||
return | ||
else: | ||
# 生成中的消息 | ||
parsed_content = re.sub(r"Searching the web for:(.*)\n", "", response) | ||
parsed_content = re.sub(r"```json(.*)```", "", parsed_content, flags=re.DOTALL) | ||
parsed_content = re.sub(r"Generating answers for you...", "", parsed_content) | ||
if constants.config.bing.show_references: | ||
parsed_content = re.sub(r"\[(\d+)\]: ", r"\1: ", parsed_content) | ||
else: | ||
parsed_content = re.sub(r"(\[\d+]: .+)+", "", parsed_content) | ||
parts = re.split(image_pattern, parsed_content) | ||
# 图片单独保存 | ||
parsed_content = parts[0] | ||
|
||
if len(parts) > 2: | ||
parsed_content = parsed_content + parts[-1] | ||
|
||
yield parsed_content | ||
logger.debug(f"[Bing AI 响应] {parsed_content}") | ||
image_tasks = [ | ||
asyncio.create_task(self.__download_image(url)) | ||
for url in image_urls | ||
] | ||
for image in await asyncio.gather(*image_tasks): | ||
yield image | ||
except (asyncio.exceptions.TimeoutError, asyncio.exceptions.CancelledError) as e: | ||
raise LlmRequestTimeoutException("bing") from e | ||
except NotAllowedToAccess as e: | ||
raise LLmAuthenticationFailedException("bing") from e | ||
except Exception as e: | ||
if str(e) == 'Redirect failed': | ||
raise DrawingFailedException() from e | ||
raise e | ||
|
||
async def text_to_img(self, prompt: str): | ||
logger.debug(f"[Bing Image] Prompt: {prompt}") | ||
try: | ||
async with ImageGenAsync( | ||
all_cookies=self.bot.chat_hub.cookies, | ||
quiet=True | ||
) as image_generator: | ||
images = await image_generator.get_images(prompt) | ||
|
||
logger.debug(f"[Bing Image] Response: {images}") | ||
tasks = [asyncio.create_task(self.__download_image(image)) for image in images] | ||
return await asyncio.gather(*tasks) | ||
except Exception as e: | ||
if str(e) == 'Redirect failed': | ||
raise DrawingFailedException() from e | ||
raise e | ||
|
||
async def img_to_img(self, init_images: List[GraiaImage], prompt=''): | ||
return await self.text_to_img(prompt) | ||
|
||
async def __download_image(self, url) -> GraiaImage: | ||
logger.debug(f"[Bing AI] 下载图片:{url}") | ||
|
||
async with aiohttp.ClientSession() as session: | ||
async with session.get(url, proxy=self.bot.proxy) as resp: | ||
resp.raise_for_status() | ||
logger.debug(f"[Bing AI] 下载完成:{resp.content_type} {url}") | ||
return GraiaImage(data_bytes=await resp.read()) | ||
async def ask(self, msg: str) -> Generator[str, None, None]: | ||
"""向 AI 发送消息""" | ||
self.messages.append({"role": "user", "content": msg}) | ||
full_chunk = [] | ||
full_text = '' | ||
while self.max_tokens - get_token_count('gpt-4', self.messages) < 0 and \ | ||
len(self.messages) > self.__conversation_keep_from: | ||
self.messages.pop(self.__conversation_keep_from) | ||
logger.debug( | ||
f"清理 token,历史记录遗忘后使用 token 数:{str(get_token_count('gpt-4', self.messages))}" | ||
) | ||
async for chunk in await openai.ChatCompletion.acreate( | ||
model=f'bing-{self.conversation_style.value}', | ||
messages=self.messages, | ||
stream=True, | ||
api_base="https://llm-proxy.lss233.com/bing/v1", | ||
api_key="sk-274a8645fd3clss233achatgptfor0botfe", | ||
headers=self.account.build_headers() | ||
): | ||
logger.info(chunk.choices[0].delta) | ||
full_chunk.append(chunk.choices[0].delta) | ||
full_text = ''.join([m.get('content', '') for m in full_chunk]) | ||
yield full_text | ||
logger.debug(f"[Bing-{self.conversation_style.value}] {self.session_id} - {full_text}") | ||
self.messages.append({"role": "assistant", "content": full_text}) | ||
|
||
# async def __download_image(self, url) -> GraiaImage: | ||
# logger.debug(f"[Bing AI] 下载图片:{url}") | ||
# | ||
# async with aiohttp.ClientSession() as session: | ||
# async with session.get(url, proxy=self.bot.proxy) as resp: | ||
# resp.raise_for_status() | ||
# logger.debug(f"[Bing AI] 下载完成:{resp.content_type} {url}") | ||
# return GraiaImage(data_bytes=await resp.read()) | ||
|
||
@classmethod | ||
def register(cls): | ||
account_manager.register_type("bing", BingCookieAuth) | ||
|
||
async def preset_ask(self, role: str, text: str): | ||
if role.endswith('bot') or role in {'assistant', 'bing'}: | ||
logger.debug(f"[预设] 响应:{text}") | ||
yield text | ||
else: | ||
logger.debug(f"[预设] 发送:{text}") | ||
item = None | ||
async for item in self.ask(text): | ||
pass | ||
if item: | ||
logger.debug(f"[预设] Chatbot 回应:{item}") | ||
async def preset_ask(self, role: str, prompt: str): | ||
if role.endswith('bot') or role in {'assistant', 'chatgpt'}: | ||
logger.debug(f"[预设] 响应:{prompt}") | ||
yield prompt | ||
role = 'assistant' | ||
if role not in ['assistant', 'user', 'system']: | ||
raise ValueError(f"预设文本有误!仅支持设定 assistant、user 或 system 的预设文本,但你写了{role}。") | ||
self.messages.append({"role": role, "content": prompt}) | ||
self.__conversation_keep_from = len(self.messages) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.