# token_and_position_embedding.py
import keras
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.position_embedding import PositionEmbedding
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.utils.keras_utils import clone_initializer


@keras_hub_export("keras_hub.layers.TokenAndPositionEmbedding")
class TokenAndPositionEmbedding(keras.layers.Layer):
    """A layer which sums a token and position embedding.

    Token and position embeddings are ways of representing words and their
    order in a sentence. This layer creates a `keras.layers.Embedding` token
    embedding and a `keras_hub.layers.PositionEmbedding` position embedding
    and sums their output when called. This layer assumes that the last
    dimension in the input corresponds to the sequence dimension.

    Args:
        vocabulary_size: The size of the vocabulary.
        sequence_length: The maximum length of the input sequence.
        embedding_dim: The output dimension of the embedding layer.
        tie_weights: Boolean, whether or not the matrix for the token
            embedding and the matrix for the `reverse` projection should
            share the same weights.
        embeddings_initializer: The initializer to use for the embedding
            layers.
        mask_zero: Boolean, whether or not the input value 0 is a special
            "padding" value that should be masked out.
            This is useful when using recurrent layers which may take
            variable length input. If this is `True`, then all subsequent
            layers in the model need to support masking or an exception will
            be raised. If `mask_zero` is set to `True`, as a consequence,
            index 0 cannot be used in the vocabulary (`input_dim` should
            equal the size of the vocabulary + 1).
        **kwargs: other keyword arguments passed to `keras.layers.Layer`,
            including `name`, `trainable`, `dtype` etc.

    Example:
    ```python
    inputs = np.ones(shape=(1, 50), dtype="int32")
    embedding_layer = keras_hub.layers.TokenAndPositionEmbedding(
        vocabulary_size=10_000,
        sequence_length=50,
        embedding_dim=128,
    )
    outputs = embedding_layer(inputs)
    ```
    """
    def __init__(
        self,
        vocabulary_size,
        sequence_length,
        embedding_dim,
        tie_weights=True,
        embeddings_initializer="uniform",
        mask_zero=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        if vocabulary_size is None:
            raise ValueError(
                "`vocabulary_size` must be an Integer, received `None`."
            )
        if sequence_length is None:
            raise ValueError(
                "`sequence_length` must be an Integer, received `None`."
            )
        if embedding_dim is None:
            raise ValueError(
                "`embedding_dim` must be an Integer, received `None`."
            )
        self.vocabulary_size = int(vocabulary_size)
        self.sequence_length = int(sequence_length)
        self.embedding_dim = int(embedding_dim)
        self.embeddings_initializer = keras.initializers.get(
            embeddings_initializer
        )
        self.token_embedding = ReversibleEmbedding(
            vocabulary_size,
            embedding_dim,
            tie_weights=tie_weights,
            embeddings_initializer=clone_initializer(
                self.embeddings_initializer
            ),
            mask_zero=mask_zero,
            dtype=self.dtype_policy,
            name="token_embedding",
        )
        self.position_embedding = PositionEmbedding(
            sequence_length=sequence_length,
            initializer=clone_initializer(self.embeddings_initializer),
            dtype=self.dtype_policy,
            name="position_embedding",
        )
        self.supports_masking = self.token_embedding.supports_masking
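
    # Build note: the token embedding is built on the integer-id input
    # shape, while the position embedding is built on the embedded shape,
    # i.e. the input shape with a trailing `embedding_dim` axis appended.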
    def build(self, input_shape):
        input_shape = tuple(input_shape)
        self.token_embedding.build(input_shape)
        self.position_embedding.build(input_shape + (self.embedding_dim,))
        self.built = True

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "sequence_length": self.sequence_length,
                "embedding_dim": self.embedding_dim,
                "embeddings_initializer": keras.initializers.serialize(
                    self.embeddings_initializer
                ),
                "tie_weights": self.token_embedding.tie_weights,
                "mask_zero": self.token_embedding.mask_zero,
            }
        )
        return config
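
    # Call note: positions are derived from the shape of the embedded
    # tokens, and `start_index` offsets them (useful, for example, when
    # decoding with a cache and only a suffix of the sequence is passed).
    # Token and position embeddings are then summed elementwise.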
    def call(self, inputs, start_index=0):
        embedded_tokens = self.token_embedding(inputs)
        embedded_positions = self.position_embedding(
            embedded_tokens,
            start_index=start_index,
        )
        outputs = embedded_tokens + embedded_positions
        return outputs
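
    # Masking note: mask computation is delegated to the token embedding,
    # so a padding mask is only produced when `mask_zero=True`.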
    def compute_mask(self, inputs, mask=None):
        return self.token_embedding.compute_mask(inputs, mask=mask)

    def compute_output_shape(self, input_shape):
        return tuple(input_shape) + (self.embedding_dim,)
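

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library API).
    # It assumes `keras` and `numpy` are installed and mirrors the docstring
    # example above; the vocabulary size, sequence length, and embedding
    # dimension below are arbitrary choices.
    import numpy as np

    token_ids = np.random.randint(1, 10_000, size=(2, 50)).astype("int32")
    layer = TokenAndPositionEmbedding(
        vocabulary_size=10_000,
        sequence_length=50,
        embedding_dim=128,
        mask_zero=True,
    )
    embedded = layer(token_ids)
    # Summing token and position embeddings adds a trailing feature axis:
    # (batch, sequence) -> (batch, sequence, embedding_dim).
    print(embedded.shape)  # (2, 50, 128)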