Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#50 from tc20042008/xk-cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
Xk cinn trivalop fuse
  • Loading branch information
tc20042008 authored Mar 11, 2024
2 parents 302ba60 + 0a97ad9 commit a6cfd99
Showing 1 changed file with 147 additions and 21 deletions.
168 changes: 147 additions & 21 deletions paddle/cinn/frontend/group_pattern_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,14 +226,17 @@ ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) {
LOG(FATAL) << "Dead code";
}

template<typename InputIt>
std::unordered_map<pir::Value, ShardableAxes> ReversedInferShardableAxes(
common::TopoWalker<const pir::Operation*>& reversed_walker,
const pir::Operation* sink,
const ShardableAxes& init_sa) {
std::unordered_map<pir::Value, ShardableAxes> value2shardable_axes{
{sink->result(0), init_sa}};
const auto& UpdateValue2ShardableAxes = [&](pir::Value value,
const ShardableAxes& sa) {
const common::TopoWalker<const pir::Operation*>& reversed_walker,
InputIt sink_and_init_begin, InputIt sink_and_init_end) {
std::unordered_map<pir::Value, ShardableAxes> value2shardable_axes;
std::list<const pir::Operation*> sinks;
for (auto iter = sink_and_init_begin; iter != sink_and_init_end; ++iter) {
sinks.push_back(iter->first.defining_op());
value2shardable_axes[iter->first] = iter->second;
}
const auto& UpdateValue2ShardableAxes = [&](pir::Value value, const ShardableAxes& sa) {
auto iter = value2shardable_axes.find(value);
if (iter != value2shardable_axes.end()) {
iter->second =
Expand All @@ -242,7 +245,7 @@ std::unordered_map<pir::Value, ShardableAxes> ReversedInferShardableAxes(
iter->second = sa;
}
};
reversed_walker(sink, [&](const auto* op) {
reversed_walker(sinks.begin(), sinks.end(), [&](const auto* op) {
auto shardable_axes_sig = MakeShardableAxesSignature4Op(op);
const auto& old2new = ShardableAxesUtil::GetOldName2NewName(
shardable_axes_sig.output_shardable_axes,
Expand All @@ -259,8 +262,17 @@ std::unordered_map<pir::Value, ShardableAxes> ReversedInferShardableAxes(
return value2shardable_axes;
}

common::TopoWalker<const pir::Operation*> GetOpsTopoWalker(
const std::unordered_set<const pir::Operation*>& ops) {
std::unordered_map<pir::Value, ShardableAxes> ReversedInferShardableAxes(
const common::TopoWalker<const pir::Operation*>& reversed_walker,
const pir::Operation* sink,
const ShardableAxes& init_sa) {
using OpAndInitValue = std::pair<pir::Value, ShardableAxes>;
CHECK_EQ(sink->num_results(), 1);
std::array<OpAndInitValue, 1> sinks{OpAndInitValue{sink->result(0), init_sa}};
return ReversedInferShardableAxes(reversed_walker, sinks.begin(), sinks.end());
}

common::TopoWalker<const pir::Operation*> GetOpsTopoWalker(const std::unordered_set<const pir::Operation*>& ops) {
const auto* ops_set = &ops;
const auto VisitUpStreamInOps = [ops_set](const pir::Operation* op,
const OpVisitor& DoEach) {
Expand Down Expand Up @@ -305,6 +317,128 @@ std::list<const pir::Operation*> GetSinks(
return sinks;
}

std::unordered_map<const pir::Operation*, ShardableAxesSignature>
GetOp2ShardableAxesSignature(const std::unordered_set<const pir::Operation*>& ops) {
std::unordered_map<const pir::Operation*, ShardableAxesSignature> ret;
for (const auto* op : ops) {
ret[op] = MakeShardableAxesSignature4Op(op);
}
return ret;
}

std::map<std::string, std::vector<std::string>>
GetAxisName2BoundAxisName(
const std::unordered_set<const pir::Operation*>& ops,
const std::unordered_map<const pir::Operation*, ShardableAxesSignature>& op2shardable_axes_signature) {
const auto GetInputShardableAxes = [&](const OpAndOperandIndex& op_and_idx) -> std::optional<const ShardableAxes*> {
const auto& [op, idx] = op_and_idx;
const auto* input_op = op->operand_source(idx).defining_op();
if (ops.count(input_op) == 0) return std::nullopt;
const auto& iter = op2shardable_axes_signature.find(input_op);
if (iter == op2shardable_axes_signature.end()) return std::nullopt;
const auto& output_sa = iter->second.output_shardable_axes;
return &output_sa;
};
std::map<std::string, std::vector<std::string>> axis_name2bound_axis_name;
const auto UpdateAxisName2BoundAxisName = [&](const ShardableAxes& input_sa, const ShardableAxes& sa) {
for (const auto& [input_axis, input_axis_name] : input_sa) {
for (const auto& [axis, axis_name] : sa) {
if (input_axis != axis) continue;
axis_name2bound_axis_name[axis_name].push_back(input_axis_name);
axis_name2bound_axis_name[input_axis_name].push_back(axis_name);
}
}
};
for (const auto& [op, signature] : op2shardable_axes_signature) {
for (const auto& [op_and_idx, sa] : signature.input_shardable_axes) {
const auto& input_sa = GetInputShardableAxes(op_and_idx);
if (!input_sa.has_value()) continue;
UpdateAxisName2BoundAxisName(*input_sa.value(), sa);
}
}
return axis_name2bound_axis_name;
}

std::unordered_map<std::string, std::string>
GetAxisName2UnionFindSetRoot(
const std::unordered_set<const pir::Operation*>& ops,
const std::unordered_map<const pir::Operation*, ShardableAxesSignature>& op2shardable_axes_signature) {
const auto axis_name2bound_axis_name = GetAxisName2BoundAxisName(ops, op2shardable_axes_signature);
using NodeVisitor = std::function<void(const std::string&)>;
const auto VisitNext = [&](const std::string& axis_name, const NodeVisitor& DoEach) {
const auto& iter = axis_name2bound_axis_name.find(axis_name);
if (iter == axis_name2bound_axis_name.end()) return;
for (const auto& input_axis_name : iter->second) {
DoEach(input_axis_name);
}
};
common::BfsWalker<std::string> walk(VisitNext);
std::unordered_map<std::string, std::string> axis_name2root;
for (const auto& [union_find_root, _] : axis_name2bound_axis_name) {
if (axis_name2root.count(union_find_root) > 0) continue;
walk(union_find_root, [&](const std::string& axis_name){
CHECK(axis_name2root.emplace(axis_name, union_find_root).second);
});
}
return axis_name2root;
}

std::unordered_map<pir::Value, ShardableAxes>
GetSinkAndInitShardableAxes(
const std::list<const pir::Operation*>& sinks,
const std::unordered_map<const pir::Operation*, ShardableAxesSignature>& op2shardable_axes_signature,
const std::unordered_map<std::string, std::string>& axis_name2union_find_set_root) {
const auto& ConvertByBoundAxisName = [&](const ShardableAxes& sa) {
ShardableAxes ret_sa;
for (const auto& [axis, axis_name] : sa) {
const auto& iter = axis_name2union_find_set_root.find(axis_name);
CHECK(iter != axis_name2union_find_set_root.end());
ret_sa.emplace_back(ShardableAxis{
.axis=axis,
.axis_name=iter->second,
});
}
return ret_sa;
};
std::unordered_map<pir::Value, ShardableAxes> sink2sa;
for (const auto* sink : sinks) {
const auto& sig_iter = op2shardable_axes_signature.find(sink);
CHECK(sig_iter != op2shardable_axes_signature.end());
const auto& output_shardable_axes = sig_iter->second.output_shardable_axes;
CHECK_EQ(sink->num_results(), 1);
sink2sa[sink->result(0)] = ConvertByBoundAxisName(output_shardable_axes);
}
return sink2sa;
}

void RenameDuplicatedAxisName(std::unordered_map<pir::Value, ShardableAxes>* sink2sa) {
const auto& RenameDuplicated = [&](ShardableAxes* sa) {
std::set<std::string> existed_axis_name;
for (auto& [_, axis_name] : *sa) {
if (!existed_axis_name.emplace(axis_name).second) {
axis_name = axis_name + "_" + std::to_string(ShardableAxis::UnqiueSeqNo());
} else {
// do nothing.
}
}
};
for (auto& [_, sa] : *sink2sa) {
RenameDuplicated(&sa);
}
}

std::unordered_map<pir::Value, ShardableAxes> GetSinkAndInitValues(
const common::TopoWalker<const pir::Operation*>& reverse_walker,
const std::unordered_set<const pir::Operation*>& ops,
const std::list<const pir::Operation*>& sinks) {
const auto& op2shardable_axes_signature = GetOp2ShardableAxesSignature(ops);
const auto& axis_name2union_find_set_root = GetAxisName2UnionFindSetRoot(ops, op2shardable_axes_signature);
std::unordered_map<pir::Value, ShardableAxes> sink_and_inits =
GetSinkAndInitShardableAxes(sinks, op2shardable_axes_signature, axis_name2union_find_set_root);
RenameDuplicatedAxisName(&sink_and_inits);
return sink_and_inits;
}

class StmtFusionHelper {
public:
explicit StmtFusionHelper(cinn::dialect::FusionOp& fusion_op)
Expand Down Expand Up @@ -738,17 +872,9 @@ std::unordered_map<pir::Value, ShardableAxes> InferShardableAxesFromSink(
std::unordered_map<pir::Value, ShardableAxes> InferShardableAxes(
const std::unordered_set<const pir::Operation*>& ops) {
auto reversed_walker = GetOpsTopoWalker(ops);
const pir::Operation* sink = [&] {
const auto& sinks = GetSinks(ops);
CHECK_EQ(sinks.size(), 1) << "ops must have only one sink node.";
return *sinks.begin();
}();
const auto& value2shardable_axes = [&] {
size_t rank = GetRank(sink->result(0));
const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank);
return ReversedInferShardableAxes(reversed_walker, sink, init_sa);
}();
return value2shardable_axes;
const auto& sinks = GetSinks(ops);
const auto& sink_and_init_value = GetSinkAndInitValues(reversed_walker, ops, sinks);
return ReversedInferShardableAxes(reversed_walker, sink_and_init_value.begin(), sink_and_init_value.end());
}

} // namespace cinn::frontend

0 comments on commit a6cfd99

Please sign in to comment.