Skip to content

Commit

Permalink
Merge pull request #57 from DuanxinCao/feature-tair-vector-tvsmindexk…
Browse files Browse the repository at this point in the history
…nnsearch-tvsmindexmknnsearch

TairVector: add cmd tvs.mindexknnsearch and tvs.mindexmknnsearch
  • Loading branch information
yangbodong22011 authored Jan 5, 2023
2 parents 575cdf8 + 0dd8db4 commit bb06f1c
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 43 deletions.
11 changes: 7 additions & 4 deletions examples/tair_vector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#!/usr/bin/env python
from conf_examples import get_tair

from tair import ResponseError


# create an index
# @param index_name the name of index
# @param dims the dimension of vector
Expand All @@ -13,12 +15,13 @@ def create_index(index_name: str, dims: str):
"M": 32,
"ef_construct": 200,
}
#index_params the params of index
return tair.tvs_create_index(index_name, dims,**index_params)
# index_params the params of index
return tair.tvs_create_index(index_name, dims, **index_params)
except ResponseError as e:
print(e)
return None


# delete an index
# @param index_name the name of index
# @return success: True, fail: False.
Expand All @@ -32,5 +35,5 @@ def delete_index(index_name: str):


if __name__ == "__main__":
create_index("test",4)
delete_index("test")
create_index("test", 4)
delete_index("test")
2 changes: 1 addition & 1 deletion tair/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from tair.tairsearch import ScandocidResult
from tair.tairstring import ExcasResult, ExgetResult
from tair.tairts import Aggregation, TairTsSkeyItem
from tair.tairvector import TairVectorIndex, TairVectorScanResult
from tair.tairzset import TairZsetItem
from tair.tairvector import TairVectorScanResult, TairVectorIndex

__all__ = [
"Aggregation",
Expand Down
12 changes: 7 additions & 5 deletions tair/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
parse_exset,
)
from tair.tairts import TairTsCommands
from tair.tairzset import TairZsetCommands, parse_tair_zset_items
from tair.tairvector import (
TairVectorCommands,
parse_tvs_get_index_result,
parse_tvs_get_result,
parse_tvs_search_result,
parse_tvs_msearch_result,
parse_tvs_hmget_result,
parse_tvs_msearch_result,
parse_tvs_search_result,
)
from tair.tairzset import TairZsetCommands, parse_tair_zset_items


class TairCommands(
Expand Down Expand Up @@ -143,16 +143,18 @@ def bool_ok(resp) -> bool:
float(resp[0].decode()), float(resp[1].decode())
),
# TairVector
"TVS.CREATEINDEX":bool_ok,
"TVS.CREATEINDEX": bool_ok,
"TVS.GETINDEX": parse_tvs_get_index_result,
"TVS.DELINDEX": int_or_none,
"TVS.HSET": int_or_none,
"TVS.DEL": int_or_none,
"TVS.HDEL": int_or_none,
"TVS.HGETALL": parse_tvs_get_result,
"TVS.HMGET":parse_tvs_hmget_result,
"TVS.HMGET": parse_tvs_hmget_result,
"TVS.KNNSEARCH": parse_tvs_search_result,
"TVS.MKNNSEARCH": parse_tvs_msearch_result,
"TVS.MINDEXKNNSEARCH": parse_tvs_search_result,
"TVS.MINDEXMKNNSEARCH": parse_tvs_msearch_result,
}


Expand Down
90 changes: 81 additions & 9 deletions tair/tairvector.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,35 @@
from typing import Sequence, Tuple, Union,Iterable
from tair.typing import ResponseT
from typing import Dict, List, Tuple, Union
from functools import partial, reduce
from typing import Dict, Iterable, List, Sequence, Tuple, Union

from redis.client import pairs_to_dict
from redis.utils import str_if_bytes
from typing import Sequence, Union
from functools import partial, reduce

from tair.typing import ResponseT

VectorType = Sequence[Union[int, float]]


class DistanceMetric:
    """String identifiers for the distance metrics named in this module."""

    L2 = "L2"
    Euclidean = L2  # alias: same value as L2
    InnerProduct = "IP"
    Jaccard = "JACCARD"


class IndexType:
    """String identifiers for the index types named in this module."""

    FLAT = "FLAT"
    HNSW = "HNSW"


