From ad0ab0a4177091edf75b2820e40c886ef3a644c4 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 15 Jul 2016 17:38:15 -0700 Subject: [PATCH] Make cython compatible with python3 (#12) --- nnvm/Makefile | 4 ++- nnvm/python/nnvm/cython/base.pyi | 22 +++++++++++++++++ nnvm/python/nnvm/cython/symbol.pyx | 39 ++++++++++++++++++++---------- 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/nnvm/Makefile b/nnvm/Makefile index feba8d278183..d24e4a8e06e2 100644 --- a/nnvm/Makefile +++ b/nnvm/Makefile @@ -3,7 +3,7 @@ export CFLAGS = -std=c++11 -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loop -Iinclude -Idmlc-core/include -fPIC # specify tensor path -.PHONY: clean all test lint doc cython cython3 +.PHONY: clean all test lint doc cython cython3 cyclean all: lib/libnnvm.so lib/libnnvm.a cli_test @@ -37,6 +37,8 @@ cython: cython3: cd python; python3 setup.py build_ext --inplace +cyclean: + rm -rf python/nnvm/*/*.so python/nnvm/*/*.cpp lint: python2 dmlc-core/scripts/lint.py nnvm cpp include src diff --git a/nnvm/python/nnvm/cython/base.pyi b/nnvm/python/nnvm/cython/base.pyi index f9175651dea9..1163f64b07fd 100644 --- a/nnvm/python/nnvm/cython/base.pyi +++ b/nnvm/python/nnvm/cython/base.pyi @@ -9,6 +9,21 @@ cdef py_str(const char* x): return x.decode("utf-8") +cdef c_str(pystr): + """Create ctypes char * from a python string + Parameters + ---------- + string : string type + python string + + Returns + ------- + str : c_char_p + A char pointer that can be passed to C API + """ + return pystr.encode("utf-8") + + cdef CALL(int ret): if ret != 0: raise NNVMError(NNGetLastError()) @@ -20,6 +35,13 @@ cdef const char** CBeginPtr(vector[const char*]& vec): else: return NULL +cdef vector[const char*] SVec2Ptr(vector[string]& vec): + cdef vector[const char*] svec + svec.resize(vec.size()) + for i in range(vec.size()): + svec[i] = vec[i].c_str() + return svec + cdef BuildDoc(nn_uint num_args, const char** arg_names, diff --git a/nnvm/python/nnvm/cython/symbol.pyx b/nnvm/python/nnvm/cython/symbol.pyx index 5554520cdf5b..eeec2c430e89 100644 --- a/nnvm/python/nnvm/cython/symbol.pyx +++ b/nnvm/python/nnvm/cython/symbol.pyx @@ -6,6 +6,7 @@ from .._base import NNVMError from ..name import NameManager from ..attribute import AttrScope from libcpp.vector cimport vector +from libcpp.string cimport string from cpython.version cimport PY_MAJOR_VERSION include "./base.pyi" @@ -110,7 +111,7 @@ cdef class Symbol: CALL(NNSymbolGetOutput(self.handle, c_index, &handle)) return NewSymbol(handle) - def attr(self, const char* key): + def attr(self, key): """Get attribute string from the symbol, this function only works for non-grouped symbol. Parameters @@ -125,6 +126,8 @@ cdef class Symbol: """ cdef const char* ret cdef int success + key = c_str(key) + CALL(NNSymbolGetAttr( self.handle, key, &ret, &success)) if success != 0: @@ -203,16 +206,19 @@ cdef class Symbol: def debug_str(self): cdef const char* out_str CALL(NNSymbolPrint(self.handle, &out_str)) - return str(out_str) + return py_str(out_str) cdef SymbolSetAttr(SymbolHandle handle, dict kwargs): - cdef vector[const char*] param_keys - cdef vector[const char*] param_vals + cdef vector[string] sparam_keys + cdef vector[string] sparam_vals cdef nn_uint num_args for k, v in kwargs.items(): - param_keys.push_back(k) - param_vals.push_back(str(v)) + sparam_keys.push_back(c_str(k)) + sparam_vals.push_back(c_str(str(v))) + # keep strings in vector + cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys) + cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals) num_args = param_keys.size() CALL(NNSymbolSetAttrs( handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals))) @@ -225,7 +231,7 @@ cdef NewSymbol(SymbolHandle handle): return sym -def Variable(const char* name, **kwargs): +def Variable(name, **kwargs): """Create a symbolic variable with specified name. Parameters @@ -241,6 +247,7 @@ def Variable(const char* name, **kwargs): The created variable symbol. """ cdef SymbolHandle handle + name = c_str(name) CALL(NNSymbolCreateVariable(name, &handle)) return NewSymbol(handle) @@ -274,10 +281,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): func_hint = func_name.lower() def creator(*args, **kwargs): - cdef vector[const char*] param_keys - cdef vector[const char*] param_vals + cdef vector[string] sparam_keys + cdef vector[string] sparam_vals cdef vector[SymbolHandle] symbol_args - cdef vector[const char*] symbol_keys + cdef vector[string] ssymbol_keys cdef SymbolHandle ret_handle name = kwargs.pop("name", None) @@ -286,11 +293,11 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): if len(kwargs) != 0: for k, v in kwargs.items(): if isinstance(v, Symbol): - symbol_keys.push_back(k) + ssymbol_keys.push_back(c_str(k)) symbol_args.push_back((v).handle) else: - param_keys.push_back(k) - param_vals.push_back(str(v)) + sparam_keys.push_back(c_str(k)) + sparam_vals.push_back(c_str(str(v))) if len(args) != 0: if symbol_args.size() != 0: @@ -301,6 +308,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): raise TypeError('Compose expect `Symbol` as arguments') symbol_args.push_back((v).handle) + cdef vector[const char*] param_keys = SVec2Ptr(sparam_keys) + cdef vector[const char*] param_vals = SVec2Ptr(sparam_vals) + cdef vector[const char*] symbol_keys = SVec2Ptr(ssymbol_keys) + CALL(NNSymbolCreateAtomicSymbol( handle, param_keys.size(), @@ -315,7 +326,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle): name = NameManager.current.get(name, func_hint) cdef const char* c_name = NULL + if name: + name = c_str(name) c_name = name CALL(NNSymbolCompose(