-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New Run() method for framework::Executor #7807
Changes from all commits
4e7ce62
466d40b
9166de5
a4d99fe
55b2e4a
5f1c614
ab9b859
4d255f0
9ff7578
e8f7777
3b8b462
9a3b510
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ limitations under the License. */ | |
#include <set> | ||
|
||
#include "gflags/gflags.h" | ||
#include "paddle/framework/feed_fetch_method.h" | ||
#include "paddle/framework/feed_fetch_type.h" | ||
#include "paddle/framework/lod_rank_table.h" | ||
#include "paddle/framework/lod_tensor_array.h" | ||
|
@@ -149,5 +150,168 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, | |
} | ||
} | ||
|
||
// Check whether the block already has feed operators and feed_holder. | ||
// Return false if the block does not have any feed operators. | ||
// If some feed operators have been prepended to the block, check that | ||
// the info contained in these feed operators matches the feed_targets | ||
// and feed_holder_name. Raise exception when any mismatch is found. | ||
// Return true if the block has feed operators and holder of matching info. | ||
static bool has_feed_operators( | ||
BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets, | ||
const std::string& feed_holder_name) { | ||
size_t feed_count = 0; | ||
for (auto* op : block->AllOps()) { | ||
if (op->Type() == kFeedOpType) { | ||
feed_count++; | ||
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name, | ||
"Input to feed op should be '%s'", feed_holder_name); | ||
std::string feed_target_name = op->Output("Out")[0]; | ||
PADDLE_ENFORCE( | ||
feed_targets.find(feed_target_name) != feed_targets.end(), | ||
"Feed operator output name '%s' cannot be found in 'feed_targets'", | ||
feed_target_name); | ||
} else { | ||
break; | ||
} | ||
} | ||
|
||
if (feed_count > 0) { | ||
PADDLE_ENFORCE_EQ( | ||
feed_count, feed_targets.size(), | ||
"The number of feed operators should match 'feed_targets'"); | ||
|
||
// When feed operator are present, so should be feed_holder | ||
auto var = block->FindVar(feed_holder_name); | ||
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", | ||
feed_holder_name); | ||
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FEED_MINIBATCH, | ||
"'%s' variable should be 'FEED_MINIBATCH' type", | ||
feed_holder_name); | ||
} | ||
|
||
return feed_count > 0; | ||
} | ||
|
||
// Check whether the block already has fetch operators and fetch_holder. | ||
// Return false if the block does not have any fetch operators. | ||
// If some fetch operators have been appended to the block, check that | ||
// the info contained in these fetch operators matches the fetch_targets | ||
// and fetch_holder_name. Raise exception when any mismatch is found. | ||
// Return true if the block has fetch operators and holder of matching info. | ||
static bool has_fetch_operators( | ||
BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets, | ||
const std::string& fetch_holder_name) { | ||
size_t fetch_count = 0; | ||
for (auto* op : block->AllOps()) { | ||
if (op->Type() == kFetchOpType) { | ||
fetch_count++; | ||
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name, | ||
"Output of fetch op should be '%s'", fetch_holder_name); | ||
std::string fetch_target_name = op->Input("X")[0]; | ||
PADDLE_ENFORCE( | ||
fetch_targets.find(fetch_target_name) != fetch_targets.end(), | ||
"Fetch operator input name '%s' cannot be found in 'fetch_targets'", | ||
fetch_target_name); | ||
} | ||
} | ||
|
||
if (fetch_count > 0) { | ||
PADDLE_ENFORCE_EQ( | ||
fetch_count, fetch_targets.size(), | ||
"The number of fetch operators should match 'fetch_targets'"); | ||
|
||
// When fetch operator are present, so should be fetch_holder | ||
auto var = block->FindVar(fetch_holder_name); | ||
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", | ||
fetch_holder_name); | ||
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FETCH_LIST, | ||
"'%s' variable should be 'FETCH_LIST' type", | ||
fetch_holder_name); | ||
} | ||
|
||
return fetch_count > 0; | ||
} | ||
|
||
void Executor::Run(const ProgramDesc& program, Scope* scope, | ||
std::map<std::string, const LoDTensor*>& feed_targets, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reminded by @kexinzhao , it is impossible to use reference in |
||
std::map<std::string, LoDTensor*>& fetch_targets, | ||
const std::string& feed_holder_name, | ||
const std::string& fetch_holder_name) { | ||
auto* copy_program = new ProgramDesc(program); | ||
auto* global_block = copy_program->MutableBlock(0); | ||
|
||
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) { | ||
// create feed_holder variable | ||
auto* feed_holder = global_block->Var(feed_holder_name); | ||
feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH); | ||
feed_holder->SetPersistable(true); | ||
|
||
int i = 0; | ||
for (auto& feed_target : feed_targets) { | ||
std::string var_name = feed_target.first; | ||
VLOG(3) << "feed target's name: " << var_name; | ||
|
||
// prepend feed op | ||
auto* op = global_block->PrependOp(); | ||
op->SetType(kFeedOpType); | ||
op->SetInput("X", {feed_holder_name}); | ||
op->SetOutput("Out", {var_name}); | ||
op->SetAttr("col", {static_cast<int>(i)}); | ||
op->CheckAttrs(); | ||
|
||
i++; | ||
} | ||
} | ||
|
||
// map the data of feed_targets to feed_holder | ||
for (auto* op : global_block->AllOps()) { | ||
if (op->Type() == kFeedOpType) { | ||
std::string feed_target_name = op->Output("Out")[0]; | ||
int idx = boost::get<int>(op->GetAttr("col")); | ||
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name, | ||
idx); | ||
} else { | ||
break; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove the else branch. |
||
} | ||
} | ||
|
||
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) { | ||
// create fetch_holder variable | ||
auto* fetch_holder = global_block->Var(fetch_holder_name); | ||
fetch_holder->SetType(proto::VarDesc::FETCH_LIST); | ||
fetch_holder->SetPersistable(true); | ||
|
||
int i = 0; | ||
for (auto& fetch_target : fetch_targets) { | ||
std::string var_name = fetch_target.first; | ||
VLOG(3) << "fetch target's name: " << var_name; | ||
|
||
// append fetch op | ||
auto* op = global_block->AppendOp(); | ||
op->SetType(kFetchOpType); | ||
op->SetInput("X", {var_name}); | ||
op->SetOutput("Out", {fetch_holder_name}); | ||
op->SetAttr("col", {static_cast<int>(i)}); | ||
op->CheckAttrs(); | ||
|
||
i++; | ||
} | ||
} | ||
|
||
Run(*copy_program, scope, 0, true, true); | ||
|
||
// obtain the data of fetch_targets from fetch_holder | ||
for (auto* op : global_block->AllOps()) { | ||
if (op->Type() == kFetchOpType) { | ||
std::string fetch_target_name = op->Input("X")[0]; | ||
int idx = boost::get<int>(op->GetAttr("col")); | ||
*fetch_targets[fetch_target_name] = | ||
GetFetchVariable(*scope, fetch_holder_name, idx); | ||
} | ||
} | ||
|
||
delete copy_program; | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/framework/feed_fetch_method.h" | ||
#include "glog/logging.h" | ||
#include "paddle/framework/variable.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
void SetFeedVariable(Scope* scope, const LoDTensor& input, | ||
const std::string& var_name, size_t index) { | ||
// If var_name Variable is not found in GlobalScope, a new variable will | ||
// be created. | ||
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index; | ||
Variable* g_feed_value = scope->Var(var_name); | ||
auto& feed_inputs = | ||
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>()); | ||
if (index >= feed_inputs.size()) { | ||
feed_inputs.resize(index + 1); | ||
} | ||
// shared data with input tensor | ||
feed_inputs[index].ShareDataWith(input); | ||
// set lod | ||
feed_inputs[index].set_lod(input.lod()); | ||
} | ||
|
||
LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name, | ||
size_t index) { | ||
// Since we want to fetch LodTensor from a variable, the variable must | ||
// be created alreadly. | ||
Variable* g_fetch_value = scope.FindVar(var_name); | ||
PADDLE_ENFORCE(g_fetch_value->IsType<FeedFetchList>(), | ||
"Only %s can be invoked by GetFetchVariable", | ||
typeid(FeedFetchList).name()); | ||
auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>(); | ||
auto& tensor = fetch_outputs[index]; | ||
VLOG(3) << "Fetch " << var_name << " with index " << index | ||
<< " shape= " << tensor.dims(); | ||
PADDLE_ENFORCE_LT(index, fetch_outputs.size()); | ||
return tensor; | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the
else
branch. It seems not necessary here.