Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[CTYPES] Fix python interface to be more consistent with C API #164

Merged
merged 1 commit into from
Sep 27, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

/*! \brief manually define unsigned int */
typedef unsigned int mx_uint;
/*! \brief manually define unsigned long int */
typedef unsigned long int mx_ulong; // NOLINT(*)
/*! \brief manually define unsigned int */
typedef float mx_float;
// all the handles are simply void *
Expand Down Expand Up @@ -108,7 +106,7 @@ MXNET_DLL int MXNDArrayCreate(const mx_uint *shape,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayLoadFromRawBytes(const void *buf,
mx_ulong size,
size_t size,
NDArrayHandle *out);
/*!
* \brief save the NDArray into raw bytes.
Expand All @@ -118,7 +116,7 @@ MXNET_DLL int MXNDArrayLoadFromRawBytes(const void *buf,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArraySaveRawBytes(NDArrayHandle handle,
mx_ulong *out_size,
size_t *out_size,
const char **out_buf);
/*!
* \brief Save list of narray into the file.
Expand Down Expand Up @@ -172,8 +170,8 @@ MXNET_DLL int MXNDArraySyncCopyFromCPU(NDArrayHandle handle,
* \param size the memory size we want to copy into.
*/
MXNET_DLL int MXNDArraySyncCopyToCPU(NDArrayHandle handle,
mx_float *data,
size_t size);
mx_float *data,
size_t size);
/*!
* \brief Wait until all the pending writes with respect NDArray are finished.
* Always call this before read data out synchronizely.
Expand Down Expand Up @@ -354,7 +352,7 @@ MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
int num_param,
mx_uint num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
Expand Down Expand Up @@ -558,7 +556,7 @@ MXNET_DLL int MXExecutorPrint(ExecutorHandle symbol, const char **out_str);
* \param is_train bool value to indicate whether the forward pass is for evaluation
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXExecutorForward(ExecutorHandle handle, bool is_train);
MXNET_DLL int MXExecutorForward(ExecutorHandle handle, int is_train);
/*!
* \brief Excecutor run backward
*
Expand Down Expand Up @@ -632,7 +630,7 @@ MXNET_DLL int MXListDataIters(mx_uint *out_size,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterCreateIter(DataIterCreator handle,
int num_param,
mx_uint num_param,
const char **keys,
const char **vals,
DataIterHandle *out);
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def _load_lib():
# type definitions
mx_uint = ctypes.c_uint
mx_float = ctypes.c_float
mx_float_p = ctypes.POINTER(mx_float)
NDArrayHandle = ctypes.c_void_p
FunctionHandle = ctypes.c_void_p
SymbolCreatorHandle = ctypes.c_void_p
Expand Down
10 changes: 7 additions & 3 deletions python/mxnet/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ def forward(self, is_train=True):
----------
is_train: bool
whether this forward is for evaluation purpose
Note: for test only network, please indicate in Bind (TODO)
"""
check_call(_LIB.MXExecutorForward(self.handle, is_train))
check_call(_LIB.MXExecutorForward(
self.handle,
ctypes.c_int(int(is_train))))

def backward(self, head_grads=None):
"""Do backward on heads' gradient.
Expand All @@ -55,7 +56,10 @@ def backward(self, head_grads=None):
if not isinstance(obj, NDArray):
raise TypeError("inputs must be NDArray")
ndarray = c_array(NDArrayHandle, [item.handle for item in head_grads])
check_call(_LIB.MXExecutorBackward(self.handle, len(head_grads), ndarray))
check_call(_LIB.MXExecutorBackward(
self.handle,
mx_uint(len(head_grads)),
ndarray))

def debug_str(self):
"""Get a debug string about internal execution plan.
Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,8 @@ def creator(*args, **kwargs):
param_vals = c_array(ctypes.c_char_p, param_vals)
iter_handle = DataIterHandle()
check_call(_LIB.MXDataIterCreateIter(
handle, len(param_keys),
handle,
mx_uint(len(param_keys)),
param_keys, param_vals,
ctypes.byref(iter_handle)))

Expand Down
25 changes: 12 additions & 13 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from .base import _LIB, string_types, numeric_types
from .base import c_array, py_str, c_str
from .base import mx_uint, mx_float, NDArrayHandle, FunctionHandle
from .base import mx_uint, mx_float, mx_float_p, NDArrayHandle, FunctionHandle
from .base import ctypes2buffer
from .base import check_call, ctypes2docstring
from .context import Context
Expand Down Expand Up @@ -38,10 +38,10 @@ def _new_alloc_handle(shape, ctx, delay_alloc):
hdl = NDArrayHandle()
check_call(_LIB.MXNDArrayCreate(
c_array(mx_uint, shape),
len(shape),
ctx.device_typeid,
ctx.device_id,
int(delay_alloc),
mx_uint(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(delay_alloc)),
ctypes.byref(hdl)))
return hdl

Expand Down Expand Up @@ -92,7 +92,6 @@ def __sub__(self, other):
return NDArray._minus_scalar(self, float(other))
else:
raise TypeError('type %s not supported' % str(type(other)))

def __isub__(self, other):
if isinstance(other, NDArray):
return NDArray._minus(self, other, out=self)
Expand Down Expand Up @@ -158,7 +157,7 @@ def __getstate__(self):
this = self.__dict__.copy()
handle = this['handle']
if handle is not None:
length = ctypes.c_ulong()
length = ctypes.c_size_t()
cptr = ctypes.POINTER(ctypes.c_char)()
check_call(_LIB.MXNDArraySaveRawBytes(self.handle,
ctypes.byref(length),
Expand All @@ -172,7 +171,7 @@ def __setstate__(self, state):
buf = handle
handle = NDArrayHandle()
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
length = ctypes.c_ulong(len(buf))
length = ctypes.c_size_t(len(buf))
check_call(_LIB.MXNDArrayLoadFromRawBytes(ptr, length, ctypes.byref(handle)))
state['handle'] = handle
self.__dict__.update(state)
Expand Down Expand Up @@ -223,8 +222,8 @@ def _sync_copyfrom(self, source_array):
raise ValueError('array shape do not match the shape of NDArray')
check_call(_LIB.MXNDArraySyncCopyFromCPU(
self.handle,
source_array.ctypes.data_as(ctypes.POINTER(mx_float)),
source_array.size))
source_array.ctypes.data_as(mx_float_p),
ctypes.c_size_t(source_array.size)))

def _slice(self, start, stop):
"""Return a sliiced NDArray that shares memory with current one.
Expand Down Expand Up @@ -292,8 +291,8 @@ def asnumpy(self):
data = np.empty(self.shape, dtype=np.float32)
check_call(_LIB.MXNDArraySyncCopyToCPU(
self.handle,
data.ctypes.data,
data.size))
data.ctypes.data_as(mx_float_p),
ctypes.c_size_t(data.size)))
return data

def copyto(self, other):
Expand Down Expand Up @@ -505,7 +504,7 @@ def save(fname, data):
handles.append(val.handle)
keys = None
check_call(_LIB.MXNDArraySave(c_str(fname),
len(handles),
mx_uint(len(handles)),
c_array(NDArrayHandle, handles),
keys))

Expand Down
17 changes: 10 additions & 7 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ def infer_shape(self, *args, **kwargs):
aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))()
complete = ctypes.c_int()
check_call(_LIB.MXSymbolInferShape(
self.handle, len(indptr) - 1,
self.handle,
mx_uint(len(indptr) - 1),
c_array(ctypes.c_char_p, keys),
c_array(mx_uint, indptr),
c_array(mx_uint, sdata),
Expand Down Expand Up @@ -561,13 +562,13 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None):

handle = ExecutorHandle()
check_call(_LIB.MXExecutorBind(self.handle,
ctx.device_typeid,
ctx.device_id,
len(args),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
mx_uint(len(args)),
args_handle,
args_grad_handle,
reqs_array,
len(aux_states),
mx_uint(len(aux_states)),
aux_args_handle,
ctypes.byref(handle)))
executor = Executor(handle)
Expand Down Expand Up @@ -642,7 +643,8 @@ def Group(symbols):
ihandles.append(sym.handle)
handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateGroup(
len(ihandles), c_array(SymbolHandle, ihandles), ctypes.byref(handle)))
mx_uint(len(ihandles)),
c_array(SymbolHandle, ihandles), ctypes.byref(handle)))
return Symbol(handle)


Expand Down Expand Up @@ -771,7 +773,8 @@ def creator(*args, **kwargs):
param_vals = c_array(ctypes.c_char_p, param_vals)
sym_handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateAtomicSymbol(
handle, len(param_keys),
handle,
mx_uint(len(param_keys)),
param_keys, param_vals,
ctypes.byref(sym_handle)))

Expand Down
34 changes: 17 additions & 17 deletions src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ int MXNDArrayCreate(const mx_uint *shape,
}

int MXNDArrayLoadFromRawBytes(const void *buf,
mx_ulong size,
size_t size,
NDArrayHandle *out) {
NDArray *ptr = nullptr;
API_BEGIN();
Expand All @@ -228,7 +228,7 @@ int MXNDArrayLoadFromRawBytes(const void *buf,
}

int MXNDArraySaveRawBytes(NDArrayHandle handle,
mx_ulong *out_size,
size_t *out_size,
const char **out_buf) {
MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
Expand Down Expand Up @@ -472,7 +472,7 @@ int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
}

int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
int num_param,
mx_uint num_param,
const char **keys,
const char **vals,
SymbolHandle *out) {
Expand All @@ -483,7 +483,7 @@ int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
OperatorPropertyReg *e = static_cast<OperatorPropertyReg *>(creator);
op = e->body();
std::vector<std::pair<std::string, std::string> > kwargs;
for (int i = 0; i < num_param; ++i) {
for (mx_uint i = 0; i < num_param; ++i) {
kwargs.push_back({std::string(keys[i]), std::string(vals[i])});
}
op->Init(kwargs);
Expand Down Expand Up @@ -762,10 +762,10 @@ int MXExecutorPrint(ExecutorHandle handle, const char **out_str) {
API_END();
}

int MXExecutorForward(ExecutorHandle handle, bool is_train) {
int MXExecutorForward(ExecutorHandle handle, int is_train) {
API_BEGIN();
Executor *exec = static_cast<Executor*>(handle);
exec->Forward(is_train);
exec->Forward(is_train != 0);
API_END();
}

Expand Down Expand Up @@ -851,28 +851,28 @@ int MXListDataIters(mx_uint *out_size,
}

int MXDataIterGetIterInfo(DataIterCreator creator,
const char **name,
const char **description,
mx_uint *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions) {
const char **name,
const char **description,
mx_uint *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions) {
DataIteratorReg *e = static_cast<DataIteratorReg *>(creator);
return MXAPIGetFunctionRegInfo(e, name, description, num_args,
arg_names, arg_type_infos, arg_descriptions);
}

int MXDataIterCreateIter(DataIterCreator creator,
int num_param,
const char **keys,
const char **vals,
DataIterHandle *out) {
mx_uint num_param,
const char **keys,
const char **vals,
DataIterHandle *out) {
IIterator<DataBatch> *iter = nullptr;
API_BEGIN();
DataIteratorReg *e = static_cast<DataIteratorReg *>(creator);
iter = e->body();
std::vector<std::pair<std::string, std::string> > kwargs;
for (int i = 0; i < num_param; ++i) {
for (mx_uint i = 0; i < num_param; ++i) {
kwargs.push_back({std::string(keys[i]), std::string(vals[i])});
}
iter->Init(kwargs);
Expand Down