Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
register AtomicSymbol
Browse files Browse the repository at this point in the history
  • Loading branch information
mavenlin committed Jul 17, 2015
1 parent 8edf819 commit 6214cd9
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 178 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ endif
BIN = test/api_registry_test
OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o
# add threaded engine after it is done
OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o operator.o
OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o operator.o atomic_symbol_registry.o
CUOBJ =
SLIB = api/libmxnet.so
ALIB = api/libmxnet.a
Expand All @@ -85,6 +85,7 @@ static_operator.o: src/static_operator/static_operator.cc
static_operator_cpu.o: src/static_operator/static_operator_cpu.cc
static_operator_gpu.o: src/static_operator/static_operator_gpu.cu
symbol.o: src/symbol/symbol.cc
atomic_symbol_registry.o: src/symbol/atomic_symbol_registry.cc
api_registry.o: src/api_registry.cc
mxnet_api.o: api/mxnet_api.cc
operator.o: src/operator/operator.cc
Expand Down
55 changes: 21 additions & 34 deletions api/mxnet_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <dmlc/logging.h>
#include <mxnet/base.h>
#include <mxnet/narray.h>
#include <mxnet/atomic_symbol_registry.h>
#include <mxnet/api_registry.h>
#include <mutex>
#include "./mxnet_api.h"
Expand Down Expand Up @@ -243,57 +244,43 @@ int MXFuncInvoke(FunctionHandle fun,
auto *f = static_cast<const FunctionRegistry::Entry *>(fun);
(*f)((NArray**)(use_vars), // NOLINT(*)
scalar_args,
(NArray**)(mutate_vars)); // NOLINT(*)
(NArray**)(mutate_vars)); // NOLINT(*)
API_END();
}

