[LP] fix gain cache updates & refactoring
N-Maas committed Sep 5, 2023
1 parent c8209cc commit 0781e22
Showing 2 changed files with 75 additions and 52 deletions.
@@ -113,6 +113,7 @@ namespace mt_kahypar {
Metrics& best_metrics,
vec<vec<Move>>& rebalance_moves_by_part) {
Metrics current_metrics = best_metrics;
+ const bool should_update_gain_cache = !PartitionedHypergraph::is_graph && _gain_cache.isInitialized();
_visited_he.reset();
_next_active.reset();
_gain.reset();
@@ -131,7 +132,7 @@
for ( size_t j = 0; j < _active_nodes.size(); ++j ) {
const HypernodeID hn = _active_nodes[j];
if ( moveVertex<unconstrained>(hypergraph, hn, next_active_nodes, objective_delta) ) {
- _active_node_was_moved[j] = uint8_t(true);
+ if (should_update_gain_cache) { _active_node_was_moved[j] = uint8_t(true); }
} else {
converged = false;
}
@@ -143,7 +144,7 @@
tbb::parallel_for(UL(0), _active_nodes.size(), [&](const size_t& j) {
const HypernodeID hn = _active_nodes[j];
if ( moveVertex<unconstrained>(hypergraph, hn, next_active_nodes, objective_delta) ) {
- _active_node_was_moved[j] = uint8_t(true);
+ if (should_update_gain_cache) { _active_node_was_moved[j] = uint8_t(true); }
} else {
converged = false;
}
@@ -153,7 +154,7 @@
current_metrics.imbalance = metrics::imbalance(hypergraph, _context);
current_metrics.quality += _gain.delta();

- if ( _gain_cache.isInitialized() ) {
+ if ( should_update_gain_cache ) {
auto recompute = [&](size_t j) {
if ( _active_node_was_moved[j] ) {
_gain_cache.recomputeInvalidTerms(hypergraph, _active_nodes[j]);
@@ -170,12 +171,13 @@
}
}

- if (unconstrained && !metrics::isBalanced(hypergraph, _context)) {
-   bool should_stop = applyRebalancing(hypergraph, next_active_nodes, best_metrics,
-                                       current_metrics, rebalance_moves_by_part);
-   converged |= should_stop;
- } else if (unconstrained) {
-   updateNodeData(hypergraph, next_active_nodes);
+ if constexpr ( unconstrained ) {
+   if (!metrics::isBalanced(hypergraph, _context)) {
+     converged = applyRebalancing(hypergraph, next_active_nodes, best_metrics,
+                                  current_metrics, rebalance_moves_by_part);
+   } else {
+     updateNodeData(hypergraph, next_active_nodes, false);
+   }
}

ASSERT(current_metrics.quality <= best_metrics.quality);
@@ -200,76 +202,90 @@
timer.stop_timer("rebalance_lp");
DBG << "[LP] Imbalance after rebalancing: " << current_metrics.imbalance << ", quality: " << current_metrics.quality;

+ const bool should_update_gain_cache = !PartitionedHypergraph::is_graph && _gain_cache.isInitialized();
if (current_metrics.quality > best_metrics.quality) { // rollback and stop LP
auto noop_obj_fn = [](const SynchronizedEdgeUpdate&) { };
current_metrics = best_metrics;

- if ( _context.refinement.label_propagation.execute_sequential ) {
-   for (const HypernodeID hn : hypergraph.nodes()) {
-     if (hypergraph.partID(hn) != _old_part[hn]) {
-       changeNodePart<true>(hypergraph, hn, hypergraph.partID(hn), _old_part[hn], noop_obj_fn);
+ // rollback all changes and update gain cache
+ forEachMovedNode(hypergraph, [&](const HypernodeID hn) {
+   const PartitionID old_part = _old_part[hn];
+   if (hypergraph.partID(hn) != old_part && old_part != kInvalidPartition) {
+     changeNodePart<true>(hypergraph, hn, hypergraph.partID(hn), old_part, noop_obj_fn);
+     if (should_update_gain_cache) {
+       // slightly hacky: use kInvalidPartition to mark nodes that need to be updated
+       _old_part[hn] = kInvalidPartition;
+     }
}
- } else {
-   hypergraph.doParallelForAllNodes([&](const HypernodeID hn) {
-     if (hypergraph.partID(hn) != _old_part[hn]) {
-       changeNodePart<true>(hypergraph, hn, hypergraph.partID(hn), _old_part[hn], noop_obj_fn);
-     }
-   });
+ }, &rebalance_moves_by_part);
+ if (should_update_gain_cache) {
+   forEachMovedNode(hypergraph, [&](const HypernodeID hn) {
+     _gain_cache.recomputeInvalidTerms(hypergraph, hn);
+   }, &rebalance_moves_by_part);
}
return true;
}

// collect activated nodes and recompute penalties
- updateNodeData(hypergraph, next_active_nodes, &rebalance_moves_by_part);
+ updateNodeData(hypergraph, next_active_nodes, should_update_gain_cache, &rebalance_moves_by_part);
return false;
}

template <typename TypeTraits, typename GainTypes>
void LabelPropagationRefiner<TypeTraits, GainTypes>::updateNodeData(PartitionedHypergraph& hypergraph,
NextActiveNodes& next_active_nodes,
+ bool should_update_gain_cache,
vec<vec<Move>>* rebalance_moves_by_part) {
- auto update_node = [&](const HypernodeID hn) {
-   const PartitionID current_part = hypergraph.partID(hn);
-   if (current_part != _old_part[hn]) {
-     activateNodeAndNeighbors(hypergraph, next_active_nodes, hn, false);
-     if ( _gain_cache.isInitialized() ) {
-       _gain_cache.recomputeInvalidTerms(hypergraph, hn);
-     }
+ // collect activated nodes and update gain cache
+ forEachMovedNode(hypergraph, [&](const HypernodeID hn) {
+   activateNodeAndNeighbors(hypergraph, next_active_nodes, hn, false);
+   if (should_update_gain_cache) {
+     _gain_cache.recomputeInvalidTerms(hypergraph, hn);
}
- };
- auto reset_node = [&](const HypernodeID hn) {
+ }, rebalance_moves_by_part);
+ // store current part of each node (required for rollback)
+ forEachMovedNode(hypergraph, [&](const HypernodeID hn) {
_old_part[hn] = hypergraph.partID(hn);
- };
+ }, rebalance_moves_by_part);
+ }

- auto for_each_moved_node = [&](auto apply) {
+ template <typename TypeTraits, typename GainTypes>
+ template<typename F>
+ void LabelPropagationRefiner<TypeTraits, GainTypes>::forEachMovedNode(const PartitionedHypergraph& hypergraph,
+                                                                       F node_fn,
+                                                                       const vec<vec<Move>>* rebalance_moves_by_part) {
+   if (rebalance_moves_by_part != nullptr) {
if ( _context.refinement.label_propagation.execute_sequential ) {
- for (const HypernodeID hn: _active_nodes) {
-   apply(hn);
+ for (const auto& moves: *rebalance_moves_by_part) {
+   for (const Move& m: moves) {
+     node_fn(m.node);
+   }
}
} else {
- tbb::parallel_for(UL(0), _active_nodes.size(), [&](const size_t j) {
-   apply(_active_nodes[j]);
- });
+ tbb::parallel_for(UL(0), rebalance_moves_by_part->size(), [&](const size_t index) {
+   const auto& moves = (*rebalance_moves_by_part)[index];
+   tbb::parallel_for(UL(0), moves.size(), [&](const size_t j) {
+     node_fn(moves[j].node);
+   });
+ }, tbb::static_partitioner());
}
}

- if (rebalance_moves_by_part != nullptr) {
-   for (const auto& moves: *rebalance_moves_by_part) {
-     if ( _context.refinement.label_propagation.execute_sequential ) {
-       for (const Move& m: moves) {
-         apply(m.node);
-       }
-     } else {
-       tbb::parallel_for(UL(0), moves.size(), [&](const size_t j) {
-         apply(moves[j].node);
-       });
-     }
+ // NOTE: we need to handle rebalancing nodes and active nodes in separate steps,
+ // otherwise a node might be concurrently updated by two threads
+ if ( _context.refinement.label_propagation.execute_sequential ) {
+   for (const HypernodeID hn: _active_nodes) {
+     if (hypergraph.partID(hn) != _old_part[hn]) {
+       node_fn(hn);
+     }
}
- };

- for_each_moved_node(update_node);
- for_each_moved_node(reset_node);
+ } else {
+   tbb::parallel_for(UL(0), _active_nodes.size(), [&](const size_t j) {
+     const HyperedgeID hn = _active_nodes[j];
+     if (hypergraph.partID(hn) != _old_part[hn]) {
+       node_fn(hn);
+     }
+   });
+ }
}

template <typename TypeTraits, typename GainTypes>
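The rollback path above relies on a sentinel-marking pattern: when a move is undone, the node's `_old_part` entry is overwritten with `kInvalidPartition` (the "slightly hacky" trick noted in the diff), and a second `forEachMovedNode` pass then recomputes gain-cache terms for exactly the marked nodes, once all rollback moves have been applied. Below is a minimal standalone sketch of this two-pass pattern; the `Partition` and `GainCache` types and all names are illustrative stand-ins, not the mt-kahypar API:

```cpp
#include <cstdint>
#include <vector>

using NodeID = uint32_t;
using PartID = int32_t;
constexpr PartID kInvalidPart = -1;  // sentinel: "this node's cached terms are stale"

// Illustrative stand-ins for the partitioned hypergraph and the gain cache.
struct Partition {
  std::vector<PartID> part;  // current block of each node
  void changeNodePart(NodeID u, PartID to) { part[u] = to; }
};
struct GainCache {
  void recomputeInvalidTerms(const Partition&, NodeID) { /* refresh cached terms of the node */ }
};

void rollbackMoves(Partition& phg, GainCache& cache, bool update_gain_cache,
                   std::vector<PartID>& old_part, const std::vector<NodeID>& moved) {
  // Pass 1: undo each move and mark the rolled-back node with the sentinel.
  // The sentinel check also keeps a node that occurs twice in `moved`
  // from being rolled back twice.
  for (const NodeID u : moved) {
    const PartID prev = old_part[u];
    if (phg.part[u] != prev && prev != kInvalidPart) {
      phg.changeNodePart(u, prev);
      if (update_gain_cache) {
        old_part[u] = kInvalidPart;  // remember: recompute this node in pass 2
      }
    }
  }
  // Pass 2: recompute cached terms only for the marked nodes, and only after
  // every position is final (a node's terms depend on its neighbors' blocks).
  if (update_gain_cache) {
    for (const NodeID u : moved) {
      if (old_part[u] == kInvalidPart) {
        cache.recomputeInvalidTerms(phg, u);
      }
    }
  }
}
```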
@@ -107,8 +107,14 @@ class LabelPropagationRefiner final : public IRefiner {

void updateNodeData(PartitionedHypergraph& hypergraph,
NextActiveNodes& next_active_nodes,
+ bool should_update_gain_cache,
vec<vec<Move>>* rebalance_moves_by_part = nullptr);

+ template<typename F>
+ void forEachMovedNode(const PartitionedHypergraph& hypergraph,
+                       F node_fn,
+                       const vec<vec<Move>>* rebalance_moves_by_part = nullptr);

template<bool unconstrained, typename F>
bool moveVertex(PartitionedHypergraph& hypergraph,
const HypernodeID hn,
@@ -173,6 +179,7 @@ class LabelPropagationRefiner final : public IRefiner {
void initializeImpl(mt_kahypar_partitioned_hypergraph_t&) final;

template<bool unconstrained, typename F>
+ MT_KAHYPAR_ATTRIBUTE_ALWAYS_INLINE
bool changeNodePart(PartitionedHypergraph& phg,
const HypernodeID hn,
const PartitionID from,
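The `NOTE` added in `forEachMovedNode` explains the shape of the new helper: rebalancing moves and active nodes are traversed in two separate steps, because a node contained in both sets could otherwise be handed to `node_fn` by two threads at once. A rough sketch of that structure with TBB follows; the free-standing function and its parameters are illustrative simplifications, not the actual refiner interface:

```cpp
#include <tbb/parallel_for.h>

#include <cstdint>
#include <vector>

using NodeID = uint32_t;
using PartID = int32_t;
struct Move { NodeID node; };

template <typename F>
void forEachMovedNode(const std::vector<std::vector<Move>>& rebalance_moves_by_part,
                      const std::vector<NodeID>& active_nodes,
                      const std::vector<PartID>& part,
                      const std::vector<PartID>& old_part,
                      F node_fn) {
  // Step 1: all nodes touched by rebalancing, one task per block's move list.
  tbb::parallel_for(size_t(0), rebalance_moves_by_part.size(), [&](const size_t p) {
    for (const Move& m : rebalance_moves_by_part[p]) {
      node_fn(m.node);
    }
  });
  // parallel_for joins before returning, so step 2 cannot overlap step 1:
  // a node appearing in both sets is visited twice, but never concurrently.
  // Step 2: the active nodes that actually changed their block.
  tbb::parallel_for(size_t(0), active_nodes.size(), [&](const size_t j) {
    const NodeID hn = active_nodes[j];
    if (part[hn] != old_part[hn]) {
      node_fn(hn);
    }
  });
}
```

Running the two loops back to back trades a possible duplicate (but now race-free) visit for a simple invariant: `node_fn` never runs concurrently on the same node.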
