Skip to content

Commit

Permalink
Tensorflow op that scales gradient for backwards pass.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 267388219
Change-Id: I9f55ff9c47de88653d9214563e433f2a27645acd
  • Loading branch information
Sonnet Contributor authored and sonnet-copybara committed Sep 5, 2019
1 parent eadee4d commit ecf4e35
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 0 deletions.
1 change: 1 addition & 0 deletions sonnet/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ snt_py_library(
"//sonnet/src:once",
"//sonnet/src:recurrent",
"//sonnet/src:reshape",
"//sonnet/src:scale_gradient",
"//sonnet/src:sequential",
"//sonnet/src:utils",
],
Expand Down
2 changes: 2 additions & 0 deletions sonnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from sonnet.src.reshape import Flatten
from sonnet.src.reshape import reshape
from sonnet.src.reshape import Reshape
from sonnet.src.scale_gradient import scale_gradient
from sonnet.src.sequential import Sequential
from sonnet.src.utils import format_variables
from sonnet.src.utils import log_variables
Expand Down Expand Up @@ -137,6 +138,7 @@
"optimizers",
"pad",
"regularizers",
"scale_gradient",
"static_unroll",
)

Expand Down
21 changes: 21 additions & 0 deletions sonnet/src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,27 @@ snt_py_test(
],
)

snt_py_library(
name = "scale_gradient",
srcs = ["scale_gradient.py"],
deps = [
":base",
# pip: numpy
# pip: tensorflow
],
)

snt_py_test(
name = "scale_gradient_test",
srcs = ["scale_gradient_test.py"],
deps = [
":scale_gradient",
":test_utils",
# pip: absl/testing:parameterized
# pip: tensorflow
],
)

snt_py_library(
name = "bayes_by_backprop",
srcs = ["bayes_by_backprop.py"],
Expand Down
40 changes: 40 additions & 0 deletions sonnet/src/scale_gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2019 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tensorflow op that scales gradient for backwards pass."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import tensorflow as tf


@tf.custom_gradient
def scale_gradient(t, scale):
"""Scales gradients for the backwards pass.
Args:
t: A Tensor.
scale: The scale factor for the gradient on the backwards pass.
Returns:
A Tensor same as input, with scaled backward gradient.
"""
def grad(dy):
"""Scaled gradient."""
return scale*dy, None
return t, grad

44 changes: 44 additions & 0 deletions sonnet/src/scale_gradient_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2019 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for sonnet.v2.src.scale_gradient."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from absl.testing import parameterized
from sonnet.src import scale_gradient
from sonnet.src import test_utils
import tensorflow as tf


class ScaleGradientTest(test_utils.TestCase, parameterized.TestCase):

@parameterized.parameters(
*itertools.product([-1.0, 0.0, 1.0], [-0.5, 0.0, 0.5, 2.0])
)
def test_scale(self, t_, scale):
t = tf.Variable([t_])
with tf.GradientTape() as tape:
y = scale_gradient.scale_gradient(t, scale)
output = y * y
grad = tape.gradient(output, t)
self.assertAllEqual(grad.numpy(), [2*t_*scale])
self.assertAllEqual(output.numpy(), [t_**2])

if __name__ == "__main__":
# tf.enable_v2_behavior()
tf.test.main()

0 comments on commit ecf4e35

Please sign in to comment.