Skip to content

Commit

Permalink
sav
Browse files Browse the repository at this point in the history
  • Loading branch information
CTY-git committed Feb 5, 2025
1 parent f5eef37 commit 238b79c
Show file tree
Hide file tree
Showing 10 changed files with 825 additions and 556 deletions.
30 changes: 19 additions & 11 deletions patchwork/common/client/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,7 @@ def example_json_to_schema(json_example: str | dict | None) -> ResponseFormat |
if json_example is None:
return None

base_model = None
if isinstance(json_example, str):
base_model = __example_string_to_base_model(json_example)
elif isinstance(json_example, dict):
base_model = __example_dict_to_base_model(json_example)

base_model = example_json_to_base_model(json_example)
if base_model is None:
return None

Expand All @@ -76,25 +71,38 @@ def base_model_to_schema(base_model: Type[BaseModel]) -> ResponseFormat:
return type_to_response_format_param(base_model)


def __example_string_to_base_model(json_example: str) -> Type[BaseModel] | None:
def example_json_to_base_model(json_example: str | dict | None) -> Type[BaseModel] | None:
if json_example is None:
return None

base_model = None
if isinstance(json_example, str):
base_model = example_string_to_base_model(json_example)
elif isinstance(json_example, dict):
base_model = example_dict_to_base_model(json_example)

return base_model


def example_string_to_base_model(json_example: str) -> Type[BaseModel] | None:
try:
example_data = json.loads(json_example)
except Exception as e:
logger.error(f"Failed to parse example json", e)
return None

return __example_dict_to_base_model(example_data)
return example_dict_to_base_model(example_data)


def __example_dict_to_base_model(example_data: dict) -> Type[BaseModel]:
def example_dict_to_base_model(example_data: dict) -> Type[BaseModel]:
base_model_field_defs: dict[str, tuple[type | BaseModel, Field]] = dict()
for example_data_key, example_data_value in example_data.items():
if isinstance(example_data_value, dict):
value_typing = __example_dict_to_base_model(example_data_value)
value_typing = example_dict_to_base_model(example_data_value)
elif isinstance(example_data_value, list):
nested_value = example_data_value[0]
if isinstance(nested_value, dict):
nested_typing = __example_dict_to_base_model(nested_value)
nested_typing = example_dict_to_base_model(nested_value)
else:
nested_typing = type(nested_value)
value_typing = List[nested_typing]
Expand Down
47 changes: 37 additions & 10 deletions patchwork/common/multiturn_strategy/agentic_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@
import sys
from json import JSONDecodeError
from pathlib import Path
from typing import Union, Any

import chevron
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic import BaseModel

from patchwork.common.client.llm.protocol import LlmClient
from patchwork.common.client.llm.utils import example_string_to_base_model, example_json_to_base_model
from patchwork.common.tools import CodeEditTool, Tool
from patchwork.common.tools.agentic_tools import EndTool

Expand Down Expand Up @@ -106,22 +111,46 @@ def __init__(self, llm_client: LlmClient, tool_set: dict[str, Tool], system_prom
self.history.append(dict(role="system", content=system_prompt))




class AgentConfig(BaseModel):
name: str
tool_set: dict[str, Tool]
system_prompt: str = ''


class AgenticStrategy:
def __init__(
self,
llm_client: LlmClient,
tool_set: dict[str, Tool],
api_key: str,
template_data: dict[str, str],
system_prompt_template: str,
user_prompt_template: str,
agent_configs: list[AgentConfig],
example_json: Union[str, dict[str, Any]] = '{"output":"output text"}',
*args,
**kwargs,
):
self.tool_set = dict(end=EndTool(), **tool_set)
self.__template_data = template_data
self.__user_prompt_template = user_prompt_template
self.__assistant_role = Assistant(llm_client, self.tool_set, self.__render_prompt(system_prompt_template))
self.__user_role = UserProxy(llm_client, dict())
model = AnthropicModel("claude-3-5-sonnet-latest", api_key=api_key)
self.__user_role = Agent(
model,
system_prompt=self.__render_prompt(system_prompt_template),
result_type=example_json_to_base_model(example_json),
)
self.__assistants = []
for assistant_config in agent_configs:
tools = []
for tool in assistant_config.tool_set.values():
tools.append(tool.to_pydantic_ai_function_tool())
assistant = Agent(
"claude-3-5-sonnet-latest",
system_prompt=self.__render_prompt(assistant_config.system_prompt),
tools=tools
)

self.__assistants.append(assistant)

def __render_prompt(self, prompt_template: str) -> str:
chevron.render.__globals__["_html_escape"] = lambda x: x
Expand All @@ -133,9 +162,6 @@ def __render_prompt(self, prompt_template: str) -> str:
partials_dict=dict(),
)

def __get_initial_prompt(self) -> list[ChatCompletionMessageParam]:
return [dict(role="user", content=self.__render_prompt(self.__user_prompt_template))]

