Fix inconsistent stop_gradient API
The stop_gradient documentation states that the argument should be a list of
variables. However, the Theano implementation crashes if the argument is a
list of variables, and the CNTK implementation crashes if it is not.

This commit makes all backends handle both cases as expected.
angeloskath committed Jul 24, 2017
1 parent aac5b53 commit 6cd6525
Showing 4 changed files with 38 additions and 7 deletions.
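
To illustrate the unified behavior described in the commit message, here is a
minimal usage sketch (editorial, not part of the commit; it assumes the Keras
backend is imported as `K`, and the names and shapes are illustrative):

    from keras import backend as K
    import numpy as np

    x = K.variable(np.random.random((4, 2)))
    y = K.square(x)

    # After this fix, both call forms work on every backend:
    frozen_y = K.stop_gradient(y)                 # single tensor in, single tensor out
    frozen_x, frozen_y = K.stop_gradient([x, y])  # list in, one stopped tensor per input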
5 changes: 4 additions & 1 deletion keras/backend/cntk_backend.py
@@ -1886,7 +1886,10 @@ def batch_set_value(tuples):
 
 
 def stop_gradient(variables):
-    return C.stop_gradient(C.combine(variables))
+    if isinstance(variables, (list, tuple)):
+        return map(C.stop_gradient, variables)
+    else:
+        return C.stop_gradient(variables)
 
 
 def switch(condition, then_expression, else_expression):
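A side note on the design choice shared by all three backends in this commit:
`map` returns a list under Python 2 but a lazy iterator under Python 3, and the
new test's tuple unpacking consumes either. A hypothetical list-returning
variant (an editorial sketch, not the committed code) would be:

    import cntk as C

    def stop_gradient(variables):
        # Sketch only: a list comprehension yields a concrete list on both
        # Python 2 and Python 3, whereas map() is lazy on Python 3.
        if isinstance(variables, (list, tuple)):
            return [C.stop_gradient(v) for v in variables]
        else:
            return C.stop_gradient(variables)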
14 changes: 10 additions & 4 deletions keras/backend/tensorflow_backend.py
@@ -2301,15 +2301,21 @@ def gradients(loss, variables):
 
 
 def stop_gradient(variables):
-    """Returns `variables` but with zero gradient w.r.t. every other variable.
+    """Returns `variables` but with zero gradient with respect to every other
+    variable.
 
     # Arguments
-        variables: List of variables.
+        variables: tensor or list of tensors to consider constant with respect
+          to any other variable
 
     # Returns
-        The same list of variables.
+        A single tensor or a list of tensors (depending on the passed argument)
+          that has constant gradient with respect to any other variable
     """
-    return tf.stop_gradient(variables)
+    if isinstance(variables, (list, tuple)):
+        return map(tf.stop_gradient, variables)
+    else:
+        return tf.stop_gradient(variables)
 
 
 # CONTROL FLOW
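As a quick check of the documented behavior, the following editorial sketch
(not from the commit) verifies through the backend's `gradients` function,
shown in the hunk context above, that a stopped branch contributes nothing to
the derivative; names and shapes are illustrative:

    from keras import backend as K
    import numpy as np

    x = K.variable(np.ones((2, 2)))
    y = K.square(x)
    # One live branch plus one stopped branch of the same tensor.
    loss = K.sum(y + K.stop_gradient(y))

    # Only the live branch flows: d(loss)/dx is 2*x rather than 4*x,
    # so with x = 1 everywhere each gradient entry should be 2.0.
    grads = K.gradients(loss, [x])
    print(K.eval(grads[0]))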
15 changes: 13 additions & 2 deletions keras/backend/theano_backend.py
@@ -1212,9 +1212,20 @@ def gradients(loss, variables):
 
 def stop_gradient(variables):
     """Returns `variables` but with zero gradient with respect to every other
-    variables.
+    variable.
+
+    # Arguments
+        variables: tensor or list of tensors to consider constant with respect
+          to any other variable
+
+    # Returns
+        A single tensor or a list of tensors (depending on the passed argument)
+          that has constant gradient with respect to any other variable
     """
-    return theano.gradient.disconnected_grad(variables)
+    if isinstance(variables, (list, tuple)):
+        return map(theano.gradient.disconnected_grad, variables)
+    else:
+        return theano.gradient.disconnected_grad(variables)
 
 
 # CONTROL FLOW
11 changes: 11 additions & 0 deletions tests/keras/backend/backend_test.py
@@ -491,6 +491,17 @@ def test_gradient(self):
             assert_allclose(zero_list[i], z_list[i], atol=1e-05)
             assert_allclose(zero_list[i + 1], zero_list[i + 1], atol=1e-05)
 
+    def test_stop_gradient(self):
+        # This test checks the consistency of the stop_gradient backend API.
+        # It doesn't check the functionality (which is checked at the
+        # test_gradient test).
+        val = np.random.random((4, 2))
+        for k in BACKENDS:
+            a = k.variable(val)
+            b = k.square(a)
+            c, d = k.stop_gradient([a, b])
+            e = k.stop_gradient(b)
+
     # cntk currently not support function in this way, so can't test as this
     def test_function(self):
         test_backend = [KTH, KTF]
