diff --git a/fakeredis/_basefakesocket.py b/fakeredis/_basefakesocket.py index 10894800..ea070778 100644 --- a/fakeredis/_basefakesocket.py +++ b/fakeredis/_basefakesocket.py @@ -2,14 +2,15 @@ import queue import time import weakref -from typing import List, Any, Tuple, Optional, Callable, Union, Match +from typing import List, Any, Tuple, Optional, Callable, Union, Match, AnyStr, Generator +from xmlrpc.client import ResponseError import redis from redis.connection import DefaultParser from . import _msgs as msgs from ._command_args_parsing import extract_args -from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Item, Signature, CommandItem, Hash +from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Signature, CommandItem, Hash from ._helpers import ( SimpleError, valid_response_type, @@ -46,7 +47,7 @@ def _extract_command(fields: List[bytes]) -> Tuple[Any, List[Any]]: return cmd, cmd_arguments -def bin_reverse(x, bits_count): +def bin_reverse(x: int, bits_count: int) -> int: result = 0 for i in range(bits_count): if (x >> i) & 1: @@ -55,7 +56,7 @@ def bin_reverse(x, bits_count): class BaseFakeSocket: - _clear_watches: Callable + _clear_watches: Callable[[], None] ACCEPTED_COMMANDS_WHILE_PUBSUB = { "ping", "subscribe", @@ -68,14 +69,14 @@ class BaseFakeSocket: } _connection_error_class = redis.ConnectionError - def __init__(self, server: "FakeServer", db: int, *args: Any, **kwargs: Any): # noqa: F821 + def __init__(self, server: "FakeServer", db: int, *args: Any, **kwargs: Any) -> None: # type: ignore # noqa: F821 super(BaseFakeSocket, self).__init__(*args, **kwargs) from fakeredis import FakeServer self._server: FakeServer = server self._db_num = db self._db = server.dbs[self._db_num] - self.responses: Optional[queue.Queue] = queue.Queue() + self.responses: Optional[queue.Queue[bytes]] = queue.Queue() # Prevents parser from processing commands. Not used in this module, # but set by aioredis module to prevent new commands being processed # while handling a blocking command. @@ -108,7 +109,7 @@ def resume(self) -> None: self._paused = False self._parser.send(b"") - def shutdown(self, _) -> None: + def shutdown(self, _: Any) -> None: self._parser.close() @staticmethod @@ -136,12 +137,12 @@ def close(self) -> None: # at any time, and hence we can't safely take the server lock. # We rely on list.append being atomic. self._server.closed_sockets.append(weakref.ref(self)) - self._server = None + self._server = None # type: ignore self._db = None self.responses = None @staticmethod - def _extract_line(buf): + def _extract_line(buf: bytes) -> Tuple[bytes, bytes]: pos = buf.find(b"\n") + 1 assert pos > 0 line = buf[:pos] @@ -149,7 +150,7 @@ def _extract_line(buf): assert line.endswith(b"\r\n") return line, buf - def _parse_commands(self): + def _parse_commands(self) -> Generator[None, Any, None]: """Generator that parses commands. It is fed pieces of redis protocol data (via `send`) and calls @@ -175,8 +176,10 @@ def _parse_commands(self): buf = buf[length + 2 :] # +2 to skip the CRLF self._process_command(fields) - def _run_command(self, func: Callable[..., Any], sig: Signature, args: List[Any], from_script: bool) -> Any: - command_items = {} + def _run_command( + self, func: Optional[Callable[[Any], Any]], sig: Signature, args: List[Any], from_script: bool + ) -> Any: + command_items: List[CommandItem] = [] try: ret = sig.apply(args, self._db, self.version) if from_script and msgs.FLAG_NO_SCRIPT in sig.flags: @@ -187,7 +190,7 @@ def _run_command(self, func: Callable[..., Any], sig: Signature, args: List[Any] result = ret[0] else: args, command_items = ret - result = func(*args) + result = func(*args) # type: ignore assert valid_response_type(result) except SimpleError as exc: result = exc @@ -195,7 +198,7 @@ def _run_command(self, func: Callable[..., Any], sig: Signature, args: List[Any] command_item.writeback(remove_empty_val=msgs.FLAG_LEAVE_EMPTY_VAL not in sig.flags) return result - def _decode_error(self, error): + def _decode_error(self, error: SimpleError) -> ResponseError: return DefaultParser(socket_read_size=65536).parse_error(error.value) # type: ignore def _decode_result(self, result: Any) -> Any: @@ -209,7 +212,7 @@ def _decode_result(self, result: Any) -> Any: else: return result - def _blocking(self, timeout: Optional[Union[float, int]], func: Callable[[bool], Any]): + def _blocking(self, timeout: Optional[Union[float, int]], func: Callable[[bool], Any]) -> Any: """Run a function until it succeeds or timeout is reached. The timeout is in seconds, and 0 means infinite. The function @@ -234,7 +237,7 @@ def _blocking(self, timeout: Optional[Union[float, int]], func: Callable[[bool], if ret is not None: return ret - def _name_to_func(self, cmd_name: str) -> Tuple[Optional[Callable], Signature]: + def _name_to_func(self, cmd_name: str) -> Tuple[Optional[Callable[[Any], Any]], Signature]: """Get the signature and the method from the command name.""" if cmd_name not in SUPPORTED_COMMANDS: # redis remaps \r or \n in an error to ' ' to make it legal protocol @@ -244,14 +247,14 @@ def _name_to_func(self, cmd_name: str) -> Tuple[Optional[Callable], Signature]: func = getattr(self, sig.func_name, None) return func, sig - def sendall(self, data): + def sendall(self, data: AnyStr) -> None: if not self._server.connected: raise self._connection_error_class(msgs.CONNECTION_ERROR_MSG) if isinstance(data, str): - data = data.encode("ascii") + data = data.encode("ascii") # type: ignore self._parser.send(data) - def _process_command(self, fields: List[bytes]): + def _process_command(self, fields: List[bytes]) -> None: if not fields: return result: Any @@ -353,7 +356,7 @@ def match_type(key) -> bool: result_cursor = 0 return [str(bin_reverse(result_cursor, bits_len)).encode(), result_data] - def _ttl(self, key: CommandItem, scale: int) -> int: + def _ttl(self, key: CommandItem, scale: float) -> int: if not key: return -2 elif key.expireat is None: @@ -372,7 +375,7 @@ def _encodeint(self, value: int) -> bytes: return Int.encode(value) @staticmethod - def _key_value_type(key: Item) -> SimpleString: + def _key_value_type(key: CommandItem) -> SimpleString: if key.value is None: return SimpleString(b"none") elif isinstance(key.value, bytes): diff --git a/fakeredis/_fakesocket.py b/fakeredis/_fakesocket.py index 8aca07a0..0525a499 100644 --- a/fakeredis/_fakesocket.py +++ b/fakeredis/_fakesocket.py @@ -1,4 +1,4 @@ -from typing import Optional, Set +from typing import Optional, Set, Any from fakeredis.stack import ( JSONCommandsMixin, @@ -9,8 +9,8 @@ TDigestCommandsMixin, TimeSeriesCommandsMixin, ) -from ._server import FakeServer from ._basefakesocket import BaseFakeSocket +from ._server import FakeServer from .commands_mixins.bitmap_mixin import BitmapCommandsMixin from .commands_mixins.connection_mixin import ConnectionCommandsMixin from .commands_mixins.generic_mixin import GenericCommandsMixin @@ -24,9 +24,9 @@ except ImportError: class ScriptingCommandsMixin: # type: ignore # noqa: E303 - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: kwargs.pop("lua_modules", None) - super(ScriptingCommandsMixin, self).__init__(*args, **kwargs) + super(ScriptingCommandsMixin, self).__init__(*args, **kwargs) # type: ignore from .commands_mixins.server_mixin import ServerCommandsMixin @@ -65,6 +65,6 @@ def __init__( self, server: "FakeServer", db: int, - lua_modules: Optional[Set[str]] = None, # type: ignore # noqa: F821 + lua_modules: Optional[Set[str]] = None, # noqa: F821 ) -> None: super(FakeSocket, self).__init__(server, db, lua_modules=lua_modules)