diff --git a/farm/modeling/predictions.py b/farm/modeling/predictions.py
index 7fedf7e02..adb58e6ef 100644
--- a/farm/modeling/predictions.py
+++ b/farm/modeling/predictions.py
@@ -3,57 +3,79 @@ from typing import List, Optional, Any
 from pydantic import BaseModel
 
 
-class Pred(BaseModel):
+class Pred:
     """
-    Base class for predictions of every task. Note that it inherits from pydantic.BaseModel which creates an __init__()
-    with the attributes defined in this class (i.e. id, prediction, context)
+    Base class for predictions of every task. Its __init__() sets the attributes defined in this class
+    (i.e. id, prediction, context)
     """
-    id: str
-    prediction: List[Any]
-    context: str
-
+    def __init__(self,
+                 id: str,
+                 prediction: List[Any],
+                 context: str):
+        self.id = id
+        self.prediction = prediction
+        self.context = context
 
     def to_json(self):
         raise NotImplementedError
 
-class QACandidate(BaseModel):
+class QACandidate:
     """
-    A single QA candidate answer. Note that it inherits from pydantic.BaseModel which builds the __init__() method.
-    See class definition to find list of compulsory and optional arguments and also comments on how they are used.
+    A single QA candidate answer. Its attributes are set by an explicit __init__() method.
+    See __init__() for the list of compulsory and optional arguments and also comments on how they are used.
     """
-
-    # self.answer_type can be "is_impossible", "yes", "no" or "span"
-    answer_type: str
-    score: float
-    probability: Optional[float] = None
-
-    # If self.answer_type is "span", self.answer is a string answer span
-    # Otherwise, it is None
-    answer: Optional[str] = None
-    offset_answer_start: int
-    offset_answer_end: int
-
-    # If self.answer_type is in ["yes", "no"] then self.answer_support is a text string
-    # If self.answer is a string answer span or self.answer_type is "is_impossible", answer_support is None
-    # TODO sample_idx can probably be removed since we have passage_id
-    answer_support: Optional[str] = None
-    offset_answer_support_start: Optional[int] = None
-    offset_answer_support_end: Optional[int] = None
-    sample_idx: Optional[int] = None
-
-    # self.context is the document or passage where the answer is found
-    context: Optional[str] = None
-    offset_context_start: Optional[int] = None
-    offset_context_end: Optional[int] = None
-
-    # Offset unit is either "token" or "char"
-    # Aggregation level is either "doc" or "passage"
-    offset_unit: str
-    aggregation_level: str
-
-    n_samples_in_doc: Optional[int] = None
-    document_id: Optional[str] = None
-    passage_id: Optional[str] = None
+    def __init__(self,
+                 answer_type: str,
+                 score: float,
+                 offset_answer_start: int,
+                 offset_answer_end: int,
+                 offset_unit: str,
+                 aggregation_level: str,
+                 probability: float=None,
+                 answer: str=None,
+                 answer_support: str=None,
+                 offset_answer_support_start: int=None,
+                 offset_answer_support_end: int=None,
+                 sample_idx: int=None,
+                 context: str=None,
+                 offset_context_start: int=None,
+                 offset_context_end: int=None,
+                 n_samples_in_doc: int=None,
+                 document_id: str=None,
+                 passage_id: str=None
+                 ):
+        # self.answer_type can be "is_impossible", "yes", "no" or "span"
+        self.answer_type = answer_type
+        self.score = score
+        self.probability = probability
+
+        # If self.answer_type is "span", self.answer is a string answer span
+        # Otherwise, it is None
+        self.answer = answer
+        self.offset_answer_start = offset_answer_start
+        self.offset_answer_end = offset_answer_end
+
+        # If self.answer_type is in ["yes", "no"] then self.answer_support is a text string
+        # If self.answer is a string answer span or self.answer_type is "is_impossible", answer_support is None
+        # TODO sample_idx can probably be removed since we have passage_id
+        self.answer_support = answer_support
+        self.offset_answer_support_start = offset_answer_support_start
+        self.offset_answer_support_end = offset_answer_support_end
+        self.sample_idx = sample_idx
+
+        # self.context is the document or passage where the answer is found
+        self.context = context
+        self.offset_context_start = offset_context_start
+        self.offset_context_end = offset_context_end
+
+        # Offset unit is either "token" or "char"
+        # Aggregation level is either "doc" or "passage"
+        self.offset_unit = offset_unit
+        self.aggregation_level = aggregation_level
+
+        self.n_samples_in_doc = n_samples_in_doc
+        self.document_id = document_id
+        self.passage_id = passage_id
 
 
     def to_doc_level(self, start, end):
