From 305614a9e4254d76005485f5f693023366bb34ce Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 2 May 2017 12:01:44 -0700 Subject: [PATCH] [PYTHON] Enable cython ndarray API (#113) --- python/tvm/_ffi/_ctypes/function.py | 3 +- python/tvm/_ffi/_ctypes/ndarray.py | 35 ++++++ python/tvm/_ffi/_ctypes/types.py | 2 +- python/tvm/_ffi/_cython/base.pxi | 30 ++++- python/tvm/_ffi/_cython/core.pyx | 1 + python/tvm/_ffi/_cython/function.pxi | 9 +- python/tvm/_ffi/_cython/ndarray.pxi | 51 ++++++++ python/tvm/_ffi/ndarray.py | 180 +++------------------------ python/tvm/_ffi/runtime_ctypes.py | 138 ++++++++++++++++++++ src/api/api_base.cc | 4 +- 10 files changed, 282 insertions(+), 171 deletions(-) create mode 100644 python/tvm/_ffi/_ctypes/ndarray.py create mode 100644 python/tvm/_ffi/_cython/ndarray.pxi create mode 100644 python/tvm/_ffi/runtime_ctypes.py diff --git a/python/tvm/_ffi/_ctypes/function.py b/python/tvm/_ffi/_ctypes/function.py index 8d8525f6e..1c18123f2 100644 --- a/python/tvm/_ffi/_ctypes/function.py +++ b/python/tvm/_ffi/_ctypes/function.py @@ -10,7 +10,8 @@ from ..base import _LIB, check_call from ..base import c_str, string_types from ..node_generic import convert_to_node, NodeGeneric -from ..ndarray import TVMType, TVMByteArray, NDArrayBase, _make_array +from ..runtime_ctypes import TVMType, TVMByteArray +from .ndarray import NDArrayBase, _make_array from .types import TVMValue, TypeCode from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func diff --git a/python/tvm/_ffi/_ctypes/ndarray.py b/python/tvm/_ffi/_ctypes/ndarray.py new file mode 100644 index 000000000..2f24f779d --- /dev/null +++ b/python/tvm/_ffi/_ctypes/ndarray.py @@ -0,0 +1,35 @@ +"""Runtime NDArray api""" +from __future__ import absolute_import + +import ctypes +from ..base import _LIB, check_call +from ..runtime_ctypes import TVMArrayHandle + +class NDArrayBase(object): + """A simple Device/CPU Array object in runtime.""" + __slots__ = ["handle", "is_view"] + # pylint: disable=no-member + def __init__(self, handle, is_view=False): + """Initialize the function with handle + + Parameters + ---------- + handle : TVMArrayHandle + the handle to the underlying C++ TVMArray + """ + self.handle = handle + self.is_view = is_view + + def __del__(self): + if not self.is_view: + check_call(_LIB.TVMArrayFree(self.handle)) + +def _make_array(handle, is_view): + handle = ctypes.cast(handle, TVMArrayHandle) + return _CLASS_NDARRAY(handle, is_view) + +_CLASS_NDARRAY = None + +def _set_class_ndarray(cls): + global _CLASS_NDARRAY + _CLASS_NDARRAY = cls diff --git a/python/tvm/_ffi/_ctypes/types.py b/python/tvm/_ffi/_ctypes/types.py index e722a3580..332e72862 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -4,7 +4,7 @@ import ctypes from ..base import py_str, check_call, _LIB -from ..ndarray import TVMByteArray +from ..runtime_ctypes import TVMByteArray class TypeCode(object): """Type code used in API calls""" diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 60e38f61d..124e815b5 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -19,18 +19,34 @@ cdef enum TVMTypeCode: kBytes = 11 cdef extern from "tvm/runtime/c_runtime_api.h": - struct DLType: + ctypedef struct DLDataType: uint8_t code uint8_t bits uint16_t lanes + ctypedef struct DLContext: + int device_id + int device_type + + ctypedef struct DLTensor: + void* data + DLContext ctx + int ndim + DLDataType dtype + int64_t* shape + int64_t* strides + size_t byte_offset; + ctypedef struct TVMValue: int64_t v_int64 double v_float64 void* v_handle const char* v_str - DLType v_type + DLDataType v_type +ctypedef int64_t tvm_index_t +ctypedef void* DLTensorHandle +ctypedef void* TVMStreamHandle ctypedef void* TVMRetValueHandle ctypedef void* TVMFunctionHandle ctypedef void* NodeHandle @@ -61,6 +77,15 @@ cdef extern from "tvm/runtime/c_runtime_api.h": void* resource_handle, TVMPackedCFuncFinalizer fin, TVMFunctionHandle *out) + int TVMArrayAlloc(tvm_index_t* shape, + tvm_index_t ndim, + DLDataType dtype, + DLContext ctx, + DLTensorHandle* out) + int TVMArrayFree(DLTensorHandle handle) + int TVMArrayCopyFromTo(DLTensorHandle src, + DLTensorHandle to, + TVMStreamHandle stream) cdef extern from "tvm/c_api.h": int TVMCbArgToReturn(TVMValue* value, int code) @@ -106,6 +131,7 @@ cdef inline object ctypes_handle(void* chandle): """Cast C handle to ctypes handle.""" return ctypes.cast(chandle, ctypes.c_void_p) + cdef inline void* c_handle(object handle): """Cast C types handle to c handle.""" cdef unsigned long long v_ptr diff --git a/python/tvm/_ffi/_cython/core.pyx b/python/tvm/_ffi/_cython/core.pyx index 9bd755b20..e392b709c 100644 --- a/python/tvm/_ffi/_cython/core.pyx +++ b/python/tvm/_ffi/_cython/core.pyx @@ -1,3 +1,4 @@ include "./base.pxi" include "./node.pxi" include "./function.pxi" +include "./ndarray.pxi" diff --git a/python/tvm/_ffi/_cython/function.pxi b/python/tvm/_ffi/_cython/function.pxi index 51943173c..8516359a6 100644 --- a/python/tvm/_ffi/_cython/function.pxi +++ b/python/tvm/_ffi/_cython/function.pxi @@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF from numbers import Number, Integral from ..base import string_types from ..node_generic import convert_to_node, NodeGeneric -from ..ndarray import NDArrayBase, TVMType, TVMByteArray, _make_array +from ..runtime_ctypes import TVMType, TVMByteArray print("TVM: Initializing cython mode...") @@ -32,7 +32,7 @@ cdef int tvm_callback(TVMValue* args, if tcode != kArrayHandle: pyargs.append(make_ret(value, tcode)) else: - pyargs.append(_make_array(ctypes_handle(value.v_handle), True)) + pyargs.append(c_make_array(value.v_handle, True)) try: rv = local_pyfunc(*pyargs) except Exception: @@ -81,8 +81,7 @@ cdef inline void make_arg(object arg, value[0].v_handle = (arg).chandle tcode[0] = kNodeHandle elif isinstance(arg, NDArrayBase): - value[0].v_handle = c_handle( - ctypes.cast(arg.handle, ctypes.c_void_p)) + value[0].v_handle = (arg).chandle tcode[0] = kArrayHandle elif isinstance(arg, Integral): value[0].v_int64 = arg @@ -205,7 +204,7 @@ cdef class FunctionBase: cdef TVMFunctionHandle chandle cdef int is_global - cdef _set_handle(self, handle): + cdef inline _set_handle(self, handle): if handle is None: self.chandle = NULL else: diff --git a/python/tvm/_ffi/_cython/ndarray.pxi b/python/tvm/_ffi/_cython/ndarray.pxi new file mode 100644 index 000000000..9a0570244 --- /dev/null +++ b/python/tvm/_ffi/_cython/ndarray.pxi @@ -0,0 +1,51 @@ +from ..runtime_ctypes import TVMArrayHandle + +cdef class NDArrayBase: + cdef DLTensor* chandle + cdef int c_is_view + + cdef inline _set_handle(self, handle): + cdef unsigned long long ptr + if handle is None: + self.chandle = NULL + else: + ptr = ctypes.addressof(handle.contents) + self.chandle = (ptr) + + property handle: + def __get__(self): + if self.chandle == NULL: + return None + else: + return ctypes.cast( + self.chandle, TVMArrayHandle) + + def __set__(self, value): + self._set_handle(value) + + + def __init__(self, handle, is_view): + self._set_handle(handle) + self.c_is_view = is_view + + + def __dealloc__(self): + if self.c_is_view == 0: + CALL(TVMArrayFree(self.chandle)) + + +cdef c_make_array(void* chandle, is_view): + ret = _CLASS_NDARRAY(None, is_view) + (ret).chandle = chandle + return ret + + +def _make_array(handle, is_view): + handle = ctypes.cast(handle, TVMArrayHandle) + return _CLASS_NDARRAY(handle, is_view) + +_CLASS_NDARRAY = None + +def _set_class_ndarray(cls): + global _CLASS_NDARRAY + _CLASS_NDARRAY = cls diff --git a/python/tvm/_ffi/ndarray.py b/python/tvm/_ffi/ndarray.py index 11023a411..ed97c6b2c 100644 --- a/python/tvm/_ffi/ndarray.py +++ b/python/tvm/_ffi/ndarray.py @@ -1,142 +1,28 @@ -# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement -# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring +# pylint: disable=invalid-name, unused-import """Runtime NDArray api""" from __future__ import absolute_import + +import sys import ctypes import numpy as np -from .base import _LIB, check_call, c_array, string_types -from .. import _api_internal - -tvm_shape_index_t = ctypes.c_int64 - -class TVMByteArray(ctypes.Structure): - """Temp data structure for byte array.""" - _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), - ("size", ctypes.c_size_t)] - -class TVMType(ctypes.Structure): - """TVM datatype structure""" - _fields_ = [("type_code", ctypes.c_uint8), - ("bits", ctypes.c_uint8), - ("lanes", ctypes.c_uint16)] - CODE2STR = { - 0 : 'int', - 1 : 'uint', - 2 : 'float', - 4 : 'handle' - } - def __init__(self, type_str, lanes=1): - super(TVMType, self).__init__() - if isinstance(type_str, np.dtype): - type_str = str(type_str) - if type_str.startswith("int"): - self.type_code = 0 - bits = int(type_str[3:]) - elif type_str.startswith("uint"): - self.type_code = 1 - bits = int(type_str[4:]) - elif type_str.startswith("float"): - self.type_code = 2 - bits = int(type_str[5:]) - elif type_str.startswith("handle"): - self.type_code = 4 - bits = 64 - else: - raise ValueError("Donot know how to handle type %s" % type_str) - - bits = 32 if bits == 0 else bits - if (bits & (bits - 1)) != 0 or bits < 8: - raise ValueError("Donot know how to handle type %s" % type_str) - self.bits = bits - self.lanes = lanes - - def __repr__(self): - x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits) - if self.lanes != 1: - x += "x%d" % self.lanes - return x - - def __eq__(self, other): - return (self.bits == other.bits and - self.type_code == other.type_code and - self.lanes == other.lanes) +from .base import _LIB, check_call, c_array, string_types, _FFI_MODE +from .runtime_ctypes import TVMType, TVMContext, TVMArray, TVMArrayHandle, tvm_shape_index_t - def __ne__(self, other): - return not self.__eq__(other) - - -class TVMContext(ctypes.Structure): - """TVM context strucure.""" - _fields_ = [("device_id", ctypes.c_int), - ("device_type", ctypes.c_int)] - MASK2STR = { - 1 : 'cpu', - 2 : 'gpu', - 4 : 'opencl', - 8 : 'metal', - 9 : 'vpi' - } - STR2MASK = { - 'cpu': 1, - 'gpu': 2, - 'cuda': 2, - 'cl': 4, - 'opencl': 4, - 'metal': 8, - 'vpi': 9 - } - def __init__(self, device_type, device_id): - super(TVMContext, self).__init__() - self.device_id = device_id - self.device_type = device_type - - @property - def exist(self): - """Whether this device exist.""" - return _api_internal._GetDeviceAttr( - self.device_type, self.device_id, 0) != 0 - - @property - def max_threads_per_block(self): - """Maximum number of threads on each block.""" - return _api_internal._GetDeviceAttr( - self.device_type, self.device_id, 1) - - @property - def warp_size(self): - """Number of threads that executes in concurrent.""" - return _api_internal._GetDeviceAttr( - self.device_type, self.device_id, 2) - def sync(self): - """Synchronize until jobs finished at the context.""" - check_call(_LIB.TVMSynchronize(self, None)) +IMPORT_EXCEPT = RuntimeError if _FFI_MODE == "cython" else ImportError - def __eq__(self, other): - return (isinstance(other, TVMContext) and - self.device_id == other.device_id and - self.device_type == other.device_type) +try: + # pylint: disable=wrong-import-position + if _FFI_MODE == "ctypes": + raise ImportError() + if sys.version_info >= (3, 0): + from ._cy3.core import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase + else: + from ._cy2.core import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase +except IMPORT_EXCEPT: + # pylint: disable=wrong-import-position + from ._ctypes.ndarray import _set_class_ndarray, _make_array, NDArrayBase as _NDArrayBase - def __ne__(self, other): - return not self.__eq__(other) - - def __repr__(self): - return "%s(%d)" % ( - TVMContext.MASK2STR[self.device_type], self.device_id) - - -class TVMArray(ctypes.Structure): - """TVMValue in C API""" - _fields_ = [("data", ctypes.c_void_p), - ("ctx", TVMContext), - ("ndim", ctypes.c_int), - ("dtype", TVMType), - ("shape", ctypes.POINTER(tvm_shape_index_t)), - ("strides", ctypes.POINTER(tvm_shape_index_t)), - ("byte_offset", ctypes.c_size_t)] - - -TVMArrayHandle = ctypes.POINTER(TVMArray) def context(dev_type, dev_id=0): """Construct a TVM context with given device type and id. @@ -214,28 +100,10 @@ def empty(shape, dtype="float32", ctx=context(1, 0)): dtype = TVMType(dtype) check_call(_LIB.TVMArrayAlloc( shape, ndim, dtype, ctx, ctypes.byref(handle))) - return _CLASS_NDARRAY(handle) - + return _make_array(handle, False) -class NDArrayBase(object): +class NDArrayBase(_NDArrayBase): """A simple Device/CPU Array object in runtime.""" - __slots__ = ["handle", "is_view"] - # pylint: disable=no-member - def __init__(self, handle, is_view=False): - """Initialize the function with handle - - Parameters - ---------- - handle : TVMArrayHandle - the handle to the underlying C++ TVMArray - """ - self.handle = handle - self.is_view = is_view - - def __del__(self): - if not self.is_view: - check_call(_LIB.TVMArrayFree(self.handle)) - @property def shape(self): """Shape of this array""" @@ -324,13 +192,3 @@ def copyto(self, target): else: raise ValueError("Unsupported target type %s" % str(type(target))) return target - -def _make_array(handle, is_view): - handle = ctypes.cast(handle, TVMArrayHandle) - return _CLASS_NDARRAY(handle, is_view) - -_CLASS_NDARRAY = None - -def _set_class_ndarray(cls): - global _CLASS_NDARRAY - _CLASS_NDARRAY = cls diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py new file mode 100644 index 000000000..ffd010aac --- /dev/null +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -0,0 +1,138 @@ +"""Common runtime ctypes.""" +# pylint: disable=invalid-name +from __future__ import absolute_import + +import ctypes +import numpy as np +from .base import _LIB, check_call +from .. import _api_internal + +tvm_shape_index_t = ctypes.c_int64 + +class TVMByteArray(ctypes.Structure): + """Temp data structure for byte array.""" + _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), + ("size", ctypes.c_size_t)] + +class TVMType(ctypes.Structure): + """TVM datatype structure""" + _fields_ = [("type_code", ctypes.c_uint8), + ("bits", ctypes.c_uint8), + ("lanes", ctypes.c_uint16)] + CODE2STR = { + 0 : 'int', + 1 : 'uint', + 2 : 'float', + 4 : 'handle' + } + def __init__(self, type_str, lanes=1): + super(TVMType, self).__init__() + if isinstance(type_str, np.dtype): + type_str = str(type_str) + if type_str.startswith("int"): + self.type_code = 0 + bits = int(type_str[3:]) + elif type_str.startswith("uint"): + self.type_code = 1 + bits = int(type_str[4:]) + elif type_str.startswith("float"): + self.type_code = 2 + bits = int(type_str[5:]) + elif type_str.startswith("handle"): + self.type_code = 4 + bits = 64 + else: + raise ValueError("Donot know how to handle type %s" % type_str) + + bits = 32 if bits == 0 else bits + if (bits & (bits - 1)) != 0 or bits < 8: + raise ValueError("Donot know how to handle type %s" % type_str) + self.bits = bits + self.lanes = lanes + + def __repr__(self): + x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits) + if self.lanes != 1: + x += "x%d" % self.lanes + return x + + def __eq__(self, other): + return (self.bits == other.bits and + self.type_code == other.type_code and + self.lanes == other.lanes) + + def __ne__(self, other): + return not self.__eq__(other) + + +class TVMContext(ctypes.Structure): + """TVM context strucure.""" + _fields_ = [("device_id", ctypes.c_int), + ("device_type", ctypes.c_int)] + MASK2STR = { + 1 : 'cpu', + 2 : 'gpu', + 4 : 'opencl', + 8 : 'metal', + 9 : 'vpi' + } + STR2MASK = { + 'cpu': 1, + 'gpu': 2, + 'cuda': 2, + 'cl': 4, + 'opencl': 4, + 'metal': 8, + 'vpi': 9 + } + def __init__(self, device_type, device_id): + super(TVMContext, self).__init__() + self.device_id = device_id + self.device_type = device_type + + @property + def exist(self): + """Whether this device exist.""" + return _api_internal._GetDeviceAttr( + self.device_type, self.device_id, 0) != 0 + + @property + def max_threads_per_block(self): + """Maximum number of threads on each block.""" + return _api_internal._GetDeviceAttr( + self.device_type, self.device_id, 1) + + @property + def warp_size(self): + """Number of threads that executes in concurrent.""" + return _api_internal._GetDeviceAttr( + self.device_type, self.device_id, 2) + + def sync(self): + """Synchronize until jobs finished at the context.""" + check_call(_LIB.TVMSynchronize(self, None)) + + def __eq__(self, other): + return (isinstance(other, TVMContext) and + self.device_id == other.device_id and + self.device_type == other.device_type) + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return "%s(%d)" % ( + TVMContext.MASK2STR[self.device_type], self.device_id) + + +class TVMArray(ctypes.Structure): + """TVMValue in C API""" + _fields_ = [("data", ctypes.c_void_p), + ("ctx", TVMContext), + ("ndim", ctypes.c_int), + ("dtype", TVMType), + ("shape", ctypes.POINTER(tvm_shape_index_t)), + ("strides", ctypes.POINTER(tvm_shape_index_t)), + ("byte_offset", ctypes.c_size_t)] + +TVMArrayHandle = ctypes.POINTER(TVMArray) diff --git a/src/api/api_base.cc b/src/api/api_base.cc index f8403f9b4..3130e109f 100644 --- a/src/api/api_base.cc +++ b/src/api/api_base.cc @@ -8,7 +8,6 @@ #include namespace tvm { - TVM_REGISTER_API("_format_str") .set_body([](TVMArgs args, TVMRetValue *ret) { CHECK(args[0].type_code() == kNodeHandle); @@ -34,4 +33,7 @@ TVM_REGISTER_API("_load_json") *ret = NodeRef(LoadJSON_(args[0])); }); +TVM_REGISTER_API("_nop") +.set_body([](TVMArgs args, TVMRetValue *ret) { + }); } // namespace tvm