diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py
new file mode 100644
index 000000000..a28a50256
--- /dev/null
+++ b/google/generativeai/caching.py
@@ -0,0 +1,260 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import dataclasses
+import datetime
+from typing import Any, Iterable, Optional
+
+from google.generativeai import protos
+from google.generativeai.types.model_types import idecode_time
+from google.generativeai.types import caching_types
+from google.generativeai.types import content_types
+from google.generativeai.utils import flatten_update_paths
+from google.generativeai.client import get_default_cache_client
+
+from google.protobuf import field_mask_pb2
+import google.ai.generativelanguage as glm
+
+
+@dataclasses.dataclass
+class CachedContent:
+    """Cached content resource."""
+
+    name: str
+    model: str
+    create_time: datetime.datetime
+    update_time: datetime.datetime
+    expire_time: datetime.datetime
+
+    # NOTE: Automatic `CachedContent` deletion via a context manager is not a P0 (P1+).
+    # Adding basic support for now.
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        self.delete()
+
+    def _to_dict(self) -> protos.CachedContent:
+        proto_paths = {
+            "name": self.name,
+            "model": self.model,
+        }
+        return protos.CachedContent(**proto_paths)
+
+    def _apply_update(self, path, value):
+        parts = path.split(".")
+        for part in parts[:-1]:
+            self = getattr(self, part)
+        if parts[-1] == "ttl":
+            value = self.expire_time + datetime.timedelta(seconds=value["seconds"])
+            parts[-1] = "expire_time"
+        setattr(self, parts[-1], value)
+
+    @classmethod
+    def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent:
+        # We are not supposed to get INPUT_ONLY repeated fields back, but the local gapic
+        # lib build returns them, hence `including_default_value_fields=False`.
+        cached_content = type(cached_content).to_dict(
+            cached_content, including_default_value_fields=False
+        )
+
+        idecode_time(cached_content, "create_time")
+        idecode_time(cached_content, "update_time")
+        # Always decode `expire_time`: a Timestamp is returned
+        # regardless of what was sent on input.
+        idecode_time(cached_content, "expire_time")
+        return cls(**cached_content)
+
+    @staticmethod
+    def _prepare_create_request(
+        model: str,
+        name: str | None = None,
+        system_instruction: Optional[content_types.ContentType] = None,
+        contents: Optional[content_types.ContentsType] = None,
+        tools: Optional[content_types.FunctionLibraryType] = None,
+        tool_config: Optional[content_types.ToolConfigType] = None,
+        ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
+    ) -> protos.CreateCachedContentRequest:
+        """Prepares a `CreateCachedContentRequest`."""
+        if name is not None:
+            if not caching_types.valid_cached_content_name(name):
+                raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name))
+
+            name = "cachedContents/" + name
+
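+        # Normalize a bare model name (e.g. "gemini-1.0-pro-001") to its
+        # fully-qualified resource form ("models/gemini-1.0-pro-001").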
if "/" not in model: + model = "models/" + model + + if system_instruction: + system_instruction = content_types.to_content(system_instruction) + + tools_lib = content_types.to_function_library(tools) + if tools_lib: + tools_lib = tools_lib.to_proto() + + if tool_config: + tool_config = content_types.to_tool_config(tool_config) + + if contents: + contents = content_types.to_contents(contents) + + if ttl: + ttl = caching_types.to_ttl(ttl) + + cached_content = protos.CachedContent( + name=name, + model=model, + system_instruction=system_instruction, + contents=contents, + tools=tools_lib, + tool_config=tool_config, + ttl=ttl, + ) + + return protos.CreateCachedContentRequest(cached_content=cached_content) + + @classmethod + def create( + cls, + model: str, + name: str | None = None, + system_instruction: Optional[content_types.ContentType] = None, + contents: Optional[content_types.ContentsType] = None, + tools: Optional[content_types.FunctionLibraryType] = None, + tool_config: Optional[content_types.ToolConfigType] = None, + ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1), + client: glm.CacheServiceClient | None = None, + ) -> CachedContent: + """Creates `CachedContent` resource. + + Args: + model: The name of the `model` to use for cached content creation. + Any `CachedContent` resource can be only used with the + `model` it was created for. + name: The resource name referring to the cached content. + system_instruction: Developer set system instruction. + contents: Contents to cache. + tools: A list of `Tools` the model may use to generate response. + tool_config: Config to apply to all tools. + ttl: TTL for cached resource (in seconds). Defaults to 1 hour. + + Returns: + `CachedContent` resource with specified name. + """ + if client is None: + client = get_default_cache_client() + + request = cls._prepare_create_request( + model=model, + name=name, + system_instruction=system_instruction, + contents=contents, + tools=tools, + tool_config=tool_config, + ttl=ttl, + ) + + response = client.create_cached_content(request) + return cls._decode_cached_content(response) + + @classmethod + def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent: + """Fetches required `CachedContent` resource. + + Args: + name: The resource name referring to the cached content. + + Returns: + `CachedContent` resource with specified `name`. + """ + if client is None: + client = get_default_cache_client() + + if "cachedContents/" not in name: + name = "cachedContents/" + name + + request = protos.GetCachedContentRequest(name=name) + response = client.get_cached_content(request) + return cls._decode_cached_content(response) + + @classmethod + def list( + cls, page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None + ) -> Iterable[CachedContent]: + """Lists `CachedContent` objects associated with the project. + + Args: + page_size: The maximum number of permissions to return (per page). + The service may return fewer `CachedContent` objects. + + Returns: + A paginated list of `CachedContent` objects. 
+ """ + if client is None: + client = get_default_cache_client() + + request = protos.ListCachedContentsRequest(page_size=page_size) + for cached_content in client.list_cached_contents(request): + yield cls._decode_cached_content(cached_content) + + def delete(self, client: glm.CachedServiceClient | None = None) -> None: + """Deletes `CachedContent` resource.""" + if client is None: + client = get_default_cache_client() + + request = protos.DeleteCachedContentRequest(name=self.name) + client.delete_cached_content(request) + return + + def update( + self, + updates: dict[str, Any], + client: glm.CacheServiceClient | None = None, + ) -> CachedContent: + """Updates requested `CachedContent` resource. + + Args: + updates: The list of fields to update. Currently only + `ttl/expire_time` is supported as an update path. + + Returns: + `CachedContent` object with specified updates. + """ + if client is None: + client = get_default_cache_client() + + updates = flatten_update_paths(updates) + for update_path in updates: + if update_path == "ttl": + updates = updates.copy() + update_path_val = updates.get(update_path) + updates[update_path] = caching_types.to_ttl(update_path_val) + else: + raise ValueError( + f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead." + ) + field_mask = field_mask_pb2.FieldMask() + + for path in updates.keys(): + field_mask.paths.append(path) + for path, value in updates.items(): + self._apply_update(path, value) + + request = protos.UpdateCachedContentRequest( + cached_content=self._to_dict(), update_mask=field_mask + ) + client.update_cached_content(request) + return self diff --git a/google/generativeai/client.py b/google/generativeai/client.py index 40c2bdcaf..7012ecc7c 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -315,6 +315,10 @@ def configure( _client_manager.configure() +def get_default_cache_client() -> glm.CacheServiceClient: + return _client_manager.get_default_client("cache") + + def get_default_discuss_client() -> glm.DiscussServiceClient: return _client_manager.get_default_client("discuss") diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 7d69ae8f9..10744a948 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -4,7 +4,7 @@ from collections.abc import Iterable import textwrap -from typing import Any +from typing import Any, Union, overload import reprlib # pylint: disable=bad-continuation, line-too-long @@ -13,6 +13,8 @@ import google.api_core.exceptions from google.generativeai import protos from google.generativeai import client + +from google.generativeai import caching from google.generativeai.types import content_types from google.generativeai.types import generation_types from google.generativeai.types import helper_types @@ -94,6 +96,15 @@ def __init__( self._client = None self._async_client = None + def __new__(cls, *args, **kwargs) -> GenerativeModel: + self = super().__new__(cls) + + if cached_instance := kwargs.pop("cached_content", None): + setattr(self, "_cached_content", cached_instance.name) + setattr(cls, "cached_content", property(fget=lambda self: self._cached_content)) + + return self + @property def model_name(self): return self._model_name @@ -112,6 +123,7 @@ def maybe_text(content): safety_settings={self._safety_settings}, tools={self._tools}, system_instruction={maybe_text(self._system_instruction)}, + cached_content={getattr(self, "cached_content", None)} )""" 
         )
 
@@ -127,6 +139,13 @@ def _prepare_request(
         tool_config: content_types.ToolConfigType | None,
     ) -> protos.GenerateContentRequest:
         """Creates a `protos.GenerateContentRequest` from raw inputs."""
+        if hasattr(self, "cached_content") and any([self._system_instruction, tools, tool_config]):
+            raise ValueError(
+                "`tools`, `tool_config`, `system_instruction` cannot be set on a model instantiated with `cached_content` as its context."
+            )
+
+        cached_content = getattr(self, "cached_content", None)
+
         tools_lib = self._get_tools_lib(tools)
         if tools_lib is not None:
             tools_lib = tools_lib.to_proto()
@@ -155,6 +174,7 @@ def _prepare_request(
             tools=tools_lib,
             tool_config=tool_config,
             system_instruction=self._system_instruction,
+            cached_content=cached_content,
         )
 
     def _get_tools_lib(
@@ -165,6 +185,55 @@ def _get_tools_lib(
         else:
             return content_types.to_function_library(tools)
 
+    @overload
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: str,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel: ...
+
+    @overload
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: caching.CachedContent,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel: ...
+
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: str | caching.CachedContent,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel:
+        """Creates a model with `cached_content` as the model's context.
+
+        Args:
+            cached_content: Context for the model.
+
+        Returns:
+            `GenerativeModel` object with `cached_content` as its context.
+        """
+        if isinstance(cached_content, str):
+            cached_content = caching.CachedContent.get(name=cached_content)
+
+        # Call __new__ with the cached_content to set the model's context. This is done to
+        # avoid exposing `cached_content` as a public attribute.
+        self = cls.__new__(cls, cached_content=cached_content)
+
+        # Call __init__ to set the model's `generation_config` and `safety_settings`.
+        # `model_name` will be the name of the model for which the `cached_content` was created.
+        self.__init__(
+            model_name=cached_content.model,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
+        )
+        return self
+
     def generate_content(
         self,
         contents: content_types.ContentsType,
diff --git a/google/generativeai/types/caching_types.py b/google/generativeai/types/caching_types.py
new file mode 100644
index 000000000..8d55b70b2
--- /dev/null
+++ b/google/generativeai/types/caching_types.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
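+"""Type helpers for `CachedContent` expiration (`ttl`) values."""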
+from __future__ import annotations
+
+import datetime
+from typing import Optional, Union
+from typing_extensions import TypedDict
+import re
+
+__all__ = ["TTL"]
+
+
+_VALID_CACHED_CONTENT_NAME = r"([a-z0-9-\.]+)$"
+NAME_ERROR_MESSAGE = (
+    "The `name` must consist of lowercase alphanumeric characters (or `-` or `.`). Received: `{name}`"
+)
+
+
+def valid_cached_content_name(name: str) -> bool:
+    return re.match(_VALID_CACHED_CONTENT_NAME, name) is not None
+
+
+class TTL(TypedDict):
+    seconds: int
+
+
+ExpirationTypes = Union[TTL, int, datetime.timedelta]
+
+
+def to_ttl(expiration: Optional[ExpirationTypes]) -> TTL:
+    if isinstance(expiration, datetime.timedelta):
+        return {"seconds": int(expiration.total_seconds())}
+    elif isinstance(expiration, dict):
+        return expiration
+    elif isinstance(expiration, int):
+        return {"seconds": expiration}
+    else:
+        raise TypeError(
+            f"Could not convert input to `ttl`.\n  type: {type(expiration)}",
+            expiration,
+        )
diff --git a/setup.py b/setup.py
index 6f9545e4f..0575dcd28 100644
--- a/setup.py
+++ b/setup.py
@@ -42,7 +42,7 @@ def get_version():
 release_status = "Development Status :: 5 - Production/Stable"
 
 dependencies = [
-    "google-ai-generativelanguage==0.6.4",
+    "google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz",
     "google-api-core",
     "google-api-python-client",
     "google-auth>=2.15.0",  # 2.15 adds API key auth support
diff --git a/tests/test_caching.py b/tests/test_caching.py
new file mode 100644
index 000000000..47692325b
--- /dev/null
+++ b/tests/test_caching.py
@@ -0,0 +1,246 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
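+"""Tests for `google.generativeai.caching`, using a mocked `CacheServiceClient`."""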
+import datetime
+import unittest.mock
+
+from google.generativeai import caching
+from google.generativeai import protos
+
+from google.generativeai import client
+from absl.testing import absltest
+from absl.testing import parameterized
+
+
+class UnitTests(parameterized.TestCase):
+    def setUp(self):
+        self.client = unittest.mock.MagicMock()
+
+        client._client_manager.clients["cache"] = self.client
+
+        self.observed_requests = []
+
+        def add_client_method(f):
+            name = f.__name__
+            setattr(self.client, name, f)
+            return f
+
+        @add_client_method
+        def create_cached_content(
+            request: protos.CreateCachedContentRequest,
+            **kwargs,
+        ) -> protos.CachedContent:
+            self.observed_requests.append(request)
+            return protos.CachedContent(
+                name="cachedContents/test-cached-content",
+                model="models/gemini-1.0-pro-001",
+                create_time="2000-01-01T01:01:01.123456Z",
+                update_time="2000-01-01T01:01:01.123456Z",
+                expire_time="2000-01-01T01:01:01.123456Z",
+            )
+
+        @add_client_method
+        def get_cached_content(
+            request: protos.GetCachedContentRequest,
+            **kwargs,
+        ) -> protos.CachedContent:
+            self.observed_requests.append(request)
+            return protos.CachedContent(
+                name="cachedContents/test-cached-content",
+                model="models/gemini-1.0-pro-001",
+                create_time="2000-01-01T01:01:01.123456Z",
+                update_time="2000-01-01T01:01:01.123456Z",
+                expire_time="2000-01-01T01:01:01.123456Z",
+            )
+
+        @add_client_method
+        def list_cached_contents(
+            request: protos.ListCachedContentsRequest,
+            **kwargs,
+        ) -> protos.ListCachedContentsResponse:
+            self.observed_requests.append(request)
+            return [
+                protos.CachedContent(
+                    name="cachedContents/test-cached-content-1",
+                    model="models/gemini-1.0-pro-001",
+                    create_time="2000-01-01T01:01:01.123456Z",
+                    update_time="2000-01-01T01:01:01.123456Z",
+                    expire_time="2000-01-01T01:01:01.123456Z",
+                ),
+                protos.CachedContent(
+                    name="cachedContents/test-cached-content-2",
+                    model="models/gemini-1.0-pro-001",
+                    create_time="2000-01-01T01:01:01.123456Z",
+                    update_time="2000-01-01T01:01:01.123456Z",
+                    expire_time="2000-01-01T01:01:01.123456Z",
+                ),
+            ]
+
+        @add_client_method
+        def update_cached_content(
+            request: protos.UpdateCachedContentRequest,
+            **kwargs,
+        ) -> protos.CachedContent:
+            self.observed_requests.append(request)
+            return protos.CachedContent(
+                name="cachedContents/test-cached-content",
+                model="models/gemini-1.0-pro-001",
+                create_time="2000-01-01T01:01:01.123456Z",
+                update_time="2000-01-01T01:01:01.123456Z",
+                expire_time="2000-01-01T03:01:01.123456Z",
+            )
+
+        @add_client_method
+        def delete_cached_content(
+            request: protos.DeleteCachedContentRequest,
+            **kwargs,
+        ) -> None:
+            self.observed_requests.append(request)
+
+    def test_create_cached_content(self):
+        def add(a: int, b: int) -> int:
+            return a + b
+
+        cc = caching.CachedContent.create(
+            name="test-cached-content",
+            model="models/gemini-1.0-pro-001",
+            contents=["Add 5 and 6"],
+            tools=[add],
+            tool_config={"function_calling_config": "ANY"},
+            system_instruction="Always add 10 to the result.",
+            ttl=datetime.timedelta(minutes=30),
+        )
+        self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest)
+        self.assertIsInstance(cc, caching.CachedContent)
+        self.assertEqual(cc.name, "cachedContents/test-cached-content")
+        self.assertEqual(cc.model, "models/gemini-1.0-pro-001")
+
+    @parameterized.named_parameters(
+        [
+            dict(
+                testcase_name="ttl-is-int-seconds",
+                ttl=7200,
+            ),
+            dict(
+                testcase_name="ttl-is-timedelta",
+                ttl=datetime.timedelta(hours=2),
+            ),
+            dict(
testcase_name="ttl-is-dict", + ttl={"seconds": 7200}, + ), + dict( + testcase_name="ttl-is-none-default-to-1-hr", + ttl=None, + ), + ] + ) + def test_expiration_types_for_create_cached_content(self, ttl): + cc = caching.CachedContent.create( + name="test-cached-content", + model="models/gemini-1.0-pro-001", + contents=["cache this please for 2 hours"], + ttl=ttl, + ) + self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + + @parameterized.named_parameters( + [ + dict( + testcase_name="upper_case", + name="Test-cached-content", + ), + dict( + testcase_name="special_characters_except_dot_and_hyphen", + name="test-cac*@/hed-conte#nt", + ), + dict( + testcase_name="empty_name", + name="", + ), + dict( + testcase_name="blank_spaces", + name="test cached content", + ), + ] + ) + def test_create_cached_content_with_invalid_name_format(self, name): + with self.assertRaises(ValueError): + _ = caching.CachedContent.create( + name=name, + model="models/gemini-1.0-pro-001", + ) + + def test_get_cached_content(self): + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + self.assertIsInstance(self.observed_requests[-1], protos.GetCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + self.assertEqual(cc.name, "cachedContents/test-cached-content") + self.assertEqual(cc.model, "models/gemini-1.0-pro-001") + + def test_list_cached_contents(self): + ccs = list(caching.CachedContent.list(page_size=2)) + self.assertIsInstance(self.observed_requests[-1], protos.ListCachedContentsRequest) + self.assertLen(ccs, 2) + self.assertIsInstance(ccs[0], caching.CachedContent) + self.assertIsInstance(ccs[1], caching.CachedContent) + + def test_update_cached_content_invalid_update_paths(self): + update_masks = dict( + name="change", + model="models/gemini-1.5-pro-001", + system_instruction="Always add 10 to the result.", + contents=["add this Content"], + ) + + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + with self.assertRaises(ValueError): + cc.update(updates=update_masks) + + def test_update_cached_content_valid_update_paths(self): + update_masks = dict( + ttl=datetime.timedelta(hours=2), + ) + + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + cc = cc.update(updates=update_masks) + self.assertIsInstance(self.observed_requests[-1], protos.UpdateCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + + def test_delete_cached_content(self): + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + cc.delete() + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) + + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + cc.delete() + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) + + def test_auto_delete_cached_content_with_context_manager(self): + with caching.CachedContent.create( + name="test-cached-content", + model="models/gemini-1.0-pro-001", + contents=["Add 5 and 6"], + system_instruction="Always add 10 to the result.", + ttl=datetime.timedelta(minutes=30), + ) as cc: + ... 
+            ...  # some logic
+
+        self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest)
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py
index 0ece77e94..73789346d 100644
--- a/tests/test_generative_models.py
+++ b/tests/test_generative_models.py
@@ -1,6 +1,7 @@
 import collections
 from collections.abc import Iterable
 import copy
+import datetime
 import pathlib
 from typing import Any
 import textwrap
@@ -10,11 +11,11 @@
 from google.generativeai import protos
 from google.generativeai import client as client_lib
 from google.generativeai import generative_models
+from google.generativeai import caching
 from google.generativeai.types import content_types
 from google.generativeai.types import generation_types
 from google.generativeai.types import helper_types
-
 import PIL.Image
 
 HERE = pathlib.Path(__file__).parent
@@ -77,6 +78,20 @@ def count_tokens(
         response = self.responses["count_tokens"].pop(0)
         return response
 
+    def get_cached_content(
+        self,
+        request: protos.GetCachedContentRequest,
+        **kwargs,
+    ) -> protos.CachedContent:
+        self.observed_requests.append(request)
+        return protos.CachedContent(
+            name="cachedContents/test-cached-content",
+            model="models/gemini-1.0-pro-001",
+            create_time="2000-01-01T01:01:01.123456Z",
+            update_time="2000-01-01T01:01:01.123456Z",
+            expire_time="2000-01-01T01:01:01.123456Z",
+        )
+
 
 class CUJTests(parameterized.TestCase):
     """Tests are in order with the design doc."""
@@ -96,6 +111,7 @@ def responses(self):
     def setUp(self):
         self.client = MockGenerativeServiceClient(self)
         client_lib._client_manager.clients["generative"] = self.client
+        client_lib._client_manager.clients["cache"] = self.client
 
     def test_hello(self):
         # Generate text from text prompt
@@ -317,6 +333,56 @@ def test_stream_prompt_feedback_not_blocked(self):
         text = "".join(chunk.text for chunk in response)
         self.assertEqual(text, "first second")
 
+    @parameterized.named_parameters(
+        [
+            dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"),
+            dict(
+                testcase_name="test_cached_content_as_CachedContent_object",
+                cached_content=caching.CachedContent(
+                    name="cachedContents/test-cached-content",
+                    model="models/gemini-1.0-pro-001",
+                    create_time=datetime.datetime.now(),
+                    update_time=datetime.datetime.now(),
+                    expire_time=datetime.datetime.now(),
+                ),
+            ),
+        ],
+    )
+    def test_model_with_cached_content_as_context(self, cached_content):
+        model = generative_models.GenerativeModel.from_cached_content(cached_content=cached_content)
+        cc_name = model.cached_content  # pytype: disable=attribute-error
+        model_name = model.model_name
+        self.assertEqual(cc_name, "cachedContents/test-cached-content")
+        self.assertEqual(model_name, "models/gemini-1.0-pro-001")
+
+    def test_content_generation_with_model_having_context(self):
+        self.responses["generate_content"] = [simple_response("world!")]
+        model = generative_models.GenerativeModel.from_cached_content(
+            cached_content="test-cached-content"
+        )
+        response = model.generate_content("Hello")
+
+        self.assertEqual(response.text, "world!")
+        self.assertEqual(
+            model.cached_content,  # pytype: disable=attribute-error
+            "cachedContents/test-cached-content",
+        )
+
+    def test_fail_content_generation_with_model_having_context(self):
+        model = generative_models.GenerativeModel.from_cached_content(
+            cached_content="test-cached-content"
+        )
+
+        def add(a: int, b: int) -> int:
+            return a + b
+
+        with self.assertRaises(ValueError):
+            model.generate_content("Hello", tools=[add])
+
     def test_chat(self):
         # Multi turn chat
         model = generative_models.GenerativeModel("gemini-pro")
@@ -1140,6 +1206,7 @@ def test_repr_for_multi_turn_chat(self):
                 safety_settings={},
                 tools=None,
                 system_instruction=None,
+                cached_content=None
             ),
             history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'first'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), protos.Content({'parts': [{'text': 'second'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'third'}], 'role': 'model'})]
         )"""
@@ -1168,6 +1235,7 @@ def test_repr_for_incomplete_streaming_chat(self):
                 safety_settings={},
                 tools=None,
                 system_instruction=None,
+                cached_content=None
             ),
             history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ]
         )"""
@@ -1212,6 +1280,7 @@ def test_repr_for_broken_streaming_chat(self):
                 safety_settings={},
                 tools=None,
                 system_instruction=None,
+                cached_content=None
             ),
             history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ]
         )"""
@@ -1223,6 +1292,14 @@ def test_repr_for_system_instruction(self):
         result = repr(model)
         self.assertIn("system_instruction='Be excellent.'", result)
 
+    def test_repr_for_model_created_from_cached_content(self):
+        model = generative_models.GenerativeModel.from_cached_content(
+            cached_content="test-cached-content"
+        )
+        result = repr(model)
+        self.assertIn("cached_content=cachedContents/test-cached-content", result)
+        self.assertIn("model_name='models/gemini-1.0-pro-001'", result)
+
     def test_count_tokens_called_with_request_options(self):
         self.responses["count_tokens"].append(protos.CountTokensResponse(total_tokens=7))
         request_options = {"timeout": 120}
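
Illustrative end-to-end usage of the API added by this diff (a sketch, not part of the change itself; the cache name and document text are made up, and the model name is taken from the tests):

    import datetime
    import google.generativeai as genai
    from google.generativeai import caching

    genai.configure(api_key=...)  # assumes an API key is available

    # Cache a large shared context once...
    cc = caching.CachedContent.create(
        model="models/gemini-1.0-pro-001",
        name="my-document-cache",
        contents=["<large document text>"],
        ttl=datetime.timedelta(hours=2),
    )

    # ...then reuse it across requests, and clean up when done.
    model = genai.GenerativeModel.from_cached_content(cached_content=cc)
    response = model.generate_content("Summarize the cached document.")
    cc.delete()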