Skip to content

Commit

Permalink
Support multiple alphas for segmented quantile. (#8758)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Feb 7, 2023
1 parent c4802bf commit 48cefa0
Show file tree
Hide file tree
Showing 2 changed files with 299 additions and 108 deletions.
271 changes: 193 additions & 78 deletions src/common/stats.cuh
Original file line number Diff line number Diff line change
@@ -1,99 +1,220 @@
/*!
* Copyright 2022 by XGBoost Contributors
/**
* Copyright 2022-2023 by XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_STATS_CUH_
#define XGBOOST_COMMON_STATS_CUH_

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/permutation_iterator.h>
#include <thrust/binary_search.h> // lower_bound
#include <thrust/for_each.h> // for_each_n
#include <thrust/iterator/constant_iterator.h> // make_constant_iterator
#include <thrust/iterator/counting_iterator.h> // make_counting_iterator
#include <thrust/iterator/permutation_iterator.h> // make_permutation_iterator
#include <thrust/scan.h> // inclusive_scan_by_key

#include <iterator> // std::distance
#include <algorithm> // std::min
#include <cstddef> // std::size_t
#include <iterator> // std::distance
#include <limits> // std::numeric_limits
#include <type_traits> // std::is_floating_point,std::iterator_traits

#include "cuda_context.cuh" // CUDAContext
#include "device_helpers.cuh"
#include "linalg_op.cuh"
#include "xgboost/context.h"
#include "xgboost/linalg.h"
#include "xgboost/tree_model.h"
#include "xgboost/context.h" // Context
#include "xgboost/span.h" // Span

namespace xgboost {
namespace common {
namespace detail {
/**
 * @brief Functor that computes the unweighted quantile of a single segment.
 *
 * This should be a lambda function, but for some reason gcc-11 + nvcc-11.8 failed to
 * compile it. As a result, a functor is extracted instead.
 *
 *   error: ‘__T288’ was not declared in this scope
 *
 * @tparam SegIt   Iterator for CSR style segments indptr
 * @tparam ValIt   Iterator for values, must yield a floating point type. The index
 *                 arithmetic below treats the values of each segment as ordered;
 *                 SegmentedQuantile in this file passes a permutation iterator built
 *                 from the arg-sorted indices.
 * @tparam AlphaIt Iterator to alphas, indexed by segment, must yield a floating point
 *                 type.
 */
template <typename SegIt, typename ValIt, typename AlphaIt>
struct QuantileSegmentOp {
  SegIt seg_begin;
  ValIt val;
  AlphaIt alpha_it;
  Span<float> d_results;

  static_assert(std::is_floating_point<typename std::iterator_traits<ValIt>::value_type>::value,
                "Invalid value for quantile.");
  // Validate the alpha iterator itself; this assertion previously re-checked ValIt.
  static_assert(std::is_floating_point<typename std::iterator_traits<AlphaIt>::value_type>::value,
                "Invalid alpha.");

  /**
   * @brief Write the quantile of segment `seg_idx` into `d_results[seg_idx]`.
   *
   * Reads `seg_begin[seg_idx]` and `seg_begin[seg_idx + 1]` to find the segment bounds.
   */
  XGBOOST_DEVICE void operator()(std::size_t seg_idx) {
    std::size_t begin = seg_begin[seg_idx];
    // The segment length is kept as double since it only takes part in floating point
    // arithmetic below.
    auto n = static_cast<double>(seg_begin[seg_idx + 1] - begin);
    double a = alpha_it[seg_idx];

    // An empty segment has no quantile.
    if (n == 0) {
      d_results[seg_idx] = std::numeric_limits<float>::quiet_NaN();
      return;
    }

    // Alpha at or below 1 / (n + 1): use the first value of the segment.
    if (a <= (1 / (n + 1))) {
      d_results[seg_idx] = val[begin];
      return;
    }
    // Alpha at or above n / (n + 1): use the value at the index given by common::LastOf
    // (presumably the last element of the segment -- confirm against its definition).
    if (a >= (n / (n + 1))) {
      d_results[seg_idx] = val[common::LastOf(seg_idx, seg_begin)];
      return;
    }

    // Linear interpolation between two neighbouring values: x is the 1-based fractional
    // position, k the 0-based index of the lower neighbour and d the fractional part.
    double x = a * static_cast<double>(n + 1);
    double k = std::floor(x) - 1;
    double d = (x - 1) - k;

    auto v0 = val[begin + static_cast<std::size_t>(k)];
    auto v1 = val[begin + static_cast<std::size_t>(k) + 1];

    d_results[seg_idx] = v0 + d * (v1 - v0);
  }
};

/**
 * @brief Build a QuantileSegmentOp with its template arguments deduced from the inputs.
 *
 * @param seg_it    Iterator for CSR style segments indptr
 * @param val_it    Iterator for values
 * @param alpha_it  Iterator to alphas
 * @param d_results Output span the functor writes to
 */
template <typename SegIt, typename ValIt, typename AlphaIt>
auto MakeQSegOp(SegIt seg_it, ValIt val_it, AlphaIt alpha_it, Span<float> d_results) {
  using Op = QuantileSegmentOp<SegIt, ValIt, AlphaIt>;
  return Op{seg_it, val_it, alpha_it, d_results};
}

/**
 * @brief Functor that forwards an index to dh::SegmentId together with the stored
 *        segment range and returns the result.
 *
 * @tparam SegIt Iterator for CSR style segments indptr
 */
template <typename SegIt>
struct SegOp {
  SegIt seg_beg;
  SegIt seg_end;

  XGBOOST_DEVICE std::size_t operator()(std::size_t elem_idx) {
    auto const seg_id = dh::SegmentId(seg_beg, seg_end, elem_idx);
    return seg_id;
  }
};

/**
 * @brief Functor that reads a weight through an index indirection: position `i` is
 *        first mapped through `d_sorted_idx`, then used to index `w_begin`.
 *
 * @tparam WIter Iterator for weights
 */
template <typename WIter>
struct WeightOp {
  WIter w_begin;
  Span<std::size_t const> d_sorted_idx;

  XGBOOST_DEVICE float operator()(std::size_t pos) {
    auto const src_idx = d_sorted_idx[pos];
    return w_begin[src_idx];
  }
};

/**
 * @brief Functor that computes the weighted quantile of a single segment.
 *
 * For segment `seg_idx` it looks up the first position in the segment's slice of
 * `d_weight_cdf` that is not less than `alpha * (last cdf value of the segment)` and
 * writes the value found through `d_sorted_idx` at that position to
 * `d_results[seg_idx]`.
 *
 * @tparam SegIt   Iterator for CSR style segments indptr
 * @tparam ValIt   Iterator for values, must yield a floating point type
 * @tparam AlphaIt Iterator to alphas, indexed by segment, must yield a floating point
 *                 type
 */
template <typename SegIt, typename ValIt, typename AlphaIt>
struct WeightedQuantileSegOp {
  AlphaIt alpha_it;
  SegIt seg_beg;
  ValIt val_begin;
  Span<float const> d_weight_cdf;
  Span<std::size_t const> d_sorted_idx;
  Span<float> d_results;

  static_assert(std::is_floating_point<typename std::iterator_traits<AlphaIt>::value_type>::value,
                "Invalid alpha.");
  static_assert(std::is_floating_point<typename std::iterator_traits<ValIt>::value_type>::value,
                "Invalid value for quantile.");

  XGBOOST_DEVICE void operator()(std::size_t seg_idx) {
    std::size_t const seg_start = seg_beg[seg_idx];
    auto const n = static_cast<double>(seg_beg[seg_idx + 1] - seg_start);
    // An empty segment has no quantile.
    if (n == 0) {
      d_results[seg_idx] = std::numeric_limits<float>::quiet_NaN();
      return;
    }

    // Slices of the cdf and of the sorted index that belong to this segment.
    auto const cdf = d_weight_cdf.subspan(seg_start, static_cast<std::size_t>(n));
    auto const sorted = d_sorted_idx.subspan(seg_start, static_cast<std::size_t>(n));

    // Scale alpha by the last cdf value of the segment to obtain the search threshold.
    double const alpha = alpha_it[seg_idx];
    double const threshold = cdf.back() * alpha;

    auto const* cdf_first = cdf.data();
    auto const* cdf_last = cdf_first + cdf.size();
    std::size_t pos = thrust::lower_bound(thrust::seq, cdf_first, cdf_last, threshold) - cdf_first;
    // lower_bound may return the end of the range; clamp to the last valid position.
    pos = std::min(pos, static_cast<std::size_t>(n - 1));

    d_results[seg_idx] = val_begin[sorted[pos]];
  }
};

/**
 * @brief Build a WeightedQuantileSegOp with its template arguments deduced from the
 *        inputs.
 *
 * @param seg_it       Iterator for CSR style segments indptr
 * @param val_it       Iterator for values
 * @param alpha_it     Iterator to alphas
 * @param d_weight_cdf Span of cumulative weights read by the functor
 * @param d_sorted_idx Span of sorted indices read by the functor
 * @param d_results    Output span the functor writes to
 */
template <typename SegIt, typename ValIt, typename AlphaIt>
auto MakeWQSegOp(SegIt seg_it, ValIt val_it, AlphaIt alpha_it, Span<float const> d_weight_cdf,
                 Span<std::size_t const> d_sorted_idx, Span<float> d_results) {
  using Op = WeightedQuantileSegOp<SegIt, ValIt, AlphaIt>;
  // The initializer order follows the member order of the functor: alpha first.
  return Op{alpha_it, seg_it, val_it, d_weight_cdf, d_sorted_idx, d_results};
}
} // namespace detail
/**
 * @brief Compute segmented quantile on GPU.
 *
 * @tparam SegIt   Iterator for CSR style segments indptr
 * @tparam ValIt   Iterator for values
 * @tparam AlphaIt Iterator to alphas
 *
 * @param ctx       Context providing the device ordinal and the CUDA stream used for
 *                  the launch.
 * @param alpha_it  Iterator to the p^th quantile we want to compute, one for each segment.
 * @param seg_begin Beginning of the segments indptr.
 * @param seg_end   End of the segments indptr.
 * @param val_begin Beginning of the values.
 * @param val_end   End of the values.
 * @param quantiles Output, resized to n_segments. Left untouched when there is no
 *                  segment.
 *
 *    std::distance(seg_begin, seg_end) should be equal to n_segments + 1
 */
