Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix reduce_kernel_M1
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Aug 3, 2018
1 parent 5628194 commit 05718f5
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/operator/tensor/broadcast_reduce-inl.cuh
Original file line number Diff line number Diff line change
@@ -268,7 +268,11 @@ __global__ void reduce_kernel_M1(const int N, const bool addto,
for (int idx = threadIdx.x + blockIdx.x*blockDim.x; idx < N; idx += blockDim.x*gridDim.x) {
Shape<ndim> coord = unravel(idx, sshape);
int j = ravel(coord, bshape);
- assign(&small[idx], addto, OP::Map(big[j]));
+ DType val, residual;
+ Reducer::SetInitValue(val, residual);
+ Reducer::Reduce(val, OP::Map(big[j]), residual);
+ Reducer::Finalize(val, residual);
+ assign(&small[idx], addto, val);
}
}

@@ -287,7 +291,10 @@ __global__ void reduce_kernel_M1(const int N, const bool addto,
int idx_big = ravel(coord, big_shape);
int idx_lhs = ravel(coord, lhs_shape);
int idx_rhs = ravel(coord, rhs_shape);
- DType val = OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs]));
+ DType val, residual;
+ Reducer::SetInitValue(val, residual);
+ Reducer::Reduce(val, OP1::Map(big[idx_big], OP2::Map(lhs[idx_lhs], rhs[idx_rhs])), residual);
+ Reducer::Finalize(val, residual);
assign(&small[idx], addto, val);
}
}
Expand Down

0 comments on commit 05718f5

Please sign in to comment.