Refactor DPR modeling for multigpu training #601

Merged · 8 commits · Oct 26, 2020
6 changes: 4 additions & 2 deletions examples/dpr_encoder.py
@@ -42,6 +42,7 @@ def dense_passage_retrieval():
     similarity_function = "dot_product"
     train_filename = "nq-train.json"
     dev_filename = "nq-dev.json"
+    max_samples = None  # load a smaller dataset (e.g. for debugging)
 
     # 1.Create question and passage tokenizers
     query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=question_lang_model,
@@ -56,15 +57,16 @@ def dense_passage_retrieval():
     metric = "text_similarity_metric"
     processor = TextSimilarityProcessor(tokenizer=query_tokenizer,
                                         passage_tokenizer=context_tokenizer,
-                                        max_seq_len=512,
+                                        max_seq_len=256,
                                         label_list=label_list,
                                         metric=metric,
                                         data_dir="data/retriever",
                                         train_filename=train_filename,
                                         dev_filename=dev_filename,
                                         test_filename=dev_filename,
                                         embed_title=embed_title,
-                                        num_hard_negatives=num_hard_negatives)
+                                        num_hard_negatives=num_hard_negatives,
+                                        max_samples=max_samples)
 
     # 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them and calculates a few descriptive statistics of our datasets
     # NOTE: In FARM, the dev set metrics differ from test set metrics in that they are calculated on a token level instead of a word level
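For orientation, the new `max_samples` option travels through all three data-handling diffs in this PR; here is a short sketch of the call chain (the value 100 is illustrative, `None` is the default and loads the full dataset):

```python
# Call chain wired up by this PR (names taken from the diffs on this page):
#   examples/dpr_encoder.py         max_samples = None  ->  TextSimilarityProcessor(..., max_samples=max_samples)
#   farm/data_handler/processor.py  file_to_dicts() forwards self.max_samples
#   farm/data_handler/utils.py      read_dpr_json(file, max_samples=...) subsamples the parsed json
max_samples = 100  # e.g. a quick debugging run on a 100-example random subset
```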
2 changes: 1 addition & 1 deletion farm/data_handler/processor.py
@@ -1956,7 +1956,7 @@ def file_to_dicts(self, file: str) -> [dict]:
             {"text": document_text, "title": xxx, "label": "hard_negative", "external_id": abb134},
             ...]}
         """
-        dicts = read_dpr_json(file)
+        dicts = read_dpr_json(file, max_samples=self.max_samples)
        return dicts
 
    def _normalize_question(self, question: str) -> str:
5 changes: 3 additions & 2 deletions farm/data_handler/utils.py
@@ -183,7 +183,7 @@ def read_ner_file(filename, sep="\t", proxies=None):
         data.append({"text": " ".join(sentence), "ner_label": label})
     return data
 
-def read_dpr_json(file, proxies=None):
+def read_dpr_json(file, max_samples=None, proxies=None):
    """
    Reads a Dense Passage Retrieval (DPR) data file in json format and returns a list of dictionaries.
 
@@ -215,7 +215,8 @@ def read_dpr_json(file, proxies=None):
         logger.info(f" Couldn't find {file} locally. Trying to download ...")
         _download_extract_downstream_data(file, proxies=proxies)
     dicts = json.load(open(file))
-
+    if max_samples:
+        dicts = random.sample(dicts, min(max_samples, len(dicts)))
     # convert DPR dictionary to standard dictionary
     query_json_keys = ["question", "questions", "query"]
     positive_context_json_keys = ["positive_contexts", "positive_ctxs", "positive_context", "positive_ctx"]
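Since `read_dpr_json` itself needs a downloaded dataset, here is a minimal self-contained sketch of just the new sampling step, assuming `random` is already imported in `farm/data_handler/utils.py` (the `seed` parameter below is an illustration for reproducible debug runs, not part of this PR):

```python
import random

def subsample(dicts, max_samples=None, seed=None):
    # Mirrors the logic added to read_dpr_json: sample without replacement,
    # with min() guarding against max_samples exceeding the dataset size.
    if max_samples:
        if seed is not None:
            random.seed(seed)  # illustration only, not in the PR
        dicts = random.sample(dicts, min(max_samples, len(dicts)))
    return dicts

data = [{"question": f"q{i}"} for i in range(1000)]
assert len(subsample(data, max_samples=100)) == 100
assert len(subsample(data, max_samples=5000)) == 1000  # capped at len(data)
```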
62 changes: 45 additions & 17 deletions farm/modeling/prediction_head.py
@@ -6,7 +6,7 @@
 from pathlib import Path
 from transformers.modeling_bert import BertForPreTraining, ACT2FN
 from transformers.modeling_auto import AutoModelForQuestionAnswering, AutoModelForTokenClassification, AutoModelForSequenceClassification
-from typing import List
+from typing import List, Tuple
 
 import torch
 from torch import nn
@@ -1557,6 +1557,9 @@ def pick_single_fn(heads, fn_name):
 
 
 class TextSimilarityHead(PredictionHead):
+    """
+    Trains a head on predicting the similarity of two texts like in Dense Passage Retrieval.
+    """
    def __init__(self, similarity_function="dot_product", **kwargs):
        super(TextSimilarityHead, self).__init__()
 
@@ -1612,21 +1615,35 @@ def get_similarity_function(self):
         elif "cosine" in self.similarity_function:
             return TextSimilarityHead.cosine_scores
 
-    def forward(self, query_vectors, context_vectors):
+    def forward(self, query_vectors: torch.Tensor, context_vectors: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Computes the log softmax similarity scores between two 2-dimensional tensors
+        Only packs the embeddings from both language models into a tuple. No further modification.
+        The similarity calculation is handled later to enable distributed training (DDP)
+        while keeping the support for in-batch negatives.
+        (Gather all embeddings from nodes => then do similarity scores + loss)
 
-        :param query_vectors: tensor of query embeddings from BiAdaptive model
+        :param query_vectors: Tensor of query embeddings from BiAdaptive model
                               of dimension n1 x D,
                               where n1 is the number of queries/batch size and D is embedding size
         :type query_vectors: torch.Tensor
-        :param context_vectors: tensor of context/passage embeddings from BiAdaptive model
+        :param context_vectors: Tensor of context/passage embeddings from BiAdaptive model
                               of dimension n2 x D,
                               where n2 is the number of passages/batch size and D is embedding size
         :type context_vectors: torch.Tensor
 
-        :return: log softmax similarity score of each query with each context/passage (dimension: n1xn2)
+        :return: (query_vectors, context_vectors)
         """
+        return (query_vectors, context_vectors)
+
+    def _embeddings_to_scores(self, query_vectors: torch.Tensor, context_vectors: torch.Tensor):
+        """
+        Calculates similarity scores between all given query_vectors and context_vectors
+
+        :param query_vectors: Tensor of queries encoded by the query encoder model
+        :param context_vectors: Tensor of passages encoded by the passage encoder model
+        :return: Tensor of log softmax similarity scores of each query with each passage (dimension: n1xn2)
+        """
+
         sim_func = self.get_similarity_function()
         scores = sim_func(query_vectors, context_vectors)
 
@@ -1637,22 +1654,31 @@ def forward(self, query_vectors, context_vectors):
         softmax_scores = nn.functional.log_softmax(scores, dim=1)
         return softmax_scores
 
-    def logits_to_loss(self, logits, **kwargs):
+    def logits_to_loss(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs):
         """
-        Computes the loss from similarity scores
+        Computes the loss (Default: NLLLoss) by applying a similarity function (Default: dot product) to the input
+        tuple of (query_vectors, context_vectors) and afterwards applying the loss function on similarity scores.
 
-        :param logits: tensor of log softmax similarity scores of each query with each context/passage (dimension: n1xn2)
-        :type logits: torch.Tensor
+        :param logits: Tuple of Tensors (query_embedding, context_embedding) as returned from forward()
 
         :return: negative log likelihood loss from similarity scores
         """
+        # Prepare predicted scores
+        query_vectors, context_vectors = logits
+        softmax_scores = self._embeddings_to_scores(query_vectors, context_vectors)
+
+        # Prepare labels
         lm_label_ids = kwargs.get(self.label_tensor_name)
-        positive_idx_per_question = (lm_label_ids.view(-1) == 1).nonzero()
-        loss = self.loss_fct(logits,
-                             torch.tensor(positive_idx_per_question).squeeze(-1).to(logits.device))
+        positive_idx_per_question = torch.nonzero((lm_label_ids.view(-1) == 1), as_tuple=False)
+        # TODO gather global tensors from all nodes for DDP
+        global_positive_idx_per_question = positive_idx_per_question
+        targets = global_positive_idx_per_question.squeeze(-1).to(softmax_scores.device)
+
+        # Calculate loss
+        loss = self.loss_fct(softmax_scores, targets)
         return loss
 
-    def logits_to_preds(self, logits, **kwargs):
+    def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs):
         """
         Returns predicted ranks(similarity) of passages/context for each query
 
@@ -1661,8 +1687,10 @@ def logits_to_preds(self, logits, **kwargs):
 
         :return: predicted ranks of passages for each query
         """
-        _, logits_sorted_indices = torch.sort(logits, dim=1, descending=True)
-        return logits_sorted_indices
+        query_vectors, context_vectors = logits
+        softmax_scores = self._embeddings_to_scores(query_vectors, context_vectors)
+        _, sorted_scores = torch.sort(softmax_scores, dim=1, descending=True)
+        return sorted_scores
 
    def prepare_labels(self, **kwargs):
        """
@@ -1677,5 +1705,5 @@ def prepare_labels(self, **kwargs):
                 labels[i, indx.item()] = 1
         return labels
 
-    def formatted_preds(self, logits, **kwargs):
+    def formatted_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs):
        raise NotImplementedError("formatted_preds is not supported in TextSimilarityHead yet!")
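The crux of the refactor: `forward()` now returns the raw embedding tuple so that a later step can gather embeddings across DDP ranks before `_embeddings_to_scores` runs, keeping in-batch negatives global. That gather is still a TODO in this diff; below is a hedged sketch of what it might look like. `ddp_gather_then_loss` is a hypothetical helper, not code from this PR, and `dot_product_scores` is assumed to be a plain matmul (consistent with `get_similarity_function` above):

```python
import torch
import torch.distributed as dist
from torch import nn

def dot_product_scores(query_vectors: torch.Tensor, context_vectors: torch.Tensor) -> torch.Tensor:
    # (n1 x D) @ (D x n2) -> n1 x n2 similarity matrix
    return torch.matmul(query_vectors, context_vectors.transpose(0, 1))

def ddp_gather_then_loss(query_vectors, context_vectors, positive_idx_per_question):
    # Sketch of the TODO in logits_to_loss: score each local query against the
    # passages of ALL ranks, not just the local batch.
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        gathered_q = [torch.zeros_like(query_vectors) for _ in range(world_size)]
        gathered_c = [torch.zeros_like(context_vectors) for _ in range(world_size)]
        dist.all_gather(gathered_q, query_vectors)
        dist.all_gather(gathered_c, context_vectors)
        # all_gather returns detached tensors; re-insert the local ones so
        # autograd still tracks this rank's part of the graph
        gathered_q[dist.get_rank()] = query_vectors
        gathered_c[dist.get_rank()] = context_vectors
        query_vectors = torch.cat(gathered_q, dim=0)
        context_vectors = torch.cat(gathered_c, dim=0)
        # NOTE: positive_idx_per_question would also need a per-rank offset
        # into the concatenated passage matrix (omitted in this sketch)
    softmax_scores = nn.functional.log_softmax(
        dot_product_scores(query_vectors, context_vectors), dim=1)
    return nn.NLLLoss()(softmax_scores, positive_idx_per_question)
```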
2 changes: 2 additions & 0 deletions farm/train.py
@@ -13,6 +13,7 @@
 from farm.data_handler.data_silo import DataSilo
 from farm.visual.ascii.images import GROWING_TREE
 from farm.modeling.adaptive_model import AdaptiveModel
+from farm.modeling.biadaptive_model import BiAdaptiveModel
 from farm.modeling.optimization import get_scheduler
 
 try:
@@ -248,6 +249,7 @@ def train(self):
         # connect the prediction heads with the right output from processor
         self.model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
         # Check that the tokenizer fits the language model
+        # TODO: make this compliant for DP / DDP where the model class is wrapped
         if self.model._get_name() == 'BiAdaptiveModel':
             self.model.verify_vocab_size(vocab_size1=len(self.data_silo.processor.tokenizer),
                                          vocab_size2=len(self.data_silo.processor.passage_tokenizer))
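One possible resolution of the TODO above, assuming the standard PyTorch convention that `DataParallel`/`DistributedDataParallel` expose the inner model as `.module` (`unwrap_model` is a hypothetical helper, not part of this PR):

```python
def unwrap_model(model):
    # DP / DDP wrap the real model; class checks such as _get_name()
    # must run on the unwrapped instance.
    return model.module if hasattr(model, "module") else model

# Usage inside Trainer.train() would then read (sketch):
#   if unwrap_model(self.model)._get_name() == 'BiAdaptiveModel':
#       unwrap_model(self.model).verify_vocab_size(...)
```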
11 changes: 7 additions & 4 deletions test/test_dpr.py
@@ -9,7 +9,6 @@
 from farm.modeling.tokenization import Tokenizer
 from farm.utils import set_all_seeds, initialize_device_settings
 from farm.data_handler.dataset import convert_features_to_dataset
-from transformers import DPRConfig
 
 def test_dpr_modules(caplog=None):
     if caplog:
@@ -111,9 +110,13 @@ def test_dpr_modules(caplog=None):
                                        0.3350, -0.3412]), torch.ones((1, 10)) * 0.0001))
 
     # test logits and loss
-    logits = model(**features)
-    loss = model.logits_to_loss_per_head(logits, **features)
-    similarity_scores = logits[0].cpu()
+    embeddings = model(**features)
+    query_emb, passage_emb = embeddings[0]
+    assert torch.all(torch.eq(query_emb.cpu(), query_vector.cpu()))
+    assert torch.all(torch.eq(passage_emb.cpu(), passage_vector.cpu()))
+
+    loss = model.logits_to_loss_per_head(embeddings, **features)
+    similarity_scores = model.prediction_heads[0]._embeddings_to_scores(query_emb, passage_emb).cpu()
     assert torch.all(torch.le(similarity_scores - torch.tensor([[-1.8311e-03, -6.3016e+00]]), torch.ones((1, 2)) * 0.0001))
     assert (loss[0].item() - 0.0018) <= 0.0001
 
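Since `forward()` no longer emits scores directly, callers follow the pattern the updated test demonstrates; a condensed sketch (assuming, as in the test, that `model` is a `BiAdaptiveModel` with a single `TextSimilarityHead` and `features` is a prepared batch):

```python
embeddings = model(**features)            # one (query_emb, passage_emb) tuple per prediction head
query_emb, passage_emb = embeddings[0]
scores = model.prediction_heads[0]._embeddings_to_scores(query_emb, passage_emb)
ranks = torch.argsort(scores, dim=1, descending=True)  # per-query passage ranking
```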