diff --git a/include/mxnet/graph_attr_types.h b/include/mxnet/graph_attr_types.h new file mode 100644 index 000000000000..a2bf3cf87d2f --- /dev/null +++ b/include/mxnet/graph_attr_types.h @@ -0,0 +1,30 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file graph_attr_types.h + * \brief Data structures that can appear in graph attributes. + */ +#ifndef MXNET_GRAPH_ATTR_TYPES_H_ +#define MXNET_GRAPH_ATTR_TYPES_H_ + +#include + +namespace mxnet { + +/*! + * \brief The result holder of storage type of each NodeEntry in the graph. + * \note Stored under graph.attrs["storage_type"], provided by Pass "InferStorageType" + * + * \code + * Graph g = ApplyPass(src_graph, "InferStorageType"); + * const StorageVector& stypes = g.GetAttr("storage_type"); + * // get shape by entry id + * int entry_type = stypes[g.indexed_graph().entry_id(my_entry)]; + * \endcode + * + * \sa FInferStorageType + */ +using StorageTypeVector = std::vector; + +} // namespace mxnet + +#endif // MXNET_GRAPH_ATTR_TYPES_H_ diff --git a/nnvm b/nnvm index d02104dca1ee..0767b966fe8a 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit d02104dca1eeb174a063aa06b54b774875a9106f +Subproject commit 0767b966fe8a985a3cb2de49876c621271f480ba diff --git a/src/common/utils.h b/src/common/utils.h index 95ddc240cbf6..4e31c9861b13 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -98,7 +99,7 @@ inline void CastNonDefaultStorage(const std::vector& dst, } // Check if any storage type is not default storage -inline bool ContainsNonDefaultStorage(const nnvm::StorageTypeVector& vstorage) { +inline bool ContainsNonDefaultStorage(const StorageTypeVector& vstorage) { for (auto& i : vstorage) { if (i != kUndefinedStorage && i != kDefaultStorage) return true; } diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index e9aca1ecf17b..8287ac15c5b1 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include "../common/utils.h" #include "./exec_pass.h" diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index c51123214a98..693e4e425ba1 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -149,7 +150,7 @@ Graph InferType(Graph graph, * The index of StorageTypeVector is given by graph.indexed_graph().entry_id. */ Graph InferStorageType(Graph graph, - nnvm::StorageTypeVector storage_type_inputs, + StorageTypeVector storage_type_inputs, const std::string& storage_type_attr_key = ""); } // namespace exec diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index be1c0c5f2eb4..eef68e0f4ca1 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -447,7 +447,7 @@ void HandleInferTypeError(const size_t num_forward_inputs, void HandleInferStorageTypeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, - const nnvm::StorageTypeVector& inferred_stypes) { + const StorageTypeVector& inferred_stypes) { int cnt = 10; std::ostringstream oss; for (size_t i = 0; i < num_forward_inputs; ++i) { @@ -505,7 +505,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, data_entry_.resize(idx.num_node_entries()); nnvm::ShapeVector arg_shapes; nnvm::DTypeVector arg_dtypes; - nnvm::StorageTypeVector arg_stypes; + StorageTypeVector arg_stypes; for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& arg_name = idx[nid].source->attrs.name; @@ -555,7 +555,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, g = InferStorageType(std::move(g), arg_stypes, "__storage_type__"); if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), - g.GetAttr("storage_type")); + g.GetAttr("storage_type")); } // Initialize the rest attributes of the graph. @@ -573,7 +573,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, - const nnvm::StorageTypeVector& inferred_stypes, + const StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -679,7 +679,7 @@ NDArray ReshapeOrCreate(const std::string& name, void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, - const nnvm::StorageTypeVector& inferred_stypes, + const StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -802,13 +802,13 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, const nnvm::NodeEntryMap& feed_dict) { const auto& idx = g.indexed_graph(); // dispatch based on stype per operator - const auto& vstorage_type = g.GetAttr("storage_type"); - nnvm::StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage); + const auto& vstorage_type = g.GetAttr("storage_type"); + StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage); for (size_t nid = 0; nid < idx.num_nodes(); nid++) { const auto& inode = idx[nid]; auto num_outputs = inode.source->num_outputs(); auto num_inputs = inode.inputs.size(); - nnvm::StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage); + StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage); for (size_t i = 0; i < num_inputs; i++) { auto e = inode.inputs[i]; vs[i] = vstorage_type[idx.entry_id(e)]; @@ -919,7 +919,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const nnvm::IndexedGraph& idx = g.indexed_graph(); nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); - nnvm::StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); + StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& name = idx[nid].source->attrs.name; @@ -951,7 +951,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, g = InferStorageType(std::move(g), arg_stypes, "__storage_type__"); if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), - g.GetAttr("storage_type")); + g.GetAttr("storage_type")); } // Create in_args, arg_grads, and aux_states using @@ -959,13 +959,13 @@ void GraphExecutor::Init(nnvm::Symbol symbol, if (nullptr == shared_buffer) { // regular simple bind InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), - g.GetAttr("storage_type"), + g.GetAttr("storage_type"), in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec); } else { // simple bind using shared data arrays and shared_exec InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), - g.GetAttr("storage_type"), + g.GetAttr("storage_type"), in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types, shared_arg_names, shared_exec, shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec); @@ -1018,7 +1018,6 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, // initialize the memory of each entries void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { using nnvm::DTypeVector; - using nnvm::StorageTypeVector; using nnvm::ShapeVector; using nnvm::StorageVector; // get the graph @@ -1169,7 +1168,7 @@ void GraphExecutor::InitCachedOps() { const auto& vctx = graph_.GetAttr("context"); const auto& addto_entry = graph_.GetAttr >("addto_entry"); const auto& skip_plus_node = graph_.GetAttr >("skip_plus_node"); - const auto& vstorage_type = graph_.GetAttr("storage_type"); + const auto& vstorage_type = graph_.GetAttr("storage_type"); op_nodes_.resize(idx.num_nodes()); // setup the array and requirements. diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index 6c9d8350774b..697c0be822e3 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -126,7 +126,7 @@ class GraphExecutor : public Executor { void InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, - const nnvm::StorageTypeVector& inferred_stypes, + const StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -139,7 +139,7 @@ class GraphExecutor : public Executor { void InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, - const nnvm::StorageTypeVector& inferred_stypes, + const StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc index 3789c313bf18..7fab732f44c5 100644 --- a/src/executor/infer_graph_attr_pass.cc +++ b/src/executor/infer_graph_attr_pass.cc @@ -5,6 +5,7 @@ */ #include +#include #include "./exec_pass.h" namespace mxnet { @@ -314,7 +315,7 @@ nnvm::Graph InferType(nnvm::Graph graph, } nnvm::Graph InferStorageType(nnvm::Graph graph, - nnvm::StorageTypeVector storage_type_inputs, + StorageTypeVector storage_type_inputs, const std::string& storage_type_attr_key) { using dmlc::any; if (storage_type_inputs.size() != 0) {