Add unit tests for T5 (#3115)
FrostML authored Sep 14, 2022
1 parent 2173cf3 commit 64c695a
Showing 11 changed files with 1,544 additions and 72 deletions.
3 changes: 2 additions & 1 deletion paddlenlp/transformers/generation_utils.py
@@ -422,7 +422,8 @@ def update_model_kwargs_for_generation(outputs,
# method.

# update cache
if isinstance(outputs, tuple):
if isinstance(outputs,
tuple) and not isinstance(outputs[1], paddle.Tensor):
model_kwargs["cache"] = outputs[1]

# update token_type_ids with last value
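A minimal sketch of what the tightened cache guard does (illustration only, not part of the commit): outputs[1] is stored as "cache" only when it is a nested cache structure, not a plain tensor such as a loss or logits value, which T5 can return in that slot.

import paddle

# Hypothetical model outputs for illustration; only the second slot differs.
model_kwargs = {}
with_cache = (paddle.zeros([1, 4, 8]), [(paddle.zeros([1, 2]), paddle.zeros([1, 2]))])
with_tensor = (paddle.zeros([1, 4, 8]), paddle.zeros([1, 4, 8]))

for outputs in (with_cache, with_tensor):
    if isinstance(outputs, tuple) and not isinstance(outputs[1], paddle.Tensor):
        model_kwargs["cache"] = outputs[1]  # only the cache-like tuple is stored
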
91 changes: 82 additions & 9 deletions paddlenlp/transformers/t5/modeling.py
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
@@ -31,6 +30,12 @@
'T5ForConditionalGeneration',
]

T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
"t5-small",
"t5-base",
"t5-large",
]


def finfo(dtype):
if dtype == paddle.float32:
@@ -107,6 +112,27 @@ def forward(self, hidden_states):
return hidden_states


class T5DenseGatedSiluDense(nn.Layer):
"""
Construct a dense-gated_silu-dense module.
"""

def __init__(self, d_model, d_ff, dropout_rate):
super().__init__()
self.wi_0 = nn.Linear(d_model, d_ff, bias_attr=False)
self.wi_1 = nn.Linear(d_model, d_ff, bias_attr=False)
self.wo = nn.Linear(d_ff, d_model, bias_attr=False)
self.dropout = nn.Dropout(dropout_rate)

def forward(self, hidden_states):
hidden_silu = F.silu(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_silu * hidden_linear
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
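
A usage sketch for the new layer (illustrative; the shapes and the direct module import are assumptions, not part of the commit). The block computes silu(x @ W0) * (x @ W1), applies dropout, then projects back to d_model through Wo.

import paddle
from paddlenlp.transformers.t5.modeling import T5DenseGatedSiluDense  # assumed import path

ffn = T5DenseGatedSiluDense(d_model=8, d_ff=32, dropout_rate=0.1)
x = paddle.randn([2, 5, 8])   # [batch, seq_len, d_model]
y = ffn(x)
print(y.shape)                # [2, 5, 8]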


class T5LayerFF(nn.Layer):

def __init__(self, feed_forward_proj, d_model, d_ff, layer_norm_epsilon,
@@ -117,6 +143,9 @@ def __init__(self, feed_forward_proj, d_model, d_ff, layer_norm_epsilon,
elif feed_forward_proj == "gated-gelu":
self.DenseReluDense = T5DenseGatedGeluDense(d_model, d_ff,
dropout_rate)
elif feed_forward_proj == "gated-silu":
self.DenseReluDense = T5DenseGatedSiluDense(d_model, d_ff,
dropout_rate)
else:
raise ValueError(
f"{feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
@@ -522,6 +551,7 @@ def forward(
output_attentions=output_attentions,
)
hidden_states, present_key_value_state = self_attention_outputs[:2]

attention_outputs = self_attention_outputs[
2:] # Keep self-attention outputs and relative position weights

@@ -989,7 +1019,7 @@ def forward(self,

# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
if use_cache is False:
if not use_cache:
layer_outputs = layer_outputs[:1] + (None, ) + layer_outputs[1:]

hidden_states, present_key_value_state = layer_outputs[:2]
@@ -1040,8 +1070,6 @@ def get_extended_attention_mask(self, attention_mask, input_shape):
causal_mask = paddle.tile(seq_ids.unsqueeze(axis=[0, 1]),
[batch_size, seq_length, 1
]) <= seq_ids.unsqueeze(axis=[0, 2])
# in case cache are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.astype(attention_mask.dtype)

if causal_mask.shape[1] < attention_mask.shape[1]:
@@ -1062,6 +1090,35 @@ def get_extended_attention_mask(self, attention_mask, input_shape):
1) * attention_mask.unsqueeze([1, 2])
else:
extended_attention_mask = attention_mask.unsqueeze([1, 2])
elif attention_mask.ndim == 4:
if self.is_decoder:
batch_size, seq_length = input_shape
seq_ids = paddle.arange(seq_length)
causal_mask = paddle.tile(seq_ids.unsqueeze(axis=[0, 1]),
[batch_size, seq_length, 1
]) <= seq_ids.unsqueeze(axis=[0, 2])
# in case cache are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.astype(attention_mask.dtype)

if causal_mask.shape[1] < attention_mask.shape[-1]:
prefix_seq_len = attention_mask.shape[
1] - causal_mask.shape[1]
causal_mask = paddle.concat(
[
paddle.ones(
[batch_size, seq_length, prefix_seq_len],
dtype=causal_mask.dtype,
),
causal_mask,
],
axis=-1,
)

extended_attention_mask = causal_mask.unsqueeze(
1) * attention_mask
else:
extended_attention_mask = attention_mask
else:
raise ValueError(
f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
@@ -1072,10 +1129,12 @@ def get_extended_attention_mask(self, attention_mask, input_shape):
return extended_attention_mask
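
For illustration (not part of the commit), the shapes the reworked method now accepts: a 2-D padding mask is expanded to [batch, 1, 1, key_len], while a pre-built 4-D mask of shape [batch, 1, query_len, key_len] is combined with the causal mask directly.

import paddle

batch, seq = 2, 4
mask_2d = paddle.ones([batch, seq], dtype="int64")          # padding mask
mask_4d = paddle.ones([batch, 1, seq, seq], dtype="int64")  # already broadcast-ready
print(mask_2d.unsqueeze([1, 2]).shape)  # [2, 1, 1, 4]
print(mask_4d.shape)                    # [2, 1, 4, 4]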

def invert_attention_mask(self, encoder_attention_mask):
if encoder_attention_mask.ndim == 3:
if encoder_attention_mask.ndim == 4:
encoder_extended_attention_mask = encoder_attention_mask
elif encoder_attention_mask.ndim == 3:
encoder_extended_attention_mask = encoder_attention_mask.unsqueeze(
1)
if encoder_attention_mask.ndim == 2:
elif encoder_attention_mask.ndim == 2:
encoder_extended_attention_mask = encoder_attention_mask.unsqueeze(
[1, 2])
encoder_extended_attention_mask = encoder_extended_attention_mask.astype(
@@ -1176,6 +1235,13 @@ def __init__(self,
self.d_model = d_model
self.initializer_factor = initializer_factor

if num_decoder_layers is None and num_layers is None:
raise ValueError(
"You have to specify either num_decoder_layers or num_layers or both."
)
elif num_decoder_layers is None:
num_decoder_layers = num_layers

self.shared = nn.Embedding(vocab_size, d_model)
self.encoder = T5Stack(d_model,
num_layers,
@@ -1401,9 +1467,10 @@ def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def get_output_embeddings(self):
if not self.t5.config["tie_word_embeddings"]:
if self.t5.config["tie_word_embeddings"]:
return self.t5.shared
return self.lm_head
else:
return self.lm_head

def get_encoder(self):
return self.t5.encoder
@@ -1514,7 +1581,10 @@ def forward(self,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states)

hidden_states = encoder_output[0]
if isinstance(encoder_output, (tuple, list)):
hidden_states = encoder_output[0]
else:
hidden_states = encoder_output

if labels is not None and decoder_input_ids is None:
# get decoder inputs from shifting lm labels to the right
@@ -1559,6 +1629,9 @@ def forward(self,
loss = loss_fct(lm_logits.reshape(shape=[-1, lm_logits.shape[-1]]),
labels.flatten())

if not isinstance(encoder_output, (list, tuple)):
encoder_output = (encoder_output, )

output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
return ((loss, ) + output) if loss is not None else output

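An end-to-end sketch of the behaviour these changes enable (hypothetical usage, not taken from the commit or its tests): forward now accepts an encoder_output that is either a tuple/list or a bare tensor and always folds it back into the returned tuple, so the encoder can be run once and reused.

import paddle
from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.eval()

input_ids = paddle.to_tensor([tokenizer("translate English to German: Hello")["input_ids"]])
encoder_output = model.get_encoder()(input_ids)  # may come back as a tuple or a bare tensor

with paddle.no_grad():
    outputs = model(input_ids=input_ids,
                    decoder_input_ids=paddle.to_tensor([[0]]),  # pad id as decoder start (assumption)
                    encoder_output=encoder_output)
print(outputs[0].shape)  # lm_logits: [1, 1, vocab_size]
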
98 changes: 59 additions & 39 deletions paddlenlp/transformers/t5/tokenizer.py
@@ -24,6 +24,12 @@
'T5Tokenizer',
]

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"t5-small": 512,
"t5-base": 512,
"t5-large": 512,
}
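
A small illustrative check of the new class-level limit table (values as added above):

from paddlenlp.transformers import T5Tokenizer

print(T5Tokenizer.max_model_input_sizes["t5-small"])  # 512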


class T5Tokenizer(AlbertEnglishTokenizer):
"""
@@ -88,6 +94,8 @@ class T5Tokenizer(AlbertEnglishTokenizer):
},
}

max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

def __init__(self,
sentencepiece_model_file,
do_lower_case=False,
@@ -98,6 +106,7 @@ def __init__(self,
pad_token="<pad>",
extra_ids=100,
additional_special_tokens=[],
sp_model_kwargs=None,
**kwargs):

# Add extra_ids to the special token list
@@ -123,28 +132,54 @@ def __init__(self,
self.extra_ids = extra_ids
self.sentencepiece_model_file = sentencepiece_model_file

self.sp_model = spm.SentencePieceProcessor()
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs

self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(sentencepiece_model_file)

def __call__(self,
text,
text_pair=None,
max_seq_len=None,
max_length=None,
stride=0,
is_split_into_words=False,
pad_to_max_seq_len=False,
truncation_strategy="longest_first",
padding=None,
truncation="longest_first",
return_position_ids=False,
return_token_type_ids=False,
return_attention_mask=True,
return_length=False,
return_overflowing_tokens=False,
return_special_tokens_mask=False):
return_special_tokens_mask=False,
**kwargs):
if "pad_to_max_seq_len" in kwargs and padding is None:
pad_to_max_seq_len = kwargs.pop("pad_to_max_seq_len")
padding = "max_length" if pad_to_max_seq_len else False
elif padding is None:
padding = False

if "max_seq_len" in kwargs and max_length is None:
max_length = kwargs["max_seq_len"]

if "truncation_strategy" in kwargs and kwargs[
"truncation_strategy"] != "longest_first":
truncation = kwargs["truncation_strategy"]

return super(T5Tokenizer, self).__call__(
text, text_pair, max_seq_len, stride, is_split_into_words,
pad_to_max_seq_len, truncation_strategy, return_position_ids,
return_token_type_ids, return_attention_mask, return_length,
return_overflowing_tokens, return_special_tokens_mask)
text=text,
text_pair=text_pair,
max_length=max_length,
stride=stride,
is_split_into_words=is_split_into_words,
padding=padding,
truncation=truncation,
return_position_ids=return_position_ids,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_length=return_length,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
**kwargs)
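
A hypothetical call showing the new-style arguments this __call__ forwards; the legacy max_seq_len / pad_to_max_seq_len / truncation_strategy keywords are mapped onto them internally.

from paddlenlp.transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
enc = tokenizer("Welcome to use PaddleNLP!",
                max_length=16,
                padding="max_length",
                truncation="longest_first")
print(len(enc["input_ids"]))  # 16, padded up to max_length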

@property
def vocab_size(self):
@@ -254,36 +289,6 @@ def convert_tokens_to_string(self, tokens):
out_string += self.sp_model.decode_pieces(current_sub_tokens)
return out_string.strip()

def decode(self,
token_ids,
skip_special_tokens=False,
clean_up_tokenization_spaces=True):
"""
Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
tokens and clean up tokenization spaces.
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
Args:
token_ids (Union[List[int], Tensor]):
List of tokenized input ids.
skip_special_tokens (bool, optional):
Whether or not to remove special tokens in the decoding. Defaults to `False`.
clean_up_tokenization_spaces (bool, optional):
Whether or not to clean up the tokenization spaces. Defaults to `True`.
Returns:
str: The decoded sentence.
"""
if hasattr(token_ids, "tolist"):
token_ids = token_ids.tolist()
text = self.convert_tokens_to_string(
self.convert_ids_to_tokens(token_ids,
skip_special_tokens=skip_special_tokens))
if clean_up_tokenization_spaces:
text = self.clean_up_tokenization(text)
return text

def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
if token.startswith("<extra_id_"):
@@ -343,3 +348,18 @@ def clean_up_tokenization(out_string):
"n't").replace(" 'm", "'m").replace(" 's", "'s").replace(
" 've", "'ve").replace(" 're", "'re"))
return out_string

def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
return state

def __setstate__(self, d):
self.__dict__ = d

# for backward compatibility
if not hasattr(self, "sp_model_kwargs"):
self.sp_model_kwargs = {}

self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.sentencepiece_model_file)
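
A sketch of why the pickling hooks are needed (assumption: the SentencePiece processor handle is not picklable, so it is dropped in __getstate__ and rebuilt from sentencepiece_model_file with the stored sp_model_kwargs in __setstate__):

import pickle
from paddlenlp.transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-small")
restored = pickle.loads(pickle.dumps(tok))  # sp_model is reloaded on unpickle
print(restored.tokenize("unit tests for T5"))
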
1 change: 0 additions & 1 deletion tests/transformers/bart/test_modeling.py
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2021, The HuggingFace Inc. team. All rights reserved.
#
1 change: 0 additions & 1 deletion tests/transformers/gpt/test_modeling.py
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
1 change: 0 additions & 1 deletion tests/transformers/gpt/test_tokenizer.py
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
13 changes: 13 additions & 0 deletions tests/transformers/t5/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.