diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh index be3d1f9223f4..33bf72798fd6 100644 --- a/src/operator/tensor/broadcast_reduce-inl.cuh +++ b/src/operator/tensor/broadcast_reduce-inl.cuh @@ -268,7 +268,11 @@ __global__ void reduce_kernel_M1(const int N, const bool addto, for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) { Shape coord = unravel(idx, sshape); int j = ravel(coord, bshape); - assign(&small[idx], addto, OP::Map(big[j])); + DType val, residual; + Reducer::SetInitValue(val, residual); + Reducer::Reduce(val, OP::Map(big[j]), residual); + Reducer::Finalize(val, residual); + assign(&small[idx], addto, val); } } @@ -287,7 +291,10 @@ __global__ void reduce_kernel_M1(const int N, const bool addto, int idx_big = ravel(coord, big_shape); int idx_lhs = ravel(coord, lhs_shape); int idx_rhs = ravel(coord, rhs_shape); - DType val = OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])); + DType val, residual; + Reducer::SetInitValue(val, residual); + Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual); + Reducer::Finalize(val, residual); assign(&small[idx], addto, val); } }