Jarvis demo, base multimodal model, whisperx -> whisperx_model
Kye committed Nov 25, 2023
1 parent 9390efb commit 51c82cf
Showing 8 changed files with 263 additions and 42 deletions.
20 changes: 20 additions & 0 deletions playground/demos/jarvis_multi_modal_auto_agent/jarvis.py
@@ -0,0 +1,20 @@
from swarms.structs import Flow
from swarms.models.gpt4_vision_api import GPT4VisionAPI
from swarms.prompts.multi_modal_autonomous_instruction_prompt import (
    MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
)


llm = GPT4VisionAPI()

task = "What is the color of the object?"
img = "images/swarms.jpeg"

# Initialize the workflow
flow = Flow(
    llm=llm,
    sop=MULTI_MODAL_AUTO_AGENT_SYSTEM_PROMPT_1,
    max_loops="auto",
)

flow.run(task=task, img=img)
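
For reference, the GPT4VisionAPI instance above can also be called directly, without wrapping it in a Flow; the run(task, img) method it exposes is touched later in this commit. The image path and the OPENAI_API_KEY requirement are assumptions for illustration, not part of the diff:

# Hypothetical direct call; assumes OPENAI_API_KEY is set and the image exists locally.
out = llm.run(task="What is the color of the object?", img="images/swarms.jpeg")
print(out)
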
4 changes: 2 additions & 2 deletions swarms/models/__init__.py
@@ -20,8 +20,6 @@

# MultiModal Models
from swarms.models.idefics import Idefics # noqa: E402

-# from swarms.models.kosmos_two import Kosmos # noqa: E402
from swarms.models.vilt import Vilt # noqa: E402
from swarms.models.nougat import Nougat # noqa: E402
from swarms.models.layoutlm_document_qa import LayoutLMDocumentQA # noqa: E402
@@ -30,6 +28,8 @@
# from swarms.models.gpt4v import GPT4Vision
# from swarms.models.dalle3 import Dalle3
# from swarms.models.distilled_whisperx import DistilWhisperModel # noqa: E402
+# from swarms.models.whisperx_model import WhisperX # noqa: E402
+# from swarms.models.kosmos_two import Kosmos # noqa: E402

__all__ = [
"Anthropic",
209 changes: 209 additions & 0 deletions swarms/models/base_multimodal_model.py
@@ -0,0 +1,209 @@
import asyncio
import base64
import concurrent.futures
import time
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import List, Optional, Tuple

import requests
from PIL import Image


class BaseMultiModalModel:
    def __init__(
        self,
        model_name: Optional[str],
        temperature: Optional[float] = 0.5,
        max_tokens: Optional[int] = 500,
        max_workers: Optional[int] = 10,
        top_p: Optional[float] = 1,
        top_k: Optional[int] = 50,
        device: Optional[str] = "cuda",
        max_new_tokens: Optional[int] = 500,
        retries: Optional[int] = 3,
    ):
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_workers = max_workers
        self.top_p = top_p
        self.top_k = top_k
        self.device = device
        self.max_new_tokens = max_new_tokens
        self.retries = retries
        self.chat_history = []
        # Timing attributes used by _time_for_generation / get_generation_time
        self.start_time = None
        self.end_time = None

    @abstractmethod
    def __call__(self, text: str, img: str):
        """Run the model"""
        pass

    def run(self, task: str, img: str):
        """Run the model"""
        pass

    async def arun(self, task: str, img: str):
        """Run the model asynchronously"""
        pass

    def get_img_from_web(self, img: str):
        """Get the image from the web"""
        try:
            response = requests.get(img)
            response.raise_for_status()
            image_pil = Image.open(BytesIO(response.content))
            return image_pil
        except requests.RequestException as error:
            print(f"Error fetching image from {img} and error: {error}")
            return None

    def encode_img(self, img: str):
        """Encode the image to base64"""
        with open(img, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def get_img(self, img: str):
        """Get the image from the path"""
        image_pil = Image.open(img)
        return image_pil

    def clear_chat_history(self):
        """Clear the chat history"""
        self.chat_history = []

    def run_many(
        self,
        tasks: List[str],
        imgs: List[str],
    ):
        """
        Run the model on multiple tasks and images all at once using concurrent threads

        Args:
            tasks (List[str]): List of tasks
            imgs (List[str]): List of image paths

        Returns:
            List[str]: List of responses
        """
        # Instantiate the thread pool executor
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            results = list(executor.map(self.run, tasks, imgs))

        # Print the results for debugging
        for result in results:
            print(result)

        return results

    def run_batch(self, tasks_images: List[Tuple[str, str]]) -> List[str]:
        """Process a batch of tasks and images"""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.run, task, img)
                for task, img in tasks_images
            ]
            results = [future.result() for future in futures]
        return results

    async def run_batch_async(
        self, tasks_images: List[Tuple[str, str]]
    ) -> List[str]:
        """Process a batch of tasks and images asynchronously"""
        loop = asyncio.get_event_loop()
        futures = [
            loop.run_in_executor(None, self.run, task, img)
            for task, img in tasks_images
        ]
        return await asyncio.gather(*futures)

    async def run_batch_async_with_retries(
        self, tasks_images: List[Tuple[str, str]]
    ) -> List[str]:
        """Process a batch of tasks and images asynchronously with retries"""
        loop = asyncio.get_event_loop()
        futures = [
            loop.run_in_executor(None, self.run_with_retries, task, img)
            for task, img in tasks_images
        ]
        return await asyncio.gather(*futures)

    def unique_chat_history(self):
        """Get the unique chat history"""
        return list(set(self.chat_history))

    def run_with_retries(self, task: str, img: str):
        """Run the model with retries"""
        for i in range(self.retries):
            try:
                return self.run(task, img)
            except Exception as error:
                print(f"Error with the request {error}")
                continue

    def run_batch_with_retries(self, tasks_images: List[Tuple[str, str]]):
        """Run the model with retries"""
        for i in range(self.retries):
            try:
                return self.run_batch(tasks_images)
            except Exception as error:
                print(f"Error with the request {error}")
                continue

    def _tokens_per_second(self) -> float:
        """Tokens per second"""
        # Relies on self.start_time / self.end_time (set by _time_for_generation)
        # and a _num_tokens() helper that concrete subclasses are expected to provide.
        elapsed_time = self.end_time - self.start_time
        if elapsed_time == 0:
            return float("inf")
        return self._num_tokens() / elapsed_time

    def _time_for_generation(self, task: str) -> float:
        """Time for Generation"""
        self.start_time = time.time()
        self.run(task)
        self.end_time = time.time()
        return self.end_time - self.start_time

    @abstractmethod
    def generate_summary(self, text: str) -> str:
        """Generate Summary"""
        pass

    def set_temperature(self, value: float):
        """Set Temperature"""
        self.temperature = value

    def set_max_tokens(self, value: int):
        """Set new max tokens"""
        self.max_tokens = value

    def get_generation_time(self) -> float:
        """Get generation time"""
        if self.start_time and self.end_time:
            return self.end_time - self.start_time
        return 0

    def get_chat_history(self):
        """Get the chat history"""
        return self.chat_history

    def get_unique_chat_history(self):
        """Get the unique chat history"""
        return list(set(self.chat_history))

    def get_chat_history_length(self):
        """Get the chat history length"""
        return len(self.chat_history)

    def get_unique_chat_history_length(self):
        """Get the unique chat history length"""
        return len(list(set(self.chat_history)))

    def get_chat_history_tokens(self):
        """Get the chat history tokens"""
        return self._num_tokens()
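
To make the base class contract concrete, here is a minimal sketch of a subclass. It is not part of this commit; the EchoModel name and its trivial bodies are illustrative assumptions that only show which hooks a real model fills in.

from swarms.models.base_multimodal_model import BaseMultiModalModel


class EchoModel(BaseMultiModalModel):
    """Toy subclass used only to illustrate the BaseMultiModalModel hooks."""

    def __call__(self, text: str, img: str):
        # A real model would load/encode the image and run inference here.
        self.chat_history.append(text)
        return f"[{self.model_name}] {img}: {text}"

    def run(self, task: str, img: str):
        # Route run() through __call__ so the batch and retry helpers on the
        # base class (run_batch, run_with_retries, ...) have something to call.
        return self(task, img)

    def generate_summary(self, text: str) -> str:
        return text[:100]


model = EchoModel(model_name="echo-demo")
print(model.run("Describe the image", "images/swarms.jpeg"))
print(model.run_batch([("What color is it?", "images/swarms.jpeg")]))

The same wiring lets run_with_retries and the async batch helpers work, since they all funnel into run().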

10 changes: 5 additions & 5 deletions swarms/models/fuyu.py
@@ -63,9 +63,9 @@ def get_img(self, img: str):

    def __call__(self, text: str, img: str):
        """Call the model with text and img paths"""
-        image_pil = Image.open(img)
+        img = self.get_img(img)
        model_inputs = self.processor(
-            text=text, images=[image_pil], device=self.device_map
+            text=text, images=[img], device=self.device_map
        )

        for k, v in model_inputs.items():
@@ -79,13 +79,13 @@ def __call__(self, text: str, img: str):
        )
        return print(str(text))

-    def get_img_from_web(self, img_url: str):
+    def get_img_from_web(self, img: str):
        """Get the image from the web"""
        try:
-            response = requests.get(img_url)
+            response = requests.get(img)
            response.raise_for_status()
            image_pil = Image.open(BytesIO(response.content))
            return image_pil
        except requests.RequestException as error:
-            print(f"Error fetching image from {img_url} and error: {error}")
+            print(f"Error fetching image from {img} and error: {error}")
            return None
1 change: 0 additions & 1 deletion swarms/models/gpt4_vision_api.py
@@ -114,7 +114,6 @@ def run(self, task: str, img: str):
        except Exception as error:
            print(f"Error with the request: {error}")
            raise error
-    # Function to handle vision tasks

    def __call__(self, task: str, img: str):
        """Run the model."""
37 changes: 15 additions & 22 deletions swarms/models/kosmos_two.py
@@ -18,38 +18,31 @@ def is_overlapping(rect1, rect2):

class Kosmos:
"""
Kosmos model by Yen-Chun Shieh
Parameters
----------
model_name : str
Path to the pretrained model
Examples
--------
>>> kosmos = Kosmos()
>>> kosmos("Hello, my name is", "path/to/image.png")
Args:
# Initialize Kosmos
kosmos = Kosmos()
# Perform multimodal grounding
kosmos.multimodal_grounding("Find the red apple in the image.", "https://example.com/apple.jpg")
# Perform referring expression comprehension
kosmos.referring_expression_comprehension("Show me the green bottle.", "https://example.com/bottle.jpg")
# Generate referring expressions
kosmos.referring_expression_generation("It is on the table.", "https://example.com/table.jpg")
# Perform grounded visual question answering
kosmos.grounded_vqa("What is the color of the car?", "https://example.com/car.jpg")
# Generate grounded image caption
kosmos.grounded_image_captioning("https://example.com/beach.jpg")
"""

    def __init__(
        self,
        model_name="ydshieh/kosmos-2-patch14-224",
        *args,
        **kwargs,
    ):
        self.model = AutoModelForVision2Seq.from_pretrained(
-            model_name, trust_remote_code=True
+            model_name, trust_remote_code=True, *args, **kwargs
        )
        self.processor = AutoProcessor.from_pretrained(
-            model_name, trust_remote_code=True
+            model_name, trust_remote_code=True, *args, **kwargs
        )

    def get_image(self, url):
