From 0781e220527fe995d6df23b051b09815c6a319e5 Mon Sep 17 00:00:00 2001 From: Nikolai Maas Date: Tue, 5 Sep 2023 14:18:40 +0200 Subject: [PATCH] [LP] fix gain cache updates & refactoring --- .../label_propagation_refiner.cpp | 120 ++++++++++-------- .../label_propagation_refiner.h | 7 + 2 files changed, 75 insertions(+), 52 deletions(-) diff --git a/mt-kahypar/partition/refinement/label_propagation/label_propagation_refiner.cpp b/mt-kahypar/partition/refinement/label_propagation/label_propagation_refiner.cpp index 41f4a8ea0..f15a65c8d 100644 --- a/mt-kahypar/partition/refinement/label_propagation/label_propagation_refiner.cpp +++ b/mt-kahypar/partition/refinement/label_propagation/label_propagation_refiner.cpp @@ -113,6 +113,7 @@ namespace mt_kahypar { Metrics& best_metrics, vec>& rebalance_moves_by_part) { Metrics current_metrics = best_metrics; + const bool should_update_gain_cache = !PartitionedHypergraph::is_graph && _gain_cache.isInitialized(); _visited_he.reset(); _next_active.reset(); _gain.reset(); @@ -131,7 +132,7 @@ namespace mt_kahypar { for ( size_t j = 0; j < _active_nodes.size(); ++j ) { const HypernodeID hn = _active_nodes[j]; if ( moveVertex(hypergraph, hn, next_active_nodes, objective_delta) ) { - _active_node_was_moved[j] = uint8_t(true); + if (should_update_gain_cache) { _active_node_was_moved[j] = uint8_t(true); } } else { converged = false; } @@ -143,7 +144,7 @@ namespace mt_kahypar { tbb::parallel_for(UL(0), _active_nodes.size(), [&](const size_t& j) { const HypernodeID hn = _active_nodes[j]; if ( moveVertex(hypergraph, hn, next_active_nodes, objective_delta) ) { - _active_node_was_moved[j] = uint8_t(true); + if (should_update_gain_cache) { _active_node_was_moved[j] = uint8_t(true); } } else { converged = false; } @@ -153,7 +154,7 @@ namespace mt_kahypar { current_metrics.imbalance = metrics::imbalance(hypergraph, _context); current_metrics.quality += _gain.delta(); - if ( _gain_cache.isInitialized() ) { + if ( should_update_gain_cache ) { auto recompute = [&](size_t j) { if ( _active_node_was_moved[j] ) { _gain_cache.recomputeInvalidTerms(hypergraph, _active_nodes[j]); @@ -170,12 +171,13 @@ namespace mt_kahypar { } } - if (unconstrained && !metrics::isBalanced(hypergraph, _context)) { - bool should_stop = applyRebalancing(hypergraph, next_active_nodes, best_metrics, - current_metrics, rebalance_moves_by_part); - converged |= should_stop; - } else if (unconstrained) { - updateNodeData(hypergraph, next_active_nodes); + if constexpr ( unconstrained ) { + if (!metrics::isBalanced(hypergraph, _context)) { + converged = applyRebalancing(hypergraph, next_active_nodes, best_metrics, + current_metrics, rebalance_moves_by_part); + } else { + updateNodeData(hypergraph, next_active_nodes, false); + } } ASSERT(current_metrics.quality <= best_metrics.quality); @@ -200,76 +202,90 @@ namespace mt_kahypar { timer.stop_timer("rebalance_lp"); DBG << "[LP] Imbalance after rebalancing: " << current_metrics.imbalance << ", quality: " << current_metrics.quality; + const bool should_update_gain_cache = !PartitionedHypergraph::is_graph && _gain_cache.isInitialized(); if (current_metrics.quality > best_metrics.quality) { // rollback and stop LP auto noop_obj_fn = [](const SynchronizedEdgeUpdate&) { }; current_metrics = best_metrics; - if ( _context.refinement.label_propagation.execute_sequential ) { - for (const HypernodeID hn : hypergraph.nodes()) { - if (hypergraph.partID(hn) != _old_part[hn]) { - changeNodePart(hypergraph, hn, hypergraph.partID(hn), _old_part[hn], noop_obj_fn); + // rollback all changes and update gain cache + forEachMovedNode(hypergraph, [&](const HypernodeID hn) { + const PartitionID old_part = _old_part[hn]; + if (hypergraph.partID(hn) != old_part && old_part != kInvalidPartition) { + changeNodePart(hypergraph, hn, hypergraph.partID(hn), old_part, noop_obj_fn); + if (should_update_gain_cache) { + // slightly hacky: use kInvalidPartition to mark nodes that need to be updated + _old_part[hn] = kInvalidPartition; } } - } else { - hypergraph.doParallelForAllNodes([&](const HypernodeID hn) { - if (hypergraph.partID(hn) != _old_part[hn]) { - changeNodePart(hypergraph, hn, hypergraph.partID(hn), _old_part[hn], noop_obj_fn); - } - }); + }, &rebalance_moves_by_part); + if (should_update_gain_cache) { + forEachMovedNode(hypergraph, [&](const HypernodeID hn) { + _gain_cache.recomputeInvalidTerms(hypergraph, hn); + }, &rebalance_moves_by_part); } return true; } - // collect activated nodes and recompute penalties - updateNodeData(hypergraph, next_active_nodes, &rebalance_moves_by_part); + updateNodeData(hypergraph, next_active_nodes, should_update_gain_cache, &rebalance_moves_by_part); return false; } template void LabelPropagationRefiner::updateNodeData(PartitionedHypergraph& hypergraph, NextActiveNodes& next_active_nodes, + bool should_update_gain_cache, vec>* rebalance_moves_by_part) { - auto update_node = [&](const HypernodeID hn) { - const PartitionID current_part = hypergraph.partID(hn); - if (current_part != _old_part[hn]) { - activateNodeAndNeighbors(hypergraph, next_active_nodes, hn, false); - if ( _gain_cache.isInitialized() ) { - _gain_cache.recomputeInvalidTerms(hypergraph, hn); - } + // collect activated nodes and update gain cache + forEachMovedNode(hypergraph, [&](const HypernodeID hn) { + activateNodeAndNeighbors(hypergraph, next_active_nodes, hn, false); + if (should_update_gain_cache) { + _gain_cache.recomputeInvalidTerms(hypergraph, hn); } - }; - auto reset_node = [&](const HypernodeID hn) { + }, rebalance_moves_by_part); + // store current part of each node (required for rollback) + forEachMovedNode(hypergraph, [&](const HypernodeID hn) { _old_part[hn] = hypergraph.partID(hn); - }; + }, rebalance_moves_by_part); + } - auto for_each_moved_node = [&](auto apply) { + template + template + void LabelPropagationRefiner::forEachMovedNode(const PartitionedHypergraph& hypergraph, + F node_fn, + const vec>* rebalance_moves_by_part) { + if (rebalance_moves_by_part != nullptr) { if ( _context.refinement.label_propagation.execute_sequential ) { - for (const HypernodeID hn: _active_nodes) { - apply(hn); + for (const auto& moves: *rebalance_moves_by_part) { + for (const Move& m: moves) { + node_fn(m.node); + } } } else { - tbb::parallel_for(UL(0), _active_nodes.size(), [&](const size_t j) { - apply(_active_nodes[j]); - }); + tbb::parallel_for(UL(0), rebalance_moves_by_part->size(), [&](const size_t index) { + const auto& moves = (*rebalance_moves_by_part)[index]; + tbb::parallel_for(UL(0), moves.size(), [&](const size_t j) { + node_fn(moves[j].node); + }); + }, tbb::static_partitioner()); } + } - if (rebalance_moves_by_part != nullptr) { - for (const auto& moves: *rebalance_moves_by_part) { - if ( _context.refinement.label_propagation.execute_sequential ) { - for (const Move& m: moves) { - apply(m.node); - } - } else { - tbb::parallel_for(UL(0), moves.size(), [&](const size_t j) { - apply(moves[j].node); - }); - } + // NOTE: we need to handle rebalancing nodes and active nodes in separate steps, + // otherwise a node might be concurrently updated by two threads + if ( _context.refinement.label_propagation.execute_sequential ) { + for (const HypernodeID hn: _active_nodes) { + if (hypergraph.partID(hn) != _old_part[hn]) { + node_fn(hn); } } - }; - - for_each_moved_node(update_node); - for_each_moved_node(reset_node); + } else { + tbb::parallel_for(UL(0), _active_nodes.size(), [&](const size_t j) { + const HyperedgeID hn = _active_nodes[j]; + if (hypergraph.partID(hn) != _old_part[hn]) { + node_fn(hn); + } + }); + } } template diff --git a/mt-kahypar/partition/refinement/label_propagation/label_propagation_refiner.h b/mt-kahypar/partition/refinement/label_propagation/label_propagation_refiner.h index 396eda886..0b90ec5fa 100644 --- a/mt-kahypar/partition/refinement/label_propagation/label_propagation_refiner.h +++ b/mt-kahypar/partition/refinement/label_propagation/label_propagation_refiner.h @@ -107,8 +107,14 @@ class LabelPropagationRefiner final : public IRefiner { void updateNodeData(PartitionedHypergraph& hypergraph, NextActiveNodes& next_active_nodes, + bool should_update_gain_cache, vec>* rebalance_moves_by_part = nullptr); + template + void forEachMovedNode(const PartitionedHypergraph& hypergraph, + F node_fn, + const vec>* rebalance_moves_by_part = nullptr); + template bool moveVertex(PartitionedHypergraph& hypergraph, const HypernodeID hn, @@ -173,6 +179,7 @@ class LabelPropagationRefiner final : public IRefiner { void initializeImpl(mt_kahypar_partitioned_hypergraph_t&) final; template + MT_KAHYPAR_ATTRIBUTE_ALWAYS_INLINE bool changeNodePart(PartitionedHypergraph& phg, const HypernodeID hn, const PartitionID from,