Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix #12672, importing numpy scalars (zero-dimensional arrays) (#12678)
Browse files Browse the repository at this point in the history
* Fix #12672

Problem is in using np.ascontiguousarray,
which is buggy for zero-dimensional arrays
 (see numpy/numpy#5300 for details).

Here I use the solution proposed by numpy team:
switch to asarray with order='C'.

Add some tests for this situation (for array() and for setitem too).

* typo in tests
  • Loading branch information
arogozhnikov authored and nswamy committed Oct 4, 2018
1 parent 01dd703 commit be0c6ef
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 3 deletions.
2 changes: 1 addition & 1 deletion amalgamation/python/mxnet_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def forward(self, **kwargs):
for k, v in kwargs.items():
if not isinstance(v, np.ndarray):
raise ValueError("Expect numpy ndarray as input")
v = np.ascontiguousarray(v, dtype=np.float32)
v = np.asarray(v, dtype=np.float32, order='C')
_check_call(_LIB.MXPredSetInput(
self.handle, c_str(k),
v.ctypes.data_as(mx_float_p),
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,7 @@ def _sync_copyfrom(self, source_array):
except:
raise TypeError('array must consist of array-like data,' +
'type %s is not supported' % str(type(array)))
source_array = np.ascontiguousarray(source_array, dtype=self.dtype)
source_array = np.asarray(source_array, dtype=self.dtype, order='C')
if source_array.shape != self.shape:
raise ValueError('Shape inconsistent: expected %s vs got %s'%(
str(self.shape), str(source_array.shape)))
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def as_stype(var, stype, dtype):
if stype == 'default':
executor.arg_dict[k][:] = as_stype(v, stype, dtype=dtype)
for k in location:
location[k] = np.ascontiguousarray(location[k])
location[k] = np.asarray(location[k], order='C')
for k, v in location.items():
if v.dtype.kind != 'f':
continue
Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def test_ndarray_setitem():
x_np[:, -3:-1, -2:-1] = 1
assert same(x.asnumpy(), x_np)

# numpy assignment for empty axis
for trivial_shape in [(), (1,), (1, 1), (1, 1, 1)]:
x = mx.nd.zeros(trivial_shape)
x[:] = np.ones(trivial_shape)
x_np = np.ones(trivial_shape, dtype=x.dtype)
assert x.shape == trivial_shape
assert same(x.asnumpy(), x_np)


@with_seed()
def test_ndarray_elementwise():
Expand Down Expand Up @@ -217,6 +225,13 @@ def test_ndarray_onehot():
assert same(npy, arr.asnumpy())


def test_init_from_scalar():
npy = np.ones([])
arr = mx.nd.array(npy)
assert arr.shape == ()
assert same(npy, arr.asnumpy())


@with_seed()
def test_ndarray_copy():
c = mx.nd.array(np.random.uniform(-10, 10, (10, 10)))
Expand Down

0 comments on commit be0c6ef

Please sign in to comment.