Skip to content

Commit

Permalink
[XPU] Mul quant (#6850)
Browse files Browse the repository at this point in the history
  • Loading branch information
newway authored Sep 27, 2021
1 parent bf16069 commit a51a85a
Show file tree
Hide file tree
Showing 10 changed files with 359 additions and 154 deletions.
9 changes: 9 additions & 0 deletions lite/core/optimizer/mir/fusion/__xpu__fc_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ class XPUFcFuser : public FuseBase {
} else if (GetStringFromEnv("XPU_ENCODER_PRECISION", "int16") == "int8" ||
lite::TargetWrapperXPU::multi_encoder_precision == "int8") {
precision = "int8";
if (op_desc.HasAttr("enable_int8") &&
op_desc.GetAttr<bool>("enable_int8")) {
CHECK(op_desc.HasAttr("X0_scale")) << " quant model fc no X0_scale";
CHECK(op_desc.HasAttr("Y0_scale")) << " quant model fc no Y0_scale";
VLOG(3) << "Use int8 quant model in XPUFcOp, InputMax:"
<< 127 * op_desc.GetAttr<std::vector<float>>("X0_scale")[0]
<< ", WeightMax: "
<< 127 * op_desc.GetAttr<std::vector<float>>("Y0_scale")[0];
}
VLOG(3) << "Use int8 in XPUFcOp";
}
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,17 @@ class XPULinkConvMaxFuser : public FuseBase {
public:
explicit XPULinkConvMaxFuser(bool with_branch) { with_branch_ = with_branch; }
void BuildPattern() override {
// Predicate handed to the pattern matcher: accepts a node only when its op is
// not flagged as int8-quantized, i.e. the "enable_int8" attribute is either
// absent or set to false.
auto non_quant_teller = [](const Node* node) -> bool {
  // Bind by const reference: the op info is only queried here, so copying it
  // for every candidate node would be wasted work.
  const auto& op_desc = *const_cast<Node*>(node)->stmt()->op_info();
  return (!op_desc.HasAttr("enable_int8") ||
          !op_desc.GetAttr<bool>("enable_int8"));
};

auto* input =
VarNode("input")->assert_is_op_input("__xpu__conv2d", "Input");
auto* xpu_fusion_op =
OpNode("xpu_fusion_op", "__xpu__conv2d")
->assert_node_satisfied(non_quant_teller)
->assert_op_attr<bool>("has_branch", with_branch_);

PMNode* branch = nullptr;
Expand Down Expand Up @@ -100,8 +107,14 @@ class XPULinkConvMaxFuser : public FuseBase {
class XPULinkFcMaxFuser : public FuseBase {
public:
// Declares the pattern for this fuser: a variable node "input" feeding the
// "Input" slot of an `__xpu__fc` op node that is not int8-quantized.
void BuildPattern() override {
  // Accept only ops whose "enable_int8" attribute is absent or false.
  auto non_quant_teller = [](const Node* node) -> bool {
    // Bind by const reference: the op info is only read, never modified, so
    // there is no reason to copy it for every candidate node.
    const auto& op_desc = *const_cast<Node*>(node)->stmt()->op_info();
    return (!op_desc.HasAttr("enable_int8") ||
            !op_desc.GetAttr<bool>("enable_int8"));
  };
  auto* input = VarNode("input")->assert_is_op_input("__xpu__fc", "Input");
  // `xpu_fusion_op` is declared exactly once, with the non-quant constraint
  // attached; a second, unconstrained declaration of the same name in this
  // scope would be a redeclaration error.
  auto* xpu_fusion_op = OpNode("xpu_fusion_op", "__xpu__fc")
                            ->assert_node_satisfied(non_quant_teller);

  *input >> *xpu_fusion_op;
}
Expand Down
Loading

0 comments on commit a51a85a

Please sign in to comment.