Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
map platform-dependent types to fixed-size types
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Jul 25, 2020
1 parent 84436cf commit cd5de03
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
25 changes: 24 additions & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,19 +72,42 @@
np.int8: 5,
np.int64: 6,
np.bool_: 7,
np.int16: 8,
np.uint16 : 9,
np.uint32 : 10,
np.uint64 : 11,
np.dtype([('bfloat16', np.uint16)]): 12,
}

def _register_platform_dependent_mx_dtype():
"""Register platform dependent types to the fixed size counterparts."""
kind_map = {'i': 'int', 'u': 'uint', 'f': 'float'}
for np_type in [
np.byte, np.ubyte, np.short, np.ushort, np.intc, np.uintc, np.int_,
np.uint, np.longlong, np.ulonglong, np.half, np.float16, np.single,
np.double, np.longdouble]:
dtype = np.dtype(np_type)
kind, size = dtype.kind, dtype.itemsize
bits = size * 8
fixed_dtype = getattr(np, kind_map[kind]+str(bits))
if fixed_dtype in _DTYPE_NP_TO_MX:
_DTYPE_NP_TO_MX[np_type] = _DTYPE_NP_TO_MX[fixed_dtype]
_register_platform_dependent_mx_dtype()

_DTYPE_MX_TO_NP = {
-1: None,
0: np.float32,
0: np.float32
1: np.float64,
2: np.float16,
3: np.uint8,
4: np.int32,
5: np.int8,
6: np.int64,
7: np.bool_,
8: np.int16,
9: np.uint16,
10: np.uint32,
11: np.uint64,
12: np.dtype([('bfloat16', np.uint16)]),
}

Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ def _as_mx_np_array(object, ctx=None):
if isinstance(object, ndarray):
return object
elif isinstance(object, _np.ndarray):
return array(object, dtype=object.dtype, ctx=ctx)
np_dtype = _np.dtype(object.dtype).type
return array(object, dtype=np_dtype, ctx=ctx)
elif isinstance(object, (integer_types, numeric_types)):
return object
elif isinstance(object, (list, tuple)):
Expand Down

0 comments on commit cd5de03

Please sign in to comment.