Skip to content

Commit

Permalink
feat: extend QA gen types (#133)
Browse files Browse the repository at this point in the history
Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
  • Loading branch information
vagenas authored Sep 1, 2023
1 parent e663d80 commit 43eeb57
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 8 deletions.
21 changes: 17 additions & 4 deletions deepsearch/model/examples/dummy_qa_generator/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Dict, List, Tuple
from typing import Any, Dict, List, Tuple

from deepsearch.model.base.types import Kind
from deepsearch.model.kinds.qagen.model import BaseQAGenerator
from deepsearch.model.kinds.qagen.types import GenerateAnswersOutput, QAGenConfig
from deepsearch.model.kinds.qagen.types import (
GenerateAnswersOutEntry,
GenerateAnswersOutput,
QAGenConfig,
)


class DummyQAGenerator(BaseQAGenerator):
Expand All @@ -18,10 +22,19 @@ def get_qagen_config(self) -> QAGenConfig:
return self._config

def generate_answers(
self, texts: List[Tuple[List[Dict], str]]
self,
texts: List[Tuple[List[Dict], str]],
extras: Dict[str, Any],
) -> GenerateAnswersOutput:
"""Just answers with the question itself.
Args:
texts: a list of context, question pairs.
extras: any extras to pass.
"""
return [question for _, question in texts]
return [
GenerateAnswersOutEntry(
answer=question,
metadata={"foo": "bar"},
)
for _, question in texts
]
1 change: 1 addition & 0 deletions deepsearch/model/kinds/qagen/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def dispatch_predict(self, spec: CtrlPredInput) -> CtrlPredOutput:
([ctx_entry.dict() for ctx_entry in ctx_list], q)
for ctx_list, q in zip(gen_answers.contexts, gen_answers.questions)
],
extras=gen_answers.extras or {},
)
return QAGenCtrlPredOutput(
answers=answers,
Expand Down
4 changes: 2 additions & 2 deletions deepsearch/model/kinds/qagen/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Dict, List, Tuple
from typing import Any, Dict, List, Tuple

from deepsearch.model.base.model import BaseDSModel
from deepsearch.model.base.types import BaseModelConfig
Expand All @@ -9,7 +9,7 @@
class BaseQAGenerator(BaseDSModel):
@abstractmethod
def generate_answers(
self, texts: List[Tuple[List[Dict], str]]
self, texts: List[Tuple[List[Dict], str]], extras: Dict[str, Any]
) -> GenerateAnswersOutput:
raise NotImplementedError()

Expand Down
10 changes: 8 additions & 2 deletions deepsearch/model/kinds/qagen/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Literal
from typing import Any, Dict, List, Literal, Optional

from pydantic import root_validator

Expand All @@ -20,6 +20,7 @@ class ContextEntry(StrictModel):
class GenerateAnswers(StrictModel):
contexts: List[List[ContextEntry]]
questions: List[str]
extras: Optional[Dict[str, Any]] = None

@root_validator
def check_lengths_match(cls, values):
Expand All @@ -38,7 +39,12 @@ class QAGenAppPredInput(BaseAppPredInput):
spec: QAGenReqSpec


GenerateAnswersOutput = List[str]
class GenerateAnswersOutEntry(StrictModel):
answer: str
metadata: Dict[str, Any]


GenerateAnswersOutput = List[GenerateAnswersOutEntry]


class QAGenCtrlPredOutput(StrictModel):
Expand Down

0 comments on commit 43eeb57

Please sign in to comment.