Skip to content

Commit

Permalink
stable diffusion (#400)
Browse files Browse the repository at this point in the history
  • Loading branch information
anisha1607 authored Jun 16, 2023
1 parent de81722 commit 4061998
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 0 deletions.
6 changes: 6 additions & 0 deletions config_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ MODEL_NAME: "gpt-3.5-turbo-0301"
MAX_TOOL_TOKEN_LIMIT: 800
MAX_MODEL_TOKEN_LIMIT: 4032 # set to 2048 for llama

# For running stable diffusion
STABILITY_API_KEY: YOUR_STABILITY_API_KEY
#Engine IDs that can be used: 'stable-diffusion-v1', 'stable-diffusion-v1-5','stable-diffusion-512-v2-0', 'stable-diffusion-768-v2-0','stable-diffusion-512-v2-1','stable-diffusion-768-v2-1','stable-diffusion-xl-beta-v2-2-2'
ENGINE_ID: "stable-diffusion-xl-beta-v2-2-2"


#DATABASE INFO
# redis details
DB_NAME: super_agi_main
Expand Down
115 changes: 115 additions & 0 deletions superagi/tools/image_generation/stable_diffusion_image_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from typing import Type, Optional
from pydantic import BaseModel, Field
from superagi.tools.base_tool import BaseTool
from superagi.config.config import get_config
import os
from PIL import Image
from io import BytesIO
import requests
import base64
from superagi.models.db import connect_db
from superagi.helper.resource_helper import ResourceHelper
from superagi.helper.s3_helper import S3Helper
from sqlalchemy.orm import sessionmaker


class StableDiffusionImageGenInput(BaseModel):
prompt: str = Field(..., description="Prompt for Image Generation to be used by Stable Diffusion.")
height: int = Field(..., description="Height of the image to be Generated. default height is 512")
width: int = Field(..., description="Width of the image to be Generated. default width is 512")
num: int = Field(..., description="Number of Images to be generated. default num is 2")
steps: int = Field(..., description="Number of diffusion steps to run. default steps are 50")
image_name: list = Field(..., description="Image Names for the generated images, example 'image_1.png'. Only include the image name. Don't include path.")


class StableDiffusionImageGenTool(BaseTool):
name: str = "Stable Diffusion Image Generation"
args_schema: Type[BaseModel] = StableDiffusionImageGenInput
description: str = "Generate Images using Stable Diffusion"
agent_id: int = None

def _execute(self, prompt: str, image_name: list, width: int = 512, height: int = 512, num: int = 2, steps: int = 50):
engine = connect_db()
Session = sessionmaker(bind=engine)
session = Session()

api_key = get_config("STABILITY_API_KEY")
engine_id = get_config("ENGINE_ID")

if api_key is None:
return "Error: Missing Stability API key."

if "768" in engine_id:
if height < 768:
height = 768
if width < 768:
width = 768

response = requests.post(
f"https://api.stability.ai/v1/generation/{engine_id}/text-to-image",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {api_key}"
},
json={
"text_prompts": [
{
"text": prompt
}
],
"height": height,
"width": width,
"samples": num,
"steps": steps,
},
)

if response.status_code != 200:
return f"Non-200 response: {str(response.text)}"

data = response.json()

artifacts = data['artifacts']
base64_strings = []
for artifact in artifacts:
base64_strings.append(artifact['base64'])

for i in range(num):
image_base64 = base64_strings[i]
img_data = base64.b64decode(image_base64)
final_img = Image.open(BytesIO(img_data))
image_format = final_img.format

image = image_name[i]
final_path = image

root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')

if root_dir is not None:
root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
final_path = root_dir + image
else:
final_path = os.getcwd() + "/" + image

try:
with open(final_path, mode="wb") as img:
final_img.save(img, format=image_format)
with open(final_path, 'rb') as img:
resource = ResourceHelper.make_written_file_resource(file_name=image_name[i],
agent_id=self.agent_id, file=img,channel="OUTPUT")
print(resource)
if resource is not None:
session.add(resource)
session.commit()
session.flush()
if resource.storage_type == "S3":
s3_helper = S3Helper()
s3_helper.upload_file(img, path=resource.path)
session.close()
print(f"Image {image} saved successfully")
except Exception as err:
print(f"Error in _execute: {err}")
return f"Error: {err}"
return "Images downloaded and saved successfully"
64 changes: 64 additions & 0 deletions tests/tools/stable_diffusion_image_gen_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os
import unittest
from unittest.mock import patch, MagicMock
from PIL import Image
from io import BytesIO
import base64
from superagi.config.config import get_config

from superagi.tools.image_generation.stable_diffusion_image_gen import StableDiffusionImageGenTool


class TestStableDiffusionImageGenTool(unittest.TestCase):

@patch('requests.post')
@patch('superagi.tools.image_generation.stable_diffusion_image_gen.get_config')
def test_stable_diffusion_image_gen_tool_execute(self, mock_get_config, mock_requests_post):
# Setup
tool = StableDiffusionImageGenTool()
prompt = 'Artificial Intelligence'
image_names = ['image1.png', 'image2.png']
height = 512
width = 512
num = 2
steps = 50

# Create a temporary directory for image storage
temp_dir = get_config("RESOURCES_OUTPUT_ROOT_DIR")

# Mock responses
mock_configs = {"STABILITY_API_KEY": "api_key", "ENGINE_ID": "engine_id", "RESOURCES_OUTPUT_ROOT_DIR": temp_dir}
mock_get_config.side_effect = lambda k: mock_configs[k]

# Prepare sample image bytes
img = Image.new("RGB", (width, height), "white")
buffer = BytesIO()
img.save(buffer, "PNG")
buffer.seek(0)
img_data = buffer.getvalue()
encoded_image_data = base64.b64encode(img_data).decode()

# Use the proper base64-encoded string
mock_requests_post.return_value = MagicMock(status_code=200, json=lambda: {
"artifacts": [
{"base64": encoded_image_data},
{"base64": encoded_image_data}
]
})

# Run the method under test
response = tool._execute(prompt, image_names, width, height, num, steps)
self.assertEqual(response, f"Images downloaded successfully")

for image_name in image_names:
path = os.path.join(temp_dir, image_name)
self.assertTrue(os.path.exists(path))
with open(path, "rb") as file:
self.assertEqual(file.read(), img_data)

# Clean up
for image_name in image_names:
os.remove(os.path.join(temp_dir, image_name))

if __name__ == '__main__':
unittest.main()

0 comments on commit 4061998

Please sign in to comment.