Gemini multimodality update phi 2173 #1548

Merged
merged 7 commits into from
Dec 12, 2024
14 changes: 14 additions & 0 deletions cookbook/providers/google/audio_agent.py
@@ -0,0 +1,14 @@
from phi.agent import Agent
from phi.model.google import Gemini

agent = Agent(
model=Gemini(id="gemini-2.0-flash-exp"),
markdown=True,
)

# Please download a sample audio file to test this Agent
agent.print_response(
"Tell me about this audio",
audio={"data": "cookbook/providers/google/sample_audio.mp3"},
stream=True,
)
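
# To run this example without hunting for a file, here is a minimal sketch for fetching a
# sample clip into the path the agent above expects. The URL is a placeholder; point it at
# any publicly reachable MP3 you have access to.
import httpx

SAMPLE_AUDIO_URL = "https://example.com/sample_audio.mp3"  # placeholder URL

response = httpx.get(SAMPLE_AUDIO_URL, follow_redirects=True)
response.raise_for_status()
with open("cookbook/providers/google/sample_audio.mp3", "wb") as f:
    f.write(response.content)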
17 changes: 17 additions & 0 deletions cookbook/providers/google/image_agent.py
@@ -0,0 +1,17 @@
from phi.agent import Agent
from phi.model.google import Gemini
from phi.tools.duckduckgo import DuckDuckGo

agent = Agent(
model=Gemini(id="gemini-2.0-flash-exp"),
tools=[DuckDuckGo()],
markdown=True,
)

agent.print_response(
"Tell me about this image and give me the latest news about it.",
images=[
"https://upload.wikimedia.org/wikipedia/commons/b/bf/Krakow_-_Kosciol_Mariacki.jpg",
],
stream=True,
)
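
The image handling added to phi/model/google/gemini.py in this PR also covers local file paths and raw bytes, so a URL is not required. A sketch of the local-path variant (the image path is illustrative and must exist on disk):

from phi.agent import Agent
from phi.model.google import Gemini

agent = Agent(
    model=Gemini(id="gemini-2.0-flash-exp"),
    markdown=True,
)

# Local file variant; the path below is illustrative.
agent.print_response(
    "Tell me about this image",
    images=["cookbook/providers/google/sample_image.jpg"],
    stream=True,
)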
16 changes: 16 additions & 0 deletions cookbook/providers/google/video_agent.py
@@ -0,0 +1,16 @@
from phi.agent import Agent
from phi.model.google import Gemini

agent = Agent(
model=Gemini(id="gemini-2.0-flash-exp"),
markdown=True,
)

# Please download "GreatRedSpot.mp4" using wget https://storage.googleapis.com/generativeai-downloads/images/GreatRedSpot.mp4
agent.print_response(
"Tell me about this video",
videos=[
"cookbook/providers/google/GreatRedSpot.mp4",
],
stream=True,
)
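
If wget is not available, the same download can be done with a short Python sketch (httpx is assumed to be installed):

import httpx

url = "https://storage.googleapis.com/generativeai-downloads/images/GreatRedSpot.mp4"
with open("cookbook/providers/google/GreatRedSpot.mp4", "wb") as f:
    f.write(httpx.get(url, follow_redirects=True).content)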
56 changes: 37 additions & 19 deletions phi/agent/agent.py
@@ -1126,9 +1126,11 @@ def convert_context_to_string(self, context: Dict[str, Any]) -> str:

def get_user_message(
self,
*,
message: Optional[Union[str, List]],
images: Optional[Sequence[Union[str, Dict]]] = None,
audio: Optional[Dict] = None,
images: Optional[Sequence[Union[str, Dict]]] = None,
videos: Optional[Sequence[Union[str, Dict]]] = None,
**kwargs: Any,
) -> Optional[Message]:
"""Return the user message for the Agent.
@@ -1170,8 +1172,9 @@ def get_user_message(
return Message(
role=self.user_message_role,
content=user_prompt_content,
images=images,
audio=audio,
images=images,
videos=videos,
**kwargs,
)

@@ -1182,8 +1185,9 @@ def get_user_message(
return Message(
role=self.user_message_role,
content=user_prompt_from_template,
images=images,
audio=audio,
images=images,
videos=videos,
**kwargs,
)

@@ -1220,17 +1224,19 @@ def get_user_message(
return Message(
role=self.user_message_role,
content=user_prompt,
images=images,
audio=audio,
images=images,
videos=videos,
**kwargs,
)

def get_messages_for_run(
self,
*,
message: Optional[Union[str, List, Dict, Message]] = None,
images: Optional[Sequence[Union[str, Dict]]] = None,
audio: Optional[Dict] = None,
images: Optional[Sequence[Union[str, Dict]]] = None,
videos: Optional[Sequence[Union[str, Dict]]] = None,
messages: Optional[Sequence[Union[Dict, Message]]] = None,
**kwargs: Any,
) -> Tuple[Optional[Message], List[Message], List[Message]]:
@@ -1312,7 +1318,7 @@ def get_messages_for_run(
elif isinstance(message, str) or isinstance(message, list):
# Get the user message
user_message: Optional[Message] = self.get_user_message(
message=message, images=images, audio=audio, **kwargs
message=message, audio=audio, images=images, videos=videos, **kwargs
)
# Add user message to the messages list
if user_message is not None:
@@ -1715,6 +1721,7 @@ def _run(
stream: bool = False,
audio: Optional[Dict] = None,
images: Optional[Sequence[Union[str, Dict]]] = None,
videos: Optional[Sequence[Union[str, Dict]]] = None,
messages: Optional[Sequence[Union[Dict, Message]]] = None,
stream_intermediate_steps: bool = False,
**kwargs: Any,
@@ -1753,7 +1760,7 @@ def _run(

# 3. Prepare messages for this run
system_message, user_messages, messages_for_model = self.get_messages_for_run(
message=message, images=images, audio=audio, messages=messages, **kwargs
message=message, audio=audio, images=images, videos=videos, messages=messages, **kwargs
)

# 4. Reason about the task if reasoning is enabled
@@ -1933,8 +1940,9 @@ def run(
message: Optional[Union[str, List, Dict, Message]] = None,
*,
stream: Literal[False] = False,
images: Optional[Sequence[Union[str, Dict]]] = None,
audio: Optional[Dict] = None,
images: Optional[Sequence[Union[str, Dict]]] = None,
videos: Optional[Sequence[Union[str, Dict]]] = None,
messages: Optional[Sequence[Union[Dict, Message]]] = None,
**kwargs: Any,
) -> RunResponse: ...
@@ -1945,8 +1953,9 @@ def run(
message: Optional[Union[str, List, Dict, Message]] = None,
*,
stream: Literal[True] = True,
images: Optional[Sequence[Union[str, Dict]]] = None,
audio: Optional[Dict] = None,
images: Optional[Sequence[Union[str, Dict]]] = None,
videos: Optional[Sequence[Union[str, Dict]]] = None,
messages: Optional[Sequence[Union[Dict, Message]]] = None,
stream_intermediate_steps: bool = False,
**kwargs: Any,
@@ -1957,8 +1966,9 @@ def run(
message: Optional[Union[str, List, Dict, Message]] = None,
*,
stream: bool = False,
images: Optional[Sequence[Union[str, Dict]]] = None,
audio: Optional[Dict] = None,
images: Optional[Sequence[Union[str, Dict]]] = None,
videos: Optional[Sequence[Union[str, Dict]]] = None,
messages: Optional[Sequence[Union[Dict, Message]]] = None,
stream_intermediate_steps: bool = False,
**kwargs: Any,
@@ -1973,8 +1983,9 @@ def run(
self._run(
message=message,
stream=False,
images=images,
audio=audio,
images=images,
videos=videos,
messages=messages,
stream_intermediate_steps=stream_intermediate_steps,
**kwargs,
@@ -2022,8 +2033,9 @@ def run(
resp = self._run(
message=message,
stream=True,
images=images,
audio=audio,
images=images,
videos=videos,
messages=messages,
stream_intermediate_steps=stream_intermediate_steps,
**kwargs,
@@ -2033,8 +2045,9 @@ def run(
resp = self._run(
message=message,
stream=False,
images=images,
audio=audio,
images=images,
videos=videos,
messages=messages,
stream_intermediate_steps=stream_intermediate_steps,
**kwargs,
@@ -2046,8 +2059,9 @@ async def _arun(
message: Optional[Union[str, List, Dict, Message]] = None,
*,
stream: bool = False,
images: Optional[Sequence[Union[str, Dict]]] = None,
audio: Optional[Dict] = None,
images: Optional[Sequence[Union[str, Dict]]] = None,
videos: Optional[Sequence[Union[str, Dict]]] = None,
messages: Optional[Sequence[Union[Dict, Message]]] = None,
stream_intermediate_steps: bool = False,
**kwargs: Any,
@@ -2083,7 +2097,7 @@ async def _arun(

# 3. Prepare messages for this run
system_message, user_messages, messages_for_model = self.get_messages_for_run(
message=message, images=images, audio=audio, messages=messages, **kwargs
message=message, audio=audio, images=images, videos=videos, messages=messages, **kwargs
)

# 4. Reason about the task if reasoning is enabled
@@ -2263,8 +2277,9 @@ async def arun(
message: Optional[Union[str, List, Dict, Message]] = None,
*,
stream: bool = False,
images: Optional[Sequence[Union[str, Dict]]] = None,
audio: Optional[Dict] = None,
images: Optional[Sequence[Union[str, Dict]]] = None,
videos: Optional[Sequence[Union[str, Dict]]] = None,
messages: Optional[Sequence[Union[Dict, Message]]] = None,
stream_intermediate_steps: bool = False,
**kwargs: Any,
@@ -2278,8 +2293,9 @@ async def arun(
run_response = await self._arun(
message=message,
stream=False,
images=images,
audio=audio,
images=images,
videos=videos,
messages=messages,
stream_intermediate_steps=stream_intermediate_steps,
**kwargs,
@@ -2324,8 +2340,9 @@ async def arun(
resp = self._arun(
message=message,
stream=True,
images=images,
audio=audio,
images=images,
videos=videos,
messages=messages,
stream_intermediate_steps=stream_intermediate_steps,
**kwargs,
@@ -2335,8 +2352,9 @@ async def arun(
resp = self._arun(
message=message,
stream=False,
images=images,
audio=audio,
images=images,
videos=videos,
messages=messages,
stream_intermediate_steps=stream_intermediate_steps,
**kwargs,
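
Taken together, the changes above make audio, images, and videos keyword-only arguments that are threaded consistently through run(), arun(), and get_messages_for_run(). A combined call might look like the sketch below (file paths are illustrative, and whether a single request can mix all three media types depends on the model):

from phi.agent import Agent
from phi.model.google import Gemini

agent = Agent(model=Gemini(id="gemini-2.0-flash-exp"), markdown=True)

# All media arguments are keyword-only; the paths below are illustrative.
response = agent.run(
    "Summarize what you see and hear.",
    images=["cookbook/providers/google/sample_image.jpg"],
    audio={"data": "cookbook/providers/google/sample_audio.mp3"},
    videos=["cookbook/providers/google/GreatRedSpot.mp4"],
)
print(response.content)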
98 changes: 94 additions & 4 deletions phi/model/google/gemini.py
@@ -147,7 +147,7 @@ def format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:

# Add content to the message for the model
content = message.content
if not content or message.role == "tool":
if not content or message.role in ["tool", "model"]:
parts = message.parts # type: ignore
else:
if isinstance(content, str):
@@ -157,6 +157,96 @@ def format_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
else:
parts = [" "]
message_for_model["parts"] = parts

# Add images to the message for the model
if message.images is not None and message.role == "user":
for image in message.images:
if isinstance(image, str):
# Case 1: Image is a URL
if image.startswith("http://") or image.startswith("https://"):
try:
import httpx
import base64

image_content = httpx.get(image).content
image_data = {
"mime_type": "image/jpeg",
"data": base64.b64encode(image_content).decode("utf-8"),
}
message_for_model["parts"].append(image_data) # type: ignore
except Exception as e:
logger.warning(f"Failed to download image from {image}: {e}")
continue
# Case 2: Image is a path
else:
try:
from os.path import exists, isfile
import PIL.Image
except ImportError:
logger.error("`PIL.Image` is not installed. Please install it using 'pip install pillow'")
raise

try:
if exists(image) and isfile(image):
image_data = PIL.Image.open(image) # type: ignore
else:
logger.error(f"Image file {image} does not exist.")
raise
message_for_model["parts"].append(image_data) # type: ignore
except Exception as e:
logger.warning(f"Failed to load image from {image}: {e}")
continue

elif isinstance(image, bytes):
# Import locally so raw-bytes images work even when the URL branch above never ran
import base64

image_data = {"mime_type": "image/jpeg", "data": base64.b64encode(image).decode("utf-8")}
message_for_model["parts"].append(image_data)

if message.videos is not None and message.role == "user":
try:
for video in message.videos:
import time
from os.path import exists, isfile

video_file = None
if exists(video) and isfile(video): # type: ignore
video_file = genai.upload_file(path=video)
else:
logger.error(f"Video file {video} does not exist.")
raise

# Check whether the file is ready to be used.
while video_file.state.name == "PROCESSING":
time.sleep(10)
video_file = genai.get_file(video_file.name)

if video_file.state.name == "FAILED":
raise ValueError(video_file.state.name)

message_for_model["parts"].insert(0, video_file) # type: ignore

except Exception as e:
logger.warning(f"Failed to load video from {message.videos}: {e}")
continue

if message.audio is not None and message.role == "user":
try:
from pathlib import Path
from os.path import exists, isfile

audio = message.audio.get("data")
if audio:
audio_file = None
if exists(audio) and isfile(audio):
audio_file = {"mime_type": "audio/mp3", "data": Path(audio).read_bytes()}
else:
logger.error(f"Audio file {audio} does not exist.")
raise
message_for_model["parts"].insert(0, audio_file) # type: ignore

except Exception as e:
logger.warning(f"Failed to load video from {message.videos}: {e}")
continue

formatted_messages.append(message_for_model)
return formatted_messages
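
For reference, the part shapes appended above map directly onto what the google.generativeai SDK accepts: inline dicts carrying a mime type plus data for images and audio, and uploaded file handles for video. A standalone sketch of the same pattern outside phi (file names are illustrative; genai.configure or GOOGLE_API_KEY is assumed to be set up):

import time
import base64
from pathlib import Path

import google.generativeai as genai

# Assumes genai.configure(api_key=...) has been called or GOOGLE_API_KEY is set.
model = genai.GenerativeModel("gemini-2.0-flash-exp")

# Inline image/audio parts: dicts carrying a mime type plus the payload.
image_part = {
    "mime_type": "image/jpeg",
    "data": base64.b64encode(Path("sample_image.jpg").read_bytes()).decode("utf-8"),
}
audio_part = {"mime_type": "audio/mp3", "data": Path("sample_audio.mp3").read_bytes()}

# Videos go through the Files API and must finish processing before use.
video_file = genai.upload_file(path="GreatRedSpot.mp4")
while video_file.state.name == "PROCESSING":
    time.sleep(10)
    video_file = genai.get_file(video_file.name)

response = model.generate_content([video_file, audio_part, image_part, "Describe this media."])
print(response.text)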

@@ -418,8 +508,8 @@ def format_function_call_results(
messages (List[Message]): The list of conversation messages.
"""
if function_call_results:
combined_content = [] # Use a list to collect all result contents
combined_parts = [] # Use a list to collect all function responses
combined_content: List = []
combined_parts: List = []

for result in function_call_results:
s = Struct()
@@ -429,7 +519,7 @@ def format_function_call_results(
)
combined_content.append(result.content)
combined_parts.append(function_response)
messages.append(Message(role="tool", content="\n".join(combined_content), parts=combined_parts)) # type: ignore
messages.append(Message(role="tool", content=combined_content, parts=combined_parts))

def handle_tool_calls(self, assistant_message: Message, messages: List[Message], model_response: ModelResponse):
"""
1 change: 1 addition & 0 deletions phi/model/message.py
@@ -113,6 +113,7 @@ def log(self, level: Optional[str] = None):
if self.videos:
_logger(f"Number of Videos: {len(self.videos)}")
if self.audio:
_logger(f"Number of Audio Files: {len(self.audio)}")
if "id" in self.audio:
_logger(f"Audio ID: {self.audio['id']}")
elif "data" in self.audio: