Skip to content

Commit

Permalink
fix consistent median (#8245)
Browse files Browse the repository at this point in the history
* fix consistent median

* refine

* replace math.prod with functools.reduce

* rm todo

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Shenghang Tsai <jackalcooper@gmail.com>
  • Loading branch information
3 people authored May 20, 2022
1 parent 38b6495 commit 0ed4f46
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 38 deletions.
49 changes: 28 additions & 21 deletions oneflow/user/kernels/median_with_indices_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,32 +32,39 @@ class CpuMedianWithIndicesKernel final : public user_op::OpKernel {
const int64_t size = in->shape().elem_cnt();
if (size == 0) return;
const int64_t stride = in->shape().At(num_axes - 1);
const int64_t instance_num = size / stride;
user_op::Tensor* values = ctx->Tensor4ArgNameAndIndex("values", 0);
user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0);
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
Memcpy<DeviceType::kCPU>(ctx->stream(), tmp_buffer->mut_dptr<void>(), in->dptr<void>(),
size * sizeof(T));
size_t thread_num = Global<ThreadPool>::Get()->thread_num();
BalancedSplitter bs(size / stride, thread_num);
MultiThreadLoop(thread_num, [&](size_t thread_idx) {
size_t end = bs.At(thread_idx).end();
for (size_t i = bs.At(thread_idx).begin(); i < end; ++i) {
T* in_ptr = tmp_buffer->mut_dptr<T>() + i * stride;
T* val_ptr = values->mut_dptr<T>() + i;
int64_t* ind_ptr = indices->mut_dptr<int64_t>() + i;
std::vector<int64_t> idx(stride);
auto first = idx.begin();
auto last = idx.end();
std::iota(first, last, 0);
auto nth = first;
nth += (stride - 1) / 2;
std::nth_element(first, nth, last, [&in_ptr](int64_t i, int64_t j) {
return in_ptr[i] < in_ptr[j] || (in_ptr[i] == in_ptr[j] && i < j);
});
*val_ptr = in_ptr[*nth];
*ind_ptr = *nth;
}
});
const int64_t thread_num =
std::min(instance_num, (int64_t)Global<ThreadPool>::Get()->thread_num());
const BalancedSplitter bs(instance_num, thread_num);
BlockingCounter bc(thread_num);
FOR_RANGE(int64_t, thread_id, 0, thread_num) {
const Range range = bs.At(thread_id);
Global<ThreadPool>::Get()->AddWork([=, &bc]() {
FOR_RANGE(int64_t, i, range.begin(), range.end()) {
T* in_ptr = tmp_buffer->mut_dptr<T>() + i * stride;
T* val_ptr = values->mut_dptr<T>() + i;
int64_t* ind_ptr = indices->mut_dptr<int64_t>() + i;
std::vector<int64_t> idx(stride);
auto first = idx.begin();
auto last = idx.end();
std::iota(first, last, 0);
auto nth = first;
nth += (stride - 1) / 2;
std::nth_element(first, nth, last, [&in_ptr](int64_t i, int64_t j) {
return in_ptr[i] < in_ptr[j] || (in_ptr[i] == in_ptr[j] && i < j);
});
*val_ptr = in_ptr[*nth];
*ind_ptr = *nth;
}
bc.Decrease();
});
}
bc.WaitForeverUntilCntEqualZero();
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
Expand Down
15 changes: 6 additions & 9 deletions oneflow/user/kernels/median_with_indices_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,11 @@ void DispatchIndexSize(ep::Stream* stream, const int64_t elem_cnt, const int64_t
const T* in, const int64_t* sort_indices, T* out, int64_t* out_indices) {
const int64_t reduce_elem_cnt = elem_cnt / stride;
if (IsSafeUseIndex32(elem_cnt)) {
MedianSelectCuda<T, int32_t><<<BlocksNum4ThreadsNum(reduce_elem_cnt), kCudaThreadsNumPerBlock,
0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
reduce_elem_cnt, stride, in, sort_indices, out, out_indices);
RUN_CUDA_KERNEL((MedianSelectCuda<T, int32_t>), stream, reduce_elem_cnt, reduce_elem_cnt,
stride, in, sort_indices, out, out_indices);
} else {
MedianSelectCuda<T, int64_t><<<BlocksNum4ThreadsNum(reduce_elem_cnt), kCudaThreadsNumPerBlock,
0, stream->As<ep::CudaStream>()->cuda_stream()>>>(
reduce_elem_cnt, stride, in, sort_indices, out, out_indices);
RUN_CUDA_KERNEL((MedianSelectCuda<T, int64_t>), stream, reduce_elem_cnt, reduce_elem_cnt,
stride, in, sort_indices, out, out_indices);
}
}

Expand Down Expand Up @@ -117,9 +115,8 @@ class CudaMedianWithIndicesKernel final : public user_op::OpKernel {
const int64_t elem_cnt = in->shape().elem_cnt();
const int64_t instance_size = in->shape().At(in->shape().NumAxes() - 1);
const int64_t instance_num = elem_cnt / instance_size;
InitializeIndices<<<BlocksNum4ThreadsNum(elem_cnt), kCudaThreadsNumPerBlock, 0,
ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
elem_cnt, buf_manager.InIndicesPtr(), instance_size);
RUN_CUDA_KERNEL(InitializeIndices, ctx->stream(), elem_cnt, elem_cnt,
buf_manager.InIndicesPtr(), instance_size);
SortPairsAscending(in->dptr<T>(), buf_manager.InIndicesPtr(), instance_num, instance_size,
buf_manager.TempStoragePtr(), buf_manager.TempStorageBytes(),
buf_manager.SortedInPtr(), buf_manager.OutIndicesPtr(),
Expand Down
6 changes: 2 additions & 4 deletions oneflow/user/ops/median_with_indices_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@ namespace oneflow {
/*static*/ Maybe<void> MedianWithIndicesOp::GetSbp(user_op::SbpContext* ctx) {
const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0);
int64_t num_axes = in_tensor.shape().NumAxes();
FOR_RANGE(int64_t, i, 0, num_axes) {
if (i != num_axes - 1) {
ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
}
FOR_RANGE(int64_t, i, 0, num_axes - 1) {
ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
}
if (num_axes == 0) {
ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
Expand Down
16 changes: 12 additions & 4 deletions python/oneflow/test/modules/test_consistent_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,35 @@
limitations under the License.
"""
import unittest
import torch
from functools import reduce
import operator
import oneflow as flow
import oneflow.unittest
from oneflow.test_utils.automated_test_util import *


@autotest(n=1, check_graph=False, rtol=1e-3, atol=1e-3)
@autotest(n=1, check_graph=False)
def _test_median(test_case, placement, sbp, ndim):
dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]
x = random_tensor(ndim, *dim_list).to_global(placement, sbp)
return torch.median(x)


@autotest(n=1, check_graph=False, rtol=1e-3, atol=1e-3)
@autotest(n=1, check_graph=False)
def _test_median_with_indices(test_case, placement, sbp, ndim):
dim = random(1, ndim).to(int).value()
dim_list = [random(1, 3).to(int).value() * 8 for _ in range(ndim)]
x = random_tensor(ndim, *dim_list).to_global(placement, sbp)
x = choice_tensor(
reduce(operator.mul, dim_list, 1),
dim_list,
replace=False,
dtype=float,
requires_grad=True,
).to_global(placement, sbp)
return torch.median(x, dim)


@unittest.skip("TODO: sometimes global TestMedian fails on 2-GPU runs")
class TestMedian(flow.unittest.TestCase):
@globaltest
def test_median(test_case):
Expand Down

0 comments on commit 0ed4f46

Please sign in to comment.