New Run() method for framework::Executor #7807

Merged · 12 commits · Jan 26, 2018
4 changes: 3 additions & 1 deletion paddle/framework/CMakeLists.txt
@@ -74,8 +74,10 @@ cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op)
cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)

cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)

cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table profiler)
framework_proto backward glog lod_rank_table profiler feed_fetch_method)

cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
165 changes: 165 additions & 0 deletions paddle/framework/executor.cc
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include <set>

#include "gflags/gflags.h"
#include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h"
@@ -149,5 +150,169 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
}
}

// Return false if the block does not have any feed operators.
Contributor: Add a description of "Check whether the block already has feed operators." here.

Author: Done.

// If some feed operators have been prepended to the block, check that
// the info contained in these feed operators matches the feed_targets
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true when the block has feed operators with matching info.
static bool has_feed_operators(
BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets,
const std::string& feed_holder_name) {
size_t feed_count = 0;
for (auto* op : block->AllOps()) {
if (op->Type() == "feed") {
Contributor: op->Type() == "feed" -> op->Type() == kFeedOpType

Author: Done.

feed_count++;
PADDLE_ENFORCE(op->Input("X")[0] == feed_holder_name,
Contributor: Can we use PADDLE_ENFORCE_EQ here?

Author: Done.

"Input to feed op should be '%s'", feed_holder_name);
std::string feed_target_name = op->Output("Out")[0];
PADDLE_ENFORCE(
feed_targets.find(feed_target_name) != feed_targets.end(),
"Feed operator output name '%s' cannot be found in 'feed_targets'",
feed_target_name);
PADDLE_ENFORCE(op->GetAttr("col").type() == typeid(int),
Contributor: Do we need to check this?

Author: Removed this.

"Attribute type of 'col' should be int");
} else {
break;
Contributor: Remove the else branch. It seems unnecessary here.

}
}

if (feed_count > 0) {
PADDLE_ENFORCE(feed_count == feed_targets.size(),
Contributor: Use PADDLE_ENFORCE_EQ here.

Author: Done.

"The number of feed operators should match 'feed_targets'");
}
return feed_count > 0;
}

// Return false if the block does not have any fetch operators.
// If some fetch operators have been appended to the block, check that
// the info contained in these fetch operators matches the fetch_targets
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true when the block has fetch operators with matching info.
static bool has_fetch_operators(
BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& fetch_holder_name) {
size_t fetch_count = 0;
for (auto* op : block->AllOps()) {
if (op->Type() == "fetch") {
Contributor: op->Type() == "fetch" -> op->Type() == kFetchOpType

Author: Done.

fetch_count++;
PADDLE_ENFORCE(op->Output("Out")[0] == fetch_holder_name,
"Output of fetch op should be '%s'", fetch_holder_name);
std::string fetch_target_name = op->Input("X")[0];
PADDLE_ENFORCE(
fetch_targets.find(fetch_target_name) != fetch_targets.end(),
"Fetch operator input name '%s' cannot be found in 'fetch_targets'",
fetch_target_name);
PADDLE_ENFORCE(op->GetAttr("col").type() == typeid(int),
"Attribute type of 'col' should be int");
}
}

if (fetch_count > 0) {
PADDLE_ENFORCE(
fetch_count == fetch_targets.size(),
"The number of fetch operators should match 'fetch_targets'");
}
return fetch_count > 0;
}

void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
Contributor: Can we use const LoDTensor& here?

std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
auto* copy_program = new ProgramDesc(program);
auto* global_block = copy_program->MutableBlock(0);

VarDesc* feed_holder = nullptr;
VarDesc* fetch_holder = nullptr;
for (auto* var : global_block->AllVars()) {
if (var->GetType() == proto::VarDesc::FEED_MINIBATCH) {
PADDLE_ENFORCE(var->Name() == feed_holder_name,
"'feed_holder_name' should match the program desc");
feed_holder = var;
} else if (var->GetType() == proto::VarDesc::FETCH_LIST) {
PADDLE_ENFORCE(var->Name() == fetch_holder_name,
"'fetch_holder_name' should match the program desc");
fetch_holder = var;
}
}
Contributor: Lines 228-238 are used to check whether there is already a feed_holder? In my opinion, if has_feed_operators() returns true, a feed_holder must already exist, and we also check the feed_holder_name inside has_feed_operators(). Maybe we can simplify this code.

Author: Great point! I have removed this code and moved the check of feed/fetch_holder_name into has_feed_operators()/has_fetch_operators().
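A minimal sketch of the simplification discussed above (illustrative only; it assumes BlockDesc::Var() has get-or-create semantics and that holder-name mismatches are reported inside has_feed_operators()/has_fetch_operators(); the actual follow-up commit may differ):

```cpp
// No prior scan over AllVars() is needed: Var() returns the existing
// variable if present, otherwise it creates a new one.
auto* feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH);
feed_holder->SetPersistable(true);

auto* fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarDesc::FETCH_LIST);
fetch_holder->SetPersistable(true);
```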


// create a feed_holder variable if not found
if (feed_holder == nullptr) {
feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH);
feed_holder->SetPersistable(true);
}

// create a fetch_holder variable if not found
if (fetch_holder == nullptr) {
fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarDesc::FETCH_LIST);
fetch_holder->SetPersistable(true);
}

if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
int i = 0;
for (auto& feed_target : feed_targets) {
std::string var_name = feed_target.first;
LOG(INFO) << "feed target's name: " << var_name;
Contributor: LOG(INFO) -> VLOG(3)

Author: Done.


// prepend feed op
auto* op = global_block->PrependOp();
op->SetType("feed");
op->SetInput("X", {feed_holder_name});
op->SetOutput("Out", {var_name});
op->SetAttr("col", {static_cast<int>(i)});
op->CheckAttrs();

i++;
}
}

// map the data of feed_targets to feed_holder
for (auto* op : global_block->AllOps()) {
if (op->Type() == "feed") {
std::string feed_target_name = op->Output("Out")[0];
int idx = boost::get<int>(op->GetAttr("col"));
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
idx);
} else {
break;
Contributor: Remove the else branch.

}
}

if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
int i = 0;
for (auto& fetch_target : fetch_targets) {
std::string var_name = fetch_target.first;
LOG(INFO) << "fetch target's name: " << var_name;

// append fetch op
auto* op = global_block->AppendOp();
op->SetType("fetch");
op->SetInput("X", {var_name});
op->SetOutput("Out", {fetch_holder_name});
op->SetAttr("col", {static_cast<int>(i)});
op->CheckAttrs();

i++;
}
}

Run(*copy_program, scope, 0, true, true);

// obtain the data of fetch_targets from fetch_holder
for (auto* op : global_block->AllOps()) {
if (op->Type() == "fetch") {
std::string fetch_target_name = op->Input("X")[0];
int idx = boost::get<int>(op->GetAttr("col"));
*fetch_targets[fetch_target_name] =
GetFetchVariable(*scope, fetch_holder_name, idx);
}
}

delete copy_program;
}

} // namespace framework
} // namespace paddle
6 changes: 6 additions & 0 deletions paddle/framework/executor.h
@@ -41,6 +41,12 @@ class Executor {
void Run(const ProgramDesc&, Scope*, int, bool create_local_scope = true,
bool create_vars = true);

void Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");

private:
const platform::Place place_;
};
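For reference, a minimal usage sketch of the new Run() overload declared above. The variable names "image" and "prediction" are hypothetical feed/fetch target names, and the include paths follow the tree layout shown in this PR:

```cpp
#include <map>
#include <string>

#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"
#include "paddle/platform/place.h"

void RunWithFeedFetch(const paddle::framework::ProgramDesc& program,
                      const paddle::platform::Place& place,
                      paddle::framework::Scope* scope,
                      const paddle::framework::LoDTensor& input,
                      paddle::framework::LoDTensor* output) {
  paddle::framework::Executor executor(place);

  // Map feed/fetch target variable names to caller-owned tensors.
  std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
  std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
  feed_targets["image"] = &input;        // hypothetical feed target name
  fetch_targets["prediction"] = output;  // hypothetical fetch target name

  // The new overload prepends/appends the required feed/fetch ops if they are
  // not already in the program, using the default holder names "feed"/"fetch".
  executor.Run(program, scope, feed_targets, fetch_targets);
}
```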
56 changes: 56 additions & 0 deletions paddle/framework/feed_fetch_method.cc
@@ -0,0 +1,56 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/feed_fetch_method.h"
#include "glog/logging.h"
#include "paddle/framework/variable.h"

namespace paddle {
namespace framework {

void SetFeedVariable(Scope* scope, const LoDTensor& input,
const std::string& var_name, size_t index) {
// If var_name Variable is not found in GlobalScope, a new variable will
// be created.
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
Variable* g_feed_value = scope->Var(var_name);
auto& feed_inputs =
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
if (index >= feed_inputs.size()) {
feed_inputs.resize(index + 1);
}
// shared data with input tensor
feed_inputs[index].ShareDataWith(input);
// set lod
feed_inputs[index].set_lod(input.lod());
}

LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
size_t index) {
// Since we want to fetch LodTensor from a variable, the variable must
// be created alreadly.
Variable* g_fetch_value = scope.FindVar(var_name);
PADDLE_ENFORCE(g_fetch_value->IsType<FeedFetchList>(),
"Only %s can be invoked by GetFetchVariable",
typeid(FeedFetchList).name());
auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>();
auto& tensor = fetch_outputs[index];
VLOG(3) << "Fetch " << var_name << " with index " << index
<< " shape= " << tensor.dims();
PADDLE_ENFORCE_LT(index, fetch_outputs.size());
return tensor;
}

} // namespace framework
} // namespace paddle
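For context, a minimal sketch of how these two helpers pair up through the holder variables and the column index (illustrative; assumes the default holder names "feed" and "fetch" and a program that wires column 0 of each):

```cpp
#include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"

void FeedFetchRoundTrip(paddle::framework::Scope* scope,
                        const paddle::framework::LoDTensor& in) {
  // Place the input at column 0 of the "feed" holder variable.
  paddle::framework::SetFeedVariable(scope, in, "feed", 0);

  // ... run a program whose feed op reads column 0 of "feed" and whose
  // fetch op writes column 0 of "fetch" ...

  // Read the result back from column 0 of the "fetch" holder variable.
  paddle::framework::LoDTensor& out =
      paddle::framework::GetFetchVariable(*scope, "fetch", 0);
  (void)out;  // use the fetched tensor
}
```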
34 changes: 3 additions & 31 deletions paddle/framework/feed_fetch_method.h
@@ -13,46 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "glog/logging.h"

#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/variable.h"

namespace paddle {
namespace framework {

void SetFeedVariable(Scope* scope, const LoDTensor& input,
const std::string& var_name, size_t index) {
// If var_name Variable is not found in GlobalScope, a new variable will
// be created.
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
Variable* g_feed_value = scope->Var(var_name);
auto& feed_inputs =
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
if (index >= feed_inputs.size()) {
feed_inputs.resize(index + 1);
}
// shared data with input tensor
feed_inputs[index].ShareDataWith(input);
// set lod
feed_inputs[index].set_lod(input.lod());
}
const std::string& var_name, size_t index);

LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
size_t index) {
// Since we want to fetch LodTensor from a variable, the variable must
// be created alreadly.
Variable* g_fetch_value = scope.FindVar(var_name);
PADDLE_ENFORCE(g_fetch_value->IsType<FeedFetchList>(),
"Only %s can be invoked by GetFetchVariable",
typeid(FeedFetchList).name());
auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>();
auto& tensor = fetch_outputs[index];
VLOG(3) << "Fetch " << var_name << " with index " << index
<< " shape= " << tensor.dims();
PADDLE_ENFORCE_LT(index, fetch_outputs.size());
return tensor;
}
size_t index);

} // namespace framework
} // namespace paddle
14 changes: 8 additions & 6 deletions paddle/inference/inference.cc
@@ -15,7 +15,6 @@ limitations under the License. */
#include "inference.h"
#include <fstream>
#include "paddle/framework/executor.h"
#include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/init.h"
#include "paddle/framework/scope.h"

@@ -154,7 +153,7 @@ void InferenceEngine::Execute(const std::vector<framework::LoDTensor>& feeds,
LOG(FATAL) << "Please initialize the program_ and load_program_ first.";
}

if (feeds.size() < feed_var_names_.size()) {
if (feeds.size() != feed_var_names_.size()) {
LOG(FATAL) << "Please feed " << feed_var_names_.size() << " input Tensors.";
}

@@ -165,19 +164,22 @@

executor->Run(*load_program_, scope, 0, true, true);

std::map<std::string, const framework::LoDTensor*> feed_targets;
std::map<std::string, framework::LoDTensor*> fetch_targets;

// set_feed_variable
for (size_t i = 0; i < feed_var_names_.size(); ++i) {
framework::SetFeedVariable(scope, feeds[i], "feed", i);
feed_targets[feed_var_names_[i]] = &feeds[i];
}

executor->Run(*program_, scope, 0, true, true);

// get_fetch_variable
fetchs.resize(fetch_var_names_.size());
for (size_t i = 0; i < fetch_var_names_.size(); ++i) {
fetchs[i] = framework::GetFetchVariable(*scope, "fetch", i);
fetch_targets[fetch_var_names_[i]] = &fetchs[i];
}

executor->Run(*program_, scope, feed_targets, fetch_targets);

delete place;
delete scope;
delete executor;
2 changes: 1 addition & 1 deletion paddle/pybind/CMakeLists.txt
@@ -1,7 +1,7 @@
if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc const_value.cc
DEPS pybind python backward proto_desc paddle_memory executor prune init profiler
DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
${GLOB_OP_LIB})
if(NOT APPLE AND NOT ANDROID)
target_link_libraries(paddle_pybind rt)
4 changes: 3 additions & 1 deletion paddle/pybind/pybind.cc
@@ -424,7 +424,9 @@ All parameter, weight, gradient are variables in Paddle.

py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>())
.def("run", &Executor::Run);
.def("run",
(void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) &
Executor::Run);

m.def("unique_integer", UniqueIntegerGenerator);
m.def("init_gflags", framework::InitGflags);
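For reference, the member-function-pointer cast in the pybind change above is the usual way to pick one overload when binding an overloaded method with pybind11. A minimal generic sketch (not Paddle code; the Engine type and method names are made up):

```cpp
#include <string>

#include <pybind11/pybind11.h>

struct Engine {
  void Run(int block_id) {}
  void Run(const std::string& name) {}
};

PYBIND11_MODULE(demo, m) {
  pybind11::class_<Engine>(m, "Engine")
      .def(pybind11::init<>())
      // The cast selects which Run overload this binding refers to.
      .def("run", static_cast<void (Engine::*)(int)>(&Engine::Run))
      .def("run_by_name",
           static_cast<void (Engine::*)(const std::string&)>(&Engine::Run));
}
```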