From 16d9090c91c761d7dfead0ba5eb3ab4c774ae974 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 24 Jul 2017 17:29:31 +0200
Subject: [PATCH] Fix stop_gradient inconsistent API

The stop_gradient documentation states that the argument should be a
list of variables. However, the Theano implementation crashes if the
argument is a list of variables, while the CNTK implementation crashes
if it is not. This commit makes all three backends accept either a
single tensor or a list of tensors and return a result of the same type
as the argument.

---
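A minimal sketch of the behaviour this patch establishes, assuming the
TensorFlow backend and that `K` is the Keras backend module:

    import numpy as np
    from keras import backend as K

    # Build a variable and a tensor that depends on it.
    x = K.variable(np.random.random((4, 2)))
    y = K.square(x)

    # A single tensor in, a single tensor out.
    y_const = K.stop_gradient(y)

    # A list of tensors in, a list of tensors out, element by element.
    x_const, y_const = K.stop_gradient([x, y])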
 keras/backend/cntk_backend.py       |  5 ++++-
 keras/backend/tensorflow_backend.py | 11 ++++++++---
 keras/backend/theano_backend.py     | 16 +++++++++++++---
 tests/keras/backend/backend_test.py | 11 +++++++++++
 4 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py
index 61d3a4b0eee..0b9c0440ca0 100644
--- a/keras/backend/cntk_backend.py
+++ b/keras/backend/cntk_backend.py
@@ -1886,7 +1886,10 @@ def batch_set_value(tuples):
 
 
 def stop_gradient(variables):
-    return C.stop_gradient(C.combine(variables))
+    if isinstance(variables, (list, tuple)):
+        return [C.stop_gradient(v) for v in variables]
+    else:
+        return C.stop_gradient(variables)
 
 
 def switch(condition, then_expression, else_expression):
diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py
index d93ecd75679..c6b06acb08d 100644
--- a/keras/backend/tensorflow_backend.py
+++ b/keras/backend/tensorflow_backend.py
@@ -2304,12 +2304,17 @@ def stop_gradient(variables):
     """Returns `variables` but with zero gradient w.r.t. every other variable.
 
     # Arguments
-        variables: List of variables.
+        variables: Tensor or list of tensors to consider constant with
+            respect to any other variable.
 
     # Returns
-        The same list of variables.
+        A single tensor or a list of tensors (depending on the passed
+        argument), with zero gradient with respect to any other variable.
     """
-    return tf.stop_gradient(variables)
+    if isinstance(variables, (list, tuple)):
+        return [tf.stop_gradient(v) for v in variables]
+    else:
+        return tf.stop_gradient(variables)
 
 
 # CONTROL FLOW
diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py
index 636f0e5a3ea..15e198e8d8b 100644
--- a/keras/backend/theano_backend.py
+++ b/keras/backend/theano_backend.py
@@ -1211,10 +1211,20 @@ def gradients(loss, variables):
 
 
 def stop_gradient(variables):
-    """Returns `variables` but with zero gradient with respect to every other
-    variables.
+    """Returns `variables` but with zero gradient w.r.t. every other variable.
+
+    # Arguments
+        variables: Tensor or list of tensors to consider constant with
+            respect to any other variable.
+
+    # Returns
+        A single tensor or a list of tensors (depending on the passed
+        argument), with zero gradient with respect to any other variable.
     """
-    return theano.gradient.disconnected_grad(variables)
+    if isinstance(variables, (list, tuple)):
+        return [theano.gradient.disconnected_grad(v) for v in variables]
+    else:
+        return theano.gradient.disconnected_grad(variables)
 
 
 # CONTROL FLOW
diff --git a/tests/keras/backend/backend_test.py b/tests/keras/backend/backend_test.py
index 65a230e4ad7..0ca4f71d6f0 100644
--- a/tests/keras/backend/backend_test.py
+++ b/tests/keras/backend/backend_test.py
@@ -491,6 +491,17 @@ def test_gradient(self):
             assert_allclose(zero_list[i], z_list[i], atol=1e-05)
             assert_allclose(zero_list[i + 1], zero_list[i + 1], atol=1e-05)
 
+    def test_stop_gradient(self):
+        # This test checks the consistency of the stop_gradient backend API;
+        # it does not check the functionality itself (that is covered by
+        # the test_gradient test).
+        val = np.random.random((4, 2))
+        for k in BACKENDS:
+            a = k.variable(val)
+            b = k.square(a)
+            c, d = k.stop_gradient([a, b])
+            e = k.stop_gradient(b)
+
     # cntk currently not support function in this way, so can't test as this
     def test_function(self):
         test_backend = [KTH, KTF]