diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index edaf9397303a..1daf0a2cb18a 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -1034,8 +1034,8 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
   using namespace mshadow::expr;
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 2U);
-  CHECK_EQ(req[take_::kIdx], kNullOp)
-    << "take layer doesn't support gradient into index";
+  CHECK_NE(req[take_::kIdx], kAddTo)
+    << "take layer doesn't support gradient of req type kAddTo to index";
   const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
@@ -1052,6 +1052,11 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
       const TShape& arrshape = outputs[0].shape_;
       const TShape& oshape = inputs[0].shape_;
 
+      if (req[take_::kIdx] != kNullOp) {
+        mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+          s, idxshape.Size(), outputs[take_::kIdx].dptr<IType>());
+      }
+
       const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
 
       int idxndim = idxshape.ndim();
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index c395199e8ea4..aeda36d8b487 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3810,6 +3810,31 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
             exe.backward([mx.nd.array(grad_out)])
             assert_almost_equal(exe.grad_dict['a'].asnumpy(), grad_in)
 
+    def check_autograd_req():
+        row_len = 2
+        col_len = 8
+        shape = (row_len, col_len)
+        sc = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype="float32")
+        sc.attach_grad()
+        i = mx.nd.array([0], dtype="int64")
+        j = mx.nd.array([0], dtype="int64")
+        with mx.autograd.record(train_mode=True):
+            xs = []
+            for _ in range(row_len):
+                x_i = []
+                for _ in range(col_len):
+                    x_ij = sc.take(i).squeeze(axis=0).take(j).squeeze(axis=0)
+                    x_i.append(x_ij)
+                    j = j + 1
+                i = i + 1
+                j = j - col_len  # reset j
+                xs.append(mx.nd.stack(*x_i))
+            x = mx.nd.stack(*xs)
+            x = x.sum()
+
+        x.backward()
+        assert_almost_equal(np.ones(sc.grad.shape), sc.grad.asnumpy())
+
     for mode in ['clip', 'wrap']:
         for data_ndim in range(1, 5):
             for idx_ndim in range(1, 4):
@@ -3822,6 +3847,8 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
                     idx_shape += (np.random.randint(low=1, high=5), )
                     check_output_n_grad(data_shape, idx_shape, axis, mode)
 
+    check_autograd_req()
+
 
 @with_seed()
 def test_grid_generator():
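
Note (illustration, not part of the patch): once the check is relaxed from "req must be kNullOp" to "req must not be kAddTo", the backward pass may receive a kWriteTo request for the index gradient, so that output buffer must be written deterministically; the added set_zero kernel launch does this, reflecting that take() has no meaningful gradient w.r.t. integer indices. Below is a condensed, hypothetical sketch of the scenario the new check_autograd_req test exercises (variable names `sc`, `i`, `row`, `out` are mine, not from the patch):

    import mxnet as mx

    sc = mx.nd.random.uniform(-1.0, 1.0, shape=(2, 8), dtype="float32")
    sc.attach_grad()
    i = mx.nd.array([0], dtype="int64")
    with mx.autograd.record(train_mode=True):
        # `i + 1` makes the index a recorded intermediate, so backward may
        # issue a non-kNullOp gradient request for take()'s index input;
        # before this patch that tripped CHECK_EQ(req[take_::kIdx], kNullOp).
        row = sc.take(i + 1).squeeze(axis=0)
        out = row.sum()
    out.backward()
    print(sc.grad.asnumpy())  # gradient flows only into the selected row
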