From 48f2617f0fad54ca32acb7816409fa20a4b6b7c6 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 26 Sep 2015 17:42:39 -0700 Subject: [PATCH] [CTYPES] Fix python interface to be more consistent with C API --- include/mxnet/c_api.h | 16 +++++++--------- python/mxnet/base.py | 1 + python/mxnet/executor.py | 10 +++++++--- python/mxnet/io.py | 3 ++- python/mxnet/ndarray.py | 25 ++++++++++++------------- python/mxnet/symbol.py | 17 ++++++++++------- src/c_api.cc | 34 +++++++++++++++++----------------- 7 files changed, 56 insertions(+), 50 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 503d4d4fa554..d9cbb25a92ad 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -19,8 +19,6 @@ /*! \brief manually define unsigned int */ typedef unsigned int mx_uint; -/*! \brief manually define unsigned long int */ -typedef unsigned long int mx_ulong; // NOLINT(*) /*! \brief manually define unsigned int */ typedef float mx_float; // all the handles are simply void * @@ -108,7 +106,7 @@ MXNET_DLL int MXNDArrayCreate(const mx_uint *shape, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArrayLoadFromRawBytes(const void *buf, - mx_ulong size, + size_t size, NDArrayHandle *out); /*! * \brief save the NDArray into raw bytes. @@ -118,7 +116,7 @@ MXNET_DLL int MXNDArrayLoadFromRawBytes(const void *buf, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle, - mx_ulong *out_size, + size_t *out_size, const char **out_buf); /*! * \brief Save list of narray into the file. @@ -172,8 +170,8 @@ MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle, * \param size the memory size we want to copy into. */ MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle, - mx_float *data, - size_t size); + mx_float *data, + size_t size); /*! * \brief Wait until all the pending writes with respect NDArray are finished. * Always call this before read data out synchronizely. @@ -354,7 +352,7 @@ MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, - int num_param, + mx_uint num_param, const char **keys, const char **vals, SymbolHandle *out); @@ -558,7 +556,7 @@ MXNET_DLL int MXExecutorPrint(ExecutorHandle symbol, const char **out_str); * \param is_train bool value to indicate whether the forward pass is for evaluation * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXExecutorForward(ExecutorHandle handle, bool is_train); +MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train); /*! * \brief Excecutor run backward * @@ -632,7 +630,7 @@ MXNET_DLL int MXListDataIters(mx_uint *out_size, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle, - int num_param, + mx_uint num_param, const char **keys, const char **vals, DataIterHandle *out); diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 56638540ba8e..7a5606066ac6 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -73,6 +73,7 @@ def _load_lib(): # type definitions mx_uint = ctypes.c_uint mx_float = ctypes.c_float +mx_float_p = ctypes.POINTER(mx_float) NDArrayHandle = ctypes.c_void_p FunctionHandle = ctypes.c_void_p SymbolCreatorHandle = ctypes.c_void_p diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 828fb93bbc4d..57a1ad1d238c 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -34,9 +34,10 @@ def forward(self, is_train=True): ---------- is_train: bool whether this forward is for evaluation purpose - Note: for test only network, please indicate in Bind (TODO) """ - check_call(_LIB.MXExecutorForward(self.handle, is_train)) + check_call(_LIB.MXExecutorForward( + self.handle, + ctypes.c_int(int(is_train)))) def backward(self, head_grads=None): """Do backward on heads' gradient. @@ -55,7 +56,10 @@ def backward(self, head_grads=None): if not isinstance(obj, NDArray): raise TypeError("inputs must be NDArray") ndarray = c_array(NDArrayHandle, [item.handle for item in head_grads]) - check_call(_LIB.MXExecutorBackward(self.handle, len(head_grads), ndarray)) + check_call(_LIB.MXExecutorBackward( + self.handle, + mx_uint(len(head_grads)), + ndarray)) def debug_str(self): """Get a debug string about internal execution plan. diff --git a/python/mxnet/io.py b/python/mxnet/io.py index eb89d5c44226..610b6544eb85 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -310,7 +310,8 @@ def creator(*args, **kwargs): param_vals = c_array(ctypes.c_char_p, param_vals) iter_handle = DataIterHandle() check_call(_LIB.MXDataIterCreateIter( - handle, len(param_keys), + handle, + mx_uint(len(param_keys)), param_keys, param_vals, ctypes.byref(iter_handle))) diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index ab7b928d1a14..5418047ee27f 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -8,7 +8,7 @@ import numpy as np from .base import _LIB, string_types, numeric_types from .base import c_array, py_str, c_str -from .base import mx_uint, mx_float, NDArrayHandle, FunctionHandle +from .base import mx_uint, mx_float, mx_float_p, NDArrayHandle, FunctionHandle from .base import ctypes2buffer from .base import check_call, ctypes2docstring from .context import Context @@ -38,10 +38,10 @@ def _new_alloc_handle(shape, ctx, delay_alloc): hdl = NDArrayHandle() check_call(_LIB.MXNDArrayCreate( c_array(mx_uint, shape), - len(shape), - ctx.device_typeid, - ctx.device_id, - int(delay_alloc), + mx_uint(len(shape)), + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + ctypes.c_int(int(delay_alloc)), ctypes.byref(hdl))) return hdl @@ -92,7 +92,6 @@ def __sub__(self, other): return NDArray._minus_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) - def __isub__(self, other): if isinstance(other, NDArray): return NDArray._minus(self, other, out=self) @@ -158,7 +157,7 @@ def __getstate__(self): this = self.__dict__.copy() handle = this['handle'] if handle is not None: - length = ctypes.c_ulong() + length = ctypes.c_size_t() cptr = ctypes.POINTER(ctypes.c_char)() check_call(_LIB.MXNDArraySaveRawBytes(self.handle, ctypes.byref(length), @@ -172,7 +171,7 @@ def __setstate__(self, state): buf = handle handle = NDArrayHandle() ptr = (ctypes.c_char * len(buf)).from_buffer(buf) - length = ctypes.c_ulong(len(buf)) + length = ctypes.c_size_t(len(buf)) check_call(_LIB.MXNDArrayLoadFromRawBytes(ptr, length, ctypes.byref(handle))) state['handle'] = handle self.__dict__.update(state) @@ -223,8 +222,8 @@ def _sync_copyfrom(self, source_array): raise ValueError('array shape do not match the shape of NDArray') check_call(_LIB.MXNDArraySyncCopyFromCPU( self.handle, - source_array.ctypes.data_as(ctypes.POINTER(mx_float)), - source_array.size)) + source_array.ctypes.data_as(mx_float_p), + ctypes.c_size_t(source_array.size))) def _slice(self, start, stop): """Return a sliiced NDArray that shares memory with current one. @@ -292,8 +291,8 @@ def asnumpy(self): data = np.empty(self.shape, dtype=np.float32) check_call(_LIB.MXNDArraySyncCopyToCPU( self.handle, - data.ctypes.data, - data.size)) + data.ctypes.data_as(mx_float_p), + ctypes.c_size_t(data.size))) return data def copyto(self, other): @@ -505,7 +504,7 @@ def save(fname, data): handles.append(val.handle) keys = None check_call(_LIB.MXNDArraySave(c_str(fname), - len(handles), + mx_uint(len(handles)), c_array(NDArrayHandle, handles), keys)) diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 1c8841a74460..e8b8af78fe3b 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -293,7 +293,8 @@ def infer_shape(self, *args, **kwargs): aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))() complete = ctypes.c_int() check_call(_LIB.MXSymbolInferShape( - self.handle, len(indptr) - 1, + self.handle, + mx_uint(len(indptr) - 1), c_array(ctypes.c_char_p, keys), c_array(mx_uint, indptr), c_array(mx_uint, sdata), @@ -561,13 +562,13 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): handle = ExecutorHandle() check_call(_LIB.MXExecutorBind(self.handle, - ctx.device_typeid, - ctx.device_id, - len(args), + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + mx_uint(len(args)), args_handle, args_grad_handle, reqs_array, - len(aux_states), + mx_uint(len(aux_states)), aux_args_handle, ctypes.byref(handle))) executor = Executor(handle) @@ -642,7 +643,8 @@ def Group(symbols): ihandles.append(sym.handle) handle = SymbolHandle() check_call(_LIB.MXSymbolCreateGroup( - len(ihandles), c_array(SymbolHandle, ihandles), ctypes.byref(handle))) + mx_uint(len(ihandles)), + c_array(SymbolHandle, ihandles), ctypes.byref(handle))) return Symbol(handle) @@ -771,7 +773,8 @@ def creator(*args, **kwargs): param_vals = c_array(ctypes.c_char_p, param_vals) sym_handle = SymbolHandle() check_call(_LIB.MXSymbolCreateAtomicSymbol( - handle, len(param_keys), + handle, + mx_uint(len(param_keys)), param_keys, param_vals, ctypes.byref(sym_handle))) diff --git a/src/c_api.cc b/src/c_api.cc index 0e3cd487d99c..5bfd72fb3e1b 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -214,7 +214,7 @@ int MXNDArrayCreate(const mx_uint *shape, } int MXNDArrayLoadFromRawBytes(const void *buf, - mx_ulong size, + size_t size, NDArrayHandle *out) { NDArray *ptr = nullptr; API_BEGIN(); @@ -228,7 +228,7 @@ int MXNDArrayLoadFromRawBytes(const void *buf, } int MXNDArraySaveRawBytes(NDArrayHandle handle, - mx_ulong *out_size, + size_t *out_size, const char **out_buf) { MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); @@ -472,7 +472,7 @@ int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, } int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, - int num_param, + mx_uint num_param, const char **keys, const char **vals, SymbolHandle *out) { @@ -483,7 +483,7 @@ int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, OperatorPropertyReg *e = static_cast(creator); op = e->body(); std::vector > kwargs; - for (int i = 0; i < num_param; ++i) { + for (mx_uint i = 0; i < num_param; ++i) { kwargs.push_back({std::string(keys[i]), std::string(vals[i])}); } op->Init(kwargs); @@ -762,10 +762,10 @@ int MXExecutorPrint(ExecutorHandle handle, const char **out_str) { API_END(); } -int MXExecutorForward(ExecutorHandle handle, bool is_train) { +int MXExecutorForward(ExecutorHandle handle, int is_train) { API_BEGIN(); Executor *exec = static_cast(handle); - exec->Forward(is_train); + exec->Forward(is_train != 0); API_END(); } @@ -851,28 +851,28 @@ int MXListDataIters(mx_uint *out_size, } int MXDataIterGetIterInfo(DataIterCreator creator, - const char **name, - const char **description, - mx_uint *num_args, - const char ***arg_names, - const char ***arg_type_infos, - const char ***arg_descriptions) { + const char **name, + const char **description, + mx_uint *num_args, + const char ***arg_names, + const char ***arg_type_infos, + const char ***arg_descriptions) { DataIteratorReg *e = static_cast(creator); return MXAPIGetFunctionRegInfo(e, name, description, num_args, arg_names, arg_type_infos, arg_descriptions); } int MXDataIterCreateIter(DataIterCreator creator, - int num_param, - const char **keys, - const char **vals, - DataIterHandle *out) { + mx_uint num_param, + const char **keys, + const char **vals, + DataIterHandle *out) { IIterator *iter = nullptr; API_BEGIN(); DataIteratorReg *e = static_cast(creator); iter = e->body(); std::vector > kwargs; - for (int i = 0; i < num_param; ++i) { + for (mx_uint i = 0; i < num_param; ++i) { kwargs.push_back({std::string(keys[i]), std::string(vals[i])}); } iter->Init(kwargs);