Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add new Gemini 2.0 models #172

Merged
merged 25 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0ecf3db
Add new adapter with Gemini 2.0 models
roman-romanov-o Jan 14, 2025
4dcb2f4
Add new models into endpoint unit test
roman-romanov-o Jan 14, 2025
b9bfbb5
Add new models to integration tests
roman-romanov-o Jan 14, 2025
3680c58
forgotten in last commit
roman-romanov-o Jan 14, 2025
b6db9c0
correctly convert genai API error
roman-romanov-o Jan 14, 2025
874b974
fix code duplication
roman-romanov-o Jan 15, 2025
75868f1
remove code duplication once again
roman-romanov-o Jan 15, 2025
d0fdfca
remove bound for typevar
roman-romanov-o Jan 15, 2025
5c55581
Refactor: gather all parts/content/conversation creation into distinc…
roman-romanov-o Jan 15, 2025
417d3f2
Minor refactoring
roman-romanov-o Jan 15, 2025
90d2aaf
Merge branch 'development' into feat/gemini-2-0
roman-romanov-o Jan 16, 2025
5eb8fd0
refactor exception type
roman-romanov-o Jan 16, 2025
43deba3
review fixes
roman-romanov-o Jan 16, 2025
0b81f5f
fix 429 issues on integration tests with retries
roman-romanov-o Jan 16, 2025
1634196
refactor: move conversation factory to distinct module
roman-romanov-o Jan 16, 2025
dba0d09
fix linter
roman-romanov-o Jan 16, 2025
f9a46f5
Fix mistake in integration test check
roman-romanov-o Jan 16, 2025
d8b1ee7
Move everything to pydantic v1 to avoid warnings
roman-romanov-o Jan 16, 2025
c397ac8
final fix of warnings
roman-romanov-o Jan 16, 2025
6719727
Merge branch 'development' into feat/gemini-2-0
roman-romanov-o Jan 16, 2025
29576d2
Review fixes
roman-romanov-o Jan 16, 2025
cbf8dce
More refactoring
roman-romanov-o Jan 16, 2025
a78258e
Fix pydantic errors
roman-romanov-o Jan 16, 2025
cade913
Update README with new models
roman-romanov-o Jan 16, 2025
92e2f88
fix README
roman-romanov-o Jan 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion aidial_adapter_vertexai/adapters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import assert_never

from google.genai.client import Client as GenAIClient

