Skip to content

Commit

Permalink
feat: extend QA generation types (#126)
Browse files Browse the repository at this point in the history
Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
  • Loading branch information
vagenas authored Aug 7, 2023
1 parent b93ed58 commit 8900d17
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 6 deletions.
4 changes: 2 additions & 2 deletions deepsearch/model/examples/dummy_qa_generator/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import Dict, List, Tuple

from deepsearch.model.base.types import Kind
from deepsearch.model.kinds.qagen.model import BaseQAGenerator
Expand All @@ -18,7 +18,7 @@ def get_qagen_config(self) -> QAGenConfig:
return self._config

def generate_answers(
self, texts: List[Tuple[List[str], str]]
self, texts: List[Tuple[List[Dict], str]]
) -> GenerateAnswersOutput:
"""Just answers with the question itself.
Args:
Expand Down
5 changes: 4 additions & 1 deletion deepsearch/model/kinds/qagen/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ def dispatch_predict(self, spec: CtrlPredInput) -> CtrlPredOutput:
if isinstance(spec, QAGenReqSpec):
gen_answers = spec.generateAnswers
answers = self._model.generate_answers(
[(c, q) for c, q in zip(gen_answers.contexts, gen_answers.questions)]
texts=[
([ctx_entry.dict() for ctx_entry in ctx_list], q)
for ctx_list, q in zip(gen_answers.contexts, gen_answers.questions)
],
)
return QAGenCtrlPredOutput(
answers=answers,
Expand Down
4 changes: 2 additions & 2 deletions deepsearch/model/kinds/qagen/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Tuple
from typing import Dict, List, Tuple

from deepsearch.model.base.model import BaseDSModel
from deepsearch.model.base.types import BaseModelConfig
Expand All @@ -9,7 +9,7 @@
class BaseQAGenerator(BaseDSModel):
@abstractmethod
def generate_answers(
self, texts: List[Tuple[List[str], str]]
self, texts: List[Tuple[List[Dict], str]]
) -> GenerateAnswersOutput:
raise NotImplementedError()

Expand Down
8 changes: 7 additions & 1 deletion deepsearch/model/kinds/qagen/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@
)


class ContextEntry(StrictModel):
text: str
type: str
representation_type: str


class GenerateAnswers(StrictModel):
contexts: List[List[str]]
contexts: List[List[ContextEntry]]
questions: List[str]

@root_validator
Expand Down

0 comments on commit 8900d17

Please sign in to comment.