From 6dcb500f19d285b51c862532b235ba23a5c3046b Mon Sep 17 00:00:00 2001
From: Dominik Rosch
Date: Tue, 1 Oct 2024 15:51:38 +0200
Subject: [PATCH] implement medians_of_medians

---
 .../sparsification/sparsification_utils.h | 36 +++++++++++++++----
 1 file changed, 30 insertions(+), 6 deletions(-)

diff --git a/kaminpar-shm/coarsening/sparsification/sparsification_utils.h b/kaminpar-shm/coarsening/sparsification/sparsification_utils.h
index 2be38b60..5cb83eb9 100644
--- a/kaminpar-shm/coarsening/sparsification/sparsification_utils.h
+++ b/kaminpar-shm/coarsening/sparsification/sparsification_utils.h
@@ -45,12 +45,12 @@ inline void parallel_for_downward_edges(const CSRGraph &g, Lambda function) {
 }
 
 template
-T select_k_smallest(size_t k, Iterator begin, Iterator end, Comp comp = std::less()) {
+T quickselect_k_smallest(size_t k, Iterator begin, Iterator end, Comp comp = std::less()) {
   size_t size = begin - end;
   if (size == 1)
     return *begin;
 
-  T pivot = begin[Random::instance().random_index(0, size)];
+  T pivot = medians_of_medians(begin, end);
   tbb::concurrent_vector less = {}, greater = {};
   tbb::parallel_for(begin, end, [&](auto x) {
     if (comp(x, pivot))
@@ -65,6 +65,30 @@ T select_k_smallest(size_t k, Iterator begin, Iterator end, Comp comp = std::les
   return select_k_smallest(k - less.size(), greater.begin(), greater.end());
 }
 
+template T medians_of_medians(Iterator begin, Iterator end) {
+  size_t size = begin - end;
+  if (size <= 5)
+    return median(begin, end);
+
+  size_t number_of_sections = (size + 4) / 5;
+  StaticArray medians(number_of_sections);
+  tbb::parallel_for(0, number_of_sections, [&](auto i) {
+    medians[i] = median(begin + 5 * i, begin + std::min(5 * (i + 1), size));
+  });
+
+  return quickselect_k_smallest(number_of_sections / 2, medians.begin(), medians.end());
+}
+template T median(Iterator begin, Iterator end) {
+  size_t size = begin - end;
+  StaticArray sorted(size);
+  std::sort(begin, end, begin);
+  if (size % 2 == 1) { // odd size
+    return sorted[size / 2];
+  } else {
+    return (sorted[size / 2] + sorted[size / 2 + 1]) / 2;
+  }
+}
+
 template
 StaticArray
 sample_k_without_replacement(WeightIterator weights_begin, WeightIterator weights_end, size_t k) {
@@ -73,17 +97,17 @@ sample_k_without_replacement(WeightIterator weights_begin, WeightIterator weight
   tbb::parallel_for(0ul, size, [&](auto i) {
     keys[i] = -std::log(Random::instance().random_double()) / weights_begin[i];
   });
-  double x = select_k_smallest(k, keys.begin(), keys.end(), std::less());
+  double x = quickselect_k_smallest(k, keys.begin(), keys.end(), std::less());
 
-  StaticArray selected(k);
+  tbb::concurrent_vector selected;
   size_t back = 0;
   tbb::parallel_for(0ul, keys.size(), [&](auto i) {
     if (keys[i] <= x) {
       __atomic_fetch_add(&back, 1, __ATOMIC_RELAXED);
-      selected[back] = i;
+      selected.push_back(i);
     }
   });
-  return selected;
+  return StaticArray(selected.begin(), selected.end());
 }
 
 } // namespace kaminpar::shm::sparsification::utils