From 7a7d4bc796f4870c66736f24c68bca40770c7687 Mon Sep 17 00:00:00 2001 From: TIANYOU CHEN <42710806+CTY-git@users.noreply.github.com> Date: Thu, 13 Feb 2025 22:56:54 +0800 Subject: [PATCH] fix gemini issue --- patchwork/common/client/llm/google.py | 8 ++++---- patchwork/steps/ReadEmail/ReadEmail.py | 12 ++++++++---- patchwork/steps/SendEmail/SendEmail.py | 4 ++-- patchwork/steps/SimplifiedLLMOnce/typed.py | 1 + pyproject.toml | 2 +- 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/patchwork/common/client/llm/google.py b/patchwork/common/client/llm/google.py index 465feb39..4168e41c 100644 --- a/patchwork/common/client/llm/google.py +++ b/patchwork/common/client/llm/google.py @@ -108,7 +108,7 @@ def is_prompt_supported( model=model, contents=chat, config=CountTokensConfig( - system_instructions=system, + system_instruction=system, ), ) token_count = token_response.total_tokens @@ -181,7 +181,7 @@ def chat_completion( config=GenerateContentConfig( system_instruction=system_content, safety_settings=self.__SAFETY_SETTINGS, - **generation_dict, + **NotGiven.remove_not_given(generation_dict), ), ) return self.__google_response_to_openai_response(response, model) @@ -189,7 +189,7 @@ def chat_completion( @staticmethod def __google_response_to_openai_response(google_response: GenerateContentResponse, model: str) -> ChatCompletion: choices = [] - for candidate in google_response.candidates: + for index, candidate in enumerate(google_response.candidates): # note that instead of system, from openai, its model, from google. parts = [part.text or part.inline_data for part in candidate.content.parts] @@ -202,7 +202,7 @@ def __google_response_to_openai_response(google_response: GenerateContentRespons choice = Choice( finish_reason=finish_reason_map.get(candidate.finish_reason, "stop"), - index=candidate.index, + index=index, message=ChatCompletionMessage( content="\n".join(parts), role="assistant", diff --git a/patchwork/steps/ReadEmail/ReadEmail.py b/patchwork/steps/ReadEmail/ReadEmail.py index 49ef0bab..d3b04b9b 100644 --- a/patchwork/steps/ReadEmail/ReadEmail.py +++ b/patchwork/steps/ReadEmail/ReadEmail.py @@ -13,9 +13,11 @@ from patchwork.step import Step from patchwork.steps.ReadEmail.typed import ReadEmailInputs, ReadEmailOutputs + class InnerParsedHeader(BaseModel): message_id: list[str] = Field(alias="message-id") + class ParsedHeader(BaseModel): subject: str = "" from_: str = Field(alias="from", default_factory=str) @@ -91,10 +93,12 @@ def run(self) -> dict: for content_transfer_encoding in attachment.content_header.content_transfer_encoding: content = self.__decode(content_transfer_encoding, content) f.write(content) - rv["attachments"].append(dict( - path=str(file_path), - # content=content.decode(), - )) + rv["attachments"].append( + dict( + path=str(file_path), + # content=content.decode(), + ) + ) for body in email_data.body: rv["body"] += body.content diff --git a/patchwork/steps/SendEmail/SendEmail.py b/patchwork/steps/SendEmail/SendEmail.py index a6105764..5042041f 100644 --- a/patchwork/steps/SendEmail/SendEmail.py +++ b/patchwork/steps/SendEmail/SendEmail.py @@ -27,8 +27,8 @@ def run(self) -> dict: msg["From"] = self.sender_email msg["To"] = self.recipient_email if self.reply_message_id is not None: - msg.add_header('Reference', self.reply_message_id) - msg.add_header('In-Reply-To', self.reply_message_id) + msg.add_header("Reference", self.reply_message_id) + msg.add_header("In-Reply-To", self.reply_message_id) # TODO: support smtp without ssl with smtplib.SMTP_SSL(self.smtp_host, self.smtp_port) as smtp_server: diff --git a/patchwork/steps/SimplifiedLLMOnce/typed.py b/patchwork/steps/SimplifiedLLMOnce/typed.py index aced8838..631b2cd6 100644 --- a/patchwork/steps/SimplifiedLLMOnce/typed.py +++ b/patchwork/steps/SimplifiedLLMOnce/typed.py @@ -38,6 +38,7 @@ class SimplifiedLLMOnceInputs(__SimplifiedLLMOncePBInputsRequired, total=False): google_api_key: Annotated[ str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key"]) ] + file: Annotated[str, StepTypeConfig(is_path=True)] class SimplifiedLLMOnceOutputs(TypedDict): diff --git a/pyproject.toml b/pyproject.toml index 0caad3a7..811a9a2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "patchwork-cli" -version = "0.0.99.dev2" +version = "0.0.99.dev3" description = "" authors = ["patched.codes"] license = "AGPL"