-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- `IndexUpdater` class for adding/removing new documents from an existing index. - `class_factory` wrapper for `HF_ColBERT` for initializing new types of models: - AutoModel - BERT - Deberta - Electra - Roberta - XLMRoberta - README updates.
- Loading branch information
1 parent
bf4df83
commit 81b7a1a
Showing
24 changed files
with
980 additions
and
104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from collections import defaultdict | ||
|
||
import tqdm | ||
import ujson | ||
from colbert.data import Ranking | ||
from colbert.distillation.scorer import Scorer | ||
from colbert.infra import Run | ||
from colbert.infra.provenance import Provenance | ||
from colbert.utility.utils.save_metadata import get_metadata_only | ||
from colbert.utils.utils import print_message, zipstar | ||
|
||
|
||
class RankingScorer:
    """Attach cross-encoder distillation scores to a (qid, pid) ranking.

    Wraps a ``Scorer`` together with a ``Ranking`` of qid--pid pairs;
    ``run()`` scores every pair, writes the per-query results to
    ``distillation_scores.json`` via ``Run().open``, and saves a metadata
    sidecar alongside it.
    """

    def __init__(self, scorer: Scorer, ranking: Ranking):
        self.scorer = scorer
        self.ranking = ranking.tolist()
        self.__provenance = Provenance()

        print_message(f"#> Loaded ranking with {len(self.ranking)} qid--pid pairs!")

    def provenance(self):
        """Return the Provenance record created when this scorer was built."""
        return self.__provenance

    def run(self):
        """Score all pairs and return the path of the written scores file."""
        print_message("#> Starting..")

        # Score every (qid, pid) pair in one distributed launch.
        qids, pids, *_ = zipstar(self.ranking)
        pair_scores = self.scorer.launch(qids, pids)

        # Regroup the flat score list by query id as (score, pid) tuples.
        scores_by_qid = defaultdict(list)
        for qid, pid, score in tqdm.tqdm(zip(qids, pids, pair_scores)):
            scores_by_qid[qid].append((score, pid))

        # One JSON line per query: (qid, [(score, pid), ...]).
        with Run().open("distillation_scores.json", "w") as f:
            for qid in tqdm.tqdm(scores_by_qid):
                f.write(ujson.dumps((qid, scores_by_qid[qid])) + "\n")

            output_path = f.name
            print_message(f"#> Saved the distillation_scores to {output_path}")

        # Sidecar with run metadata and provenance for reproducibility.
        with Run().open(f"{output_path}.meta", "w") as f:
            meta = {
                "metadata": get_metadata_only(),
                "provenance": self.provenance(),
            }
            f.write(ujson.dumps(meta, indent=4))

        return output_path
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import torch | ||
import tqdm | ||
from colbert.infra import Run, RunConfig | ||
from colbert.infra.launcher import Launcher | ||
from colbert.modeling.reranker.electra import ElectraReranker | ||
from colbert.utils.utils import flatten | ||
from transformers import AutoModelForSequenceClassification, AutoTokenizer | ||
|
||
DEFAULT_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"


class Scorer:
    """Distributed cross-encoder scorer for (query id, passage id) pairs.

    ``launch()`` fans the pairs out across ``Run().config.nranks`` processes;
    each process scores its contiguous share with a HuggingFace
    sequence-classification model and the flat score lists are concatenated.
    """

    def __init__(self, queries, collection, model=DEFAULT_MODEL, maxlen=180, bsize=256, device=None):
        """
        Args:
            queries: mapping of qid -> query text.
            collection: mapping of pid -> passage text.
            model: HuggingFace model name or path of a cross-encoder.
            maxlen: tokenizer truncation length per (query, passage) pair.
            bsize: inference batch size.
            device: torch device string for inference. Defaults to "cuda"
                when None, which preserves the original hard-coded behavior.
        """
        self.queries = queries
        self.collection = collection
        self.model = model

        self.maxlen = maxlen
        self.bsize = bsize
        self.device = "cuda" if device is None else device

    def launch(self, qids, pids):
        """Score all (qid, pid) pairs across ranks; return a flat list of floats."""
        launcher = Launcher(self._score_pairs_process, return_all=True)
        outputs = launcher.launch(Run().config, qids, pids)

        return flatten(outputs)

    def _score_pairs_process(self, config, qids, pids):
        """Per-rank entry point: score this rank's contiguous share of the pairs."""
        assert len(qids) == len(pids), (len(qids), len(pids))

        # Ceiling-divide the pairs into nranks contiguous shares; only rank 0
        # shows a progress bar.
        share = 1 + len(qids) // config.nranks
        offset = config.rank * share
        endpos = (1 + config.rank) * share

        return self._score_pairs(
            qids[offset:endpos], pids[offset:endpos], show_progress=(config.rank < 1)
        )

    def _score_pairs(self, qids, pids, show_progress=False):
        """Score the given pairs with the cross-encoder; returns list[float]."""
        tokenizer = AutoTokenizer.from_pretrained(self.model)
        model = AutoModelForSequenceClassification.from_pretrained(self.model).to(self.device)

        assert len(qids) == len(pids), (len(qids), len(pids))

        scores = []

        model.eval()
        with torch.inference_mode():
            # torch.cuda.amp.autocast is deprecated; torch.amp.autocast takes
            # the device type explicitly and also works for CPU inference.
            with torch.amp.autocast(device_type=torch.device(self.device).type):
                for offset in tqdm.tqdm(
                    range(0, len(qids), self.bsize), disable=(not show_progress)
                ):
                    endpos = offset + self.bsize

                    queries_ = [self.queries[qid] for qid in qids[offset:endpos]]
                    passages_ = [self.collection[pid] for pid in pids[offset:endpos]]

                    features = tokenizer(
                        queries_,
                        passages_,
                        padding="longest",
                        truncation=True,
                        return_tensors="pt",
                        max_length=self.maxlen,
                    ).to(model.device)

                    scores.append(model(**features).logits.flatten())

        # torch.cat raises on an empty list; an empty share yields no scores.
        scores = torch.cat(scores).tolist() if scores else []

        Run().print(f"Returning with {len(scores)} scores")

        return scores


# LONG-TERM TODO: This can be sped up by sorting by length in advance.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.