@@ -81,16 +103,30 @@ class QAPred(Pred):
-    the attributes are found in the Pred class and not here. Pred in turn inherits from pydantic.BaseModel which creates
-    an __init__() method. See class definition for required and optional arguments.
+    the attributes are found in the Pred class and not here. Pred sets them in its own __init__(), which is called
+    from QAPred.__init__(). See __init__() for required and optional arguments.
     """
-
-    question: str
-    token_offsets: List[int]
-    context_window_size: int #TODO only needed for to_json() - can we get rid context_window_size, TODO Do we really need this?
-    aggregation_level: str
-    question_id: Optional[str]
-    answer_types: Optional[List[str]] = []
-    ground_truth_answer: Optional[str] = None
-    no_answer_gap: Optional[float] = None
-    n_samples: int = None
+    def __init__(self,
+                 id: str,
+                 prediction: List[Any],
+                 context: str,
+                 question: str,
+                 token_offsets: List[int],
+                 context_window_size: int,
+                 aggregation_level: str,
+                 answer_types: List[str]=None,
+                 ground_truth_answer: str=None,
+                 no_answer_gap: float=None,
+                 n_samples: int=None,
+                 question_id: str=None
+                 ):
+        super().__init__(id, prediction, context)
+        self.question = question
+        self.token_offsets = token_offsets
+        self.context_window_size = context_window_size #TODO only needed for to_json() - can we get rid context_window_size, TODO Do we really need this?
+        self.aggregation_level = aggregation_level
+        self.answer_types = [] if answer_types is None else answer_types
+        self.ground_truth_answer = ground_truth_answer
+        self.no_answer_gap = no_answer_gap
+        self.n_samples = n_samples
+        self.question_id = question_id
 
     def to_json(self, squad=False):
         answers = self.answers_to_json(squad)
@@ -100,7 +136,7 @@ def to_json(self, squad=False):
                 {
                     "question": self.question,
                     "question_id": self.question_id,
-                    "ground_truth": None,
+                    "ground_truth": self.ground_truth_answer,
                     "answers": answers,
                     "no_ans_gap": self.no_answer_gap # Add no_ans_gap to current no_ans_boost for switching top prediction
                 }
@@ -116,22 +152,21 @@ def answers_to_json(self, squad=False):
             string = qa_answer.answer
             start_t = qa_answer.offset_answer_start
             end_t = qa_answer.offset_answer_end
-            score = qa_answer.score
 
             _, ans_start_ch, ans_end_ch = span_to_string(start_t, end_t, self.token_offsets, self.context)
             context_string, context_start_ch, context_end_ch = self.create_context(ans_start_ch, ans_end_ch, self.context)
             if squad:
                 if string == "is_impossible":
                     string = ""
-            curr = {"score": score,
-                    "probability": -1,
+            curr = {"score": qa_answer.score,
+                    "probability": None,
                     "answer": string,
                     "offset_answer_start": ans_start_ch,
                     "offset_answer_end": ans_end_ch,
                     "context": context_string,
                     "offset_context_start": context_start_ch,
                     "offset_context_end": context_end_ch,
-                    "document_id": self.id}
+                    "document_id": qa_answer.document_id}
             ret.append(curr)
         return ret
 