Skip to content

Commit

Permalink
Fix types
Browse files Browse the repository at this point in the history
Change-Id: Ia4bf6b936fab4c1992798c65cff91c15e51a92c0
  • Loading branch information
mayureshagashe2105 committed May 27, 2024
1 parent 645ceab commit f48cedc
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 39 deletions.
4 changes: 2 additions & 2 deletions google/generativeai/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _decode_cached_content(cls, cached_content: glm.CachedContent) -> CachedCont
@staticmethod
def _prepare_create_request(
model: str,
name: str = None,
name: str | None = None,
system_instruction: Optional[content_types.ContentType] = None,
contents: Optional[content_types.ContentsType] = None,
tools: Optional[content_types.FunctionLibraryType] = None,
Expand Down Expand Up @@ -129,7 +129,7 @@ def _prepare_create_request(
def create(
cls,
model: str,
name: str = None,
name: str | None = None,
system_instruction: Optional[content_types.ContentType] = None,
contents: Optional[content_types.ContentsType] = None,
tools: Optional[content_types.FunctionLibraryType] = None,
Expand Down
24 changes: 12 additions & 12 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def create_cached_content(
) -> glm.CachedContent:
self.observed_requests.append(request)
return glm.CachedContent(
name="cachedContent/test-cached-content",
name="cachedContents/test-cached-content",
model="models/gemini-1.0-pro-001",
create_time="2000-01-01T01:01:01.123456Z",
update_time="2000-01-01T01:01:01.123456Z",
Expand All @@ -63,7 +63,7 @@ def get_cached_content(
) -> glm.CachedContent:
self.observed_requests.append(request)
return glm.CachedContent(
name="cachedContent/test-cached-content",
name="cachedContents/test-cached-content",
model="models/gemini-1.0-pro-001",
create_time="2000-01-01T01:01:01.123456Z",
update_time="2000-01-01T01:01:01.123456Z",
Expand All @@ -78,14 +78,14 @@ def list_cached_contents(
self.observed_requests.append(request)
return [
glm.CachedContent(
name="cachedContent/test-cached-content-1",
name="cachedContents/test-cached-content-1",
model="models/gemini-1.0-pro-001",
create_time="2000-01-01T01:01:01.123456Z",
update_time="2000-01-01T01:01:01.123456Z",
expire_time="2000-01-01T01:01:01.123456Z",
),
glm.CachedContent(
name="cachedContent/test-cached-content-2",
name="cachedContents/test-cached-content-2",
model="models/gemini-1.0-pro-001",
create_time="2000-01-01T01:01:01.123456Z",
update_time="2000-01-01T01:01:01.123456Z",
Expand All @@ -100,7 +100,7 @@ def update_cached_content(
) -> glm.CachedContent:
self.observed_requests.append(request)
return glm.CachedContent(
name="cachedContent/test-cached-content",
name="cachedContents/test-cached-content",
model="models/gemini-1.0-pro-001",
create_time="2000-01-01T01:01:01.123456Z",
update_time="2000-01-01T01:01:01.123456Z",
Expand Down Expand Up @@ -130,7 +130,7 @@ def add(a: int, b: int) -> int:
)
self.assertIsInstance(self.observed_requests[-1], glm.CreateCachedContentRequest)
self.assertIsInstance(cc, caching.CachedContent)
self.assertEqual(cc.name, "cachedContent/test-cached-content")
self.assertEqual(cc.name, "cachedContents/test-cached-content")
self.assertEqual(cc.model, "models/gemini-1.0-pro-001")

@parameterized.named_parameters(
Expand Down Expand Up @@ -191,10 +191,10 @@ def test_create_cached_content_with_invalid_name_format(self, name):
)

def test_get_cached_content(self):
cc = caching.CachedContent.get(name="cachedContent/test-cached-content")
cc = caching.CachedContent.get(name="cachedContents/test-cached-content")
self.assertIsInstance(self.observed_requests[-1], glm.GetCachedContentRequest)
self.assertIsInstance(cc, caching.CachedContent)
self.assertEqual(cc.name, "cachedContent/test-cached-content")
self.assertEqual(cc.name, "cachedContents/test-cached-content")
self.assertEqual(cc.model, "models/gemini-1.0-pro-001")

def test_list_cached_contents(self):
Expand All @@ -212,7 +212,7 @@ def test_update_cached_content_invalid_update_paths(self):
contents=["add this Content"],
)

cc = caching.CachedContent.get(name="cachedContent/test-cached-content")
cc = caching.CachedContent.get(name="cachedContents/test-cached-content")
with self.assertRaises(ValueError):
cc.update(updates=update_masks)

Expand All @@ -221,17 +221,17 @@ def test_update_cached_content_valid_update_paths(self):
ttl=datetime.timedelta(hours=2),
)

cc = caching.CachedContent.get(name="cachedContent/test-cached-content")
cc = caching.CachedContent.get(name="cachedContents/test-cached-content")
cc = cc.update(updates=update_masks)
self.assertIsInstance(self.observed_requests[-1], glm.UpdateCachedContentRequest)
self.assertIsInstance(cc, caching.CachedContent)

def test_delete_cached_content(self):
cc = caching.CachedContent.get(name="cachedContent/test-cached-content")
cc = caching.CachedContent.get(name="cachedContents/test-cached-content")
cc.delete()
self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest)

cc = caching.CachedContent.get(name="cachedContent/test-cached-content")
cc = caching.CachedContent.get(name="cachedContents/test-cached-content")
cc.delete()
self.assertIsInstance(self.observed_requests[-1], glm.DeleteCachedContentRequest)

Expand Down
35 changes: 10 additions & 25 deletions tests/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_cached_content(
) -> glm.CachedContent:
self.observed_requests.append(request)
return glm.CachedContent(
name="cachedContent/test-cached-content",
name="cachedContents/test-cached-content",
model="models/gemini-1.0-pro-001",
create_time="2000-01-01T01:01:01.123456Z",
update_time="2000-01-01T01:01:01.123456Z",
Expand Down Expand Up @@ -112,19 +112,6 @@ def setUp(self):
client_lib._client_manager.clients["generative"] = self.client
client_lib._client_manager.clients["cache"] = self.client

# @add_client_method
# def get_cached_content(
# request: glm.GetCachedContentRequest,
# **kwargs,
# ) -> glm.CachedContent:
# self.observed_requests.append(request)
# return glm.CachedContent(
# name="cachedContent/test-cached-content",
# model="models/gemini-1.0-pro-001",
# create_time="2000-01-01T01:01:01.123456Z",
# update_time="2000-01-01T01:01:01.123456Z",
# expire_time="2000-01-01T01:01:01.123456Z",
# )

def test_hello(self):
# Generate text from text prompt
Expand Down Expand Up @@ -351,23 +338,19 @@ def test_stream_prompt_feedback_not_blocked(self):
dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"),
dict(
testcase_name="test_cached_content_as_CachedContent_object",
cached_content=caching.CachedContent(
name="cachedContent/test-cached-content",
model="models/gemini-1.0-pro-001",
create_time="2000-01-01T01:01:01.123456Z",
update_time="2000-01-01T01:01:01.123456Z",
expire_time="2000-01-01T01:01:01.123456Z",
),
cached_content=caching.CachedContent.get(name="cachedContents/test-cached-content"),
),
],
)
def test_model_with_cached_content_as_context(self, cached_content):
model = generative_models.GenerativeModel.from_cached_content(cached_content=cached_content)
cc_name = model.cached_content
cc_name = model.cached_content # pytype: disable=attribute-error
model_name = model.model_name
self.assertEqual(cc_name, "cachedContent/test-cached-content")
self.assertEqual(cc_name, "cachedContents/test-cached-content")
self.assertEqual(model_name, "models/gemini-1.0-pro-001")
self.assertEqual(model.cached_content, "cachedContent/test-cached-content")
self.assertEqual(
model.cached_content, "cachedContents/test-cached-content"
) # pytype: disable=attribute-error

def test_content_generation_with_model_having_context(self):
self.responses["generate_content"] = [simple_response("world!")]
Expand All @@ -377,7 +360,9 @@ def test_content_generation_with_model_having_context(self):
response = model.generate_content("Hello")

self.assertEqual(response.text, "world!")
self.assertEqual(model.cached_content, "cachedContent/test-cached-content")
self.assertEqual(
model.cached_content, "cachedContents/test-cached-content"
) # pytype: disable=attribute-error

def test_fail_content_generation_with_model_having_context(self):
model = generative_models.GenerativeModel.from_cached_content(
Expand Down

0 comments on commit f48cedc

Please sign in to comment.