From 82e612daf703911b355e31708ca9f863a5bd98c7 Mon Sep 17 00:00:00 2001
From: Cristian Garcia
Date: Thu, 26 Aug 2021 18:51:23 -0500
Subject: [PATCH] add Sequence

---
 tests/nn/test_mlp.py      | 18 +++++++++
 tests/nn/test_sequence.py | 34 ++++++++++++++++
 treex/module.py           | 10 +++--
 treex/nn/__init__.py      |  7 ++--
 treex/nn/batch_norm.py    |  6 +--
 treex/nn/conv.py          |  2 +-
 treex/nn/linear.py        |  2 +-
 treex/nn/mlp.py           | 81 +++++++++++++++++++++++++++++++++++++++
 treex/nn/sequence.py      | 73 ++++++++++++++++++++++++++++++++++
 9 files changed, 222 insertions(+), 11 deletions(-)
 create mode 100644 tests/nn/test_mlp.py
 create mode 100644 tests/nn/test_sequence.py
 create mode 100644 treex/nn/mlp.py

diff --git a/tests/nn/test_mlp.py b/tests/nn/test_mlp.py
new file mode 100644
index 00000000..b2158e35
--- /dev/null
+++ b/tests/nn/test_mlp.py
@@ -0,0 +1,18 @@
+import numpy as np
+import pytest
+import treex as tx
+
+
+class TestMLP:
+    def test_basic(self):
+        mlp = tx.MLP([2, 32, 8, 4]).init(42)
+
+        x = np.random.uniform(-1, 1, (10, 2))
+        y = mlp(x)
+
+        assert y.shape == (10, 4)
+
+    def test_too_few_features(self):
+
+        with pytest.raises(ValueError):
+            mlp = tx.MLP([2]).init(42)
diff --git a/tests/nn/test_sequence.py b/tests/nn/test_sequence.py
new file mode 100644
index 00000000..7b57c18d
--- /dev/null
+++ b/tests/nn/test_sequence.py
@@ -0,0 +1,34 @@
+import jax
+import numpy as np
+import pytest
+import treex as tx
+
+
+class TestSequence:
+    def test_basic(self):
+        mlp = tx.Sequence(
+            tx.Linear(2, 32),
+            jax.nn.relu,
+            tx.Linear(32, 8),
+            jax.nn.relu,
+            tx.Linear(8, 4),
+        ).init(42)
+
+        assert isinstance(mlp.layers[1], tx.Lambda)
+        assert isinstance(mlp.layers[3], tx.Lambda)
+
+        x = np.random.uniform(-1, 1, (10, 2))
+        y = mlp(x)
+
+        assert y.shape == (10, 4)
+
+    def test_pytree(self):
+        mlp = tx.Sequence(
+            tx.Linear(2, 32),
+            jax.nn.relu,
+            tx.Linear(32, 8),
+            jax.nn.relu,
+            tx.Linear(8, 4),
+        ).init(42)
+
+        jax.tree_map(lambda x: 2 * x, mlp)
diff --git a/treex/module.py b/treex/module.py
index 65276115..e29dddf0 100644
--- a/treex/module.py
+++ b/treex/module.py
@@ -269,7 +269,6 @@ def double_params(self):
 
         Returns:
             A new module with the updated values or `None` if `inplace` is `True`.
