Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #10 from antinucleon/master
Browse files Browse the repository at this point in the history
new symbol interface
  • Loading branch information
antinucleon committed Aug 10, 2015
2 parents fead856 + e50817d commit 84023c7
Show file tree
Hide file tree
Showing 14 changed files with 1,017 additions and 385 deletions.
11 changes: 9 additions & 2 deletions include/mxnet/atomic_symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ class AtomicSymbol {
*/
virtual ~AtomicSymbol() {}
/*! \brief get the descriptions of inputs for this symbol */
virtual std::vector<std::string> DescribeArguments() const {
virtual std::vector<std::string> ListArguments() const {
// default implementation returns "data"
return std::vector<std::string>(1, std::string("data"));
}
/*! \brief get the descriptions of outputs for this symbol */
virtual std::vector<std::string> DescribeReturns() const {
virtual std::vector<std::string> ListReturns() const {
// default implementation returns "output"
return std::vector<std::string>(1, std::string("output"));
}
Expand Down Expand Up @@ -77,6 +77,13 @@ class AtomicSymbol {
*/
virtual std::string TypeString() const = 0;
friend class Symbol;

/*!
* \brief create atomic symbol by type name
* \param type_name the type string of the AtomicSymbol
* \return a new constructed AtomicSymbol
*/
static AtomicSymbol *Create(const char* type_name);
};

} // namespace mxnet
Expand Down
89 changes: 63 additions & 26 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,21 @@ MXNET_DLL int MXFuncInvoke(FunctionHandle fun,
// Part 3: symbolic configuration generation
//--------------------------------------------
/*!
* \brief create symbol from config
* \param cfg configuration string
* \param out created symbol handle
* \brief list all the available AtomicSymbolEntry
* \param out_size the size of returned array
* \param out_array the output AtomicSymbolCreator array
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg,
SymbolHandle *out);
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
AtomicSymbolCreator **out_array);
/*!
* \brief Get the name of AtomicSymbol.
* \param creator the AtomicSymbolCreator
* \param out the returned name of the creator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **out);
/*!
* \brief create Symbol by wrapping AtomicSymbol
* \param creator the AtomicSymbolCreator
Expand All @@ -231,50 +239,79 @@ MXNET_DLL int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator,
const char **vals,
SymbolHandle *out);
/*!
* \brief free the symbol handle
* \brief Create a Variable Symbol.
* \param name name of the variable
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateVariable(const char *name, SymbolHandle *out);
/*!
* \brief Create symbol from config.
* \param cfg configuration string
* \param out created symbol handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg,
SymbolHandle *out);
/*!
* \brief Free the symbol handle.
* \param symbol the symbol
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolFree(SymbolHandle symbol);
/*!
* \brief list all the available AtomicSymbolEntry
* \param out_size the size of returned array
* \param out_array the output AtomicSymbolCreator array
* \brief Copy the symbol to another handle
* \param symbol the source symbol
* \param out used to hold the result of copy
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
AtomicSymbolCreator **out_array);
MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
/*!
* \brief get the singleton Symbol of the AtomicSymbol if any
* \param creator the AtomicSymbolCreator
* \param out the returned singleton Symbol of the AtomicSymbol the creator stands for
* \brief Print the content of symbol, used for debug.
* \param symbol the symbol
* \param out_str pointer to hold the output string of the printing.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetSingleton(AtomicSymbolCreator creator,
SymbolHandle *out);
MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str);
/*!
* \brief get the singleton Symbol of the AtomicSymbol if any
* \param creator the AtomicSymbolCreator
* \param out the returned name of the creator
* \brief List arguments in the symbol.
* \param symbol the symbol
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **out);
MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);
/*!
* \brief compose the symbol on other symbol
* \brief List returns in the symbol.
* \param symbol the symbol
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolListReturns(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);
/*!
* \brief Compose the symbol on other symbols.
*
* This function will change the sym hanlde.
* To achieve function apply behavior, copy the symbol first
* before apply.
*
* \param sym the symbol to apply
* \param name the name of symbol
* \param num_args number of arguments
* \param keys the key of keyword args (optional)
* \param args arguments to sym
* \param out the resulting symbol
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCompose(SymbolHandle sym,
const char *name,
mx_uint num_args,
const char** keys,
SymbolHandle* args,
SymbolHandle* out);

SymbolHandle* args);
//--------------------------------------------
// Part 4: operator interface on NArray
//--------------------------------------------
Expand Down
17 changes: 1 addition & 16 deletions include/mxnet/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,31 +229,16 @@ struct AtomicSymbolEntry {
std::string name;
/*! \brief function body to create AtomicSymbol */
Creator body;
/*! \brief singleton is created when no param is needed for the AtomicSymbol */
Symbol *singleton_symbol;
/*! \brief constructor */
explicit AtomicSymbolEntry(const std::string& name)
: use_param(true), name(name), body(NULL), singleton_symbol(NULL) {}
/*!
* \brief set if param is needed by this AtomicSymbol
*/
inline AtomicSymbolEntry &set_use_param(bool use_param) {
this->use_param = use_param;
return *this;
}
: use_param(true), name(name), body(NULL) {}
/*!
* \brief set the function body
*/
inline AtomicSymbolEntry &set_body(Creator body) {
this->body = body;
return *this;
}
/*!
* \brief return the singleton symbol
*/
Symbol *GetSingletonSymbol();
/*! \brief destructor */
~AtomicSymbolEntry();
/*!
* \brief invoke the function
* \return the created AtomicSymbol
Expand Down
43 changes: 16 additions & 27 deletions include/mxnet/static_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,28 @@ struct StaticGraph {
/*! \brief Node in static graph */
struct StaticNode {
/*! \brief wrapped atomic symbol */
AtomicSymbol* sym_;
std::unique_ptr<AtomicSymbol> sym;
/*! \brief name of the node */
std::string name_;
std::string name;
/*! \brief index of output from the source. */
int index;
/*! \brief output shape for node */
std::vector<TShape> in_shape;
/*! \brief output shape for node */
std::vector<TShape> out_shape;
/*! \brief input id for each node */
std::vector<int> inputs_index;
/*! \brief output id for each node */
std::vector<int> outputs_index;
};
/*! \brief head node (need input from outside) */
std::vector<int> in_args_node_id;
/*! \brief tail node (generate data to outside) */
std::vector<int> return_node_id;
/*! \brief node name to id dictionary */
std::unordered_map<std::string, int> name_id_map;
/*! \brief all nodes in the graph */
std::vector<StaticNode> nodes;
/*! \brief output id for each node */
std::vector<std::vector<int> > output_index;
/*! \brief connected graph for each node */
std::vector<std::vector<int> > connected_graph;
/*! \brief find node by using name
* \param name node name
* \param sym symbol need to be copied into node
* \return node id
*/
int FindNodeByName(const std::string& name, const AtomicSymbol* sym) {
int id = 0;
if (name_id_map.find(name) == name_id_map.end()) {
name_id_map[name] = name_id_map.size();
StaticNode static_node;
static_node.sym_ = sym->Copy();
static_node.name_ = name;
nodes.push_back(static_node);
output_index.push_back(std::vector<int>());
connected_graph.push_back(std::vector<int>());
id = name_id_map.size();
} else {
id = name_id_map[name];
}
return id;
}
};
} // namespace mxnet
#endif // MXNET_STATIC_GRAPH_H_
Loading

0 comments on commit 84023c7

Please sign in to comment.