diff --git a/keras/backend/cntk_backend.py b/keras/backend/cntk_backend.py
index 61d3a4b0eee9..0b9c0440ca0d 100644
--- a/keras/backend/cntk_backend.py
+++ b/keras/backend/cntk_backend.py
@@ -1886,7 +1886,10 @@ def batch_set_value(tuples):
 
 
 def stop_gradient(variables):
-    return C.stop_gradient(C.combine(variables))
+    if isinstance(variables, (list, tuple)):
+        return [C.stop_gradient(v) for v in variables]
+    else:
+        return C.stop_gradient(variables)
 
 
 def switch(condition, then_expression, else_expression):
diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py
index d93ecd756791..8e35d740ff18 100644
--- a/keras/backend/tensorflow_backend.py
+++ b/keras/backend/tensorflow_backend.py
@@ -2301,15 +2301,21 @@ def gradients(loss, variables):
 
 
 def stop_gradient(variables):
-    """Returns `variables` but with zero gradient w.r.t. every other variable.
+    """Returns `variables` but with zero gradient with respect to every other
+    variable.
 
     # Arguments
-        variables: List of variables.
+        variables: tensor or list of tensors to consider constant with respect
+            to any other variable.
 
     # Returns
-        The same list of variables.
+        A single tensor or a list of tensors (depending on the passed argument)
+        that has a constant gradient with respect to any other variable.
     """
-    return tf.stop_gradient(variables)
+    if isinstance(variables, (list, tuple)):
+        return [tf.stop_gradient(v) for v in variables]
+    else:
+        return tf.stop_gradient(variables)
 
 
 # CONTROL FLOW
diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py
index 636f0e5a3ea0..f36089c59d3a 100644
--- a/keras/backend/theano_backend.py
+++ b/keras/backend/theano_backend.py
@@ -1212,9 +1212,20 @@ def gradients(loss, variables):
 
 def stop_gradient(variables):
     """Returns `variables` but with zero gradient with respect to every other
-    variables.
+    variable.
+
+    # Arguments
+        variables: tensor or list of tensors to consider constant with respect
+            to any other variable.
+
+    # Returns
+        A single tensor or a list of tensors (depending on the passed argument)
+        that has a constant gradient with respect to any other variable.
     """
-    return theano.gradient.disconnected_grad(variables)
+    if isinstance(variables, (list, tuple)):
+        return [theano.gradient.disconnected_grad(v) for v in variables]
+    else:
+        return theano.gradient.disconnected_grad(variables)
 
 
 # CONTROL FLOW
diff --git a/tests/keras/backend/backend_test.py b/tests/keras/backend/backend_test.py
index 65a230e4ad78..0ca4f71d6f0a 100644
--- a/tests/keras/backend/backend_test.py
+++ b/tests/keras/backend/backend_test.py
@@ -491,6 +491,17 @@ def test_gradient(self):
             assert_allclose(zero_list[i], z_list[i], atol=1e-05)
             assert_allclose(zero_list[i + 1], zero_list[i + 1], atol=1e-05)
 
+    def test_stop_gradient(self):
+        # This test checks the consistency of the stop_gradient backend API.
+        # It doesn't check the functionality (which is checked by the
+        # test_gradient test).
+        val = np.random.random((4, 2))
+        for k in BACKENDS:
+            a = k.variable(val)
+            b = k.square(a)
+            c, d = k.stop_gradient([a, b])
+            e = k.stop_gradient(b)
+
     # cntk does not currently support functions this way, so can't test here
     def test_function(self):
         test_backend = [KTH, KTF]
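
A quick usage sketch of the extended API (not part of the patch): `stop_gradient` now accepts either a single tensor or a list of tensors, with any of the three backends configured. The variable names below are illustrative only.

    import numpy as np
    from keras import backend as K

    x = K.variable(np.random.random((4, 2)))
    y = K.square(x)

    # A list argument now yields one constant-gradient tensor per input.
    x_const, y_const = K.stop_gradient([x, y])

    # A single tensor still yields a single tensor, as before.
    y_const = K.stop_gradient(y)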
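
As a functional sanity check (a sketch assuming the TensorFlow backend; the test suite covers functionality in test_gradient), a loss that reaches a variable only through a stopped tensor should receive no gradient from that path:

    import numpy as np
    from keras import backend as K

    a = K.variable(np.random.random((4, 2)))
    b_const = K.stop_gradient(K.square(a))

    # The squared path is cut by stop_gradient, so only the identity
    # term `a` contributes to the gradient of the loss.
    loss = K.sum(a + b_const)
    grads = K.gradients(loss, [a])
    g = K.function([], grads)([])[0]

    # Without stop_gradient this would be 1 + 2 * K.get_value(a).
    assert np.allclose(g, np.ones((4, 2)))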