[python] support multimodal models in vllm
sindhuvahinis committed Jul 6, 2024
1 parent 5cb0b2f commit dc186b7
Showing 12 changed files with 300 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
@@ -142,7 +142,7 @@ jobs:
          # of aarch64 with ubuntu-20.04 not supported by the actions/setup-python
          sudo apt-get install python3 python-is-python3 python3-pip -y
      - name: Install pip dependencies
-       run: pip3 install pytest requests "numpy<2" pillow huggingface_hub
+       run: pip3 install pytest requests "numpy<2" pillow huggingface_hub openai
      - name: Install awscurl
        working-directory: tests/integration
        run: |

engines/python/setup/djl_python/chat_completions/chat_properties.py

@@ -10,7 +10,7 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
-from typing import Optional, Union, List, Dict
+from typing import Optional, Union, List, Dict, Any

from pydantic import BaseModel, Field, field_validator, ValidationInfo

@@ -21,7 +21,7 @@ class ChatProperties(BaseModel):
    See https://platform.openai.com/docs/api-reference/chat/create
    """

-    messages: List[Dict[str, str]]
+    messages: List[Dict[str, Union[str, List]]]
    model: Optional[str] = Field(default=None, exclude=True)  # Unused
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = Field(default=None, exclude=True)
@@ -41,7 +41,8 @@ class ChatProperties(BaseModel):

    @field_validator('messages', mode='before')
    def validate_messages(
-            cls, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+        cls, messages: List[Dict[str, Union[str, List]]]
+    ) -> List[Dict[str, Union[str, List]]]:
        if messages is None:
            return None

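
With this change, a message's "content" may be either a plain string or a list of typed parts. A minimal sketch of a payload that now validates (hypothetical values; all other fields keep their defaults):

    # Sketch: a multimodal message passing ChatProperties validation.
    props = ChatProperties(messages=[{
        "role": "user",
        "content": [{"type": "text", "text": "What's in this image?"}],
    }])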
35 changes: 34 additions & 1 deletion engines/python/setup/djl_python/chat_completions/chat_utils.py
@@ -10,15 +10,31 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
-from typing import Dict
+from typing import List, Dict

from djl_python.chat_completions.chat_properties import ChatProperties
from djl_python.multimodal.utils import fetch_image, get_image_text_prompt


def is_chat_completions_request(inputs: Dict) -> bool:
    return "messages" in inputs


def parse_multi_modal_chat_content(contents: List):
    prompt_texts = []
    images = []
    for content in contents:
        content_type = content.get("type")
        if content_type == "text":
            prompt_texts.append(content.get("text"))
        elif content_type == "image_url":
            image = fetch_image(content.get("image_url")["url"])
            images.append(image)
        else:
            raise ValueError("We only support types text and image_url")
    return prompt_texts, images


def parse_chat_completions_request(input_map: Dict, is_rolling_batch: bool,
                                   tokenizer):
    if not is_rolling_batch:

@@ -33,11 +49,28 @@ def parse_chat_completions_request(input_map: Dict, is_rolling_batch: bool,
    chat_params = ChatProperties(**input_map)
    param = chat_params.model_dump(by_alias=True, exclude_none=True)
    messages = param.pop("messages")
    images = []
    for message in messages:
        content = message.get("content")
        if not isinstance(content, str):
            prompt_texts, content_images = parse_multi_modal_chat_content(
                content)
            prompt_texts = '\n'.join(prompt_texts)
            if content_images:
                images.extend(content_images)
                content = get_image_text_prompt(prompt_texts)
            else:
                content = prompt_texts
            message["content"] = content

    inputs = tokenizer.apply_chat_template(messages, tokenize=False)
    param["do_sample"] = chat_params.temperature is not None and chat_params.temperature > 0.0
    param["details"] = True  # Enable details for chat completions
    param["output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"

    if images:
        param["images"] = images

    return inputs, param
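
As a usage sketch, parse_multi_modal_chat_content splits one message's content list into its text and image parts (hypothetical input; the image URL must be reachable):

    contents = [
        {"type": "text", "text": "What's in this image?"},
        {"type": "image_url",
         "image_url": {"url": "https://resources.djl.ai/images/dog_bike_car.jpg"}},
    ]
    prompt_texts, images = parse_multi_modal_chat_content(contents)
    # prompt_texts == ["What's in this image?"], images == [<PIL.Image.Image>]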
Empty file.
69 changes: 69 additions & 0 deletions engines/python/setup/djl_python/multimodal/utils.py
@@ -0,0 +1,69 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import base64
from io import BytesIO
from typing import Union

import requests
from PIL import Image

# TODO: Image token differs for each VLM model.
# Including model_config becomes easier once parse_input refactor PR is done.


def get_image_text_prompt(prompt_text: str) -> str:
    # TODO: image token str must be decoded from image_token_id in serving.properties. Change it after refactor PR.
    image_token_str = '<image>'

    # TODO: image_feature_size should be read from serving.properties. Change it after refactor PR.
    image_feature_size = 1176

    # TODO: Remove image_token_str*1176 after the next vLLM release, as the image placeholder is not needed.
    return f"{image_token_str*image_feature_size}\n{prompt_text}"


def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
    return Image.open(BytesIO(base64.b64decode(image)))


def fetch_image_from_url(image_url: str) -> Image.Image:
    # TODO: Add a configurable timeout, via an env var or serving.properties, from properties.py
    # TODO: add validation for the http url
    # TODO: For now we always assume an image format; it could also be a pixel numpy file or an image-features file (.pt)
    # Fetch the image from the http url
    with requests.get(url=image_url) as response:
        response.raise_for_status()
        image_raw = response.content
    # Open the image using Pillow; this does not load the image into memory yet
    # (image.load()), as some frameworks like vllm do that anyway.
    image = Image.open(BytesIO(image_raw))
    return image


def fetch_image(image_url: str) -> Image.Image:
    if image_url.startswith("http"):
        return fetch_image_from_url(image_url)
    elif image_url.startswith("data:image"):
        _, image_base64 = image_url.split(",", 1)
        return load_image_from_base64(image_base64)
    else:
        raise ValueError("Invalid image url")


# Use base64 encoded image in the payload
def encode_image_base64_from_url(image_url: str) -> str:
    """Encode an image retrieved from a remote url to base64 format."""
    with requests.get(image_url) as response:
        response.raise_for_status()
        base64_image = base64.b64encode(response.content).decode('utf-8')
    return base64_image
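
fetch_image accepts the two image_url forms the OpenAI chat API allows: an http(s) URL and an inline data URI. A quick round-trip sketch using the helpers above:

    url = "https://resources.djl.ai/images/dog_bike_car.jpg"
    image = fetch_image(url)                              # fetched via requests
    b64 = encode_image_base64_from_url(url)
    image = fetch_image(f"data:image/jpeg;base64,{b64}")  # decoded from base64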

engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py

@@ -48,6 +48,15 @@ class VllmRbProperties(Properties):
    device: Optional[str] = None
    preloaded_model: Optional[Any] = None

    # Vision language configurations
    # TODO: remove this after vLLM next release
    image_token_id: Optional[int] = None
    image_input_type: Optional[str] = None
    image_input_shape: Optional[str] = None
    image_feature_size: Optional[int] = None
    image_processor: Optional[str] = None
    image_processor_revision: Optional[str] = None

    @field_validator('engine')
    def validate_engine(cls, engine):
        if engine != "Python":
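
For reference, these fields would be supplied through serving.properties. A hypothetical sketch, assuming DJL Serving's option.* prefix convention; the vision values are illustrative for a llava-style model, not verified defaults:

    engine=Python
    option.rolling_batch=vllm
    option.model_id=llava-hf/llava-v1.6-34b-hf
    option.image_token_id=64000
    option.image_input_type=pixel_values
    option.image_input_shape=1,3,336,336
    option.image_feature_size=1176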

engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py

@@ -16,6 +16,9 @@
from vllm import EngineArgs
from vllm.outputs import CompletionOutput, RequestOutput as vLLMRequestOutput
from vllm.lora.request import LoRARequest
from vllm.multimodal.image import ImagePixelData
from vllm.inputs import PromptInputs

from djl_python.request_io import Token, Sequence
from djl_python.request import Request
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
@@ -224,4 +227,27 @@ def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs:
                      max_lora_rank=config.max_lora_rank,
                      lora_extra_vocab_size=config.lora_extra_vocab_size,
                      max_cpu_loras=config.max_cpu_loras,
-                     revision=config.revision)
+                     revision=config.revision,
+                     image_input_type=config.image_input_type,
+                     image_token_id=config.image_token_id,
+                     image_input_shape=config.image_input_shape,
+                     image_feature_size=config.image_feature_size,
+                     image_processor=config.image_processor,
+                     image_processor_revision=config.image_processor_revision)


def get_multi_modal_data(request: Request) -> dict:
    parameters = request.request_input.parameters
    images = parameters.pop("images", None)
    multi_modal_data = None
    if images:
        multi_modal_data = ImagePixelData(images[0])
    return multi_modal_data


def get_prompt_inputs(request: Request):
    prompt_inputs: PromptInputs = {"prompt": request.request_input.input_text}
    multi_modal_data = get_multi_modal_data(request)
    if multi_modal_data:
        prompt_inputs["multi_modal_data"] = multi_modal_data
    return prompt_inputs
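
Note that ImagePixelData wraps a single image, so only the first image of a request is forwarded to the engine. A sketch of the resulting engine input, assuming a hypothetical request that went through the chat parsing above:

    prompt_inputs = get_prompt_inputs(request)
    # => {"prompt": "<image>...<image>\nWhat's in this image?",
    #     "multi_modal_data": ImagePixelData(<PIL.Image.Image>)}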

engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py

@@ -10,15 +10,15 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
from collections import OrderedDict, defaultdict

-from vllm import EngineArgs, LLMEngine, SamplingParams
+from vllm import LLMEngine, SamplingParams
from vllm.utils import random_uuid

from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
from djl_python.rolling_batch.rolling_batch_vllm_utils import (
    update_request_cache_with_output, get_lora_request_params,
-    get_engine_args_from_config)
+    get_engine_args_from_config, get_prompt_inputs)
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
from typing import List

@@ -126,11 +126,14 @@ def inference(self,
        # step 0: register new requests to engine
        for request in new_requests:
            request_id = random_uuid()
            prompt_inputs = get_prompt_inputs(request)
            params = self.translate_vllm_params(request.parameters)
            sampling_params = SamplingParams(**params)
            request_params = get_lora_request_params(request, self.lora_ids)
-           self.engine.add_request(request_id, request.input_text,
-                                   sampling_params, **request_params)
+           self.engine.add_request(request_id=request_id,
+                                   inputs=prompt_inputs,
+                                   params=sampling_params,
+                                   **request_params)
            self.request_cache[request_id] = {
                "request_output": request.request_output
            }
@@ -0,0 +1,51 @@
import base64
import unittest

from openai import OpenAI
from transformers import AutoTokenizer

from djl_python.chat_completions.chat_utils import parse_chat_completions_request
from djl_python.multimodal.utils import encode_image_base64_from_url

OPENAI_API_KEY = "EMPTY"
OPENAI_API_BASE = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_API_BASE,
)


class TestLmiDist(unittest.TestCase):

    def test_open_ai_format_parse(self):
        image_url = "https://resources.djl.ai/images/dog_bike_car.jpg"
        image_base64 = encode_image_base64_from_url(image_url=image_url)
        sample_messages = [{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What’s in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image_base64}"
                    },
                },
            ],
        }]
        sample_input_map = {'messages': sample_messages, 'model': ""}
        tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-v1.6-34b-hf",
                                                  use_fast=False)
        inputs, params = parse_chat_completions_request(sample_input_map,
                                                        is_rolling_batch=True,
                                                        tokenizer=tokenizer)
        print(inputs)
        images = params.pop("images", None)
        for image in images:
            print(image)
        print(params)
82 changes: 82 additions & 0 deletions tests/integration/llm/openai_vision_client.py
@@ -0,0 +1,82 @@
import argparse
import base64
import sys

import requests
from openai import OpenAI

OPENAI_API_KEY = "EMPTY"
OPENAI_API_BASE = "http://localhost:8080/invocations"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_API_BASE,
)


def call_chat_completion_api(image: str):
    sample_messages = [{
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "What’s in this image?"
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": f"{image}"
                },
            },
        ],
    }]

    chat_completion_with_image = client.chat.completions.create(
        messages=sample_messages,
        model="",
    )

    return chat_completion_with_image


def get_image_url(image_url_type: str, image: str):
    if image_url_type == "base64":
        if image.startswith("http"):
            with requests.get(image) as response:
                response.raise_for_status()
                image_base64 = base64.b64encode(
                    response.content).decode('utf-8')
        else:
            with open(image, "rb") as image_file:
                image_base64 = base64.b64encode(
                    image_file.read()).decode('utf-8')
        return f"data:image/jpeg;base64,{image_base64}"
    else:
        return image


def run(raw_args):
    parser = argparse.ArgumentParser(description="OpenAI VLM API client")
    parser.add_argument("image_url_type",
                        type=str,
                        choices=["url", "base64"],
                        default="url",
                        help="image url type")
    parser.add_argument(
        "image",
        type=str,
        default="https://resources.djl.ai/images/dog_bike_car.jpg",
        help="image http url or local path")

    global args
    args = parser.parse_args(args=raw_args)

    image_url = get_image_url(args.image_url_type, args.image)
    result = call_chat_completion_api(image_url)
    print(f"OpenAI vision client result {result}")


if __name__ == "__main__":
    run(sys.argv[1:])
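
Hypothetical invocations of the client, assuming a DJL Serving endpoint is listening at OPENAI_API_BASE:

    # python openai_vision_client.py url https://resources.djl.ai/images/dog_bike_car.jpg
    # python openai_vision_client.py base64 dog_bike_car.jpg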