Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use PIL Image internally for the Multimodal Agent #1124

Merged
merged 16 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 139 additions & 12 deletions autogen/agentchat/contrib/img_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64
import copy
import mimetypes
import os
import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
Expand All @@ -8,17 +10,63 @@
from PIL import Image


def get_image_data(image_file: str, use_b64=True) -> bytes:
def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
"""
Loads an image from a file and returns a PIL Image object.

Parameters:
image_file (str, or Image): The filename, URL, URI, or base64 string of the image file.

Returns:
Image.Image: The PIL Image object.
"""
if isinstance(image_file, Image.Image):
# Already a PIL Image object
return image_file

if image_file.startswith("http://") or image_file.startswith("https://"):
# A URL file
response = requests.get(image_file)
content = response.content
content = BytesIO(response.content)
image = Image.open(content)
elif re.match(r"data:image/(?:png|jpeg);base64,", image_file):
return re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
# A URI. Remove the prefix and decode the base64 string.
base64_data = re.sub(r"data:image/(?:png|jpeg);base64,", "", image_file)
image = _to_pil(base64_data)
elif os.path.exists(image_file):
# A local file
image = Image.open(image_file)
else:
image = Image.open(image_file).convert("RGB")
buffered = BytesIO()
image.save(buffered, format="PNG")
content = buffered.getvalue()
# base64 encoded string
image = _to_pil(image_file)

return image.convert("RGB")


def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
"""
Loads an image and returns its data either as raw bytes or in base64-encoded format.

This function first loads an image from the specified file, URL, or base64 string using
the `get_pil_image` function. It then saves this image in memory in PNG format and
retrieves its binary content. Depending on the `use_b64` flag, this binary content is
either returned directly or as a base64-encoded string.

Parameters:
image_file (str, or Image): The path to the image file, a URL to an image, or a base64-encoded
string of the image.
use_b64 (bool): If True, the function returns a base64-encoded string of the image data.
If False, it returns the raw byte data of the image. Defaults to True.

Returns:
bytes: The image data in raw bytes if `use_b64` is False, or a base64-encoded string
if `use_b64` is True.
"""
image = get_pil_image(image_file)

buffered = BytesIO()
image.save(buffered, format="PNG")
content = buffered.getvalue()

