Skip to content

Commit

Permalink
created renderer prompt class
Browse files Browse the repository at this point in the history
  • Loading branch information
p3nGu1nZz committed Sep 12, 2024
1 parent 377a2d5 commit e2b30db
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 45 deletions.
3 changes: 2 additions & 1 deletion oproof/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def validate_response(self, prompt: str, response: str) -> Dict[str, Any]:
'is_valid': validation_result['is_valid'],
'domain': domain,
'context': context,
'reason': validation_result.get('reason', None)
'reason': validation_result.get('reason', None),
'raw_response': validation_result.get('raw_response', "")
}

def generate_prompt(self, prompt: str, response: str) -> str:
Expand Down
32 changes: 32 additions & 0 deletions oproof/renderer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from jinja2 import Template as T
from .template import Template
from .constants import Const

class Renderer:
@staticmethod
def render_prompt(prompt: str, response: str, system_prompt: str, instructions: str, cfg) -> str:
task_name = list(Template.TASKS.keys())[0]
rendered_template = Template.TEMPLATES["validation"].render(
system=system_prompt,
task=task_name,
text=prompt,
example=Template.TASKS[task_name],
instructions=instructions,
lang=cfg.lang,
prompt=prompt,
response=response
)
return Renderer._post_process(rendered_template)

@staticmethod
def _post_process(prompt: str) -> str:
task_name = list(Template.TASKS.keys())[0]
replacements = {
"{{ task }}": task_name,
"{{ lang }}": Const.LANG_DEFAULT,
"{{ system_type }}": Template.SYSTEM_TYPE,
"{{ domains }}": ', '.join(Template.DOMAINS)
}
for key, value in replacements.items():
prompt = prompt.replace(key, value)
return prompt
3 changes: 2 additions & 1 deletion oproof/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ def serialize_output(text: str, responses: List[Dict[str, Any]], response_prompt
"is_valid": response.get("is_valid", False),
"domain": response.get("domain", "unknown"),
"context": response.get("context", "unknown"),
"reason": response.get("reason", "No reason provided")
"reason": response.get("reason", "No reason provided"),
"raw_response": response.get("raw_response", "") # Include the raw response
}
for response in responses
]
Expand Down
36 changes: 7 additions & 29 deletions oproof/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .constants import Const
from .template import Template
from .log import Log
from .renderer import Renderer
import ollama as oll
from httpx import ConnectError

Expand All @@ -20,11 +21,10 @@ def run(self, cmd: List[str], error_msg: str = Const.RUN_COMMAND_ERROR) -> None:
self._log_error_and_raise(error_msg, error_msg)

def execute(self, prompt: str, response: str, template, system_prompt, instructions) -> Dict[str, Any]:
rendered_prompt = self._render_prompt(prompt, response, template, system_prompt, instructions)
processed_prompt = self._post_process(rendered_prompt)
Log.debug(f"Prompt: {processed_prompt}")
rendered_prompt = Renderer.render_prompt(prompt, response, system_prompt, instructions, self.cfg)
Log.debug(f"Prompt: {rendered_prompt}")

output = self._generate_output(processed_prompt)
output = self._generate_output(rendered_prompt)
Log.debug(f"Response: {output}")
Log.debug(Const.PROMPT_SEPARATOR)

