
updated embed docstring
chiamp committed Dec 8, 2023
1 parent 50cd169 commit 837f3b7
Showing 3 changed files with 45 additions and 17 deletions.
8 changes: 6 additions & 2 deletions docs/_templates/autosummary/flax_module.rst
@@ -7,15 +7,19 @@

{% block methods %}

.. automethod:: __call__
{% for item in methods %}
{%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %}
.. automethod:: {{ item }}
{%- endif %}
{%- endfor %}

{% if methods %}
.. rubric:: Methods

.. autosummary::

{% for item in methods %}
{%- if item not in inherited_members and item not in annotations and not item in ['__init__'] %}
{%- if item not in inherited_members and item not in annotations and not item in ['__init__', 'setup'] %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
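For illustration, the Jinja condition above can be restated as a plain Python filter. The `methods` list below is a hypothetical stand-in for the variable autosummary supplies to the template; the real template additionally filters out inherited members and annotations.

# A minimal sketch of the template's new filter, assuming a hypothetical
# `methods` list like the one autosummary would supply for a Module.
methods = ['__call__', '__init__', 'setup', 'attend']  # assumed input
excluded = {'__init__', 'setup'}  # names the updated template skips

documented = [m for m in methods if m not in excluded]
print(documented)  # ['__call__', 'attend']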
11 changes: 9 additions & 2 deletions flax/experimental/nnx/nnx/nn/linear.py
@@ -586,10 +586,15 @@ def maybe_broadcast(
class Embed(Module):
"""Embedding Module.
A parameterized function from integers [0, n) to d-dimensional vectors.
A parameterized function from integers [0, ``num_embeddings``) to
``features``-dimensional vectors. This ``Module`` will create an ``embedding``
matrix with shape ``(num_embeddings, features)``. When calling this layer,
the input values will be used to 0-index into the ``embedding`` matrix.
Indexing on a value greater than or equal to ``num_embeddings`` will result
in ``nan`` values.
Attributes:
num_embeddings: number of embeddings.
num_embeddings: number of embeddings / vocab size.
features: number of feature dimensions for each embedding.
dtype: the dtype of the embedding vectors (default: same as embedding).
param_dtype: the dtype passed to parameter initializers (default: float32).
@@ -623,6 +628,7 @@ def __call__(self, inputs: Array) -> Array:
Args:
inputs: input data, all dimensions are considered batch dimensions.
Values in the input array must be integers.
Returns:
Output which is embedded input data. The output shape follows the input,
@@ -643,6 +649,7 @@ def attend(self, query: Array) -> Array:
Args:
query: array with last dimension equal the feature depth `features` of the
embedding.
Returns:
An array with final dim `num_embeddings` corresponding to the batched
inner-product of the array of query vectors against each embedding.
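The lookup behavior the updated docstring describes is plain integer row indexing into the embedding table. As a standalone sketch of that behavior (not the Flax source), assuming only jax.numpy:

import jax.numpy as jnp

# The embedding table has shape (num_embeddings, features); integer
# inputs of any batch shape 0-index into its rows, and the output
# appends a trailing `features` dimension.
num_embeddings, features = 5, 3
embedding = jnp.arange(num_embeddings * features, dtype=jnp.float32)
embedding = embedding.reshape(num_embeddings, features)

inputs = jnp.array([[0, 1], [4, 3]])       # integer ids, batch shape (2, 2)
out = jnp.take(embedding, inputs, axis=0)  # shape (2, 2, 3)
print(out.shape)
# Out-of-range ids yield nan rows under jnp.take's default 'fill' mode,
# which matches the docstring's note about values >= num_embeddings.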
43 changes: 30 additions & 13 deletions flax/linen/linear.py
@@ -935,24 +935,39 @@ def __call__(self, inputs: Array) -> Array:
class Embed(Module):
"""Embedding Module.
A parameterized function from integers [0, n) to d-dimensional vectors.
A parameterized function from integers [0, ``num_embeddings``) to
``features``-dimensional vectors. This ``Module`` will create an ``embedding``
matrix with shape ``(num_embeddings, features)``. When calling this layer,
the input values will be used to 0-index into the ``embedding`` matrix.
Indexing on a value greater than or equal to ``num_embeddings`` will result
in ``nan`` values.
Example usage::
>>> import flax.linen as nn
>>> import jax, jax.numpy as jnp
>>> layer = nn.Embed(num_embeddings=4, features=3)
>>> variables = layer.init(jax.random.key(0), jnp.ones((1, 5), dtype=int))
>>> jax.tree_map(jnp.shape, variables)
{'params': {'embedding': (4, 3)}}
>>> layer.apply(variables, jnp.ones((5,), dtype=int)).shape
(5, 3)
>>> layer.apply(variables, jnp.ones((5, 6), dtype=int)).shape
(5, 6, 3)
>>> layer = nn.Embed(num_embeddings=5, features=3)
>>> indices_input = jnp.array([[0, 1, 2], [-1, -2, -3]])
>>> variables = layer.init(jax.random.key(0), indices_input)
>>> variables
{'params': {'embedding': Array([[-0.28884724, 0.19018005, -0.414205 ],
[-0.11768015, -0.54618824, -0.3789283 ],
[ 0.30428642, 0.49511626, 0.01706631],
[-0.0982546 , -0.43055868, 0.20654906],
[-0.688412 , -0.46882293, 0.26723292]], dtype=float32)}}
>>> # get the first three and last three embeddings
>>> layer.apply(variables, indices_input)
Array([[[-0.28884724, 0.19018005, -0.414205 ],
[-0.11768015, -0.54618824, -0.3789283 ],
[ 0.30428642, 0.49511626, 0.01706631]],
<BLANKLINE>
[[-0.688412 , -0.46882293, 0.26723292],
[-0.0982546 , -0.43055868, 0.20654906],
[ 0.30428642, 0.49511626, 0.01706631]]], dtype=float32)
Attributes:
num_embeddings: number of embeddings.
num_embeddings: number of embeddings / vocab size.
features: number of feature dimensions for each embedding.
dtype: the dtype of the embedding vectors (default: same as embedding).
param_dtype: the dtype passed to parameter initializers (default: float32).
@@ -980,10 +995,11 @@ def __call__(self, inputs: Array) -> Array:
Args:
inputs: input data, all dimensions are considered batch dimensions.
Values in the input array must be integers.
Returns:
Output which is embedded input data. The output shape follows the input,
with an additional `features` dimension appended.
with an additional ``features`` dimension appended.
"""
if not jnp.issubdtype(inputs.dtype, jnp.integer):
raise ValueError('Input type must be an integer or unsigned integer.')
@@ -998,10 +1014,11 @@ def attend(self, query: Array) -> Array:
"""Attend over the embedding using a query array.
Args:
query: array with last dimension equal the feature depth `features` of the
query: array with last dimension equal the feature depth ``features`` of the
embedding.
Returns:
An array with final dim `num_embeddings` corresponding to the batched
An array with final dim ``num_embeddings`` corresponding to the batched
inner-product of the array of query vectors against each embedding.
Commonly used for weight-sharing between embeddings and logit transform
in NLP models.
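As a usage sketch of the weight-tying pattern `attend`'s docstring describes, assuming linen's `apply` accepts a `method=` override (the sizes and inputs below are illustrative):

import jax, jax.numpy as jnp
import flax.linen as nn

# Reuse the embedding matrix to turn feature vectors into vocab logits.
embed = nn.Embed(num_embeddings=5, features=3)
variables = embed.init(jax.random.key(0), jnp.array([0]))

query = jnp.ones((2, 3))  # last dimension must equal `features`
logits = embed.apply(variables, query, method=embed.attend)
print(logits.shape)  # (2, 5): one inner product per embedding row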
