# Credit: facebookresearch/DPR
"""
FAISS-based index components for the dense retriever
"""
import logging
import os
import pickle
from typing import List, Tuple

import faiss
import numpy as np

logger = logging.getLogger()


class DenseIndexer(object):
    """Base class for FAISS index wrappers that map FAISS row ids to external db ids."""

    def __init__(self, buffer_size: int = 50000):
        self.buffer_size = buffer_size
        self.index_id_to_db_id = []
        self.index = None

    def init_index(self, vector_sz: int):
        raise NotImplementedError

    def index_data(self, data: List[Tuple[object, np.ndarray]]):
        raise NotImplementedError

    def get_index_name(self):
        raise NotImplementedError

    def search_knn(
        self, query_vectors: np.ndarray, top_docs: int
    ) -> List[Tuple[List[object], List[float]]]:
        raise NotImplementedError
    def serialize(self, file: str):
        logger.info("Serializing index to %s", file)
        # Resolve file names through get_files() so that index_exists()
        # and deserialize() look for exactly the files written here.
        index_file, meta_file = self.get_files(file)
        faiss.write_index(self.index, index_file)
        with open(meta_file, mode="wb") as f:
            pickle.dump(self.index_id_to_db_id, f)
    def get_files(self, path: str):
        if os.path.isdir(path):
            index_file = os.path.join(path, "index.dpr")
            meta_file = os.path.join(path, "index_meta.dpr")
        else:
            index_file = path + ".{}.dpr".format(self.get_index_name())
            meta_file = path + ".{}_meta.dpr".format(self.get_index_name())
        return index_file, meta_file

    def index_exists(self, path: str):
        index_file, meta_file = self.get_files(path)
        return os.path.isfile(index_file) and os.path.isfile(meta_file)

    def deserialize(self, path: str):
        logger.info("Loading index from %s", path)
        index_file, meta_file = self.get_files(path)
        self.index = faiss.read_index(index_file)
        logger.info(
            "Loaded index of type %s and size %d", type(self.index), self.index.ntotal
        )
        with open(meta_file, "rb") as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert (
            len(self.index_id_to_db_id) == self.index.ntotal
        ), "Deserialized index_id_to_db_id should match faiss index size"

    def _update_id_mapping(self, db_ids: List) -> int:
        self.index_id_to_db_id.extend(db_ids)
        return len(self.index_id_to_db_id)

class DenseFlatIndexer(DenseIndexer):
    """Exact (brute-force) inner-product search over a flat FAISS index."""

    def __init__(self, buffer_size: int = 50000):
        super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size)

    def init_index(self, vector_sz: int):
        self.index = faiss.IndexFlatIP(vector_sz)

    def index_data(self, data: List[Tuple[object, np.ndarray]]):
        n = len(data)
        # Indexing in batches is beneficial for many FAISS index types.
        for i in range(0, n, self.buffer_size):
            db_ids = [t[0] for t in data[i : i + self.buffer_size]]
            vectors = [
                np.reshape(t[1], (1, -1)) for t in data[i : i + self.buffer_size]
            ]
            vectors = np.concatenate(vectors, axis=0)
            total_data = self._update_id_mapping(db_ids)
            self.index.add(vectors)
            logger.info("data indexed %d", total_data)
        indexed_cnt = len(self.index_id_to_db_id)
        logger.info("Total data indexed %d", indexed_cnt)

    def search_knn(
        self, query_vectors: np.ndarray, top_docs: int
    ) -> List[Tuple[List[object], List[float]]]:
        scores, indexes = self.index.search(query_vectors, top_docs)
        # Convert internal FAISS row ids back to external db ids.
        db_ids = [
            [self.index_id_to_db_id[i] for i in query_top_idxs]
            for query_top_idxs in indexes
        ]
        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
        return result

    def get_index_name(self):
        return "flat_index"