LMM Code added
BeibinLi committed Nov 5, 2023
1 parent f052977 commit 157a158
Showing 8 changed files with 1,764 additions and 1,361 deletions.
175 changes: 175 additions & 0 deletions autogen/agentchat/contrib/llava_agent.py
@@ -0,0 +1,175 @@
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import replicate
import requests

from autogen.agentchat.agent import Agent
from autogen.agentchat.multimodal_conversable_agent import MultimodalConversableAgent
from autogen.img_utils import get_image_data, lmm_formater

try:
from termcolor import colored
except ImportError:

def colored(x, *args, **kwargs):
return x


logger = logging.getLogger(__name__)

# The following separator may be overridden later.
SEP = "###"

DEFAULT_LLAVA_SYS_MSG = "You are an AI agent and you can view images."


class LLaVAAgent(MultimodalConversableAgent):
def __init__(
self,
name: str,
system_message: Optional[Union[str, List]] = DEFAULT_LLAVA_SYS_MSG,
*args,
**kwargs,
):
"""
Args:
name (str): agent name.
system_message (str or list): system message for the ChatCompletion inference.
Please override this attribute if you want to reprogram the agent.
**kwargs (dict): Please refer to other kwargs in
[ConversableAgent](conversable_agent#__init__).
"""
super().__init__(
name,
system_message=system_message,
*args,
**kwargs,
)

assert self.llm_config is not None, "llm_config must be provided."
self.register_reply([Agent, None], reply_func=LLaVAAgent._image_reply, position=0)

def _image_reply(self, messages=None, sender=None, config=None):
# Note: llm_config is not yet routed through the standard client; only a few of its fields are read below.
# TODO 1: make the LLaVA API design compatible with llm_config

if all((messages is None, sender is None)):
error_msg = f"Either {messages=} or {sender=} must be provided."
logger.error(error_msg)
raise AssertionError(error_msg)

if messages is None:
messages = self._oai_messages[sender]

# LLaVA and GPT use different message formats, so we assemble the prompt manually here.
# TODO: format the images from the history accordingly.
images = []
prompt = self._content_str(self.system_message) + "\n"
for msg in messages:
role = "Human" if msg["role"] == "user" else "Assistant"
images += [d["image"] for d in msg["content"] if isinstance(d, dict)]
content_prompt = self._content_str(msg["content"])
prompt += f"{SEP}{role}: {content_prompt}\n"
prompt += "\n" + SEP + "Assistant: "
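# For illustration, the assembled prompt looks roughly like:
#   <system message>
#   ###Human: What is in <image>?
#   ###Assistant: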
print(colored(prompt, "blue"))

out = ""
retry = 10
while len(out) == 0 and retry > 0:
# image names will be inferred automatically from llava_call
out = llava_call_binary(
prompt=prompt,
images=images,
config_list=self.llm_config["config_list"],
temperature=self.llm_config.get("temperature", 0.5),
max_new_tokens=self.llm_config.get("max_new_tokens", 2000),
)
retry -= 1

assert out != "", "Empty response from LLaVA."

return True, out


def _llava_call_binary_with_config(
prompt: str, images: list, config: dict, max_new_tokens: int = 1000, temperature: float = 0.5, seed: int = 1
):
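# Illustrative (hypothetical) "api_base" values for the two modes handled below:
#   local:  "http://0.0.0.0:10000"          -> POST to <api_base>/worker_generate_stream on a LLaVA worker
#   remote: "yorickvp/llava-13b:<version>"  -> passed as the model identifier to the replicate client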
if config["api_base"].find("0.0.0.0") >= 0 or config["api_base"].find("localhost") >= 0:
llava_mode = "local"
else:
llava_mode = "remote"

if llava_mode == "local":
headers = {"User-Agent": "LLaVA Client"}
pload = {
"model": config["model"],
"prompt": prompt,
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"stop": SEP,
"images": images,
}

response = requests.post(
config["api_base"].rstrip("/") + "/worker_generate_stream", headers=headers, json=pload, stream=False
)

for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"].split(SEP)[-1]
elif llava_mode == "remote":
# The Replicate version of the model only supports one image for now.
img = "data:image/jpeg;base64," + images[0]
response = replicate.run(
config["api_base"], input={"image": img, "prompt": prompt.replace("<image>", " "), "seed": seed}
)
# The yorickvp/llava-13b model can stream output as it's running.
# The predict method returns an iterator, and you can iterate over that output.
output = ""
for item in response:
# https://replicate.com/yorickvp/llava-13b/versions/2facb4a474a0462c15041b78b1ad70952ea46b5ec6ad29583c0b29dbd4249591/api#output-schema
output += item

# Remove the echoed prompt and surrounding whitespace.
output = output.replace(prompt, "").strip()
return output


def llava_call_binary(
prompt: str, images: list, config_list: list, max_new_tokens: int = 1000, temperature: float = 0.5, seed: int = 1
):
# TODO 1: add caching around the LLaVA call to save compute and cost
# TODO 2: add `seed` to ensure reproducibility. The seed is not working now.

for config in config_list:
try:
return _llava_call_binary_with_config(prompt, images, config, max_new_tokens, temperature, seed)
except Exception as e:
print(f"Error: {e}")
continue


def llava_call(prompt: str, llm_config: dict) -> str:
"""
Makes a call to the LLaVA service to generate text based on a given prompt
"""

prompt, images = lmm_formater(prompt, order_image_tokens=False)

for im in images:
if len(im) == 0:
raise RuntimeError("An image is empty!")

return llava_call_binary(
prompt,
images,
config_list=llm_config["config_list"],
max_new_tokens=llm_config.get("max_new_tokens", 2000),
temperature=llm_config.get("temperature", 0.5),
seed=llm_config.get("seed", None),
)
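
For reference, a minimal usage sketch (not part of this commit). The worker address, the model name, and the `<img ...>` tag syntax parsed by `lmm_formater` are assumptions, not guaranteed by the code above:

from autogen.agentchat.contrib.llava_agent import LLaVAAgent, llava_call

# Hypothetical config: a LLaVA worker is assumed to be serving locally, which makes
# _llava_call_binary_with_config pick the "local" mode based on the address.
llava_config = {
    "config_list": [
        {
            "model": "llava-v1.5-13b",           # placeholder model name
            "api_base": "http://0.0.0.0:10000",  # placeholder worker address
        }
    ],
    "temperature": 0.5,
    "max_new_tokens": 500,
}

# One-shot call; the "<img ...>" tag is assumed to be how lmm_formater locates images in the prompt.
print(llava_call("Describe the figure in <img coco.jpg>.", llm_config=llava_config))

# Or wrap the model in an agent for multi-turn chat.
assistant = LLaVAAgent(name="llava", llm_config=llava_config)
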
118 changes: 118 additions & 0 deletions autogen/agentchat/multimodal_conversable_agent.py
@@ -0,0 +1,118 @@
from typing import Callable, Dict, List, Optional, Tuple, Union

from autogen import oai
from autogen.img_utils import gpt4v_formatter

from .agent import Agent
from .conversable_agent import ConversableAgent

try:
from termcolor import colored
except ImportError:

def colored(x, *args, **kwargs):
return x


DEFAULT_LMM_SYS_MSG = """You are a helpful AI assistant.
You can also view images, where the "<image i>" tag represents the i-th image you received."""


class MultimodalConversableAgent(ConversableAgent):
def __init__(
self,
name: str,
system_message: Optional[Union[str, List]] = DEFAULT_LMM_SYS_MSG,
is_termination_msg=None,
*args,
**kwargs,
):
"""
Args:
name (str): agent name.
system_message (str or list): system message for the ChatCompletion inference.
Please override this attribute if you want to reprogram the agent.
**kwargs (dict): Please refer to other kwargs in
[ConversableAgent](conversable_agent#__init__).
"""
super().__init__(
name,
system_message,
is_termination_msg=is_termination_msg,
*args,
**kwargs,
)

self.update_system_message(system_message)
self._is_termination_msg = (
is_termination_msg if is_termination_msg is not None else (lambda x: x.get("content")[-1] == "TERMINATE")
)

@property
def system_message(self) -> List:
"""Return the system message."""
return self._oai_system_message[0]["content"]

def update_system_message(self, system_message: Union[Dict, List, str]):
"""Update the system message.
Args:
system_message (str, list, or dict): system message for the ChatCompletion inference.
"""
self._oai_system_message[0]["content"] = self._message_to_dict(system_message)["content"]
self._oai_system_message[0]["role"] = "system"

@staticmethod
def _message_to_dict(message: Union[Dict, List, str]):
"""Convert a message to a dictionary.
The message can be a string, a list, or a dictionary. A string is parsed with gpt4v_formatter and, like a list, placed in the "content" field of the new dictionary.
"""
if isinstance(message, str):
return {"content": gpt4v_formatter(message)}
if isinstance(message, list):
return {"content": message}
else:
return message

def _content_str(self, content: List) -> str:
rst = ""
for item in content:
if isinstance(item, str):
rst += item
else:
assert isinstance(item, dict) and "image" in item, "Wrong content format."
rst += "<image>"
return rst

def _print_received_message(self, message: Union[Dict, str], sender: Agent):
# print the message received
print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
if message.get("role") == "function":
func_print = f"***** Response from calling function \"{message['name']}\" *****"
print(colored(func_print, "green"), flush=True)
print(self._content_str(message["content"]), flush=True)
print(colored("*" * len(func_print), "green"), flush=True)
else:
content = message.get("content")
if content is not None:
if "context" in message:
content = oai.ChatCompletion.instantiate(
content,
message["context"],
self.llm_config and self.llm_config.get("allow_format_str_template", False),
)
print(self._content_str(content), flush=True)
if "function_call" in message:
func_print = f"***** Suggested function Call: {message['function_call'].get('name', '(No function name found)')} *****"
print(colored(func_print, "green"), flush=True)
print(
"Arguments: \n",
message["function_call"].get("arguments", "(No arguments found)"),
flush=True,
sep="",
)
print(colored("*" * len(func_print), "green"), flush=True)
print("\n", "-" * 80, flush=True, sep="")

# TODO: we may want to update `generate_code_execution_reply` or `extract_code` for the "content" type change.
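
For illustration, a sketch of the mixed string/image content format the helpers above operate on (the values are hypothetical). A plain-string message is first run through `gpt4v_formatter`, which is assumed to produce a list of this shape:

# Hypothetical message in the mixed text/image content format used above.
message = {
    "role": "user",
    "content": [
        "What does ",
        {"image": "<base64-encoded image data>"},  # e.g. produced by get_image_data
        " show?",
    ],
}
# _content_str(message["content"]) renders this content as: "What does <image> show?"
# The {"image": ...} entries are what LLaVAAgent._image_reply collects into its `images` list.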