[CPU] Fix GPT-J RoPE fusion #23519

Merged
@@ -387,8 +387,11 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {

auto varsplit = makePattern<opset1::VariadicSplit>({gather_sin_cos, -1, {ndims / 2, -1}});
varsplit->set_output_size(2);
auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}});
auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}});
// Both Reshape and Unsqueeze should be supported here
auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}}) |
makePattern<opset1::Unsqueeze>({varsplit->output(0), 2});
auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}}) |
makePattern<opset1::Unsqueeze>({varsplit->output(1), 2});
// repeat cos/sin table
auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) {
const auto& vec = node.get_vector<int32_t>();
@@ -402,9 +405,6 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
auto repeat_interleave_sin = makePattern<opset8::Gather>({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}});
auto repeat_interleave_cos = makePattern<opset8::Gather>({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}});

auto t_cos = makePattern(ov::Rank(4));
auto t_sin = makePattern(ov::Rank(4));

// x interleave (-x[:,:,:, 1::2], x[:,:,:, 0::2])
auto slice_Slice_1174 = GenSlice(slice_Slice_965, 1, int32_max, 2, 3);

@@ -418,13 +418,16 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
auto ShapeOf_169068 = makePattern<opset1::ShapeOf>({stack_1182});
auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0);
auto flatten_Concat_1197 = makePattern<opset1::Concat>({flatten_Slice_1194, {-1}}, {{"axis", 0}});
// With special_zero set, there is no need to use ShapeOf to obtain the full shape
auto flatten_Reshape_1198 = makePattern<opset1::Reshape>({stack_1182, flatten_Concat_1197});
auto flatten_Reshape_Zero =
makePattern<opset1::Reshape>({stack_1182, ov::pass::pattern::any_input()}, {{"special_zero", true}});

// x*cos [B,L,H,ndims]
auto mul_cos =
makePattern<opset1::Multiply>({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}});
auto mul_sin =
makePattern<opset1::Multiply>({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
makePattern<opset1::Multiply>({flatten_Reshape_1198 | flatten_Reshape_Zero, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});

// *cos + *sin
auto rotary_emb = makePattern<opset1::Add>({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});
@@ -460,22 +463,30 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
auto new_node = std::make_shared<RoPENode>(new_args, config);
new_node->set_friendly_name(old_node->get_friendly_name());
ov::copy_runtime_info({pattern_map.at(varsplit).get_node_shared_ptr(),
pattern_map.at(unsqueeze_sin).get_node_shared_ptr(),
pattern_map.at(unsqueeze_cos).get_node_shared_ptr(),
pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(),
pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(),
pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(),
pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(),
pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(),
pattern_map.at(stack_1182).get_node_shared_ptr(),
pattern_map.at(flatten_Concat_1197).get_node_shared_ptr(),
pattern_map.at(mul_cos).get_node_shared_ptr(),
pattern_map.at(mul_sin).get_node_shared_ptr(),
pattern_map.at(rotary_emb).get_node_shared_ptr(),
pattern_map.at(cat_Concat_1211).get_node_shared_ptr(),
pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()},
new_node);
ov::replace_node(old_node, new_node);
// ShapeOf may have been moved up from the transpose to the add;
// after RoPE fusion it must be re-attached to the data input of RoPE, otherwise an extra subgraph remains
std::shared_ptr<ov::Node> rotary_emb_node = pattern_map.at(rotary_emb).get_node_shared_ptr();
auto rotary_emb_out = rotary_emb_node->output(0);
if (rotary_emb_out.get_target_inputs().size() == 2) {
for (auto& input : rotary_emb_out.get_target_inputs()) {
if (ov::is_type<opset1::ShapeOf>(input.get_node())) {
input.replace_source_output(pattern_map.at(view_Reshape));
}
}
}
Review comment on lines +479 to +489 by @usstq (Contributor), Mar 19, 2024:
Just for the record, this is a tricky fix for a common issue: any unexpected branch out of the middle of a fusion pattern subgraph will create unexpected leftovers and introduce performance regressions.

This is a generic issue for pattern matching, and it becomes a serious one as patterns grow more complex, since the likelihood of such a branch-out is much higher than in smaller patterns.
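
For illustration only (not part of this diff): a minimal sketch of the repair step described above, factored into a standalone helper. The helper name reattach_shapeof and its parameters are hypothetical; the OpenVINO calls (get_target_inputs, replace_source_output, is_type) are the same ones used in the fix, where data_input corresponds to view_Reshape and fused_output to the matched rotary_emb output.

#include <openvino/core/node.hpp>
#include <openvino/core/type.hpp>
#include <openvino/opsets/opset1.hpp>

// Re-attach any ShapeOf consumer that branched out of the matched subgraph to
// the equivalent data input of the fused node, so no leftover subgraph
// survives the fusion. ShapeOf only needs the shape, which the fusion does not
// change, so it can read it directly from the original data input.
static void reattach_shapeof(const ov::Output<ov::Node>& fused_output,
                             const ov::Output<ov::Node>& data_input) {
    for (auto& consumer : fused_output.get_target_inputs()) {
        if (ov::is_type<ov::opset1::ShapeOf>(consumer.get_node())) {
            consumer.replace_source_output(data_input);
        }
    }
}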

return true;
};

@@ -457,5 +457,148 @@ TEST_F(RoPECPUTestQwen7b, smoke_CompareWithRefs) {
CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
}

class RoPECPUTestGPTJ : public SubgraphBaseTest, public testing::WithParamInterface<bool> {
public:
static std::string getTestCaseName(const testing::TestParamInfo<bool>& obj) {
        bool hasShapeOf = obj.param;
std::ostringstream result;
        result << "hasShapeOf=" << hasShapeOf;
return result.str();
}
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
const auto& funcInputs = function->inputs();

auto& input_shape = targetInputStaticShapes[0];
auto& sincos_shape = targetInputStaticShapes[1];
ov::Tensor t_input =
utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768);
ov::Tensor t_cos_sin_cache =
utils::create_and_fill_tensor(funcInputs[1].get_element_type(), sincos_shape, 2, -1.0f, 32768);

inputs.clear();
inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input});
inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_sin_cache});
}

protected:
std::shared_ptr<ov::Model> buildROPE_GPTJ(const int num_head,
const int hidden_dims,
const int rotary_dims,
bool hasShapeOf) {
auto int32_max = std::numeric_limits<std::int32_t>::max();
auto input =
std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, -1, num_head, hidden_dims});
auto sincos = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, -1, rotary_dims});

auto slice_Slice_965 =
makeOP<ov::op::v1::StridedSlice>({input, {0, 0, 0, 0}, {0, 0, 0, rotary_dims}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
slice_Slice_965->set_friendly_name("slice_Slice_965");

auto varsplit = makeOP<ov::op::v1::VariadicSplit>({sincos, -1, {rotary_dims / 2, -1}});
varsplit->set_output_size(2);
varsplit->set_friendly_name("varsplit");
auto unsqueeze_sin = makeOP<opset1::Unsqueeze>({varsplit->output(0), 2});
auto unsqueeze_cos = makeOP<opset1::Unsqueeze>({varsplit->output(1), 2});
std::vector<int32_t> gather_idx(rotary_dims, 1);
int32_t v = 0;
for (size_t i = 0; i < gather_idx.size(); i += 2, v++) {
gather_idx[i] = v;
gather_idx[i + 1] = v;
}

auto const_idx = makeConst(ov::element::i32, ov::Shape({static_cast<size_t>(rotary_dims)}), gather_idx);
auto constant_155588 = makeConst(element::f32,
ov::Shape({
1,
1,
1,
1,
}),
{-1.000000f});
auto repeat_interleave_sin = makeOP<opset8::Gather>({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}});
auto repeat_interleave_cos = makeOP<opset8::Gather>({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}});
repeat_interleave_sin->set_friendly_name("repeat_interleave_sin");
repeat_interleave_cos->set_friendly_name("repeat_interleave_cos");
// x interleave (-x[:,:,:, 1::2], x[:,:,:, 0::2])
auto slice_Slice_1174 =
makeOP<ov::op::v1::StridedSlice>({slice_Slice_965, {0, 0, 0, 1}, {0, 0, 0, int32_max}, {1, 1, 1, 2}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto neg_Multiply_1177 =
makeOP<opset1::Multiply>({slice_Slice_1174, constant_155588}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_65524 = makeOP<opset1::Unsqueeze>({neg_Multiply_1177, -1});

auto slice_Slice_1168 =
makeOP<ov::op::v1::StridedSlice>({slice_Slice_965, {0, 0, 0, 0}, {0, 0, 0, int32_max}, {1, 1, 1, 2}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto Unsqueeze_65525 = makeOP<opset1::Unsqueeze>({slice_Slice_1168, -1});
auto stack_1182 = makeOP<opset1::Concat>({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}});
auto flatten_Reshape_1198 =
makeOP<opset1::Reshape>({stack_1182, {0, 0, num_head, rotary_dims}}, {{"special_zero", true}});
// x*cos [B,L,H,ndims]
auto mul_cos =
makeOP<opset1::Multiply>({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}});
mul_cos->set_friendly_name("mul_cos");
auto mul_sin =
makeOP<opset1::Multiply>({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
// *cos + *sin
auto rotary_emb = makeOP<opset1::Add>({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});

auto slice_Slice_971 =
makeOP<ov::op::v1::StridedSlice>({input, {0, 0, 0, rotary_dims}, {0, 0, 0, int32_max}, {1, 1, 1, 1}},
{{"begin_mask", {1, 1, 1, 0}},
{"end_mask", {1, 1, 1, 0}},
{"new_axis_mask", {}},
{"shrink_axis_mask", {}},
{"ellipsis_mask", {}}});
auto cat_Concat_1211 = makeOP<opset1::Concat>({rotary_emb, slice_Slice_971}, {{"axis", -1}});
auto permute_Transpose_1213 = makeOP<opset1::Transpose>({cat_Concat_1211, {0, 2, 1, 3}});
ov::NodeVector model_output = {permute_Transpose_1213};
if (hasShapeOf) {
auto shapeOf = makeOP<opset1::ShapeOf>({rotary_emb}, {{"output_type", "i32"}});
auto gather = makeOP<opset8::Gather>({shapeOf, {1}, 0}, {{"batch_dims", 0}});
model_output.push_back(gather);
}
return std::make_shared<ov::Model>(model_output, ov::ParameterVector{input, sincos});
}
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;
bool hasShapeOf = this->GetParam();
const int batch = 2;
const int seq_length = 7;
const int num_head = 16;
const int hidden_dims = 256;
const int rotary_dims = 64;

InputShape input = {{batch, seq_length, num_head, hidden_dims}, {{batch, seq_length, num_head, hidden_dims}}};
InputShape sincos = {{batch, seq_length, rotary_dims}, {{batch, seq_length, rotary_dims}}};
init_input_shapes({input, sincos});
function = buildROPE_GPTJ(num_head, hidden_dims, rotary_dims, hasShapeOf);
}
};

TEST_P(RoPECPUTestGPTJ, smoke_CompareWithRefs) {
run();
CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
}

INSTANTIATE_TEST_SUITE_P(smoke_RoPECPUTestGPTJ,
RoPECPUTestGPTJ,
::testing::Values(true, false),
RoPECPUTestGPTJ::getTestCaseName);

} // namespace test
} // namespace ov