diff --git a/src/gemini_webapi/client.py b/src/gemini_webapi/client.py index beff247..2be9037 100644 --- a/src/gemini_webapi/client.py +++ b/src/gemini_webapi/client.py @@ -10,10 +10,10 @@ from .exceptions import AuthError, APIError, TimeoutError, GeminiError from .constants import Endpoint, Headers from .utils import ( - get_cookie_by_name, upload_file, - get_access_token, rotate_1psidts, + get_access_token, + load_browser_cookies, rotate_tasks, logger, ) @@ -124,10 +124,8 @@ def __init__( self.cookies["__Secure-1PSIDTS"] = secure_1psidts else: try: - import browser_cookie3 - - cookies = browser_cookie3.load(domain_name="google.com") - if not (cookies and get_cookie_by_name(cookies, "__Secure-1PSID")): + cookies = load_browser_cookies(domain_name="google.com") + if not (cookies and cookies.get("__Secure-1PSID")): raise ValueError( "Failed to load cookies from local browser. Please pass cookie values manually." ) @@ -154,7 +152,7 @@ async def init( Request timeout of the client in seconds. Used to limit the max waiting time when sending a request. auto_close: `bool`, optional If `True`, the client will close connections and clear resource usage after a certain period - of inactivity. Useful for keep-alive services. + of inactivity. Useful for always-on services. close_delay: `float`, optional Time to wait before auto-closing the client in seconds. Effective only if `auto_close` is `True`. auto_refresh: `bool`, optional diff --git a/src/gemini_webapi/utils/__init__.py b/src/gemini_webapi/utils/__init__.py index 2cb78d4..433f7c0 100644 --- a/src/gemini_webapi/utils/__init__.py +++ b/src/gemini_webapi/utils/__init__.py @@ -2,7 +2,8 @@ from .upload_file import upload_file # noqa: F401 from .rotate_1psidts import rotate_1psidts # noqa: F401 -from .get_access_token import get_access_token, get_cookie_by_name # noqa: F401 +from .get_access_token import get_access_token # noqa: F401 +from .load_browser_cookies import load_browser_cookies # noqa: F401 from .logger import logger, set_log_level # noqa: F401 diff --git a/src/gemini_webapi/utils/get_access_token.py b/src/gemini_webapi/utils/get_access_token.py index 28f2da8..5543517 100644 --- a/src/gemini_webapi/utils/get_access_token.py +++ b/src/gemini_webapi/utils/get_access_token.py @@ -2,37 +2,15 @@ import asyncio from asyncio import Task from pathlib import Path -from http.cookiejar import CookieJar from httpx import AsyncClient, Response from ..constants import Endpoint, Headers from ..exceptions import AuthError +from .load_browser_cookies import load_browser_cookies from .logger import logger -def get_cookie_by_name(jar: CookieJar, name: str) -> str: - """ - Get the value of a cookie from a http CookieJar. - - Parameters - ---------- - jar : `http.cookiejar.CookieJar` - Cookie jar to be used. - name : `str` - Name of the cookie to be retrieved. - - Returns - ------- - `str` - Value of the cookie. - """ - - for cookie in jar: - if cookie.name == name: - return cookie.value - - async def get_access_token( base_cookies: dict, proxies: dict | None = None, verbose: bool = False ) -> tuple[str, dict]: @@ -100,16 +78,12 @@ async def send_request(cookies: dict) -> tuple[Response | None, dict]: ) try: - import browser_cookie3 - - browser_cookies = browser_cookie3.load(domain_name="google.com") - if browser_cookies and ( - secure_1psid := get_cookie_by_name(browser_cookies, "__Secure-1PSID") - ): + browser_cookies = load_browser_cookies( + domain_name="google.com", verbose=verbose + ) + if browser_cookies and (secure_1psid := browser_cookies.get("__Secure-1PSID")): local_cookies = {"__Secure-1PSID": secure_1psid} - if secure_1psidts := get_cookie_by_name( - browser_cookies, "__Secure-1PSIDTS" - ): + if secure_1psidts := browser_cookies.get("__Secure-1PSIDTS"): local_cookies["__Secure-1PSIDTS"] = secure_1psidts tasks.append(Task(send_request(local_cookies))) elif verbose: diff --git a/src/gemini_webapi/utils/load_browser_cookies.py b/src/gemini_webapi/utils/load_browser_cookies.py new file mode 100644 index 0000000..96ad23d --- /dev/null +++ b/src/gemini_webapi/utils/load_browser_cookies.py @@ -0,0 +1,52 @@ +import browser_cookie3 as bc3 + +from .logger import logger + + +def load_browser_cookies(domain_name: str = "", verbose=True) -> dict: + """ + Try to load cookies from all supported browsers and return combined cookiejar. + Optionally pass in a domain name to only load cookies from the specified domain. + + Parameters + ---------- + domain_name : str, optional + Domain name to filter cookies by, by default will load all cookies without filtering. + verbose : bool, optional + If `True`, will print more infomation in logs. + + Returns + ------- + `dict` + Dictionary with cookie name as key and cookie value as value. + """ + cookies = {} + for cookie_fn in [ + bc3.chrome, + bc3.chromium, + bc3.opera, + bc3.opera_gx, + bc3.brave, + bc3.edge, + bc3.vivaldi, + bc3.firefox, + bc3.librewolf, + bc3.safari, + ]: + try: + for cookie in cookie_fn(domain_name=domain_name): + cookies[cookie.name] = cookie.value + except bc3.BrowserCookieError: + pass + except PermissionError as e: + if verbose: + logger.warning( + f"Permission denied while trying to load cookies from {cookie_fn.__name__}. {e}" + ) + except Exception as e: + if verbose: + logger.error( + f"Error happened while trying to load cookies from {cookie_fn.__name__}. {e}" + ) + + return cookies diff --git a/tests/test_rotate_cookie.py b/tests/test_rotate_cookies.py similarity index 100% rename from tests/test_rotate_cookie.py rename to tests/test_rotate_cookies.py