diff --git a/src/operator/tensor/ordering_op-inl.h b/src/operator/tensor/ordering_op-inl.h index 105ee8b90db8..16e6c0ecd3fc 100644 --- a/src/operator/tensor/ordering_op-inl.h +++ b/src/operator/tensor/ordering_op-inl.h @@ -170,11 +170,13 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, // Use full sort when K is relatively large. const bool full_sort(K*8 > N); // Batch size. - const int M(dat.size(0)/N); + const int M(work.size(0)/(sizeof(real_t)*N)); const int omp_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()); #pragma omp parallel for num_threads(omp_threads) for (int i = 0; i < M; ++i) { - real_t *vals = dat.dptr_; + // Tensor `work` stores the flattened source data, while `dat` stores the sorted result. + real_t *vals = reinterpret_cast(work.dptr_); + real_t *sorted_vals = dat.dptr_+i*N; int *indices = ind.dptr_+i*N; if (is_ascend) { if (full_sort) { @@ -193,11 +195,9 @@ MSHADOW_FORCE_INLINE void TopKSort(const Tensor& dat, [&](const int& i1, const int& i2){ return vals[i1] > vals[i2]; }); } } - real_t *buff = reinterpret_cast(work.dptr_)+i*K; for (int j = 0; j < K; ++j) { - buff[j] = vals[indices[j]]; + sorted_vals[j] = vals[indices[j]]; } - std::copy(buff, buff+K, &vals[i*N]); } } @@ -380,16 +380,7 @@ void TopKImpl(RunContext ctx, indices = Tensor(reinterpret_cast(workspace_curr_ptr), Shape1(src.Size()), s); // indices in the original matrix workspace_curr_ptr += sizeof(int) * src.Size(); - if (do_transpose) { - sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size())); - } else { - sorted_dat = reshape(dat, Shape1(src.Size())); - } - mxnet_op::Kernel::Launch(s, batch_size * element_num, 1, 0, 1, - kWriteTo, indices.dptr_); - CHECK_EQ(sorted_dat.CheckContiguous(), true); - CHECK_EQ(indices.CheckContiguous(), true); if (param.ret_typ == topk_enum::kReturnMask) { sel_indices = Tensor(reinterpret_cast(workspace_curr_ptr), Shape1(batch_size * k), s); @@ -401,15 +392,47 @@ void TopKImpl(RunContext ctx, CHECK_EQ(sel_indices.CheckContiguous(), true); CHECK_EQ(mask_val.CheckContiguous(), true); } - temp_workspace = Tensor(workspace_curr_ptr, Shape1(temp_size), s); // temp space - workspace_curr_ptr += temp_size; + + if (std::is_same::value) { + Tensor flattened_data; + if (do_transpose) { + flattened_data = Tensor(reinterpret_cast(workspace_curr_ptr), + Shape1(src.Size()), s); + workspace_curr_ptr += sizeof(real_t) * src.Size(); + flattened_data = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size())); + CHECK_EQ(flattened_data.CheckContiguous(), true); + } else { + flattened_data = src.FlatTo1D(s); + } + // `temp_workspace` stores the flattened data + temp_workspace = Tensor(reinterpret_cast(flattened_data.dptr_), + Shape1(sizeof(real_t)*src.Size()), s); + CHECK_EQ(temp_workspace.CheckContiguous(), true); + } else { + if (do_transpose) { + sorted_dat = reshape(transpose(dat, Shape3(0, 2, 1)), Shape1(src.Size())); + } else { + sorted_dat = reshape(dat, Shape1(src.Size())); + } + CHECK_EQ(sorted_dat.CheckContiguous(), true); + temp_workspace = Tensor(workspace_curr_ptr, Shape1(temp_size), s); // temp space + workspace_curr_ptr += temp_size; + } + + mxnet_op::Kernel::Launch(s, batch_size * element_num, 1, 0, 1, + kWriteTo, indices.dptr_); + CHECK_EQ(indices.CheckContiguous(), true); // 2. Perform inplace batch sort. // After sorting, each batch in `sorted_dat` will be sorted in the corresponding order // up to the k-th element and the `indices` will contain the corresponding index in `sorted_dat` + // `temp_workspace` is used to store the flattend source data for CPU device, and it's used as + // a temporal buffer for GPU device. TopKSort(sorted_dat, indices, temp_workspace, k, element_num, is_ascend, s); // 3. Assign results to the ret blob + // When returning indices, only update(modulo) required elements instead of full elements + // to avoid redundant calculation. if (param.ret_typ == topk_enum::kReturnMask) { Tensor ret_mask = ret[0].get_with_shape(Shape2(ret[0].Size(), 1), s); @@ -427,7 +450,6 @@ void TopKImpl(RunContext ctx, } IndexFill(ret_mask, sel_indices, mask_val); } else if (param.ret_typ == topk_enum::kReturnIndices) { - indices = F(indices, element_num); if (do_transpose) { Tensor ret_indices = ret[0].FlatTo3D(axis, axis, s); ret_indices = tcast(transpose( @@ -437,14 +459,15 @@ void TopKImpl(RunContext ctx, element_num)), 0, k), Shape3(0, 2, 1))); + ret_indices = F(ret_indices, element_num); } else { Tensor ret_indices = ret[0].get_with_shape(Shape2(batch_size, k), s); ret_indices = tcast(slice<1>( inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)); + ret_indices = F(ret_indices, element_num); } } else { - indices = F(indices, element_num); if (do_transpose) { Tensor ret_value = ret[0].FlatTo3D(axis, axis, s); Tensor ret_indices = ret[1].FlatTo3D(axis, axis, s); @@ -460,6 +483,7 @@ void TopKImpl(RunContext ctx, element_num)), 0, k), Shape3(0, 2, 1))); + ret_indices = F(ret_indices, element_num); } else { Tensor ret_value = ret[0].get_with_shape(Shape2(batch_size, k), s); @@ -468,6 +492,7 @@ void TopKImpl(RunContext ctx, ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k); ret_indices = tcast(slice<1>( inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k)); + ret_indices = F(ret_indices, element_num); } } }