matmul+activation fuse pass (#43519)
* add method for post ops

* format code

* gpd

* format style

* add matmul+act test

* implement matmul+activation

* whitespaces

* code style

* python code format

* Increase UT timeout

* code format

* update style

* generalize activation fuse passes

* change order

* Unify activation GPD

* Revert changes with op_act

* remove softmax mkldnn attrs

* set common name for act attributes

* whitespace

* append postops by helper function

* ut style

* revert changes related to quantization

* Reduce redundancy

* reduce number of parameters

* trigger CI

* validate attribute

* trim unit test
Silv3S authored Jul 12, 2022
1 parent 40b6863 commit 3333a43
Showing 8 changed files with 456 additions and 3 deletions.
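
At a high level, the pass collapses a matmul op followed by a supported activation into a single matmul op that carries the activation as a fused attribute, which the oneDNN matmul kernel then applies as a post-op. A before/after sketch of the rewrite (tensor names are illustrative, not from the commit):

    before:  X, Y -> matmul -> tmp;  tmp -> relu -> out
    after:   X, Y -> matmul{fuse_activation="relu"} -> out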
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -212,6 +212,7 @@ if(WITH_MKLDNN)
  pass_library(shuffle_channel_mkldnn_detect_pass inference DIR mkldnn)
  pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
  pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
  pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
  pass_library(cpu_quantize_placement_pass base DIR mkldnn)
  pass_library(cpu_quantize_pass inference DIR mkldnn)
  pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
1 change: 0 additions & 1 deletion paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -26,7 +26,6 @@ using string::PrettyLogDetail;

void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
  auto act_types = paddle::platform::GetSupportedActivations();

  std::vector<std::string> conv_types = {"conv2d"};

  for (const auto& conv_type : conv_types)
281 changes: 281 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.cc
@@ -0,0 +1,281 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h"

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using string::PrettyLogDetail;

void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
  auto act_types = paddle::platform::GetSupportedActivations();
  std::vector<std::string> matmul_types = {"matmul"};

  for (const auto& matmul_type : matmul_types)
    for (auto& act_type : act_types) {
      FuseMatmulAct(graph, matmul_type, act_type);
    }
}

void MatmulActivationMkldnnFusePass::FuseMatmulAct(
    Graph* graph, const std::string& matmul_type, std::string& act_type) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(matmul_type + "_" + act_type + "_mkldnn_fuse_pass", graph);

  GraphPatternDetector gpd;
  patterns::OperatorActivation matmul_act_pattern(
      gpd.mutable_pattern(), "matmul_activation_mkldnn_fuse");
  matmul_act_pattern(matmul_type, act_type);

  int found_matmul_activation_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle " + matmul_type + "+" + act_type + " fuse";

    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "matmul_activation_mkldnn_fuse_pass op compat failed.";
      return;
    }

    GET_IR_NODE_FROM_SUBGRAPH(matmul, preceding_op, matmul_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, preceding_op_out, matmul_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(activation, activation, matmul_act_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(
        activation_out, activation_out, matmul_act_pattern);

    OpDesc* matmul_op = matmul->Op();
    OpDesc* act_op = activation->Op();

    // Copy the activation's attributes onto the matmul op under the
    // common fused-attribute names.
    auto attr_map = paddle::platform::GetAttributeMap(act_type);
    for (const auto& attrs : attr_map) {
      if (act_op->HasAttr(attrs.first)) {
        matmul_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
      }
    }

    // gelu has two variants; pick the algorithm name that matches its
    // "approximate" attribute.
    if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
      act_type = BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
                     ? "gelu_tanh"
                     : "gelu_erf";
    }
    matmul_op->SetAttr("fuse_activation", act_type);
    matmul_op->SetOutput("Out", {activation_out->Name()});

    // Rewire matmul directly to the activation's output, then drop the
    // activation op and the now-unused intermediate tensor.
    IR_NODE_LINK_TO(matmul, activation_out);
    GraphSafeRemoveNodes(graph, {activation, matmul_out});
    found_matmul_activation_count++;
  };

  gpd(graph, handler);
  AddStatis(found_matmul_activation_count);
  if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
    PrettyLogDetail("--- fused %d matmul with %s activation",
                    found_matmul_activation_count,
                    act_type);
  }
}

MatmulActivationMkldnnFusePass::MatmulActivationMkldnnFusePass() {
  AddOpCompat(OpCompat("matmul"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("alpha")
      .IsType<float>()
      .End()
      .AddAttr("transpose_X")
      .IsType<bool>()
      .End()
      .AddAttr("transpose_Y")
      .IsType<bool>()
      .End();

  AddOpCompat(OpCompat("abs"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End();

  AddOpCompat(OpCompat("clip"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("min")
      .End()
      .AddAttr("max")
      .End();

  AddOpCompat(OpCompat("gelu"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("approximate")
      .IsType<bool>()
      .End();

  AddOpCompat(OpCompat("hard_sigmoid"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("slope")
      .IsOptional()
      .IsType<float>()
      .End()
      .AddAttr("offset")
      .IsOptional()
      .IsType<float>()
      .End();

  AddOpCompat(OpCompat("hard_swish"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("threshold")
      .IsOptional()
      .IsType<float>()
      .End()
      .AddAttr("scale")
      .IsOptional()
      .IsType<float>()
      .End()
      .AddAttr("offset")
      .IsOptional()
      .IsType<float>()
      .End();

  AddOpCompat(OpCompat("leaky_relu"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("alpha")
      .IsType<float>()
      .End();

  AddOpCompat(OpCompat("mish"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End();

  AddOpCompat(OpCompat("relu"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End();

  AddOpCompat(OpCompat("relu6"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("threshold")
      .IsType<float>()
      .End();

  AddOpCompat(OpCompat("sigmoid"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End();

  AddOpCompat(OpCompat("sqrt"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End();

  AddOpCompat(OpCompat("swish"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("beta")
      .IsType<float>()
      .End();

  AddOpCompat(OpCompat("tanh"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End();
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(matmul_activation_mkldnn_fuse_pass,
              paddle::framework::ir::MatmulActivationMkldnnFusePass);

REGISTER_PASS_CAPABILITY(matmul_activation_mkldnn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("matmul", 1)
            .EQ("abs", 0)
            .LE("clip", 1)
            .EQ("gelu", 0)
            .EQ("hard_sigmoid", 0)
            .LE("hard_swish", 0)
            .LE("leaky_relu", 1)
            .LE("mish", 1)
            .EQ("relu", 0)
            .EQ("relu6", 0)
            .EQ("sigmoid", 0)
            .EQ("sqrt", 0)
            .EQ("swish", 0)
            .EQ("tanh", 0));
41 changes: 41 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h
@@ -0,0 +1,41 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

class MatmulActivationMkldnnFusePass : public FusePassBase {
 public:
  MatmulActivationMkldnnFusePass();
  virtual ~MatmulActivationMkldnnFusePass() {}

 protected:
  void ApplyImpl(Graph *graph) const override;

  void FuseMatmulAct(Graph *graph,
                     const std::string &matmul_type,
                     std::string &act_type) const;
};

} // namespace ir
} // namespace framework
} // namespace paddle
@@ -50,9 +50,9 @@ void MainTest(const std::string& activation_type) {
     const auto* op = node->Op();
     ASSERT_TRUE(op->HasAttr("use_mkldnn"));
     EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-    ASSERT_TRUE(op->HasAttr("fuse_activation_type"));
+    ASSERT_TRUE(op->HasAttr("fuse_activation"));
     auto activation_type =
-        BOOST_GET_CONST(std::string, op->GetAttr("fuse_activation_type"));
+        BOOST_GET_CONST(std::string, op->GetAttr("fuse_activation"));
     EXPECT_EQ(activation_type.compare(activation_type), 0);
   }
 }
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -302,6 +302,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"softplus_activation_mkldnn_fuse_pass", //
"shuffle_channel_mkldnn_detect_pass", //
"elt_act_mkldnn_fuse_pass", //
"matmul_activation_mkldnn_fuse_pass", //
// TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710
// "mkldnn_inplace_pass", // This pass should be activated after
3 changes: 3 additions & 0 deletions paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
@@ -17,6 +17,7 @@ limitations under the License. */
#include <tuple>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

using dnnl::memory;
using dnnl::primitive;
@@ -453,6 +454,8 @@ class MatMulMKLDNNHandler
      matmul_attrs.set_output_scales(0, {scale_out});
    }

    paddle::platform::AppendActivation(ctx, post_operations);

    matmul_attrs.set_post_ops(post_operations);
    return matmul_attrs;
  }
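
Appending the activation as a oneDNN post-op means the matmul primitive applies the eltwise function to its output inside the same kernel, so no separate activation op runs at inference time. A rough sketch of the equivalent raw oneDNN calls, assuming the oneDNN 2.x API (editorial illustration; AppendActivation encapsulates logic along these lines for whichever activation was fused):

// Sketch only: fuse a relu into a primitive via post-ops (oneDNN 2.x API).
dnnl::post_ops post_operations;
const float scale = 1.0f, alpha = 0.0f, beta = 0.0f;
post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_relu, alpha, beta);

dnnl::primitive_attr matmul_attrs;
matmul_attrs.set_post_ops(post_operations);
// matmul_attrs is then passed when building the matmul primitive
// descriptor, so the activation executes fused with the matmul.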