diff --git a/farm/modeling/prediction_head.py b/farm/modeling/prediction_head.py
index 00b79df65..f21037939 100644
--- a/farm/modeling/prediction_head.py
+++ b/farm/modeling/prediction_head.py
@@ -739,7 +739,7 @@ def __init__(self, layer_dims, task_name="question_answering", no_ans_threshold=
         :type no_ans_threshold: float
         :param context_window_size: The size, in characters, of the window around the answer span that is used when displaying the context around the answer.
         :type context_window_size: int
-        :param n_best: The number of candidate positive answer spans to consider from each passage
+        :param n_best: The number of candidate positive answer spans to consider from each passage. The same value is used as the number of candidates to be considered on document level.
         :type n_best: int
         """
         super(QuestionAnsweringHead, self).__init__()
@@ -1198,11 +1198,13 @@ def reduce_preds(self, preds):
                             for passage_preds in preds
                             for start, end, score in passage_preds
                             if not (start == -1 and end == -1)]
-        pos_answers_sorted = sorted(pos_answers_flat, key=lambda x: x[2], reverse=True)
+
+        pos_answer_dedup = self.deduplicate(pos_answers_flat)
+        pos_answers_sorted = sorted(pos_answer_dedup, key=lambda x: x[2], reverse=True)
         pos_answers_reduced = pos_answers_sorted[:self.n_best]
         no_answer_pred = [-1, -1, max(no_answer_scores)]
 
-        # TODO this is how big the no_answer threshold needs to be to change a no_answer to a pos answer
+        # This is how big the no_answer threshold needs to be to change a no_answer to a pos answer
         # (or vice versa). This can in future be used to train the threshold value
         no_ans_gap = max([nas - pbs for nas, pbs in zip(no_answer_scores, passage_best_score)])
 
@@ -1212,6 +1214,21 @@ def reduce_preds(self, preds):
             n_preds = pos_answers_reduced
         return n_preds, no_ans_gap
 
+    @staticmethod
+    def deduplicate(flat_pos_answers):
+        # Remove duplicate spans that might be predicted twice in two different passages
+        seen = {}
+        for (start, end, score) in flat_pos_answers:
+            if (start, end) not in seen:
+                seen[(start, end)] = score
+            else:
+                seen_score = seen[(start, end)]
+                if score > seen_score:
+                    seen[(start, end)] = score
+        return [(start, end, score) for (start, end), score in seen.items()]
+
+
     ## THIS IS A SIMPLER IMPLEMENTATION OF PICKING BEST ANSWERS FOR A DOCUMENT. MATCHES THE HUGGINGFACE METHOD
     # @staticmethod
     # def reduce_preds(preds, n_best=5):
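
For reference, below is a minimal standalone sketch of the deduplication step introduced in this diff. The function body mirrors the new QuestionAnsweringHead.deduplicate in slightly condensed form; the sample spans and scores are made up purely for illustration.

    # Standalone sketch of the deduplication logic added in this diff.
    # Duplicate (start, end) spans (e.g. the same answer found in two
    # overlapping passages) are collapsed, keeping the highest score.

    def deduplicate(flat_pos_answers):
        seen = {}
        for (start, end, score) in flat_pos_answers:
            if (start, end) not in seen or score > seen[(start, end)]:
                seen[(start, end)] = score
        return [(start, end, score) for (start, end), score in seen.items()]

    if __name__ == "__main__":
        # The span (10, 15) is predicted in two passages with different scores
        candidates = [(10, 15, 0.8), (42, 47, 0.6), (10, 15, 0.9)]
        print(deduplicate(candidates))
        # -> [(10, 15, 0.9), (42, 47, 0.6)]  only the higher-scoring duplicate survives

Sorting and the n_best cut then happen on this deduplicated list, so document-level candidates are no longer wasted on repeated spans from overlapping passages.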