From 15d7c6d6ccf4781c624ffbf54c90d23c6e94dc52 Mon Sep 17 00:00:00 2001 From: Mike McKerns Date: Tue, 10 Sep 2024 18:43:34 -0400 Subject: [PATCH] better handle import strings of numpy scalars (#678) * better handling of import strings of numpy scalars * handle deprecated np.bool is bool --- dill/source.py | 9 +++++++-- dill/tests/test_source.py | 27 +++++++++++++++++++++++---- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/dill/source.py b/dill/source.py index 160355c3..4b22c945 100644 --- a/dill/source.py +++ b/dill/source.py @@ -602,10 +602,13 @@ def dumpsource(object, alias='', new=False, enclose=True): def getname(obj, force=False, fqn=False): #XXX: throw(?) to raise error on fail? """get the name of the object. for lambdas, get the name of the pointer """ - if fqn: return '.'.join(_namespace(obj)) + if fqn: return '.'.join(_namespace(obj)) #NOTE: returns 'type' module = getmodule(obj) if not module: # things like "None" and "1" - if not force: return None + if not force: return None #NOTE: returns 'instance' NOT 'type' #FIXME? + # handle some special cases + if hasattr(obj, 'dtype') and not obj.shape: + return getname(obj.__class__) + "(" + repr(obj.tolist()) + ")" return repr(obj) try: #XXX: 'wrong' for decorators and curried functions ? @@ -740,6 +743,8 @@ def getimport(obj, alias='', verify=True, builtin=False, enclosing=False): except Exception: # it's probably something 'importable' if head in ['builtins','__builtin__']: name = repr(obj) #XXX: catch [1,2], (1,2), set([1,2])... others? + elif _isinstance(obj): + name = getname(obj, force=True).split('(')[0] else: name = repr(obj).split('(')[0] #if not repr(obj).startswith('<'): name = repr(obj).split('(')[0] diff --git a/dill/tests/test_source.py b/dill/tests/test_source.py index 8771a542..1adb9d86 100644 --- a/dill/tests/test_source.py +++ b/dill/tests/test_source.py @@ -130,12 +130,31 @@ def test_importable(): def test_numpy(): try: - from numpy import array - x = array([1,2,3]) + import numpy as np + y = np.array + x = y([1,2,3]) assert getimportable(x) == 'from numpy import array\narray([1, 2, 3])\n' - assert getimportable(array) == 'from %s import array\n' % array.__module__ + assert getimportable(y) == 'from %s import array\n' % y.__module__ assert getimportable(x, byname=False) == 'from numpy import array\narray([1, 2, 3])\n' - assert getimportable(array, byname=False) == 'from %s import array\n' % array.__module__ + assert getimportable(y, byname=False) == 'from %s import array\n' % y.__module__ + y = np.int64 + x = y(0) + assert getimportable(x) == 'from numpy import int64\nint64(0)\n' + assert getimportable(y) == 'from %s import int64\n' % y.__module__ + assert getimportable(x, byname=False) == 'from numpy import int64\nint64(0)\n' + assert getimportable(y, byname=False) == 'from %s import int64\n' % y.__module__ + y = np.bool_ + x = y(0) + import warnings + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=FutureWarning) + warnings.filterwarnings('ignore', category=DeprecationWarning) + if hasattr(np, 'bool'): b = 'bool_' if np.bool is bool else 'bool' + else: b = 'bool_' + assert getimportable(x) == 'from numpy import %s\n%s(False)\n' % (b,b) + assert getimportable(y) == 'from %s import %s\n' % (y.__module__,b) + assert getimportable(x, byname=False) == 'from numpy import %s\n%s(False)\n' % (b,b) + assert getimportable(y, byname=False) == 'from %s import %s\n' % (y.__module__,b) except ImportError: pass #NOTE: if before likely_import(pow), will cause pow to throw AssertionError