Skip to content

Commit

Permalink
add cmd tvs.mindexknnsearch and tvs.mindexmknnsearch
Browse files Browse the repository at this point in the history
  • Loading branch information
DuanxinCao committed Jan 5, 2023
1 parent 575cdf8 commit eca10e7
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:
- name: set up python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.9"

- name: install redis-py
run: pip install redis
Expand Down
11 changes: 7 additions & 4 deletions examples/tair_vector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#!/usr/bin/env python
from conf_examples import get_tair

from tair import ResponseError


# create an index
# @param index_name the name of index
# @param dims the dimension of vector
Expand All @@ -13,12 +15,13 @@ def create_index(index_name: str, dims: str):
"M": 32,
"ef_construct": 200,
}
#index_params the params of index
return tair.tvs_create_index(index_name, dims,**index_params)
# index_params the params of index
return tair.tvs_create_index(index_name, dims, **index_params)
except ResponseError as e:
print(e)
return None


# delete an index
# @param index_name the name of index
# @return success: True, fail: False.
Expand All @@ -32,5 +35,5 @@ def delete_index(index_name: str):


if __name__ == "__main__":
create_index("test",4)
delete_index("test")
create_index("test", 4)
delete_index("test")
2 changes: 1 addition & 1 deletion tair/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from tair.tairsearch import ScandocidResult
from tair.tairstring import ExcasResult, ExgetResult
from tair.tairts import Aggregation, TairTsSkeyItem
from tair.tairvector import TairVectorIndex, TairVectorScanResult
from tair.tairzset import TairZsetItem
from tair.tairvector import TairVectorScanResult, TairVectorIndex

__all__ = [
"Aggregation",
Expand Down
12 changes: 7 additions & 5 deletions tair/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
parse_exset,
)
from tair.tairts import TairTsCommands
from tair.tairzset import TairZsetCommands, parse_tair_zset_items
from tair.tairvector import (
TairVectorCommands,
parse_tvs_get_index_result,
parse_tvs_get_result,
parse_tvs_search_result,
parse_tvs_msearch_result,
parse_tvs_hmget_result,
parse_tvs_msearch_result,
parse_tvs_search_result,
)
from tair.tairzset import TairZsetCommands, parse_tair_zset_items


class TairCommands(
Expand Down Expand Up @@ -143,16 +143,18 @@ def bool_ok(resp) -> bool:
float(resp[0].decode()), float(resp[1].decode())
),
# TairVector
"TVS.CREATEINDEX":bool_ok,
"TVS.CREATEINDEX": bool_ok,
"TVS.GETINDEX": parse_tvs_get_index_result,
"TVS.DELINDEX": int_or_none,
"TVS.HSET": int_or_none,
"TVS.DEL": int_or_none,
"TVS.HDEL": int_or_none,
"TVS.HGETALL": parse_tvs_get_result,
"TVS.HMGET":parse_tvs_hmget_result,
"TVS.HMGET": parse_tvs_hmget_result,
"TVS.KNNSEARCH": parse_tvs_search_result,
"TVS.MKNNSEARCH": parse_tvs_msearch_result,
"TVS.MINDEXKNNSEARCH": parse_tvs_search_result,
"TVS.MINDEXMKNNSEARCH": parse_tvs_msearch_result,
}


Expand Down
90 changes: 81 additions & 9 deletions tair/tairvector.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,35 @@
from typing import Sequence, Tuple, Union,Iterable
from tair.typing import ResponseT
from typing import Dict, List, Tuple, Union
from functools import partial, reduce
from typing import Dict, Iterable, List, Sequence, Tuple, Union

from redis.client import pairs_to_dict
from redis.utils import str_if_bytes
from typing import Sequence, Union
from functools import partial, reduce

from tair.typing import ResponseT

VectorType = Sequence[Union[int, float]]


class DistanceMetric:
Euclidean = "L2" # an alias to L2
L2 = "L2"
InnerProduct = "IP"
Jaccard = "JACCARD"


class IndexType:
HNSW = "HNSW"
FLAT = "FLAT"


