Skip to content

Commit

Permalink
Use cache in cox.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 10, 2023
1 parent 4802d58 commit e1cec37
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 39 deletions.
15 changes: 0 additions & 15 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,21 +130,6 @@ class MetaInfo {
inline bst_float GetWeight(size_t i) const {
  // No weights stored: every instance gets the default weight of 1.
  if (weights_.Size() == 0) {
    return 1.0f;
  }
  // Otherwise read the i-th entry from the host copy of the weights.
  return weights_.HostVector()[i];
}
/*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
inline const std::vector<size_t>& LabelAbsSort() const {
if (label_order_cache_.size() == labels.Size()) {
return label_order_cache_;
}
label_order_cache_.resize(labels.Size());
std::iota(label_order_cache_.begin(), label_order_cache_.end(), 0);
const auto& l = labels.Data()->HostVector();
XGBOOST_PARALLEL_STABLE_SORT(label_order_cache_.begin(), label_order_cache_.end(),
[&l](size_t i1, size_t i2) {return std::abs(l[i1]) < std::abs(l[i2]);});

return label_order_cache_;
}
/*! \brief clear all the information */
void Clear();
/*!
* \brief Load the Meta info from binary stream.
* \param fi The input stream
Expand Down
9 changes: 0 additions & 9 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,15 +178,6 @@ namespace xgboost {

uint64_t constexpr MetaInfo::kNumField;

// implementation of inline functions
void MetaInfo::Clear() {
  // Zero the shape counters.
  num_nonzero_ = 0;
  num_col_ = 0;
  num_row_ = 0;
  // Replace the label and base-margin containers with freshly constructed ones.
  labels = decltype(labels){};
  base_margin_ = decltype(base_margin_){};
  // Empty the group boundaries and the host-side weights.
  group_ptr_.clear();
  weights_.HostVector().clear();
}

/*
* Binary serialization format for MetaInfo:
*
Expand Down
31 changes: 19 additions & 12 deletions src/metric/rank_metric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,30 +310,37 @@ struct EvalMAP : public EvalRank {
};

/*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
struct EvalCox : public MetricNoCache {
class EvalCox : public Metric {
DMatrixCache<std::vector<std::size_t>> cache_{
DMatrixCache<std::vector<std::size_t>>::DefaultSize()};

public:
EvalCox() = default;
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
auto const& info = p_fmat->Info();
CHECK(!collective::IsDistributed()) << "Cox metric does not support distributed evaluation";
using namespace std; // NOLINT(*)

const auto ndata = static_cast<bst_omp_uint>(info.labels.Size());
const auto &label_order = info.LabelAbsSort();

auto p_sorted_idx = cache_.CacheItem(p_fmat);
auto& sorted_idx = *p_sorted_idx;
auto labels = info.labels.HostView();

if (sorted_idx.empty()) {
sorted_idx =
common::ArgSort<std::size_t>(ctx_, linalg::cbegin(labels), linalg::cend(labels),
[&](float l, float r) { return std::abs(l) < std::abs(r); });
}
// pre-compute a sum for the denominator
double exp_p_sum = 0; // we use double because we might need the precision with large datasets

const auto &h_preds = preds.ConstHostVector();
for (omp_ulong i = 0; i < ndata; ++i) {
for (std::size_t i = 0; i < info.num_row_; ++i) {
exp_p_sum += h_preds[i];
}

double out = 0;
double accumulated_sum = 0;
bst_omp_uint num_events = 0;
const auto& labels = info.labels.HostView();
for (bst_omp_uint i = 0; i < ndata; ++i) {
const size_t ind = label_order[i];
for (bst_omp_uint i = 0; i < info.num_row_; ++i) {
const size_t ind = sorted_idx[i];
const auto label = labels(ind);
if (label > 0) {
out -= log(h_preds[ind]) - log(exp_p_sum);
Expand All @@ -342,7 +349,7 @@ struct EvalCox : public MetricNoCache {

// only update the denominator after we move forward in time (labels are sorted)
accumulated_sum += h_preds[ind];
if (i == ndata - 1 || std::abs(label) < std::abs(labels(label_order[i + 1]))) {
if (i == info.num_row_ - 1 || std::abs(label) < std::abs(labels(sorted_idx[i + 1]))) {
exp_p_sum -= accumulated_sum;
accumulated_sum = 0;
}
Expand Down
3 changes: 0 additions & 3 deletions tests/cpp/data/test_metainfo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ TEST(MetaInfo, GetSet) {
info.SetInfo(ctx, "group", uint64_t2, xgboost::DataType::kUInt64, 2);
ASSERT_EQ(info.group_ptr_.size(), 3);
EXPECT_EQ(info.group_ptr_[2], 3);

info.Clear();
ASSERT_EQ(info.group_ptr_.size(), 0);
}

TEST(MetaInfo, GetSetFeature) {
Expand Down

0 comments on commit e1cec37

Please sign in to comment.