refactor API done
kennymckormick committed Apr 2, 2024
1 parent 0b6b0e5 commit bd20c15
Showing 11 changed files with 149 additions and 152 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -24,6 +24,7 @@ torch>=2.0.1
tqdm
transformers
typing_extensions==4.7.1
validators
visual_genome
xlsxwriter
xtuner
53 changes: 44 additions & 9 deletions vlmeval/api/base.py
@@ -1,11 +1,14 @@
import time
import random as rd
from abc import abstractmethod
from ..smp import get_logger
import os.path as osp
from ..smp import get_logger, parse_file


class BaseAPI:

allowed_types = ['text', 'image']

def __init__(self,
retry=10,
wait=3,
@@ -41,15 +44,47 @@ def working(self):
retry -= 1
return False

def check_content(self, msgs):
if isinstance(msgs, str):
return 'str'
if isinstance(msgs, dict):
return 'dict'
if isinstance(msgs, list):
types = [self.check_content(m) for m in msgs]
if all(t == 'str' for t in types):
return 'liststr'
if all(t == 'dict' for t in types):
return 'listdict'
return 'unknown'

def preproc_content(self, inputs):
if self.check_content(inputs) == 'str':
return [dict(type='text', value=inputs)]
elif self.check_content(inputs) == 'dict':
assert 'type' in inputs and 'value' in inputs
return [inputs]
elif self.check_content(inputs) == 'liststr':
res = []
for s in inputs:
mime, pth = parse_file(s)
if mime is None or mime == 'unknown':
res.append(dict(type='text', value=s))
else:
res.append(dict(type=mime.split('/')[0], value=pth))
return res
elif self.check_content(inputs) == 'listdict':
for item in inputs:
assert 'type' in item and 'value' in item
return inputs
else:
return None

def generate(self, inputs, **kwargs):
input_type = None
if isinstance(inputs, str):
input_type = 'str'
elif isinstance(inputs, list) and isinstance(inputs[0], str):
input_type = 'strlist'
elif isinstance(inputs, list) and isinstance(inputs[0], dict):
input_type = 'dictlist'
assert input_type is not None, input_type
assert self.check_content(inputs) in ['str', 'dict', 'liststr', 'listdict'], f'Invalid input type: {inputs}'
inputs = self.preproc_content(inputs)
assert inputs is not None and self.check_content(inputs) == 'listdict'
for item in inputs:
assert item['type'] in self.allowed_types, f'Invalid input type: {item["type"]}'

answer = None
# a very small random delay [0s - 0.5s]
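
The upshot of this refactor: every wrapper now funnels its input through check_content/preproc_content, so a bare string, a list of strings, or a list of dicts all normalize to one canonical form, a list of {'type': ..., 'value': ...} dicts. A minimal self-contained sketch of that normalization (the real parse_file helper lives in vlmeval.smp; the stand-in below is an assumption for illustration):

import mimetypes
import os.path as osp

def parse_file(s):
    # Hypothetical stand-in for vlmeval.smp.parse_file: returns
    # (mime, path) for an existing file, (None, s) for plain text.
    if osp.exists(s):
        mime, _ = mimetypes.guess_type(s)
        return mime or 'unknown', s
    return None, s

def preproc_content(inputs):
    if isinstance(inputs, str):   # 'str' input -> treat as single-item list
        inputs = [inputs]
    res = []
    for s in inputs:
        if isinstance(s, dict):   # already in normalized form
            res.append(s)
            continue
        mime, pth = parse_file(s)
        if mime is None or mime == 'unknown':
            res.append(dict(type='text', value=s))
        else:
            res.append(dict(type=mime.split('/')[0], value=pth))  # 'image/png' -> 'image'
    return res

open('cat.png', 'wb').close()     # placeholder file so parse_file sees an image
print(preproc_content(['cat.png', 'What animal is this?']))
# [{'type': 'image', 'value': 'cat.png'}, {'type': 'text', 'value': 'What animal is this?'}]
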
36 changes: 16 additions & 20 deletions vlmeval/api/claude.py
@@ -2,6 +2,7 @@
from vlmeval.api.base import BaseAPI
from time import sleep
import base64
import mimetypes

url = 'https://openxlab.org.cn/gw/alles-apin-hub/v1/claude/v1/text/chat'
headers = {
@@ -37,29 +38,27 @@ def __init__(self,

super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs)

@staticmethod
def build_msgs(msgs_raw):
def build_msgs(self, msgs_raw):

messages = []
message = {'role': 'user', 'content': []}
for msg in msgs_raw:
if isimg(msg):
media_type_map = {
'jpg': 'image/jpeg',
'jpeg': 'image/jpeg',
'png': 'image/png',
'gif': 'image/gif',
'webp': 'iamge/webp'
}
media_type = media_type_map[msg.split('.')[-1].lower()]
with open(msg, 'rb') as file:
image_data = base64.b64encode(file.read()).decode('utf-8')
if msg['type'] == 'image':
pth = msg['value']
suffix = osp.splitext(pth)[-1].lower()
media_type = mimetypes.types_map.get(suffix, None)
assert media_type is not None

item = {
'type': 'image',
'source': {'type': 'base64', 'media_type': media_type, 'data': image_data}
'source': {'type': 'base64', 'media_type': media_type, 'data': encode_image_file_to_base64(pth)}
}

elif msg['type'] == 'text':
item = {'type': 'text', 'text': msg['value']}
else:
item = {'type': 'text', 'text': msg}
raise NotImplementedError(f'Unsupported message type: {msg["type"]}')

message['content'].append(item)
messages.append(message)
return messages
@@ -96,8 +95,5 @@ def generate_inner(self, inputs, **kwargs) -> str:

class Claude3V(Claude_Wrapper):

def generate(self, image_path, prompt, dataset=None):
return super(Claude_Wrapper, self).generate([image_path, prompt])

def interleave_generate(self, ti_list, dataset=None):
return super(Claude_Wrapper, self).generate(ti_list)
def generate(self, msgs, dataset=None):
return super(Claude_Wrapper, self).generate(msgs)
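
Swapping the hand-maintained suffix map (which carried the 'iamge/webp' typo) for the standard mimetypes table delegates media-type lookup to the stdlib. A quick check of what that lookup yields:

import mimetypes
import os.path as osp

for pth in ['a.jpg', 'a.jpeg', 'a.png', 'a.gif', 'a.webp']:
    suffix = osp.splitext(pth)[-1].lower()
    print(suffix, '->', mimetypes.types_map.get(suffix))
# .jpg -> image/jpeg, .jpeg -> image/jpeg, .png -> image/png, .gif -> image/gif
# .webp -> image/webp on Python >= 3.11; older interpreters may return None,
# which the assert in build_msgs would surface immediately.
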
44 changes: 11 additions & 33 deletions vlmeval/api/gemini.py
@@ -30,19 +30,14 @@ def __init__(self,
proxy_set(proxy)
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)

@staticmethod
def build_msgs(msgs_raw, system_prompt=None):
msgs = cp.deepcopy(msgs_raw)
assert len(msgs) % 2 == 1

if system_prompt is not None:
msgs[0] = [system_prompt, msgs[0]]
ret = []
for i, msg in enumerate(msgs):
role = 'user' if i % 2 == 0 else 'model'
parts = msg if isinstance(msg, list) else [msg]
ret.append(dict(role=role, parts=parts))
return ret
def build_msgs(self, inputs):
messages = [] if self.system_prompt is None else [self.system_prompt]
for inp in inputs:
if inp['type'] == 'text':
messages.append(inp['value'])
elif inp['type'] == 'image':
messages.append(Image.open(inp['value']))
return messages

def generate_inner(self, inputs, **kwargs) -> str:
import google.generativeai as genai
@@ -54,21 +49,7 @@ def generate_inner(self, inputs, **kwargs) -> str:
pure_text = False
genai.configure(api_key=self.api_key)
model = genai.GenerativeModel('gemini-pro') if pure_text else genai.GenerativeModel('gemini-pro-vision')
if isinstance(inputs, str):
messages = [inputs] if self.system_prompt is None else [self.system_prompt, inputs]
elif pure_text:
messages = self.build_msgs(inputs, self.system_prompt)
else:
messages = [] if self.system_prompt is None else [self.system_prompt]
for s in inputs:
if osp.exists(s):
messages.append(Image.open(s))
elif s.startswith('http'):
pth = download_file(s)
messages.append(Image.open(pth))
shutil.remove(pth)
else:
messages.append(s)
messages = self.build_msgs(inputs)
gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
gen_config.update(self.kwargs)
try:
@@ -84,8 +65,5 @@ def generate_inner(self, inputs, **kwargs) -> str:

class GeminiProVision(GeminiWrapper):

def generate(self, image_path, prompt, dataset=None):
return super(GeminiProVision, self).generate([image_path, prompt])

def interleave_generate(self, ti_list, dataset=None):
return super(GeminiProVision, self).generate(ti_list)
def interleave_generate(self, msgs, dataset=None):
return super(GeminiProVision, self).generate(msgs)
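
The rewritten build_msgs leans on the google-generativeai convention that a content list may interleave plain strings with PIL images, which removes the old download/cleanup branching. A minimal sketch of the list it produces (demo.png is a placeholder generated on the spot):

from PIL import Image

Image.new('RGB', (8, 8), 'white').save('demo.png')   # placeholder image

inputs = [
    dict(type='image', value='demo.png'),
    dict(type='text', value='Describe this image.'),
]
messages = []
for inp in inputs:
    if inp['type'] == 'text':
        messages.append(inp['value'])
    elif inp['type'] == 'image':
        messages.append(Image.open(inp['value']))
# messages now holds [<PIL.Image.Image>, 'Describe this image.'] and can be
# passed to model.generate_content(messages) as-is.
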
61 changes: 21 additions & 40 deletions vlmeval/api/gpt.py
@@ -26,7 +26,7 @@ def GPT_context_window(model):
if model in length_map:
return length_map[model]
else:
return 4096
return 128000


class OpenAIWrapper(BaseAPI):
@@ -91,39 +91,24 @@ def prepare_inputs(self, inputs):
input_msgs = []
if self.system_prompt is not None:
input_msgs.append(dict(role='system', content=self.system_prompt))
if isinstance(inputs, str):
input_msgs.append(dict(role='user', content=inputs))
return input_msgs
assert isinstance(inputs, list)
dict_flag = [isinstance(x, dict) for x in inputs]
if np.all(dict_flag):
input_msgs.extend(inputs)
return input_msgs
str_flag = [isinstance(x, str) for x in inputs]
if np.all(str_flag):
img_flag = [x.startswith('http') or osp.exists(x) for x in inputs]
if np.any(img_flag):
content_list = []
for fl, msg in zip(img_flag, inputs):
if not fl:
content_list.append(dict(type='text', text=msg))
elif msg.startswith('http'):
content_list.append(dict(type='image_url', image_url={'url': msg, 'detail': self.img_detail}))
elif osp.exists(msg):
from PIL import Image
img = Image.open(msg)
b64 = encode_image_to_base64(img, target_size=self.img_size)
img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail)
content_list.append(dict(type='image_url', image_url=img_struct))
input_msgs.append(dict(role='user', content=content_list))
return input_msgs
else:
roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user']
roles = roles * len(inputs)
for role, msg in zip(roles, inputs):
input_msgs.append(dict(role=role, content=msg))
return input_msgs
raise NotImplementedError('list of list prompt not implemented now. ')
has_images = np.sum([x['type'] == 'image' for x in inputs])
if has_images:
content_list = []
for msg in inputs:
if msg['type'] == 'text':
content_list.append(dict(type='text', text=msg['value']))
elif msg['type'] == 'image':
from PIL import Image
img = Image.open(msg['value'])
b64 = encode_image_to_base64(img, target_size=self.img_size)
img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail)
content_list.append(dict(type='image_url', image_url=img_struct))
input_msgs.append(dict(role='user', content=content_list))
else:
assert all([x['type'] == 'text' for x in inputs])
text = '\n'.join([x['value'] for x in inputs])
input_msgs.append(dict(role='user', content=text))
return input_msgs

def generate_inner(self, inputs, **kwargs) -> str:
input_msgs = self.prepare_inputs(inputs)
@@ -182,10 +167,6 @@ def get_token_len(self, inputs) -> int:

class GPT4V(OpenAIWrapper):

def generate(self, image_path, prompt, dataset=None):
assert self.model == 'gpt-4-vision-preview'
return super(GPT4V, self).generate([image_path, prompt])

def interleave_generate(self, ti_list, dataset=None):
def generate(self, msgs, dataset=None):
assert self.model == 'gpt-4-vision-preview'
return super(GPT4V, self).generate(ti_list)
return super(GPT4V, self).generate(msgs)
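
prepare_inputs now branches on whether the normalized list contains any image items; if so, every image is inlined as a base64 data URI in OpenAI's content-list format. A self-contained sketch of the payload shape, with encode_image_to_base64 from vlmeval.smp inlined by hand for illustration:

import base64
import io
from PIL import Image

img = Image.new('RGB', (32, 32), 'blue')             # placeholder image
buf = io.BytesIO()
img.save(buf, format='JPEG')
b64 = base64.b64encode(buf.getvalue()).decode('utf-8')

input_msgs = [dict(role='user', content=[
    dict(type='text', text='What color is this square?'),
    dict(type='image_url',
         image_url=dict(url=f'data:image/jpeg;base64,{b64}', detail='low')),
])]
# This is the structure sent as the 'messages' field of a
# chat-completions request for gpt-4-vision-preview.
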
8 changes: 2 additions & 6 deletions vlmeval/api/gpt_int.py
@@ -90,10 +90,6 @@ def generate_inner(self, inputs, **kwargs) -> str:

class GPT4V_Internal(OpenAIWrapperInternal):

def generate(self, image_path, prompt, dataset=None):
def generate(self, msgs, dataset=None):
assert self.model == 'gpt-4-vision-preview'
return super(GPT4V_Internal, self).generate([image_path, prompt])

def interleave_generate(self, ti_list, dataset=None):
assert self.model == 'gpt-4-vision-preview'
return super(GPT4V_Internal, self).generate(ti_list)
return super(GPT4V_Internal, self).generate(msgs)
1 change: 1 addition & 0 deletions vlmeval/api/qwen_api.py
@@ -4,6 +4,7 @@
from vlmeval.smp import *


# Note: This is a pure language model API.
class QwenAPI(BaseAPI):

is_api: bool = True
25 changes: 8 additions & 17 deletions vlmeval/api/qwen_vl_api.py
@@ -43,24 +43,18 @@ def build_msgs(msgs_raw, system_prompt=None):
content = [dict(text=system_prompt)]
ret.append(dict(role='system', content=content))
content = []
for i, msg in enumerate(msgs):
if osp.exists(msg):
content.append(dict(image='file://' + msg))
elif msg.startswith('http'):
content.append(dict(image=msg))
else:
content.append(dict(text=msg))
for msg in msgs:
if msg['type'] == 'text':
content.append(dict(text=msg['value']))
elif msg['type'] == 'image':
content.append(dict(image='file://' + msg['value']))
ret.append(dict(role='user', content=content))
return ret

def generate_inner(self, inputs, **kwargs) -> str:
from dashscope import MultiModalConversation
assert isinstance(inputs, str) or isinstance(inputs, list)
pure_text = True
if isinstance(inputs, list):
for pth in inputs:
if osp.exists(pth) or pth.startswith('http'):
pure_text = False
pure_text = np.all([x['type'] == 'text' for x in inputs])
assert not pure_text
messages = self.build_msgs(msgs_raw=inputs, system_prompt=self.system_prompt)
gen_config = dict(max_output_tokens=self.max_tokens, temperature=self.temperature)
@@ -81,8 +75,5 @@ def generate_inner(self, inputs, **kwargs) -> str:

class QwenVLAPI(QwenVLWrapper):

def generate(self, image_path, prompt, dataset=None):
return super(QwenVLAPI, self).generate([image_path, prompt])

def interleave_generate(self, ti_list, dataset=None):
return super(QwenVLAPI, self).generate(ti_list)
def generate(self, msgs, dataset=None):
return super(QwenVLAPI, self).generate(msgs)
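
For dashscope's MultiModalConversation, the new build_msgs emits 'file://' URIs for local images and bare 'text' entries for prose. A sketch of the resulting message structure (paths are placeholders):

inputs = [
    dict(type='image', value='/tmp/demo.png'),
    dict(type='text', value='What is in the picture?'),
]
content = []
for msg in inputs:
    if msg['type'] == 'text':
        content.append(dict(text=msg['value']))
    elif msg['type'] == 'image':
        content.append(dict(image='file://' + msg['value']))
messages = [dict(role='user', content=content)]
# -> [{'role': 'user', 'content': [{'image': 'file:///tmp/demo.png'},
#                                  {'text': 'What is in the picture?'}]}]
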