Refactor DPR for haystack compatibility (#606)

* DPRProcessor, BiAdaptiveModel modified for optional query/passage

* TextSimilarityProcessor, BiAdaptiveModel haystack inference compatibility added

* TextSimilarityProcessor query/passage features optional

* prediction head modified

* bugfix in BiAdaptiveModel init

* Fix removal of yes no answers (#540)

* fix removal of yes no answers

* Make use of the answer_type linked to each answer.

* pin seqeval version

* remove hardcoded answer types list

Co-authored-by: Fabio Tesser <fabio.tesser@gmail.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>

* DPR test modified for max_seq_len_query/max_seq_len_context

* Infer model type from config (#600)

* Infer model and tokenizer type from config

* Infer type from model name as fallback

* BiAdaptiveModel output type modified from dict to tuple

* DPR tests reflect BiAdaptiveModel output type (tuple)

* renamed variables from 'context' to 'passage'

* DPR language model loading refactored

* DPR Language model comments fix

Co-authored-by: Branden Chan <33759007+brandenchan@users.noreply.github.com>
Co-authored-by: Fabio Tesser <fabio.tesser@gmail.com>
Co-authored-by: Malte Pietsch <malte.pietsch@deepset.ai>
Co-authored-by: bogdankostic <bogdankostic@web.de>
5 people authored Oct 28, 2020
1 parent b7ecb37 commit 6dd113d
Showing 6 changed files with 197 additions and 150 deletions.
16 changes: 9 additions & 7 deletions examples/dpr_encoder.py
@@ -23,7 +23,7 @@ def dense_passage_retrieval():
)

ml_logger = MLFlowLogger(tracking_uri="https://public-mlflow.deepset.ai/")
ml_logger.init_experiment(experiment_name="FARM-dense_passage_retrieval", run_name="Run_dpr_enocder")
ml_logger.init_experiment(experiment_name="FARM-dense_passage_retrieval", run_name="Run_dpr")

##########################
########## Settings
@@ -42,12 +42,13 @@ def dense_passage_retrieval():
similarity_function = "dot_product"
train_filename = "nq-train.json"
dev_filename = "nq-dev.json"
test_filename = "nq-dev.json"
max_samples = None #load a smaller dataset (e.g. for debugging)

# 1.Create question and passage tokenizers
query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=question_lang_model,
do_lower_case=do_lower_case, use_fast=use_fast)
context_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_lang_model,
passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_lang_model,
do_lower_case=do_lower_case, use_fast=use_fast)

# 2. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
@@ -56,14 +57,15 @@ def dense_passage_retrieval():
label_list = ["hard_negative", "positive"]
metric = "text_similarity_metric"
processor = TextSimilarityProcessor(tokenizer=query_tokenizer,
passage_tokenizer=context_tokenizer,
max_seq_len=256,
passage_tokenizer=passage_tokenizer,
max_seq_len_query=256,
max_seq_len_passage=256,
label_list=label_list,
metric=metric,
data_dir="data/retriever",
train_filename=train_filename,
dev_filename=dev_filename,
test_filename=dev_filename,
test_filename=test_filename,
embed_title=embed_title,
num_hard_negatives=num_hard_negatives,
max_samples=max_samples)
@@ -73,8 +75,8 @@ def dense_passage_retrieval():
data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)


# 4. Create an AdaptiveModel+
# a) which consists of a pretrained language model as a basis
# 4. Create a BiAdaptiveModel+
# a) which consists of 2 pretrained language models as a basis
question_language_model = LanguageModel.load(pretrained_model_name_or_path="bert-base-uncased", language_model_class="DPRQuestionEncoder")
passage_language_model = LanguageModel.load(pretrained_model_name_or_path="bert-base-uncased", language_model_class="DPRContextEncoder")

