From f13228dc01728e410d5ca6916176049a04490218 Mon Sep 17 00:00:00 2001
From: Mayuresh Agashe
Date: Fri, 26 Apr 2024 16:54:09 +0000
Subject: [PATCH 01/19] *Initial prototype for explicit caching

*Add basic CRUD support for caching
*Remove INPUT_ONLY marked fields from CachedContent dataclass
*Rename files 'cached_content*' -> 'caching*'
*Update 'Create' method for explicit instantiation of 'CachedContent'
*Add a factory method to instantiate model with `CachedContent` as its context
*blacken
*Add tests

Change-Id: I694545243efda467d6fd599beded0dc6679b727d
---
 google/generativeai/caching.py             |  76 ++++++
 google/generativeai/client.py              |   4 +
 google/generativeai/generative_models.py   |  71 +++++-
 google/generativeai/types/caching_types.py | 257 +++++++++++++++++++++
 tests/test_caching.py                      | 224 ++++++++++++++++++
 tests/test_generative_models.py            |  58 +++++
 6 files changed, 688 insertions(+), 2 deletions(-)
 create mode 100644 google/generativeai/caching.py
 create mode 100644 google/generativeai/types/caching_types.py
 create mode 100644 tests/test_caching.py

diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py
new file mode 100644
index 000000000..55fd95bac
--- /dev/null
+++ b/google/generativeai/caching.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from typing import Optional, Iterable
+
+import google.ai.generativelanguage as glm
+
+from google.generativeai.types import caching_types
+from google.generativeai.types import content_types
+from google.generativeai.client import get_default_cache_client
+
+
+# alias for `caching_types.CachedContent`.
+CachedContent = caching_types.CachedContent
+
+
+def get_cached_content(name: str, client: glm.CacheServiceClient | None = None) -> CachedContent:
+    """Fetches required `CachedContent` resource.
+
+    Args:
+        name: name: The resource name referring to the cached content.
+
+    Returns:
+        `CachedContent` resource with specified name.
+    """
+    return CachedContent.get_cached_content(name=name, client=client)
+
+
+def delete_cached_content(name: str, client: glm.CacheServiceClient | None = None) -> None:
+    """Deletes `CachedContent` resource.
+
+    Args:
+        name: The resource name referring to the cached content.
+           Format: cachedContents/{id}.
+    """
+    if client is None:
+        client = get_default_cache_client()
+
+    if "cachedContents/" not in name:
+        name = "cachedContents/" + name
+
+    request = glm.DeleteCachedContentRequest(name=name)
+    client.delete_cached_content(request)
+    return
+
+
+def list_cached_contents(
+    page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None
+) -> Iterable[CachedContent]:
+    """Lists `CachedContent` objects associated with the project.
+
+    Args:
+        page_size: The maximum number of permissions to return (per page).
+          The service may return fewer `CachedContent` objects.
+
+    Returns:
+        A paginated list of `CachedContent` objects.
+    """
+    if client is None:
+        client = get_default_cache_client()
+
+    request = glm.ListCachedContentsRequest(page_size=page_size)
+    for cached_content in client.list_cached_contents(request):
+        yield caching_types.decode_cached_content(cached_content)
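For orientation, here is a minimal usage sketch of the module-level helpers this file adds. It is not part of the patch itself: the API key is a placeholder, and `test-cache` is assumed to name an existing cached content resource.

    import google.generativeai as genai
    from google.generativeai import caching

    genai.configure(api_key="YOUR_API_KEY")  # placeholder key

    # Fetch a single cached content resource by its id; the
    # "cachedContents/" prefix is added automatically when missing.
    cc = caching.get_cached_content(name="test-cache")
    print(cc.name, cc.model, cc.expire_time)

    # Page through the cached contents associated with the project.
    for entry in caching.list_cached_contents(page_size=2):
        print(entry.name)

    # Delete the resource once it is no longer needed.
    caching.delete_cached_content(name="test-cache")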
diff --git a/google/generativeai/client.py b/google/generativeai/client.py
index bee858122..ec0970241 100644
--- a/google/generativeai/client.py
+++ b/google/generativeai/client.py
@@ -289,6 +289,10 @@ def configure(
     _client_manager.configure()
 
 
+def get_default_cache_client() -> glm.CacheServiceClient:
+    return _client_manager.get_default_client("cache")
+
+
 def get_default_discuss_client() -> glm.DiscussServiceClient:
     return _client_manager.get_default_client("discuss")
 
diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py
index 02cab0b29..30613a52a 100644
--- a/google/generativeai/generative_models.py
+++ b/google/generativeai/generative_models.py
@@ -5,8 +5,7 @@
 from collections.abc import Iterable
 import dataclasses
 import textwrap
-from typing import Any
-from typing import Union
+from typing import Any, Union, overload
 import reprlib
 
 # pylint: disable=bad-continuation, line-too-long
@@ -16,9 +15,11 @@
 from google.ai import generativelanguage as glm
 from google.generativeai import client
 from google.generativeai import string_utils
+from google.generativeai import caching
 from google.generativeai.types import content_types
 from google.generativeai.types import generation_types
 from google.generativeai.types import safety_types
+from google.generativeai.types import caching_types
 
 
 class GenerativeModel:
@@ -98,6 +99,15 @@ def __init__(
         self._client = None
         self._async_client = None
 
+    def __new__(cls, *args, **kwargs):
+        self = super().__new__(cls)
+
+        if cached_instance := kwargs.pop("cached_content", None):
+            setattr(self, "_cached_content", cached_instance.name)
+            setattr(cls, "cached_content", property(fget=lambda self: self._cached_content))
+
+        return self
+
     @property
     def model_name(self):
         return self._model_name
@@ -134,6 +144,13 @@ def _prepare_request(
         if not contents:
             raise TypeError("contents must not be empty")
 
+        if hasattr(self, "cached_content") and any([self._system_instruction, tools, tool_config]):
+            raise ValueError(
+                "`tools`, `tool_config`, `system_instruction` cannot be set on a model instantiated with `cached_content` as its context."
+            )
+
+        cached_content = getattr(self, "cached_content", None)
+
         tools_lib = self._get_tools_lib(tools)
         if tools_lib is not None:
             tools_lib = tools_lib.to_proto()
@@ -162,6 +179,7 @@ def _prepare_request(
             tools=tools_lib,
             tool_config=tool_config,
             system_instruction=self._system_instruction,
+            cached_content=cached_content,
         )
 
     def _get_tools_lib(
@@ -172,6 +190,55 @@ def _get_tools_lib(
         else:
             return content_types.to_function_library(tools)
 
+    @overload
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: str,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel: ...
+
+    @overload
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: caching_types.CachedContent,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel: ...
+
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: str | caching_types.CachedContent,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel:
+        """Creates a model with `cached_content` as its context.
+
+        Args:
+            cached_content: The cached content resource (or its name) to use as the model's context.
+
+        Returns:
+            `GenerativeModel` object with `cached_content` as its context.
+        """
+        if isinstance(cached_content, str):
+            cached_content = caching.get_cached_content(name=cached_content)
+
+        # call __new__ with the cached_content to set the model's context. This is done to avoid
+        # the exposing `cached_content` as a public attribute.
+        self = cls.__new__(cls, cached_content=cached_content)
+
+        # call __init__ to set the model's `generation_config`, `safety_settings`.
+        # `model_name` will be the name of the model for which the `cached_content` was created.
+        self.__init__(
+            model_name=cached_content.model,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
+        )
+        return self
+
     def generate_content(
         self,
         contents: content_types.ContentsType,
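A short sketch of the factory in action (illustrative only, not part of the patch; it assumes the same configured client as the earlier sketch, the names are placeholders, and `CachedContent.create` is the classmethod added in `caching_types.py` below):

    import google.generativeai as genai
    from google.generativeai import caching

    cc = caching.CachedContent.create(
        name="test-cache",
        model="models/gemini-1.0-pro-001",
        contents=["A long document worth reusing across many requests."],
    )
    model = genai.GenerativeModel.from_cached_content(cached_content=cc)

    # `tools`, `tool_config`, and `system_instruction` must stay unset here;
    # `_prepare_request` raises a ValueError if they are combined with a
    # cached context.
    response = model.generate_content("Summarize the cached document.")
    print(response.text)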
diff --git a/google/generativeai/types/caching_types.py b/google/generativeai/types/caching_types.py
new file mode 100644
index 000000000..a99f25414
--- /dev/null
+++ b/google/generativeai/types/caching_types.py
@@ -0,0 +1,257 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations + +import dataclasses +import datetime +from typing import Optional, Union +from typing_extensions import TypedDict + +from google.generativeai.types import content_types +from google.generativeai.types.model_types import idecode_time +from google.generativeai.utils import flatten_update_paths +from google.generativeai.client import get_default_cache_client + +from google.protobuf import field_mask_pb2 + +import google.ai.generativelanguage as glm + + +__all__ = ["CachedContent"] + + +class TTL(TypedDict): + seconds: int + + +ExpirationTypes = Union[TTL, int, datetime.timedelta] + + +def _to_ttl(expiration: Optional[ExpirationTypes]) -> TTL: + if isinstance(expiration, datetime.timedelta): + return {"seconds": int(expiration.total_seconds())} + elif isinstance(expiration, dict): + return expiration + elif isinstance(expiration, int): + return {"seconds": expiration} + else: + raise TypeError( + f"Could not convert input to `expire_time` \n'" f" type: {type(expiration)}\n", + expiration, + ) + + +def decode_cached_content(cached_content: glm.CachedContent) -> CachedContent: + # not supposed to get INPUT_ONLY repeated fields, but local gapic lib build + # is returning these, hence setting including_default_value_fields to False + cached_content = type(cached_content).to_dict( + cached_content, including_default_value_fields=False + ) + + idecode_time(cached_content, "create_time") + idecode_time(cached_content, "update_time") + # always decode `expire_time` as Timestamp is returned + # regardless of what was sent on input + idecode_time(cached_content, "expire_time") + return CachedContent(**cached_content) + + +@dataclasses.dataclass +class CachedContent: + """Cached content resource.""" + + name: str + model: str + create_time: datetime.datetime + update_time: datetime.datetime + expire_time: datetime.datetime + + # NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+). + # Adding basic support for now. 
+ def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.delete() + + def _to_dict(self) -> glm.CachedContent: + proto_paths = { + "name": self.name, + "model": self.model, + } + return glm.CachedContent(**proto_paths) + + def _apply_update(self, path, value): + parts = path.split(".") + for part in parts[:-1]: + self = getattr(self, part) + if parts[-1] == "ttl": + value = self.expire_time + datetime.timedelta(seconds=value["seconds"]) + parts[-1] = "expire_time" + setattr(self, parts[-1], value) + + @staticmethod + def _prepare_create_request( + name: str, + model: str, + system_instruction: Optional[content_types.ContentType] = None, + contents: Optional[content_types.ContentsType] = None, + tools: Optional[content_types.FunctionLibraryType] = None, + tool_config: Optional[content_types.ToolConfigType] = None, + ttl: Optional[ExpirationTypes] = datetime.timedelta(hours=1), + ) -> glm.CreateCachedContentRequest: + """Prepares a CreateCachedContentRequest.""" + if "cachedContents/" not in name: + name = "cachedContents/" + name + + if "/" not in model: + model = "models/" + model + + if system_instruction: + system_instruction = content_types.to_content(system_instruction) + + tools_lib = content_types.to_function_library(tools) + if tools_lib: + tools_lib = tools_lib.to_proto() + + if tool_config: + tool_config = content_types.to_tool_config(tool_config) + + if contents: + contents = content_types.to_contents(contents) + + if ttl: + ttl = _to_ttl(ttl) + + cached_content = glm.CachedContent( + name=name, + model=model, + system_instruction=system_instruction, + contents=contents, + tools=tools_lib, + tool_config=tool_config, + ttl=ttl, + ) + + return glm.CreateCachedContentRequest(cached_content=cached_content) + + @classmethod + def create( + cls, + name: str, + model: str, + system_instruction: Optional[content_types.ContentType] = None, + contents: Optional[content_types.ContentsType] = None, + tools: Optional[content_types.FunctionLibraryType] = None, + tool_config: Optional[content_types.ToolConfigType] = None, + ttl: Optional[ExpirationTypes] = datetime.timedelta(hours=1), + client: glm.CacheServiceClient | None = None, + ) -> CachedContent: + """Creates CachedContent resource. + + Args: + name: The resource name referring to the cached content. + Format: cachedContents/{id}. + model: The name of the `Model` to use for cached content + Format: models/{model}. Cached content resource can be only + used with model it was created for. + system_instruction: Developer set system instruction. + contents: Contents to cache. + tools: A list of `Tools` the model may use to generate response. + tool_config: Config to apply to all tools. + ttl: TTL for cached resource (in seconds). Defaults to 1 hour. + + Returns: + `CachedContent` resource with specified name. 
+ """ + if client is None: + client = get_default_cache_client() + + request = cls._prepare_create_request( + name=name, + model=model, + system_instruction=system_instruction, + contents=contents, + tools=tools, + tool_config=tool_config, + ttl=ttl, + ) + + response = client.create_cached_content(request) + return decode_cached_content(response) + + @classmethod + def get_cached_content( + cls, name: str, client: glm.CacheServiceClient | None = None + ) -> CachedContent: + """Gets a `CachedContent` resource.""" + if client is None: + client = get_default_cache_client() + + if "cachedContents/" not in name: + name = "cachedContents/" + name + + request = glm.GetCachedContentRequest(name=name) + response = client.get_cached_content(request) + return decode_cached_content(response) + + def delete(self, client: glm.CachedServiceClient | None = None) -> None: + """Deletes a `CachedContent` resource.""" + if client is None: + client = get_default_cache_client() + + request = glm.DeleteCachedContentRequest(name=self.name) + client.delete_cached_content(request) + return + + def update( + self, + updates: dict[str, Any], + client: glm.CacheServiceClient | None = None, + ) -> CachedContent: + """Updates requested `CachedContent` resource. + + Args: + updates: The list of fields to update. + Currently only `ttl/expire_time` is supported as an update path. + + Returns: + `CachedContent` object with specified updates. + """ + if client is None: + client = get_default_cache_client() + + updates = flatten_update_paths(updates) + for update_path in updates: + if update_path == "ttl": + updates = updates.copy() + update_path_val = updates.get(update_path) + updates[update_path] = _to_ttl(update_path_val) + else: + raise ValueError( + f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead." + ) + field_mask = field_mask_pb2.FieldMask() + + for path in updates.keys(): + field_mask.paths.append(path) + for path, value in updates.items(): + self._apply_update(path, value) + + request = glm.UpdateCachedContentRequest( + cached_content=self._to_dict(), update_mask=field_mask + ) + client.update_cached_content(request) + return self diff --git a/tests/test_caching.py b/tests/test_caching.py new file mode 100644 index 000000000..916dc407b --- /dev/null +++ b/tests/test_caching.py @@ -0,0 +1,224 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import math +import datetime +from typing import Any +import unittest +import unittest.mock as mock + +import google.ai.generativelanguage as glm + +from google.generativeai import caching +from google.generativeai.types import caching_types + +from google.generativeai import client +from absl.testing import absltest +from absl.testing import parameterized + + +class UnitTests(parameterized.TestCase): + def setUp(self): + self.client = unittest.mock.MagicMock() + + client._client_manager.clients["cache"] = self.client + + self.observed_requests = [] + + def add_client_method(f): + name = f.__name__ + setattr(self.client, name, f) + return f + + @add_client_method + def create_cached_content( + request: glm.CreateCachedContentRequest, + **kwargs, + ) -> glm.CachedContent: + self.observed_requests.append(request) + return glm.CachedContent( + name="cachedContent/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ) + + @add_client_method + def get_cached_content( + request: glm.GetCachedContentRequest, + **kwargs, + ) -> glm.CachedContent: + self.observed_requests.append(request) + return glm.CachedContent( + name="cachedContent/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ) + + @add_client_method + def list_cached_contents( + request: glm.ListCachedContentsRequest, + **kwargs, + ) -> glm.ListCachedContentsResponse: + self.observed_requests.append(request) + return [ + glm.CachedContent( + name="cachedContent/test-cached-content-1", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ), + glm.CachedContent( + name="cachedContent/test-cached-content-2", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ), + ] + + @add_client_method + def update_cached_content( + request: glm.UpdateCachedContentRequest, + **kwargs, + ) -> glm.CachedContent: + self.observed_requests.append(request) + return glm.CachedContent( + name="cachedContent/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T03:01:01.123456Z", + ) + + @add_client_method + def delete_cached_content( + request: glm.DeleteCachedContentRequest, + **kwargs, + ) -> None: + self.observed_requests.append(request) + + def test_create_cached_content(self): + + def add(a: int, b: int) -> int: + return a + b + + cc = caching.CachedContent.create( + name="test-cached-content", + model="models/gemini-1.0-pro-001", + contents=["Add 5 and 6"], + tools=[add], + tool_config={"function_calling_config": "ANY"}, + system_instruction="Always add 10 to the result.", + ttl=datetime.timedelta(minutes=30), + ) + self.assertIsInstance(self.observed_requests[-1], glm.CreateCachedContentRequest) + self.assertIsInstance(cc, caching_types.CachedContent) + self.assertEqual(cc.name, "cachedContent/test-cached-content") + self.assertEqual(cc.model, "models/gemini-1.0-pro-001") + + @parameterized.named_parameters( + [ + dict( + testcase_name="ttl-is-int-seconds", + ttl=7200, + ), + dict( + 
testcase_name="ttl-is-timedelta", + ttl=datetime.timedelta(hours=2), + ), + dict( + testcase_name="ttl-is-dict", + ttl={"seconds": 7200}, + ), + dict( + testcase_name="ttl-is-none-default-to-1-hr", + ttl=None, + ), + ] + ) + def test_expiration_types_for_create_cached_content(self, ttl): + cc = caching.CachedContent.create( + name="test-cached-content", + model="models/gemini-1.0-pro-001", + contents=["cache this please for 2 hours"], + ttl=ttl, + ) + self.assertIsInstance(self.observed_requests[-1], glm.CreateCachedContentRequest) + self.assertIsInstance(cc, caching_types.CachedContent) + + def test_get_cached_content(self): + cc = caching.get_cached_content(name="cachedContent/test-cached-content") + self.assertIsInstance(self.observed_requests[-1], glm.GetCachedContentRequest) + self.assertIsInstance(cc, caching_types.CachedContent) + self.assertEqual(cc.name, "cachedContent/test-cached-content") + self.assertEqual(cc.model, "models/gemini-1.0-pro-001") + + def test_list_cached_contents(self): + ccs = list(caching.list_cached_contents(page_size=2)) + self.assertIsInstance(self.observed_requests[-1], glm.ListCachedContentsRequest) + self.assertLen(ccs, 2) + self.assertIsInstance(ccs[0], caching_types.CachedContent) + self.assertIsInstance(ccs[1], caching_types.CachedContent) + + def test_update_cached_content_invalid_update_paths(self): + update_masks = dict( + name="change", + model="models/gemini-1.5-pro-001", + system_instruction="Always add 10 to the result.", + contents=["add this Content"], + ) + + cc = caching.get_cached_content(name="cachedContent/test-cached-content") + with self.assertRaises(ValueError): + cc.update(updates=update_masks) + + def test_update_cached_content_valid_update_paths(self): + update_masks = dict( + ttl=datetime.timedelta(hours=2), + ) + + cc = caching.get_cached_content(name="cachedContent/test-cached-content") + cc = cc.update(updates=update_masks) + self.assertIsInstance(self.observed_requests[-1], glm.UpdateCachedContentRequest) + self.assertIsInstance(cc, caching_types.CachedContent) + + def test_delete_cached_content(self): + cc = caching.get_cached_content(name="cachedContent/test-cached-content") + cc.delete() + self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest) + + cc = caching.delete_cached_content(name="cachedContent/test-cached-content") + self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest) + + def test_auto_delete_cached_content_with_context_manager(self): + with caching.CachedContent.create( + name="test-cached-content", + model="models/gemini-1.0-pro-001", + contents=["Add 5 and 6"], + system_instruction="Always add 10 to the result.", + ttl=datetime.timedelta(minutes=30), + ) as cc: + ... 
# some logic + + self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index bc011823c..948c87c1f 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -12,6 +12,7 @@ from google.generativeai import generative_models from google.generativeai.types import content_types from google.generativeai.types import generation_types +from google.generativeai.types import caching_types import PIL.Image @@ -40,6 +41,7 @@ def setUp(self): self.client = unittest.mock.MagicMock() client_lib._client_manager.clients["generative"] = self.client + client_lib._client_manager.clients["cache"] = self.client def add_client_method(f): name = f.__name__ @@ -77,6 +79,20 @@ def count_tokens( response = self.responses["count_tokens"].pop(0) return response + @add_client_method + def get_cached_content( + request: glm.GetCachedContentRequest, + **kwargs, + ) -> glm.CachedContent: + self.observed_requests.append(request) + return glm.CachedContent( + name="cachedContent/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ) + def test_hello(self): # Generate text from text prompt model = generative_models.GenerativeModel(model_name="gemini-pro") @@ -293,6 +309,48 @@ def test_stream_prompt_feedback_not_blocked(self): text = "".join(chunk.text for chunk in response) self.assertEqual(text, "first second") + @parameterized.named_parameters( + [ + dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"), + dict( + testcase_name="test_cached_content_as_CachedContent_object", + cached_content=caching_types.CachedContent( + name="cachedContent/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ), + ), + ], + ) + def test_model_with_cached_content_as_context(self, cached_content): + model = generative_models.GenerativeModel.from_cached_content(cached_content=cached_content) + cc_name = model.cached_content + model_name = model.model_name + self.assertEqual(cc_name, "cachedContent/test-cached-content") + self.assertEqual(model_name, "models/gemini-1.0-pro-001") + + def test_content_generation_with_model_having_context(self): + self.responses["generate_content"] = [simple_response("world!")] + model = generative_models.GenerativeModel.from_cached_content( + cached_content="test-cached-content" + ) + response = model.generate_content("Hello") + + self.assertEqual(response.text, "world!") + + def test_fail_content_generation_with_model_having_context(self): + model = generative_models.GenerativeModel.from_cached_content( + cached_content="test-cached-content" + ) + + def add(a: int, b: int) -> int: + return a + b + + with self.assertRaises(ValueError): + model.generate_content("Hello", tools=[add]) + def test_chat(self): # Multi turn chat model = generative_models.GenerativeModel("gemini-pro") From 6fafe6b329647586ffd073fd22588290fecc28db Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Wed, 22 May 2024 10:49:35 +0530 Subject: [PATCH 02/19] rename get_cached_content to get --- google/generativeai/caching.py | 2 +- google/generativeai/types/caching_types.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git 
a/google/generativeai/caching.py b/google/generativeai/caching.py index 55fd95bac..c0d9e3ef0 100644 --- a/google/generativeai/caching.py +++ b/google/generativeai/caching.py @@ -36,7 +36,7 @@ def get_cached_content(name: str, client: glm.CacheServiceClient | None = None) Returns: `CachedContent` resource with specified name. """ - return CachedContent.get_cached_content(name=name, client=client) + return CachedContent.get(name=name, client=client) def delete_cached_content(name: str, client: glm.CacheServiceClient | None = None) -> None: diff --git a/google/generativeai/types/caching_types.py b/google/generativeai/types/caching_types.py index a99f25414..aa9fbe559 100644 --- a/google/generativeai/types/caching_types.py +++ b/google/generativeai/types/caching_types.py @@ -193,9 +193,7 @@ def create( return decode_cached_content(response) @classmethod - def get_cached_content( - cls, name: str, client: glm.CacheServiceClient | None = None - ) -> CachedContent: + def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent: """Gets a `CachedContent` resource.""" if client is None: client = get_default_cache_client() From cfc936e164d5bc8fcedf6ae2894fa0369f75f762 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Thu, 23 May 2024 23:10:16 +0530 Subject: [PATCH 03/19] Stroke out functional approach for CachedContent CURD ops --- google/generativeai/caching.py | 268 ++++++++++++++++++--- google/generativeai/generative_models.py | 8 +- google/generativeai/types/caching_types.py | 218 +---------------- tests/test_caching.py | 22 +- tests/test_generative_models.py | 4 +- 5 files changed, 246 insertions(+), 274 deletions(-) diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py index c0d9e3ef0..5ed449ef8 100644 --- a/google/generativeai/caching.py +++ b/google/generativeai/caching.py @@ -14,63 +14,251 @@ # limitations under the License. from __future__ import annotations -from typing import Optional, Iterable - -import google.ai.generativelanguage as glm +import dataclasses +import datetime +from typing import Any, Iterable, Optional +from google.generativeai.types.model_types import idecode_time from google.generativeai.types import caching_types from google.generativeai.types import content_types +from google.generativeai.utils import flatten_update_paths from google.generativeai.client import get_default_cache_client +from google.protobuf import field_mask_pb2 +import google.ai.generativelanguage as glm + + +@dataclasses.dataclass +class CachedContent: + """Cached content resource.""" + + name: str + model: str + create_time: datetime.datetime + update_time: datetime.datetime + expire_time: datetime.datetime + + # NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+). + # Adding basic support for now. 
+ def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.delete() + + def _to_dict(self) -> glm.CachedContent: + proto_paths = { + "name": self.name, + "model": self.model, + } + return glm.CachedContent(**proto_paths) + + def _apply_update(self, path, value): + parts = path.split(".") + for part in parts[:-1]: + self = getattr(self, part) + if parts[-1] == "ttl": + value = self.expire_time + datetime.timedelta(seconds=value["seconds"]) + parts[-1] = "expire_time" + setattr(self, parts[-1], value) + + @classmethod + def _decode_cached_content(cls, cached_content: glm.CachedContent) -> CachedContent: + # not supposed to get INPUT_ONLY repeated fields, but local gapic lib build + # is returning these, hence setting including_default_value_fields to False + cached_content = type(cached_content).to_dict( + cached_content, including_default_value_fields=False + ) + + idecode_time(cached_content, "create_time") + idecode_time(cached_content, "update_time") + # always decode `expire_time` as Timestamp is returned + # regardless of what was sent on input + idecode_time(cached_content, "expire_time") + return cls(**cached_content) + + @staticmethod + def _prepare_create_request( + name: str, + model: str, + system_instruction: Optional[content_types.ContentType] = None, + contents: Optional[content_types.ContentsType] = None, + tools: Optional[content_types.FunctionLibraryType] = None, + tool_config: Optional[content_types.ToolConfigType] = None, + ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1), + ) -> glm.CreateCachedContentRequest: + """Prepares a CreateCachedContentRequest.""" + if "cachedContents/" not in name: + name = "cachedContents/" + name + + if "/" not in model: + model = "models/" + model + + if system_instruction: + system_instruction = content_types.to_content(system_instruction) + + tools_lib = content_types.to_function_library(tools) + if tools_lib: + tools_lib = tools_lib.to_proto() + + if tool_config: + tool_config = content_types.to_tool_config(tool_config) + + if contents: + contents = content_types.to_contents(contents) + + if ttl: + ttl = caching_types.to_ttl(ttl) + + cached_content = glm.CachedContent( + name=name, + model=model, + system_instruction=system_instruction, + contents=contents, + tools=tools_lib, + tool_config=tool_config, + ttl=ttl, + ) + + return glm.CreateCachedContentRequest(cached_content=cached_content) + + @classmethod + def create( + cls, + name: str, + model: str, + system_instruction: Optional[content_types.ContentType] = None, + contents: Optional[content_types.ContentsType] = None, + tools: Optional[content_types.FunctionLibraryType] = None, + tool_config: Optional[content_types.ToolConfigType] = None, + ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1), + client: glm.CacheServiceClient | None = None, + ) -> CachedContent: + """Creates CachedContent resource. + + Args: + name: The resource name referring to the cached content. + Format: cachedContents/{id}. + model: The name of the `Model` to use for cached content + Format: models/{model}. Cached content resource can be only + used with model it was created for. + system_instruction: Developer set system instruction. + contents: Contents to cache. + tools: A list of `Tools` the model may use to generate response. + tool_config: Config to apply to all tools. + ttl: TTL for cached resource (in seconds). Defaults to 1 hour. + + Returns: + `CachedContent` resource with specified name. 
+ """ + if client is None: + client = get_default_cache_client() + + request = cls._prepare_create_request( + name=name, + model=model, + system_instruction=system_instruction, + contents=contents, + tools=tools, + tool_config=tool_config, + ttl=ttl, + ) + + response = client.create_cached_content(request) + return cls._decode_cached_content(response) + + @classmethod + def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent: + """Fetches required `CachedContent` resource. + + Args: + name: name: The resource name referring to the cached content. -# alias for `caching_types.CachedContent`. -CachedContent = caching_types.CachedContent + Returns: + `CachedContent` resource with specified name. + """ + if client is None: + client = get_default_cache_client() + if "cachedContents/" not in name: + name = "cachedContents/" + name -def get_cached_content(name: str, client: glm.CacheServiceClient | None = None) -> CachedContent: - """Fetches required `CachedContent` resource. + request = glm.GetCachedContentRequest(name=name) + response = client.get_cached_content(request) + return cls._decode_cached_content(response) + + @classmethod + def list( + cls, + page_size: Optional[int] = 1, + client: glm.CacheServiceClient | None = None + ) -> Iterable[CachedContent]: + """Lists `CachedContent` objects associated with the project. - Args: - name: name: The resource name referring to the cached content. + Args: + page_size: The maximum number of permissions to return (per page). + The service may return fewer `CachedContent` objects. - Returns: - `CachedContent` resource with specified name. - """ - return CachedContent.get(name=name, client=client) + Returns: + A paginated list of `CachedContent` objects. + """ + if client is None: + client = get_default_cache_client() + request = glm.ListCachedContentsRequest(page_size=page_size) + for cached_content in client.list_cached_contents(request): + yield cls._decode_cached_content(cached_content) -def delete_cached_content(name: str, client: glm.CacheServiceClient | None = None) -> None: - """Deletes `CachedContent` resource. + def delete(self, client: glm.CachedServiceClient | None = None) -> None: + """Deletes `CachedContent` resource. - Args: - name: The resource name referring to the cached content. - Format: cachedContents/{id}. - """ - if client is None: - client = get_default_cache_client() + Args: + name: The resource name referring to the cached content. + Format: cachedContents/{id}. + """ + if client is None: + client = get_default_cache_client() - if "cachedContents/" not in name: - name = "cachedContents/" + name + request = glm.DeleteCachedContentRequest(name=self.name) + client.delete_cached_content(request) + return - request = glm.DeleteCachedContentRequest(name=name) - client.delete_cached_content(request) - return + def update( + self, + updates: dict[str, Any], + client: glm.CacheServiceClient | None = None, + ) -> CachedContent: + """Updates requested `CachedContent` resource. + Args: + updates: The list of fields to update. + Currently only `ttl/expire_time` is supported as an update path. -def list_cached_contents( - page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None -) -> Iterable[CachedContent]: - """Lists `CachedContent` objects associated with the project. + Returns: + `CachedContent` object with specified updates. + """ + if client is None: + client = get_default_cache_client() - Args: - page_size: The maximum number of permissions to return (per page). 
The service may return fewer `CachedContent` objects. + updates = flatten_update_paths(updates) + for update_path in updates: + if update_path == "ttl": + updates = updates.copy() + update_path_val = updates.get(update_path) + updates[update_path] = caching_types.to_ttl(update_path_val) + else: + raise ValueError( + f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead." + ) + field_mask = field_mask_pb2.FieldMask() - Returns: - A paginated list of `CachedContent` objects. - """ - if client is None: - client = get_default_cache_client() + for path in updates.keys(): + field_mask.paths.append(path) + for path, value in updates.items(): + self._apply_update(path, value) - request = glm.ListCachedContentsRequest(page_size=page_size) - for cached_content in client.list_cached_contents(request): - yield caching_types.decode_cached_content(cached_content) + request = glm.UpdateCachedContentRequest( + cached_content=self._to_dict(), update_mask=field_mask + ) + client.update_cached_content(request) + return self diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 89af27607..7101d930e 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -19,8 +19,6 @@ from google.generativeai.types import generation_types from google.generativeai.types import helper_types from google.generativeai.types import safety_types -from google.generativeai.types import caching_types - class GenerativeModel: """ @@ -198,7 +196,7 @@ def from_cached_content( @classmethod def from_cached_content( cls, - cached_content: caching_types.CachedContent, + cached_content: caching.CachedContent, generation_config: generation_types.GenerationConfigType | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, ) -> GenerativeModel: ... @@ -206,7 +204,7 @@ def from_cached_content( @classmethod def from_cached_content( cls, - cached_content: str | caching_types.CachedContent, + cached_content: str | caching.CachedContent, generation_config: generation_types.GenerationConfigType | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, ) -> GenerativeModel: @@ -219,7 +217,7 @@ def from_cached_content( `GenerativeModel` object with `cached_content` as its context. """ if isinstance(cached_content, str): - cached_content = caching.get_cached_content(name=cached_content) + cached_content = caching.CachedContent.get(name=cached_content) # call __new__ with the cached_content to set the model's context. This is done to avoid # the exposing `cached_content` as a public attribute. diff --git a/google/generativeai/types/caching_types.py b/google/generativeai/types/caching_types.py index aa9fbe559..2af39f4f2 100644 --- a/google/generativeai/types/caching_types.py +++ b/google/generativeai/types/caching_types.py @@ -14,22 +14,11 @@ # limitations under the License. 
from __future__ import annotations -import dataclasses import datetime from typing import Optional, Union from typing_extensions import TypedDict -from google.generativeai.types import content_types -from google.generativeai.types.model_types import idecode_time -from google.generativeai.utils import flatten_update_paths -from google.generativeai.client import get_default_cache_client - -from google.protobuf import field_mask_pb2 - -import google.ai.generativelanguage as glm - - -__all__ = ["CachedContent"] +__all__ = ["TTL"] class TTL(TypedDict): @@ -38,8 +27,7 @@ class TTL(TypedDict): ExpirationTypes = Union[TTL, int, datetime.timedelta] - -def _to_ttl(expiration: Optional[ExpirationTypes]) -> TTL: +def to_ttl(expiration: Optional[ExpirationTypes]) -> TTL: if isinstance(expiration, datetime.timedelta): return {"seconds": int(expiration.total_seconds())} elif isinstance(expiration, dict): @@ -51,205 +39,3 @@ def _to_ttl(expiration: Optional[ExpirationTypes]) -> TTL: f"Could not convert input to `expire_time` \n'" f" type: {type(expiration)}\n", expiration, ) - - -def decode_cached_content(cached_content: glm.CachedContent) -> CachedContent: - # not supposed to get INPUT_ONLY repeated fields, but local gapic lib build - # is returning these, hence setting including_default_value_fields to False - cached_content = type(cached_content).to_dict( - cached_content, including_default_value_fields=False - ) - - idecode_time(cached_content, "create_time") - idecode_time(cached_content, "update_time") - # always decode `expire_time` as Timestamp is returned - # regardless of what was sent on input - idecode_time(cached_content, "expire_time") - return CachedContent(**cached_content) - - -@dataclasses.dataclass -class CachedContent: - """Cached content resource.""" - - name: str - model: str - create_time: datetime.datetime - update_time: datetime.datetime - expire_time: datetime.datetime - - # NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+). - # Adding basic support for now. 
- def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, exc_tb): - self.delete() - - def _to_dict(self) -> glm.CachedContent: - proto_paths = { - "name": self.name, - "model": self.model, - } - return glm.CachedContent(**proto_paths) - - def _apply_update(self, path, value): - parts = path.split(".") - for part in parts[:-1]: - self = getattr(self, part) - if parts[-1] == "ttl": - value = self.expire_time + datetime.timedelta(seconds=value["seconds"]) - parts[-1] = "expire_time" - setattr(self, parts[-1], value) - - @staticmethod - def _prepare_create_request( - name: str, - model: str, - system_instruction: Optional[content_types.ContentType] = None, - contents: Optional[content_types.ContentsType] = None, - tools: Optional[content_types.FunctionLibraryType] = None, - tool_config: Optional[content_types.ToolConfigType] = None, - ttl: Optional[ExpirationTypes] = datetime.timedelta(hours=1), - ) -> glm.CreateCachedContentRequest: - """Prepares a CreateCachedContentRequest.""" - if "cachedContents/" not in name: - name = "cachedContents/" + name - - if "/" not in model: - model = "models/" + model - - if system_instruction: - system_instruction = content_types.to_content(system_instruction) - - tools_lib = content_types.to_function_library(tools) - if tools_lib: - tools_lib = tools_lib.to_proto() - - if tool_config: - tool_config = content_types.to_tool_config(tool_config) - - if contents: - contents = content_types.to_contents(contents) - - if ttl: - ttl = _to_ttl(ttl) - - cached_content = glm.CachedContent( - name=name, - model=model, - system_instruction=system_instruction, - contents=contents, - tools=tools_lib, - tool_config=tool_config, - ttl=ttl, - ) - - return glm.CreateCachedContentRequest(cached_content=cached_content) - - @classmethod - def create( - cls, - name: str, - model: str, - system_instruction: Optional[content_types.ContentType] = None, - contents: Optional[content_types.ContentsType] = None, - tools: Optional[content_types.FunctionLibraryType] = None, - tool_config: Optional[content_types.ToolConfigType] = None, - ttl: Optional[ExpirationTypes] = datetime.timedelta(hours=1), - client: glm.CacheServiceClient | None = None, - ) -> CachedContent: - """Creates CachedContent resource. - - Args: - name: The resource name referring to the cached content. - Format: cachedContents/{id}. - model: The name of the `Model` to use for cached content - Format: models/{model}. Cached content resource can be only - used with model it was created for. - system_instruction: Developer set system instruction. - contents: Contents to cache. - tools: A list of `Tools` the model may use to generate response. - tool_config: Config to apply to all tools. - ttl: TTL for cached resource (in seconds). Defaults to 1 hour. - - Returns: - `CachedContent` resource with specified name. 
- """ - if client is None: - client = get_default_cache_client() - - request = cls._prepare_create_request( - name=name, - model=model, - system_instruction=system_instruction, - contents=contents, - tools=tools, - tool_config=tool_config, - ttl=ttl, - ) - - response = client.create_cached_content(request) - return decode_cached_content(response) - - @classmethod - def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent: - """Gets a `CachedContent` resource.""" - if client is None: - client = get_default_cache_client() - - if "cachedContents/" not in name: - name = "cachedContents/" + name - - request = glm.GetCachedContentRequest(name=name) - response = client.get_cached_content(request) - return decode_cached_content(response) - - def delete(self, client: glm.CachedServiceClient | None = None) -> None: - """Deletes a `CachedContent` resource.""" - if client is None: - client = get_default_cache_client() - - request = glm.DeleteCachedContentRequest(name=self.name) - client.delete_cached_content(request) - return - - def update( - self, - updates: dict[str, Any], - client: glm.CacheServiceClient | None = None, - ) -> CachedContent: - """Updates requested `CachedContent` resource. - - Args: - updates: The list of fields to update. - Currently only `ttl/expire_time` is supported as an update path. - - Returns: - `CachedContent` object with specified updates. - """ - if client is None: - client = get_default_cache_client() - - updates = flatten_update_paths(updates) - for update_path in updates: - if update_path == "ttl": - updates = updates.copy() - update_path_val = updates.get(update_path) - updates[update_path] = _to_ttl(update_path_val) - else: - raise ValueError( - f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead." 
- ) - field_mask = field_mask_pb2.FieldMask() - - for path in updates.keys(): - field_mask.paths.append(path) - for path, value in updates.items(): - self._apply_update(path, value) - - request = glm.UpdateCachedContentRequest( - cached_content=self._to_dict(), update_mask=field_mask - ) - client.update_cached_content(request) - return self diff --git a/tests/test_caching.py b/tests/test_caching.py index 916dc407b..db8f5a90d 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -129,7 +129,7 @@ def add(a: int, b: int) -> int: ttl=datetime.timedelta(minutes=30), ) self.assertIsInstance(self.observed_requests[-1], glm.CreateCachedContentRequest) - self.assertIsInstance(cc, caching_types.CachedContent) + self.assertIsInstance(cc, caching.CachedContent) self.assertEqual(cc.name, "cachedContent/test-cached-content") self.assertEqual(cc.model, "models/gemini-1.0-pro-001") @@ -161,21 +161,21 @@ def test_expiration_types_for_create_cached_content(self, ttl): ttl=ttl, ) self.assertIsInstance(self.observed_requests[-1], glm.CreateCachedContentRequest) - self.assertIsInstance(cc, caching_types.CachedContent) + self.assertIsInstance(cc, caching.CachedContent) def test_get_cached_content(self): - cc = caching.get_cached_content(name="cachedContent/test-cached-content") + cc = caching.CachedContent.get(name="cachedContent/test-cached-content") self.assertIsInstance(self.observed_requests[-1], glm.GetCachedContentRequest) - self.assertIsInstance(cc, caching_types.CachedContent) + self.assertIsInstance(cc, caching.CachedContent) self.assertEqual(cc.name, "cachedContent/test-cached-content") self.assertEqual(cc.model, "models/gemini-1.0-pro-001") def test_list_cached_contents(self): - ccs = list(caching.list_cached_contents(page_size=2)) + ccs = list(caching.CachedContent.list(page_size=2)) self.assertIsInstance(self.observed_requests[-1], glm.ListCachedContentsRequest) self.assertLen(ccs, 2) - self.assertIsInstance(ccs[0], caching_types.CachedContent) - self.assertIsInstance(ccs[1], caching_types.CachedContent) + self.assertIsInstance(ccs[0], caching.CachedContent) + self.assertIsInstance(ccs[1], caching.CachedContent) def test_update_cached_content_invalid_update_paths(self): update_masks = dict( @@ -185,7 +185,7 @@ def test_update_cached_content_invalid_update_paths(self): contents=["add this Content"], ) - cc = caching.get_cached_content(name="cachedContent/test-cached-content") + cc = caching.CachedContent.get(name="cachedContent/test-cached-content") with self.assertRaises(ValueError): cc.update(updates=update_masks) @@ -194,13 +194,13 @@ def test_update_cached_content_valid_update_paths(self): ttl=datetime.timedelta(hours=2), ) - cc = caching.get_cached_content(name="cachedContent/test-cached-content") + cc = caching.CachedContent.get(name="cachedContent/test-cached-content") cc = cc.update(updates=update_masks) self.assertIsInstance(self.observed_requests[-1], glm.UpdateCachedContentRequest) - self.assertIsInstance(cc, caching_types.CachedContent) + self.assertIsInstance(cc, caching.CachedContent) def test_delete_cached_content(self): - cc = caching.get_cached_content(name="cachedContent/test-cached-content") + cc = caching.CachedContent.get(name="cachedContent/test-cached-content") cc.delete() self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 78ccc0bce..31444a03a 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -10,9 +10,9 @@ import 
google.ai.generativelanguage as glm from google.generativeai import client as client_lib from google.generativeai import generative_models +from google.generativeai import caching from google.generativeai.types import content_types from google.generativeai.types import generation_types -from google.generativeai.types import caching_types from google.generativeai.types import helper_types import PIL.Image @@ -337,7 +337,7 @@ def test_stream_prompt_feedback_not_blocked(self): dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"), dict( testcase_name="test_cached_content_as_CachedContent_object", - cached_content=caching_types.CachedContent( + cached_content=caching.CachedContent( name="cachedContent/test-cached-content", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", From e65d16e5a8c72683780631a037769c1e00dc6b7d Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Thu, 23 May 2024 23:12:05 +0530 Subject: [PATCH 04/19] blacken --- google/generativeai/caching.py | 10 ++++------ google/generativeai/generative_models.py | 1 + google/generativeai/types/caching_types.py | 1 + 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py index 5ed449ef8..4dbde18bc 100644 --- a/google/generativeai/caching.py +++ b/google/generativeai/caching.py @@ -61,7 +61,7 @@ def _apply_update(self, path, value): value = self.expire_time + datetime.timedelta(seconds=value["seconds"]) parts[-1] = "expire_time" setattr(self, parts[-1], value) - + @classmethod def _decode_cached_content(cls, cached_content: glm.CachedContent) -> CachedContent: # not supposed to get INPUT_ONLY repeated fields, but local gapic lib build @@ -186,17 +186,15 @@ def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedC request = glm.GetCachedContentRequest(name=name) response = client.get_cached_content(request) return cls._decode_cached_content(response) - + @classmethod def list( - cls, - page_size: Optional[int] = 1, - client: glm.CacheServiceClient | None = None + cls, page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None ) -> Iterable[CachedContent]: """Lists `CachedContent` objects associated with the project. Args: - page_size: The maximum number of permissions to return (per page). + page_size: The maximum number of permissions to return (per page). The service may return fewer `CachedContent` objects. 
Returns: diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 7101d930e..ff3d10b30 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -20,6 +20,7 @@ from google.generativeai.types import helper_types from google.generativeai.types import safety_types + class GenerativeModel: """ The `genai.GenerativeModel` class wraps default parameters for calls to diff --git a/google/generativeai/types/caching_types.py b/google/generativeai/types/caching_types.py index 2af39f4f2..753fb2026 100644 --- a/google/generativeai/types/caching_types.py +++ b/google/generativeai/types/caching_types.py @@ -27,6 +27,7 @@ class TTL(TypedDict): ExpirationTypes = Union[TTL, int, datetime.timedelta] + def to_ttl(expiration: Optional[ExpirationTypes]) -> TTL: if isinstance(expiration, datetime.timedelta): return {"seconds": int(expiration.total_seconds())} From d35cc7194a905d2776abcc719eafac3f4c91d512 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Thu, 23 May 2024 23:12:38 +0530 Subject: [PATCH 05/19] Improve tests --- tests/test_generative_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 31444a03a..eca103ee1 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -353,6 +353,7 @@ def test_model_with_cached_content_as_context(self, cached_content): model_name = model.model_name self.assertEqual(cc_name, "cachedContent/test-cached-content") self.assertEqual(model_name, "models/gemini-1.0-pro-001") + self.assertEqual(model.cached_content, "cachedContent/test-cached-content") def test_content_generation_with_model_having_context(self): self.responses["generate_content"] = [simple_response("world!")] @@ -362,6 +363,7 @@ def test_content_generation_with_model_having_context(self): response = model.generate_content("Hello") self.assertEqual(response.text, "world!") + self.assertEqual(model.cached_content, "cachedContent/test-cached-content") def test_fail_content_generation_with_model_having_context(self): model = generative_models.GenerativeModel.from_cached_content( From d862dae543645d13e5cf31512b8306c03dcb3fc1 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Thu, 23 May 2024 18:09:14 +0000 Subject: [PATCH 06/19] fix tests Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7 --- tests/test_generative_models.py | 40 ++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index eca103ee1..db89482ed 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -76,6 +76,20 @@ def count_tokens( self.observed_kwargs.append(kwargs) response = self.responses["count_tokens"].pop(0) return response + + def get_cached_content( + self, + request: glm.GetCachedContentRequest, + **kwargs, + ) -> glm.CachedContent: + self.observed_requests.append(request) + return glm.CachedContent( + name="cachedContent/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ) class CUJTests(parameterized.TestCase): @@ -98,19 +112,19 @@ def setUp(self): client_lib._client_manager.clients["generative"] = self.client client_lib._client_manager.clients["cache"] = self.client - @add_client_method - def get_cached_content( - request: glm.GetCachedContentRequest, - 
**kwargs, - ) -> glm.CachedContent: - self.observed_requests.append(request) - return glm.CachedContent( - name="cachedContent/test-cached-content", - model="models/gemini-1.0-pro-001", - create_time="2000-01-01T01:01:01.123456Z", - update_time="2000-01-01T01:01:01.123456Z", - expire_time="2000-01-01T01:01:01.123456Z", - ) + # @add_client_method + # def get_cached_content( + # request: glm.GetCachedContentRequest, + # **kwargs, + # ) -> glm.CachedContent: + # self.observed_requests.append(request) + # return glm.CachedContent( + # name="cachedContent/test-cached-content", + # model="models/gemini-1.0-pro-001", + # create_time="2000-01-01T01:01:01.123456Z", + # update_time="2000-01-01T01:01:01.123456Z", + # expire_time="2000-01-01T01:01:01.123456Z", + # ) def test_hello(self): # Generate text from text prompt From 2cde1a21ea15c42eceb6778add040eb6d3a69b95 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Thu, 23 May 2024 18:09:14 +0000 Subject: [PATCH 07/19] fix tests Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7 --- tests/test_caching.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_caching.py b/tests/test_caching.py index db8f5a90d..20b675950 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -204,7 +204,8 @@ def test_delete_cached_content(self): cc.delete() self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest) - cc = caching.delete_cached_content(name="cachedContent/test-cached-content") + cc = caching.CachedContent.get(name="cachedContent/test-cached-content") + cc.delete() self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest) def test_auto_delete_cached_content_with_context_manager(self): From e1d8c7ac2785add8b27e4fee8bd7835a98156de7 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Fri, 24 May 2024 10:21:42 +0000 Subject: [PATCH 08/19] Validate name checks for CachedContent creation Change-Id: Ie41602621d99ddff6404c6708c7278e0da790652 --- google/generativeai/caching.py | 4 +++- google/generativeai/types/caching_types.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py index 4dbde18bc..c8dc4b598 100644 --- a/google/generativeai/caching.py +++ b/google/generativeai/caching.py @@ -88,8 +88,10 @@ def _prepare_create_request( ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1), ) -> glm.CreateCachedContentRequest: """Prepares a CreateCachedContentRequest.""" - if "cachedContents/" not in name: + if caching_types.valid_cached_content_name(name): name = "cachedContents/" + name + else: + raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name)) if "/" not in model: model = "models/" + model diff --git a/google/generativeai/types/caching_types.py b/google/generativeai/types/caching_types.py index 753fb2026..8d55b70b2 100644 --- a/google/generativeai/types/caching_types.py +++ b/google/generativeai/types/caching_types.py @@ -17,10 +17,21 @@ import datetime from typing import Optional, Union from typing_extensions import TypedDict +import re __all__ = ["TTL"] +_VALID_CACHED_CONTENT_NAME = r"([a-z0-9-\.]+)$" +NAME_ERROR_MESSAGE = ( + "The `name` must consist of alphanumeric characters (or `-` or `.`). 
Received: `{name}`" +) + + +def valid_cached_content_name(name: str) -> bool: + return re.match(_VALID_CACHED_CONTENT_NAME, name) is not None + + class TTL(TypedDict): seconds: int From 59663c88d6fc3958544fe877d3c71962c15bd865 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Fri, 24 May 2024 10:22:08 +0000 Subject: [PATCH 09/19] Add tests Change-Id: I249188fa585bd9b7193efa48b1cfca20b8a79821 --- tests/test_caching.py | 27 +++++++++++++++++++++++++++ tests/test_generative_models.py | 2 +- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/test_caching.py b/tests/test_caching.py index 20b675950..3ad352d0d 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -163,6 +163,33 @@ def test_expiration_types_for_create_cached_content(self, ttl): self.assertIsInstance(self.observed_requests[-1], glm.CreateCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) + @parameterized.named_parameters( + [ + dict( + testcase_name="upper_case", + name="Test-cached-content", + ), + dict( + testcase_name="special_characters_except_dot_and_hyphen", + name="test-cac*@/hed-conte#nt", + ), + dict( + testcase_name="empty_name", + name="", + ), + dict( + testcase_name="blank_spaces", + name="test cached content", + ), + ] + ) + def test_create_cached_content_with_invalid_name_format(self, name): + with self.assertRaises(ValueError): + _ = caching.CachedContent.create( + name=name, + model="models/gemini-1.0-pro-001", + ) + def test_get_cached_content(self): cc = caching.CachedContent.get(name="cachedContent/test-cached-content") self.assertIsInstance(self.observed_requests[-1], glm.GetCachedContentRequest) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index db89482ed..669107735 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -76,7 +76,7 @@ def count_tokens( self.observed_kwargs.append(kwargs) response = self.responses["count_tokens"].pop(0) return response - + def get_cached_content( self, request: glm.GetCachedContentRequest, From f37df8cc5e3dc5f81603ec013746059ce1abc717 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Sun, 26 May 2024 06:51:54 +0000 Subject: [PATCH 10/19] mark name as OPTIONAL for CachedContent creation If not provided, the name will be randomly generated Change-Id: Ib95fbafd3dfe098b43164d7ee4d6c2a84b0aae2e --- google/generativeai/caching.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py index c8dc4b598..d825db902 100644 --- a/google/generativeai/caching.py +++ b/google/generativeai/caching.py @@ -79,8 +79,8 @@ def _decode_cached_content(cls, cached_content: glm.CachedContent) -> CachedCont @staticmethod def _prepare_create_request( - name: str, model: str, + name: str = None, system_instruction: Optional[content_types.ContentType] = None, contents: Optional[content_types.ContentsType] = None, tools: Optional[content_types.FunctionLibraryType] = None, @@ -88,10 +88,11 @@ def _prepare_create_request( ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1), ) -> glm.CreateCachedContentRequest: """Prepares a CreateCachedContentRequest.""" - if caching_types.valid_cached_content_name(name): + if name is not None: + if not caching_types.valid_cached_content_name(name): + raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name)) + name = "cachedContents/" + name - else: - raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name)) if "/" not in 
model: model = "models/" + model @@ -127,8 +128,8 @@ def _prepare_create_request( @classmethod def create( cls, - name: str, model: str, + name: str = None, system_instruction: Optional[content_types.ContentType] = None, contents: Optional[content_types.ContentsType] = None, tools: Optional[content_types.FunctionLibraryType] = None, @@ -139,11 +140,10 @@ def create( """Creates CachedContent resource. Args: - name: The resource name referring to the cached content. - Format: cachedContents/{id}. model: The name of the `Model` to use for cached content Format: models/{model}. Cached content resource can be only used with model it was created for. + name: The resource name referring to the cached content. system_instruction: Developer set system instruction. contents: Contents to cache. tools: A list of `Tools` the model may use to generate response. @@ -157,8 +157,8 @@ def create( client = get_default_cache_client() request = cls._prepare_create_request( - name=name, model=model, + name=name, system_instruction=system_instruction, contents=contents, tools=tools, From d1fd7496ea09612b6d8df64bd374603589fb62fb Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Mon, 27 May 2024 05:04:43 +0000 Subject: [PATCH 11/19] Add type-annotations to __new__ to fix pytype checks Change-Id: I6c69c036e54d56d18ea60368fa0a1dcda2d315fd --- google/generativeai/generative_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index ff3d10b30..56488a743 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -96,7 +96,7 @@ def __init__( self._client = None self._async_client = None - def __new__(cls, *args, **kwargs): + def __new__(cls, *args, **kwargs) -> GenerativeModel: self = super().__new__(cls) if cached_instance := kwargs.pop("cached_content", None): From 17372e3f118d1126ac32e918aac25975d8f455c4 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Mon, 27 May 2024 05:54:06 +0000 Subject: [PATCH 12/19] Add 'cached_content' to GenerativeModel's repr Change-Id: I06676fad23895e3e1a6393baa938fc1f2df57d80 --- google/generativeai/generative_models.py | 1 + tests/test_generative_models.py | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 56488a743..a460b17f9 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -123,6 +123,7 @@ def maybe_text(content): safety_settings={self._safety_settings}, tools={self._tools}, system_instruction={maybe_text(self._system_instruction)}, + cached_content={getattr(self, "cached_content", None)} )""" ) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 669107735..68b358546 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -1213,6 +1213,7 @@ def test_repr_for_multi_turn_chat(self): safety_settings={}, tools=None, system_instruction=None, + cached_content=None ), history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), glm.Content({'parts': [{'text': 'first'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), glm.Content({'parts': [{'text': 'second'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), 
glm.Content({'parts': [{'text': 'third'}], 'role': 'model'})] )""" @@ -1241,6 +1242,7 @@ def test_repr_for_incomplete_streaming_chat(self): safety_settings={}, tools=None, system_instruction=None, + cached_content=None ), history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" @@ -1285,6 +1287,7 @@ def test_repr_for_broken_streaming_chat(self): safety_settings={}, tools=None, system_instruction=None, + cached_content=None ), history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" @@ -1296,6 +1299,14 @@ def test_repr_for_system_instruction(self): result = repr(model) self.assertIn("system_instruction='Be excellent.'", result) + def test_repr_for_model_created_from_cached_content(self): + model = generative_models.GenerativeModel.from_cached_content( + cached_content="test-cached-content" + ) + result = repr(model) + self.assertIn("cached_content=cachedContent/test-cached-content", result) + self.assertIn("model_name='models/gemini-1.0-pro-001'", result) + def test_count_tokens_called_with_request_options(self): self.responses["count_tokens"].append(glm.CountTokensResponse()) request_options = {"timeout": 120} From 645ceab6d2bd10524edf0edd43f780e4c93c410b Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Mon, 27 May 2024 05:54:26 +0000 Subject: [PATCH 13/19] blacken Change-Id: I4e073d821d29eea30801bdb7e2a8dc01bb7d6b9a --- google/generativeai/caching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py index d825db902..5f249841a 100644 --- a/google/generativeai/caching.py +++ b/google/generativeai/caching.py @@ -91,7 +91,7 @@ def _prepare_create_request( if name is not None: if not caching_types.valid_cached_content_name(name): raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name)) - + name = "cachedContents/" + name if "/" not in model: From f48cedc391982f2442dde08e553303298c61f49c Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Mon, 27 May 2024 11:13:44 +0000 Subject: [PATCH 14/19] Fix types Change-Id: Ia4bf6b936fab4c1992798c65cff91c15e51a92c0 --- google/generativeai/caching.py | 4 ++-- tests/test_caching.py | 24 +++++++++++----------- tests/test_generative_models.py | 35 ++++++++++----------------------- 3 files changed, 24 insertions(+), 39 deletions(-) diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py index 5f249841a..37bd10532 100644 --- a/google/generativeai/caching.py +++ b/google/generativeai/caching.py @@ -80,7 +80,7 @@ def _decode_cached_content(cls, cached_content: glm.CachedContent) -> CachedCont @staticmethod def _prepare_create_request( model: str, - name: str = None, + name: str | None = None, system_instruction: Optional[content_types.ContentType] = None, contents: Optional[content_types.ContentsType] = None, tools: Optional[content_types.FunctionLibraryType] = None, @@ -129,7 +129,7 @@ def _prepare_create_request( def create( cls, model: str, - name: str = None, + name: str | None = None, system_instruction: Optional[content_types.ContentType] = None, contents: Optional[content_types.ContentsType] = None, tools: Optional[content_types.FunctionLibraryType] = None, diff --git a/tests/test_caching.py b/tests/test_caching.py index 3ad352d0d..9489db092 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -49,7 +49,7 @@ def create_cached_content( ) -> glm.CachedContent: self.observed_requests.append(request) return glm.CachedContent( -
name="cachedContent/test-cached-content", + name="cachedContents/test-cached-content", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", @@ -63,7 +63,7 @@ def get_cached_content( ) -> glm.CachedContent: self.observed_requests.append(request) return glm.CachedContent( - name="cachedContent/test-cached-content", + name="cachedContents/test-cached-content", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", @@ -78,14 +78,14 @@ def list_cached_contents( self.observed_requests.append(request) return [ glm.CachedContent( - name="cachedContent/test-cached-content-1", + name="cachedContents/test-cached-content-1", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T01:01:01.123456Z", ), glm.CachedContent( - name="cachedContent/test-cached-content-2", + name="cachedContents/test-cached-content-2", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", @@ -100,7 +100,7 @@ def update_cached_content( ) -> glm.CachedContent: self.observed_requests.append(request) return glm.CachedContent( - name="cachedContent/test-cached-content", + name="cachedContents/test-cached-content", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", @@ -130,7 +130,7 @@ def add(a: int, b: int) -> int: ) self.assertIsInstance(self.observed_requests[-1], glm.CreateCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) - self.assertEqual(cc.name, "cachedContent/test-cached-content") + self.assertEqual(cc.name, "cachedContents/test-cached-content") self.assertEqual(cc.model, "models/gemini-1.0-pro-001") @parameterized.named_parameters( @@ -191,10 +191,10 @@ def test_create_cached_content_with_invalid_name_format(self, name): ) def test_get_cached_content(self): - cc = caching.CachedContent.get(name="cachedContent/test-cached-content") + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") self.assertIsInstance(self.observed_requests[-1], glm.GetCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) - self.assertEqual(cc.name, "cachedContent/test-cached-content") + self.assertEqual(cc.name, "cachedContents/test-cached-content") self.assertEqual(cc.model, "models/gemini-1.0-pro-001") def test_list_cached_contents(self): @@ -212,7 +212,7 @@ def test_update_cached_content_invalid_update_paths(self): contents=["add this Content"], ) - cc = caching.CachedContent.get(name="cachedContent/test-cached-content") + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") with self.assertRaises(ValueError): cc.update(updates=update_masks) @@ -221,17 +221,17 @@ def test_update_cached_content_valid_update_paths(self): ttl=datetime.timedelta(hours=2), ) - cc = caching.CachedContent.get(name="cachedContent/test-cached-content") + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") cc = cc.update(updates=update_masks) self.assertIsInstance(self.observed_requests[-1], glm.UpdateCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) def test_delete_cached_content(self): - cc = caching.CachedContent.get(name="cachedContent/test-cached-content") + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") cc.delete() self.assertIsInstance(self.observed_requests[-1], 
glm.DeleteCachedContentRequest) - cc = caching.CachedContent.get(name="cachedContent/test-cached-content") + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") cc.delete() self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 68b358546..6ca8cdb2e 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -84,7 +84,7 @@ def get_cached_content( ) -> glm.CachedContent: self.observed_requests.append(request) return glm.CachedContent( - name="cachedContent/test-cached-content", + name="cachedContents/test-cached-content", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", @@ -112,19 +112,6 @@ def setUp(self): client_lib._client_manager.clients["generative"] = self.client client_lib._client_manager.clients["cache"] = self.client - # @add_client_method - # def get_cached_content( - # request: glm.GetCachedContentRequest, - # **kwargs, - # ) -> glm.CachedContent: - # self.observed_requests.append(request) - # return glm.CachedContent( - # name="cachedContent/test-cached-content", - # model="models/gemini-1.0-pro-001", - # create_time="2000-01-01T01:01:01.123456Z", - # update_time="2000-01-01T01:01:01.123456Z", - # expire_time="2000-01-01T01:01:01.123456Z", - # ) def test_hello(self): # Generate text from text prompt @@ -351,23 +338,19 @@ def test_stream_prompt_feedback_not_blocked(self): dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"), dict( testcase_name="test_cached_content_as_CachedContent_object", - cached_content=caching.CachedContent( - name="cachedContent/test-cached-content", - model="models/gemini-1.0-pro-001", - create_time="2000-01-01T01:01:01.123456Z", - update_time="2000-01-01T01:01:01.123456Z", - expire_time="2000-01-01T01:01:01.123456Z", - ), + cached_content=caching.CachedContent.get(name="cachedContents/test-cached-content"), ), ], ) def test_model_with_cached_content_as_context(self, cached_content): model = generative_models.GenerativeModel.from_cached_content(cached_content=cached_content) - cc_name = model.cached_content + cc_name = model.cached_content # pytype: disable=attribute-error model_name = model.model_name - self.assertEqual(cc_name, "cachedContent/test-cached-content") + self.assertEqual(cc_name, "cachedContents/test-cached-content") self.assertEqual(model_name, "models/gemini-1.0-pro-001") - self.assertEqual(model.cached_content, "cachedContent/test-cached-content") + self.assertEqual( + model.cached_content, "cachedContents/test-cached-content" + ) # pytype: disable=attribute-error def test_content_generation_with_model_having_context(self): self.responses["generate_content"] = [simple_response("world!")] @@ -377,7 +360,9 @@ def test_content_generation_with_model_having_context(self): response = model.generate_content("Hello") self.assertEqual(response.text, "world!") - self.assertEqual(model.cached_content, "cachedContent/test-cached-content") + self.assertEqual( + model.cached_content, "cachedContents/test-cached-content" + ) # pytype: disable=attribute-error def test_fail_content_generation_with_model_having_context(self): model = generative_models.GenerativeModel.from_cached_content( From a1c8c725540ebe1b3ea486ad1b45ee6836b40ca6 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Mon, 27 May 2024 11:15:15 +0000 Subject: [PATCH 15/19] Fix docstrings Change-Id: 
I6020df4e862a4f1d58462a4cd70876a8448293cf --- google/generativeai/caching.py | 7 +------ tests/test_generative_models.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py index 37bd10532..2e59717d0 100644 --- a/google/generativeai/caching.py +++ b/google/generativeai/caching.py @@ -210,12 +210,7 @@ def list( yield cls._decode_cached_content(cached_content) def delete(self, client: glm.CachedServiceClient | None = None) -> None: - """Deletes `CachedContent` resource. - - Args: - name: The resource name referring to the cached content. - Format: cachedContents/{id}. - """ + """Deletes `CachedContent` resource.""" if client is None: client = get_default_cache_client() diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 6ca8cdb2e..ef7213740 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -1,6 +1,7 @@ import collections from collections.abc import Iterable import copy +import datetime import pathlib from typing import Any import textwrap @@ -112,7 +113,6 @@ def setUp(self): client_lib._client_manager.clients["generative"] = self.client client_lib._client_manager.clients["cache"] = self.client - def test_hello(self): # Generate text from text prompt model = generative_models.GenerativeModel(model_name="gemini-pro") @@ -338,7 +338,13 @@ def test_stream_prompt_feedback_not_blocked(self): dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"), dict( testcase_name="test_cached_content_as_CachedContent_object", - cached_content=caching.CachedContent.get(name="cachedContents/test-cached-content"), + cached_content=caching.CachedContent( + name="cachedContents/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time=datetime.datetime.now(), + update_time=datetime.datetime.now(), + expire_time=datetime.datetime.now(), + ), ), ], ) @@ -1289,7 +1295,7 @@ def test_repr_for_model_created_from_cached_content(self): cached_content="test-cached-content" ) result = repr(model) - self.assertIn("cached_content=cachedContent/test-cached-content", result) + self.assertIn("cached_content=cachedContents/test-cached-content", result) self.assertIn("model_name='models/gemini-1.0-pro-001'", result) From 67472d32bcd1dbbb62972e1ad626efdee30cf0c1 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Mon, 27 May 2024 11:26:03 +0000 Subject: [PATCH 16/19] Fix types Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075 --- tests/test_generative_models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index ef7213740..0c9d5650f 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -355,8 +355,8 @@ def test_model_with_cached_content_as_context(self, cached_content): self.assertEqual(cc_name, "cachedContents/test-cached-content") self.assertEqual(model_name, "models/gemini-1.0-pro-001") self.assertEqual( - model.cached_content, "cachedContents/test-cached-content" - ) # pytype: disable=attribute-error + model.cached_content, "cachedContents/test-cached-content" # pytype: disable=attribute-error + ) def test_content_generation_with_model_having_context(self): self.responses["generate_content"] = [simple_response("world!")] @@ -367,7 +367,7 @@ response = model.generate_content("Hello") self.assertEqual(response.text, "world!")
self.assertEqual( - model.cached_content, "cachedContents/test-cached-content" - ) # pytype: disable=attribute-error + model.cached_content, "cachedContents/test-cached-content" # pytype: disable=attribute-error + ) def test_fail_content_generation_with_model_having_context(self): model = generative_models.GenerativeModel.from_cached_content( From bf6551ac133c50be294788357fb52a318d4d5d4d Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Mon, 27 May 2024 11:26:03 +0000 Subject: [PATCH 17/19] Fix types Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075 --- tests/test_generative_models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 0c9d5650f..b2d95d6d4 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -355,7 +355,8 @@ def test_model_with_cached_content_as_context(self, cached_content): self.assertEqual(cc_name, "cachedContents/test-cached-content") self.assertEqual(model_name, "models/gemini-1.0-pro-001") self.assertEqual( - model.cached_content, "cachedContents/test-cached-content" # pytype: disable=attribute-error + model.cached_content, # pytype: disable=attribute-error + "cachedContents/test-cached-content", ) def test_content_generation_with_model_having_context(self): @@ -367,7 +368,8 @@ def test_content_generation_with_model_having_context(self): self.assertEqual(response.text, "world!") self.assertEqual( - model.cached_content, "cachedContents/test-cached-content" # pytype: disable=attribute-error + model.cached_content, # pytype: disable=attribute-error + "cachedContents/test-cached-content", ) def test_fail_content_generation_with_model_having_context(self): From 8e86ef19f9b9fce9d384e12ff364c4e8bdb0265f Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Thu, 30 May 2024 16:18:22 +0000 Subject: [PATCH 18/19] Refactor for genai.protos module Change-Id: I2f02d2421d7303f0309ec86f05d33c07332c03c1 --- google/generativeai/caching.py | 37 +++++++++++----------- tests/test_caching.py | 54 +++++++++++++++------------------ tests/test_generative_models.py | 6 ++-- 3 files changed, 46 insertions(+), 51 deletions(-) diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py index 2e59717d0..a28a50256 100644 --- a/google/generativeai/caching.py +++ b/google/generativeai/caching.py @@ -18,6 +18,7 @@ import datetime from typing import Any, Iterable, Optional +from google.generativeai import protos from google.generativeai.types.model_types import idecode_time from google.generativeai.types import caching_types from google.generativeai.types import content_types @@ -46,12 +47,12 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_tb): self.delete() - def _to_dict(self) -> glm.CachedContent: + def _to_dict(self) -> protos.CachedContent: proto_paths = { "name": self.name, "model": self.model, } - return glm.CachedContent(**proto_paths) + return protos.CachedContent(**proto_paths) def _apply_update(self, path, value): parts = path.split(".") @@ -63,7 +64,7 @@ def _apply_update(self, path, value): setattr(self, parts[-1], value) @classmethod - def _decode_cached_content(cls, cached_content: glm.CachedContent) -> CachedContent: + def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent: # not supposed to get INPUT_ONLY repeated fields, but local gapic lib build # is returning these, hence setting including_default_value_fields to False cached_content = type(cached_content).to_dict( @@ -86,7 +87,7 @@ 
def _prepare_create_request( tools: Optional[content_types.FunctionLibraryType] = None, tool_config: Optional[content_types.ToolConfigType] = None, ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1), - ) -> glm.CreateCachedContentRequest: + ) -> protos.CreateCachedContentRequest: """Prepares a CreateCachedContentRequest.""" if name is not None: if not caching_types.valid_cached_content_name(name): @@ -113,7 +114,7 @@ def _prepare_create_request( if ttl: ttl = caching_types.to_ttl(ttl) - cached_content = glm.CachedContent( + cached_content = protos.CachedContent( name=name, model=model, system_instruction=system_instruction, @@ -123,7 +124,7 @@ def _prepare_create_request( ttl=ttl, ) - return glm.CreateCachedContentRequest(cached_content=cached_content) + return protos.CreateCachedContentRequest(cached_content=cached_content) @classmethod def create( @@ -137,12 +138,12 @@ def create( ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1), client: glm.CacheServiceClient | None = None, ) -> CachedContent: - """Creates CachedContent resource. + """Creates `CachedContent` resource. Args: - model: The name of the `Model` to use for cached content - Format: models/{model}. Cached content resource can be only - used with model it was created for. + model: The name of the `model` to use for cached content creation. + Any `CachedContent` resource can only be used with the + `model` it was created for. name: The resource name referring to the cached content. system_instruction: Developer set system instruction. contents: Contents to cache. tools: A list of `Tools` the model may use to generate response. @@ -174,10 +175,10 @@ def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedC """Fetches required `CachedContent` resource. Args: - name: name: The resource name referring to the cached content. + name: The resource name referring to the cached content. Returns: - `CachedContent` resource with specified name. + `CachedContent` resource with specified `name`. """ if client is None: client = get_default_cache_client() @@ -185,7 +186,7 @@ def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedC if "cachedContents/" not in name: name = "cachedContents/" + name - request = glm.GetCachedContentRequest(name=name) + request = protos.GetCachedContentRequest(name=name) response = client.get_cached_content(request) return cls._decode_cached_content(response) @@ -205,7 +206,7 @@ def list( if client is None: client = get_default_cache_client() - request = glm.ListCachedContentsRequest(page_size=page_size) + request = protos.ListCachedContentsRequest(page_size=page_size) for cached_content in client.list_cached_contents(request): yield cls._decode_cached_content(cached_content) @@ -214,7 +215,7 @@ def delete(self, client: glm.CachedServiceClient | None = None) -> None: if client is None: client = get_default_cache_client() - request = glm.DeleteCachedContentRequest(name=self.name) + request = protos.DeleteCachedContentRequest(name=self.name) client.delete_cached_content(request) return @@ -226,8 +227,8 @@ def update( """Updates requested `CachedContent` resource. Args: - updates: The list of fields to update. - Currently only `ttl/expire_time` is supported as an update path. + updates: The list of fields to update. Currently only + `ttl/expire_time` is supported as an update path. Returns: `CachedContent` object with specified updates.
@@ -252,7 +253,7 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateCachedContentRequest( + request = protos.UpdateCachedContentRequest( cached_content=self._to_dict(), update_mask=field_mask ) client.update_cached_content(request) diff --git a/tests/test_caching.py b/tests/test_caching.py index 9489db092..47692325b 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,17 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy -import math import datetime -from typing import Any import unittest -import unittest.mock as mock - -import google.ai.generativelanguage as glm from google.generativeai import caching -from google.generativeai.types import caching_types +from google.generativeai import protos from google.generativeai import client from absl.testing import absltest @@ -44,11 +38,11 @@ def add_client_method(f): @add_client_method def create_cached_content( - request: glm.CreateCachedContentRequest, + request: protos.CreateCachedContentRequest, **kwargs, - ) -> glm.CachedContent: + ) -> protos.CachedContent: self.observed_requests.append(request) - return glm.CachedContent( + return protos.CachedContent( name="cachedContents/test-cached-content", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", @@ -58,11 +52,11 @@ def create_cached_content( @add_client_method def get_cached_content( - request: glm.GetCachedContentRequest, + request: protos.GetCachedContentRequest, **kwargs, - ) -> glm.CachedContent: + ) -> protos.CachedContent: self.observed_requests.append(request) - return glm.CachedContent( + return protos.CachedContent( name="cachedContents/test-cached-content", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", @@ -72,19 +66,19 @@ def get_cached_content( @add_client_method def list_cached_contents( - request: glm.ListCachedContentsRequest, + request: protos.ListCachedContentsRequest, **kwargs, - ) -> glm.ListCachedContentsResponse: + ) -> protos.ListCachedContentsResponse: self.observed_requests.append(request) return [ - glm.CachedContent( + protos.CachedContent( name="cachedContents/test-cached-content-1", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T01:01:01.123456Z", ), - glm.CachedContent( + protos.CachedContent( name="cachedContents/test-cached-content-2", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", @@ -95,11 +89,11 @@ def list_cached_contents( @add_client_method def update_cached_content( - request: glm.UpdateCachedContentRequest, + request: protos.UpdateCachedContentRequest, **kwargs, - ) -> glm.CachedContent: + ) -> protos.CachedContent: self.observed_requests.append(request) - return glm.CachedContent( + return protos.CachedContent( name="cachedContents/test-cached-content", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", @@ -109,7 +103,7 @@ def update_cached_content( @add_client_method def delete_cached_content( - request: glm.DeleteCachedContentRequest, + request: protos.DeleteCachedContentRequest, **kwargs, ) -> 
None: self.observed_requests.append(request) @@ -128,7 +122,7 @@ def add(a: int, b: int) -> int: system_instruction="Always add 10 to the result.", ttl=datetime.timedelta(minutes=30), ) - self.assertIsInstance(self.observed_requests[-1], glm.CreateCachedContentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) self.assertEqual(cc.name, "cachedContents/test-cached-content") self.assertEqual(cc.model, "models/gemini-1.0-pro-001") @@ -160,7 +154,7 @@ def test_expiration_types_for_create_cached_content(self, ttl): contents=["cache this please for 2 hours"], ttl=ttl, ) - self.assertIsInstance(self.observed_requests[-1], glm.CreateCachedContentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) @parameterized.named_parameters( @@ -192,14 +186,14 @@ def test_create_cached_content_with_invalid_name_format(self, name): def test_get_cached_content(self): cc = caching.CachedContent.get(name="cachedContents/test-cached-content") - self.assertIsInstance(self.observed_requests[-1], glm.GetCachedContentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) self.assertEqual(cc.name, "cachedContents/test-cached-content") self.assertEqual(cc.model, "models/gemini-1.0-pro-001") def test_list_cached_contents(self): ccs = list(caching.CachedContent.list(page_size=2)) - self.assertIsInstance(self.observed_requests[-1], glm.ListCachedContentsRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListCachedContentsRequest) self.assertLen(ccs, 2) self.assertIsInstance(ccs[0], caching.CachedContent) self.assertIsInstance(ccs[1], caching.CachedContent) @@ -223,17 +217,17 @@ def test_update_cached_content_valid_update_paths(self): cc = caching.CachedContent.get(name="cachedContents/test-cached-content") cc = cc.update(updates=update_masks) - self.assertIsInstance(self.observed_requests[-1], glm.UpdateCachedContentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdateCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) def test_delete_cached_content(self): cc = caching.CachedContent.get(name="cachedContents/test-cached-content") cc.delete() - self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) cc = caching.CachedContent.get(name="cachedContents/test-cached-content") cc.delete() - self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) def test_auto_delete_cached_content_with_context_manager(self): with caching.CachedContent.create( @@ -245,7 +239,7 @@ def test_auto_delete_cached_content_with_context_manager(self): ) as cc: ... 
# some logic - self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) if __name__ == "__main__": diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 9af48b844..73789346d 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -80,11 +80,11 @@ def count_tokens( def get_cached_content( self, - request: glm.GetCachedContentRequest, + request: protos.GetCachedContentRequest, **kwargs, - ) -> glm.CachedContent: + ) -> protos.CachedContent: self.observed_requests.append(request) - return glm.CachedContent( + return protos.CachedContent( name="cachedContents/test-cached-content", model="models/gemini-1.0-pro-001", create_time="2000-01-01T01:01:01.123456Z", From 4627fe1b411dcb1b5e3c7c1d882ce18b8eac73f7 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 4 Jun 2024 09:54:31 -0700 Subject: [PATCH 19/19] use preview build Change-Id: Ic1cd4fc28f591794dc5fbff0647a00a77ea7f601 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6f9545e4f..0575dcd28 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def get_version(): release_status = "Development Status :: 5 - Production/Stable" dependencies = [ - "google-ai-generativelanguage==0.6.4", + "google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz", "google-api-core", "google-api-python-client", "google-auth>=2.15.0", # 2.15 adds API key auth support
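Taken together, this series is intended to enable the end-to-end flow sketched below. The snippet is an illustrative reviewer sketch, not part of any patch: the model name, cache name, system instruction, and cached contents are placeholder assumptions, and it presumes the preview `google-ai-generativelanguage` build pinned in the final patch.

import datetime

import google.generativeai as genai
from google.generativeai import caching

# Cache a long context once. `name` is optional as of PATCH 10; when given,
# it must be lowercase alphanumeric plus `-`/`.` per the PATCH 08 validation.
cc = caching.CachedContent.create(
    model="models/gemini-1.0-pro-001",        # placeholder model name
    name="demo-cache",                        # placeholder cache name
    system_instruction="Answer using the cached transcript.",
    contents=["<long transcript to cache>"],  # placeholder contents
    ttl=datetime.timedelta(hours=1),
)

# Build a model whose context is the cached content. Per the validation in
# _prepare_request, `tools`, `tool_config`, and `system_instruction` cannot
# be re-set on a model instantiated this way.
model = genai.GenerativeModel.from_cached_content(cached_content=cc)
response = model.generate_content("Summarize the transcript.")
print(response.text)
print(model.cached_content)  # e.g. "cachedContents/demo-cache"

# Only `ttl`/`expire_time` are valid update paths; other fields raise ValueError.
cc = cc.update(updates=dict(ttl=datetime.timedelta(hours=2)))
cc.delete()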