Skip to content

Commit

Permalink
better handle import strings of numpy scalars (#678)
Browse files Browse the repository at this point in the history
* better handling of import strings of numpy scalars

* handle deprecated np.bool is bool
  • Loading branch information
mmckerns authored Sep 10, 2024
1 parent a14e75d commit 15d7c6d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
9 changes: 7 additions & 2 deletions dill/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,10 +602,13 @@ def dumpsource(object, alias='', new=False, enclose=True):

def getname(obj, force=False, fqn=False): #XXX: throw(?) to raise error on fail?
"""get the name of the object. for lambdas, get the name of the pointer """
if fqn: return '.'.join(_namespace(obj))
if fqn: return '.'.join(_namespace(obj)) #NOTE: returns 'type'
module = getmodule(obj)
if not module: # things like "None" and "1"
if not force: return None
if not force: return None #NOTE: returns 'instance' NOT 'type' #FIXME?
# handle some special cases
if hasattr(obj, 'dtype') and not obj.shape:
return getname(obj.__class__) + "(" + repr(obj.tolist()) + ")"
return repr(obj)
try:
#XXX: 'wrong' for decorators and curried functions ?
Expand Down Expand Up @@ -740,6 +743,8 @@ def getimport(obj, alias='', verify=True, builtin=False, enclosing=False):
except Exception: # it's probably something 'importable'
if head in ['builtins','__builtin__']:
name = repr(obj) #XXX: catch [1,2], (1,2), set([1,2])... others?
elif _isinstance(obj):
name = getname(obj, force=True).split('(')[0]
else:
name = repr(obj).split('(')[0]
#if not repr(obj).startswith('<'): name = repr(obj).split('(')[0]
Expand Down
27 changes: 23 additions & 4 deletions dill/tests/test_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,31 @@ def test_importable():

def test_numpy():
try:
from numpy import array
x = array([1,2,3])
import numpy as np
y = np.array
x = y([1,2,3])
assert getimportable(x) == 'from numpy import array\narray([1, 2, 3])\n'
assert getimportable(array) == 'from %s import array\n' % array.__module__
assert getimportable(y) == 'from %s import array\n' % y.__module__
assert getimportable(x, byname=False) == 'from numpy import array\narray([1, 2, 3])\n'
assert getimportable(array, byname=False) == 'from %s import array\n' % array.__module__
assert getimportable(y, byname=False) == 'from %s import array\n' % y.__module__
y = np.int64
x = y(0)
assert getimportable(x) == 'from numpy import int64\nint64(0)\n'
assert getimportable(y) == 'from %s import int64\n' % y.__module__
assert getimportable(x, byname=False) == 'from numpy import int64\nint64(0)\n'
assert getimportable(y, byname=False) == 'from %s import int64\n' % y.__module__
y = np.bool_
x = y(0)
import warnings
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)
if hasattr(np, 'bool'): b = 'bool_' if np.bool is bool else 'bool'
else: b = 'bool_'
assert getimportable(x) == 'from numpy import %s\n%s(False)\n' % (b,b)
assert getimportable(y) == 'from %s import %s\n' % (y.__module__,b)
assert getimportable(x, byname=False) == 'from numpy import %s\n%s(False)\n' % (b,b)
assert getimportable(y, byname=False) == 'from %s import %s\n' % (y.__module__,b)
except ImportError: pass

#NOTE: if before likely_import(pow), will cause pow to throw AssertionError
Expand Down

0 comments on commit 15d7c6d

Please sign in to comment.