Skip to content

Commit

Permalink
Fuse bias to horizontally fused fc
Browse files Browse the repository at this point in the history
  • Loading branch information
yeonbok committed Sep 14, 2024
1 parent fca9341 commit 3144fcf
Showing 1 changed file with 58 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
return std::dynamic_pointer_cast<op::Placeholder>(node);
};
// Three FCs connected to the same input
const int min_num_fcs_to_fuse = 3;
const int max_num_fcs_to_fuse = 3;
const auto& fc = std::dynamic_pointer_cast<op::FullyConnectedCompressed>(output.get_node_shared_ptr());
const auto& input = fc->get_input_node_shared_ptr(0);
Expand All @@ -65,9 +66,7 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
}
user_fc_count++;
}
if (std::getenv("MLP_FUSION") == nullptr && user_fc_count != 3)
return false;
return (user_fc_count > 1) && (user_fc_count <= max_num_fcs_to_fuse) &&
return (user_fc_count >= min_num_fcs_to_fuse) && (user_fc_count <= max_num_fcs_to_fuse) &&
(nodes_with_bias == static_cast<int32_t>(user_fc_count) || nodes_with_bias == 0) &&
(nodes_with_zp == static_cast<int32_t>(user_fc_count) || nodes_with_zp == 0);
};
Expand Down Expand Up @@ -98,23 +97,29 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
zp_nodes.push_back(fc_user->get_input_node_shared_ptr(4));
}
}
auto weight_dtype = fc_nodes[0]->get_input_element_type(1);
auto k_size = fc_nodes[0]->get_input_shape(1)[fc_nodes[0]->get_input_shape(1).size() - 1];
// fc weight is already transposed to [N, K]
const size_t weight_idx = 1;
if (fc_nodes[0]->get_input_shape(weight_idx).size() != 2)
return false;
const size_t n_axis = 0;
const size_t k_axis = 1;
auto weight_dtype = fc_nodes[0]->get_input_element_type(weight_idx);
auto k_size = fc_nodes[0]->get_input_shape(weight_idx)[k_axis];
std::vector<int64_t> orig_n_sizes;
// merge weights, scale, zp
for (auto fc : fc_nodes) {
if (k_size != fc->get_input_shape(1)[fc->get_input_shape(1).size() - 1])
if (k_size != fc->get_input_shape(weight_idx)[k_axis])
return false;
if (weight_dtype != fc->get_input_element_type(1))
if (weight_dtype != fc->get_input_element_type(weight_idx))
return false;
orig_n_sizes.push_back(fc->get_input_shape(1)[fc->get_input_shape(1).size() - 2]);
orig_n_sizes.push_back(fc->get_input_shape(weight_idx)[n_axis]);
}
ov::OutputVector weight_nodes_as_output_vector;
for (size_t i = 0; i < weight_nodes.size(); ++i) {
weight_nodes_as_output_vector.push_back(weight_nodes[i]);
}
auto fused_weight = std::make_shared<ov::op::v0::Concat>(weight_nodes_as_output_vector, 0);
fused_weight->set_friendly_name(weight_nodes[0]->get_friendly_name() + "_fused");
fused_weight->set_friendly_name(weight_nodes[0]->get_friendly_name() + "_fused_weight");
ov::copy_runtime_info(weight_nodes, fused_weight);

ov::OutputVector scales_as_output_vector;
Expand All @@ -123,17 +128,52 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
}

auto fused_scale = std::make_shared<ov::op::v0::Concat>(scales_as_output_vector, 0);
fused_scale->set_friendly_name(scale_nodes[0]->get_friendly_name() + "_fused");
fused_scale->set_friendly_name(scale_nodes[0]->get_friendly_name() + "_fused_scale");
ov::copy_runtime_info(scale_nodes, fused_scale);
// check if all of the fc has a bias user, set it as bias input
size_t n_bias_users = 0;
for (auto fc : fc_nodes) {
if (fc->get_users().size() == 1
&& fc->get_users()[0]->get_type_info() == ov::opset1::Add::get_type_info_static()
&& ov::is_type<ov::op::v0::Constant>(fc->get_users()[0]->inputs()[1].get_source_output().get_node())) {
n_bias_users++;
}
}

