Skip to content

Commit

Permalink
feat: OpenAI图片模型
Browse files Browse the repository at this point in the history
  • Loading branch information
liuruibin committed Nov 5, 2024
1 parent ddad340 commit 3a1728c
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# coding=utf-8
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)

def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

for key in ['api_base', 'api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.check_auth()
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True

def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

def get_model_params_setting_form(self, model_name):
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import base64
import os
from typing import Dict

from openai import OpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_image import BaseImage


def custom_get_token_ids(text: str):
tokenizer = TokenizerManage.get_tokenizer()
return tokenizer.encode(text)


class OpenAIImage(MaxKBBaseModel, BaseImage):
api_base: str
api_key: str
model: str

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.api_key = kwargs.get('api_key')
self.api_base = kwargs.get('api_base')

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {}
if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None:
optional_params['max_tokens'] = model_kwargs['max_tokens']
if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None:
optional_params['temperature'] = model_kwargs['temperature']
return OpenAIImage(
model=model_name,
api_base=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
**optional_params,
)

def check_auth(self):
client = OpenAI(
base_url=self.api_base,
api_key=self.api_key
)
response_list = client.models.with_raw_response.list()
# print(response_list)
# cwd = os.path.dirname(os.path.abspath(__file__))
# with open(f'{cwd}/img_1.png', 'rb') as f:
# self.image_understand(f, "一句话概述这个图片")

def image_understand(self, image_file, text):
client = OpenAI(
base_url=self.api_base,
api_key=self.api_key
)
base64_image = base64.b64encode(image_file.read()).decode('utf-8')

response = client.chat.completions.create(
model=self.model,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": text,
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
},
},
],
}
],
)
return response.choices[0].message.content
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
from setting.models_provider.impl.openai_model_provider.credential.image import OpenAIImageModelCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.openai_model_provider.model.image import OpenAIImage
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
Expand All @@ -24,6 +26,7 @@
openai_llm_model_credential = OpenAILLMModelCredential()
openai_stt_model_credential = OpenAISTTModelCredential()
openai_tts_model_credential = OpenAITTSModelCredential()
openai_image_model_credential = OpenAIImageModelCredential()
model_info_list = [
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
Expand Down Expand Up @@ -88,11 +91,20 @@
OpenAIEmbeddingModel)
]

model_info_image_list = [
ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新',
ModelTypeConst.IMAGE, openai_image_model_credential,
OpenAIImage),
ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新',
ModelTypeConst.IMAGE, openai_image_model_credential,
OpenAIImage),
]

model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
)).append_model_info_list(model_info_embedding_list).append_default_model_info(
model_info_embedding_list[0]).build()
model_info_embedding_list[0]).append_model_info_list(model_info_image_list).build()


class OpenAIModelProvider(IModelProvider):
Expand Down

0 comments on commit 3a1728c

Please sign in to comment.