Skip to content

Commit

Permalink
fix: update GeminiClient.close and Image.save
Browse files Browse the repository at this point in the history
  • Loading branch information
HanaokaYuzu committed Feb 25, 2024
1 parent 4da5170 commit 3341fd2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 19 deletions.
40 changes: 26 additions & 14 deletions src/gemini/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,23 +80,23 @@ def __init__(
self.access_token: Optional[str] = None
self.running: bool = False
self.auto_close: bool = False
self.close_delay: int = 0
self.close_delay: float = 300
self.close_task: Task | None = None

async def init(
self, timeout: float = 30, auto_close: bool = False, close_delay: int = 300
self, timeout: float = 30, auto_close: bool = False, close_delay: float = 300
) -> None:
"""
Get SNlM0e value as access token. Without this token posting will fail with 400 bad request.
Parameters
----------
timeout: `int`, optional
timeout: `float`, optional
Request timeout of the client in seconds. Used to limit the max waiting time when sending a request
auto_close: `bool`, optional
If `True`, the client will close connections and clear resource usage after a certain period
of inactivity. Useful for keep-alive services
close_delay: `int`, optional
close_delay: `float`, optional
Time to wait before auto-closing the client in seconds. Effective only if `auto_close` is `True`
"""
try:
Expand Down Expand Up @@ -130,19 +130,25 @@ async def init(
if self.auto_close:
await self.reset_close_task()
except Exception:
await self.close(0)
await self.close()
raise

async def close(self, wait: int | None = None) -> None:
async def close(self, delay: float = 0) -> None:
"""
Close the client after a certain period of inactivity, or call manually to close immediately.
Parameters
----------
wait: `int`, optional
delay: `float`, optional
Time to wait before closing the client in seconds
"""
await asyncio.sleep(wait is not None and wait or self.close_delay)
if delay:
await asyncio.sleep(delay)

if self.close_task:
self.close_task.cancel()
self.close_task = None

await self.client.aclose()
self.running = False

Expand All @@ -153,7 +159,7 @@ async def reset_close_task(self) -> None:
if self.close_task:
self.close_task.cancel()
self.close_task = None
self.close_task = asyncio.create_task(self.close())
self.close_task = asyncio.create_task(self.close(self.close_delay))

@running
async def generate_content(
Expand Down Expand Up @@ -196,23 +202,25 @@ async def generate_content(
)

if response.status_code != 200:
await self.close(0)
await self.close()
raise APIError(
f"Failed to generate contents. Request failed with status code {response.status_code}"
)
else:
try:
body = json.loads(json.loads(response.text.split("\n")[2])[0][2]) # Plain request
# Plain request
body = json.loads(json.loads(response.text.split("\n")[2])[0][2])

if not body[4]:
body = json.loads(json.loads(response.text.split("\n")[2])[4][2]) # Request with extensions as middleware
# Request with extensions as middleware
body = json.loads(json.loads(response.text.split("\n")[2])[4][2])

if not body[4]:
raise APIError(
"Failed to parse response body. Data structure is invalid. To report this error, please submit an issue at https://github.com/HanaokaYuzu/Gemini-API/issues"
)
except Exception:
await self.close(0)
await self.close()
raise APIError(
"Failed to generate contents. Invalid response data received. Client will try to re-initiate on next request."
)
Expand All @@ -223,7 +231,11 @@ async def generate_content(
web_images = (
candidate[4]
and [
WebImage(url=image[0][0][0], title=image[2], alt=image[0][4])
WebImage(
url=image[0][0][0],
title=image[2],
alt=image[0][4],
)
for image in candidate[4]
]
or []
Expand Down
29 changes: 24 additions & 5 deletions src/gemini/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from pathlib import Path
from datetime import datetime

from loguru import logger
Expand Down Expand Up @@ -31,22 +32,34 @@ def __repr__(self):
return f"""Image(title='{self.title}', url='{len(self.url) <= 20 and self.url or self.url[:8] + '...' + self.url[-12:]}', alt='{self.alt}')"""

async def save(
self, path: str = "temp/", filename: str = None, cookies: dict = None
self,
path: str = "temp",
filename: str | None = None,
cookies: dict | None = None,
verbose: bool = False
) -> None:
"""
Save the image to disk.
Parameters
----------
path: `str`, optional
Path to save the image
Path to save the image, by default will save to ./temp
filename: `str`, optional
Filename to save the image, by default will use the original filename from the URL
cookies: `dict`, optional
Cookies used for requesting the content of the image
verbose : `bool`, optional
If True, print the path of the saved file, by default False
"""
try:
filename = filename or re.search(r"^(.*\.\w+)", self.url.split("/")[-1]).group()
filename = (
filename
or (
re.search(r"^(.*\.\w+)", self.url.split("/")[-1])
or re.search(r"^(.*)\?", self.url.split("/")[-1])
).group()
)
except AttributeError:
filename = self.url.split("/")[-1]

Expand All @@ -59,8 +72,14 @@ async def save(
f"Content type of {filename} is not image, but {content_type}."
)

with open(f"{path}{filename}", "wb") as file:
file.write(response.content)
path = Path(path)
path.mkdir(parents=True, exist_ok=True)

dest = path / filename
dest.write_bytes(response.content)

if verbose:
logger.info(f"Image saved as {dest.resolve()}")
else:
raise HTTPError(
f"Error downloading image: {response.status_code} {response.reason_phrase}"
Expand Down

0 comments on commit 3341fd2

Please sign in to comment.