quantization.py
import logging
import time
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
from torch import Tensor

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    import faiss
    import usearch


def semantic_search_faiss(
    query_embeddings: np.ndarray,
    corpus_embeddings: Optional[np.ndarray] = None,
    corpus_index: Optional["faiss.Index"] = None,
    corpus_precision: Literal["float32", "uint8", "ubinary"] = "float32",
    top_k: int = 10,
    ranges: Optional[np.ndarray] = None,
    calibration_embeddings: Optional[np.ndarray] = None,
    rescore: bool = True,
    rescore_multiplier: int = 2,
    exact: bool = True,
    output_index: bool = False,
) -> Tuple[List[List[Dict[str, Union[int, float]]]], float, "faiss.Index"]:
    """
    Performs semantic search using the FAISS library.

    Rescoring will be performed only if:

    1. `rescore` is True
    2. The query embeddings are not quantized
    3. The corpus is quantized, i.e. the corpus precision is not "float32"

    Only when all of these conditions are met will the search retrieve `top_k * rescore_multiplier` samples
    and then rescore to keep just `top_k`.

    :param query_embeddings: Embeddings of the query sentences. Ideally not quantized, to allow for rescoring.
    :param corpus_embeddings: Embeddings of the corpus sentences. Either `corpus_embeddings` or `corpus_index`
        should be used, not both. The embeddings can be quantized to "uint8" or "ubinary" for more efficient
        search.
    :param corpus_index: FAISS index for the corpus sentences. Either `corpus_embeddings` or `corpus_index`
        should be used, not both.
    :param corpus_precision: Precision of the corpus embeddings. The options are "float32", "uint8", or
        "ubinary". Default is "float32".
    :param top_k: Number of top results to retrieve. Default is 10.
    :param ranges: Ranges for quantization of embeddings. This is only used for int8 quantization, where the
        ranges refer to the minimum and maximum values per dimension, i.e. a 2D array with shape
        (2, embedding_dim). Default is None, in which case the ranges are calculated from the calibration
        embeddings.
    :param calibration_embeddings: Embeddings used for calibration during quantization. This is only used for
        int8 quantization, where the calibration embeddings are used to compute the ranges, i.e. the minimum
        and maximum values per dimension. Default is None, in which case the ranges are calculated from the
        query embeddings. This is not recommended.
    :param rescore: Whether to perform rescoring. Note that rescoring is only applied if the query embeddings
        are not quantized and the corpus is quantized, i.e. the corpus precision is not "float32". Default is
        True.
    :param rescore_multiplier: Oversampling factor for rescoring. The search retrieves
        `top_k * rescore_multiplier` samples and then rescores to keep just `top_k`. Default is 2.
    :param exact: Whether to use exact search or approximate search. Default is True.
    :param output_index: Whether to output the FAISS index used for the search. Default is False.
    :return: A tuple containing a list of search results and the time taken for the search. If `output_index`
        is True, the tuple also contains the FAISS index used for the search.
    :raises ValueError: If both `corpus_embeddings` and `corpus_index` are provided, or if neither is provided.

    The list of search results is in the format: [[{"corpus_id": int, "score": float}, ...], ...]
    The time taken for the search is a float value.
    """
    import faiss

    if corpus_embeddings is not None and corpus_index is not None:
        raise ValueError("Only corpus_embeddings or corpus_index should be used, not both.")
    if corpus_embeddings is None and corpus_index is None:
        raise ValueError("Either corpus_embeddings or corpus_index should be used.")

    # If corpus_index is not provided, create a new index
    if corpus_index is None:
        if corpus_precision in ("float32", "uint8"):
            if exact:
                corpus_index = faiss.IndexFlatIP(corpus_embeddings.shape[1])
            else:
                corpus_index = faiss.IndexHNSWFlat(corpus_embeddings.shape[1], 16)
        elif corpus_precision == "ubinary":
            if exact:
                corpus_index = faiss.IndexBinaryFlat(corpus_embeddings.shape[1] * 8)
            else:
                corpus_index = faiss.IndexBinaryHNSW(corpus_embeddings.shape[1] * 8, 16)
        corpus_index.add(corpus_embeddings)
    # If rescoring is enabled and the query embeddings are in float32, we need to quantize them
    # to the same precision as the corpus embeddings. Also update the top_k value to account for the
    # rescore_multiplier
    rescore_embeddings = None
    k = top_k
    if query_embeddings.dtype not in (np.uint8, np.int8):
        if rescore:
            if corpus_precision != "float32":
                rescore_embeddings = query_embeddings
                k *= rescore_multiplier
            else:
                logger.warning(
                    "Rescoring is enabled but the corpus is not quantized. Either pass `rescore=False` or "
                    'quantize the corpus embeddings with `quantize_embeddings(embeddings, precision="...")` '
                    'and pass `corpus_precision="..."` to `semantic_search_faiss`.'
                )
        query_embeddings = quantize_embeddings(
            query_embeddings,
            precision=corpus_precision,
            ranges=ranges,
            calibration_embeddings=calibration_embeddings,
        )
    elif rescore:
        logger.warning(
            "Rescoring is enabled but the query embeddings are quantized. Either pass `rescore=False` or "
            "don't quantize the query embeddings."
        )
    # Perform the search using the FAISS index
    start_t = time.time()
    scores, indices = corpus_index.search(query_embeddings, k)

    # If rescoring is enabled, we need to rescore the results using the rescore_embeddings
    if rescore_embeddings is not None:
        top_k_embeddings = np.array(
            [[corpus_index.reconstruct(idx.item()) for idx in query_indices] for query_indices in indices]
        )
        # If the corpus precision is binary, we need to unpack the bits
        if corpus_precision == "ubinary":
            top_k_embeddings = np.unpackbits(top_k_embeddings, axis=-1).astype(int)
        else:
            top_k_embeddings = top_k_embeddings.astype(int)

        # rescore_embeddings: [num_queries, embedding_dim]
        # top_k_embeddings: [num_queries, top_k, embedding_dim]
        # rescored_scores: [num_queries, top_k]
        # We use einsum to calculate the dot product between the query and the top_k embeddings, equivalent
        # to looping over the queries and calculating 'rescore_embeddings[i] @ top_k_embeddings[i].T'
        rescored_scores = np.einsum("ij,ikj->ik", rescore_embeddings, top_k_embeddings)
        rescored_indices = np.argsort(-rescored_scores)[:, :top_k]
        indices = indices[np.arange(len(query_embeddings))[:, None], rescored_indices]
        scores = rescored_scores[np.arange(len(query_embeddings))[:, None], rescored_indices]
    delta_t = time.time() - start_t

    outputs = (
        [
            [
                {"corpus_id": int(neighbor), "score": float(score)}
                for score, neighbor in zip(scores[query_id], indices[query_id])
            ]
            for query_id in range(len(query_embeddings))
        ],
        delta_t,
    )
    if output_index:
        outputs = (*outputs, corpus_index)
    return outputs
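
# A minimal usage sketch for `semantic_search_faiss` (illustrative, not part of this module):
# the model name and the `SentenceTransformer` import below are assumptions.
#
#     from sentence_transformers import SentenceTransformer
#
#     model = SentenceTransformer("all-MiniLM-L6-v2")
#     corpus = model.encode(["The weather is lovely.", "He drove to work.", "She ate a pizza."])
#     queries = model.encode(["What is the weather like?"])
#
#     # Quantize the corpus to packed binary; since the queries stay float32 and the corpus is
#     # quantized, rescoring kicks in automatically.
#     binary_corpus = quantize_embeddings(corpus, precision="ubinary")
#     results, search_time = semantic_search_faiss(
#         queries, corpus_embeddings=binary_corpus, corpus_precision="ubinary", top_k=1
#     )
#     # results: [[{"corpus_id": int, "score": float}]]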


def semantic_search_usearch(
    query_embeddings: np.ndarray,
    corpus_embeddings: Optional[np.ndarray] = None,
    corpus_index: Optional["usearch.index.Index"] = None,
    corpus_precision: Literal["float32", "int8", "binary"] = "float32",
    top_k: int = 10,
    ranges: Optional[np.ndarray] = None,
    calibration_embeddings: Optional[np.ndarray] = None,
    rescore: bool = True,
    rescore_multiplier: int = 2,
    exact: bool = True,
    output_index: bool = False,
) -> Tuple[List[List[Dict[str, Union[int, float]]]], float, "usearch.index.Index"]:
    """
    Performs semantic search using the usearch library.

    Rescoring will be performed only if:

    1. `rescore` is True
    2. The query embeddings are not quantized
    3. The corpus is quantized, i.e. the corpus precision is not "float32"

    Only when all of these conditions are met will the search retrieve `top_k * rescore_multiplier` samples
    and then rescore to keep just `top_k`.

    :param query_embeddings: Embeddings of the query sentences. Ideally not quantized, to allow for rescoring.
    :param corpus_embeddings: Embeddings of the corpus sentences. Either `corpus_embeddings` or `corpus_index`
        should be used, not both. The embeddings can be quantized to "int8" or "binary" for more efficient
        search.
    :param corpus_index: usearch index for the corpus sentences. Either `corpus_embeddings` or `corpus_index`
        should be used, not both.
    :param corpus_precision: Precision of the corpus embeddings. The options are "float32", "int8", or
        "binary". Default is "float32".
    :param top_k: Number of top results to retrieve. Default is 10.
    :param ranges: Ranges for quantization of embeddings. This is only used for int8 quantization, where the
        ranges refer to the minimum and maximum values per dimension, i.e. a 2D array with shape
        (2, embedding_dim). Default is None, in which case the ranges are calculated from the calibration
        embeddings.
    :param calibration_embeddings: Embeddings used for calibration during quantization. This is only used for
        int8 quantization, where the calibration embeddings are used to compute the ranges, i.e. the minimum
        and maximum values per dimension. Default is None, in which case the ranges are calculated from the
        query embeddings. This is not recommended.
    :param rescore: Whether to perform rescoring. Note that rescoring is only applied if the query embeddings
        are not quantized and the corpus is quantized, i.e. the corpus precision is not "float32". Default is
        True.
    :param rescore_multiplier: Oversampling factor for rescoring. The search retrieves
        `top_k * rescore_multiplier` samples and then rescores to keep just `top_k`. Default is 2.
    :param exact: Whether to use exact search or approximate search. Default is True.
    :param output_index: Whether to output the usearch index used for the search. Default is False.
    :return: A tuple containing a list of search results and the time taken for the search. If `output_index`
        is True, the tuple also contains the usearch index used for the search.
    :raises ValueError: If both `corpus_embeddings` and `corpus_index` are provided, or if neither is provided.

    The list of search results is in the format: [[{"corpus_id": int, "score": float}, ...], ...]
    The time taken for the search is a float value.
    """
    from usearch.compiled import ScalarKind
    from usearch.index import Index

    if corpus_embeddings is not None and corpus_index is not None:
        raise ValueError("Only corpus_embeddings or corpus_index should be used, not both.")
    if corpus_embeddings is None and corpus_index is None:
        raise ValueError("Either corpus_embeddings or corpus_index should be used.")
    if corpus_precision not in ["float32", "int8", "binary"]:
        raise ValueError('corpus_precision must be "float32", "int8", or "binary" for usearch')

    # If corpus_index is not provided, create a new index
    if corpus_index is None:
        if corpus_precision == "float32":
            corpus_index = Index(
                ndim=corpus_embeddings.shape[1],
                metric="cos",
                dtype="f32",
            )
        elif corpus_precision == "int8":
            corpus_index = Index(
                ndim=corpus_embeddings.shape[1],
                metric="ip",
                dtype="i8",
            )
        elif corpus_precision == "binary":
            corpus_index = Index(
                ndim=corpus_embeddings.shape[1],
                metric="hamming",
                dtype="i8",
            )
        corpus_index.add(np.arange(len(corpus_embeddings)), corpus_embeddings)
    # If rescoring is enabled and the query embeddings are in float32, we need to quantize them
    # to the same precision as the corpus embeddings. Also update the top_k value to account for the
    # rescore_multiplier
    rescore_embeddings = None
    k = top_k
    if query_embeddings.dtype not in (np.uint8, np.int8):
        if rescore:
            if corpus_index.dtype != ScalarKind.F32:
                rescore_embeddings = query_embeddings
                k *= rescore_multiplier
            else:
                logger.warning(
                    "Rescoring is enabled but the corpus is not quantized. Either pass `rescore=False` or "
                    'quantize the corpus embeddings with `quantize_embeddings(embeddings, precision="...")` '
                    'and pass `corpus_precision="..."` to `semantic_search_usearch`.'
                )
        query_embeddings = quantize_embeddings(
            query_embeddings,
            precision=corpus_precision,
            ranges=ranges,
            calibration_embeddings=calibration_embeddings,
        )
    elif rescore:
        logger.warning(
            "Rescoring is enabled but the query embeddings are quantized. Either pass `rescore=False` or "
            "don't quantize the query embeddings."
        )
    # Perform the search using the usearch index
    start_t = time.time()
    matches = corpus_index.search(query_embeddings, count=k, exact=exact)
    scores = matches.distances
    indices = matches.keys

    # If rescoring is enabled, we need to rescore the results using the rescore_embeddings
    if rescore_embeddings is not None:
        top_k_embeddings = np.array([corpus_index.get(query_indices) for query_indices in indices])
        # If the corpus precision is binary, we need to unpack the bits
        if corpus_precision == "binary":
            top_k_embeddings = np.unpackbits(top_k_embeddings.astype(np.uint8), axis=-1)
        top_k_embeddings = top_k_embeddings.astype(int)

        # rescore_embeddings: [num_queries, embedding_dim]
        # top_k_embeddings: [num_queries, top_k, embedding_dim]
        # rescored_scores: [num_queries, top_k]
        # We use einsum to calculate the dot product between the query and the top_k embeddings, equivalent
        # to looping over the queries and calculating 'rescore_embeddings[i] @ top_k_embeddings[i].T'
        rescored_scores = np.einsum("ij,ikj->ik", rescore_embeddings, top_k_embeddings)
        rescored_indices = np.argsort(-rescored_scores)[:, :top_k]
        indices = indices[np.arange(len(query_embeddings))[:, None], rescored_indices]
        scores = rescored_scores[np.arange(len(query_embeddings))[:, None], rescored_indices]
    delta_t = time.time() - start_t

    outputs = (
        [
            [
                {"corpus_id": int(neighbor), "score": float(score)}
                for score, neighbor in zip(scores[query_id], indices[query_id])
            ]
            for query_id in range(len(query_embeddings))
        ],
        delta_t,
    )
    if output_index:
        outputs = (*outputs, corpus_index)
    return outputs
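
# A minimal usage sketch for `semantic_search_usearch` (illustrative, not part of this module):
# the model name and the `SentenceTransformer` import below are assumptions.
#
#     from sentence_transformers import SentenceTransformer
#
#     model = SentenceTransformer("all-MiniLM-L6-v2")
#     corpus = model.encode(["The weather is lovely.", "He drove to work.", "She ate a pizza."])
#     queries = model.encode(["What is the weather like?"])
#
#     # An int8 corpus needs calibration ranges; here the corpus itself doubles as the
#     # calibration set, so the queries get quantized with the same buckets internally.
#     int8_corpus = quantize_embeddings(corpus, precision="int8", calibration_embeddings=corpus)
#     results, search_time = semantic_search_usearch(
#         queries,
#         corpus_embeddings=int8_corpus,
#         corpus_precision="int8",
#         calibration_embeddings=corpus,
#         top_k=1,
#     )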


def quantize_embeddings(
    embeddings: Union[Tensor, np.ndarray],
    precision: Literal["float32", "int8", "uint8", "binary", "ubinary"],
    ranges: Optional[np.ndarray] = None,
    calibration_embeddings: Optional[np.ndarray] = None,
) -> np.ndarray:
    """
    Quantizes embeddings to a lower precision. This can be used to reduce the memory footprint and increase
    the speed of similarity search. The supported precisions are "float32", "int8", "uint8", "binary", and
    "ubinary".

    :param embeddings: Unquantized (e.g. float) embeddings to quantize to a given precision
    :param precision: The precision to convert to. Options are "float32", "int8", "uint8", "binary", "ubinary".
    :param ranges: Ranges for quantization of embeddings. This is only used for int8 quantization, where the
        ranges refer to the minimum and maximum values per dimension, i.e. a 2D array with shape
        (2, embedding_dim). Default is None, in which case the ranges are calculated from the calibration
        embeddings.
    :type ranges: Optional[np.ndarray]
    :param calibration_embeddings: Embeddings used for calibration during quantization. This is only used for
        int8 quantization, where the calibration embeddings are used to compute the ranges, i.e. the minimum
        and maximum values per dimension. Default is None, in which case the ranges are calculated from the
        embeddings that are being quantized. This is not recommended.
    :type calibration_embeddings: Optional[np.ndarray]
    :return: Quantized embeddings with the specified precision
    """
    if isinstance(embeddings, Tensor):
        embeddings = embeddings.cpu().numpy()
    elif isinstance(embeddings, list):
        if isinstance(embeddings[0], Tensor):
            embeddings = [embedding.cpu().numpy() for embedding in embeddings]
        embeddings = np.array(embeddings)
    if embeddings.dtype in (np.uint8, np.int8):
        raise Exception("Embeddings to quantize must be float rather than int8 or uint8.")

    if precision == "float32":
        return embeddings.astype(np.float32)

    if precision.endswith("int8"):
        # Either use 1. the provided ranges, 2. the calibration dataset, or 3. the provided embeddings
        if ranges is None:
            if calibration_embeddings is not None:
                ranges = np.vstack((np.min(calibration_embeddings, axis=0), np.max(calibration_embeddings, axis=0)))
            else:
                if embeddings.shape[0] < 100:
                    logger.warning(
                        f"Computing {precision} quantization buckets based on {len(embeddings)} embedding{'s' if len(embeddings) != 1 else ''}."
                        f" {precision} quantization is more stable with `ranges` calculated from more embeddings "
                        "or a `calibration_embeddings` that can be used to calculate the buckets."
                    )
                ranges = np.vstack((np.min(embeddings, axis=0), np.max(embeddings, axis=0)))
        # Map each dimension affinely from [min, max] onto the 256 quantization buckets
        starts = ranges[0, :]
        steps = (ranges[1, :] - ranges[0, :]) / 255

        if precision == "uint8":
            return ((embeddings - starts) / steps).astype(np.uint8)
        elif precision == "int8":
            return ((embeddings - starts) / steps - 128).astype(np.int8)

    if precision == "binary":
        return (np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1) - 128).astype(np.int8)

    if precision == "ubinary":
        return np.packbits(embeddings > 0).reshape(embeddings.shape[0], -1)

    raise ValueError(f"Precision {precision} is not supported")
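
# A minimal usage sketch for `quantize_embeddings` (illustrative, not part of this module),
# using random data in place of real embeddings. Per dimension with range [min, max], a value
# v maps to (v - min) / ((max - min) / 255), so min -> 0 and max -> 255; "int8" shifts that
# result down by 128, and "binary"/"ubinary" pack the signs of 8 dimensions into each byte.
#
#     calibration = np.random.default_rng(42).normal(size=(1000, 384)).astype(np.float32)
#     embeddings = np.random.default_rng(0).normal(size=(4, 384)).astype(np.float32)
#
#     int8_embeddings = quantize_embeddings(
#         embeddings, precision="int8", calibration_embeddings=calibration
#     )
#     ubinary_embeddings = quantize_embeddings(embeddings, precision="ubinary")
#     # int8_embeddings.shape == (4, 384); ubinary_embeddings.shape == (4, 48), i.e. 384 / 8
#     # packed bytes per embedding.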