Skip to content

Commit

Permalink
[XPU] Extend slice num of multi_encoder (#10565)
Browse files Browse the repository at this point in the history
  • Loading branch information
WhatGhost authored Sep 24, 2024
1 parent 9b53b4a commit c809775
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,22 @@ class XPUMultiEncoderSliceLinkFuser : public FuseBase {
PMNode* layer_norm = nullptr;
PMNode* layer_norm_out = nullptr;

auto* slice = OpNode("slice", "slice")
->assert_op_attr_satisfied<std::vector<int>>(
"axes",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 1;
})
->assert_op_attr_satisfied<std::vector<int>>(
"starts",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 0;
})
->assert_op_attr_satisfied<std::vector<int>>(
"ends", [](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 1;
});
auto* slice =
OpNode("slice", "slice")
->assert_op_attr_satisfied<std::vector<int>>(
"axes",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 1;
})
->assert_op_attr_satisfied<std::vector<int>>(
"starts",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 0;
})
->assert_op_attr_satisfied<std::vector<int>>(
"ends", [](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] > 0 && attr[0] <= 20;
});
if (pre_ln_) {
xpu_encoder->assert_op_attr<bool>("norm_before", true);
encoder_out->assert_is_op_input("layer_norm", "X");
Expand Down
5 changes: 3 additions & 2 deletions lite/kernels/xpu/__xpu__multi_encoder_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,10 @@ void XPUMultiEncoderCompute::PrepareForRun() {
}
// prepare with sice
if ((param.slice_starts.size() > 0 && param.slice_starts[0] == 0) &&
(param.slice_ends.size() > 0 && param.slice_ends[0] == 1) &&
(param.slice_ends.size() > 0 && param.slice_ends[0] > 0 &&
param.slice_ends[0] <= 20) &&
(param.slice_axes.size() > 0 && param.slice_axes[0] == 1)) {
slice_idx = 0;
slice_idx = param.slice_ends[0] - 1;
}
// prepare input_cast and output_cast guard_
cast_in_guard_ = TargetWrapperXPU::MallocScratchPad(4 * 1024 * 1024);
Expand Down
8 changes: 5 additions & 3 deletions lite/operators/__xpu__multi_encoder_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ bool XPUMultiEncoderOp::InferShapeImpl() const {
seq_len = param_.PadSeqLen->data<int>()[0];
}
if ((param_.slice_starts.size() > 0 && param_.slice_starts[0] == 0) &&
(param_.slice_ends.size() > 0 && param_.slice_ends[0] == 1) &&
(param_.slice_ends.size() > 0 && param_.slice_ends[0] > 0 &&
param_.slice_ends[0] <= 20) &&
(param_.slice_axes.size() > 0 && param_.slice_axes[0] == 1)) {
DDim out_dims(std::vector<int64_t>({batch_size, 1, head_num}));
DDim out_dims(
std::vector<int64_t>({batch_size, param_.slice_ends[0], head_num}));
if (param_.slice_decrease_axis.size() > 0) {
std::vector<int64_t> new_out_shape;
for (size_t i = 0; i < slice_decrease_axis.size(); ++i) {
Expand All @@ -65,7 +67,7 @@ bool XPUMultiEncoderOp::InferShapeImpl() const {
out_dims = new_dims;
}
if (param_.norm_before) {
param_.output->Resize({batch_size, 1, head_num});
param_.output->Resize({batch_size, param_.slice_ends[0], head_num});
} else {
param_.output->Resize(out_dims);
}
Expand Down

0 comments on commit c809775

Please sign in to comment.