Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable generating code for a given subgraph. #21126

Merged
merged 10 commits into from
Nov 20, 2019
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/fusion_group/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph)
if(NOT APPLE AND NOT WIN32)
if(WITH_GPU)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass)
endif()
endif()

Expand Down
201 changes: 169 additions & 32 deletions paddle/fluid/framework/ir/fusion_group/code_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include <set>
#include <sstream>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/operation.h"

namespace paddle {
namespace framework {
Expand All @@ -30,69 +31,205 @@ CodeGenerator::CodeGenerator() {
code_templates_[0] = elementwise_t;
}

std::string CodeGenerator::Generate(SubGraph* subgraph) {
std::vector<OperationExpression> expressions = ConvertToExpressions(subgraph);
return Generate(subgraph->func_name, expressions);
}

std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
SubGraph* subgraph) {
std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);
std::vector<OperationExpression> expressions;
for (auto* node : subgraph->SortedNodes()) {
if (node && node->IsOp() && node->Op()) {
auto* op = node->Op();

// Input ids should be set in fixed order, like:
// - x, y in forward operations
// - x, y, out, out@GRAD in backward operations
std::vector<int> input_ids;
std::vector<std::string> input_names =
OperationMap::Instance().Get(op->Type()).input_names;
for (auto& name : input_names) {
// TODO(liuyiqun): support duplicated input.
if (op->Input(name).size() >= 1U) {
// Some input vars are not used in grad ops, such as
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
PADDLE_ENFORCE_NE(var_ids.find(op->Input(name)[0]), var_ids.end(),
"Input(%s) of operation %s should be set.", name,
op->Type());
input_ids.push_back(var_ids[op->Input(name)[0]]);
} else {
input_ids.push_back(-1);
}
}
// Output ids should be set in fixed order, like:
// - dx, dy in backward operations
std::vector<int> output_ids;
std::vector<std::string> output_names =
OperationMap::Instance().Get(op->Type()).output_names;
for (auto& name : output_names) {
PADDLE_ENFORCE_EQ(op->Output(name).size(), 1U,
"Output(%s) of operation %s should be set.", name,
op->Type());
PADDLE_ENFORCE_NE(var_ids.find(op->Output(name)[0]), var_ids.end(),
"Output(%s) of operation %s should be set.", name,
op->Type());
output_ids.push_back(var_ids[op->Output(name)[0]]);
}
expressions.push_back(
OperationExpression(node->Name(), input_ids, output_ids));
}
}
return expressions;
}

// In order to get the right result of expression, we need to calculate and
// store the expression as suffix Expressions using vector.
std::string CodeGenerator::GenerateCode(
std::string CodeGenerator::Generate(
std::string func_name, std::vector<OperationExpression> expressions) {
// Check whether all expressions are elementwise operations.
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
std::string dtype = "float";
std::set<int> input_ids = DistilInputIds(expressions);
std::set<int> output_ids = DistilOutputIds(expressions);

TemplateVariable template_var;
template_var.Add("func_name", func_name);
template_var.Add("parameters", EmitParameters(expressions, "float"));
template_var.Add("compute_body", EmitComputeBody(expressions));
template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtype));
template_var.Add("compute_body",
EmitComputeBody(expressions, input_ids, output_ids, dtype));
return predefined_cuda_functions + code_templates_[0].Format(template_var);
}

// we get the parameter list code for the expression information
std::string CodeGenerator::EmitParameters(
std::vector<OperationExpression> expressions, std::string dtype) {
std::set<int> CodeGenerator::DistilInputIds(
const std::vector<OperationExpression>& expressions) {
std::set<int> input_ids;
std::set<int> output_ids;
// Remove the reptead id and get a ordered list.
// Use std::set to remove the reptead id and get a ordered list.
for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetInputIds()) {
input_ids.insert(id);
if (id >= 0) {
input_ids.insert(id);
}
}
}
return input_ids;
}

std::set<int> CodeGenerator::DistilOutputIds(
const std::vector<OperationExpression>& expressions) {
std::set<int> output_ids;
// Use std::set to remove the reptead id and get a ordered list.
for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetOutputIds()) {
output_ids.insert(id);
}
}
return output_ids;
}

// we get the parameter list code for the expression information
std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype) {
std::stringstream ret;
ret << "int N, ";

// If a id is in the input and output list at the same time, then remove it
// from the input list.
for (auto iter = input_ids.begin(); iter != input_ids.end();) {
if (output_ids.find(*iter) != output_ids.end()) {
input_ids.erase(iter++);
} else {
iter++;
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end()) {
ret << dtype << "* " << ArgName(id) << ", ";
}
}

std::stringstream ret;
ret << "int N, ";
for (auto iter = input_ids.begin(); iter != input_ids.end(); iter++) {
ret << dtype << "* " << VarName(*iter) << ", ";
}

size_t count_index = 0;
for (auto iter = output_ids.begin(); iter != output_ids.end(); iter++) {
ret << dtype << "* " << VarName(*iter);
if (count_index != output_ids.size() - 1) {
size_t index = 0;
for (auto id : output_ids) {
ret << dtype << "* " << ArgName(id);
if (index != output_ids.size() - 1) {
ret << ", ";
}
count_index++;
index++;
}

return ret.str();
}

std::string CodeGenerator::EmitComputeBody(
std::vector<OperationExpression> expressions) {
// get the right experssion code using suffix expression
std::stringstream ret;
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::string dtype) {
std::ostringstream compute;
std::unordered_set<int> used;
for (size_t i = 0; i < expressions.size(); i++) {
ret << expressions[i].GetExpression();
VLOG(3) << DebugString(expressions[i]);
compute << expressions[i].GetExpression(dtype, &used);
}
return ret.str();

// Load input to temporal variables.
std::ostringstream load;
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end() &&
used.find(id) != used.end()) {
load << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];";
}
}

