-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New Run() method for framework::Executor #7807
Changes from 10 commits
4e7ce62
466d40b
9166de5
a4d99fe
55b2e4a
5f1c614
ab9b859
4d255f0
9ff7578
e8f7777
3b8b462
9a3b510
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ limitations under the License. */ | |
#include <set> | ||
|
||
#include "gflags/gflags.h" | ||
#include "paddle/framework/feed_fetch_method.h" | ||
#include "paddle/framework/feed_fetch_type.h" | ||
#include "paddle/framework/lod_rank_table.h" | ||
#include "paddle/framework/lod_tensor_array.h" | ||
|
@@ -149,5 +150,169 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, | |
} | ||
} | ||
|
||
// Return false if the block does not have any feed operators. | ||
// If some feed operators have been prepended to the block, check that | ||
// the info contained in these feed operators matches the feed_targets | ||
// and feed_holder_name. Raise exception when any mismatch is found. | ||
// Return true when the block has feed operators with matching info. | ||
static bool has_feed_operators( | ||
BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets, | ||
const std::string& feed_holder_name) { | ||
size_t feed_count = 0; | ||
for (auto* op : block->AllOps()) { | ||
if (op->Type() == "feed") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
feed_count++; | ||
PADDLE_ENFORCE(op->Input("X")[0] == feed_holder_name, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
"Input to feed op should be '%s'", feed_holder_name); | ||
std::string feed_target_name = op->Output("Out")[0]; | ||
PADDLE_ENFORCE( | ||
feed_targets.find(feed_target_name) != feed_targets.end(), | ||
"Feed operator output name '%s' cannot be found in 'feed_targets'", | ||
feed_target_name); | ||
PADDLE_ENFORCE(op->GetAttr("col").type() == typeid(int), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to check this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed this. |
||
"Attribute type of 'col' should be int"); | ||
} else { | ||
break; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove the |
||
} | ||
} | ||
|
||
if (feed_count > 0) { | ||
PADDLE_ENFORCE(feed_count == feed_targets.size(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
"The number of feed operators should match 'feed_targets'"); | ||
} | ||
return feed_count > 0; | ||
} | ||
|
||
// Return false if the block does not have any fetch operators. | ||
// If some fetch operators have been appended to the block, check that | ||
// the info contained in these fetch operators matches the fetch_targets | ||
// and fetch_holder_name. Raise exception when any mismatch is found. | ||
// Return true when the block has fetch operators with matching info. | ||
static bool has_fetch_operators( | ||
BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets, | ||
const std::string& fetch_holder_name) { | ||
size_t fetch_count = 0; | ||
for (auto* op : block->AllOps()) { | ||
if (op->Type() == "fetch") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
fetch_count++; | ||
PADDLE_ENFORCE(op->Output("Out")[0] == fetch_holder_name, | ||
"Output of fetch op should be '%s'", fetch_holder_name); | ||
std::string fetch_target_name = op->Input("X")[0]; | ||
PADDLE_ENFORCE( | ||
fetch_targets.find(fetch_target_name) != fetch_targets.end(), | ||
"Fetch operator input name '%s' cannot be found in 'fetch_targets'", | ||
fetch_target_name); | ||
PADDLE_ENFORCE(op->GetAttr("col").type() == typeid(int), | ||
"Attribute type of 'col' should be int"); | ||
} | ||
} | ||
|
||
if (fetch_count > 0) { | ||
PADDLE_ENFORCE( | ||
fetch_count == fetch_targets.size(), | ||
"The number of fetch operators should match 'fetch_targets'"); | ||
} | ||
return fetch_count > 0; | ||
} | ||
|
||
void Executor::Run(const ProgramDesc& program, Scope* scope, | ||
std::map<std::string, const LoDTensor*>& feed_targets, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reminded by @kexinzhao , it is impossible to use reference in |
||
std::map<std::string, LoDTensor*>& fetch_targets, | ||
const std::string& feed_holder_name, | ||
const std::string& fetch_holder_name) { | ||
auto* copy_program = new ProgramDesc(program); | ||
auto* global_block = copy_program->MutableBlock(0); | ||
|
||
VarDesc* feed_holder = nullptr; | ||
VarDesc* fetch_holder = nullptr; | ||
for (auto* var : global_block->AllVars()) { | ||
if (var->GetType() == proto::VarDesc::FEED_MINIBATCH) { | ||
PADDLE_ENFORCE(var->Name() == feed_holder_name, | ||
"'feed_holder_name' should match the program desc"); | ||
feed_holder = var; | ||
} else if (var->GetType() == proto::VarDesc::FETCH_LIST) { | ||
PADDLE_ENFORCE(var->Name() == fetch_holder_name, | ||
"'fetch_holder_name' should match the program desc"); | ||
fetch_holder = var; | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Great point! Have removed these codes and move the check of feed/fetch_holder_name in has_feed/fetch_operator(). |
||
|
||
// create a feed_holder variable if not found | ||
if (feed_holder == nullptr) { | ||
feed_holder = global_block->Var(feed_holder_name); | ||
feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH); | ||
feed_holder->SetPersistable(true); | ||
} | ||
|
||
// create a fetch_holder variable if not found | ||
if (fetch_holder == nullptr) { | ||
fetch_holder = global_block->Var(fetch_holder_name); | ||
fetch_holder->SetType(proto::VarDesc::FETCH_LIST); | ||
fetch_holder->SetPersistable(true); | ||
} | ||
|
||
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) { | ||
int i = 0; | ||
for (auto& feed_target : feed_targets) { | ||
std::string var_name = feed_target.first; | ||
LOG(INFO) << "feed target's name: " << var_name; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LOG(INFO) -> VLOG(3) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
// prepend feed op | ||
auto* op = global_block->PrependOp(); | ||
op->SetType("feed"); | ||
op->SetInput("X", {feed_holder_name}); | ||
op->SetOutput("Out", {var_name}); | ||
op->SetAttr("col", {static_cast<int>(i)}); | ||
op->CheckAttrs(); | ||
|
||
i++; | ||
} | ||
} | ||
|
||
// map the data of feed_targets to feed_holder | ||
for (auto* op : global_block->AllOps()) { | ||
if (op->Type() == "feed") { | ||
std::string feed_target_name = op->Output("Out")[0]; | ||
int idx = boost::get<int>(op->GetAttr("col")); | ||
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, | ||
idx); | ||
} else { | ||
break; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove the else branch. |
||
} | ||
} | ||
|
||
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) { | ||
int i = 0; | ||
for (auto& fetch_target : fetch_targets) { | ||
std::string var_name = fetch_target.first; | ||
LOG(INFO) << "fetch target's name: " << var_name; | ||
|
||
// append fetch op | ||
auto* op = global_block->AppendOp(); | ||
op->SetType("fetch"); | ||
op->SetInput("X", {var_name}); | ||
op->SetOutput("Out", {fetch_holder_name}); | ||
op->SetAttr("col", {static_cast<int>(i)}); | ||
op->CheckAttrs(); | ||
|
||
i++; | ||
} | ||
} | ||
|
||
Run(*copy_program, scope, 0, true, true); | ||
|
||
// obtain the data of fetch_targets from fetch_holder | ||
for (auto* op : global_block->AllOps()) { | ||
if (op->Type() == "fetch") { | ||
std::string fetch_target_name = op->Input("X")[0]; | ||
int idx = boost::get<int>(op->GetAttr("col")); | ||
*fetch_targets[fetch_target_name] = | ||
GetFetchVariable(*scope, fetch_holder_name, idx); | ||
} | ||
} | ||
|
||
delete copy_program; | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/framework/feed_fetch_method.h" | ||
#include "glog/logging.h" | ||
#include "paddle/framework/variable.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
void SetFeedVariable(Scope* scope, const LoDTensor& input, | ||
const std::string& var_name, size_t index) { | ||
// If var_name Variable is not found in GlobalScope, a new variable will | ||
// be created. | ||
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index; | ||
Variable* g_feed_value = scope->Var(var_name); | ||
auto& feed_inputs = | ||
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>()); | ||
if (index >= feed_inputs.size()) { | ||
feed_inputs.resize(index + 1); | ||
} | ||
// shared data with input tensor | ||
feed_inputs[index].ShareDataWith(input); | ||
// set lod | ||
feed_inputs[index].set_lod(input.lod()); | ||
} | ||
|
||
LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name, | ||
size_t index) { | ||
// Since we want to fetch LodTensor from a variable, the variable must | ||
// be created alreadly. | ||
Variable* g_fetch_value = scope.FindVar(var_name); | ||
PADDLE_ENFORCE(g_fetch_value->IsType<FeedFetchList>(), | ||
"Only %s can be invoked by GetFetchVariable", | ||
typeid(FeedFetchList).name()); | ||
auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>(); | ||
auto& tensor = fetch_outputs[index]; | ||
VLOG(3) << "Fetch " << var_name << " with index " << index | ||
<< " shape= " << tensor.dims(); | ||
PADDLE_ENFORCE_LT(index, fetch_outputs.size()); | ||
return tensor; | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add description of
Check whether the block already has feed operators.
here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done