From 9569bb2674f980eab84edbcd55c1060230d73ce1 Mon Sep 17 00:00:00 2001 From: pingpingy1 Date: Mon, 11 Dec 2023 21:28:12 +0900 Subject: [PATCH] [feat] add testing function --- hopre/core.py | 44 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/hopre/core.py b/hopre/core.py index 9c8724f..e4436ad 100644 --- a/hopre/core.py +++ b/hopre/core.py @@ -1,6 +1,6 @@ """User interface for HoPRe""" -import os +import os, sys, getopt, json from swiplserver import PrologMQI, PrologThread from hopre.utils import assert_all_json, tokenize @@ -155,5 +155,45 @@ def hopre() -> None: print() +def test(filename: str | None = None, encoding: str = "UTF-8") -> None: + """Perform all tests given in the JSON file. + + :param filename: Path to JSON file, if provided + :param encoding: Encoding of JSON file + """ + if not filename: + filename = os.path.join(BASE_DIR, "punTestSuite.json") + + print(f"Running tests in {filename}...\n") + + with open(filename, "r", encoding=encoding) as f: + testcases = json.load(f) + + assert isinstance(testcases, list) + + with PrologMQI(output_file_name=os.path.join(BASE_DIR, "output.log")) as mqi: + with mqi.create_thread() as pthread: + init_kb(pthread) + + for tc in testcases: + assert "input" in tc.keys() + assert "joke" in tc.keys() + + sentences = [tokenize(s) for s in tc["input"]] + result, hopre_ans = check_pun(sentences, pthread) + + print(f'Provided input:\n{tc["input"]}') + print(f'Answer: {"Pun" if tc["joke"] else "Not a pun"}') + print(f'HoPRe: {"Pun" if hopre_ans else "Not a pun"}') + if hopre_ans: + print(f"Analysis: {result}") + print(f'Status: {"SUCCESS" if hopre_ans == tc["joke"] else "FAIL"}') + print() + + if __name__ == "__main__": - hopre() + opts, _ = getopt.getopt(sys.argv[1:], "t", ["test"]) + if opts: + test() + else: + hopre()