From 01dd70396951864cd08b0bfc5da9d95e0896bb98 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Thu, 4 Oct 2018 10:05:48 -0700 Subject: [PATCH] add FListInputNames attribute to softmax_cross_entropy (#12701) --- src/operator/loss_binary_op.cc | 4 ++++ tests/python/unittest/test_operator.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/operator/loss_binary_op.cc b/src/operator/loss_binary_op.cc index c1fedb3de61c..df8576cfbb83 100644 --- a/src/operator/loss_binary_op.cc +++ b/src/operator/loss_binary_op.cc @@ -67,6 +67,10 @@ Example:: }) .set_attr("FCompute", SoftmaxCrossEntropyForward) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_softmax_cross_entropy"}) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "label"}; +}) .add_argument("data", "NDArray-or-Symbol", "Input data") .add_argument("label", "NDArray-or-Symbol", "Input label"); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index b5a7303195f1..b17562c1d946 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6970,6 +6970,25 @@ def test_invalid_depth_dim(): test_invalid_block_size() test_invalid_depth_dim() + +@with_seed() +def test_softmax_cross_entropy(): + def f_sm_ce(data, label): + return np.sum(-np.log(data) * label) + + data = mx.sym.Variable('data') + label = mx.sym.Variable('label') + sym = mx.sym.softmax_cross_entropy(data=data, label=label) + num_labels = random.randint(100, 200) + batch_size = random.randint(100, 200) + np_data = rand_ndarray((batch_size, num_labels), stype='default').asnumpy() + np_sm = np_softmax(np_data) + np_label = np.random.randint(0, num_labels, (batch_size, )) + np_one_hot_label = np.zeros((batch_size, num_labels)) + np_one_hot_label[np.arange(batch_size), np_label] = 1. + check_symbolic_forward(sym, {'data' : np_data, 'label' : np_label}, [np.array([f_sm_ce(np_sm, np_one_hot_label)])], rtol=1e-3, atol=1e-5) + + @with_seed() def test_invalid_kernel_size(): invalid_kernel_size = 28