Skip to content

Commit

Permalink
[PYTHON] Enable cython ndarray API (dmlc#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored May 2, 2017
1 parent 706f9b6 commit 305614a
Show file tree
Hide file tree
Showing 10 changed files with 282 additions and 171 deletions.
3 changes: 2 additions & 1 deletion python/tvm/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from ..base import _LIB, check_call
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import TVMType, TVMByteArray, NDArrayBase, _make_array
from ..runtime_ctypes import TVMType, TVMByteArray
from .ndarray import NDArrayBase, _make_array
from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
Expand Down
35 changes: 35 additions & 0 deletions python/tvm/_ffi/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Runtime NDArray api"""
from __future__ import absolute_import

import ctypes
from ..base import _LIB, check_call
from ..runtime_ctypes import TVMArrayHandle

class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime."""
__slots__ = ["handle", "is_view"]
# pylint: disable=no-member
def __init__(self, handle, is_view=False):
"""Initialize the function with handle
Parameters
----------
handle : TVMArrayHandle
the handle to the underlying C++ TVMArray
"""
self.handle = handle
self.is_view = is_view

def __del__(self):
if not self.is_view:
check_call(_LIB.TVMArrayFree(self.handle))

def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view)

_CLASS_NDARRAY = None

def _set_class_ndarray(cls):
global _CLASS_NDARRAY
_CLASS_NDARRAY = cls
2 changes: 1 addition & 1 deletion python/tvm/_ffi/_ctypes/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import ctypes
from ..base import py_str, check_call, _LIB
from ..ndarray import TVMByteArray
from ..runtime_ctypes import TVMByteArray

class TypeCode(object):
"""Type code used in API calls"""
Expand Down
30 changes: 28 additions & 2 deletions python/tvm/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,34 @@ cdef enum TVMTypeCode:
kBytes = 11

cdef extern from "tvm/runtime/c_runtime_api.h":
struct DLType:
ctypedef struct DLDataType:
uint8_t code
uint8_t bits
uint16_t lanes

ctypedef struct DLContext:
int device_id
int device_type

ctypedef struct DLTensor:
void* data
DLContext ctx
int ndim
DLDataType dtype
int64_t* shape
int64_t* strides
size_t byte_offset;

ctypedef struct TVMValue:
int64_t v_int64
double v_float64
void* v_handle
const char* v_str
DLType v_type
DLDataType v_type

ctypedef int64_t tvm_index_t
ctypedef void* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* NodeHandle
Expand Down Expand Up @@ -61,6 +77,15 @@ cdef extern from "tvm/runtime/c_runtime_api.h":
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out)
int TVMArrayAlloc(tvm_index_t* shape,
tvm_index_t ndim,
DLDataType dtype,
DLContext ctx,
DLTensorHandle* out)
int TVMArrayFree(DLTensorHandle handle)
int TVMArrayCopyFromTo(DLTensorHandle src,
DLTensorHandle to,
TVMStreamHandle stream)

cdef extern from "tvm/c_api.h":
int TVMCbArgToReturn(TVMValue* value, int code)
Expand Down Expand Up @@ -106,6 +131,7 @@ cdef inline object ctypes_handle(void* chandle):
"""Cast C handle to ctypes handle."""
return ctypes.cast(<unsigned long long>chandle, ctypes.c_void_p)


cdef inline void* c_handle(object handle):
"""Cast C types handle to c handle."""
cdef unsigned long long v_ptr
Expand Down
1 change: 1 addition & 0 deletions python/tvm/_ffi/_cython/core.pyx
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include "./base.pxi"
include "./node.pxi"
include "./function.pxi"
include "./ndarray.pxi"
9 changes: 4 additions & 5 deletions python/tvm/_ffi/_cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import NDArrayBase, TVMType, TVMByteArray, _make_array
from ..runtime_ctypes import TVMType, TVMByteArray

print("TVM: Initializing cython mode...")

Expand Down Expand Up @@ -32,7 +32,7 @@ cdef int tvm_callback(TVMValue* args,
if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode))
else:
pyargs.append(_make_array(ctypes_handle(value.v_handle), True))
pyargs.append(c_make_array(value.v_handle, True))
try:
rv = local_pyfunc(*pyargs)
except Exception:
Expand Down Expand Up @@ -81,8 +81,7 @@ cdef inline void make_arg(object arg,
value[0].v_handle = (<NodeBase>arg).chandle
tcode[0] = kNodeHandle
elif isinstance(arg, NDArrayBase):
value[0].v_handle = c_handle(
ctypes.cast(arg.handle, ctypes.c_void_p))
value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = kArrayHandle
elif isinstance(arg, Integral):
value[0].v_int64 = arg
Expand Down Expand Up @@ -205,7 +204,7 @@ cdef class FunctionBase:
cdef TVMFunctionHandle chandle
cdef int is_global

cdef _set_handle(self, handle):
cdef inline _set_handle(self, handle):
if handle is None:
self.chandle = NULL
else:
Expand Down
51 changes: 51 additions & 0 deletions python/tvm/_ffi/_cython/ndarray.pxi
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from ..runtime_ctypes import TVMArrayHandle

cdef class NDArrayBase:
cdef DLTensor* chandle
cdef int c_is_view

cdef inline _set_handle(self, handle):
cdef unsigned long long ptr
if handle is None:
self.chandle = NULL
else:
ptr = ctypes.addressof(handle.contents)
self.chandle = <DLTensor*>(ptr)

property handle:
def __get__(self):
if self.chandle == NULL:
return None
else:
return ctypes.cast(
<unsigned long long>self.chandle, TVMArrayHandle)

def __set__(self, value):
self._set_handle(value)


def __init__(self, handle, is_view):
self._set_handle(handle)
self.c_is_view = is_view


def __dealloc__(self):
if self.c_is_view == 0:
CALL(TVMArrayFree(self.chandle))


cdef c_make_array(void* chandle, is_view):
ret = _CLASS_NDARRAY(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
return ret


def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view)

_CLASS_NDARRAY = None

def _set_class_ndarray(cls):
global _CLASS_NDARRAY
_CLASS_NDARRAY = cls
Loading

0 comments on commit 305614a

Please sign in to comment.