fix: gemini system prompt with variable raises error #11946

Merged: 9 commits, Dec 21, 2024
16 changes: 13 additions & 3 deletions in api/core/model_runtime/model_providers/google/llm/llm.py
@@ -21,6 +21,7 @@
     PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
@@ -143,7 +144,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
"""

try:
ping_message = SystemPromptMessage(content="ping")
ping_message = UserPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})

except Exception as ex:
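For context, a minimal standalone sketch (an assumption, not part of the PR) of why the credential-check ping must now be a user-role message: with this change a system prompt is routed to system_instruction rather than into the chat contents, so a system-only ping would leave the contents list empty and trip the new InvokeError below. The API key and model name are placeholders.

import google.generativeai as genai

# Placeholder credentials and model name, purely illustrative.
genai.configure(api_key="YOUR_GOOGLE_API_KEY")
model = genai.GenerativeModel(model_name="gemini-1.5-flash")

# The ping is sent as a user turn so that `contents` is never empty,
# mirroring the switch from SystemPromptMessage to UserPromptMessage above.
response = model.generate_content(
    contents=[{"role": "user", "parts": ["ping"]}],
    generation_config=genai.types.GenerationConfig(max_output_tokens=5),
)
print(response.text)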
@@ -187,17 +188,23 @@ def _generate(
             config_kwargs["stop_sequences"] = stop
 
         genai.configure(api_key=credentials["google_api_key"])
-        google_model = genai.GenerativeModel(model_name=model)
 
         history = []
+        system_instruction = None
 
         for msg in prompt_messages:  # makes message roles strictly alternating
             content = self._format_message_to_glm_content(msg)
             if history and history[-1]["role"] == content["role"]:
                 history[-1]["parts"].extend(content["parts"])
+            elif content["role"] == "system":
+                system_instruction = content["parts"][0]
             else:
                 history.append(content)
 
+        if not history:
+            raise InvokeError("The user prompt message is required. You only add a system prompt message.")
+
+        google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
         response = google_model.generate_content(
             contents=history,
             generation_config=genai.types.GenerationConfig(**config_kwargs),
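A minimal standalone sketch of the flow above, assuming the google-generativeai SDK; the system text, chat turn, and model name are illustrative stand-ins rather than values taken from the PR:

import google.generativeai as genai

genai.configure(api_key="YOUR_GOOGLE_API_KEY")  # placeholder key

# The system prompt (including any rendered template variables) is passed once via
# system_instruction instead of being injected as a fake "user" turn; only real
# user/model turns are kept in the chat contents.
system_instruction = "You are a concise assistant for the {{product}} docs."  # illustrative
history = [{"role": "user", "parts": ["How do I rotate my API key?"]}]

google_model = genai.GenerativeModel(
    model_name="gemini-1.5-pro",            # illustrative model name
    system_instruction=system_instruction,  # supported by the SDK
)
response = google_model.generate_content(contents=history)
print(response.text)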
@@ -404,7 +411,10 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
                 )
             return glm_content
         elif isinstance(message, SystemPromptMessage):
-            return {"role": "user", "parts": [to_part(message.content)]}
+            if isinstance(message.content, list):
+                text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
+                message.content = "".join(c.data for c in text_contents)
+            return {"role": "system", "parts": [to_part(message.content)]}
         elif isinstance(message, ToolPromptMessage):
             return {
                 "role": "function",
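For illustration, a hedged standalone sketch of how a SystemPromptMessage whose content arrives as a list of parts (for example, after prompt-variable expansion) collapses into a single text part under the "system" role; the classes and helper below are simplified stand-ins for the project's real types, not the actual implementation:

from dataclasses import dataclass

@dataclass
class TextPart:  # simplified stand-in for TextPromptMessageContent
    data: str

def to_part(text):
    # simplified stand-in for the glm part constructor used in llm.py
    return {"text": text}

def format_system_message(content):
    # Keep only the text parts and join them, mirroring the new SystemPromptMessage branch.
    if isinstance(content, list):
        content = "".join(c.data for c in content if isinstance(c, TextPart))
    return {"role": "system", "parts": [to_part(content)]}

print(format_system_message([TextPart("You are a translator. "), TextPart("Target language: French.")]))
# -> {'role': 'system', 'parts': [{'text': 'You are a translator. Target language: French.'}]}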