Add functional API to Sonnet 2 inspired by JAX [0] and Haiku [1].

[0] https://github.com/google/jax
[1] https://github.com/deepmind/dm-haiku

PiperOrigin-RevId: 317279063
Commit 92c43b5 (parent 6c2bf5f). Showing 16 changed files with 1,494 additions and 1 deletion.
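The pattern this commit adds mirrors Haiku's transform: wrap Sonnet module construction in fn.variables(), then unzip any function that uses the module into a pure init/apply pair. A minimal sketch of the idea, based on the example file below (the MLP shape and input here are illustrative, not from the commit):

import sonnet as snt
import tensorflow as tf

fn = snt.functional

# Capture variable creation so the network can be used functionally.
with fn.variables():
  net = snt.nets.MLP([10, 2])

def forward(x):
  return net(x)

# Unzip into a pure init/apply pair, as in Haiku.
forward = fn.transform(forward)

x = tf.ones([4, 8])
params = forward.init(x)      # Build initial parameters from example input.
y = forward.apply(params, x)  # Pure function of (params, inputs).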
@@ -0,0 +1,120 @@
# Copyright 2020 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Toy MLP on MNIST example of using TF2 JAX/HK shims."""

from absl import app
from absl import logging
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds

fn = snt.functional


def main(unused_argv):
  del unused_argv

  with fn.variables():
    net = snt.nets.MLP([1000, 100, 10])

  def loss_fn(images, labels):
    images = snt.flatten(images)
    logits = net(images)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                       logits=logits))
    return loss

  loss_fn = fn.transform(loss_fn)

  def preprocess(images, labels):
    images = tf.image.convert_image_dtype(images, tf.float32)
    return images, labels

  #  _             _
  # | |_ _ __ __ _(_)_ __
  # | __| '__/ _` | | '_ \
  # | |_| | | (_| | | | | |
  #  \__|_| \__,_|_|_| |_|
  #

  batch_size = 100

  dataset = tfds.load("mnist", split="train", as_supervised=True)
  dataset = dataset.map(preprocess)
  dataset = dataset.cache()
  dataset = dataset.shuffle(batch_size * 8)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  # Prefetch requires a buffer size; AUTOTUNE lets tf.data choose one.
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  # As with loss_fn, the optimizer is unzipped into init and apply functions.
  optimizer = fn.adam(0.01)

  # To get our initial state we need to pull a record from our dataset and pass
  # it to our init function. We'll also be sure to use `device_put` such that
  # the parameters are on the accelerator.
  images, labels = next(iter(dataset))
  params = fn.device_put(loss_fn.init(images, labels))
  opt_state = fn.device_put(optimizer.init(params))

  # Our training loop iterates through 10 epochs of the train dataset, and
  # after each minibatch applies the optimizer to update our parameters
  # according to the gradient from our loss function.
  grad_apply_fn = fn.jit(fn.value_and_grad(loss_fn.apply))
  apply_opt_fn = fn.jit(optimizer.apply)

  for epoch in range(10):
    for images, labels in dataset:
      loss, grads = grad_apply_fn(params, images, labels)
      params, opt_state = apply_opt_fn(opt_state, grads, params)
    logging.info("[Epoch %s] loss=%s", epoch, loss.numpy())

  #  _            _
  # | |_ ___  ___| |_
  # | __/ _ \/ __| __|
  # | ||  __/\__ \ |_
  #  \__\___||___/\__|
  #

  def accuracy_fn(images, labels):
    images = snt.flatten(images)
    predictions = tf.argmax(net(images), axis=1)
    correct = tf.math.count_nonzero(tf.equal(predictions, labels))
    total = tf.shape(labels)[0]
    return correct, total

  accuracy_fn = fn.transform(accuracy_fn)

  batch_size = 10000
  dataset = tfds.load("mnist", split="test", as_supervised=True)
  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.map(preprocess)

  # Note that while we still unzip our accuracy function, we can ignore the
  # init_fn here since we already have all the state we need from our training
  # function.
  apply_fn = fn.jit(accuracy_fn.apply)

  # Compute top-1 accuracy.
  num_correct = num_total = 0
  for images, labels in dataset:
    correct, total = apply_fn(params, images, labels)
    num_correct += correct
    num_total += total
  accuracy = (int(num_correct) / int(num_total)) * 100
  logging.info("Accuracy %.5f%%", accuracy)


if __name__ == "__main__":
  app.run(main)
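The example uses fn.value_and_grad, but the module (below) also exports fn.grad. Assuming it mirrors jax.grad and differentiates with respect to the first argument of loss_fn.apply, the inner training step could equivalently drop the loss value:

# Hedged sketch: assumes fn.grad follows jax.grad semantics (gradients
# with respect to the first argument, here the parameters).
grad_fn = fn.jit(fn.grad(loss_fn.apply))

for images, labels in dataset:
  grads = grad_fn(params, images, labels)  # Gradients only, no loss value.
  params, opt_state = apply_opt_fn(opt_state, grads, params)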
@@ -0,0 +1,67 @@
# Copyright 2020 The Sonnet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Simple functional APIs for TF2."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

from sonnet import optimizers as oo_optimizers
from sonnet.src.functional import haiku
from sonnet.src.functional import jax
from sonnet.src.functional import optimizers

# Utilities for converting Sonnet code into pure functions.
variables = haiku.variables
transform = haiku.transform
transform_with_state = haiku.transform_with_state
without_state = haiku.without_state

# Utilities for working with tensors on device.
device_get = jax.device_get
device_put = jax.device_put

# Utilities for transforming pure functions.
grad = jax.grad
jit = jax.jit
value_and_grad = jax.value_and_grad

# Optimizers.
optimizer = optimizers.optimizer
sgd = optimizer(oo_optimizers.SGD)
adam = optimizer(oo_optimizers.Adam)
rmsprop = optimizer(oo_optimizers.RMSProp)
momentum = optimizer(oo_optimizers.Momentum)

# Avoid accidentally exporting the private API.
del oo_optimizers, haiku, optimizers, jax

__all__ = (
    "variables",
    "transform",
    "transform_with_state",
    "without_state",
    "device_get",
    "device_put",
    "grad",
    "jit",
    "value_and_grad",
    "optimizer",
    "sgd",
    "adam",
    "rmsprop",
    "momentum",
)
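Each exported optimizer wraps its object-oriented counterpart (for example snt.optimizers.SGD) in the same init/apply shape the MNIST example uses for fn.adam. A sketch of the calling convention, assuming it matches that example (the learning rate here is illustrative):

optimizer = fn.sgd(0.1)             # Wraps snt.optimizers.SGD.
opt_state = optimizer.init(params)  # Optimizer state built from the parameters.
params, opt_state = optimizer.apply(opt_state, grads, params)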
@@ -76,6 +76,7 @@ snt_py_test(
        "//sonnet",
        "//sonnet/src:test_utils",
        # pip: tensorflow
        # pip: tree
    ],
)
@@ -0,0 +1,82 @@
load("//sonnet/src:build_defs.bzl", "snt_py_library", "snt_py_test")

package(default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"])

licenses(["notice"])

snt_py_library(
    name = "haiku",
    srcs = ["haiku.py"],
    deps = [
        ":utils",
        # pip: tensorflow
    ],
)

snt_py_library(
    name = "jax",
    srcs = ["jax.py"],
    deps = [
        ":utils",
        # pip: tensorflow
        # pip: tree
    ],
)

snt_py_library(
    name = "optimizers",
    srcs = ["optimizers.py"],
    deps = [
        ":haiku",
        "//sonnet/src:base",
        # pip: tensorflow
        # pip: tree
    ],
)

snt_py_library(
    name = "utils",
    srcs = ["utils.py"],
    deps = [
        "//sonnet/src:utils",
        # pip: tensorflow
        # pip: tree
    ],
)

snt_py_test(
    name = "haiku_test",
    srcs = ["haiku_test.py"],
    deps = [
        ":haiku",
        # pip: absl/testing:parameterized
        "//sonnet",
        "//sonnet/src:test_utils",
        # pip: tensorflow
        # pip: tree
    ],
)

snt_py_test(
    name = "jax_test",
    srcs = ["jax_test.py"],
    deps = [
        ":jax",
        # pip: absl/testing:parameterized
        "//sonnet/src:test_utils",
        # pip: tensorflow
    ],
)

snt_py_test(
    name = "optimizers_test",
    srcs = ["optimizers_test.py"],
    deps = [
        ":optimizers",
        # pip: absl/testing:parameterized
        "//sonnet",
        "//sonnet/src:test_utils",
        # pip: tensorflow
        # pip: tree
    ],
)