feat: add support to specify language model version
ref #51
HanaokaYuzu committed Dec 22, 2024
1 parent 3e229b9 commit a3f7f25
Showing 4 changed files with 109 additions and 12 deletions.
34 changes: 31 additions & 3 deletions README.md
@@ -50,6 +50,7 @@ A reverse-engineered asynchronous python wrapper for [Google Gemini](https://gem
- [Retrieve images in response](#retrieve-images-in-response)
- [Generate images with ImageFx](#generate-images-with-imagefx)
- [Save images to local files](#save-images-to-local-files)
- [Specify language model version](#specify-language-model-version)
- [Generate contents with Gemini extensions](#generate-contents-with-gemini-extensions)
- [Check and switch to other reply candidates](#check-and-switch-to-other-reply-candidates)
- [Control log level](#control-log-level)
@@ -92,9 +93,9 @@ pip install -U browser-cookie3
```yaml
services:
  main:
    volumes:
      - ./gemini_cookies:/usr/local/lib/python3.12/site-packages/gemini_webapi/utils/temp
```
> [!NOTE]
@@ -255,6 +256,33 @@ async def main():
asyncio.run(main())
```

### Specify language model version

You can choose a specific language model version by passing the `model` argument to `GeminiClient.generate_content` or `GeminiClient.start_chat`. The default value is `unspecified`.

Currently available models (as of Dec 21, 2024):

- `unspecified` - Default model (Gemini 1.5 Flash)
- `gemini-1.5-flash` - Gemini 1.5 Flash
- `gemini-2.0-flash-exp` - Gemini 2.0 Flash Experimental

```python
import asyncio

from gemini_webapi.constants import Model

async def main():
    # Assumes `client` was initialized as shown in the earlier examples
    response1 = await client.generate_content(
        "What's your language model version? Reply version number only.",
        model="gemini-1.5-flash",
    )
    print(f"Model version (gemini-1.5-flash): {response1.text}")

    chat = client.start_chat(model=Model.G_2_0_FLASH_EXP)
    response2 = await chat.send_message(
        "What's your language model version? Reply version number only."
    )
    print(f"Model version ({Model.G_2_0_FLASH_EXP.model_name}): {response2.text}")

asyncio.run(main())
```
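Model name strings are resolved through the `Model.from_name` helper added in this commit, so a typo fails fast instead of silently falling back to the default model. A minimal sketch of that behavior (`gemini-3.0-pro` is a made-up name used only to trigger the error):

```python
from gemini_webapi.constants import Model

# Passing a string is equivalent to passing the corresponding enum member
assert Model.from_name("gemini-2.0-flash-exp") is Model.G_2_0_FLASH_EXP

# Unknown names raise ValueError listing the available models
try:
    Model.from_name("gemini-3.0-pro")  # hypothetical model name
except ValueError as e:
    print(e)
```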

### Generate contents with Gemini extensions

> [!IMPORTANT]
45 changes: 40 additions & 5 deletions src/gemini_webapi/client.py
@@ -8,7 +8,7 @@

from httpx import AsyncClient, ReadTimeout

from .constants import Endpoint, Headers
from .constants import Endpoint, Headers, Model
from .exceptions import AuthError, APIError, TimeoutError, GeminiError
from .types import WebImage, GeneratedImage, Candidate, ModelOutput
from .utils import (
@@ -79,6 +79,9 @@ class GeminiClient:
__Secure-1PSIDTS cookie value. Some Google accounts don't require this value; provide it only if it's in the cookie list.
proxy: `str`, optional
Proxy URL.
kwargs: `dict`, optional
Additional arguments which will be passed to the http client.
Refer to `httpx.AsyncClient` for more information.
Raises
------
@@ -98,13 +101,15 @@ class GeminiClient:
"close_task",
"auto_refresh",
"refresh_interval",
"kwargs",
]

def __init__(
self,
secure_1psid: str | None = None,
secure_1psidts: str | None = None,
proxy: str | None = None,
**kwargs,
):
self.cookies = {}
self.proxy = proxy
@@ -117,6 +122,7 @@ def __init__(
self.close_task: Task | None = None
self.auto_refresh: bool = True
self.refresh_interval: float = 540
self.kwargs = kwargs

# Validate cookies
if secure_1psid:
@@ -173,6 +179,7 @@ async def init(
follow_redirects=True,
headers=Headers.GEMINI.value,
cookies=valid_cookies,
**self.kwargs,
)
self.access_token = access_token
self.cookies = valid_cookies
@@ -256,7 +263,9 @@ async def generate_content(
self,
prompt: str,
images: list[bytes | str | Path] | None = None,
model: Model | str = Model.UNSPECIFIED,
chat: Optional["ChatSession"] = None,
**kwargs,
) -> ModelOutput:
"""
Generates contents with prompt.
@@ -267,8 +276,14 @@
Prompt provided by user.
images: `list[bytes | str | Path]`, optional
List of image file paths or file data in bytes.
model: `Model` | `str`, optional
Specify the model to use for generation.
Pass either a `gemini_webapi.constants.Model` enum or a model name string.
chat: `ChatSession`, optional
Chat data to retrieve conversation history. If None, a new chat id will be generated automatically when sending the POST request.
kwargs: `dict`, optional
Additional arguments which will be passed to the post request.
Refer to `httpx.AsyncClient.request` for more information.
Returns
-------
@@ -291,12 +306,16 @@

assert prompt, "Prompt cannot be empty."

if not isinstance(model, Model):
model = Model.from_name(model)

if self.auto_close:
await self.reset_close_task()

try:
response = await self.client.post(
Endpoint.GENERATE.value,
headers=model.model_header,
data={
"at": self.access_token,
"f.req": json.dumps(
@@ -325,6 +344,7 @@ async def generate_content(
]
),
},
**kwargs,
)
except ReadTimeout:
raise TimeoutError(
@@ -431,12 +451,13 @@ def start_chat(self, **kwargs) -> "ChatSession":
Parameters
----------
kwargs: `dict`, optional
Other arguments to pass to `ChatSession.__init__`.
Additional arguments which will be passed to the chat session.
Refer to `gemini_webapi.ChatSession` for more information.
Returns
-------
:class:`ChatSession`
Empty chat object for retrieving conversation history.
Empty chat session object for retrieving conversation history.
"""

return ChatSession(geminiclient=self, **kwargs)
@@ -458,9 +479,17 @@ class ChatSession:
Reply id, if provided together with metadata, will override the second value in it.
rcid: `str`, optional
Reply candidate id, if provided together with metadata, will override the third value in it.
model: `Model` | `str`, optional
Specify the model to use for generation.
Pass either a `gemini_webapi.constants.Model` enum or a model name string.
"""

__slots__ = ["__metadata", "geminiclient", "last_output"]
__slots__ = [
"__metadata",
"geminiclient",
"last_output",
"model",
]

def __init__(
self,
@@ -469,10 +498,12 @@ def __init__(
cid: str | None = None, # chat id
rid: str | None = None, # reply id
rcid: str | None = None, # reply candidate id
model: Model | str = Model.UNSPECIFIED,
):
self.__metadata: list[str | None] = [None, None, None]
self.geminiclient: GeminiClient = geminiclient
self.last_output: ModelOutput | None = None
self.model = model

if metadata:
self.metadata = metadata
@@ -499,6 +530,7 @@ async def send_message(
self,
prompt: str,
images: list[bytes | str | Path] | None = None,
**kwargs,
) -> ModelOutput:
"""
Generates contents with prompt.
@@ -510,6 +542,9 @@
Prompt provided by user.
images: `list[bytes | str | Path]`, optional
List of image file paths or file data in bytes.
kwargs: `dict`, optional
Additional arguments which will be passed to the post request.
Refer to `httpx.AsyncClient.request` for more information.
Returns
-------
@@ -531,7 +566,7 @@
"""

return await self.geminiclient.generate_content(
prompt=prompt, images=images, chat=self
prompt=prompt, images=images, model=self.model, chat=self, **kwargs
)

def choose_candidate(self, index: int) -> ModelOutput:
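Taken together, the two `kwargs` passthroughs added here let you set connection-level options once on the constructor (forwarded to `httpx.AsyncClient`) and per-call options on `generate_content`/`send_message` (forwarded to the underlying POST request). A hedged sketch of how that might look — the cookie value is a placeholder, `trust_env` and `timeout` are standard `httpx` options rather than part of this wrapper's own API, and `init()` is assumed to be callable with its defaults:

```python
import asyncio

from gemini_webapi import GeminiClient
from gemini_webapi.constants import Model

async def main():
    # Constructor kwargs are stored and forwarded to httpx.AsyncClient in init()
    client = GeminiClient("<your __Secure-1PSID value>", trust_env=False)
    await client.init()

    # Per-request kwargs are forwarded to httpx.AsyncClient.post
    response = await client.generate_content(
        "Hello World!",
        model=Model.G_2_0_FLASH_EXP,
        timeout=120,  # override the client-wide timeout for this call only
    )
    print(response.text)

asyncio.run(main())
```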
25 changes: 25 additions & 0 deletions src/gemini_webapi/constants.py
@@ -21,3 +21,28 @@ class Headers(Enum):
"Content-Type": "application/json",
}
UPLOAD = {"Push-ID": "feeds/mcudyrk2a4khkz"}


class Model(Enum):
UNSPECIFIED = ("unspecified", {})
G_1_5_FLASH = (
"gemini-1.5-flash",
{"x-goog-ext-525001261-jspb": '[null,null,null,null,"7daceb7ef88130f5"]'},
)
G_2_0_FLASH_EXP = (
"gemini-2.0-flash-exp",
{"x-goog-ext-525001261-jspb": '[null,null,null,null,"948b866104ccf484"]'},
)

def __init__(self, name, header):
self.model_name = name
self.model_header = header

@classmethod
def from_name(cls, name: str):
for model in cls:
if model.model_name == name:
return model
raise ValueError(
f"Unknown model name: {name}. Available models: {', '.join([model.model_name for model in cls])}"
)
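Each `Model` member thus pairs a public model name with the extra request header that selects that backend — `generate_content` attaches it via `headers=model.model_header`, and `UNSPECIFIED` carries an empty dict so the server falls back to its default. A quick inspection sketch based only on the definitions above:

```python
from gemini_webapi.constants import Model

for model in Model:
    # e.g. gemini-1.5-flash: {'x-goog-ext-525001261-jspb': '[null,null,null,null,"7daceb7ef88130f5"]'}
    print(f"{model.model_name}: {model.model_header}")
```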
17 changes: 13 additions & 4 deletions tests/test_client_features.py
@@ -6,6 +6,7 @@
from loguru import logger

from gemini_webapi import GeminiClient, AuthError, set_log_level
from gemini_webapi.constants import Model

logging.getLogger("asyncio").setLevel(logging.ERROR)
set_log_level("DEBUG")
@@ -27,10 +28,20 @@ async def test_successful_request(self):
response = await self.geminiclient.generate_content("Hello World!")
self.assertTrue(response.text)

@logger.catch(reraise=True)
async def test_switch_model(self):
for model in Model:
response = await self.geminiclient.generate_content(
"What's you language model version? Reply version number only.",
model=model,
)
logger.debug(f"Model version ({model.model_name}): {response.text}")

@logger.catch(reraise=True)
async def test_upload_image(self):
response = await self.geminiclient.generate_content(
"Describe these images", images=[Path("assets/banner.png"), "assets/favicon.png"]
"Describe these images",
images=[Path("assets/banner.png"), "assets/favicon.png"],
)
logger.debug(response.text)

@@ -86,9 +97,7 @@ async def test_ai_image_generation(self):

@logger.catch(reraise=True)
async def test_card_content(self):
response = await self.geminiclient.generate_content(
"How is today's weather?"
)
response = await self.geminiclient.generate_content("How is today's weather?")
logger.debug(response.text)

@logger.catch(reraise=True)
