Explicit Caching #355

Merged
merged 23 commits into from
Jun 5, 2024
Commits (23)
f13228d
Initial prototype for explicit caching
mayureshagashe2105 Apr 26, 2024
a4ac7a5
Merge branch 'main' into caching
mayureshagashe2105 May 21, 2024
6fafe6b
rename get_cached_content to get
mayureshagashe2105 May 22, 2024
afd066d
Merge branch 'main' into caching
mayureshagashe2105 May 22, 2024
cfc936e
Strike out functional approach for CachedContent CRUD ops
mayureshagashe2105 May 23, 2024
e65d16e
blacken
mayureshagashe2105 May 23, 2024
d35cc71
Improve tests
mayureshagashe2105 May 23, 2024
d862dae
fix tests
mayureshagashe2105 May 23, 2024
2cde1a2
fix tests
mayureshagashe2105 May 23, 2024
e1d8c7a
Validate name checks for CachedContent creation
mayureshagashe2105 May 24, 2024
59663c8
Add tests
mayureshagashe2105 May 24, 2024
f37df8c
mark name as OPTIONAL for CachedContent creation
mayureshagashe2105 May 26, 2024
d1fd749
Add type-annotations to __new__ to fix pytype checks
mayureshagashe2105 May 27, 2024
17372e3
Add 'cached_content' to GenerativeModel's repr
mayureshagashe2105 May 27, 2024
645ceab
blacken
mayureshagashe2105 May 27, 2024
f48cedc
Fix types
mayureshagashe2105 May 27, 2024
a1c8c72
Fix docstrings
mayureshagashe2105 May 27, 2024
67472d3
Fix types
mayureshagashe2105 May 27, 2024
bf6551a
Fix types
mayureshagashe2105 May 27, 2024
82d3c5a
Merge branch 'main' of https://github.com/mayureshagashe2105/generati…
mayureshagashe2105 May 30, 2024
8e86ef1
Refactor for genai.protos module
mayureshagashe2105 May 30, 2024
4627fe1
use preview build
MarkDaoust Jun 4, 2024
fb9995c
Merge branch 'main' into caching
MarkDaoust Jun 4, 2024
260 changes: 260 additions & 0 deletions google/generativeai/caching.py
@@ -0,0 +1,260 @@
# -*- coding: utf-8 -*-
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import dataclasses
import datetime
from typing import Any, Iterable, Optional

from google.generativeai import protos
from google.generativeai.types.model_types import idecode_time
from google.generativeai.types import caching_types
from google.generativeai.types import content_types
from google.generativeai.utils import flatten_update_paths
from google.generativeai.client import get_default_cache_client

from google.protobuf import field_mask_pb2
import google.ai.generativelanguage as glm


@dataclasses.dataclass
class CachedContent:
    """Cached content resource."""

    name: str
    model: str
    create_time: datetime.datetime
    update_time: datetime.datetime
    expire_time: datetime.datetime

    # NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+).
    # Adding basic support for now.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.delete()

    def _to_dict(self) -> protos.CachedContent:
        proto_paths = {
            "name": self.name,
            "model": self.model,
        }
        return protos.CachedContent(**proto_paths)

    def _apply_update(self, path, value):
        parts = path.split(".")
        for part in parts[:-1]:
            self = getattr(self, part)
        if parts[-1] == "ttl":
            # `ttl` is relative to the time of the update, so the new expiry
            # is now + ttl, not the previous `expire_time` + ttl.
            value = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
                seconds=value["seconds"]
            )
            parts[-1] = "expire_time"
        setattr(self, parts[-1], value)

    @classmethod
    def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent:
        # We are not supposed to get INPUT_ONLY repeated fields, but a local gapic
        # lib build returns them, hence `including_default_value_fields=False`.
        cached_content = type(cached_content).to_dict(
            cached_content, including_default_value_fields=False
        )

        idecode_time(cached_content, "create_time")
        idecode_time(cached_content, "update_time")
        # Always decode `expire_time`, as a Timestamp is returned
        # regardless of what was sent on input.
        idecode_time(cached_content, "expire_time")
        return cls(**cached_content)

    @staticmethod
    def _prepare_create_request(
        model: str,
        name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
    ) -> protos.CreateCachedContentRequest:
        """Prepares a `CreateCachedContentRequest`."""
        if name is not None:
            if not caching_types.valid_cached_content_name(name):
                raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name))

            name = "cachedContents/" + name

        if "/" not in model:
            model = "models/" + model

        if system_instruction:
            system_instruction = content_types.to_content(system_instruction)

        tools_lib = content_types.to_function_library(tools)
        if tools_lib:
            tools_lib = tools_lib.to_proto()

        if tool_config:
            tool_config = content_types.to_tool_config(tool_config)

        if contents:
            contents = content_types.to_contents(contents)

        if ttl:
            ttl = caching_types.to_ttl(ttl)

        cached_content = protos.CachedContent(
            name=name,
            model=model,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools_lib,
            tool_config=tool_config,
            ttl=ttl,
        )

        return protos.CreateCachedContentRequest(cached_content=cached_content)

    @classmethod
    def create(
[Review thread]
Collaborator: Async?
Contributor Author: Didn't implement it, for two reasons:
  1. Does it make sense to have async for the cache service?
  2. Anyway, I'll submit it in a separate PR; it's just copy-paste plus async/await :)
        cls,
        model: str,
        name: str | None = None,
        system_instruction: Optional[content_types.ContentType] = None,
        contents: Optional[content_types.ContentsType] = None,
        tools: Optional[content_types.FunctionLibraryType] = None,
        tool_config: Optional[content_types.ToolConfigType] = None,
        ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
        client: glm.CacheServiceClient | None = None,
    ) -> CachedContent:
        """Creates a `CachedContent` resource.

        Args:
            model: The name of the `model` to use for cached content creation.
                Any `CachedContent` resource can only be used with the
                `model` it was created for.
            name: The resource name referring to the cached content.
            system_instruction: Developer-set system instruction.
            contents: Contents to cache.
            tools: A list of `Tools` the model may use to generate a response.
            tool_config: Config to apply to all tools.
            ttl: TTL for the cached resource (in seconds). Defaults to 1 hour.
            client: The client to use; defaults to the default cache client.

        Returns:
            A `CachedContent` resource with the specified name.
        """
        if client is None:
            client = get_default_cache_client()

        request = cls._prepare_create_request(
            model=model,
            name=name,
            system_instruction=system_instruction,
            contents=contents,
            tools=tools,
            tool_config=tool_config,
            ttl=ttl,
        )

        response = client.create_cached_content(request)
        return cls._decode_cached_content(response)

    @classmethod
    def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent:
        """Fetches the requested `CachedContent` resource.

        Args:
            name: The resource name referring to the cached content.
            client: The client to use; defaults to the default cache client.

        Returns:
            The `CachedContent` resource with the specified `name`.
        """
        if client is None:
            client = get_default_cache_client()

        if "cachedContents/" not in name:
            name = "cachedContents/" + name

        request = protos.GetCachedContentRequest(name=name)
        response = client.get_cached_content(request)
        return cls._decode_cached_content(response)

    @classmethod
    def list(
        cls, page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None
    ) -> Iterable[CachedContent]:
        """Lists `CachedContent` objects associated with the project.

        Args:
            page_size: The maximum number of `CachedContent` objects to return
                (per page). The service may return fewer.
            client: The client to use; defaults to the default cache client.

        Returns:
            A paginated list of `CachedContent` objects.
        """
        if client is None:
            client = get_default_cache_client()

        request = protos.ListCachedContentsRequest(page_size=page_size)
        for cached_content in client.list_cached_contents(request):
            yield cls._decode_cached_content(cached_content)

    def delete(self, client: glm.CacheServiceClient | None = None) -> None:
        """Deletes the `CachedContent` resource."""
        if client is None:
            client = get_default_cache_client()

        request = protos.DeleteCachedContentRequest(name=self.name)
        client.delete_cached_content(request)
        return

    def update(
        self,
        updates: dict[str, Any],
        client: glm.CacheServiceClient | None = None,
    ) -> CachedContent:
        """Updates the requested `CachedContent` resource.

        Args:
            updates: A dictionary of fields to update. Currently only
                `ttl`/`expire_time` is supported as an update path.
            client: The client to use; defaults to the default cache client.

        Returns:
            The `CachedContent` object with the specified updates applied.
        """
        if client is None:
            client = get_default_cache_client()

        updates = flatten_update_paths(updates)
        for update_path in updates:
            if update_path == "ttl":
                updates = updates.copy()
                update_path_val = updates.get(update_path)
                updates[update_path] = caching_types.to_ttl(update_path_val)
            else:
                raise ValueError(
                    f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead."
                )
        field_mask = field_mask_pb2.FieldMask()

        for path in updates.keys():
            field_mask.paths.append(path)
        for path, value in updates.items():
            self._apply_update(path, value)

        request = protos.UpdateCachedContentRequest(
            cached_content=self._to_dict(), update_mask=field_mask
        )
        client.update_cached_content(request)
        return self
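Taken together, caching.py gives `CachedContent` a full create/get/list/update/delete surface plus context-manager cleanup. Below is a minimal usage sketch against this API; the model name, cache name, and contents are hypothetical placeholders, and `genai.configure` is assumed to have been called with a valid API key.

    import datetime

    import google.generativeai as genai
    from google.generativeai import caching

    genai.configure(api_key="...")  # assumed: a valid API key

    # Create a cache over some large, shared context (placeholder values).
    cache = caching.CachedContent.create(
        model="models/gemini-1.5-flash-001",  # hypothetical model name
        name="my-context",
        system_instruction="You are an expert code reviewer.",
        contents=["<a large corpus of shared context>"],
        ttl=datetime.timedelta(hours=2),
    )

    # Fetch, list, and update follow the same pattern as the other resources.
    same = caching.CachedContent.get(name="my-context")
    for c in caching.CachedContent.list(page_size=10):
        print(c.name, c.expire_time)
    cache = cache.update(updates={"ttl": datetime.timedelta(hours=4)})

    # The basic context-manager support deletes the cache on exit.
    with caching.CachedContent.get(name="my-context") as cache:
        ...  # use the cache
    # cache.delete() has run here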
4 changes: 4 additions & 0 deletions google/generativeai/client.py
@@ -315,6 +315,10 @@ def configure(
    _client_manager.configure()


def get_default_cache_client() -> glm.CacheServiceClient:
    return _client_manager.get_default_client("cache")


def get_default_discuss_client() -> glm.DiscussServiceClient:
    return _client_manager.get_default_client("discuss")

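Every `CachedContent` method takes an optional `client` and falls back to this new default when it is omitted, matching the pattern of the other service clients. A short sketch of both call styles (assuming `genai.configure` has already run; the no-argument client construction relies on ambient credentials):

    import google.ai.generativelanguage as glm
    from google.generativeai import caching

    # Default: the module-level client manager supplies the cache client.
    cache = caching.CachedContent.get(name="my-context")

    # Explicit injection, e.g. for tests or a custom transport.
    client = glm.CacheServiceClient()  # assumption: default credentials are available
    cache = caching.CachedContent.get(name="my-context", client=client)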
71 changes: 70 additions & 1 deletion google/generativeai/generative_models.py
@@ -4,7 +4,7 @@

from collections.abc import Iterable
import textwrap
from typing import Any
from typing import Any, Union, overload
import reprlib

# pylint: disable=bad-continuation, line-too-long
@@ -13,6 +13,8 @@
import google.api_core.exceptions
from google.generativeai import protos
from google.generativeai import client

from google.generativeai import caching
from google.generativeai.types import content_types
from google.generativeai.types import generation_types
from google.generativeai.types import helper_types
@@ -94,6 +96,15 @@ def __init__(
        self._client = None
        self._async_client = None

    def __new__(cls, *args, **kwargs) -> GenerativeModel:
        self = super().__new__(cls)

        if cached_instance := kwargs.pop("cached_content", None):
            setattr(self, "_cached_content", cached_instance.name)
            setattr(cls, "cached_content", property(fget=lambda self: self._cached_content))

        return self

    @property
    def model_name(self):
        return self._model_name
@@ -112,6 +123,7 @@ def maybe_text(content):
                safety_settings={self._safety_settings},
                tools={self._tools},
                system_instruction={maybe_text(self._system_instruction)},
                cached_content={getattr(self, "cached_content", None)}
            )"""
        )

@@ -127,6 +139,13 @@ def _prepare_request(
        tool_config: content_types.ToolConfigType | None,
    ) -> protos.GenerateContentRequest:
        """Creates a `protos.GenerateContentRequest` from raw inputs."""
        if hasattr(self, "cached_content") and any([self._system_instruction, tools, tool_config]):
            raise ValueError(
                "`tools`, `tool_config`, `system_instruction` cannot be set on a model instantiated with `cached_content` as its context."
            )

        cached_content = getattr(self, "cached_content", None)

        tools_lib = self._get_tools_lib(tools)
        if tools_lib is not None:
            tools_lib = tools_lib.to_proto()
@@ -155,6 +174,7 @@
            tools=tools_lib,
            tool_config=tool_config,
            system_instruction=self._system_instruction,
            cached_content=cached_content,
        )

    def _get_tools_lib(
@@ -165,6 +185,55 @@
        else:
            return content_types.to_function_library(tools)

    @overload
    @classmethod
    def from_cached_content(
        cls,
        cached_content: str,
        generation_config: generation_types.GenerationConfigType | None = None,
        safety_settings: safety_types.SafetySettingOptions | None = None,
    ) -> GenerativeModel: ...

    @overload
    @classmethod
    def from_cached_content(
        cls,
        cached_content: caching.CachedContent,
        generation_config: generation_types.GenerationConfigType | None = None,
        safety_settings: safety_types.SafetySettingOptions | None = None,
    ) -> GenerativeModel: ...

    @classmethod
    def from_cached_content(
        cls,
        cached_content: str | caching.CachedContent,
        generation_config: generation_types.GenerationConfigType | None = None,
        safety_settings: safety_types.SafetySettingOptions | None = None,
    ) -> GenerativeModel:
        """Creates a model with `cached_content` as the model's context.

        Args:
            cached_content: Context for the model, either a resource name or a
                `CachedContent` object.

        Returns:
            A `GenerativeModel` object with `cached_content` as its context.
        """
        if isinstance(cached_content, str):
            cached_content = caching.CachedContent.get(name=cached_content)

        # Call __new__ with the cached_content to set the model's context. This is
        # done to avoid exposing `cached_content` as a public attribute.
        self = cls.__new__(cls, cached_content=cached_content)

        # Call __init__ to set the model's `generation_config` and `safety_settings`.
        # `model_name` will be the name of the model the `cached_content` was created for.
        self.__init__(
            model_name=cached_content.model,
            generation_config=generation_config,
            safety_settings=safety_settings,
        )
        return self

    def generate_content(
        self,
        contents: content_types.ContentsType,
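End to end, the two halves of this PR compose as follows. A hedged sketch (the cache name is a placeholder from the earlier example): `from_cached_content` accepts either a resource name or a `CachedContent` object, and the resulting model rejects `tools`, `tool_config`, and `system_instruction`, since those live on the cache.

    import google.generativeai as genai
    from google.generativeai import caching

    cache = caching.CachedContent.get(name="my-context")  # placeholder name

    # Accepts the resource name or the CachedContent object itself.
    model = genai.GenerativeModel.from_cached_content(cached_content=cache)
    response = model.generate_content("Summarize the cached corpus.")
    print(response.text)

    # Passing tools / tool_config / system_instruction here would raise a
    # ValueError in _prepare_request, since they belong to the cache.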