from aidial_adapter_vertexai.chat.bison.adapter import (
BisonChatAdapter,
BisonCodeChatAdapter,
Expand All @@ -9,6 +11,7 @@
)
from aidial_adapter_vertexai.chat.gemini.adapter import (
GeminiChatCompletionAdapter,
GeminiGenAIChatCompletionAdapter,
)
from aidial_adapter_vertexai.chat.imagen.adapter import (
ImagenChatCompletionAdapter,
Expand All @@ -28,7 +31,7 @@


async def get_chat_completion_model(
api_key: str, deployment: ChatCompletionDeployment
api_key: str, deployment: ChatCompletionDeployment, client: GenAIClient
) -> ChatCompletionAdapter:
model_id = deployment.get_model_id()

Expand Down Expand Up @@ -58,6 +61,15 @@ async def get_chat_completion_model(
return await GeminiChatCompletionAdapter.create(
storage, model_id, deployment
)
case (
ChatCompletionDeployment.GEMINI_2_0_FLASH_EXP
| ChatCompletionDeployment.GEMINI_2_0_FLASH_THINKING_EXP_1219
| ChatCompletionDeployment.GEMINI_2_0_EXPERIMENTAL_1206
):
storage = create_file_storage(api_key)
return GeminiGenAIChatCompletionAdapter(
client, storage, model_id, deployment
)
case ChatCompletionDeployment.IMAGEN_005:
storage = create_file_storage(api_key)
return await ImagenChatCompletionAdapter.create(storage, model_id)
Expand Down
12 changes: 10 additions & 2 deletions aidial_adapter_vertexai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import vertexai
from aidial_sdk import DIALApp
from aidial_sdk.telemetry.types import TelemetryConfig
from google.genai.client import Client as GenAIClient

from aidial_adapter_vertexai.chat_completion import VertexAIChatCompletion
from aidial_adapter_vertexai.deployments import (
Expand Down Expand Up @@ -51,8 +52,15 @@ async def models():
return ModelsResponse(data=models)


for deployment in ChatCompletionDeployment:
app.add_chat_completion(deployment.get_model_id(), VertexAIChatCompletion())
genai_client = GenAIClient(
vertexai=True, project=GCP_PROJECT_ID, location=DEFAULT_REGION
)


for deployment in ChatCompletionDeployment:
app.add_chat_completion(
deployment.get_model_id(),
VertexAIChatCompletion(client=genai_client),
)
for deployment in EmbeddingsDeployment:
app.add_embeddings(deployment.get_model_id(), VertexAIEmbeddings())
8 changes: 8 additions & 0 deletions aidial_adapter_vertexai/chat/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Choice,
FinishReason,
Response,
Stage,
)

from aidial_adapter_vertexai.dial_api.token_usage import TokenUsage
Expand Down Expand Up @@ -41,6 +42,10 @@ async def set_finish_reason(self, finish_reason: FinishReason):
def is_empty(self) -> bool:
pass

@abstractmethod
async def create_stage(self, name: str) -> Stage:
pass


class ChoiceConsumer(Consumer):
response: Response
Expand Down Expand Up @@ -132,3 +137,6 @@ async def set_finish_reason(self, finish_reason: FinishReason):
)

self.finish_reason = finish_reason

async def create_stage(self, name) -> Stage:
return self.choice.create_stage(name)
5 changes: 2 additions & 3 deletions aidial_adapter_vertexai/chat/gemini/adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .genai_lib import GeminiGenAIChatCompletionAdapter
from .vertex_lib import GeminiChatCompletionAdapter

__all__ = [
"GeminiChatCompletionAdapter",
]
__all__ = ["GeminiChatCompletionAdapter", "GeminiGenAIChatCompletionAdapter"]
238 changes: 238 additions & 0 deletions aidial_adapter_vertexai/chat/gemini/adapter/genai_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
from logging import DEBUG
from typing import AsyncIterator, Callable, List, Optional, assert_never

from aidial_sdk.chat_completion import FinishReason, Message, Stage
from aidial_sdk.exceptions import RuntimeServerError
from google.genai.client import Client as GenAIClient
from google.genai.types import (
GenerateContentResponse as GenAIGenerateContentResponse,
)
from typing_extensions import override

from aidial_adapter_vertexai.chat.chat_completion_adapter import (
ChatCompletionAdapter,
)
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.errors import UserError
from aidial_adapter_vertexai.chat.gemini.error import generate_with_retries
from aidial_adapter_vertexai.chat.gemini.finish_reason import (
genai_to_openai_finish_reason,
)
from aidial_adapter_vertexai.chat.gemini.generation_config import (
create_genai_generation_config,
)
from aidial_adapter_vertexai.chat.gemini.grounding import create_grounding
from aidial_adapter_vertexai.chat.gemini.output import (
create_attachments_from_citations,
create_function_calls_from_genai,
set_usage,
)
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiGenAIPrompt
from aidial_adapter_vertexai.chat.gemini.prompt.gemini_2 import Gemini_2_Prompt
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.deployments import (
ChatCompletionDeployment,
Gemini2Deployment,
)
from aidial_adapter_vertexai.dial_api.request import ModelParameters
from aidial_adapter_vertexai.dial_api.storage import FileStorage
from aidial_adapter_vertexai.utils.json import json_dumps, json_dumps_short
from aidial_adapter_vertexai.utils.log_config import vertex_ai_logger as log
from aidial_adapter_vertexai.utils.timer import Timer

_COUNT_TOKENS_ERROR = RuntimeServerError("Failed to count tokens for prompt")
adubovik marked this conversation as resolved.
Show resolved Hide resolved


class GeminiGenAIChatCompletionAdapter(
ChatCompletionAdapter[GeminiGenAIPrompt]
):
deployment: Gemini2Deployment

def __init__(
self,
client: GenAIClient,
file_storage: Optional[FileStorage],
model_id: str,
deployment: Gemini2Deployment,
):
self.file_storage = file_storage
self.model_id = model_id
self.deployment = deployment
self.client = client

@override
async def parse_prompt(
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> GeminiGenAIPrompt | UserError:
match self.deployment:
case (
ChatCompletionDeployment.GEMINI_2_0_EXPERIMENTAL_1206
| ChatCompletionDeployment.GEMINI_2_0_FLASH_EXP
| ChatCompletionDeployment.GEMINI_2_0_FLASH_THINKING_EXP_1219
):
return await Gemini_2_Prompt.parse(
self.file_storage, tools, static_tools, messages
)
case _:
assert_never(self.deployment)

async def send_message_async(
self, params: ModelParameters, prompt: GeminiGenAIPrompt
) -> AsyncIterator[GenAIGenerateContentResponse]:

generation_config = create_genai_generation_config(
params,
prompt.tools,
prompt.static_tools,
prompt.system_instruction,
)
if params.stream:
async for chunk in self.client.aio.models.generate_content_stream(
model=self.model_id,
contents=list(prompt.contents),
config=generation_config,
):
yield chunk
else:
yield await self.client.aio.models.generate_content(
model=self.model_id,
contents=list(prompt.contents),
config=generation_config,
)

async def process_chunks(
self,
consumer: Consumer,
tools: ToolsConfig,
generator: Callable[[], AsyncIterator[GenAIGenerateContentResponse]],
):
thinking_stage: Stage | None = None

usage_metadata = None
is_grounding_added = False
try:
async for chunk in generator():
if log.isEnabledFor(DEBUG):
chunk_str = json_dumps(chunk)
log.debug(f"response chunk: {chunk_str}")

if chunk.prompt_feedback:
await consumer.set_finish_reason(
FinishReason.CONTENT_FILTER
)

if chunk.usage_metadata:
usage_metadata = chunk.usage_metadata

if not chunk.candidates:
continue

candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
await create_function_calls_from_genai(
part, consumer, tools
)
if part.thought and part.text:
if thinking_stage is None:
thinking_stage = await consumer.create_stage(
"Thought Process"
)
thinking_stage.open()
thinking_stage.append_content(part.text)
yield part.text
elif part.text:
await consumer.append_content(part.text)
yield part.text

is_grounding_added |= await create_grounding(
candidate, consumer
)

await create_attachments_from_citations(candidate, consumer)
if openai_reason := genai_to_openai_finish_reason(
candidate.finish_reason,
consumer.is_empty(),
):
await consumer.set_finish_reason(openai_reason)
finally:
if thinking_stage:
thinking_stage.close()
# It's possible that max tokens will be reached during the thinking stage
# and there will be no content in response.
# And set_usage will fail with 'Trying to set "usage" before generating all choices' error.
# Append empty content, so at least one choice is generated.
await consumer.append_content("")

if usage_metadata:
await set_usage(
usage_metadata,
consumer,
self.deployment,
is_grounding_added,
)

@override
async def truncate_prompt(
self, prompt: GeminiGenAIPrompt, max_prompt_tokens: int
) -> TruncatedPrompt[GeminiGenAIPrompt]:
return await prompt.truncate(
tokenizer=self.count_prompt_tokens, user_limit=max_prompt_tokens
)

@override
async def count_prompt_tokens(self, prompt: GeminiGenAIPrompt) -> int:
with Timer("count_tokens[prompt] timing: {time}", log.debug):
resp = await self.client.aio.models.count_tokens(
model=self.model_id,
contents=list(prompt.contents),
)
log.debug(f"count_tokens[prompt] response: {json_dumps(resp)}")
if resp.total_tokens is None:
raise _COUNT_TOKENS_ERROR
return resp.total_tokens

@override
async def count_completion_tokens(self, string: str) -> int:
with Timer("count_tokens[completion] timing: {time}", log.debug):
resp = await self.client.aio.models.count_tokens(
model=self.model_id,
contents=string,
)
log.debug(f"count_tokens[completion] response: {json_dumps(resp)}")
if resp.total_tokens is None:
raise _COUNT_TOKENS_ERROR
return resp.total_tokens

@override
async def chat(
self,
params: ModelParameters,
consumer: Consumer,
prompt: GeminiGenAIPrompt,
) -> None:

with Timer("predict timing: {time}", log.debug):
if log.isEnabledFor(DEBUG):
log.debug(
"predict request: "
+ json_dumps_short({"parameters": params, "prompt": prompt})
)

completion = ""
async for content in generate_with_retries(
lambda: self.process_chunks(
consumer,
prompt.tools,
lambda: self.send_message_async(params, prompt),
),
2,
):
completion += content

log.debug(f"predict response: {completion!r}")
Loading
Loading