Skip to content

Commit

Permalink
Update test cases and llava
Browse files: browse the repository at this point in the history
  • Loading branch information
BeibinLi committed Jan 3, 2024
1 parent 70e982b commit 791c4e0
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 7 deletions.
12 changes: 8 additions & 4 deletions autogen/agentchat/contrib/img_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,20 @@
from PIL import Image


def get_pil_image(image_file: str) -> Image.Image:
def get_pil_image(image_file: Union[str, Image.Image]) -> Image.Image:
"""
Loads an image from a file and returns a PIL Image object.
Parameters:
image_file (str): The filename, URL, URI, or base64 string of the image file.
image_file (str or Image.Image): The filename, URL, URI, or base64 string of the image file, or an already-loaded PIL Image object (returned as is).
Returns:
Image.Image: The PIL Image object.
"""
if isinstance(image_file, Image.Image):
# Already a PIL Image object
return image_file

if image_file.startswith("http://") or image_file.startswith("https://"):
# A URL file
response = requests.get(image_file)
Expand All @@ -39,7 +43,7 @@ def get_pil_image(image_file: str) -> Image.Image:
return image.convert("RGB")


def get_image_data(image_file: str, use_b64=True) -> bytes:
def get_image_data(image_file: Union[str, Image.Image], use_b64=True) -> bytes:
"""
Loads an image and returns its data either as raw bytes or in base64-encoded format.
Expand All @@ -49,7 +53,7 @@ def get_image_data(image_file: str, use_b64=True) -> bytes:
either returned directly or as a base64-encoded string.
Parameters:
image_file (str): The path to the image file, a URL to an image, or a base64-encoded
image_file (str or Image.Image): The path to the image file, a URL to an image, a base64-encoded
string of the image, or a PIL Image object.
use_b64 (bool): If True, the function returns a base64-encoded string of the image data.
If False, it returns the raw byte data of the image. Defaults to True.
Expand Down
4 changes: 3 additions & 1 deletion autogen/agentchat/contrib/llava_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def _image_reply(self, messages=None, sender=None, config=None):
content_prompt = content_str(msg["content"])
prompt += f"{SEP}{role}: {content_prompt}\n"
prompt += "\n" + SEP + "Assistant: "
images = [re.sub("data:image/.+;base64,", "", im, count=1) for im in images]

# TODO: PIL to base64
images = [get_image_data(im) for im in images]
print(colored(prompt, "blue"))

out = ""
Expand Down
7 changes: 6 additions & 1 deletion test/agentchat/contrib/test_img_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,18 @@

@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestGetPilImage(unittest.TestCase):
def test_image(self) -> bytes:
def test_read_local_file(self):
# Create a small red image for testing
temp_file = "_temp.png"
raw_pil_image.save(temp_file)
img2 = get_pil_image(temp_file)
self.assert_((np.array(raw_pil_image) == np.array(img2)).all())

def test_read_pil(self):
# A PIL Image passed in directly should come back with identical pixel data
img2 = get_pil_image(raw_pil_image)
self.assert_((np.array(raw_pil_image) == np.array(img2)).all())


def are_b64_images_equal(x: str, y: str):
"""
Expand Down
5 changes: 4 additions & 1 deletion test/agentchat/contrib/test_lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import autogen
from autogen.agentchat.agent import Agent
from autogen.agentchat.contrib.img_utils import get_pil_image

try:
from autogen.agentchat.contrib.multimodal_conversable_agent import MultimodalConversableAgent
Expand All @@ -19,6 +20,8 @@
"//8/w38GIAXDIBKE0DHxgljNBAAO9TXL0Y4OHwAAAABJRU5ErkJggg=="
)

pil_image = get_pil_image(base64_encoded_image)


@pytest.mark.skipif(skip, reason="dependency is not installed")
class TestMultimodalConversableAgent(unittest.TestCase):
Expand Down Expand Up @@ -51,7 +54,7 @@ def test_system_message(self):
self.agent.system_message,
[
{"type": "text", "text": "We will discuss "},
{"type": "image_url", "image_url": {"url": base64_encoded_image}},
{"type": "image_url", "image_url": {"url": pil_image}},
{"type": "text", "text": " in this conversation."},
],
)
Expand Down

0 comments on commit 791c4e0

Please sign in to comment.