Skip to content
This repository has been archived by the owner on Feb 1, 2020. It is now read-only.

force infer storage on backward pass (#1) #113

Merged
merged 2 commits into from
May 15, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,6 @@ Symbol Symbol::GetInternals() const {
}

Symbol Symbol::GetChildren() const {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol ret;
std::unordered_set<Node*> visited;
for (const auto& p : this->outputs) {
Expand Down
26 changes: 15 additions & 11 deletions src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
namespace nnvm {
namespace pass {
namespace {
// TODO(haibin) change file name to infer_attrs.cc
template<typename AttrType, typename IsNone, typename FDefault>
Graph InferAttr(Graph &&ret,
const AttrType empty_val,
Expand All @@ -20,7 +19,8 @@ Graph InferAttr(Graph &&ret,
const char* attr_name,
const char* unknown_name,
IsNone fis_none,
FDefault fdefault) {
FDefault fdefault,
bool backward_identity_assign) {
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Expand Down Expand Up @@ -88,7 +88,8 @@ Graph InferAttr(Graph &&ret,
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
}
}
} else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
} else if (is_backward.get(inode.source->op(), false) &&
inode.control_deps.size() && backward_identity_assign) {
CHECK_GE(inode.control_deps.size(), 1U)
<< "BackwardOp need to have control_deps to its forward op";
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
Expand Down Expand Up @@ -208,7 +209,7 @@ NNVM_REGISTER_PASS(InferShape)
"FInferShape", "shape_inputs", "shape_attr_key",
"shape", "shape_num_unknown_nodes",
[](const TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
nullptr);
nullptr, true);
})
.set_change_graph(false)
.provide_graph_attr("shape");
Expand Down Expand Up @@ -241,29 +242,32 @@ inline bool SameType(const NodeAttrs& attrs,
}

// assigning default type N to both input and output attrs with value -1
template <int N>
template <int default_val, int none>
inline bool DefaultType(const NodeAttrs& attrs,
std::vector<int> *iattr,
std::vector<int> *oattr) {
// LOG(INFO) << "DefaultType " << N;
for (int& v : *oattr) {
if (v == -1) v = N;
if (v == none) v = default_val;
}
for (int& v : *iattr) {
if (v == -1) v = N;
if (v == none) v = default_val;
}
return true;
}

NNVM_REGISTER_PASS(InferStorageType)
.describe("Infer the storage type of each node entries.")
.set_body([](Graph ret) {
// for storage type, the backward attr is not necessarily the same as it's correspondence
const int none = -1;
const int kDefaultStorage = 0;
return InferAttr<int>(
std::move(ret), -1,
std::move(ret), none,
"FInferStorageType", "storage_type_inputs", "storage_type_attr_key",
"storage_type", "storage_type_num_unknown_nodes",
[](const int t) { return t == -1; },
DefaultType<1>);
[](const int t) { return t == none; },
DefaultType<kDefaultStorage, none>, false);
})
.set_change_graph(false)
.provide_graph_attr("storage_type");
Expand All @@ -276,7 +280,7 @@ NNVM_REGISTER_PASS(InferType)
"FInferType", "dtype_inputs", "dtype_attr_key",
"dtype", "dtype_num_unknown_nodes",
[](const int t) { return t == -1; },
SameType);
SameType, true);
})
.set_change_graph(false)
.provide_graph_attr("dtype");
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_infer_storage_type():
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
assert g.json_attr('storage_type')[jnode_row_ptr[nindex["add1"]]] == 1
assert g.json_attr('storage_type')[jnode_row_ptr[nindex["add1"]]] == 0

def test_place_device():
x = sym.Variable('x', device_group="stage1")
Expand Down