188 changes: 104 additions & 84 deletions farm/data_handler/processor.py
@@ -1799,8 +1799,9 @@ def __init__(
self,
tokenizer,
passage_tokenizer,
max_seq_len,
data_dir,
max_seq_len_query,
max_seq_len_passage,
data_dir="",
metric=None,
train_filename="train.json",
dev_filename=None,
@@ -1819,8 +1820,10 @@ def __init__(
"""
:param tokenizer: Used to split a question (str) into tokens
:param passage_tokenizer: Used to split a passage (str) into tokens.
:param max_seq_len: Samples are truncated after this many tokens.
:type max_seq_len: int
:param max_seq_len_query: Query samples are truncated after this many tokens.
:type max_seq_len_query: int
:param max_seq_len_passage: Context/Passage Samples are truncated after this many tokens.
:type max_seq_len_passage: int
:param data_dir: The directory in which the train and dev files can be found.
If not available, the dataset will be loaded automatically
if the last directory has the same name as a predefined dataset.
@@ -1868,10 +1871,12 @@ def __init__(
self.num_positives = num_positives
self.shuffle_negatives = shuffle_negatives
self.shuffle_positives = shuffle_positives
self.max_seq_len_query = max_seq_len_query
self.max_seq_len_passage = max_seq_len_passage

super(TextSimilarityProcessor, self).__init__(
tokenizer=tokenizer,
max_seq_len=max_seq_len,
max_seq_len=max_seq_len_query,
train_filename=train_filename,
dev_filename=dev_filename,
test_filename=test_filename,
@@ -1993,88 +1998,103 @@ def _dict_to_samples(self, dictionary: dict, **kwargs) -> [Sample]:
Returns:
sample: instance of Sample
"""


clear_text = {}
tokenized = {}
features = {}
# extract query, positive context passages and titles, hard-negative passages and titles
query = self._normalize_question(dictionary["query"])
positive_context = list(filter(lambda x: x["label"] == "positive", dictionary["passages"]))
if self.shuffle_positives:
random.shuffle(positive_context)
positive_context = positive_context[:self.num_positives]
hard_negative_context = list(filter(lambda x: x["label"] == "hard_negative", dictionary["passages"]))
if self.shuffle_negatives:
random.shuffle(hard_negative_context)
hard_negative_context = hard_negative_context[:self.num_hard_negatives]

positive_ctx_titles = [passage.get("title", None) for passage in positive_context]
positive_ctx_texts = [passage["text"] for passage in positive_context]
hard_negative_ctx_titles = [passage.get("title", None) for passage in hard_negative_context]
hard_negative_ctx_texts = [passage["text"] for passage in hard_negative_context]

# all context passages and labels: 1 for positive context and 0 for hard-negative context
ctx_label = [1]*self.num_positives + [0]*self.num_hard_negatives #(self.num_positives if self.num_positives < len(positive_context) else len(positive_context)) + \
# +(self.num_hard_negatives if self.num_hard_negatives < len(hard_negative_context) else len(hard_negative_context))

# featurize the query
query_inputs = self.query_tokenizer.encode_plus(
text=query,
max_length=self.max_seq_len,
add_special_tokens=True,
truncation_strategy='do_not_truncate',
padding="max_length",
return_token_type_ids=True,
)
if "query" in dictionary.keys():
query = self._normalize_question(dictionary["query"])

# featurize the query
query_inputs = self.query_tokenizer.encode_plus(
text=query,
max_length=self.max_seq_len_query,
add_special_tokens=True,
truncation_strategy='do_not_truncate',
padding="max_length",
return_token_type_ids=True,
)

# featurize context passages
if self.embed_title:
# embed title with positive context passages + negative context passages
all_ctx = [tuple((title, ctx)) for title, ctx in
zip(positive_ctx_titles, positive_ctx_texts)] + \
[tuple((title, ctx)) for title, ctx in
zip(hard_negative_ctx_titles, hard_negative_ctx_texts)]
else:
all_ctx = positive_ctx_texts + hard_negative_ctx_texts

# assign empty string tuples if hard_negative passages less than num_hard_negatives
all_ctx += [('', '')] * ((self.num_positives + self.num_hard_negatives)-len(all_ctx))

ctx_inputs = self.passage_tokenizer.batch_encode_plus(
all_ctx,
add_special_tokens=True,
truncation=True,
padding="max_length",
max_length=self.max_seq_len,
return_token_type_ids=True
)
query_input_ids, query_segment_ids, query_padding_mask = query_inputs["input_ids"], query_inputs[
"token_type_ids"], query_inputs["attention_mask"]

# tokenize query
tokenized_query = self.query_tokenizer.convert_ids_to_tokens(query_input_ids)

if len(tokenized_query) == 0:
logger.warning(
f"The query could not be tokenized, likely because it contains a character that the query tokenizer does not recognize")
return None

clear_text["query_text"] = query
tokenized["query_tokens"] = tokenized_query
features["query_input_ids"] = query_input_ids
features["query_segment_ids"] = query_segment_ids
features["query_attention_mask"] = query_padding_mask

if "passages" in dictionary.keys():
positive_context = list(filter(lambda x: x["label"] == "positive", dictionary["passages"]))
if self.shuffle_positives:
random.shuffle(positive_context)
positive_context = positive_context[:self.num_positives]
hard_negative_context = list(filter(lambda x: x["label"] == "hard_negative", dictionary["passages"]))
if self.shuffle_negatives:
random.shuffle(hard_negative_context)
hard_negative_context = hard_negative_context[:self.num_hard_negatives]

positive_ctx_titles = [passage.get("title", None) for passage in positive_context]
positive_ctx_texts = [passage["text"] for passage in positive_context]
hard_negative_ctx_titles = [passage.get("title", None) for passage in hard_negative_context]
hard_negative_ctx_texts = [passage["text"] for passage in hard_negative_context]

# all context passages and labels: 1 for positive context and 0 for hard-negative context
ctx_label = [1]*self.num_positives + [0]*self.num_hard_negatives #(self.num_positives if self.num_positives < len(positive_context) else len(positive_context)) + \
# +(self.num_hard_negatives if self.num_hard_negatives < len(hard_negative_context) else len(hard_negative_context))

# featurize context passages
if self.embed_title:
# embed title with positive context passages + negative context passages
all_ctx = [tuple((title, ctx)) for title, ctx in
zip(positive_ctx_titles, positive_ctx_texts)] + \
[tuple((title, ctx)) for title, ctx in
zip(hard_negative_ctx_titles, hard_negative_ctx_texts)]
else:
all_ctx = positive_ctx_texts + hard_negative_ctx_texts

# assign empty string tuples if hard_negative passages less than num_hard_negatives
all_ctx += [('', '')] * ((self.num_positives + self.num_hard_negatives)-len(all_ctx))


ctx_inputs = self.passage_tokenizer.batch_encode_plus(
all_ctx,
add_special_tokens=True,
truncation=True,
padding="max_length",
max_length=self.max_seq_len_passage,
return_token_type_ids=True
)


ctx_input_ids, ctx_segment_ids_, ctx_padding_mask = ctx_inputs["input_ids"], ctx_inputs["token_type_ids"], \
ctx_inputs["attention_mask"]
ctx_segment_ids = list(torch.zeros((len(ctx_segment_ids_), len(ctx_segment_ids_[0]))).numpy())

# tokenize query and contexts
tokenized_passage = [self.passage_tokenizer.convert_ids_to_tokens(ctx) for ctx in ctx_input_ids]

if len(tokenized_passage) == 0:
logger.warning(f"The context could not be tokenized, likely because it contains a character that the context tokenizer does not recognize")
return None

clear_text["passages"] = positive_context + hard_negative_context
tokenized["passages_tokens"] = tokenized_passage
features["passage_input_ids"] = ctx_input_ids
features["passage_segment_ids"] = ctx_segment_ids
features["passage_attention_mask"] = ctx_padding_mask
features["label_ids"] = ctx_label

query_input_ids, query_segment_ids, query_padding_mask = query_inputs["input_ids"], query_inputs[
"token_type_ids"], query_inputs["attention_mask"]
ctx_input_ids, ctx_segment_ids_, ctx_padding_mask = ctx_inputs["input_ids"], ctx_inputs["token_type_ids"], \
ctx_inputs["attention_mask"]
ctx_segment_ids = list(torch.zeros((len(ctx_segment_ids_), len(ctx_segment_ids_[0]))).numpy())

# tokenize query and contexts
tokenized_query = self.query_tokenizer.convert_ids_to_tokens(query_input_ids)
tokenized_passage = [self.passage_tokenizer.convert_ids_to_tokens(ctx) for ctx in ctx_input_ids]

if len(tokenized_query) == 0 or len(tokenized_passage) == 0:
logger.warning(f"The text could not be tokenized, likely because it contains a character that the tokenizer does not recognize")
return None

clear_text = {"query_text": query,
"passages": positive_context + hard_negative_context
}

tokenized = {"query_tokens": tokenized_query,
"passages_tokens": tokenized_passage
}

features = {"query_input_ids": query_input_ids,
"query_segment_ids": query_segment_ids,
"query_attention_mask": query_padding_mask,
"passage_input_ids": ctx_input_ids,
"passage_segment_ids": ctx_segment_ids,
"passage_attention_mask": ctx_padding_mask,
"label_ids": ctx_label
}

sample = Sample(id=None,
clear_text=clear_text,
56 changes: 33 additions & 23 deletions farm/modeling/biadaptive_model.py
@@ -151,8 +151,8 @@ def __init__(
language_model1,
language_model2,
prediction_heads,
embeds_dropout_prob,
device,
embeds_dropout_prob=0.1,
device="cuda",
lm1_output_types=["per_sequence"],
lm2_output_types=["per_sequence"],
loss_aggregation_fn=None,
@@ -199,9 +199,9 @@ def __init__(
self.lm1_output_dims = language_model1.get_output_dims()
self.language_model2 = language_model2.to(device)
self.lm2_output_dims = language_model2.get_output_dims()
self.prediction_heads = nn.ModuleList([ph.to(device) for ph in prediction_heads])
self.dropout1 = nn.Dropout(embeds_dropout_prob)
self.dropout2 = nn.Dropout(embeds_dropout_prob)
self.prediction_heads = nn.ModuleList([ph.to(device) for ph in prediction_heads])
self.lm1_output_types = (
[lm1_output_types] if isinstance(lm1_output_types, str) else lm1_output_types
)
@@ -354,32 +354,38 @@ def forward(self, **kwargs):
"""

# Run forward pass of language model
pooled_output1, pooled_output2 = self.forward_lm(**kwargs)
pooled_output = self.forward_lm(**kwargs)

# Run forward pass of (multiple) prediction heads using the output from above
all_logits = []
if len(self.prediction_heads) > 0:
for head, lm1_out, lm2_out in zip(self.prediction_heads, self.lm1_output_types, self.lm2_output_types):
# Choose relevant vectors from LM as output and perform dropout
if lm1_out == "per_sequence" or lm1_out == "per_sequence_continuous":
output1 = self.dropout1(pooled_output1)
if pooled_output[0] is not None:
if lm1_out == "per_sequence" or lm1_out == "per_sequence_continuous":
output1 = self.dropout1(pooled_output[0])
else:
raise ValueError(
"Unknown extraction strategy from BiAdaptive language_model1: {}".format(lm1_out)
)
else:
raise ValueError(
"Unknown extraction strategy from DPR model: {}".format(lm1_out)
)

if lm2_out == "per_sequence" or lm2_out == "per_sequence_continuous":
output2 = self.dropout2(pooled_output2)
output1 = None

if pooled_output[1] is not None:
if lm2_out == "per_sequence" or lm2_out == "per_sequence_continuous":
output2 = self.dropout2(pooled_output[1])
else:
raise ValueError(
"Unknown extraction strategy from BiAdaptive language_model2: {}".format(lm2_out)
)
else:
raise ValueError(
"Unknown extraction strategy from DPR model: {}".format(lm2_out)
)
output2 = None

# Do the actual forward pass of a single head
all_logits.append(head(output1, output2))
embedding1, embedding2 = head(output1, output2)
all_logits.append(tuple([embedding1, embedding2]))
else:
# just return LM output (e.g. useful for extracting embeddings at inference time)
all_logits.append((pooled_output1, pooled_output2))
all_logits.append((pooled_output))

return all_logits

@@ -390,10 +396,15 @@ def forward_lm(self, **kwargs):
:param kwargs:
:return: 2 tensors of pooled_output from the 2 language models
"""
pooled_output1, hidden_states1 = self.language_model1(**kwargs)
pooled_output2, hidden_states2 = self.language_model2(**kwargs)
pooled_output = [None, None]
if "query_input_ids" in kwargs.keys():
pooled_output1, hidden_states1 = self.language_model1(**kwargs)
pooled_output[0] = pooled_output1
if "passage_input_ids" in kwargs.keys():
pooled_output2, hidden_states2 = self.language_model2(**kwargs)
pooled_output[1] = pooled_output2

return pooled_output1, pooled_output2
return tuple(pooled_output)

def log_params(self):
"""
@@ -407,8 +418,7 @@ def log_params(self):
"lm2_name": self.language_model2.name,
"lm2_output_types": ",".join(self.lm2_output_types),
"prediction_heads": ",".join(
[head.__class__.__name__ for head in self.prediction_heads]
),
[head.__class__.__name__ for head in self.prediction_heads])
}
try:
MlLogger.log_params(params)