diff --git a/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.cpp index b366ce29bed672..fe496ae44c9087 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/fc_horizontal_fusion.cpp @@ -40,6 +40,7 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() { return std::dynamic_pointer_cast(node); }; // Three FCs connected to the same input + const int min_num_fcs_to_fuse = 3; const int max_num_fcs_to_fuse = 3; const auto& fc = std::dynamic_pointer_cast(output.get_node_shared_ptr()); const auto& input = fc->get_input_node_shared_ptr(0); @@ -65,9 +66,7 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() { } user_fc_count++; } - if (std::getenv("MLP_FUSION") == nullptr && user_fc_count != 3) - return false; - return (user_fc_count > 1) && (user_fc_count <= max_num_fcs_to_fuse) && + return (user_fc_count >= min_num_fcs_to_fuse) && (user_fc_count <= max_num_fcs_to_fuse) && (nodes_with_bias == static_cast(user_fc_count) || nodes_with_bias == 0) && (nodes_with_zp == static_cast(user_fc_count) || nodes_with_zp == 0); }; @@ -98,23 +97,29 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() { zp_nodes.push_back(fc_user->get_input_node_shared_ptr(4)); } } - auto weight_dtype = fc_nodes[0]->get_input_element_type(1); - auto k_size = fc_nodes[0]->get_input_shape(1)[fc_nodes[0]->get_input_shape(1).size() - 1]; + // fc weight is already transposed to [N, K] + const size_t weight_idx = 1; + if (fc_nodes[0]->get_input_shape(weight_idx).size() != 2) + return false; + const size_t n_axis = 0; + const size_t k_axis = 1; + auto weight_dtype = fc_nodes[0]->get_input_element_type(weight_idx); + auto k_size = fc_nodes[0]->get_input_shape(weight_idx)[k_axis]; std::vector orig_n_sizes; // merge weights, scale, zp for (auto fc : fc_nodes) { - if (k_size != fc->get_input_shape(1)[fc->get_input_shape(1).size() - 1]) + if (k_size != fc->get_input_shape(weight_idx)[k_axis]) return false; - if (weight_dtype != fc->get_input_element_type(1)) + if (weight_dtype != fc->get_input_element_type(weight_idx)) return false; - orig_n_sizes.push_back(fc->get_input_shape(1)[fc->get_input_shape(1).size() - 2]); + orig_n_sizes.push_back(fc->get_input_shape(weight_idx)[n_axis]); } ov::OutputVector weight_nodes_as_output_vector; for (size_t i = 0; i < weight_nodes.size(); ++i) { weight_nodes_as_output_vector.push_back(weight_nodes[i]); } auto fused_weight = std::make_shared(weight_nodes_as_output_vector, 0); - fused_weight->set_friendly_name(weight_nodes[0]->get_friendly_name() + "_fused"); + fused_weight->set_friendly_name(weight_nodes[0]->get_friendly_name() + "_fused_weight"); ov::copy_runtime_info(weight_nodes, fused_weight); ov::OutputVector scales_as_output_vector; @@ -123,17 +128,52 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() { } auto fused_scale = std::make_shared(scales_as_output_vector, 0); - fused_scale->set_friendly_name(scale_nodes[0]->get_friendly_name() + "_fused"); + fused_scale->set_friendly_name(scale_nodes[0]->get_friendly_name() + "_fused_scale"); ov::copy_runtime_info(scale_nodes, fused_scale); + // check if all of the fc has a bias user, set it as bias input + size_t n_bias_users = 0; + for (auto fc : fc_nodes) { + if (fc->get_users().size() == 1 + && fc->get_users()[0]->get_type_info() == ov::opset1::Add::get_type_info_static() + && ov::is_type(fc->get_users()[0]->inputs()[1].get_source_output().get_node())) { + n_bias_users++; + } + } + if (bias_nodes.empty() && n_bias_users == fc_nodes.size()) { + // set user as bias + for (size_t i = 0; i < fc_nodes.size(); ++i) { + auto orig_fc = fc_nodes[i]; + auto bias_node = orig_fc->get_users()[0]; + auto bias_const = orig_fc->get_users()[0]->input_value(1); + auto orig_users_of_bias_user = bias_node->get_users(); + ov::OutputVector fc_inputs = orig_fc->input_values(); + fc_inputs[2] = bias_const; + auto new_fc = orig_fc->clone_with_new_inputs(fc_inputs); + new_fc->set_friendly_name(orig_fc->get_friendly_name() + "_with_bias"); + ov::copy_runtime_info(orig_fc, new_fc); + for (auto u : orig_users_of_bias_user) { + for (size_t idx = 0; idx < u->inputs().size(); ++idx) { + if (u->get_input_node_shared_ptr(idx) == bias_node) { + u->input(idx).replace_source_output(new_fc->output(0)); + } + } + } + fc_nodes[i] = std::dynamic_pointer_cast(new_fc); + bias_nodes.push_back(fc_nodes[i]->get_input_node_shared_ptr(2)); + bias_node->clear_control_dependencies(); + orig_fc->clear_control_dependencies(); + } + } std::shared_ptr fused_bias; if (bias_nodes.size() > 1) { ov::OutputVector bias_nodes_as_output_vector; for (size_t i = 0; i < bias_nodes.size(); ++i) { bias_nodes_as_output_vector.push_back(bias_nodes[i]); } - fused_bias = std::make_shared(bias_nodes_as_output_vector, 0); - fused_bias->set_friendly_name(bias_nodes[0]->get_friendly_name() + "_fused"); + const auto bias_concat_axis = 2; + fused_bias = std::make_shared(bias_nodes_as_output_vector, bias_concat_axis); + fused_bias->set_friendly_name(bias_nodes[0]->get_friendly_name() + "_fused_bias"); ov::copy_runtime_info(bias_nodes, fused_bias); } else { fused_bias = std::make_shared(); @@ -174,11 +214,12 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() { return false; } } else { - auto zp_nodes_as_output_vector = ov::OutputVector{zp_nodes[0], zp_nodes[1]}; - if (fc_nodes.size() == 3) - zp_nodes_as_output_vector.push_back(zp_nodes[2]); + ov::OutputVector zp_nodes_as_output_vector; + for (size_t i = 0; i < zp_nodes.size(); ++i) { + zp_nodes_as_output_vector.push_back(zp_nodes[i]); + } fused_zps = std::make_shared(zp_nodes_as_output_vector, 0); - fused_zps->set_friendly_name(zp_nodes[0]->get_friendly_name() + "_fused"); + fused_zps->set_friendly_name(zp_nodes[0]->get_friendly_name() + "_fused_zps"); } } // Create new fc with merged weights, bias, scale, zp @@ -197,7 +238,7 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() { fused_scale, fc_nodes[0]->get_output_type()); - auto new_fc_name = fc_nodes[0]->get_friendly_name() + "_fused"; + auto new_fc_name = fc_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(fc_nodes.size()) + "FCs"; new_fc->set_friendly_name(new_fc_name); copy_runtime_info(fc_nodes_vec, new_fc);