Skip to content

Commit

Permalink
Fixed bug for calculating bias axis
Browse files Browse the repository at this point in the history
  • Loading branch information
yeonbok committed Sep 16, 2024
1 parent 1c59230 commit a828dec
Showing 1 changed file with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,18 +152,18 @@ FullyConnectedHorizontalFusion::FullyConnectedHorizontalFusion() {
}
// Check shape and find axis
const auto bias_rank = bias_nodes[0]->get_output_partial_shape(0).size();
std::vector<int32_t> bias_add_shape_diffs(bias_rank, 0);
for (size_t i = 1; i < bias_nodes.size(); ++i) {
for (size_t j = 0; j < bias_rank; ++j) {
bias_add_shape_diffs[j] += (bias_nodes[i]->get_output_shape(0)[j] - bias_nodes[i - 1]->get_output_shape(0)[j]);
size_t non_zero_diffs = 0;
for (size_t i = 0; i < bias_rank; ++i) {
std::unordered_set<size_t> dims;
for (size_t j = 0; j < bias_nodes.size(); ++j) {
dims.insert(bias_nodes[j]->get_output_partial_shape(0)[i].get_length());
}
if (dims.size() > 1) {
bias_concat_axis = i;
non_zero_diffs++;
}
}
auto non_zero_diffs = std::count_if(bias_add_shape_diffs.begin(), bias_add_shape_diffs.end(), [](int32_t diff) { return diff != 0; });
if (non_zero_diffs <= 1) {
for (size_t i = 0; i < bias_rank; ++i) {
if (bias_add_shape_diffs[i] != 0)
bias_concat_axis = i;
}
for (size_t i = 0; i < fc_nodes.size(); ++i) {
auto orig_fc = fc_nodes[i];
auto bias_node = orig_fc->get_users()[0];
Expand Down

0 comments on commit a828dec

Please sign in to comment.