[PIR] Support layer norm lowering with CINN PIR #58606

Merged
Commits (36 commits; changes shown from all commits)
- eb08b34 [CINN+PIR]Support DoGroupSchedule for PIRComppiler (Aurelius84, Oct 26, 2023)
- 78ab1c7 fix conflict (Aurelius84, Oct 30, 2023)
- 4c2ad01 support cinn broadcast code gen (phlrain, Oct 31, 2023)
- 2418c20 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 31, 2023)
- 764de36 fix op fusion pass bug (phlrain, Oct 31, 2023)
- 31256fc Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Oct 31, 2023)
- cde335a using output_ops to parse function arguments (Aurelius84, Oct 31, 2023)
- 3dab76c update (phlrain, Oct 31, 2023)
- 330c7b8 fix unittest (Aurelius84, Oct 31, 2023)
- af8ba0c remove VLOG(1) (Aurelius84, Oct 31, 2023)
- 65e0126 ignore some UT and add FIXME (Aurelius84, Oct 31, 2023)
- a0a1d6b update (phlrain, Nov 1, 2023)
- d64f696 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Nov 1, 2023)
- aab3735 Merge commit 'refs/pull/58399/head' of https://github.com/PaddlePaddl… (phlrain, Nov 1, 2023)
- b88ecf8 remove args limit (phlrain, Nov 1, 2023)
- 3501eab Merge commit 'refs/pull/58399/head' of https://github.com/PaddlePaddl… (phlrain, Nov 1, 2023)
- 0992bc4 fix bug and remove useless code (phlrain, Nov 1, 2023)
- 1be0fcf Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Nov 1, 2023)
- 94c7d7f update (phlrain, Nov 1, 2023)
- 448971d Merge commit 'refs/pull/58516/head' of https://github.com/PaddlePaddl… (phlrain, Nov 1, 2023)
- e1c3cfc update (phlrain, Nov 1, 2023)
- 8867d25 fix bug (phlrain, Nov 1, 2023)
- e44190d Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Nov 1, 2023)
- 62903dc update (phlrain, Nov 1, 2023)
- a74b749 fix bug (phlrain, Nov 1, 2023)
- 375541f Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Nov 1, 2023)
- 0f0e473 update (phlrain, Nov 2, 2023)
- fbbf5bc Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Nov 2, 2023)
- 81318fb Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Nov 3, 2023)
- 8efca8f update (phlrain, Nov 7, 2023)
- 9810688 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Nov 7, 2023)
- ddefacd remove useless code (phlrain, Nov 7, 2023)
- 03fca47 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Nov 7, 2023)
- 4d553c0 remove useless code (phlrain, Nov 8, 2023)
- 1896ceb Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Nov 8, 2023)
- a7c9f68 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (phlrain, Nov 8, 2023)
@@ -81,6 +81,19 @@ bool IsSameDim(const phi::DDim& first, const std::vector<int64_t>& second) {
return false;
}

std::vector<int64_t> GetBroadcastAxis(const phi::DDim& in_shape,
const std::vector<int64_t>& out_shape) {
std::vector<int64_t> broadcast_axes(in_shape.size(), 0);
auto in_shape_size = in_shape.size();
if (in_shape_size >= 1) {
for (int i = 1; i <= in_shape_size; ++i) {
broadcast_axes[in_shape_size - i] = out_shape.size() - i;
}
}

return broadcast_axes;
}

bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
auto x_dims = op->operand_source(0)
.type()
@@ -93,21 +106,21 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {

if (x_dims != y_dims) {
auto output_shape = GetOutputShape(x_dims, y_dims);
std::vector<int64_t> vec_dims;
for (int64_t i = 0; i < output_shape.size(); ++i) {
vec_dims.push_back(i);
}
if (!IsSameDim(x_dims, output_shape)) {
// add broadcast to input 0
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(0), vec_dims, output_shape);
op->operand_source(0),
GetBroadcastAxis(x_dims, output_shape),
output_shape);

op->operand(0).set_source(new_transpose_op->result(0));
}

if (!IsSameDim(y_dims, output_shape)) {
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(1), vec_dims, output_shape);
op->operand_source(1),
GetBroadcastAxis(y_dims, output_shape),
output_shape);

op->operand(1).set_source(new_transpose_op->result(0));
}
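Note: the vec_dims vector used before this change numbered the axes 0..rank(out)-1, which is only correct when the operand already has the output rank. GetBroadcastAxis instead right-aligns the input dimensions against the output, NumPy-style. A minimal standalone sketch of that mapping (the helper name below is made up for illustration and is not part of the diff):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Right-aligned broadcast axes: input dim i maps to output dim
// (out_rank - in_rank + i), i.e. the trailing dims of the output.
std::vector<int64_t> BroadcastAxesSketch(std::size_t in_rank,
                                         std::size_t out_rank) {
  std::vector<int64_t> axes(in_rank, 0);
  for (std::size_t i = 1; i <= in_rank; ++i) {
    axes[in_rank - i] = static_cast<int64_t>(out_rank - i);
  }
  return axes;
}

int main() {
  // A [768] operand broadcast against a [128, 128, 768] operand maps its
  // single dim to output axis 2, not axis 0 as the old vec_dims implied.
  assert((BroadcastAxesSketch(1, 3) == std::vector<int64_t>{2}));
  assert((BroadcastAxesSketch(2, 3) == std::vector<int64_t>({1, 2})));
  return 0;
}
```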
@@ -58,6 +58,9 @@ OpPatternKind GetOpKind(const std::string& op_name) {
}

phi::DDim GetFirstInputShape(const ::pir::Operation* op) {
if (op->num_operands() == 0) {
return phi::DDim({});
}
auto in = op->operand_source(0);

return in.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
1 change: 1 addition & 0 deletions paddle/cinn/hlir/framework/pir/utils.cc
@@ -38,6 +38,7 @@ const std::unordered_map<std::string, std::string> CompatibleInfo::OP_NAMES = {
{"pd_op.add", "elementwise_add"},
{"pd_op.subtract", "subtract"},
{"pd_op.divide", "divide"},
{"pd_op.multiply", "elementwise_mul"},
{"cinn_op.broadcast", "broadcast_to"}};

// Tagging PaddleDialect Op with REGITER_OP_MAPPER(OP)
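Note: this table maps PIR op names to the CINN operator names they lower to, so pd_op.multiply can now reuse the existing elementwise_mul implementation. A hypothetical, self-contained sketch of the kind of lookup such a table supports (MapOpName is illustrative only, not Paddle API):

```cpp
#include <string>
#include <unordered_map>

// Look up the CINN-side name for a PIR op; entries absent from the table
// simply keep the original name in this sketch.
std::string MapOpName(
    const std::unordered_map<std::string, std::string>& op_names,
    const std::string& pir_name) {
  auto it = op_names.find(pir_name);
  return it == op_names.end() ? pir_name : it->second;
}

// Usage: MapOpName(op_names, "pd_op.multiply") would yield "elementwise_mul"
// given the entry added above.
```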
32 changes: 28 additions & 4 deletions paddle/fluid/pir/transforms/build_cinn_pass.cc
@@ -299,8 +299,9 @@ class CinnSubgraphDetector {
CinnSubgraphDetector(pir::Block* block, const OpClassifier& classifier)
: block_(block), op_classifier_(classifier) {
sort_ops_ = InverselyTopologicalSort(block_);
for (size_t i = 0; i < sort_ops_.size(); ++i) {
op2id_[sort_ops_[i]] = i;
size_t index = 0;
for (auto* op : *block) {
op2id_[op] = index++;
}
}

@@ -313,7 +314,18 @@
if (!subgraph->substitute) {
continue;
}
groups.push_back(subgraph->ops);

// sort group ops
std::vector<pir::Operation*> tmp_ops(subgraph->ops.begin(),
subgraph->ops.end());
auto& op2id = op2id_;
std::sort(tmp_ops.begin(),
tmp_ops.end(),
[&op2id](pir::Operation* a, pir::Operation* b) {
return op2id.at(a) > op2id.at(b);
});

groups.push_back(tmp_ops);
}

return groups;
@@ -581,6 +593,12 @@ std::vector<pir::Value> AnalysisOutputs(GroupOpsVec& group_ops) {  // NOLINT
}
}

if (vec_res.size() == 0) {
for (size_t i = 0; i < group_ops.back()->num_results(); ++i) {
vec_res.push_back(group_ops.back()->result(i));
}
}

return vec_res;
}

@@ -601,14 +619,20 @@ void ReplaceWithGroupOp(pir::Block* block,
// step 2: Replace the old op with GroupOp.
auto new_group_op = builder.Build<cinn::dialect::GroupOp>(output_types);
pir::Block* group_block = new_group_op.block();

for (auto* op : group_ops) {
op->MoveTo(group_block, group_block->begin());
}

// step 3: Replace outputs of inner ops
std::vector<pir::OpResult> group_outs = new_group_op->results();
std::unordered_set<pir::Operation*> inner_ops(group_ops.begin(),
group_ops.end());
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].ReplaceAllUsesWith(group_outs[i]);
outputs[i].ReplaceUsesWithIf(group_outs[i],
[&inner_ops](pir::OpOperand op) {
return !inner_ops.count(op.owner());
});
}

// step 4: Insert YieldOp for outputs
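Note on the behavioral changes in this file: op2id_ is now built from the block's own order and each subgraph's ops are sorted by that index before a group is formed, making the group contents deterministic; AnalysisOutputs gains a fallback that exposes the last op's results when nothing in the group is used externally; and output replacement (step 3) only rewrites uses whose owner lies outside the group, so the YieldOp added in step 4 still references the inner ops. A minimal sketch of the filtered replacement, using stand-in types rather than the real pir classes:

```cpp
#include <functional>
#include <unordered_set>
#include <vector>

struct Op {};                           // stand-in for pir::Operation
struct Use { Op* owner; Op** slot; };   // stand-in for pir::OpOperand

// Replace every use that satisfies the predicate (mirrors the shape of
// Value::ReplaceUsesWithIf).
void ReplaceUsesWithIf(std::vector<Use>& uses, Op* new_value,
                       const std::function<bool(const Use&)>& should_replace) {
  for (auto& use : uses) {
    if (should_replace(use)) *use.slot = new_value;
  }
}

// Redirect only the uses owned by ops outside the group, leaving uses inside
// the group (e.g. the yield of the group's own outputs) untouched.
void RedirectExternalUses(std::vector<Use>& uses, Op* group_result,
                          const std::unordered_set<Op*>& inner_ops) {
  ReplaceUsesWithIf(uses, group_result, [&inner_ops](const Use& u) {
    return inner_ops.count(u.owner) == 0;
  });
}
```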
2 changes: 2 additions & 0 deletions paddle/pir/core/value.cc
@@ -83,6 +83,8 @@ void Value::ReplaceUsesWithIf(
for (auto it = use_begin(); it != use_end();) {
if (should_replace(*it)) {
(it++)->set_source(new_value);
} else {
it++;
}
}
}
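Note: before this fix, the loop only advanced the iterator inside the branch that performed the replacement, so a use rejected by should_replace would stall the loop. The post-increment-before-mutation shape of the loop is presumably there because set_source unlinks the operand from this value's use list, invalidating the current position. A generic illustration of the same iterate-while-removing pattern on a std::list:

```cpp
#include <list>

// Drop even numbers while iterating: move the iterator past the current
// node before the call that removes it, and always advance in the "keep"
// branch too, otherwise the loop would spin on the same element forever.
void DropEvens(std::list<int>& xs) {
  for (auto it = xs.begin(); it != xs.end();) {
    if (*it % 2 == 0) {
      xs.erase(it++);  // `it` already points at the next node when erase runs
    } else {
      ++it;
    }
  }
}
```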
106 changes: 106 additions & 0 deletions test/cpp/pir/cinn/pir_all_path_test.cc
@@ -116,3 +116,109 @@ TEST(GroupOp, TestBuild) {
bool res0 = simple_cmp(out_tensor.data<float>()[0], 1.0 / 768);
EXPECT_EQ(res0, true);
}

std::shared_ptr<::pir::Program> BuildLayerNormProgram() {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
auto program = std::make_shared<::pir::Program>(ctx);
::pir::Builder builder = ::pir::Builder(ctx, program->block());

std::vector<int64_t> axes{-1};
auto x =
builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>({128, 128, 768}),
1.0,
phi::DataType::FLOAT32,
phi::GPUPlace())
.result(0);

auto bias = builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>({768}),
1.0,
phi::DataType::FLOAT32,
phi::GPUPlace())
.result(0);

auto scale = builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>({768}),
1.0,
phi::DataType::FLOAT32,
phi::GPUPlace())
.result(0);

auto num = builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1},
768.0,
phi::DataType::FLOAT32,
phi::CPUPlace())
.result(0);
auto eps = builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>{1},
1e-5,
phi::DataType::FLOAT32,
phi::CPUPlace())
.result(0);

auto sum =
builder
.Build<paddle::dialect::SumOp>(x, axes, phi::DataType::FLOAT32, true)
.result(0);

auto mean = builder.Build<paddle::dialect::DivideOp>(sum, num).result(0);
auto power = builder.Build<paddle::dialect::MultiplyOp>(x, x).result(0);
auto power_sum = builder
.Build<paddle::dialect::SumOp>(
power, axes, phi::DataType::FLOAT32, true)
.result(0);
auto mean2 =
builder.Build<paddle::dialect::DivideOp>(power_sum, num).result(0);
auto power_mean =
builder.Build<paddle::dialect::MultiplyOp>(mean, mean).result(0);

auto var =
builder.Build<paddle::dialect::SubtractOp>(mean2, power_mean).result(0);

auto sub = builder.Build<paddle::dialect::SubtractOp>(x, mean).result(0);
auto t1 = builder.Build<paddle::dialect::AddOp>(var, eps).result(0);
auto t2 = builder.Build<paddle::dialect::SqrtOp>(t1).result(0);
auto t3 = builder.Build<paddle::dialect::DivideOp>(sub, t2).result(0);
auto t5 = builder.Build<paddle::dialect::MultiplyOp>(t3, scale).result(0);
auto out = builder.Build<paddle::dialect::MultiplyOp>(t5, bias).result(0);

builder.Build<paddle::dialect::FetchOp>(out, "out", 0);
return program;
}

TEST(GroupOp, TestBuildLayerNorm) {
// Step 1: Construct pir::Program
::pir::IrContext* ctx = ::pir::IrContext::Instance();
std::shared_ptr<::pir::Program> program = BuildLayerNormProgram();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();

cinn::dialect::ir::PdOp2CinnOpConverter(program.get());

pir::PassManager pm(ctx);
pm.AddPass(
std::make_unique<cinn::dialect::ir::AddBroadcastToElementwisePass>());
pm.AddPass(pir::CreateBuildCinnPass());
CHECK_EQ(pm.Run(program.get()), true);

auto res = cinn::dialect::ir::CINNGroupLoweringPass(program.get());

paddle::platform::Place place = paddle::platform::CUDAPlace(0);

auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(res.get(), place);

paddle::framework::Scope exe_scope;

paddle::framework::InterpreterCore executor(
place, {"out@fetch"}, kernel_program->block(), &exe_scope);

// TODO(phlrain): fix exec error
// executor.Run({}, true);

// auto out_tensor =
// executor.local_scope()->FindVar("out@fetch")->Get<phi::DenseTensor>();
}