From 8090a4164bfb0b12d912fee47739993e685ba0ac Mon Sep 17 00:00:00 2001 From: Sijun He Date: Sat, 1 Oct 2022 16:59:23 +0800 Subject: [PATCH 1/8] add inputs_embeds input arguments to all electra models --- paddlenlp/transformers/electra/modeling.py | 36 +++++++++++++++++-- tests/requirements.txt | 1 + tests/transformers/electra/test_modeling.py | 38 +++++++++++++++++---- 3 files changed, 66 insertions(+), 9 deletions(-) diff --git a/paddlenlp/transformers/electra/modeling.py b/paddlenlp/transformers/electra/modeling.py index 400b21ce462d..dd3fa12a69c0 100644 --- a/paddlenlp/transformers/electra/modeling.py +++ b/paddlenlp/transformers/electra/modeling.py @@ -199,11 +199,26 @@ def __init__(self, vocab_size, embedding_size, hidden_dropout_prob, self.layer_norm = nn.LayerNorm(embedding_size, epsilon=layer_norm_eps) self.dropout = nn.Dropout(hidden_dropout_prob) - def forward(self, input_ids, token_type_ids=None, position_ids=None): + def forward(self, + input_ids, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=None): + + if input_ids is not None: + input_shape = paddle.shape(input_ids) + input_embeddings = self.word_embeddings(input_ids) + else: + input_shape = paddle.shape(inputs_embeds)[:-1] + input_embeddings = inputs_embeds + if position_ids is None: ones = paddle.ones_like(input_ids, dtype="int64") seq_length = paddle.cumsum(ones, axis=-1) position_ids = seq_length - ones + if past_key_values_length is not None: + position_ids += past_key_values_length position_ids.stop_gradient = True position_ids = position_ids.astype("int64") @@ -550,6 +565,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + inputs_embeds=None, output_attentions=False, output_hidden_states=False, return_dict=False): @@ -585,6 +601,11 @@ def forward(self, When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values. It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`. Defaults to `None`, which means nothing needed to be prevented attention to. + inputs_embeds (Tensor, optional): + Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. + This is useful for use cases such as P-Tuning, where you want more control over how to convert input_ids indices + into associated vectors than the model's internal embedding lookup matrix. + Its data type should be `float32` and it has a shape of [batch_size, sequence_length, embedding_size]. output_hidden_states (bool, optional): Whether to return the hidden states of all layers. Defaults to `False`. 
@@ -625,7 +646,8 @@ def forward(self, embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, - token_type_ids=token_type_ids) + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds) if hasattr(self, "embeddings_project"): embedding_output = self.embeddings_project(embedding_output) @@ -743,6 +765,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + inputs_embeds=None, labels=None, output_attentions=False, output_hidden_states=False, @@ -790,6 +813,7 @@ def forward(self, token_type_ids, position_ids, attention_mask, + inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) @@ -999,6 +1023,7 @@ def forward( token_type_ids=None, position_ids=None, attention_mask=None, + inputs_embeds=None, labels=None, output_attentions: bool = None, output_hidden_states: bool = None, @@ -1050,6 +1075,7 @@ def forward( token_type_ids, position_ids, attention_mask, + inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) @@ -1115,6 +1141,7 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None, + inputs_embeds=None, labels: Optional[Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1169,6 +1196,7 @@ def forward(self, token_type_ids, position_ids, attention_mask, + inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict) @@ -1625,6 +1653,7 @@ def forward( token_type_ids=None, position_ids=None, attention_mask=None, + inputs_embeds=None, labels: Optional[Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1729,6 +1758,7 @@ def forward( token_type_ids, position_ids, attention_mask, + inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -2019,6 +2049,7 @@ def forward( token_type_ids=None, position_ids=None, attention_mask=None, + inputs_embeds=None, start_positions: Optional[Tensor] = None, end_positions: Optional[Tensor] = None, output_attentions: Optional[bool] = None, @@ -2089,6 +2120,7 @@ def forward( token_type_ids, position_ids=position_ids, attention_mask=attention_mask, + inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, diff --git a/tests/requirements.txt b/tests/requirements.txt index 7842e011aad4..ab1dcac09df6 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,2 +1,3 @@ +parameterized sentencepiece regex diff --git a/tests/transformers/electra/test_modeling.py b/tests/transformers/electra/test_modeling.py index 59ae4f0ee6de..3fb6f1592e07 100644 --- a/tests/transformers/electra/test_modeling.py +++ b/tests/transformers/electra/test_modeling.py @@ -45,6 +45,7 @@ def __init__( self.is_training = True self.use_input_mask = True self.use_token_type_ids = True + self.use_inputs_embeds = False self.vocab_size = 99 self.embedding_size = 32 self.hidden_size = 32 @@ -76,6 +77,12 @@ def prepare_config_and_inputs(self): if self.use_token_type_ids: token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + inputs_embeds = None + if self.use_inputs_embeds: + inputs_embeds = floats_tensor([self.batch_size, self.seq_length, self.embedding_size]) + # In order to use 
inputs_embeds, input_ids needs to set to None + input_ids = None sequence_labels = None token_labels = None @@ -88,7 +95,7 @@ def prepare_config_and_inputs(self): choice_labels = ids_tensor([self.batch_size], self.num_choices) config = self.get_config() - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + return config, input_ids, token_type_ids, input_mask, inputs_embeds, sequence_labels, token_labels, choice_labels def get_config(self): return { @@ -113,6 +120,7 @@ def create_and_check_electra_model( input_ids, token_type_ids, input_mask, + inputs_embeds, sequence_labels, token_labels, choice_labels, @@ -122,6 +130,7 @@ def create_and_check_electra_model( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, return_dict=self.parent.return_dict) result = model(input_ids, token_type_ids=token_type_ids) result = model(input_ids, return_dict=self.parent.return_dict) @@ -139,6 +148,7 @@ def create_and_check_electra_for_masked_lm( input_ids, token_type_ids, input_mask, + inputs_embeds, sequence_labels, token_labels, choice_labels, @@ -148,6 +158,7 @@ def create_and_check_electra_for_masked_lm( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, labels=token_labels, return_dict=self.parent.return_dict) if not self.parent.return_dict and token_labels is None: @@ -168,6 +179,7 @@ def create_and_check_electra_for_token_classification( input_ids, token_type_ids, input_mask, + inputs_embeds, sequence_labels, token_labels, choice_labels, @@ -178,6 +190,7 @@ def create_and_check_electra_for_token_classification( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, labels=token_labels, return_dict=self.parent.return_dict) @@ -199,6 +212,7 @@ def create_and_check_electra_for_pretraining( input_ids, token_type_ids, input_mask, + inputs_embeds, sequence_labels, token_labels, choice_labels, @@ -209,6 +223,7 @@ def create_and_check_electra_for_pretraining( input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, ) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length)) @@ -219,6 +234,7 @@ def create_and_check_electra_for_sequence_classification( input_ids, token_type_ids, input_mask, + inputs_embeds, sequence_labels, token_labels, choice_labels, @@ -229,6 +245,7 @@ def create_and_check_electra_for_sequence_classification( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, labels=sequence_labels, return_dict=self.parent.return_dict) if not self.parent.return_dict and token_labels is None: @@ -248,6 +265,7 @@ def create_and_check_electra_for_question_answering( input_ids, token_type_ids, input_mask, + inputs_embeds, sequence_labels, token_labels, choice_labels, @@ -257,6 +275,7 @@ def create_and_check_electra_for_question_answering( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, start_positions=sequence_labels, end_positions=sequence_labels, return_dict=self.parent.return_dict) @@ -275,6 +294,7 @@ def create_and_check_electra_for_multiple_choice( input_ids, token_type_ids, input_mask, + inputs_embeds, sequence_labels, token_labels, choice_labels, @@ -291,6 +311,7 @@ def create_and_check_electra_for_multiple_choice( result = model(multiple_choice_inputs_ids, 
attention_mask=multiple_choice_input_mask, token_type_ids=multiple_choice_token_type_ids, + inputs_embeds=inputs_embeds, labels=choice_labels, return_dict=self.parent.return_dict) @@ -311,6 +332,7 @@ def prepare_config_and_inputs_for_common(self): config, input_ids, token_type_ids, + inputs_embeds, input_mask, sequence_labels, token_labels, @@ -319,16 +341,18 @@ def prepare_config_and_inputs_for_common(self): inputs_dict = { "input_ids": input_ids, "token_type_ids": token_type_ids, - "attention_mask": input_mask + "attention_mask": input_mask, + "inputs_embeds": inputs_embeds } return config, inputs_dict -@parameterized_class(("return_dict", "use_labels"), [ - [False, False], - [False, True], - [True, False], - [True, True], +@parameterized_class(("return_dict", "use_labels", "use_inputs_embeds"), [ + [False, False, True], + [False, False, False], + [False, True, False], + [True, False, False], + [True, True, False], ]) class ElectraModelTest(ModelTesterMixin, unittest.TestCase): test_resize_embeddings = False From 5107cd7554af292433c1f9b23704a1cc8bfd3d16 Mon Sep 17 00:00:00 2001 From: Sijun He Date: Sat, 1 Oct 2022 17:04:43 +0800 Subject: [PATCH 2/8] save past_key_value for next PR --- paddlenlp/transformers/electra/modeling.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/paddlenlp/transformers/electra/modeling.py b/paddlenlp/transformers/electra/modeling.py index dd3fa12a69c0..1c843c786c89 100644 --- a/paddlenlp/transformers/electra/modeling.py +++ b/paddlenlp/transformers/electra/modeling.py @@ -203,8 +203,7 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, - inputs_embeds=None, - past_key_values_length=None): + inputs_embeds=None): if input_ids is not None: input_shape = paddle.shape(input_ids) @@ -217,8 +216,6 @@ def forward(self, ones = paddle.ones_like(input_ids, dtype="int64") seq_length = paddle.cumsum(ones, axis=-1) position_ids = seq_length - ones - if past_key_values_length is not None: - position_ids += past_key_values_length position_ids.stop_gradient = True position_ids = position_ids.astype("int64") From 80e03f11d11296f30392090883357858e66024ce Mon Sep 17 00:00:00 2001 From: Sijun He Date: Sat, 1 Oct 2022 17:08:10 +0800 Subject: [PATCH 3/8] remove unused input_shape --- paddlenlp/transformers/electra/modeling.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/paddlenlp/transformers/electra/modeling.py b/paddlenlp/transformers/electra/modeling.py index 1c843c786c89..f82ae11a6cac 100644 --- a/paddlenlp/transformers/electra/modeling.py +++ b/paddlenlp/transformers/electra/modeling.py @@ -206,10 +206,8 @@ def forward(self, inputs_embeds=None): if input_ids is not None: - input_shape = paddle.shape(input_ids) input_embeddings = self.word_embeddings(input_ids) else: - input_shape = paddle.shape(inputs_embeds)[:-1] input_embeddings = inputs_embeds if position_ids is None: @@ -599,9 +597,9 @@ def forward(self, It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`. Defaults to `None`, which means nothing needed to be prevented attention to. inputs_embeds (Tensor, optional): - Optionally, instead of passing input_ids you can choose to directly pass an embedded representation. + Instead of passing input_ids you can choose to directly pass an embedded representation. This is useful for use cases such as P-Tuning, where you want more control over how to convert input_ids indices - into associated vectors than the model's internal embedding lookup matrix. 
+ into the embedding space. Its data type should be `float32` and it has a shape of [batch_size, sequence_length, embedding_size]. output_hidden_states (bool, optional): Whether to return the hidden states of all layers. From 7f9cb134b099293164d419cc431a387f7d9046d8 Mon Sep 17 00:00:00 2001 From: Sijun He Date: Sat, 1 Oct 2022 17:21:56 +0800 Subject: [PATCH 4/8] fix yapf style check --- paddlenlp/transformers/electra/modeling.py | 10 +++++----- tests/transformers/electra/test_modeling.py | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/paddlenlp/transformers/electra/modeling.py b/paddlenlp/transformers/electra/modeling.py index f82ae11a6cac..19a7440ad853 100644 --- a/paddlenlp/transformers/electra/modeling.py +++ b/paddlenlp/transformers/electra/modeling.py @@ -200,16 +200,16 @@ def __init__(self, vocab_size, embedding_size, hidden_dropout_prob, self.dropout = nn.Dropout(hidden_dropout_prob) def forward(self, - input_ids, - token_type_ids=None, - position_ids=None, - inputs_embeds=None): + input_ids, + token_type_ids=None, + position_ids=None, + inputs_embeds=None): if input_ids is not None: input_embeddings = self.word_embeddings(input_ids) else: input_embeddings = inputs_embeds - + if position_ids is None: ones = paddle.ones_like(input_ids, dtype="int64") seq_length = paddle.cumsum(ones, axis=-1) diff --git a/tests/transformers/electra/test_modeling.py b/tests/transformers/electra/test_modeling.py index 3fb6f1592e07..83a185059614 100644 --- a/tests/transformers/electra/test_modeling.py +++ b/tests/transformers/electra/test_modeling.py @@ -77,10 +77,11 @@ def prepare_config_and_inputs(self): if self.use_token_type_ids: token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - + inputs_embeds = None if self.use_inputs_embeds: - inputs_embeds = floats_tensor([self.batch_size, self.seq_length, self.embedding_size]) + inputs_embeds = floats_tensor( + [self.batch_size, self.seq_length, self.embedding_size]) # In order to use inputs_embeds, input_ids needs to set to None input_ids = None From c18cd2699ca156cc7c6f11bcee6918882da8656e Mon Sep 17 00:00:00 2001 From: Sijun He Date: Tue, 11 Oct 2022 23:34:46 +0800 Subject: [PATCH 5/8] address comments --- paddlenlp/transformers/electra/modeling.py | 10 ++++------ tests/transformers/electra/test_modeling.py | 17 ++++++++--------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/paddlenlp/transformers/electra/modeling.py b/paddlenlp/transformers/electra/modeling.py index 19a7440ad853..8bd9ad4954a1 100644 --- a/paddlenlp/transformers/electra/modeling.py +++ b/paddlenlp/transformers/electra/modeling.py @@ -205,11 +205,6 @@ def forward(self, position_ids=None, inputs_embeds=None): - if input_ids is not None: - input_embeddings = self.word_embeddings(input_ids) - else: - input_embeddings = inputs_embeds - if position_ids is None: ones = paddle.ones_like(input_ids, dtype="int64") seq_length = paddle.cumsum(ones, axis=-1) @@ -220,7 +215,10 @@ def forward(self, if token_type_ids is None: token_type_ids = paddle.zeros_like(input_ids, dtype="int64") - input_embeddings = self.word_embeddings(input_ids) + if input_ids is not None: + input_embeddings = self.word_embeddings(input_ids) + else: + input_embeddings = inputs_embeds position_embeddings = self.position_embeddings(position_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids) diff --git a/tests/transformers/electra/test_modeling.py b/tests/transformers/electra/test_modeling.py index 83a185059614..f817f489c2a5 100644 --- 
a/tests/transformers/electra/test_modeling.py +++ b/tests/transformers/electra/test_modeling.py @@ -65,8 +65,14 @@ def __init__( self.num_choices = 2 def prepare_config_and_inputs(self): - input_ids = ids_tensor([self.batch_size, self.seq_length], - self.vocab_size) + input_ids = None + inputs_embeds = None + if self.use_inputs_embeds: + inputs_embeds = floats_tensor( + [self.batch_size, self.seq_length, self.embedding_size]) + else: + input_ids = ids_tensor([self.batch_size, self.seq_length], + self.vocab_size) input_mask = None if self.use_input_mask: @@ -78,13 +84,6 @@ def prepare_config_and_inputs(self): token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - inputs_embeds = None - if self.use_inputs_embeds: - inputs_embeds = floats_tensor( - [self.batch_size, self.seq_length, self.embedding_size]) - # In order to use inputs_embeds, input_ids needs to set to None - input_ids = None - sequence_labels = None token_labels = None choice_labels = None From 2604460a1558d2651b7bdd98d8737fc25fa24b3b Mon Sep 17 00:00:00 2001 From: Sijun He Date: Thu, 13 Oct 2022 16:26:08 +0800 Subject: [PATCH 6/8] fix style --- paddlenlp/transformers/electra/modeling.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/paddlenlp/transformers/electra/modeling.py b/paddlenlp/transformers/electra/modeling.py index 7b2b857b8d5d..da704defeac9 100644 --- a/paddlenlp/transformers/electra/modeling.py +++ b/paddlenlp/transformers/electra/modeling.py @@ -664,11 +664,12 @@ def forward(self, if attention_mask.ndim == 2: attention_mask = attention_mask.unsqueeze(axis=[1, 2]) - embedding_output = self.embeddings(input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length) if hasattr(self, "embeddings_project"): embedding_output = self.embeddings_project(embedding_output) From f70742191a407c809f8f2e9027f8b7470bf87016 Mon Sep 17 00:00:00 2001 From: Sijun He Date: Sun, 16 Oct 2022 01:16:18 +0800 Subject: [PATCH 7/8] fix unit test --- tests/transformers/electra/test_modeling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/transformers/electra/test_modeling.py b/tests/transformers/electra/test_modeling.py index 0e469bd86d71..ee6a1f8cf602 100644 --- a/tests/transformers/electra/test_modeling.py +++ b/tests/transformers/electra/test_modeling.py @@ -144,8 +144,8 @@ def create_and_check_electra_model( def create_and_check_electra_model_cache(self, config, input_ids, token_type_ids, input_mask, - sequence_labels, token_labels, - choice_labels): + inputs_embeds, sequence_labels, + token_labels, choice_labels): model = ElectraModel(**config) model.eval() From 3cc21738753c2eefb3e9a6bdfc3d90a6e6b09b9c Mon Sep 17 00:00:00 2001 From: Sijun He Date: Tue, 18 Oct 2022 00:47:06 +0800 Subject: [PATCH 8/8] address comment --- paddlenlp/transformers/electra/modeling.py | 119 +------------------- tests/transformers/electra/test_modeling.py | 56 ++++----- 2 files changed, 24 insertions(+), 151 deletions(-) diff --git a/paddlenlp/transformers/electra/modeling.py b/paddlenlp/transformers/electra/modeling.py index da704defeac9..60f5db3b7ca0 100644 --- a/paddlenlp/transformers/electra/modeling.py +++ b/paddlenlp/transformers/electra/modeling.py @@ -71,121 +71,6 @@ 
def swish(x): } -class TransformerEncoderLayerPro(TransformerEncoderLayer): - - def __init__(self, - d_model, - nhead, - dim_feedforward, - dropout=0.1, - activation="relu", - attn_dropout=None, - act_dropout=None, - normalize_before=False, - weight_attr=None, - bias_attr=None): - super(TransformerEncoderLayerPro, - self).__init__(d_model, nhead, dim_feedforward, dropout, - activation, attn_dropout, act_dropout, - normalize_before, weight_attr, bias_attr) - - def forward(self, src, src_mask=None, cache=None, output_attentions=False): - self.self_attn.need_weights = output_attentions - src_mask = _convert_attention_mask(src_mask, src.dtype) - attentions = None - - residual = src - if self.normalize_before: - src = self.norm1(src) - if cache is None: - src = self.self_attn(src, src, src, src_mask) - if output_attentions: - src, attentions = src - else: - output = self.self_attn(src, src, src, src_mask, cache) - if output_attentions: - src, attentions, incremental_cache = output - else: - src, incremental_cache = output - - src = residual + self.dropout1(src) - if not self.normalize_before: - src = self.norm1(src) - - residual = src - if self.normalize_before: - src = self.norm2(src) - src = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src) - if not self.normalize_before: - src = self.norm2(src) - if output_attentions: - src = (src, attentions) - return src if cache is None else (src, incremental_cache) - - -class TransformerEncoderPro(TransformerEncoder): - - def __init__(self, encoder_layer, num_layers, norm=None): - super(TransformerEncoderPro, self).__init__(encoder_layer, num_layers, - norm) - - def forward(self, - src, - src_mask=None, - cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=False): - src_mask = _convert_attention_mask(src_mask, src.dtype) - - output = src - new_caches = [] - all_attentions = [] if output_attentions else None - all_hidden_states = [] if output_hidden_states else None - for i, mod in enumerate(self.layers): - - if output_hidden_states: - all_hidden_states.append(output) - - if cache is None: - output = mod(output, - src_mask=src_mask, - output_attentions=output_attentions) - else: - cache_wrapper = cache[i] if isinstance( - cache[i], nn.MultiHeadAttention.Cache - ) else nn.MultiHeadAttention.Cache(*cache[i]) - output, new_cache = mod(output, - src_mask=src_mask, - cache=cache_wrapper, - output_attentions=output_attentions) - new_caches.append(new_cache) - if output_attentions: - all_attentions.append(output[1]) - output = output[0] - - if output_hidden_states: - all_hidden_states.append(output) - - if self.norm is not None: - output = self.norm(output) - - if output_hidden_states: - all_hidden_states[-1] = output - - if not return_dict: - if output_attentions or output_hidden_states: - output = (output, all_attentions, all_hidden_states) - return output if cache is None else (output, new_caches) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=output, - hidden_states=all_hidden_states, - attentions=all_attentions, - past_key_values=new_caches) - - class ElectraEmbeddings(nn.Layer): """Construct the embeddings from word, position and token_type embeddings.""" @@ -539,7 +424,7 @@ def __init__(self, if embedding_size != hidden_size: self.embeddings_project = nn.Linear(embedding_size, hidden_size) - encoder_layer = TransformerEncoderLayerPro( + encoder_layer = TransformerEncoderLayer( hidden_size, num_attention_heads, intermediate_size, @@ -547,7 +432,7 @@ 
def __init__(self, activation=hidden_act, attn_dropout=attention_probs_dropout_prob, act_dropout=0) - self.encoder = TransformerEncoderPro(encoder_layer, num_hidden_layers) + self.encoder = TransformerEncoder(encoder_layer, num_hidden_layers) self.init_weights() diff --git a/tests/transformers/electra/test_modeling.py b/tests/transformers/electra/test_modeling.py index ee6a1f8cf602..c218ac2180b7 100644 --- a/tests/transformers/electra/test_modeling.py +++ b/tests/transformers/electra/test_modeling.py @@ -154,47 +154,35 @@ def create_and_check_electra_model_cache(self, config, input_ids, input_token_types = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - # create tensors for past_key_values of shape [batch_size, num_heads, seq_length, head_size] - embed_size_per_head = self.hidden_size // self.num_attention_heads - key_tensor = floats_tensor((self.batch_size, self.num_attention_heads, - self.seq_length, embed_size_per_head)) - values_tensor = floats_tensor( - (self.batch_size, self.num_attention_heads, self.seq_length, - embed_size_per_head)) - past_key_values = (( - key_tensor, - values_tensor, - ), ) * self.num_hidden_layers - - # create fully-visible attention mask for input_ids only and input_ids + past - attention_mask = paddle.ones([self.batch_size, self.seq_length]) - attention_mask_with_past = paddle.ones( - [self.batch_size, self.seq_length * 2]) - - outputs_with_cache = model(input_ids, + # first forward pass + first_pass_outputs = model(input_ids, token_type_ids=input_token_types, - attention_mask=attention_mask_with_past, - past_key_values=past_key_values, - return_dict=self.parent.return_dict) - outputs_without_cache = model(input_ids, - token_type_ids=input_token_types, - attention_mask=attention_mask, - return_dict=self.parent.return_dict) + use_cache=True, + return_dict=True) + past_key_values = first_pass_outputs.past_key_values + + # fully-visible attention mask + attention_mask = paddle.ones([self.batch_size, self.seq_length * 2]) + + # second forward pass with past_key_values with visible mask + second_pass_outputs = model(input_ids, + token_type_ids=input_token_types, + attention_mask=attention_mask, + past_key_values=past_key_values, + return_dict=self.parent.return_dict) # last_hidden_state should have the same shape but different values when given past_key_values if self.parent.return_dict: - self.parent.assertEqual( - outputs_with_cache.last_hidden_state.shape, - outputs_without_cache.last_hidden_state.shape) + self.parent.assertEqual(second_pass_outputs.last_hidden_state.shape, + first_pass_outputs.last_hidden_state.shape) self.parent.assertFalse( - paddle.allclose(outputs_with_cache.last_hidden_state, - outputs_without_cache.last_hidden_state)) + paddle.allclose(second_pass_outputs.last_hidden_state, + first_pass_outputs.last_hidden_state)) else: - outputs_with_cache, _ = outputs_with_cache - self.parent.assertEqual(outputs_with_cache.shape, - outputs_without_cache.shape) + self.parent.assertEqual(second_pass_outputs.shape, + first_pass_outputs[0].shape) self.parent.assertFalse( - paddle.allclose(outputs_with_cache, outputs_without_cache)) + paddle.allclose(second_pass_outputs, first_pass_outputs[0])) def create_and_check_electra_for_masked_lm( self,
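
Usage sketch for the new `inputs_embeds` argument and for the cache path exercised by the
updated unit test (illustrative only, not part of the patches above). It assumes a PaddleNLP
build that already contains this series; the checkpoint name "electra-small" and all shapes
are placeholders. Because the defaults for position_ids, token_type_ids and attention_mask
are still derived from input_ids, they are passed explicitly whenever input_ids is None.

import paddle
from paddlenlp.transformers import ElectraModel

model = ElectraModel.from_pretrained("electra-small")  # any Electra checkpoint should work
model.eval()

batch_size, seq_length = 2, 16
input_ids = paddle.randint(low=1, high=1000, shape=[batch_size, seq_length], dtype="int64")

# Side inputs the model would otherwise derive from input_ids.
token_type_ids = paddle.zeros([batch_size, seq_length], dtype="int64")
position_ids = paddle.tile(
    paddle.arange(seq_length, dtype="int64").unsqueeze(0), [batch_size, 1])
attention_mask = paddle.ones([batch_size, seq_length], dtype="int64")

# In a real P-Tuning setup these vectors would come from a learned prompt encoder;
# here they are simply the model's own word embeddings, so both calls should agree.
inputs_embeds = model.embeddings.word_embeddings(input_ids)

with paddle.no_grad():
    out_ids = model(input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    return_dict=True)
    out_embeds = model(input_ids=None,
                       inputs_embeds=inputs_embeds,
                       token_type_ids=token_type_ids,
                       position_ids=position_ids,
                       attention_mask=attention_mask,
                       return_dict=True)

print(out_ids.last_hidden_state.shape)              # [2, 16, hidden_size]
print(paddle.allclose(out_ids.last_hidden_state,    # True in eval mode: same embeddings,
                      out_embeds.last_hidden_state))  # dropout disabled

# Cache path, mirroring create_and_check_electra_model_cache: run once with
# use_cache=True, then feed past_key_values back in together with an attention
# mask that covers both the cached and the new positions.
with paddle.no_grad():
    first = model(input_ids=input_ids,
                  token_type_ids=token_type_ids,
                  use_cache=True,
                  return_dict=True)
    second = model(input_ids=input_ids,
                   token_type_ids=token_type_ids,
                   attention_mask=paddle.ones([batch_size, seq_length * 2], dtype="int64"),
                   past_key_values=first.past_key_values,
                   return_dict=True)

print(second.last_hidden_state.shape)               # [2, 16, hidden_size]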