diff --git a/examples/dpr_encoder.py b/examples/dpr_encoder.py
index eb04225a9..8b7622b31 100644
--- a/examples/dpr_encoder.py
+++ b/examples/dpr_encoder.py
@@ -23,7 +23,7 @@ def dense_passage_retrieval():
     )
 
     ml_logger = MLFlowLogger(tracking_uri="https://public-mlflow.deepset.ai/")
-    ml_logger.init_experiment(experiment_name="FARM-dense_passage_retrieval", run_name="Run_dpr_enocder")
+    ml_logger.init_experiment(experiment_name="FARM-dense_passage_retrieval", run_name="Run_dpr")
 
     ##########################
     ########## Settings
@@ -42,12 +42,13 @@ def dense_passage_retrieval():
     similarity_function = "dot_product"
     train_filename = "nq-train.json"
     dev_filename = "nq-dev.json"
+    test_filename = "nq-dev.json"
     max_samples = None #load a smaller dataset (e.g. for debugging)
 
     # 1.Create question and passage tokenizers
     query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=question_lang_model,
                                      do_lower_case=do_lower_case, use_fast=use_fast)
-    context_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_lang_model,
+    passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_lang_model,
                                        do_lower_case=do_lower_case, use_fast=use_fast)
 
     # 2. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
@@ -56,14 +57,15 @@ def dense_passage_retrieval():
     label_list = ["hard_negative", "positive"]
     metric = "text_similarity_metric"
     processor = TextSimilarityProcessor(tokenizer=query_tokenizer,
-                                        passage_tokenizer=context_tokenizer,
-                                        max_seq_len=256,
+                                        passage_tokenizer=passage_tokenizer,
+                                        max_seq_len_query=256,
+                                        max_seq_len_passage=256,
                                         label_list=label_list,
                                         metric=metric,
                                         data_dir="data/retriever",
                                         train_filename=train_filename,
                                         dev_filename=dev_filename,
-                                        test_filename=dev_filename,
+                                        test_filename=test_filename,
                                         embed_title=embed_title,
                                         num_hard_negatives=num_hard_negatives,
                                         max_samples=max_samples)
@@ -73,8 +75,8 @@ def dense_passage_retrieval():
     data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)
 
-    # 4. Create an AdaptiveModel+
-    # a) which consists of a pretrained language model as a basis
+    # 4. Create a BiAdaptiveModel+
+    # a) which consists of 2 pretrained language models as a basis
     question_language_model = LanguageModel.load(pretrained_model_name_or_path="bert-base-uncased",
                                                  language_model_class="DPRQuestionEncoder")
     passage_language_model = LanguageModel.load(pretrained_model_name_or_path="bert-base-uncased",
                                                 language_model_class="DPRContextEncoder")
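Note (not part of the patch): the processor changes in farm/data_handler/processor.py below key off dicts that carry a "query" entry and a "passages" list whose items have "label" ("positive" / "hard_negative"), an optional "title" and a "text" field. A minimal sketch of such a dict with illustrative values only; the exact NQ json schema may contain further fields:

    # Hypothetical input dict, inferred from the filtering logic in the processor diff below.
    nq_style_dict = {
        "query": "who wrote the song does he love you",
        "passages": [
            {"title": "Does He Love You", "text": "Does He Love You is a song written by ...", "label": "positive"},
            {"title": "Red Sandy Spika dress", "text": "The dress was worn by ...", "label": "hard_negative"},
        ],
    }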
diff --git a/farm/data_handler/processor.py b/farm/data_handler/processor.py
index 35d30fa6f..08e7f0f0c 100644
--- a/farm/data_handler/processor.py
+++ b/farm/data_handler/processor.py
@@ -1799,8 +1799,9 @@ def __init__(
         self,
         tokenizer,
         passage_tokenizer,
-        max_seq_len,
-        data_dir,
+        max_seq_len_query,
+        max_seq_len_passage,
+        data_dir="",
         metric=None,
         train_filename="train.json",
         dev_filename=None,
@@ -1819,8 +1820,10 @@ def __init__(
         """
         :param tokenizer: Used to split a question (str) into tokens
         :param passage_tokenizer: Used to split a passage (str) into tokens.
-        :param max_seq_len: Samples are truncated after this many tokens.
-        :type max_seq_len: int
+        :param max_seq_len_query: Query samples are truncated after this many tokens.
+        :type max_seq_len_query: int
+        :param max_seq_len_passage: Context/passage samples are truncated after this many tokens.
+        :type max_seq_len_passage: int
         :param data_dir: The directory in which the train and dev files can be found.
                          If not available the dataset will be loaded automaticaly
                          if the last directory has the same name as a predefined dataset.
@@ -1868,10 +1871,12 @@ def __init__(
         self.num_positives = num_positives
        self.shuffle_negatives = shuffle_negatives
         self.shuffle_positives = shuffle_positives
+        self.max_seq_len_query = max_seq_len_query
+        self.max_seq_len_passage = max_seq_len_passage
 
         super(TextSimilarityProcessor, self).__init__(
             tokenizer=tokenizer,
-            max_seq_len=max_seq_len,
+            max_seq_len=max_seq_len_query,
             train_filename=train_filename,
             dev_filename=dev_filename,
             test_filename=test_filename,
@@ -1993,88 +1998,103 @@ def _dict_to_samples(self, dictionary: dict, **kwargs) -> [Sample]:
         Returns:
             sample: instance of Sample
         """
+
+
+        clear_text = {}
+        tokenized = {}
+        features = {}
         # extract query, positive context passages and titles, hard-negative passages and titles
-        query = self._normalize_question(dictionary["query"])
-        positive_context = list(filter(lambda x: x["label"] == "positive", dictionary["passages"]))
-        if self.shuffle_positives:
-            random.shuffle(positive_context)
-        positive_context = positive_context[:self.num_positives]
-        hard_negative_context = list(filter(lambda x: x["label"] == "hard_negative", dictionary["passages"]))
-        if self.shuffle_negatives:
-            random.shuffle(hard_negative_context)
-        hard_negative_context = hard_negative_context[:self.num_hard_negatives]
-
-        positive_ctx_titles = [passage.get("title", None) for passage in positive_context]
-        positive_ctx_texts = [passage["text"] for passage in positive_context]
-        hard_negative_ctx_titles = [passage.get("title", None) for passage in hard_negative_context]
-        hard_negative_ctx_texts = [passage["text"] for passage in hard_negative_context]
-
-        # all context passages and labels: 1 for positive context and 0 for hard-negative context
-        ctx_label = [1]*self.num_positives + [0]*self.num_hard_negatives #(self.num_positives if self.num_positives < len(positive_context) else len(positive_context)) + \
-                    # +(self.num_hard_negatives if self.num_hard_negatives < len(hard_negative_context) else len(hard_negative_context))
-
-        # featurize the query
-        query_inputs = self.query_tokenizer.encode_plus(
-            text=query,
-            max_length=self.max_seq_len,
-            add_special_tokens=True,
-            truncation_strategy='do_not_truncate',
-            padding="max_length",
-            return_token_type_ids=True,
-        )
+        if "query" in dictionary.keys():
+            query = self._normalize_question(dictionary["query"])
+
+            # featurize the query
+            query_inputs = self.query_tokenizer.encode_plus(
+                text=query,
+                max_length=self.max_seq_len_query,
+                add_special_tokens=True,
+                truncation_strategy='do_not_truncate',
+                padding="max_length",
+                return_token_type_ids=True,
+            )
 
-        # featurize context passages
-        if self.embed_title:
-            # embed title with positive context passages + negative context passages
-            all_ctx = [tuple((title, ctx)) for title, ctx in
-                       zip(positive_ctx_titles, positive_ctx_texts)] + \
-                      [tuple((title, ctx)) for title, ctx in
-                       zip(hard_negative_ctx_titles, hard_negative_ctx_texts)]
-        else:
-            all_ctx = positive_ctx_texts + hard_negative_ctx_texts
-
-        # assign empty string tuples if hard_negative passages less than num_hard_negatives
-        all_ctx += [('', '')] * ((self.num_positives + self.num_hard_negatives)-len(all_ctx))
-
-        ctx_inputs = self.passage_tokenizer.batch_encode_plus(
-            all_ctx,
-            add_special_tokens=True,
-            truncation=True,
-            padding="max_length",
-            max_length=self.max_seq_len,
-            return_token_type_ids=True
-        )
+            query_input_ids, query_segment_ids, query_padding_mask = query_inputs["input_ids"], query_inputs[
+                "token_type_ids"], query_inputs["attention_mask"]
+
+            # tokenize query
+            tokenized_query = self.query_tokenizer.convert_ids_to_tokens(query_input_ids)
+
+            if len(tokenized_query) == 0:
+                logger.warning(
+                    f"The query could not be tokenized, likely because it contains a character that the query tokenizer does not recognize")
+                return None
+
+            clear_text["query_text"] = query
+            tokenized["query_tokens"] = tokenized_query
+            features["query_input_ids"] = query_input_ids
+            features["query_segment_ids"] = query_segment_ids
+            features["query_attention_mask"] = query_padding_mask
+
+        if "passages" in dictionary.keys():
+            positive_context = list(filter(lambda x: x["label"] == "positive", dictionary["passages"]))
+            if self.shuffle_positives:
+                random.shuffle(positive_context)
+            positive_context = positive_context[:self.num_positives]
+            hard_negative_context = list(filter(lambda x: x["label"] == "hard_negative", dictionary["passages"]))
+            if self.shuffle_negatives:
+                random.shuffle(hard_negative_context)
+            hard_negative_context = hard_negative_context[:self.num_hard_negatives]
+
+            positive_ctx_titles = [passage.get("title", None) for passage in positive_context]
+            positive_ctx_texts = [passage["text"] for passage in positive_context]
+            hard_negative_ctx_titles = [passage.get("title", None) for passage in hard_negative_context]
+            hard_negative_ctx_texts = [passage["text"] for passage in hard_negative_context]
+
+            # all context passages and labels: 1 for positive context and 0 for hard-negative context
+            ctx_label = [1]*self.num_positives + [0]*self.num_hard_negatives #(self.num_positives if self.num_positives < len(positive_context) else len(positive_context)) + \
+                        # +(self.num_hard_negatives if self.num_hard_negatives < len(hard_negative_context) else len(hard_negative_context))
+
+            # featurize context passages
+            if self.embed_title:
+                # embed title with positive context passages + negative context passages
+                all_ctx = [tuple((title, ctx)) for title, ctx in
+                           zip(positive_ctx_titles, positive_ctx_texts)] + \
+                          [tuple((title, ctx)) for title, ctx in
+                           zip(hard_negative_ctx_titles, hard_negative_ctx_texts)]
+            else:
+                all_ctx = positive_ctx_texts + hard_negative_ctx_texts
+
+            # assign empty string tuples if hard_negative passages less than num_hard_negatives
+            all_ctx += [('', '')] * ((self.num_positives + self.num_hard_negatives)-len(all_ctx))
+
+
+            ctx_inputs = self.passage_tokenizer.batch_encode_plus(
+                all_ctx,
+                add_special_tokens=True,
+                truncation=True,
+                padding="max_length",
+                max_length=self.max_seq_len_passage,
+                return_token_type_ids=True
+            )
+
+
+            ctx_input_ids, ctx_segment_ids_, ctx_padding_mask = ctx_inputs["input_ids"], ctx_inputs["token_type_ids"], \
+                                                                ctx_inputs["attention_mask"]
+            ctx_segment_ids = list(torch.zeros((len(ctx_segment_ids_), len(ctx_segment_ids_[0]))).numpy())
+
+            # tokenize passages
+            tokenized_passage = [self.passage_tokenizer.convert_ids_to_tokens(ctx) for ctx in ctx_input_ids]
+
+            if len(tokenized_passage) == 0:
+                logger.warning(f"The context could not be tokenized, likely because it contains a character that the context tokenizer does not recognize")
+                return None
+
+            clear_text["passages"] = positive_context + hard_negative_context
+            tokenized["passages_tokens"] = tokenized_passage
+            features["passage_input_ids"] = ctx_input_ids
+            features["passage_segment_ids"] = ctx_segment_ids
+            features["passage_attention_mask"] = ctx_padding_mask
+            features["label_ids"] = ctx_label
 
-        query_input_ids, query_segment_ids, query_padding_mask = query_inputs["input_ids"], query_inputs[
-            "token_type_ids"], query_inputs["attention_mask"]
-        ctx_input_ids, ctx_segment_ids_, ctx_padding_mask = ctx_inputs["input_ids"], ctx_inputs["token_type_ids"], \
-                                                            ctx_inputs["attention_mask"]
-        ctx_segment_ids = list(torch.zeros((len(ctx_segment_ids_), len(ctx_segment_ids_[0]))).numpy())
-
-        # tokenize query and contexts
-        tokenized_query = self.query_tokenizer.convert_ids_to_tokens(query_input_ids)
-        tokenized_passage = [self.passage_tokenizer.convert_ids_to_tokens(ctx) for ctx in ctx_input_ids]
-
-        if len(tokenized_query) == 0 or len(tokenized_passage) == 0:
-            logger.warning(f"The text could not be tokenized, likely because it contains a character that the tokenizer does not recognize")
-            return None
-
-        clear_text = {"query_text": query,
-                      "passages": positive_context + hard_negative_context
-                      }
-
-        tokenized = {"query_tokens": tokenized_query,
-                     "passages_tokens": tokenized_passage
-                     }
-
-        features = {"query_input_ids": query_input_ids,
-                    "query_segment_ids": query_segment_ids,
-                    "query_attention_mask": query_padding_mask,
-                    "passage_input_ids": ctx_input_ids,
-                    "passage_segment_ids": ctx_segment_ids,
-                    "passage_attention_mask": ctx_padding_mask,
-                    "label_ids": ctx_label
-                    }
 
         sample = Sample(id=None,
                         clear_text=clear_text,
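Note (not part of the patch): splitting _dict_to_samples into independent "query" and "passages" branches means a dict containing only one of the two keys can now be featurized on its own, which is what embedding extraction at inference time needs. A hedged sketch, assuming a processor built as in examples/dpr_encoder.py and the usual FARM Sample return convention:

    # Hedged usage sketch; the exact return shape (Sample vs. [Sample]) follows existing FARM conventions.
    query_only = processor._dict_to_samples({"query": "where is berlin?"})
    passage_only = processor._dict_to_samples(
        {"passages": [{"title": "Berlin", "text": "Berlin is the capital of Germany.", "label": "positive"}]}
    )
    # query_only carries only query_* features; passage_only carries only passage_* features plus label_ids.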
diff --git a/farm/modeling/biadaptive_model.py b/farm/modeling/biadaptive_model.py
index 37aaf47f9..b388fc106 100644
--- a/farm/modeling/biadaptive_model.py
+++ b/farm/modeling/biadaptive_model.py
@@ -151,8 +151,8 @@ def __init__(
         language_model1,
         language_model2,
         prediction_heads,
-        embeds_dropout_prob,
-        device,
+        embeds_dropout_prob=0.1,
+        device="cuda",
         lm1_output_types=["per_sequence"],
         lm2_output_types=["per_sequence"],
         loss_aggregation_fn=None,
@@ -199,9 +199,9 @@ def __init__(
         self.lm1_output_dims = language_model1.get_output_dims()
         self.language_model2 = language_model2.to(device)
         self.lm2_output_dims = language_model2.get_output_dims()
-        self.prediction_heads = nn.ModuleList([ph.to(device) for ph in prediction_heads])
         self.dropout1 = nn.Dropout(embeds_dropout_prob)
         self.dropout2 = nn.Dropout(embeds_dropout_prob)
+        self.prediction_heads = nn.ModuleList([ph.to(device) for ph in prediction_heads])
         self.lm1_output_types = (
             [lm1_output_types] if isinstance(lm1_output_types, str) else lm1_output_types
         )
@@ -354,32 +354,38 @@ def forward(self, **kwargs):
         """
         # Run forward pass of language model
-        pooled_output1, pooled_output2 = self.forward_lm(**kwargs)
+        pooled_output = self.forward_lm(**kwargs)
 
         # Run forward pass of (multiple) prediction heads using the output from above
         all_logits = []
         if len(self.prediction_heads) > 0:
             for head, lm1_out, lm2_out in zip(self.prediction_heads, self.lm1_output_types, self.lm2_output_types):
                 # Choose relevant vectors from LM as output and perform dropout
-                if lm1_out == "per_sequence" or lm1_out == "per_sequence_continuous":
-                    output1 = self.dropout1(pooled_output1)
+                if pooled_output[0] is not None:
+                    if lm1_out == "per_sequence" or lm1_out == "per_sequence_continuous":
+                        output1 = self.dropout1(pooled_output[0])
+                    else:
+                        raise ValueError(
+                            "Unknown extraction strategy from BiAdaptive language_model1: {}".format(lm1_out)
+                        )
                 else:
-                    raise ValueError(
-                        "Unknown extraction strategy from DPR model: {}".format(lm1_out)
-                    )
-
-                if lm2_out == "per_sequence" or lm2_out == "per_sequence_continuous":
-                    output2 = self.dropout2(pooled_output2)
+                    output1 = None
+
+                if pooled_output[1] is not None:
+                    if lm2_out == "per_sequence" or lm2_out == "per_sequence_continuous":
+                        output2 = self.dropout2(pooled_output[1])
+                    else:
+                        raise ValueError(
+                            "Unknown extraction strategy from BiAdaptive language_model2: {}".format(lm2_out)
+                        )
                 else:
-                    raise ValueError(
-                        "Unknown extraction strategy from DPR model: {}".format(lm2_out)
-                    )
+                    output2 = None
 
-                # Do the actual forward pass of a single head
-                all_logits.append(head(output1, output2))
+                embedding1, embedding2 = head(output1, output2)
+                all_logits.append(tuple([embedding1, embedding2]))
         else:
             # just return LM output (e.g. useful for extracting embeddings at inference time)
-            all_logits.append((pooled_output1, pooled_output2))
+            all_logits.append((pooled_output))
 
         return all_logits
 
@@ -390,10 +396,15 @@ def forward_lm(self, **kwargs):
         :param kwargs:
         :return: 2 tensors of pooled_output from the 2 language models
         """
-        pooled_output1, hidden_states1 = self.language_model1(**kwargs)
-        pooled_output2, hidden_states2 = self.language_model2(**kwargs)
+        pooled_output = [None, None]
+        if "query_input_ids" in kwargs.keys():
+            pooled_output1, hidden_states1 = self.language_model1(**kwargs)
+            pooled_output[0] = pooled_output1
+        if "passage_input_ids" in kwargs.keys():
+            pooled_output2, hidden_states2 = self.language_model2(**kwargs)
+            pooled_output[1] = pooled_output2
 
-        return pooled_output1, pooled_output2
+        return tuple(pooled_output)
 
     def log_params(self):
         """
@@ -407,8 +418,7 @@ def log_params(self):
             "lm2_name": self.language_model2.name,
             "lm2_output_types": ",".join(self.lm2_output_types),
             "prediction_heads": ",".join(
-                [head.__class__.__name__ for head in self.prediction_heads]
-            ),
+                [head.__class__.__name__ for head in self.prediction_heads])
         }
         try:
             MlLogger.log_params(params)
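Note (not part of the patch): with the forward_lm change, BiAdaptiveModel can be run with only one of the two input sets and the missing side comes back as None. A rough sketch, assuming a BiAdaptiveModel built as in examples/dpr_encoder.py and the feature names produced by TextSimilarityProcessor; the shapes below are illustrative assumptions:

    import torch

    # Dummy query batch; 256 matches the max_seq_len_query used in the example script.
    q_ids = torch.zeros((1, 256), dtype=torch.long)
    q_segs = torch.zeros((1, 256), dtype=torch.long)
    q_mask = torch.ones((1, 256), dtype=torch.long)

    # Only query_* kwargs are passed, so language_model2 is skipped entirely.
    query_emb, passage_emb = model.forward_lm(query_input_ids=q_ids,
                                              query_segment_ids=q_segs,
                                              query_attention_mask=q_mask)
    # query_emb holds the pooled query encoding; passage_emb is None.
    # Passing only passage_* features would give (None, passage_embedding) instead.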
diff --git a/farm/modeling/language_model.py b/farm/modeling/language_model.py
index f4f40eee3..6d4b1223d 100644
--- a/farm/modeling/language_model.py
+++ b/farm/modeling/language_model.py
@@ -1444,10 +1444,16 @@ def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
             dpr_question_encoder.model = transformers.DPRQuestionEncoder.from_pretrained(farm_lm_model, config=dpr_config, **kwargs)
             dpr_question_encoder.language = dpr_question_encoder.model.config.language
         else:
-            # Pytorch-transformer Style
-            dpr_question_encoder.model = transformers.DPRQuestionEncoder(config=transformers.DPRConfig(**kwargs))
-            # load weights from pretrained_model_name_or_path Language model into DPRQuestionEncoder
-            dpr_question_encoder.model.base_model.bert_model = AutoModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
+            model_type = AutoConfig.from_pretrained(pretrained_model_name_or_path).model_type
+            if model_type == "dpr":
+                # "pretrained dpr model": load existing pretrained DPRQuestionEncoder model
+                dpr_question_encoder.model = transformers.DPRQuestionEncoder.from_pretrained(
+                    str(pretrained_model_name_or_path), **kwargs)
+            else:
+                # "from scratch": load weights from different architecture (e.g. bert) into DPRQuestionEncoder
+                dpr_question_encoder.model = transformers.DPRQuestionEncoder(config=transformers.DPRConfig(**kwargs))
+                dpr_question_encoder.model.base_model.bert_model = AutoModel.from_pretrained(
+                    str(pretrained_model_name_or_path), **kwargs)
             dpr_question_encoder.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
 
         return dpr_question_encoder
@@ -1532,9 +1538,16 @@ def load(cls, pretrained_model_name_or_path, language=None, **kwargs):
             dpr_context_encoder.language = dpr_context_encoder.model.config.language
         else:
             # Pytorch-transformer Style
-            dpr_context_encoder.model = transformers.DPRContextEncoder(config=transformers.DPRConfig(**kwargs))
-            # load weights from pretrained_model_name_or_path Language model into DPRContextEncoder
-            dpr_context_encoder.model.base_model.bert_model = AutoModel.from_pretrained(str(pretrained_model_name_or_path), **kwargs)
+            model_type = AutoConfig.from_pretrained(pretrained_model_name_or_path).model_type
+            if model_type == "dpr":
+                # "pretrained dpr model": load existing pretrained DPRContextEncoder model
+                dpr_context_encoder.model = transformers.DPRContextEncoder.from_pretrained(
+                    str(pretrained_model_name_or_path), **kwargs)
+            else:
+                # "from scratch": load weights from different architecture (e.g. bert) into DPRContextEncoder
+                dpr_context_encoder.model = transformers.DPRContextEncoder(config=transformers.DPRConfig(**kwargs))
+                dpr_context_encoder.model.base_model.bert_model = AutoModel.from_pretrained(
+                    str(pretrained_model_name_or_path), **kwargs)
             dpr_context_encoder.language = cls._get_or_infer_language_from_name(language, pretrained_model_name_or_path)
 
         return dpr_context_encoder
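Note (not part of the patch): the AutoConfig.model_type check above gives both encoders two load paths. A sketch of each through the LanguageModel.load call used in the example script; the facebook/dpr checkpoint names are the public pretrained DPR models also used in test/test_dpr.py:

    from farm.modeling.language_model import LanguageModel

    # Path 1: model_type == "dpr" -> reuse an existing pretrained DPR question encoder as-is.
    question_encoder = LanguageModel.load(
        pretrained_model_name_or_path="facebook/dpr-question_encoder-single-nq-base",
        language_model_class="DPRQuestionEncoder")

    # Path 2: any other architecture (e.g. bert) -> wrap its weights in a fresh DPRQuestionEncoder,
    # i.e. train DPR "from scratch" on top of a plain BERT checkpoint.
    question_encoder_from_bert = LanguageModel.load(
        pretrained_model_name_or_path="bert-base-uncased",
        language_model_class="DPRQuestionEncoder")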
diff --git a/farm/modeling/prediction_head.py b/farm/modeling/prediction_head.py
index 583c6956a..d7053dc6b 100644
--- a/farm/modeling/prediction_head.py
+++ b/farm/modeling/prediction_head.py
@@ -1572,23 +1572,23 @@ def __init__(self, similarity_function="dot_product", **kwargs):
         self.generate_config()
 
     @classmethod
-    def dot_product_scores(cls, query_vectors, context_vectors):
+    def dot_product_scores(cls, query_vectors, passage_vectors):
         """
         Calculates dot product similarity scores for two 2-dimensional tensors
 
         :param query_vectors: tensor of query embeddings from BiAdaptive model
                               of dimension n1 x D, where n1 is the number of queries/batch size and D is embedding size
         :type query_vectors: torch.Tensor
-        :param context_vectors: tensor of context/passage embeddings from BiAdaptive model of dimension n2 x D, where n2 is the number of queries/batch size and D is embedding size
-        :type context_vectors: torch.Tensor
+        :param passage_vectors: tensor of context/passage embeddings from BiAdaptive model of dimension n2 x D, where n2 is the number of queries/batch size and D is embedding size
+        :type passage_vectors: torch.Tensor
 
         :return dot_product: similarity score of each query with each context/passage (dimension: n1xn2)
         """
         # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2
-        dot_product = torch.matmul(query_vectors, torch.transpose(context_vectors, 0, 1))
+        dot_product = torch.matmul(query_vectors, torch.transpose(passage_vectors, 0, 1))
         return dot_product
 
     @classmethod
-    def cosine_scores(cls, query_vectors, context_vectors):
+    def cosine_scores(cls, query_vectors, passage_vectors):
         """
         Calculates cosine similarity scores for two 2-dimensional tensors
 
@@ -1596,15 +1596,15 @@
         of dimension n1 x D, where n1 is the number of queries/batch size and D is embedding size
         :type query_vectors: torch.Tensor
-        :param context_vectors: tensor of context/passage embeddings from BiAdaptive model
+        :param passage_vectors: tensor of context/passage embeddings from BiAdaptive model
         of dimension n2 x D, where n2 is the number of queries/batch size and D is embedding size
-        :type context_vectors: torch.Tensor
+        :type passage_vectors: torch.Tensor
 
         :return: cosine similarity score of each query with each context/passage (dimension: n1xn2)
         """
         # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2
-        return nn.functional.cosine_similarity(query_vectors, context_vectors, dim=1)
+        return nn.functional.cosine_similarity(query_vectors, passage_vectors, dim=1)
 
     def get_similarity_function(self):
         """
@@ -1615,7 +1615,7 @@ def get_similarity_function(self):
         elif "cosine" in self.similarity_function:
             return TextSimilarityHead.cosine_scores
 
-    def forward(self, query_vectors:torch.Tensor, context_vectors:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, query_vectors:torch.Tensor, passage_vectors:torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Only packs the embeddings from both language models into a tuple. No further modification.
         The similarity calculation is handled later to enable distributed training (DDP)
@@ -1626,26 +1626,26 @@ def forward(self, query_vectors:torch.Tensor, context_vectors:torch.Tensor) -> T
                               of dimension n1 x D, where n1 is the number of queries/batch size and D is embedding size
         :type query_vectors: torch.Tensor
-        :param context_vectors: Tensor of context/passage embeddings from BiAdaptive model
+        :param passage_vectors: Tensor of context/passage embeddings from BiAdaptive model
                               of dimension n2 x D, where n2 is the number of queries/batch size and D is embedding size
-        :type context_vectors: torch.Tensor
+        :type passage_vectors: torch.Tensor
 
-        :return: (query_vectors, context_vectors)
+        :return: (query_vectors, passage_vectors)
         """
-        return (query_vectors, context_vectors)
+        return (query_vectors, passage_vectors)
 
-    def _embeddings_to_scores(self, query_vectors:torch.Tensor, context_vectors:torch.Tensor):
+    def _embeddings_to_scores(self, query_vectors:torch.Tensor, passage_vectors:torch.Tensor):
         """
-        Calculates similarity scores between all given query_vectors and context_vectors
+        Calculates similarity scores between all given query_vectors and passage_vectors
 
         :param query_vectors: Tensor of queries encoded by the query encoder model
-        :param context_vectors: Tensor of passages encoded by the passage encoder model
+        :param passage_vectors: Tensor of passages encoded by the passage encoder model
         :return: Tensor of log softmax similarity scores of each query with each passage (dimension: n1xn2)
         """
         sim_func = self.get_similarity_function()
-        scores = sim_func(query_vectors, context_vectors)
+        scores = sim_func(query_vectors, passage_vectors)
 
         if len(query_vectors.size()) > 1:
             q_num = query_vectors.size(0)
@@ -1657,15 +1657,15 @@ def logits_to_loss(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs):
         """
         Computes the loss (Default: NLLLoss) by applying a similarity function (Default: dot product) to the input
-        tuple of (query_vectors, context_vectors) and afterwards applying the loss function on similarity scores.
+        tuple of (query_vectors, passage_vectors) and afterwards applying the loss function on similarity scores.
 
-        :param logits: Tuple of Tensors (query_embedding, context_embedding) as returned from forward()
+        :param logits: Tuple of Tensors (query_embedding, passage_embedding) as returned from forward()
         :return: negative log likelihood loss from similarity scores
         """
         # Prepare predicted scores
-        query_vectors, context_vectors = logits
-        softmax_scores = self._embeddings_to_scores(query_vectors, context_vectors)
+        query_vectors, passage_vectors = logits
+        softmax_scores = self._embeddings_to_scores(query_vectors, passage_vectors)
 
         # Prepare Labels
         lm_label_ids = kwargs.get(self.label_tensor_name)
@@ -1687,8 +1687,8 @@ def logits_to_preds(self, logits: Tuple[torch.Tensor, torch.Tensor], **kwargs):
 
         :return: predicted ranks of passages for each query
         """
-        query_vectors, context_vectors = logits
-        softmax_scores = self._embeddings_to_scores(query_vectors, context_vectors)
+        query_vectors, passage_vectors = logits
+        softmax_scores = self._embeddings_to_scores(query_vectors, passage_vectors)
         _, sorted_scores = torch.sort(softmax_scores, dim=1, descending=True)
         return sorted_scores
diff --git a/test/test_dpr.py b/test/test_dpr.py
index e68c01e81..92d627e5e 100644
--- a/test/test_dpr.py
+++ b/test/test_dpr.py
@@ -20,13 +20,14 @@ def test_dpr_modules(caplog=None):
     # 1.Create question and passage tokenizers
     query_tokenizer = Tokenizer.load(pretrained_model_name_or_path="facebook/dpr-question_encoder-single-nq-base",
                                      do_lower_case=True, use_fast=True)
-    context_tokenizer = Tokenizer.load(pretrained_model_name_or_path="facebook/dpr-ctx_encoder-single-nq-base",
+    passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path="facebook/dpr-ctx_encoder-single-nq-base",
                                        do_lower_case=True, use_fast=True)
 
     processor = TextSimilarityProcessor(
         tokenizer=query_tokenizer,
-        passage_tokenizer=context_tokenizer,
-        max_seq_len=256,
+        passage_tokenizer=passage_tokenizer,
+        max_seq_len_query=256,
+        max_seq_len_passage=256,
         label_list=["hard_negative", "positive"],
         metric="text_similarity_metric",
         data_dir="data/retriever",
@@ -205,10 +206,11 @@ def test_dpr_processor(embed_title, passage_ids, passage_attns, use_fast, num_ha
     query_tok = "facebook/dpr-question_encoder-single-nq-base"
     query_tokenizer = Tokenizer.load(query_tok, use_fast=use_fast)
     passage_tok = "facebook/dpr-ctx_encoder-single-nq-base"
-    context_tokenizer = Tokenizer.load(passage_tok, use_fast=use_fast)
+    passage_tokenizer = Tokenizer.load(passage_tok, use_fast=use_fast)
     processor = TextSimilarityProcessor(tokenizer=query_tokenizer,
-                                        passage_tokenizer=context_tokenizer,
-                                        max_seq_len=256,
+                                        passage_tokenizer=passage_tokenizer,
+                                        max_seq_len_query=256,
+                                        max_seq_len_passage=256,
                                         data_dir="data/retriever",
                                         train_filename="nq-train.json",
                                         test_filename="nq-dev.json",