if (bias_nodes.empty() && n_bias_users == fc_nodes.size()) {
// set user as bias
for (size_t i = 0; i < fc_nodes.size(); ++i) {
auto orig_fc = fc_nodes[i];
auto bias_node = orig_fc->get_users()[0];
auto bias_const = orig_fc->get_users()[0]->input_value(1);
auto orig_users_of_bias_user = bias_node->get_users();
ov::OutputVector fc_inputs = orig_fc->input_values();
fc_inputs[2] = bias_const;
auto new_fc = orig_fc->clone_with_new_inputs(fc_inputs);
new_fc->set_friendly_name(orig_fc->get_friendly_name() + "_with_bias");
ov::copy_runtime_info(orig_fc, new_fc);
for (auto u : orig_users_of_bias_user) {
for (size_t idx = 0; idx < u->inputs().size(); ++idx) {
if (u->get_input_node_shared_ptr(idx) == bias_node) {
u->input(idx).replace_source_output(new_fc->output(0));
}
}
}
fc_nodes[i] = std::dynamic_pointer_cast<op::FullyConnectedCompressed>(new_fc);
bias_nodes.push_back(fc_nodes[i]->get_input_node_shared_ptr(2));
bias_node->clear_control_dependencies();
orig_fc->clear_control_dependencies();
}
}
std::shared_ptr<ov::Node> fused_bias;
if (bias_nodes.size() > 1) {
ov::OutputVector bias_nodes_as_output_vector;
for (size_t i = 0; i < bias_nodes.size(); ++i) {
bias_nodes_as_output_vector.push_back(bias_nodes[i]);
}
fused_bias = std::make_shared<ov::op::v0::Concat>(bias_nodes_as_output_vector, 0);
fused_bias->set_friendly_name(bias_nodes[0]->get_friendly_name() + "_fused");
const auto bias_concat_axis = 2;
fused_bias = std::make_shared<ov::op::v0::Concat>(bias_nodes_as_output_vector, bias_concat_axis);
fused_bias->set_friendly_name(bias_nodes[0]->get_friendly_name() + "_fused_bias");
ov::copy_runtime_info(bias_nodes, fused_bias);
} else {
fused_bias = std::make_shared<op::Placeholder>();
Expand Down Expand Up @@ -174,11 +214,12 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
return false;
}
} else {
auto zp_nodes_as_output_vector = ov::OutputVector{zp_nodes[0], zp_nodes[1]};
if (fc_nodes.size() == 3)
zp_nodes_as_output_vector.push_back(zp_nodes[2]);
ov::OutputVector zp_nodes_as_output_vector;
for (size_t i = 0; i < zp_nodes.size(); ++i) {
zp_nodes_as_output_vector.push_back(zp_nodes[i]);
}
fused_zps = std::make_shared<ov::op::v0::Concat>(zp_nodes_as_output_vector, 0);
fused_zps->set_friendly_name(zp_nodes[0]->get_friendly_name() + "_fused");
fused_zps->set_friendly_name(zp_nodes[0]->get_friendly_name() + "_fused_zps");
}
}
// Create new fc with merged weights, bias, scale, zp
Expand All @@ -197,7 +238,7 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
fused_scale,
fc_nodes[0]->get_output_type());

auto new_fc_name = fc_nodes[0]->get_friendly_name() + "_fused";
auto new_fc_name = fc_nodes[0]->get_friendly_name() + "_fused_" + std::to_string(fc_nodes.size()) + "FCs";
new_fc->set_friendly_name(new_fc_name);
copy_runtime_info(fc_nodes_vec, new_fc);

Expand Down

0 comments on commit 3144fcf

Please sign in to comment.