diff --git a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_slice_link_fuse_pass.cc b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_slice_link_fuse_pass.cc index 7037805c28f..03948c51cf2 100644 --- a/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_slice_link_fuse_pass.cc +++ b/lite/core/optimizer/mir/fusion/__xpu__multi_encoder_slice_link_fuse_pass.cc @@ -36,21 +36,22 @@ class XPUMultiEncoderSliceLinkFuser : public FuseBase { PMNode* layer_norm = nullptr; PMNode* layer_norm_out = nullptr; - auto* slice = OpNode("slice", "slice") - ->assert_op_attr_satisfied>( - "axes", - [](const std::vector& attr) { - return attr.size() == 1 && attr[0] == 1; - }) - ->assert_op_attr_satisfied>( - "starts", - [](const std::vector& attr) { - return attr.size() == 1 && attr[0] == 0; - }) - ->assert_op_attr_satisfied>( - "ends", [](const std::vector& attr) { - return attr.size() == 1 && attr[0] == 1; - }); + auto* slice = + OpNode("slice", "slice") + ->assert_op_attr_satisfied>( + "axes", + [](const std::vector& attr) { + return attr.size() == 1 && attr[0] == 1; + }) + ->assert_op_attr_satisfied>( + "starts", + [](const std::vector& attr) { + return attr.size() == 1 && attr[0] == 0; + }) + ->assert_op_attr_satisfied>( + "ends", [](const std::vector& attr) { + return attr.size() == 1 && attr[0] > 0 && attr[0] <= 20; + }); if (pre_ln_) { xpu_encoder->assert_op_attr("norm_before", true); encoder_out->assert_is_op_input("layer_norm", "X"); diff --git a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc index c87a3e270d8..051372a2938 100644 --- a/lite/kernels/xpu/__xpu__multi_encoder_compute.cc +++ b/lite/kernels/xpu/__xpu__multi_encoder_compute.cc @@ -239,9 +239,10 @@ void XPUMultiEncoderCompute::PrepareForRun() { } // prepare with sice if ((param.slice_starts.size() > 0 && param.slice_starts[0] == 0) && - (param.slice_ends.size() > 0 && param.slice_ends[0] == 1) && + (param.slice_ends.size() > 0 && param.slice_ends[0] > 0 && + param.slice_ends[0] <= 20) && (param.slice_axes.size() > 0 && param.slice_axes[0] == 1)) { - slice_idx = 0; + slice_idx = param.slice_ends[0] - 1; } // prepare input_cast and output_cast guard_ cast_in_guard_ = TargetWrapperXPU::MallocScratchPad(4 * 1024 * 1024); diff --git a/lite/operators/__xpu__multi_encoder_op.cc b/lite/operators/__xpu__multi_encoder_op.cc index 8ef7da2e7b6..11b027ed94a 100644 --- a/lite/operators/__xpu__multi_encoder_op.cc +++ b/lite/operators/__xpu__multi_encoder_op.cc @@ -42,9 +42,11 @@ bool XPUMultiEncoderOp::InferShapeImpl() const { seq_len = param_.PadSeqLen->data()[0]; } if ((param_.slice_starts.size() > 0 && param_.slice_starts[0] == 0) && - (param_.slice_ends.size() > 0 && param_.slice_ends[0] == 1) && + (param_.slice_ends.size() > 0 && param_.slice_ends[0] > 0 && + param_.slice_ends[0] <= 20) && (param_.slice_axes.size() > 0 && param_.slice_axes[0] == 1)) { - DDim out_dims(std::vector({batch_size, 1, head_num})); + DDim out_dims( + std::vector({batch_size, param_.slice_ends[0], head_num})); if (param_.slice_decrease_axis.size() > 0) { std::vector new_out_shape; for (size_t i = 0; i < slice_decrease_axis.size(); ++i) { @@ -65,7 +67,7 @@ bool XPUMultiEncoderOp::InferShapeImpl() const { out_dims = new_dims; } if (param_.norm_before) { - param_.output->Resize({batch_size, 1, head_num}); + param_.output->Resize({batch_size, param_.slice_ends[0], head_num}); } else { param_.output->Resize(out_dims); }