Skip to content

Commit

Permalink
feat: allow uploading multiple images in a single generation
Browse files Browse the repository at this point in the history
  • Loading branch information
HanaokaYuzu committed May 23, 2024
1 parent 17523c2 commit e978a5f
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 17 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,14 @@ asyncio.run(main())
### Generate contents from image

Gemini supports image recognition and generate contents from image (currently only supports one image at a time). Optionally, you can pass image data in `bytes` or its path in `str` to `GeminiClient.generate_content` together with text prompt.
Gemini supports image recognition and generating contents from images. Optionally, you can pass images in a list of file data in `bytes` or their paths in `str` to `GeminiClient.generate_content` together with text prompt.

```python
async def main():
response = await client.generate_content("Describe the image", image="assets/banner.png")
response = await client.generate_content(
"Describe each of these images",
images=["assets/banner.png", "assets/favicon.png"],
)
print(response.text)

asyncio.run(main())
Expand Down
Binary file added assets/favicon.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
34 changes: 22 additions & 12 deletions src/gemini_webapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ def __init__(
self.cookies = {}
self.proxies = proxies
self.running: bool = False
self.client: AsyncClient = None
self.access_token: str = None
self.client: AsyncClient | None = None
self.access_token: str | None = None
self.timeout: float = 30
self.auto_close: bool = False
self.close_delay: float = 300
self.close_task: Task = None
self.close_task: Task | None = None
self.auto_refresh: bool = True
self.refresh_interval: float = 540

Expand Down Expand Up @@ -252,7 +252,7 @@ async def start_auto_refresh(self) -> None:
async def generate_content(
self,
prompt: str,
image: bytes | str | None = None,
images: list[bytes | str] | None = None,
chat: Optional["ChatSession"] = None,
) -> ModelOutput:
"""
Expand All @@ -262,8 +262,8 @@ async def generate_content(
----------
prompt: `str`
Prompt provided by user.
image: `bytes` | `str`, optional
File data in bytes, or path to the image file to be sent together with the prompt.
images: `list[bytes | str]`, optional
List of image file data in bytes or file paths in string.
chat: `ChatSession`, optional
Chat data to retrieve conversation history. If None, will automatically generate a new chat id when sending post request.
Expand Down Expand Up @@ -300,12 +300,22 @@ async def generate_content(
None,
json.dumps(
[
image
images
and [
prompt,
0,
None,
[[[await upload_file(image, self.proxies), 1]]],
[
[
[
await upload_file(
image, self.proxies
),
1,
]
]
for image in images
],
]
or [prompt],
None,
Expand Down Expand Up @@ -475,7 +485,7 @@ def __setattr__(self, name: str, value: Any) -> None:
self.rcid = value.rcid

async def send_message(
self, prompt: str, image: bytes | str | None = None
self, prompt: str, images: list[bytes | str] | None = None,
) -> ModelOutput:
"""
Generates contents with prompt.
Expand All @@ -485,8 +495,8 @@ async def send_message(
----------
prompt: `str`
Prompt provided by user.
image: `bytes` | `str`, optional
File data in bytes, or path to the image file to be sent together with the prompt.
images: `list[bytes | str]`, optional
List of image file data in bytes or file paths in string.
Returns
-------
Expand All @@ -507,7 +517,7 @@ async def send_message(
- If response structure is invalid and failed to parse.
"""
return await self.geminiclient.generate_content(
prompt=prompt, image=image, chat=self
prompt=prompt, images=images, chat=self
)

def choose_candidate(self, index: int) -> ModelOutput:
Expand Down
7 changes: 4 additions & 3 deletions tests/test_client_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def test_successful_request(self):
@logger.catch(reraise=True)
async def test_upload_image(self):
response = await self.geminiclient.generate_content(
"Describe the image", image="assets/banner.png"
"Describe the image", images=["assets/banner.png"]
)
self.assertTrue(response.text)
logger.debug(response.text)
Expand Down Expand Up @@ -59,11 +59,12 @@ async def test_retrieve_previous_conversation(self):
async def test_chatsession_with_image(self):
chat = self.geminiclient.start_chat()
response1 = await chat.send_message(
"Describe the image", image="assets/banner.png"
"What's the difference between these two images?",
images=["assets/pic1.png", "assets/pic2.png"],
)
self.assertTrue(response1.text)
logger.debug(response1.text)
response2 = await chat.send_message("Tell me more about it.")
response2 = await chat.send_message("Tell me more.")
self.assertTrue(response2.text)
logger.debug(response2.text)

Expand Down

0 comments on commit e978a5f

Please sign in to comment.