diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index ef59145bb4a9..b663ef0179df 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -570,6 +570,10 @@ Examples:: indices = [[1, 1, 0], [0, 1, 0]] gather_nd(data, indices) = [2, 3, 0] + data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + indices = [[0, 1], [1, 0]] + gather_nd(data, indices) = [[3, 4], [5, 6]] + )code") .set_num_outputs(1) .set_num_inputs(2) @@ -629,6 +633,21 @@ Examples:: shape = (2, 2) scatter_nd(data, indices, shape) = [[0, 0], [2, 3]] + data = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]] + indices = [[0, 1], [1, 1]] + shape = (2, 2, 2, 2) + scatter_nd(data, indices, shape) = [[[[0, 0], + [0, 0]], + + [[1, 2], + [3, 4]]], + + [[[0, 0], + [0, 0]], + + [[5, 6], + [7, 8]]]] + )code") .set_num_outputs(1) .set_num_inputs(2)