From 2195728129509227fcaa26142df8538000be6e9b Mon Sep 17 00:00:00 2001 From: p3n Date: Thu, 12 Sep 2024 01:05:54 -0400 Subject: [PATCH] switch --prompts to --raw makes more sense --- oproof/args.py | 2 +- oproof/constants.py | 4 ++-- oproof/main.py | 6 +++--- oproof/response.py | 2 +- oproof/serializer.py | 10 +++++----- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/oproof/args.py b/oproof/args.py index 7de0fa9..0edf46c 100644 --- a/oproof/args.py +++ b/oproof/args.py @@ -7,7 +7,7 @@ def __init__(self): self.parser.add_argument(Const.ARG_PROMPT_TEXT, type=str, help=Const.ARG_PROMPT_TEXT_HELP) self.parser.add_argument(Const.ARG_RESPONSE_TEXT, type=str, help=Const.ARG_RESPONSE_TEXT_HELP) self.parser.add_argument(Const.ARG_DEBUG, action="store_true", help=Const.ARG_DEBUG_HELP) - self.parser.add_argument(Const.ARG_PROMPTS, action="store_true", help=Const.ARG_PROMPTS_HELP) + self.parser.add_argument(Const.ARG_RAW, action="store_true", help=Const.ARG_RAW_HELP) self.args = None def parse(self): diff --git a/oproof/constants.py b/oproof/constants.py index 524070a..cc740d3 100644 --- a/oproof/constants.py +++ b/oproof/constants.py @@ -20,8 +20,8 @@ class Const: ARG_TEXT_HELP = "Input text" ARG_DEBUG = "--debug" ARG_DEBUG_HELP = "Enable debug logging" - ARG_PROMPTS = "--prompts" - ARG_PROMPTS_HELP = "Include prompts in the output JSON" + ARG_RAW = "--raw" + ARG_RAW_HELP = "Include raw response in the output JSON" ARG_DESCRIPTION = "Oproof script" ARG_PROMPT_TEXT = "prompt" ARG_PROMPT_TEXT_HELP = "Input prompt" diff --git a/oproof/main.py b/oproof/main.py index 8c23cf4..7386163 100644 --- a/oproof/main.py +++ b/oproof/main.py @@ -28,15 +28,15 @@ def run(parsed_args): Log.setup(parsed_args.debug) if parsed_args.debug: Log.start_main_function() - main._execute(parsed_args.prompt, parsed_args.response, parsed_args.debug, parsed_args.prompts) + main._execute(parsed_args.prompt, parsed_args.response, parsed_args.debug, parsed_args.raw) except Exception as e: handle_error(e, parsed_args.debug) - def _execute(self, prompt: str, response: str, debug: bool, include_prompts: bool) -> None: + def _execute(self, prompt: str, response: str, debug: bool, include_raw: bool) -> None: try: self.manager.check_version() validation_result = self.manager.validate_response(prompt, response) - final_result = Serializer.serialize_output(prompt, [validation_result], include_prompts) + final_result = Serializer.serialize_output(prompt, [validation_result], include_raw) json_output = json.dumps(final_result, indent=2, separators=(',', ': ')) console.print(JSON(json_output)) except ValidationError as e: diff --git a/oproof/response.py b/oproof/response.py index 6fadd9f..7e14140 100644 --- a/oproof/response.py +++ b/oproof/response.py @@ -19,7 +19,7 @@ def get_response_data(self) -> Dict[str, Any]: "context": self.parsed_response.get("context", "unknown") } - if self.parsed_response.get("reason") is not None: + if not self.parsed_response.get("is_valid", False) and self.parsed_response.get("reason") is not None: response_data["reason"] = self.parsed_response.get("reason") if self.output.get("raw_response") is not None: diff --git a/oproof/serializer.py b/oproof/serializer.py index e4e238c..1e51325 100644 --- a/oproof/serializer.py +++ b/oproof/serializer.py @@ -2,18 +2,18 @@ class Serializer: @staticmethod - def serialize_output(text: str, responses: List[Dict[str, Any]], include_prompts: bool) -> Dict[str, Any]: + def serialize_output(text: str, responses: List[Dict[str, Any]], include_raw: bool) -> Dict[str, Any]: result = { "original_text": text, "responses": [ - Serializer._serialize_response(response, include_prompts) + Serializer._serialize_response(response, include_raw) for response in responses ] } return result @staticmethod - def _serialize_response(response: Dict[str, Any], include_prompts: bool) -> Dict[str, Any]: + def _serialize_response(response: Dict[str, Any], include_raw: bool) -> Dict[str, Any]: serialized_response = { "prompt": response.get("prompt", ""), "response": response.get("response", ""), @@ -22,10 +22,10 @@ def _serialize_response(response: Dict[str, Any], include_prompts: bool) -> Dict "context": response.get("context", "unknown") } - if response.get("reason") is not None: + if not response.get("is_valid", False) and response.get("reason") is not None: serialized_response["reason"] = response.get("reason") - if include_prompts and response.get("raw_response") is not None: + if include_raw and response.get("raw_response") is not None: serialized_response["raw_response"] = response.get("raw_response") return serialized_response