Expand All @@ -34,33 +34,12 @@ def execute(self, prompt: str, response: str, template, system_prompt, instructi
parsed_response = self._parse_response(output['response'])
Log.debug(f"Parsed Response: {parsed_response}")

return {"prompt": processed_prompt, "data": output['response'], "response": parsed_response}
return {"prompt": rendered_prompt, "data": output['response'], "response": parsed_response}

def _log_error_and_raise(self, error_message: str, exception_message: str) -> None:
Log.error(error_message)
raise Exception(exception_message)

def _render_prompt(self, prompt: str, response: str, template, system_prompt, instructions) -> str:
task_name = list(Template.TASKS.keys())[0]
return template.render(
system=system_prompt,
task=task_name,
text=prompt,
example=Template.TASKS[task_name],
instructions=instructions,
lang=self.cfg.lang
)

def _post_process(self, prompt: str) -> str:
task_name = list(Template.TASKS.keys())[0]
replacements = {
"{{ task }}": task_name,
"{{ lang }}": Const.LANG_DEFAULT
}
for key, value in replacements.items():
prompt = prompt.replace(key, value)
return prompt

def _generate_output(self, prompt: str) -> Dict[str, Any]:
try:
return oll.generate(prompt=prompt, model=self.cfg.model)
Expand All @@ -74,9 +53,8 @@ def _parse_response(self, response: str) -> Any:
try:
corrected_response = self._correct_response(response)
return json.loads(corrected_response)
except json.JSONDecodeError as e:
Log.error(f"JSON decode error: {e}")
return {"error": "Invalid JSON response"}
except Exception as ex:
return {"error": str(ex)}

def _correct_response(self, response: str) -> str:
if response.startswith("[") and response.endswith("]"):
Expand Down
37 changes: 24 additions & 13 deletions oproof/template.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,43 @@
from jinja2 import Template as T

class Template:
INSTRUCTIONS = (
"Provide the task as plain JSON, no explanations or markdown.\n"
"Return a JSON object with the validation result.\n"
"The object should include 'is_valid', 'domain', 'context', and 'reason' fields.\n"
"No markdown or code.\n"
"Do not answer the input; only validate the response.\n"
"No explanations; only the JSON object."
SYSTEM_TYPE = (
"proof validation"
)

DOMAINS = [
"basic math",
"grammar",
"spelling"
]

SYSTEM_PROMPTS = {
"validation": "You are an expert validation system that validates responses for {{ task }}s. Validate the following response in {{ lang }}."
"validation": "You are an expert {{ system_type }} system that identifies the domain for {{ task }}. Identify the domain of the following response in {{ lang }}."
}

INSTRUCTIONS = (
"You are an expert {{ system_type }} system. Your task is to identify the domain and context of the given pair of prompt and response strings.\n"
"Return the domain and context as plain text.\n"
"Do not provide any explanations, markdown, code, or other content beside a JSON Object.\n"
"Only return the domain and context of input prompt and response pair.\n"
"Return JSON Object of type { \"domain\": domain, \"context\": context }\n"
"The domains to choose from are: {{ domains }}.\n"
)

PROMPT_TEMPLATE = (
PROMPT = (
"System: {{ system }}\n"
"Instructions: {{ instructions }}\n"
"Example: {{ example }}\n"
"Example: 'What is 2 + 2?' '4' returns 'basic math' with context 'arithmetic'\n"
"User: {{ prompt }}\n"
"Response: {{ response }}\n"
"System: Return only a JSON object with the validation result. No explanations, only JSON object; e.g., {\"is_valid\": true, \"domain\": \"basic math\", \"context\": \"arithmetic\", \"reason\": null}"
"System: Return only the domain and context. No explanations, only the domain and context; e.g., {\"domain\": \"basic math\", \"context\": \"arithmetic\"}\n"
"The domains to choose from are: {{ domains }}.\n"
)

TEMPLATES = {
"validation": T(PROMPT_TEMPLATE)
"validation": T(PROMPT)
}

TASKS = {
"proofs": "Proof the given prompt and response pair of input text strings. e.g., 'What is 2 + 2?' '4' returns {\"is_valid\": true, \"domain\": \"basic math\", \"context\": \"arithmetic\", \"reason\": null}"
"proofs": "Proof the given prompt and response pair of input text strings. e.g., 'What is 2 + 2?' '4' returns 'basic math' with context 'arithmetic'\n"
}
3 changes: 2 additions & 1 deletion oproof/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ def _collect_validation(self, prompt: str, response: str) -> Dict[str, Any]:
'is_valid': is_valid,
'domain': domain,
'context': context,
'reason': reason
'reason': reason,
'raw_response': validation_result # Include the raw response
}

0 comments on commit e2b30db

Please sign in to comment.