Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Improve Serialization Logic #82

Merged
merged 3 commits into from
Sep 26, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 97 additions & 28 deletions codejail/safe_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# Flags to let developers temporarily change some behavior in this file.

# Set this to True to log all the code and globals being executed.
LOG_ALL_CODE = False
LOG_ALL_CODE = True

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I presume this change should be reverted before merging?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

# Set this to True to use the unsafe code, so that you can debug it.
ALWAYS_BE_UNSAFE = False

Expand Down Expand Up @@ -97,6 +97,9 @@ def safe_exec(code, globals_dict, files=None, python_path=None, slug=None,
class DevNull(object):
def write(self, *args, **kwargs):
pass

def flush(self, *args, **kwargs):
pass
sys.stdout = DevNull()
"""
# Read the code and the globals from the stdin.
Expand Down Expand Up @@ -126,20 +129,6 @@ def write(self, *args, **kwargs):
# so recursively convert them to strings prior to creating the final globals
# dict
"""
def decode_object(obj):
if isinstance(obj, bytes):
return obj.decode('utf8')
elif isinstance(obj, list):
return [decode_object(i) for i in obj]
elif isinstance(obj, dict):
return {k: decode_object(v) for k, v in six.iteritems(obj)}
elif isinstance(obj, tuple):
return tuple(decode_object(i) for i in obj)
else:
return obj

decoded_dict = decode_object(g_dict)

def jsonable(v):
if not isinstance(v, ok_types):
return False
Expand All @@ -148,11 +137,48 @@ def jsonable(v):
except Exception:
return False
return True
g_dict = {
k:v
for k,v in six.iteritems(decoded_dict)
if jsonable(v) and k not in bad_keys
}

def filter_unserializable(obj):
if isinstance(obj, bytes):
return obj.decode('utf-8')
elif isinstance(obj, list):
new_list = []
for i in obj:
try:
new_obj = filter_unserializable(i)
if jsonable(new_obj):
new_list.append(new_obj)
except Exception as e:
pass # Don't add the item if we can't decode it
return new_list
elif isinstance(obj, dict):
new_dict = {}
for k,v in six.iteritems(obj):
try:
new_value = filter_unserializable(v)
if jsonable(new_value):
new_dict[k] = new_value

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the key comes in as bytes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add that case.

except Exception as e:
pass # Don't add the item if we can't decode it
return new_dict
elif isinstance(obj, tuple):
list_for_new_tuple = []
for i in obj:
try:
new_obj = filter_unserializable(i)
if jsonable(new_obj):
list_for_new_tuple.append(new_obj)
except Exception as e:
pass # Don't add the item if we can't decode it
return tuple(list_for_new_tuple)
else:
return obj

for key in bad_keys:
if key in g_dict:
del g_dict[key]

g_dict = filter_unserializable(g_dict)
"""
# Write the globals back to the calling process.
"""
Expand All @@ -172,6 +198,12 @@ def jsonable(v):
"python", code=jailed_code, stdin=stdin, files=files, slug=slug,
extra_files=extra_files,
)

if LOG_ALL_CODE:
log.debug("Status: %s", res.status)
log.debug("Stdout: %s", res.stdout)
log.debug("Stderr: %s", res.stderr)

if res.status != 0:
raise SafeExecException((
"Couldn't execute jailed code: stdout: {res.stdout!r}, "
Expand All @@ -187,23 +219,60 @@ def json_safe(d):
Used to emulate reading data through a serialization straw.

"""
def decode_object(obj):

ok_types = (type(None), int, float, str, six.text_type, list, tuple, dict)

def jsonable(v):
if not isinstance(v, ok_types):
return False
try:
json.dumps(v)
except Exception:
return False
return True

def filter_unserializable(obj):
if isinstance(obj, bytes):
return obj.decode('utf8')
return obj.decode('utf-8')
elif isinstance(obj, list):
return [decode_object(i) for i in obj]
new_list = []
for i in obj:
try:
new_obj = filter_unserializable(i)
if jsonable(new_obj):
new_list.append(new_obj)
except Exception:
pass # Don't add the item if we can't decode it
return new_list
elif isinstance(obj, dict):
return {k: decode_object(v) for k, v in six.iteritems(obj)}
new_dict = {}
for k,v in six.iteritems(obj):
try:
new_value = filter_unserializable(v)
if jsonable(new_value):
new_dict[k] = new_value
except Exception:
pass # Don't add the item if we can't decode it
return new_dict
elif isinstance(obj, tuple):
return tuple(decode_object(i) for i in obj)
list_for_new_tuple = []
for i in obj:
try:
new_obj = filter_unserializable(i)
if jsonable(new_obj):
list_for_new_tuple.append(new_obj)
except Exception:
pass # Don't add the item if we can't decode it
return tuple(list_for_new_tuple)
else:
return obj
decoded_dict = decode_object(d)

ok_types = (type(None), int, float, str, six.text_type, list, tuple, dict)
serializable_dict = filter_unserializable(d)
#serializable_dict = d

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to keep this commented-out line (`#serializable_dict = d`)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will remove.


bad_keys = ("__builtins__",)
jd = {}
for k, v in six.iteritems(decoded_dict):
for k, v in six.iteritems(serializable_dict):
if not isinstance(v, ok_types):
continue
if k in bad_keys:
Expand Down
13 changes: 11 additions & 2 deletions codejail/tests/test_safe_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,16 @@ def test_set_values(self):
def test_complex_globals(self):
globs = {}
self.safe_exec(
"from builtins import bytes; test_dict = {1: bytes('a', 'utf8'), 2: 'b', 3: {1: bytes('b', 'utf8'), 2: (1, bytes('a', 'utf8'))}}",
textwrap.dedent("""\
from builtins import bytes
test_dict = {1: bytes('a', 'utf8'), 2: 'b', 3: {1: bytes('b', 'utf8'), 2: (1, bytes('a', 'utf8'))}}
foo = "bar"
test_dict_type = type(test_dict)
"""),
globs
)
from pprint import pprint
pprint(globs)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this just for debugging?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

self.assertDictEqual(globs['test_dict'], {'1': 'a', '2': 'b', '3': {'1': 'b', '2': [1, 'a']}})

def test_files_are_copied(self):
Expand Down Expand Up @@ -113,14 +120,16 @@ def test_extra_files(self):
]
self.safe_exec(textwrap.dedent("""\
import six
with open("extra.txt", 'rb') as f:
import io
with io.open("extra.txt", 'r') as f:
extra = f.read()
with open("also.dat", 'rb') as f:
if six.PY2:
also = f.read().encode("hex")
else:
also = f.read().hex()
"""), globs, extra_files=extras)

self.assertEqual(globs['extra'], "I'm extra!\n")
self.assertEqual(globs['also'], "01ff02fe")

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="codejail",
version="2.0",
version="2.1",
packages=['codejail'],
classifiers=[
"License :: OSI Approved :: Apache Software License",
Expand Down