Add Dense Passage Retriever (incl. Training) #513

Merged
merged 78 commits into from
Oct 15, 2020
Commits
0dd4959
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
1982b3e
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
6066afd
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
b4c79ac
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
407b87f
Merge branch 'dpr_processor' of https://github.com/deepset-ai/FARM in…
kolk Sep 8, 2020
34b7046
DPR tokenization and test added
kolk Sep 11, 2020
76c89f9
DPR question and context models added
kolk Sep 14, 2020
2ebadf8
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
6c3744b
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
a1bcdf4
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
f3ef0f8
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
265e928
DPR tokenization and test added
kolk Sep 11, 2020
dafff7e
DPR question and context models added
kolk Sep 14, 2020
a390ab4
Merge branch 'dpr_processor' of https://github.com/deepset-ai/FARM in…
kolk Sep 14, 2020
dfe4178
DPR training added
kolk Sep 17, 2020
2eb4aba
biadaptive model and dpr example script added
kolk Sep 17, 2020
e4a4e16
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
e738d91
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
eaf21b6
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
e293197
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
6aef273
DPR tokenization and test added
kolk Sep 11, 2020
9a2f767
DPR question and context models added
kolk Sep 14, 2020
52ea8bc
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
43b0f5d
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
4373ba6
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
5bd6fe3
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
5a84c25
DPR tokenization and test added
kolk Sep 11, 2020
0de21d6
DPR training added
kolk Sep 17, 2020
d0e1af5
biadaptive model and dpr example script added
kolk Sep 17, 2020
e3d2d37
Merge branch 'dpr_processor' of https://github.com/deepset-ai/FARM in…
kolk Sep 17, 2020
0c33ffe
dpr eval added
kolk Sep 18, 2020
25e834f
Average rank evaluation added
kolk Sep 28, 2020
313f94a
dpr eval report creation modified
kolk Sep 28, 2020
5a3d8e5
DPR eval refactored
kolk Sep 30, 2020
aec8432
DPR eval bug fix
kolk Oct 5, 2020
9466c6a
Saving models fixed, PredictionHead and Processor names refactored
kolk Oct 8, 2020
14fb2bc
DPR language model names refactored
kolk Oct 12, 2020
9b083c3
bug fix in intitalizing DPR language model
kolk Oct 14, 2020
02ed47f
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
09397d2
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
e217f4e
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
e02ecc3
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
9fe273a
DPR tokenization and test added
kolk Sep 11, 2020
69d7ae5
DPR question and context models added
kolk Sep 14, 2020
1ec0152
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
17fdf2c
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
afa5916
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
24c1b4f
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
af1b6d8
DPR tokenization and test added
kolk Sep 11, 2020
37ff6be
DPR training added
kolk Sep 17, 2020
ffaf8c6
biadaptive model and dpr example script added
kolk Sep 17, 2020
d589fc0
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
80d4549
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
1186186
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
bef17e7
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
92dacfa
DPR question and context models added
kolk Sep 14, 2020
4ba91ae
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
35f8703
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
7f65b75
DPRProcessor, sample_to_features_dpr added
kolk Sep 3, 2020
3bd5d3a
remove samples_to_features_dpr, featurize in _dict_to_samples
kolk Sep 7, 2020
d4eee5a
DPR training added
kolk Sep 17, 2020
c761b2c
dpr eval added
kolk Sep 18, 2020
55fe3e6
Average rank evaluation added
kolk Sep 28, 2020
744d967
dpr eval report creation modified
kolk Sep 28, 2020
9831be5
DPR eval refactored
kolk Sep 30, 2020
71155e4
DPR eval bug fix
kolk Oct 5, 2020
0f36335
Saving models fixed, PredictionHead and Processor names refactored
kolk Oct 8, 2020
7e0d9c5
DPR language model names refactored
kolk Oct 12, 2020
c0e4fa7
bug fix in intitalizing DPR language model
kolk Oct 14, 2020
ceb6484
merge conflict resolved
kolk Oct 14, 2020
4664857
conversion to and from transformers added
kolk Oct 14, 2020
f1d6815
DPR tests added
kolk Oct 14, 2020
c9395a4
DPR language model config parameter loading added
kolk Oct 14, 2020
b5cfecf
doc strings modified, metrics refactored
kolk Oct 14, 2020
1f47f4a
docstring warnings fixed
kolk Oct 15, 2020
f406581
Merge branch 'master' into dpr_processor
tholor Oct 15, 2020
15889c9
Doc string modified, from_scratch() removed from DPR lan models
kolk Oct 15, 2020
e82b61e
Merge branch 'dpr_processor' of https://github.com/deepset-ai/FARM in…
kolk Oct 15, 2020
134 changes: 134 additions & 0 deletions examples/dpr_encoder.py
@@ -0,0 +1,134 @@
# fmt: off
import logging
import os
import pprint
from pathlib import Path

from farm.data_handler.data_silo import DataSilo
from farm.data_handler.processor import TextSimilarityProcessor
from farm.modeling.biadaptive_model import BiAdaptiveModel
from farm.modeling.language_model import LanguageModel
from farm.modeling.optimization import initialize_optimizer
from farm.modeling.prediction_head import TextSimilarityHead
from farm.modeling.tokenization import Tokenizer
from farm.train import Trainer
from farm.utils import set_all_seeds, MLFlowLogger, initialize_device_settings
from farm.eval import Evaluator

def dense_passage_retrieval():
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    ml_logger = MLFlowLogger(tracking_uri="https://public-mlflow.deepset.ai/")
    ml_logger.init_experiment(experiment_name="FARM-dense_passage_retrieval", run_name="Run_dpr_encoder")

    ##########################
    ########## Settings
    ##########################
    set_all_seeds(seed=42)
    device, n_gpu = initialize_device_settings(use_cuda=True)
    batch_size = 2
    n_epochs = 3
    evaluate_every = 1000
    question_lang_model = "facebook/dpr-question_encoder-single-nq-base"
    passage_lang_model = "facebook/dpr-ctx_encoder-single-nq-base"
    do_lower_case = True
    use_fast = True
    embed_title = True
    num_hard_negatives = 1
    similarity_function = "dot_product"
    train_filename = "nq-train.json"
    dev_filename = "nq-dev.json"

    # 1. Create question and passage tokenizers
    query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=question_lang_model,
                                     do_lower_case=do_lower_case, use_fast=use_fast)
    context_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_lang_model,
                                       do_lower_case=do_lower_case, use_fast=use_fast)

    # 2. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
    # data_dir "data/retriever" should contain DPR training and dev files downloaded from https://github.com/facebookresearch/DPR
    # i.e., nq-train.json, nq-dev.json or trivia-train.json, trivia-dev.json
    label_list = ["hard_negative", "positive"]
    metric = "text_similarity_metric"
    processor = TextSimilarityProcessor(tokenizer=query_tokenizer,
                                        passage_tokenizer=context_tokenizer,
                                        max_seq_len=512,
                                        label_list=label_list,
                                        metric=metric,
                                        data_dir="data/retriever",
                                        train_filename=train_filename,
                                        dev_filename=dev_filename,
                                        test_filename=dev_filename,
                                        embed_title=embed_title,
                                        num_hard_negatives=num_hard_negatives)

    # 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them and calculates a few descriptive statistics of our datasets
    # NOTE: In FARM, the dev set metrics differ from test set metrics in that they are calculated on a token level instead of a word level
    data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False)


    # 4. Create a BiAdaptiveModel
    # a) which consists of two pretrained language models as a basis (one per encoder)
    # NOTE: both encoders are initialized from "bert-base-uncased" here, not from the DPR checkpoints used for the tokenizers above
    question_language_model = LanguageModel.load(pretrained_model_name_or_path="bert-base-uncased", language_model_class="DPRQuestionEncoder")
    passage_language_model = LanguageModel.load(pretrained_model_name_or_path="bert-base-uncased", language_model_class="DPRContextEncoder")


    # b) and a prediction head on top that is suited for our task => text similarity (question-passage matching)
    prediction_head = TextSimilarityHead(similarity_function=similarity_function)

    model = BiAdaptiveModel(
        language_model1=question_language_model,
        language_model2=passage_language_model,
        prediction_heads=[prediction_head],
        embeds_dropout_prob=0.1,
        lm1_output_types=["per_sequence"],
        lm2_output_types=["per_sequence"],
        device=device,
    )

    # 5. Create an optimizer
    model, optimizer, lr_schedule = initialize_optimizer(
        model=model,
        learning_rate=1e-5,
        optimizer_opts={"name": "TransformersAdamW", "correct_bias": True, "weight_decay": 0.0,
                        "eps": 1e-08},
        schedule_opts={"name": "LinearWarmup", "num_warmup_steps": 100},
        n_batches=len(data_silo.loaders["train"]),
        n_epochs=n_epochs,
        grad_acc_steps=1,
        device=device
    )

    # 6. Feed everything to the Trainer, which takes care of growing our model and evaluates it from time to time
    trainer = Trainer(
        model=model,
        optimizer=optimizer,
        data_silo=data_silo,
        epochs=n_epochs,
        n_gpu=n_gpu,
        lr_schedule=lr_schedule,
        evaluate_every=evaluate_every,
        device=device,
    )

    # 7. Let it grow! Watch the tracked metrics live on the public mlflow server: https://public-mlflow.deepset.ai
    trainer.train()

    # 8. Hooray! You have a model. Store it:
    save_dir = Path("../saved_models/dpr-tutorial")
    model.save(save_dir)
    processor.save(save_dir)

    # 9. Evaluate
    test_data_loader = data_silo.get_data_loader("test")
    if test_data_loader is not None:
        evaluator_test = Evaluator(
            data_loader=test_data_loader, tasks=data_silo.processor.tasks, device=device)
        model.connect_heads_with_processor(processor.tasks)
        test_result = evaluator_test.eval(model)

dense_passage_retrieval()
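
The script sets `similarity_function = "dot_product"` and `num_hard_negatives = 1`. As a rough illustration of the kind of objective such a setup optimizes, the sketch below scores every question against every passage in a batch by dot product and computes the negative log-likelihood of each question's positive passage. This is a plain-Python approximation of the in-batch-negatives idea described for DPR, not FARM's `TextSimilarityHead` code; the helper names `dot_product_scores` and `nll_of_positives`, the passage layout, and the toy vectors are all assumptions made for illustration.

```python
import math
from typing import List, Sequence


def dot_product_scores(queries: Sequence[Sequence[float]],
                       passages: Sequence[Sequence[float]]) -> List[List[float]]:
    """Return a len(queries) x len(passages) matrix of dot products."""
    return [[sum(q_i * p_i for q_i, p_i in zip(q, p)) for p in passages]
            for q in queries]


def nll_of_positives(scores: List[List[float]],
                     positive_idx: Sequence[int]) -> float:
    """Mean negative log-likelihood of each query's positive passage.

    All other passages in the same row act as negatives (hypothetical helper).
    """
    total = 0.0
    for row, pos in zip(scores, positive_idx):
        # Numerically stable log-sum-exp: subtract the row maximum first.
        row_max = max(row)
        log_norm = row_max + math.log(sum(math.exp(s - row_max) for s in row))
        total += log_norm - row[pos]
    return total / len(scores)


# Toy batch of 2 questions. Each contributes one positive and one hard
# negative passage, giving 4 passages laid out as [pos_0, neg_0, pos_1, neg_1].
queries = [[1.0, 0.0], [0.0, 1.0]]
passages = [[2.0, 0.0], [1.0, 0.5], [0.0, 2.0], [0.5, 1.0]]
positive_idx = [0, 2]

scores = dot_product_scores(queries, passages)
loss = nll_of_positives(scores, positive_idx)
print(scores)
print(round(loss, 4))
```

With real encoders, the toy vectors would be replaced by the per-sequence embeddings of the question and passage models, and the loss would be computed on tensors so that gradients flow back into both encoders.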
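
The optimizer step in `examples/dpr_encoder.py` passes `schedule_opts={"name": "LinearWarmup", "num_warmup_steps": 100}`. A linear warmup schedule usually ramps the learning rate from zero to its peak over the warmup steps and then decays it linearly back to zero. The sketch below shows that shape in plain Python; the function name `linear_warmup_lr` is hypothetical, the total of 3000 training steps is an assumed value rather than one taken from the PR, and FARM's actual scheduler may differ in detail.

```python
def linear_warmup_lr(step: int, peak_lr: float,
                     num_warmup_steps: int, num_training_steps: int) -> float:
    """Learning rate at `step` for linear warmup followed by linear decay."""
    if step < num_warmup_steps:
        # Warmup phase: scale linearly from 0 up to peak_lr.
        return peak_lr * step / max(1, num_warmup_steps)
    # Decay phase: scale linearly from peak_lr down to 0 at the last step.
    remaining = max(0, num_training_steps - step)
    return peak_lr * remaining / max(1, num_training_steps - num_warmup_steps)


# Peak rate and warmup length match the script; 3000 total steps is assumed.
for step in (0, 50, 100, 1550, 3000):
    print(step, linear_warmup_lr(step, peak_lr=1e-5,
                                 num_warmup_steps=100,
                                 num_training_steps=3000))
```

In the script, the total number of steps would follow from `n_batches`, `n_epochs`, and `grad_acc_steps`, which is why those values are handed to `initialize_optimizer`.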