Implement Prediction Objects for Question Answering (#405)
* Init

* Proper support for is_impossible answers

* Create add_answer method

* Fix population of QAAnswer

* Prediction objects are now pydantic models

* Benchmark test now passes

Co-authored-by: Timo Moeller <timo.moeller@deepset.ai>
brandenchan and Timoeller authored Jun 15, 2020
1 parent 32ebf6e commit 71d2b8e
Showing 7 changed files with 205 additions and 149 deletions.
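The prediction objects themselves are defined in farm/modeling/predictions.py, which is not among the diffs shown below. As an orientation aid only, here is a minimal sketch of what the two pydantic models could look like, reconstructed from the fields and methods used at the call sites in this commit; field types, defaults, and method bodies are assumptions, not the actual implementation.

from typing import List, Optional
from pydantic import BaseModel

class QACandidate(BaseModel):
    # One candidate answer for a question, on passage or document level
    offset_answer_start: int
    offset_answer_end: int
    score: float
    answer_type: str                       # e.g. "span" or "is_impossible"
    offset_unit: str                       # "token" in this commit
    aggregation_level: str                 # "passage" or "document"
    sample_idx: Optional[int] = None       # index of the passage sample the candidate came from
    n_samples_in_doc: Optional[int] = None
    answer: Optional[str] = None           # answer string, filled in via add_answer()

    def add_answer(self, string: str) -> None:
        # attach the answer string resolved from token offsets and document text
        self.answer = string

    def to_doc_level(self, start: int, end: int) -> None:
        # re-express the candidate with document-level token indices
        self.offset_answer_start = start
        self.offset_answer_end = end
        self.aggregation_level = "document"

class QAPred(BaseModel):
    # All candidate answers for one question on one document
    id: str
    prediction: List[QACandidate]
    context: str                           # the full document text
    question: str
    question_id: str
    token_offsets: List[int]
    context_window_size: int
    aggregation_level: str
    answer_types: List[str] = []
    no_answer_gap: float
    n_samples: int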
farm/data_handler/processor.py (2 changes: 1 addition & 1 deletion)
@@ -207,7 +207,7 @@ def load_from_dir(cls, load_dir):
metric=task["metric"],
label_list=task["label_list"],
label_column_name=task["label_column_name"],
- text_column_name=task["text_column_name"],
+ text_column_name=task.get("text_column_name", None),
task_type=task["task_type"])

if processor is None:
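The processor change above swaps a hard dictionary lookup for dict.get so that saved task configs without a "text_column_name" entry (for instance question-answering tasks) can still be loaded. A tiny illustration with a made-up task dict:

task = {"metric": "squad", "label_list": ["start_token", "end_token"], "task_type": "question_answering"}
# task["text_column_name"] would raise KeyError for a config like this
text_column_name = task.get("text_column_name", None)   # -> None, so load_from_dir keeps working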
farm/evaluation/metrics.py (8 changes: 4 additions & 4 deletions)
@@ -145,8 +145,8 @@ def squad_EM(preds, labels):
n_correct = 0
for doc_idx in range(n_docs):
span = preds[doc_idx][0][0]
- pred_start = span.start
- pred_end = span.end
+ pred_start = span.offset_answer_start
+ pred_end = span.offset_answer_end
curr_labels = labels[doc_idx]
if (pred_start, pred_end) in curr_labels:
n_correct += 1
@@ -165,8 +165,8 @@ def squad_f1(preds, labels):
def squad_f1_single(pred, label, pred_idx=0):
label_start, label_end = label
span = pred[pred_idx]
- pred_start = span.start
- pred_end = span.end
+ pred_start = span.offset_answer_start
+ pred_end = span.offset_answer_end

if (pred_start + pred_end == 0) or (label_start + label_end == 0):
if pred_start == label_start:
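squad_EM and squad_f1 still work on plain (start, end) index pairs; only the attribute names on the prediction object change. A small worked example of the exact-match check with made-up indices:

pred_start, pred_end = 42, 48              # span.offset_answer_start / span.offset_answer_end
curr_labels = [(42, 48), (40, 48)]         # gold answer spans for this document
n_correct = int((pred_start, pred_end) in curr_labels)   # -> 1, an exact match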
farm/modeling/prediction_head.py (147 changes: 91 additions & 56 deletions)
@@ -17,7 +17,7 @@

from farm.data_handler.utils import is_json
from farm.utils import convert_iob_to_simple_tags, span_to_string
- from farm.modeling.predictions import Span, DocumentPred
+ from farm.modeling.predictions import QACandidate, QAPred

logger = logging.getLogger(__name__)

@@ -1104,16 +1104,17 @@ def logits_to_preds(self, logits, padding_mask, start_of_word, seq_2_start_t, ma
# Get the n_best candidate answers for each sample that are valid (via some heuristic checks)
for sample_idx in range(batch_size):
sample_top_n = self.get_top_candidates(sorted_candidates[sample_idx],
- start_end_matrix[sample_idx],
- n_non_padding[sample_idx].item(),
- max_answer_length,
- seq_2_start_t[sample_idx].item())
+ start_end_matrix[sample_idx],
+ n_non_padding[sample_idx].item(),
+ max_answer_length,
+ seq_2_start_t[sample_idx].item(),
+ sample_idx)
all_top_n.append(sample_top_n)

return all_top_n

def get_top_candidates(self, sorted_candidates, start_end_matrix,
- n_non_padding, max_answer_length, seq_2_start_t):
+ n_non_padding, max_answer_length, seq_2_start_t, sample_idx):
""" Returns top candidate answers as a list of Span objects. Operates on a matrix of summed start and end logits.
This matrix corresponds to a single sample (includes special tokens, question tokens, passage tokens).
This method always returns a list of len n_best + 1 (it is comprised of the n_best positive answers along with the one no_answer)"""
@@ -1136,10 +1137,22 @@ def get_top_candidates(self, sorted_candidates, start_end_matrix,
# Check that the candidate's indices are valid and save them if they are
if self.valid_answer_idxs(start_idx, end_idx, n_non_padding, max_answer_length, seq_2_start_t):
score = start_end_matrix[start_idx, end_idx].item()
- top_candidates.append(Span(start_idx, end_idx, score, unit="token", level="passage"))
+ top_candidates.append(QACandidate(offset_answer_start=start_idx,
+ offset_answer_end=end_idx,
+ score=score,
+ answer_type="span",
+ offset_unit="token",
+ aggregation_level="passage",
+ sample_idx=sample_idx))

no_answer_score = start_end_matrix[0, 0].item()
- top_candidates.append(Span(0, 0, no_answer_score, unit="token", pred_str="", level="passage"))
+ top_candidates.append(QACandidate(offset_answer_start=0,
+ offset_answer_end=0,
+ score=no_answer_score,
+ answer_type="is_impossible",
+ offset_unit="token",
+ aggregation_level="passage",
+ sample_idx=None))

return top_candidates

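For orientation: sorted_candidates and start_end_matrix come from the (unchanged) start of logits_to_preds, where entry (i, j) of the matrix is the summed start logit of token i and end logit of token j, and candidates are (start, end) index pairs sorted by that sum. A rough sketch of how such a matrix can be built for one sample; this is simplified and the actual tensor code in this file may differ:

import torch

seq_len = 384
start_logits = torch.randn(seq_len)        # per-token start scores for one passage sample
end_logits = torch.randn(seq_len)          # per-token end scores

start_end_matrix = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)   # (seq_len, seq_len)
flat_order = torch.argsort(start_end_matrix.flatten(), descending=True)
sorted_candidates = torch.stack([torch.div(flat_order, seq_len, rounding_mode="floor"),
                                 flat_order % seq_len], dim=1)
# each row is a (start_idx, end_idx) pair, best-scoring first; candidates are then filtered
# by valid_answer_idxs() before being wrapped in QACandidate objects as shown above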
@@ -1210,12 +1223,12 @@ def formatted_preds(self, logits=None, preds_p=None, baskets=None, **kwargs):
top_preds, no_ans_gaps = zip(*preds_d)

# Takes document level prediction spans and returns string predictions
- doc_preds = self.to_doc_preds(top_preds, no_ans_gaps, baskets)
+ doc_preds = self.to_qa_preds(top_preds, no_ans_gaps, baskets)

return doc_preds

- def to_doc_preds(self, top_preds, no_ans_gaps, baskets):
- """ Groups Span objects together in a DocumentPred object """
+ def to_qa_preds(self, top_preds, no_ans_gaps, baskets):
+ """ Groups Span objects together in a QAPred object """
ret = []

# Iterate over each set of document level prediction
@@ -1254,19 +1267,27 @@ def to_doc_preds(self, top_preds, no_ans_gaps, baskets):

# Iterate over each prediction on the one document
full_preds = []
- for span, basket in zip(pred_d, baskets):
+ for qa_answer, basket in zip(pred_d, baskets):
# This should be a method of Span
- pred_str, _, _ = span_to_string(span.start, span.end, token_offsets, document_text)
- span.pred_str = pred_str
- full_preds.append(span)
- curr_doc_pred = DocumentPred(id=basket_id,
- preds=full_preds,
- document_text=document_text,
- question=question,
- no_ans_gap=no_ans_gap,
- token_offsets=token_offsets,
- context_window_size=self.context_window_size,
- question_id=question_id)
+ pred_str, _, _ = span_to_string(qa_answer.offset_answer_start,
+ qa_answer.offset_answer_end,
+ token_offsets,
+ document_text)
+ qa_answer.add_answer(pred_str)
+ full_preds.append(qa_answer)
+ n_samples = full_preds[0].n_samples_in_doc
+ curr_doc_pred = QAPred(id=basket_id,
+ prediction=full_preds,
+ context=document_text,
+ question=question,
+ question_id=question_id,
+ token_offsets=token_offsets,
+ context_window_size=self.context_window_size,
+ aggregation_level="document",
+ answer_types=[], # TODO
+ no_answer_gap=no_ans_gap,
+ n_samples=n_samples
+ )
ret.append(curr_doc_pred)
return ret

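span_to_string (from farm.utils) turns document-level token indices back into the answer string using token_offsets, the character offset of each token in the document text; add_answer then stores that string on the candidate. A simplified sketch of the idea, in which end_t is treated as the token just after the answer; the real helper also returns the character start and end, and its indexing convention and edge-case handling may differ:

def span_to_string_sketch(start_t, end_t, token_offsets, document_text):
    # no-answer predictions are encoded as (-1, -1) on document level -> empty string
    if start_t == -1 and end_t == -1:
        return ""
    char_start = token_offsets[start_t]
    char_end = token_offsets[end_t] if end_t < len(token_offsets) else len(document_text)
    return document_text[char_start:char_end].strip()

text = "FARM builds on transformers. It adds prediction heads."
offsets = [0, 5, 12, 15, 29, 32, 37, 48]             # hypothetical token start offsets
print(span_to_string_sketch(1, 3, offsets, text))    # -> "builds on"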
@@ -1363,12 +1384,19 @@ def reduce_preds(self, preds):
passage_best_score.append(best_pred_score)

# Get all predictions in flattened list and sort by score
- pos_answers_flat = [
- Span(span.start, span.end, span.score, sample_idx, n_samples, unit="token", level="passage")
- for sample_idx, passage_preds in enumerate(preds)
- for span in passage_preds
- if not (span.start == -1 and span.end == -1)
- ]
+ pos_answers_flat = []
+ for sample_idx, passage_preds in enumerate(preds):
+ for qa_answer in passage_preds:
+ if not (qa_answer.offset_answer_start == -1 and qa_answer.offset_answer_end == -1):
+ pos_answers_flat.append(QACandidate(offset_answer_start=qa_answer.offset_answer_start,
+ offset_answer_end=qa_answer.offset_answer_end,
+ score=qa_answer.score,
+ answer_type=qa_answer.answer_type,
+ offset_unit="token",
+ aggregation_level="passage",
+ sample_idx=sample_idx,
+ n_samples_in_doc=n_samples)
+ )

# TODO add switch for more variation in answers, e.g. if varied_ans then never return overlapping answers
pos_answer_dedup = self.deduplicate(pos_answers_flat)
@@ -1377,13 +1405,20 @@
no_ans_gap = -min([nas - pbs for nas, pbs in zip(no_answer_scores, passage_best_score)])

# "no answer" scores and positive answers scores are difficult to compare, because
- # + a positive answer score is related to a specific text span
+ # + a positive answer score is related to a specific text qa_answer
# - a "no answer" score is related to all input texts
# Thus we compute the "no answer" score relative to the best possible answer and adjust it by
# the most significant difference between scores.
# Most significant difference: change top prediction from "no answer" to answer (or vice versa)
best_overall_positive_score = max(x.score for x in pos_answer_dedup)
- no_answer_pred = Span(-1, -1, best_overall_positive_score - no_ans_gap, None, n_samples, unit="token")
+ no_answer_pred = QACandidate(offset_answer_start=-1,
+ offset_answer_end=-1,
+ score=best_overall_positive_score - no_ans_gap,
+ answer_type="is_impossible",
+ offset_unit="token",
+ aggregation_level="document",
+ sample_idx=None,
+ n_samples_in_doc=n_samples)

# Add no answer to positive answers, sort the order and return the n_best
n_preds = [no_answer_pred] + pos_answer_dedup
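A small worked example of the no-answer adjustment above, with made-up scores for a document that was split into two passages:

no_answer_scores   = [4.0, 6.5]    # score of the (0, 0) no-answer candidate per passage
passage_best_score = [7.0, 5.0]    # best positive-answer score per passage

no_ans_gap = -min(nas - pbs for nas, pbs in zip(no_answer_scores, passage_best_score))
# per-passage differences are -3.0 and 1.5, so no_ans_gap = 3.0

best_overall_positive_score = 7.0
no_answer_score = best_overall_positive_score - no_ans_gap    # -> 4.0
# the document-level no-answer candidate enters the ranking with 4.0 < 7.0,
# so the top prediction for this document stays a positive answer span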
@@ -1395,14 +1430,14 @@
def deduplicate(flat_pos_answers):
# Remove duplicate spans that might be twice predicted in two different passages
seen = {}
- for span in flat_pos_answers:
- if (span.start, span.end) not in seen:
- seen[(span.start, span.end)] = span
+ for qa_answer in flat_pos_answers:
+ if (qa_answer.offset_answer_start, qa_answer.offset_answer_end) not in seen:
+ seen[(qa_answer.offset_answer_start, qa_answer.offset_answer_end)] = qa_answer
else:
- seen_score = seen[(span.start, span.end)].score
- if span.score > seen_score:
- seen[(span.start, span.end)] = span
- return [span for span in seen.values()]
+ seen_score = seen[(qa_answer.offset_answer_start, qa_answer.offset_answer_end)].score
+ if qa_answer.score > seen_score:
+ seen[(qa_answer.offset_answer_start, qa_answer.offset_answer_end)] = qa_answer
+ return list(seen.values())



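Reusing the illustrative QACandidate sketch from the top of this page (this snippet runs only together with that sketch), the deduplication keeps exactly one candidate per document-level (start, end) pair and prefers the higher score when the same span was predicted from two overlapping passages:

a = QACandidate(offset_answer_start=10, offset_answer_end=14, score=8.2, answer_type="span",
                offset_unit="token", aggregation_level="document", sample_idx=0)
b = QACandidate(offset_answer_start=10, offset_answer_end=14, score=9.1, answer_type="span",
                offset_unit="token", aggregation_level="document", sample_idx=1)

seen = {}
for qa_answer in [a, b]:
    key = (qa_answer.offset_answer_start, qa_answer.offset_answer_end)
    if key not in seen or qa_answer.score > seen[key].score:
        seen[key] = qa_answer
print([c.score for c in seen.values()])    # -> [9.1]: only the better-scoring duplicate survives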
@@ -1429,10 +1464,10 @@ def deduplicate(flat_pos_answers):

@staticmethod
def get_no_answer_score(preds):
- for span in preds:
- start = span.start
- end = span.end
- score = span.score
+ for qa_answer in preds:
+ start = qa_answer.offset_answer_start
+ end = qa_answer.offset_answer_end
+ score = qa_answer.score
if start == -1 and end == -1:
return score
raise Exception
@@ -1441,12 +1476,11 @@ def get_no_answer_score(preds):
def pred_to_doc_idxs(pred, passage_start_t):
""" Converts the passage level predictions to document level predictions. Note that on the doc level we
don't have special tokens or question tokens. This means that a no answer
- cannot be prepresented by a (0,0) span but will instead be represented by (-1, -1)"""
+ cannot be prepresented by a (0,0) qa_answer but will instead be represented by (-1, -1)"""
new_pred = []
- for span in pred:
- start = span.start
- end = span.end
- score = span.score
+ for qa_answer in pred:
+ start = qa_answer.offset_answer_start
+ end = qa_answer.offset_answer_end
if start == 0:
start = -1
else:
@@ -1457,7 +1491,8 @@
else:
end += passage_start_t
assert start >= 0
- new_pred.append(Span(start, end, score, level="doc"))
+ qa_answer.to_doc_level(start, end)
+ new_pred.append(qa_answer)
return new_pred

@staticmethod
@@ -1503,18 +1538,18 @@ def chunk(iterable, lengths):
cls_preds_grouped = chunk(cls_preds, samples_per_doc)

for qa_doc_pred, cls_preds in zip(qa_preds, cls_preds_grouped):
- pred_spans = qa_doc_pred.preds
- pred_spans_new = []
- for pred_span in pred_spans:
- sample_idx = pred_span.sample_idx
+ pred_qa_answers = qa_doc_pred.prediction
+ pred_qa_answers_new = []
+ for pred_qa_answer in pred_qa_answers:
+ sample_idx = pred_qa_answer.sample_idx
if sample_idx is not None:
cls_pred = cls_preds[sample_idx]["label"]
# i.e. if is_impossible
else:
- cls_pred = None
- pred_span.classification = cls_pred
- pred_spans_new.append(pred_span)
- qa_doc_pred.preds = pred_spans_new
+ cls_pred = "is_impossible"
+ pred_qa_answer.answer = cls_pred
+ pred_qa_answers_new.append(pred_qa_answer)
+ qa_doc_pred.prediction = pred_qa_answers_new
ret.append(qa_doc_pred)
return ret

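In this last hunk, cls_preds holds one answerable/is_impossible classification per passage sample, and chunk(cls_preds, samples_per_doc) regroups that flat list into one sub-list per document before it is merged into each QAPred. A minimal sketch of such a chunk helper with hypothetical label values; the real helper defined earlier in this file may be implemented differently:

def chunk(iterable, lengths):
    # split a flat list into consecutive sub-lists with the given lengths
    out, cursor = [], 0
    for length in lengths:
        out.append(iterable[cursor:cursor + length])
        cursor += length
    return out

cls_preds = [{"label": "span"}, {"label": "is_impossible"}, {"label": "span"}]
samples_per_doc = [2, 1]        # first document was split into 2 passages, second into 1
print(chunk(cls_preds, samples_per_doc))
# -> [[{'label': 'span'}, {'label': 'is_impossible'}], [{'label': 'span'}]]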
(Diffs for the remaining 4 changed files are not shown.)
