diff --git a/docs/conf.py b/docs/conf.py index 017eb46f..3367b356 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -98,6 +98,7 @@ doctest_global_setup = """ import tensorflow as tf import sonnet as snt +import tree # `TpuReplicator` cannot be constructed without a TPU, however it has exactly # the same API as `Replicator` so we can run doctests using that instead. diff --git a/examples/BUILD b/examples/BUILD index 55350d0b..c33861ba 100644 --- a/examples/BUILD +++ b/examples/BUILD @@ -39,3 +39,17 @@ snt_py_test( # pip: tensorflow ], ) + +py_binary( + name = "functional_mlp_mnist", + srcs = ["functional_mlp_mnist.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + # pip: absl:app + # pip: absl/logging + "//sonnet", + # pip: tensorflow + # pip: tensorflow_datasets + ], +) diff --git a/examples/functional_mlp_mnist.py b/examples/functional_mlp_mnist.py new file mode 100644 index 00000000..b1de25aa --- /dev/null +++ b/examples/functional_mlp_mnist.py @@ -0,0 +1,120 @@ +# Copyright 2020 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Toy MLP on MNIST example of using TF2 JAX/HK shims.""" + +from absl import app +from absl import logging +import sonnet as snt +import tensorflow as tf +import tensorflow_datasets as tfds + +fn = snt.functional + + +def main(unused_argv): + del unused_argv + + with fn.variables(): + net = snt.nets.MLP([1000, 100, 10]) + + def loss_fn(images, labels): + images = snt.flatten(images) + logits = net(images) + loss = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, + logits=logits)) + return loss + + loss_fn = fn.transform(loss_fn) + + def preprocess(images, labels): + images = tf.image.convert_image_dtype(images, tf.float32) + return images, labels + + # _ _ + # | |_ _ __ __ _(_)_ __ + # | __| '__/ _` | | '_ \ + # | |_| | | (_| | | | | | + # \__|_| \__,_|_|_| |_| + # + + batch_size = 100 + + dataset = tfds.load("mnist", split="train", as_supervised=True) + dataset = dataset.map(preprocess) + dataset = dataset.cache() + dataset = dataset.shuffle(batch_size * 8) + dataset = dataset.batch(batch_size, drop_remainder=True) + dataset = dataset.prefetch() + + # As before we want to unzip our loss_fn into init and apply. + optimizer = fn.adam(0.01) + + # To get our initial state we need to pull a record from our dataset and pass + # it to our init function. We'll also be sure to use `device_put` such that + # the parameters are on the accelerator. + images, labels = next(iter(dataset)) + params = fn.device_put(loss_fn.init(images, labels)) + opt_state = fn.device_put(optimizer.init(params)) + + # Our training loop is to iterate through 10 epochs of the train dataset, and + # use sgd after each minibatch to update our parameters according to the + # gradient from our loss function. 
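+  #
+  # (Descriptive note on the two wrappers used below:
+  # `fn.value_and_grad(loss_fn.apply)` returns a function that computes both
+  # the loss and the gradients with respect to `params` (its first argument),
+  # and `fn.jit` stages each function as a `tf.function` pinned to the first
+  # available accelerator, falling back to the CPU if there is none.)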
+ grad_apply_fn = fn.jit(fn.value_and_grad(loss_fn.apply)) + apply_opt_fn = fn.jit(optimizer.apply) + + for epoch in range(10): + for images, labels in dataset: + loss, grads = grad_apply_fn(params, images, labels) + params, opt_state = apply_opt_fn(opt_state, grads, params) + logging.info("[Epoch %s] loss=%s", epoch, loss.numpy()) + + # _ _ + # | |_ ___ ___| |_ + # | __/ _ \/ __| __| + # | || __/\__ \ |_ + # \__\___||___/\__| + # + + def accuracy_fn(images, labels): + images = snt.flatten(images) + predictions = tf.argmax(net(images), axis=1) + correct = tf.math.count_nonzero(tf.equal(predictions, labels)) + total = tf.shape(labels)[0] + return correct, total + + accuracy_fn = fn.transform(accuracy_fn) + + batch_size = 10000 + dataset = tfds.load("mnist", split="test", as_supervised=True) + dataset = dataset.batch(batch_size, drop_remainder=True) + dataset = dataset.map(preprocess) + + # Note that while we still unzip our accuracy function, we can ignore the + # init_fn here since we already have all the state we need from our training + # function. + apply_fn = fn.jit(accuracy_fn.apply) + + # Compute top-1 accuracy. + num_correct = num_total = 0 + for images, labels in dataset: + correct, total = apply_fn(params, images, labels) + num_correct += correct + num_total += total + accuracy = (int(num_correct) / int(num_total)) * 100 + logging.info("Accuracy %.5f%%", accuracy) + +if __name__ == "__main__": + app.run(main) diff --git a/sonnet/BUILD b/sonnet/BUILD index 46a85031..61c69085 100644 --- a/sonnet/BUILD +++ b/sonnet/BUILD @@ -12,6 +12,7 @@ snt_py_library( visibility = ["//visibility:public"], deps = [ ":distribute", + ":functional", ":initializers", ":mixed_precision", ":optimizers", @@ -54,6 +55,17 @@ snt_py_library( ], ) +snt_py_library( + name = "functional", + srcs = ["functional.py"], + deps = [ + ":optimizers", + "//sonnet/src/functional:haiku", + "//sonnet/src/functional:jax", + "//sonnet/src/functional:optimizers", + ], +) + snt_py_library( name = "initializers", srcs = ["initializers.py"], diff --git a/sonnet/__init__.py b/sonnet/__init__.py index 9a3ce219..fa29e03d 100644 --- a/sonnet/__init__.py +++ b/sonnet/__init__.py @@ -19,6 +19,7 @@ from __future__ import print_function from sonnet import distribute +from sonnet import functional from sonnet import initializers from sonnet import mixed_precision from sonnet import nets @@ -133,6 +134,7 @@ "distribute", "dynamic_unroll", "format_variables", + "functional", "initializers", "log_variables", "lstm_with_recurrent_dropout", diff --git a/sonnet/functional.py b/sonnet/functional.py new file mode 100644 index 00000000..8662ed42 --- /dev/null +++ b/sonnet/functional.py @@ -0,0 +1,67 @@ +# Copyright 2020 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Simple functional APIs for TF2.""" + +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +from sonnet import optimizers as oo_optimizers +from sonnet.src.functional import haiku +from sonnet.src.functional import jax +from sonnet.src.functional import optimizers + +# Utilities for converting Sonnet code into pure functions. +variables = haiku.variables +transform = haiku.transform +transform_with_state = haiku.transform_with_state +without_state = haiku.without_state + +# Utilities for working with tensors on device. +device_get = jax.device_get +device_put = jax.device_put + +# Utilities for transforming pure functions. +grad = jax.grad +jit = jax.jit +value_and_grad = jax.value_and_grad + +# Optimizers. +optimizer = optimizers.optimizer +sgd = optimizer(oo_optimizers.SGD) +adam = optimizer(oo_optimizers.Adam) +rmsprop = optimizer(oo_optimizers.RMSProp) +momentum = optimizer(oo_optimizers.Momentum) + +# Avoid accidentally exporting the private API. +del oo_optimizers, haiku, optimizers, jax + +__all__ = ( + "variables", + "transform", + "transform_with_state", + "without_state", + "device_get", + "device_put", + "grad", + "jit", + "value_and_grad", + "optimizer", + "sgd", + "adam", + "rmsprop", + "momentum", +) diff --git a/sonnet/src/conformance/BUILD b/sonnet/src/conformance/BUILD index b706a3e9..855aa9fc 100644 --- a/sonnet/src/conformance/BUILD +++ b/sonnet/src/conformance/BUILD @@ -76,6 +76,7 @@ snt_py_test( "//sonnet", "//sonnet/src:test_utils", # pip: tensorflow + # pip: tree ], ) diff --git a/sonnet/src/conformance/doctest_test.py b/sonnet/src/conformance/doctest_test.py index cb24c9aa..ae9339ba 100644 --- a/sonnet/src/conformance/doctest_test.py +++ b/sonnet/src/conformance/doctest_test.py @@ -25,6 +25,7 @@ import sonnet as snt from sonnet.src import test_utils import tensorflow as tf +import tree class DoctestTest(test_utils.TestCase, parameterized.TestCase): @@ -62,7 +63,8 @@ def test_doctest(self, module): optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE, extraglobs={ "snt": snt, - "tf": tf + "tf": tf, + "tree": tree, }) if num_attempted == 0: self.skipTest("No doctests in %s" % module.__name__) diff --git a/sonnet/src/functional/BUILD b/sonnet/src/functional/BUILD new file mode 100644 index 00000000..43ffd402 --- /dev/null +++ b/sonnet/src/functional/BUILD @@ -0,0 +1,84 @@ +load("//sonnet/src:build_defs.bzl", "snt_py_library", "snt_py_test") + +package(default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"]) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +snt_py_library( + name = "haiku", + srcs = ["haiku.py"], + deps = [ + ":utils", + # pip: tensorflow + ], +) + +snt_py_library( + name = "jax", + srcs = ["jax.py"], + deps = [ + ":utils", + # pip: tensorflow + # pip: tree + ], +) + +snt_py_library( + name = "optimizers", + srcs = ["optimizers.py"], + deps = [ + ":haiku", + "//sonnet/src:base", + # pip: tensorflow + # pip: tree + ], +) + +snt_py_library( + name = "utils", + srcs = ["utils.py"], + deps = [ + "//sonnet/src:utils", + # pip: tensorflow + # pip: tree + ], +) + +snt_py_test( + name = "haiku_test", + srcs = ["haiku_test.py"], + deps = [ + ":haiku", + # pip: absl/testing:parameterized + "//sonnet", + "//sonnet/src:test_utils", + # pip: tensorflow + # pip: tree + ], +) + +snt_py_test( + name = 
"jax_test", + srcs = ["jax_test.py"], + deps = [ + ":jax", + # pip: absl/testing:parameterized + "//sonnet/src:test_utils", + # pip: tensorflow + ], +) + +snt_py_test( + name = "optimizers_test", + srcs = ["optimizers_test.py"], + deps = [ + ":optimizers", + # pip: absl/testing:parameterized + "//sonnet", + "//sonnet/src:test_utils", + # pip: tensorflow + # pip: tree + ], +) diff --git a/sonnet/src/functional/haiku.py b/sonnet/src/functional/haiku.py new file mode 100644 index 00000000..3b65970d --- /dev/null +++ b/sonnet/src/functional/haiku.py @@ -0,0 +1,462 @@ +# Copyright 2020 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Implements part of the Haiku ("Sonnet for JAX") API in TensorFlow 2.""" + +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +import collections +import functools +import itertools +import threading + +import contextlib +from sonnet.src.functional import utils +import tensorflow as tf + +Transformed = collections.namedtuple("Transformed", ("init", "apply")) +TransformedWithState = collections.namedtuple("TransformedWithState", + ("init", "apply")) + +# pylint: disable=not-context-manager + + +class TensorVariableCallbacks(threading.local): + """Holds callbacks that are notified when TensorVariable are used.""" + + instance = None # Thread local singleton instance. + + def __init__(self): + super(TensorVariableCallbacks, self).__init__() + self._recording = False + self._callbacks = [] + + def notify(self, variable): + if self._recording: + assert isinstance(variable, TensorVariable) + for callback in self._callbacks: + callback(variable) + + @contextlib.contextmanager + def __call__(self, callback): + self._callbacks.append(callback) + recording = self._recording + try: + self._recording = True + yield + finally: + assert self._callbacks.pop() is callback + self._recording = recording + +TensorVariableCallbacks.instance = TensorVariableCallbacks() + + +def notify(f): + """Wraps `f` such that callbacks are notified about it being called.""" + @functools.wraps(f) + def wrapper(self, *args, **kwargs): + TensorVariableCallbacks.instance.notify(self) + return f(self, *args, **kwargs) + return wrapper + + +def defer_property(name): + return property(fget=notify(lambda self: getattr(self.tensor_value, name))) + + +def safe_read_tensor_value(variable): + """Reads variable value or raises an exception.""" + + value = variable.tensor_value + if value is None: + raise ValueError("".join(( + "Attempted to read a TensorVariable in a context where it has no ", + "value. This commonly happens for one of two reasons:", + "", + " 1) You created a model in one transformed function and directly", + " accessed the model variables (e.g. 
via `model.variables` or" + " `model.w`) inside another transformed function.", + " 2) You are trying to read a model variable outside of a", + " transformed function.", + "", + "For (1) you can safely do this if you do not read the value of the", + "variable (e.g. you just use metadata like `v.shape` or `v.dtype`).", + "If you want to read the value of the variable then you must pass in", + "the value (e.g. pass the result of `f.init(..)`).", + "", + "For (2) to read variable values inspect the result of a transformed", + "function (e.g. look at the `params` dictionary returned from ", + "`f.init(..)`)."))) + + return value + + +def defer_read(): + return property( + fget=notify(lambda self: (lambda: safe_read_tensor_value(self)))) + + +def defer_raise_notimplemented(): + def _raise_notimplemented(): + raise NotImplementedError + + return property(fget=notify(_raise_notimplemented)) + + +def defer_indexed(f): + return property(fget=notify(lambda self, i: f(self, i.indices, i.values))) + + +def defer_assign(map_fn=None): + """Returns a function implementing notify+assign.""" + @notify + def wrapped(self, v): + if v is not None: + v = tf.convert_to_tensor(v, dtype=self.dtype) + if map_fn is not None: + v = map_fn(self.tensor_value, v) + if self.initial_tensor_value is None: + self.initial_tensor_value = v + self.tensor_value = v + return v + return wrapped + + +class TensorVariable(tf.Variable): + """Implements the tf.Variable API but backed by a tf.Tensor.""" + + def __init__(self, value, trainable, name=None): + # NOTE: Intentionally not calling super ctor. + self.initial_tensor_value = value + self.tensor_value = value + self._trainable = trainable + self._name = name + self._shape = value.shape + self._dtype = value.dtype + self._device = value.device + + # Properties. + # NOTE: These do not notify since they do not result in TensorFlow operations. + shape = property(fget=lambda self: self._shape) + dtype = property(fget=lambda self: self._dtype) + trainable = property(fget=lambda self: self._trainable) + name = property(fget=lambda self: self._name) + device = property(fget=lambda self: self._device) + + # Dense assign. + assign = defer_assign() + assign_add = defer_assign(tf.add) + assign_sub = defer_assign(tf.subtract) + + # Sparse assign. + batch_scatter_update = defer_raise_notimplemented() + scatter_add = defer_raise_notimplemented() + scatter_div = defer_raise_notimplemented() + scatter_max = defer_raise_notimplemented() + scatter_min = defer_raise_notimplemented() + scatter_mul = defer_raise_notimplemented() + scatter_sub = defer_raise_notimplemented() + scatter_update = defer_raise_notimplemented() + scatter_nd_add = defer_indexed(tf.tensor_scatter_nd_add) + scatter_nd_sub = defer_indexed(tf.tensor_scatter_nd_sub) + scatter_nd_update = defer_indexed(tf.tensor_scatter_nd_update) + + # Load not supported. + load = defer_raise_notimplemented() + + # Shape ops. + set_shape = defer_property("set_shape") + get_shape = defer_property("get_shape") + + # Read dense. + initialized_value = property( + fget=notify(lambda self: self.initial_tensor_value)) + read_value = defer_read() + numpy = defer_property("numpy") + value = defer_read() + eval = defer_property("eval") + + # Read sparse. + gather_nd = defer_indexed(tf.gather_nd) + sparse_read = defer_indexed(tf.gather) + + # Serialize. + to_proto = defer_raise_notimplemented() + + # Misc. 
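+  # NOTE: every deferred member of this class first notifies the thread-local
+  # callbacks (this is how `transform` tracks which variables a function
+  # touches and snapshots their values), then operates on `self.tensor_value`
+  # or raises NotImplementedError for the unsupported ops.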
+ count_up_to = defer_raise_notimplemented() + + def __repr__(self): + return "TensorVariable(shape={}, dtype={}, name={!r})".format( + list(self.shape), self.dtype.name, self.name) + + __str__ = __repr__ + + # Math ops. + __add__ = defer_property("__add__") + __sub__ = defer_property("__sub__") + __mul__ = defer_property("__mul__") + __div__ = defer_property("__div__") + + +@functools.partial(tf.register_tensor_conversion_function, TensorVariable) +@notify +def tv_to_tensor(value, dtype=None, name=None, as_ref=None): + """Converts a TensorVariable to a tf.Tensor.""" + del as_ref + tensor_value = value.tensor_value + if tensor_value is None: + # TODO(tomhennigan) We should probably not notify in this case. + tensor_value = tf.zeros(value.shape, dtype=value.dtype) + if dtype is not None: + tensor_value = tf.cast(tensor_value, dtype=dtype, name=name) + return tensor_value + + +def create_tensor_variables(): + """Defines a scope in which `TensorVariable`s are created. + + >>> with snt.functional.variables(): + ... v = tf.Variable(tf.ones([]), name="v") + >>> v.tensor_value + + + Returns: + A context manager that forces tf.Variable to create TensorVariables. + """ + + def getter(next_getter, **kwargs): + del next_getter + initial_value = tf.convert_to_tensor(kwargs["initial_value"]) + trainable = utils.first_non_none(kwargs["trainable"], True) + name = utils.first_non_none(kwargs["name"], "Variable") + name = utils.get_name_scope() + name + ":0" + return TensorVariable(initial_value, trainable=trainable, name=name) + + return tf.variable_creator_scope(getter) + +variables = create_tensor_variables + + +@contextlib.contextmanager +def track_tensor_variables(): + tensor_variables = [] + with TensorVariableCallbacks.instance(tensor_variables.append): # pylint: disable=not-callable + yield tensor_variables + + +@contextlib.contextmanager +def track_new_variables(): + new_variables = [] + def getter(next_getter, *args, **kwargs): + var = next_getter(*args, **kwargs) + new_variables.append(var) + return var + + with tf.variable_creator_scope(getter): + yield new_variables + + +@contextlib.contextmanager +def track_initial_state(): + var_state = {} + def callback(v): + r = v.ref() + if r not in var_state: + var_state[r] = (v.initial_tensor_value, v.tensor_value) + + with TensorVariableCallbacks.instance(callback): # pylint: disable=not-callable + yield var_state + + +def initial_value_by_ref(tf_variables): + # TODO(tomhennigan) Consider rolling own ref class comparing by name/shape. + return {v.ref(): v.initial_tensor_value for v in tf_variables} + + +def final_value_by_ref(tf_variables): + # TODO(tomhennigan) Consider rolling own ref class comparing by name/shape. + return {v.ref(): v.tensor_value for v in tf_variables} + + +def transform(f) -> Transformed: + """Transforms a function using Sonnet modules into a pair of pure functions. + + The first thing to do is to create some `snt.Module` instances: + + >>> with snt.functional.variables(): + ... a = snt.Linear(10, name="a") + ... b = snt.Linear(10, name="b") + + Next, define some function that creates and applies modules: + + >>> def f(x): + ... 
return a(x) + b(x) + + Now we can convert that function into a pair of functions that allow us to + lift all the parameters out of the function (`f.ini`) and apply the function + with a given set of parameters (`f.apply`): + + >>> f = snt.functional.transform(f) + + To get the initial state of the module call `f.init` with an example input: + + >>> x = tf.ones([1, 1]) + >>> params = f.init(x) + >>> params + {<...>: , + <...>: , + <...>: , + <...>: } + + You can then apply the function with the given parameters by calling + `f.apply`: + + >>> f.apply(params, x) + + + It is expected that your program will at some point produce updated parameters + and you will want to re-apply `f.apply`. You can do this by calling + `f.apply` with different parameters: + + >>> new_params = tree.map_structure(lambda p: p + 1, params) + >>> f.apply(new_params, x) + + + If your network contains non-trainable state (e.g. moving averages) then you + will need to use :func:`transform_with_state`. + + Args: + f: A function closing over `Module` instances. + + Returns: + A transformed function with `init` and `apply`. See docstring for details. + """ + return without_state(transform_with_state(f)) + + +def transform_with_state(f) -> TransformedWithState: + r"""Like :func:`transform` but supporting non-trainable state. + + See :func:`transform` for more details. + + It is possible for the network to maintain internal state (e.g. for a module + like `BatchNorm` that may want to maintain a moving average): + + >>> with snt.functional.variables(): + ... ema = snt.ExponentialMovingAverage(decay=0.5) + + >>> f = snt.functional.transform_with_state(ema) + + When initializing this network we are returned the parameters (any "trainable" + :tf:`Variable`\ s) and all other state (any non-trainable :tf:`Variable`\ s): + + >>> params, state = f.init(3.0) + >>> params + {} + >>> state + {<...>: , + <...>: , + <...>: } + + To apply the network we simply call it and get back updated values for our + non-trainable state: + + >>> y, state = f.apply(params, state, 3.0) + >>> y.numpy() + 3.0 + + >>> y, state = f.apply(params, state, 6.0) + >>> y.numpy() + 5.0 + + Args: + f: A function closing over `Module` instances. + + Returns: + A transformed function with `init` and `apply`. See docstring for details. + """ + def init_fn(*args, **kwargs): + """Applies `f(*a, **k)` and extracts initial variable values.""" + with create_tensor_variables(), \ + track_new_variables() as new_variables, \ + track_initial_state() as prev_var_state, \ + track_tensor_variables() as tensor_variables: + + # NOTE: Intentionally discarding result. + f(*args, **kwargs) + + params = initial_value_by_ref(v for v in tensor_variables if v.trainable) + state = initial_value_by_ref(v for v in tensor_variables if not v.trainable) + + # Reset variable values. + new_variables = {v.ref() for v in new_variables} + for v in tensor_variables: + r = v.ref() + if r in new_variables: + # Variables created inside the function have their values nullified. + initial_tensor_value, tensor_value = None, None + else: + # Variables that already existed have their value reset. 
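+        # `prev_var_state` holds the (initial_value, value) pair captured the
+        # first time each pre-existing variable was touched inside `f`, so
+        # restoring it below means `init` has no lasting side effect on
+        # variables captured from an outer scope.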
+ initial_tensor_value, tensor_value = prev_var_state[r] + v.initial_tensor_value = initial_tensor_value + v.tensor_value = tensor_value + + return params, state + + def apply_fn(params, state, *args, **kwargs): + """Applies `f(*a, **k)` with variable values passed in.""" + initial_values = {} + for r, t in itertools.chain(params.items(), state.items()): + v = r.deref() + initial_values[r] = (v.tensor_value, v.initial_tensor_value) + v.assign(t) + + try: + with track_new_variables() as new_variables: + out = f(*args, **kwargs) + if new_variables: + raise ValueError("Apply function cannot create new variables.") + state = final_value_by_ref(p.deref() for p in state.keys()) + return out, state + + finally: + # Reset values to their initial state. + for r, (tensor_value, initial_tensor_value) in initial_values.items(): + v = r.deref() + v.tensor_value = tensor_value + v.initial_tensor_value = initial_tensor_value + + return TransformedWithState(init=init_fn, apply=apply_fn) + + +def without_state(with_state: TransformedWithState) -> Transformed: + """Returns init/apply functions that ignore state.""" + + def init_fn(*args, **kwargs): + params, state = with_state.init(*args, **kwargs) + if state: + raise ValueError("Stateful networks must use `transform_with_state(f)`") + return params + + def apply_fn(params, *args, **kwargs): + y, state = with_state.apply(params, {}, *args, **kwargs) + if state: + raise ValueError("Stateful networks must use `transform_with_state(f)`") + return y + + return Transformed(init_fn, apply_fn) diff --git a/sonnet/src/functional/haiku_test.py b/sonnet/src/functional/haiku_test.py new file mode 100644 index 00000000..502c06b3 --- /dev/null +++ b/sonnet/src/functional/haiku_test.py @@ -0,0 +1,242 @@ +# Copyright 2020 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for Haiku compatibility layer.""" + +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +from absl.testing import parameterized +import sonnet as snt +from sonnet.src import test_utils +from sonnet.src.functional import haiku as hk +import tensorflow as tf +import tree + + +class TensorVariableTest(test_utils.TestCase, parameterized.TestCase): + + def test_initial_value(self): + with hk.variables(): + v = tf.Variable(tf.ones([])) + self.assertIsInstance(v, hk.TensorVariable) + self.assertAllEqual(v, 1) + self.assertAllEqual(v.read_value(), 1) + self.assertAllEqual(v.tensor_value, 1) + + @parameterized.parameters(None, True, False) + def test_trainable(self, trainable): + with hk.variables(): + v = tf.Variable(1., trainable=trainable) + if trainable is None: + self.assertTrue(v.trainable) + else: + self.assertEqual(v.trainable, trainable) + + def test_name(self): + with hk.variables(): + v = tf.Variable(tf.ones([]), name="v") + self.assertEqual(v.name, "v:0") + + def test_name_with_scope(self): + with hk.variables(), tf.name_scope("foo"), tf.name_scope("bar"): + v = tf.Variable(tf.ones([]), name="v") + self.assertEqual(v.name, "foo/bar/v:0") + + @parameterized.parameters(([],), ([1, 2, 3],)) + def test_shape(self, shape): + with hk.variables(): + v = tf.Variable(tf.ones(shape)) + self.assertEqual(shape, v.shape.as_list()) + + @parameterized.parameters(tf.float32, tf.int32) + def test_dtype(self, dtype): + with hk.variables(): + v = tf.Variable(tf.ones([], dtype=dtype)) + self.assertEqual(dtype, v.dtype) + + def test_attributes_do_not_notify(self): + with hk.variables(): + v = tf.Variable(1.) + s = tf.Variable(1., trainable=False) + + def f(): + for c in (v, s): + self.assertIsNotNone(c.shape) + self.assertIsNotNone(c.dtype) + self.assertIsNotNone(c.trainable) + self.assertIsNotNone(c.name) + self.assertIsNotNone(c.device) + + f = hk.transform_with_state(f) + params, state = f.init() + self.assertEmpty(params) + self.assertEmpty(state) + + out, state = f.apply(params, state) + self.assertIsNone(out) + self.assertEmpty(state) + + def test_read_captured_variables_included(self): + with hk.variables(): + v = tf.Variable(1.) + s = tf.Variable(1., trainable=False) + + f = hk.transform_with_state(lambda: (v.read_value() + s.read_value())) + + params, state = f.init() + self.assertEqual(params, {v.ref(): v.tensor_value}) + self.assertEqual(state, {s.ref(): s.tensor_value}) + + def test_captured_variable_from_other_function_raises(self): + def f(model): + if not model: + model.append(tf.Variable(1.)) + model.append(tf.Variable(1., trainable=False)) + return sum(model) + + f = hk.transform_with_state(f) + + model = [] + params, state = f.init(model) + self.assertLen(params, 1) + self.assertLen(state, 1) + + with self.assertRaisesRegex(ValueError, "TensorVariable .* has no value"): + f.init(model) + + def test_assign(self): + with hk.variables(): + v = tf.Variable(tf.ones([])) + v.assign(tf.zeros([])) + self.assertAllEqual(v.numpy(), 0) + self.assertAllEqual(v.read_value().numpy(), 0) + self.assertAllEqual(v.tensor_value.numpy(), 0) + + def test_assign_add(self): + with hk.variables(): + v = tf.Variable(tf.ones([])) + v.assign_add(1.) 
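+    # assign_add applies the update eagerly; the reads below (numpy,
+    # read_value and the backing tensor_value) all observe the new value.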
+ self.assertAllEqual(v.numpy(), 2) + self.assertAllEqual(v.read_value().numpy(), 2) + self.assertAllEqual(v.tensor_value.numpy(), 2) + + def test_assign_sub(self): + with hk.variables(): + v = tf.Variable(tf.ones([])) + v.assign_sub(1.) + self.assertAllEqual(v.numpy(), 0) + self.assertAllEqual(v.read_value().numpy(), 0) + self.assertAllEqual(v.tensor_value.numpy(), 0) + + +class NetworkTest(test_utils.TestCase, parameterized.TestCase): + + def test_transform(self): + mod = snt.Linear(1, w_init=tf.ones) + snt.allow_empty_variables(mod) + self.assertEmpty(mod.variables) + + f = hk.transform(mod) + x = tf.ones([1, 1]) + + params = f.init(x) + self.assertLen(params.items(), 2) + self.assertAllEqual(params[mod.w.ref()], [[1.]]) + self.assertAllEqual(params[mod.b.ref()], [0.]) + + y = f.apply(params, x) + self.assertEqual(y, [[1.]]) + + params = tree.map_structure(lambda p: p + 1, params) + y = f.apply(params, x) + self.assertEqual(y, [[3.]]) + + def test_initial_values_preserved(self): + with hk.variables(): + v = tf.Variable(0) + v.assign(1) + + def assert_values(): + self.assertEqual(v.initial_tensor_value.numpy(), 0) + self.assertEqual(v.tensor_value.numpy(), 1) + + assert_values() + f = hk.transform(lambda: v.assign(2)) + assert_values() + params = f.init() + assert_values() + f.apply(params) + assert_values() + + def test_variables_in_transform_set_to_none(self): + mod = snt.Bias() + f = hk.transform(mod) + params = f.init(tf.ones([1, 1])) # Will create `mod.b`. + self.assertIsNone(mod.b.tensor_value) + self.assertIsNone(mod.b.initial_tensor_value) + + y = f.apply(params, tf.ones([1, 1])) + self.assertAllEqual(y.numpy(), [[1.]]) + self.assertIsNone(mod.b.tensor_value) + self.assertIsNone(mod.b.initial_tensor_value) + + def test_disallows_variables_in_apply(self): + _, apply_fn = hk.transform(lambda: tf.Variable(1)) + with self.assertRaisesRegex(ValueError, + "Apply function cannot create new variables"): + apply_fn({}) + + def test_state_returns_initial_value(self): + with hk.variables(): + # NOTE: Initial value defined outside transform. + v = tf.Variable(0, trainable=False) + + f = hk.transform_with_state(lambda: v.assign(1)) + params, state = f.init() + initial_v = state[v.ref()] + self.assertEqual(initial_v.numpy(), 0) + + y, state = f.apply(params, state) + final_v = state[v.ref()] + self.assertEqual(y.numpy(), 1) + self.assertEqual(final_v.numpy(), 1) + + def test_state_counter(self): + with hk.variables(): + v = tf.Variable(0, trainable=False) + + f = hk.transform_with_state(lambda: v.assign_add(1)) + params, initial_state = f.init() + for _ in range(2): + state = initial_state + for i in range(10): + y, state = f.apply(params, state) + self.assertEqual(y.numpy(), i + 1) + + def test_state_ema(self): + with hk.variables(): + ema = snt.ExponentialMovingAverage(decay=0.5) + ema = hk.transform_with_state(ema) + + params, state = ema.init(3.0) + y, state = ema.apply(params, state, 3.0) + self.assertAllClose(y.numpy(), 3.0) + y, state = ema.apply(params, state, 6.0) + self.assertAllClose(y.numpy(), 5.0) + +if __name__ == "__main__": + tf.test.main() diff --git a/sonnet/src/functional/jax.py b/sonnet/src/functional/jax.py new file mode 100644 index 00000000..2a0a05c4 --- /dev/null +++ b/sonnet/src/functional/jax.py @@ -0,0 +1,75 @@ +# Copyright 2020 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""A subset of the JAX API in TF2.""" + +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +import functools + +from sonnet.src.functional import utils +import tensorflow as tf +import tree + + +def device_put(t, device=None): + return tree.map_structure(utils.run_on_device(lambda x: x, device), t) + + +def device_get(t): + return tree.map_structure(lambda x: x.numpy(), t) + + +# TODO(tomhennigan) This should be cached. +def jit(f, device=None): + if device is None: + device = utils.get_first_accelerator() + # TODO(tomhennigan) Enable XLA compilation (experimental_compile=True). + return tf.function(utils.run_on_device(f, device)) + + +def grad(f, argnums=0, has_aux=False): + """Returns the gradient function for `f`.""" + value_and_grad_f = value_and_grad(f, argnums=argnums, has_aux=has_aux) + @functools.wraps(f) + def wrapper(*args, **kwargs): + if has_aux: + (_, aux), g = value_and_grad_f(*args, **kwargs) + return g, aux + else: + _, g = value_and_grad_f(*args, **kwargs) + return g + return wrapper + + +def value_and_grad(f, argnums=0, has_aux=False): + """Returns the gradient function for `f`.""" + @functools.wraps(f) + def wrapper(*args, **kwargs): + """Computes `f` and returns derivatives of the output wrt input(s).""" + params = tree.map_structure(args.__getitem__, argnums) + with tf.GradientTape(watch_accessed_variables=False) as tape: + tree.map_structure(tape.watch, params) + out = f(*args, **kwargs) + if has_aux: + out, aux = out + grads = tape.gradient(out, params) + if has_aux: + return (out, aux), grads + else: + return out, grads + return wrapper diff --git a/sonnet/src/functional/jax_test.py b/sonnet/src/functional/jax_test.py new file mode 100644 index 00000000..11d3d1d4 --- /dev/null +++ b/sonnet/src/functional/jax_test.py @@ -0,0 +1,89 @@ +# Copyright 2020 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for Sonnet JAX interop layer.""" + +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +from absl.testing import parameterized +from sonnet.src import test_utils +from sonnet.src.functional import jax +import tensorflow as tf + + +class JaxTest(test_utils.TestCase, parameterized.TestCase): + + def test_jit_copies_to_device(self): + accelerators = get_accelerators() + if not accelerators: + self.skipTest("No accelerator.") + + with tf.device("CPU"): + x = tf.ones([]) + + self.assertTrue(x.device.endswith("CPU:0")) + + for device in accelerators: + y = jax.jit(lambda x: x, device=device)(x) + self.assertTrue(y.device, device) + + def test_device_put(self): + accelerators = get_accelerators() + if not accelerators: + self.skipTest("No accelerator.") + + with tf.device("CPU"): + x = tf.ones([]) + + for device in accelerators: + y = jax.device_put(x, device=device) + self.assertTrue(y.device.endswith(device)) + + +class GradTest(test_utils.TestCase, parameterized.TestCase): + + def test_grad(self): + f = lambda x: x ** 2 + g = jax.grad(f) + x = tf.constant(4.) + self.assertAllClose(g(x).numpy(), (2 * x).numpy()) + + def test_argnums(self): + f = lambda x, y: (x ** 2 + y ** 2) + g = jax.grad(f, argnums=(0, 1)) + x = tf.constant(4.) + y = tf.constant(5.) + gx, gy = g(x, y) + self.assertAllClose(gx.numpy(), (2 * x).numpy()) + self.assertAllClose(gy.numpy(), (2 * y).numpy(), rtol=1e-3) + + def test_has_aux(self): + f = lambda x: (x ** 2, "aux") + g = jax.grad(f, has_aux=True) + x = tf.constant(2.) + gx, aux = g(x) + self.assertAllClose(gx.numpy(), (2 * x).numpy()) + self.assertEqual(aux, "aux") + + +def get_accelerators(): + gpus = tf.config.experimental.list_logical_devices("GPU") + tpus = tf.config.experimental.list_logical_devices("TPU") + return [tf.DeviceSpec.from_string(d.name).to_string() for d in gpus + tpus] + +if __name__ == "__main__": + tf.test.main() diff --git a/sonnet/src/functional/optimizers.py b/sonnet/src/functional/optimizers.py new file mode 100644 index 00000000..8083231e --- /dev/null +++ b/sonnet/src/functional/optimizers.py @@ -0,0 +1,164 @@ +# Copyright 2020 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Functional optimizers.""" + +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +import collections +import functools +from typing import Callable, Type + +from sonnet.src import base +from sonnet.src.functional import haiku +import tensorflow as tf +import tree + +TransformedOptimizer = collections.namedtuple("TransformedOptimizer", + ("init", "apply")) + + +def optimizer(cls: Type[base.Optimizer]) -> Callable[..., TransformedOptimizer]: + """Converts a snt.Optimizer subclass into a functional optimizer. + + To wrap a Sonnet optimizer class simply pass it to :func:`optimizer`: + + >>> adam = snt.functional.optimizer(snt.optimizers.Adam) + + This will give you back a function that drives the constructor of the + optimizer and returns a pair of functions that give you the optimizer state + and a way to apply it: + + >>> optimizer = adam(learning_rate=0.01) + + NOTE: We provide convenience wrappers for the builtin optimizers so you can + just use `opt = snt.functional.adam(learning_rate=0.01)` if you prefer: + + >>> optimizer = snt.functional.adam(learning_rate=0.01) + + To make this example useful lets create a simple network to test: + + >>> with snt.functional.variables(): + ... net = snt.nets.MLP([100, 10]) + + >>> def loss_fn(images, labels): + ... logits = net(images) + ... x_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, + ... labels=labels) + ... loss = tf.reduce_mean(x_ent) + ... return loss + + >>> loss_fn = snt.functional.transform(loss_fn) + + >>> x = tf.ones([1, 1]) + >>> y = tf.constant([1]) + >>> params = loss_fn.init(x, y) + + To get the initial state of our optimizer (e.g. m/v terms in Adam) we need to + run the `optimizer.init` function: + + >>> opt_state = optimizer.init(params) + + Now we can run a single training step by taking gradients of our network and + applying one step of our optimizer: + + >>> grad_apply_net = snt.functional.grad(loss_fn.apply) + + >>> def train_step(x, y, params, opt_state): + ... grads = grad_apply_net(params, x, y) + ... params, opt_state = optimizer.apply(opt_state, grads, params) + ... return params, opt_state + + Teach the network to always predict one: + + >>> target = tf.constant([1]) + >>> dataset = [(tf.random.normal([1, 1]), target) for _ in range(10)] + >>> for x, y in dataset: + ... params, opt_state = train_step(x, y, params, opt_state) + + Args: + cls: A :class:`~sonnet.Optimizer` subclass to functionalize. + + Returns: + A transformed optimizer with `init` and `apply`. See docstring for details. 
+ """ + @functools.wraps(cls.__init__) + def wrapper(*args, **kwargs): + with haiku.variables(): + opt = cls(*args, **kwargs) # pytype: disable=not-instantiable + return _wrap_optimizer(opt) + return wrapper + + +def _split_on_trainable(opt_state): + trainable = {} + non_trainable = {} + for param_ref, value in opt_state.items(): + if param_ref.deref().trainable: + trainable[param_ref] = value + else: + non_trainable[param_ref] = value + return trainable, non_trainable + + +def _merge(a, b): + """Merges two dictionaries and returns a new one.""" + c = dict(a) + c.update(b) + return c + + +def _wrap_optimizer(opt: base.Optimizer) -> TransformedOptimizer: + """Returns a functional optimizer.""" + + def init_opt_fn(params): + """Creates initial optimizer state.""" + def f(params): + params = [p.deref() for p in sorted(params.keys())] + updates = [tf.zeros_like(p) for p in params] + for p, zero in zip(params, updates): + p.assign(zero) + opt.apply(updates, params) + + f = haiku.transform_with_state(f) + + trainable, non_trainable = f.init(params) + opt_state = _merge( + {r: v for r, v in trainable.items() if r not in params}, + {r: v for r, v in non_trainable.items() if r not in params}) + + return opt_state + + def apply_opt_fn(opt_state, updates, params): + """Applies the optimizer and returns updated parameters and opt state.""" + def f(opt_state, params, updates): + flat_params = [p.deref() for p in sorted(params)] + updates = tree.flatten(updates) + opt.apply(updates, flat_params) + params = {r: r.deref().tensor_value for r in params} + opt_state = {r: r.deref().tensor_value for r in opt_state} + return params, opt_state + + f = haiku.transform_with_state(f) + + trainable_opt_state, non_trainable = _split_on_trainable(opt_state) + trainable = _merge(params, trainable_opt_state) + (params, opt_state), _ = f.apply(trainable, non_trainable, + opt_state, params, updates) + return params, opt_state + + return TransformedOptimizer(init=init_opt_fn, apply=apply_opt_fn) diff --git a/sonnet/src/functional/optimizers_test.py b/sonnet/src/functional/optimizers_test.py new file mode 100644 index 00000000..4593549e --- /dev/null +++ b/sonnet/src/functional/optimizers_test.py @@ -0,0 +1,80 @@ +# Copyright 2020 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Tests for functional optimizers.""" + +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +from absl.testing import parameterized +import sonnet as snt +from sonnet.src import test_utils +from sonnet.src.functional import haiku +from sonnet.src.functional import optimizers +import tensorflow as tf +import tree + +sgd = optimizers.optimizer(snt.optimizers.SGD) +adam = optimizers.optimizer(snt.optimizers.Adam) + + +class OptimizersTest(test_utils.TestCase, parameterized.TestCase): + + def test_sgd(self): + with haiku.variables(): + params = [tf.Variable(1.)] + params = {p.ref(): tf.ones_like(p) for p in params} + + opt = sgd(learning_rate=0.01) + opt_state = opt.init(params) + grads = tree.map_structure(tf.ones_like, params) + params, opt_state = opt.apply(opt_state, grads, params) + p, = tree.flatten(params) + self.assertAllClose(p.numpy(), 1. - (0.01 * 1)) + + def test_adam(self): + lin = haiku.transform(snt.Linear(1)) + x = tf.ones([1, 1]) + params = lin.init(x) + + optimizer = adam(learning_rate=0.01) + opt_state = optimizer.init(params) + # Step + (m, v) per parameter. + self.assertLen(tree.flatten(opt_state), 5) + + @parameterized.parameters(True, False) + def test_adam_with_variable_lr(self, trainable_lr): + lin = haiku.transform(snt.Linear(1)) + x = tf.ones([1, 1]) + initial_params = lin.init(x) + + with haiku.variables(): + lr = tf.Variable(0.01, trainable=trainable_lr, name="lr") + + optimizer = adam(learning_rate=lr) + initial_opt_state = optimizer.init(initial_params) + # Learning rate, step + (m, v) per parameter. + self.assertLen(tree.flatten(initial_opt_state), 6) + + grads = tree.map_structure(tf.ones_like, initial_params) + params, opt_state = optimizer.apply( + initial_opt_state, grads, initial_params) + + tree.assert_same_structure(initial_opt_state, opt_state) + tree.assert_same_structure(initial_params, params) + +if __name__ == "__main__": + tf.test.main() diff --git a/sonnet/src/functional/utils.py b/sonnet/src/functional/utils.py new file mode 100644 index 00000000..734b692e --- /dev/null +++ b/sonnet/src/functional/utils.py @@ -0,0 +1,80 @@ +# Copyright 2020 The Sonnet Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Utility functions for the JAX API in TF2.""" + +from __future__ import absolute_import +from __future__ import division +# from __future__ import google_type_annotations +from __future__ import print_function + +import functools + +from sonnet.src import utils +import tensorflow as tf +import tree + + +def get_first_accelerator(): + tpus = tf.config.experimental.list_logical_devices("TPU") + if tpus: + return tpus[0].name + else: + gpus = tf.config.experimental.list_logical_devices("GPU") + return gpus[0].name if gpus else "/device:CPU:0" + + +def run_on_device(f, device): + """Runs `f` under a tf.device context on the given device.""" + f = utils.smart_autograph(f) + + @tf.autograph.experimental.do_not_convert + @functools.wraps(f) + def wrapper(*args, **kwargs): + with tf.device(device): + args = tree.map_structure(tf.identity, args) + kwargs = tree.map_structure(tf.identity, kwargs) + return f(*args, **kwargs) + return wrapper + + +def get_name_scope(): + with tf.name_scope("x") as ns: + return ns[:-2] + + +def first_non_none(*args): + return next(a for a in args if a is not None) + + +def compose(f0, *fs): + """Composes a sequence of functions. + + >>> f1 = lambda a, b: f"f1({a}, {b})" + >>> f2 = lambda a: f"f2({a})" + >>> f3 = lambda a: f"f3({a})" + >>> f = compose(f1, f2, f3) + >>> f("a", "b") + 'f3(f2(f1(a, b)))' + + Args: + f0: The first function to apply. + *fs: Other functions to apply in sequence. + + Returns: + A function that is the composition of the input functions. + """ + def wrapper(*args, **kwargs): + return functools.reduce(lambda x, f: f(x), fs, f0(*args, **kwargs)) + return wrapper
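
Putting the new API together: the standalone sketch below (illustration only, not part of the change) shows how `snt.functional.variables`, `transform`, `value_and_grad`, `jit` and the wrapped optimizers are intended to compose into a train step. The toy regression network, shapes and learning rate are invented for the example; only the functions exported from `sonnet/functional.py` above are assumed.

    import sonnet as snt
    import tensorflow as tf

    fn = snt.functional

    # Modules must be built under the TensorVariable creator scope so that
    # their parameters can later be lifted out as pure values.
    with fn.variables():
      net = snt.nets.MLP([32, 1])

    def loss_fn(x, y):
      return tf.reduce_mean(tf.square(net(x) - y))

    # `transform` gives us a pure (init, apply) pair; `adam` wraps
    # snt.optimizers.Adam the same way.
    loss_fn = fn.transform(loss_fn)
    optimizer = fn.adam(learning_rate=0.01)

    x = tf.random.normal([8, 4])
    y = tf.random.normal([8, 1])

    params = fn.device_put(loss_fn.init(x, y))
    opt_state = fn.device_put(optimizer.init(params))

    grad_fn = fn.jit(fn.value_and_grad(loss_fn.apply))
    step_fn = fn.jit(optimizer.apply)

    for _ in range(10):
      loss, grads = grad_fn(params, x, y)
      params, opt_state = step_fn(opt_state, grads, params)

    print("final loss:", loss.numpy())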