From cae0d91741bc00fdd8278b84182f724569b3932d Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Thu, 13 Jul 2017 20:39:39 +0000 Subject: [PATCH 1/3] move storage type vector from nnvm to mxnet --- include/mxnet/graph_attr_types.h | 30 +++++++++++++++++++++++++++ src/common/utils.h | 3 ++- src/executor/attach_op_execs_pass.cc | 2 +- src/executor/exec_pass.h | 3 ++- src/executor/graph_executor.cc | 27 ++++++++++++------------ src/executor/graph_executor.h | 4 ++-- src/executor/infer_graph_attr_pass.cc | 3 ++- 7 files changed, 52 insertions(+), 20 deletions(-) create mode 100644 include/mxnet/graph_attr_types.h diff --git a/include/mxnet/graph_attr_types.h b/include/mxnet/graph_attr_types.h new file mode 100644 index 000000000000..a2bf3cf87d2f --- /dev/null +++ b/include/mxnet/graph_attr_types.h @@ -0,0 +1,30 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file graph_attr_types.h + * \brief Data structures that can appear in graph attributes. + */ +#ifndef MXNET_GRAPH_ATTR_TYPES_H_ +#define MXNET_GRAPH_ATTR_TYPES_H_ + +#include + +namespace mxnet { + +/*! + * \brief The result holder of storage type of each NodeEntry in the graph. + * \note Stored under graph.attrs["storage_type"], provided by Pass "InferStorageType" + * + * \code + * Graph g = ApplyPass(src_graph, "InferStorageType"); + * const StorageVector& stypes = g.GetAttr("storage_type"); + * // get shape by entry id + * int entry_type = stypes[g.indexed_graph().entry_id(my_entry)]; + * \endcode + * + * \sa FInferStorageType + */ +using StorageTypeVector = std::vector; + +} // namespace mxnet + +#endif // MXNET_GRAPH_ATTR_TYPES_H_ diff --git a/src/common/utils.h b/src/common/utils.h index 19592affacac..254b6ce5bd21 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -97,7 +98,7 @@ inline void CastNonDefaultStorage(const std::vector& dst, } // Check if any storage type is not default storage -inline bool ContainsNonDefaultStorage(const nnvm::StorageTypeVector& vstorage) { +inline bool ContainsNonDefaultStorage(const StorageTypeVector& vstorage) { for (auto& i : vstorage) { if (i != kUndefinedStorage && i != kDefaultStorage) return true; } diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 0d718df41c9e..68fe3dd16e64 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include "./exec_pass.h" #include "../common/utils.h" @@ -232,7 +233,6 @@ class FComputeExExecutor : public OpExecutor { // pass to attach operator executors Graph AttachOpExecs(Graph g) { using nnvm::DTypeVector; - using nnvm::StorageTypeVector; using nnvm::ShapeVector; using nnvm::FMutateInputs; diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 9be2d6c2f672..35f5e8b4a41c 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -145,7 +146,7 @@ Graph InferType(Graph graph, * The index of StorageTypeVector is given by graph.indexed_graph().entry_id. */ Graph InferStorageType(Graph graph, - nnvm::StorageTypeVector storage_type_inputs, + StorageTypeVector storage_type_inputs, const std::string& storage_type_attr_key = ""); } // namespace exec diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index cf428ff5701d..59559f8577aa 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -432,7 +432,7 @@ void HandleInferTypeError(const size_t num_forward_inputs, void HandleInferStorageTypeError(const size_t num_forward_inputs, const nnvm::IndexedGraph& idx, - const nnvm::StorageTypeVector& inferred_stypes) { + const StorageTypeVector& inferred_stypes) { int cnt = 10; std::ostringstream oss; for (size_t i = 0; i < num_forward_inputs; ++i) { @@ -490,7 +490,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, data_entry_.resize(idx.num_node_entries()); nnvm::ShapeVector arg_shapes; nnvm::DTypeVector arg_dtypes; - nnvm::StorageTypeVector arg_stypes; + StorageTypeVector arg_stypes; for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& arg_name = idx[nid].source->attrs.name; @@ -540,7 +540,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, g = InferStorageType(std::move(g), arg_stypes, "__storage_type__"); if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), - g.GetAttr("storage_type")); + g.GetAttr("storage_type")); } // Initialize the rest attributes of the graph. @@ -558,7 +558,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, - const nnvm::StorageTypeVector& inferred_stypes, + const StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -664,7 +664,7 @@ NDArray ReshapeOrCreate(const std::string& name, void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, - const nnvm::StorageTypeVector& inferred_stypes, + const StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -787,13 +787,13 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, const nnvm::NodeEntryMap& feed_dict) { const auto& idx = g.indexed_graph(); // dispatch based on stype per operator - const auto& vstorage_type = g.GetAttr("storage_type"); - nnvm::StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage); + const auto& vstorage_type = g.GetAttr("storage_type"); + StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage); for (size_t nid = 0; nid < idx.num_nodes(); nid++) { const auto& inode = idx[nid]; auto num_outputs = inode.source->num_outputs(); auto num_inputs = inode.inputs.size(); - nnvm::StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage); + StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage); for (size_t i = 0; i < num_inputs; i++) { auto e = inode.inputs[i]; vs[i] = vstorage_type[idx.entry_id(e)]; @@ -904,7 +904,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const nnvm::IndexedGraph& idx = g.indexed_graph(); nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); - nnvm::StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); + StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& name = idx[nid].source->attrs.name; @@ -936,7 +936,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, g = InferStorageType(std::move(g), arg_stypes, "__storage_type__"); if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), - g.GetAttr("storage_type")); + g.GetAttr("storage_type")); } // Create in_args, arg_grads, and aux_states using @@ -944,13 +944,13 @@ void GraphExecutor::Init(nnvm::Symbol symbol, if (nullptr == shared_buffer) { // regular simple bind InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), - g.GetAttr("storage_type"), + g.GetAttr("storage_type"), in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec); } else { // simple bind using shared data arrays and shared_exec InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), - g.GetAttr("storage_type"), + g.GetAttr("storage_type"), in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types, shared_arg_names, shared_exec, shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec); @@ -1003,7 +1003,6 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, // initialize the memory of each entries void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { using nnvm::DTypeVector; - using nnvm::StorageTypeVector; using nnvm::ShapeVector; using nnvm::StorageVector; // get the graph @@ -1154,7 +1153,7 @@ void GraphExecutor::InitCachedOps() { const auto& vctx = graph_.GetAttr("context"); const auto& addto_entry = graph_.GetAttr >("addto_entry"); const auto& skip_plus_node = graph_.GetAttr >("skip_plus_node"); - const auto& vstorage_type = graph_.GetAttr("storage_type"); + const auto& vstorage_type = graph_.GetAttr("storage_type"); op_nodes_.resize(idx.num_nodes()); // setup the array and requirements. diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index 308eddba8b80..45063a0e2e2f 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -129,7 +129,7 @@ class GraphExecutor : public Executor { void InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, - const nnvm::StorageTypeVector& inferred_stypes, + const StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -142,7 +142,7 @@ class GraphExecutor : public Executor { void InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, - const nnvm::StorageTypeVector& inferred_stypes, + const StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc index 3789c313bf18..7fab732f44c5 100644 --- a/src/executor/infer_graph_attr_pass.cc +++ b/src/executor/infer_graph_attr_pass.cc @@ -5,6 +5,7 @@ */ #include +#include #include "./exec_pass.h" namespace mxnet { @@ -314,7 +315,7 @@ nnvm::Graph InferType(nnvm::Graph graph, } nnvm::Graph InferStorageType(nnvm::Graph graph, - nnvm::StorageTypeVector storage_type_inputs, + StorageTypeVector storage_type_inputs, const std::string& storage_type_attr_key) { using dmlc::any; if (storage_type_inputs.size() != 0) { From fa526d840c4e83b6aa75f644e051a63d1e57faa0 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Thu, 13 Jul 2017 20:40:25 +0000 Subject: [PATCH 2/3] update nnvm --- nnvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnvm b/nnvm index d02104dca1ee..66d2e8000bbd 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit d02104dca1eeb174a063aa06b54b774875a9106f +Subproject commit 66d2e8000bbd7bb8da844e3c94003f1c1d6e5f43 From 7edab5a521c6bd21afa8c473632e2a8fb0aa23b9 Mon Sep 17 00:00:00 2001 From: eric-haibin-lin Date: Fri, 14 Jul 2017 23:23:39 +0000 Subject: [PATCH 3/3] update nnvm --- nnvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nnvm b/nnvm index 66d2e8000bbd..0767b966fe8a 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit 66d2e8000bbd7bb8da844e3c94003f1c1d6e5f43 +Subproject commit 0767b966fe8a985a3cb2de49876c621271f480ba