Add gemini client support
richard-to committed Sep 9, 2024
1 parent 55be4d9 commit 4d9f803
Showing 3 changed files with 154 additions and 27 deletions.
113 changes: 113 additions & 0 deletions ai/src/ai/common/llm_client.py
@@ -0,0 +1,113 @@
from os import getenv
from typing import Iterable, Protocol

import google.generativeai as genai
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam

genai.configure(api_key=getenv("GOOGLE_API_KEY"))


class LlmClient(Protocol):
  def generate_content_stream(
    self,
    model: str,
    messages: Iterable[ChatCompletionMessageParam],
    max_tokens: int,
  ):
    raise NotImplementedError()

  def generate_content_blocking(
    self,
    model: str,
    messages: Iterable[ChatCompletionMessageParam],
    max_tokens: int,
  ):
    raise NotImplementedError()


class OpenAIClient(LlmClient):
  def __init__(self) -> None:
    self.client = OpenAI(
      api_key=getenv("OPENAI_API_KEY"),
    )

  def generate_content_stream(
    self,
    model: str,
    messages: Iterable[ChatCompletionMessageParam],
    max_tokens: int,
  ):
    response = self.client.chat.completions.create(
      model=model,
      max_tokens=max_tokens,
      messages=messages,
      stream=True,
    )
    for chunk in response:
      if chunk.choices[0].delta.content:
        yield chunk.choices[0].delta.content

  def generate_content_blocking(
    self,
    model: str,
    messages: Iterable[ChatCompletionMessageParam],
    max_tokens: int,
  ):
    response = self.client.chat.completions.create(
      model=model,
      max_tokens=max_tokens,
      messages=messages,
      stream=False,
    )
    content = response.choices[0].message.content
    assert content is not None
    return content


class GeminClient(LlmClient):
  def generate_content_stream(
    self,
    model: str,
    messages: Iterable[ChatCompletionMessageParam],
    max_tokens: int,
  ):
    contents = self._make_messages(messages)
    client = self._make_client(model, max_tokens)
    for response in client.generate_content(contents, stream=True):
      yield response.text

  def generate_content_blocking(
    self,
    model: str,
    messages: Iterable[ChatCompletionMessageParam],
    max_tokens: int,
  ):
    contents = self._make_messages(messages)
    client = self._make_client(model, max_tokens)
    return client.generate_content(contents).text

  def _make_client(self, model: str, max_tokens: int) -> genai.GenerativeModel:
    return genai.GenerativeModel(
      model_name=model,
      generation_config={
        "temperature": 1,
        "top_p": 0.95,
        "top_k": 64,
        "max_output_tokens": max_tokens,
        "response_mime_type": "text/plain",
      },  # type: ignore
    )

  def _make_messages(
    self, messages: Iterable[ChatCompletionMessageParam]
  ) -> Iterable[str]:
    contents = []
    for message in messages:
      if message["role"] == "system":
        contents.append(message["content"])
      elif message["role"] == "user":
        contents.append("input: " + str(message["content"]))
      elif message["role"] == "assistant" and "content" in message:
        contents.append("output: " + str(message["content"]))
    return contents
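
For orientation, here is a minimal usage sketch of the new LlmClient abstraction. This is not part of the commit: the model name below is an illustrative placeholder, and it assumes GOOGLE_API_KEY (or OPENAI_API_KEY for OpenAIClient) is set in the environment.

from ai.common.llm_client import GeminClient, OpenAIClient

# Both clients satisfy the LlmClient protocol, so callers can swap providers
# without changing how they build messages or consume output.
messages = [
  {"role": "system", "content": "You generate Mesop UI code."},
  {"role": "user", "content": "Add a button that resets the form."},
]

client = GeminClient()  # or OpenAIClient()

# Streaming: both implementations yield plain text chunks.
for chunk in client.generate_content_stream(
  model="gemini-1.5-flash",  # placeholder model name
  messages=messages,
  max_tokens=1024,
):
  print(chunk, end="")

# Blocking: returns the full response text in one call.
text = client.generate_content_blocking(
  model="gemini-1.5-flash",  # placeholder model name
  messages=messages,
  max_tokens=1024,
)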
62 changes: 38 additions & 24 deletions ai/src/ai/common/llm_lib.py
@@ -4,11 +4,12 @@
from typing import NamedTuple

from dotenv import load_dotenv
from openai import OpenAI
from openai.types.chat import (
  ChatCompletionMessageParam,
)

from ai.common.llm_client import GeminClient, LlmClient, OpenAIClient

load_dotenv()

EDIT_HERE_MARKER = " # <--- EDIT HERE"
@@ -52,9 +53,8 @@ def apply_patch(original_code: str, patch: str) -> ApplyPatchResult:


DEFAULT_MODEL = "ft:gpt-4o-mini-2024-07-18:mesop:small-prompt:A1472X3c"
DEFAULT_CLIENT = OpenAI(
api_key=getenv("OPENAI_API_KEY"),
)
DEFAULT_CLIENT = OpenAIClient()
GEMINI_CLIENT = GeminClient()


class MessageFormatter:
@@ -117,8 +117,11 @@ def MakeMessageFormatterShorterUserMsg():
  )


def load_unused_goldens():
goldens_path = "ft/gen/formatted_dataset_for_prompting.jsonl"
def load_unused_goldens(all: bool = False):
if all:
goldens_path = "ft/gen/formatted_dataset.jsonl"
else:
goldens_path = "ft/gen/formatted_dataset_for_prompting.jsonl"
new_goldens = []
num_rows = 0
try:
@@ -138,10 +141,12 @@ def load_unused_goldens():

if getenv("MESOP_AI_INCLUDE_NEW_GOLDENS"):
  message_formatter = MakeMessageFormatterShorterUserMsg()
  goldens_path = load_unused_goldens()
  goldens = load_unused_goldens()
  all_goldens = load_unused_goldens(all=True)
else:
  message_formatter = MakeDefaultMessageFormatter()
  goldens_path = []
  goldens = []
  all_goldens = []


def format_messages(
@@ -155,22 +160,20 @@ def adjust_mesop_app_stream(
  code: str,
  user_input: str,
  line_number: int | None,
  client: OpenAI = DEFAULT_CLIENT,
  client: LlmClient | None = None,
  model: str = DEFAULT_MODEL,
):
  """
  Returns a stream of the code diff.
  """
  if not client:
    client = GEMINI_CLIENT if "gemini" in model else DEFAULT_CLIENT
  messages = format_messages(code, user_input, line_number)

  if goldens_path:
    messages = [messages[0], *goldens_path, messages[1]]

  return client.chat.completions.create(
  messages = _include_goldens(messages, model)
  return client.generate_content_stream(
    model=model,
    max_tokens=16_384,
    messages=messages,
    stream=True,
  )


@@ -179,22 +182,33 @@ def adjust_mesop_app_blocking(
  code: str,
  user_input: str,
  line_number: int | None = None,
  client: OpenAI = DEFAULT_CLIENT,
  client: LlmClient | None = None,
  model: str = DEFAULT_MODEL,
) -> str:
  """
  Returns the code diff.
  """
  if not client:
    client = GEMINI_CLIENT if "gemini" in model else DEFAULT_CLIENT
  messages = format_messages(code, user_input, line_number)
  if goldens_path:
    messages = [messages[0], *goldens_path, messages[1]]

  response = client.chat.completions.create(
  messages = _include_goldens(messages, model)
  return client.generate_content_blocking(
    model=model,
    max_tokens=16_384,
    messages=messages,
    stream=False,
  )
  content = response.choices[0].message.content
  assert content is not None
  return content


def _include_goldens(
  messages: list[ChatCompletionMessageParam], model: str
) -> list[ChatCompletionMessageParam]:
  examples = []
  if all_goldens and "gemini" in model:
    examples = all_goldens
  elif goldens:
    examples = goldens

  if examples:
    messages = [messages[0], *examples, messages[1]]

  return messages
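
A rough illustration of the resulting client selection (again, not from the commit; the Gemini model string is a placeholder): when no client is passed in, a model name containing "gemini" routes to GEMINI_CLIENT, and anything else falls back to DEFAULT_CLIENT with the fine-tuned OpenAI default model.

from ai.common import llm_lib

code = """
import mesop as me

@me.page()
def app():
  me.text("Hello World")
"""

# Model name contains "gemini", so GEMINI_CLIENT handles the request.
gemini_diff = llm_lib.adjust_mesop_app_blocking(
  code=code,
  user_input="Add a button that increments a counter",
  model="gemini-1.5-pro",  # placeholder model name
)

# No model given: DEFAULT_MODEL and DEFAULT_CLIENT (OpenAIClient) are used.
openai_diff = llm_lib.adjust_mesop_app_blocking(
  code=code,
  user_input="Add a button that increments a counter",
)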
6 changes: 3 additions & 3 deletions ai/src/service.py
@@ -72,9 +72,9 @@ def generate():
  )
  diff = ""
  for chunk in stream:
    if chunk.choices[0].delta.content:
      diff += chunk.choices[0].delta.content
      yield f"data: {json.dumps({'type': 'progress', 'data': chunk.choices[0].delta.content})}\n\n"
    if chunk:
      diff += chunk
      yield f"data: {json.dumps({'type': 'progress', 'data': chunk})}\n\n"

  result = apply_patch(code, diff)
  if result.has_error:
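
For completeness, a hedged sketch of how a caller might consume the progress events this handler emits. The endpoint path, port, and request payload are hypothetical stand-ins, not taken from this diff; only the SSE line format (data: {"type": "progress", "data": ...}) comes from the code above.

import json

import requests  # assumes the requests package is available

# Hypothetical endpoint; the actual route is defined elsewhere in service.py.
with requests.post(
  "http://localhost:8080/adjust-mesop-app",
  json={"prompt": "Add a button", "code": "..."},  # hypothetical payload shape
  stream=True,
) as response:
  for line in response.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
      continue
    event = json.loads(line[len("data: "):])
    if event["type"] == "progress":
      print(event["data"], end="")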