- """ modules = (self, other) + rest @@ -727,7 +726,12 @@ def _resolve_tree_type(name: str, t: tp.Optional[type]) -> tp.Optional[type]: if t is None: return None - tree_types = [x for x in _all_types(t) if issubclass(x, (types.TreePart, Module))] + tree_types = [ + x + for x in _all_types(t) + if isinstance(x, tp.Type) + if issubclass(x, (types.TreePart, Module)) + ] if len(tree_types) > 1: # if its a type with many Module subtypes just mark them all as Module @@ -740,7 +744,7 @@ def _resolve_tree_type(name: str, t: tp.Optional[type]) -> tp.Optional[type]: elif len(tree_types) == 1: return tree_types[0] else: - return t + return None def _all_types(t: tp.Type) -> tp.Iterable[tp.Type]: diff --git a/treex/nn/__init__.py b/treex/nn/__init__.py index 8b889dc2..5160946c 100644 --- a/treex/nn/__init__.py +++ b/treex/nn/__init__.py @@ -1,5 +1,6 @@ -from .linear import Linear from .batch_norm import BatchNorm -from .sequence import sequence -from .dropout import Dropout from .conv import Conv +from .dropout import Dropout +from .linear import Linear +from .mlp import MLP +from .sequence import sequence, Sequence, Lambda diff --git a/treex/nn/batch_norm.py b/treex/nn/batch_norm.py index b90d04d4..cf925d1d 100644 --- a/treex/nn/batch_norm.py +++ b/treex/nn/batch_norm.py @@ -103,9 +103,9 @@ def module_init(self, key: jnp.ndarray): # Extract collections if "params" in variables: - self.params = variables["params"] + self.params = variables["params"].unfreeze() - self.batch_stats = variables["batch_stats"] + self.batch_stats = variables["batch_stats"].unfreeze() def __call__( self, x: np.ndarray, use_running_average: tp.Optional[bool] = None @@ -144,6 +144,6 @@ def __call__( # update batch_stats if "batch_stats" in variables: - self.batch_stats = variables["batch_stats"] + self.batch_stats = variables["batch_stats"].unfreeze() return tp.cast(jnp.ndarray, output) diff --git a/treex/nn/conv.py b/treex/nn/conv.py index fd59c599..6de0e667 100644 --- a/treex/nn/conv.py +++ b/treex/nn/conv.py @@ -112,7 +112,7 @@ def module_init(self, key: jnp.ndarray): variables = self.module.init(key, x) # Extract collections - self.params = variables["params"] + self.params = variables["params"].unfreeze() def __call__(self, x: np.ndarray) -> jnp.ndarray: """Applies a convolution to the inputs. diff --git a/treex/nn/linear.py b/treex/nn/linear.py index b43e06c2..da44189f 100644 --- a/treex/nn/linear.py +++ b/treex/nn/linear.py @@ -71,7 +71,7 @@ def module_init(self, key: jnp.ndarray): variables = self.module.init(key, x) # Extract collections - self.params = variables["params"] + self.params = variables["params"].unfreeze() def __call__(self, x: np.ndarray) -> jnp.ndarray: """Applies a linear transformation to the inputs along the last dimension. diff --git a/treex/nn/mlp.py b/treex/nn/mlp.py new file mode 100644 index 00000000..ec741789 --- /dev/null +++ b/treex/nn/mlp.py @@ -0,0 +1,81 @@ +import typing as tp +from flax.linen import linear as flax_module +import jax +import jax.numpy as jnp +import numpy as np + +from treex.module import Module +from treex import types +from treex.nn.linear import Linear + + +class MLP(Module): + """A Multi-Layer Perceptron (MLP) that applies a sequence of linear layers + with relu activations, the last layer is linear. 
+ """ + + # pytree + layers: tp.List[Linear] + + # props + features: tp.Sequence[int] + module: flax_module.Dense + + def __init__( + self, + features: tp.Sequence[int], + use_bias: bool = True, + dtype: tp.Any = jnp.float32, + precision: tp.Any = None, + kernel_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] = flax_module.default_kernel_init, + bias_init: tp.Callable[ + [flax_module.PRNGKey, flax_module.Shape, flax_module.Dtype], + flax_module.Array, + ] = flax_module.zeros, + ): + """ + Arguments: + features: a sequence of L+1 integers, where L is the number of layers, + the first integer is the number of input features and all subsequent + integers are the number of output features of the respective layer. + use_bias: whether to add a bias to the output (default: True). + dtype: the dtype of the computation (default: float32). + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer function for the weight matrix. + bias_init: initializer function for the bias. + """ + if len(features) < 2: + raise ValueError("features must have at least 2 elements") + + self.features = features + self.layers = [ + Linear( + features_in=features_in, + features_out=features_out, + use_bias=use_bias, + dtype=dtype, + precision=precision, + kernel_init=kernel_init, + bias_init=bias_init, + ) + for features_in, features_out in zip(features[:-1], features[1:]) + ] + + def __call__(self, x: np.ndarray) -> jnp.ndarray: + """ + Applies the MLP to the input. + + Arguments: + x: input array. + + Returns: + The output of the MLP. + """ + for layer in self.layers[:-1]: + x = jax.nn.relu(layer(x)) + + return self.layers[-1](x) diff --git a/treex/nn/sequence.py b/treex/nn/sequence.py index 67f482c0..34bd8dbc 100644 --- a/treex/nn/sequence.py +++ b/treex/nn/sequence.py @@ -1,5 +1,8 @@ import typing as tp + +import jax.numpy as jnp import numpy as np +from treex.module import Module def sequence( @@ -11,3 +14,74 @@ def _sequence(x: np.ndarray) -> np.ndarray: return x return _sequence + + +CallableModule = tp.cast( + tp.Type[tp.Callable[[np.ndarray], np.ndarray]], tp.List[Module] +) + + +class Sequence(Module): + """ + A Module that applies a sequence of Modules or functions in order. + + Example: + + ```python + mlp = tx.Sequence( + tx.Linear(2, 32), + jax.nn.relu, + tx.Linear(32, 8), + jax.nn.relu, + tx.Linear(8, 4), + ).init(42) + + x = np.random.uniform(size=(10, 2)) + y = mlp(x) + + assert y.shape == (10, 4) + ``` + """ + + layers: tp.List[CallableModule] + + def __init__( + self, *layers: tp.Union[CallableModule, tp.Callable[[np.ndarray], np.ndarray]] + ): + """ + Arguments: + layers: A list of layers or callables to apply to apply in sequence. + """ + self.layers = [ + layer if isinstance(layer, Module) else Lambda(layer) for layer in layers + ] + + def __call__(self, x: np.ndarray) -> np.ndarray: + for layer in self.layers: + x = layer(x) + return x + + +class Lambda(Module): + """ + A Module that applies a pure function to its input. + """ + + f: tp.Callable[[np.ndarray], np.ndarray] + + def __init__(self, f: tp.Callable[[np.ndarray], np.ndarray]): + """ + Arguments: + f: A function to apply to the input. + """ + self.f = f + self.f = f + + def __call__(self, x: np.ndarray) -> np.ndarray: + """ + Arguments: + x: The input to the function. + Returns: + The output of the function. + """ + return self.f(x)