Skip to content

Commit

Permalink
Support linalg data structures in check device. (#9243)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jun 6, 2023
1 parent fc8110e commit 0cba2cd
Showing 1 changed file with 30 additions and 20 deletions.
50 changes: 30 additions & 20 deletions src/data/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
#include <dmlc/registry.h>

#include <array>
#include <cstddef>
#include <cstring>

#include "../collective/communicator-inl.h"
#include "../collective/communicator.h"
#include "../common/common.h"
#include "../common/algorithm.h" // for StableSort
#include "../common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "../common/error_msg.h" // for InfInData
#include "../common/common.h"
#include "../common/error_msg.h" // for InfInData, GroupWeight, GroupSize
#include "../common/group_data.h"
#include "../common/io.h"
#include "../common/linalg_op.h"
Expand All @@ -35,6 +36,7 @@
#include "xgboost/context.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/learner.h"
#include "xgboost/linalg.h" // Vector
#include "xgboost/logging.h"
#include "xgboost/string_view.h"
#include "xgboost/version_config.h"
Expand Down Expand Up @@ -491,7 +493,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
}
// uint info
if (key == "group") {
linalg::Tensor<bst_group_t, 1> t;
linalg::Vector<bst_group_t> t;
CopyTensorInfoImpl(ctx, arr, &t);
auto const& h_groups = t.Data()->HostVector();
group_ptr_.clear();
Expand All @@ -516,6 +518,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
data::ValidateQueryGroup(group_ptr_);
return;
}

// float info
linalg::Tensor<float, 1> t;
CopyTensorInfoImpl<1>(ctx, arr, &t);
Expand Down Expand Up @@ -717,58 +720,63 @@ void MetaInfo::SynchronizeNumberOfColumns() {
}
}

namespace {
template <typename T>
void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
<< "Data is resided on a different device than `gpu_id`. "
<< "Device that data is on: " << v.DeviceIdx() << ", "
<< "`gpu_id` for XGBoost: " << device;
}
template <typename T, std::int32_t D>
void CheckDevice(std::int32_t device, linalg::Tensor<T, D> const& v) {
CheckDevice(device, *v.Data());
}
} // anonymous namespace

void MetaInfo::Validate(std::int32_t device) const {
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
<< "Size of weights must equal to number of groups when ranking "
"group is used.";
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) << error::GroupWeight();
return;
}
if (group_ptr_.size() != 0) {
CHECK_EQ(group_ptr_.back(), num_row_)
<< "Invalid group structure. Number of rows obtained from groups "
"doesn't equal to actual number of rows given by data.";
<< error::GroupSize() << "the actual number of rows given by data.";
}
auto check_device = [device](HostDeviceVector<float> const& v) {
CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
<< "Data is resided on a different device than `gpu_id`. "
<< "Device that data is on: " << v.DeviceIdx() << ", "
<< "`gpu_id` for XGBoost: " << device;
};

if (weights_.Size() != 0) {
CHECK_EQ(weights_.Size(), num_row_)
<< "Size of weights must equal to number of rows.";
check_device(weights_);
CheckDevice(device, weights_);
return;
}
if (labels.Size() != 0) {
CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows.";
check_device(*labels.Data());
CheckDevice(device, labels);
return;
}
if (labels_lower_bound_.Size() != 0) {
CHECK_EQ(labels_lower_bound_.Size(), num_row_)
<< "Size of label_lower_bound must equal to number of rows.";
check_device(labels_lower_bound_);
CheckDevice(device, labels_lower_bound_);
return;
}
if (feature_weights.Size() != 0) {
CHECK_EQ(feature_weights.Size(), num_col_)
<< "Size of feature_weights must equal to number of columns.";
check_device(feature_weights);
CheckDevice(device, feature_weights);
}
if (labels_upper_bound_.Size() != 0) {
CHECK_EQ(labels_upper_bound_.Size(), num_row_)
<< "Size of label_upper_bound must equal to number of rows.";
check_device(labels_upper_bound_);
CheckDevice(device, labels_upper_bound_);
return;
}
CHECK_LE(num_nonzero_, num_col_ * num_row_);
if (base_margin_.Size() != 0) {
CHECK_EQ(base_margin_.Size() % num_row_, 0)
<< "Size of base margin must be a multiple of number of rows.";
check_device(*base_margin_.Data());
CheckDevice(device, base_margin_);
}
}

Expand Down Expand Up @@ -1028,6 +1036,8 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
bool SparsePage::IsIndicesSorted(int32_t n_threads) const {
auto& h_offset = this->offset.HostVector();
auto& h_data = this->data.HostVector();
n_threads = std::max(std::min(static_cast<std::size_t>(n_threads), this->Size()),
static_cast<std::size_t>(1));
std::vector<int32_t> is_sorted_tloc(n_threads, 0);
common::ParallelFor(this->Size(), n_threads, [&](auto i) {
auto beg = h_offset[i];
Expand Down

0 comments on commit 0cba2cd

Please sign in to comment.