Skip to content

Commit

Permalink
implement medians_of_medians
Browse files Browse the repository at this point in the history
  • Loading branch information
Dominik Rosch committed Oct 1, 2024
1 parent b8302d8 commit 6dcb500
Showing 1 changed file with 30 additions and 6 deletions.
36 changes: 30 additions & 6 deletions kaminpar-shm/coarsening/sparsification/sparsification_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ inline void parallel_for_downward_edges(const CSRGraph &g, Lambda function) {
}

template <typename T, typename Iterator, typename Comp>
T select_k_smallest(size_t k, Iterator begin, Iterator end, Comp comp = std::less<T>()) {
T quickselect_k_smallest(size_t k, Iterator begin, Iterator end, Comp comp = std::less<T>()) {

size_t size = begin - end;
if (size == 1)
return *begin;
T pivot = begin[Random::instance().random_index(0, size)];
T pivot = medians_of_medians(begin, end);
tbb::concurrent_vector<T> less = {}, greater = {};
tbb::parallel_for(begin, end, [&](auto x) {
if (comp(x, pivot))
Expand All @@ -65,6 +65,30 @@ T select_k_smallest(size_t k, Iterator begin, Iterator end, Comp comp = std::les
return select_k_smallest(k - less.size(), greater.begin(), greater.end());
}

template <typename T, typename Iterator> T medians_of_medians(Iterator begin, Iterator end) {
size_t size = begin - end;
if (size <= 5)
return median(begin, end);

size_t number_of_sections = (size + 4) / 5;
StaticArray<T> medians(number_of_sections);
tbb::parallel_for(0, number_of_sections, [&](auto i) {
medians[i] = median(begin + 5 * i, begin + std::min(5 * (i + 1), size));
});

return quickselect_k_smallest<T, Iterator>(number_of_sections / 2, medians.begin(), medians.end());
}
template <typename T, typename Iterator> T median(Iterator begin, Iterator end) {
size_t size = begin - end;
StaticArray<T> sorted(size);
std::sort(begin, end, begin);
if (size % 2 == 1) { // odd size
return sorted[size / 2];
} else {
return (sorted[size / 2] + sorted[size / 2 + 1]) / 2;
}
}

template <typename WeightIterator>
StaticArray<size_t>
sample_k_without_replacement(WeightIterator weights_begin, WeightIterator weights_end, size_t k) {
Expand All @@ -73,17 +97,17 @@ sample_k_without_replacement(WeightIterator weights_begin, WeightIterator weight
tbb::parallel_for(0ul, size, [&](auto i) {
keys[i] = -std::log(Random::instance().random_double()) / weights_begin[i];
});
double x = select_k_smallest<double>(k, keys.begin(), keys.end(), std::less<double>());
double x = quickselect_k_smallest<double>(k, keys.begin(), keys.end(), std::less<double>());

StaticArray<size_t> selected(k);
tbb::concurrent_vector<double> selected;
size_t back = 0;
tbb::parallel_for(0ul, keys.size(), [&](auto i) {
if (keys[i] <= x) {
__atomic_fetch_add(&back, 1, __ATOMIC_RELAXED);
selected[back] = i;
selected.push_back(i);
}
});
return selected;
return StaticArray<size_t>(selected.begin(), selected.end());
}

} // namespace kaminpar::shm::sparsification::utils

0 comments on commit 6dcb500

Please sign in to comment.