Change context in QA formatted preds to not split words #138

Merged 1 commit on Nov 11, 2019
65 changes: 40 additions & 25 deletions farm/modeling/prediction_head.py
@@ -714,7 +714,8 @@ def __init__(self,
 task_name="question_answering",
 no_answer_shift=0,
 top_n_predictions=3,
-context_size=100,
+context_size=20,
+max_ans_len=30,
 **kwargs):
 """
 :param layer_dims: dimensions of Feed Forward block, e.g. [768,2], for adjusting to BERT embedding. Output should be always 2
@@ -727,8 +728,8 @@ def __init__(self,
 :type no_answer_shift: int
 :param top_n_predictions: When we split a document into multiple passages we can return top n passage answers
 :type top_n_predictions: int
-:param context_size: When we format predictions back to string space we also return surrounding context
-                     of size context_size
+:param context_size: When we format predictions back to string space we also return the surrounding
+                     "context_size" subword tokens of context (widened so that it starts and ends on word boundaries)
 :type context_size: int
 :param kwargs: placeholder for passing generic parameters
 :type kwargs: object
@@ -743,10 +744,11 @@ def __init__(self,
 ) # predicts start and end token of answer
 self.task_name = task_name
 self.no_answer_shift = no_answer_shift # how much we want to upweight no answer logit scores compared to text answer ones
-self.top_n_predictions = top_n_predictions #for how many passages we want to get predictions
-self.context_size = context_size
-self.max_ans_len = 1000 # disabling max ans len. Impact on squad performance seems minor
+self.top_n_predictions = top_n_predictions
+# each answer is returned with surrounding context of "context_size" subword tokens, snapped to word boundaries
+self.context_size = context_size
+self.max_ans_len = max_ans_len

 self.generate_config()


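For reference, a minimal usage sketch of the changed constructor with the new defaults. This assumes the head shown in this diff is FARM's QuestionAnsweringHead; the import path, the layer_dims value (taken from the docstring example) and the surrounding setup are illustrative, not part of the change.

```python
# Illustrative only: assumes the class in this diff is FARM's QuestionAnsweringHead.
from farm.modeling.prediction_head import QuestionAnsweringHead

qa_head = QuestionAnsweringHead(
    layer_dims=[768, 2],            # BERT hidden size -> start/end logits (docstring example)
    task_name="question_answering",
    no_answer_shift=0,              # no extra bias towards the "no answer" option
    top_n_predictions=3,            # passage-level answers to return
    context_size=20,                # new default: 20 surrounding subword tokens per side
    max_ans_len=30,                 # now configurable instead of being hard-coded to 1000
)
```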
@@ -857,7 +859,7 @@ def logits_to_preds(self, logits, **kwargs):
 best_answer_sum = np.zeros(num_per_batch)
 # check if start or end point to the context. Context starts at segment id == 1 (question comes before at segment ids == 0)
 context_start = np.argmax(segment_ids,axis=1)
-context_end = segment_ids.shape[1] - np.argmax(segment_ids[::-1],axis=1)
+context_end = segment_ids.shape[1] - np.argmax(segment_ids[:,::-1],axis=1)
 start_proposals = self._get_best_textanswer_indices(start_logits, 3)
 end_proposals = self._get_best_textanswer_indices(end_logits, 3)
 best_indices = np.zeros((num_per_batch,2),dtype=int)
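The one-character fix above is easy to miss: segment_ids is a 2-D (batch, sequence) array, so segment_ids[::-1] reverses the batch axis and argmax then measures the wrong row, whereas segment_ids[:, ::-1] reverses each sequence so argmax finds the last context token per passage. A small NumPy illustration with toy values (not taken from the codebase):

```python
import numpy as np

# Toy segment_ids for a batch of two passages (0 = question tokens, 1 = context tokens).
segment_ids = np.array([
    [0, 0, 1, 1, 1, 0, 0],   # context occupies positions 2..4
    [0, 1, 1, 1, 1, 1, 0],   # context occupies positions 1..5
])
seq_len = segment_ids.shape[1]

# Old (buggy): reverses the batch axis, so each row's result comes from the other row.
wrong_end = seq_len - np.argmax(segment_ids[::-1], axis=1)     # -> array([6, 5])
# New (fixed): reverses each sequence, so argmax finds the last 1 in every row.
right_end = seq_len - np.argmax(segment_ids[:, ::-1], axis=1)  # -> array([5, 6])
```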
@@ -959,38 +961,51 @@ def formatted_preds(self, logits, preds, samples):
passage_pred["probability"] = -1 # TODO add probabilities that make sense : )
try:
#default to returning no answer
start = 0
end = 0
context_start = 0
start_char = 0
end_char = 0
context_start_char = 0
answer = ""
context = ""
if(s_i + e_i > 0):
current_start = int(s_i + passage_shift_i - question_shift_i)
current_end = int(e_i + passage_shift_i - question_shift_i) + 1
start_token = int(s_i + passage_shift_i - question_shift_i)
end_token = int(e_i + passage_shift_i - question_shift_i) + 1
temptext = " ".join(current_sample.clear_text["doc_tokens"])
start = current_sample.tokenized["offsets"][current_start]
start_char = current_sample.tokenized["offsets"][start_token]
# if the last end token is predicted we cannot take the char offset of the following word
if current_end >= len(current_sample.tokenized["offsets"]):
end = len(temptext)
if end_token >= len(current_sample.tokenized["offsets"]):
end_char = len(temptext)
else:
end = current_sample.tokenized["offsets"][current_end]
end_char = current_sample.tokenized["offsets"][end_token]
# we want the answer in original string space (containing newline, tab or multiple
# whitespace. So we need to join doc tokens and work with character offsets
answer = temptext[start:end]
answer = temptext[start_char:end_char]
answer = answer.strip()
# sometimes we strip trailing whitespaces, so we need to adjust end
end = start + len(answer)
context_start = int(np.clip((start-self.context_size),a_min=0,a_max=None))
context_end = int(np.clip(end +self.context_size,a_max=len(temptext),a_min=None))
context = temptext[context_start:context_end]
end_char = start_char + len(answer)
context_start_token = int(np.clip((start_token - self.context_size),a_min=0,a_max=None))
while not current_sample.tokenized["start_of_word"][context_start_token]:
context_start_token -= 1
context_start_char = current_sample.tokenized["offsets"][context_start_token]
context_end_token = int(np.clip(end_token + self.context_size,
a_min=None,
a_max=len(current_sample.tokenized["offsets"]) -1))
while not current_sample.tokenized["start_of_word"][context_end_token]:
context_end_token += 1
if context_end_token >= len(current_sample.tokenized["offsets"]):
break
if context_end_token >= len(current_sample.tokenized["offsets"]):
context_end_char = len(temptext)
else:
context_end_char = current_sample.tokenized["offsets"][context_end_token]
context = temptext[context_start_char:context_end_char]
except IndexError as e:
logger.info(e)
passage_pred["answer"] = answer
passage_pred["offset_start"] = start
passage_pred["offset_end"] = end
passage_pred["offset_start"] = start_char
passage_pred["offset_end"] = end_char
passage_pred["context"] = context
passage_pred["offset_context_start"] = start - context_start
passage_pred["offset_context_end"] = end - context_start
passage_pred["offset_context_start"] = context_start_char
passage_pred["offset_context_end"] = context_end_char
passage_pred["document_id"] = current_sample.clear_text.get("document_id", None)
passage_predictions.append(passage_pred)

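Taken together, the changes in formatted_preds are what give this PR its title: the context around each answer is now measured in subword tokens (context_size tokens on either side of the answer span) and the window is widened until both ends land on word boundaries, so the returned context no longer splits words in half. Below is a simplified, self-contained sketch of that logic; the helper name, the toy data, and the offsets/start_of_word inputs (mimicking sample.tokenized["offsets"] and sample.tokenized["start_of_word"]) are illustrative, not part of the diff.

```python
import numpy as np

def context_window(text, offsets, start_of_word, start_token, end_token, context_size=20):
    """Sketch of the new context logic: take `context_size` subword tokens on each
    side of the answer span, then widen the window until both ends sit on word
    boundaries. `offsets` holds the character position where each token starts,
    `start_of_word` flags tokens that begin a new word."""
    ctx_start = int(np.clip(start_token - context_size, a_min=0, a_max=None))
    ctx_end = int(np.clip(end_token + context_size, a_min=None, a_max=len(offsets) - 1))

    # walk left until the window starts at the beginning of a word
    while not start_of_word[ctx_start]:
        ctx_start -= 1
    # walk right until the next word boundary (or the end of the passage)
    while ctx_end < len(offsets) and not start_of_word[ctx_end]:
        ctx_end += 1

    start_char = offsets[ctx_start]
    end_char = len(text) if ctx_end >= len(offsets) else offsets[ctx_end]
    return text[start_char:end_char]

# Toy passage tokenized into subwords; the answer "tokenization" covers tokens 5..6
# (end_token is exclusive, matching the diff).
text = "Transformers use wordpiece tokenization for QA"
#          Transform ##ers use word ##piece token ##ization for QA
offsets = [0,        9,    13, 17,  21,     27,   32,       40, 44]
start_of_word = [1, 0, 1, 1, 0, 1, 0, 1, 1]

print(context_window(text, offsets, start_of_word, start_token=5, end_token=7, context_size=1))
# -> "wordpiece tokenization for "  (a plain character window could have cut "wordpiece" in half)
```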