From 82021a3a17e1416dc251e2f5fb2dadd431e44f64 Mon Sep 17 00:00:00 2001
From: Zhang Yi
Date: Thu, 21 Mar 2024 18:33:57 +0800
Subject: [PATCH] [CPU]Fix GPT-J RoPE fusion (#23519)

### Details:
 - *Support the new RoPE pattern of GPT-J*
 - *A local test shows a 17% improvement in 2nd-token latency for BF16 on
   `Intel(R) Xeon(R) Platinum 8468`*

### Tickets:
 - *CVS-134949*
---
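Notes (editorial, ignored by `git am`): the matcher below fuses the GPT-J
"interleaved" rotary embedding. As a reading aid, here is a minimal scalar
sketch of what the fused RoPE node computes per token and head; the function
name and signature are illustrative only, assuming an even rotary_dims <=
head_size and per-position sin/cos tables of rotary_dims / 2 entries each.

    #include <cstddef>
    #include <vector>

    // Interleaved (GPT-J style) RoPE over one head vector x. The first
    // rotary_dims channels are rotated pairwise, matching mul_cos + mul_sin
    // in the pattern; channels beyond rotary_dims pass through unchanged.
    static void rope_gptj_ref(std::vector<float>& x,
                              const std::vector<float>& sin_tab,
                              const std::vector<float>& cos_tab,
                              std::size_t rotary_dims) {
        for (std::size_t i = 0; i + 1 < rotary_dims; i += 2) {
            const float s = sin_tab[i / 2];  // repeat_interleave duplicates
            const float c = cos_tab[i / 2];  // each table entry for the pair
            const float x0 = x[i];
            const float x1 = x[i + 1];
            x[i] = x0 * c - x1 * s;      // even lane: x*cos + (-x_odd)*sin
            x[i + 1] = x1 * c + x0 * s;  // odd lane:  x*cos + x_even*sin
        }
    }

The two alternatives now accepted for unsqueeze_sin/unsqueeze_cos (Reshape to
{1, -1, 1, 32} or Unsqueeze at axis 2) produce the same broadcastable sin/cos
layout, which is why either branch may feed the repeat_interleave Gather.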
 .../cpu_opset/common/pass/rope_fusion.cpp     |  29 ++--
 .../subgraph_tests/src/rotary_pos_emb.cpp     | 143 ++++++++++++++++++
 2 files changed, 163 insertions(+), 9 deletions(-)

diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.cpp
index 839ccc1267d226..200659f5f2e322 100644
--- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.cpp
+++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/rope_fusion.cpp
@@ -387,8 +387,11 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
 
     auto varsplit = makePattern<opset1::VariadicSplit>({gather_sin_cos, -1, {ndims / 2, -1}});
     varsplit->set_output_size(2);
-    auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}});
-    auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}});
+    // Both Reshape and Unsqueeze should be supported
+    auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}}) |
+                         makePattern<opset1::Unsqueeze>({varsplit->output(0), 2});
+    auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}}) |
+                         makePattern<opset1::Unsqueeze>({varsplit->output(1), 2});
     // repeate cos/sin table
     auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) {
         const auto& vec = node.get_vector<int32_t>();
@@ -402,9 +405,6 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
 
     auto repeat_interleave_sin = makePattern<opset8::Gather>({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}});
     auto repeat_interleave_cos = makePattern<opset8::Gather>({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}});
-    auto t_cos = makePattern(ov::Rank(4));
-    auto t_sin = makePattern(ov::Rank(4));
-
     // x interleave (-x[:,:,:, 1::2], x[:,:,:, 0::2])
     auto slice_Slice_1174 = GenSlice(slice_Slice_965, 1, int32_max, 2, 3);
 
@@ -418,13 +418,16 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
     auto ShapeOf_169068 = makePattern<opset3::ShapeOf>({stack_1182});
     auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0);
     auto flatten_Concat_1197 = makePattern<opset1::Concat>({flatten_Slice_1194, {-1}}, {{"axis", 0}});
+    // If special_zero is set, there is no need to use ShapeOf to get the full shape
     auto flatten_Reshape_1198 = makePattern<opset1::Reshape>({stack_1182, flatten_Concat_1197});
+    auto flatten_Reshape_Zero =
+        makePattern<opset1::Reshape>({stack_1182, ov::pass::pattern::any_input()}, {{"special_zero", true}});
 
     // x*cos [B,L,H,ndims]
     auto mul_cos =
         makePattern<opset1::Multiply>({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}});
     auto mul_sin =
-        makePattern<opset1::Multiply>({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
+        makePattern<opset1::Multiply>({flatten_Reshape_1198 | flatten_Reshape_Zero, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
 
     // *cos + *sin
     auto rotary_emb = makePattern<opset1::Add>({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});
@@ -460,15 +463,12 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
         auto new_node = std::make_shared<RoPENode>(new_args, config);
         new_node->set_friendly_name(old_node->get_friendly_name());
         ov::copy_runtime_info({pattern_map.at(varsplit).get_node_shared_ptr(),
-                               pattern_map.at(unsqueeze_sin).get_node_shared_ptr(),
-                               pattern_map.at(unsqueeze_cos).get_node_shared_ptr(),
                                pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(),
                                pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(),
                                pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(),
                                pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(),
                                pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(),
                                pattern_map.at(stack_1182).get_node_shared_ptr(),
-                               pattern_map.at(flatten_Concat_1197).get_node_shared_ptr(),
                                pattern_map.at(mul_cos).get_node_shared_ptr(),
                                pattern_map.at(mul_sin).get_node_shared_ptr(),
                                pattern_map.at(rotary_emb).get_node_shared_ptr(),
@@ -476,6 +476,17 @@ ov::intel_cpu::RoPEFusionGPTJ::RoPEFusionGPTJ() {
                               pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()},
                              new_node);
         ov::replace_node(old_node, new_node);
+        // ShapeOf may have been moved up from the Transpose to the Add;
+        // after RoPE fusion it must be re-attached to the data input of RoPE, otherwise an extra subgraph remains
+        std::shared_ptr<ov::Node> rotary_emb_node = pattern_map.at(rotary_emb).get_node_shared_ptr();
+        auto rotary_emb_out = rotary_emb_node->output(0);
+        if (rotary_emb_out.get_target_inputs().size() == 2) {
+            for (auto& input : rotary_emb_out.get_target_inputs()) {
+                if (ov::is_type<opset3::ShapeOf>(input.get_node())) {
+                    input.replace_source_output(pattern_map.at(view_Reshape));
+                }
+            }
+        }
         return true;
     };
 
diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/rotary_pos_emb.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/rotary_pos_emb.cpp
index 3fdaadc8d4362e..d0bf420278e412 100644
--- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/rotary_pos_emb.cpp
+++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/rotary_pos_emb.cpp
@@ -457,5 +457,148 @@ TEST_F(RoPECPUTestQwen7b, smoke_CompareWithRefs) {
     CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
 }
 
+class RoPECPUTestGPTJ : public SubgraphBaseTest, public testing::WithParamInterface<bool> {
+public:
+    static std::string getTestCaseName(const testing::TestParamInfo<bool>& obj) {
+        bool hasShapeOf;
+        hasShapeOf = obj.param;
+        std::ostringstream result;
+        result << "hasShapeOf=" << hasShapeOf << std::endl;
+        return result.str();
+    }
+    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
+        const auto& funcInputs = function->inputs();
+
+        auto& input_shape = targetInputStaticShapes[0];
+        auto& sincos_shape = targetInputStaticShapes[1];
+        ov::Tensor t_input =
+            utils::create_and_fill_tensor(funcInputs[0].get_element_type(), input_shape, 2, -1.0f, 32768);
+        ov::Tensor t_cos_sin_cache =
+            utils::create_and_fill_tensor(funcInputs[1].get_element_type(), sincos_shape, 2, -1.0f, 32768);
+
+        inputs.clear();
+        inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input});
+        inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos_sin_cache});
+    }
+
+protected:
+    std::shared_ptr<ov::Model> buildROPE_GPTJ(const int num_head,
+                                              const int hidden_dims,
+                                              const int rotary_dims,
+                                              bool hasShapeOf) {
+        auto int32_max = std::numeric_limits<int32_t>::max();
+        auto input =
+            std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, -1, num_head, hidden_dims});
+        auto sincos = std::make_shared<ov::opset1::Parameter>(ov::element::f32, PartialShape{-1, -1, rotary_dims});
+
+        auto slice_Slice_965 =
+            makeOP<ov::opset1::StridedSlice>({input, {0, 0, 0, 0}, {0, 0, 0, rotary_dims}, {1, 1, 1, 1}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        slice_Slice_965->set_friendly_name("slice_Slice_965");
+
+        auto varsplit = makeOP<ov::opset1::VariadicSplit>({sincos, -1, {rotary_dims / 2, -1}});
+        varsplit->set_output_size(2);
+        varsplit->set_friendly_name("varsplit");
+        auto unsqueeze_sin = makeOP<ov::opset1::Unsqueeze>({varsplit->output(0), 2});
+        auto unsqueeze_cos = makeOP<ov::opset1::Unsqueeze>({varsplit->output(1), 2});
+        std::vector<int32_t> gather_idx(rotary_dims, 1);
+        int32_t v = 0;
+        for (size_t i = 0; i < gather_idx.size(); i += 2, v++) {
+            gather_idx[i] = v;
+            gather_idx[i + 1] = v;
+        }
+
+        auto const_idx = makeConst(ov::element::i32, ov::Shape({static_cast<size_t>(rotary_dims)}), gather_idx);
+        auto constant_155588 = makeConst(element::f32,
+                                         ov::Shape({
+                                             1,
+                                             1,
+                                             1,
+                                             1,
+                                         }),
+                                         {-1.000000f});
+        auto repeat_interleave_sin = makeOP<ov::opset8::Gather>({unsqueeze_sin, const_idx, 3}, {{"batch_dims", 0}});
+        auto repeat_interleave_cos = makeOP<ov::opset8::Gather>({unsqueeze_cos, const_idx, 3}, {{"batch_dims", 0}});
+        repeat_interleave_sin->set_friendly_name("repeat_interleave_sin");
+        repeat_interleave_cos->set_friendly_name("repeat_interleave_cos");
+        // x interleave (-x[:,:,:, 1::2], x[:,:,:, 0::2])
+        auto slice_Slice_1174 =
+            makeOP<ov::opset1::StridedSlice>({slice_Slice_965, {0, 0, 0, 1}, {0, 0, 0, int32_max}, {1, 1, 1, 2}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        auto neg_Multiply_1177 =
+            makeOP<ov::opset1::Multiply>({slice_Slice_1174, constant_155588}, {{"auto_broadcast", "numpy"}});
+        auto Unsqueeze_65524 = makeOP<ov::opset1::Unsqueeze>({neg_Multiply_1177, -1});
+
+        auto slice_Slice_1168 =
+            makeOP<ov::opset1::StridedSlice>({slice_Slice_965, {0, 0, 0, 0}, {0, 0, 0, int32_max}, {1, 1, 1, 2}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        auto Unsqueeze_65525 = makeOP<ov::opset1::Unsqueeze>({slice_Slice_1168, -1});
+        auto stack_1182 = makeOP<ov::opset1::Concat>({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}});
+        auto flatten_Reshape_1198 =
+            makeOP<ov::opset1::Reshape>({stack_1182, {0, 0, num_head, rotary_dims}}, {{"special_zero", true}});
+        // x*cos [B,L,H,ndims]
+        auto mul_cos =
+            makeOP<ov::opset1::Multiply>({slice_Slice_965, repeat_interleave_cos}, {{"auto_broadcast", "numpy"}});
+        mul_cos->set_friendly_name("mul_cos");
+        auto mul_sin =
+            makeOP<ov::opset1::Multiply>({flatten_Reshape_1198, repeat_interleave_sin}, {{"auto_broadcast", "numpy"}});
+        // *cos + *sin
+        auto rotary_emb = makeOP<ov::opset1::Add>({mul_cos, mul_sin}, {{"auto_broadcast", "numpy"}});
+
+        auto slice_Slice_971 =
+            makeOP<ov::opset1::StridedSlice>({input, {0, 0, 0, rotary_dims}, {0, 0, 0, int32_max}, {1, 1, 1, 1}},
+                                             {{"begin_mask", {1, 1, 1, 0}},
+                                              {"end_mask", {1, 1, 1, 0}},
+                                              {"new_axis_mask", {}},
+                                              {"shrink_axis_mask", {}},
+                                              {"ellipsis_mask", {}}});
+        auto cat_Concat_1211 = makeOP<ov::opset1::Concat>({rotary_emb, slice_Slice_971}, {{"axis", -1}});
+        auto permute_Transpose_1213 = makeOP<ov::opset1::Transpose>({cat_Concat_1211, {0, 2, 1, 3}});
+        ov::NodeVector model_output = {permute_Transpose_1213};
+        if (hasShapeOf) {
+            auto shapeOf = makeOP<ov::opset3::ShapeOf>({rotary_emb}, {{"output_type", "i32"}});
+            auto gather = makeOP<ov::opset8::Gather>({shapeOf, {1}, 0}, {{"batch_dims", 0}});
+            model_output.push_back(gather);
+        }
+        return std::make_shared<ov::Model>(model_output, ov::ParameterVector{input, sincos});
+    }
+    void SetUp() override {
+        targetDevice = ov::test::utils::DEVICE_CPU;
+        bool hasShapeOf = this->GetParam();
+        const int batch = 2;
+        const int seq_length = 7;
+        const int num_head = 16;
+        const int hidden_dims = 256;
+        const int rotary_dims = 64;
+
+        InputShape input = {{batch, seq_length, num_head, hidden_dims}, {{batch, seq_length, num_head, hidden_dims}}};
+        InputShape sincos = {{batch, seq_length, rotary_dims}, {{batch, seq_length, rotary_dims}}};
+        init_input_shapes({input, sincos});
+        function = buildROPE_GPTJ(num_head, hidden_dims, rotary_dims, hasShapeOf);
+    }
+};
+
+TEST_P(RoPECPUTestGPTJ, smoke_CompareWithRefs) {
+    run();
+    CheckNumberOfNodesWithType(compiledModel, "RoPE", 1);
+}
+
+INSTANTIATE_TEST_SUITE_P(smoke_RoPECPUTestGPTJ,
+                         RoPECPUTestGPTJ,
+                         ::testing::Values(true, false),
+                         RoPECPUTestGPTJ::getTestCaseName);
+
 }  // namespace test
 }  // namespace ov
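Editor's note (outside the patch): in the new test, gather_idx is the table
[0, 0, 1, 1, ..., rotary_dims/2 - 1, rotary_dims/2 - 1]; gathering with it
along axis 3 implements repeat_interleave(sin/cos, 2), and the matcher's
const_idx predicate accepts any constant with this layout. A minimal sketch of
the index construction, assuming an even rotary_dims (the helper name is
illustrative, not part of the patch):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Index table that duplicates each sin/cos entry for a channel pair,
    // i.e. repeat_interleave(table, 2) expressed as a Gather along the
    // last axis: [0, 0, 1, 1, ..., n/2 - 1, n/2 - 1] for n = rotary_dims.
    static std::vector<int32_t> make_repeat_interleave_idx(int rotary_dims) {
        std::vector<int32_t> idx(static_cast<std::size_t>(rotary_dims));
        for (int i = 0; i + 1 < rotary_dims; i += 2) {
            idx[i] = i / 2;
            idx[i + 1] = i / 2;
        }
        return idx;
    }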