refactor VL modules for internvl and qwen2-vl (#2764)
* qwen2-vl

* internvl

* qwen2
lvhan028 authored Nov 18, 2024
1 parent 0c80baa commit 464d451
Showing 6 changed files with 586 additions and 249 deletions.
233 changes: 140 additions & 93 deletions lmdeploy/serve/vl_async_engine.py
@@ -1,14 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from typing import Dict, List, Optional, Union

import numpy as np

from lmdeploy.pytorch.check_env import try_import_deeplink
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.utils import get_logger
from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX, IMAGE_TOKEN
from lmdeploy.vl.engine import ImageEncoder
from lmdeploy.vl.templates import VLPromptType, get_vl_prompt_template
from lmdeploy.vl.utils import load_image

logger = get_logger('lmdeploy')

@@ -19,9 +18,11 @@ class VLAsyncEngine(AsyncEngine):
def __init__(self, model_path: str, **kwargs) -> None:
vision_config = kwargs.pop('vision_config', None)
backend_config = kwargs.get('backend_config', None)
-if kwargs.get('backend', '') == 'pytorch':
+self.backend = kwargs['backend']
+if self.backend == 'pytorch':
try_import_deeplink(backend_config.device_type)
self.vl_encoder = ImageEncoder(model_path,
self.backend,
vision_config,
backend_config=backend_config)
super().__init__(model_path, **kwargs)
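For reference, the usual way to exercise this engine is through lmdeploy.pipeline, which builds a VLAsyncEngine when it is given a vision-language model. A minimal sketch follows; the model id and image URL are placeholders and are not part of this commit:

from lmdeploy import pipeline
from lmdeploy.vl import load_image

# placeholder model id; any supported VLM such as an InternVL or Qwen2-VL
# checkpoint can be used here
pipe = pipeline('OpenGVLab/InternVL2-8B')
image = load_image('https://example.com/demo.jpg')  # placeholder URL
response = pipe(('describe this image', image))
print(response.text)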
@@ -35,7 +36,7 @@ def __init__(self, model_path: str, **kwargs) -> None:
def _convert_prompts(self,
prompts: Union[VLPromptType, List[Dict],
List[VLPromptType], List[List[Dict]]]):
"""convert prompts to openai format."""
"""convert prompts to openai GPT4V format."""
if isinstance(prompts, str) or isinstance(prompts, tuple):
_prompts = self.vl_prompt_template.prompt_to_messages(prompts)
elif isinstance(prompts[0], tuple) or isinstance(prompts[0], str):
@@ -47,102 +48,148 @@ def _convert_prompts(self,
return _prompts
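For context, the tuple form accepted here is a shorthand for the GPT4V-style message list used by the rest of the engine. A sketch of the equivalence follows; the exact key names produced by the prompt template are not shown in this diff, so treat them as an assumption:

from PIL import Image

image = Image.new('RGB', (64, 64))

# shorthand prompt form: a (text, image) tuple
prompt = ('describe this image', image)

# equivalent GPT4V-style message list (key names assumed for illustration)
messages = [{
    'role': 'user',
    'content': [
        {'type': 'text', 'text': 'describe this image'},
        {'type': 'image_data', 'image_data': {'data': image}},
    ],
}]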

async def _get_prompt_input(self,
-prompt: Dict,
+messages: Union[str, List[Dict]],
do_preprocess: bool,
sequence_start: bool,
adapter_name: str,
tools: Optional[List[object]] = None,
**kwargs):
"""get input_ids, embeddings and offsets."""
if do_preprocess:
decorated = self.vl_prompt_template.messages2prompt(
prompt, sequence_start)
else:
decorated = prompt
segs = decorated.split(IMAGE_TOKEN)

results = {}
input_ids = []
from lmdeploy.vl.templates import (MllamaTempateWrapper,
MolmoChatTemplateWrapper,
Qwen2VLChatTemplateWrapper)
ranges = None
grid_thws = None
if len(segs) > 1:
# yapf: disable
images_with_kwargs = await self.vl_prompt_template.async_collect_pil_images(prompt) # noqa: E501
# yapf: enable
features = []
if len(images_with_kwargs) > 0:
images, image_kwargs = list(zip(*images_with_kwargs))
features = await self.vl_encoder.async_infer(
images, image_kwargs)

from lmdeploy.vl.templates import MiniCPMVTempateWrapper
if isinstance(self.vl_prompt_template, MiniCPMVTempateWrapper):
decorated, features = self.vl_prompt_template.update_image_token( # noqa: E501
decorated, features)
segs = decorated.split(IMAGE_TOKEN)

if isinstance(self.vl_prompt_template,
Qwen2VLChatTemplateWrapper):
grid_thws = [x['grid_thw'] for x in features]
features = [x['embeddings'] for x in features]

if isinstance(self.vl_prompt_template, MllamaTempateWrapper):
# llama3.2 just encode <|image|> and inference
decorated = decorated.replace(IMAGE_TOKEN, '<|image|>')
input_ids = self.tokenizer.encode(decorated,
add_bos=sequence_start)
results['input_ids'] = input_ids
results['prompt'] = decorated
assert len(features)
results['cross_attention_states'] = features[0]
return results

if isinstance(self.vl_prompt_template,
MolmoChatTemplateWrapper):
return features[0]

features = [x.cpu().numpy() for x in features]
input_ids = []
begins = []
ends = []
if len(segs) != len(features) + 1:
logger.error(
f'the number of {IMAGE_TOKEN} is not equal '
f'to input images, {len(segs) - 1} vs {len(features)}')
features = features[:len(segs) - 1]
for i, seg in enumerate(segs):
if i > 0 and i <= len(features):
image_dim = features[i - 1].shape[0]
begins.append(len(input_ids))
ends.append(begins[-1] + image_dim)
input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim)
seg_ids = self.tokenizer.encode(seg,
add_bos=((i == 0)
and sequence_start))
input_ids.extend(seg_ids)
ranges = np.stack([begins, ends], axis=1).tolist()
results['input_embeddings'] = features or None
results['input_embedding_ranges'] = ranges or None
"""process messages and return the required data for the inference
engines. Refer to pytorch.engine.EngineInstance.async_stream_infer and
turbomind.TurboMindInstance.async_stream_infer for the argument
specification.
Args:
messages
do_preprocess
sequence_start
adapter_name
tools
Returns:
"""
if isinstance(messages, str):
return super(self)._get_prompt_input(messages, do_preprocess,
sequence_start, adapter_name,
tools, **kwargs)
elif isinstance(messages, List):
has_multimodal_input = any(
isinstance(message['content'], list) and any(
item['type'] in ['image_url', 'image_data']
for item in message['content']) for message in messages)
if not has_multimodal_input:
return super(self)._get_prompt_input(messages, do_preprocess,
sequence_start,
adapter_name, tools,
**kwargs)
else:
input_ids = self.tokenizer.encode(decorated,
add_bos=sequence_start)

if isinstance(self.vl_prompt_template, Qwen2VLChatTemplateWrapper):
# TODO: refactor _get_prompt_input function
mrope_position_ids, mrope_position_delta = \
self.vl_prompt_template.get_mrope_info(
len(input_ids), grid_thws=grid_thws,
embedding_ranges=ranges)
results['mrope_position_ids'] = mrope_position_ids
results['mrope_position_delta'] = mrope_position_delta

results['input_ids'] = input_ids
results['prompt'] = decorated
raise RuntimeError(f'unsupported messages {messages}')

messages = await self.async_convert_to_pil_images(messages)
results = await self.vl_encoder.preprocess(messages)
if self.backend == 'turbomind':
results = await self.vl_encoder.async_infer(results)
results = await self.vl_encoder.wrap_for_turbomind(
results, self.chat_template, self.tokenizer, sequence_start)
elif self.backend == 'pytorch':
results = await self.vl_encoder.wrap_for_pytorch(
results, self.chat_template, self.tokenizer, sequence_start)
return results
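A standalone restatement of the routing rule above, for illustration only and not lmdeploy code: plain strings and text-only message lists are delegated to the base AsyncEngine, while messages carrying image_url or image_data items take the vision path.

from typing import Dict, List, Union


def needs_vl_path(messages: Union[str, List[Dict]]) -> bool:
    """Mirror of the has_multimodal_input check in _get_prompt_input."""
    if isinstance(messages, str):
        return False
    return any(
        isinstance(m['content'], list) and any(
            item['type'] in ('image_url', 'image_data')
            for item in m['content']) for m in messages)


text_only = [{'role': 'user', 'content': 'hello'}]
with_image = [{
    'role': 'user',
    'content': [{'type': 'image_url', 'image_url': {'url': 'https://example.com/a.jpg'}}],
}]
print(needs_vl_path(text_only), needs_vl_path(with_image))  # False True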

@classmethod
async def async_convert_to_pil_images(cls,
messages: List[Dict]) -> List[Dict]:
"""Scan the provided messages to find image URLs or base64-encoded
image data. Loads the images into Pillow image objects.
Args:
messages (List[Dict]): a user request of GPT4V message format
"""
if isinstance(messages, Dict):
messages = [messages]
assert isinstance(messages, List)

out_messages = [None] * len(messages)

def _inner_call(i, in_messages, out_messages):
role = in_messages[i]['role']
content = in_messages[i]['content']
if role != 'user' or isinstance(content, str):
# either a non-user message or a plain-text user prompt;
# return it unchanged
out_messages[i] = in_messages[i]
return
# the role is 'user' and the content is a list that may contain
# image_url or image_data items
assert isinstance(content, List)
message = dict(role=role, content=[])
for item in content:
# image url or base64-encoded image data
if item['type'] == 'image_url':
"""
convert the following item:
{
'type': 'image_url',
'image_url': {
'url': 'image url or base64-encoded image data',
'key': 'value' # parameters used in image processing
...
}
}
to:
{
'type': 'image',
'image': Pillow.Image,
'key': 'value' # parameters used in image processing
...
}
""" # noqa
data = item['image_url'].copy()
try:
url = data.pop('url')
image = load_image(url)
data.update(type='image', image=image)
message['content'].append(data)
except KeyError:
logger.error(f'invalid format {message}')
elif item['type'] == 'image_data':
"""
convert the following item:
{
'type': 'image_data',
'image_data': {
'data': Pillow.Image,
'key': 'value' # parameters used in image processing
...
}
}
to:
{
'type': 'image',
'image': Pillow.Image,
'key': 'value' # parameters used in image processing
...
}
""" # noqa
data = item['image_data'].copy()
try:
image = data.pop('data')
data.update(type='image', image=image)
message['content'].append(data)
except KeyError:
logger.error(f'invalid format {message}')
elif item['type'] == 'text':
message['content'].append(item)
else:
logger.error(f'unexpected content type {message}')
out_messages[i] = message

await asyncio.gather(*[
asyncio.get_event_loop().run_in_executor(None, _inner_call, i,
messages, out_messages)
for i in range(len(messages))
])
return out_messages
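A usage sketch for the classmethod above, assuming lmdeploy at this commit is installed; it builds a message with an in-memory Pillow image and checks that the image_data item is rewritten to an image item:

import asyncio

from PIL import Image

from lmdeploy.serve.vl_async_engine import VLAsyncEngine

img = Image.new('RGB', (64, 64), color='white')
messages = [{
    'role': 'user',
    'content': [
        {'type': 'text', 'text': 'describe this image'},
        {'type': 'image_data', 'image_data': {'data': img}},
    ],
}]

converted = asyncio.run(VLAsyncEngine.async_convert_to_pil_images(messages))
# the image_data item becomes {'type': 'image', 'image': <PIL.Image>, ...}
print(converted[0]['content'][1]['type'])  # -> 'image'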

def batch_infer(self, prompts: Union[VLPromptType, List[Dict],
List[VLPromptType], List[List[Dict]]],
**kwargs):