Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Improve sparse ndarray error message (#7181)
Browse files Browse the repository at this point in the history
* add test for broadcast_to

* add comments
  • Loading branch information
eric-haibin-lin authored and piiswrong committed Jul 25, 2017
1 parent 784e689 commit 6644d22
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 14 deletions.
15 changes: 15 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,21 @@ class MXNetError(Exception):
"""Error that will be throwed by all mxnet functions."""
pass

class NotSupportedForSparseNDArray(MXNetError):
def __init__(self, function, alias, *args):
super(NotSupportedForSparseNDArray, self).__init__()
self.function = function.__name__
self.alias = alias
self.args = [str(type(a)) for a in args]
def __str__(self):
msg = 'Function {}'.format(self.function)
if self.alias:
msg += ' (namely operator "{}")'.format(self.alias)
if self.args:
msg += ' with arguments ({})'.format(', '.join(self.args))
msg += ' is not supported for SparseNDArray and only available in NDArray.'
return msg

def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
Expand Down
20 changes: 8 additions & 12 deletions python/mxnet/ndarray/sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# import operator
import numpy as np
from ..base import NotSupportedForSparseNDArray
from ..base import _LIB, numeric_types
from ..base import c_array, mx_real_t
from ..base import mx_uint, NDArrayHandle, check_call
Expand Down Expand Up @@ -179,9 +180,9 @@ def __getitem__(self, key):
"""
stype = self.stype
if stype != 'csr':
raise Exception("__getitem__ for " + str(stype) + " not implemented yet")
raise Exception("__getitem__ for " + str(stype) + " is not implemented yet")
if isinstance(key, int):
raise Exception("Not implemented yet")
raise Exception("__getitem__ with int key is not implemented yet")
if isinstance(key, py_slice):
if key.step is not None:
raise ValueError('NDArray only supports continuous slicing on axis 0')
Expand All @@ -198,13 +199,13 @@ def _sync_copyfrom(self, source_array):
raise Exception('Not implemented for SparseND yet!')

def _at(self, idx):
raise Exception('at operator for SparseND is not supported.')
raise NotSupportedForSparseNDArray(self._at, '[idx]', idx)

def reshape(self, shape):
raise Exception('Not implemented for SparseND yet!')
def _slice(self, start, stop):
raise NotSupportedForSparseNDArray(self._slice, None, start, stop)

def broadcast_to(self, shape):
raise Exception('Not implemented for SparseND yet!')
def reshape(self, shape):
raise NotSupportedForSparseNDArray(self.reshape, None, shape)

def _aux_type(self, i):
"""Data-type of the array’s ith aux data.
Expand Down Expand Up @@ -235,11 +236,6 @@ def _num_aux(self):
"""
return len(_STORAGE_AUX_TYPES[self.stype])

@property
# pylint: disable= invalid-name, undefined-variable
def T(self):
raise Exception('Transpose is not supported for SparseNDArray.')

@property
def _aux_types(self):
"""The data types of the aux data for the SparseNDArray.
Expand Down
49 changes: 47 additions & 2 deletions tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,39 @@ def test_sparse_nd_negate():
# we compute (-arr)
assert_almost_equal(npy, arr.asnumpy())

def test_sparse_nd_broadcast():
sample_num = 1000
# TODO(haibin) test with more than 2 dimensions
def test_broadcast_to(stype):
for i in range(sample_num):
ndim = 2
target_shape = np.random.randint(1, 11, size=ndim)
shape = target_shape.copy()
axis_flags = np.random.randint(0, 2, size=ndim)
axes = []
for (axis, flag) in enumerate(axis_flags):
if flag:
shape[axis] = 1
dat = np.random.rand(*shape) - 0.5
numpy_ret = dat
ndarray = mx.nd.cast_storage(mx.nd.array(dat), stype=stype)
ndarray_ret = ndarray.broadcast_to(shape=target_shape)
if type(ndarray_ret) is mx.ndarray.NDArray:
ndarray_ret = ndarray_ret.asnumpy()
assert (ndarray_ret.shape == target_shape).all()
err = np.square(ndarray_ret - numpy_ret).mean()
assert err < 1E-8
stypes = ['csr', 'row_sparse']
for stype in stypes:
test_broadcast_to(stype)


def test_sparse_nd_transpose():
npy = np.random.uniform(-10, 10, rand_shape_2d())
stypes = ['csr', 'row_sparse']
for stype in stypes:
nd = mx.nd.cast_storage(mx.nd.array(npy), stype=stype)
assert_almost_equal(npy.T, (nd.T).asnumpy())

def test_sparse_nd_output_fallback():
shape = (10, 10)
Expand All @@ -327,7 +360,7 @@ def test_sparse_nd_astype():
assert(y.dtype == np.int32), y.dtype


def test_sparse_ndarray_pickle():
def test_sparse_nd_pickle():
np.random.seed(0)
repeat = 10
dim0 = 40
Expand All @@ -347,7 +380,7 @@ def test_sparse_ndarray_pickle():
assert same(a.asnumpy(), b.asnumpy())


def test_sparse_ndarray_save_load():
def test_sparse_nd_save_load():
np.random.seed(0)
repeat = 1
stypes = ['default', 'row_sparse', 'csr']
Expand Down Expand Up @@ -379,6 +412,18 @@ def test_sparse_ndarray_save_load():
assert same(x.asnumpy(), y.asnumpy())
os.remove(fname)

def test_sparse_nd_unsupported():
nd = mx.nd.zeros((2,2), stype='row_sparse')
fn_slice = lambda x: x._slice(None, None)
fn_at = lambda x: x._at(None)
fn_reshape = lambda x: x.reshape(None)
fns = [fn_slice, fn_at, fn_reshape]
for fn in fns:
try:
fn(nd)
assert(False)
except:
pass

def test_create_csr():
dim0 = 50
Expand Down

0 comments on commit 6644d22

Please sign in to comment.