Skip to content

Commit

Permalink
More refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
yeonbok committed Sep 16, 2024
1 parent 50f4315 commit 1c59230
Showing 1 changed file with 45 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "intel_gpu/op/placeholder.hpp"
#include "intel_gpu/runtime/debug_configuration.hpp"

namespace ov {
namespace intel_gpu {
Expand Down Expand Up @@ -140,38 +141,63 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
}
}

size_t bias_concat_axis = 0;
if (bias_nodes.empty() && n_bias_users == fc_nodes.size()) {
// set user as bias
// Set Add user as bias input to FC
for (size_t i = 0; i < fc_nodes.size(); ++i) {
auto orig_fc = fc_nodes[i];
auto bias_node = orig_fc->get_users()[0];
auto bias_const = orig_fc->get_users()[0]->input_value(1);
auto orig_users_of_bias_user = bias_node->get_users();
ov::OutputVector fc_inputs = orig_fc->input_values();
fc_inputs[2] = bias_const;
auto new_fc = orig_fc->clone_with_new_inputs(fc_inputs);
new_fc->set_friendly_name(orig_fc->get_friendly_name() + "_with_bias");
ov::copy_runtime_info(orig_fc, new_fc);
for (auto u : orig_users_of_bias_user) {
for (size_t idx = 0; idx < u->inputs().size(); ++idx) {
if (u->get_input_node_shared_ptr(idx) == bias_node) {
u->input(idx).replace_source_output(new_fc->output(0));
auto bias_const_ptr = orig_fc->get_users()[0]->get_input_node_shared_ptr(1);
bias_nodes.push_back(bias_const_ptr);
}
// Check shape and find axis
const auto bias_rank = bias_nodes[0]->get_output_partial_shape(0).size();
std::vector<int32_t> bias_add_shape_diffs(bias_rank, 0);
for (size_t i = 1; i < bias_nodes.size(); ++i) {
for (size_t j = 0; j < bias_rank; ++j) {
bias_add_shape_diffs[j] += (bias_nodes[i]->get_output_shape(0)[j] - bias_nodes[i - 1]->get_output_shape(0)[j]);
}
}
auto non_zero_diffs = std::count_if(bias_add_shape_diffs.begin(), bias_add_shape_diffs.end(), [](int32_t diff) { return diff != 0; });
if (non_zero_diffs <= 1) {
for (size_t i = 0; i < bias_rank; ++i) {
if (bias_add_shape_diffs[i] != 0)
bias_concat_axis = i;
}
for (size_t i = 0; i < fc_nodes.size(); ++i) {
auto orig_fc = fc_nodes[i];
auto bias_node = orig_fc->get_users()[0];
GPU_DEBUG_TRACE_DETAIL << "Set Add op user " << bias_node->get_friendly_name() << " as the FC "
<< orig_fc->get_friendly_name() << "'s bias input" << std::endl;
auto bias_const = orig_fc->get_users()[0]->input_value(1);
auto orig_users_of_bias_user = bias_node->get_users();
ov::OutputVector fc_inputs = orig_fc->input_values();
fc_inputs[2] = bias_const;
auto new_fc = orig_fc->clone_with_new_inputs(fc_inputs);
new_fc->set_friendly_name(orig_fc->get_friendly_name() + "_with_bias");
ov::copy_runtime_info(orig_fc, new_fc);
for (auto u : orig_users_of_bias_user) {
for (size_t idx = 0; idx < u->inputs().size(); ++idx) {
if (u->get_input_node_shared_ptr(idx) == bias_node) {
u->input(idx).replace_source_output(new_fc->output(0));
}
}
}
fc_nodes[i] = std::dynamic_pointer_cast<op::FullyConnectedCompressed>(new_fc);
bias_node->clear_control_dependencies();
orig_fc->clear_control_dependencies();
}
fc_nodes[i] = std::dynamic_pointer_cast<op::FullyConnectedCompressed>(new_fc);
bias_nodes.push_back(fc_nodes[i]->get_input_node_shared_ptr(2));
bias_node->clear_control_dependencies();
orig_fc->clear_control_dependencies();
} else {
// The biases cannot be fused (their shapes differ along more than one axis), so do not set the Add users as FC bias inputs.
bias_nodes.clear();
}
}
std::shared_ptr<ov::Node> fused_bias;
if (bias_nodes.size() > 1) {
if (bias_nodes.size() == fc_nodes.size()) {
ov::OutputVector bias_nodes_as_output_vector;
for (size_t i = 0; i < bias_nodes.size(); ++i) {
bias_nodes_as_output_vector.push_back(bias_nodes[i]);
}
const auto bias_concat_axis = 2;
fused_bias = std::make_shared<ov::op::v0::Concat>(bias_nodes_as_output_vector, bias_concat_axis);
fused_bias->set_friendly_name(bias_nodes[0]->get_friendly_name() + "_fused_bias");
ov::copy_runtime_info(bias_nodes, fused_bias);
Expand Down Expand Up @@ -261,7 +287,7 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
}
org_fc->clear_control_dependencies();
}
std::cout << "new fc: " << new_fc_name << std::endl;
GPU_DEBUG_TRACE_DETAIL << "Created a new fused FC " << new_fc_name << std::endl;
return true;
};

Expand Down

0 comments on commit 1c59230

Please sign in to comment.