From a3b76b1a8adc66d5e7bb45c200d69d7fa459ad75 Mon Sep 17 00:00:00 2001 From: Malte Pietsch Date: Mon, 26 Oct 2020 18:06:22 +0100 Subject: [PATCH] Refactor DPR modeling for multigpu training (#601) * change defaults in dpr example. implement max_samples in processor. simplify prediction head * move similarity score calc out of forward pass * fix variable name * adjust dpr test to new forward pass * update docstrings * update example script * update typehints --- examples/dpr_encoder.py | 6 ++-- farm/data_handler/processor.py | 2 +- farm/data_handler/utils.py | 5 +-- farm/modeling/prediction_head.py | 62 +++++++++++++++++++++++--------- farm/train.py | 2 ++ test/test_dpr.py | 11 +++--- 6 files changed, 62 insertions(+), 26 deletions(-) diff --git a/examples/dpr_encoder.py b/examples/dpr_encoder.py index 764a1e2e8..eb04225a9 100644 --- a/examples/dpr_encoder.py +++ b/examples/dpr_encoder.py @@ -42,6 +42,7 @@ def dense_passage_retrieval(): similarity_function = "dot_product" train_filename = "nq-train.json" dev_filename = "nq-dev.json" + max_samples = None #load a smaller dataset (e.g. for debugging) # 1.Create question and passage tokenizers query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=question_lang_model, @@ -56,7 +57,7 @@ def dense_passage_retrieval(): metric = "text_similarity_metric" processor = TextSimilarityProcessor(tokenizer=query_tokenizer, passage_tokenizer=context_tokenizer, - max_seq_len=512, + max_seq_len=256, label_list=label_list, metric=metric, data_dir="data/retriever", @@ -64,7 +65,8 @@ def dense_passage_retrieval(): dev_filename=dev_filename, test_filename=dev_filename, embed_title=embed_title, - num_hard_negatives=num_hard_negatives) + num_hard_negatives=num_hard_negatives, + max_samples=max_samples) # 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them and calculates a few descriptive statistics of our datasets # NOTE: In FARM, the dev set metrics differ from test set metrics in that they are calculated on a token level instead of a word level diff --git a/farm/data_handler/processor.py b/farm/data_handler/processor.py index 5ff14eca0..f6657a5a4 100644 --- a/farm/data_handler/processor.py +++ b/farm/data_handler/processor.py @@ -1956,7 +1956,7 @@ def file_to_dicts(self, file: str) -> [dict]: {"text": document_text, "title": xxx, "label": "hard_negative", "external_id": abb134}, ...]} """ - dicts = read_dpr_json(file) + dicts = read_dpr_json(file, max_samples=self.max_samples) return dicts def _normalize_question(self, question: str) -> str: diff --git a/farm/data_handler/utils.py b/farm/data_handler/utils.py index bec25baeb..25d25ce83 100644 --- a/farm/data_handler/utils.py +++ b/farm/data_handler/utils.py @@ -183,7 +183,7 @@ def read_ner_file(filename, sep="\t", proxies=None): data.append({"text": " ".join(sentence), "ner_label": label}) return data -def read_dpr_json(file, proxies=None): +def read_dpr_json(file, max_samples=None, proxies=None): """ Reads a Dense Passage Retrieval (DPR) data file in json format and returns a list of dictionaries. @@ -215,7 +215,8 @@ def read_dpr_json(file, proxies=None): logger.info(f" Couldn't find {file} locally. Trying to download ...") _download_extract_downstream_data(file, proxies=proxies) dicts = json.load(open(file)) - + if max_samples: + dicts = random.sample(dicts, min(max_samples, len(dicts))) # convert DPR dictionary to standard dictionary query_json_keys = ["question", "questions", "query"] positive_context_json_keys = ["positive_contexts", "positive_ctxs", "positive_context", "positive_ctx"] diff --git a/farm/modeling/prediction_head.py b/farm/modeling/prediction_head.py index 0e9244289..583c6956a 100644 --- a/farm/modeling/prediction_head.py +++ b/farm/modeling/prediction_head.py @@ -6,7 +6,7 @@ from pathlib import Path from transformers.modeling_bert import BertForPreTraining, ACT2FN from transformers.modeling_auto import AutoModelForQuestionAnswering, AutoModelForTokenClassification, AutoModelForSequenceClassification -from typing import List +from typing import List, Tuple import torch from torch import nn @@ -1557,6 +1557,9 @@ def pick_single_fn(heads, fn_name): class TextSimilarityHead(PredictionHead): + """ + Trains a head on predicting the similarity of two texts like in Dense Passage Retrieval. + """ def __init__(self, similarity_function="dot_product", **kwargs): super(TextSimilarityHead, self).__init__() @@ -1612,21 +1615,35 @@ def get_similarity_function(self): elif "cosine" in self.similarity_function: return TextSimilarityHead.cosine_scores - def forward(self, query_vectors, context_vectors): + def forward(self, query_vectors:torch.Tensor, context_vectors:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ - Computes the log softmax similarity scores between two 2-dimensional tensors + Only packs the embeddings from both language models into a tuple. No further modification. + The similarity calculation is handled later to enable distributed training (DDP) + while keeping the support for in-batch negatives. + (Gather all embeddings from nodes => then do similarity scores + loss) - :param query_vectors: tensor of query embeddings from BiAdaptive model + :param query_vectors: Tensor of query embeddings from BiAdaptive model of dimension n1 x D, where n1 is the number of queries/batch size and D is embedding size :type query_vectors: torch.Tensor - :param context_vectors: tensor of context/passage embeddings from BiAdaptive model + :param context_vectors: Tensor of context/passage embeddings from BiAdaptive model of dimension n2 x D, where n2 is the number of queries/batch size and D is embedding size :type context_vectors: torch.Tensor - :return: log softmax similarity score of each query with each context/passage (dimension: n1xn2) + :return: (query_vectors, context_vectors) + """ + return (query_vectors, context_vectors) + + def _embeddings_to_scores(self, query_vectors:torch.Tensor, context_vectors:torch.Tensor): + """ + Calculates similarity scores between all given query_vectors and context_vectors + + :param query_vectors: Tensor of queries encoded by the query encoder model + :param context_vectors: Tensor of passages encoded by the passage encoder model + :return: Tensor of log softmax similarity scores of each query with each passage (dimension: n1xn2) """ + sim_func = self.get_similarity_function() scores = sim_func(query_vectors, context_vectors) @@ -1637,22 +1654,31 @@ def forward(self, query_vectors, context_vectors): softmax_scores = nn.functional.log_softmax(scores, dim=1) return softmax_scores - def logits_to_loss(self, logits, **kwargs): + def logits_to_loss(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs): """ - Computes the loss from similarity scores + Computes the loss (Default: NLLLoss) by applying a similarity function (Default: dot product) to the input + tuple of (query_vectors, context_vectors) and afterwards applying the loss function on similarity scores. - :param logits: tensor of log softmax similarity scores of each query with each context/passage (dimension: n1xn2) - :type logits: torch.Tensor + :param logits: Tuple of Tensors (query_embedding, context_embedding) as returned from forward() :return: negative log likelihood loss from similarity scores """ + # Prepare predicted scores + query_vectors, context_vectors = logits + softmax_scores = self._embeddings_to_scores(query_vectors, context_vectors) + + # Prepare Labels lm_label_ids = kwargs.get(self.label_tensor_name) - positive_idx_per_question = (lm_label_ids.view(-1) == 1).nonzero() - loss = self.loss_fct(logits, - torch.tensor(positive_idx_per_question).squeeze(-1).to(logits.device)) + positive_idx_per_question = torch.nonzero((lm_label_ids.view(-1) == 1), as_tuple=False) + #TODO gather global tensors from all nodes for DDP + global_positive_idx_per_question = positive_idx_per_question + targets = global_positive_idx_per_question.squeeze(-1).to(softmax_scores.device) + + # Calculate loss + loss = self.loss_fct(softmax_scores, targets) return loss - def logits_to_preds(self, logits, **kwargs): + def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs): """ Returns predicted ranks(similarity) of passages/context for each query @@ -1661,8 +1687,10 @@ def logits_to_preds(self, logits, **kwargs): :return: predicted ranks of passages for each query """ - _, logits_sorted_indices = torch.sort(logits, dim=1, descending=True) - return logits_sorted_indices + query_vectors, context_vectors = logits + softmax_scores = self._embeddings_to_scores(query_vectors, context_vectors) + _, sorted_scores = torch.sort(softmax_scores, dim=1, descending=True) + return sorted_scores def prepare_labels(self, **kwargs): """ @@ -1677,5 +1705,5 @@ def prepare_labels(self, **kwargs): labels[i, indx.item()] = 1 return labels - def formatted_preds(self, logits, **kwargs): + def formatted_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs): raise NotImplementedError("formatted_preds is not supported in TextSimilarityHead yet!") \ No newline at end of file diff --git a/farm/train.py b/farm/train.py index d459fba28..455a738fb 100644 --- a/farm/train.py +++ b/farm/train.py @@ -13,6 +13,7 @@ from farm.data_handler.data_silo import DataSilo from farm.visual.ascii.images import GROWING_TREE from farm.modeling.adaptive_model import AdaptiveModel +from farm.modeling.biadaptive_model import BiAdaptiveModel from farm.modeling.optimization import get_scheduler try: @@ -248,6 +249,7 @@ def train(self): # connect the prediction heads with the right output from processor self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True) # Check that the tokenizer fits the language model + #TODO: make this compliant for DP / DDP where the model class is wrapped if self.model._get_name() == 'BiAdaptiveModel': self.model.verify_vocab_size(vocab_size1=len(self.data_silo.processor.tokenizer), vocab_size2=len(self.data_silo.processor.passage_tokenizer)) diff --git a/test/test_dpr.py b/test/test_dpr.py index 12530ee57..e68c01e81 100644 --- a/test/test_dpr.py +++ b/test/test_dpr.py @@ -9,7 +9,6 @@ from farm.modeling.tokenization import Tokenizer from farm.utils import set_all_seeds, initialize_device_settings from farm.data_handler.dataset import convert_features_to_dataset -from transformers import DPRConfig def test_dpr_modules(caplog=None): if caplog: @@ -111,9 +110,13 @@ def test_dpr_modules(caplog=None): 0.3350, -0.3412]), torch.ones((1, 10)) * 0.0001)) # test logits and loss - logits = model(**features) - loss = model.logits_to_loss_per_head(logits, **features) - similarity_scores = logits[0].cpu() + embeddings = model(**features) + query_emb, passage_emb = embeddings[0] + assert torch.all(torch.eq(query_emb.cpu(), query_vector.cpu())) + assert torch.all(torch.eq(passage_emb.cpu(), passage_vector.cpu())) + + loss = model.logits_to_loss_per_head(embeddings, **features) + similarity_scores = model.prediction_heads[0]._embeddings_to_scores(query_emb, passage_emb).cpu() assert torch.all(torch.le(similarity_scores - torch.tensor([[-1.8311e-03, -6.3016e+00]]), torch.ones((1, 2)) * 0.0001)) assert (loss[0].item() - 0.0018) <= 0.0001