diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 4df97e594a01..fa26dfff9628 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -72,9 +72,28 @@ np.int8: 5, np.int64: 6, np.bool_: 7, + np.int16: 8, + np.uint16 : 9, + np.uint32 : 10, + np.uint64 : 11, np.dtype([('bfloat16', np.uint16)]): 12, } +def _register_platform_dependent_mx_dtype(): + """Register platform dependent types to the fixed size counterparts.""" + kind_map = {'i': 'int', 'u': 'uint', 'f': 'float'} + for np_type in [ + np.byte, np.ubyte, np.short, np.ushort, np.intc, np.uintc, np.int_, + np.uint, np.longlong, np.ulonglong, np.half, np.float16, np.single, + np.double, np.longdouble]: + dtype = np.dtype(np_type) + kind, size = dtype.kind, dtype.itemsize + bits = size * 8 + fixed_dtype = getattr(np, kind_map[kind]+str(bits)) + if fixed_dtype in _DTYPE_NP_TO_MX: + _DTYPE_NP_TO_MX[np_type] = _DTYPE_NP_TO_MX[fixed_dtype] +_register_platform_dependent_mx_dtype() + _DTYPE_MX_TO_NP = { -1: None, 0: np.float32, @@ -85,6 +104,10 @@ 5: np.int8, 6: np.int64, 7: np.bool_, + 8: np.int16, + 9: np.uint16, + 10: np.uint32, + 11: np.uint64, 12: np.dtype([('bfloat16', np.uint16)]), } diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 17c549193324..5274408e4403 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -185,7 +185,8 @@ def _as_mx_np_array(object, ctx=None): if isinstance(object, ndarray): return object elif isinstance(object, _np.ndarray): - return array(object, dtype=object.dtype, ctx=ctx) + np_dtype = _np.dtype(object.dtype).type + return array(object, dtype=np_dtype, ctx=ctx) elif isinstance(object, (integer_types, numeric_types)): return object elif isinstance(object, (list, tuple)):