From 3f8c76da3d96670074b440f97cb3357e06e3f57c Mon Sep 17 00:00:00 2001 From: nsantacruz Date: Thu, 22 Feb 2024 14:42:58 +0200 Subject: [PATCH] feat: all either dict or interface class to be passed --- .../topic_prompt/topic_prompt_input.py | 6 +++--- .../topic_prompt/topic_prompt_output.py | 6 +++--- .../topic_prompt/topic_prompt_source.py | 8 ++++---- app/topic_prompt/tasks.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_input.py b/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_input.py index e675934..912b499 100644 --- a/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_input.py +++ b/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_input.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Union from dataclasses import dataclass from sefaria_llm_interface import Topic from sefaria_llm_interface.topic_prompt import TopicPromptSource @@ -10,7 +10,7 @@ class TopicPromptInput: topic: Topic sources: List[TopicPromptSource] - def __init__(self, lang: str, topic: dict, sources: List[dict]): + def __init__(self, lang: str, topic: dict, sources: List[Union[dict, TopicPromptSource]]): self.lang = lang self.topic = Topic(**topic) - self.sources = [TopicPromptSource(**raw_source) for raw_source in sources] + self.sources = [s if isinstance(s, TopicPromptSource) else TopicPromptSource(**s) for s in sources] diff --git a/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_output.py b/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_output.py index cf95f1a..03e0058 100644 --- a/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_output.py +++ b/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_output.py @@ -1,7 +1,7 @@ """ Classes for instantiating objects received from the topic prompt generator """ -from typing import List +from typing import List, Union from dataclasses import dataclass @@ -18,6 +18,6 @@ class TopicPromptGenerationOutput: lang: str prompts: List[TopicPrompt] - def __init__(self, lang: str, prompts: List[dict]): + def __init__(self, lang: str, prompts: List[Union[dict, TopicPrompt]]): self.lang = lang - self.prompts = [TopicPrompt(**raw_prompt) for raw_prompt in prompts] + self.prompts = [p if isinstance(p, TopicPrompt) else TopicPrompt(**p) for p in prompts] diff --git a/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_source.py b/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_source.py index 4111497..9118b19 100644 --- a/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_source.py +++ b/app/llm_interface/sefaria_llm_interface/topic_prompt/topic_prompt_source.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union from dataclasses import dataclass @@ -23,7 +23,7 @@ class TopicPromptSource: def __init__(self, ref: str, categories: List[str], book_description: Dict[str, str], book_title: Dict[str, str], comp_date: str, author_name: str, context_hint: str, text: Dict[str, str], - commentary: List[dict] = None, surrounding_text: Dict[str, str]=None): + commentary: List[Union[dict, TopicPromptCommentary]] = None, surrounding_text: Dict[str, str]=None): self.ref = ref self.categories = categories self.book_description = book_description @@ -33,7 +33,7 @@ def __init__(self, ref: str, categories: List[str], book_description: Dict[str, self.context_hint = context_hint self.text = text self.commentary = [ - TopicPromptCommentary(**comment) - for comment in (commentary or []) + c if isinstance(c, TopicPromptCommentary) else TopicPromptCommentary(**c) + for c in (commentary or []) ] self.surrounding_text = surrounding_text diff --git a/app/topic_prompt/tasks.py b/app/topic_prompt/tasks.py index 60580ee..cf19819 100644 --- a/app/topic_prompt/tasks.py +++ b/app/topic_prompt/tasks.py @@ -9,6 +9,6 @@ def generate_topic_prompts(raw_topic_prompt_input: dict) -> dict: tp_input = TopicPromptInput(**raw_topic_prompt_input) toprompt_options_list = get_toprompts(tp_input) # only return the first option for now - toprompts = [options.toprompts[0].serialize() for options in toprompt_options_list] + toprompts = [TopicPrompt(**options.toprompts[0].serialize()) for options in toprompt_options_list] output = TopicPromptGenerationOutput(lang=tp_input.lang, prompts=toprompts) return asdict(output)