From 8f4d469c993368692cca530b61c72fe52a1ea9b0 Mon Sep 17 00:00:00 2001 From: UZQueen <157540577+HanaokaYuzu@users.noreply.github.com> Date: Mon, 18 Mar 2024 19:28:21 -0500 Subject: [PATCH] feat: v1.0.0 release - feat: add support to auto refresh cookies in background - feat: add support to import cookies from local browser - feat: add support to control log level - feat: now client will automatically retry when generate_content raises APIError - fix: now the timeout value will be correctly applied after re-initializing the client - docs: update readme and function docstrings - refactor: split utils.py into multiple files - build: update supported python version close #6 --- .gitignore | 2 +- .vscode/launch.json | 15 ++ README.md | 35 ++- pyproject.toml | 2 +- src/gemini_webapi/__init__.py | 1 + src/gemini_webapi/client.py | 245 +++++++++++++----- src/gemini_webapi/types/image.py | 9 +- src/gemini_webapi/utils/__init__.py | 9 + src/gemini_webapi/utils/get_access_token.py | 150 +++++++++++ src/gemini_webapi/utils/logger.py | 38 +++ src/gemini_webapi/utils/rotate_1psidts.py | 55 ++++ .../{utils.py => utils/upload_file.py} | 2 +- tests/test_client_features.py | 6 +- tests/test_rotate_cookie.py | 26 ++ tests/test_save_image.py | 6 +- 15 files changed, 527 insertions(+), 74 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 src/gemini_webapi/utils/__init__.py create mode 100644 src/gemini_webapi/utils/get_access_token.py create mode 100644 src/gemini_webapi/utils/logger.py create mode 100644 src/gemini_webapi/utils/rotate_1psidts.py rename src/gemini_webapi/{utils.py => utils/upload_file.py} (96%) create mode 100644 tests/test_rotate_cookie.py diff --git a/.gitignore b/.gitignore index 9ef091f..86c6a83 100644 --- a/.gitignore +++ b/.gitignore @@ -201,4 +201,4 @@ Temporary Items .apdisk # Temporary files -/temp +temp/ diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..6b76b4f --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,15 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md index 0b0a40a..febb2bc 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ A reverse-engineered asynchronous python wrapper for [Google Gemini](https://gem ## Features -- **(WIP) Auto Cookie Management** +- **Smart Cookies** - Automatically import cookies and refresh them in background. Free up your hands! - **ImageFx Support** - Supports retrieving images generated by ImageFx, Google's latest AI image generator. - **Extension Support** - Supports generating contents with [Gemini extensions](https://gemini.google.com/extensions) on, like YouTube and Gmail. - **Classified Outputs** - Auto categorizes texts, web images and AI generated images from the response. @@ -51,6 +51,7 @@ A reverse-engineered asynchronous python wrapper for [Google Gemini](https://gem - [Save images to local files](#save-images-to-local-files) - [Generate contents with Gemini extensions](#generate-contents-with-gemini-extensions) - [Check and switch to other reply candidates](#check-and-switch-to-other-reply-candidates) + - [Control log level](#control-log-level) - [References](#references) - [Stargazers](#stargazers) @@ -60,35 +61,45 @@ A reverse-engineered asynchronous python wrapper for [Google Gemini](https://gem pip install gemini_webapi ``` +Optionally, package offers a way to automatically import cookies from your local browser. To enable this feature, install `browser-cookie3` as well. Supported platforms and browsers can be found [here](https://github.com/borisbabic/browser_cookie3?tab=readme-ov-file#contribute). + +```bash +pip install browser-cookie3 +``` + ## Authentication +> [!NOTE] +> +> If `browser-cookie3` is installed, you can skip this step and go directly to [usage](#usage) section. Just make sure you have logged in to in your browser. + - Go to and login with your Google account - Press F12 for web inspector, go to `Network` tab and refresh the page - Click any request and copy cookie values of `__Secure-1PSID` and `__Secure-1PSIDTS` > [!TIP] > -> `__Secure-1PSIDTS` could get expired frequently if is kept opened in the browser after copying cookies. It's recommended to get cookies from a separate session (e.g. a new login in browser's private mode) if you are building a keep-alive service with this package. -> -> For more details, please refer to discussions in [issue #6](https://github.com/HanaokaYuzu/Gemini-API/issues/6) +> API's auto cookie refresh feature may cause that you need to re-login to your Google account in the browser. This is an expected behavior and won't affect the API's functionality. To avoid such result, it's recommended to get cookies from a separate browser session for best utilization (e.g. a fresh login in browser's private mode), or set `auto_refresh` to `False` in `GeminiClient.init` to disable this feature. ## Usage ### Initialization -Import required packages and initialize a client with your cookies obtained from the previous step. +Import required packages and initialize a client with your cookies obtained from the previous step. After a successful initialization, the API will automatically refresh `__Secure-1PSIDTS` in background as long as the process is alive. ```python import asyncio from gemini_webapi import GeminiClient -# Replace "COOKIE VALUE HERE" with your actual cookie values +# Replace "COOKIE VALUE HERE" with your actual cookie values. +# Leave Secure_1PSIDTS empty if it's not available for your account. Secure_1PSID = "COOKIE VALUE HERE" Secure_1PSIDTS = "COOKIE VALUE HERE" async def main(): + # If browser-cookie3 is installed, simply use `client = GeminiClient()` client = GeminiClient(Secure_1PSID, Secure_1PSIDTS, proxies=None) - await client.init(timeout=30, auto_close=False, close_delay=300) + await client.init(timeout=30, auto_close=False, close_delay=300, auto_refresh=True) asyncio.run(main()) ``` @@ -250,6 +261,16 @@ async def main(): asyncio.run(main()) ``` +### Control log level + +You can set the log level of the package to one of the following values: `DEBUG`, `INFO`, `WARNING`, `ERROR` and `CRITICAL`. The default value is `INFO`. + +```python +from gemini_webapi import set_log_level + +set_log_level("DEBUG") +``` + ## References [Google AI Studio](https://ai.google.dev/tutorials/ai-studio_quickstart) diff --git a/pyproject.toml b/pyproject.toml index 8791613..079c1ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", ] -requires-python = ">=3.7" +requires-python = ">=3.8" dependencies = [ "httpx>=0.25.2", "pydantic>=2.5.3", diff --git a/src/gemini_webapi/__init__.py b/src/gemini_webapi/__init__.py index c1ac8cc..2704b87 100644 --- a/src/gemini_webapi/__init__.py +++ b/src/gemini_webapi/__init__.py @@ -1,3 +1,4 @@ from .client import GeminiClient, ChatSession # noqa: F401 from .exceptions import * # noqa: F401, F403 from .types import * # noqa: F401, F403 +from .utils import set_log_level # noqa: F401 diff --git a/src/gemini_webapi/client.py b/src/gemini_webapi/client.py index 4cd4f7a..c40ad94 100644 --- a/src/gemini_webapi/client.py +++ b/src/gemini_webapi/client.py @@ -1,83 +1,149 @@ -import re import json +import functools import asyncio from asyncio import Task from typing import Any, Optional from httpx import AsyncClient, ReadTimeout -from loguru import logger from .types import WebImage, GeneratedImage, Candidate, ModelOutput -from .exceptions import APIError, AuthError, TimeoutError, GeminiError +from .exceptions import AuthError, APIError, TimeoutError, GeminiError from .constants import Endpoint, Headers -from .utils import upload_file +from .utils import ( + get_cookie_by_name, + upload_file, + get_access_token, + rotate_1psidts, + rotate_tasks, + logger, +) -def running(func) -> callable: +def running(retry: int = 0) -> callable: """ Decorator to check if client is running before making a request. + + Parameters + ---------- + retry: `int`, optional + Max number of retries when `gemini_webapi.APIError` is raised. """ - async def wrapper(self: "GeminiClient", *args, **kwargs): - if not self.running: - await self.init(auto_close=self.auto_close, close_delay=self.close_delay) - if self.running: - return await func(self, *args, **kwargs) + def decorator(func): + @functools.wraps(func) + async def wrapper(client: "GeminiClient", *args, retry=retry, **kwargs): + try: + if not client.running: + await client.init( + timeout=client.timeout, + auto_close=client.auto_close, + close_delay=client.close_delay, + auto_refresh=client.auto_refresh, + refresh_interval=client.refresh_interval, + verbose=False, + ) + if client.running: + return await func(client, *args, **kwargs) - raise Exception( - f"Invalid function call: GeminiClient.{func.__name__}. Client initialization failed." - ) - else: - return await func(self, *args, **kwargs) + # Should not reach here + raise APIError( + f"Invalid function call: GeminiClient.{func.__name__}. Client initialization failed." + ) + else: + return await func(client, *args, **kwargs) + except APIError: + if retry > 0: + await asyncio.sleep(1) + return await wrapper(client, *args, retry=retry - 1, **kwargs) + raise + + return wrapper - return wrapper + return decorator class GeminiClient: """ Async httpx client interface for gemini.google.com. + `secure_1psid` must be provided unless the optional dependency `browser-cookie3` is installed and + you have logged in to google.com in your local browser. + Parameters ---------- - secure_1psid: `str` + secure_1psid: `str`, optional __Secure-1PSID cookie value. secure_1psidts: `str`, optional __Secure-1PSIDTS cookie value, some google accounts don't require this value, provide only if it's in the cookie list. proxies: `dict`, optional Dict of proxies. + + Raises + ------ + `ValueError` + If `secure_1psid` is not provided and optional dependency `browser-cookie3` is not installed, or + `browser-cookie3` is installed but cookies for google.com are not found in your local browser storage. """ __slots__ = [ "cookies", "proxies", + "running", "client", "access_token", - "running", + "timeout", "auto_close", "close_delay", "close_task", + "auto_refresh", + "refresh_interval", ] def __init__( self, - secure_1psid: str, - secure_1psidts: Optional[str] = None, - proxies: Optional[dict] = None, + secure_1psid: str | None = None, + secure_1psidts: str | None = None, + proxies: dict | None = None, ): - self.cookies = {"__Secure-1PSID": secure_1psid} + self.cookies = {} self.proxies = proxies + self.running: bool = False self.client: AsyncClient = None self.access_token: str = None - self.running: bool = False + self.timeout: float = 30 self.auto_close: bool = False self.close_delay: float = 300 self.close_task: Task = None + self.auto_refresh: bool = True + self.refresh_interval: float = 540 + + # Validate cookies + if secure_1psid: + self.cookies["__Secure-1PSID"] = secure_1psid + if secure_1psidts: + self.cookies["__Secure-1PSIDTS"] = secure_1psidts + else: + try: + import browser_cookie3 - if secure_1psidts: - self.cookies["__Secure-1PSIDTS"] = secure_1psidts + cookies = browser_cookie3.load(domain_name="google.com") + if not (cookies and get_cookie_by_name(cookies, "__Secure-1PSID")): + raise ValueError( + "Failed to load cookies from local browser. Please pass cookie values manually." + ) + except ImportError: + raise ValueError( + "'secure_1psid' must be provided if optional dependency 'browser-cookie3' is not installed." + ) async def init( - self, timeout: float = 30, auto_close: bool = False, close_delay: float = 300 + self, + timeout: float = 30, + auto_close: bool = False, + close_delay: float = 300, + auto_refresh: bool = True, + refresh_interval: float = 540, + verbose: bool = True, ) -> None: """ Get SNlM0e value as access token. Without this token posting will fail with 400 bad request. @@ -91,37 +157,46 @@ async def init( of inactivity. Useful for keep-alive services. close_delay: `float`, optional Time to wait before auto-closing the client in seconds. Effective only if `auto_close` is `True`. + auto_refresh: `bool`, optional + If `True`, will schedule a task to automatically refresh cookies in the background. + refresh_interval: `float`, optional + Time interval for background cookie refresh in seconds. Effective only if `auto_refresh` is `True`. + verbose: `bool`, optional + If `True`, will print more infomation in logs. """ try: + access_token, valid_cookies = await get_access_token( + base_cookies=self.cookies, proxies=self.proxies, verbose=verbose + ) + self.client = AsyncClient( timeout=timeout, proxies=self.proxies, follow_redirects=True, headers=Headers.GEMINI.value, - cookies=self.cookies, + cookies=valid_cookies, ) + self.access_token = access_token + self.cookies = valid_cookies + self.running = True - response = await self.client.get(Endpoint.INIT.value) - - if response.status_code != 200: - raise APIError( - f"Failed to initiate client. Request failed with status code {response.status_code}" - ) - else: - match = re.search(r'"SNlM0e":"(.*?)"', response.text) - if match: - self.access_token = match.group(1) - self.running = True - logger.success("Gemini client initiated successfully.") - else: - raise AuthError( - "Failed to initiate client. SECURE_1PSIDTS could get expired frequently, please make sure cookie values are up to date." - ) - + self.timeout = timeout self.auto_close = auto_close self.close_delay = close_delay if self.auto_close: await self.reset_close_task() + + self.auto_refresh = auto_refresh + self.refresh_interval = refresh_interval + if self.auto_refresh: + if task := rotate_tasks.get(self.cookies["__Secure-1PSID"]): + task.cancel() + rotate_tasks[self.cookies["__Secure-1PSID"]] = asyncio.create_task( + self.start_auto_refresh() + ) + + if verbose: + logger.success("Gemini client initialized successfully.") except Exception: await self.close() raise @@ -142,7 +217,9 @@ async def close(self, delay: float = 0) -> None: self.close_task.cancel() self.close_task = None - await self.client.aclose() + if self.client: + await self.client.aclose() + self.running = False async def reset_close_task(self) -> None: @@ -154,11 +231,30 @@ async def reset_close_task(self) -> None: self.close_task = None self.close_task = asyncio.create_task(self.close(self.close_delay)) - @running + async def start_auto_refresh(self) -> None: + """ + Start the background task to automatically refresh cookies. + """ + while True: + try: + new_1psidts = await rotate_1psidts(self.cookies, self.proxies) + except AuthError: + if task := rotate_tasks.get(self.cookies["__Secure-1PSID"]): + task.cancel() + logger.warning( + "Failed to refresh cookies. Background auto refresh task canceled." + ) + + logger.debug(f"Cookies refreshed. New __Secure-1PSIDTS: {new_1psidts}") + if new_1psidts: + self.cookies["__Secure-1PSIDTS"] = new_1psidts + await asyncio.sleep(self.refresh_interval) + + @running(retry=1) async def generate_content( self, prompt: str, - image: Optional[bytes | str] = None, + image: bytes | str | None = None, chat: Optional["ChatSession"] = None, ) -> ModelOutput: """ @@ -178,6 +274,18 @@ async def generate_content( :class:`ModelOutput` Output data from gemini.google.com, use `ModelOutput.text` to get the default text reply, `ModelOutput.images` to get a list of images in the default reply, `ModelOutput.candidates` to get a list of all answer candidates in the output. + + Raises + ------ + `AssertionError` + If prompt is empty. + `gemini_webapi.TimeoutError` + If request timed out. + `gemini_webapi.GenimiError` + If no reply candidate found in response. + `gemini_webapi.APIError` + - If request failed with status code other than 200. + - If response structure is invalid and failed to parse. """ assert prompt, "Prompt cannot be empty." @@ -212,7 +320,7 @@ async def generate_content( ) except ReadTimeout: raise TimeoutError( - "Request timed out, please try again. If the problem persists, consider setting a higher `timeout` value when initiating GeminiClient." + "Request timed out, please try again. If the problem persists, consider setting a higher `timeout` value when initializing GeminiClient." ) if response.status_code != 200: @@ -236,7 +344,7 @@ async def generate_content( except Exception: await self.close() raise APIError( - "Failed to generate contents. Invalid response data received. Client will try to re-initiate on next request." + "Failed to generate contents. Invalid response data received. Client will try to re-initialize on next request." ) try: @@ -299,7 +407,12 @@ async def generate_content( def start_chat(self, **kwargs) -> "ChatSession": """ - Returns a `ChatSession` object attached to this model. + Returns a `ChatSession` object attached to this client. + + Parameters + ---------- + kwargs: `dict`, optional + Other arguments to pass to `ChatSession.__init__`. Returns ------- @@ -327,20 +440,19 @@ class ChatSession: Reply candidate id, if provided together with metadata, will override the third value in it. """ - # @properties needn't have their slots pre-defined __slots__ = ["__metadata", "geminiclient", "last_output"] def __init__( self, geminiclient: GeminiClient, - metadata: Optional[list[str]] = None, - cid: Optional[str] = None, # chat id - rid: Optional[str] = None, # reply id - rcid: Optional[str] = None, # reply candidate id + metadata: list[str | None] | None = None, + cid: str | None = None, # chat id + rid: str | None = None, # reply id + rcid: str | None = None, # reply candidate id ): - self.__metadata: list[Optional[str]] = [None, None, None] + self.__metadata: list[str | None] = [None, None, None] self.geminiclient: GeminiClient = geminiclient - self.last_output: Optional[ModelOutput] = None + self.last_output: ModelOutput | None = None if metadata: self.metadata = metadata @@ -364,7 +476,7 @@ def __setattr__(self, name: str, value: Any) -> None: self.rcid = value.rcid async def send_message( - self, prompt: str, image: Optional[bytes | str] = None + self, prompt: str, image: bytes | str | None = None ) -> ModelOutput: """ Generates contents with prompt. @@ -382,6 +494,18 @@ async def send_message( :class:`ModelOutput` Output data from gemini.google.com, use `ModelOutput.text` to get the default text reply, `ModelOutput.images` to get a list of images in the default reply, `ModelOutput.candidates` to get a list of all answer candidates in the output. + + Raises + ------ + `AssertionError` + If prompt is empty. + `gemini_webapi.TimeoutError` + If request timed out. + `gemini_webapi.GenimiError` + If no reply candidate found in response. + `gemini_webapi.APIError` + - If request failed with status code other than 200. + - If response structure is invalid and failed to parse. """ return await self.geminiclient.generate_content( prompt=prompt, image=image, chat=self @@ -400,6 +524,11 @@ def choose_candidate(self, index: int) -> ModelOutput: ------- :class:`ModelOutput` Output data of the chosen candidate. + + Raises + ------ + `ValueError` + If no previous output data found in this chat session, or if index exceeds the number of candidates in last model output. """ if not self.last_output: raise ValueError("No previous output data found in this chat session.") diff --git a/src/gemini_webapi/types/image.py b/src/gemini_webapi/types/image.py index cb08ee1..7c4cbc0 100644 --- a/src/gemini_webapi/types/image.py +++ b/src/gemini_webapi/types/image.py @@ -4,7 +4,8 @@ from httpx import AsyncClient, HTTPError from pydantic import BaseModel, field_validator -from loguru import logger + +from ..utils import logger class Image(BaseModel): @@ -54,7 +55,7 @@ async def save( cookies: `dict`, optional Cookies used for requesting the content of the image. verbose : `bool`, optional - If True, print the path of the saved file or warning for invalid file name, by default False. + If True, will print the path of the saved file or warning for invalid file name, by default False. skip_invalid_filename: `bool`, optional If True, will only save the image if the file name and extension are valid, by default False. @@ -130,7 +131,7 @@ class GeneratedImage(Image): def validate_cookies(cls, v: dict) -> dict: if len(v) == 0: raise ValueError( - "GeneratedImage is designed to be initiated with same cookies as GeminiClient." + "GeneratedImage is designed to be initialized with same cookies as GeminiClient." ) return v @@ -144,7 +145,7 @@ async def save(self, **kwargs) -> None: filename: `str`, optional Filename to save the image, generated images are always in .png format, but file extension will not be included in the URL. And since the URL ends with a long hash, by default will use timestamp + end of the hash as the filename. - **kwargs: `dict`, optional + kwargs: `dict`, optional Other arguments to pass to `Image.save`. """ await super().save( diff --git a/src/gemini_webapi/utils/__init__.py b/src/gemini_webapi/utils/__init__.py new file mode 100644 index 0000000..2cb78d4 --- /dev/null +++ b/src/gemini_webapi/utils/__init__.py @@ -0,0 +1,9 @@ +from asyncio import Task + +from .upload_file import upload_file # noqa: F401 +from .rotate_1psidts import rotate_1psidts # noqa: F401 +from .get_access_token import get_access_token, get_cookie_by_name # noqa: F401 +from .logger import logger, set_log_level # noqa: F401 + + +rotate_tasks: dict[str, Task] = {} diff --git a/src/gemini_webapi/utils/get_access_token.py b/src/gemini_webapi/utils/get_access_token.py new file mode 100644 index 0000000..28f2da8 --- /dev/null +++ b/src/gemini_webapi/utils/get_access_token.py @@ -0,0 +1,150 @@ +import re +import asyncio +from asyncio import Task +from pathlib import Path +from http.cookiejar import CookieJar + +from httpx import AsyncClient, Response + +from ..constants import Endpoint, Headers +from ..exceptions import AuthError +from .logger import logger + + +def get_cookie_by_name(jar: CookieJar, name: str) -> str: + """ + Get the value of a cookie from a http CookieJar. + + Parameters + ---------- + jar : `http.cookiejar.CookieJar` + Cookie jar to be used. + name : `str` + Name of the cookie to be retrieved. + + Returns + ------- + `str` + Value of the cookie. + """ + + for cookie in jar: + if cookie.name == name: + return cookie.value + + +async def get_access_token( + base_cookies: dict, proxies: dict | None = None, verbose: bool = False +) -> tuple[str, dict]: + """ + Send a get request to gemini.google.com for each group of available cookies and return + the value of "SNlM0e" as access token on the first successful request. + + Possible cookie sources: + - Base cookies passed to the function. + - __Secure-1PSID from base cookies with __Secure-1PSIDTS from cache. + - Local browser cookies (if optional dependency `browser-cookie3` is installed). + + Parameters + ---------- + base_cookies : `dict` + Base cookies to be used in the request. + proxies: `dict`, optional + Dict of proxies. + verbose: `bool`, optional + If `True`, will print more infomation in logs. + + Returns + ------- + `str` + Access token. + `dict` + Cookies of the successful request. + + Raises + ------ + `gemini_webapi.AuthError` + If all requests failed. + """ + + async def send_request(cookies: dict) -> tuple[Response | None, dict]: + async with AsyncClient( + proxies=proxies, + headers=Headers.GEMINI.value, + cookies=cookies, + follow_redirects=True, + ) as client: + response = await client.get(Endpoint.INIT.value) + response.raise_for_status() + return response, cookies + + tasks = [] + + if "__Secure-1PSID" in base_cookies: + tasks.append(Task(send_request(base_cookies))) + + filename = f".cached_1psidts_{base_cookies['__Secure-1PSID']}.txt" + path = Path(__file__).parent / "temp" / filename + if path.is_file(): + cached_1psidts = path.read_text() + if cached_1psidts: + cached_cookies = {**base_cookies, "__Secure-1PSIDTS": cached_1psidts} + tasks.append(Task(send_request(cached_cookies))) + elif verbose: + logger.debug("Skipping loading cached cookies. Cache file is empty.") + elif verbose: + logger.debug("Skipping loading cached cookies. Cache file not found.") + elif verbose: + logger.debug( + "Skipping loading base cookies and cached cookies. __Secure-1PSID is not provided." + ) + + try: + import browser_cookie3 + + browser_cookies = browser_cookie3.load(domain_name="google.com") + if browser_cookies and ( + secure_1psid := get_cookie_by_name(browser_cookies, "__Secure-1PSID") + ): + local_cookies = {"__Secure-1PSID": secure_1psid} + if secure_1psidts := get_cookie_by_name( + browser_cookies, "__Secure-1PSIDTS" + ): + local_cookies["__Secure-1PSIDTS"] = secure_1psidts + tasks.append(Task(send_request(local_cookies))) + elif verbose: + logger.debug( + "Skipping loading local browser cookies. Login to gemini.google.com in your browser first." + ) + except ImportError: + if verbose: + logger.debug( + "Skipping loading local browser cookies. Optional dependency 'browser-cookie3' is not installed." + ) + except Exception as e: + if verbose: + logger.warning(f"Skipping loading local browser cookies. {e}") + + for i, future in enumerate(asyncio.as_completed(tasks)): + try: + response, request_cookies = await future + match = re.search(r'"SNlM0e":"(.*?)"', response.text) + if match: + if verbose: + logger.debug( + f"Init attempt ({i + 1}/{len(tasks)}) succeeded. Initializing client..." + ) + return match.group(1), request_cookies + elif verbose: + logger.debug( + f"Init attempt ({i + 1}/{len(tasks)}) failed. Cookies invalid." + ) + except Exception as e: + if verbose: + logger.debug( + f"Init attempt ({i + 1}/{len(tasks)}) failed with error: {e}" + ) + + raise AuthError( + "Failed to initialize client. SECURE_1PSIDTS could get expired frequently, please make sure cookie values are up to date." + ) diff --git a/src/gemini_webapi/utils/logger.py b/src/gemini_webapi/utils/logger.py new file mode 100644 index 0000000..714c859 --- /dev/null +++ b/src/gemini_webapi/utils/logger.py @@ -0,0 +1,38 @@ +import atexit +from sys import stderr + +from loguru._logger import Core as _Core +from loguru._logger import Logger as _Logger + +logger = _Logger( + core=_Core(), + exception=None, + depth=0, + record=False, + lazy=False, + colors=False, + raw=False, + capture=True, + patchers=[], + extra={}, +) + +if stderr: + logger.add(stderr, level="INFO") + +atexit.register(logger.remove) + + +def set_log_level(level: str): + """ + Set the log level for the whole module. Default is "INFO". Set to "DEBUG" to see more detailed logs. + + Parameters + ---------- + level : str + The log level to set. Must be one of "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL". + """ + assert level in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + + logger.remove() + logger.add(stderr, level=level) diff --git a/src/gemini_webapi/utils/rotate_1psidts.py b/src/gemini_webapi/utils/rotate_1psidts.py new file mode 100644 index 0000000..32f9f81 --- /dev/null +++ b/src/gemini_webapi/utils/rotate_1psidts.py @@ -0,0 +1,55 @@ +import os +import time +from pathlib import Path + +from httpx import AsyncClient + +from ..constants import Endpoint, Headers +from ..exceptions import AuthError + + +async def rotate_1psidts(cookies: dict, proxies: dict | None = None) -> str: + """ + Refresh the __Secure-1PSIDTS cookie and store the refreshed cookie value in cache file. + + Parameters + ---------- + cookies : `dict` + Cookies to be used in the request. + proxies: `dict`, optional + Dict of proxies. + + Returns + ------- + `str` + New value of the __Secure-1PSIDTS cookie. + + Raises + ------ + `gemini_webapi.AuthError` + If request failed with 401 Unauthorized. + `httpx.HTTPStatusError` + If request failed with other status codes. + """ + + path = Path(__file__).parent / "temp" + path.mkdir(parents=True, exist_ok=True) + filename = f".cached_1psidts_{cookies['__Secure-1PSID']}.txt" + path = path / filename + + # Check if the cache file was modified in the last minute to avoid 429 Too Many Requests + if not (path.is_file() and time.time() - os.path.getmtime(path) <= 60): + async with AsyncClient(proxies=proxies) as client: + response = await client.post( + url=Endpoint.ROTATE_COOKIES.value, + headers=Headers.ROTATE_COOKIES.value, + cookies=cookies, + data='[000,"-0000000000000000000"]', + ) + if response.status_code == 401: + raise AuthError + response.raise_for_status() + + if new_1psidts := response.cookies.get("__Secure-1PSIDTS"): + path.write_text(new_1psidts) + return new_1psidts diff --git a/src/gemini_webapi/utils.py b/src/gemini_webapi/utils/upload_file.py similarity index 96% rename from src/gemini_webapi/utils.py rename to src/gemini_webapi/utils/upload_file.py index 1456ee7..64b0e34 100644 --- a/src/gemini_webapi/utils.py +++ b/src/gemini_webapi/utils/upload_file.py @@ -1,7 +1,7 @@ from httpx import AsyncClient from pydantic import validate_call -from .constants import Endpoint, Headers +from ..constants import Endpoint, Headers @validate_call diff --git a/tests/test_client_features.py b/tests/test_client_features.py index 68835f6..dd1c411 100644 --- a/tests/test_client_features.py +++ b/tests/test_client_features.py @@ -1,9 +1,13 @@ import os import unittest +import logging from loguru import logger -from gemini_webapi import GeminiClient, AuthError +from gemini_webapi import GeminiClient, AuthError, set_log_level + +logging.getLogger("asyncio").setLevel(logging.ERROR) +set_log_level("DEBUG") class TestGeminiClient(unittest.IsolatedAsyncioTestCase): diff --git a/tests/test_rotate_cookie.py b/tests/test_rotate_cookie.py new file mode 100644 index 0000000..2955951 --- /dev/null +++ b/tests/test_rotate_cookie.py @@ -0,0 +1,26 @@ +import os +import asyncio + +from loguru import logger + +from gemini_webapi import GeminiClient, set_log_level + +set_log_level("DEBUG") + + +@logger.catch() +async def main(): + client = GeminiClient(os.getenv("SECURE_1PSID"), os.getenv("SECURE_1PSIDTS")) + await client.init(close_delay=30, refresh_interval=60) + + while True: + try: + response = await client.generate_content("Hello world") + logger.info(response) + except Exception as e: + logger.error(e) + await asyncio.sleep(60) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_save_image.py b/tests/test_save_image.py index 49d1f1e..d8f40e5 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -1,10 +1,14 @@ import os import unittest +import logging from httpx import HTTPError from loguru import logger -from gemini_webapi import GeminiClient, AuthError +from gemini_webapi import GeminiClient, AuthError, set_log_level + +logging.getLogger("asyncio").setLevel(logging.ERROR) +set_log_level("DEBUG") class TestGeminiClient(unittest.IsolatedAsyncioTestCase):