Skip to content

Commit

Permalink
acl save/load
Browse files Browse the repository at this point in the history
  • Loading branch information
cunla committed Nov 3, 2024
1 parent f65c4d3 commit cda0c00
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 7 deletions.
15 changes: 14 additions & 1 deletion fakeredis/commands_mixins/acl_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,19 @@ def acl_load(self) -> SimpleString:
with open(acl_filename, "rb") as f:
rules_list = f.readlines()
for rule in rules_list:
if not rule.startswith(b"user "):
continue
splitted = rule.split(b" ")
self._set_user_acl(splitted[0], *splitted[1:])
components = list()
i = 1
while i < len(splitted):
current_component = splitted[i]
if current_component.startswith(b"("):
while not current_component.endswith(b")"):
i += 1
current_component += b" " + splitted[i]
components.append(current_component)
i += 1

self._set_user_acl(components[0], *components[1:])
return OK
15 changes: 10 additions & 5 deletions fakeredis/model/_acl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,28 @@ def as_array(self) -> List[bytes]:

@classmethod
def from_bytes(cls, data: bytes) -> "Selector":
allowed = data[0] == ord("+")
data = data[1:]
command, data = data.split(b" ", 1)
keys = b""
channels = b""
command = b""
allowed = False
data = data.split(b" ")
for item in data:
if item.startswith(b"&"):
if item.startswith(b"&"): # channels
channels = item
continue
if item.startswith(b"%RW"):
if item.startswith(b"%RW"): # keys
item = item[3:]
key = item
if key.startswith(b"%"):
key = key[2:]
if key.startswith(b"~"):
keys = item
continue
# command
if item[0] == ord("+") or item[0] == ord("-"):
command = item[1:]
allowed = item[0] == ord("+")

return cls(command, allowed, keys, channels)


Expand Down
37 changes: 36 additions & 1 deletion test/test_internals/test_acl_save_load.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from fakeredis import FakeServer, FakeStrictRedis


Expand All @@ -19,6 +21,39 @@ def test_acl_save_load():
)
r.acl_save()

with open(acl_filename, "rb") as f:
# assert acl file contains all data
with open(acl_filename, "r") as f:
lines = f.readlines()
assert len(lines) == 1
user_rule = lines[0]
assert user_rule.startswith("user fakeredis-user")
assert "nopass" not in user_rule
assert "#e6c3da5b206634d7f3f3586d747ffdb36b5c675757b380c6a5fe5c570c714349" in user_rule
assert "#1ba3d16e9881959f8c9a9762854f72c6e6321cdd44358a10a4e939033117eab9" in user_rule
assert "on" in user_rule
assert "~cache:*" in user_rule
assert "~objects:*" in user_rule
assert "resetchannels &message:*" in user_rule
assert "(%W~app* resetchannels -@all +set)" in user_rule
assert "(~app* resetchannels &x -@all +get)" in user_rule
assert "(%W~app* resetchannels -@all -hset)" in user_rule

# assert acl file is loaded correctly
server2 = FakeServer(config={b"aclfile": acl_filename})
r2 = FakeStrictRedis(server=server2)
r2.acl_load()
rules = r2.acl_list()
user_rule = next(filter(lambda x: x.startswith(f"user {username}"), rules), None)
assert user_rule.startswith("user fakeredis-user")
assert "nopass" not in user_rule
assert "#e6c3da5b206634d7f3f3586d747ffdb36b5c675757b380c6a5fe5c570c714349" in user_rule
assert "#1ba3d16e9881959f8c9a9762854f72c6e6321cdd44358a10a4e939033117eab9" in user_rule
assert "on" in user_rule
assert "~cache:*" in user_rule
assert "~objects:*" in user_rule
assert "resetchannels &message:*" in user_rule
assert "(%W~app* resetchannels -@all +set)" in user_rule
assert "(~app* resetchannels &x -@all +get)" in user_rule
assert "(%W~app* resetchannels -@all -hset)" in user_rule

os.remove(acl_filename)

0 comments on commit cda0c00

Please sign in to comment.