This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Accelerate the performance of topk for CPU side #12085

Merged: 2 commits merged on Aug 13, 2018
61 changes: 43 additions & 18 deletions in src/operator/tensor/ordering_op-inl.h
@@ -170,11 +170,13 @@ MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
// Use full sort when K is relatively large.
const bool full_sort(K*8 > N);
// Batch size.
const int M(dat.size(0)/N);
const int M(work.size(0)/(sizeof(real_t)*N));
const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount());
#pragma omp parallel for num_threads(omp_threads)
for (int i = 0; i < M; ++i) {
real_t *vals = dat.dptr_;
// Tensor `work` stores the flattened source data, while `dat` stores the sorted result.
real_t *vals = reinterpret_cast<real_t*>(work.dptr_);
real_t *sorted_vals = dat.dptr_+i*N;
int *indices = ind.dptr_+i*N;
if (is_ascend) {
if (full_sort) {
@@ -193,11 +195,9 @@ MSHADOW_FORCE_INLINE void TopKSort<cpu>(const Tensor<cpu, 1, real_t>& dat,
[&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; });
}
}
real_t *buff = reinterpret_cast<real_t*>(work.dptr_)+i*K;
for (int j = 0; j < K; ++j) {
buff[j] = vals[indices[j]];
sorted_vals[j] = vals[indices[j]];
}
std::copy(buff, buff+K, &vals[i*N]);
}
}
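
For orientation, here is a minimal, self-contained sketch (not part of the patch) of the batched top-k pattern the rewritten CPU TopKSort follows: each OpenMP thread sorts an index array for its own row, choosing std::partial_sort or a full std::sort depending on K, and gathers the selected values directly into the output buffer, which is what removes the old temporary buff and the std::copy back into the source. The function name batch_topk and the raw-pointer interface are hypothetical; the real kernel operates on mshadow tensors.

#include <algorithm>
#include <numeric>

// Hypothetical sketch; compile with -fopenmp to enable the parallel loop.
void batch_topk(const float* vals, float* sorted_vals, int* indices,
                int M, int N, int K, bool is_ascend) {
  // Heuristic mirrored from the patch: use a full sort when K is relatively large.
  const bool full_sort = K * 8 > N;
  #pragma omp parallel for
  for (int i = 0; i < M; ++i) {
    const float* row = vals + i * N;        // flattened source data for this row
    float* out_vals = sorted_vals + i * N;  // sorted values for this row
    int* out_idx = indices + i * N;         // sorted positions within this row
    std::iota(out_idx, out_idx + N, 0);
    auto asc  = [row](int a, int b) { return row[a] < row[b]; };
    auto desc = [row](int a, int b) { return row[a] > row[b]; };
    if (full_sort) {
      if (is_ascend) std::sort(out_idx, out_idx + N, asc);
      else           std::sort(out_idx, out_idx + N, desc);
    } else {
      if (is_ascend) std::partial_sort(out_idx, out_idx + K, out_idx + N, asc);
      else           std::partial_sort(out_idx, out_idx + K, out_idx + N, desc);
    }
    // Gather the top-K values straight into the output row; no temporary
    // buffer and no copy back into the source are needed.
    for (int j = 0; j < K; ++j) {
      out_vals[j] = row[out_idx[j]];
    }
  }
}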

@@ -380,16 +380,7 @@ void TopKImpl(RunContext ctx,
indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
Shape1(src.Size()), s); // indices in the original matrix
workspace_curr_ptr += sizeof(int) * src.Size();
if (do_transpose) {
sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
} else {
sorted_dat = reshape(dat, Shape1(src.Size()));
}
mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, 0, 1,
kWriteTo, indices.dptr_);

CHECK_EQ(sorted_dat.CheckContiguous(), true);
CHECK_EQ(indices.CheckContiguous(), true);
if (param.ret_typ == topk_enum::kReturnMask) {
sel_indices = Tensor<xpu, 1, int>(reinterpret_cast<int*>(workspace_curr_ptr),
Shape1(batch_size * k), s);
@@ -401,15 +392,47 @@
CHECK_EQ(sel_indices.CheckContiguous(), true);
CHECK_EQ(mask_val.CheckContiguous(), true);
}
temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s); // temp space
workspace_curr_ptr += temp_size;

if (std::is_same<xpu, cpu>::value) {
Tensor<xpu, 1, real_t> flattened_data;
if (do_transpose) {
flattened_data = Tensor<xpu, 1, real_t>(reinterpret_cast<real_t*>(workspace_curr_ptr),
Shape1(src.Size()), s);
workspace_curr_ptr += sizeof(real_t) * src.Size();
flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
CHECK_EQ(flattened_data.CheckContiguous(), true);
} else {
flattened_data = src.FlatTo1D<xpu, real_t>(s);
}
// `temp_workspace` stores the flattened data
temp_workspace = Tensor<xpu, 1, char>(reinterpret_cast<char*>(flattened_data.dptr_),
Shape1(sizeof(real_t)*src.Size()), s);
CHECK_EQ(temp_workspace.CheckContiguous(), true);
} else {
if (do_transpose) {
sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size()));
} else {
sorted_dat = reshape(dat, Shape1(src.Size()));
}
CHECK_EQ(sorted_dat.CheckContiguous(), true);
temp_workspace = Tensor<xpu, 1, char>(workspace_curr_ptr, Shape1(temp_size), s); // temp space
workspace_curr_ptr += temp_size;
}

mxnet_op::Kernel<range_fwd, xpu>::Launch(s, batch_size * element_num, 1, 0, 1,
kWriteTo, indices.dptr_);
CHECK_EQ(indices.CheckContiguous(), true);

// 2. Perform inplace batch sort.
// After sorting, each batch in `sorted_dat` will be sorted in the corresponding order
// up to the k-th element and the `indices` will contain the corresponding index in `sorted_dat`
// `temp_workspace` is used to store the flattened source data for the CPU device, and it is used
// as a temporary buffer for the GPU device.
TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s);

// 3. Assign results to the ret blob
// When returning indices, apply the modulo only to the required top-k elements instead of all
// elements, to avoid redundant computation.
if (param.ret_typ == topk_enum::kReturnMask) {
Tensor<xpu, 2, real_t> ret_mask =
ret[0].get_with_shape<xpu, 2, real_t>(Shape2(ret[0].Size(), 1), s);
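
As an aside on the dispatch above: the CPU-only workspace path is guarded by std::is_same<xpu, cpu>::value, which is a compile-time constant per template instantiation, so the unused branch is dead code for the other device type. Below is a hypothetical sketch of that pattern and of the workspace accounting it enables; the function ExtraWorkspaceBytes and the bare cpu/gpu tag structs are made up for illustration, and the real code bumps workspace_curr_ptr rather than returning a size.

#include <cstddef>
#include <type_traits>

struct cpu {};  // stand-in device tags for illustration only
struct gpu {};

// Returns how many extra bytes the top-k workspace needs under the patched scheme.
template <typename xpu>
size_t ExtraWorkspaceBytes(bool do_transpose, size_t src_size, size_t temp_size) {
  if (std::is_same<xpu, cpu>::value) {
    // CPU path: the flattened source doubles as the sort workspace, so extra
    // space is only required when a transposed copy has to be materialized.
    return do_transpose ? sizeof(float) * src_size : 0;
  }
  // GPU path: unchanged; a separate temporary buffer of temp_size bytes is reserved.
  return temp_size;
}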
@@ -427,7 +450,6 @@ void TopKImpl(RunContext ctx,
}
IndexFill(ret_mask, sel_indices, mask_val);
} else if (param.ret_typ == topk_enum::kReturnIndices) {
indices = F<mshadow_op::mod>(indices, element_num);
if (do_transpose) {
Tensor<xpu, 3, real_t> ret_indices = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
ret_indices = tcast<real_t>(transpose(
@@ -437,14 +459,15 @@
element_num)),
0, k),
Shape3(0, 2, 1)));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
} else {
Tensor<xpu, 2, real_t> ret_indices =
ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
ret_indices = tcast<real_t>(slice<1>(
inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
}
} else {
indices = F<mshadow_op::mod>(indices, element_num);
if (do_transpose) {
Tensor<xpu, 3, real_t> ret_value = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
Tensor<xpu, 3, real_t> ret_indices = ret[1].FlatTo3D<xpu, real_t>(axis, axis, s);
@@ -460,6 +483,7 @@
element_num)),
0, k),
Shape3(0, 2, 1)));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
} else {
Tensor<xpu, 2, real_t> ret_value =
ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
@@ -468,6 +492,7 @@
ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k);
ret_indices = tcast<real_t>(slice<1>(
inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
}
}
}
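
Finally, the index post-processing can be illustrated in isolation. indices is initialized by range_fwd as a flat iota over the flattened [batch_size, element_num] data, so taking a value modulo element_num recovers its position within a row; the patch applies that modulo after the top-k slice, so only batch_size*k values are converted instead of batch_size*element_num. A standalone, hypothetical example with made-up data:

#include <cstdio>
#include <vector>

int main() {
  const int batch_size = 2, element_num = 5, k = 2;
  // Flat positions 0..9 as range_fwd would produce them, reordered per row by the
  // sort (the ordering here is invented purely for illustration).
  std::vector<int> sorted_flat_indices = {3, 0, 4, 1, 2,  7, 9, 5, 6, 8};
  for (int row = 0; row < batch_size; ++row) {
    for (int j = 0; j < k; ++j) {  // only the k kept columns are converted
      int flat = sorted_flat_indices[row * element_num + j];
      std::printf("row %d, rank %d -> in-row index %d\n", row, j, flat % element_num);
    }
  }
  return 0;
}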