From 22fe24498e8d7fc235270d59a02e846fe9a0a198 Mon Sep 17 00:00:00 2001
From: Marcus Chiam
Date: Thu, 7 Dec 2023 14:19:09 -0800
Subject: [PATCH] updated embed docstring

---
 docs/_templates/autosummary/flax_module.rst |  3 +-
 flax/experimental/nnx/nnx/nn/linear.py      | 11 +++++-
 flax/linen/linear.py                        | 43 ++++++++++++++-------
 3 files changed, 41 insertions(+), 16 deletions(-)

diff --git a/docs/_templates/autosummary/flax_module.rst b/docs/_templates/autosummary/flax_module.rst
index b3f605222a..24c4ea836d 100644
--- a/docs/_templates/autosummary/flax_module.rst
+++ b/docs/_templates/autosummary/flax_module.rst
@@ -8,6 +8,7 @@
    {% block methods %}
 
    .. automethod:: __call__
+   .. automethod:: attend
 
    {% if methods %}
    .. rubric:: Methods
@@ -15,7 +16,7 @@
 
    .. autosummary::
    {% for item in methods %}
-   {%- if item not in inherited_members and item not in annotations and not item in ['__init__'] %}
+   {%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %}
    ~{{ name }}.{{ item }}
    {%- endif %}
    {%- endfor %}
diff --git a/flax/experimental/nnx/nnx/nn/linear.py b/flax/experimental/nnx/nnx/nn/linear.py
index 7f1c86b52a..6308f5799b 100644
--- a/flax/experimental/nnx/nnx/nn/linear.py
+++ b/flax/experimental/nnx/nnx/nn/linear.py
@@ -586,10 +586,15 @@ def maybe_broadcast(
 class Embed(Module):
   """Embedding Module.
 
-  A parameterized function from integers [0, n) to d-dimensional vectors.
+  A parameterized function from integers [0, ``num_embeddings``) to
+  ``features``-dimensional vectors. This ``Module`` will create an ``embedding``
+  matrix with shape ``(num_embeddings, features)``. When calling this layer,
+  the input values will be used to 0-index into the ``embedding`` matrix.
+  Indexing on a value greater than or equal to ``num_embeddings`` will result
+  in ``nan`` values.
 
   Attributes:
-    num_embeddings: number of embeddings.
+    num_embeddings: number of embeddings / vocab size.
     features: number of feature dimensions for each embedding.
     dtype: the dtype of the embedding vectors (default: same as embedding).
     param_dtype: the dtype passed to parameter initializers (default: float32).
@@ -623,6 +628,7 @@ def __call__(self, inputs: Array) -> Array:
 
     Args:
       inputs: input data, all dimensions are considered batch dimensions.
+        Values in the input array must be integers.
 
     Returns:
       Output which is embedded input data. The output shape follows the input,
@@ -643,6 +649,7 @@ def attend(self, query: Array) -> Array:
     Args:
       query: array with last dimension equal the feature depth `features` of the
         embedding.
+
     Returns:
       An array with final dim `num_embeddings` corresponding to the batched
       inner-product of the array of query vectors against each embedding.
diff --git a/flax/linen/linear.py b/flax/linen/linear.py
index ff6a6dab4b..9309539870 100644
--- a/flax/linen/linear.py
+++ b/flax/linen/linear.py
@@ -935,24 +935,39 @@ def __call__(self, inputs: Array) -> Array:
 class Embed(Module):
   """Embedding Module.
 
-  A parameterized function from integers [0, n) to d-dimensional vectors.
+  A parameterized function from integers [0, ``num_embeddings``) to
+  ``features``-dimensional vectors. This ``Module`` will create an ``embedding``
+  matrix with shape ``(num_embeddings, features)``. When calling this layer,
+  the input values will be used to 0-index into the ``embedding`` matrix.
+  Indexing on a value greater than or equal to ``num_embeddings`` will result
+  in ``nan`` values.
 
   Example usage::
 
     >>> import flax.linen as nn
     >>> import jax, jax.numpy as jnp
 
-    >>> layer = nn.Embed(num_embeddings=4, features=3)
-    >>> variables = layer.init(jax.random.key(0), jnp.ones((1, 5), dtype=int))
-    >>> jax.tree_map(jnp.shape, variables)
-    {'params': {'embedding': (4, 3)}}
-    >>> layer.apply(variables, jnp.ones((5,), dtype=int)).shape
-    (5, 3)
-    >>> layer.apply(variables, jnp.ones((5, 6), dtype=int)).shape
-    (5, 6, 3)
+    >>> layer = nn.Embed(num_embeddings=5, features=3)
+    >>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
+    >>> variables = layer.init(jax.random.key(0), indices_input)
+    >>> variables
+    {'params': {'embedding': Array([[-0.28884724,  0.19018005, -0.414205  ],
+           [-0.11768015, -0.54618824, -0.3789283 ],
+           [ 0.30428642,  0.49511626,  0.01706631],
+           [-0.0982546 , -0.43055868,  0.20654906],
+           [-0.688412  , -0.46882293,  0.26723292]], dtype=float32)}}
+    >>> # get the first three and last three embeddings
+    >>> layer.apply(variables, indices_input)
+    Array([[[-0.28884724,  0.19018005, -0.414205  ],
+            [-0.11768015, -0.54618824, -0.3789283 ],
+            [ 0.30428642,  0.49511626,  0.01706631]],
+
+           [[-0.688412  , -0.46882293,  0.26723292],
+            [-0.0982546 , -0.43055868,  0.20654906],
+            [ 0.30428642,  0.49511626,  0.01706631]]], dtype=float32)
 
   Attributes:
-    num_embeddings: number of embeddings.
+    num_embeddings: number of embeddings / vocab size.
     features: number of feature dimensions for each embedding.
     dtype: the dtype of the embedding vectors (default: same as embedding).
     param_dtype: the dtype passed to parameter initializers (default: float32).
@@ -980,10 +995,11 @@ def __call__(self, inputs: Array) -> Array:
 
     Args:
       inputs: input data, all dimensions are considered batch dimensions.
+        Values in the input array must be integers.
 
     Returns:
       Output which is embedded input data. The output shape follows the input,
-      with an additional `features` dimension appended.
+      with an additional ``features`` dimension appended.
     """
     if not jnp.issubdtype(inputs.dtype, jnp.integer):
       raise ValueError('Input type must be an integer or unsigned integer.')
@@ -998,10 +1014,11 @@ def attend(self, query: Array) -> Array:
     """Attend over the embedding using a query array.
 
     Args:
-      query: array with last dimension equal the feature depth `features` of the
+      query: array with last dimension equal the feature depth ``features`` of the
         embedding.
+
     Returns:
-      An array with final dim `num_embeddings` corresponding to the batched
+      An array with final dim ``num_embeddings`` corresponding to the batched
       inner-product of the array of query vectors against each embedding.
       Commonly used for weight-sharing between embeddings and logit transform
       in NLP models.
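
Note: the ``attend`` method documented above still carries no usage example of its own. A minimal doctest-style sketch of the weight-sharing pattern it describes (not part of this patch; it assumes the ``layer`` and ``variables`` from the new ``Embed`` docstring example, and the ``query``/``logits`` names are hypothetical)::

    >>> query = jnp.ones((2, 3))  # last dimension equals `features` (3)
    >>> logits = layer.apply(variables, query, method=nn.Embed.attend)
    >>> logits.shape  # final dimension equals `num_embeddings` (5)
    (2, 5)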