if use_b64:
return base64.b64encode(content).decode("utf-8")
Expand Down Expand Up @@ -72,6 +120,22 @@ def llava_formatter(prompt: str, order_image_tokens: bool = False) -> Tuple[str,
return new_prompt, images


def pil_to_data_uri(image: Image.Image) -> str:
    """
    Serializes a PIL Image object into a data URI string.

    The image is rendered as PNG into an in-memory buffer, the PNG bytes are
    base64-encoded, and the resulting string is handed to
    `convert_base64_to_data_uri`, which builds the final data URI.

    Parameters:
        image (Image.Image): The PIL Image object to serialize.

    Returns:
        str: The data URI string returned by `convert_base64_to_data_uri`.
    """
    # Render the image as PNG entirely in memory.
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")

    # Base64-encode the PNG bytes as UTF-8 text for the data URI helper.
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    return convert_base64_to_data_uri(encoded_png)


def convert_base64_to_data_uri(base64_image):
def _get_mime_type_from_data_uri(base64_image):
# Decode the base64 string
Expand All @@ -92,16 +156,19 @@ def _get_mime_type_from_data_uri(base64_image):
return data_uri


def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
def gpt4v_formatter(prompt: str, img_format: str = "uri") -> List[Union[str, dict]]:
"""
Formats the input prompt by replacing image tags and returns a list of text and images.

Parameters:
Args:
- prompt (str): The input string that may contain image tags like <img ...>.
- img_format (str): what image format should be used. One of "uri", "url", "pil".

Returns:
- List[Union[str, dict]]: A list of alternating text and image dictionary items.
"""
assert img_format in ["uri", "url", "pil"]

output = []
last_index = 0
image_count = 0
Expand All @@ -114,7 +181,15 @@ def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
image_location = match.group(1)

try:
img_data = get_image_data(image_location)
if img_format == "pil":
img_data = get_pil_image(image_location)
elif img_format == "uri":
img_data = get_image_data(image_location)
img_data = convert_base64_to_data_uri(img_data)
elif img_format == "url":
img_data = image_location
else:
raise ValueError(f"Unknown image format {img_format}")
except Exception as e:
# Warning and skip this token
print(f"Warning! Unable to load image from {image_location}, because {e}")
Expand All @@ -124,7 +199,7 @@ def gpt4v_formatter(prompt: str) -> List[Union[str, dict]]:
output.append({"type": "text", "text": prompt[last_index : match.start()]})

# Add image data to output list
output.append({"type": "image_url", "image_url": {"url": convert_base64_to_data_uri(img_data)}})
output.append({"type": "image_url", "image_url": {"url": img_data}})

last_index = match.end()
image_count += 1
Expand Down Expand Up @@ -162,9 +237,61 @@ def _to_pil(data: str) -> Image.Image:
and finally creates and returns a PIL Image object from the BytesIO object.

Parameters:
data (str): The base64 encoded image data string.
data (str): The encoded image data string.

Returns:
Image.Image: The PIL Image object created from the input data.
"""
return Image.open(BytesIO(base64.b64decode(data)))


def message_formatter_pil_to_b64(messages: List[Dict]) -> List[Dict]:
    """
    Converts the PIL images stored in the messages to base64 encoded data URIs.

    This function iterates over a list of message dictionaries. For each message
    whose 'content' is a list, it looks for dict items with an 'image_url' key
    and replaces the object stored under item['image_url']['url'] with the data
    URI produced by `pil_to_data_uri`.

    A 'url' value that is already a string (for example an http(s) URL or a data
    URI) is left unchanged, because it cannot be serialized as a PIL image.

    Messages that are modified are deep-copied first, so the input messages are
    never mutated. Messages that need no conversion (non-dict messages, or
    messages whose 'content' is not a list) are placed in the result as-is.

    Parameters:
        messages (List[Dict]): A list of message dictionaries. Each dictionary
                               may contain a 'content' key with a list of items,
                               some of which might be image items.

    Returns:
        List[Dict]: A new list of message dictionaries with PIL images in the
                    'image_url' key converted to base64 encoded data URIs.

    Example Input:
        [
            {'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
            {'content': [
                {'type': 'text', 'text': "What's the breed of this dog here? \n"},
                {'type': 'image_url', 'image_url': {'url': a PIL.Image.Image}},
                {'type': 'text', 'text': '.'}],
            'role': 'user'}
        ]

    Example Output:
        [
            {'content': [{'type': 'text', 'text': 'You are a helpful AI assistant.'}], 'role': 'system'},
            {'content': [
                {'type': 'text', 'text': "What's the breed of this dog here? \n"},
                {'type': 'image_url', 'image_url': {'url': a B64 Image}},
                {'type': 'text', 'text': '.'}],
            'role': 'user'}
        ]
    """
    new_messages = []
    for message in messages:
        # Handle the new GPT messages format (content given as a list of items).
        if isinstance(message, dict) and isinstance(message.get("content"), list):
            # Work on a copy so the caller's messages keep their PIL images.
            message = copy.deepcopy(message)
            for item in message["content"]:
                if not (isinstance(item, dict) and "image_url" in item):
                    continue
                url = item["image_url"]["url"]
                # A string is already a URL / data URI: passing it to
                # pil_to_data_uri would fail, so keep it untouched.
                if not isinstance(url, str):
                    item["image_url"]["url"] = pil_to_data_uri(url)

        new_messages.append(message)

    return new_messages
4 changes: 3 additions & 1 deletion autogen/agentchat/contrib/llava_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def _image_reply(self, messages=None, sender=None, config=None):
content_prompt = content_str(msg["content"])
prompt += f"{SEP}{role}: {content_prompt}\n"
prompt += "\n" + SEP + "Assistant: "
images = [re.sub("data:image/.+;base64,", "", im, count=1) for im in images]

# TODO: PIL to base64
images = [get_image_data(im) for im in images]
print(colored(prompt, "blue"))

out = ""
Expand Down
52 changes: 49 additions & 3 deletions autogen/agentchat/contrib/multimodal_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

from autogen import OpenAIWrapper
from autogen.agentchat import Agent, ConversableAgent
from autogen.agentchat.contrib.img_utils import gpt4v_formatter
from autogen.agentchat.contrib.img_utils import (
convert_base64_to_data_uri,
gpt4v_formatter,
message_formatter_pil_to_b64,
pil_to_data_uri,
)

from ..._pydantic import model_dump

try:
from termcolor import colored
Expand Down Expand Up @@ -55,6 +62,21 @@ def __init__(
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)

# Override the `generate_oai_reply`
def _replace_reply_func(arr, x, y):
for item in arr:
if item["reply_func"] is x:
item["reply_func"] = y

_replace_reply_func(
self._reply_func_list, ConversableAgent.generate_oai_reply, MultimodalConversableAgent.generate_oai_reply
)
_replace_reply_func(
self._reply_func_list,
ConversableAgent.a_generate_oai_reply,
MultimodalConversableAgent.a_generate_oai_reply,
)

def update_system_message(self, system_message: Union[Dict, List, str]):
"""Update the system message.

Expand All @@ -76,18 +98,42 @@ def _message_to_dict(message: Union[Dict, List, str]) -> Dict:
will be processed using the gpt4v_formatter.
"""
if isinstance(message, str):
return {"content": gpt4v_formatter(message)}
return {"content": gpt4v_formatter(message, img_format="pil")}
if isinstance(message, list):
return {"content": message}
if isinstance(message, dict):
assert "content" in message, "The message dict must have a `content` field"
if isinstance(message["content"], str):
message = copy.deepcopy(message)
message["content"] = gpt4v_formatter(message["content"])
message["content"] = gpt4v_formatter(message["content"], img_format="pil")
try:
content_str(message["content"])
except (TypeError, ValueError) as e:
print("The `content` field should be compatible with the content_str function!")
raise e
return message
raise ValueError(f"Unsupported message type: {type(message)}")

def generate_oai_reply(
    self,
    messages: Optional[List[Dict]] = None,
    sender: Optional[Agent] = None,
    config: Optional[OpenAIWrapper] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
    """Generate a reply using autogen.oai.

    Args:
        messages: The conversation to reply to. When None, the messages stored
            in `self._oai_messages[sender]` are used.
        sender: The agent whose stored conversation is used when `messages`
            is None.
        config: The client to call. When None, `self.client` is used.

    Returns:
        A tuple `(final, reply)`. `(False, None)` when no client is available;
        otherwise `(True, reply)` where `reply` is the extracted text, or the
        `model_dump` of the extracted completion object when it is not a string.
    """
    client = self.client if config is None else config
    if client is None:
        return False, None
    if messages is None:
        messages = self._oai_messages[sender]

    # Remove the per-message context BEFORE building the payload.
    # message_formatter_pil_to_b64 deep-copies messages whose content is a list,
    # so popping afterwards would leave a stray "context" key in the copy that
    # is actually sent to the client.
    context = messages[-1].pop("context", None)

    # PIL images are kept internally; convert them to data URIs only for the call.
    messages_with_b64_img = message_formatter_pil_to_b64(self._oai_system_message + messages)

    # TODO: #1143 handle token limit exceeded error
    response = client.create(context=context, messages=messages_with_b64_img)

    # TODO: the conversion of the extracted message to a dict below can be removed
    # after ChatCompletionMessage_to_dict is merged.
    extracted_response = client.extract_text_or_completion_object(response)[0]
    if not isinstance(extracted_response, str):
        extracted_response = model_dump(extracted_response)
    return True, extracted_response
Loading
Loading