Skip to content
This repository has been archived by the owner on Sep 11, 2019. It is now read-only.

Commit

Permalink
Merge pull request #13 from ska-sa/lua-return-types
Browse files Browse the repository at this point in the history
Improve emulation of redis -> Lua returns
  • Loading branch information
bmerry authored Feb 20, 2018
2 parents cc44e71 + 2788f83 commit 35f8948
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 24 deletions.
51 changes: 48 additions & 3 deletions fakenewsredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,19 @@ def _patch_responses(obj):
setattr(obj, attr_name, func)


def _lua_bool_ok(lua_runtime, value):
# Inverse of bool_ok wrapper from redis-py
return lua_runtime.table(ok='OK')


def _lua_reply(converter):
def decorator(func):
func._lua_reply = converter
return func

return decorator


def _remove_empty(func):
@functools.wraps(func)
def wrapper(self, key, *args, **kwargs):
Expand Down Expand Up @@ -285,15 +298,18 @@ def __init__(self, db=0, charset='utf-8', errors='strict',
if decode_responses:
_patch_responses(self)

@_lua_reply(_lua_bool_ok)
def flushdb(self):
self._db.clear()
return True

@_lua_reply(_lua_bool_ok)
def flushall(self):
for db in self._dbs.values():
db.clear()

del self._pubsubs[:]
return True

def _remove_if_empty(self, key):
try:
Expand Down Expand Up @@ -464,6 +480,7 @@ def mget(self, keys, *args):
found.append(value)
return found

@_lua_reply(_lua_bool_ok)
def mset(self, *args, **kwargs):
if args:
if len(args) != 1 or not isinstance(args[0], dict):
Expand Down Expand Up @@ -497,6 +514,7 @@ def ping(self):
def randomkey(self):
pass

@_lua_reply(_lua_bool_ok)
def rename(self, src, dst):
try:
value = self._db[src]
Expand Down Expand Up @@ -640,9 +658,11 @@ def type(self, name):
assert key is None
return b'none'

@_lua_reply(_lua_bool_ok)
def watch(self, *names):
pass

@_lua_reply(_lua_bool_ok)
def unwatch(self):
pass

Expand Down Expand Up @@ -758,14 +778,33 @@ def eval(self, script, numkeys, *keys_and_args):

return self._convert_lua_result(result, nested=False)

def _convert_redis_result(self, result):
def _convert_redis_result(self, lua_runtime, result):
if isinstance(result, dict):
return [
i
for item in result.items()
for i in item
]
return result
elif isinstance(result, set):
converted = sorted(
self._convert_redis_result(lua_runtime, item)
for item in result
)
return lua_runtime.table_from(converted)
elif isinstance(result, (list, set, tuple)):
converted = [
self._convert_redis_result(lua_runtime, item)
for item in result
]
return lua_runtime.table_from(converted)
elif isinstance(result, bool):
return int(result)
elif isinstance(result, float):
return to_bytes(result)
elif result is None:
return False
else:
return result

def _convert_lua_result(self, result, nested=True):
from lupa import lua_type
Expand Down Expand Up @@ -846,7 +885,9 @@ def _lua_redis_call(self, lua_runtime, expected_globals, op, *args):
'incrby': FakeStrictRedis.incr
}
func = special_cases[op] if op in special_cases else getattr(FakeStrictRedis, op)
return self._convert_redis_result(func(self, *args))
result = func(self, *args)
converter = getattr(func, '_lua_reply', self._convert_redis_result)
return converter(lua_runtime, result)

def _retrieve_data_from_sort(self, data, get):
if get is not None:
Expand Down Expand Up @@ -969,6 +1010,7 @@ def lpop(self, name):
except IndexError:
return None

@_lua_reply(_lua_bool_ok)
def lset(self, name, index, value):
try:
lst = self._get_list_or_none(name)
Expand All @@ -977,10 +1019,12 @@ def lset(self, name, index, value):
lst[index] = to_bytes(value)
except IndexError:
raise redis.ResponseError("index out of range")
return True

def rpushx(self, name, value):
self._get_list(name).append(to_bytes(value))

