From f04c2301b3c8ef9557adabbbefa9c7af685ed468 Mon Sep 17 00:00:00 2001 From: Branden Chan <33759007+brandenchan@users.noreply.github.com> Date: Fri, 3 Jul 2020 13:21:08 +0200 Subject: [PATCH] Question Answering improvements - NQ3 (#419) * unify squad and nq baskets * Clean id handling * Add QAInference type hints * add input_features test --- examples/natural_questions.py | 6 +- examples/question_answering.py | 4 +- farm/data_handler/input_features.py | 120 +++++++--------- farm/data_handler/processor.py | 97 +++++++++---- farm/evaluation/metrics.py | 6 +- farm/infer.py | 22 ++- farm/modeling/adaptive_model.py | 24 ++-- farm/modeling/prediction_head.py | 98 ++++++------- farm/modeling/predictions.py | 163 ++++++++++++++++------ farm/utils.py | 25 ---- test/samples/qa/no_answer/clear_text.json | 1 + test/samples/qa/no_answer/features.json | 1 + test/samples/qa/no_answer/tokenized.json | 1 + test/samples/qa/span/clear_text.json | 1 + test/samples/qa/span/features.json | 1 + test/samples/qa/span/tokenized.json | 1 + test/samples/qa/train-sample.json | 2 +- test/test_input_features.py | 45 ++++++ test/test_question_answering.py | 1 - 19 files changed, 378 insertions(+), 241 deletions(-) create mode 100644 test/samples/qa/no_answer/clear_text.json create mode 100644 test/samples/qa/no_answer/features.json create mode 100644 test/samples/qa/no_answer/tokenized.json create mode 100644 test/samples/qa/span/clear_text.json create mode 100644 test/samples/qa/span/features.json create mode 100644 test/samples/qa/span/tokenized.json create mode 100644 test/test_input_features.py diff --git a/examples/natural_questions.py b/examples/natural_questions.py index 1a8495351..c6e70b33d 100644 --- a/examples/natural_questions.py +++ b/examples/natural_questions.py @@ -6,7 +6,7 @@ from farm.data_handler.data_silo import DataSilo from farm.data_handler.processor import NaturalQuestionsProcessor from farm.file_utils import fetch_archive_from_http -from farm.infer import Inferencer +from farm.infer import QAInferencer from farm.modeling.adaptive_model import AdaptiveModel from farm.modeling.language_model import LanguageModel from farm.modeling.optimization import initialize_optimizer @@ -68,7 +68,7 @@ def question_answering(): max_seq_len=384, train_filename=train_filename, dev_filename=dev_filename, - keep_is_impossible=keep_is_impossible, + keep_no_answer=keep_is_impossible, downsample_context_size=downsample_context_size, data_dir=Path("../data/natural_questions"), ) @@ -131,7 +131,7 @@ def question_answering(): } ] - model = Inferencer.load(model_name_or_path="../saved_models/farm/roberta-base-squad2-nq", batch_size=batch_size, gpu=True) + model = QAInferencer.load(model_name_or_path="../saved_models/farm/roberta-base-squad2-nq", batch_size=batch_size, gpu=True) result = model.inference_from_dicts(dicts=QA_input, return_json=False) # result is a list of QAPred objects print(f"\nQuestion: Did GameTrailers rated Twilight Princess as one of the best games ever created?" diff --git a/examples/question_answering.py b/examples/question_answering.py index bad2ba73c..0317780c8 100644 --- a/examples/question_answering.py +++ b/examples/question_answering.py @@ -7,7 +7,7 @@ from farm.data_handler.data_silo import DataSilo from farm.data_handler.processor import SquadProcessor from farm.data_handler.utils import write_squad_predictions -from farm.infer import Inferencer +from farm.infer import QAInferencer from farm.modeling.adaptive_model import AdaptiveModel from farm.modeling.language_model import LanguageModel from farm.modeling.optimization import initialize_optimizer @@ -110,7 +110,7 @@ def question_answering(): "context": "Twilight Princess was released to universal critical acclaim and commercial success. It received perfect scores from major publications such as 1UP.com, Computer and Video Games, Electronic Gaming Monthly, Game Informer, GamesRadar, and GameSpy. On the review aggregators GameRankings and Metacritic, Twilight Princess has average scores of 95% and 95 for the Wii version and scores of 95% and 96 for the GameCube version. GameTrailers in their review called it one of the greatest games ever created." }] - model = Inferencer.load(save_dir, batch_size=40, gpu=True) + model = QAInferencer.load(save_dir, batch_size=40, gpu=True) result = model.inference_from_dicts(dicts=QA_input)[0] pprint.pprint(result) diff --git a/farm/data_handler/input_features.py b/farm/data_handler/input_features.py index eaf09a507..669f2cad7 100644 --- a/farm/data_handler/input_features.py +++ b/farm/data_handler/input_features.py @@ -70,7 +70,6 @@ def sample_to_features_text( input_ids = pad(input_ids, max_seq_len, tokenizer.pad_token_id, pad_on_left=pad_on_left) padding_mask = pad(padding_mask, max_seq_len, 0, pad_on_left=pad_on_left) - assert len(input_ids) == max_seq_len assert len(padding_mask) == max_seq_len assert len(segment_ids) == max_seq_len @@ -307,10 +306,28 @@ def samples_to_features_bert_lm(sample, max_seq_len, tokenizer, next_sent_pred=T return [feature_dict] -def sample_to_features_qa(sample, tokenizer, max_seq_len, answer_type_list=None, max_answers=6): +def sample_to_features_qa(sample, tokenizer, max_seq_len, sp_toks_start, sp_toks_mid, + answer_type_list=None, max_answers=6): """ Prepares data for processing by the model. Supports cases where there are multiple answers for the one question/document pair. max_answers is by default set to 6 since - that is the most number of answers in the squad2.0 dev set.""" + that is the most number of answers in the squad2.0 dev set. + + :param sample: A Sample object that contains one question / passage pair + :type sample: Sample + :param tokenizer: A Tokenizer object + :type tokenizer: Tokenizer + :param max_seq_len: The maximum sequence length + :type max_seq_len: int + :param sp_toks_start: The number of special tokens that come before the question tokens + :type sp_toks_start: int + :param sp_toks_mid: The number of special tokens that come between the question and passage tokens + :type sp_toks_mid: int + :param answer_type_list: A list of all the answer types that can be expected e.g. ["no_answer", "span", "yes", "no"] for Natural Questions + :type answer_type_list: List[str] + :param max_answers: The maximum number of answer annotations for a sample (In SQuAD, this is 6 hence the default) + :type max_answers: int + :return: dict (keys: [input_ids, padding_mask, segment_ids, answer_type_ids, passage_start_t, start_of_word, labels, id, seq_2_start_2]) + """ # Initialize some basic variables question_tokens = sample.tokenized["question_tokens"] @@ -329,9 +346,10 @@ def sample_to_features_qa(sample, tokenizer, max_seq_len, answer_type_list=None, labels, answer_types = generate_labels(answers, passage_len_t, question_len_t, - tokenizer, - answer_type_list=answer_type_list, - max_answers=max_answers) + max_answers, + sp_toks_start, + sp_toks_mid, + answer_type_list) # Generate a start of word vector for the full sequence (i.e. question + answer + special tokens). # This will allow us to perform evaluation during training without clear text. @@ -382,10 +400,9 @@ def sample_to_features_qa(sample, tokenizer, max_seq_len, answer_type_list=None, if tokenizer.__class__.__name__ in ["XLMRobertaTokenizer", "RobertaTokenizer"]: segment_ids = np.zeros_like(segment_ids) - # Todo: explain how only the first of labels will be used in train, and the full array will be used in eval - # TODO Offset, start of word and spec_tok_mask are not actually needed by model.forward() but are needed for model.formatted_preds() - # TODO passage_start_t is index of passage's first token relative to document - # I don't think we actually need offsets anymore + # The first of the labels will be used in train, and the full array will be used in eval. + # start of word and spec_tok_mask are not actually needed by model.forward() but are needed for model.formatted_preds() + # passage_start_t is index of passage's first token relative to document feature_dict = {"input_ids": input_ids, "padding_mask": padding_mask, "segment_ids": segment_ids, @@ -398,19 +415,25 @@ def sample_to_features_qa(sample, tokenizer, max_seq_len, answer_type_list=None, return [feature_dict] -def generate_labels(answers, passage_len_t, question_len_t, tokenizer, max_answers, answer_type_list=None): +def generate_labels(answers, passage_len_t, question_len_t, max_answers, + sp_toks_start, sp_toks_mid, answer_type_list=None): """ - Creates QA label for each answer in answers. The labels are the index of the start and end token + Creates QA label vector for each answer in answers. The labels are the index of the start and end token relative to the passage. They are contained in an array of size (max_answers, 2). - -1 used to fill array since there the number of answers is often less than max_answers. + -1 is used to fill array since there the number of answers is often less than max_answers. The index values take in to consideration the question tokens, and also special tokens such as [CLS]. When the answer is not fully contained in the passage, or the question is impossible to answer, the start_idx and end_idx are 0 i.e. start and end are on the very first token - (in most models, this is the [CLS] token). Note that in our implementation NQ has 4 labels + (in most models, this is the [CLS] token). Note that in our implementation NQ has 4 answer types ["no_answer", "yes", "no", "span"] and this is what answer_type_list should look like""" + # Note here that label_idxs get passed to the QuestionAnsweringHead and answer_types get passed to the text + # classification head. label_idxs may contain multiple start, end labels since SQuAD dev and test sets + # can have multiple annotations. By contrast, Natural Questions only has one annotation per sample hence + # why answer_types is only of length 1 label_idxs = np.full((max_answers, 2), fill_value=-1) - answer_types = np.full((max_answers), fill_value=-1) + answer_types = np.asarray([-1]) + answer_str = "" # If there are no answers if len(answers) == 0: @@ -418,65 +441,24 @@ def generate_labels(answers, passage_len_t, question_len_t, tokenizer, max_answe answer_types[:] = 0 return label_idxs, answer_types + # Iterate over the answers for the one sample for i, answer in enumerate(answers): - answer_type = answer["answer_type"] start_idx = answer["start_t"] end_idx = answer["end_t"] - # We are going to operate on one-hot label vectors which will later be converted back to label indices. - # This is to take advantage of tokenizer.encode_plus() which applies model dependent special token conventions. - # The two label vectors (start and end) are composed of sections that correspond to the question and - # passage tokens. These are initialized here. The section corresponding to the question - # will always be composed of 0s. - start_vec_question = [0] * question_len_t - end_vec_question = [0] * question_len_t - start_vec_passage = [0] * passage_len_t - end_vec_passage = [0] * passage_len_t - - # If the answer is in the current passage, populate the label vector with 1s for start and end + # Check that the start and end are contained within this passage if answer_in_passage(start_idx, end_idx, passage_len_t): - start_vec_passage[start_idx] = 1 - end_vec_passage[end_idx] = 1 - - # Combine the sections of the label vectors. The length of each of these will be: - # question_len_t + passage_len_t + n_special_tokens - start_vec = combine_vecs(start_vec_question, - start_vec_passage, - tokenizer, - spec_tok_val=0) - end_vec = combine_vecs(end_vec_question, - end_vec_passage, - tokenizer, - spec_tok_val=0) - - start_label_present = 1 in start_vec - end_label_present = 1 in end_vec - - # This is triggered if the answer is not in the passage or the question warrants a no_answer - # In both cases, the token at idx=0 (in BERT, this is the [CLS] token) is given both the start and end label - if start_label_present is False and end_label_present is False: - start_vec[0] = 1 - end_vec[0] = 1 - answer_type = "no_answer" - elif start_label_present is False or end_label_present is False: - raise Exception("The label vectors are lacking either a start or end label") - - # Ensure label vectors are one-hot - assert sum(start_vec) == 1 - assert sum(end_vec) == 1 - - start_idx = start_vec.index(1) - end_idx = end_vec.index(1) - - label_idxs[i, 0] = start_idx - label_idxs[i, 1] = end_idx - - # Only Natural Questions trains a classification head on answer_type, SQuAD only has the QA head. answer_type_list - # will be None for SQuAD but something like ["no_answer", "span", "yes", "no"] for Natural Questions - if answer_type_list: - answer_types[i] = answer_type_list.index(answer_type) - - assert np.max(label_idxs) > -1 + label_idxs[i][0] = sp_toks_start + question_len_t + sp_toks_mid + start_idx + label_idxs[i][1] = sp_toks_start + question_len_t + sp_toks_mid + end_idx + answer_str = answer["answer_type"] + # If the start or end of the span answer is outside the passage, treat passage as no_answer + else: + label_idxs[i][0] = 0 + label_idxs[i][1] = 0 + answer_str = "no_answer" + + if answer_type_list: + answer_types[0] = answer_type_list.index(answer_str) return label_idxs, answer_types diff --git a/farm/data_handler/processor.py b/farm/data_handler/processor.py index 64ae7a0dc..8b7260c2a 100644 --- a/farm/data_handler/processor.py +++ b/farm/data_handler/processor.py @@ -40,6 +40,7 @@ get_sequence_pair, join_sentences ) + from farm.modeling.tokenization import Tokenizer, tokenize_with_metadata, truncate_sequences from farm.utils import MLFlowLogger as MlLogger from farm.utils import try_get @@ -348,6 +349,7 @@ def _log_samples(self, n_samples): random_sample = random.choice(random_basket.samples) logger.info(random_sample) + def _log_params(self): params = { "processor": self.__class__.__name__, @@ -1035,7 +1037,26 @@ def estimate_n_samples(self, filepath, max_docs=500): # QA Processors #### ######################################### -class SquadProcessor(Processor): +class QAProcessor(Processor): + """ + This is class inherits from Processor and is the parent to SquadProcessor and NaturalQuestionsProcessor. + Its main role is to extend the __init__() so that the number of starting, intermediate and end special tokens + are calculated from the tokenizer and store as attributes. These are used by the child processors in their + sample_to_features() methods + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.initialize_special_tokens_count() + + def initialize_special_tokens_count(self): + vec = self.tokenizer.build_inputs_with_special_tokens(token_ids_0=["a"], + token_ids_1=["b"]) + self.sp_toks_start = vec.index("a") + self.sp_toks_mid = vec.index("b") - self.sp_toks_start - 1 + self.sp_toks_end = len(vec) - vec.index("b") - 1 + + +class SquadProcessor(QAProcessor): """ Used to handle the SQuAD dataset""" def __init__( @@ -1159,13 +1180,15 @@ def _dict_to_samples(self, dictionary: dict, **kwargs) -> [Sample]: return samples def _sample_to_features(self, sample) -> dict: - # TODO, make this function return one set of features per sample + check_valid_answer(sample) features = sample_to_features_qa(sample=sample, tokenizer=self.tokenizer, - max_seq_len=self.max_seq_len) + max_seq_len=self.max_seq_len, + sp_toks_start=self.sp_toks_start, + sp_toks_mid=self.sp_toks_mid) return features -class NaturalQuestionsProcessor(Processor): +class NaturalQuestionsProcessor(QAProcessor): """ Used to handle the Natural Question QA dataset""" def __init__( @@ -1252,7 +1275,6 @@ def __init__( self.add_task("question_answering", "squad", ["start_token", "end_token"]) self.add_task("text_classification", "f1_macro", self.answer_type_list, label_name="answer_type") - def file_to_dicts(self, file: str) -> [dict]: dicts = read_jsonl(file, proxies=self.proxies) return dicts @@ -1270,7 +1292,7 @@ def _dict_to_samples(self, dictionary: dict, all_dicts=None) -> [Sample]: """ # Turns a NQ dictionaries into a SQuAD style dictionaries if not self.inference: - dictionary = self.prepare_dict(dictionary=dictionary) + dictionary = self._prepare_dict(dictionary=dictionary) dictionary_tokenized = apply_tokenization(dictionary, self.tokenizer)[0] n_special_tokens = self.tokenizer.num_special_tokens_to_add(pair=True) @@ -1282,16 +1304,16 @@ def _dict_to_samples(self, dictionary: dict, all_dicts=None) -> [Sample]: # Downsample the number of samples with an no_answer label. This fn will always return at least one sample # so that we don't end up with a basket with 0 samples if not self.inference: - samples = self.downsample(samples, self.keep_no_answer) + samples = self._downsample(samples, self.keep_no_answer) return samples - def downsample(self, samples, keep_prob): + def _downsample(self, samples, keep_prob): # Downsamples samples with a no_answer label (since there is an overrepresentation of these in NQ) # This method will always return at least one sample. This is done so that we don't end up with SampleBaskets # with 0 samples ret = [] for s in samples: - if self.check_no_answer_sample(s): + if self._check_no_answer_sample(s): if random_float() > 1 - keep_prob: ret.append(s) else: @@ -1300,7 +1322,7 @@ def downsample(self, samples, keep_prob): ret = [random.choice(samples)] return ret - def downsample_unprocessed(self, dictionary): + def _downsample_unprocessed(self, dictionary): doc_text = dictionary["document_text"] doc_tokens = doc_text.split(" ") annotations = dictionary.get("annotations",[]) @@ -1308,7 +1330,7 @@ def downsample_unprocessed(self, dictionary): if len(annotations) == 1: annotation = annotations[0] # There seem to be cases where there is no answer but an annotation is given as a (-1, -1) long answer - if self.check_no_answer(annotation): + if self._check_no_answer(annotation): dictionary["document_text"] = " ".join(doc_tokens[:self.max_seq_len+randint(1,self.downsample_context_size)]) else: # finding earliest start and latest end labels @@ -1341,28 +1363,28 @@ def downsample_unprocessed(self, dictionary): return dictionary - def prepare_dict(self, dictionary): + def _prepare_dict(self, dictionary): """ Casts a Natural Questions dictionary that is loaded from a jsonl file into SQuAD format so that the same featurization functions can be called for both tasks. Each annotation can be one of four answer types, ["yes", "no", "span", "no_answer"]""" if self.downsample_context_size is not None: - dictionary = self.downsample_unprocessed(dictionary) + dictionary = self._downsample_unprocessed(dictionary) converted_answers = [] doc_text = dictionary["document_text"] _, tok_to_ch = split_with_metadata(doc_text) for annotation in dictionary["annotations"]: # There seem to be cases where there is no answer but an annotation is given as a (-1, -1) long answer - if self.check_no_answer(annotation): + if self._check_no_answer(annotation): continue - sa_text, sa_start_c = self.unify_short_answers(annotation["short_answers"], doc_text, tok_to_ch) - la_text, la_start_c = self.retrieve_long_answer(annotation["long_answer"]["start_token"], - annotation["long_answer"]["end_token"], - tok_to_ch, - doc_text) + sa_text, sa_start_c = self._unify_short_answers(annotation["short_answers"], doc_text, tok_to_ch) + la_text, la_start_c = self._retrieve_long_answer(annotation["long_answer"]["start_token"], + annotation["long_answer"]["end_token"], + tok_to_ch, + doc_text) # Picks the span to be considered as annotation by choosing between short answer, long answer and no_answer - text, start_c = self.choose_span(sa_text, sa_start_c, la_text, la_start_c) + text, start_c = self._choose_span(sa_text, sa_start_c, la_text, la_start_c) converted_answers.append({"text": text, "answer_start": start_c}) if len(converted_answers) == 0: @@ -1380,7 +1402,7 @@ def prepare_dict(self, dictionary): return converted @staticmethod - def check_no_answer(annotation): + def _check_no_answer(annotation): if annotation["long_answer"]["start_token"] > -1 or annotation["long_answer"]["end_token"] > -1: return False for sa in annotation["short_answers"]: @@ -1390,7 +1412,7 @@ def check_no_answer(annotation): return True @staticmethod - def check_no_answer_sample(sample): + def _check_no_answer_sample(sample): sample_tok = sample.tokenized if len(sample_tok["answers"]) == 0: return True @@ -1404,14 +1426,14 @@ def check_no_answer_sample(sample): else: return False - def retrieve_long_answer(self, start_t, end_t, tok_to_ch, doc_text): + def _retrieve_long_answer(self, start_t, end_t, tok_to_ch, doc_text): """ Retrieves the string long answer and also its starting character index""" - start_c, end_c = self.convert_tok_to_ch(start_t, end_t, tok_to_ch, doc_text) + start_c, end_c = self._convert_tok_to_ch(start_t, end_t, tok_to_ch, doc_text) text = doc_text[start_c: end_c] return text, start_c @staticmethod - def choose_span(sa_text, sa_start_c, la_text, la_start_c): + def _choose_span(sa_text, sa_start_c, la_text, la_start_c): if sa_text: return sa_text, sa_start_c elif la_text: @@ -1419,7 +1441,7 @@ def choose_span(sa_text, sa_start_c, la_text, la_start_c): else: return "", -1 - def unify_short_answers(self, short_answers, doc_text, tok_to_ch): + def _unify_short_answers(self, short_answers, doc_text, tok_to_ch): """ In cases where an NQ sample has multiple disjoint short answers, this fn generates the single shortest span that contains all the answers""" if not short_answers: @@ -1431,13 +1453,13 @@ def unify_short_answers(self, short_answers, doc_text, tok_to_ch): short_answer_idxs.append(short_answer["end_token"]) answer_start_t = min(short_answer_idxs) answer_end_t = max(short_answer_idxs) - answer_start_c, answer_end_c = self.convert_tok_to_ch(answer_start_t, answer_end_t, tok_to_ch, doc_text) + answer_start_c, answer_end_c = self._convert_tok_to_ch(answer_start_t, answer_end_t, tok_to_ch, doc_text) answer_text = doc_text[answer_start_c: answer_end_c] assert answer_text == " ".join(doc_text.split()[answer_start_t: answer_end_t]) return answer_text, answer_start_c @staticmethod - def convert_tok_to_ch(start_t, end_t, tok_to_ch, doc_text): + def _convert_tok_to_ch(start_t, end_t, tok_to_ch, doc_text): n_tokens = len(tok_to_ch) if start_t == -1 and end_t == -1: return -1, -1 @@ -1452,9 +1474,12 @@ def convert_tok_to_ch(start_t, end_t, tok_to_ch, doc_text): return start_c, end_c def _sample_to_features(self, sample: Sample) -> dict: + check_valid_answer(sample) features = sample_to_features_qa(sample=sample, tokenizer=self.tokenizer, max_seq_len=self.max_seq_len, + sp_toks_start=self.sp_toks_start, + sp_toks_mid=self.sp_toks_mid, answer_type_list=self.answer_type_list) return features @@ -1677,3 +1702,19 @@ def is_impossible_to_answer_type(qas): q["answer_type"] = answer_type new_qas.append(q) return new_qas + + +def check_valid_answer(sample): + passage_text = sample.clear_text["passage_text"] + for answer in sample.clear_text["answers"]: + len_passage = len(passage_text) + start = answer["start_c"] + end = answer["end_c"] + # Cases where the answer is not within the current passage will be turned into no answers by the featurization fn + if start < 0 or end >= len_passage: + continue + answer_indices = passage_text[start: end + 1] + answer_text = answer["text"] + if answer_indices != answer_text: + raise ValueError(f"""Answer using start/end indices is '{answer_indices}' while gold label text is '{answer_text}'""") + diff --git a/farm/evaluation/metrics.py b/farm/evaluation/metrics.py index 311ee934c..702ab9469 100644 --- a/farm/evaluation/metrics.py +++ b/farm/evaluation/metrics.py @@ -142,9 +142,9 @@ def squad_EM(preds, labels): n_docs = len(preds) n_correct = 0 for doc_idx in range(n_docs): - span = preds[doc_idx][0][0] - pred_start = span.offset_answer_start - pred_end = span.offset_answer_end + qa_candidate = preds[doc_idx][0][0] + pred_start = qa_candidate.offset_answer_start + pred_end = qa_candidate.offset_answer_end curr_labels = labels[doc_idx] if (pred_start, pred_end) in curr_labels: n_correct += 1 diff --git a/farm/infer.py b/farm/infer.py index a6c67aa57..af765aebd 100644 --- a/farm/infer.py +++ b/farm/infer.py @@ -7,6 +7,7 @@ from torch.utils.data.sampler import SequentialSampler from tqdm import tqdm from transformers.configuration_auto import AutoConfig +from typing import Generator, List, Union from farm.data_handler.dataloader import NamedDataLoader from farm.data_handler.processor import Processor, InferenceProcessor, SquadProcessor, NERProcessor, TextClassificationProcessor @@ -16,7 +17,7 @@ from farm.modeling.optimization import optimize_model from farm.utils import initialize_device_settings from farm.utils import set_all_seeds, calc_chunksize, log_ascii_workers - +from farm.modeling.predictions import QAPred logger = logging.getLogger(__name__) @@ -594,7 +595,7 @@ def _get_predictions_and_aggregate(self, dataset, tensor_names, baskets): # can assume that we have only complete docs i.e. all the samples of one doc are in the current chunk logits = [None] preds_all = self.model.formatted_preds(logits=logits, # For QA we collected preds per batch and do not want to pass logits - preds_p=unaggregated_preds_all, + preds=unaggregated_preds_all, baskets=baskets) return preds_all @@ -624,6 +625,23 @@ def extract_vectors(self, dicts, extraction_strategy="cls_token", extraction_lay return self.inference_from_dicts(dicts) +class QAInferencer(Inferencer): + + def inference_from_dicts(self, + dicts, + return_json=True, + multiprocessing_chunksize=None, + streaming=False) -> Union[List[QAPred], Generator[QAPred, None, None]]: + return Inferencer.inference_from_dicts(self, dicts, return_json=return_json, multiprocessing_chunksize=None, streaming=False) + + def inference_from_file(self, + file, + multiprocessing_chunksize=None, + streaming=False, + return_json=True) -> Union[List[QAPred], Generator[QAPred, None, None]]: + return Inferencer.inference_from_file(self, file, return_json=return_json, multiprocessing_chunksize=None, streaming=False) + + class FasttextInferencer: def __init__(self, model, name=None): self.model = model diff --git a/farm/modeling/adaptive_model.py b/farm/modeling/adaptive_model.py index aada3c358..b1e124d3d 100644 --- a/farm/modeling/adaptive_model.py +++ b/farm/modeling/adaptive_model.py @@ -89,14 +89,16 @@ def formatted_preds(self, logits, **kwargs): elif n_heads == 1: preds_final = [] - # TODO This is very specific to QA, make more general + # This try catch is to deal with the fact that sometimes we collect preds before passing it to + # formatted_preds (see Inferencer._get_predictions_and_aggregate()) and sometimes we don't + # (see Inferencer._get_predictions()) try: - preds_p = kwargs["preds_p"] - temp = [y[0] for y in preds_p] - preds_p_flat = [item for sublist in temp for item in sublist] - kwargs["preds_p"] = preds_p_flat + preds = kwargs["preds"] + temp = [y[0] for y in preds] + preds_flat = [item for sublist in temp for item in sublist] + kwargs["preds"] = preds_flat except KeyError: - kwargs["preds_p"] = None + kwargs["preds"] = None head = self.prediction_heads[0] logits_for_head = logits[0] preds = head.formatted_preds(logits=logits_for_head, **kwargs) @@ -109,17 +111,17 @@ def formatted_preds(self, logits, **kwargs): # This case is triggered by Natural Questions else: preds_final = [list() for _ in range(n_heads)] - preds = kwargs["preds_p"] + preds = kwargs["preds"] preds_for_heads = stack(preds) logits_for_heads = [None] * n_heads samples = [s for b in kwargs["baskets"] for s in b.samples] kwargs["samples"] = samples - del kwargs["preds_p"] + del kwargs["preds"] - for i, (head, preds_p_for_head, logits_for_head) in enumerate(zip(self.prediction_heads, preds_for_heads, logits_for_heads)): - preds = head.formatted_preds(logits=logits_for_head, preds_p=preds_p_for_head, **kwargs) + for i, (head, preds_for_head, logits_for_head) in enumerate(zip(self.prediction_heads, preds_for_heads, logits_for_heads)): + preds = head.formatted_preds(logits=logits_for_head, preds=preds_for_head, **kwargs) preds_final[i].append(preds) # Look for a merge() function amongst the heads and if a single one exists, apply it to preds_final @@ -662,7 +664,7 @@ def convert_to_onnx(self, output_path, opset_version=11, optimize_for=None): { "question": "In what country is Normandy located?", "id": "56ddde6b9a695914005b9628", - "answers": [{"text": "France", "answer_start": 159}], + "answers": [], "is_impossible": False, } ], diff --git a/farm/modeling/prediction_head.py b/farm/modeling/prediction_head.py index b29ae5a5c..da9839a32 100644 --- a/farm/modeling/prediction_head.py +++ b/farm/modeling/prediction_head.py @@ -2,16 +2,18 @@ import logging import os import numpy as np + from pathlib import Path from transformers.modeling_bert import BertForPreTraining, BertLayerNorm, ACT2FN from transformers.modeling_auto import AutoModelForQuestionAnswering, AutoModelForTokenClassification, AutoModelForSequenceClassification +from typing import List import torch from torch import nn from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss from farm.data_handler.utils import is_json -from farm.utils import convert_iob_to_simple_tags, span_to_string, try_get +from farm.utils import convert_iob_to_simple_tags, try_get from farm.modeling.predictions import QACandidate, QAPred logger = logging.getLogger(__name__) @@ -337,9 +339,7 @@ def forward(self, X): def logits_to_loss(self, logits, **kwargs): label_ids = kwargs.get(self.label_tensor_name) - # In Natural Questions, each dev sample can have multiple labels - # For loss calculation we only use the first label - label_ids = label_ids.narrow(1,0,1) + label_ids = label_ids return self.loss_fct(logits, label_ids.view(-1)) def logits_to_probs(self, logits, return_class_probs, **kwargs): @@ -369,22 +369,20 @@ def prepare_labels(self, **kwargs): labels = [self.label_list[int(x[0])] for x in label_ids] return labels - def formatted_preds(self, logits=None, preds_p=None, samples=None, return_class_probs=False, **kwargs): - """ Like QuestionAnsweringHead.formatted_preds(), this fn can operate on either logits or preds_p. This + def formatted_preds(self, logits=None, preds=None, samples=None, return_class_probs=False, **kwargs): + """ Like QuestionAnsweringHead.formatted_preds(), this fn can operate on either logits or preds. This is needed since at inference, the order of operations is very different depending on whether we are performing - aggregation or not (compare Inferencer._get_predictions() vs Inferencer._get_predictions_and_aggregate()) - - TODO: Preds_p should be renamed to preds""" + aggregation or not (compare Inferencer._get_predictions() vs Inferencer._get_predictions_and_aggregate())""" - assert (logits is not None) or (preds_p is not None) + assert (logits is not None) or (preds is not None) - # When this method is used along side a QAHead at inference (e.g. Natural Questions), preds_p is the input and + # When this method is used along side a QAHead at inference (e.g. Natural Questions), preds is the input and # there is currently no good way of generating probs if logits is not None: - preds_p = self.logits_to_preds(logits) + preds = self.logits_to_preds(logits) probs = self.logits_to_probs(logits, return_class_probs) else: - probs = [None] * len(preds_p) + probs = [None] * len(preds) # TODO this block has to do with the difference in Basket and Sample structure between SQuAD and NQ try: @@ -397,10 +395,10 @@ def formatted_preds(self, logits=None, preds_p=None, samples=None, return_class_ if len(contexts_b) != 0: contexts = ["|".join([a, b]) for a,b in zip(contexts, contexts_b)] - assert len(preds_p) == len(probs) == len(contexts) + assert len(preds) == len(probs) == len(contexts) res = {"task": "text_classification", "predictions": []} - for pred, prob, context in zip(preds_p, probs, contexts): + for pred, prob, context in zip(preds, probs, contexts): if not return_class_probs: pred_dict = { "start": None, @@ -1194,11 +1192,11 @@ def valid_answer_idxs(start_idx, end_idx, n_non_padding, max_answer_length, seq_ return False return True - def formatted_preds(self, logits=None, preds_p=None, baskets=None, **kwargs): - """ Takes a list of predictions, each corresponding to one sample, and converts them into document level + def formatted_preds(self, logits=None, preds=None, baskets=None, **kwargs): + """ Takes a list of passage level predictions, each corresponding to one sample, and converts them into document level predictions. Leverages information in the SampleBaskets. Assumes that we are being passed predictions from ALL samples in the one SampleBasket i.e. all passages of a document. Logits should be None, because we have - already converted the logits to predictions before calling formatted_preds + already converted the logits to predictions before calling formatted_preds. (see Inferencer._get_predictions_and_aggregate()). """ @@ -1207,17 +1205,17 @@ def formatted_preds(self, logits=None, preds_p=None, baskets=None, **kwargs): # seq_2_start_t is the token index of the first token in passage relative to the input sequence (i.e. number of # special tokens and question tokens that come before the passage tokens) assert logits is None, "Logits are not None, something is passed wrongly into formatted_preds() in infer.py" - assert preds_p is not None, "No preds_p passed to formatted_preds()" + assert preds is not None, "No preds passed to formatted_preds()" samples = [s for b in baskets for s in b.samples] ids = [s.id for s in samples] passage_start_t = [s.features[0]["passage_start_t"] for s in samples] seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples] # Aggregate passage level predictions to create document level predictions. - # This method assumes that all passages of each document are contained in preds_p + # This method assumes that all passages of each document are contained in preds # i.e. that there are no incomplete documents. The output of this step # are prediction spans - preds_d = self.aggregate_preds(preds_p, passage_start_t, ids, seq_2_start_t) + preds_d = self.aggregate_preds(preds, passage_start_t, ids, seq_2_start_t) assert len(preds_d) == len(baskets) @@ -1238,7 +1236,7 @@ def to_qa_preds(self, top_preds, no_ans_gaps, baskets): # Unpack document offsets, clear text and id token_offsets = basket.samples[0].tokenized["document_offsets"] - basket_id = basket.id_external if basket.id_external else basket.id_internal + pred_id = basket.id_external if basket.id_external else basket.id_internal # These options reflect the different input dicts that can be assigned to the basket # before any kind of normalization or preprocessing can happen @@ -1250,27 +1248,23 @@ def to_qa_preds(self, top_preds, no_ans_gaps, baskets): # Iterate over each prediction on the one document full_preds = [] - for qa_answer in pred_d: - # This should be a method of Span - pred_str, _, _ = span_to_string(qa_answer.offset_answer_start, - qa_answer.offset_answer_end, - token_offsets, - document_text) - qa_answer.add_answer(pred_str) - full_preds.append(qa_answer) + for qa_candidate in pred_d: + pred_str, _, _ = qa_candidate.span_to_string(token_offsets, document_text) + qa_candidate.add_answer(pred_str) + full_preds.append(qa_candidate) n_samples = full_preds[0].n_passages_in_doc - curr_doc_pred = QAPred(id=basket_id, + curr_doc_pred = QAPred(id=pred_id, + prediction=full_preds, context=document_text, question=question, token_offsets=token_offsets, context_window_size=self.context_window_size, aggregation_level="document", - answer_types=[], # TODO no_answer_gap=no_ans_gap, - n_passages=n_samples - ) + n_passages=n_samples) + ret.append(curr_doc_pred) return ret @@ -1370,15 +1364,15 @@ def reduce_preds(self, preds): # Get all predictions in flattened list and sort by score pos_answers_flat = [] for sample_idx, passage_preds in enumerate(preds): - for qa_answer in passage_preds: - if not (qa_answer.offset_answer_start == -1 and qa_answer.offset_answer_end == -1): - pos_answers_flat.append(QACandidate(offset_answer_start=qa_answer.offset_answer_start, - offset_answer_end=qa_answer.offset_answer_end, - score=qa_answer.score, - answer_type=qa_answer.answer_type, + for qa_candidate in passage_preds: + if not (qa_candidate.offset_answer_start == -1 and qa_candidate.offset_answer_end == -1): + pos_answers_flat.append(QACandidate(offset_answer_start=qa_candidate.offset_answer_start, + offset_answer_end=qa_candidate.offset_answer_end, + score=qa_candidate.score, + answer_type=qa_candidate.answer_type, offset_unit="token", aggregation_level="passage", - passage_id=sample_idx, + passage_id=str(sample_idx), n_passages_in_doc=n_samples) ) @@ -1389,7 +1383,7 @@ def reduce_preds(self, preds): no_ans_gap = -min([nas - pbs for nas, pbs in zip(no_answer_scores, passage_best_score)]) # "no answer" scores and positive answers scores are difficult to compare, because - # + a positive answer score is related to a specific text qa_answer + # + a positive answer score is related to a specific text qa_candidate # - a "no answer" score is related to all input texts # Thus we compute the "no answer" score relative to the best possible answer and adjust it by # the most significant difference between scores. @@ -1521,20 +1515,20 @@ def chunk(iterable, lengths): samples_per_doc = [doc_pred.n_passages for doc_pred in preds_all[0][0]] cls_preds_grouped = chunk(cls_preds, samples_per_doc) - for qa_doc_pred, cls_preds in zip(qa_preds, cls_preds_grouped): - pred_qa_answers = qa_doc_pred.prediction - pred_qa_answers_new = [] - for pred_qa_answer in pred_qa_answers: - passage_id = pred_qa_answer.passage_id + for qa_pred, cls_preds in zip(qa_preds, cls_preds_grouped): + qa_candidates = qa_pred.prediction + qa_candidates_new = [] + for qa_candidate in qa_candidates: + passage_id = qa_candidate.passage_id if passage_id is not None: - cls_pred = cls_preds[passage_id]["label"] + cls_pred = cls_preds[int(passage_id)]["label"] # i.e. if no_answer else: cls_pred = "no_answer" - pred_qa_answer.add_cls(cls_pred) - pred_qa_answers_new.append(pred_qa_answer) - qa_doc_pred.prediction = pred_qa_answers_new - ret.append(qa_doc_pred) + qa_candidate.add_cls(cls_pred) + qa_candidates_new.append(qa_candidate) + qa_pred.prediction = qa_candidates_new + ret.append(qa_pred) return ret diff --git a/farm/modeling/predictions.py b/farm/modeling/predictions.py index ff0c759a5..7994e2de8 100644 --- a/farm/modeling/predictions.py +++ b/farm/modeling/predictions.py @@ -1,14 +1,15 @@ -from abc import ABC from typing import List, Any +from abc import ABC import logging -from farm.utils import span_to_string logger = logging.getLogger(__name__) + + class Pred(ABC): """ - Base Abstract Class for predictions of every task. + Abstract base class for predictions of every task """ def __init__(self, @@ -26,7 +27,6 @@ def to_json(self): class QACandidate: """ A single QA candidate answer. - See class definition to find list of compulsory and optional arguments and also comments on how they are used. """ def __init__(self, @@ -37,38 +37,42 @@ def __init__(self, offset_unit: str, aggregation_level: str, probability: float=None, - answer: str=None, - answer_support: str=None, - offset_answer_support_start: int=None, - offset_answer_support_end: int=None, - context: str=None, - offset_context_start: int=None, - offset_context_end: int=None, n_passages_in_doc: int=None, passage_id: str=None, ): + """ + :param answer_type: The category that this answer falls into e.g. "no_answer", "yes", "no" or "span" + :param score: The score representing the model's confidence of this answer + :param offset_answer_start: The index of the start of the answer span (whether it is char or tok is stated in self.offset_unit) + :param offset_answer_end: The index of the start of the answer span (whether it is char or tok is stated in self.offset_unit) + :param offset_unit: States whether the offsets refer to character or token indices + :param aggregation_level: States whether this candidate and its indices are on a passage level (pre aggregation) or on a document level (post aggregation) + :param probability: The probability the model assigns to the answer + :param n_passages_in_doc: Number of passages that make up the document + :param passage_id: The id of the passage which contains this candidate answer + """ + # self.answer_type can be "no_answer", "yes", "no" or "span" self.answer_type = answer_type self.score = score self.probability = probability - # If self.answer_type is "span", self.answer is a string answer span + # If self.answer_type is "span", self.answer is a string answer (generated by self.span_to_string()) # Otherwise, it is None - self.answer = answer + self.answer = None self.offset_answer_start = offset_answer_start self.offset_answer_end = offset_answer_end # If self.answer_type is in ["yes", "no"] then self.answer_support is a text string # If self.answer is a string answer span or self.answer_type is "no_answer", answer_support is None - self.answer_support = answer_support - self.offset_answer_support_start = offset_answer_support_start - self.offset_answer_support_end = offset_answer_support_end - self.passage_id = passage_id + self.answer_support = None + self.offset_answer_support_start = None + self.offset_answer_support_end = None # self.context is the document or passage where the answer is found - self.context = context - self.offset_context_start = offset_context_start - self.offset_context_end = offset_context_end + self.context = None + self.offset_context_start = None + self.offset_context_end = None # Offset unit is either "token" or "char" # Aggregation level is either "doc" or "passage" @@ -78,13 +82,50 @@ def __init__(self, self.n_passages_in_doc = n_passages_in_doc self.passage_id = passage_id + def span_to_string(self, token_offsets: List[int], clear_text: str): + """ + Generates a string answer span using self.offset_answer_start and self.offset_answer_end. If the candidate + is a no answer, an empty string is returned + + :param token_offsets: A list of ints which give the start character index of the corresponding token + :param clear_text: The text from which the answer span is to be extracted + :return: The string answer span, followed by the start and end character indices + """ + + assert self.offset_unit == "token" + + start_t = self.offset_answer_start + end_t = self.offset_answer_end + + # If it is a no_answer prediction + if start_t == -1 and end_t == -1: + return "", 0, 0 + + n_tokens = len(token_offsets) + + # We do this to point to the beginning of the first token after the span instead of + # the beginning of the last token in the span + end_t += 1 + + # Predictions sometimes land on the very final special token of the passage. But there are no + # special tokens on the document level. We will just interpret this as a span that stretches + # to the end of the document + end_t = min(end_t, n_tokens) + + start_ch = token_offsets[start_t] + # i.e. pointing at the END of the last token + if end_t == n_tokens: + end_ch = len(clear_text) + else: + end_ch = token_offsets[end_t] + return clear_text[start_ch: end_ch].strip(), start_ch, end_ch + def add_cls(self, predicted_class: str): """ Adjust the final QA prediction depending on the prediction of the classification head (e.g. for binary answers in NQ) Currently designed so that the QA head's prediction will always be preferred over the Classification head - :param predicted_class: the predicted class value - :return: None + :param predicted_class: The predicted class e.g. "yes", "no", "no_answer", "span" """ if predicted_class in ["yes", "no"] and self.answer != "no_answer": @@ -95,13 +136,16 @@ def add_cls(self, predicted_class: str): self.offset_answer_support_end = self.offset_answer_end def to_doc_level(self, start, end): + """ Populate the start and end indices with document level indices. Changes aggregation level to 'document'""" self.offset_answer_start = start self.offset_answer_end = end self.aggregation_level = "document" def add_answer(self, string): + """ Set the answer string. This method will check that the answer given is valid given the start + and end indices that are stored in the object. """ if string == "": - self.answer = "is_impossible" + self.answer = "no_answer" if self.offset_answer_start != -1 or self.offset_answer_end != -1: logger.error(f"Something went wrong in tokenization. We have start and end offsets: " f"{self.offset_answer_start, self.offset_answer_end} with an empty answer. " @@ -118,10 +162,9 @@ def to_list(self): class QAPred(Pred): - """Question Answering predictions for a passage or a document. The self.prediction attribute is populated by a - list of QACandidate objects. Note that this object inherits from the Pred class which is why some of - the attributes are found in the Pred class and not here. - See class definition for required and optional arguments. + """ A set of QA predictions for a passage or a document. The candidates are stored in QAPred.prediction which is a + list of QACandidate objects. Also contains all attributes needed to convert the object into json format and also + to create a context window for a UI """ def __init__(self, @@ -132,16 +175,27 @@ def __init__(self, token_offsets: List[int], context_window_size: int, aggregation_level: str, - answer_types: List[str]=None, - ground_truth_answer: str =None, - no_answer_gap: float =None, - n_passages: int=None - - ): + no_answer_gap: float, + n_passages: int, + ground_truth_answer: str = None, + answer_types: List[str] = []): + """ + :param id: The id of the passage or document + :param prediction: A list of QACandidate objects for the given question and document + :param context: The text passage from which the answer can be extracted + :param question: The question being posed + :param token_offsets: A list of ints indicating the start char index of each token + :param context_window_size: The number of chars in the text window around the answer + :param aggregation_level: States whether this candidate and its indices are on a passage level (pre aggregation) or on a document level (post aggregation) + :param no_answer_gap: How much the QuestionAnsweringHead.no_ans_boost needs to change to turn a no_answer to a positive answer + :param n_passages: Number of passages in the context document + :param ground_truth_answer: Ground truth answers + :param answer_types: List of answer_types supported by this task e.g. ["span", "yes_no", "no_answer"] + """ super().__init__(id, prediction, context) self.question = question self.token_offsets = token_offsets - self.context_window_size = context_window_size # TODO only needed for to_json() - can we get rid context_window_size, TODO Do we really need this? + self.context_window_size = context_window_size self.aggregation_level = aggregation_level self.answer_types = answer_types self.ground_truth_answer = ground_truth_answer @@ -149,6 +203,13 @@ def __init__(self, self.n_passages = n_passages def to_json(self, squad=False): + """ + Converts the information stored in the object into a json format. + + :param squad: If True, no_answers are represented by the empty string instead of "no_answer" + :return: + """ + answers = self.answers_to_json(self.id, squad) ret = { "task": "qa", @@ -165,20 +226,23 @@ def to_json(self, squad=False): return ret def answers_to_json(self, id, squad=False): + """ + Convert all answers into a json format + + :param id: ID of the question document pair + :param squad: If True, no_answers are represented by the empty string instead of "no_answer" + :return: + """ + ret = [] # iterate over the top_n predictions of the one document for qa_candidate in self.prediction: - string = str(qa_candidate.answer) - start_t = qa_candidate.offset_answer_start - end_t = qa_candidate.offset_answer_end - - _, ans_start_ch, ans_end_ch = span_to_string(start_t, end_t, self.token_offsets, self.context) - context_string, context_start_ch, context_end_ch = self.create_context(ans_start_ch, - ans_end_ch, - self.context) - if squad: - if string == "no_answer": + string = qa_candidate.answer + + _, ans_start_ch, ans_end_ch = qa_candidate.span_to_string(self.token_offsets, self.context) + context_string, context_start_ch, context_end_ch = self.create_context(ans_start_ch, ans_end_ch, self.context) + if squad and string == "no_answer": string = "" curr = {"score": qa_candidate.score, "probability": None, @@ -192,7 +256,18 @@ def answers_to_json(self, id, squad=False): ret.append(curr) return ret + def create_context(self, ans_start_ch, ans_end_ch, clear_text): + """ + Extract from the clear_text a window that contains the answer and some amount of text on either + side of the answer. Useful for cases where the answer and its surrounding context needs to be + displayed in a UI. + + :param ans_start_ch: Start character index of the answer + :param ans_end_ch: End character index of the answer + :param clear_text: The text from which the answer is extracted + :return: + """ if ans_start_ch == 0 and ans_end_ch == 0: return "", 0, 0 else: diff --git a/farm/utils.py b/farm/utils.py index 8270245f7..4756f81c7 100644 --- a/farm/utils.py +++ b/farm/utils.py @@ -424,31 +424,6 @@ def stack(list_of_lists): ret[i] += (x) return ret -def span_to_string(start_t, end_t, token_offsets, clear_text): - - # If it is a no_answer prediction - if start_t == -1 and end_t == -1: - return "", 0, 0 - - n_tokens = len(token_offsets) - - # We do this to point to the beginning of the first token after the span instead of - # the beginning of the last token in the span - end_t += 1 - - # Predictions sometimes land on the very final special token of the passage. But there are no - # special tokens on the document level. We will just interpret this as a span that stretches - # to the end of the document - end_t = min(end_t, n_tokens) - - start_ch = token_offsets[start_t] - # i.e. pointing at the END of the last token - if end_t == n_tokens: - end_ch = len(clear_text) - else: - end_ch = token_offsets[end_t] - return clear_text[start_ch: end_ch].strip(), start_ch, end_ch - def try_get(keys, dictionary): for key in keys: diff --git a/test/samples/qa/no_answer/clear_text.json b/test/samples/qa/no_answer/clear_text.json new file mode 100644 index 000000000..58e508c38 --- /dev/null +++ b/test/samples/qa/no_answer/clear_text.json @@ -0,0 +1 @@ +{"passage_text": "Note: The green arrows (), red arrows (), and blue dashes () represent changes in rank when compared to the new 2012 data HDI for 2011 \u2013 published in the 2012 report.", "question_text": "What dashes do not represent changes in rank? ", "passage_id": 0, "answers": []} \ No newline at end of file diff --git a/test/samples/qa/no_answer/features.json b/test/samples/qa/no_answer/features.json new file mode 100644 index 000000000..f5c1a8534 --- /dev/null +++ b/test/samples/qa/no_answer/features.json @@ -0,0 +1 @@ +{"input_ids": [0, 2264, 385, 14829, 109, 45, 3594, 1022, 11, 7938, 116, 2, 2, 27728, 35, 20, 2272, 36486, 49038, 1275, 36486, 49038, 8, 2440, 385, 14829, 36418, 3594, 1022, 11, 7938, 77, 1118, 7, 5, 92, 1125, 414, 7951, 100, 13, 1466, 126, 1027, 11, 5, 1125, 266, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "padding_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "segment_ids": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "answer_type_ids": [0], "passage_start_t": 0, "start_of_word": [0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "labels": [[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]], "id": [1735, 0, 0], "seq_2_start_t": 13} \ No newline at end of file diff --git a/test/samples/qa/no_answer/tokenized.json b/test/samples/qa/no_answer/tokenized.json new file mode 100644 index 000000000..799aab2de --- /dev/null +++ b/test/samples/qa/no_answer/tokenized.json @@ -0,0 +1 @@ +{"passage_start_t": 0, "passage_tokens": ["Note", ":", "\u0120The", "\u0120green", "\u0120arrows", "\u0120(),", "\u0120red", "\u0120arrows", "\u0120(),", "\u0120and", "\u0120blue", "\u0120d", "ashes", "\u0120()", "\u0120represent", "\u0120changes", "\u0120in", "\u0120rank", "\u0120when", "\u0120compared", "\u0120to", "\u0120the", "\u0120new", "\u01202012", "\u0120data", "\u0120HD", "I", "\u0120for", "\u01202011", "\u0120\u00e2\u0122\u0135", "\u0120published", "\u0120in", "\u0120the", "\u01202012", "\u0120report", "."], "passage_offsets": [0, 4, 6, 10, 16, 23, 27, 31, 38, 42, 46, 51, 52, 58, 61, 71, 79, 82, 87, 92, 101, 104, 108, 112, 117, 122, 124, 126, 130, 135, 137, 147, 150, 154, 159, 165], "passage_start_of_word": [1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0], "question_tokens": ["What", "\u0120d", "ashes", "\u0120do", "\u0120not", "\u0120represent", "\u0120changes", "\u0120in", "\u0120rank", "?"], "question_offsets": [0, 5, 6, 12, 15, 19, 29, 37, 40, 44], "question_start_of_word": [1, 1, 0, 1, 1, 1, 1, 1, 1, 0], "answers": [], "document_offsets": [0, 4, 6, 10, 16, 23, 27, 31, 38, 42, 46, 51, 52, 58, 61, 71, 79, 82, 87, 92, 101, 104, 108, 112, 117, 122, 124, 126, 130, 135, 137, 147, 150, 154, 159, 165]} \ No newline at end of file diff --git a/test/samples/qa/span/clear_text.json b/test/samples/qa/span/clear_text.json new file mode 100644 index 000000000..d07a1c903 --- /dev/null +++ b/test/samples/qa/span/clear_text.json @@ -0,0 +1 @@ +{"passage_text": "Beyonc\u00e9 Giselle Knowles-Carter (/bi\u02d0\u02c8j\u0252nse\u026a/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny's Child. Managed by her father, Mathew Knowles, the group became one of the world's best-selling girl groups of all time. Their hiatus saw the release of Beyonc\u00e9's debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles \"Crazy in Love\" and \"Baby Boy\".", "question_text": "When did Beyonce start becoming popular?", "passage_id": 0, "answers": [{"text": "in the late 1990s", "start_c": 269, "end_c": 285}]} \ No newline at end of file diff --git a/test/samples/qa/span/features.json b/test/samples/qa/span/features.json new file mode 100644 index 000000000..eec1f38be --- /dev/null +++ b/test/samples/qa/span/features.json @@ -0,0 +1 @@ +{"input_ids": [0, 1779, 222, 12674, 1755, 386, 1959, 1406, 116, 2, 2, 40401, 261, 12695, 272, 354, 6591, 10690, 1634, 12, 43732, 48229, 5605, 43621, 16948, 49066, 267, 35423, 10659, 282, 1090, 35423, 10278, 73, 19417, 12, 975, 2191, 12, 28357, 43, 36, 5400, 772, 204, 6, 14130, 43, 16, 41, 470, 3250, 6, 2214, 9408, 6, 638, 3436, 8, 3390, 4, 8912, 8, 1179, 11, 2499, 6, 1184, 6, 79, 3744, 11, 1337, 6970, 8, 7950, 9150, 25, 10, 920, 6, 8, 1458, 7, 9444, 11, 5, 628, 4525, 29, 25, 483, 3250, 9, 248, 947, 387, 1816, 12, 13839, 23313, 18, 7442, 4, 1554, 4628, 30, 69, 1150, 6, 4101, 16152, 10690, 1634, 6, 5, 333, 1059, 65, 9, 5, 232, 18, 275, 12, 11393, 1816, 1134, 9, 70, 86, 4, 2667, 25224, 794, 5, 800, 9, 12674, 12695, 18, 2453, 2642, 6, 34880, 9412, 11, 3437, 36, 35153, 238, 61, 2885, 69, 25, 10, 5540, 3025, 3612, 6, 2208, 292, 12727, 4229, 8, 3520, 5, 18919, 6003, 727, 346, 12, 1264, 7695, 22, 347, 36616, 11, 3437, 113, 8, 22, 30047, 5637, 845, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], "padding_mask": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "segment_ids": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "answer_type_ids": [-1], "passage_start_t": 0, "start_of_word": [0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "labels": [[85, 89], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]], "id": [0, 0, 0], "seq_2_start_t": 11} \ No newline at end of file diff --git a/test/samples/qa/span/tokenized.json b/test/samples/qa/span/tokenized.json new file mode 100644 index 000000000..9d6ddc609 --- /dev/null +++ b/test/samples/qa/span/tokenized.json @@ -0,0 +1 @@ +{"passage_start_t": 0, "passage_tokens": ["Bey", "on", "c\u00c3\u00a9", "\u0120G", "is", "elle", "\u0120Know", "les", "-", "Carter", "\u0120(/", "bi", "\u00cb", "\u0132", "\u00cb\u012a", "j", "\u00c9", "\u0134", "n", "se", "\u00c9", "\u00aa", "/", "\u0120bee", "-", "Y", "ON", "-", "say", ")", "\u0120(", "born", "\u0120September", "\u01204", ",", "\u01201981", ")", "\u0120is", "\u0120an", "\u0120American", "\u0120singer", ",", "\u0120song", "writer", ",", "\u0120record", "\u0120producer", "\u0120and", "\u0120actress", ".", "\u0120Born", "\u0120and", "\u0120raised", "\u0120in", "\u0120Houston", ",", "\u0120Texas", ",", "\u0120she", "\u0120performed", "\u0120in", "\u0120various", "\u0120singing", "\u0120and", "\u0120dancing", "\u0120competitions", "\u0120as", "\u0120a", "\u0120child", ",", "\u0120and", "\u0120rose", "\u0120to", "\u0120fame", "\u0120in", "\u0120the", "\u0120late", "\u01201990", "s", "\u0120as", "\u0120lead", "\u0120singer", "\u0120of", "\u0120R", "&", "B", "\u0120girl", "-", "group", "\u0120Destiny", "'s", "\u0120Child", ".", "\u0120Man", "aged", "\u0120by", "\u0120her", "\u0120father", ",", "\u0120Mat", "hew", "\u0120Know", "les", ",", "\u0120the", "\u0120group", "\u0120became", "\u0120one", "\u0120of", "\u0120the", "\u0120world", "'s", "\u0120best", "-", "selling", "\u0120girl", "\u0120groups", "\u0120of", "\u0120all", "\u0120time", ".", "\u0120Their", "\u0120hiatus", "\u0120saw", "\u0120the", "\u0120release", "\u0120of", "\u0120Beyon", "c\u00c3\u00a9", "'s", "\u0120debut", "\u0120album", ",", "\u0120Danger", "ously", "\u0120in", "\u0120Love", "\u0120(", "2003", "),", "\u0120which", "\u0120established", "\u0120her", "\u0120as", "\u0120a", "\u0120solo", "\u0120artist", "\u0120worldwide", ",", "\u0120earned", "\u0120five", "\u0120Grammy", "\u0120Awards", "\u0120and", "\u0120featured", "\u0120the", "\u0120Billboard", "\u0120Hot", "\u0120100", "\u0120number", "-", "one", "\u0120singles", "\u0120\"", "C", "razy", "\u0120in", "\u0120Love", "\"", "\u0120and", "\u0120\"", "Baby", "\u0120Boy", "\"."], "passage_offsets": [0, 3, 5, 8, 9, 11, 16, 20, 23, 24, 31, 33, 35, 36, 37, 39, 40, 41, 42, 43, 45, 46, 47, 45, 48, 49, 50, 52, 53, 56, 58, 59, 64, 74, 75, 77, 81, 83, 86, 89, 98, 104, 106, 110, 116, 118, 125, 134, 138, 145, 147, 152, 156, 163, 166, 173, 175, 180, 182, 186, 196, 199, 207, 215, 219, 227, 240, 243, 245, 250, 252, 256, 261, 264, 269, 272, 276, 281, 285, 287, 290, 295, 302, 305, 306, 307, 309, 313, 314, 320, 327, 330, 335, 337, 340, 345, 348, 352, 358, 360, 363, 367, 371, 374, 376, 380, 386, 393, 397, 400, 404, 409, 412, 416, 417, 425, 430, 437, 440, 444, 448, 450, 456, 463, 467, 471, 479, 482, 487, 490, 492, 498, 503, 505, 511, 517, 520, 525, 526, 530, 533, 539, 551, 555, 558, 560, 565, 572, 581, 583, 590, 595, 602, 609, 613, 622, 626, 636, 640, 644, 650, 651, 655, 663, 664, 665, 670, 673, 677, 679, 683, 684, 689, 692], "passage_start_of_word": [1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0], "question_tokens": ["When", "\u0120did", "\u0120Beyon", "ce", "\u0120start", "\u0120becoming", "\u0120popular", "?"], "question_offsets": [0, 5, 9, 14, 17, 23, 32, 39], "question_start_of_word": [1, 1, 1, 0, 1, 1, 1, 0], "answers": [{"start_t": 74, "end_t": 78, "answer_type": "span"}], "document_offsets": [0, 3, 5, 8, 9, 11, 16, 20, 23, 24, 31, 33, 35, 36, 37, 39, 40, 41, 42, 43, 45, 46, 47, 45, 48, 49, 50, 52, 53, 56, 58, 59, 64, 74, 75, 77, 81, 83, 86, 89, 98, 104, 106, 110, 116, 118, 125, 134, 138, 145, 147, 152, 156, 163, 166, 173, 175, 180, 182, 186, 196, 199, 207, 215, 219, 227, 240, 243, 245, 250, 252, 256, 261, 264, 269, 272, 276, 281, 285, 287, 290, 295, 302, 305, 306, 307, 309, 313, 314, 320, 327, 330, 335, 337, 340, 345, 348, 352, 358, 360, 363, 367, 371, 374, 376, 380, 386, 393, 397, 400, 404, 409, 412, 416, 417, 425, 430, 437, 440, 444, 448, 450, 456, 463, 467, 471, 479, 482, 487, 490, 492, 498, 503, 505, 511, 517, 520, 525, 526, 530, 533, 539, 551, 555, 558, 560, 565, 572, 581, 583, 590, 595, 602, 609, 613, 622, 626, 636, 640, 644, 650, 651, 655, 663, 664, 665, 670, 673, 677, 679, 683, 684, 689, 692]} \ No newline at end of file diff --git a/test/samples/qa/train-sample.json b/test/samples/qa/train-sample.json index 71610b81a..0a1d2b9ff 100644 --- a/test/samples/qa/train-sample.json +++ b/test/samples/qa/train-sample.json @@ -1 +1 @@ -{"data": [{"paragraphs": [{"qas": [{"question": "In what country is Normandy located?", "id": "56ddde6b9a695914005b9628", "answers": [{"text": "France", "answer_start": 53}], "is_impossible": false}], "context": "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia."}]}]} \ No newline at end of file +{"data": [{"paragraphs": [{"qas": [{"question": "In what country is Normandy located?", "id": "56ddde6b9a695914005b9628", "answers": [{"text": "France", "answer_start": 159}], "is_impossible": false}], "context": "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse (\"Norman\" comes from \"Norseman\") raiders and pirates from Denmark, Iceland and Norway who, under their leader Rollo, agreed to swear fealty to King Charles III of West Francia."}]}]} \ No newline at end of file diff --git a/test/test_input_features.py b/test/test_input_features.py new file mode 100644 index 000000000..bd4273c08 --- /dev/null +++ b/test/test_input_features.py @@ -0,0 +1,45 @@ +import json +import logging + +from farm.data_handler.input_features import sample_to_features_qa +from farm.data_handler.samples import Sample +from farm.modeling.tokenization import Tokenizer + + +MODEL = "roberta-base" +SP_TOKENS_START = 1 +SP_TOKENS_MID = 2 + +def to_list(x): + try: + return x.tolist() + except: + return x + +def test_sample_to_features_qa(caplog): + if caplog: + caplog.set_level(logging.CRITICAL) + + sample_types = ["span", "no_answer"] + + for sample_type in sample_types: + clear_text = json.load(open(f"samples/qa/{sample_type}/clear_text.json")) + tokenized = json.load(open(f"samples/qa/{sample_type}/tokenized.json")) + features_gold = json.load(open(f"samples/qa/{sample_type}/features.json")) + max_seq_len = len(features_gold["input_ids"]) + + tokenizer = Tokenizer.load(pretrained_model_name_or_path=MODEL, do_lower_case=False) + curr_id = "-".join([str(x) for x in features_gold["id"]]) + + s = Sample(id=curr_id, clear_text=clear_text, tokenized=tokenized) + features = sample_to_features_qa(s, tokenizer, max_seq_len, SP_TOKENS_START, SP_TOKENS_MID)[0] + features = to_list(features) + + keys = features_gold.keys() + for k in keys: + value_gold = features_gold[k] + value = to_list(features[k]) + assert value == value_gold, f"Mismatch between the {k} features in the {sample_type} test sample." + +if __name__ == "__main__": + test_sample_to_features_qa(None) diff --git a/test/test_question_answering.py b/test/test_question_answering.py index d04ce0671..397c60e86 100644 --- a/test/test_question_answering.py +++ b/test/test_question_answering.py @@ -86,7 +86,6 @@ def test_qa(caplog=None): "context": "Twilight Princess was released to universal critical acclaim and commercial success. It received perfect scores from major publications such as 1UP.com, Computer and Video Games, Electronic Gaming Monthly, Game Informer, GamesRadar, and GameSpy. On the review aggregators GameRankings and Metacritic, Twilight Princess has average scores of 95% and 95 for the Wii version and scores of 95% and 96 for the GameCube version. GameTrailers in their review called it one of the greatest games ever created.", }] - result1 = inferencer.inference_from_dicts(dicts=qa_format_1) result2 = inferencer.inference_from_dicts(dicts=qa_format_2) assert result1 == result2