-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon + GradientCache】 #1799
Changes from 52 commits
784f345
5c9a957
01f1bc0
fa610bc
c430b8a
8f420d1
8f3a97c
a26786d
2227a26
4510e1d
f86eeb9
d129186
483cb62
5e91937
bba0521
ae0125e
ff3789c
ff34e2a
ccbd5b1
202664a
b7a6db3
e675ea9
17be523
43acadb
4563c2d
c892976
c600939
87f029a
25aa42c
d5984a1
c7fdafd
8012929
7650da6
675efb6
d890d8e
88ba024
abff61d
1cd93be
162165c
9cbcd71
e533e10
67bad62
d57380c
748b63f
be889df
2f0901d
f2a4397
de9ba83
db2ccf0
476aaa5
f6716fb
6343cf7
25c0b2a
644438d
7ccabad
865d50c
f5a9606
0aa9739
3891997
ed675ec
152437f
2c57eb6
2fbfde8
fb38a58
ab0f9d1
8f335c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,254 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import random | ||
|
||
import paddle | ||
from paddle.io import Dataset | ||
import json | ||
from paddlenlp.transformers.bert.tokenizer import BertTokenizer | ||
import collections | ||
from typing import Dict, List, Tuple | ||
import numpy as np | ||
|
||
BiEncoderPassage = collections.namedtuple("BiEncoderPassage", ["text", "title"]) | ||
|
||
BiENcoderBatch = collections.namedtuple("BiEncoderInput", [ | ||
"questions_ids", | ||
"question_segments", | ||
"context_ids", | ||
"ctx_segments", | ||
"is_positive", | ||
"hard_negatives", | ||
"encoder_type", | ||
]) | ||
|
||
|
||
def normalize_question(question: str) -> str: | ||
question = question.replace("’", "'") | ||
return question | ||
|
||
|
||
def normalize_passage(ctx_text: str): | ||
ctx_text = ctx_text.replace("\n", " ").replace("’", "'") | ||
if ctx_text.startswith('"'): | ||
ctx_text = ctx_text[1:] | ||
if ctx_text.endswith('"'): | ||
ctx_text = ctx_text[:-1] | ||
return ctx_text | ||
|
||
|
||
class BiEncoderSample(object): | ||
query: str | ||
positive_passages: List[BiEncoderPassage] | ||
negative_passages: List[BiEncoderPassage] | ||
hard_negative_passages: List[BiEncoderPassage] | ||
|
||
|
||
class NQdataSetForDPR(Dataset): | ||
""" | ||
class for managing dataset | ||
""" | ||
|
||
def __init__(self, dataPath, query_special_suffix=None): | ||
super(NQdataSetForDPR, self).__init__() | ||
self.data = self._read_json_data(dataPath) | ||
self.tokenizer = BertTokenizer | ||
self.query_special_suffix = query_special_suffix | ||
self.new_data = [] | ||
for i in range(0, self.__len__()): | ||
self.new_data.append(self.__getitem__(i)) | ||
|
||
def _read_json_data(self, dataPath): | ||
results = [] | ||
with open(dataPath, "r", encoding="utf-8") as f: | ||
print("Reading file %s" % dataPath) | ||
data = json.load(f) | ||
results.extend(data) | ||
print("Aggregated data size: {}".format(len(results))) | ||
return results | ||
|
||
def __getitem__(self, index): | ||
json_sample_data = self.data[index] | ||
r = BiEncoderSample() | ||
r.query = self._porcess_query(json_sample_data["question"]) | ||
|
||
positive_ctxs = json_sample_data["positive_ctxs"] | ||
|
||
negative_ctxs = json_sample_data[ | ||
"negative_ctxs"] if "negative_ctxs" in json_sample_data else [] | ||
hard_negative_ctxs = json_sample_data["hard_negative_ctxs"] if "hard_negative_ctxs" in json_sample_data else [] | ||
|
||
for ctx in positive_ctxs + negative_ctxs + hard_negative_ctxs: | ||
if "title" not in ctx: | ||
ctx["title"] = None | ||
|
||
def create_passage(ctx): | ||
return BiEncoderPassage(normalize_passage(ctx["text"]), | ||
ctx["title"]) | ||
|
||
r.positive_passages = [create_passage(ctx) for ctx in positive_ctxs] | ||
r.negative_passages = [create_passage(ctx) for ctx in negative_ctxs] | ||
r.hard_negative_passages = [ | ||
create_passage(ctx) for ctx in hard_negative_ctxs | ||
] | ||
|
||
return r | ||
|
||
def _porcess_query(self, query): | ||
query = normalize_question(query) | ||
|
||
if self.query_special_suffix and not query.endswith( | ||
self.query_special_suffix): | ||
query += self.query_special_suffix | ||
|
||
return query | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
|
||
class DataUtil(): | ||
""" | ||
Class for working with datasets | ||
""" | ||
|
||
def __init__(self): | ||
self.tensorizer = BertTensorizer() | ||
|
||
def create_biencoder_input(self, | ||
samples: List[BiEncoderSample], | ||
inserted_title, | ||
num_hard_negatives=0, | ||
num_other_negatives=0, | ||
shuffle=True, | ||
shuffle_positives=False, | ||
hard_neg_positives=False, | ||
hard_neg_fallback=True, | ||
query_token=None): | ||
|
||
question_tensors = [] | ||
ctx_tensors = [] | ||
positive_ctx_indices = [] | ||
hard_neg_ctx_indices = [] | ||
|
||
for sample in samples: | ||
|
||
if shuffle and shuffle_positives: | ||
positive_ctxs = sample.positive_passages | ||
positive_ctx = positive_ctxs[np.random.choice( | ||
len(positive_ctxs))] | ||
else: | ||
positive_ctx = sample.positive_passages[0] | ||
|
||
neg_ctxs = sample.negative_passages | ||
hard_neg_ctxs = sample.hard_negative_passages | ||
question = sample.query | ||
|
||
if shuffle: | ||
random.shuffle(neg_ctxs) | ||
random.shuffle(hard_neg_ctxs) | ||
|
||
if hard_neg_fallback and len(hard_neg_ctxs) == 0: | ||
hard_neg_ctxs = neg_ctxs[0:num_hard_negatives] | ||
|
||
neg_ctxs = neg_ctxs[0:num_other_negatives] | ||
hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives] | ||
|
||
all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs | ||
hard_negative_start_idx = 1 | ||
hard_negative_end_idx = 1 + len(hard_neg_ctxs) | ||
|
||
current_ctxs_len = len(ctx_tensors) | ||
|
||
sample_ctxs_tensors = [ | ||
self.tensorizer.text_to_tensor( | ||
ctx.text, | ||
title=ctx.title if (inserted_title and ctx.title) else None) | ||
for ctx in all_ctxs | ||
] | ||
|
||
ctx_tensors.extend(sample_ctxs_tensors) | ||
positive_ctx_indices.append(current_ctxs_len) | ||
hard_neg_ctx_indices.append(i for i in range( | ||
current_ctxs_len + hard_negative_start_idx, | ||
current_ctxs_len + hard_negative_end_idx, | ||
)) | ||
"""if query_token: | ||
if query_token == "[START_END]": | ||
query_span = _select_span | ||
else: | ||
question_tensors.append(self.tensorizer.text_to_tensor(" ".join([query_token, question]))) | ||
else:""" | ||
|
||
question_tensors.append(self.tensorizer.text_to_tensor(question)) | ||
|
||
ctxs_tensor = paddle.concat( | ||
[paddle.reshape(ctx, [1, -1]) for ctx in ctx_tensors], axis=0) | ||
questions_tensor = paddle.concat( | ||
[paddle.reshape(q, [1, -1]) for q in question_tensors], axis=0) | ||
|
||
ctx_segments = paddle.zeros_like(ctxs_tensor) | ||
question_segments = paddle.zeros_like(questions_tensor) | ||
|
||
return BiENcoderBatch( | ||
questions_tensor, | ||
question_segments, | ||
ctxs_tensor, | ||
ctx_segments, | ||
positive_ctx_indices, | ||
hard_neg_ctx_indices, | ||
"question", | ||
) | ||
|
||
|
||
class BertTensorizer(): | ||
|
||
def __init__(self, pad_to_max=True, max_length=256): | ||
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") | ||
self.max_length = max_length | ||
self.pad_to_max = pad_to_max | ||
|
||
def text_to_tensor( | ||
self, | ||
text: str, | ||
title=None, | ||
): | ||
text = text.strip() | ||
|
||
if title: | ||
token_ids = self.tokenizer.encode( | ||
text, | ||
text_pair=title, | ||
max_seq_len=self.max_length, | ||
pad_to_max_seq_len=False, | ||
truncation_strategy="longest_first", | ||
)["input_ids"] | ||
else: | ||
token_ids = self.tokenizer.encode( | ||
text, | ||
max_seq_len=self.max_length, | ||
pad_to_max_seq_len=False, | ||
truncation_strategy="longest_first", | ||
)["input_ids"] | ||
|
||
seq_len = self.max_length | ||
if self.pad_to_max and len(token_ids) < seq_len: | ||
token_ids = token_ids + [self.tokenizer.pad_token_type_id | ||
] * (seq_len - len(token_ids)) | ||
if len(token_ids) >= seq_len: | ||
token_ids = token_ids[0:seq_len] | ||
token_ids[-1] = 102 | ||
|
||
return paddle.to_tensor(token_ids) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# Gradient Cache策略 [DPR](https://arxiv.org/abs/2004.04906) | ||
|
||
|
||
### 实验结果 | ||
|
||
`Gradient Cache` 的实验结果如下,使用的评估指标是`Accuracy`: | ||
|
||
| DPR method | TOP-5 | TOP-10 | TOP-50| 说明 | | ||
| :-----: | :----: | :----: | :----: | :---- | | ||
| Gradient_cache | 68.1 | 79.4| 86.2 | DPR结合GC策略训练 | ||
| GC_Batch_size_512 | 67.3 | 79.6| 86.3| DPR结合GC策略训练,且batch_size设置为512| | ||
|
||
实验对应的超参数如下: | ||
|
||
| Hyper Parameter | batch_size| learning_rate| warmup_steps| epoches| chunk_size|max_grad_norm | | ||
| :----: | :----: | :----: | :----: | :---: | :----: | :----: | | ||
| \ | 128/512| 2e-05 | 1237 | 40 | 2| 16/8 | | ||
|
||
## 数据准备 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 在介绍具体的用法之前,在这里用表格给出复现的实验结果及其对应的重要的超参数 |
||
我们使用Dense Passage Retrieval的[原始仓库](https://github.com/Elvisambition/DPR) | ||
中提供的数据集进行训练和评估。可以使用[download_data.py](https://github.com/Elvisambition/DPR/blob/main/dpr/data/download_data.py) | ||
脚本下载所需数据集。 数据集详细介绍见[原仓库](https://github.com/Elvisambition/DPR) 。 | ||
|
||
### 数据格式 | ||
``` | ||
[ | ||
{ | ||
"question": "....", | ||
"answers": ["...", "...", "..."], | ||
"positive_ctxs": [{ | ||
"title": "...", | ||
"text": "...." | ||
}], | ||
"negative_ctxs": ["..."], | ||
"hard_negative_ctxs": ["..."] | ||
}, | ||
... | ||
] | ||
``` | ||
|
||
### 数据下载 | ||
在[原始仓库](https://github.com/Elvisambition/DPR) | ||
下使用命令 | ||
``` | ||
python data/download_data.py --resource data.wikipedia_split.psgs_w100 | ||
python data/download_data.py --resource data.retriever.nq | ||
python data/download_data.py --resource data.retriever.qas.nq | ||
``` | ||
### 单独下载链接 | ||
[data.retriever.nq-train](https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz) | ||
[data.retriever.nq-dev](https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-dev.json.gz) | ||
[data.retriever.qas.nq-dev](https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-dev.qa.csv) | ||
[data.retriever.qas.nq-test](https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-test.qa.csv) | ||
[data.retriever.qas.nq-train](https://dl.fbaipublicfiles.com/dpr/data/retriever/nq-train.qa.csv) | ||
[psgs_w100.tsv](https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz) | ||
|
||
## 模型训练 | ||
### 基于 [Dense Passage Retriever](https://arxiv.org/abs/2004.04906) 策略训练 | ||
``` | ||
python train_dense_encoder.py \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
--batch_size 128 \ | ||
--learning_rate 2e-05 \ | ||
--save_dir save_biencoder | ||
--warmup_steps 1237 \ | ||
--epoches 40 \ | ||
--max_grad_norm 2 \ | ||
--train_data_path {data_path} \ | ||
--chunk_size 16 \ | ||
``` | ||
|
||
参数含义说明 | ||
* `batch_size`: 批次大小 | ||
* `learning_rate`: 学习率 | ||
* `save_dir`:模型保存位置 | ||
* `warmupsteps`: 预热学习率参数 | ||
* `epoches`: 训练批次大小 | ||
* `max_grad_norm`: 详见ClipGradByGlobalNorm | ||
* `train_data_path`:训练数据存放地址 | ||
* `chunk_size`:chunk大小 | ||
|
||
## 生成文章稠密向量表示 | ||
|
||
``` | ||
python generate_dense_embeddings.py \ | ||
--model_file {path to biencoder} \ | ||
--ctx_file {path to psgs_w100.tsv file} \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 解释一下这些参数,并给出默认数据目录的路径,让用户一键运行 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 参数的解释就在{}里面,已经很清楚了,把需要的数据下载下来看看就一清二楚了。具体的路径是和上一步强相关的,得用户自己设置。 |
||
--shard_id {shard_num, 0-based} --num_shards {total number of shards} \ | ||
--out_file ${out files location + name PREFX} \ | ||
--que_model_path {que_model_path} \ | ||
--con_model_path {con_model_path} | ||
``` | ||
|
||
## 如果只有一台机器,可以直接使用 | ||
|
||
``` | ||
python generate_dense_embedding \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
--ctx_file {data/psgs_w100.tsv} \ | ||
--out_file {test_generate} \ | ||
--que_model_path {que_model_path} \ | ||
--con_model_path {con_model_path} | ||
``` | ||
|
||
|
||
参数含义说明 | ||
* `ctx_file`: ctx文件读取地址 | ||
* `out_file`: 生成后的文件输出地址 | ||
* `que_model_path`: question model path | ||
* `con_model_path`: context model path | ||
|
||
|
||
## 针对全部文档的检索器验证 | ||
``` | ||
python dense_retriever.py --hnsw_index \ | ||
--out_file {out_file} \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 请确保命令能够一键运行,比如 |
||
--encoded_ctx_file {encoded_ctx} \ | ||
--ctx_file {ctx} \ | ||
--qa_file {nq.qa.csv} \ | ||
--que_model_path {que_model_path} \ | ||
--con_model_path {con_model_path} | ||
``` | ||
参数含义说明 | ||
* `hnsw_index`:使用hnsw_index | ||
* `outfile`: 输出文件地址 | ||
* `encoded_ctx_file`: 编码后的ctx文件 | ||
* `ctx_file`: ctx文件 | ||
* `qa_file`: qa_file文件 | ||
* `que_model_path`: question encoder model | ||
* `con_model_path`: context encoder model |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不能只有表格,要给出相应的文字说明。
可以参考,https://github.com/PaddlePaddle/PaddleNLP/tree/develop/applications/neural_search