Skip to content

Commit

Permalink
feat: all either dict or interface class to be passed
Browse files Browse the repository at this point in the history
  • Loading branch information
nsantacruz committed Feb 22, 2024
1 parent 56c798d commit 3f8c76d
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Union
from dataclasses import dataclass
from sefaria_llm_interface import Topic
from sefaria_llm_interface.topic_prompt import TopicPromptSource
Expand All @@ -10,7 +10,7 @@ class TopicPromptInput:
topic: Topic
sources: List[TopicPromptSource]

def __init__(self, lang: str, topic: dict, sources: List[dict]):
def __init__(self, lang: str, topic: dict, sources: List[Union[dict, TopicPromptSource]]):
self.lang = lang
self.topic = Topic(**topic)
self.sources = [TopicPromptSource(**raw_source) for raw_source in sources]
self.sources = [s if isinstance(s, TopicPromptSource) else TopicPromptSource(**s) for s in sources]
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Classes for instantiating objects received from the topic prompt generator
"""
from typing import List
from typing import List, Union
from dataclasses import dataclass


Expand All @@ -18,6 +18,6 @@ class TopicPromptGenerationOutput:
lang: str
prompts: List[TopicPrompt]

def __init__(self, lang: str, prompts: List[dict]):
def __init__(self, lang: str, prompts: List[Union[dict, TopicPrompt]]):
self.lang = lang
self.prompts = [TopicPrompt(**raw_prompt) for raw_prompt in prompts]
self.prompts = [p if isinstance(p, TopicPrompt) else TopicPrompt(**p) for p in prompts]
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict, Optional
from typing import List, Dict, Optional, Union
from dataclasses import dataclass


Expand All @@ -23,7 +23,7 @@ class TopicPromptSource:

def __init__(self, ref: str, categories: List[str], book_description: Dict[str, str], book_title: Dict[str, str],
comp_date: str, author_name: str, context_hint: str, text: Dict[str, str],
commentary: List[dict] = None, surrounding_text: Dict[str, str]=None):
commentary: List[Union[dict, TopicPromptCommentary]] = None, surrounding_text: Dict[str, str]=None):
self.ref = ref
self.categories = categories
self.book_description = book_description
Expand All @@ -33,7 +33,7 @@ def __init__(self, ref: str, categories: List[str], book_description: Dict[str,
self.context_hint = context_hint
self.text = text
self.commentary = [
TopicPromptCommentary(**comment)
for comment in (commentary or [])
c if isinstance(c, TopicPromptCommentary) else TopicPromptCommentary(**c)
for c in (commentary or [])
]
self.surrounding_text = surrounding_text
2 changes: 1 addition & 1 deletion app/topic_prompt/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ def generate_topic_prompts(raw_topic_prompt_input: dict) -> dict:
tp_input = TopicPromptInput(**raw_topic_prompt_input)
toprompt_options_list = get_toprompts(tp_input)
# only return the first option for now
toprompts = [options.toprompts[0].serialize() for options in toprompt_options_list]
toprompts = [TopicPrompt(**options.toprompts[0].serialize()) for options in toprompt_options_list]
output = TopicPromptGenerationOutput(lang=tp_input.lang, prompts=toprompts)
return asdict(output)

0 comments on commit 3f8c76d

Please sign in to comment.