This repository has been archived by the owner on Feb 26, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
223 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import numpy as np | ||
import pytest | ||
import treex as tx | ||
|
||
|
||
class TestMLP: | ||
def test_basic(self): | ||
mlp = tx.MLP([2, 32, 8, 4]).init(42) | ||
|
||
x = np.random.uniform(-1, 1, (10, 2)) | ||
y = mlp(x) | ||
|
||
assert y.shape == (10, 4) | ||
|
||
def test_too_few_features(self): | ||
|
||
with pytest.raises(ValueError): | ||
mlp = tx.MLP([2]).init(42) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import jax | ||
import numpy as np | ||
import pytest | ||
import treex as tx | ||
|
||
|
||
class TestSequence: | ||
def test_basic(self): | ||
mlp = tx.Sequence( | ||
tx.Linear(2, 32), | ||
jax.nn.relu, | ||
tx.Linear(32, 8), | ||
jax.nn.relu, | ||
tx.Linear(8, 4), | ||
).init(42) | ||
|
||
assert isinstance(mlp.layers[1], tx.Lambda) | ||
assert isinstance(mlp.layers[3], tx.Lambda) | ||
|
||
x = np.random.uniform(-1, 1, (10, 2)) | ||
y = mlp(x) | ||
|
||
assert y.shape == (10, 4) | ||
|
||
def test_pytree(self): | ||
mlp = tx.Sequence( | ||
tx.Linear(2, 32), | ||
jax.nn.relu, | ||
tx.Linear(32, 8), | ||
jax.nn.relu, | ||
tx.Linear(8, 4), | ||
).init(42) | ||
|
||
jax.tree_map(lambda x: 2 * x, mlp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from .linear import Linear | ||
from .batch_norm import BatchNorm | ||
from .sequence import sequence | ||
from .dropout import Dropout | ||
from .conv import Conv | ||
from .dropout import Dropout | ||
from .linear import Linear | ||
from .mlp import MLP | ||
from .sequence import sequence, Sequence, Lambda |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import typing as tp | ||
from flax.linen import linear as flax_module | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
|
||
from treex.module import Module | ||
from treex import types | ||
from treex.nn.linear import Linear | ||
|
||
|
||
class MLP(Module): | ||
"""A Multi-Layer Perceptron (MLP) that applies a sequence of linear layers | ||
with relu activations, the last layer is linear. | ||
""" | ||
|
||
# pytree | ||
layers: tp.List[Linear] | ||
|
||
# props | ||
features: tp.Sequence[int] | ||
module: flax_module.Dense | ||
|
||
def __init__( | ||
self, | ||
features: tp.Sequence[int], | ||
use_bias: bool = True, | ||
dtype: tp.Any = jnp.float32, | ||
precision: tp.Any = None, | ||
kernel_init: tp.Callable[ | ||
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], | ||
flax_module.Array, | ||
] = flax_module.default_kernel_init, | ||
bias_init: tp.Callable[ | ||
[flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], | ||
flax_module.Array, | ||
] = flax_module.zeros, | ||
): | ||
""" | ||
Arguments: | ||
features: a sequence of L+1 integers, where L is the number of layers, | ||
the first integer is the number of input features and all subsequent | ||
integers are the number of output features of the respective layer. | ||
use_bias: whether to add a bias to the output (default: True). | ||
dtype: the dtype of the computation (default: float32). | ||
precision: numerical precision of the computation see `jax.lax.Precision` | ||
for details. | ||
kernel_init: initializer function for the weight matrix. | ||
bias_init: initializer function for the bias. | ||
""" | ||
if len(features) < 2: | ||
raise ValueError("features must have at least 2 elements") | ||
|
||
self.features = features | ||
self.layers = [ | ||
Linear( | ||
features_in=features_in, | ||
features_out=features_out, | ||
use_bias=use_bias, | ||
dtype=dtype, | ||
precision=precision, | ||
kernel_init=kernel_init, | ||
bias_init=bias_init, | ||
) | ||
for features_in, features_out in zip(features[:-1], features[1:]) | ||
] | ||
|
||
def __call__(self, x: np.ndarray) -> jnp.ndarray: | ||
""" | ||
Applies the MLP to the input. | ||
Arguments: | ||
x: input array. | ||
Returns: | ||
The output of the MLP. | ||
""" | ||
for layer in self.layers[:-1]: | ||
x = jax.nn.relu(layer(x)) | ||
|
||
return self.layers[-1](x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
82e612d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
Sequential
would be a much better name because Keras, Pytorch, and sonnet all have such a nameSequential
rather thanSequence
.82e612d
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wookayin this is a great point!
I've create #3 to address this.