def __is_session_completed(self) -> bool:
for message in reversed(self.__assistant_role.history):
if message.get("tool") is not None:
Expand All @@ -149,9 +175,10 @@ def execute(self, limit: int | None = None) -> None:
message = self.__render_prompt(self.__user_prompt_template)
try:
for i in range(limit or self.__limit or sys.maxsize):
self.__user_role.run_sync(self.__user_prompt_template)
self.run_count = i + 1
for role in [self.__assistant_role, self.__user_role]:
message = role.generate_reply(message)
for role in [*self.__assistants, self.__user_role]:
message = role.run_sync(message)
if self.__is_session_completed():
break
except Exception as e:
Expand Down
8 changes: 6 additions & 2 deletions patchwork/common/tools/code_edit_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import Literal
from typing import Literal, Optional

from patchwork.common.tools.tool import Tool
from patchwork.common.utils.utils import detect_newline
from pydantic import BaseModel
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.test import TestModel
from pydantic_ai.tools import Tool, ToolDefinition


class CodeEditTool(Tool, tool_name="code_edit_tool"):
Expand Down
10 changes: 10 additions & 0 deletions patchwork/common/tools/tool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Type
from pydantic_ai.tools import ToolDefinition, Tool as PydanticTool, RunContext


class Tool(ABC):
Expand Down Expand Up @@ -42,3 +43,12 @@ def get_description(tooling: "ToolProtocol") -> str:
@staticmethod
def get_parameters(tooling: "ToolProtocol") -> str:
return ", ".join(tooling.json_schema.get("required", []))

def to_pydantic_ai_function_tool(self) -> PydanticTool[None]:
async def _prep(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition:
tool_def.name = self.name
tool_def.description = self.json_schema.get("description", "")
tool_def.parameters_json_schema = self.json_schema.get("input_schema", {})
return tool_def

return PydanticTool(self.execute, prepare=_prep)
17 changes: 1 addition & 16 deletions patchwork/patchflows/LogAnalysis/LogAnalysis.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,11 @@
from enum import IntEnum
from pathlib import Path

import yaml

from patchwork.common.utils.progress_bar import PatchflowProgressBar
from patchwork.common.utils.step_typing import validate_steps_with_inputs
from patchwork.logger import logger
from patchwork.step import Step
from patchwork.steps import (
LLM,
PR,
CallLLM,
CommitChanges,
CreatePR,
ExtractCode,
ExtractModelResponse,
ModifyCode,
PreparePR,
PreparePrompt,
ScanSemgrep,
ScanSonar, CallSQL, AgenticLLM,
)
from patchwork.steps import AgenticLLM, CallSQL

_DEFAULT_INPUT_FILE = Path(__file__).parent / "defaults.yml"

Expand Down
2 changes: 0 additions & 2 deletions patchwork/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ def __init__(self, inputs: DataPoint):
self.run = self.__managed_run

def __init_subclass__(cls, input_class: Optional[Type] = None, output_class: Optional[Type] = None, **kwargs):
if cls.__name__ == "PreparePR":
print(1)
input_class = input_class or getattr(cls, "input_class", None)
if input_class is not None and not is_typeddict(input_class):
input_class = None
Expand Down
11 changes: 7 additions & 4 deletions patchwork/steps/AgenticLLM/AgenticLLM.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path

from patchwork.common.client.llm.aio import AioLlmClient
from patchwork.common.multiturn_strategy.agentic_strategy import AgenticStrategy
from patchwork.common.multiturn_strategy.agentic_strategy import AgenticStrategy, AgentConfig
from patchwork.common.tools import Tool
from patchwork.step import Step
from patchwork.steps.AgenticLLM.typed import AgenticLLMInputs, AgenticLLMOutputs
Expand All @@ -15,13 +15,16 @@ def __init__(self, inputs):
base_path = str(Path.cwd())
self.conversation_limit = int(int(inputs.get("max_llm_calls", 2)) / 2)
self.agentic_strategy = AgenticStrategy(
llm_client=AioLlmClient.create_aio_client(inputs),
tool_set=Tool.get_tools(path=base_path),
api_key=inputs.get("anthropic_api_key"),
template_data=inputs.get("prompt_value"),
system_prompt_template=inputs.get("system_prompt"),
system_prompt_template="",
user_prompt_template=inputs.get("user_prompt"),
agent_configs=[
AgentConfig(name="", tool_set=Tool.get_tools(path=base_path), system_prompt=inputs.get("system_prompt"))
]
)


def run(self) -> dict:
self.agentic_strategy.execute(limit=self.conversation_limit)
return dict(
Expand Down
1 change: 0 additions & 1 deletion patchwork/steps/CallShell/CallShell.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from patchwork.common.utils.utils import mustache_render
from patchwork.logger import logger
from patchwork.step import Step, StepStatus
from patchwork.steps import CallSQL
from patchwork.steps.CallShell.typed import CallShellInputs, CallShellOutputs


Expand Down
Loading

0 comments on commit 238b79c

Please sign in to comment.