Refactor DPR modeling for multigpu training (#601)
* Change defaults in the DPR example, implement max_samples in the processor, simplify the prediction head

* Move the similarity score calculation out of the forward pass

* Fix a variable name

* Adjust the DPR test to the new forward pass

* Update docstrings

* Update the example script

* Update type hints
tholor authored Oct 26, 2020
1 parent dac388a commit a3b76b1
Showing 6 changed files with 62 additions and 26 deletions.
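The central change is the TextSimilarityHead contract: forward() no longer computes scores itself but returns the raw query and passage embeddings, and the log-softmax similarity scores are derived afterwards. A simplified sketch with toy tensors (not the verbatim FARM code):

```python
import torch

# Toy embeddings: 2 queries, 4 passages, embedding size 8.
query_emb = torch.randn(2, 8)
passage_emb = torch.randn(4, 8)

# Old contract: the head's forward pass returned the n1 x n2 log-softmax scores.
# New contract: forward() returns (query_emb, passage_emb) untouched; the scores
# are computed later from the (potentially DDP-gathered) embeddings:
scores = torch.matmul(query_emb, passage_emb.transpose(0, 1))  # dot products, 2 x 4
log_softmax_scores = torch.nn.functional.log_softmax(scores, dim=1)
print(log_softmax_scores.shape)  # torch.Size([2, 4])
```

Deferring the scoring step is what makes distributed training with in-batch negatives possible: embeddings can be gathered across GPUs before any query-passage interaction happens.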

examples/dpr_encoder.py (6 changes: 4 additions & 2 deletions)

@@ -42,6 +42,7 @@ def dense_passage_retrieval():
     similarity_function = "dot_product"
     train_filename = "nq-train.json"
     dev_filename = "nq-dev.json"
+    max_samples = None  # load a smaller dataset (e.g. for debugging)
 
     # 1.Create question and passage tokenizers
     query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=question_lang_model,
@@ -56,15 +57,16 @@ def dense_passage_retrieval():
     metric = "text_similarity_metric"
     processor = TextSimilarityProcessor(tokenizer=query_tokenizer,
                                         passage_tokenizer=context_tokenizer,
-                                        max_seq_len=512,
+                                        max_seq_len=256,
                                         label_list=label_list,
                                         metric=metric,
                                         data_dir="data/retriever",
                                         train_filename=train_filename,
                                         dev_filename=dev_filename,
                                         test_filename=dev_filename,
                                         embed_title=embed_title,
-                                        num_hard_negatives=num_hard_negatives)
+                                        num_hard_negatives=num_hard_negatives,
+                                        max_samples=max_samples)
 
     # 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them and calculates a few descriptive statistics of our datasets
     # NOTE: In FARM, the dev set metrics differ from test set metrics in that they are calculated on a token level instead of a word level

farm/data_handler/processor.py (2 changes: 1 addition & 1 deletion)

@@ -1956,7 -1956,7 @@ def file_to_dicts(self, file: str) -> [dict]:
         {"text": document_text, "title": xxx, "label": "hard_negative", "external_id": abb134},
         ...]}
         """
-        dicts = read_dpr_json(file)
+        dicts = read_dpr_json(file, max_samples=self.max_samples)
         return dicts
 
     def _normalize_question(self, question: str) -> str:

farm/data_handler/utils.py (5 changes: 3 additions & 2 deletions)

@@ -183,7 +183,7 @@ def read_ner_file(filename, sep="\t", proxies=None):
         data.append({"text": " ".join(sentence), "ner_label": label})
     return data
 
-def read_dpr_json(file, proxies=None):
+def read_dpr_json(file, max_samples=None, proxies=None):
     """
     Reads a Dense Passage Retrieval (DPR) data file in json format and returns a list of dictionaries.
@@ -215,7 +215,8 @@ def read_dpr_json(file, proxies=None):
         logger.info(f" Couldn't find {file} locally. Trying to download ...")
         _download_extract_downstream_data(file, proxies=proxies)
     dicts = json.load(open(file))
-
+    if max_samples:
+        dicts = random.sample(dicts, min(max_samples, len(dicts)))
     # convert DPR dictionary to standard dictionary
     query_json_keys = ["question", "questions", "query"]
     positive_context_json_keys = ["positive_contexts", "positive_ctxs", "positive_context", "positive_ctx"]
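For reference, the new max_samples argument draws a random subset of the parsed dicts before any conversion happens. A minimal standalone sketch of the sampling semantics (file path hypothetical):

```python
import json
import random

# Parse a DPR-style JSON file into a list of dicts (hypothetical path).
dicts = json.load(open("data/retriever/nq-dev.json"))

max_samples = 100
if max_samples:
    # Draw without replacement; min() guards against requesting more
    # samples than the file contains.
    dicts = random.sample(dicts, min(max_samples, len(dicts)))
```

The subset is drawn randomly on each call, so runs are only reproducible if the random seed is fixed beforehand (e.g. with FARM's set_all_seeds).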

farm/modeling/prediction_head.py (62 changes: 45 additions & 17 deletions)

@@ -6,7 +6,7 @@
 from pathlib import Path
 from transformers.modeling_bert import BertForPreTraining, ACT2FN
 from transformers.modeling_auto import AutoModelForQuestionAnswering, AutoModelForTokenClassification, AutoModelForSequenceClassification
-from typing import List
+from typing import List, Tuple
 
 import torch
 from torch import nn
@@ -1557,6 +1557,9 @@ def pick_single_fn(heads, fn_name):
 
 
 class TextSimilarityHead(PredictionHead):
+    """
+    Trains a head on predicting the similarity of two texts like in Dense Passage Retrieval.
+    """
     def __init__(self, similarity_function="dot_product", **kwargs):
         super(TextSimilarityHead, self).__init__()
 
@@ -1612,21 +1615,35 @@ def get_similarity_function(self):
         elif "cosine" in self.similarity_function:
             return TextSimilarityHead.cosine_scores
 
-    def forward(self, query_vectors, context_vectors):
+    def forward(self, query_vectors:torch.Tensor, context_vectors:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Computes the log softmax similarity scores between two 2-dimensional tensors
+        Only packs the embeddings from both language models into a tuple. No further modification.
+        The similarity calculation is handled later to enable distributed training (DDP)
+        while keeping the support for in-batch negatives.
+        (Gather all embeddings from nodes => then do similarity scores + loss)
-        :param query_vectors: tensor of query embeddings from BiAdaptive model
+        :param query_vectors: Tensor of query embeddings from BiAdaptive model
                               of dimension n1 x D,
                               where n1 is the number of queries/batch size and D is embedding size
         :type query_vectors: torch.Tensor
-        :param context_vectors: tensor of context/passage embeddings from BiAdaptive model
+        :param context_vectors: Tensor of context/passage embeddings from BiAdaptive model
                                 of dimension n2 x D,
                                 where n2 is the number of queries/batch size and D is embedding size
         :type context_vectors: torch.Tensor
-        :return: log softmax similarity score of each query with each context/passage (dimension: n1xn2)
+        :return: (query_vectors, context_vectors)
         """
+        return (query_vectors, context_vectors)
+
+    def _embeddings_to_scores(self, query_vectors:torch.Tensor, context_vectors:torch.Tensor):
+        """
+        Calculates similarity scores between all given query_vectors and context_vectors
+        :param query_vectors: Tensor of queries encoded by the query encoder model
+        :param context_vectors: Tensor of passages encoded by the passage encoder model
+        :return: Tensor of log softmax similarity scores of each query with each passage (dimension: n1xn2)
+        """
 
         sim_func = self.get_similarity_function()
         scores = sim_func(query_vectors, context_vectors)
 
@@ -1637,22 +1654,31 @@ def forward(self, query_vectors, context_vectors):
         softmax_scores = nn.functional.log_softmax(scores, dim=1)
         return softmax_scores
 
-    def logits_to_loss(self, logits, **kwargs):
+    def logits_to_loss(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs):
         """
-        Computes the loss from similarity scores
+        Computes the loss (Default: NLLLoss) by applying a similarity function (Default: dot product) to the input
+        tuple of (query_vectors, context_vectors) and afterwards applying the loss function on similarity scores.
-        :param logits: tensor of log softmax similarity scores of each query with each context/passage (dimension: n1xn2)
-        :type logits: torch.Tensor
+        :param logits: Tuple of Tensors (query_embedding, context_embedding) as returned from forward()
         :return: negative log likelihood loss from similarity scores
         """
+        # Prepare predicted scores
+        query_vectors, context_vectors = logits
+        softmax_scores = self._embeddings_to_scores(query_vectors, context_vectors)
+
+        # Prepare Labels
         lm_label_ids = kwargs.get(self.label_tensor_name)
-        positive_idx_per_question = (lm_label_ids.view(-1) == 1).nonzero()
-        loss = self.loss_fct(logits,
-                             torch.tensor(positive_idx_per_question).squeeze(-1).to(logits.device))
+        positive_idx_per_question = torch.nonzero((lm_label_ids.view(-1) == 1), as_tuple=False)
+        #TODO gather global tensors from all nodes for DDP
+        global_positive_idx_per_question = positive_idx_per_question
+        targets = global_positive_idx_per_question.squeeze(-1).to(softmax_scores.device)
+
+        # Calculate loss
+        loss = self.loss_fct(softmax_scores, targets)
         return loss
 
-    def logits_to_preds(self, logits, **kwargs):
+    def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs):
         """
         Returns predicted ranks(similarity) of passages/context for each query
@@ -1661,8 +1687,10 @@ def logits_to_preds(self, logits, **kwargs):
         :return: predicted ranks of passages for each query
         """
-        _, logits_sorted_indices = torch.sort(logits, dim=1, descending=True)
-        return logits_sorted_indices
+        query_vectors, context_vectors = logits
+        softmax_scores = self._embeddings_to_scores(query_vectors, context_vectors)
+        _, sorted_scores = torch.sort(softmax_scores, dim=1, descending=True)
+        return sorted_scores
 
     def prepare_labels(self, **kwargs):
         """
@@ -1677,5 +1705,5 @@ def prepare_labels(self, **kwargs):
             labels[i, indx.item()] = 1
         return labels
 
-    def formatted_preds(self, logits, **kwargs):
+    def formatted_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs):
         raise NotImplementedError("formatted_preds is not supported in TextSimilarityHead yet!")
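The tuple returned by forward() is what leaves room for the DDP gather that is still marked TODO in logits_to_loss above: embeddings from all ranks can be collected before the in-batch-negative scores are computed. A minimal sketch of that intended pattern, assuming dot-product similarity and the standard torch.distributed API (not part of this commit):

```python
import torch
import torch.distributed as dist
from torch import nn

def gather_embeddings(local_emb: torch.Tensor) -> torch.Tensor:
    """All-gather embeddings so each rank can score its queries against the
    passages of the whole global batch (i.e. more in-batch negatives)."""
    gathered = [torch.zeros_like(local_emb) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local_emb)
    # all_gather output carries no autograd history; re-insert the local
    # tensor so gradients still flow for this rank's share of the batch.
    gathered[dist.get_rank()] = local_emb
    return torch.cat(gathered, dim=0)

def in_batch_negative_loss(query_vectors: torch.Tensor,
                           context_vectors: torch.Tensor,
                           positive_idx_per_question: torch.Tensor) -> torch.Tensor:
    """NLL loss over dot-product scores of every query against every passage
    (n1 x n2); positive_idx_per_question is a LongTensor of shape (n1,)."""
    scores = torch.matmul(query_vectors, context_vectors.transpose(0, 1))
    softmax_scores = nn.functional.log_softmax(scores, dim=1)
    return nn.functional.nll_loss(softmax_scores, positive_idx_per_question)
```

Keeping forward() free of query-passage interaction confines all cross-rank communication to the loss step, which is why the head now just packs the two tensors.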

farm/train.py (2 changes: 2 additions & 0 deletions)

@@ -13,6 +13,7 @@
 from farm.data_handler.data_silo import DataSilo
 from farm.visual.ascii.images import GROWING_TREE
 from farm.modeling.adaptive_model import AdaptiveModel
+from farm.modeling.biadaptive_model import BiAdaptiveModel
 from farm.modeling.optimization import get_scheduler
 
 try:
@@ -248,6 +249,7 @@ def train(self):
         # connect the prediction heads with the right output from processor
         self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
         # Check that the tokenizer fits the language model
+        #TODO: make this compliant for DP / DDP where the model class is wrapped
         if self.model._get_name() == 'BiAdaptiveModel':
             self.model.verify_vocab_size(vocab_size1=len(self.data_silo.processor.tokenizer),
                                          vocab_size2=len(self.data_silo.processor.passage_tokenizer))
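A hedged sketch of what that TODO could look like: under DP/DDP the real model hides behind a .module attribute, so the class-name check needs to unwrap it first (hypothetical helper, not part of this commit):

```python
import torch
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel

def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    # DP/DDP wrap the real model, so _get_name() on the wrapper returns
    # 'DataParallel' / 'DistributedDataParallel' instead of 'BiAdaptiveModel'.
    return model.module if isinstance(model, (DataParallel, DistributedDataParallel)) else model
```

The vocab-size check could then run on unwrap_model(self.model) and behave identically with and without a wrapper.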

test/test_dpr.py (11 changes: 7 additions & 4 deletions)

@@ -9,7 +9,6 @@
 from farm.modeling.tokenization import Tokenizer
 from farm.utils import set_all_seeds, initialize_device_settings
 from farm.data_handler.dataset import convert_features_to_dataset
-from transformers import DPRConfig
 
 def test_dpr_modules(caplog=None):
     if caplog:
@@ -111,9 +110,13 @@ def test_dpr_modules(caplog=None):
                                           0.3350, -0.3412]), torch.ones((1, 10)) * 0.0001))
 
     # test logits and loss
-    logits = model(**features)
-    loss = model.logits_to_loss_per_head(logits, **features)
-    similarity_scores = logits[0].cpu()
+    embeddings = model(**features)
+    query_emb, passage_emb = embeddings[0]
+    assert torch.all(torch.eq(query_emb.cpu(), query_vector.cpu()))
+    assert torch.all(torch.eq(passage_emb.cpu(), passage_vector.cpu()))
+
+    loss = model.logits_to_loss_per_head(embeddings, **features)
+    similarity_scores = model.prediction_heads[0]._embeddings_to_scores(query_emb, passage_emb).cpu()
     assert torch.all(torch.le(similarity_scores - torch.tensor([[-1.8311e-03, -6.3016e+00]]), torch.ones((1, 2)) * 0.0001))
     assert (loss[0].item() - 0.0018) <= 0.0001
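With the refactored head, ranking passages for each query follows the same two-step pattern the test exercises; a small usage sketch, reusing model and features from this test:

```python
# forward() now returns embeddings per prediction head, not scores.
query_emb, passage_emb = model(**features)[0]

# Scores, and from them ranks, are derived from the embeddings on demand.
head = model.prediction_heads[0]
softmax_scores = head._embeddings_to_scores(query_emb, passage_emb)
_, ranked_passage_indices = torch.sort(softmax_scores, dim=1, descending=True)
```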

