From bc7eda53bc26d5500407b93ec0e51828f480d999 Mon Sep 17 00:00:00 2001 From: zg0d233 <11971297+zhao85@users.noreply.github.com> Date: Mon, 7 Oct 2024 11:06:38 +0800 Subject: [PATCH] fix bug when adding openai or openai-compatible stt model instance (#9006) --- .../agent/output_parser/cot_output_parser.py | 2 +- .../openai/speech2text/speech2text.py | 17 +++++++++++++++++ .../speech2text/speech2text.py | 17 +++++++++++++++++ api/core/tools/tool/workflow_tool.py | 11 +++++++---- 4 files changed, 42 insertions(+), 5 deletions(-) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index d04e38777a54aa..99876b2f5eb3e3 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -14,7 +14,7 @@ def handle_react_stream_output( ) -> Generator[Union[str, AgentScratchpadUnit.Action], None, None]: def parse_action(json_str): try: - action = json.loads(json_str) + action = json.loads(json_str, strict=False) action_name = None action_input = None diff --git a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py index 18f97e45f33bd8..0d54d2ea9a9dda 100644 --- a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py @@ -2,6 +2,8 @@ from openai import OpenAI +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI @@ -58,3 +60,18 @@ def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> response = client.audio.transcriptions.create(model=model, file=file) return response.text + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.SPEECH2TEXT, + model_properties={}, + parameter_rules=[], + ) + + return entity diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py index 405096578cdd5d..cef77cc9419d1f 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py @@ -3,6 +3,8 @@ import requests +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel @@ -59,3 +61,18 @@ def validate_credentials(self, model: str, credentials: dict) -> None: self._invoke(model, credentials, audio_file) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.SPEECH2TEXT, + model_properties={}, + parameter_rules=[], + ) + + return entity diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index ad0c7fc631fd08..a885b8784f27f6 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -68,10 +68,13 @@ def _invoke( result = [] - outputs = data.get("outputs", {}) - outputs, files = self._extract_files(outputs) - for file in files: - result.append(self.create_file_var_message(file)) + outputs = data.get("outputs") + if outputs == None: + outputs = {} + else: + outputs, files = self._extract_files(outputs) + for file in files: + result.append(self.create_file_var_message(file)) result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False))) result.append(self.create_json_message(outputs))