Add functional API to Sonnet 2 inspired by JAX [0] and Haiku [1].
tomhennigan authored and copybara-github committed Jun 18, 2021
1 parent 6c2bf5f commit 92c43b5
Showing 16 changed files with 1,494 additions and 1 deletion.
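
In outline, the new `snt.functional` API follows the Haiku pattern: construct modules under a variable-tracking context, transform functions that close over them into an init/apply pair, and thread parameters around explicitly. A minimal sketch based on the files in this commit (the gradient-with-respect-to-first-argument semantics of `fn.grad` is an assumption carried over from JAX):

import sonnet as snt
import tensorflow as tf

fn = snt.functional

# Modules built under `fn.variables()` have their variable creation tracked,
# so their parameters can later be lifted out as explicit state.
with fn.variables():
  mod = snt.Linear(1)

def loss_fn(x, y):
  return tf.reduce_mean(tf.square(mod(x) - y))

# `fn.transform` unzips loss_fn into an init/apply pair.
loss_fn = fn.transform(loss_fn)

x, y = tf.ones([4, 3]), tf.ones([4, 1])
params = loss_fn.init(x, y)                   # Parameters as an explicit value.
grads = fn.grad(loss_fn.apply)(params, x, y)  # Assumed: grads w.r.t. params.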
1 change: 1 addition & 0 deletions docs/conf.py
@@ -98,6 +98,7 @@
doctest_global_setup = """
import tensorflow as tf
import sonnet as snt
import tree
# `TpuReplicator` cannot be constructed without a TPU, however it has exactly
# the same API as `Replicator` so we can run doctests using that instead.
14 changes: 14 additions & 0 deletions examples/BUILD
@@ -39,3 +39,17 @@ snt_py_test(
# pip: tensorflow
],
)

py_binary(
name = "functional_mlp_mnist",
srcs = ["functional_mlp_mnist.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
# pip: absl:app
# pip: absl/logging
"//sonnet",
# pip: tensorflow
# pip: tensorflow_datasets
],
)
120 changes: 120 additions & 0 deletions examples/functional_mlp_mnist.py
@@ -0,0 +1,120 @@
# Copyright 2020 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Toy MLP on MNIST example of using TF2 JAX/HK shims."""

from absl import app
from absl import logging
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds

fn = snt.functional


def main(unused_argv):
del unused_argv

with fn.variables():
net = snt.nets.MLP([1000, 100, 10])

def loss_fn(images, labels):
images = snt.flatten(images)
logits = net(images)
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
logits=logits))
return loss

loss_fn = fn.transform(loss_fn)

def preprocess(images, labels):
images = tf.image.convert_image_dtype(images, tf.float32)
return images, labels

# _ _
# | |_ _ __ __ _(_)_ __
# | __| '__/ _` | | '_ \
# | |_| | | (_| | | | | |
# \__|_| \__,_|_|_| |_|
#

batch_size = 100

dataset = tfds.load("mnist", split="train", as_supervised=True)
dataset = dataset.map(preprocess)
dataset = dataset.cache()
dataset = dataset.shuffle(batch_size * 8)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

# Create an Adam optimizer. Like the transformed loss_fn, the optimizer
# splits into an init function (creating optimizer state from the params)
# and an apply function (updating params given gradients).
optimizer = fn.adam(0.01)

# To get our initial state we need to pull a record from our dataset and pass
# it to our init function. We'll also be sure to use `device_put` such that
# the parameters are on the accelerator.
images, labels = next(iter(dataset))
params = fn.device_put(loss_fn.init(images, labels))
opt_state = fn.device_put(optimizer.init(params))

# Our training loop iterates through 10 epochs of the train dataset, and
# after each minibatch applies the optimizer to update our parameters
# according to the gradient of our loss function.
grad_apply_fn = fn.jit(fn.value_and_grad(loss_fn.apply))
apply_opt_fn = fn.jit(optimizer.apply)

for epoch in range(10):
for images, labels in dataset:
loss, grads = grad_apply_fn(params, images, labels)
params, opt_state = apply_opt_fn(opt_state, grads, params)
logging.info("[Epoch %s] loss=%s", epoch, loss.numpy())

# _ _
# | |_ ___ ___| |_
# | __/ _ \/ __| __|
# | || __/\__ \ |_
# \__\___||___/\__|
#

def accuracy_fn(images, labels):
images = snt.flatten(images)
predictions = tf.argmax(net(images), axis=1)
correct = tf.math.count_nonzero(tf.equal(predictions, labels))
total = tf.shape(labels)[0]
return correct, total

accuracy_fn = fn.transform(accuracy_fn)

batch_size = 10000
dataset = tfds.load("mnist", split="test", as_supervised=True)
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.map(preprocess)

# Note that while we still unzip our accuracy function, we can ignore the
# init_fn here since we already have all the state we need from our training
# function.
apply_fn = fn.jit(accuracy_fn.apply)

# Compute top-1 accuracy.
num_correct = num_total = 0
for images, labels in dataset:
correct, total = apply_fn(params, images, labels)
num_correct += correct
num_total += total
accuracy = (int(num_correct) / int(num_total)) * 100
logging.info("Accuracy %.5f%%", accuracy)

if __name__ == "__main__":
app.run(main)
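
The example uses `fn.device_put` to move the initial parameters onto the accelerator. The shim also exports the inverse, `fn.device_get`; a one-line hedged sketch continuing the example above (assuming it mirrors `jax.device_get`, as the `sonnet/src/functional/jax.py` naming suggests, returning host-resident values suitable for e.g. checkpointing):

# Assumption: fn.device_get mirrors jax.device_get and copies the (possibly
# accelerator-resident) parameter structure back to host memory.
host_params = fn.device_get(params)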
12 changes: 12 additions & 0 deletions sonnet/BUILD
@@ -12,6 +12,7 @@ snt_py_library(
visibility = ["//visibility:public"],
deps = [
":distribute",
":functional",
":initializers",
":mixed_precision",
":optimizers",
@@ -54,6 +55,17 @@ snt_py_library(
],
)

snt_py_library(
name = "functional",
srcs = ["functional.py"],
deps = [
":optimizers",
"//sonnet/src/functional:haiku",
"//sonnet/src/functional:jax",
"//sonnet/src/functional:optimizers",
],
)

snt_py_library(
name = "initializers",
srcs = ["initializers.py"],
2 changes: 2 additions & 0 deletions sonnet/__init__.py
@@ -19,6 +19,7 @@
from __future__ import print_function

from sonnet import distribute
from sonnet import functional
from sonnet import initializers
from sonnet import mixed_precision
from sonnet import nets
@@ -133,6 +134,7 @@
"distribute",
"dynamic_unroll",
"format_variables",
"functional",
"initializers",
"log_variables",
"lstm_with_recurrent_dropout",
67 changes: 67 additions & 0 deletions sonnet/functional.py
@@ -0,0 +1,67 @@
# Copyright 2020 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Simple functional APIs for TF2."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

from sonnet import optimizers as oo_optimizers
from sonnet.src.functional import haiku
from sonnet.src.functional import jax
from sonnet.src.functional import optimizers

# Utilities for converting Sonnet code into pure functions.
variables = haiku.variables
transform = haiku.transform
transform_with_state = haiku.transform_with_state
without_state = haiku.without_state

# Utilities for working with tensors on device.
device_get = jax.device_get
device_put = jax.device_put

# Utilities for transforming pure functions.
grad = jax.grad
jit = jax.jit
value_and_grad = jax.value_and_grad

# Optimizers.
optimizer = optimizers.optimizer
sgd = optimizer(oo_optimizers.SGD)
adam = optimizer(oo_optimizers.Adam)
rmsprop = optimizer(oo_optimizers.RMSProp)
momentum = optimizer(oo_optimizers.Momentum)

# Avoid accidentally exporting the private API.
del oo_optimizers, haiku, optimizers, jax

__all__ = (
"variables",
"transform",
"transform_with_state",
"without_state",
"device_get",
"device_put",
"grad",
"jit",
"value_and_grad",
"optimizer",
"sgd",
"adam",
"rmsprop",
"momentum",
)
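
`transform_with_state` and `without_state` are exported here but not exercised by the MNIST example. A hedged sketch of how a stateful module might flow through them, assuming the semantics mirror Haiku's `transform_with_state`; the exact signatures and return values below are an assumption, not confirmed by this commit:

import sonnet as snt
import tensorflow as tf

fn = snt.functional

with fn.variables():
  bn = snt.BatchNorm(create_scale=True, create_offset=True)

def f(x):
  # BatchNorm updates moving averages: non-parameter state that must be
  # threaded explicitly in a functional setting.
  return bn(x, is_training=True)

f = fn.transform_with_state(f)

x = tf.ones([4, 3])
params, state = f.init(x)             # Assumed: parameters and state split.
y, state = f.apply(params, state, x)  # Assumed: updated state returned.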
1 change: 1 addition & 0 deletions sonnet/src/conformance/BUILD
@@ -76,6 +76,7 @@ snt_py_test(
"//sonnet",
"//sonnet/src:test_utils",
# pip: tensorflow
# pip: tree
],
)

4 changes: 3 additions & 1 deletion sonnet/src/conformance/doctest_test.py
@@ -25,6 +25,7 @@
import sonnet as snt
from sonnet.src import test_utils
import tensorflow as tf
import tree


class DoctestTest(test_utils.TestCase, parameterized.TestCase):
@@ -62,7 +63,8 @@ def test_doctest(self, module):
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
extraglobs={
"snt": snt,
"tf": tf
"tf": tf,
"tree": tree,
})
if num_attempted == 0:
self.skipTest("No doctests in %s" % module.__name__)
82 changes: 82 additions & 0 deletions sonnet/src/functional/BUILD
@@ -0,0 +1,82 @@
load("//sonnet/src:build_defs.bzl", "snt_py_library", "snt_py_test")

package(default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"])

licenses(["notice"])

snt_py_library(
name = "haiku",
srcs = ["haiku.py"],
deps = [
":utils",
# pip: tensorflow
],
)

snt_py_library(
name = "jax",
srcs = ["jax.py"],
deps = [
":utils",
# pip: tensorflow
# pip: tree
],
)

snt_py_library(
name = "optimizers",
srcs = ["optimizers.py"],
deps = [
":haiku",
"//sonnet/src:base",
# pip: tensorflow
# pip: tree
],
)

snt_py_library(
name = "utils",
srcs = ["utils.py"],
deps = [
"//sonnet/src:utils",
# pip: tensorflow
# pip: tree
],
)

snt_py_test(
name = "haiku_test",
srcs = ["haiku_test.py"],
deps = [
":haiku",
# pip: absl/testing:parameterized
"//sonnet",
"//sonnet/src:test_utils",
# pip: tensorflow
# pip: tree
],
)

snt_py_test(
name = "jax_test",
srcs = ["jax_test.py"],
deps = [
":jax",
# pip: absl/testing:parameterized
"//sonnet/src:test_utils",
# pip: tensorflow
],
)

snt_py_test(
name = "optimizers_test",
srcs = ["optimizers_test.py"],
deps = [
":optimizers",
# pip: absl/testing:parameterized
"//sonnet",
"//sonnet/src:test_utils",
# pip: tensorflow
# pip: tree
],
)