[Feature] Adds Image Generation Capability 2.0 #1907

Merged · 55 commits · Mar 15, 2024

Commits
636bf9d
adds image generation capability
WaelKarkoub Mar 5, 2024
aca17d5
add todo
WaelKarkoub Mar 5, 2024
81e438c
readded cache
WaelKarkoub Mar 6, 2024
1bfe7c7
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 6, 2024
a35a8d2
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 6, 2024
8f7aeff
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 6, 2024
5cfada7
wip
WaelKarkoub Mar 7, 2024
5a1e23b
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 7, 2024
8e37db3
fix content str bugs
WaelKarkoub Mar 7, 2024
d2682e0
removed todo: delete imshow
WaelKarkoub Mar 7, 2024
0f74c33
wip
WaelKarkoub Mar 7, 2024
2403c1b
fix circular imports
WaelKarkoub Mar 7, 2024
4a0f842
add notebook
WaelKarkoub Mar 7, 2024
ed8e2d6
improve prompt
WaelKarkoub Mar 7, 2024
e62f7cf
improved text analyzer + notebook
WaelKarkoub Mar 7, 2024
8bd9d66
notebook update
WaelKarkoub Mar 7, 2024
8908fb2
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 7, 2024
f38cf6a
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 8, 2024
f89f686
improve notebook
WaelKarkoub Mar 8, 2024
e7329d9
smaller notebook size
WaelKarkoub Mar 8, 2024
85b6bcf
made changes to the wrong branch :(
WaelKarkoub Mar 8, 2024
8984b47
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 9, 2024
d241dea
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 9, 2024
aada6b4
resolve comments + 1
WaelKarkoub Mar 10, 2024
50ea140
adds doc strings
WaelKarkoub Mar 10, 2024
2439241
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 12, 2024
d5b9d2c
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 13, 2024
691aabf
adds cache doc string
WaelKarkoub Mar 13, 2024
03406ac
adds doc string to add_to_agent
WaelKarkoub Mar 13, 2024
298cc90
adds doc string to ImageGeneration
WaelKarkoub Mar 13, 2024
1440a0d
instructions are not configurable
WaelKarkoub Mar 13, 2024
3c0a3e1
removed unnecessary imports
WaelKarkoub Mar 13, 2024
a502420
changed doc string location
WaelKarkoub Mar 13, 2024
322e55d
more doc strings
WaelKarkoub Mar 13, 2024
d5e9b5f
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 14, 2024
2418c60
improves testability
WaelKarkoub Mar 15, 2024
c2e833d
adds tests
WaelKarkoub Mar 15, 2024
e2bd370
adds cache test
WaelKarkoub Mar 15, 2024
70cd6e3
added test to github workflow
WaelKarkoub Mar 15, 2024
c21d0e2
compatible llm config format
WaelKarkoub Mar 15, 2024
fad3ab2
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 15, 2024
4636a95
configurable reply function position
WaelKarkoub Mar 15, 2024
18f39ab
skip_openai + better comments
WaelKarkoub Mar 15, 2024
b5a8ee4
Merge branch 'main' into describe-image-capability
WaelKarkoub Mar 15, 2024
45fa663
fix test
WaelKarkoub Mar 15, 2024
1a9b7b9
fix test?
WaelKarkoub Mar 15, 2024
793da60
please fix test?
WaelKarkoub Mar 15, 2024
7464a1c
last fix test?
WaelKarkoub Mar 15, 2024
6a93ebd
remove type hint
WaelKarkoub Mar 15, 2024
8603f71
skip cache test
WaelKarkoub Mar 15, 2024
83a64ac
adds mock api key
WaelKarkoub Mar 15, 2024
300c53e
dalle-2 test
sonichi Mar 15, 2024
af912f0
Merge remote-tracking branch 'refs/remotes/origin/describe-image-capa…
sonichi Mar 15, 2024
cf02c1a
fix dalle config
WaelKarkoub Mar 15, 2024
a250c3d
use apu key function
WaelKarkoub Mar 15, 2024
35 changes: 35 additions & 0 deletions .github/workflows/contrib-openai.yml
@@ -299,3 +299,38 @@ jobs:
with:
file: ./coverage.xml
flags: unittests
ImageGen:
strategy:
matrix:
os: [ubuntu-latest]
python-version: ["3.12"]
runs-on: ${{ matrix.os }}
environment: openai1
steps:
# checkout to pr branch
- name: Checkout
uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies
run: |
docker --version
python -m pip install --upgrade pip wheel
pip install -e .[lmm]
python -c "import autogen"
pip install coverage pytest
- name: Coverage
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
run: |
coverage run -a -m pytest test/agentchat/contrib/capabilities/test_image_generation_capability.py
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
7 changes: 4 additions & 3 deletions .github/workflows/contrib-tests.yml
@@ -15,8 +15,9 @@ on:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions: {}
# actions: read
permissions:
{}
# actions: read
# checks: read
# contents: read
# deployments: read
@@ -246,7 +247,7 @@ jobs:
- name: Coverage
run: |
pip install coverage>=5.3
coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py --skip-openai
coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py test/agentchat/contrib/capabilities/test_image_generation_capability.py --skip-openai
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
291 changes: 291 additions & 0 deletions autogen/agentchat/contrib/capabilities/generate_images.py
@@ -0,0 +1,291 @@
import re
from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union

from openai import OpenAI
from PIL.Image import Image

from autogen import Agent, ConversableAgent, code_utils
from autogen.cache import Cache
from autogen.agentchat.contrib import img_utils
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent

SYSTEM_MESSAGE = "You've been given the special ability to generate images."
DESCRIPTION_MESSAGE = "This agent has the ability to generate images."

PROMPT_INSTRUCTIONS = """In detail, please summarize the provided prompt to generate the image described in the TEXT.
DO NOT include any advice. RESPOND like the following example:
EXAMPLE: Blue background, 3D shapes, ...
"""


class ImageGenerator(Protocol):
"""This class defines an interface for image generators.

Concrete implementations of this protocol must provide a `generate_image` method that takes a string prompt as
input and returns a PIL Image object.

NOTE: Current implementation does not allow you to edit a previously existing image.
"""

def generate_image(self, prompt: str) -> Image:
"""Generates an image based on the provided prompt.

Args:
prompt: A string describing the desired image.

Returns:
A PIL Image object representing the generated image.

Raises:
ValueError: If the image generation fails.
"""
...

def cache_key(self, prompt: str) -> str:
"""Generates a unique cache key for the given prompt.

This key can be used to store and retrieve generated images based on the prompt.

Args:
prompt: A string describing the desired image.

Returns:
A unique string that can be used as a cache key.
"""
...
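
# Illustrative sketch (not part of this PR): a minimal generator that satisfies the
# ImageGenerator protocol without calling any external API, e.g. for offline testing.
# The class name and the solid-color placeholder image are assumptions, not library code.
class SolidColorImageGenerator:
    def generate_image(self, prompt: str) -> Image:
        from PIL import Image as PILImage  # module import; avoids shadowing the PIL.Image.Image class imported above

        return PILImage.new("RGB", (256, 256), color="gray")  # same placeholder image for any prompt

    def cache_key(self, prompt: str) -> str:
        return f"solid-color,{prompt}"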


class DalleImageGenerator:
"""Generates images using OpenAI's DALL-E models.

This class provides a convenient interface for generating images based on textual prompts using OpenAI's DALL-E
models. It allows you to specify the DALL-E model, resolution, quality, and the number of images to generate.

Note: Current implementation does not allow you to edit a previously existing image.
"""

def __init__(
self,
llm_config: Dict,
resolution: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] = "1024x1024",
quality: Literal["standard", "hd"] = "standard",
num_images: int = 1,
):
"""
Args:
llm_config (dict): The LLM config. Its config_list must contain a valid DALL-E model and an OpenAI API key.
resolution (str): The resolution of the image you want to generate. Must be one of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792".
quality (str): The quality of the image you want to generate. Must be one of "standard", "hd".
num_images (int): The number of images to generate.
"""
config_list = llm_config["config_list"]
_validate_dalle_model(config_list[0]["model"])
_validate_resolution_format(resolution)

self._model = config_list[0]["model"]
self._resolution = resolution
self._quality = quality
self._num_images = num_images
self._dalle_client = OpenAI(api_key=config_list[0]["api_key"])

def generate_image(self, prompt: str) -> Image:
response = self._dalle_client.images.generate(
model=self._model,
prompt=prompt,
size=self._resolution,
quality=self._quality,
n=self._num_images,
)

image_url = response.data[0].url
if image_url is None:
raise ValueError("Failed to generate image.")

return img_utils.get_pil_image(image_url)

def cache_key(self, prompt: str) -> str:
keys = (prompt, self._model, self._resolution, self._quality, self._num_images)
return ",".join([str(k) for k in keys])


class ImageGeneration(AgentCapability):
"""This capability allows a ConversableAgent to generate images based on the message received from other Agents.

1. Utilizes a TextAnalyzerAgent to analyze incoming messages to identify requests for image generation and
extract relevant details.
2. Leverages the provided ImageGenerator (e.g., DalleImageGenerator) to create the image.
3. Optionally caches generated images for faster retrieval in future conversations.

NOTE: This capability increases the token usage of the agent, as it uses TextAnalyzerAgent to analyze every
message received by the agent.

Example:
```python
import autogen
from autogen.agentchat.contrib.capabilities.generate_images import DalleImageGenerator, ImageGeneration

# Assuming you have llm configs set up for the LLMs you want to use and for DALL-E.
# Create the agent
agent = autogen.ConversableAgent(
name="dalle", llm_config={...}, max_consecutive_auto_reply=3, human_input_mode="NEVER"
)

# Create an ImageGenerator with desired settings
dalle_gen = DalleImageGenerator(llm_config={...})

# Add the ImageGeneration capability to the agent
ImageGeneration(image_generator=dalle_gen).add_to_agent(agent)
```
"""

def __init__(
self,
image_generator: ImageGenerator,
cache: Optional[Cache] = None,
text_analyzer_llm_config: Optional[Dict] = None,
text_analyzer_instructions: str = PROMPT_INSTRUCTIONS,
verbosity: int = 0,
register_reply_position: int = 2,
):
"""
Args:
image_generator (ImageGenerator): The image generator you would like to use to generate images.
cache (None or Cache): The cache client to use to store and retrieve generated images. If None,
no caching will be used.
text_analyzer_llm_config (Dict or None): The LLM config for the text analyzer. If None, the LLM config will
be retrieved from the agent you're adding the capability to.
text_analyzer_instructions (str): Instructions provided to the TextAnalyzerAgent used to analyze
incoming messages and extract the prompt for image generation. The default instructions focus on
summarizing the prompt. You can customize the instructions to achieve more granular control over prompt
extraction.
Example: 'Extract specific details from the message, like desired objects, styles, or backgrounds.'
verbosity (int): The verbosity level. Defaults to 0 and must be greater than or equal to 0. The text
analyzer llm calls will be silent if verbosity is less than 2.
register_reply_position (int): The position of the reply function in the agent's list of reply functions.
This capability registers a new reply function to handle messages with image generation requests.
Defaults to 2 to place it after the check termination and human reply for a ConversableAgent.
"""
self._image_generator = image_generator
self._cache = cache
self._text_analyzer_llm_config = text_analyzer_llm_config
self._text_analyzer_instructions = text_analyzer_instructions
self._verbosity = verbosity
self._register_reply_position = register_reply_position

self._agent: Optional[ConversableAgent] = None
self._text_analyzer: Optional[TextAnalyzerAgent] = None

def add_to_agent(self, agent: ConversableAgent):
"""Adds the Image Generation capability to the specified ConversableAgent.

This function performs the following modifications to the agent:

1. Registers a reply function: A new reply function is registered with the agent to handle messages that
potentially request image generation. This function analyzes the message and triggers image generation if
necessary.
2. Creates an Agent (TextAnalyzerAgent): This is used to analyze messages for image generation requirements.
3. Updates System Message: The agent's system message is updated to include a message indicating the
capability to generate images has been added.
4. Updates Description: The agent's description is updated to reflect the addition of the Image Generation
capability. This might be helpful in certain use cases, like group chats.

Args:
agent (ConversableAgent): The ConversableAgent to add the capability to.
"""
self._agent = agent

agent.register_reply([Agent, None], self._image_gen_reply, position=self._register_reply_position)

self._text_analyzer_llm_config = self._text_analyzer_llm_config or agent.llm_config
self._text_analyzer = TextAnalyzerAgent(llm_config=self._text_analyzer_llm_config)

agent.update_system_message(agent.system_message + "\n" + SYSTEM_MESSAGE)
agent.description += "\n" + DESCRIPTION_MESSAGE

def _image_gen_reply(
self,
recipient: ConversableAgent,
messages: Optional[List[Dict]],
sender: Optional[Agent] = None,
config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
if messages is None:
return False, None

last_message = code_utils.content_str(messages[-1]["content"])

if not last_message:
return False, None

if self._should_generate_image(last_message):
prompt = self._extract_prompt(last_message)

image = self._cache_get(prompt)
if image is None:
image = self._image_generator.generate_image(prompt)
self._cache_set(prompt, image)

return True, self._generate_content_message(prompt, image)

else:
return False, None

def _should_generate_image(self, message: str) -> bool:
assert self._text_analyzer is not None

instructions = """
Does any part of the TEXT ask the agent to generate an image?
The TEXT must explicitly mention that the image must be generated.
Answer with just one word, yes or no.
"""
analysis = self._text_analyzer.analyze_text(message, instructions)

return "yes" in self._extract_analysis(analysis).lower()

def _extract_prompt(self, last_message) -> str:
assert self._text_analyzer is not None

analysis = self._text_analyzer.analyze_text(last_message, self._text_analyzer_instructions)
return self._extract_analysis(analysis)

def _cache_get(self, prompt: str) -> Optional[Image]:
if self._cache:
key = self._image_generator.cache_key(prompt)
cached_value = self._cache.get(key)

if cached_value:
return img_utils.get_pil_image(cached_value)

def _cache_set(self, prompt: str, image: Image):
if self._cache:
key = self._image_generator.cache_key(prompt)
self._cache.set(key, img_utils.pil_to_data_uri(image))

def _extract_analysis(self, analysis: Union[str, Dict, None]) -> str:
if isinstance(analysis, Dict):
return code_utils.content_str(analysis["content"])
else:
return code_utils.content_str(analysis)

def _generate_content_message(self, prompt: str, image: Image) -> Dict[str, Any]:
return {
"content": [
{"type": "text", "text": f"I generated an image with the prompt: {prompt}"},
{"type": "image_url", "image_url": {"url": img_utils.pil_to_data_uri(image)}},
]
}
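
# Illustrative wiring sketch (not part of this PR). The function name and the choice of a
# disk cache are assumptions; any Cache instance (or None) can be passed.
def example_add_image_generation(agent: ConversableAgent, dalle_llm_config: Dict) -> None:
    generator = DalleImageGenerator(llm_config=dalle_llm_config)
    # With a cache, generated images are stored under DalleImageGenerator.cache_key(prompt)
    # and reused when the same prompt is extracted again.
    capability = ImageGeneration(image_generator=generator, cache=Cache.disk())
    # Registers the reply function, creates the TextAnalyzerAgent, and updates the agent's
    # system message and description, as documented in add_to_agent above.
    capability.add_to_agent(agent)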


### Helpers
def _validate_resolution_format(resolution: str):
"""Checks if a string is in a valid resolution format (e.g., "1024x768")."""
pattern = r"^\d+x\d+$" # Matches a pattern of digits, "x", and digits
matched_resolution = re.match(pattern, resolution)
if matched_resolution is None:
raise ValueError(f"Invalid resolution format: {resolution}")


def _validate_dalle_model(model: str):
if model not in ["dall-e-3", "dall-e-2"]:
raise ValueError(f"Invalid DALL-E model: {model}. Must be 'dall-e-3' or 'dall-e-2'")