diff --git a/app/topic_prompt/topic_prompt_generator.py b/app/topic_prompt/topic_prompt_generator.py index 7c70f8e..cc7cd4f 100644 --- a/app/topic_prompt/topic_prompt_generator.py +++ b/app/topic_prompt/topic_prompt_generator.py @@ -8,6 +8,7 @@ from topic_prompt.toprompt_llm_prompt import TopromptLLMPrompt, get_output_parser from topic_prompt.toprompt import Toprompt, TopromptOptions from topic_prompt.differentiate_writing import repeated_phrase +from util.general import escape_json_inner_quotes from langchain.prompts import PromptTemplate from basic_langchain.chat_models import ChatOpenAI @@ -32,7 +33,7 @@ def _get_toprompt_options(lang: str, topic: Topic, source: TopicPromptSource, ot responses += [HumanMessage(content=secondary_prompt.format())] output_parser = get_output_parser() - parsed_output = output_parser.parse(curr_response.content) + parsed_output = output_parser.parse(escape_json_inner_quotes(curr_response.content)) parsed_output.title = _remove_colon_from_title_with_validation(responses, parsed_output.title) topic_prompts += [Toprompt(topic, source, parsed_output.why, parsed_output.what, parsed_output.title)] @@ -46,7 +47,7 @@ def _get_toprompt_options(lang: str, topic: Topic, source: TopicPromptSource, ot partial_variables={"phrase": phrase_to_avoid, "format_instructions": get_output_parser().get_format_instructions()}) curr_response = llm([human_message] + responses + [HumanMessage(content=avoid_prompt.format())]) output_parser = get_output_parser() - parsed_output = output_parser.parse(curr_response.content) + parsed_output = output_parser.parse(escape_json_inner_quotes(curr_response.content)) parsed_output.title = _remove_colon_from_title_with_validation(responses + [curr_response], parsed_output.title) topic_prompts[-1] = Toprompt(topic, source, parsed_output.why, parsed_output.what, parsed_output.title) diff --git a/app/util/general.py b/app/util/general.py index 7ef1d21..43ecf99 100644 --- a/app/util/general.py +++ b/app/util/general.py @@ -8,6 +8,23 @@ from basic_langchain.schema import SystemMessage, HumanMessage +def escape_json_inner_quotes(json_string): + """ + Given a JSON string, escape all double quotes that are in values to avoid invalid JSON + Assumes JSON is pretty for + + :param json_string: + :return: + """ + pattern = r'(:\s*")(.*?)(?="[,}\n])' + + def escape_quotes(match): + # Escape quotes within the matched group + return match.group(1) + match.group(2).replace('"', '\\"') + + return re.sub(pattern, escape_quotes, json_string) + + def get_source_text_with_fallback(source: TopicPromptSource, lang: str, auto_translate=False) -> str: text = source.text.get(lang, "") other_lang = "en" if lang == "he" else "he"