This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from antinucleon/master
static graph
- Loading branch information
Showing
7 changed files
with
224 additions
and
155 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,84 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file static_graph.h | ||
* \brief the static graph of symbols | ||
* \brief The static graph of symbols | ||
*/ | ||
#ifndef MXNET_STATIC_GRAPH_H_ | ||
#define MXNET_STATIC_GRAPH_H_ | ||
|
||
#include <vector> | ||
#include <unordered_map> | ||
#include <string> | ||
#include <memory> | ||
#include "./base.h" | ||
#include "./atomic_symbol.h" | ||
|
||
namespace mxnet { | ||
/*! \brief static graph interface | ||
* static graph is an internal representation of symbol graph. | ||
* | ||
* The main purpose for static graph for binding a composite operator | ||
/*! | ||
* \brief StaticGraph is the configuration of computation graphs. | ||
* This is the "configuration file" of mxnet. | ||
* It can be converted to/from Symbol, and can be used to bind to operators. | ||
*/ | ||
struct StaticGraph { | ||
/*! \brief Node in static graph */ | ||
struct StaticNode { | ||
class StaticGraph { | ||
public: | ||
/*! \brief represents a data in the graph */ | ||
struct DataEntry { | ||
/*! \brief the source node id in the computation graph */ | ||
uint32_t source_id; | ||
/*! | ||
* \brief index of output from the source. | ||
* If index == -1, it represents all the outputs. | ||
*/ | ||
int32_t index; | ||
}; | ||
/*! \brief Operation Node in static graph */ | ||
struct Node { | ||
/*! \brief wrapped atomic symbol */ | ||
std::unique_ptr<AtomicSymbol> sym; | ||
/*! \brief name of the node */ | ||
std::string name; | ||
/*! \brief index of output from the source. */ | ||
int index; | ||
/*! \brief output shape for node */ | ||
std::vector<TShape> in_shape; | ||
/*! \brief output shape for node */ | ||
std::vector<TShape> out_shape; | ||
/*! \brief input id for each node */ | ||
std::vector<int> inputs_index; | ||
/*! \brief output id for each node */ | ||
std::vector<int> outputs_index; | ||
/*! \brief inputs (node_id, index) for of the nodes*/ | ||
std::vector<DataEntry> inputs; | ||
}; | ||
/*! \brief head node (need input from outside) */ | ||
std::vector<int> in_args_node_id; | ||
/*! \brief tail node (generate data to outside) */ | ||
std::vector<int> return_node_id; | ||
/*! \brief node name to id dictionary */ | ||
std::unordered_map<std::string, int> name_id_map; | ||
/*! \brief all nodes in the graph */ | ||
std::vector<StaticNode> nodes; | ||
std::vector<Node> nodes; | ||
/*! \brief index is nodes that correspods to arguments */ | ||
std::vector<uint32_t> arg_nodes; | ||
/*! \brief outputs(heads) of the graph */ | ||
std::vector<DataEntry> outputs; | ||
// funtions to help inference in static graph | ||
/*! | ||
* \brief Perform a topological sort on the graph | ||
* \return a topological order of node indices. | ||
*/ | ||
std::vector<uint32_t> TopoSort() const; | ||
/*! | ||
* \brief infer the node shapes in the computation graph. | ||
* | ||
* When calling this function, user can setup the shape information known into right position. | ||
* Unknown shape are indicated by shape.ndim() == 0. | ||
* | ||
* \param topo_order The topological order of node index, as created by TopoSort. | ||
* \param node_out_shapes The shapes of the each outputs of nodes in the graph. | ||
* \return if the shape inference is successful, return true, else return false. | ||
*/ | ||
bool InferNodeShapes(const std::vector<uint32_t> &topo_order, | ||
std::vector<std::vector<TShape> > *node_out_shapes) const; | ||
/*! | ||
* \brief infer the shapes of outputs and unknown input arguments | ||
* \param in_shape the shape of input arguments of the operator | ||
* this should be of same length as the vector returned by ListArguments | ||
* in_shape allows unknown elements, which are checked by shape.ndim() == 0. | ||
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape | ||
* For known shapes, InferShape will check shape consistency | ||
* | ||
* common practice: set the shape of data input, and usually weight's shape can be infered | ||
* | ||
* \param out_shape the shape of outputs of the operator | ||
* InferShape will modify the vector to fill output TShape | ||
* \return if the shape inference is successful, return true, else return false. | ||
*/ | ||
bool InferShape(std::vector<TShape> *in_shape, | ||
std::vector<TShape> *out_shape) const; | ||
}; | ||
} // namespace mxnet | ||
#endif // MXNET_STATIC_GRAPH_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file static_graph.cc | ||
* \brief static graph of mxnet | ||
*/ | ||
#include <dmlc/logging.h> | ||
#include <mxnet/static_graph.h> | ||
#include <vector> | ||
#include <queue> | ||
|
||
std::vector<uint32_t> StaticGraph::TopoSort() const { | ||
std::vector<int> out_degree(nodes.size(), 0); | ||
for (const Node &n : nodes) { | ||
for (const DataEntry &e : n.inputs) { | ||
++out_degree[e.source_id]; | ||
} | ||
} | ||
std::vector<uint32_t> ret(nodes.size()); | ||
auto result = ret.rbegin(); | ||
std::queue<uint32_t> queue; | ||
for (size_t i = 0; i < nodes.size(); ++i) { | ||
if (out_degree[i] == 0) { | ||
queue.push(static_cast<uint32_t>(i)); | ||
} | ||
} | ||
while (!queue.empty()) { | ||
uint32_t node_id = queue.front(); | ||
queue.pop(); | ||
*result = node_id; | ||
++result; | ||
for (const DataEntry &e : nodes[node_id].inputs) { | ||
out_degree[e.source_id] -= 1; | ||
if (out_degree[e.source_id] == 0) { | ||
queue.push(e.source_id); | ||
} | ||
} | ||
} | ||
return std::move(ret); | ||
} | ||
|
||
bool StaticGraph::InferShape(const std::vector<uint32_t> &topo_order, | ||
std::vector<std::vector<TShape> > *node_out_shapes) const { | ||
bool success = true; | ||
for (uint32_t nid : topo_order) { | ||
const Node &node = nodes[nid]; | ||
if (node.sym != nullptr) { | ||
std::vector<TShape> in_shape; | ||
for (const DataEntry &e : node.inputs) { | ||
in_shape.push_back(node_out_shapes[e.source_id][e.index]); | ||
} | ||
if (!node.sym->InferShape(&in_shape, &(*node_out_shapes)[nid])) return false; | ||
for (size_t i = 0; i < node.inputs.size(); ++i) { | ||
const DataEntry &e = node.inputs[i]; | ||
node_out_shapes[e.source_id][e.index] = in_shape[i]; | ||
} | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
bool StaticGraph::InferShape(std::vector<TShape> *in_shape, | ||
std::vector<TShape> *out_shape) const { | ||
std::vector<std::vector<TShape> > node_out_shapes(nodes.size()); | ||
for (size_t i = 0; i < nodes.size(); ++i) { | ||
int nout = 1; | ||
if (nodes[i].sym != nullptr) { | ||
nout = nodes[i].sym->NumReturns(); | ||
} | ||
node_out_shapes[i].resize(nout); | ||
} | ||
CHECK(in_shape->size() == arg_nodes.size()) | ||
<< "Wrong number of inputs to infer shape"; | ||
for (size_t i = 0; i < arg_nodes.size(); ++i) { | ||
node_out_shapes[nid][0] = (*in_shape)[i]; | ||
} | ||
if (!InferNodeShapes(this->TopoSort(), | ||
&node_out_shapes)) return false; | ||
for (size_t i = 0; i < arg_nodes.size(); ++i) { | ||
(*in_shape)[i] = node_out_shapes[nid][0]; | ||
} | ||
for (size_t i = 0; i < outputs.size(); ++i) { | ||
DataEntry e = outputs[i]; | ||
(*out_shape)[i] = node_out_shapes[e.source_id][e.index]; | ||
} | ||
} |
Oops, something went wrong.