Skip to content

Commit

Permalink
parallel sort.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 3, 2023
1 parent 9d3afba commit a7ade13
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/common/stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,14 @@ float Quantile(double alpha, Iter const& begin, Iter const& end) {

std::vector<size_t> sorted_idx(n);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
[&](size_t l, size_t r) { return *(begin + l) < *(begin + r); });
if (omp_in_parallel()) {
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
} else {
XGBOOST_PARALLEL_STABLE_SORT(
sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
}

auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
static_assert(std::is_same<decltype(val(0)), float>::value, "");
Expand Down Expand Up @@ -76,8 +82,14 @@ float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter weights) {
}
std::vector<size_t> sorted_idx(n);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
[&](size_t l, size_t r) { return *(begin + l) < *(begin + r); });
if (omp_in_parallel()) {
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
} else {
XGBOOST_PARALLEL_STABLE_SORT(
sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
}

auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };

Expand Down

0 comments on commit a7ade13

Please sign in to comment.