From e2b30dbcf1cf10b734aa5cc3ea814ea0ddd425b3 Mon Sep 17 00:00:00 2001 From: p3n Date: Wed, 11 Sep 2024 23:50:10 -0400 Subject: [PATCH] created renderer prompt class --- oproof/manager.py | 3 ++- oproof/renderer.py | 32 ++++++++++++++++++++++++++++++++ oproof/serializer.py | 3 ++- oproof/task.py | 36 +++++++----------------------------- oproof/template.py | 37 ++++++++++++++++++++++++------------- oproof/validator.py | 3 ++- 6 files changed, 69 insertions(+), 45 deletions(-) create mode 100644 oproof/renderer.py diff --git a/oproof/manager.py b/oproof/manager.py index febb72b..a862a57 100644 --- a/oproof/manager.py +++ b/oproof/manager.py @@ -31,7 +31,8 @@ def validate_response(self, prompt: str, response: str) -> Dict[str, Any]: 'is_valid': validation_result['is_valid'], 'domain': domain, 'context': context, - 'reason': validation_result.get('reason', None) + 'reason': validation_result.get('reason', None), + 'raw_response': validation_result.get('raw_response', "") } def generate_prompt(self, prompt: str, response: str) -> str: diff --git a/oproof/renderer.py b/oproof/renderer.py new file mode 100644 index 0000000..f2e3a83 --- /dev/null +++ b/oproof/renderer.py @@ -0,0 +1,32 @@ +from jinja2 import Template as T +from .template import Template +from .constants import Const + +class Renderer: + @staticmethod + def render_prompt(prompt: str, response: str, system_prompt: str, instructions: str, cfg) -> str: + task_name = list(Template.TASKS.keys())[0] + rendered_template = Template.TEMPLATES["validation"].render( + system=system_prompt, + task=task_name, + text=prompt, + example=Template.TASKS[task_name], + instructions=instructions, + lang=cfg.lang, + prompt=prompt, + response=response + ) + return Renderer._post_process(rendered_template) + + @staticmethod + def _post_process(prompt: str) -> str: + task_name = list(Template.TASKS.keys())[0] + replacements = { + "{{ task }}": task_name, + "{{ lang }}": Const.LANG_DEFAULT, + "{{ system_type }}": Template.SYSTEM_TYPE, + "{{ domains }}": ', '.join(Template.DOMAINS) + } + for key, value in replacements.items(): + prompt = prompt.replace(key, value) + return prompt diff --git a/oproof/serializer.py b/oproof/serializer.py index d27aff7..f643e03 100644 --- a/oproof/serializer.py +++ b/oproof/serializer.py @@ -12,7 +12,8 @@ def serialize_output(text: str, responses: List[Dict[str, Any]], response_prompt "is_valid": response.get("is_valid", False), "domain": response.get("domain", "unknown"), "context": response.get("context", "unknown"), - "reason": response.get("reason", "No reason provided") + "reason": response.get("reason", "No reason provided"), + "raw_response": response.get("raw_response", "") # Include the raw response } for response in responses ] diff --git a/oproof/task.py b/oproof/task.py index 6105f9e..faf7e83 100644 --- a/oproof/task.py +++ b/oproof/task.py @@ -4,6 +4,7 @@ from .constants import Const from .template import Template from .log import Log +from .renderer import Renderer import ollama as oll from httpx import ConnectError @@ -20,11 +21,10 @@ def run(self, cmd: List[str], error_msg: str = Const.RUN_COMMAND_ERROR) -> None: self._log_error_and_raise(error_msg, error_msg) def execute(self, prompt: str, response: str, template, system_prompt, instructions) -> Dict[str, Any]: - rendered_prompt = self._render_prompt(prompt, response, template, system_prompt, instructions) - processed_prompt = self._post_process(rendered_prompt) - Log.debug(f"Prompt: {processed_prompt}") + rendered_prompt = Renderer.render_prompt(prompt, response, system_prompt, instructions, self.cfg) + Log.debug(f"Prompt: {rendered_prompt}") - output = self._generate_output(processed_prompt) + output = self._generate_output(rendered_prompt) Log.debug(f"Response: {output}") Log.debug(Const.PROMPT_SEPARATOR) @@ -34,33 +34,12 @@ def execute(self, prompt: str, response: str, template, system_prompt, instructi parsed_response = self._parse_response(output['response']) Log.debug(f"Parsed Response: {parsed_response}") - return {"prompt": processed_prompt, "data": output['response'], "response": parsed_response} + return {"prompt": rendered_prompt, "data": output['response'], "response": parsed_response} def _log_error_and_raise(self, error_message: str, exception_message: str) -> None: Log.error(error_message) raise Exception(exception_message) - def _render_prompt(self, prompt: str, response: str, template, system_prompt, instructions) -> str: - task_name = list(Template.TASKS.keys())[0] - return template.render( - system=system_prompt, - task=task_name, - text=prompt, - example=Template.TASKS[task_name], - instructions=instructions, - lang=self.cfg.lang - ) - - def _post_process(self, prompt: str) -> str: - task_name = list(Template.TASKS.keys())[0] - replacements = { - "{{ task }}": task_name, - "{{ lang }}": Const.LANG_DEFAULT - } - for key, value in replacements.items(): - prompt = prompt.replace(key, value) - return prompt - def _generate_output(self, prompt: str) -> Dict[str, Any]: try: return oll.generate(prompt=prompt, model=self.cfg.model) @@ -74,9 +53,8 @@ def _parse_response(self, response: str) -> Any: try: corrected_response = self._correct_response(response) return json.loads(corrected_response) - except json.JSONDecodeError as e: - Log.error(f"JSON decode error: {e}") - return {"error": "Invalid JSON response"} + except Exception as ex: + return {"error": str(ex)} def _correct_response(self, response: str) -> str: if response.startswith("[") and response.endswith("]"): diff --git a/oproof/template.py b/oproof/template.py index 3faf102..e801203 100644 --- a/oproof/template.py +++ b/oproof/template.py @@ -1,32 +1,43 @@ from jinja2 import Template as T class Template: - INSTRUCTIONS = ( - "Provide the task as plain JSON, no explanations or markdown.\n" - "Return a JSON object with the validation result.\n" - "The object should include 'is_valid', 'domain', 'context', and 'reason' fields.\n" - "No markdown or code.\n" - "Do not answer the input; only validate the response.\n" - "No explanations; only the JSON object." + SYSTEM_TYPE = ( + "proof validation" ) + DOMAINS = [ + "basic math", + "grammar", + "spelling" + ] + SYSTEM_PROMPTS = { - "validation": "You are an expert validation system that validates responses for {{ task }}s. Validate the following response in {{ lang }}." + "validation": "You are an expert {{ system_type }} system that identifies the domain for {{ task }}. Identify the domain of the following response in {{ lang }}." } + + INSTRUCTIONS = ( + "You are an expert {{ system_type }} system. Your task is to identify the domain and context of the given pair of prompt and response strings.\n" + "Return the domain and context as plain text.\n" + "Do not provide any explanations, markdown, code, or other content beside a JSON Object.\n" + "Only return the domain and context of input prompt and response pair.\n" + "Return JSON Object of type { \"domain\": domain, \"context\": context }\n" + "The domains to choose from are: {{ domains }}.\n" + ) - PROMPT_TEMPLATE = ( + PROMPT = ( "System: {{ system }}\n" "Instructions: {{ instructions }}\n" - "Example: {{ example }}\n" + "Example: 'What is 2 + 2?' '4' returns 'basic math' with context 'arithmetic'\n" "User: {{ prompt }}\n" "Response: {{ response }}\n" - "System: Return only a JSON object with the validation result. No explanations, only JSON object; e.g., {\"is_valid\": true, \"domain\": \"basic math\", \"context\": \"arithmetic\", \"reason\": null}" + "System: Return only the domain and context. No explanations, only the domain and context; e.g., {\"domain\": \"basic math\", \"context\": \"arithmetic\"}\n" + "The domains to choose from are: {{ domains }}.\n" ) TEMPLATES = { - "validation": T(PROMPT_TEMPLATE) + "validation": T(PROMPT) } TASKS = { - "proofs": "Proof the given prompt and response pair of input text strings. e.g., 'What is 2 + 2?' '4' returns {\"is_valid\": true, \"domain\": \"basic math\", \"context\": \"arithmetic\", \"reason\": null}" + "proofs": "Proof the given prompt and response pair of input text strings. e.g., 'What is 2 + 2?' '4' returns 'basic math' with context 'arithmetic'\n" } diff --git a/oproof/validator.py b/oproof/validator.py index d97ec54..f382e8d 100644 --- a/oproof/validator.py +++ b/oproof/validator.py @@ -35,5 +35,6 @@ def _collect_validation(self, prompt: str, response: str) -> Dict[str, Any]: 'is_valid': is_valid, 'domain': domain, 'context': context, - 'reason': reason + 'reason': reason, + 'raw_response': validation_result # Include the raw response }