Skip to content

Commit

Permalink
Improve: Compatible with decode responses
Browse files Browse the repository at this point in the history
solved #59
  • Loading branch information
yangbodong22011 committed Sep 27, 2023
1 parent eb5915f commit 021b407
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 104 deletions.
4 changes: 4 additions & 0 deletions format_code.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash

python -m isort --profile black tair/*.py tests/*.py
python -m isort -v --profile black tair/*.py tests/*.py
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
author_email="vincillau@outlook.com",
python_requires=">=3.7",
packages=["tair", "tair.asyncio"],
install_requires=["redis >= 4.4.4"],
install_requires=["redis == 4.4.4"],
)
42 changes: 20 additions & 22 deletions tair/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,27 @@ def bool_ok(resp) -> bool:
"TR.SETBITARRAY": bool_ok,
"TR.OPTIMIZE": bool_ok,
"TR.SCAN": parse_tr_scan,
"TR.RANGEBITARRAY": lambda resp: resp.decode(),
"TR.JACCARD": lambda resp: float(resp.decode()),
"TR.RANGEBITARRAY": lambda resp: str_if_bytes(resp),
"TR.JACCARD": lambda resp: float(resp),
# TairSearch
"TFT.CREATEINDEX": bool_ok,
"TFT.UPDATEINDEX": bool_ok,
"TFT.GETINDEX": lambda resp: None if resp is None else resp.decode(),
"TFT.ADDDOC": lambda resp: resp.decode(),
"TFT.GETINDEX": lambda resp: None if resp is None else str_if_bytes(resp),
"TFT.ADDDOC": lambda resp: str_if_bytes(resp),
"TFT.MADDDOC": bool_ok,
"TFT.DELDOC": lambda resp: int(resp.decode()),
"TFT.DELDOC": lambda resp: int(resp),
"TFT.UPDATEDOCFIELD": bool_ok,
"TFT.INCRFLOATDOCFIELD": lambda resp: float(resp.decode()),
"TFT.GETDOC": lambda resp: None if resp is None else resp.decode(),
"TFT.INCRFLOATDOCFIELD": lambda resp: float(resp),
"TFT.GETDOC": lambda resp: None if resp is None else str_if_bytes(resp),
"TFT.SCANDOCID": lambda resp: ScandocidResult(
resp[0].decode(), [i.decode() for i in resp[1]]
str_if_bytes(resp[0]), [str_if_bytes(i) for i in resp[1]]
),
"TFT.DELALL": bool_ok,
"TFT.SEARCH": lambda resp: resp.decode(),
"TFT.GETSUG": lambda resp: [i.decode() for i in resp],
"TFT.GETALLSUGS": lambda resp: [i.decode() for i in resp],
"TFT.SEARCH": lambda resp: str_if_bytes(resp),
"TFT.GETSUG": lambda resp: [str_if_bytes(i) for i in resp],
"TFT.GETALLSUGS": lambda resp: [str_if_bytes(i) for i in resp],
# TairDoc
"JSON.SET": lambda resp: None if resp is None else resp == b"OK",
"JSON.SET": lambda resp: None if resp is None else bool_ok(resp),
"JSON.TYPE": str_if_bytes,
# TairTs
"EXTS.P.CREATE": bool_ok,
Expand All @@ -130,18 +130,16 @@ def bool_ok(resp) -> bool:
"EXTS.S.RAW_MINCRBY": lambda resp: [bool_ok(i) for i in resp],
# TairCpc
"CPC.UPDATE": bool_ok,
"CPC.ESTIMATE": lambda resp: float(resp.decode()),
"CPC.UPDATE2EST": lambda resp: float(resp.decode()),
"CPC.UPDATE2JUD": lambda resp: CpcUpdate2judResult(
float(resp[0].decode()), float(resp[1].decode())
),
"CPC.ESTIMATE": lambda resp: float(resp),
"CPC.UPDATE2EST": lambda resp: float(resp),
"CPC.UPDATE2JUD": lambda resp: CpcUpdate2judResult(float(resp[0]), float(resp[1])),
"CPC.ARRAY.UPDATE": bool_ok,
"CPC.ARRAY.ESTIMATE": lambda resp: float(resp.decode()),
"CPC.ARRAY.ESTIMATE.RANGE": lambda resp: [float(i.decode()) for i in resp],
"CPC.ARRAY.ESTIMATE.RANGE.MERGE": lambda resp: float(resp.decode()),
"CPC.ARRAY.UPDATE2EST": lambda resp: float(resp.decode()),
"CPC.ARRAY.ESTIMATE": lambda resp: float(resp),
"CPC.ARRAY.ESTIMATE.RANGE": lambda resp: [float(i) for i in resp],
"CPC.ARRAY.ESTIMATE.RANGE.MERGE": lambda resp: float(resp),
"CPC.ARRAY.UPDATE2EST": lambda resp: float(resp),
"CPC.ARRAY.UPDATE2JUD": lambda resp: CpcUpdate2judResult(
float(resp[0].decode()), float(resp[1].decode())
float(resp[0]), float(resp[1])
),
# TairVector
"TVS.CREATEINDEX": bool_ok,
Expand Down
25 changes: 18 additions & 7 deletions tair/tairhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class ValueVersionItem:
def __init__(self, value: bytes, version: int) -> None:
def __init__(self, value: Union[bytes, str], version: int) -> None:
self.value = value
self.version = version

Expand All @@ -28,11 +28,11 @@ def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

def __repr__(self) -> str:
return f"{{value: {self.value.decode()}, version: {self.version}}}"
return f"{{value: {self.value}, version: {self.version}}}"


class FieldValueItem:
def __init__(self, field: bytes, value: bytes) -> None:
def __init__(self, field: Union[bytes, str], value: Union[bytes, str]) -> None:
self.field = field
self.value = value

Expand All @@ -41,15 +41,26 @@ def __eq__(self, other: object) -> bool:
return False
return self.field == other.field and self.value == other.value

def __lt__(self, other: object) -> bool:
if not isinstance(other, FieldValueItem):
raise TypeError(
"Cannot compare 'FieldValueItem' with non-FieldValueItem objects."
)
return self.field < other.field or (
self.field == other.field and self.value < other.value
)

def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

def __repr__(self) -> str:
return f"{{field: {self.field.decode()}, value: {self.value.decode()}}}"
return f"{{field: {self.field}, value: {self.value}}}"


class ExhscanResult:
def __init__(self, next_field: bytes, items: Iterable[FieldValueItem]) -> None:
def __init__(
self, next_field: Union[bytes, str], items: Iterable[FieldValueItem]
) -> None:
self.next_field = next_field
self.items = list(items)

Expand All @@ -62,7 +73,7 @@ def __ne__(self, other: object) -> bool:
return not self.__eq__(other)

def __repr__(self) -> str:
return f"{{next_field: {self.next_field.decode()}, items: {self.items}}}"
return f"{{next_field: {self.next_field}, items: {self.items}}}"


class TairHashCommands(CommandsProtocol):
Expand Down Expand Up @@ -417,7 +428,7 @@ def exhdel(self, key: KeyT, fields: Iterable[FieldT]) -> ResponseT:
def parse_exhincrbyfloat(resp) -> Union[float, None]:
if resp is None:
return resp
return float(resp.decode())
return float(resp)


def parse_exhgetwithver(resp) -> Union[ValueVersionItem, None]:
Expand Down
13 changes: 8 additions & 5 deletions tair/tairstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import time
from typing import List, Optional, Union

from redis.client import bool_ok
from redis.utils import str_if_bytes

from tair.exceptions import DataError
from tair.typing import (
AbsExpiryT,
Expand All @@ -14,7 +17,7 @@


class ExgetResult:
def __init__(self, value: bytes, version: int) -> None:
def __init__(self, value: Union[bytes, str], version: int) -> None:
self.value = value
self.version = version

Expand All @@ -31,7 +34,7 @@ def __repr__(self) -> str:


class ExcasResult:
def __init__(self, msg: str, value: bytes, version: int) -> None:
def __init__(self, msg: str, value: Union[bytes, str], version: int) -> None:
self.msg = msg
self.value = value
self.version = version
Expand Down Expand Up @@ -272,7 +275,7 @@ def cad(self, key: KeyT, value: EncodableT) -> ResponseT:
def parse_exset(resp) -> Union[bool, None]:
if resp is None:
return None
return resp == b"OK"
return bool_ok(resp)


def parse_exget(resp) -> ExgetResult:
Expand All @@ -282,10 +285,10 @@ def parse_exget(resp) -> ExgetResult:
def parse_excas(resp) -> ExcasResult:
if isinstance(resp, int):
return resp
return ExcasResult(resp[0].decode(), resp[1], resp[2])
return ExcasResult(str_if_bytes(resp[0]), resp[1], resp[2])


def parse_exincrbyfloat(resp) -> Union[float, None]:
if resp is None:
return resp
return float(resp.decode())
return float(resp)
7 changes: 4 additions & 3 deletions tair/tairzset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Iterable, List, Mapping, Optional
from typing import Iterable, List, Mapping, Optional, Union

from redis.typing import CommandsProtocol
from redis.utils import str_if_bytes

from tair.exceptions import DataError
from tair.typing import AnyKeyT, EncodableT, KeyT, ResponseT


class TairZsetItem:
def __init__(self, member: bytes, score: str) -> None:
def __init__(self, member: Union[bytes, str], score: str) -> None:
self.member = member
self.score = score

Expand Down Expand Up @@ -237,7 +238,7 @@ def parse_tair_zset_items(resp, **options):
result: List[TairZsetItem] = []
if options.get("withscores"):
for i in range(0, len(resp), 2):
result.append(TairZsetItem(resp[i], resp[i + 1].decode()))
result.append(TairZsetItem(resp[i], str_if_bytes(resp[i + 1])))
else:
for i in resp:
result.append(TairZsetItem(i, None))
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,12 @@ def get_server_time(client) -> datetime:
seconds, milliseconds = client.time()
timestamp = float(f"{seconds}.{milliseconds}")
return datetime.fromtimestamp(timestamp)


def compare_str(left, right):
if isinstance(left, bytes):
left = left.decode()
if isinstance(right, bytes):
right = right.decode()

return left == right
Loading

0 comments on commit 021b407

Please sign in to comment.