Add unit tests for T5 #3115
@@ -31,6 +31,12 @@
     'T5ForConditionalGeneration',
 ]
 
+T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "t5-small",
+    "t5-base",
+    "t5-large",
+]
+
 
 def finfo(dtype):
     if dtype == paddle.float32:
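As a hedged sketch of how a constant like this is typically consumed (the validation snippet below is illustrative only and not part of the PR):

T5_PRETRAINED_MODEL_ARCHIVE_LIST = ["t5-small", "t5-base", "t5-large"]  # mirrors the new constant

name = "t5-small"
if name not in T5_PRETRAINED_MODEL_ARCHIVE_LIST:
    # reject unknown checkpoint names before attempting to download weights
    raise ValueError(f"{name} is not a known T5 checkpoint")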
@@ -107,6 +113,27 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+class T5DenseGatedSiluDense(nn.Layer):
+    """
+    Construct a dense-gated_silu-dense module.
+    """
+
+    def __init__(self, d_model, d_ff, dropout_rate):
+        super().__init__()
+        self.wi_0 = nn.Linear(d_model, d_ff, bias_attr=False)
+        self.wi_1 = nn.Linear(d_model, d_ff, bias_attr=False)
+        self.wo = nn.Linear(d_ff, d_model, bias_attr=False)
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(self, hidden_states):
+        hidden_silu = F.silu(self.wi_0(hidden_states))
+        hidden_linear = self.wi_1(hidden_states)
+        hidden_states = hidden_silu * hidden_linear
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
 class T5LayerFF(nn.Layer):
 
     def __init__(self, feed_forward_proj, d_model, d_ff, layer_norm_epsilon,
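A minimal shape check for the new block, assuming the class above is importable from the modified module; the sizes below are illustrative only and are not values used by the tests in this PR:

import paddle
# from paddlenlp.transformers.t5.modeling import T5DenseGatedSiluDense  # assumed import path

block = T5DenseGatedSiluDense(d_model=8, d_ff=32, dropout_rate=0.1)
x = paddle.randn([2, 5, 8])      # [batch, seq_len, d_model]
y = block(x)
assert y.shape == [2, 5, 8]      # the gated-SiLU MLP keeps the model width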
@@ -117,6 +144,9 @@ def __init__(self, feed_forward_proj, d_model, d_ff, layer_norm_epsilon,
         elif feed_forward_proj == "gated-gelu":
             self.DenseReluDense = T5DenseGatedGeluDense(d_model, d_ff,
                                                         dropout_rate)
+        elif feed_forward_proj == "gated-silu":
+            self.DenseReluDense = T5DenseGatedSiluDense(d_model, d_ff,
+                                                        dropout_rate)
         else:
             raise ValueError(
                 f"{feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
@@ -522,6 +552,7 @@ def forward(
             output_attentions=output_attentions,
         )
         hidden_states, present_key_value_state = self_attention_outputs[:2]
+
         attention_outputs = self_attention_outputs[
             2:]  # Keep self-attention outputs and relative position weights
 
@@ -914,7 +945,8 @@ def forward(self,
                 cache=None,
                 use_cache=False,
                 output_attentions=False,
-                output_hidden_states=False):
+                output_hidden_states=False,
+                **kwargs):
         assert input_ids is not None, "input_ids can not be None"
         input_shape = input_ids.shape
         input_ids = input_ids.reshape(shape=[-1, input_shape[-1]])
@@ -991,7 +1023,7 @@ def forward(self,
 
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
-            if use_cache is False:
+            if not use_cache:
                 layer_outputs = layer_outputs[:1] + (None, ) + layer_outputs[1:]
 
             hidden_states, present_key_value_state = layer_outputs[:2]
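The rewritten condition is not a pure style change: with use_cache=None the two forms disagree, as the small sketch below (illustrative, not from the PR) shows.

for use_cache in (True, False, None):
    print(use_cache, use_cache is False, not use_cache)
# True  -> is False: False, not: False
# False -> is False: True,  not: True
# None  -> is False: False, not: True   (the only case where behavior changes)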
@@ -1042,8 +1074,6 @@ def get_extended_attention_mask(self, attention_mask, input_shape):
                 causal_mask = paddle.tile(seq_ids.unsqueeze(axis=[0, 1]),
                                           [batch_size, seq_length, 1
                                            ]) <= seq_ids.unsqueeze(axis=[0, 2])
-                # in case cache are used we need to add a prefix ones mask to the causal mask
-                # causal and attention masks must have same type with pytorch version < 1.3
                 causal_mask = causal_mask.astype(attention_mask.dtype)
 
                 if causal_mask.shape[1] < attention_mask.shape[1]:
@@ -1064,6 +1094,35 @@ def get_extended_attention_mask(self, attention_mask, input_shape):
                     1) * attention_mask.unsqueeze([1, 2])
             else:
                 extended_attention_mask = attention_mask.unsqueeze([1, 2])
+        elif attention_mask.ndim == 4:
+            if self.is_decoder:
+                batch_size, seq_length = input_shape
+                seq_ids = paddle.arange(seq_length)
+                causal_mask = paddle.tile(seq_ids.unsqueeze(axis=[0, 1]),
+                                          [batch_size, seq_length, 1
+                                           ]) <= seq_ids.unsqueeze(axis=[0, 2])
+                # in case cache is used we need to add a prefix ones mask to the causal mask
+                # causal and attention masks must have the same dtype
+                causal_mask = causal_mask.astype(attention_mask.dtype)
+
+                if causal_mask.shape[1] < attention_mask.shape[-1]:
+                    prefix_seq_len = attention_mask.shape[
+                        1] - causal_mask.shape[1]
+                    causal_mask = paddle.concat(
+                        [
+                            paddle.ones(
+                                [batch_size, seq_length, prefix_seq_len],
+                                dtype=causal_mask.dtype,
+                            ),
+                            causal_mask,
+                        ],
+                        axis=-1,
+                    )
+
+                extended_attention_mask = causal_mask.unsqueeze(
+                    1) * attention_mask
+            else:
+                extended_attention_mask = attention_mask
         else:
             raise ValueError(
                 f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
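A standalone sketch of the prefix-ones logic in the new 4-D branch, with made-up sizes: when cached prefix tokens make the attention mask longer than the freshly built causal mask, a block of ones is prepended so every query position may attend to the prefix.

import paddle

# illustrative sizes only, not values used in the tests
batch_size, seq_length, prefix_seq_len = 2, 3, 4
seq_ids = paddle.arange(seq_length)
causal_mask = (seq_ids.unsqueeze([0, 1]).tile([batch_size, seq_length, 1])
               <= seq_ids.unsqueeze([0, 2])).astype("float32")      # [2, 3, 3] lower-triangular
prefix = paddle.ones([batch_size, seq_length, prefix_seq_len], dtype="float32")
full_mask = paddle.concat([prefix, causal_mask], axis=-1)
print(full_mask.shape)  # [2, 3, 7]: seq_length x (prefix_seq_len + seq_length)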
@@ -1074,6 +1133,8 @@ def get_extended_attention_mask(self, attention_mask, input_shape):
         return extended_attention_mask
 
     def invert_attention_mask(self, encoder_attention_mask):
+        if encoder_attention_mask.ndim == 4:
+            encoder_extended_attention_mask = encoder_attention_mask
         if encoder_attention_mask.ndim == 3:
             encoder_extended_attention_mask = encoder_attention_mask.unsqueeze(
                 1)
@@ -1178,6 +1239,13 @@ def __init__(self,
         self.d_model = d_model
         self.initializer_factor = initializer_factor
 
+        if num_decoder_layers is None and num_layers is None:
+            raise ValueError(
+                "You have to specify either num_decoder_layers or num_layers or both."
+            )
+        elif num_decoder_layers is None:
+            num_decoder_layers = num_layers
+
         self.shared = nn.Embedding(vocab_size, d_model)
         self.encoder = T5Stack(d_model,
                                num_layers,
@@ -1403,9 +1471,10 @@ def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
     def get_output_embeddings(self):
-        if not self.t5.config["tie_word_embeddings"]:
+        if self.t5.config["tie_word_embeddings"]:
             return self.t5.shared
-        return self.lm_head
+        else:
+            return self.lm_head
 
     def get_encoder(self):
         return self.t5.encoder
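A toy sketch (using a stand-in config dict, not the real model) of why the condition had to be flipped: with tied word embeddings the shared embedding matrix acts as the output projection, so that is what should be returned.

config = {"tie_word_embeddings": True}          # stand-in for self.t5.config
shared, lm_head = "shared-embedding", "lm-head"

old = shared if not config["tie_word_embeddings"] else lm_head   # inverted (old) logic
new = shared if config["tie_word_embeddings"] else lm_head       # fixed logic
print(old, new)  # lm-head shared-embedding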
@@ -1516,7 +1585,10 @@ def forward(self,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states)
 
-        hidden_states = encoder_output[0]
+        if isinstance(encoder_output, (tuple, list)):
+            hidden_states = encoder_output[0]
+        else:
+            hidden_states = encoder_output
 
         if labels is not None and decoder_input_ids is None:
             # get decoder inputs from shifting lm labels to the right
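Why the isinstance guard matters, as a sketch with illustrative shapes: if a caller hands in a precomputed encoder output as a bare tensor, indexing it with [0] would slice off the first batch element instead of selecting the hidden states.

import paddle

encoder_output = paddle.randn([2, 5, 8])   # bare tensor, not a tuple
sliced = encoder_output[0]                 # shape [5, 8]: batch dimension accidentally dropped
if isinstance(encoder_output, (tuple, list)):
    hidden_states = encoder_output[0]
else:
    hidden_states = encoder_output         # shape [2, 5, 8], as intended
print(sliced.shape, hidden_states.shape)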
@@ -1561,7 +1633,13 @@ def forward(self,
             loss = loss_fct(lm_logits.reshape(shape=[-1, lm_logits.shape[-1]]),
                             labels.flatten())
 
-        output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
+        if not isinstance(encoder_output, (list, tuple)):
+            encoder_output = (encoder_output, )
+
+        if use_cache:
+            output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
+        else:
+            output = (lm_logits, ) + (None, ) + encoder_output
 
         return ((loss, ) + output) if loss is not None else output
 
     @staticmethod

Review thread on the use_cache branch above:
Reviewer: Is this still the same behavior as before?
Author: No, it is different.
Reviewer: If this is new behavior, please filter out the None entries; that is how we currently align with HF.
Author: It is written this way to fit our generation API; without changing the generation API, this cannot be aligned with HF.
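A sketch of the filtering the reviewer is asking for (this is the reviewer's suggestion, not what the PR merges): drop None entries from the output tuple so the return value follows the HF convention of omitting absent fields rather than padding with None.

# illustrative placeholder values, not real model outputs
lm_logits, cache, encoder_output = "logits", None, ("enc", )
output = (lm_logits, ) + (cache, ) + encoder_output
filtered = tuple(v for v in output if v is not None)
print(filtered)  # ('logits', 'enc')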
Review thread on the **kwargs added to forward above:
Reviewer: Which parameters is this **kwargs meant to accept?
Author: It was only there because of the generation API; that handling has now been removed.
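For context on the question above, a toy example of what a catch-all **kwargs does in a forward signature: it silently accepts and discards any extra keyword arguments (including typos), which is why it was dropped once the generation API no longer needed it.

def forward(input_ids, use_cache=False, **kwargs):
    # any unrecognized keyword ends up in kwargs and is ignored
    return input_ids, use_cache

print(forward("ids", use_cachee=True))  # typo absorbed silently: ('ids', False)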