Skip to content

Commit

Permalink
horizontal mlp fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
yeonbok committed Sep 12, 2024
1 parent 05bd0c7 commit e0e22fa
Showing 1 changed file with 34 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,11 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
return std::dynamic_pointer_cast<op::Placeholder>(node);
};
// Three FCs connected to the same input
const int num_fcs_to_fuse = 3;
const int max_num_fcs_to_fuse = 3;
const auto& fc = std::dynamic_pointer_cast<op::FullyConnectedCompressed>(output.get_node_shared_ptr());
const auto& input = fc->get_input_node_shared_ptr(0);
if (!fc->get_input_partial_shape(0).is_dynamic())
return false;
if (input->get_users().size() < num_fcs_to_fuse)
return false;
size_t user_fc_count = 0;
int32_t nodes_with_bias = 0;
int32_t nodes_with_zp = 0;
Expand All @@ -67,8 +65,11 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
}
user_fc_count++;
}
return (user_fc_count == num_fcs_to_fuse) && (nodes_with_bias == num_fcs_to_fuse || nodes_with_bias == 0) &&
(nodes_with_zp == num_fcs_to_fuse || nodes_with_zp == 0);
if (std::getenv("MLP_FUSION") == nullptr && user_fc_count != 3)
return false;
return (user_fc_count > 1) && (user_fc_count <= max_num_fcs_to_fuse) &&
(nodes_with_bias == static_cast<int32_t>(user_fc_count) || nodes_with_bias == 0) &&
(nodes_with_zp == static_cast<int32_t>(user_fc_count) || nodes_with_zp == 0);
};

auto target_fc = wrap_type<op::FullyConnectedCompressed>(is_target_pattern);
Expand All @@ -78,6 +79,7 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
auto m_fc = pattern_map.at(target_fc).get_node_shared_ptr();
auto input_node = m_fc->get_input_node_shared_ptr(0);
std::vector<std::shared_ptr<op::FullyConnectedCompressed>> fc_nodes;
ov::NodeVector fc_nodes_vec;
ov::NodeVector weight_nodes;
ov::NodeVector scale_nodes;
ov::NodeVector bias_nodes;
Expand All @@ -87,6 +89,7 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
if (fc_user) {
OPENVINO_ASSERT(fc_user->inputs().size() >= 4, "Compressed FC should have at least 4 inputs");
fc_nodes.push_back(fc_user);
fc_nodes_vec.push_back(fc_user);
weight_nodes.push_back(fc_user->get_input_node_shared_ptr(1));
if (!std::dynamic_pointer_cast<op::Placeholder>(fc_user->get_input_node_shared_ptr(2)))
bias_nodes.push_back(fc_user->get_input_node_shared_ptr(2));
Expand All @@ -106,22 +109,32 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
return false;
orig_n_sizes.push_back(fc->get_input_shape(1)[fc->get_input_shape(1).size() - 2]);
}
auto weight_nodes_as_output_vector = ov::OutputVector{weight_nodes[0], weight_nodes[1], weight_nodes[2]};
ov::OutputVector weight_nodes_as_output_vector;
for (size_t i = 0; i < weight_nodes.size(); ++i) {
weight_nodes_as_output_vector.push_back(weight_nodes[i]);
}
auto fused_weight = std::make_shared<ov::op::v0::Concat>(weight_nodes_as_output_vector, 0);
fused_weight->set_friendly_name(weight_nodes[0]->get_friendly_name() + "_fused");
ov::copy_runtime_info({weight_nodes[0], weight_nodes[1], weight_nodes[2]}, fused_weight);
ov::copy_runtime_info(weight_nodes, fused_weight);

auto scale_nodes_as_output_vector = ov::OutputVector{scale_nodes[0], scale_nodes[1], scale_nodes[2]};
auto fused_scale = std::make_shared<ov::op::v0::Concat>(scale_nodes_as_output_vector, 0);
ov::OutputVector scales_as_output_vector;
for (size_t i = 0; i < scale_nodes.size(); ++i) {
scales_as_output_vector.push_back(scale_nodes[i]);
}

auto fused_scale = std::make_shared<ov::op::v0::Concat>(scales_as_output_vector, 0);
fused_scale->set_friendly_name(scale_nodes[0]->get_friendly_name() + "_fused");
ov::copy_runtime_info({scale_nodes[0], scale_nodes[1], scale_nodes[2]}, fused_scale);
ov::copy_runtime_info(scale_nodes, fused_scale);

std::shared_ptr<ov::Node> fused_bias;
if (bias_nodes.size() == 3) {
auto bias_nodes_as_output_vector = ov::OutputVector{bias_nodes[0], bias_nodes[1], bias_nodes[2]};
if (bias_nodes.size() > 1) {
ov::OutputVector bias_nodes_as_output_vector;
for (size_t i = 0; i < bias_nodes.size(); ++i) {
bias_nodes_as_output_vector.push_back(bias_nodes[i]);
}
fused_bias = std::make_shared<ov::op::v0::Concat>(bias_nodes_as_output_vector, 0);
fused_bias->set_friendly_name(bias_nodes[0]->get_friendly_name() + "_fused");
ov::copy_runtime_info({bias_nodes[0], bias_nodes[1], bias_nodes[2]}, fused_bias);
ov::copy_runtime_info(bias_nodes, fused_bias);
} else {
fused_bias = std::make_shared<op::Placeholder>();
}
Expand Down Expand Up @@ -161,7 +174,9 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
return false;
}
} else {
auto zp_nodes_as_output_vector = ov::OutputVector{zp_nodes[0], zp_nodes[1], zp_nodes[2]};
auto zp_nodes_as_output_vector = ov::OutputVector{zp_nodes[0], zp_nodes[1]};
if (fc_nodes.size() == 3)
zp_nodes_as_output_vector.push_back(zp_nodes[2]);
fused_zps = std::make_shared<ov::op::v0::Concat>(zp_nodes_as_output_vector, 0);
fused_zps->set_friendly_name(zp_nodes[0]->get_friendly_name() + "_fused");
}
Expand All @@ -184,14 +199,15 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {

auto new_fc_name = fc_nodes[0]->get_friendly_name() + "_fused";
new_fc->set_friendly_name(new_fc_name);
copy_runtime_info({fc_nodes[0], fc_nodes[1], fc_nodes[2]}, new_fc);
copy_runtime_info(fc_nodes_vec, new_fc);

// Split output and connect to the orig users
auto split_name = fc_nodes[0]->get_friendly_name() + "_split";
auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {new_fc->get_output_partial_shape(0).size() - 1});
auto split_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, orig_n_sizes);
auto split_size = fc_nodes.size();
auto split_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{split_size}, orig_n_sizes);
auto output_split = std::make_shared<ov::op::v1::VariadicSplit>(new_fc, axis_const, split_const);
copy_runtime_info({fc_nodes[0], fc_nodes[1], fc_nodes[2]}, output_split);
copy_runtime_info(fc_nodes_vec, output_split);
output_split->set_friendly_name(split_name);
for (size_t i = 0; i < fc_nodes.size(); ++i) {
auto org_fc = fc_nodes[i];
Expand All @@ -204,6 +220,7 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
}
org_fc->clear_control_dependencies();
}
std::cout << "new fc: " << new_fc_name << std::endl;
return true;
};

Expand Down

0 comments on commit e0e22fa

Please sign in to comment.