diff --git a/src/par/src/Partitioner.cpp b/src/par/src/Partitioner.cpp index 9a01437d8c..ea7475ed9f 100644 --- a/src/par/src/Partitioner.cpp +++ b/src/par/src/Partitioner.cpp @@ -176,15 +176,30 @@ void Partitioner::RandomPart(const HGraphPtr& hgraph, vertices.insert(vertices.begin(), path_vertices.begin(), path_vertices.end()); if (vile_mode == false) { + std::vector total_vertices_weight(hgraph->GetVertexDimensions(), + 0.0f); + for (const auto& v : vertices) { + total_vertices_weight + = total_vertices_weight + hgraph->GetVertexWeights(v); + } + // try to generate balanced random partitioning int block_id = 0; for (const auto& v : vertices) { solution[v] = block_id; block_balance[block_id] = block_balance[block_id] + hgraph->GetVertexWeights(v); + const int previous_block_id = block_id; if (block_balance[block_id] >= lower_block_balance[block_id]) { block_id++; block_id = block_id % num_parts_; // adjust the block_id + if (block_balance[previous_block_id] == total_vertices_weight) { + solution[v] = block_id; + block_balance[block_id] + = block_balance[block_id] + hgraph->GetVertexWeights(v); + block_balance[previous_block_id] + = block_balance[previous_block_id] - hgraph->GetVertexWeights(v); + } } } } else { @@ -195,10 +210,11 @@ void Partitioner::RandomPart(const HGraphPtr& hgraph, bool stop_flag = false; for (const auto& v : vertices) { solution[v] = block_id; + const std::vector previous_block_balance = block_balance[block_id]; block_balance[block_id] = block_balance[block_id] + hgraph->GetVertexWeights(v); if (block_balance[block_id] >= upper_block_balance[block_id] - && stop_flag == false) { + && stop_flag == false && !equal(previous_block_balance, 0.0f)) { block_id++; solution[v] = block_id; // move the vertex to next block if (block_id == num_parts_ - 1) { diff --git a/src/par/src/Utilities.cpp b/src/par/src/Utilities.cpp index 34c246e210..705f4e581c 100644 --- a/src/par/src/Utilities.cpp +++ b/src/par/src/Utilities.cpp @@ -139,6 +139,17 @@ void Accumulate(std::vector& a, const std::vector& b) std::transform(a.begin(), a.end(), b.begin(), a.begin(), std::plus()); } +bool equal(const std::vector& vertex_weights, const float value) +{ + for (const float dimension_weight : vertex_weights) { + if (dimension_weight != value) { + return false; + } + } + + return true; +} + // weighted sum std::vector WeightedSum(const std::vector& a, const float a_factor, diff --git a/src/par/src/Utilities.h b/src/par/src/Utilities.h index 1dd4851ff8..07da6ae54f 100644 --- a/src/par/src/Utilities.h +++ b/src/par/src/Utilities.h @@ -94,6 +94,8 @@ std::vector SplitLine(const std::string& line); // Add right vector to left vector void Accumulate(std::vector& a, const std::vector& b); +bool equal(const std::vector& vertex_weights, const float value); + // weighted sum std::vector WeightedSum(const std::vector& a, float a_factor,