Skip to content

Commit

Permalink
mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla committed Sep 24, 2024
1 parent cec1307 commit 1705351
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 26 deletions.
45 changes: 24 additions & 21 deletions fakeredis/_basefakesocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import queue
import time
import weakref
from typing import List, Any, Tuple, Optional, Callable, Union, Match
from typing import List, Any, Tuple, Optional, Callable, Union, Match, AnyStr, Generator
from xmlrpc.client import ResponseError

import redis
from redis.connection import DefaultParser

from . import _msgs as msgs
from ._command_args_parsing import extract_args
from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Item, Signature, CommandItem, Hash
from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Signature, CommandItem, Hash
from ._helpers import (
SimpleError,
valid_response_type,
Expand Down Expand Up @@ -46,7 +47,7 @@ def _extract_command(fields: List[bytes]) -> Tuple[Any, List[Any]]:
return cmd, cmd_arguments


def bin_reverse(x, bits_count):
def bin_reverse(x: int, bits_count: int) -> int:
result = 0
for i in range(bits_count):
if (x >> i) & 1:
Expand All @@ -55,7 +56,7 @@ def bin_reverse(x, bits_count):


class BaseFakeSocket:
_clear_watches: Callable
_clear_watches: Callable[[], None]
ACCEPTED_COMMANDS_WHILE_PUBSUB = {
"ping",
"subscribe",
Expand All @@ -68,14 +69,14 @@ class BaseFakeSocket:
}
_connection_error_class = redis.ConnectionError

def __init__(self, server: "FakeServer", db: int, *args: Any, **kwargs: Any): # noqa: F821
def __init__(self, server: "FakeServer", db: int, *args: Any, **kwargs: Any) -> None: # type: ignore # noqa: F821
super(BaseFakeSocket, self).__init__(*args, **kwargs)
from fakeredis import FakeServer

self._server: FakeServer = server
self._db_num = db
self._db = server.dbs[self._db_num]
self.responses: Optional[queue.Queue] = queue.Queue()
self.responses: Optional[queue.Queue[bytes]] = queue.Queue()
# Prevents parser from processing commands. Not used in this module,
# but set by aioredis module to prevent new commands being processed
# while handling a blocking command.
Expand Down Expand Up @@ -108,7 +109,7 @@ def resume(self) -> None:
self._paused = False
self._parser.send(b"")

def shutdown(self, _) -> None:
def shutdown(self, _: Any) -> None:
self._parser.close()

@staticmethod
Expand Down Expand Up @@ -136,20 +137,20 @@ def close(self) -> None:
# at any time, and hence we can't safely take the server lock.
# We rely on list.append being atomic.
self._server.closed_sockets.append(weakref.ref(self))
self._server = None
self._server = None # type: ignore
self._db = None
self.responses = None

@staticmethod
def _extract_line(buf):
def _extract_line(buf: bytes) -> Tuple[bytes, bytes]:
pos = buf.find(b"\n") + 1
assert pos > 0
line = buf[:pos]
buf = buf[pos:]
assert line.endswith(b"\r\n")
return line, buf

def _parse_commands(self):
def _parse_commands(self) -> Generator[None, Any, None]:
"""Generator that parses commands.
It is fed pieces of redis protocol data (via `send`) and calls
Expand All @@ -175,8 +176,10 @@ def _parse_commands(self):
buf = buf[length + 2 :] # +2 to skip the CRLF
self._process_command(fields)

def _run_command(self, func: Callable[..., Any], sig: Signature, args: List[Any], from_script: bool) -> Any:
command_items = {}
def _run_command(
self, func: Optional[Callable[[Any], Any]], sig: Signature, args: List[Any], from_script: bool
) -> Any:
command_items: List[CommandItem] = []
try:
ret = sig.apply(args, self._db, self.version)
if from_script and msgs.FLAG_NO_SCRIPT in sig.flags:
Expand All @@ -187,15 +190,15 @@ def _run_command(self, func: Callable[..., Any], sig: Signature, args: List[Any]
result = ret[0]
else:
args, command_items = ret
result = func(*args)
result = func(*args) # type: ignore
assert valid_response_type(result)
except SimpleError as exc:
result = exc
for command_item in command_items:
command_item.writeback(remove_empty_val=msgs.FLAG_LEAVE_EMPTY_VAL not in sig.flags)
return result

def _decode_error(self, error):
def _decode_error(self, error: SimpleError) -> ResponseError:
return DefaultParser(socket_read_size=65536).parse_error(error.value) # type: ignore

def _decode_result(self, result: Any) -> Any:
Expand All @@ -209,7 +212,7 @@ def _decode_result(self, result: Any) -> Any:
else:
return result

def _blocking(self, timeout: Optional[Union[float, int]], func: Callable[[bool], Any]):
def _blocking(self, timeout: Optional[Union[float, int]], func: Callable[[bool], Any]) -> Any:
"""Run a function until it succeeds or timeout is reached.
The timeout is in seconds, and 0 means infinite. The function
Expand All @@ -234,7 +237,7 @@ def _blocking(self, timeout: Optional[Union[float, int]], func: Callable[[bool],
if ret is not None:
return ret

def _name_to_func(self, cmd_name: str) -> Tuple[Optional[Callable], Signature]:
def _name_to_func(self, cmd_name: str) -> Tuple[Optional[Callable[[Any], Any]], Signature]:
"""Get the signature and the method from the command name."""
if cmd_name not in SUPPORTED_COMMANDS:
# redis remaps \r or \n in an error to ' ' to make it legal protocol
Expand All @@ -244,14 +247,14 @@ def _name_to_func(self, cmd_name: str) -> Tuple[Optional[Callable], Signature]:
func = getattr(self, sig.func_name, None)
return func, sig

def sendall(self, data):
def sendall(self, data: AnyStr) -> None:
if not self._server.connected:
raise self._connection_error_class(msgs.CONNECTION_ERROR_MSG)
if isinstance(data, str):
data = data.encode("ascii")
data = data.encode("ascii") # type: ignore
self._parser.send(data)

def _process_command(self, fields: List[bytes]):
def _process_command(self, fields: List[bytes]) -> None:
if not fields:
return
result: Any
Expand Down Expand Up @@ -353,7 +356,7 @@ def match_type(key) -> bool:
result_cursor = 0
return [str(bin_reverse(result_cursor, bits_len)).encode(), result_data]

def _ttl(self, key: CommandItem, scale: int) -> int:
def _ttl(self, key: CommandItem, scale: float) -> int:
if not key:
return -2
elif key.expireat is None:
Expand All @@ -372,7 +375,7 @@ def _encodeint(self, value: int) -> bytes:
return Int.encode(value)

@staticmethod
def _key_value_type(key: Item) -> SimpleString:
def _key_value_type(key: CommandItem) -> SimpleString:
if key.value is None:
return SimpleString(b"none")
elif isinstance(key.value, bytes):
Expand Down
10 changes: 5 additions & 5 deletions fakeredis/_fakesocket.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Set
from typing import Optional, Set, Any

from fakeredis.stack import (
JSONCommandsMixin,
Expand All @@ -9,8 +9,8 @@
TDigestCommandsMixin,
TimeSeriesCommandsMixin,
)
from ._server import FakeServer
from ._basefakesocket import BaseFakeSocket
from ._server import FakeServer
from .commands_mixins.bitmap_mixin import BitmapCommandsMixin
from .commands_mixins.connection_mixin import ConnectionCommandsMixin
from .commands_mixins.generic_mixin import GenericCommandsMixin
Expand All @@ -24,9 +24,9 @@
except ImportError:

class ScriptingCommandsMixin: # type: ignore # noqa: E303
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs.pop("lua_modules", None)
super(ScriptingCommandsMixin, self).__init__(*args, **kwargs)
super(ScriptingCommandsMixin, self).__init__(*args, **kwargs) # type: ignore


from .commands_mixins.server_mixin import ServerCommandsMixin
Expand Down Expand Up @@ -65,6 +65,6 @@ def __init__(
self,
server: "FakeServer",
db: int,
lua_modules: Optional[Set[str]] = None, # type: ignore # noqa: F821
lua_modules: Optional[Set[str]] = None, # noqa: F821
) -> None:
super(FakeSocket, self).__init__(server, db, lua_modules=lua_modules)

0 comments on commit 1705351

Please sign in to comment.