Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

move storage type vector from nnvm to mxnet #7054

Merged
merged 3 commits into from
Jul 15, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions include/mxnet/graph_attr_types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*!
* Copyright (c) 2016 by Contributors
* \file graph_attr_types.h
* \brief Data structures that can appear in graph attributes.
*/
#ifndef MXNET_GRAPH_ATTR_TYPES_H_
#define MXNET_GRAPH_ATTR_TYPES_H_

#include <vector>

namespace mxnet {

/*!
* \brief The result holder of storage type of each NodeEntry in the graph.
* \note Stored under graph.attrs["storage_type"], provided by Pass "InferStorageType"
*
* \code
* Graph g = ApplyPass(src_graph, "InferStorageType");
* const StorageVector& stypes = g.GetAttr<StorageTypeVector>("storage_type");
* // get shape by entry id
* int entry_type = stypes[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferStorageType
*/
using StorageTypeVector = std::vector<int>;

} // namespace mxnet

#endif // MXNET_GRAPH_ATTR_TYPES_H_
3 changes: 2 additions & 1 deletion src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <mxnet/engine.h>
#include <mxnet/ndarray.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
#include <nnvm/graph_attr_types.h>

#include <memory>
Expand Down Expand Up @@ -97,7 +98,7 @@ inline void CastNonDefaultStorage(const std::vector<NDArray>& dst,
}

