Skip to content

Commit

Permalink
Don't use shared memory in Softmax.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 31, 2018
1 parent 4d41f52 commit 09607bc
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 16 deletions.
30 changes: 29 additions & 1 deletion src/common/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ inline avx::Float8 Sigmoid(avx::Float8 x) {
}

/*
* \brief do inplace softmax transformaton on start to end
* \brief Do inplace softmax transformaton on start to end
*
* \tparam Iterator Input iterator type
*
Expand All @@ -51,6 +51,34 @@ XGBOOST_DEVICE inline void Softmax(Iterator start, Iterator end) {
*i /= static_cast<float>(wsum);
}
}

/*
 * \brief Softmax transformation on the input range [beg_in, end_in),
 *        writing the result to the output range [beg_out, end_out).
 *
 * The input range is left unmodified.  The maximum input value is
 * subtracted before exponentiation so that expf never overflows for
 * large inputs; this shift does not change the mathematical result.
 *
 * \tparam InIter Input iterator type
 * \tparam OutIter Output iterator type
 *
 * \param beg_in Start iterator of input
 * \param end_in End iterator of input
 * \param beg_out Start iterator of output
 * \param end_out End iterator of output
 */
template <typename InIter, typename OutIter>
XGBOOST_DEVICE inline void Softmax(InIter beg_in, InIter end_in,
                                   OutIter beg_out, OutIter end_out) {
  // Empty input: nothing to write.
  if (beg_in == end_in) { return; }
  // Running maximum for numerical stability.
  float wmax = *beg_in;
  for (InIter i = beg_in; i != end_in; ++i) {
    wmax = fmaxf(wmax, static_cast<float>(*i));
  }
  // Exponentiate the shifted inputs and accumulate the normalizer.
  float wsum = 0.0f;
  OutIter o = beg_out;
  for (InIter i = beg_in; i != end_in; ++i, ++o) {
    *o = expf(static_cast<float>(*i) - wmax);
    wsum += *o;
  }
  // Normalize so the outputs sum to 1.
  for (o = beg_out; o != end_out; ++o) {
    *o /= wsum;
  }
}

/*!
* \brief Find the maximum iterator within the iterators
* \param begin The begining iterator.
Expand Down
30 changes: 15 additions & 15 deletions src/objective/multiclass_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
#include <dmlc/omp.h>
#include <dmlc/parameter.h>
#include <xgboost/data.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <vector>
Expand Down Expand Up @@ -60,25 +61,23 @@ class SoftmaxMultiClassObj : public ObjFunction {
label_correct_.Fill(-1);
const int nclass = param_.num_class;
const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
preds_cache_.Resize(preds.Size());

const bool is_null_weight = info.weights_.Size() == 0;
common::TransformN<>(
[=] XGBOOST_DEVICE (size_t idx,
common::Span<bst_float> points_cache,
common::Span<bst_float> preds_cache,
common::Span<GradientPair> gpair,
common::Span<bst_float const> labels,
common::Span<bst_float const> preds,
common::Span<bst_float const> weights,
common::Span<int> _label_correct) {
common::Span<bst_float const> pred_point =
preds.subspan(idx * nclass, nclass);
size_t offset_in_block = idx * nclass;
for (size_t i = 0; i < nclass; ++i) {
points_cache[offset_in_block + i] = pred_point[i];
}
common::Span<bst_float> current_point =
points_cache.subspan(offset_in_block, nclass);
common::Softmax(current_point.begin(), current_point.end());
common::Span<bst_float const> point = preds.subspan(idx * nclass, nclass);
common::Span<bst_float> point_cache =
preds_cache.subspan(idx * nclass, nclass);

common::Softmax(point.begin(), point.end(),
point_cache.begin(), point_cache.end());

auto label = labels[idx];
if (label < 0 || label >= nclass) {
Expand All @@ -89,9 +88,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
}
bst_float wt = is_null_weight ? 1.0f : weights[idx];
for (int k = 0; k < nclass; ++k) {
bst_float p = current_point[k];
const float eps = 1e-16f;
const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, eps);
bst_float p = point_cache[k];
const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, kRtEps);
if (label == k) {
gpair[idx * nclass + k] = GradientPair((p - 1.0f) * wt, h);
} else {
Expand All @@ -100,8 +98,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
}
},
common::Range{0, ndata}, devices_,
common::SharedMem<bst_float>{nclass * common::kBlockThreads},
out_gpair, &info.labels_, &preds, &info.weights_, &label_correct_);
&preds_cache_, out_gpair, &info.labels_, &preds, &info.weights_,
&label_correct_);

std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
Expand Down Expand Up @@ -158,6 +156,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
// parameter
SoftmaxMultiClassParam param_;
GPUSet devices_;
// Cache for predictions; used in GetGradient so the const input is not modified.
HostDeviceVector<bst_float> preds_cache_;
// Cache for max_preds
HostDeviceVector<bst_float> max_preds_;
HostDeviceVector<int> label_correct_;
Expand Down

0 comments on commit 09607bc

Please sign in to comment.