template <typename SegIt, typename ValIt, typename AlphaIt,
          std::enable_if_t<!std::is_floating_point<AlphaIt>::value>* = nullptr>
void SegmentedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_begin, SegIt seg_end,
                       ValIt val_begin, ValIt val_end, HostDeviceVector<float>* quantiles) {
  dh::device_vector<std::size_t> sorted_idx;
  dh::SegmentedArgSort(seg_begin, seg_end, val_begin, val_end, &sorted_idx);
  auto n_segments = std::distance(seg_begin, seg_end) - 1;
  if (n_segments <= 0) {
    return;
  }

  // View the values through the arg-sorted indices.
  auto d_sorted_idx = dh::ToSpan(sorted_idx);
  auto val = thrust::make_permutation_iterator(val_begin, dh::tcbegin(d_sorted_idx));

  quantiles->SetDevice(ctx->gpu_id);
  quantiles->Resize(n_segments);
  auto d_results = quantiles->DeviceSpan();

  // One functor invocation per segment.
  dh::LaunchN(n_segments, ctx->CUDACtx()->Stream(),
              detail::MakeQSegOp(seg_begin, val, alpha_it, d_results));
}
/**
 * @brief Compute segmented quantile on GPU, using the same alpha for every segment.
 *
 * Checks that alpha lies in [0, 1], then forwards to the overload that accepts an
 * iterator to alphas, passing a constant iterator over `alpha`.
 *
 * @tparam SegIt Iterator for CSR style segments indptr
 * @tparam ValIt Iterator for values
 *
 * @param alpha The p^th quantile we want to compute
 *
 *    std::distance(seg_begin, seg_end) should be equal to n_segments + 1
 */