class Constants:
    """Shared string constants for the vector module."""

    # NOTE(review): presumably the field name under which a record's vector
    # is stored -- confirm against the code that reads VECTOR_KEY.
    VECTOR_KEY = "VECTOR"


class DataType:
    """String identifiers for the vector data types named in this module."""

    Binary = "BINARY"
    Float32 = "FLOAT32"


class TextVectorEncoder:
SEP = bytes(",", "ascii")
BITS = ("0", "1")
Expand All @@ -50,6 +55,7 @@ def decode(buf: bytes) -> Tuple[float]:
return tuple(int(x) for x in components)
return tuple(float(x) for x in components)


class TairVectorScanResult:
"""
wrapper for the results of scan commands
Expand Down Expand Up @@ -152,7 +158,7 @@ def tvs_mknnsearch(
self, k: int, vectors: Sequence[VectorType], filter_str: str = None, **kwargs
):
"""batch approximate nearest neighbors search for a list of vectors"""
return self.client.tvs_knnsearch(
return self.client.tvs_mknnsearch(
self.name, k, vectors, self.is_binary, filter_str, **kwargs
)

Expand All @@ -164,7 +170,6 @@ def __repr__(self):


class TairVectorCommands:

encode_vector = TextVectorEncoder.encode
decode_vector = TextVectorEncoder.decode

Expand Down Expand Up @@ -304,6 +309,8 @@ def tvs_scan(self, index: str, pattern: str = None, batch: int = 10):

SEARCH_CMD = "TVS.KNNSEARCH"
MSEARCH_CMD = "TVS.MKNNSEARCH"
MINDEXKNNSEARCH_CMD = "TVS.MINDEXKNNSEARCH"
MINDEXMKNNSEARCH_CMD = "TVS.MINDEXMKNNSEARCH"

def tvs_knnsearch(
self,
Expand Down Expand Up @@ -361,11 +368,73 @@ def tvs_mknnsearch(
*params
)

def tvs_mindexknnsearch(
    self,
    index: Sequence[str],
    k: int,
    vector: Union[VectorType, str],
    is_binary: bool = False,
    filter_str: str = None,
    **kwargs
):
    """
    search several indexes at once for the top @k approximate nearest
    neighbors of @vector

    @param index names of the indexes to search; a single name passed as a
        bare str/bytes is treated as a one-element list
    @param k number of neighbors requested
    @param vector the query vector; a str is sent unchanged, anything else
        is encoded with TairVectorCommands.encode_vector
    @param is_binary forwarded to encode_vector when @vector is not a str
    @param filter_str optional filter expression, sent right after the vector
    @param kwargs extra search parameters, appended as alternating
        key, value arguments
    @return the result of execute_command for TVS.MINDEXKNNSEARCH
    """
    # A bare name is itself a sequence: without this guard it would be
    # unpacked character by character and the index count would be wrong.
    if isinstance(index, (str, bytes)):
        index = [index]
    if not isinstance(vector, str):
        vector = TairVectorCommands.encode_vector(vector, is_binary)

    # Wire order: index count, index names, k, vector, [filter], [key value ...]
    args = [len(index), *index, k, vector]
    if filter_str is not None:
        args.append(filter_str)
    for name, value in kwargs.items():
        args.extend((name, value))
    return self.execute_command(self.MINDEXKNNSEARCH_CMD, *args)

def tvs_mindexmknnsearch(
    self,
    index: Sequence[str],
    k: int,
    vectors: Sequence[VectorType],
    is_binary: bool = False,
    filter_str: str = None,
    **kwargs
):
    """
    batch approximate nearest neighbors search for a list of vectors across
    several indexes

    @param index names of the indexes to search; a single name passed as a
        bare str/bytes is treated as a one-element list
    @param k number of neighbors requested per query vector
    @param vectors the query vectors; each one is encoded with
        TairVectorCommands.encode_vector
    @param is_binary forwarded to encode_vector for every vector
    @param filter_str optional filter expression, sent right after the vectors
    @param kwargs extra search parameters, appended as alternating
        key, value arguments
    @return the result of execute_command for TVS.MINDEXMKNNSEARCH
    """
    # A bare name is itself a sequence: without this guard it would be
    # unpacked character by character and the index count would be wrong.
    if isinstance(index, (str, bytes)):
        index = [index]
    encoded_vectors = [
        TairVectorCommands.encode_vector(x, is_binary) for x in vectors
    ]

    # Wire order: index count, index names, k, vector count, vectors,
    # [filter], [key value ...]
    args = [len(index), *index, k, len(encoded_vectors), *encoded_vectors]
    if filter_str is not None:
        args.append(filter_str)
    for name, value in kwargs.items():
        args.extend((name, value))
    return self.execute_command(self.MINDEXMKNNSEARCH_CMD, *args)


def parse_tvs_get_index_result(resp) -> Union[Dict, None]:
    """
    Response callback for TVS.GETINDEX.

    @param resp the raw reply: a flat sequence of alternating keys and values
    @return None when the reply is empty or missing, otherwise the dict built
        by pairs_to_dict with keys and string values decoded
    """
    # Truthiness covers both an empty reply and a nil (None) reply;
    # len(None) would raise TypeError.
    if not resp:
        return None
    return pairs_to_dict(resp, decode_keys=True, decode_string_values=True)


def parse_tvs_get_result(resp) -> Dict:
result = pairs_to_dict(resp, decode_keys=True, decode_string_values=False)

Expand All @@ -376,13 +445,16 @@ def parse_tvs_get_result(resp) -> Dict:
values = map(str_if_bytes, result.values())
return dict(zip(result.keys(), values))


def parse_tvs_hmget_result(resp) -> Union[List[Union[str, None]], None]:
    """
    Response callback for TVS.HMGET.

    @param resp the raw reply: a sequence of bytes values, with falsy entries
        (None or empty) where there is no value
    @return None when the reply is empty or missing, otherwise a list with
        each truthy entry decoded as ASCII and every other entry set to None
    """
    # Truthiness covers both an empty reply and a nil (None) reply.
    if not resp:
        return None
    return [item.decode("ascii") if item else None for item in resp]


def parse_tvs_search_result(resp) -> List[Tuple]:
    """
    Turn a flat reply of alternating entries into a list of pairs.

    Each even-positioned entry is kept as-is and paired with the entry that
    follows it, converted to float.
    """
    pairs = []
    for pos in range(0, len(resp), 2):
        key, score = resp[pos], resp[pos + 1]
        pairs.append((key, float(score)))
    return pairs


def parse_tvs_msearch_result(resp) -> List[List[Tuple]]:
    """
    Apply parse_tvs_search_result to every element of the reply.

    Returns one parsed list per element, in reply order.
    """
    parsed = []
    for single_reply in resp:
        parsed.append(parse_tvs_search_result(single_reply))
    return parsed
30 changes: 15 additions & 15 deletions tests/test_from_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,18 @@ async def test_from_url_async():
await t.close()


# @pytest.mark.asyncio
# async def test_from_url_async_cluster():
# url = f"{TAIR_CLUSTER_SCHEME}://{TAIR_CLUSTER_HOST}:{TAIR_CLUSTER_PORT}"
# tc = AsyncTairCluster.from_url(
# url, username=TAIR_CLUSTER_USERNAME, password=TAIR_CLUSTER_PASSWORD
# )
# key = "key_" + str(uuid.uuid4())
# value = "value_" + str(uuid.uuid4())
#
# assert await tc.exset(key, value)
# result: ExgetResult = await tc.exget(key)
# assert result.value == value.encode()
# assert result.version == 1
#
# await tc.close()
@pytest.mark.asyncio
async def test_from_url_async_cluster():
    """Round-trip exset/exget through an AsyncTairCluster built from a URL."""
    url = f"{TAIR_CLUSTER_SCHEME}://{TAIR_CLUSTER_HOST}:{TAIR_CLUSTER_PORT}"
    tc = AsyncTairCluster.from_url(
        url, username=TAIR_CLUSTER_USERNAME, password=TAIR_CLUSTER_PASSWORD
    )
    try:
        # Random key/value so repeated runs do not collide with each other.
        key = "key_" + str(uuid.uuid4())
        value = "value_" + str(uuid.uuid4())

        assert await tc.exset(key, value)
        result: ExgetResult = await tc.exget(key)
        assert result.value == value.encode()
        assert result.version == 1
    finally:
        # Close the client even when an assertion above fails.
        await tc.close()
Loading

0 comments on commit bb06f1c

Please sign in to comment.