Adding inputs_embeds argument and switch to paddle.nn.TransformerEncoder for Electra models #3401

Merged 19 commits on Nov 3, 2022
142 changes: 24 additions & 118 deletions paddlenlp/transformers/electra/modeling.py
@@ -71,121 +71,6 @@ def swish(x):
}


class TransformerEncoderLayerPro(TransformerEncoderLayer):

def __init__(self,
d_model,
nhead,
dim_feedforward,
dropout=0.1,
activation="relu",
attn_dropout=None,
act_dropout=None,
normalize_before=False,
weight_attr=None,
bias_attr=None):
super(TransformerEncoderLayerPro,
self).__init__(d_model, nhead, dim_feedforward, dropout,
activation, attn_dropout, act_dropout,
normalize_before, weight_attr, bias_attr)

def forward(self, src, src_mask=None, cache=None, output_attentions=False):
self.self_attn.need_weights = output_attentions
src_mask = _convert_attention_mask(src_mask, src.dtype)
attentions = None

residual = src
if self.normalize_before:
src = self.norm1(src)
if cache is None:
src = self.self_attn(src, src, src, src_mask)
if output_attentions:
src, attentions = src
else:
output = self.self_attn(src, src, src, src_mask, cache)
if output_attentions:
src, attentions, incremental_cache = output
else:
src, incremental_cache = output

src = residual + self.dropout1(src)
if not self.normalize_before:
src = self.norm1(src)

residual = src
if self.normalize_before:
src = self.norm2(src)
src = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src)
if not self.normalize_before:
src = self.norm2(src)
if output_attentions:
src = (src, attentions)
return src if cache is None else (src, incremental_cache)


class TransformerEncoderPro(TransformerEncoder):

def __init__(self, encoder_layer, num_layers, norm=None):
super(TransformerEncoderPro, self).__init__(encoder_layer, num_layers,
norm)

def forward(self,
src,
src_mask=None,
cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=False):
src_mask = _convert_attention_mask(src_mask, src.dtype)

output = src
new_caches = []
all_attentions = [] if output_attentions else None
all_hidden_states = [] if output_hidden_states else None
for i, mod in enumerate(self.layers):

if output_hidden_states:
all_hidden_states.append(output)

if cache is None:
output = mod(output,
src_mask=src_mask,
output_attentions=output_attentions)
else:
cache_wrapper = cache[i] if isinstance(
cache[i], nn.MultiHeadAttention.Cache
) else nn.MultiHeadAttention.Cache(*cache[i])
output, new_cache = mod(output,
src_mask=src_mask,
cache=cache_wrapper,
output_attentions=output_attentions)
new_caches.append(new_cache)
if output_attentions:
all_attentions.append(output[1])
output = output[0]

if output_hidden_states:
all_hidden_states.append(output)

if self.norm is not None:
output = self.norm(output)

if output_hidden_states:
all_hidden_states[-1] = output

if not return_dict:
if output_attentions or output_hidden_states:
output = (output, all_attentions, all_hidden_states)
return output if cache is None else (output, new_caches)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=output,
hidden_states=all_hidden_states,
attentions=all_attentions,
past_key_values=new_caches)


class ElectraEmbeddings(nn.Layer):
"""Construct the embeddings from word, position and token_type embeddings."""

@@ -205,6 +90,7 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
inputs_embeds=None,
past_key_values_length=None):
if position_ids is None:
ones = paddle.ones_like(input_ids, dtype="int64")
@@ -218,7 +104,10 @@ def forward(self,
if token_type_ids is None:
token_type_ids = paddle.zeros_like(input_ids, dtype="int64")

input_embeddings = self.word_embeddings(input_ids)
if input_ids is not None:
input_embeddings = self.word_embeddings(input_ids)
else:
input_embeddings = inputs_embeds
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

@@ -535,15 +424,15 @@ def __init__(self,
if embedding_size != hidden_size:
self.embeddings_project = nn.Linear(embedding_size, hidden_size)

encoder_layer = TransformerEncoderLayerPro(
encoder_layer = TransformerEncoderLayer(
hidden_size,
num_attention_heads,
intermediate_size,
dropout=hidden_dropout_prob,
activation=hidden_act,
attn_dropout=attention_probs_dropout_prob,
act_dropout=0)
self.encoder = TransformerEncoderPro(encoder_layer, num_hidden_layers)
self.encoder = TransformerEncoder(encoder_layer, num_hidden_layers)

self.init_weights()

@@ -558,6 +447,7 @@ def forward(self,
token_type_ids=None,
position_ids=None,
attention_mask=None,
inputs_embeds=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
@@ -595,6 +485,11 @@ def forward(self,
When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values.
It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
Defaults to `None`, which means that nothing is prevented from being attended to.
inputs_embeds (Tensor, optional):
Instead of passing `input_ids`, you can choose to pass an embedded representation directly.
This is useful for use cases such as P-Tuning, where you want more control over how `input_ids` indices
are converted into the embedding space.
Its data type should be `float32` and its shape is `[batch_size, sequence_length, embedding_size]`.
past_key_values (tuple(tuple(Tensor)), optional):
Precomputed key and value hidden states of the attention blocks of each layer. This can be used to speedup
auto-regressive decoding for generation tasks or to support use cases such as Prefix-Tuning where vectors are prepended
@@ -658,6 +553,7 @@ def forward(self,
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length)

if hasattr(self, "embeddings_project"):
@@ -778,6 +674,7 @@ def forward(self,
token_type_ids=None,
position_ids=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=False,
output_hidden_states=False,
@@ -825,6 +722,7 @@ def forward(self,
token_type_ids,
position_ids,
attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
@@ -1034,6 +932,7 @@ def forward(
token_type_ids=None,
position_ids=None,
attention_mask=None,
inputs_embeds=None,
labels=None,
output_attentions: bool = None,
output_hidden_states: bool = None,
@@ -1085,6 +984,7 @@ def forward(
token_type_ids,
position_ids,
attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
@@ -1150,6 +1050,7 @@ def forward(self,
token_type_ids=None,
position_ids=None,
attention_mask=None,
inputs_embeds=None,
labels: Optional[Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -1204,6 +1105,7 @@ def forward(self,
token_type_ids,
position_ids,
attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
@@ -1660,6 +1562,7 @@ def forward(
token_type_ids=None,
position_ids=None,
attention_mask=None,
inputs_embeds=None,
labels: Optional[Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
@@ -1764,6 +1667,7 @@ def forward(
token_type_ids,
position_ids,
attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
@@ -2054,6 +1958,7 @@ def forward(
token_type_ids=None,
position_ids=None,
attention_mask=None,
inputs_embeds=None,
start_positions: Optional[Tensor] = None,
end_positions: Optional[Tensor] = None,
output_attentions: Optional[bool] = None,
@@ -2124,6 +2029,7 @@ def forward(
token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
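Taken together, the hunks above thread the new `inputs_embeds` argument from `ElectraEmbeddings` through `ElectraModel` and into each task head. Below is a minimal usage sketch (not part of this diff): the checkpoint name is illustrative, and `position_ids`, `token_type_ids`, and `attention_mask` are supplied explicitly on the assumption that the id-dependent defaults still expect `input_ids`.

import paddle
from paddlenlp.transformers import ElectraModel, ElectraTokenizer

model = ElectraModel.from_pretrained("electra-small")
tokenizer = ElectraTokenizer.from_pretrained("electra-small")

encoded = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
input_ids = paddle.to_tensor([encoded["input_ids"]])
seq_len = input_ids.shape[1]

# Look up the word embeddings by hand, e.g. to perturb them or prepend prompt
# vectors before they reach the encoder (the P-Tuning use case from the docstring).
inputs_embeds = model.embeddings.word_embeddings(input_ids)

outputs = model(
    input_ids=None,
    inputs_embeds=inputs_embeds,
    token_type_ids=paddle.zeros([1, seq_len], dtype="int64"),
    position_ids=paddle.arange(seq_len, dtype="int64").unsqueeze(0),
    attention_mask=paddle.zeros([1, 1, 1, seq_len]),  # 0.0 everywhere = attend to all tokens
)
# With the default output flags this is the last hidden state, shape [1, seq_len, hidden_size].
sequence_output = outputs

The task heads changed in this diff (e.g. `ElectraForSequenceClassification`) accept the same keyword and simply forward it to `ElectraModel`.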
1 change: 1 addition & 0 deletions tests/requirements.txt
@@ -1,2 +1,3 @@
parameterized
sentencepiece
regex