// Check if any storage type is not default storage
inline bool ContainsNonDefaultStorage(const nnvm::StorageTypeVector& vstorage) {
inline bool ContainsNonDefaultStorage(const StorageTypeVector& vstorage) {
for (auto& i : vstorage) {
if (i != kUndefinedStorage && i != kDefaultStorage) return true;
}
Expand Down
2 changes: 1 addition & 1 deletion src/executor/attach_op_execs_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <mxnet/base.h>
#include <mxnet/operator.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include "./exec_pass.h"
#include "../common/utils.h"
Expand Down Expand Up @@ -232,7 +233,6 @@ class FComputeExExecutor : public OpExecutor {
// pass to attach operator executors
Graph AttachOpExecs(Graph g) {
using nnvm::DTypeVector;
using nnvm::StorageTypeVector;
using nnvm::ShapeVector;
using nnvm::FMutateInputs;

Expand Down
3 changes: 2 additions & 1 deletion src/executor/exec_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <mxnet/operator.h>
#include <mxnet/graph_attr_types.h>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include <vector>
Expand Down Expand Up @@ -145,7 +146,7 @@ Graph InferType(Graph graph,
* The index of StorageTypeVector is given by graph.indexed_graph().entry_id.
*/
Graph InferStorageType(Graph graph,
nnvm::StorageTypeVector storage_type_inputs,
StorageTypeVector storage_type_inputs,
const std::string& storage_type_attr_key = "");

} // namespace exec
Expand Down
27 changes: 13 additions & 14 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ void HandleInferTypeError(const size_t num_forward_inputs,

void HandleInferStorageTypeError(const size_t num_forward_inputs,
const nnvm::IndexedGraph& idx,
const nnvm::StorageTypeVector& inferred_stypes) {
const StorageTypeVector& inferred_stypes) {
int cnt = 10;
std::ostringstream oss;
for (size_t i = 0; i < num_forward_inputs; ++i) {
Expand Down Expand Up @@ -490,7 +490,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
data_entry_.resize(idx.num_node_entries());
nnvm::ShapeVector arg_shapes;
nnvm::DTypeVector arg_dtypes;
nnvm::StorageTypeVector arg_stypes;
StorageTypeVector arg_stypes;
for (size_t i = 0; i < num_forward_inputs_; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
const std::string& arg_name = idx[nid].source->attrs.name;
Expand Down Expand Up @@ -540,7 +540,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
g = InferStorageType(std::move(g), arg_stypes, "__storage_type__");
if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
g.GetAttr<nnvm::StorageTypeVector>("storage_type"));
g.GetAttr<StorageTypeVector>("storage_type"));
}

// Initialize the rest attributes of the graph.
Expand All @@ -558,7 +558,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes,
const nnvm::DTypeVector& inferred_dtypes,
const nnvm::StorageTypeVector& inferred_stypes,
const StorageTypeVector& inferred_stypes,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
Expand Down Expand Up @@ -664,7 +664,7 @@ NDArray ReshapeOrCreate(const std::string& name,
void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes,
const nnvm::DTypeVector& inferred_dtypes,
const nnvm::StorageTypeVector& inferred_stypes,
const StorageTypeVector& inferred_stypes,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
Expand Down Expand Up @@ -787,13 +787,13 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
const auto& idx = g.indexed_graph();
// dispatch based on stype per operator
const auto& vstorage_type = g.GetAttr<nnvm::StorageTypeVector>("storage_type");
nnvm::StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage);
const auto& vstorage_type = g.GetAttr<StorageTypeVector>("storage_type");
StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage);
for (size_t nid = 0; nid < idx.num_nodes(); nid++) {
const auto& inode = idx[nid];
auto num_outputs = inode.source->num_outputs();
auto num_inputs = inode.inputs.size();
nnvm::StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage);
StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage);
for (size_t i = 0; i < num_inputs; i++) {
auto e = inode.inputs[i];
vs[i] = vstorage_type[idx.entry_id(e)];
Expand Down Expand Up @@ -904,7 +904,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
const nnvm::IndexedGraph& idx = g.indexed_graph();
nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1);
nnvm::StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
for (size_t i = 0; i < num_forward_inputs_; ++i) {
const uint32_t nid = idx.input_nodes().at(i);
const std::string& name = idx[nid].source->attrs.name;
Expand Down Expand Up @@ -936,21 +936,21 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
g = InferStorageType(std::move(g), arg_stypes, "__storage_type__");
if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
g.GetAttr<nnvm::StorageTypeVector>("storage_type"));
g.GetAttr<StorageTypeVector>("storage_type"));
}

// Create in_args, arg_grads, and aux_states using
// the inferred shapes and dtypes.
if (nullptr == shared_buffer) { // regular simple bind
InitArguments(idx, g.GetAttr<nnvm::ShapeVector>("shape"),
g.GetAttr<nnvm::DTypeVector>("dtype"),
g.GetAttr<nnvm::StorageTypeVector>("storage_type"),
g.GetAttr<StorageTypeVector>("storage_type"),
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec);
} else { // simple bind using shared data arrays and shared_exec
InitArguments(idx, g.GetAttr<nnvm::ShapeVector>("shape"),
g.GetAttr<nnvm::DTypeVector>("dtype"),
g.GetAttr<nnvm::StorageTypeVector>("storage_type"),
g.GetAttr<StorageTypeVector>("storage_type"),
in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
grad_req_types, shared_arg_names, shared_exec,
shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec);
Expand Down Expand Up @@ -1003,7 +1003,6 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
// initialize the memory of each entries
void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
using nnvm::DTypeVector;
using nnvm::StorageTypeVector;
using nnvm::ShapeVector;
using nnvm::StorageVector;
// get the graph
Expand Down Expand Up @@ -1154,7 +1153,7 @@ void GraphExecutor::InitCachedOps() {
const auto& vctx = graph_.GetAttr<ContextVector>("context");
const auto& addto_entry = graph_.GetAttr<std::vector<int> >("addto_entry");
const auto& skip_plus_node = graph_.GetAttr<std::vector<int> >("skip_plus_node");
const auto& vstorage_type = graph_.GetAttr<nnvm::StorageTypeVector>("storage_type");
const auto& vstorage_type = graph_.GetAttr<StorageTypeVector>("storage_type");

op_nodes_.resize(idx.num_nodes());
// setup the array and requirements.
Expand Down
4 changes: 2 additions & 2 deletions src/executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class GraphExecutor : public Executor {
void InitArguments(const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes,
const nnvm::DTypeVector& inferred_dtypes,
const nnvm::StorageTypeVector& inferred_stypes,
const StorageTypeVector& inferred_stypes,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
Expand All @@ -142,7 +142,7 @@ class GraphExecutor : public Executor {
void InitArguments(const nnvm::IndexedGraph& idx,
const nnvm::ShapeVector& inferred_shapes,
const nnvm::DTypeVector& inferred_dtypes,
const nnvm::StorageTypeVector& inferred_stypes,
const StorageTypeVector& inferred_stypes,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
Expand Down
3 changes: 2 additions & 1 deletion src/executor/infer_graph_attr_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*/

#include <mxnet/op_attr_types.h>
#include <mxnet/graph_attr_types.h>
#include "./exec_pass.h"

namespace mxnet {
Expand Down Expand Up @@ -314,7 +315,7 @@ nnvm::Graph InferType(nnvm::Graph graph,
}

nnvm::Graph InferStorageType(nnvm::Graph graph,
nnvm::StorageTypeVector storage_type_inputs,
StorageTypeVector storage_type_inputs,
const std::string& storage_type_attr_key) {
using dmlc::any;
if (storage_type_inputs.size() != 0) {
Expand Down