[python] support multimodal models in vllm
1 parent 5cb0b2f, commit dc186b7
Showing 12 changed files with 300 additions and 11 deletions.
@@ -0,0 +1,69 @@
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
#   http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import base64
from io import BytesIO
from typing import Union

import requests
from PIL import Image

# TODO: The image token differs for each VLM model.
# Including model_config becomes easier once the parse_input refactor PR is done.


def get_image_text_prompt(prompt_text: str) -> str:
    # TODO: The image token string must be decoded from image_token_id in serving.properties. Change after the refactor PR.
    image_token_str = '<image>'

    # TODO: image_feature_size should be read from serving.properties. Change after the refactor PR.
    image_feature_size = 1176

    # TODO: Remove image_token_str * image_feature_size after the next vLLM release, as the image placeholder will no longer be needed.
    return f"{image_token_str * image_feature_size}\n{prompt_text}"


def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
    return Image.open(BytesIO(base64.b64decode(image)))


def fetch_image_from_url(image_url: str) -> Image.Image:
    # TODO: Add a configurable timeout via an env var or serving.properties (properties.py).
    # TODO: Add validation for the HTTP URL.
    # TODO: We currently always assume an image format; it could also be a pixel numpy file or an image-features file (.pt).
    # Fetch the image from the HTTP URL.
    with requests.get(url=image_url) as response:
        response.raise_for_status()
        image_raw = response.content
    # Open the image with Pillow, but do not load the image into memory yet
    # (image.load()), since frameworks like vLLM do that anyway.
    image = Image.open(BytesIO(image_raw))
    return image


def fetch_image(image_url: str) -> Image.Image:
    if image_url.startswith("http"):
        return fetch_image_from_url(image_url)
    elif image_url.startswith("data:image"):
        _, image_base64 = image_url.split(",", 1)
        return load_image_from_base64(image_base64)
    else:
        raise ValueError("Invalid image url")


# Use a base64-encoded image in the payload
def encode_image_base64_from_url(image_url: str) -> str:
    """Encode an image retrieved from a remote url to base64 format."""
    with requests.get(image_url) as response:
        response.raise_for_status()
        base64_image = base64.b64encode(response.content).decode('utf-8')
    return base64_image
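Taken together, these helpers cover both ways a client can supply an image: as a plain HTTP URL or as a base64 data URL embedded in the request. A minimal usage sketch follows; it is illustrative only and not part of the commit, and it assumes the helpers live in djl_python.multimodal.utils, as the test import further down suggests. The sample image URL is the one used in the tests below.

# Illustrative usage of the helpers above (not part of the commit).
from djl_python.multimodal.utils import (encode_image_base64_from_url,
                                          fetch_image, get_image_text_prompt)

image_url = "https://resources.djl.ai/images/dog_bike_car.jpg"

# Fetch the image directly from the HTTP URL ...
image = fetch_image(image_url)

# ... or embed it as a base64 data URL, as an OpenAI-style client would,
# and let fetch_image decode it again on the server side.
data_url = f"data:image/jpeg;base64,{encode_image_base64_from_url(image_url)}"
same_image = fetch_image(data_url)

# Expand the text prompt with the image placeholder tokens the model expects.
prompt = get_image_text_prompt("What's in this image?")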
51 changes: 51 additions & 0 deletions
engines/python/setup/djl_python/tests/multimodal/test_parse_multimodal.py
@@ -0,0 +1,51 @@
import base64
import unittest

from openai import OpenAI
from transformers import AutoTokenizer

from djl_python.chat_completions.chat_utils import parse_chat_completions_request
from djl_python.multimodal.utils import encode_image_base64_from_url

OPENAI_API_KEY = "EMPTY"
OPENAI_API_BASE = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_API_BASE,
)


class TestLmiDist(unittest.TestCase):

    def test_open_ai_format_parse(self):
        image_url = "https://resources.djl.ai/images/dog_bike_car.jpg"
        image_base64 = encode_image_base64_from_url(image_url=image_url)
        sample_messages = [{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What’s in this image?"
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image_base64}"
                    },
                },
            ],
        }]
        sample_input_map = {'messages': sample_messages, 'model': ""}
        tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-v1.6-34b-hf",
                                                  use_fast=False)
        inputs, params = parse_chat_completions_request(sample_input_map,
                                                        is_rolling_batch=True,
                                                        tokenizer=tokenizer)
        print(inputs)
        images = params.pop("images", None)
        for image in images:
            print(image)
        print(params)
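Note that this test only exercises the request-parsing path: it builds the same base64 data URL payload an OpenAI-compatible client would send, and the module-level OpenAI client is not actually used. It downloads the llava-hf/llava-v1.6-34b-hf tokenizer from the Hugging Face Hub, so it needs network access. One way to run just this file, assuming pytest is available (a suggestion, not part of the commit):

python -m pytest engines/python/setup/djl_python/tests/multimodal/test_parse_multimodal.py -s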
@@ -0,0 +1,82 @@
import argparse
import base64
import sys

import requests
from openai import OpenAI

OPENAI_API_KEY = "EMPTY"
OPENAI_API_BASE = "http://localhost:8080/invocations"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=OPENAI_API_KEY,
    base_url=OPENAI_API_BASE,
)


def call_chat_completion_api(image: str):
    sample_messages = [{
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "What’s in this image?"
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": f"{image}"
                },
            },
        ],
    }]

    chat_completion_with_image = client.chat.completions.create(
        messages=sample_messages,
        model="",
    )

    return chat_completion_with_image


def get_image_url(image_url_type: str, image: str):
    if image_url_type == "base64":
        if image.startswith("http"):
            # Fetch the remote image and embed it as a base64 data URL.
            with requests.get(image) as response:
                response.raise_for_status()
                image_base64 = base64.b64encode(
                    response.content).decode('utf-8')
        else:
            # Read a local file and embed it as a base64 data URL.
            with open(image, "rb") as image_file:
                image_base64 = base64.b64encode(
                    image_file.read()).decode('utf-8')
        return f"data:image/jpeg;base64,{image_base64}"
    else:
        return image


def run(raw_args):
    parser = argparse.ArgumentParser(description="OpenAI VLM API client")
    parser.add_argument("image_url_type",
                        type=str,
                        choices=["url", "base64"],
                        default="url",
                        help="image url type")
    parser.add_argument(
        "image",
        type=str,
        default="https://resources.djl.ai/images/dog_bike_car.jpg",
        help="image http url or local path")

    args = parser.parse_args(args=raw_args)

    image_url = get_image_url(args.image_url_type, args.image)
    result = call_chat_completion_api(image_url)
    print(f"OpenAI vision client result {result}")


if __name__ == "__main__":
    run(sys.argv[1:])
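For reference, two example invocations, assuming the script is saved as vllm_vision_client.py (the diff does not show the filename) and a DJL Serving endpoint is listening on localhost:8080:

# Send the image as a plain HTTP URL, forwarded to the server unchanged
python vllm_vision_client.py url https://resources.djl.ai/images/dog_bike_car.jpg

# Send a local file (or an HTTP URL) as a base64 data URL embedded in the request
python vllm_vision_client.py base64 ./dog_bike_car.jpg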