This repository has been archived by the owner on Feb 26, 2023. It is now read-only.

add Sequence
cgarciae committed Aug 26, 2021
1 parent ff372b7 commit 82e612d
Showing 9 changed files with 223 additions and 11 deletions.
18 changes: 18 additions & 0 deletions tests/nn/test_mlp.py
@@ -0,0 +1,18 @@
import numpy as np
import pytest
import treex as tx


class TestMLP:
    def test_basic(self):
        mlp = tx.MLP([2, 32, 8, 4]).init(42)

        x = np.random.uniform(-1, 1, (10, 2))
        y = mlp(x)

        assert y.shape == (10, 4)

    def test_too_few_features(self):

        with pytest.raises(ValueError):
            mlp = tx.MLP([2]).init(42)
34 changes: 34 additions & 0 deletions tests/nn/test_sequence.py
@@ -0,0 +1,34 @@
import jax
import numpy as np
import pytest
import treex as tx


class TestSequence:
    def test_basic(self):
        mlp = tx.Sequence(
            tx.Linear(2, 32),
            jax.nn.relu,
            tx.Linear(32, 8),
            jax.nn.relu,
            tx.Linear(8, 4),
        ).init(42)

        assert isinstance(mlp.layers[1], tx.Lambda)
        assert isinstance(mlp.layers[3], tx.Lambda)

        x = np.random.uniform(-1, 1, (10, 2))
        y = mlp(x)

        assert y.shape == (10, 4)

    def test_pytree(self):
        mlp = tx.Sequence(
            tx.Linear(2, 32),
            jax.nn.relu,
            tx.Linear(32, 8),
            jax.nn.relu,
            tx.Linear(8, 4),
        ).init(42)

        jax.tree_map(lambda x: 2 * x, mlp)
10 changes: 7 additions & 3 deletions treex/module.py
@@ -269,7 +269,6 @@ def double_params(self):
        Returns:
            A new module with the updated values or `None` if `inplace` is `True`.
        """
-        modules = (self, other) + rest

@@ -727,7 +726,12 @@ def _resolve_tree_type(name: str, t: tp.Optional[type]) -> tp.Optional[type]:
    if t is None:
        return None

-    tree_types = [x for x in _all_types(t) if issubclass(x, (types.TreePart, Module))]
+    tree_types = [
+        x
+        for x in _all_types(t)
+        if isinstance(x, tp.Type)
+        if issubclass(x, (types.TreePart, Module))
+    ]

    if len(tree_types) > 1:
        # if its a type with many Module subtypes just mark them all as Module
@@ -740,7 +744,7 @@ def _resolve_tree_type(name: str, t: tp.Optional[type]) -> tp.Optional[type]:
    elif len(tree_types) == 1:
        return tree_types[0]
    else:
-        return t
+        return None


def _all_types(t: tp.Type) -> tp.Iterable[tp.Type]:
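A brief note on the `_resolve_tree_type` change above: the added `isinstance(x, tp.Type)` guard suggests that `_all_types` can yield typing constructs (generic aliases such as `tp.List[int]`) that are not real classes, and `issubclass` raises `TypeError` for those. The sketch below is illustrative only and not part of the commit; it just shows the failure mode the guard appears to avoid.

```python
import typing as tp

# Illustrative sketch (not from the commit): issubclass rejects typing aliases,
# which is presumably why _resolve_tree_type now filters to real classes first.
alias = tp.List[int]

try:
    issubclass(alias, list)
except TypeError as err:
    print("issubclass rejected the alias:", err)

print(issubclass(list, object))  # plain classes work fine: True
```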
7 changes: 4 additions & 3 deletions treex/nn/__init__.py
@@ -1,5 +1,6 @@
-from .linear import Linear
 from .batch_norm import BatchNorm
-from .sequence import sequence
-from .dropout import Dropout
 from .conv import Conv
+from .dropout import Dropout
+from .linear import Linear
+from .mlp import MLP
+from .sequence import sequence, Sequence, Lambda
6 changes: 3 additions & 3 deletions treex/nn/batch_norm.py
@@ -103,9 +103,9 @@ def module_init(self, key: jnp.ndarray):

        # Extract collections
        if "params" in variables:
-            self.params = variables["params"]
+            self.params = variables["params"].unfreeze()

-        self.batch_stats = variables["batch_stats"]
+        self.batch_stats = variables["batch_stats"].unfreeze()

    def __call__(
        self, x: np.ndarray, use_running_average: tp.Optional[bool] = None
@@ -144,6 +144,6 @@ def __call__(

        # update batch_stats
        if "batch_stats" in variables:
-            self.batch_stats = variables["batch_stats"]
+            self.batch_stats = variables["batch_stats"].unfreeze()

        return tp.cast(jnp.ndarray, output)
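A side note on the repeated `.unfreeze()` edits in `batch_norm.py`, `conv.py`, and `linear.py`: Flax's `init` returns variable collections as immutable `FrozenDict`s, and `.unfreeze()` converts them to plain nested `dict`s, which is presumably what treex wants to hold as mutable module state. A minimal sketch of that conversion, assuming a standard Flax install:

```python
from flax.core import freeze

# Sketch only: FrozenDict is immutable, .unfreeze() returns a plain nested dict.
variables = freeze({"params": {"kernel": [[1.0]], "bias": [0.0]}})

params = variables["params"]        # FrozenDict
mutable = params.unfreeze()         # dict: leaves can now be reassigned
mutable["bias"] = [1.0]

print(type(params).__name__, type(mutable).__name__)  # FrozenDict dict
```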
2 changes: 1 addition & 1 deletion treex/nn/conv.py
@@ -112,7 +112,7 @@ def module_init(self, key: jnp.ndarray):
        variables = self.module.init(key, x)

        # Extract collections
-        self.params = variables["params"]
+        self.params = variables["params"].unfreeze()

    def __call__(self, x: np.ndarray) -> jnp.ndarray:
        """Applies a convolution to the inputs.
2 changes: 1 addition & 1 deletion treex/nn/linear.py
@@ -71,7 +71,7 @@ def module_init(self, key: jnp.ndarray):
        variables = self.module.init(key, x)

        # Extract collections
-        self.params = variables["params"]
+        self.params = variables["params"].unfreeze()

    def __call__(self, x: np.ndarray) -> jnp.ndarray:
        """Applies a linear transformation to the inputs along the last dimension.
81 changes: 81 additions & 0 deletions treex/nn/mlp.py
@@ -0,0 +1,81 @@
import typing as tp
from flax.linen import linear as flax_module
import jax
import jax.numpy as jnp
import numpy as np

from treex.module import Module
from treex import types
from treex.nn.linear import Linear


class MLP(Module):
    """A Multi-Layer Perceptron (MLP) that applies a sequence of linear layers
    with relu activations, the last layer is linear.
    """

    # pytree
    layers: tp.List[Linear]

    # props
    features: tp.Sequence[int]
    module: flax_module.Dense

    def __init__(
        self,
        features: tp.Sequence[int],
        use_bias: bool = True,
        dtype: tp.Any = jnp.float32,
        precision: tp.Any = None,
        kernel_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.default_kernel_init,
        bias_init: tp.Callable[
            [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype],
            flax_module.Array,
        ] = flax_module.zeros,
    ):
        """
        Arguments:
            features: a sequence of L+1 integers, where L is the number of layers,
                the first integer is the number of input features and all subsequent
                integers are the number of output features of the respective layer.
            use_bias: whether to add a bias to the output (default: True).
            dtype: the dtype of the computation (default: float32).
            precision: numerical precision of the computation see `jax.lax.Precision`
                for details.
            kernel_init: initializer function for the weight matrix.
            bias_init: initializer function for the bias.
        """
        if len(features) < 2:
            raise ValueError("features must have at least 2 elements")

        self.features = features
        self.layers = [
            Linear(
                features_in=features_in,
                features_out=features_out,
                use_bias=use_bias,
                dtype=dtype,
                precision=precision,
                kernel_init=kernel_init,
                bias_init=bias_init,
            )
            for features_in, features_out in zip(features[:-1], features[1:])
        ]

    def __call__(self, x: np.ndarray) -> jnp.ndarray:
        """
        Applies the MLP to the input.
        Arguments:
            x: input array.
        Returns:
            The output of the MLP.
        """
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))

        return self.layers[-1](x)
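As a quick check of how `features` is consumed in `MLP.__init__` above: consecutive pairs of the sequence become `Linear` layers, with `jax.nn.relu` applied between layers and no activation after the last one. A small sketch of the pairing, illustrative only:

```python
# MLP([2, 32, 8, 4]) builds Linear(2, 32), Linear(32, 8), Linear(8, 4).
features = [2, 32, 8, 4]
pairs = list(zip(features[:-1], features[1:]))
assert pairs == [(2, 32), (32, 8), (8, 4)]
```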
74 changes: 74 additions & 0 deletions treex/nn/sequence.py
@@ -1,5 +1,8 @@
import typing as tp

import jax.numpy as jnp
import numpy as np
from treex.module import Module


def sequence(
@@ -11,3 +14,74 @@ def _sequence(x: np.ndarray) -> np.ndarray:
        return x

    return _sequence


CallableModule = tp.cast(
    tp.Type[tp.Callable[[np.ndarray], np.ndarray]], tp.List[Module]
)


class Sequence(Module):
    """
    A Module that applies a sequence of Modules or functions in order.

    Example:
    ```python
    mlp = tx.Sequence(
        tx.Linear(2, 32),
        jax.nn.relu,
        tx.Linear(32, 8),
        jax.nn.relu,
        tx.Linear(8, 4),
    ).init(42)
    x = np.random.uniform(size=(10, 2))
    y = mlp(x)
    assert y.shape == (10, 4)
    ```
    """

    layers: tp.List[CallableModule]

    def __init__(
        self, *layers: tp.Union[CallableModule, tp.Callable[[np.ndarray], np.ndarray]]
    ):
        """
        Arguments:
            layers: A list of layers or callables to apply in sequence.
        """
        self.layers = [
            layer if isinstance(layer, Module) else Lambda(layer) for layer in layers
        ]

    def __call__(self, x: np.ndarray) -> np.ndarray:
        for layer in self.layers:
            x = layer(x)
        return x


class Lambda(Module):
    """
    A Module that applies a pure function to its input.
    """

    f: tp.Callable[[np.ndarray], np.ndarray]

    def __init__(self, f: tp.Callable[[np.ndarray], np.ndarray]):
        """
        Arguments:
            f: A function to apply to the input.
        """
        self.f = f

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Arguments:
            x: The input to the function.
        Returns:
            The output of the function.
        """
        return self.f(x)

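A short aside on the `CallableModule = tp.cast(...)` line in `sequence.py` (my reading of the intent, not stated in the commit): `tp.cast` is a no-op at runtime and simply returns its second argument, so at runtime `CallableModule` is the `tp.List[Module]` alias and treex's tree-type resolution still sees `Module` in the `layers` annotation, while static type checkers treat the stored layers as plain callables. A tiny sketch of the runtime behavior of `tp.cast`:

```python
import typing as tp

# tp.cast only influences static type checkers; at runtime it returns
# its second argument unchanged.
value = tp.cast(int, "still a string")
print(value, type(value).__name__)  # still a string str

alias = tp.cast(tp.Type[tp.Callable[[int], int]], tp.List[str])
print(alias)  # typing.List[str]
```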
2 comments on commit 82e612d

@wookayin

I think Sequential would be a much better name, since Keras, PyTorch, and Sonnet all use the name Sequential rather than Sequence.

@cgarciae
Owner Author


@wookayin this is a great point!
I've created #3 to address this.
