diff --git a/tair/commands.py b/tair/commands.py index cd6523b..7d7a6c2 100644 --- a/tair/commands.py +++ b/tair/commands.py @@ -30,10 +30,10 @@ TairVectorCommands, parse_tvs_get_index_result, parse_tvs_get_result, + parse_tvs_hincrbyfloat_result, parse_tvs_hmget_result, parse_tvs_msearch_result, parse_tvs_search_result, - parse_tvs_hincrbyfloat_result, ) from tair.tairzset import TairZsetCommands, parse_tair_zset_items diff --git a/tair/tairsearch.py b/tair/tairsearch.py index debacb8..90fc854 100644 --- a/tair/tairsearch.py +++ b/tair/tairsearch.py @@ -153,11 +153,18 @@ def tft_search(self, index: KeyT, query: str, use_cache: bool = False) -> Respon pieces.append("use_cache") return self.execute_command("TFT.SEARCH", *pieces) - def tft_msearch(self, index_count: int, index: Iterable[KeyT], query: str) -> ResponseT: + def tft_msearch( + self, index_count: int, index: Iterable[KeyT], query: str + ) -> ResponseT: return self.execute_command("TFT.MSEARCH", index_count, *index, query) - def tft_analyzer(self, analyzer_name: str, text: str, index: Optional[KeyT] = None, - show_time: Optional[bool] = False) -> ResponseT: + def tft_analyzer( + self, + analyzer_name: str, + text: str, + index: Optional[KeyT] = None, + show_time: Optional[bool] = False, + ) -> ResponseT: pieces: List[EncodableT] = [analyzer_name, text] if index is not None: pieces.append("INDEX") @@ -167,9 +174,11 @@ def tft_analyzer(self, analyzer_name: str, text: str, index: Optional[KeyT] = No target_nodes = None if isinstance(self, tair.TairCluster): if index is None: - target_nodes = 'random' + target_nodes = "random" else: - target_nodes = self.nodes_manager.get_node_from_slot(self.keyslot(index)) + target_nodes = self.nodes_manager.get_node_from_slot( + self.keyslot(index) + ) return self.execute_command("TFT.ANALYZER", *pieces, target_nodes=target_nodes) def tft_explaincost(self, index: KeyT, query: str) -> ResponseT: diff --git a/tair/tairvector.py b/tair/tairvector.py index c3e4fc0..a886a5a 100644 --- a/tair/tairvector.py +++ b/tair/tairvector.py @@ -1,11 +1,12 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial, reduce -from typing import Dict, List, Sequence, Tuple, Union, Optional, Iterable -from tair.typing import AbsExpiryT, CommandsProtocol, ExpiryT, ResponseT +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union from redis.client import pairs_to_dict from redis.utils import str_if_bytes +from tair.typing import AbsExpiryT, CommandsProtocol, ExpiryT, ResponseT + VectorType = Sequence[Union[int, float]] @@ -123,11 +124,11 @@ def __init__(self, client, name, **index_params): # bind methods for method in ( - "tvs_del", - "tvs_hdel", - "tvs_hgetall", - "tvs_hmget", - "tvs_scan", + "tvs_del", + "tvs_hdel", + "tvs_hgetall", + "tvs_hmget", + "tvs_scan", ): attr = getattr(TairVectorCommands, method) if callable(attr): @@ -150,11 +151,11 @@ def tvs_hset(self, key: str, vector: Union[VectorType, str, None] = None, **kwar return self.client.tvs_hset(self.name, key, vector, self.is_binary, **kwargs) def tvs_knnsearch( - self, - k: int, - vector: Union[VectorType, str], - filter_str: Optional[str] = None, - **kwargs + self, + k: int, + vector: Union[VectorType, str], + filter_str: Optional[str] = None, + **kwargs ): """search for the top @k approximate nearest neighbors of @vector""" return self.client.tvs_knnsearch( @@ -162,11 +163,11 @@ def tvs_knnsearch( ) def tvs_mknnsearch( - self, - k: int, - vectors: Sequence[VectorType], - filter_str: Optional[str] = None, - **kwargs + self, + k: int, + vectors: Sequence[VectorType], + filter_str: Optional[str] = None, + **kwargs ): """batch approximate nearest neighbors search for a list of vectors""" return self.client.tvs_mknnsearch( @@ -190,13 +191,13 @@ class TairVectorCommands(CommandsProtocol): SCAN_INDEX_CMD = "TVS.SCANINDEX" def tvs_create_index( - self, - name: str, - dim: int, - distance_type: str = DistanceMetric.L2, - index_type: str = IndexType.HNSW, - data_type: str = DataType.Float32, - **kwargs + self, + name: str, + dim: int, + distance_type: str = DistanceMetric.L2, + index_type: str = IndexType.HNSW, + data_type: str = DataType.Float32, + **kwargs ): """ create a vector @@ -231,7 +232,7 @@ def tvs_del_index(self, name: str): return self.execute_command(self.DEL_INDEX_CMD, name) def tvs_scan_index( - self, pattern: Optional[str] = None, batch: int = 10 + self, pattern: Optional[str] = None, batch: int = 10 ) -> TairVectorScanResult: """ scan all the indices @@ -257,12 +258,12 @@ def tvs_index(self, name: str, **index_params) -> TairVectorIndex: SCAN_CMD = "TVS.SCAN" def tvs_hset( - self, - index: str, - key: str, - vector: Union[VectorType, str, None] = None, - is_binary=False, - **kwargs + self, + index: str, + key: str, + vector: Union[VectorType, str, None] = None, + is_binary=False, + **kwargs ): """ add/update a data entry to index @@ -309,13 +310,13 @@ def tvs_hmget(self, index: str, key: str, *args): return self.execute_command(self.HMGET_CMD, index, key, *args) def tvs_scan( - self, - index: str, - pattern: Optional[str] = None, - batch: int = 10, - filter_str: Optional[str] = None, - vector: Optional[VectorType] = None, - max_dist: Optional[float] = None, + self, + index: str, + pattern: Optional[str] = None, + batch: int = 10, + filter_str: Optional[str] = None, + vector: Optional[VectorType] = None, + max_dist: Optional[float] = None, ): """ scan all data entries in an index @@ -340,14 +341,14 @@ def get_batch(c): return TairVectorScanResult(self, get_batch) def _tvs_scan( - self, - index: str, - cursor: int = 0, - count: Optional[int] = None, - pattern: Optional[str] = None, - filter_str: Optional[str] = None, - vector: Union[VectorType, bytes, None] = None, - max_dist: Optional[float] = None, + self, + index: str, + cursor: int = 0, + count: Optional[int] = None, + pattern: Optional[str] = None, + filter_str: Optional[str] = None, + vector: Union[VectorType, bytes, None] = None, + max_dist: Optional[float] = None, ): args = [] if pattern is None else ["MATCH", pattern] if count is not None: @@ -374,13 +375,13 @@ def _tvs_scan( MINDEXMKNNSEARCH_CMD = "TVS.MINDEXMKNNSEARCH" def tvs_knnsearch( - self, - index: str, - k: int, - vector: Union[VectorType, str, bytes], - is_binary: bool = False, - filter_str: Optional[str] = None, - **kwargs + self, + index: str, + k: int, + vector: Union[VectorType, str, bytes], + is_binary: bool = False, + filter_str: Optional[str] = None, + **kwargs ): """ search for the top @k approximate nearest neighbors of @vector in an index @@ -395,13 +396,13 @@ def tvs_knnsearch( ) def tvs_mknnsearch( - self, - index: str, - k: int, - vectors: Sequence[VectorType], - is_binary: bool = False, - filter_str: Optional[str] = None, - **kwargs + self, + index: str, + k: int, + vectors: Sequence[VectorType], + is_binary: bool = False, + filter_str: Optional[str] = None, + **kwargs ): """ batch approximate nearest neighbors search for a list of vectors @@ -430,13 +431,13 @@ def tvs_mknnsearch( ) def tvs_mindexknnsearch( - self, - index: Sequence[str], - k: int, - vector: Union[VectorType, str, bytes], - is_binary: bool = False, - filter_str: Optional[str] = None, - **kwargs + self, + index: Sequence[str], + k: int, + vector: Union[VectorType, str, bytes], + is_binary: bool = False, + filter_str: Optional[str] = None, + **kwargs ): """ search for the top @k approximate nearest neighbors of @vector in indexs @@ -453,13 +454,13 @@ def tvs_mindexknnsearch( ) def tvs_mindexmknnsearch( - self, - index: Sequence[str], - k: int, - vectors: Sequence[VectorType], - is_binary: bool = False, - filter_str: Optional[str] = None, - **kwargs + self, + index: Sequence[str], + k: int, + vectors: Sequence[VectorType], + is_binary: bool = False, + filter_str: Optional[str] = None, + **kwargs ): """ batch approximate nearest neighbors search for a list of vectors @@ -492,13 +493,13 @@ def tvs_mindexmknnsearch( GETDISTANCE_CMD = "TVS.GETDISTANCE" def _tvs_getdistance( - self, - index_name: str, - vector: VectorType, - keys: Iterable[str], - top_n: Optional[int] = None, - max_dist: Optional[float] = None, - filter_str: Optional[str] = None, + self, + index_name: str, + vector: VectorType, + keys: Iterable[str], + top_n: Optional[int] = None, + max_dist: Optional[float] = None, + filter_str: Optional[str] = None, ): """ low level interface for TVS.GETDISTANCE @@ -520,15 +521,15 @@ def _tvs_getdistance( ) def tvs_getdistance( - self, - index_name: str, - vector: Union[VectorType, str, bytes], - keys: Iterable[str], - batch_size: int = 100000, - parallelism: int = 1, - top_n: Optional[int] = None, - max_dist: Optional[float] = None, - filter_str: Optional[str] = None, + self, + index_name: str, + vector: Union[VectorType, str, bytes], + keys: Iterable[str], + batch_size: int = 100000, + parallelism: int = 1, + top_n: Optional[int] = None, + max_dist: Optional[float] = None, + filter_str: Optional[str] = None, ): """ wrapped interface for TVS.GETDISTANCE @@ -562,7 +563,7 @@ def process_batch(batch): with ThreadPoolExecutor(max_workers=parallelism) as executor: batches = [ - keys[i: i + batch_size] for i in range(0, len(keys), batch_size) + keys[i : i + batch_size] for i in range(0, len(keys), batch_size) ] futures = [executor.submit(process_batch, batch) for batch in batches] diff --git a/tests/test_tairsearch.py b/tests/test_tairsearch.py index 40e5a3f..4a6444d 100644 --- a/tests/test_tairsearch.py +++ b/tests/test_tairsearch.py @@ -50,7 +50,6 @@ def test_tft_updateindex(self, t: Tair): assert t.tft_updateindex(index, mappings2) t.delete(index) - def test_tft_getindex(self, t: Tair): index = "idx_" + str(uuid.uuid4()) mappings = """ @@ -455,7 +454,7 @@ def test_tft_search(self, t: Tair): result = t.tft_search(index, '{"sort":[{"price":{"order":"desc"}}]}', True) assert json.loads(want) == json.loads(result) result = t.tft_explaincost(index, '{"sort":[{"price":{"order":"desc"}}]}') - assert json.loads(result)['QUERY_COST'] + assert json.loads(result)["QUERY_COST"] t.delete(index) def test_tft_msearch(self, t: Tair): @@ -480,12 +479,8 @@ def test_tft_msearch(self, t: Tair): assert t.tft_createindex(index1, mappings) assert t.tft_createindex(index2, mappings) - assert t.tft_madddoc( - index1, {document1: "00001", document2: "00002"} - ) - assert t.tft_madddoc( - index2, {document3: "00003", document4: "00004"} - ) + assert t.tft_madddoc(index1, {document1: "00001", document2: "00002"}) + assert t.tft_madddoc(index2, {document3: "00003", document4: "00004"}) want = f"""{{ "aux_info": {{"index_crc64": 5843875291690071373}}, @@ -520,7 +515,9 @@ def test_tft_msearch(self, t: Tair): "total": {{ "relation": "eq", "value": 4 }} }} }}""" - result = t.tft_msearch(2, {index1, index2}, '{"sort":[{"_doc":{"order":"asc"}}]}') + result = t.tft_msearch( + 2, {index1, index2}, '{"sort":[{"_doc":{"order":"asc"}}]}' + ) assert json.loads(want) == json.loads(result) t.delete(index1) t.delete(index2) @@ -547,11 +544,13 @@ def test_tft_analyzer(self, t: Tair): } } }""" - text = 'This is tair-py.' + text = "This is tair-py." assert t.tft_createindex(index, mappings) - assert t.tft_analyzer("standard", text) == t.tft_analyzer("my_analyzer", text, index) - assert 'consuming time' in str(t.tft_analyzer("standard", text, None, True)) + assert t.tft_analyzer("standard", text) == t.tft_analyzer( + "my_analyzer", text, index + ) + assert "consuming time" in str(t.tft_analyzer("standard", text, None, True)) t.delete(index) diff --git a/tests/test_tairvector.py b/tests/test_tairvector.py index b62f201..412e083 100644 --- a/tests/test_tairvector.py +++ b/tests/test_tairvector.py @@ -2,22 +2,17 @@ import os import string import sys +import time import unittest import uuid -import pytest -import time from random import choice, randint, random +import pytest import redis sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from tair.tairvector import ( - Constants, - DataType, - DistanceMetric, - TairVectorIndex -) +from tair.tairvector import Constants, DataType, DistanceMetric, TairVectorIndex from .conftest import get_tair_client @@ -34,7 +29,7 @@ for _ in range(num_vectors * num_attrs) ] test_attributes = [ - dict(zip(attr_keys, attr_values[i: i + 3])) + dict(zip(attr_keys, attr_values[i : i + 3])) for i in range(0, num_vectors * num_attrs, num_attrs) ] @@ -229,7 +224,9 @@ def test_8_hincry(self): def test_9_hincrbyfloat(self): key_tmp = "key_tmp1" client.tvs_hset("test", key_tmp, field2=1.1) - assert client.tvs_hincrbyfloat("test", key_tmp, "field2", 2.2) == pytest.approx(3.3) + assert client.tvs_hincrbyfloat("test", key_tmp, "field2", 2.2) == pytest.approx( + 3.3 + ) def test_10_delete(self): client.tvs_del_index("test") @@ -729,7 +726,15 @@ def test_tvs_hexpire(self): key = "key_" + str(uuid.uuid4()) vector = [random() for _ in range(dim)] self.assertEqual( - client.tvs_hset(self.index_name, key, vector, field1=str(uuid.uuid4()), field2=randint(0, 100)), 3) + client.tvs_hset( + self.index_name, + key, + vector, + field1=str(uuid.uuid4()), + field2=randint(0, 100), + ), + 3, + ) self.assertEqual(client.tvs_hexpire(self.index_name, key, 100), 1) assert 0 < client.tvs_httl(self.index_name, key) <= 100 @@ -738,7 +743,15 @@ def test_tvs_hpexpire(self): key = "key_" + str(uuid.uuid4()) vector = [random() for _ in range(dim)] self.assertEqual( - client.tvs_hset(self.index_name, key, vector, field1=str(uuid.uuid4()), field2=randint(0, 100)), 3) + client.tvs_hset( + self.index_name, + key, + vector, + field1=str(uuid.uuid4()), + field2=randint(0, 100), + ), + 3, + ) self.assertEqual(client.tvs_hpexpire(self.index_name, key, 100), 1) assert 0 < client.tvs_hpttl(self.index_name, key) <= 100 @@ -748,7 +761,15 @@ def test_tvs_hexpireat(self): vector = [random() for _ in range(dim)] abs_expire = int(time.time()) + 100 self.assertEqual( - client.tvs_hset(self.index_name, key, vector, field1=str(uuid.uuid4()), field2=randint(0, 100)), 3) + client.tvs_hset( + self.index_name, + key, + vector, + field1=str(uuid.uuid4()), + field2=randint(0, 100), + ), + 3, + ) self.assertEqual(client.tvs_hexpireat(self.index_name, key, abs_expire), 1) assert 0 < client.tvs_httl(self.index_name, key) <= 100 self.assertEqual(client.tvs_hexpiretime(self.index_name, key), abs_expire) @@ -759,7 +780,15 @@ def test_tvs_hpexpireat(self): vector = [random() for _ in range(dim)] abs_expire = int(time.time() * 1000) + 100 self.assertEqual( - client.tvs_hset(self.index_name, key, vector, field1=str(uuid.uuid4()), field2=randint(0, 100)), 3) + client.tvs_hset( + self.index_name, + key, + vector, + field1=str(uuid.uuid4()), + field2=randint(0, 100), + ), + 3, + ) self.assertEqual(client.tvs_hpexpireat(self.index_name, key, abs_expire), 1) assert 0 < client.tvs_hpttl(self.index_name, key) <= 100 self.assertEqual(client.tvs_hpexpiretime(self.index_name, key), abs_expire)