From f7b22b6e45dedd60ae80344714d3178cff88d23d Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 31 Aug 2018 22:41:55 +0800 Subject: [PATCH] Remove support for shared memory, split up Eval. --- src/common/common.cu | 3 + src/common/math.h | 2 +- src/common/transform.h | 280 +++++++++-------------- src/objective/hinge.cu | 12 +- src/objective/multiclass_obj.cu | 39 ++-- src/objective/regression_obj_gpu.cu | 51 +++-- tests/cpp/common/test_transform_range.cc | 4 +- 7 files changed, 159 insertions(+), 232 deletions(-) diff --git a/src/common/common.cu b/src/common/common.cu index 53740265280c..d0818390d0b6 100644 --- a/src/common/common.cu +++ b/src/common/common.cu @@ -1,3 +1,6 @@ +/*! + * Copyright 2018 XGBoost contributors + */ #include "common.h" namespace xgboost { diff --git a/src/common/math.h b/src/common/math.h index c15de0a5f4db..569dcefff70b 100644 --- a/src/common/math.h +++ b/src/common/math.h @@ -69,7 +69,7 @@ XGBOOST_DEVICE inline void Softmax(InIter beg_in, InIter end_in, float wsum = 0; InIter i = beg_in; OutIter o = beg_out; - for (;i != end_in; ++i, ++o) { + for (; i != end_in; ++i, ++o) { *o = expf(*i); wsum += *o; } diff --git a/src/common/transform.h b/src/common/transform.h index 6b67740a4c1b..27262dd8d1c4 100644 --- a/src/common/transform.h +++ b/src/common/transform.h @@ -7,7 +7,7 @@ #include #include #include -#include +#include // enable_if #include "host_device_vector.h" #include "common.h" @@ -85,52 +85,16 @@ namespace common { constexpr size_t kBlockThreads = 256; -template -struct SharedMem { - size_t size_; -}; - namespace detail { -#if defined (__CUDACC__) - -template -struct KernelSharedMem { - size_t size_; - __device__ Span GetSpan() { - extern __shared__ __align__(sizeof(T)) T* mem[]; - return Span{(T*)mem, - static_cast::index_type>(size_)}; - } -}; - -template -__device__ Range SegGridStrideRange(T _end_it, int64_t _segment) { - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - idx *= _segment; - if (idx >= *_end_it) { - return Range {0, 0}; - } - return Range{idx, *_end_it, static_cast(gridDim.x * blockDim.x)}; -} - +#if defined(__CUDACC__) template __global__ void LaunchCUDAKernel(Functor _func, Range _range, Span... _spans) { - for (auto i : SegGridStrideRange(_range.end(), 1)) { + for (auto i : dh::GridStrideRange(*_range.begin(), *_range.end())) { _func(i, _spans...); } } -template -__global__ void LaunchCUDAKernel(Functor _func, - Range _range, SharedMem _shared, - Span... _spans) { - KernelSharedMem shared {_shared.size_}; - Span shared_span = shared.GetSpan(); - for (auto i : SegGridStrideRange(_range.end(), 1)) { - _func(i, shared_span, _spans...); - } -} #endif } // namespace detail @@ -149,163 +113,123 @@ __global__ void LaunchCUDAKernel(Functor _func, * will merge function with same signature. */ template -struct TransformN { - template - TransformN(Functor _func, Range _range, GPUSet _devices, - HDV... _vectors) { - bool on_device = _devices != GPUSet::Empty(); - - Reshard(_devices, _vectors...); - - if (on_device) { - LaunchCUDA(_func, _range, _devices, _vectors...); - } else { - LaunchCPU(_func, _range, _vectors...); +class Transform { + private: + template + struct Evaluator { + public: + Evaluator(Functor _func, Range _range, GPUSet _devices) : + func_{_func}, range_{_range}, devices_{_devices} {} + + template + void Eval(HDV... 
_vectors) { + bool on_device = devices_ != GPUSet::Empty(); + + Reshard(devices_, _vectors...); + + if (on_device) { + LaunchCUDA(func_, range_, devices_, _vectors...); + } else { + LaunchCPU(func_, range_, _vectors...); + } } - } - template - TransformN(Functor _func, - Range _range, GPUSet _devices, SharedMem _shared, - HDV... _vectors) { - bool on_device = _devices != GPUSet::Empty(); - Reshard(_devices, _vectors...); + private: + template + Span UnpackHDV(HostDeviceVector* _vec, int _device) { + return _vec->DeviceSpan(_device); + } + template + Span UnpackHDV(const HostDeviceVector* _vec, int _device) { + return _vec->ConstDeviceSpan(_device); + } - if (on_device) { - LaunchCUDA(_func, _range, _devices, _shared, _vectors...); - } else { - LaunchCPU(_func, _range, _shared, _vectors...); + template + Span UnpackHDV(HostDeviceVector* _vec) { + return Span {_vec->HostPointer(), + static_cast::index_type>(_vec->Size())}; + } + template + Span UnpackHDV(const HostDeviceVector* _vec) { + return Span {_vec->ConstHostPointer(), + static_cast::index_type>(_vec->Size())}; } - } - private: - template - void Reshard(GPUSet _devices, HDV*... _vectors) { - std::vector vectors {_vectors...}; + template + void Reshard(GPUSet _devices, HDV*... _vectors) { + std::vector vectors {_vectors...}; #pragma omp parallel for schedule(static, 1) if (vectors.size() > 1) - for (omp_ulong i = 0; i < vectors.size(); ++i) { // NOLINT - switch (vectors[i].GetType()) { - case HDVAnyPtr::Type::kBstFloatType: - vectors[i].GetFloat()->Reshard(_devices); - break; - case HDVAnyPtr::Type::kGradientPairType: - vectors[i].GetGradientPair()->Reshard(_devices); - break; - case HDVAnyPtr::Type::kIntType: - vectors[i].GetInt()->Reshard(_devices); - break; - case HDVAnyPtr::Type::kEntryType: - vectors[i].GetEntry()->Reshard(_devices); - break; - case HDVAnyPtr::Type::kSizeType: - vectors[i].GetSizeT()->Reshard(_devices); - break; - default: - LOG(FATAL) << "Unknown HostDeviceVector type: " << (int) vectors[i].GetType(); + for (omp_ulong i = 0; i < vectors.size(); ++i) { // NOLINT + switch (vectors[i].GetType()) { + case HDVAnyPtr::Type::kBstFloatType: + vectors[i].GetFloat()->Reshard(_devices); + break; + case HDVAnyPtr::Type::kGradientPairType: + vectors[i].GetGradientPair()->Reshard(_devices); + break; + case HDVAnyPtr::Type::kIntType: + vectors[i].GetInt()->Reshard(_devices); + break; + case HDVAnyPtr::Type::kEntryType: + vectors[i].GetEntry()->Reshard(_devices); + break; + case HDVAnyPtr::Type::kSizeType: + vectors[i].GetSizeT()->Reshard(_devices); + break; + default: + LOG(FATAL) << "Unknown HostDeviceVector type: " << + static_cast(vectors[i].GetType()); + } } } - } - template - Span UnpackHDV(HostDeviceVector* _vec, int _device) { - return _vec->DeviceSpan(_device); - } - template - Span UnpackHDV(const HostDeviceVector* _vec, int _device) { - return _vec->ConstDeviceSpan(_device); - } - - template - Span UnpackHDV(HostDeviceVector* _vec) { - return Span {_vec->HostPointer(), - static_cast::index_type>(_vec->Size())}; - } - template - Span UnpackHDV(const HostDeviceVector* _vec) { - return Span {_vec->ConstHostPointer(), - static_cast::index_type>(_vec->Size())}; - } - #if defined(__CUDACC__) - template ::type* = nullptr, - typename... HDV> - void LaunchCUDA(Functor _func, Range _range, GPUSet _devices, - HDV*... 
_vectors) { -#pragma omp parallel for schedule(static, 1) if (_devices.Size() > 1) - for (omp_ulong i = 0; i < _devices.Size(); ++i) { - int d = _devices.Index(i); - dh::safe_cuda(cudaSetDevice(d)); - const int GRID_SIZE = - static_cast(dh::DivRoundUp(*(_range.end()), kBlockThreads)); - - detail::LaunchCUDAKernel<<>>( - _func, _range, UnpackHDV(_vectors, d)...); - dh::safe_cuda(cudaGetLastError()); - dh::safe_cuda(cudaDeviceSynchronize()); - } - } - template ::type* = nullptr, - typename S, - typename... HDV> - void LaunchCUDA(Functor _func, - Range _range, GPUSet _devices, SharedMem _shared, - HDV*... _vectors) { + template ::type* = nullptr, + typename... HDV> + void LaunchCUDA(Functor _func, Range _range, GPUSet _devices, + HDV*... _vectors) { #pragma omp parallel for schedule(static, 1) if (_devices.Size() > 1) - for (omp_ulong i = 0; i < _devices.Size(); ++i) { - int d = _devices.Index(i); - dh::safe_cuda(cudaSetDevice(d)); - const int GRID_SIZE = - static_cast(dh::DivRoundUp(*(_range.end()), kBlockThreads)); - detail::LaunchCUDAKernel - <<>> - (_func, _range, _shared, UnpackHDV(_vectors, d)...); - dh::safe_cuda(cudaGetLastError()); - dh::safe_cuda(cudaDeviceSynchronize()); + for (omp_ulong i = 0; i < _devices.Size(); ++i) { + int d = _devices.Index(i); + dh::safe_cuda(cudaSetDevice(d)); + const int GRID_SIZE = + static_cast(dh::DivRoundUp(*(_range.end()), kBlockThreads)); + + detail::LaunchCUDAKernel<<>>( + _func, _range, UnpackHDV(_vectors, d)...); + dh::safe_cuda(cudaGetLastError()); + dh::safe_cuda(cudaDeviceSynchronize()); + } } - } #else - template ::type* = nullptr, - typename... HDV> - void LaunchCUDA(Functor _func, Range _range, GPUSet _devices, - HDV*... _vectors) { - LOG(FATAL) << "Not part of device code. WITH_CUDA: " << WITH_CUDA(); - } - template ::type* = nullptr, - typename S, - typename... HDV> - void LaunchCUDA(Functor _func, - Range _range, GPUSet _devices, size_t _shared_size, - HDV*... _vectors) { - LOG(FATAL) << "Not part of device code. WITH_CUDA: " << WITH_CUDA(); - } + template ::type* = nullptr, + typename... HDV> + void LaunchCUDA(Functor _func, Range _range, GPUSet _devices, + HDV*... _vectors) { + LOG(FATAL) << "Not part of device code. WITH_CUDA: " << WITH_CUDA(); + } #endif - template - void LaunchCPU(Functor _func, common::Range _range, HDV*... vectors) { - auto end = *(_range.end()); + // CPU implementations + template + void LaunchCPU(Functor _func, common::Range _range, HDV*... vectors) { + auto end = *(_range.end()); #pragma omp parallel for schedule(static, 1) - for (omp_ulong idx = 0; idx < end; ++idx) { - _func(idx, UnpackHDV(vectors)...); - } - } - - template - void LaunchCPU(Functor _func, - common::Range _range, SharedMem _shared, HDV*... 
vectors) { - auto end = *(_range.end()); -#pragma omp parallel - { - std::vector shared_mem (_shared.size_); - Span shared_span {shared_mem.data(), - static_cast::index_type>(shared_mem.size())}; -#pragma omp for schedule(static, 1) for (omp_ulong idx = 0; idx < end; ++idx) { - _func(idx, shared_span, UnpackHDV(vectors)...); + _func(idx, UnpackHDV(vectors)...); } } + + private: + Functor func_; + Range range_; + GPUSet devices_; + }; + + public: + template + static Evaluator Init(Functor _func, + Range _range, GPUSet _devices) { + return Evaluator {_func, _range, _devices}; } }; diff --git a/src/objective/hinge.cu b/src/objective/hinge.cu index a0fc102b82c8..2855239efb77 100644 --- a/src/objective/hinge.cu +++ b/src/objective/hinge.cu @@ -57,7 +57,7 @@ class HingeObj : public ObjFunction { const bool is_null_weight = info.weights_.Size() == 0; const size_t ndata = preds.Size(); out_gpair->Resize(ndata); - common::TransformN<>( + common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, common::Span _label_correct, common::Span _out_gpair, @@ -77,17 +77,17 @@ class HingeObj : public ObjFunction { } _out_gpair[_idx] = GradientPair(g, h); }, - common::Range{0, static_cast(ndata)}, devices_, - &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + common::Range{0, static_cast(ndata)}, devices_).Eval( + &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); } void PredTransform(HostDeviceVector *io_preds) override { - common::TransformN<>( + common::Transform<>::Init( [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { _preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0; }, - common::Range{0, static_cast(io_preds->Size()), 1}, - devices_, io_preds); + common::Range{0, static_cast(io_preds->Size()), 1}, devices_) + .Eval(io_preds); } const char* DefaultEvalMetric() const override { diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 6edfd190f624..d0b99885593b 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -64,14 +64,14 @@ class SoftmaxMultiClassObj : public ObjFunction { preds_cache_.Resize(preds.Size()); const bool is_null_weight = info.weights_.Size() == 0; - common::TransformN<>( - [=] XGBOOST_DEVICE (size_t idx, - common::Span preds_cache, - common::Span gpair, - common::Span labels, - common::Span preds, - common::Span weights, - common::Span _label_correct) { + common::Transform<>::Init( + [=] XGBOOST_DEVICE(size_t idx, + common::Span preds_cache, + common::Span gpair, + common::Span labels, + common::Span preds, + common::Span weights, + common::Span _label_correct) { common::Span point = preds.subspan(idx * nclass, nclass); common::Span point_cache = preds_cache.subspan(idx * nclass, nclass); @@ -96,10 +96,9 @@ class SoftmaxMultiClassObj : public ObjFunction { gpair[idx * nclass + k] = GradientPair(p * wt, h); } } - }, - common::Range{0, ndata}, devices_, - &preds_cache_, out_gpair, &info.labels_, &preds, &info.weights_, - &label_correct_); + }, common::Range{0, ndata}, devices_) + .Eval(&preds_cache_, out_gpair, &info.labels_, &preds, &info.weights_, + &label_correct_); std::vector& label_correct_h = label_correct_.HostVector(); for (auto const flag : label_correct_h) { @@ -125,24 +124,24 @@ class SoftmaxMultiClassObj : public ObjFunction { const auto ndata = static_cast(io_preds->Size() / nclass); max_preds_.Resize(ndata); if (prob) { - common::TransformN<>( - [=] XGBOOST_DEVICE (size_t _idx, common::Span _preds) { + common::Transform<>::Init( + [=] XGBOOST_DEVICE(size_t _idx, common::Span 
_preds) { common::Span point = _preds.subspan(_idx * nclass, nclass); common::Softmax(point.begin(), point.end()); }, - common::Range{0, ndata}, devices_, io_preds); + common::Range{0, ndata}, devices_).Eval(io_preds); } else { - common::TransformN<>( - [=] XGBOOST_DEVICE (size_t _idx, - common::Span _preds, - common::Span _max_preds) { + common::Transform<>::Init( + [=] XGBOOST_DEVICE(size_t _idx, + common::Span _preds, + common::Span _max_preds) { common::Span point = _preds.subspan(_idx * nclass, nclass); _max_preds[_idx] = common::FindMaxIndex(point.cbegin(), point.cend()) - point.cbegin(); }, - common::Range{0, ndata}, devices_, io_preds, &max_preds_); + common::Range{0, ndata}, devices_).Eval(io_preds, &max_preds_); } if (!prob) { io_preds->Resize(max_preds_.Size()); diff --git a/src/objective/regression_obj_gpu.cu b/src/objective/regression_obj_gpu.cu index d719456efdf7..efd6d8497047 100644 --- a/src/objective/regression_obj_gpu.cu +++ b/src/objective/regression_obj_gpu.cu @@ -74,7 +74,7 @@ class RegLossObj : public ObjFunction { bool is_null_weight = info.weights_.Size() == 0; auto scale_pos_weight = param_.scale_pos_weight; - common::TransformN<>( + common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, common::Span _label_correct, common::Span _out_gpair, @@ -94,8 +94,8 @@ class RegLossObj : public ObjFunction { _out_gpair[_idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w, Loss::SecondOrderGradient(p, label) * w); }, - common::Range{0, static_cast(ndata)}, devices_, - &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + common::Range{0, static_cast(ndata)}, devices_).Eval( + &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); // copy "label correct" flags back to host std::vector& label_correct_h = label_correct_.HostVector(); @@ -112,11 +112,11 @@ class RegLossObj : public ObjFunction { } void PredTransform(HostDeviceVector *io_preds) override { - common::TransformN<>( + common::Transform<>::Init( [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { _preds[_idx] = Loss::PredTransform(_preds[_idx]); }, common::Range{0, static_cast(io_preds->Size())}, - devices_, io_preds); + devices_).Eval(io_preds); } float ProbToMargin(float base_score) const override { @@ -189,7 +189,7 @@ class PoissonRegression : public ObjFunction { bool is_null_weight = info.weights_.Size() == 0; bst_float max_delta_step = param_.max_delta_step; - common::TransformN<>( + common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, common::Span _label_correct, common::Span _out_gpair, @@ -205,8 +205,8 @@ class PoissonRegression : public ObjFunction { _out_gpair[_idx] = GradientPair{(expf(p) - y) * w, expf(p + max_delta_step) * w}; }, - common::Range{0, static_cast(ndata)}, devices_, - &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + common::Range{0, static_cast(ndata)}, devices_).Eval( + &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); // copy "label correct" flags back to host std::vector& label_correct_h = label_correct_.HostVector(); for (auto const flag : label_correct_h) { @@ -216,12 +216,12 @@ class PoissonRegression : public ObjFunction { } } void PredTransform(HostDeviceVector *io_preds) override { - common::TransformN<>( + common::Transform<>::Init( [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { _preds[_idx] = expf(_preds[_idx]); }, - common::Range{0, static_cast(io_preds->Size())}, - devices_, io_preds); + common::Range{0, static_cast(io_preds->Size())}, devices_) + .Eval(io_preds); } void 
EvalTransform(HostDeviceVector *io_preds) override { PredTransform(io_preds); @@ -371,7 +371,7 @@ class GammaRegression : public ObjFunction { label_correct_.Fill(1); const bool is_null_weight = info.weights_.Size() == 0; - common::TransformN<>( + common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, common::Span _label_correct, common::Span _out_gpair, @@ -386,8 +386,8 @@ class GammaRegression : public ObjFunction { } _out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w); }, - common::Range{0, static_cast(ndata)}, devices_, - &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + common::Range{0, static_cast(ndata)}, devices_).Eval( + &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); // copy "label correct" flags back to host std::vector& label_correct_h = label_correct_.HostVector(); @@ -398,12 +398,12 @@ class GammaRegression : public ObjFunction { } } void PredTransform(HostDeviceVector *io_preds) override { - common::TransformN<>( + common::Transform<>::Init( [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { _preds[_idx] = expf(_preds[_idx]); }, - common::Range{0, static_cast(io_preds->Size())}, - devices_, io_preds); + common::Range{0, static_cast(io_preds->Size())}, devices_) + .Eval(io_preds); } void EvalTransform(HostDeviceVector *io_preds) override { PredTransform(io_preds); @@ -470,7 +470,7 @@ class TweedieRegression : public ObjFunction { const bool is_null_weight = info.weights_.Size() == 0; const float rho = param_.tweedie_variance_power; - common::TransformN<>( + common::Transform<>::Init( [=] XGBOOST_DEVICE(size_t _idx, common::Span _label_correct, common::Span _out_gpair, @@ -484,12 +484,13 @@ class TweedieRegression : public ObjFunction { _label_correct[0] = 0; } bst_float grad = -y * expf((1 - rho) * p) + expf((2 - rho) * p); - bst_float hess = -y * (1 - rho) * \ - std::exp((1 - rho) * p) + (2 - rho) * expf((2 - rho) * p); + bst_float hess = + -y * (1 - rho) * \ + std::exp((1 - rho) * p) + (2 - rho) * expf((2 - rho) * p); _out_gpair[_idx] = GradientPair(grad * w, hess * w); }, - common::Range{0, static_cast(ndata), 1}, devices_, - &label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); + common::Range{0, static_cast(ndata), 1}, devices_) + .Eval(&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_); // copy "label correct" flags back to host std::vector& label_correct_h = label_correct_.HostVector(); @@ -500,12 +501,12 @@ class TweedieRegression : public ObjFunction { } } void PredTransform(HostDeviceVector *io_preds) override { - common::TransformN<>( + common::Transform<>::Init( [] XGBOOST_DEVICE(size_t _idx, common::Span _preds) { _preds[_idx] = expf(_preds[_idx]); }, - common::Range{0, static_cast(io_preds->Size())}, - devices_, io_preds); + common::Range{0, static_cast(io_preds->Size())}, devices_) + .Eval(io_preds); } bst_float ProbToMargin(bst_float base_score) const override { diff --git a/tests/cpp/common/test_transform_range.cc b/tests/cpp/common/test_transform_range.cc index 367053af7e54..44ea1ec406bb 100644 --- a/tests/cpp/common/test_transform_range.cc +++ b/tests/cpp/common/test_transform_range.cc @@ -49,8 +49,8 @@ TEST(Transform, DeclareUnifiedTest(N)) { HostDeviceVector out_vec{h_out, TRANSFORM_GPU_DIST}; out_vec.Fill(0); - TransformN<>(TestTransformRange{}, Range{0, 16}, TRANSFORM_GPU_RANGE, - &out_vec, &in_vec); + Transform<>::Init(TestTransformRange{}, Range{0, 16}, TRANSFORM_GPU_RANGE) + .Eval(&out_vec, &in_vec); std::vector res = out_vec.HostVector(); 
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin())); }
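
Below is a minimal standalone sketch of the Init/Eval split introduced by this patch, added for illustration only. The Transform::Init(...).Eval(...) calling convention follows the call sites above (e.g. HingeObj::PredTransform), but everything else is a stand-in: plain std::vector instead of HostDeviceVector, no GPUSet or CUDA dispatch, and a hinge-style lambda chosen only as an example. It is not the actual xgboost API, just the shape of the pattern.

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Stand-in for the pattern introduced by this patch: a static Init() that
    // captures the functor and range, returning an Evaluator whose variadic
    // Eval() receives the data vectors and performs the launch.
    template <typename Functor>
    class Evaluator {
     public:
      Evaluator(Functor func, std::size_t begin, std::size_t end)
          : func_{func}, begin_{begin}, end_{end} {}

      // Variadic Eval: each argument plays the role of a HostDeviceVector*.
      template <typename... Vec>
      void Eval(Vec*... vectors) const {
        for (std::size_t idx = begin_; idx < end_; ++idx) {
          func_(idx, (*vectors)...);  // CPU path only in this sketch
        }
      }

     private:
      Functor func_;
      std::size_t begin_, end_;
    };

    struct Transform {
      template <typename Functor>
      static Evaluator<Functor> Init(Functor func,
                                     std::size_t begin, std::size_t end) {
        return Evaluator<Functor>{func, begin, end};
      }
    };

    int main() {
      std::vector<float> preds{-1.5f, 0.5f, 2.0f};
      // Mirrors the refactored call sites: build the evaluator first,
      // then hand it the vectors via Eval().
      Transform::Init(
          [](std::size_t idx, std::vector<float>& p) {
            p[idx] = p[idx] > 0.0f ? 1.0f : 0.0f;
          },
          0, preds.size())
          .Eval(&preds);
      for (float v : preds) std::cout << v << " ";  // prints: 0 1 1
      std::cout << "\n";
    }

As in the patch, construction (Init) captures the functor, range, and device set, while the vectors appear only as variadic arguments at the Eval call, so objective code reads as Transform<>::Init(functor, range, devices).Eval(&vec1, &vec2, ...).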