// Store temporal variables to memory.
std::ostringstream store;
for (auto id : output_ids) {
store << ArgName(id) << "[idx] = " << TmpName(id) << ";";
}

return load.str() + compute.str() + store.str();
}

std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
SubGraph* subgraph) {
const auto& input_var_nodes = subgraph->GetInputVarNodes();
const auto& output_var_nodes = subgraph->GetOutputVarNodes();

int id = 0;
std::unordered_map<std::string, int> var_ids;
// Numbering input vars.
for (auto* in : input_var_nodes) {
VLOG(3) << "Encoding input names:" << in->Name() << ", id:" << id;
if (var_ids.find(in->Name()) == var_ids.end()) {
var_ids[in->Name()] = id++;
}
}
// Numbering internal vars.
for (auto* node : subgraph->SortedNodes()) {
if (node && node->IsVar() && node->Var()) {
bool is_found = false;
for (auto* in : input_var_nodes) {
if (node == in) {
is_found = true;
break;
}
}
if (is_found) {
continue;
}
for (auto* out : output_var_nodes) {
if (node == out) {
is_found = true;
break;
}
}
PADDLE_ENFORCE_EQ(
is_found, true,
"Subgraph with internal var nodes (%s) is not supported yet.",
node->Name());
}
}
// Encoding output vars.
for (auto* out : output_var_nodes) {
VLOG(3) << "Ecoding output names:" << out->Name() << ", id:" << id;
if (var_ids.find(out->Name()) == var_ids.end()) {
var_ids[out->Name()] = id++;
}
}
return var_ids;
}

} // namespace fusion_group
Expand Down
28 changes: 22 additions & 6 deletions paddle/fluid/framework/ir/fusion_group/code_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ limitations under the License. */

#pragma once

#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"

namespace paddle {
namespace framework {
Expand All @@ -27,18 +30,31 @@ class CodeGenerator {
public:
CodeGenerator();

std::string GenerateCode(std::string func_name,
std::vector<OperationExpression> expressions);
std::string Generate(std::string func_name,
std::vector<OperationExpression> expressions);

// TODO(wangchao): add a more general interface
// std::string Generate(const std::string name, const SubGraph& subgraph);
std::string Generate(SubGraph* subgraph);

std::vector<OperationExpression> ConvertToExpressions(SubGraph* subgraph);

private:
std::set<int> DistilInputIds(
const std::vector<OperationExpression>& expressions);
std::set<int> DistilOutputIds(
const std::vector<OperationExpression>& expressions);

// we get the parameter list code for the expression information
std::string EmitParameters(std::vector<OperationExpression> expressions,
std::string EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype);

std::string EmitComputeBody(std::vector<OperationExpression> expressions);
std::string EmitComputeBody(
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::string dtype);

// Encode all var nodes in the subgraph with an unique number.
std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);

private:
std::vector<CodeTemplate> code_templates_;
Expand Down
23 changes: 14 additions & 9 deletions paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ static T StringTo(const std::string& str) {
return value;
}

std::string OperationExpression::GetRHS(size_t i) {
auto rhs = OperationMap::Instance().Get(op_).exprs[i];
std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
size_t i) const {
auto rhs = OperationMap::Instance().Get(op_type_).exprs[i];
for (size_t i = 0; i < rhs.size(); i++) {
size_t pos = i;
if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
Expand All @@ -47,29 +48,33 @@ std::string OperationExpression::GetRHS(size_t i) {
PADDLE_ENFORCE_LT(index, input_ids_.size(),
"Only %d inputs are provided, but need %d.",
input_ids_.size(), index + 1);
rhs.replace(pos, length + 3, VarName(input_ids_[index]) + R"([idx])");
PADDLE_ENFORCE_GE(input_ids_[index], 0,
"Input id should be no less than 0.");
rhs.replace(pos, length + 3, TmpName(input_ids_[index]));
used->insert(input_ids_[index]);
}
}
return rhs;
}

std::string OperationExpression::GetLHS(size_t i) {
std::string OperationExpression::GetLHS(size_t i) const {
std::stringstream ret;
ret << VarName(output_ids_[i]) << R"([idx])";
ret << TmpName(output_ids_[i]);
return ret.str();
}

bool OperationExpression::IsSupport() {
return OperationMap::Instance().Has(op_);
bool OperationExpression::IsSupport() const {
return OperationMap::Instance().Has(op_type_);
}

// we Traverse the graph and get the group , all input id and output id is
// unique for the node which belong the group
std::string OperationExpression::GetExpression() {
std::string OperationExpression::GetExpression(
std::string dtype, std::unordered_set<int>* used) const {
std::stringstream ret;
if (IsSupport()) {
for (size_t i = 0; i < output_ids_.size(); ++i) {
ret << GetLHS(i) << " = " << GetRHS(i) << ";";
ret << dtype << " " << GetLHS(i) << " = " << GetRHS(used, i) << ";";
}
}

Expand Down
Loading