2024.01.31 Evidence Pattern Retrieval
scarlettllc committed Jan 31, 2024
1 parent 57e9c1d commit a137eb3
Showing 13 changed files with 1,290 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -10,4 +10,5 @@ bert-base-uncased/
 NSM_H/datasets
 NSM_H/checkpoint
 log/
-test.ipynb
+test.ipynb
+runs/
12 changes: 4 additions & 8 deletions atomic_pattern_retrieval/biencoder/biencoder_inference.py
@@ -144,8 +144,7 @@ def build_index(output_path, rr_ap_vectors_path, index_buffer=50000):
     index.serialize(output_path)
 
 
-def reference_pipeline(
-    split,
+def inference_pipeline(
     questions_path,
     all_rr_aps,
     model_path,
@@ -227,7 +226,7 @@ def reference_pipeline(
 
 
 if __name__ == "__main__":
-    topk = 100
+    topk = 500
     epoch = 5
     inference_dir = os.path.join(
         f"data/{Config.ds_tag}/ap_retrieval/model", f"{Config.ds_tag}_ep_{epoch}"
@@ -253,11 +252,7 @@ def reference_pipeline(
     )
     for split in ["dev", "test", "train"]:
         start = time.time()
-        split_folder = os.path.join(inference_dir, f"{epoch_folder}_{split}")
-        if not os.path.exists(split_folder):
-            os.makedirs(split_folder)
-        reference_pipeline(
-            split,
+        inference_pipeline(
             questions_path=Config.ds_split_f(split),
             all_rr_aps=read_json(Config.cache_rr_aps),
             model_path=model_file,
@@ -267,6 +262,7 @@ def reference_pipeline(
                 Config.retrieved_ap_f(split),
             ),
             index_file=os.path.join(inference_dir, "flat.index"),
+            top_k=topk
         )
         end = time.time()
         print(end - start)
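
For context: build_index and the serialized flat.index file are only partially visible in this diff. Below is a minimal sketch of the kind of flat inner-product index such a retrieval pipeline typically builds and then queries with top_k; FAISS and 768-dimensional biencoder embeddings are assumptions for illustration, not code from this repository.

import faiss
import numpy as np

dim = 768                                                   # assumed embedding size
ap_vectors = np.random.rand(10000, dim).astype("float32")   # placeholder candidate vectors

index = faiss.IndexFlatIP(dim)            # exact inner-product search
index.add(ap_vectors)                     # could also be added in index_buffer-sized chunks
faiss.write_index(index, "flat.index")    # analogous to index.serialize(output_path)

# At inference time: load the index and retrieve top_k candidates per question embedding.
index = faiss.read_index("flat.index")
question_vecs = np.random.rand(2, dim).astype("float32")    # placeholder question embeddings
scores, ids = index.search(question_vecs, 500)              # top_k = 500, as set in __main__
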
3 changes: 2 additions & 1 deletion config.py
@@ -55,7 +55,8 @@ class Config:
     ep_retrieval_dir = config["ep_retrieval"]["work_dir"]
     max_combine_rels = config["ep_retrieval"]["max_combine_rels"]
     ep_rank_td_f = lambda split: f"{Config.ep_retrieval_dir}{Config.ds_tag}_{split}_top{Config.ap_topk}_ap_ep_rank_td.jsonl"
-    candi_ep_f = lambda split: f"{Config.ep_retrieval_dir}{Config.ds_tag}_{split}_top{Config.ap_topk}_ap_candi_ep.json"
+    ep_rank_sample_size = config["ep_retrieval"]["sample_size"]
+    candi_ep_f = lambda split: f"{Config.ep_retrieval_dir}{Config.ds_tag}_{split}_top{Config.ap_topk}_ap_candi_eps.json"
     ranked_ep_f = lambda split: f"{Config.ep_retrieval_dir}{Config.ds_tag}_{split}_top{Config.ap_topk}_ap_ranked_ep.json"
 
     # subgraph extraction
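
As a quick reference, here is how the renamed path lambda resolves once the config values are filled in. The ds_tag, ap_topk, and work_dir values below are assumptions chosen to match the CWQ config, not values read from the repository at runtime.

# Hypothetical standalone reproduction of the Config path helpers.
ep_retrieval_dir = "data/CWQ/ep_retrieval/"   # assumed ep_retrieval work_dir
ds_tag = "CWQ"                                # assumed dataset tag
ap_topk = 100                                 # assumed atomic-pattern top-k

candi_ep_f = lambda split: f"{ep_retrieval_dir}{ds_tag}_{split}_top{ap_topk}_ap_candi_eps.json"
ranked_ep_f = lambda split: f"{ep_retrieval_dir}{ds_tag}_{split}_top{ap_topk}_ap_ranked_ep.json"

print(candi_ep_f("dev"))   # data/CWQ/ep_retrieval/CWQ_dev_top100_ap_candi_eps.json
print(ranked_ep_f("dev"))  # data/CWQ/ep_retrieval/CWQ_dev_top100_ap_ranked_ep.json
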
1 change: 1 addition & 0 deletions config_CWQ.yaml
@@ -35,6 +35,7 @@ ep_retrieval:
   work_dir: "data/CWQ/ep_retrieval/"
   # max combine count
   max_combine_rels: 5
+  sample_size: 64
 
 subgraph_extraction:
   work_dir: "data/CWQ/subgraph_extraction/"
1 change: 1 addition & 0 deletions config_WebQSP.yaml
@@ -35,6 +35,7 @@ ep_retrieval:
   work_dir: "data/WebQSP/ep_retrieval/"
   # max combine count
   max_combine_rels: 3
+  sample_size: 100
 
 subgraph_extraction:
   work_dir: "data/WebQSP/subgraph_extraction/"
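
Both YAML files gain the same key, which config.py now surfaces as Config.ep_rank_sample_size. A minimal sketch of that plumbing, assuming the repository loads its config with PyYAML (the actual loader is outside this diff):

import yaml

with open("config_CWQ.yaml") as f:            # or config_WebQSP.yaml
    config = yaml.safe_load(f)

ep_rank_sample_size = config["ep_retrieval"]["sample_size"]    # 64 for CWQ, 100 for WebQSP
max_combine_rels = config["ep_retrieval"]["max_combine_rels"]
print(ep_rank_sample_size, max_combine_rels)
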
158 changes: 158 additions & 0 deletions evidence_pattern_retrieval/BERT_Ranker/BertRanker.py
@@ -0,0 +1,158 @@
"""
Copyright (c) 2021, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset
from transformers import BertPreTrainedModel, BertModel


def get_inf_mask(bool_mask):
    return (~bool_mask) * -100000.0


class BertForCandidateRanking(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    # for training, returns (loss, logits) with logits of shape [batch_size, num_sample]
    # for testing, batch size has to be 1
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        sample_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        assert return_dict is None
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # for training, input is batch_size * sample_size * L
        # for testing, it is batch_size * L
        if labels is not None:
            batch_size = input_ids.size(0)
            sample_size = input_ids.size(1)
            seq_length = input_ids.size(2)
            # flatten the first two dims
            input_ids = input_ids.view((batch_size * sample_size, -1))
            token_type_ids = token_type_ids.view((batch_size * sample_size, -1))
            attention_mask = attention_mask.view((batch_size * sample_size, -1))

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # embedding_by_tokens = outputs[0]
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        # embedding_by_tokens = embedding_by_tokens.view((batch_size, sample_size, seq_length, 768))

        loss = None
        if labels[0].item() != -1:
            # reshape logits
            logits = logits.view((batch_size, sample_size))
            logits = logits + get_inf_mask(sample_mask)
            # apply infmask
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels.view(-1))
        else:
            logits = logits.view((batch_size, sample_size))
            logits = logits + get_inf_mask(sample_mask)

        return (loss, logits) if loss is not None else logits


class ListDataset(Dataset):
    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]

    def __iter__(self):
        return iter(self.examples)


# for single problem
class RankingFeature:
    def __init__(self, pid, input_ids, token_type_ids, target_idx):
        self.pid = pid
        self.candidate_input_ids = input_ids
        self.candidate_token_type_ids = token_type_ids
        self.target_idx = target_idx


def _collect_contrastive_inputs(feat, num_sample, dummy_inputs):
    input_ids = []
    token_type_ids = []

    input_ids.extend(feat.candidate_input_ids)
    token_type_ids.extend(feat.candidate_token_type_ids)
    filled_num = len(input_ids)
    # force padding
    for _ in range(filled_num, num_sample):
        input_ids.append(dummy_inputs['input_ids'])
        token_type_ids.append(dummy_inputs['token_type_ids'])
    sample_mask = [1] * filled_num + [0] * (num_sample - filled_num)
    return input_ids, token_type_ids, sample_mask


def disamb_collate_fn(data, tokenizer):
    dummy_inputs = tokenizer('', '', return_token_type_ids=True)
    # batch size
    # input_id: B * N_Sample * L
    # token_type: B * N_Sample * L
    # attention_mask: B * N_Sample * N
    # sample_mask: B * N_Sample
    # labels: B, all zero
    batch_size = len(data)
    num_sample = max([len(x.candidate_input_ids) for x in data])

    all_input_ids = []
    all_token_type_ids = []
    all_sample_masks = []
    for feat in data:
        input_ids, token_type_ids, sample_mask = _collect_contrastive_inputs(feat, num_sample, dummy_inputs)
        all_input_ids.extend(input_ids)
        all_token_type_ids.extend(token_type_ids)
        all_sample_masks.append(sample_mask)

    encoded = tokenizer.pad({'input_ids': all_input_ids, 'token_type_ids': all_token_type_ids}, return_tensors='pt')
    all_sample_masks = torch.BoolTensor(all_sample_masks)
    labels = torch.LongTensor([x.target_idx for x in data])

    all_input_ids = encoded['input_ids'].view((batch_size, num_sample, -1))
    all_token_type_ids = encoded['token_type_ids'].view((batch_size, num_sample, -1))
    all_attention_masks = encoded['attention_mask'].view((batch_size, num_sample, -1))
    return all_input_ids, all_token_type_ids, all_attention_masks, all_sample_masks, labels
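
To show how these pieces fit together, here is a minimal end-to-end usage sketch: one RankingFeature per question, disamb_collate_fn to pad the candidate lists and build the sample mask, and BertForCandidateRanking to score all candidates with a listwise cross-entropy loss. The checkpoint name, example question, candidate strings, and DataLoader settings are illustrative assumptions, not code from this commit.

from functools import partial

from torch.utils.data import DataLoader
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")    # assumed checkpoint
model = BertForCandidateRanking.from_pretrained("bert-base-uncased")  # ranking head is newly initialized

question = "what movies did the director of Inception direct"         # toy example
candidate_eps = ["film.director.film", "film.film.directed_by", "film.film.starring"]  # toy candidates
target_idx = 0   # index of the gold candidate among candidate_eps

# Tokenize each (question, candidate) pair without padding;
# disamb_collate_fn pads the whole batch and builds the sample mask.
enc = [tokenizer(question, ep, return_token_type_ids=True) for ep in candidate_eps]
feat = RankingFeature(
    pid="q0",
    input_ids=[e["input_ids"] for e in enc],
    token_type_ids=[e["token_type_ids"] for e in enc],
    target_idx=target_idx,
)

loader = DataLoader(
    ListDataset([feat]),
    batch_size=1,
    collate_fn=partial(disamb_collate_fn, tokenizer=tokenizer),
)
for input_ids, token_type_ids, attention_mask, sample_mask, labels in loader:
    loss, logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        sample_mask=sample_mask,
        labels=labels,
    )
    print(loss.item(), logits.shape)   # logits: [batch_size, num_sample]
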