diff --git a/include/nnvm/op_attr_types.h b/include/nnvm/op_attr_types.h index 4d10c304a5d6..b29ade88a034 100644 --- a/include/nnvm/op_attr_types.h +++ b/include/nnvm/op_attr_types.h @@ -30,6 +30,18 @@ namespace nnvm { */ using FListInputNames = std::function (const NodeAttrs& attrs)>; +/*! + * \brief Return number of visible outputs by the user. + * + * \param attrs The attributes of the node. + * + * \note Register under "FNumVisibleOutputs", default not registered. + * This can be used to hide certain output from the user, + * but the additional outputs can be used to pass information from + * forward to gradient pass. + */ +using FNumVisibleOutputs = std::function; + /*! * \brief Return list of output arguments names of each operator. * diff --git a/src/core/symbolic.cc b/src/core/symbolic.cc index 8da33faa6f54..535dfa4f2127 100644 --- a/src/core/symbolic.cc +++ b/src/core/symbolic.cc @@ -87,7 +87,7 @@ inline std::vector GetKeys( // whether the symbol is atomic functor inline bool IsAtomic(const std::vector& outputs) { - return outputs.size() == 1 && outputs[0].node->inputs.size() == 0; + return outputs[0].node->inputs.size() == 0; } // public functions @@ -222,6 +222,7 @@ std::vector Symbol::ListInputNames(ListInputOption option) const { std::vector Symbol::ListOutputNames() const { static auto& flist_ouputs = Op::GetAttr("FListOutputNames"); + std::vector ret; for (auto &head : outputs) { if (head.node->is_variable()) { @@ -256,8 +257,6 @@ void Symbol::Compose(const array_view& args, const std::string& name) { static auto& flist_inputs = Op::GetAttr("FListInputNames"); - CHECK_EQ(outputs.size(), 1) - << "Only composition of value function is supported currently"; CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed"; // parameter check. for (size_t i = 0; i < args.size(); ++i) { @@ -400,6 +399,7 @@ void Symbol::AddControlDeps(const Symbol& src) { } Symbol Symbol::GetInternals() const { + static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol ret; DFSVisit(this->outputs, [&ret](const NodePtr& node) { Node* n = node.get(); @@ -409,6 +409,9 @@ Symbol Symbol::GetInternals() const { ret.outputs.emplace_back(NodeEntry{node, 0, param.version}); } else { uint32_t nout = n->num_outputs(); + if (fnum_vis_output.count(n->op())) { + nout = fnum_vis_output[n->op()](n->attrs); + } for (uint32_t i = 0; i < nout; ++i) { ret.outputs.emplace_back(NodeEntry{node, i, 0}); } @@ -467,6 +470,7 @@ std::unordered_map Symbol::ListAttrs(ListAttrOption op Symbol Symbol::CreateFunctor(const Op* op, std::unordered_map attrs) { + static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs"); Symbol s; NodePtr n = Node::Create(); n->attrs.op = op; @@ -474,7 +478,14 @@ Symbol Symbol::CreateFunctor(const Op* op, if (n->op()->attr_parser != nullptr) { n->op()->attr_parser(&(n->attrs)); } - s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0}); + + uint32_t nout = n->num_outputs(); + if (fnum_vis_output.count(n->op())) { + nout = fnum_vis_output[n->op()](n->attrs); + } + for (uint32_t i = 0; i < nout; ++i) { + s.outputs.emplace_back(NodeEntry{n, i, 0}); + } return s; } diff --git a/src/pass/place_device.cc b/src/pass/place_device.cc index 607c51a7f319..00d216512558 100644 --- a/src/pass/place_device.cc +++ b/src/pass/place_device.cc @@ -12,7 +12,6 @@ namespace nnvm { namespace pass { namespace { - // simply logic to place device according to device_group hint // insert copy node when there is Graph PlaceDevice(Graph src) { diff --git a/src/pass/plan_memory.cc b/src/pass/plan_memory.cc index 34b05d5d6c94..0f0b8e79b9c2 100644 --- a/src/pass/plan_memory.cc +++ b/src/pass/plan_memory.cc @@ -142,11 +142,11 @@ Graph PlanMemory(Graph ret) { // step 1: initialize reference count for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (const auto& e : idx[nid].inputs) { - ++ref_count[e.node_id]; + ++ref_count[idx.entry_id(e)]; } } for (const auto& e : idx.outputs()) { - ++ref_count[e.node_id]; + ++ref_count[idx.entry_id(e)]; } // step 2: allocate memory. StorageVector storage(idx.num_node_entries(), -1); @@ -202,10 +202,13 @@ Graph PlanMemory(Graph ret) { } } // check if there are outputs that can be freeded immediately + // these output are not referenced by any operator. for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { uint32_t eid = idx.entry_id(nid, index); if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) { allocator.Release(storage[eid], nid); + // use -2 to indicate that the node was never touched. + storage_inplace_index[eid] = -2; } if (storage[eid] == GraphAllocator::kBadStorageID) { ++num_not_allocated;