template <typename SegIt, typename ValIt>
void SegmentedQuantile(Context const* ctx, double alpha, SegIt seg_begin, SegIt seg_end,
                       ValIt val_begin, ValIt val_end, HostDeviceVector<float>* quantiles) {
  CHECK(alpha >= 0 && alpha <= 1);
  SegmentedQuantile(ctx, thrust::make_constant_iterator(alpha), seg_begin, seg_end, val_begin,
                    val_end, quantiles);
}

template <typename SegIt, typename ValIt, typename WIter>
void SegmentedWeightedQuantile(Context const* ctx, double alpha, SegIt seg_beg, SegIt seg_end,
/**
* @brief Compute segmented quantile on GPU with weighted inputs.
*
* @tparam SegIt Iterator for CSR style segments indptr
* @tparam ValIt Iterator for values
* @tparam WIter Iterator for weights
*
* @param alpha_it Iterator for the p^th quantile we want to compute, one per-segment
* @param w_begin Iterator for weight for each input element
*/
template <typename SegIt, typename ValIt, typename AlphaIt, typename WIter,
typename std::enable_if_t<!std::is_same<
typename std::iterator_traits<AlphaIt>::value_type, void>::value>* = nullptr>
void SegmentedWeightedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_beg, SegIt seg_end,
ValIt val_begin, ValIt val_end, WIter w_begin, WIter w_end,
HostDeviceVector<float>* quantiles) {
CHECK(alpha >= 0 && alpha <= 1);
dh::device_vector<size_t> sorted_idx;
auto cuctx = ctx->CUDACtx();
dh::device_vector<std::size_t> sorted_idx;
dh::SegmentedArgSort(seg_beg, seg_end, val_begin, val_end, &sorted_idx);
auto d_sorted_idx = dh::ToSpan(sorted_idx);
size_t n_weights = std::distance(w_begin, w_end);
std::size_t n_weights = std::distance(w_begin, w_end);
dh::device_vector<float> weights_cdf(n_weights);
std::size_t n_elems = std::distance(val_begin, val_end);
CHECK_EQ(n_weights, n_elems);

dh::XGBCachingDeviceAllocator<char> caching;
auto scan_key = dh::MakeTransformIterator<size_t>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(seg_beg, seg_end, i); });
auto scan_val = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(size_t i) { return w_begin[d_sorted_idx[i]]; });
auto scan_key = dh::MakeTransformIterator<std::size_t>(thrust::make_counting_iterator(0ul),
detail::SegOp<SegIt>{seg_beg, seg_end});
auto scan_val = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
detail::WeightOp<WIter>{w_begin, d_sorted_idx});
thrust::inclusive_scan_by_key(thrust::cuda::par(caching), scan_key, scan_key + n_weights,
scan_val, weights_cdf.begin());

