Skip to content

Commit

Permalink
[CINN] replace struct Group with OpLoweringGroup in lower_cinn_fusion…
Browse files Browse the repository at this point in the history
…_op_pass (#62339)

Signed-off-by: ZelinMa557 <3388706467@qq.com>
  • Loading branch information
ZelinMa557 authored Mar 25, 2024
1 parent 7d9b987 commit a7d5ea9
Show file tree
Hide file tree
Showing 22 changed files with 586 additions and 288 deletions.
4 changes: 2 additions & 2 deletions paddle/cinn/adt/adapter_dynamic_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/dim_expr.h"
#include "paddle/cinn/adt/symbolic_dim.h"
#include "paddle/cinn/hlir/framework/pir/group.h"
#include "paddle/cinn/hlir/framework/pir/op_lowering_group.h"

namespace cinn::adt::adapter {

struct DynamicTensor final {
::pir::Value node_data;
const hlir::framework::pir::Group* group;
const hlir::framework::pir::OpLoweringGroup* group;

bool operator==(const DynamicTensor& other) const {
return this->node_data == other.node_data;
Expand Down
34 changes: 19 additions & 15 deletions paddle/cinn/adt/generate_map_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ bool HasDynamicShape(const ::pir::Value& tensor) {
return false;
}

List<Arg> MakeOpStmtInputList(const ::pir::Operation* op,
const hlir::framework::pir::Group* group) {
List<Arg> MakeOpStmtInputList(
const ::pir::Operation* op,
const hlir::framework::pir::OpLoweringGroup* group) {
List<Arg> ret{};

VisitEachInputTensor(op, [&](const ::pir::Value& tensor) {
Expand All @@ -131,8 +132,9 @@ void VisitEachOutputTensor(const ::pir::Operation* op, const DoEachT& DoEach) {
}
}

List<Arg> MakeOpStmtOutputList(const ::pir::Operation* op,
const hlir::framework::pir::Group* group) {
List<Arg> MakeOpStmtOutputList(
const ::pir::Operation* op,
const hlir::framework::pir::OpLoweringGroup* group) {
List<Arg> ret{};

VisitEachOutputTensor(op, [&](const ::pir::Value& tensor) {
Expand All @@ -147,9 +149,10 @@ List<Arg> MakeOpStmtOutputList(const ::pir::Operation* op,
}

template <typename DoEachT>
void VisitEachOpStmt(const std::shared_ptr<hlir::framework::pir::Group>& group,
const DoEachT& DoEach) {
for (const auto* op : group->CollectOps()) {
void VisitEachOpStmt(
const std::shared_ptr<hlir::framework::pir::OpLoweringGroup>& group,
const DoEachT& DoEach) {
for (const auto* op : group->ops()) {
DoEach(OpStmt{MakeOp(op),
MakeOpStmtInputList(op, group.get()),
MakeOpStmtOutputList(op, group.get())});
Expand Down Expand Up @@ -187,7 +190,7 @@ void CollectRewrittenOpStmts(const OpStmt& op_stmt, List<OpStmt>* ret) {
}

List<OpStmt> MakeOpStmts(
const std::shared_ptr<hlir::framework::pir::Group>& group) {
const std::shared_ptr<hlir::framework::pir::OpLoweringGroup>& group) {
List<OpStmt> ret{};

VisitEachOpStmt(group, [&](const auto& op_stmt) {
Expand Down Expand Up @@ -223,7 +226,7 @@ std::shared_ptr<IGroup> MakeIGroup(const AnchorGroup& igroup_spec) {
}

std::vector<std::shared_ptr<IGroup>> GenerateIGroups(
const std::shared_ptr<hlir::framework::pir::Group>& group) {
const std::shared_ptr<hlir::framework::pir::OpLoweringGroup>& group) {
std::vector<std::shared_ptr<IGroup>> ret{};

List<OpStmt> op_stmts = MakeOpStmts(group);
Expand All @@ -237,7 +240,7 @@ std::vector<std::shared_ptr<IGroup>> GenerateIGroups(
}

std::shared_ptr<KGroup> GenerateKGroups(
const std::shared_ptr<hlir::framework::pir::Group>& group,
const std::shared_ptr<hlir::framework::pir::OpLoweringGroup>& group,
const std::vector<std::shared_ptr<IGroup>>& igroups) {
CHECK_EQ(igroups.size(), 1);
return std::make_shared<KGroup>(group, igroups);
Expand Down Expand Up @@ -352,15 +355,15 @@ Tensor GetAnchorTensor(const std::shared_ptr<IGroup>& igroup) {
}

template <typename DoEachT>
void VisitInputTensor(const hlir::framework::pir::Group& group,
void VisitInputTensor(const hlir::framework::pir::OpLoweringGroup& group,
const DoEachT& DoEach) {
for (const ::pir::Value& node_data : group.GetInputOpValues()) {
DoEach(node_data);
}
}

template <typename DoEachT>
void VisitOutputTensor(const hlir::framework::pir::Group& group,
void VisitOutputTensor(const hlir::framework::pir::OpLoweringGroup& group,
const DoEachT& DoEach) {
for (const ::pir::Value& node_data : group.GetOutputOpValues()) {
DoEach(node_data);
Expand Down Expand Up @@ -444,7 +447,7 @@ MapExpr GenerateMapExpr(const std::shared_ptr<KGroup>& kgroup) {
} // namespace

MapExpr GenerateMapExpr(
const std::shared_ptr<hlir::framework::pir::Group>& group) {
const std::shared_ptr<hlir::framework::pir::OpLoweringGroup>& group) {
const auto& igroups = GenerateIGroups(group);

const auto& kgroup = GenerateKGroups(group, igroups);
Expand All @@ -453,13 +456,14 @@ MapExpr GenerateMapExpr(
}

void TryGenerateMapExprFromGroup(
const std::shared_ptr<hlir::framework::pir::Group>& fusion_group) {
const std::shared_ptr<hlir::framework::pir::OpLoweringGroup>&
fusion_group) {
if (!FLAGS_cinn_enable_map_expr) {
return;
}
const auto& map_expr = GenerateMapExpr(fusion_group);
VLOG(4) << "Generate MapExpr: \n"
<< ToTxtString(map_expr, fusion_group->group_id);
<< ToTxtString(map_expr, fusion_group->group_id());
fusion_group->set_map_expr_ctx(std::make_shared<MapExprCtx>(map_expr));
}

Expand Down
7 changes: 3 additions & 4 deletions paddle/cinn/adt/generate_map_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@

namespace cinn::hlir::framework::pir {

struct Group;
using GroupList = std::vector<std::shared_ptr<Group>>;
struct OpLoweringGroup;

} // namespace cinn::hlir::framework::pir

namespace cinn::adt {

MapExpr GenerateMapExpr(
const std::shared_ptr<cinn::hlir::framework::pir::Group>& group);
const std::shared_ptr<cinn::hlir::framework::pir::OpLoweringGroup>& group);

void TryGenerateMapExprFromGroup(
const std::shared_ptr<hlir::framework::pir::Group>& fusion_group);
const std::shared_ptr<hlir::framework::pir::OpLoweringGroup>& fusion_group);

} // namespace cinn::adt
8 changes: 4 additions & 4 deletions paddle/cinn/adt/kgroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

namespace cinn::hlir::framework::pir {

struct Group;
struct OpLoweringGroup;

} // namespace cinn::hlir::framework::pir

Expand All @@ -39,11 +39,11 @@ using cinn::adt::LoopDescriptors;
class KGroup final {
public:
explicit KGroup(
const std::shared_ptr<hlir::framework::pir::Group>& cinn_group,
const std::shared_ptr<hlir::framework::pir::OpLoweringGroup>& cinn_group,
const std::vector<std::shared_ptr<IGroup>>& igroups)
: cinn_group_(cinn_group), igroups_(igroups) {}

std::shared_ptr<hlir::framework::pir::Group> cinn_group() const {
std::shared_ptr<hlir::framework::pir::OpLoweringGroup> cinn_group() const {
return CHECK_NOTNULL(cinn_group_.lock());
}

Expand All @@ -58,7 +58,7 @@ class KGroup final {
const std::shared_ptr<IGroup>& igroup) const;

private:
std::weak_ptr<hlir::framework::pir::Group> cinn_group_;
std::weak_ptr<hlir::framework::pir::OpLoweringGroup> cinn_group_;
// NOTE: Use single igroup temporarily. Actually KGroup contains
// multiple IGroups
std::vector<std::shared_ptr<IGroup>> igroups_;
Expand Down
Loading

0 comments on commit a7d5ea9

Please sign in to comment.