Skip to content

Commit

Permalink
fix bug while pickle globals of function
Browse files Browse the repository at this point in the history
  • Loading branch information
davies committed Sep 24, 2014
1 parent 729952a commit dfbccf5
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
42 changes: 36 additions & 6 deletions python/pyspark/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import itertools
from copy_reg import _extension_registry, _inverted_registry, _extension_cache
import new
import dis
import traceback
import platform

Expand All @@ -61,6 +62,14 @@
import logging
cloudLog = logging.getLogger("Cloud.Transport")

#relevant opcodes
STORE_GLOBAL = chr(dis.opname.index('STORE_GLOBAL'))
DELETE_GLOBAL = chr(dis.opname.index('DELETE_GLOBAL'))
LOAD_GLOBAL = chr(dis.opname.index('LOAD_GLOBAL'))
GLOBAL_OPS = [STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL]

HAVE_ARGUMENT = chr(dis.HAVE_ARGUMENT)
EXTENDED_ARG = chr(dis.EXTENDED_ARG)

if PyImp == "PyPy":
# register builtin type in `new`
Expand Down Expand Up @@ -304,16 +313,37 @@ def save_function_tuple(self, func, forced_imports):
write(pickle.REDUCE) # applies _fill_function on the tuple

@staticmethod
def extract_code_globals(code):
def extract_code_globals(co):
"""
Find all globals names read or written to by codeblock co
"""
names = set(code.co_names)
if code.co_consts: # see if nested function have any global refs
for const in code.co_consts:
code = co.co_code
names = co.co_names
out_names = set()

n = len(code)
i = 0
extended_arg = 0
while i < n:
op = code[i]

i = i+1
if op >= HAVE_ARGUMENT:
oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
extended_arg = 0
i = i+2
if op == EXTENDED_ARG:
extended_arg = oparg*65536L
if op in GLOBAL_OPS:
out_names.add(names[oparg])
#print 'extracted', out_names, ' from ', names

if co.co_consts: # see if nested function have any global refs
for const in co.co_consts:
if type(const) is types.CodeType:
names |= CloudPickler.extract_code_globals(const)
return names
out_names |= CloudPickler.extract_code_globals(const)

return out_names

def extract_func_data(self, func):
"""
Expand Down
18 changes: 18 additions & 0 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,24 @@ def test_pickling_file_handles(self):
out2 = ser.loads(ser.dumps(out1))
self.assertEquals(out1, out2)

def test_func_globals(self):

class Unpicklable(object):
def __reduce__(self):
raise Exception("not picklable")

global exit
exit = Unpicklable()

ser = CloudPickleSerializer()
self.assertRaises(Exception, lambda: ser.dumps(exit))

def foo():
sys.exit(0)

self.assertTrue("exit" in foo.func_code.co_names)
ser.dumps(foo)


class PySparkTestCase(unittest.TestCase):

Expand Down

0 comments on commit dfbccf5

Please sign in to comment.