Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Serialization Logic #82

Merged
merged 3 commits into from
Sep 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 97 additions & 27 deletions codejail/safe_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def safe_exec(code, globals_dict, files=None, python_path=None, slug=None,
class DevNull(object):
def write(self, *args, **kwargs):
pass

def flush(self, *args, **kwargs):
pass
sys.stdout = DevNull()
"""
# Read the code and the globals from the stdin.
Expand Down Expand Up @@ -126,20 +129,6 @@ def write(self, *args, **kwargs):
# so recursively convert them to strings prior to creating the final globals
# dict
"""
def decode_object(obj):
if isinstance(obj, bytes):
return obj.decode('utf8')
elif isinstance(obj, list):
return [decode_object(i) for i in obj]
elif isinstance(obj, dict):
return {k: decode_object(v) for k, v in six.iteritems(obj)}
elif isinstance(obj, tuple):
return tuple(decode_object(i) for i in obj)
else:
return obj

decoded_dict = decode_object(g_dict)

def jsonable(v):
if not isinstance(v, ok_types):
return False
Expand All @@ -148,11 +137,49 @@ def jsonable(v):
except Exception:
return False
return True
g_dict = {
k:v
for k,v in six.iteritems(decoded_dict)
if jsonable(v) and k not in bad_keys
}

def filter_unserializable(obj):
if isinstance(obj, bytes):
return obj.decode('utf-8')
elif isinstance(obj, list):
new_list = []
for i in obj:
try:
new_obj = filter_unserializable(i)
if jsonable(new_obj):
new_list.append(new_obj)
except Exception as e:
pass # Don't add the item if we can't decode it
return new_list
elif isinstance(obj, dict):
new_dict = {}
for k,v in six.iteritems(obj):
try:
new_key = filter_unserializable(k)
new_value = filter_unserializable(v)
if jsonable(new_value) and jsonable(new_key):
new_dict[new_key] = new_value
except Exception as e:
pass # Don't add the item if we can't decode it
return new_dict
elif isinstance(obj, tuple):
list_for_new_tuple = []
for i in obj:
try:
new_obj = filter_unserializable(i)
if jsonable(new_obj):
list_for_new_tuple.append(new_obj)
except Exception as e:
pass # Don't add the item if we can't decode it
return tuple(list_for_new_tuple)
else:
return obj

for key in bad_keys:
if key in g_dict:
del g_dict[key]

g_dict = filter_unserializable(g_dict)
"""
# Write the globals back to the calling process.
"""
Expand All @@ -172,6 +199,12 @@ def jsonable(v):
"python", code=jailed_code, stdin=stdin, files=files, slug=slug,
extra_files=extra_files,
)

if LOG_ALL_CODE:
log.debug("Status: %s", res.status)
log.debug("Stdout: %s", res.stdout)
log.debug("Stderr: %s", res.stderr)

if res.status != 0:
raise SafeExecException((
"Couldn't execute jailed code: stdout: {res.stdout!r}, "
Expand All @@ -187,23 +220,60 @@ def json_safe(d):
Used to emulate reading data through a serialization straw.

"""
def decode_object(obj):

ok_types = (type(None), int, float, str, six.text_type, list, tuple, dict)

def jsonable(v):
if not isinstance(v, ok_types):
return False
try:
json.dumps(v)
except Exception:
return False
return True

def filter_unserializable(obj):
if isinstance(obj, bytes):
return obj.decode('utf8')
return obj.decode('utf-8')
elif isinstance(obj, list):
return [decode_object(i) for i in obj]
new_list = []
for i in obj:
try:
new_obj = filter_unserializable(i)
if jsonable(new_obj):
new_list.append(new_obj)
except Exception:
pass # Don't add the item if we can't decode it
return new_list
elif isinstance(obj, dict):
return {k: decode_object(v) for k, v in six.iteritems(obj)}
new_dict = {}
for k,v in six.iteritems(obj):
try:
new_key = filter_unserializable(k)
new_value = filter_unserializable(v)
if jsonable(new_value) and jsonable(new_key):
new_dict[new_key] = new_value
except Exception:
pass # Don't add the item if we can't decode it
return new_dict
elif isinstance(obj, tuple):
return tuple(decode_object(i) for i in obj)
list_for_new_tuple = []
for i in obj:
try:
new_obj = filter_unserializable(i)
if jsonable(new_obj):
list_for_new_tuple.append(new_obj)
except Exception:
pass # Don't add the item if we can't decode it
return tuple(list_for_new_tuple)
else:
return obj
decoded_dict = decode_object(d)

ok_types = (type(None), int, float, str, six.text_type, list, tuple, dict)
serializable_dict = filter_unserializable(d)

bad_keys = ("__builtins__",)
jd = {}
for k, v in six.iteritems(decoded_dict):
for k, v in six.iteritems(serializable_dict):
if not isinstance(v, ok_types):
continue
if k in bad_keys:
Expand Down
11 changes: 9 additions & 2 deletions codejail/tests/test_safe_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ def test_set_values(self):
def test_complex_globals(self):
globs = {}
self.safe_exec(
"from builtins import bytes; test_dict = {1: bytes('a', 'utf8'), 2: 'b', 3: {1: bytes('b', 'utf8'), 2: (1, bytes('a', 'utf8'))}}",
textwrap.dedent("""\
from builtins import bytes
test_dict = {1: bytes('a', 'utf8'), 2: 'b', 3: {1: bytes('b', 'utf8'), 2: (1, bytes('a', 'utf8'))}}
foo = "bar"
test_dict_type = type(test_dict)
"""),
globs
)
self.assertDictEqual(globs['test_dict'], {'1': 'a', '2': 'b', '3': {'1': 'b', '2': [1, 'a']}})
Expand Down Expand Up @@ -113,14 +118,16 @@ def test_extra_files(self):
]
self.safe_exec(textwrap.dedent("""\
import six
with open("extra.txt", 'rb') as f:
import io
with io.open("extra.txt", 'r') as f:
extra = f.read()
with open("also.dat", 'rb') as f:
if six.PY2:
also = f.read().encode("hex")
else:
also = f.read().hex()
"""), globs, extra_files=extras)

self.assertEqual(globs['extra'], "I'm extra!\n")
self.assertEqual(globs['also'], "01ff02fe")

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="codejail",
version="2.0",
version="2.1",
packages=['codejail'],
classifiers=[
"License :: OSI Approved :: Apache Software License",
Expand Down