Skip to content

Commit

Permalink
Improve embedding retrieval performance
Browse files Browse the repository at this point in the history
Signed-off-by: yhmo <yihua.mo@zilliz.com>
  • Loading branch information
yhmo committed Oct 14, 2024
1 parent c7de801 commit 06a67af
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 6 deletions.
31 changes: 28 additions & 3 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import time
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import ujson
Expand Down Expand Up @@ -425,10 +426,15 @@ def __init__(
nq_thres = 0
for topk in all_topks:
start, end = nq_thres, nq_thres + topk
t1 = time.time()
nq_th_fields = self.get_fields_by_range(start, end, fields_data)
t2 = time.time()
print("get_fields_by_range", t2-t1)
data.append(
Hits(topk, all_pks[start:end], all_scores[start:end], nq_th_fields, output_fields)
Hits(topk, all_pks[start:end], all_scores[start:end], nq_th_fields, output_fields, start)
)
t3 = time.time()
print("Hits", t3 - t2)
nq_thres += topk

super().__init__(data)
Expand Down Expand Up @@ -487,7 +493,11 @@ def get_fields_by_range(
dim, vectors = field.vectors.dim, field.vectors
field_meta.vectors.dim = dim
if dtype == DataType.FLOAT_VECTOR:
field2data[name] = vectors.float_vector.data[start * dim : end * dim], field_meta
# The vectors.float_vector.data is a protobuf.RepeatedScalarFieldContainer.
# Converting this class to a list performs poorly, especially for long arrays.
# For high-topk, high-dimension cases, the array length can exceed 10 million
# elements and the conversion might take nearly a second.
# Here we pass the reference instead of converting it to a list, and let Hits split it.
field2data[name] = vectors.float_vector.data, field_meta
continue

if dtype == DataType.BINARY_VECTOR:
Expand Down Expand Up @@ -543,6 +553,7 @@ def __init__(
distances: List[float],
fields: Dict[str, Tuple[List[Any], schema_pb2.FieldData]],
output_fields: List[str],
start: int,  # for FloatVector: the offset at which to read vectors from the protobuf.RepeatedScalarFieldContainer reference
):
"""
Args:
Expand All @@ -551,19 +562,32 @@ def __init__(
"""
self.ids = pks
self.distances = distances
self.start = start

all_fields = list(fields.keys())
dynamic_fields = list(set(output_fields) - set(all_fields))

# In high-topk cases, hits.append() can be slow because of unnecessary GC activity.
# Disable GC here. Performance improved by 10% for the topk=16384 case.
import gc
gc.disable()
hits = []
for i in range(topk):
curr_field = {}
for fname, (data, field_meta) in fields.items():
if len(data) <= i:
curr_field[fname] = None
# Get dense vectors
if field_meta.type == DataType.FLOAT_VECTOR:
dim = field_meta.vectors.dim
# For a FloatVector field, the data is a reference to a protobuf.RepeatedScalarFieldContainer.
# Fetching a long list from a RepeatedScalarFieldContainer is time-consuming.
# Here we fetch the vectors one by one, each as a short list, to improve performance.
# Performance improved by 25% for the topk=16384 dim=1536 case (19 seconds decreased to 14 seconds).
curr_field[fname] = data[(self.start + i) * dim: (self.start + i + 1) * dim]
continue

if field_meta.type in (
DataType.FLOAT_VECTOR,
DataType.BINARY_VECTOR,
DataType.BFLOAT16_VECTOR,
DataType.FLOAT16_VECTOR,
Expand Down Expand Up @@ -591,6 +615,7 @@ def __init__(

hits.append(Hit(pks[i], distances[i], curr_field))

gc.enable()
super().__init__(hits)

def __iter__(self) -> SequenceIterator:
Expand Down
8 changes: 5 additions & 3 deletions pymilvus/client/entity_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,9 +644,11 @@ def check_append(field_data: Any):
dim = field_data.vectors.dim
if len(field_data.vectors.float_vector.data) >= index * dim:
start_pos, end_pos = index * dim, (index + 1) * dim
entity_row_data[field_data.field_name] = [
np.single(x) for x in field_data.vectors.float_vector.data[start_pos:end_pos]
]
# Here we use numpy.array to convert the float64 values to numpy.float32,
# and return a list of numpy.float32 values to the user.
# Using numpy.array improved performance by 60% for the topk=16384 dim=1536 case.
arr = np.array(field_data.vectors.float_vector.data[start_pos:end_pos], dtype=np.float32)
entity_row_data[field_data.field_name] = [x for x in arr]
elif field_data.type == DataType.BINARY_VECTOR:
dim = field_data.vectors.dim
if len(field_data.vectors.binary_vector) >= index * (dim // 8):
Expand Down
3 changes: 3 additions & 0 deletions pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,10 @@ def _execute_search(
func = kwargs.get("_callback", None)
return SearchFuture(future, func)

t1 = time.time()
response = self._stub.Search(request, timeout=timeout)
t2 = time.time()
print("stub.Search", t2-t1)
check_status(response.status)
round_decimal = kwargs.get("round_decimal", -1)
return SearchResult(response.results, round_decimal, status=response.status)
Expand Down

0 comments on commit 06a67af

Please sign in to comment.