Remove support for shared memory, split up Eval.
trivialfis committed Aug 31, 2018
1 parent 09607bc commit f7b22b6
Showing 7 changed files with 159 additions and 232 deletions.
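The visible change at the call sites is that the old TransformN constructor, which did the work immediately, is replaced by Transform<>::Init(...) returning an Evaluator whose Eval(...) takes the HostDeviceVector pointers (see the hinge.cu hunks below). The following is a minimal, self-contained miniature of that call shape only, using plain std::vector in place of xgboost's HostDeviceVector, Range, and GPUSet; it is a sketch, not the real common::Transform, which also reshards data and dispatches to CUDA.

// Toy sketch only: mimics the Init(...).Eval(...) call shape introduced by
// this commit; the real class shards work across devices.
#include <cstddef>
#include <iostream>
#include <vector>

template <typename Functor>
class Evaluator {
 public:
  explicit Evaluator(Functor func) : func_{func} {}

  // Eval is now a separate step instead of work done inside a constructor.
  template <typename Head, typename... Tail>
  void Eval(Head* head, Tail*... tail) const {
    for (std::size_t i = 0; i < head->size(); ++i) {
      func_(i, head, tail...);  // all vectors are assumed to be equally long
    }
  }

 private:
  Functor func_;
};

struct Transform {
  template <typename Functor>
  static Evaluator<Functor> Init(Functor func) {
    return Evaluator<Functor>{func};
  }
};

int main() {
  std::vector<float> preds{-1.0f, 2.0f, 0.5f};
  // Old style (removed): TransformN<>(functor, range, devices, &preds);
  // New style (added):   Transform<>::Init(functor, range, devices).Eval(&preds);
  Transform::Init([](std::size_t i, std::vector<float>* p) {
    (*p)[i] = (*p)[i] > 0.0f ? 1.0f : 0.0f;  // same thresholding as PredTransform in hinge.cu
  }).Eval(&preds);
  for (float v : preds) std::cout << v << " ";  // prints: 0 1 1
  std::cout << "\n";
  return 0;
}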
3 changes: 3 additions & 0 deletions src/common/common.cu
@@ -1,3 +1,6 @@
/*!
* Copyright 2018 XGBoost contributors
*/
#include "common.h"

namespace xgboost {
2 changes: 1 addition & 1 deletion src/common/math.h
@@ -69,7 +69,7 @@ XGBOOST_DEVICE inline void Softmax(InIter beg_in, InIter end_in,
float wsum = 0;
InIter i = beg_in;
OutIter o = beg_out;
for (;i != end_in; ++i, ++o) {
for (; i != end_in; ++i, ++o) {
*o = expf(*i);
wsum += *o;
}
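The hunk above only adds a space in Softmax's accumulation loop. For context, here is a standalone sketch of the same exponentiate-then-normalize pattern; the real function in src/common/math.h is iterator-based and is not reproduced here.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Exponentiate each input, accumulate the sum, then divide through — the
// same pattern as the accumulation loop shown above.
std::vector<float> Softmax(const std::vector<float>& in) {
  std::vector<float> out(in.size());
  float wsum = 0.0f;
  for (std::size_t i = 0; i < in.size(); ++i) {
    out[i] = std::exp(in[i]);
    wsum += out[i];
  }
  for (float& v : out) {
    v /= wsum;
  }
  return out;
}

int main() {
  for (float v : Softmax({1.0f, 2.0f, 3.0f})) {
    std::printf("%.3f ", v);  // prints 0.090 0.245 0.665
  }
  std::printf("\n");
  return 0;
}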
280 changes: 102 additions & 178 deletions src/common/transform.h
@@ -7,7 +7,7 @@
#include <dmlc/omp.h>
#include <xgboost/data.h>
#include <vector>
#include <type_traits>
#include <type_traits> // enable_if

#include "host_device_vector.h"
#include "common.h"
@@ -85,52 +85,16 @@ namespace common {

constexpr size_t kBlockThreads = 256;

template <typename T>
struct SharedMem {
size_t size_;
};

namespace detail {

#if defined (__CUDACC__)

template <typename T>
struct KernelSharedMem {
size_t size_;
__device__ Span<T> GetSpan() {
extern __shared__ __align__(sizeof(T)) T* mem[];
return Span<bst_float>{(T*)mem,
static_cast<typename Span<T>::index_type>(size_)};
}
};

template <typename T>
__device__ Range SegGridStrideRange(T _end_it, int64_t _segment) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
idx *= _segment;
if (idx >= *_end_it) {
return Range {0, 0};
}
return Range{idx, *_end_it, static_cast<int64_t>(gridDim.x * blockDim.x)};
}

#if defined(__CUDACC__)
template <typename Functor, typename... T>
__global__ void LaunchCUDAKernel(Functor _func, Range _range,
Span<T>... _spans) {
for (auto i : SegGridStrideRange(_range.end(), 1)) {
for (auto i : dh::GridStrideRange(*_range.begin(), *_range.end())) {
_func(i, _spans...);
}
}
template <typename Functor, typename S, typename... T>
__global__ void LaunchCUDAKernel(Functor _func,
Range _range, SharedMem<S> _shared,
Span<T>... _spans) {
KernelSharedMem<S> shared {_shared.size_};
Span<S> shared_span = shared.GetSpan();
for (auto i : SegGridStrideRange(_range.end(), 1)) {
_func(i, shared_span, _spans...);
}
}
#endif

} // namespace detail
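The hunk above drops the custom SegGridStrideRange helper in favour of dh::GridStrideRange inside LaunchCUDAKernel. The underlying pattern is the standard CUDA grid-stride loop; below is a minimal standalone version (compile with nvcc). The Scale kernel, buffer size, and factor are made up for illustration and are not part of the commit.

// sketch.cu — minimal grid-stride kernel; build with: nvcc sketch.cu
#include <cstdio>
#include <cuda_runtime.h>

// Each thread starts at its global index and strides by the total thread
// count, so any grid size covers any n. dh::GridStrideRange wraps the same
// begin/stride pair behind a range-for, which LaunchCUDAKernel iterates over.
__global__ void Scale(float* data, int n, float factor) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    data[i] *= factor;
  }
}

int main() {
  const int n = 1 << 20;
  float* data;
  cudaMallocManaged(&data, n * sizeof(float));
  for (int i = 0; i < n; ++i) data[i] = 1.0f;

  const int kBlockThreads = 256;  // same block size constant as transform.h
  const int grid = (n + kBlockThreads - 1) / kBlockThreads;  // DivRoundUp
  Scale<<<grid, kBlockThreads>>>(data, n, 2.0f);
  cudaDeviceSynchronize();

  std::printf("data[0] = %.1f, data[n-1] = %.1f\n", data[0], data[n - 1]);
  cudaFree(data);
  return 0;
}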
@@ -149,163 +113,123 @@ __global__ void LaunchCUDAKernel(Functor _func,
* will merge function with same signature.
*/
template <bool CompiledWithCuda = WITH_CUDA()>
struct TransformN {
template <typename Functor, typename... HDV>
TransformN(Functor _func, Range _range, GPUSet _devices,
HDV... _vectors) {
bool on_device = _devices != GPUSet::Empty();

Reshard(_devices, _vectors...);

if (on_device) {
LaunchCUDA(_func, _range, _devices, _vectors...);
} else {
LaunchCPU(_func, _range, _vectors...);
class Transform {
private:
template <typename Functor>
struct Evaluator {
public:
Evaluator(Functor _func, Range _range, GPUSet _devices) :
func_{_func}, range_{_range}, devices_{_devices} {}

template <typename... HDV>
void Eval(HDV... _vectors) {
bool on_device = devices_ != GPUSet::Empty();

Reshard(devices_, _vectors...);

if (on_device) {
LaunchCUDA(func_, range_, devices_, _vectors...);
} else {
LaunchCPU(func_, range_, _vectors...);
}
}
}
template <typename Functor, typename S, typename... HDV>
TransformN(Functor _func,
Range _range, GPUSet _devices, SharedMem<S> _shared,
HDV... _vectors) {
bool on_device = _devices != GPUSet::Empty();

Reshard(_devices, _vectors...);
private:
template <typename T>
Span<T> UnpackHDV(HostDeviceVector<T>* _vec, int _device) {
return _vec->DeviceSpan(_device);
}
template <typename T>
Span<T const> UnpackHDV(const HostDeviceVector<T>* _vec, int _device) {
return _vec->ConstDeviceSpan(_device);
}

if (on_device) {
LaunchCUDA(_func, _range, _devices, _shared, _vectors...);
} else {
LaunchCPU(_func, _range, _shared, _vectors...);
template <typename T>
Span<T> UnpackHDV(HostDeviceVector<T>* _vec) {
return Span<T> {_vec->HostPointer(),
static_cast<typename Span<T>::index_type>(_vec->Size())};
}
template <typename T>
Span<T const> UnpackHDV(const HostDeviceVector<T>* _vec) {
return Span<T const> {_vec->ConstHostPointer(),
static_cast<typename Span<T>::index_type>(_vec->Size())};
}
}

private:
template <typename... HDV>
void Reshard(GPUSet _devices, HDV*... _vectors) {
std::vector<HDVAnyPtr> vectors {_vectors...};
template <typename... HDV>
void Reshard(GPUSet _devices, HDV*... _vectors) {
std::vector<HDVAnyPtr> vectors {_vectors...};
#pragma omp parallel for schedule(static, 1) if (vectors.size() > 1)
for (omp_ulong i = 0; i < vectors.size(); ++i) { // NOLINT
switch (vectors[i].GetType()) {
case HDVAnyPtr::Type::kBstFloatType:
vectors[i].GetFloat()->Reshard(_devices);
break;
case HDVAnyPtr::Type::kGradientPairType:
vectors[i].GetGradientPair()->Reshard(_devices);
break;
case HDVAnyPtr::Type::kIntType:
vectors[i].GetInt()->Reshard(_devices);
break;
case HDVAnyPtr::Type::kEntryType:
vectors[i].GetEntry()->Reshard(_devices);
break;
case HDVAnyPtr::Type::kSizeType:
vectors[i].GetSizeT()->Reshard(_devices);
break;
default:
LOG(FATAL) << "Unknown HostDeviceVector type: " << (int) vectors[i].GetType();
for (omp_ulong i = 0; i < vectors.size(); ++i) { // NOLINT
switch (vectors[i].GetType()) {
case HDVAnyPtr::Type::kBstFloatType:
vectors[i].GetFloat()->Reshard(_devices);
break;
case HDVAnyPtr::Type::kGradientPairType:
vectors[i].GetGradientPair()->Reshard(_devices);
break;
case HDVAnyPtr::Type::kIntType:
vectors[i].GetInt()->Reshard(_devices);
break;
case HDVAnyPtr::Type::kEntryType:
vectors[i].GetEntry()->Reshard(_devices);
break;
case HDVAnyPtr::Type::kSizeType:
vectors[i].GetSizeT()->Reshard(_devices);
break;
default:
LOG(FATAL) << "Unknown HostDeviceVector type: " <<
static_cast<int>(vectors[i].GetType());
}
}
}
}
template <typename T>
Span<T> UnpackHDV(HostDeviceVector<T>* _vec, int _device) {
return _vec->DeviceSpan(_device);
}
template <typename T>
Span<T const> UnpackHDV(const HostDeviceVector<T>* _vec, int _device) {
return _vec->ConstDeviceSpan(_device);
}

template <typename T>
Span<T> UnpackHDV(HostDeviceVector<T>* _vec) {
return Span<T> {_vec->HostPointer(),
static_cast<typename Span<T>::index_type>(_vec->Size())};
}
template <typename T>
Span<T const> UnpackHDV(const HostDeviceVector<T>* _vec) {
return Span<T const> {_vec->ConstHostPointer(),
static_cast<typename Span<T>::index_type>(_vec->Size())};
}

#if defined(__CUDACC__)
template <typename Functor,
typename std::enable_if<CompiledWithCuda>::type* = nullptr,
typename... HDV>
void LaunchCUDA(Functor _func, Range _range, GPUSet _devices,
HDV*... _vectors) {
#pragma omp parallel for schedule(static, 1) if (_devices.Size() > 1)
for (omp_ulong i = 0; i < _devices.Size(); ++i) {
int d = _devices.Index(i);
dh::safe_cuda(cudaSetDevice(d));
const int GRID_SIZE =
static_cast<int>(dh::DivRoundUp(*(_range.end()), kBlockThreads));

detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
_func, _range, UnpackHDV(_vectors, d)...);
dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());
}
}
template <typename Functor,
typename std::enable_if<CompiledWithCuda>::type* = nullptr,
typename S,
typename... HDV>
void LaunchCUDA(Functor _func,
Range _range, GPUSet _devices, SharedMem<S> _shared,
HDV*... _vectors) {
template <typename std::enable_if<CompiledWithCuda>::type* = nullptr,
typename... HDV>
void LaunchCUDA(Functor _func, Range _range, GPUSet _devices,
HDV*... _vectors) {
#pragma omp parallel for schedule(static, 1) if (_devices.Size() > 1)
for (omp_ulong i = 0; i < _devices.Size(); ++i) {
int d = _devices.Index(i);
dh::safe_cuda(cudaSetDevice(d));
const int GRID_SIZE =
static_cast<int>(dh::DivRoundUp(*(_range.end()), kBlockThreads));
detail::LaunchCUDAKernel
<<<GRID_SIZE, kBlockThreads, sizeof(S) * _shared.size_>>>
(_func, _range, _shared, UnpackHDV(_vectors, d)...);
dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());
for (omp_ulong i = 0; i < _devices.Size(); ++i) {
int d = _devices.Index(i);
dh::safe_cuda(cudaSetDevice(d));
const int GRID_SIZE =
static_cast<int>(dh::DivRoundUp(*(_range.end()), kBlockThreads));

detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
_func, _range, UnpackHDV(_vectors, d)...);
dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());
}
}
}
#else
template <typename Functor,
typename std::enable_if<!CompiledWithCuda>::type* = nullptr,
typename... HDV>
void LaunchCUDA(Functor _func, Range _range, GPUSet _devices,
HDV*... _vectors) {
LOG(FATAL) << "Not part of device code. WITH_CUDA: " << WITH_CUDA();
}
template <typename Functor,
typename std::enable_if<!CompiledWithCuda>::type* = nullptr,
typename S,
typename... HDV>
void LaunchCUDA(Functor _func,
Range _range, GPUSet _devices, size_t _shared_size,
HDV*... _vectors) {
LOG(FATAL) << "Not part of device code. WITH_CUDA: " << WITH_CUDA();
}
template <typename std::enable_if<!CompiledWithCuda>::type* = nullptr,
typename... HDV>
void LaunchCUDA(Functor _func, Range _range, GPUSet _devices,
HDV*... _vectors) {
LOG(FATAL) << "Not part of device code. WITH_CUDA: " << WITH_CUDA();
}
#endif

template <typename Functor, typename... HDV>
void LaunchCPU(Functor _func, common::Range _range, HDV*... vectors) {
auto end = *(_range.end());
// CPU implementations
template <typename... HDV>
void LaunchCPU(Functor _func, common::Range _range, HDV*... vectors) {
auto end = *(_range.end());
#pragma omp parallel for schedule(static, 1)
for (omp_ulong idx = 0; idx < end; ++idx) {
_func(idx, UnpackHDV(vectors)...);
}
}

template <typename Functor, typename S, typename ... HDV>
void LaunchCPU(Functor _func,
common::Range _range, SharedMem<S> _shared, HDV*... vectors) {
auto end = *(_range.end());
#pragma omp parallel
{
std::vector<S> shared_mem (_shared.size_);
Span<S> shared_span {shared_mem.data(),
static_cast<typename Span<S>::index_type>(shared_mem.size())};
#pragma omp for schedule(static, 1)
for (omp_ulong idx = 0; idx < end; ++idx) {
_func(idx, shared_span, UnpackHDV(vectors)...);
_func(idx, UnpackHDV(vectors)...);
}
}

private:
Functor func_;
Range range_;
GPUSet devices_;
};

public:
template <typename Functor>
static Evaluator<Functor> Init(Functor _func,
Range _range, GPUSet _devices) {
return Evaluator<Functor> {_func, _range, _devices};
}
};
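In the new Transform class, LaunchCUDA outside of CUDA translation units collapses to the LOG(FATAL) stub of the same name: the #if defined(__CUDACC__) guard plus the std::enable_if<CompiledWithCuda> template parameter chooses the overload at compile time, with CompiledWithCuda defaulting to WITH_CUDA(). Below is a self-contained sketch of that dispatch with both overloads visible in one file; it uses the usual C = Flag redirection so the selection is plain SFINAE, unlike the header above, which keeps only one overload per translation unit. The Launcher name and message text are illustrative only.

#include <iostream>
#include <type_traits>

// Stand-in for xgboost's WITH_CUDA(); flip to true to exercise the other branch.
constexpr bool kCompiledWithCuda = false;

template <bool CompiledWithCuda = kCompiledWithCuda>
class Launcher {
 public:
  // Chosen when CUDA support was compiled in.
  template <bool C = CompiledWithCuda,
            typename std::enable_if<C>::type* = nullptr>
  void Launch() const { std::cout << "launching CUDA kernel\n"; }

  // Fallback otherwise; mirrors the LOG(FATAL) stub in transform.h.
  template <bool C = CompiledWithCuda,
            typename std::enable_if<!C>::type* = nullptr>
  void Launch() const {
    std::cout << "Not part of device code. WITH_CUDA: 0\n";
  }
};

int main() {
  Launcher<>{}.Launch();  // prints the CPU-only fallback in this configuration
  return 0;
}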

12 changes: 6 additions & 6 deletions src/objective/hinge.cu
@@ -57,7 +57,7 @@ class HingeObj : public ObjFunction {
const bool is_null_weight = info.weights_.Size() == 0;
const size_t ndata = preds.Size();
out_gpair->Resize(ndata);
common::TransformN<>(
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct,
common::Span<GradientPair> _out_gpair,
@@ -77,17 +77,17 @@
}
_out_gpair[_idx] = GradientPair(g, h);
},
common::Range{0, static_cast<int64_t>(ndata)}, devices_,
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
common::Range{0, static_cast<int64_t>(ndata)}, devices_).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
}

void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
common::TransformN<>(
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
},
common::Range{0, static_cast<int64_t>(io_preds->Size()), 1},
devices_, io_preds);
common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, devices_)
.Eval(io_preds);
}

const char* DefaultEvalMetric() const override {