diff --git a/.gitignore b/.gitignore
index 1aa2ce47da11..d62c63f403e9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -53,3 +53,6 @@ Debug
 # Emacs
 .clang_complete
 .dir-locals.el
+__pycache__
+*.pkl
+*
\ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
index ec2aa8dab0a3..f01653aec86f 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -8,6 +8,8 @@ env:
   - TASK=lint LINT_LANG=python
   - TASK=doc
   - TASK=build CXX=g++
+  - TASK=python CXX=g++
+  - TASK=python3 CXX=g++
   - TASK=unittest_gtest CXX=g++
 
 # dependent apt packages
@@ -27,12 +29,18 @@ addons:
       - g++-4.8
       - clang
      - python-numpy
+      - python-nose
+      - python3-numpy
+      - python3-dev
+      - python3-nose
+
 before_install:
   - export NVCC_PREFIX=${HOME}
   - if [ "$CXX" = "g++" ]; then export CXX="g++-4.8" CC="gcc-4.8"; fi
   - scripts/build_dmlc.sh
   - export TRAVIS=dmlc-core/scripts/travis
+  - export PYTHONPATH=${PYTHONPATH}:${PWD}/python
   - source ${TRAVIS}/travis_setup_env.sh
 
@@ -40,7 +48,6 @@ install:
   - pip install cpplint pylint --user `whoami`
   - if [ "$CXX" = "g++" ]; then export CXX="g++-4.8" CC="gcc-4.8"; fi
-
 script:
   - scripts/travis_script.sh
diff --git a/doc/Doxyfile b/doc/Doxyfile
index aeef012f2384..407ea96e95ab 100644
--- a/doc/Doxyfile
+++ b/doc/Doxyfile
@@ -1925,7 +1925,7 @@ PERLMOD_MAKEVAR_PREFIX =
 # C-preprocessor directives found in the sources and include files.
 # The default value is: YES.
 
-ENABLE_PREPROCESSING = YES
+ENABLE_PREPROCESSING = NO
 
 # If the MACRO_EXPANSION tag is set to YES doxygen will expand all macro names
 # in the source code. If set to NO only conditional compilation will be
diff --git a/doc/index.md b/doc/index.md
index 54c20e37c729..9d3b7c0e8805 100644
--- a/doc/index.md
+++ b/doc/index.md
@@ -5,6 +5,7 @@ Contents
 --------
 * [Python User Guide](python/python_guide.md)
 * [C++ Developer Guide](cpp/cpp_guide.md)
+* [Doxygen Version of C++ API](https://mxnet.readthedocs.org/en/latest/doxygen)
 
 Indices and tables
 ------------------
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index 77748dd1950c..c7720dcbd935 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# pylint: disable=invalid-name, protected-access
 # coding: utf-8
 """MXNet: a concise, fast and flexible framework for deep learning
 
@@ -10,6 +9,7 @@
 from __future__ import absolute_import
 
 from .context import Context, current_context
+from .base import MXNetError
 from . import narray
 from . import symbol
 
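Note (not part of the diff): with MXNetError re-exported from the package root above, failures surfaced by check_call can be caught as mx.MXNetError from user code. A minimal illustrative sketch, mirroring the shape-inconsistency case exercised in tests/python/test_inter_shape.py later in this diff:

    import mxnet as mx

    data = mx.symbol.Variable('data')
    net = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=1000)
    try:
        # fc1_weight should be (1000, 100) for this data; (1, 100) is inconsistent
        net.infer_shape(data=(100, 100), fc1_weight=(1, 100))
    except mx.MXNetError as err:
        print('shape inference failed: %s' % err)
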
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index c514d6939988..6cf8c616f805 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -9,13 +9,18 @@
 import platform
 import numpy as np
 
+__all__ = ['MXNetError']
 #----------------------------
 # library loading
 #----------------------------
 if sys.version_info[0] == 3:
     string_types = str,
+    # this function is needed for python3
+    # to convert ctypes.char_p .value back to python str
+    py_str = lambda x: x.decode('utf-8')
 else:
     string_types = basestring,
+    py_str = lambda x: x
 
 
 class MXNetError(Exception):
@@ -86,8 +91,7 @@ def check_call(ret):
         return value from API calls
     """
     if ret != 0:
-        raise MXNetError(_LIB.MXGetLastError())
-
+        raise MXNetError(py_str(_LIB.MXGetLastError()))
 
 def c_str(string):
     """Create ctypes char * from a python string
@@ -98,7 +102,8 @@ def c_str(string):
 
     Returns
     -------
-    a char pointer that can be passed to C API
+    str : c_char_p
+        A char pointer that can be passed to C API
     """
     return ctypes.c_char_p(string.encode('utf-8'))
 
@@ -116,7 +121,8 @@ def c_array(ctype, values):
 
     Returns
     -------
-    created ctypes array
+    out : ctypes array
+        Created ctypes array
     """
     return (ctype * len(values))(*values)
 
@@ -136,7 +142,8 @@ def ctypes2numpy_shared(cptr, shape):
 
     Returns
     -------
-    a numpy array : numpy array
+    out : numpy_array
+        A numpy array : numpy array
     """
     if not isinstance(cptr, ctypes.POINTER(mx_float)):
         raise RuntimeError('expected float pointer')
     size = 1
     for s in shape:
         size *= s
     dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
     return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)
+
+
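Note (not part of the diff): py_str and c_str above are the two halves of the Python 2/3 string boundary used throughout these ctypes wrappers — strings going into the C API are encoded to bytes, and char* values coming back are decoded. A self-contained sketch of the same idea:

    import ctypes
    import sys

    if sys.version_info[0] == 3:
        py_str = lambda x: x.decode('utf-8')   # c_char_p .value is bytes on Python 3
    else:
        py_str = lambda x: x                   # already str on Python 2

    def c_str(string):
        """Create a ctypes char* from a Python string."""
        return ctypes.c_char_p(string.encode('utf-8'))

    name = c_str('fc1')          # what gets handed to the C API
    print(py_str(name.value))    # prints 'fc1' on both Python 2 and 3
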
diff --git a/python/mxnet/narray.py b/python/mxnet/narray.py
index 0693c2fa2cb4..78ac5e0471fa 100644
--- a/python/mxnet/narray.py
+++ b/python/mxnet/narray.py
@@ -1,12 +1,11 @@
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access, too-many-locals, fixme, no-member
 """NArray interface of mxnet"""
 from __future__ import absolute_import
 
 import ctypes
 import sys
 from .base import _LIB
-from .base import c_array
+from .base import c_array, py_str
 from .base import mx_uint, mx_float, NArrayHandle, FunctionHandle
 from .base import ctypes2numpy_shared
 from .base import check_call
@@ -49,7 +48,7 @@ class NArray(object):
 
     NArray is basic ndarray/Tensor like data structure in mxnet.
     """
-
+    # pylint: disable= no-member
     def __init__(self, handle):
         """initialize a new NArray
 
@@ -165,7 +164,7 @@ def copyto(self, other):
             return NArray._copyto(self, out=hret)
         else:
             raise TypeError('copyto do not support type ' + type(other))
-
+    # pylint: enable= no-member
 
 def create(shape, ctx=Context.default_ctx):
     """Create a new NArray, with specified shape.
@@ -181,10 +180,9 @@ def create(shape, ctx=Context.default_ctx):
     """
     return NArray(handle=_new_alloc_handle(shape, ctx, False))
 
-
+# pylint: disable=too-many-locals, invalid-name
 def _make_narray_function(handle):
     """Create a NArray function from the FunctionHandle."""
-
     # Constants for type masks.
     NARRAY_ARG_BEFORE_SCALAR = 1
     ACCEPT_EMPTY_MUTATE_TARGET = 1 << 2
     # Get the property of NArray
@@ -193,12 +191,12 @@ def _make_narray_function(handle):
     n_scalars = mx_uint()
     n_mutate_vars = mx_uint()
     type_mask = ctypes.c_int()
-    check_call(_LIB.MXFuncDescribe( \
-        handle, \
-        ctypes.byref(n_used_vars), \
-        ctypes.byref(n_scalars), \
-        ctypes.byref(n_mutate_vars), \
-        ctypes.byref(type_mask)))
+    check_call(_LIB.MXFuncDescribe(
+        handle,
+        ctypes.byref(n_used_vars),
+        ctypes.byref(n_scalars),
+        ctypes.byref(n_mutate_vars),
+        ctypes.byref(type_mask)))
     n_mutate_vars = n_mutate_vars.value
     n_used_vars = n_used_vars.value
     n_scalars = n_scalars.value
@@ -220,19 +218,19 @@ def _make_narray_function(handle):
     arg_types = ctypes.POINTER(ctypes.c_char_p)()
     arg_descs = ctypes.POINTER(ctypes.c_char_p)()
 
-    check_call(_LIB.MXFuncGetInfo( \
-        handle, ctypes.byref(name), ctypes.byref(desc), \
-        ctypes.byref(num_args), \
-        ctypes.byref(arg_names), \
-        ctypes.byref(arg_types), \
-        ctypes.byref(arg_descs)))
-    func_name = name.value
+    check_call(_LIB.MXFuncGetInfo(
+        handle, ctypes.byref(name), ctypes.byref(desc),
+        ctypes.byref(num_args),
+        ctypes.byref(arg_names),
+        ctypes.byref(arg_types),
+        ctypes.byref(arg_descs)))
+    func_name = py_str(name.value)
 
     param_str = []
     for i in range(num_args.value):
-        ret = '%s : %s' % (arg_names[i], arg_types[i])
+        ret = '%s : %s' % (py_str(arg_names[i]), py_str(arg_types[i]))
         if len(arg_descs[i]) != 0:
-            ret += '\n    ' + arg_descs[i]
+            ret += '\n    ' + py_str(arg_descs[i])
         param_str.append(ret)
 
     doc_str = ('%s\n\n' +
@@ -245,7 +243,7 @@ def _make_narray_function(handle):
                '-------\n' +
                'out : NArray\n'+
                '    The output of binary function.')
-    doc_str = doc_str % (desc.value, '\n'.join(param_str))
+    doc_str = doc_str % (py_str(desc.value), '\n'.join(param_str))
 
     # Definition of internal functions.
     def binary_narray_function(lhs, rhs, out=None):
@@ -258,11 +256,10 @@ def binary_narray_function(lhs, rhs, out=None):
             if not accept_empty_mutate:
                 raise TypeError('argument out is required to call %s' % func_name)
             out = NArray(_new_empty_handle())
-        check_call(_LIB.MXFuncInvoke( \
-                handle, \
-                c_array(NArrayHandle, (lhs.handle, rhs.handle)), \
-                c_array(mx_float, ()), \
-                c_array(NArrayHandle, (out.handle,))))
+        check_call(_LIB.MXFuncInvoke(handle,
+                                     c_array(NArrayHandle, (lhs.handle, rhs.handle)),
+                                     c_array(mx_float, ()),
+                                     c_array(NArrayHandle, (out.handle,))))
         return out
 
     def unary_narray_function(src, out=None):
@@ -327,7 +324,7 @@ def generic_narray_function(*args, **kwargs):
     ret_function.__name__ = func_name
     ret_function.__doc__ = doc_str
     return ret_function
-
+# pylint: enable=too-many-locals, invalid-name
 
 def _init_narray_module():
     """List and add all the narray functions to current module."""
diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index 899f78ec22b8..c35f84fda25d 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -1,12 +1,12 @@
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access, too-many-locals, fixme
+# pylint: disable=invalid-name, protected-access, fixme
 """Symbol support of mxnet"""
 from __future__ import absolute_import
 
 import ctypes
 import sys
 from .base import _LIB
-from .base import c_array, c_str, mx_uint, string_types
+from .base import c_array, c_str, mx_uint, py_str, string_types
 from .base import NArrayHandle, ExecutorHandle, SymbolHandle
 from .base import check_call
 from .context import Context
@@ -92,8 +92,8 @@ def _compose(self, *args, **kwargs):
         else:
             keys = None
         args = c_array(SymbolHandle, [s.handle for s in args])
-        check_call(_LIB.MXSymbolCompose( \
-            self.handle, name, num_args, keys, args))
+        check_call(_LIB.MXSymbolCompose(
+            self.handle, name, num_args, keys, args))
 
     def list_arguments(self):
         """List all the arguments in the symbol.
@@ -105,9 +105,9 @@ def list_arguments(self):
         """
         size = ctypes.c_uint()
         sarr = ctypes.POINTER(ctypes.c_char_p)()
-        check_call(_LIB.MXSymbolListArguments( \
-            self.handle, ctypes.byref(size), ctypes.byref(sarr)))
-        return [sarr[i] for i in range(size.value)]
+        check_call(_LIB.MXSymbolListArguments(
+            self.handle, ctypes.byref(size), ctypes.byref(sarr)))
+        return [py_str(sarr[i]) for i in range(size.value)]
 
     def list_returns(self):
         """List all returns in the symbol.
@@ -119,9 +119,9 @@ def list_returns(self):
         """
         size = ctypes.c_uint()
         sarr = ctypes.POINTER(ctypes.c_char_p)()
-        check_call(_LIB.MXSymbolListReturns( \
-            self.handle, ctypes.byref(size), ctypes.byref(sarr)))
-        return [sarr[i] for i in range(size.value)]
+        check_call(_LIB.MXSymbolListReturns(
+            self.handle, ctypes.byref(size), ctypes.byref(sarr)))
+        return [py_str(sarr[i]) for i in range(size.value)]
 
     def infer_shape(self, *args, **kwargs):
         """Infer the shape of outputs and arguments of given known shapes of arguments.
@@ -148,6 +148,7 @@ def infer_shape(self, *args, **kwargs):
            List of shapes of outputs.
            The order is in the same order as list_returns()
         """
+        # pylint: disable=too-many-locals
         if len(args) != 0 and len(kwargs) != 0:
             raise ValueError('Can only specify known argument \
                     shapes either by positional or kwargs way.')
@@ -176,26 +177,27 @@ def infer_shape(self, *args, **kwargs):
         out_shape_ndim = ctypes.POINTER(mx_uint)()
         out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
         complete = ctypes.c_int()
-        check_call(_LIB.MXSymbolInferShape( \
-            self.handle, len(indptr) - 1, \
-            c_array(ctypes.c_char_p, keys), \
-            c_array(mx_uint, indptr), \
-            c_array(mx_uint, sdata), \
-            ctypes.byref(arg_shape_size), \
-            ctypes.byref(arg_shape_ndim), \
-            ctypes.byref(arg_shape_data), \
-            ctypes.byref(out_shape_size), \
-            ctypes.byref(out_shape_ndim), \
-            ctypes.byref(out_shape_data), \
-            ctypes.byref(complete)))
+        check_call(_LIB.MXSymbolInferShape(
+            self.handle, len(indptr) - 1,
+            c_array(ctypes.c_char_p, keys),
+            c_array(mx_uint, indptr),
+            c_array(mx_uint, sdata),
+            ctypes.byref(arg_shape_size),
+            ctypes.byref(arg_shape_ndim),
+            ctypes.byref(arg_shape_data),
+            ctypes.byref(out_shape_size),
+            ctypes.byref(out_shape_ndim),
+            ctypes.byref(out_shape_data),
+            ctypes.byref(complete)))
         if complete.value != 0:
-            arg_shapes = [tuple(arg_shape_data[i][:arg_shape_ndim[i]]) \
-                          for i in range(arg_shape_size.value)]
-            out_shapes = [tuple(out_shape_data[i][:out_shape_ndim[i]]) \
-                          for i in range(out_shape_size.value)]
+            arg_shapes = [
+                tuple(arg_shape_data[i][:arg_shape_ndim[i]]) for i in range(arg_shape_size.value)]
+            out_shapes = [
+                tuple(out_shape_data[i][:out_shape_ndim[i]]) for i in range(out_shape_size.value)]
             return (arg_shapes, out_shapes)
         else:
             return (None, None)
+        # pylint: enable=too-many-locals
 
     def debug_str(self):
         """Get a debug string.
@@ -206,9 +208,9 @@ def debug_str(self):
             Debug string of the symbol.
         """
         debug_str = ctypes.c_char_p()
-        check_call(_LIB.MXSymbolPrint( \
-            self.handle, ctypes.byref(debug_str)))
-        return debug_str.value
+        check_call(_LIB.MXSymbolPrint(
+            self.handle, ctypes.byref(debug_str)))
+        return py_str(debug_str.value)
 
     def bind(self, ctx, args, args_grad, reqs):
         """bind current symbol to get an executor.
@@ -275,7 +277,7 @@ def Variable(name):
     if not isinstance(name, string_types):
         raise TypeError('Expect a string for variable `name`')
     handle = SymbolHandle()
-    check_call(_LIB.MXSymbolCreateVariable(name, ctypes.byref(handle)))
+    check_call(_LIB.MXSymbolCreateVariable(c_str(name), ctypes.byref(handle)))
     return Symbol(handle)
 
 
@@ -312,18 +314,18 @@ def _make_atomic_symbol_function(handle):
     arg_types = ctypes.POINTER(ctypes.c_char_p)()
     arg_descs = ctypes.POINTER(ctypes.c_char_p)()
 
-    check_call(_LIB.MXSymbolGetAtomicSymbolInfo( \
-        handle, ctypes.byref(name), ctypes.byref(desc), \
-        ctypes.byref(num_args), \
-        ctypes.byref(arg_names), \
-        ctypes.byref(arg_types), \
-        ctypes.byref(arg_descs)))
-    func_name = name.value
+    check_call(_LIB.MXSymbolGetAtomicSymbolInfo(
+        handle, ctypes.byref(name), ctypes.byref(desc),
+        ctypes.byref(num_args),
+        ctypes.byref(arg_names),
+        ctypes.byref(arg_types),
+        ctypes.byref(arg_descs)))
+    func_name = py_str(name.value)
 
     param_str = []
     for i in range(num_args.value):
-        ret = '%s : %s' % (arg_names[i], arg_types[i])
+        ret = '%s : %s' % (py_str(arg_names[i]), py_str(arg_types[i]))
         if len(arg_descs[i]) != 0:
-            ret += '\n    ' + arg_descs[i]
+            ret += '\n    ' + py_str(arg_descs[i])
         param_str.append(ret)
 
     doc_str = ('%s\n\n' +
@@ -336,7 +338,7 @@ def _make_atomic_symbol_function(handle):
                '-------\n' +
                'symbol: Symbol\n'+
                '    The result symbol.')
-    doc_str = doc_str % (desc.value, '\n'.join(param_str))
+    doc_str = doc_str % (py_str(desc.value), '\n'.join(param_str))
 
     def creator(*args, **kwargs):
         """Activation Operator of Neural Net.
@@ -373,8 +375,9 @@ def creator(*args, **kwargs):
             ctypes.byref(sym_handle)))
 
         if len(args) != 0 and len(symbol_kwargs) != 0:
-            raise TypeError('%s can only accept input \
-                Symbols either as positional or keyword arguments, not both' % func_name)
+            raise TypeError(
+                '%s can only accept input'
+                'Symbols either as positional or keyword arguments, not both' % func_name)
         s = Symbol(sym_handle)
         s._compose(*args, name=name, **symbol_kwargs)
@@ -394,7 +397,7 @@ def _init_symbol_module():
         ctypes.byref(plist)))
     module_obj = sys.modules[__name__]
     for i in range(size.value):
-        hdl = ctypes.c_void_p(plist[i])
+        hdl = SymbolHandle(plist[i])
         function = _make_atomic_symbol_function(hdl)
         setattr(module_obj, function.__name__, function)
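Note (not part of the diff): the c_str(name) change above is what makes Variable creation work under Python 3, since MXSymbolCreateVariable expects a char* rather than a Python str. A minimal user-level sketch, assuming the package is importable (e.g. via the PYTHONPATH export added to .travis.yml):

    import mxnet as mx

    data = mx.symbol.Variable('data')          # name is now passed through c_str()
    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
    print(fc1.list_arguments())                # expected: ['data', 'fc1_weight', 'fc1_bias']
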
diff --git a/python/test_infer_shape.py b/python/test_infer_shape.py
deleted file mode 100644
index 236ad1e7ae71..000000000000
--- a/python/test_infer_shape.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# pylint: skip-file
-import mxnet as mx
-
-data = mx.symbol.Variable('data')
-
-fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=1000)
-fc2 = mx.symbol.FullyConnected(data=fc1, name='fc2', num_hidden=10)
-fc3 = mx.symbol.FullyConnected( name='fc2', num_hidden=10)
-
-print fc2.list_arguments()
-
-data_shape = (100, 100)
-arg_shapes, out_shapes = fc2.infer_shape(data=data_shape)
-print dict(zip(fc2.list_arguments(), arg_shapes))
-print dict(zip(fc2.list_returns(), out_shapes))
-
-weight_shape= (1, 100)
-data_shape = (100, 100)
-arg_shapes, out_shapes = fc2.infer_shape(data=data_shape, fc1_weight=weight_shape)
diff --git a/python/test_python.py b/python/test_python.py
deleted file mode 100644
index 905d16c283f8..000000000000
--- a/python/test_python.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# pylint: skip-file
-import mxnet as mx
-
-a = mx.narray.create((3000, 4000))
-b = mx.narray.create((3000, 4000))
-a.numpy[:] = 10
-b.numpy[:] = 11
-print(a.numpy)
-
-c = b * a
-
-cc = mx.narray.NArray._mul(b, a)
-
-print(c.context)
-print(cc.numpy)
-d = c.copyto(mx.Context('cpu', 0))
-
-print(d.numpy)
-
-with mx.Context('gpu', 0) as ctx:
-    # gpu operations
-    print mx.current_context()
-    print ctx
-    a_gpu = a.copyto(ctx)
-    b_gpu = b.copyto(ctx)
-    c_gpu = b * a
-
-d_cpu = c_gpu.copyto(mx.current_context())
-print d_cpu.numpy
-
diff --git a/python/test_symbol.py b/python/test_symbol.py
deleted file mode 100644
index 451ee39775c9..000000000000
--- a/python/test_symbol.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# pylint: skip-file
-import mxnet as mx
-
-data = mx.symbol.Variable('data')
-print data.debug_str()
-
-fc1 = mx.symbol.FullyConnected(data=data, name='fc1', no_bias=0)
-fc2 = mx.symbol.FullyConnected(data=fc1, name='fc2', no_bias=0)
-
-print fc2.debug_str()
-
-print fc2.list_arguments()
-
-fc3 = mx.symbol.FullyConnected(name='fc3')
-fc4 = mx.symbol.FullyConnected(data=fc3, name='fc4')
-
-print fc4.debug_str()
-
-print "-" * 10
-composed_fc4 = fc4(fc3_data=fc2, name='composed')
-print composed_fc4.debug_str()
-
-multi_out = mx.symbol.Group([composed_fc4, fc2])
-
-print multi_out.debug_str()
-print multi_out.list_arguments()
-print multi_out.list_returns()
diff --git a/scripts/travis_script.sh b/scripts/travis_script.sh
index ce9a22110fa2..1046051464be 100755
--- a/scripts/travis_script.sh
+++ b/scripts/travis_script.sh
@@ -3,26 +3,40 @@
 # main script of travis
 if [ ${TASK} == "lint" ]; then
     make lint || exit -1
+    exit 0
 fi
 
 if [ ${TASK} == "doc" ]; then
     make doc 2>log.txt
     (cat log.txt|grep warning) && exit -1
+    exit 0
 fi
 
+# prereqs for things that need make
+cp make/config.mk config.mk
+echo "USE_BLAS=blas" >> config.mk
+echo "USE_CUDNN=0" >> config.mk
+echo "CXX=g++-4.8" >> config.mk
+export CXX="g++-4.8"
+
+
 if [ ${TASK} == "build" ]; then
-    echo "USE_BLAS=blas" >> config.mk
     echo "USE_CUDA=1" >> config.mk
-    echo "USE_CUDNN=0" >> config.mk
-    echo "CXX=g++-4.8" >> config.mk
-    export CXX="g++-4.8"
     ./dmlc-core/scripts/setup_nvcc.sh $NVCC_PREFIX
     make all || exit -1
 fi
 
-if [ ${TASK} == "test" ]; then
-    cd test
+if [ ${TASK} == "python" ]; then
+    echo "USE_CUDA=0" >> config.mk
+    make all || exit -1
+    nosetests tests/python || exit -1
+fi
+
+if [ ${TASK} == "python3" ]; then
+    echo "USE_CUDA=0" >> config.mk
     make all || exit -1
-    ../scripts/travis_runtest.sh || exit -1
+    nosetests3 tests/python || exit -1
 fi
+
+# TODO(yutian): add unittest back
diff --git a/test/.gitignore b/tests/.gitignore
similarity index 100%
rename from test/.gitignore
rename to tests/.gitignore
diff --git a/tests/python/models.py b/tests/python/models.py
new file mode 100644
index 000000000000..d7fb74e4fd1e
--- /dev/null
+++ b/tests/python/models.py
@@ -0,0 +1,10 @@
+"""This file defines various models used in the test"""
+import mxnet as mx
+
+def mlp2():
+    data = mx.symbol.Variable('data')
+    out = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=1000)
+    out = mx.symbol.Activation(data=out, act_type='relu')
+    out = mx.symbol.FullyConnected(data=out, name='fc2', num_hidden=10)
+    return out
+
diff --git a/tests/python/test_inter_shape.py b/tests/python/test_inter_shape.py
new file mode 100644
index 000000000000..fa18ff175fbf
--- /dev/null
+++ b/tests/python/test_inter_shape.py
@@ -0,0 +1,30 @@
+# pylint: skip-file
+import mxnet as mx
+import models
+from nose.tools import *
+
+def test_mlp2_infer_shape():
+    # Build MLP
+    out = models.mlp2()
+    # infer shape
+    data_shape = (100, 100)
+    arg_shapes, out_shapes = out.infer_shape(data=data_shape)
+    arg_shape_dict = dict(zip(out.list_arguments(), arg_shapes))
+
+    assert len(out_shapes) == 1
+    assert out_shapes[0] == (100, 10)
+    true_shapes = {'fc2_bias': (10,),
+                   'fc2_weight' : (10, 1000),
+                   'fc1_bias' : (1000,),
+                   'fc1_weight' : (1000,100) }
+    for k, v in true_shapes.items():
+        assert arg_shape_dict[k] == v
+
+@raises(mx.MXNetError)
+def test_mlp2_infer_error():
+    # Test shape inconsistent case
+    out = models.mlp2()
+    weight_shape= (1, 100)
+    data_shape = (100, 100)
+    arg_shapes, out_shapes = out.infer_shape(data=data_shape, fc1_weight=weight_shape)
+
diff --git a/tests/python/test_narray.py b/tests/python/test_narray.py
new file mode 100644
index 000000000000..986176de81fe
--- /dev/null
+++ b/tests/python/test_narray.py
@@ -0,0 +1,46 @@
+import mxnet as mx
+import numpy as np
+
+def reldiff(a, b):
+    diff = np.sum(np.abs(a - b))
+    norm = np.sum(np.abs(a))
+    reldiff = diff / norm
+    return reldiff
+
+def check_with_uniform(uf, arg_shapes, dim=None):
+    """check function consistency with uniform random numbers"""
+    if isinstance(arg_shapes, int):
+        assert dim
+        shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim))
+        arg_shapes = [shape] * arg_shapes
+    narray_arg = []
+    numpy_arg = []
+    for s in arg_shapes:
+        narr = mx.narray.create(s)
+        npy = np.random.uniform(-10, 10, s)
+        narr.numpy[:] = npy
+        narray_arg.append(narr)
+        numpy_arg.append(npy)
+    out1 = uf(*narray_arg)
+    out2 = uf(*numpy_arg)
+    assert out1.shape == out2.shape
+    assert reldiff(out1.numpy, out2) < 1e-6
+
+
+def test_narray_elementwise():
+    np.random.seed(0)
+    nrepeat = 10
+    maxdim = 4
+    for repeat in range(nrepeat):
+        for dim in range(1, maxdim):
+            check_with_uniform(lambda x, y: x + y, 2, dim)
+            check_with_uniform(lambda x, y: x - y, 2, dim)
+            check_with_uniform(lambda x, y: x * y, 2, dim)
+            # check_with_uniform(lambda x, y: x / y, 2, dim)
+
+
+def test_narray_copy():
+    c = mx.narray.create((10,10))
+    c.numpy[:] = np.random.uniform(-10, 10, c.shape)
+    d = c.copyto(mx.Context('cpu', 0))
+    assert np.sum(np.abs(c.numpy != d.numpy)) == 0.0
diff --git a/tests/python/test_symbol.py b/tests/python/test_symbol.py
new file mode 100644
index 000000000000..b08f6a310570
--- /dev/null
+++ b/tests/python/test_symbol.py
@@ -0,0 +1,28 @@
+import mxnet as mx
+import models
+
+def test_symbol_basic():
+    mlist = []
+    mlist.append(models.mlp2())
+    for m in mlist:
+        m.list_arguments()
+        m.list_returns()
+
+
+def test_compose():
+    data = mx.symbol.Variable('data')
+    net1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
+    net1 = mx.symbol.FullyConnected(data=net1, name='fc2', num_hidden=100)
+    net1.list_arguments() == ['data',
+                              'fc1_weight', 'fc1_bias',
+                              'fc2_weight', 'fc2_bias']
+
+    net2 = mx.symbol.FullyConnected(name='fc3', num_hidden=10)
+    net2 = mx.symbol.Activation(data=net2)
+    net2 = mx.symbol.FullyConnected(data=net2, name='fc4', num_hidden=20)
+    print(net2.debug_str())
+
+    composed = net2(fc3_data=net1, name='composed')
+    print(composed.debug_str())
+    multi_out = mx.symbol.Group([composed, net1])
+    assert len(multi_out.list_returns()) == 2
diff --git a/test/test_storage.cc b/tests/test_storage.cc
similarity index 100%
rename from test/test_storage.cc
rename to tests/test_storage.cc
diff --git a/test/test_threaded_engine.cc b/tests/test_threaded_engine.cc
similarity index 100%
rename from test/test_threaded_engine.cc
rename to tests/test_threaded_engine.cc
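Note (not part of the diff): the element-wise tests added above all follow the same pattern — run an operation on NArray and on numpy with identical inputs, then compare the results. A compact illustrative sketch of that pattern, using the same narray API exercised in tests/python/test_narray.py:

    import numpy as np
    import mxnet as mx

    a = mx.narray.create((3, 4))
    b = mx.narray.create((3, 4))
    a.numpy[:] = np.random.uniform(-10, 10, (3, 4))
    b.numpy[:] = np.random.uniform(-10, 10, (3, 4))
    c = a + b                                    # NArray op, dispatched through MXFuncInvoke
    assert np.allclose(c.numpy, a.numpy + b.numpy)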