Skip to content

Commit

Permalink
Use PIL Image internally for the Multimodal Agent (#1124)
Browse files Browse the repository at this point in the history
* Change defualt model for `lmm`

* Try to use PIL image for LMM's _oai_messages

* Update test cases and llava

* Remove redundant files

* Update the imports for lmm tests

* Test case fix

* Docstring update

* LMM notebook lint

* Typo correction for img_utils and its test

* Update test_llava.py

debug, reformat

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Shaokun Zhang <shaokunzhang529@gmail.com>
Co-authored-by: Shaokun Zhang <shaokun.zhang@psu.edu>
  • Loading branch information
4 people authored Feb 18, 2024
1 parent 2d29d36 commit 9de374a
Show file tree
Hide file tree
Showing 7 changed files with 493 additions and 192 deletions.
151 changes: 139 additions & 12 deletions autogen/agentchat/contrib/img_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64
import copy
import mimetypes
import os
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
Expand All @@ -8,17 +10,63 @@
from PIL import Image


def get_image_data(image_file: str, use_b64=True) -> bytes:
def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
"""
Loads an image from a file and returns a PIL Image object.
Parameters:
image_file (str, or Image): The filename, URL, URI, or base64 string of the image file.
Returns:
Image.Image: The PIL Image object.
"""
if isinstance(image_file, Image.Image):
# Already a PIL Image object
return image_file

if image_file.startswith("http://") or image_file.startswith("https://"):
# A URL file
response = requests.get(image_file)
content = response.content
content = BytesIO(response.content)
image = Image.open(content)
elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
return re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
# A URI. Remove the prefix and decode the base64 string.
base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
image = _to_pil(base64_data)
elif os.path.exists(image_file):
# A local file
image = Image.open(image_file)
else:
image = Image.open(image_file).convert("RGB")
buffered = BytesIO()
image.save(buffered, format="PNG")
content = buffered.getvalue()
# base64 encoded string
image = _to_pil(image_file)

return image.convert("RGB")


def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
"""
Loads an image and returns its data either as raw bytes or in base64-encoded format.
This function first loads an image from the specified file, URL, or base64 string using
the `get_pil_image` function. It then saves this image in memory in PNG format and
retrieves its binary content. Depending on the `use_b64` flag, this binary content is
either returned directly or as a base64-encoded string.
Parameters:
image_file (str, or Image): The path to the image file, a URL to an image, or a base64-encoded
string of the image.
use_b64 (bool): If True, the function returns a base64-encoded string of the image data.
If False, it returns the raw byte data of the image. Defaults to True.
Returns:
bytes: The image data in raw bytes if `use_b64` is False, or a base64-encoded string
if `use_b64` is True.
"""
image = get_pil_image(image_file)

buffered = BytesIO()
image.save(buffered, format="PNG")
content = buffered.getvalue()

if use_b64:
return base64.b64encode(content).decode("utf-8")
Expand Down Expand Up @@ -72,6 +120,22 @@ def llava_formatter(prompt: str, order_image_tokens: bool = False) -> Tuple[str,
return new_prompt, images


def pil_to_data_uri(image: Image.Image) -> str:
"""
Converts a PIL Image object to a data URI.
Parameters:
image (Image.Image): The PIL Image object.
Returns:
str: The data URI string.
"""
buffered = BytesIO()
image.save(buffered, format="PNG")
content = buffered.getvalue()
return convert_base64_to_data_uri(base64.b64encode(content).decode("utf-8"))


def convert_base64_to_data_uri(base64_image):
def _get_mime_type_from_data_uri(base64_image):
# Decode the base64 string
Expand All @@ -92,16 +156,19 @@ def _get_mime_type_from_data_uri(base64_image):
return data_uri


def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
def gpt4v_formatter(prompt: str, img_format: str = "uri") -> List[Union[str, dict]]:
"""
Formats the input prompt by replacing image tags and returns a list of text and images.
Parameters:
Args:
- prompt (str): The input string that may contain image tags like <img ...>.
- img_format (str): what image format should be used. One of "uri", "url", "pil".
Returns:
- List[Union[str, dict]]: A list of alternating text and image dictionary items.
"""
assert img_format in ["uri", "url", "pil"]

output = []
last_index = 0
image_count = 0
Expand All @@ -114,7 +181,15 @@ def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
image_location = match.group(1)

try:
img_data = get_image_data(image_location)
if img_format == "pil":
img_data = get_pil_image(image_location)
elif img_format == "uri":
img_data = get_image_data(image_location)
img_data = convert_base64_to_data_uri(img_data)
elif img_format == "url":
img_data = image_location
else:
raise ValueError(f"Unknown image format {img_format}")
except Exception as e:
# Warning and skip this token
print(f"Warning! Unable to load image from {image_location}, because {e}")
Expand All @@ -124,7 +199,7 @@ def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
output.append({"type": "text", "text": prompt[last_index : match.start()]})

# Add image data to output list
output.append({"type": "image_url", "image_url": {"url": convert_base64_to_data_uri(img_data)}})
output.append({"type": "image_url", "image_url": {"url": img_data}})

last_index = match.end()
image_count += 1
Expand Down Expand Up @@ -162,9 +237,61 @@ def _to_pil(data: str) -> Image.Image:
and finally creates and returns a PIL Image object from the BytesIO object.
Parameters:
data (str): The base64 encoded image data string.
data (str): The encoded image data string.
Returns:
Image.Image: The PIL Image object created from the input data.
"""
return Image.open(BytesIO(base64.b64decode(data)))


def message_formatter_pil_to_b64(messages: List[Dict]) -> List[Dict]:
"""
Converts the PIL image URLs in the messages to base64 encoded data URIs.
This function iterates over a list of message dictionaries. For each message,
if it contains a 'content' key with a list of items, it looks for items
with an 'image_url' key. The function then converts the PIL image URL
(pointed to by 'image_url') to a base64 encoded data URI.
Parameters:
messages (List[Dict]): A list of message dictionaries. Each dictionary
may contain a 'content' key with a list of items,
some of which might be image URLs.
Returns:
List[Dict]: A new list of message dictionaries with PIL image URLs in the
'image_url' key converted to base64 encoded data URIs.
Example Input:
[
{'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
{'content': [
{'type': 'text', 'text': "What's the breed of this dog here? \n"},
{'type': 'image_url', 'image_url': {'url': a PIL.Image.Image}},
{'type': 'text', 'text': '.'}],
'role': 'user'}
]
Example Output:
[
{'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
{'content': [
{'type': 'text', 'text': "What's the breed of this dog here? \n"},
{'type': 'image_url', 'image_url': {'url': a B64 Image}},
{'type': 'text', 'text': '.'}],
'role': 'user'}
]
"""
new_messages = []
for message in messages:
# Handle the new GPT messages format.
if isinstance(message, dict) and "content" in message and isinstance(message["content"], list):
message = copy.deepcopy(message)
for item in message["content"]:
if isinstance(item, dict) and "image_url" in item:
item["image_url"]["url"] = pil_to_data_uri(item["image_url"]["url"])

new_messages.append(message)

return new_messages
4 changes: 3 additions & 1 deletion autogen/agentchat/contrib/llava_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def _image_reply(self, messages=None, sender=None, config=None):
content_prompt = content_str(msg["content"])
prompt += f"{SEP}{role}: {content_prompt}\n"
prompt += "\n" + SEP + "Assistant: "
images = [re.sub("data:image/.+;base64,", "", im, count=1) for im in images]

# TODO: PIL to base64
images = [get_image_data(im) for im in images]
print(colored(prompt, "blue"))

out = ""
Expand Down
52 changes: 49 additions & 3 deletions autogen/agentchat/contrib/multimodal_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

from autogen import OpenAIWrapper
from autogen.agentchat import Agent, ConversableAgent
from autogen.agentchat.contrib.img_utils import gpt4v_formatter
from autogen.agentchat.contrib.img_utils import (
convert_base64_to_data_uri,
gpt4v_formatter,
message_formatter_pil_to_b64,
pil_to_data_uri,
)

from ..._pydantic import model_dump

try:
from termcolor import colored
Expand Down Expand Up @@ -55,6 +62,21 @@ def __init__(
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)

# Override the `generate_oai_reply`
def _replace_reply_func(arr, x, y):
for item in arr:
if item["reply_func"] is x:
item["reply_func"] = y

_replace_reply_func(
self._reply_func_list, ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply
)
_replace_reply_func(
self._reply_func_list,
ConversableAgent.a_generate_oai_reply,
MultimodalConversableAgent.a_generate_oai_reply,
)

def update_system_message(self, system_message: Union[Dict, List, str]):
"""Update the system message.
Expand All @@ -76,18 +98,42 @@ def _message_to_dict(message: Union[Dict, List, str]) -> Dict:
will be processed using the gpt4v_formatter.
"""
if isinstance(message, str):
return {"content": gpt4v_formatter(message)}
return {"content": gpt4v_formatter(message, img_format="pil")}
if isinstance(message, list):
return {"content": message}
if isinstance(message, dict):
assert "content" in message, "The message dict must have a `content` field"
if isinstance(message["content"], str):
message = copy.deepcopy(message)
message["content"] = gpt4v_formatter(message["content"])
message["content"] = gpt4v_formatter(message["content"], img_format="pil")
try:
content_str(message["content"])
except (TypeError, ValueError) as e:
print("The `content` field should be compatible with the content_str function!")
raise e
return message
raise ValueError(f"Unsupported message type: {type(message)}")

def generate_oai_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[OpenAIWrapper] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
"""Generate a reply using autogen.oai."""
client = self.client if config is None else config
if client is None:
return False, None
if messages is None:
messages = self._oai_messages[sender]

messages_with_b64_img = message_formatter_pil_to_b64(self._oai_system_message + messages)

# TODO: #1143 handle token limit exceeded error
response = client.create(context=messages[-1].pop("context", None), messages=messages_with_b64_img)

# TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
extracted_response = client.extract_text_or_completion_object(response)[0]
if not isinstance(extracted_response, str):
extracted_response = model_dump(extracted_response)
return True, extracted_response
Loading

0 comments on commit 9de374a

Please sign in to comment.