int MXSymFree(SymbolHandle sym) {
int MXSymCreate(const char *type_str,
int num_param,
const char** keys,
const char** vals,
SymbolHandle* out) {
API_BEGIN();
delete static_cast<Symbol*>(sym);
CCreateSymbol(type_str, num_param, keys, vals, (Symbol**)out); // NOLINT(*)
API_END();
}

int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator,
mx_uint *use_param) {
int MXSymFree(SymbolHandle sym) {
API_BEGIN();
auto *sc = static_cast<const SymbolCreatorRegistry::Entry *>(sym_creator);
*use_param = sc->use_param ? 1 : 0;
delete static_cast<Symbol*>(sym);
API_END();
}

int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator,
int count,
const char** keys,
const char** vals,
SymbolHandle* out) {
int MXSymDescribe(const char *type_str,
mx_uint *use_param) {
API_BEGIN();
const SymbolCreatorRegistry::Entry *sc =
static_cast<const SymbolCreatorRegistry::Entry *>(sym_creator);
sc->body(count, keys, vals, (Symbol**)(out)); // NOLINT(*)
*use_param = AtomicSymbolRegistry::Find(type_str)->use_param ? 1 : 0;
API_END();
}

int MXListSymCreators(mx_uint *out_size,
SymbolCreatorHandle **out_array) {
int MXListSyms(mx_uint *out_size,
const char ***out_array) {
API_BEGIN();
auto &vec = SymbolCreatorRegistry::List();
auto &vec = AtomicSymbolRegistry::List();
*out_size = static_cast<mx_uint>(vec.size());
*out_array = (SymbolCreatorHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*)
API_END();
}

int MXGetSymCreator(const char *name,
SymbolCreatorHandle *out) {
API_BEGIN();
*out = SymbolCreatorRegistry::Find(name);
API_END();
}

int MXSymCreatorGetName(SymbolCreatorHandle sym_creator,
const char **out_name) {
API_BEGIN();
auto *f = static_cast<const SymbolCreatorRegistry::Entry *>(sym_creator);
*out_name = f->name.c_str();
std::vector<const char*> type_strs;
for (auto entry : vec) {
type_strs.push_back(entry->type_str.c_str());
}
*out_array = (const char**)(dmlc::BeginPtr(type_strs)); // NOLINT(*)
API_END();
}

Expand Down
54 changes: 19 additions & 35 deletions api/mxnet_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,20 @@ MXNET_DLL int MXFuncInvoke(FunctionHandle fun,
*/
MXNET_DLL int MXSymCreateFromConfig(const char *cfg,
SymbolHandle *out);
/*!
* \brief invoke registered symbol creator through its handle.
* \param type_str the type of the AtomicSymbol
* \param num_param the number of the key value pairs in the param.
* \param keys an array of c str.
* \param vals the corresponding values of the keys.
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymCreate(const char *type_str,
int num_param,
const char** keys,
const char** vals,
SymbolHandle* out);
/*!
* \brief free the symbol handle
* \param sym the symbol
Expand All @@ -222,51 +236,21 @@ MXNET_DLL int MXSymCreateFromConfig(const char *cfg,
MXNET_DLL int MXSymFree(SymbolHandle sym);
/*!
* \brief query if the symbol creator needs param.
* \param sym_creator the symbol creator handle
* \param type_str the type of the AtomicSymbol
* \param use_param describe if the symbol creator requires param
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator,
mx_uint *use_param);
/*!
* \brief invoke registered symbol creator through its handle.
* \param sym_creator pointer to the symbolcreator function.
* \param count the number of the key value pairs in the param.
* \param keys an array of c str.
* \param vals the corresponding values of the keys.
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator,
int count,
const char** keys,
const char** vals,
SymbolHandle* out);
MXNET_DLL int MXSymDescribe(const char *type_str,
mx_uint *use_param);
/*!
* \brief list all the available sym_creator
* most user can use it to list all the needed sym_creators
* \param out_size the size of returned array
* \param out_array the output sym_creators
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXListSymCreators(mx_uint *out_size,
SymbolCreatorHandle **out_array);
/*!
* \brief get the sym_creator by name
* \param name the name of the sym_creator
* \param out the corresponding sym_creator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXGetSymCreator(const char *name,
SymbolCreatorHandle *out);
/*!
* \brief get the name of sym_creator handle
* \param sym_creator the sym_creator handle
* \param out_name the name of the sym_creator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymCreatorGetName(SymbolCreatorHandle sym_creator,
const char **out_name);
MXNET_DLL int MXListSyms(mx_uint *out_size,
const char ***out_array);
/*!
* \brief compose the symbol on other symbol
* \param sym the symbol to apply
Expand Down
23 changes: 10 additions & 13 deletions api/python/mxnet/symbol_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class _SymbolCreator(object):
"""SymbolCreator is a function that takes Param and return symbol"""

def __init__(self, handle, name):
def __init__(self, name):
"""Initialize the function with handle
Parameters
Expand All @@ -23,11 +23,10 @@ def __init__(self, handle, name):
name : string
the name of the function
"""
self.handle = handle
self.name = name
use_param = mx_uint()
check_call(_LIB.MXSymCreatorDescribe(
self.handle,
check_call(_LIB.MXSymDescribe(
c_str(self.name),
ctypes.byref(use_param)))
self.use_param = use_param.value

Expand All @@ -45,8 +44,8 @@ def __call__(self, **kwargs):
keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()])
vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()])
sym_handle = SymbolHandle()
check_call(_LIB.MXSymCreatorInvoke(
self.handle,
check_call(_LIB.MXSymCreate(
c_str(self.name),
mx_uint(len(kwargs)),
keys,
vals,
Expand All @@ -56,14 +55,12 @@ def __call__(self, **kwargs):
class _SymbolCreatorRegistry(object):
"""Function Registry"""
def __init__(self):
plist = ctypes.POINTER(ctypes.c_void_p)()
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.MXListSymCreators(ctypes.byref(size),
ctypes.byref(plist)))
check_call(_LIB.MXListSyms(ctypes.byref(size),
ctypes.byref(plist)))
hmap = {}
for i in range(size.value):
hdl = plist[i]
name = ctypes.c_char_p()
check_call(_LIB.MXSymCreatorGetName(hdl, ctypes.byref(name)))
hmap[name.value] = _SymbolCreator(hdl, name.value)
name = plist[i]
hmap[name.value] = _SymbolCreator(name.value)
self.__dict__.update(hmap)
65 changes: 0 additions & 65 deletions include/mxnet/api_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,70 +212,5 @@ class FunctionRegistry {
static auto __ ## name ## _narray_fun__ = \
::mxnet::FunctionRegistry::Get()->Register("" # name)

/*! \brief registry of symbol creator */
class SymbolCreatorRegistry {
public:
/*! \brief SymbolCreator is a function pointer */
typedef void(*SymbolCreator)(int count, const char**, const char**, Symbol**);
/*! \return get a singleton */
static SymbolCreatorRegistry *Get();
/*! \brief keep the SymbolCreator function and its meta information */
struct Entry {
/*! \brief the name of the symbol creator */
std::string name;
/*! \brief the body of the function */
SymbolCreator body;
/*! \brief if the creator requires params to construct */
bool use_param;
/*! \brief constructor */
explicit Entry(const std::string& name) : name(name), body(nullptr), use_param(true) {}
/*! \brief setter of body */
inline Entry& set_body(SymbolCreator sc) { body = sc; return *this; }
/*! \brief setter of use_param */
inline Entry& set_use_param(bool up) { use_param = up; return *this; }
};
/*!
* \brief register a name symbol under name
* \param name name of the function
* \return ref to the registered entry, used to set properties
*/
Entry &Register(const std::string& name);
/*! \return list of functions in the registry */
inline static const std::vector<const Entry*> &List() {
return Get()->fun_list_;
}
/*!
* \brief find an symbolcreator entry with corresponding name
* \param name name of the symbolcreator
* \return the corresponding symbolcreator, can be NULL
*/
inline static const Entry *Find(const std::string &name) {
auto &fmap = Get()->fmap_;
auto p = fmap.find(name);
if (p != fmap.end()) {
return p->second;
} else {
return nullptr;
}
}

private:
/*! \brief list of functions */
std::vector<const Entry*> fun_list_;
/*! \brief map of name->function */
std::map<std::string, Entry*> fmap_;
/*! \brief constructor */
SymbolCreatorRegistry() {}
/*! \brief destructor */
~SymbolCreatorRegistry();
};

/*!
* \brief macro to register symbol creator
*/
#define REGISTER_SYMBOL_CREATOR(name) \
static auto __ ## name ## _symbol_creator__ = \
::mxnet::SymbolCreatorRegistry::Get()->Register("" # name)

} // namespace mxnet
#endif // MXNET_API_REGISTRY_H_
Loading

0 comments on commit 6214cd9

Please sign in to comment.