Skip to content

Commit

Permalink
Fix missing messages in Gemini history (#2906)
Browse files Browse the repository at this point in the history
* fix missing message in history

* fix message handling

* add list of Parts to Content object

* add test for gemini message conversion function

* add test for gemini message conversion

* add message to asserts

* add safety setting support for vertexai

* remove vertexai safety settings
  • Loading branch information
luxzoli authored Jun 14, 2024
1 parent 6d4cf40 commit 10b8fa5
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
10 changes: 5 additions & 5 deletions autogen/oai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def create(self, params: Dict) -> ChatCompletion:
for attempt in range(max_retries):
ans = None
try:
response = chat.send_message(gemini_messages[-1].parts[0].text, stream=stream)
response = chat.send_message(gemini_messages[-1], stream=stream)
except InternalServerError:
delay = 5 * (2**attempt)
warnings.warn(
Expand Down Expand Up @@ -344,19 +344,19 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li
for i, message in enumerate(messages):
parts = self._oai_content_to_gemini_content(message["content"])
role = "user" if message["role"] in ["user", "system"] else "model"

if prev_role is None or role == prev_role:
if (prev_role is None) or (role == prev_role):
curr_parts += parts
elif role != prev_role:
if self.use_vertexai:
rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=prev_role))
rst.append(VertexAIContent(parts=curr_parts, role=prev_role))
else:
rst.append(Content(parts=curr_parts, role=prev_role))
curr_parts = parts
prev_role = role

# handle the last message
if self.use_vertexai:
rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=role))
rst.append(VertexAIContent(parts=curr_parts, role=role))
else:
rst.append(Content(parts=curr_parts, role=role))

Expand Down
42 changes: 42 additions & 0 deletions test/oai/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,48 @@ def test_valid_initialization(gemini_client):
assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_gemini_message_handling(gemini_client):
messages = [
{"role": "system", "content": "You are my personal assistant."},
{"role": "model", "content": "How can I help you?"},
{"role": "user", "content": "Which planet is the nearest to the sun?"},
{"role": "user", "content": "Which planet is the farthest from the sun?"},
{"role": "model", "content": "Mercury is the closest palnet to the sun."},
{"role": "model", "content": "Neptune is the farthest palnet from the sun."},
{"role": "user", "content": "How can we determine the mass of a black hole?"},
]

# The datastructure below defines what the structure of the messages
# should resemble after converting to Gemini format.
# Messages of similar roles are expected to be merged to a single message,
# where the contents of the original messages will be included in
# consecutive parts of the converted Gemini message
expected_gemini_struct = [
# system role is converted to user role
{"role": "user", "parts": ["You are my personal assistant."]},
{"role": "model", "parts": ["How can I help you?"]},
{
"role": "user",
"parts": ["Which planet is the nearest to the sun?", "Which planet is the farthest from the sun?"],
},
{
"role": "model",
"parts": ["Mercury is the closest palnet to the sun.", "Neptune is the farthest palnet from the sun."],
},
{"role": "user", "parts": ["How can we determine the mass of a black hole?"]},
]

converted_messages = gemini_client._oai_messages_to_gemini_messages(messages)

assert len(converted_messages) == len(expected_gemini_struct), "The number of messages is not as expected"

for i, expected_msg in enumerate(expected_gemini_struct):
assert expected_msg["role"] == converted_messages[i].role, "Incorrect mapped message role"
for j, part in enumerate(expected_msg["parts"]):
assert converted_messages[i].parts[j].text == part, "Incorrect mapped message text"


# Test error handling
@patch("autogen.oai.gemini.genai")
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
Expand Down

0 comments on commit 10b8fa5

Please sign in to comment.