This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix the error of gradient of np.pad #19044

Merged 7 commits on Sep 1, 2020
Changes from 2 commits
3 changes: 1 addition & 2 deletions src/operator/numpy/np_pad_op-inl.h
@@ -545,8 +545,7 @@ template <typename xpu, int req>
 struct pad_grad {
   template<typename DType>
   MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *a){
-    using namespace mxnet_op;
-    KERNEL_ASSIGN(out[i], req, 1);
+    KERNEL_ASSIGN(out[i], req, a[i]);
   }
 };

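For context, `pad_grad` is the backward kernel of `np.pad`: each element of the input gradient should be copied from the corresponding element of the incoming head gradient `a`, but the old code wrote a constant 1 regardless of `a`. Below is a minimal plain-NumPy sketch of the intended vector-Jacobian product for a constant pad; `pad_vjp_constant` and the concrete shapes are illustrative, not MXNet code from this PR.

```python
import numpy as np

def pad_vjp_constant(grad_out, pad_width):
    """VJP of np.pad(x, pad_width, mode='constant') w.r.t. x:
    slice the padded border off the incoming (head) gradient."""
    slices = tuple(slice(before, size - after)
                   for (before, after), size in zip(pad_width, grad_out.shape))
    return grad_out[slices]

x = np.arange(6.0).reshape(2, 3)
pad_width = ((1, 1), (2, 2))
y = np.pad(x, pad_width, mode='constant')

head_grad = np.arange(y.size, dtype=float).reshape(y.shape)  # arbitrary head gradient
grad_x = pad_vjp_constant(head_grad, pad_width)

assert grad_x.shape == x.shape
assert not np.allclose(grad_x, np.ones_like(x))  # depends on head_grad, not all ones
```

Note that for a constant pad with an all-ones head gradient the correct input gradient also happens to be all ones, so the old kernel and the old test below agreed with each other, which helps explain why the bug went unnoticed.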
4 changes: 2 additions & 2 deletions tests/python/unittest/test_numpy_op.py
@@ -8431,8 +8431,8 @@ def hybrid_forward(self,F,A,**kwargs):
     assert_almost_equal(mx_out.asnumpy(), np_out, rtol = rtol, atol = atol)

     # test gradient
-    mx_out.backward()
-    np_backward = np.ones(shape)
+    mx_out.backward(x)
+    np_backward = x
     assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol)

     # test imperative once again
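The updated test passes an explicit head gradient to `backward` and checks `x.grad` against it, rather than always expecting a tensor of ones regardless of the head gradient. A hedged imperative sketch of the same kind of check through MXNet's numpy interface is shown below; it is not the test from this diff, and the shapes, pad width, and use of `npx.set_np()` are assumptions.

```python
import numpy as onp                      # classic NumPy, for the reference check
from mxnet import autograd, np, npx

npx.set_np()                             # enable NumPy-compatible semantics

x = np.random.uniform(size=(2, 3))
x.attach_grad()
with autograd.record():
    y = np.pad(x, ((1, 1), (2, 2)), mode='constant')

head = np.random.uniform(size=y.shape)   # explicit, non-trivial head gradient
y.backward(head)

# For a constant pad, x.grad should be the interior slice of the head gradient:
# the padded border contributes nothing to x.
expected = head[1:-1, 2:-2]
assert onp.allclose(x.grad.asnumpy(), expected.asnumpy())
```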