Expand All @@ -103,24 +224,18 @@ void SegmentedWeightedQuantile(Context const* ctx, double alpha, SegIt seg_beg,
auto d_results = quantiles->DeviceSpan();
auto d_weight_cdf = dh::ToSpan(weights_cdf);

dh::LaunchN(n_segments, [=] XGBOOST_DEVICE(size_t i) {
size_t seg_idx = i;
size_t begin = seg_beg[seg_idx];
auto n = static_cast<double>(seg_beg[seg_idx + 1] - begin);
if (n == 0) {
d_results[i] = std::numeric_limits<float>::quiet_NaN();
return;
}
auto leaf_cdf = d_weight_cdf.subspan(begin, static_cast<size_t>(n));
auto leaf_sorted_idx = d_sorted_idx.subspan(begin, static_cast<size_t>(n));
float thresh = leaf_cdf.back() * alpha;

size_t idx = thrust::lower_bound(thrust::seq, leaf_cdf.data(),
leaf_cdf.data() + leaf_cdf.size(), thresh) -
leaf_cdf.data();
idx = std::min(idx, static_cast<size_t>(n - 1));
d_results[i] = val_begin[leaf_sorted_idx[idx]];
});
thrust::for_each_n(
cuctx->CTP(), thrust::make_counting_iterator(0ul), n_segments,
detail::MakeWQSegOp(seg_beg, val_begin, alpha_it, d_weight_cdf, d_sorted_idx, d_results));
}

/**
 * @brief Compute segmented quantile on GPU with weighted inputs, using the same alpha
 *        for every segment.
 *
 * Checks that alpha lies in [0, 1], then forwards to the overload that accepts an
 * iterator to alphas, passing a constant iterator over `alpha`.
 *
 * @tparam SegIt Iterator for CSR style segments indptr
 * @tparam ValIt Iterator for values
 * @tparam WIter Iterator for weights
 *
 * @param alpha   The p^th quantile we want to compute
 * @param w_begin Iterator for weight for each input element
 */
template <typename SegIt, typename ValIt, typename WIter>
void SegmentedWeightedQuantile(Context const* ctx, double alpha, SegIt seg_beg, SegIt seg_end,
                               ValIt val_begin, ValIt val_end, WIter w_begin, WIter w_end,
                               HostDeviceVector<float>* quantiles) {
  CHECK(alpha >= 0 && alpha <= 1);
  auto alpha_it = thrust::make_constant_iterator(alpha);
  SegmentedWeightedQuantile(ctx, alpha_it, seg_beg, seg_end, val_begin, val_end, w_begin, w_end,
                            quantiles);
}
} // namespace common
} // namespace xgboost
Expand Down
Loading

0 comments on commit 48cefa0

Please sign in to comment.