class Constants:
VECTOR_KEY = "VECTOR"


class DataType:
Float32 = "FLOAT32"
Binary = "BINARY"


class TextVectorEncoder:
SEP = bytes(",", "ascii")
BITS = ("0", "1")
Expand All @@ -50,6 +55,7 @@ def decode(buf: bytes) -> Tuple[float]:
return tuple(int(x) for x in components)
return tuple(float(x) for x in components)


class TairVectorScanResult:
"""
wrapper for the results of scan commands
Expand Down Expand Up @@ -152,7 +158,7 @@ def tvs_mknnsearch(
self, k: int, vectors: Sequence[VectorType], filter_str: str = None, **kwargs
):
"""batch approximate nearest neighbors search for a list of vectors"""
return self.client.tvs_knnsearch(
return self.client.tvs_mknnsearch(
self.name, k, vectors, self.is_binary, filter_str, **kwargs
)

Expand All @@ -164,7 +170,6 @@ def __repr__(self):


class TairVectorCommands:

encode_vector = TextVectorEncoder.encode
decode_vector = TextVectorEncoder.decode

Expand Down Expand Up @@ -304,6 +309,8 @@ def tvs_scan(self, index: str, pattern: str = None, batch: int = 10):

SEARCH_CMD = "TVS.KNNSEARCH"
MSEARCH_CMD = "TVS.MKNNSEARCH"
MINDEXKNNSEARCH_CMD = "TVS.MINDEXKNNSEARCH"
MINDEXMKNNSEARCH_CMD = "TVS.MINDEXMKNNSEARCH"

def tvs_knnsearch(
self,
Expand Down Expand Up @@ -361,11 +368,73 @@ def tvs_mknnsearch(
*params
)

def tvs_mindexknnsearch(
self,
index: Sequence[str],
k: int,
vector: Union[VectorType, str],
is_binary: bool = False,
filter_str: str = None,
**kwargs
):
"""
search for the top @k approximate nearest neighbors of @vector in indexs
"""
params = reduce(lambda x, y: x + y, kwargs.items(), ())
if not isinstance(vector, str):
vector = TairVectorCommands.encode_vector(vector, is_binary)
if filter_str is None:
return self.execute_command(
self.MINDEXKNNSEARCH_CMD, len(index), *index, k, vector, *params
)
return self.execute_command(
self.MINDEXKNNSEARCH_CMD, len(index), *index, k, vector, filter_str, *params
)

def tvs_mindexmknnsearch(
self,
index: Sequence[str],
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: str = None,
**kwargs
):
"""
batch approximate nearest neighbors search for a list of vectors
"""
params = reduce(lambda x, y: x + y, kwargs.items(), ())
encoded_vectors = [
TairVectorCommands.encode_vector(x, is_binary) for x in vectors
]
if filter_str is None:
return self.execute_command(
self.MINDEXMKNNSEARCH_CMD,
len(index),
*index,
k,
len(encoded_vectors),
*encoded_vectors,
*params
)
return self.execute_command(
self.MINDEXMKNNSEARCH_CMD,
len(index),
*index,
k,
len(encoded_vectors),
*encoded_vectors,
filter_str,
*params
)


def parse_tvs_get_index_result(resp) -> Union[Dict, None]:
if len(resp) == 0:
return None
return pairs_to_dict(resp, decode_keys=True, decode_string_values=True)


def parse_tvs_get_result(resp) -> Dict:
result = pairs_to_dict(resp, decode_keys=True, decode_string_values=False)

Expand All @@ -376,13 +445,16 @@ def parse_tvs_get_result(resp) -> Dict:
values = map(str_if_bytes, result.values())
return dict(zip(result.keys(), values))


def parse_tvs_hmget_result(resp) -> tuple:
if len(resp) == 0:
return None
return ([resp[i].decode("ascii") if resp[i] else None for i in range(0, len(resp))])
return [resp[i].decode("ascii") if resp[i] else None for i in range(0, len(resp))]


def parse_tvs_search_result(resp) -> List[Tuple]:
return [(resp[i], float(resp[i + 1])) for i in range(0, len(resp), 2)]


def parse_tvs_msearch_result(resp) -> List[List[Tuple]]:
return [parse_tvs_search_result(r) for r in resp]
return [parse_tvs_search_result(r) for r in resp]
56 changes: 47 additions & 9 deletions tests/test_tairvector.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,39 @@
# /user/bin/env python3
import sys
import os
import string
import sys
import unittest
import uuid
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from random import choice, randint, random

from random import random, randint, choice
import unittest
import redis
import string

from tair.tairvector import DataType, DistanceMetric, Constants, TairVectorIndex
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from tair.tairvector import (
Constants,
DataType,
DistanceMetric,
TairVectorCommands,
TairVectorIndex,
)

from .conftest import get_tair_client

client = get_tair_client()

dim = 16
num_vectors = 100
test_vectors = [[random() for _ in range(dim)] for _ in range(num_vectors)]
test2_vectors = [[random() for _ in range(dim)] for _ in range(num_vectors)]
num_attrs = 3
attr_keys = ["key-%d" % i for i in range(num_attrs)]
attr_values = [
"".join(choice(string.ascii_uppercase + string.digits) for _ in range(4))
for _ in range(num_vectors * num_attrs)
]
test_attributes = [
dict(zip(attr_keys, attr_values[i: i + 3]))
dict(zip(attr_keys, attr_values[i : i + 3]))
for i in range(0, num_vectors * num_attrs, num_attrs)
]

Expand Down Expand Up @@ -177,7 +186,9 @@ def test_hmget(self):
value2 = "value_" + str(uuid.uuid4())
ret = client.tvs_hset("test", key, vector=vector, field1=value1, field2=value2)
self.assertTrue(ret)
obj = client.tvs_hmget("test", key, Constants.VECTOR_KEY, "field1", "field2", "field3")
obj = client.tvs_hmget(
"test", key, Constants.VECTOR_KEY, "field1", "field2", "field3"
)
self.assertEqual(len(obj[0].split(",")), len(vector))
self.assertEqual(obj[1], str(value1))
self.assertEqual(obj[2], str(value2))
Expand Down Expand Up @@ -230,17 +241,22 @@ def test_0_init(self):
# delete test index
try:
client.tvs_del_index("test")
client.tvs_del_index("test2")
except:
pass

ret = client.tvs_create_index("test", dim, **self.index_params)
ret = client.tvs_create_index("test2", dim, **self.index_params)
if not ret:
raise RuntimeError("create test index failed")
raise RuntimeError("create test/test2 index failed")

def test_1_insert_vectors(self):
for i, v in enumerate(test_vectors):
ret = client.tvs_hset("test", str(i).zfill(6), vector=v)
self.assertTrue(ret)
for i, v in enumerate(test2_vectors):
ret = client.tvs_hset("test2", str(i).zfill(6), vector=v)
self.assertTrue(ret)

def test_2_knn_search(self):
for q in queries:
Expand Down Expand Up @@ -280,8 +296,30 @@ def test_5_search_with_filters(self):
def test_6_msearch_with_filters(self):
pass

def test_7_mindexknnsearch(self):
indexs = ["test", "test2"]
for q in queries:
result = client.tvs_mindexknnsearch(indexs, self.top_k, vector=q)
self.assertEqual(self.top_k, len(result))
d = 0.0
for k, v in result:
self.assertGreaterEqual(v, d)
d = v

def test_8_mindexmknnsearch(self):
indexs = ["test", "test2"]
batch = queries[:2]
result = client.tvs_mindexmknnsearch(indexs, self.top_k, batch)
self.assertEqual(len(result), len(batch))
for r in result:
d = 0.0
for _, v in r:
self.assertGreaterEqual(v, d)
d = v

def test_9_delete_index(self):
client.tvs_del_index("test")
client.tvs_del_index("test2")


class IndexApiTest(unittest.TestCase):
Expand Down

0 comments on commit eca10e7

Please sign in to comment.