@_lua_reply(_lua_bool_ok)
def ltrim(self, name, start, end):
val = self._get_list_or_none(name)
if val is not None:
Expand Down Expand Up @@ -1886,6 +1930,7 @@ def pfcount(self, *sources):
"""
return len(self.sunion(*sources))

@_lua_reply(_lua_bool_ok)
def pfmerge(self, dest, *sources):
"Merge N different HyperLogLogs into a single one."
self.sunionstore(dest, sources)
Expand Down
123 changes: 102 additions & 21 deletions test_fakenewsredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,6 @@ def test_rpush_then_lrange_with_nested_list1(self):
self.assertEqual(self.redis.lrange(
'foo', 0, -1), ['[12345L, 6789L]', '[54321L, 9876L]'] if PY2 else
[b'[12345, 6789]', b'[54321, 9876]'])
self.redis.flushall()

def test_rpush_then_lrange_with_nested_list2(self):
self.assertEqual(self.redis.rpush('foo', [long(12345), 'banana']), 1)
Expand All @@ -597,7 +596,6 @@ def test_rpush_then_lrange_with_nested_list2(self):
'foo', 0, -1),
['[12345L, \'banana\']', '[54321L, \'elephant\']'] if PY2 else
[b'[12345, \'banana\']', b'[54321, \'elephant\']'])
self.redis.flushall()

def test_rpush_then_lrange_with_nested_list3(self):
self.assertEqual(self.redis.rpush('foo', [long(12345), []]), 1)
Expand All @@ -606,7 +604,6 @@ def test_rpush_then_lrange_with_nested_list3(self):
self.assertEqual(self.redis.lrange(
'foo', 0, -1), ['[12345L, []]', '[54321L, []]'] if PY2 else
[b'[12345, []]', b'[54321, []]'])
self.redis.flushall()

def test_lpush_then_lrange_all(self):
self.assertEqual(self.redis.lpush('foo', 'bar'), 1)
Expand Down Expand Up @@ -2309,7 +2306,7 @@ def test_multidb(self):
self.assertEqual(r1['r1'], b'r1')
self.assertEqual(r2['r2'], b'r2')

r1.flushall()
self.assertEqual(r1.flushall(), True)

self.assertTrue('r1' not in r1)
self.assertTrue('r2' not in r2)
Expand Down Expand Up @@ -3046,13 +3043,6 @@ def test_set_existing_key_persists(self):
self.redis.set('foo', 'foo')
self.assertEqual(self.redis.ttl('foo'), -1)

def test_eval_delete(self):
self.redis.set('foo', 'bar')
val = self.redis.get('foo')
self.assertEqual(val, b'bar')
val = self.redis.eval('redis.call("DEL", KEYS[1])', 1, 'foo')
self.assertIsNone(val)

def test_eval_set_value_to_arg(self):
self.redis.eval('redis.call("SET", KEYS[1], ARGV[1])', 1, 'foo', 'bar')
val = self.redis.get('foo')
Expand All @@ -3074,11 +3064,6 @@ def test_eval_conditional(self):
val = self.redis.get('foo')
self.assertEqual(val, b'baz')

def test_eval_lrange(self):
self.redis.lpush("foo", "bar")
val = self.redis.eval('return redis.call("LRANGE", KEYS[1], 0, 1)', 1, 'foo')
self.assertEqual(val, [b'bar'])

def test_eval_table(self):
lua = """
local a = {}
Expand Down Expand Up @@ -3168,23 +3153,23 @@ def test_eval_runtime_error(self):
with self.assertRaises(ResponseError):
self.redis.eval('error("CRASH")', 0)

def test_more_keys_than_args(self):
def test_eval_more_keys_than_args(self):
with self.assertRaises(ResponseError):
self.redis.eval('return 1', 42)

def test_numkeys_float_string(self):
def test_eval_numkeys_float_string(self):
with self.assertRaises(ResponseError):
self.redis.eval('return KEYS[1]', '0.7', 'foo')

def test_numkeys_integer_string(self):
def test_eval_numkeys_integer_string(self):
val = self.redis.eval('return KEYS[1]', "1", "foo")
self.assertEqual(val, b'foo')

def test_numkeys_negative(self):
def test_eval_numkeys_negative(self):
with self.assertRaises(ResponseError):
self.redis.eval('return KEYS[1]', -1, "foo")

