From 0b255b3c52e6942146915916d7985f2087e546a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Sat, 21 Dec 2024 21:42:53 +0800
Subject: [PATCH 1/9] fix gemini system prompt with variables raising an error

---
 api/core/model_runtime/model_providers/google/llm/llm.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index b54668a12d3207..b26904ccf075c9 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -20,6 +20,7 @@
     PromptMessageContent,
     PromptMessageContentType,
     PromptMessageTool,
+    TextPromptMessageContent,
     SystemPromptMessage,
     ToolPromptMessage,
     UserPromptMessage,
@@ -404,6 +405,9 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
                 )
             return glm_content
         elif isinstance(message, SystemPromptMessage):
+            if isinstance(message.content, list):
+                text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
+                message.content = "".join(c.data for c in text_contents)
             return {"role": "user", "parts": [to_part(message.content)]}
         elif isinstance(message, ToolPromptMessage):
             return {

From 702ce3a1a5370eeadbfa2f65a5e2623c1fa19ec9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Sat, 21 Dec 2024 21:43:01 +0800
Subject: [PATCH 2/9] fix gemini system prompt with variables raising an error

---
 api/core/model_runtime/model_providers/google/llm/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index b26904ccf075c9..4235534ee0b828 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -20,8 +20,8 @@
     PromptMessageContent,
     PromptMessageContentType,
     PromptMessageTool,
-    TextPromptMessageContent,
     SystemPromptMessage,
+    TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
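Note on patches 1-2: when a prompt template contains variables, `SystemPromptMessage.content` arrives as a list of content objects rather than a plain string, and passing that list straight to `to_part` raised the error named in the subject line; the fix flattens the list down to its text parts (patch 2 only restores alphabetical import order). A minimal standalone sketch of the flattening step, using simplified stand-ins for Dify's content classes (the real ones carry more fields than `data`):

```python
from dataclasses import dataclass

# Simplified stand-ins for Dify's prompt message content types.
@dataclass
class TextPromptMessageContent:
    data: str

@dataclass
class ImagePromptMessageContent:
    data: str  # e.g. a base64 payload; must be skipped for system prompts

def flatten_system_content(content):
    """Collapse a list of content parts into one string, keeping text only."""
    if isinstance(content, list):
        text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), content)
        return "".join(c.data for c in text_contents)
    return content

# A system prompt with template variables renders as multiple parts:
parts = [
    TextPromptMessageContent(data="You are "),
    ImagePromptMessageContent(data="<base64>"),
    TextPromptMessageContent(data="a helpful bot."),
]
assert flatten_system_content(parts) == "You are a helpful bot."
assert flatten_system_content("plain string") == "plain string"
```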
From e50991e4b00148d5d563f3a4d1bc8919f7ca2d39 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Sat, 21 Dec 2024 22:15:21 +0800
Subject: [PATCH 3/9] support system_instruction

---
 .../model_runtime/model_providers/google/llm/llm.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index 4235534ee0b828..7efab826aa3e63 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -1,6 +1,7 @@
 import base64
 import json
 import os
+import sys
 import tempfile
 import time
 from collections.abc import Generator
@@ -187,18 +188,21 @@ def _generate(
         if stop:
             config_kwargs["stop_sequences"] = stop
 
-        genai.configure(api_key=credentials["google_api_key"])
-        google_model = genai.GenerativeModel(model_name=model)
+        genai.configure(api_key=credentials["google_api_key"]) 
 
         history = []
+        system_instruction = ""
         for msg in prompt_messages:  # makes message roles strictly alternating
             content = self._format_message_to_glm_content(msg)
             if history and history[-1]["role"] == content["role"]:
                 history[-1]["parts"].extend(content["parts"])
+            elif content["role"] == "system":
+                system_instruction = content["parts"][0]
             else:
                 history.append(content)
 
+        google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
         response = google_model.generate_content(
             contents=history,
             generation_config=genai.types.GenerationConfig(**config_kwargs),
@@ -408,7 +412,7 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
             if isinstance(message.content, list):
                 text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
                 message.content = "".join(c.data for c in text_contents)
-            return {"role": "user", "parts": [to_part(message.content)]}
+            return {"role": "system", "parts": [to_part(message.content)]}
         elif isinstance(message, ToolPromptMessage):
             return {
                 "role": "function",

From e36906937c62b50295106ef3e8ed5dd247beb533 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Sat, 21 Dec 2024 22:15:29 +0800
Subject: [PATCH 4/9] support system_instruction

---
 api/core/model_runtime/model_providers/google/llm/llm.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index 7efab826aa3e63..b0af6716b50bff 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -1,7 +1,6 @@
 import base64
 import json
 import os
-import sys
 import tempfile
 import time
 from collections.abc import Generator

From d69b584a8edef53721ffa6be4c7b190dc52c553a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Sat, 21 Dec 2024 22:18:48 +0800
Subject: [PATCH 5/9] fix CI

---
 api/core/model_runtime/model_providers/google/llm/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index b0af6716b50bff..eb0631d4af04c2 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -187,7 +187,7 @@ def _generate(
         if stop:
             config_kwargs["stop_sequences"] = stop
 
-        genai.configure(api_key=credentials["google_api_key"]) 
+        genai.configure(api_key=credentials["google_api_key"])
 
         history = []
         system_instruction = ""

From a890e95309f36a1bb682c536830bc428c3faaefe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Sat, 21 Dec 2024 22:21:42 +0800
Subject: [PATCH 6/9] system_instruction default value should be None

---
 api/core/model_runtime/model_providers/google/llm/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index eb0631d4af04c2..a672b1815a2aa4 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -190,7 +190,7 @@ def _generate(
         genai.configure(api_key=credentials["google_api_key"])
 
         history = []
-        system_instruction = ""
+        system_instruction = None
         for msg in prompt_messages:  # makes message roles strictly alternating
             content = self._format_message_to_glm_content(msg)
             if history and history[-1]["role"] == content["role"]:
                 history[-1]["parts"].extend(content["parts"])
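Note on patches 3-6: instead of disguising the system prompt as a `user` turn, the loop now diverts any `system`-role content into the `system_instruction` parameter that the `google-generativeai` SDK supports natively, and patch 6 changes the default from `""` to `None` because the SDK rejects an empty instruction string. A condensed sketch of the resulting history-building logic; the dict-shaped messages and the model name here are illustrative (the real code builds them via `_format_message_to_glm_content`):

```python
import google.generativeai as genai  # pip install google-generativeai

def build_history(contents: list[dict]) -> tuple[list[dict], str | None]:
    """Merge consecutive same-role turns and pull system content aside."""
    history: list[dict] = []
    system_instruction = None  # patch 6: "" would make the SDK reject the instruction
    for content in contents:  # each item: {"role": ..., "parts": [...]}
        if history and history[-1]["role"] == content["role"]:
            history[-1]["parts"].extend(content["parts"])  # keep roles strictly alternating
        elif content["role"] == "system":
            system_instruction = content["parts"][0]
        else:
            history.append(content)
    return history, system_instruction

history, system_instruction = build_history([
    {"role": "system", "parts": ["Answer in French."]},
    {"role": "user", "parts": ["Hello"]},
])
# Constructing the model object does not call the API; the model name is an assumption.
google_model = genai.GenerativeModel(model_name="gemini-1.5-flash", system_instruction=system_instruction)
```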
From f7982d24731a1fc070ae56e237a7dcec35014797 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Sat, 21 Dec 2024 22:35:01 +0800
Subject: [PATCH 7/9] raise error when only system prompts are given

---
 api/core/model_runtime/model_providers/google/llm/llm.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index a672b1815a2aa4..1a8faf95ea3eb7 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -200,6 +200,9 @@ def _generate(
                 system_instruction = content["parts"][0]
             else:
                 history.append(content)
+        
+        if not history:
+            raise InvokeError("The user prompt message is required. You only add a system prompt message.")
 
         google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
         response = google_model.generate_content(

From 3daffb6b7bcc812364f7e22d8c233ecb4898d13f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Sat, 21 Dec 2024 22:35:12 +0800
Subject: [PATCH 8/9] raise error when only system prompts are given

---
 api/core/model_runtime/model_providers/google/llm/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index 1a8faf95ea3eb7..6653f62167649f 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -200,7 +200,7 @@ def _generate(
                 system_instruction = content["parts"][0]
             else:
                 history.append(content)
-        
+
         if not history:
             raise InvokeError("The user prompt message is required. You only add a system prompt message.")
 

From 6c63dbe701f984c432336ed2e25a5ed02235d997 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?=
Date: Sat, 21 Dec 2024 23:03:26 +0800
Subject: [PATCH 9/9] fix unit tests

---
 api/core/model_runtime/model_providers/google/llm/llm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index 6653f62167649f..7d19ccbb74a011 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -144,7 +144,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         """
 
         try:
-            ping_message = SystemPromptMessage(content="ping")
+            ping_message = UserPromptMessage(content="ping")
             self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
         except Exception as ex:
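Note on patches 7-9: once system prompts are routed into `system_instruction`, a conversation consisting only of system prompts leaves `history` empty, so patch 7 fails fast with an `InvokeError` instead of sending an empty request. That guard in turn broke `validate_credentials`, whose connectivity ping used a `SystemPromptMessage`, hence the switch to `UserPromptMessage` in patch 9. A hypothetical pytest sketch (the helper and test names are illustrative, not from the PR) of the behavior the guard enforces:

```python
import pytest

class InvokeError(Exception):  # stand-in for Dify's InvokeError
    pass

def check_history(history: list) -> None:
    # Mirrors the guard added in patch 7: an empty history means every
    # message was a system prompt and was diverted to system_instruction.
    if not history:
        raise InvokeError("The user prompt message is required. You only add a system prompt message.")

def test_system_only_prompt_raises():
    with pytest.raises(InvokeError):
        check_history([])  # a system-only conversation produces no history

def test_user_ping_passes():
    # Mirrors patch 9: validate_credentials now pings with a user message.
    check_history([{"role": "user", "parts": ["ping"]}])
```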