swarm - openaiimggenmodel implemented #673

Merged (1 commit) on Oct 22, 2024
43 changes: 0 additions & 43 deletions pkgs/swarmauri/swarmauri/llms/concrete/OpenAIImageGenerator.py

This file was deleted.

133 changes: 133 additions & 0 deletions pkgs/swarmauri/swarmauri/llms/concrete/OpenAIImgGenModel.py
@@ -0,0 +1,133 @@
import asyncio
from typing import List, Literal, Optional

from openai import AsyncOpenAI, OpenAI
from pydantic import Field

from swarmauri.llms.base.LLMBase import LLMBase


class OpenAIImgGenModel(LLMBase):
"""
Provider resources: https://platform.openai.com/docs/api-reference/images
"""

api_key: str
allowed_models: List[str] = ["dall-e-2", "dall-e-3"]
name: str = "dall-e-3"
type: Literal["OpenAIImgGenModel"] = "OpenAIImgGenModel"
    client: Optional[OpenAI] = Field(default=None, exclude=True)
    async_client: Optional[AsyncOpenAI] = Field(default=None, exclude=True)

def __init__(self, **data):
super().__init__(**data)
self.client = OpenAI(api_key=self.api_key)
self.async_client = AsyncOpenAI(api_key=self.api_key)

def generate_image(
self,
prompt: str,
size: str = "1024x1024",
quality: str = "standard",
n: int = 1,
style: Optional[str] = None,
) -> List[str]:
"""
Generate images using the OpenAI DALL-E model.

Parameters:
- prompt (str): The prompt to generate images from.
        - size (str): Size of the generated images. Options: "256x256", "512x512", "1024x1024" (DALL-E 2) or "1024x1024", "1024x1792", "1792x1024" (DALL-E 3).
- quality (str): Quality of the generated images. Options: "standard", "hd" (only for DALL-E 3).
- n (int): Number of images to generate (max 10 for DALL-E 2, 1 for DALL-E 3).
- style (str): Optional. The style of the generated images. Options: "vivid", "natural" (only for DALL-E 3).

Returns:
- List of URLs of the generated images.
"""
if self.name == "dall-e-3" and n > 1:
raise ValueError("DALL-E 3 only supports generating 1 image at a time.")

kwargs = {
"model": self.name,
"prompt": prompt,
"size": size,
"quality": quality,
"n": n,
}

if style and self.name == "dall-e-3":
kwargs["style"] = style

response = self.client.images.generate(**kwargs)
return [image.url for image in response.data]

async def agenerate_image(
self,
prompt: str,
size: str = "1024x1024",
quality: str = "standard",
n: int = 1,
style: Optional[str] = None,
) -> List[str]:
"""Asynchronous version of generate_image"""
if self.name == "dall-e-3" and n > 1:
raise ValueError("DALL-E 3 only supports generating 1 image at a time.")

kwargs = {
"model": self.name,
"prompt": prompt,
"size": size,
"quality": quality,
"n": n,
}

if style and self.name == "dall-e-3":
kwargs["style"] = style

response = await self.async_client.images.generate(**kwargs)
return [image.url for image in response.data]

def batch(
self,
prompts: List[str],
size: str = "1024x1024",
quality: str = "standard",
n: int = 1,
style: Optional[str] = None,
) -> List[List[str]]:
"""Synchronously process multiple prompts"""
return [
self.generate_image(
prompt,
size=size,
quality=quality,
n=n,
style=style,
)
for prompt in prompts
]

async def abatch(
self,
prompts: List[str],
size: str = "1024x1024",
quality: str = "standard",
n: int = 1,
style: Optional[str] = None,
max_concurrent: int = 5,
) -> List[List[str]]:
"""Process multiple prompts in parallel with controlled concurrency"""
semaphore = asyncio.Semaphore(max_concurrent)

async def process_prompt(prompt):
async with semaphore:
return await self.agenerate_image(
prompt,
size=size,
quality=quality,
n=n,
style=style,
)

tasks = [process_prompt(prompt) for prompt in prompts]
return await asyncio.gather(*tasks)
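
For reviewers who want to exercise the new class locally, the following minimal usage sketch (not part of the diff) assumes a valid OPENAI_API_KEY in the environment and uses only the method signatures added above.

# Hypothetical usage sketch; OPENAI_API_KEY and an installed swarmauri package are assumed.
import asyncio
import os

from swarmauri.llms.concrete.OpenAIImgGenModel import OpenAIImgGenModel

model = OpenAIImgGenModel(api_key=os.environ["OPENAI_API_KEY"])  # name defaults to "dall-e-3"

# Single prompt, synchronous call; returns a list of image URLs.
urls = model.generate_image(
    prompt="A watercolor painting of a lighthouse at dawn",
    size="1024x1024",
    quality="standard",
    style="vivid",  # forwarded only when name == "dall-e-3"
)
print(urls)

# Several prompts with bounded concurrency via abatch.
async def main():
    prompts = ["A red bicycle leaning on a brick wall", "A paper boat on a pond"]
    results = await model.abatch(prompts=prompts, max_concurrent=2)
    for prompt, prompt_urls in zip(prompts, results):
        print(prompt, "->", prompt_urls)

asyncio.run(main())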
119 changes: 119 additions & 0 deletions pkgs/swarmauri/tests/unit/llms/OpenAIImgGenModel_unit_tesst.py
@@ -0,0 +1,119 @@
import pytest
import os
from dotenv import load_dotenv
from swarmauri.llms.concrete.OpenAIImgGenModel import OpenAIImgGenModel

load_dotenv()

API_KEY = os.getenv("OPENAI_API_KEY")


@pytest.fixture(scope="module")
def openai_image_model():
if not API_KEY:
pytest.skip("Skipping due to missing OPENAI_API_KEY environment variable")
model = OpenAIImgGenModel(api_key=API_KEY)
return model


def get_allowed_models():
if not API_KEY:
return []
model = OpenAIImgGenModel(api_key=API_KEY)
return model.allowed_models


@pytest.mark.unit
def test_ubc_resource(openai_image_model):
assert openai_image_model.resource == "LLM"


@pytest.mark.unit
def test_ubc_type(openai_image_model):
assert openai_image_model.type == "OpenAIImgGenModel"


@pytest.mark.unit
def test_serialization(openai_image_model):
assert (
openai_image_model.id
== OpenAIImgGenModel.model_validate_json(
openai_image_model.model_dump_json()
).id
)


@pytest.mark.unit
def test_default_model_name(openai_image_model):
assert openai_image_model.name == "dall-e-3"


@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.integration
def test_generate_image(openai_image_model, model_name):
openai_image_model.name = model_name
prompt = "A cute robot dog playing in a park"
image_urls = openai_image_model.generate_image(prompt=prompt)

assert isinstance(image_urls, list)
assert len(image_urls) > 0
assert all(isinstance(url, str) and url.startswith("http") for url in image_urls)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", get_allowed_models())
@pytest.mark.integration
async def test_agenerate_image(openai_image_model, model_name):
openai_image_model.name = model_name
prompt = "A futuristic cityscape with flying cars"
image_urls = await openai_image_model.agenerate_image(prompt=prompt)

assert isinstance(image_urls, list)
assert len(image_urls) > 0
assert all(isinstance(url, str) and url.startswith("http") for url in image_urls)


@pytest.mark.integration
def test_batch(openai_image_model):
prompts = [
"A serene mountain landscape",
"A bustling city street at night",
"An underwater scene with colorful fish",
]

batch_results = openai_image_model.batch(prompts=prompts)

assert len(batch_results) == len(prompts)
for result in batch_results:
assert isinstance(result, list)
assert len(result) > 0
assert all(isinstance(url, str) and url.startswith("http") for url in result)


@pytest.mark.asyncio
@pytest.mark.integration
async def test_abatch(openai_image_model):
prompts = [
"A magical forest with glowing mushrooms",
"A steampunk-inspired flying machine",
"A cozy cabin in the snow",
]

batch_results = await openai_image_model.abatch(prompts=prompts)

assert len(batch_results) == len(prompts)
for result in batch_results:
assert isinstance(result, list)
assert len(result) > 0
assert all(isinstance(url, str) and url.startswith("http") for url in result)


@pytest.mark.integration
def test_dall_e_3_single_image(openai_image_model):
openai_image_model.name = "dall-e-3"
prompt = "A surreal landscape with floating islands"
image_urls = openai_image_model.generate_image(prompt=prompt)

assert isinstance(image_urls, list)
assert len(image_urls) == 1
assert isinstance(image_urls[0], str) and image_urls[0].startswith("http")
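
The batch and parametrized tests above are integration tests that require a live OPENAI_API_KEY. As a complementary idea (not part of this PR), here is a sketch of an offline unit test that patches the client's images.generate call; the mocked response mirrors the response.data[i].url access used by the implementation.

# Hypothetical offline unit test sketch; it avoids any network call by patching
# the OpenAI client method that OpenAIImgGenModel.generate_image invokes.
from unittest.mock import MagicMock, patch

import pytest

from swarmauri.llms.concrete.OpenAIImgGenModel import OpenAIImgGenModel


@pytest.mark.unit
def test_generate_image_with_mocked_client():
    model = OpenAIImgGenModel(api_key="fake-key")

    # Fake response exposing .data[i].url, the shape the implementation reads.
    fake_image = MagicMock()
    fake_image.url = "https://example.com/fake-image.png"
    fake_response = MagicMock(data=[fake_image])

    with patch.object(
        model.client.images, "generate", return_value=fake_response
    ) as mocked_generate:
        urls = model.generate_image(prompt="A test prompt")

    assert urls == ["https://example.com/fake-image.png"]
    mocked_generate.assert_called_once()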