def test_numkeys_float(self):
def test_eval_numkeys_float(self):
with self.assertRaises(ResponseError):
self.redis.eval('return KEYS[1]', 0.7, "foo")

Expand Down Expand Up @@ -3271,6 +3256,102 @@ def test_eval_pcall_return_value(self):
with self.assertRaises(ResponseError):
self.redis.eval('return redis.pcall("foo")', 0)

def test_eval_delete(self):
self.redis.set('foo', 'bar')
val = self.redis.get('foo')
self.assertEqual(val, b'bar')
val = self.redis.eval('redis.call("DEL", KEYS[1])', 1, 'foo')
self.assertIsNone(val)

def test_eval_exists(self):
val = self.redis.eval('return redis.call("exists", KEYS[1]) == 0', 1, 'foo')
self.assertEqual(val, 1)

This comment has been minimized.

Copy link
@blueyed

blueyed Feb 21, 2018

Would pass for val = True also: self.assertEqual(True, 1) does not fail.

This comment has been minimized.

Copy link
@bmerry

bmerry Feb 22, 2018

Author

That's true, but not really what the test is testing: it's testing the Lua type (and value) of redis.call return value, not the Python type obtained returned from the eval. The latter is tested by test_eval_convert_bool.


def test_eval_flushdb(self):
self.redis.set('foo', 'bar')
val = self.redis.eval(
'''
local value = redis.call("FLUSHDB");
return type(value) == "table" and value.ok == "OK";
''', 0
)
self.assertEqual(val, 1)

def test_eval_flushall(self):
r1 = self.create_redis(db=0)
r2 = self.create_redis(db=1)

r1['r1'] = 'r1'
r2['r2'] = 'r2'

val = self.redis.eval(
'''
local value = redis.call("FLUSHALL");
return type(value) == "table" and value.ok == "OK";
''', 0
)

self.assertEqual(val, 1)
self.assertNotIn('r1', r1)
self.assertNotIn('r2', r2)

def test_eval_incrbyfloat(self):
self.redis.set('foo', 0.5)
val = self.redis.eval(
'''
local value = redis.call("INCRBYFLOAT", KEYS[1], 2.0);
return type(value) == "string" and tonumber(value) == 2.5;
''', 1, 'foo'
)
self.assertEqual(val, 1)

def test_eval_lrange(self):
self.redis.rpush('foo', 'a', 'b')
val = self.redis.eval(
'''
local value = redis.call("LRANGE", KEYS[1], 0, -1);
return type(value) == "table" and value[1] == "a" and value[2] == "b";
''', 1, 'foo'
)
self.assertEqual(val, 1)

def test_eval_ltrim(self):
self.redis.rpush('foo', 'a', 'b', 'c', 'd')
val = self.redis.eval(
'''
local value = redis.call("LTRIM", KEYS[1], 1, 2);
return type(value) == "table" and value.ok == "OK";
''', 1, 'foo'
)
self.assertEqual(val, 1)
self.assertEqual(self.redis.lrange('foo', 0, -1), [b'b', b'c'])

def test_eval_lset(self):
self.redis.rpush('foo', 'a', 'b')
val = self.redis.eval(
'''
local value = redis.call("LSET", KEYS[1], 0, "z");
return type(value) == "table" and value.ok == "OK";
''', 1, 'foo'
)
self.assertEqual(val, 1)
self.assertEqual(self.redis.lrange('foo', 0, -1), [b'z', b'b'])

def test_eval_sdiff(self):
self.redis.sadd('foo', 'a', 'b', 'c', 'f', 'e', 'd')
self.redis.sadd('bar', 'b')
val = self.redis.eval(
'''
local value = redis.call("SDIFF", KEYS[1], KEYS[2]);
if type(value) ~= "table" then
return redis.error_reply(type(value) .. ", should be table");
else
return value;
end
''', 2, 'foo', 'bar')
# Lua must receive the set *sorted*
self.assertEqual(val, [b'a', b'c', b'd', b'e', b'f'])


class TestFakeRedis(unittest.TestCase):
decode_responses = False
Expand Down

0 comments on commit 35f8948

Please sign in to comment.