Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix missing messages in Gemini history #2906

Merged
merged 9 commits into from
Jun 14, 2024
10 changes: 5 additions & 5 deletions autogen/oai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def create(self, params: Dict) -> ChatCompletion:
for attempt in range(max_retries):
ans = None
try:
response = chat.send_message(gemini_messages[-1].parts[0].text, stream=stream)
response = chat.send_message(gemini_messages[-1], stream=stream)
except InternalServerError:
delay = 5 * (2**attempt)
warnings.warn(
Expand Down Expand Up @@ -344,19 +344,19 @@ def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> li
for i, message in enumerate(messages):
parts = self._oai_content_to_gemini_content(message["content"])
role = "user" if message["role"] in ["user", "system"] else "model"

if prev_role is None or role == prev_role:
if (prev_role is None) or (role == prev_role):
curr_parts += parts
elif role != prev_role:
if self.use_vertexai:
rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=prev_role))
rst.append(VertexAIContent(parts=curr_parts, role=prev_role))
else:
rst.append(Content(parts=curr_parts, role=prev_role))
curr_parts = parts
prev_role = role

# handle the last message
if self.use_vertexai:
rst.append(VertexAIContent(parts=self._concat_parts(curr_parts), role=role))
rst.append(VertexAIContent(parts=curr_parts, role=role))
else:
rst.append(Content(parts=curr_parts, role=role))

Expand Down
42 changes: 42 additions & 0 deletions test/oai/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,48 @@ def test_valid_initialization(gemini_client):
assert gemini_client.api_key == "fake_api_key", "API Key should be correctly set"


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_gemini_message_handling(gemini_client):
messages = [
{"role": "system", "content": "You are my personal assistant."},
{"role": "model", "content": "How can I help you?"},
{"role": "user", "content": "Which planet is the nearest to the sun?"},
{"role": "user", "content": "Which planet is the farthest from the sun?"},
{"role": "model", "content": "Mercury is the closest palnet to the sun."},
{"role": "model", "content": "Neptune is the farthest palnet from the sun."},
{"role": "user", "content": "How can we determine the mass of a black hole?"},
]

# The datastructure below defines what the structure of the messages
# should resemble after converting to Gemini format.
# Messages of similar roles are expected to be merged to a single message,
# where the contents of the original messages will be included in
# consecutive parts of the converted Gemini message
expected_gemini_struct = [
# system role is converted to user role
{"role": "user", "parts": ["You are my personal assistant."]},
{"role": "model", "parts": ["How can I help you?"]},
{
"role": "user",
"parts": ["Which planet is the nearest to the sun?", "Which planet is the farthest from the sun?"],
},
{
"role": "model",
"parts": ["Mercury is the closest palnet to the sun.", "Neptune is the farthest palnet from the sun."],
},
{"role": "user", "parts": ["How can we determine the mass of a black hole?"]},
]

converted_messages = gemini_client._oai_messages_to_gemini_messages(messages)

assert len(converted_messages) == len(expected_gemini_struct), "The number of messages is not as expected"

for i, expected_msg in enumerate(expected_gemini_struct):
assert expected_msg["role"] == converted_messages[i].role, "Incorrect mapped message role"
for j, part in enumerate(expected_msg["parts"]):
assert converted_messages[i].parts[j].text == part, "Incorrect mapped message text"


# Test error handling
@patch("autogen.oai.gemini.genai")
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
Expand Down
Loading