Skip to content

Commit

Permalink
fix compile issue
Browse files Browse the repository at this point in the history
Signed-off-by: ZelinMa557 <3388706467@qq.com>
  • Loading branch information
ZelinMa557 committed Mar 18, 2024
1 parent e459ab8 commit e7a7d74
Show file tree
Hide file tree
Showing 11 changed files with 217 additions and 144 deletions.
4 changes: 2 additions & 2 deletions paddle/cinn/adt/generate_map_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ template <typename DoEachT>
void VisitEachOpStmt(
const std::shared_ptr<hlir::framework::pir::OpLoweringGroup>& group,
const DoEachT& DoEach) {
for (const auto* op : group->CollectOps()) {
for (const auto* op : group->ops()) {
DoEach(OpStmt{MakeOp(op),
MakeOpStmtInputList(op, group.get()),
MakeOpStmtOutputList(op, group.get())});
Expand Down Expand Up @@ -463,7 +463,7 @@ void TryGenerateMapExprFromGroup(
}
const auto& map_expr = GenerateMapExpr(fusion_group);
VLOG(4) << "Generate MapExpr: \n"
<< ToTxtString(map_expr, fusion_group->group_id);
<< ToTxtString(map_expr, fusion_group->group_id());
fusion_group->set_map_expr_ctx(std::make_shared<MapExprCtx>(map_expr));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,13 @@ void SetLeafBlockByGroupView(
pir::Block* block,
std::unordered_map<pir::Block*, OpLoweringGroupPtr>* group_map) {
pir::IrMapping ir_mapping;
auto origin_group_inputs = GetBlockOutsideInput(origin_group->ops);
auto origin_group_inputs = GetBlockOutsideInput(origin_group->ops());
for (auto input : origin_group_inputs) {
ir_mapping.Add(input, input);
}

auto new_group = CloneGroup(origin_group, block, &ir_mapping);
CHECK_EQ(origin_group->ops.size(), new_group->ops.size());
CHECK_EQ(origin_group->ops().size(), new_group->ops().size());
UpdateGroupShapeExprs(new_group,
origin_group,
ir_mapping,
Expand Down Expand Up @@ -425,15 +425,12 @@ void SimplyConditionBlock(
DoEach(block, group);
std::vector<pir::Operation*> group_new_ops;
group_new_ops.reserve(block->size());
std::unordered_set<pir::Operation*> group_ops_set;
for (auto& op : *block) {
if (!op.isa<pir::YieldOp>()) {
group_new_ops.push_back(&op);
group_ops_set.insert(&op);
}
}
group->ops = group_new_ops;
group->ops_set = group_ops_set;
group->SetOps(group_new_ops);
}
};
ForEachMutBlockGroup([&](auto* block, const auto& group) {
Expand All @@ -460,7 +457,7 @@ void CompileGroupToJitKernelOp(
VLOG(4) << "The size of group_map is : " << group_map->size();
for (auto& [block, group] : *group_map) {
std::vector<pir::Type> output_types;
const auto& group_output_values = group->output_values;
const auto& group_output_values = group->output_values();
for (size_t i = 0; i < group_output_values.size(); ++i) {
output_types.push_back(group_output_values[i].type());
}
Expand Down Expand Up @@ -512,7 +509,7 @@ pir::Operation* CompileBroadcastTreeToConditionBlock(
rewriter.block(),
&group_map);
// 2. simply every condition block
auto* program = group->ops.front()->GetParentProgram();
auto* program = group->ops().front()->GetParentProgram();
VLOG(6) << "Before simply condition block: " << *program;

SimplyConditionBlock(rewriter, &group_map);
Expand Down Expand Up @@ -561,7 +558,7 @@ pir::Operation* ProcessDyShapeGroup(
cinn::common::BroadcastLeaf(all_value_dim_exprs));
VLOG(4) << "broadcast-tree: \n" << ToTxtString(broadcast_tree);

auto group_inputs = GetBlockOutsideInput(group->ops);
auto group_inputs = GetBlockOutsideInput(group->ops());

// has multiple branch
if (broadcast_tree
Expand All @@ -583,7 +580,7 @@ pir::Operation* ProcessDyShapeGroup(
// compile group to jit_kernel_op
auto op_attr_map = CompileGroupAsOpAttribute(pir_compiler, {group});
std::vector<pir::Type> output_types;
const auto& group_output_values = group->output_values;
const auto& group_output_values = group->output_values();
for (size_t i = 0; i < group_output_values.size(); ++i) {
output_types.push_back(group_output_values[i].type());
}
Expand All @@ -610,8 +607,9 @@ bool IsComplicatedDimExpr(const symbol::DimExpr& dim_expr) {
}

template <typename DoEachT>
void VisitEachInputValue(const GroupPtr& group, const DoEachT& DoEach) {
for (pir::Value value : GetBlockOutsideInput(group->ops)) {
void VisitEachInputValue(const OpLoweringGroupPtr& group,
const DoEachT& DoEach) {
for (pir::Value value : GetBlockOutsideInput(group->ops())) {
DoEach(value);
}
}
Expand Down Expand Up @@ -650,7 +648,7 @@ void VisitEachDimExpr(const symbol::ShapeOrDataDimExprs& shape_or_data,

std::unordered_map<symbol::DimExpr, symbol::DimExpr>
CollectSubstituteDimExprMap(
const GroupPtr& group,
const OpLoweringGroupPtr& group,
pir::ShapeConstraintIRAnalysis& shape_analysis) { // NOLINT
std::unordered_map<symbol::DimExpr, symbol::DimExpr> dim_expr_map;

Expand Down Expand Up @@ -746,7 +744,7 @@ CreateGroupShapeOrDataExprs(
std::unordered_map<symbol::DimExpr, symbol::DimExpr> dim_expr_map =
CollectSubstituteDimExprMap(group, shape_analysis);
std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs> value2shape;
for (auto* op : group->ops) {
for (auto* op : group->ops()) {
for (size_t i = 0; i < op->num_operands(); ++i) {
auto operand = op->operand_source(i);
if (operand && value2shape.find(operand) == value2shape.end() &&
Expand Down Expand Up @@ -829,11 +827,11 @@ class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT
const std::shared_ptr<cinn::hlir::framework::PirCompiler>& pir_compiler,
pir::PatternRewriter& rewriter) const { // NOLINT
auto group_inputs = GetBlockOutsideInput(group->ops);
auto group_inputs = GetBlockOutsideInput(group->ops());
// compile group to jit_kernel_op
auto op_attr_map = CompileGroupAsOpAttribute(pir_compiler, {group});
std::vector<pir::Type> output_types;
const auto& group_output_values = group->output_values;
const auto& group_output_values = group->output_values();
for (size_t i = 0; i < group_output_values.size(); ++i) {
output_types.push_back(group_output_values[i].type());
}
Expand All @@ -846,44 +844,40 @@ class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
std::shared_ptr<OpLoweringGroup> RebuildGroup(
cinn::dialect::FusionOp fusion_op) const {
auto group = std::make_shared<OpLoweringGroup>();
group->op_pattern_kind = cinn::hlir::framework::OpPatternKind::kElementWise;
group->set_op_pattern_kind(
cinn::hlir::framework::OpPatternKind::kElementWise);
if (fusion_op.attributes().count("group_info")) {
auto attr = fusion_op.attribute("group_info")
.dyn_cast<cinn::dialect::GroupInfoAttribute>()
.data();

group->op_pattern_kind = attr.op_pattern_kind;
group->loop_ranges = attr.loop_ranges;
group->loop_ranges_expr = attr.loop_ranges_expr;

group->reduce_axis = attr.reduce_axis;
group->alignment_schedule_info = attr.alignment_schedule_info;
group->set_op_pattern_kind(attr.op_pattern_kind);
group->set_loop_ranges(attr.loop_ranges);
group->set_loop_ranges_expr(attr.loop_ranges_expr);
group->set_reduce_axis(attr.reduce_axis);
group->set_alignment_schedule_info(attr.alignment_schedule_info);
}

// Rebuild ops of the group
for (auto op : fusion_op.GetOperators()) {
if (!op->isa<::pir::YieldOp>()) {
group->ops.push_back(op);

group->ops_set.insert(op);
group->op_pattern_kind =
group->mut_ops().push_back(op);
group->set_op_pattern_kind(
static_cast<int>(CompatibleInfo::OpKind(*op)) >
static_cast<int>(group->op_pattern_kind)
static_cast<int>(group->op_pattern_kind())
? CompatibleInfo::OpKind(*op)
: group->op_pattern_kind;
: group->op_pattern_kind());
}
}

// Rebuild output_ops and input_ops of the group
auto yield_op = fusion_op.GetOperators().back();
for (size_t i = 0; i < yield_op->num_operands(); ++i) {
auto in = yield_op->operand_source(i);
group->output_ops.insert(in.defining_op());
group->output_values.push_back(in);
group->mut_output_ops().insert(in.defining_op());
group->mut_output_values().push_back(in);
}

// Rebuild other informations
// TODO(zhangyuqin1998): Do we need group.master_ops?
return group;
}
};
Expand Down
6 changes: 3 additions & 3 deletions paddle/cinn/hlir/framework/pir/compilation_task.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ std::unique_ptr<Instruction> CompilationTask::BuildInstruction() {
std::unique_ptr<Instruction> instr =
std::make_unique<Instruction>(context_->target_,
context_->scope_.get(),
context_->group_->input_names,
context_->group_->output_names,
context_->group_->input_names(),
context_->group_->output_names(),
fn_name);
VLOG(4) << "Lookup kernel name: " << fn_name;
auto* fn_ptr = context_->backend_compiler_->Lookup(fn_name);
Expand All @@ -113,7 +113,7 @@ pir::CINNKernelInfo CompilationTask::BuildPirCINNKernelInfo() {
cinn_kernel_info.fn_name = fn_name;
cinn_kernel_info.fn_ptr = fn_ptr;
cinn_kernel_info.infer_shape_fn_ptr = infer_shape_fn_ptr;
cinn_kernel_info.int_args_map = context_->group_->int_args_map;
cinn_kernel_info.int_args_map = context_->group_->int_args_map();
return cinn_kernel_info;
}

Expand Down
32 changes: 14 additions & 18 deletions paddle/cinn/hlir/framework/pir/op_lowering_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ std::shared_ptr<OpLoweringGroup> OpLoweringGroup::Clone(
// Mapper from original to new ops.
std::unordered_map<::pir::Operation*, ::pir::Operation*> ops_mapper;
auto clone_options = ::pir::CloneOptions(false, true, false);
for (auto* op : ops) {
for (auto* op : ops_) {
VLOG(4) << "clone op :" << op->name();
auto* new_op = op->Clone(*ir_mapping, clone_options);
// NOTE(dev): Must call block.insert to deal with ownership, otherwise it
Expand All @@ -37,31 +37,27 @@ std::shared_ptr<OpLoweringGroup> OpLoweringGroup::Clone(

// Construct Base information for new Group
auto new_group = std::make_shared<OpLoweringGroup>(new_ops);
for (auto& iter : this->input_ops) {
new_group->input_ops[ops_mapper.at(iter.first)] = iter.second;
for (auto* op : this->output_ops_) {
new_group->output_ops_.insert(ops_mapper.at(op));
}
for (auto* op : this->output_ops) {
new_group->output_ops.insert(ops_mapper.at(op));
}
for (const auto& output_value : this->output_values) {
new_group->output_values.push_back(ir_mapping->Lookup(output_value));
for (const auto& output_value : this->output_values_) {
new_group->output_values_.push_back(ir_mapping->Lookup(output_value));
}

new_group->input_names = this->input_names;
new_group->output_names = this->output_names;
new_group->output_values = this->output_values;
new_group->fn_name = this->fn_name;
new_group->int_args_map = this->int_args_map;
new_group->alignment_schedule_info = this->alignment_schedule_info;
new_group->reduce_axis = this->reduce_axis;
new_group->loop_ranges = this->loop_ranges;
new_group->input_names_ = this->input_names_;
new_group->output_names_ = this->output_names_;
new_group->fn_name_ = this->fn_name_;
new_group->int_args_map_ = this->int_args_map_;
new_group->alignment_schedule_info_ = this->alignment_schedule_info_;
new_group->reduce_axis_ = this->reduce_axis_;
new_group->loop_ranges_ = this->loop_ranges_;
return new_group;
}

std::ostream& operator<<(std::ostream& os, const OpLoweringGroup& group) {
::pir::IrPrinter printer(os);
os << "Group " << group.group_id << " :\n";
for (auto* op : group.ops) {
os << "Group " << group.group_id() << " :\n";
for (auto* op : group.ops()) {
printer.PrintOperation(op);
os << "\n";
}
Expand Down
Loading

0 comments on commit e7a7d74

Please sign in to comment.