diff --git a/dill/source.py b/dill/source.py index 4b22c945..18bd875c 100644 --- a/dill/source.py +++ b/dill/source.py @@ -530,7 +530,7 @@ def outdent(code, spaces=None, all=True): return '\n'.join(_outdent(code.split('\n'), spaces=spaces, all=all)) -#XXX: not sure what the point of _wrap is... +# _wrap provides an wrapper to correctly exec and load into locals __globals__ = globals() __locals__ = locals() def _wrap(f): diff --git a/dill/tests/test_source.py b/dill/tests/test_source.py index 1adb9d86..51dc8527 100644 --- a/dill/tests/test_source.py +++ b/dill/tests/test_source.py @@ -6,8 +6,8 @@ # License: 3-clause BSD. The full license text is available at: # - https://github.com/uqfoundation/dill/blob/master/LICENSE -from dill.source import getsource, getname, _wrap, likely_import -from dill.source import getimportable +from dill.source import getsource, getname, _wrap, getimport +from dill.source import importable from dill._dill import IS_PYPY import sys @@ -55,31 +55,31 @@ def test_getsource(): # test itself def test_itself(): - assert likely_import(likely_import)=='from dill.source import likely_import\n' + assert getimport(getimport)=='from dill.source import getimport\n' # builtin functions and objects def test_builtin(): - assert likely_import(pow) == 'pow\n' - assert likely_import(100) == '100\n' - assert likely_import(True) == 'True\n' - assert likely_import(pow, explicit=True) == 'from builtins import pow\n' - assert likely_import(100, explicit=True) == '100\n' - assert likely_import(True, explicit=True) == 'True\n' + assert getimport(pow) == 'pow\n' + assert getimport(100) == '100\n' + assert getimport(True) == 'True\n' + assert getimport(pow, builtin=True) == 'from builtins import pow\n' + assert getimport(100, builtin=True) == '100\n' + assert getimport(True, builtin=True) == 'True\n' # this is kinda BS... you can't import a None - assert likely_import(None) == 'None\n' - assert likely_import(None, explicit=True) == 'None\n' + assert getimport(None) == 'None\n' + assert getimport(None, builtin=True) == 'None\n' # other imported functions def test_imported(): from math import sin - assert likely_import(sin) == 'from math import sin\n' + assert getimport(sin) == 'from math import sin\n' # interactively defined functions def test_dynamic(): - assert likely_import(add) == 'from %s import add\n' % __name__ + assert getimport(add) == 'from %s import add\n' % __name__ # interactive lambdas - assert likely_import(squared) == 'from %s import squared\n' % __name__ + assert getimport(squared) == 'from %s import squared\n' % __name__ # classes and class instances def test_classes(): @@ -88,44 +88,44 @@ def test_classes(): x = y if (IS_PYPY or sys.hexversion >= PY310b) else "from io import BytesIO\n" s = StringIO() - assert likely_import(StringIO) == x - assert likely_import(s) == y + assert getimport(StringIO) == x + assert getimport(s) == y # interactively defined classes and class instances - assert likely_import(Foo) == 'from %s import Foo\n' % __name__ - assert likely_import(_foo) == 'from %s import Foo\n' % __name__ + assert getimport(Foo) == 'from %s import Foo\n' % __name__ + assert getimport(_foo) == 'from %s import Foo\n' % __name__ -# test getimportable +# test importable def test_importable(): - assert getimportable(add) == 'from %s import add\n' % __name__ - assert getimportable(squared) == 'from %s import squared\n' % __name__ - assert getimportable(Foo) == 'from %s import Foo\n' % __name__ - assert getimportable(Foo.bar) == 'from %s import bar\n' % __name__ - assert getimportable(_foo.bar) == 'from %s import bar\n' % __name__ - assert getimportable(None) == 'None\n' - assert getimportable(100) == '100\n' - - assert getimportable(add, byname=False) == 'def add(x,y):\n return x+y\n' - assert getimportable(squared, byname=False) == 'squared = lambda x:x**2\n' - assert getimportable(None, byname=False) == 'None\n' - assert getimportable(Bar, byname=False) == 'class Bar:\n pass\n' - assert getimportable(Foo, byname=False) == 'class Foo(object):\n def bar(self, x):\n return x*x+x\n' - assert getimportable(Foo.bar, byname=False) == 'def bar(self, x):\n return x*x+x\n' - assert getimportable(Foo.bar, byname=True) == 'from %s import bar\n' % __name__ - assert getimportable(Foo.bar, alias='memo', byname=True) == 'from %s import bar as memo\n' % __name__ - assert getimportable(Foo, alias='memo', byname=True) == 'from %s import Foo as memo\n' % __name__ - assert getimportable(squared, alias='memo', byname=True) == 'from %s import squared as memo\n' % __name__ - assert getimportable(squared, alias='memo', byname=False) == 'memo = squared = lambda x:x**2\n' - assert getimportable(add, alias='memo', byname=False) == 'def add(x,y):\n return x+y\n\nmemo = add\n' - assert getimportable(None, alias='memo', byname=False) == 'memo = None\n' - assert getimportable(100, alias='memo', byname=False) == 'memo = 100\n' - assert getimportable(add, explicit=True) == 'from %s import add\n' % __name__ - assert getimportable(squared, explicit=True) == 'from %s import squared\n' % __name__ - assert getimportable(Foo, explicit=True) == 'from %s import Foo\n' % __name__ - assert getimportable(Foo.bar, explicit=True) == 'from %s import bar\n' % __name__ - assert getimportable(_foo.bar, explicit=True) == 'from %s import bar\n' % __name__ - assert getimportable(None, explicit=True) == 'None\n' - assert getimportable(100, explicit=True) == '100\n' + assert importable(add, source=False) == 'from %s import add\n' % __name__ + assert importable(squared, source=False) == 'from %s import squared\n' % __name__ + assert importable(Foo, source=False) == 'from %s import Foo\n' % __name__ + assert importable(Foo.bar, source=False) == 'from %s import bar\n' % __name__ + assert importable(_foo.bar, source=False) == 'from %s import bar\n' % __name__ + assert importable(None, source=False) == 'None\n' + assert importable(100, source=False) == '100\n' + + assert importable(add, source=True) == 'def add(x,y):\n return x+y\n' + assert importable(squared, source=True) == 'squared = lambda x:x**2\n' + assert importable(None, source=True) == 'None\n' + assert importable(Bar, source=True) == 'class Bar:\n pass\n' + assert importable(Foo, source=True) == 'class Foo(object):\n def bar(self, x):\n return x*x+x\n' + assert importable(Foo.bar, source=True) == 'def bar(self, x):\n return x*x+x\n' + assert importable(Foo.bar, source=False) == 'from %s import bar\n' % __name__ + assert importable(Foo.bar, alias='memo', source=False) == 'from %s import bar as memo\n' % __name__ + assert importable(Foo, alias='memo', source=False) == 'from %s import Foo as memo\n' % __name__ + assert importable(squared, alias='memo', source=False) == 'from %s import squared as memo\n' % __name__ + assert importable(squared, alias='memo', source=True) == 'memo = squared = lambda x:x**2\n' + assert importable(add, alias='memo', source=True) == 'def add(x,y):\n return x+y\n\nmemo = add\n' + assert importable(None, alias='memo', source=True) == 'memo = None\n' + assert importable(100, alias='memo', source=True) == 'memo = 100\n' + assert importable(add, builtin=True, source=False) == 'from %s import add\n' % __name__ + assert importable(squared, builtin=True, source=False) == 'from %s import squared\n' % __name__ + assert importable(Foo, builtin=True, source=False) == 'from %s import Foo\n' % __name__ + assert importable(Foo.bar, builtin=True, source=False) == 'from %s import bar\n' % __name__ + assert importable(_foo.bar, builtin=True, source=False) == 'from %s import bar\n' % __name__ + assert importable(None, builtin=True, source=False) == 'None\n' + assert importable(100, builtin=True, source=False) == '100\n' def test_numpy(): @@ -133,16 +133,16 @@ def test_numpy(): import numpy as np y = np.array x = y([1,2,3]) - assert getimportable(x) == 'from numpy import array\narray([1, 2, 3])\n' - assert getimportable(y) == 'from %s import array\n' % y.__module__ - assert getimportable(x, byname=False) == 'from numpy import array\narray([1, 2, 3])\n' - assert getimportable(y, byname=False) == 'from %s import array\n' % y.__module__ + assert importable(x, source=False) == 'from numpy import array\narray([1, 2, 3])\n' + assert importable(y, source=False) == 'from %s import array\n' % y.__module__ + assert importable(x, source=True) == 'from numpy import array\narray([1, 2, 3])\n' + assert importable(y, source=True) == 'from %s import array\n' % y.__module__ y = np.int64 x = y(0) - assert getimportable(x) == 'from numpy import int64\nint64(0)\n' - assert getimportable(y) == 'from %s import int64\n' % y.__module__ - assert getimportable(x, byname=False) == 'from numpy import int64\nint64(0)\n' - assert getimportable(y, byname=False) == 'from %s import int64\n' % y.__module__ + assert importable(x, source=False) == 'from numpy import int64\nint64(0)\n' + assert importable(y, source=False) == 'from %s import int64\n' % y.__module__ + assert importable(x, source=True) == 'from numpy import int64\nint64(0)\n' + assert importable(y, source=True) == 'from %s import int64\n' % y.__module__ y = np.bool_ x = y(0) import warnings @@ -151,15 +151,15 @@ def test_numpy(): warnings.filterwarnings('ignore', category=DeprecationWarning) if hasattr(np, 'bool'): b = 'bool_' if np.bool is bool else 'bool' else: b = 'bool_' - assert getimportable(x) == 'from numpy import %s\n%s(False)\n' % (b,b) - assert getimportable(y) == 'from %s import %s\n' % (y.__module__,b) - assert getimportable(x, byname=False) == 'from numpy import %s\n%s(False)\n' % (b,b) - assert getimportable(y, byname=False) == 'from %s import %s\n' % (y.__module__,b) + assert importable(x, source=False) == 'from numpy import %s\n%s(False)\n' % (b,b) + assert importable(y, source=False) == 'from %s import %s\n' % (y.__module__,b) + assert importable(x, source=True) == 'from numpy import %s\n%s(False)\n' % (b,b) + assert importable(y, source=True) == 'from %s import %s\n' % (y.__module__,b) except ImportError: pass -#NOTE: if before likely_import(pow), will cause pow to throw AssertionError +#NOTE: if before getimport(pow), will cause pow to throw AssertionError def test_foo(): - assert getimportable(_foo, byname=False).startswith("import dill\nclass Foo(object):\n def bar(self, x):\n return x*x+x\ndill.loads(") + assert importable(_foo, source=True).startswith("import dill\nclass Foo(object):\n def bar(self, x):\n return x*x+x\ndill.loads(") if __name__ == '__main__': test_getsource() diff --git a/dill/tests/test_sources.py b/dill/tests/test_sources.py new file mode 100644 index 00000000..9deb2422 --- /dev/null +++ b/dill/tests/test_sources.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python +# +# Author: Mike McKerns (mmckerns @uqfoundation) +# Copyright (c) 2024 The Uncertainty Quantification Foundation. +# License: 3-clause BSD. The full license text is available at: +# - https://github.com/uqfoundation/dill/blob/master/LICENSE +""" +check that dill.source performs as expected with changes to locals in 3.13.0b1 +see: https://github.com/python/cpython/issues/118888 +""" +# repeat functions from test_source.py +f = lambda x: x**2 +def g(x): return f(x) - x + +def h(x): + def g(x): return x + return g(x) - x + +class Foo(object): + def bar(self, x): + return x*x+x +_foo = Foo() + +def add(x,y): + return x+y + +squared = lambda x:x**2 + +class Bar: + pass +_bar = Bar() + +# repeat, but from test_source.py +import test_source as ts + +# test objects created in other test modules +import test_mixins as tm + +import dill.source as ds + + +def test_isfrommain(): + assert ds.isfrommain(add) == True + assert ds.isfrommain(squared) == True + assert ds.isfrommain(Bar) == True + assert ds.isfrommain(_bar) == True + assert ds.isfrommain(ts.add) == False + assert ds.isfrommain(ts.squared) == False + assert ds.isfrommain(ts.Bar) == False + assert ds.isfrommain(ts._bar) == False + assert ds.isfrommain(tm.quad) == False + assert ds.isfrommain(tm.double_add) == False + assert ds.isfrommain(tm.quadratic) == False + assert ds.isdynamic(add) == False + assert ds.isdynamic(squared) == False + assert ds.isdynamic(ts.add) == False + assert ds.isdynamic(ts.squared) == False + assert ds.isdynamic(tm.double_add) == False + assert ds.isdynamic(tm.quadratic) == False + + +def test_matchlambda(): + assert ds._matchlambda(f, 'f = lambda x: x**2\n') + assert ds._matchlambda(squared, 'squared = lambda x:x**2\n') + assert ds._matchlambda(ts.f, 'f = lambda x: x**2\n') + assert ds._matchlambda(ts.squared, 'squared = lambda x:x**2\n') + + +def test_findsource(): + lines, lineno = ds.findsource(add) + assert lines[lineno] == 'def add(x,y):\n' + lines, lineno = ds.findsource(ts.add) + assert lines[lineno] == 'def add(x,y):\n' + lines, lineno = ds.findsource(squared) + assert lines[lineno] == 'squared = lambda x:x**2\n' + lines, lineno = ds.findsource(ts.squared) + assert lines[lineno] == 'squared = lambda x:x**2\n' + lines, lineno = ds.findsource(Bar) + assert lines[lineno] == 'class Bar:\n' + lines, lineno = ds.findsource(ts.Bar) + assert lines[lineno] == 'class Bar:\n' + lines, lineno = ds.findsource(_bar) + assert lines[lineno] == 'class Bar:\n' + lines, lineno = ds.findsource(ts._bar) + assert lines[lineno] == 'class Bar:\n' + lines, lineno = ds.findsource(tm.quad) + assert lines[lineno] == 'def quad(a=1, b=1, c=0):\n' + lines, lineno = ds.findsource(tm.double_add) + assert lines[lineno] == ' def func(*args, **kwds):\n' + lines, lineno = ds.findsource(tm.quadratic) + assert lines[lineno] == ' def dec(f):\n' + + +def test_getsourcelines(): + assert ''.join(ds.getsourcelines(add)[0]) == 'def add(x,y):\n return x+y\n' + assert ''.join(ds.getsourcelines(ts.add)[0]) == 'def add(x,y):\n return x+y\n' + assert ''.join(ds.getsourcelines(squared)[0]) == 'squared = lambda x:x**2\n' + assert ''.join(ds.getsourcelines(ts.squared)[0]) == 'squared = lambda x:x**2\n' + assert ''.join(ds.getsourcelines(Bar)[0]) == 'class Bar:\n pass\n' + assert ''.join(ds.getsourcelines(ts.Bar)[0]) == 'class Bar:\n pass\n' + assert ''.join(ds.getsourcelines(_bar)[0]) == 'class Bar:\n pass\n' #XXX: ? + assert ''.join(ds.getsourcelines(ts._bar)[0]) == 'class Bar:\n pass\n' #XXX: ? + assert ''.join(ds.getsourcelines(tm.quad)[0]) == 'def quad(a=1, b=1, c=0):\n inverted = [False]\n def invert():\n inverted[0] = not inverted[0]\n def dec(f):\n def func(*args, **kwds):\n x = f(*args, **kwds)\n if inverted[0]: x = -x\n return a*x**2 + b*x + c\n func.__wrapped__ = f\n func.invert = invert\n func.inverted = inverted\n return func\n return dec\n' + assert ''.join(ds.getsourcelines(tm.quadratic)[0]) == ' def dec(f):\n def func(*args,**kwds):\n fx = f(*args,**kwds)\n return a*fx**2 + b*fx + c\n return func\n' + assert ''.join(ds.getsourcelines(tm.quadratic, lstrip=True)[0]) == 'def dec(f):\n def func(*args,**kwds):\n fx = f(*args,**kwds)\n return a*fx**2 + b*fx + c\n return func\n' + assert ''.join(ds.getsourcelines(tm.quadratic, enclosing=True)[0]) == 'def quad_factory(a=1,b=1,c=0):\n def dec(f):\n def func(*args,**kwds):\n fx = f(*args,**kwds)\n return a*fx**2 + b*fx + c\n return func\n return dec\n' + assert ''.join(ds.getsourcelines(tm.double_add)[0]) == ' def func(*args, **kwds):\n x = f(*args, **kwds)\n if inverted[0]: x = -x\n return a*x**2 + b*x + c\n' + assert ''.join(ds.getsourcelines(tm.double_add, enclosing=True)[0]) == 'def quad(a=1, b=1, c=0):\n inverted = [False]\n def invert():\n inverted[0] = not inverted[0]\n def dec(f):\n def func(*args, **kwds):\n x = f(*args, **kwds)\n if inverted[0]: x = -x\n return a*x**2 + b*x + c\n func.__wrapped__ = f\n func.invert = invert\n func.inverted = inverted\n return func\n return dec\n' + + +def test_indent(): + assert ds.outdent(''.join(ds.getsourcelines(tm.quadratic)[0])) == ''.join(ds.getsourcelines(tm.quadratic, lstrip=True)[0]) + assert ds.indent(''.join(ds.getsourcelines(tm.quadratic, lstrip=True)[0]), 2) == ''.join(ds.getsourcelines(tm.quadratic)[0]) + + +def test_dumpsource(): + local = {} + exec(ds.dumpsource(add, alias='raw'), {}, local) + exec(ds.dumpsource(ts.add, alias='mod'), {}, local) + assert local['raw'](1,2) == local['mod'](1,2) + exec(ds.dumpsource(squared, alias='raw'), {}, local) + exec(ds.dumpsource(ts.squared, alias='mod'), {}, local) + assert local['raw'](3) == local['mod'](3) + assert ds._wrap(add)(1,2) == ds._wrap(ts.add)(1,2) + assert ds._wrap(squared)(3) == ds._wrap(ts.squared)(3) + + +def test_name(): + assert ds._namespace(add) == ds.getname(add, fqn=True).split('.') + assert ds._namespace(ts.add) == ds.getname(ts.add, fqn=True).split('.') + assert ds._namespace(squared) == ds.getname(squared, fqn=True).split('.') + assert ds._namespace(ts.squared) == ds.getname(ts.squared, fqn=True).split('.') + assert ds._namespace(Bar) == ds.getname(Bar, fqn=True).split('.') + assert ds._namespace(ts.Bar) == ds.getname(ts.Bar, fqn=True).split('.') + assert ds._namespace(tm.quad) == ds.getname(tm.quad, fqn=True).split('.') + #XXX: the following also works, however behavior may be wrong for nested functions + #assert ds._namespace(tm.double_add) == ds.getname(tm.double_add, fqn=True).split('.') + #assert ds._namespace(tm.quadratic) == ds.getname(tm.quadratic, fqn=True).split('.') + assert ds.getname(add) == 'add' + assert ds.getname(ts.add) == 'add' + assert ds.getname(squared) == 'squared' + assert ds.getname(ts.squared) == 'squared' + assert ds.getname(Bar) == 'Bar' + assert ds.getname(ts.Bar) == 'Bar' + assert ds.getname(tm.quad) == 'quad' + assert ds.getname(tm.double_add) == 'func' #XXX: ? + assert ds.getname(tm.quadratic) == 'dec' #XXX: ? + + +def test_getimport(): + local = {} + exec(ds.getimport(add, alias='raw'), {}, local) + exec(ds.getimport(ts.add, alias='mod'), {}, local) + assert local['raw'](1,2) == local['mod'](1,2) + exec(ds.getimport(squared, alias='raw'), {}, local) + exec(ds.getimport(ts.squared, alias='mod'), {}, local) + assert local['raw'](3) == local['mod'](3) + exec(ds.getimport(Bar, alias='raw'), {}, local) + exec(ds.getimport(ts.Bar, alias='mod'), {}, local) + assert ds.getname(local['raw']) == ds.getname(local['mod']) + exec(ds.getimport(tm.quad, alias='mod'), {}, local) + assert local['mod']()(sum)([1,2,3]) == tm.quad()(sum)([1,2,3]) + #FIXME: wrong results for nested functions (e.g. tm.double_add, tm.quadratic) + + +def test_importable(): + assert ds.importable(add, source=False) == ds.getimport(add) + assert ds.importable(add) == ds.getsource(add) + assert ds.importable(squared, source=False) == ds.getimport(squared) + assert ds.importable(squared) == ds.getsource(squared) + assert ds.importable(Bar, source=False) == ds.getimport(Bar) + assert ds.importable(Bar) == ds.getsource(Bar) + assert ds.importable(ts.add) == ds.getimport(ts.add) + assert ds.importable(ts.add, source=True) == ds.getsource(ts.add) + assert ds.importable(ts.squared) == ds.getimport(ts.squared) + assert ds.importable(ts.squared, source=True) == ds.getsource(ts.squared) + assert ds.importable(ts.Bar) == ds.getimport(ts.Bar) + assert ds.importable(ts.Bar, source=True) == ds.getsource(ts.Bar) + + +if __name__ == '__main__': + test_isfrommain() + test_matchlambda() + test_findsource() + test_getsourcelines() + test_indent() + test_dumpsource() + test_name() + test_getimport() + test_importable()