Skip to content

Commit

Permalink
format code with black and isort
Browse files Browse the repository at this point in the history
  • Loading branch information
yangbodong22011 committed Aug 1, 2023
1 parent 7f15861 commit eb5915f
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 123 deletions.
2 changes: 1 addition & 1 deletion tair/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
TairVectorCommands,
parse_tvs_get_index_result,
parse_tvs_get_result,
parse_tvs_hincrbyfloat_result,
parse_tvs_hmget_result,
parse_tvs_msearch_result,
parse_tvs_search_result,
parse_tvs_hincrbyfloat_result,
)
from tair.tairzset import TairZsetCommands, parse_tair_zset_items

Expand Down
19 changes: 14 additions & 5 deletions tair/tairsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,18 @@ def tft_search(self, index: KeyT, query: str, use_cache: bool = False) -> Respon
pieces.append("use_cache")
return self.execute_command("TFT.SEARCH", *pieces)

def tft_msearch(self, index_count: int, index: Iterable[KeyT], query: str) -> ResponseT:
def tft_msearch(
self, index_count: int, index: Iterable[KeyT], query: str
) -> ResponseT:
return self.execute_command("TFT.MSEARCH", index_count, *index, query)

def tft_analyzer(self, analyzer_name: str, text: str, index: Optional[KeyT] = None,
show_time: Optional[bool] = False) -> ResponseT:
def tft_analyzer(
self,
analyzer_name: str,
text: str,
index: Optional[KeyT] = None,
show_time: Optional[bool] = False,
) -> ResponseT:
pieces: List[EncodableT] = [analyzer_name, text]
if index is not None:
pieces.append("INDEX")
Expand All @@ -167,9 +174,11 @@ def tft_analyzer(self, analyzer_name: str, text: str, index: Optional[KeyT] = No
target_nodes = None
if isinstance(self, tair.TairCluster):
if index is None:
target_nodes = 'random'
target_nodes = "random"
else:
target_nodes = self.nodes_manager.get_node_from_slot(self.keyslot(index))
target_nodes = self.nodes_manager.get_node_from_slot(
self.keyslot(index)
)
return self.execute_command("TFT.ANALYZER", *pieces, target_nodes=target_nodes)

def tft_explaincost(self, index: KeyT, query: str) -> ResponseT:
Expand Down
183 changes: 92 additions & 91 deletions tair/tairvector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from concurrent.futures import ThreadPoolExecutor
from functools import partial, reduce
from typing import Dict, List, Sequence, Tuple, Union, Optional, Iterable
from tair.typing import AbsExpiryT, CommandsProtocol, ExpiryT, ResponseT
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

from redis.client import pairs_to_dict
from redis.utils import str_if_bytes

from tair.typing import AbsExpiryT, CommandsProtocol, ExpiryT, ResponseT

VectorType = Sequence[Union[int, float]]


Expand Down Expand Up @@ -123,11 +124,11 @@ def __init__(self, client, name, **index_params):

# bind methods
for method in (
"tvs_del",
"tvs_hdel",
"tvs_hgetall",
"tvs_hmget",
"tvs_scan",
"tvs_del",
"tvs_hdel",
"tvs_hgetall",
"tvs_hmget",
"tvs_scan",
):
attr = getattr(TairVectorCommands, method)
if callable(attr):
Expand All @@ -150,23 +151,23 @@ def tvs_hset(self, key: str, vector: Union[VectorType, str, None] = None, **kwar
return self.client.tvs_hset(self.name, key, vector, self.is_binary, **kwargs)

def tvs_knnsearch(
self,
k: int,
vector: Union[VectorType, str],
filter_str: Optional[str] = None,
**kwargs
self,
k: int,
vector: Union[VectorType, str],
filter_str: Optional[str] = None,
**kwargs
):
"""search for the top @k approximate nearest neighbors of @vector"""
return self.client.tvs_knnsearch(
self.name, k, vector, self.is_binary, filter_str, **kwargs
)

def tvs_mknnsearch(
self,
k: int,
vectors: Sequence[VectorType],
filter_str: Optional[str] = None,
**kwargs
self,
k: int,
vectors: Sequence[VectorType],
filter_str: Optional[str] = None,
**kwargs
):
"""batch approximate nearest neighbors search for a list of vectors"""
return self.client.tvs_mknnsearch(
Expand All @@ -190,13 +191,13 @@ class TairVectorCommands(CommandsProtocol):
SCAN_INDEX_CMD = "TVS.SCANINDEX"

def tvs_create_index(
self,
name: str,
dim: int,
distance_type: str = DistanceMetric.L2,
index_type: str = IndexType.HNSW,
data_type: str = DataType.Float32,
**kwargs
self,
name: str,
dim: int,
distance_type: str = DistanceMetric.L2,
index_type: str = IndexType.HNSW,
data_type: str = DataType.Float32,
**kwargs
):
"""
create a vector
Expand Down Expand Up @@ -231,7 +232,7 @@ def tvs_del_index(self, name: str):
return self.execute_command(self.DEL_INDEX_CMD, name)

def tvs_scan_index(
self, pattern: Optional[str] = None, batch: int = 10
self, pattern: Optional[str] = None, batch: int = 10
) -> TairVectorScanResult:
"""
scan all the indices
Expand All @@ -257,12 +258,12 @@ def tvs_index(self, name: str, **index_params) -> TairVectorIndex:
SCAN_CMD = "TVS.SCAN"

def tvs_hset(
self,
index: str,
key: str,
vector: Union[VectorType, str, None] = None,
is_binary=False,
**kwargs
self,
index: str,
key: str,
vector: Union[VectorType, str, None] = None,
is_binary=False,
**kwargs
):
"""
add/update a data entry to index
Expand Down Expand Up @@ -309,13 +310,13 @@ def tvs_hmget(self, index: str, key: str, *args):
return self.execute_command(self.HMGET_CMD, index, key, *args)

def tvs_scan(
self,
index: str,
pattern: Optional[str] = None,
batch: int = 10,
filter_str: Optional[str] = None,
vector: Optional[VectorType] = None,
max_dist: Optional[float] = None,
self,
index: str,
pattern: Optional[str] = None,
batch: int = 10,
filter_str: Optional[str] = None,
vector: Optional[VectorType] = None,
max_dist: Optional[float] = None,
):
"""
scan all data entries in an index
Expand All @@ -340,14 +341,14 @@ def get_batch(c):
return TairVectorScanResult(self, get_batch)

def _tvs_scan(
self,
index: str,
cursor: int = 0,
count: Optional[int] = None,
pattern: Optional[str] = None,
filter_str: Optional[str] = None,
vector: Union[VectorType, bytes, None] = None,
max_dist: Optional[float] = None,
self,
index: str,
cursor: int = 0,
count: Optional[int] = None,
pattern: Optional[str] = None,
filter_str: Optional[str] = None,
vector: Union[VectorType, bytes, None] = None,
max_dist: Optional[float] = None,
):
args = [] if pattern is None else ["MATCH", pattern]
if count is not None:
Expand All @@ -374,13 +375,13 @@ def _tvs_scan(
MINDEXMKNNSEARCH_CMD = "TVS.MINDEXMKNNSEARCH"

def tvs_knnsearch(
self,
index: str,
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: str,
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
search for the top @k approximate nearest neighbors of @vector in an index
Expand All @@ -395,13 +396,13 @@ def tvs_knnsearch(
)

def tvs_mknnsearch(
self,
index: str,
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: str,
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
batch approximate nearest neighbors search for a list of vectors
Expand Down Expand Up @@ -430,13 +431,13 @@ def tvs_mknnsearch(
)

def tvs_mindexknnsearch(
self,
index: Sequence[str],
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: Sequence[str],
k: int,
vector: Union[VectorType, str, bytes],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
search for the top @k approximate nearest neighbors of @vector in indexs
Expand All @@ -453,13 +454,13 @@ def tvs_mindexknnsearch(
)

def tvs_mindexmknnsearch(
self,
index: Sequence[str],
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
self,
index: Sequence[str],
k: int,
vectors: Sequence[VectorType],
is_binary: bool = False,
filter_str: Optional[str] = None,
**kwargs
):
"""
batch approximate nearest neighbors search for a list of vectors
Expand Down Expand Up @@ -492,13 +493,13 @@ def tvs_mindexmknnsearch(
GETDISTANCE_CMD = "TVS.GETDISTANCE"

def _tvs_getdistance(
self,
index_name: str,
vector: VectorType,
keys: Iterable[str],
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
self,
index_name: str,
vector: VectorType,
keys: Iterable[str],
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
):
"""
low level interface for TVS.GETDISTANCE
Expand All @@ -520,15 +521,15 @@ def _tvs_getdistance(
)

def tvs_getdistance(
self,
index_name: str,
vector: Union[VectorType, str, bytes],
keys: Iterable[str],
batch_size: int = 100000,
parallelism: int = 1,
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
self,
index_name: str,
vector: Union[VectorType, str, bytes],
keys: Iterable[str],
batch_size: int = 100000,
parallelism: int = 1,
top_n: Optional[int] = None,
max_dist: Optional[float] = None,
filter_str: Optional[str] = None,
):
"""
wrapped interface for TVS.GETDISTANCE
Expand Down Expand Up @@ -562,7 +563,7 @@ def process_batch(batch):

with ThreadPoolExecutor(max_workers=parallelism) as executor:
batches = [
keys[i: i + batch_size] for i in range(0, len(keys), batch_size)
keys[i : i + batch_size] for i in range(0, len(keys), batch_size)
]

futures = [executor.submit(process_batch, batch) for batch in batches]
Expand Down
Loading

0 comments on commit eb5915f

Please sign in to comment.