complete t5 more output (#3370)
Yam0214 authored Sep 29, 2022
1 parent bc23b8b commit 985a9a4
Showing 3 changed files with 328 additions and 72 deletions.
132 changes: 132 additions & 0 deletions paddlenlp/transformers/model_outputs.py
@@ -733,3 +733,135 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
hidden_states: Optional[Tuple[paddle.Tensor]] = None
attentions: Optional[Tuple[paddle.Tensor]] = None
cross_attentions: Optional[Tuple[paddle.Tensor]] = None


@dataclass
class Seq2SeqModelOutput(ModelOutput):
"""
Base class for encoder-decoder (sequence-to-sequence) model outputs that also contains pre-computed hidden states that can speed up sequential
decoding.
Args:
last_hidden_state (`paddle.Tensor`):
Sequence of hidden-states at the output of the last layer of the decoder of the model, whose shape is `(batch_size, sequence_length, hidden_size)`.
If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
hidden_size)` is output.
past_key_values (`tuple(tuple(paddle.Tensor))`, optional):
Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Returned when `use_cache=True` is passed or when `config.use_cache=True`.
Contains pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see the `past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (`tuple(paddle.Tensor)`, optional):
Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`.
Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
decoder_attentions (`tuple(paddle.Tensor)`, optional):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (`tuple(paddle.Tensor)`, optional):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (`paddle.Tensor`, optional):
Sequence of hidden-states at the output of the last layer of the encoder of the model, whose shape is `(batch_size, sequence_length, hidden_size)`.
encoder_hidden_states (`tuple(paddle.Tensor)`, optional):
Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`.
Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
encoder_attentions (`tuple(paddle.Tensor)`, optional):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""

last_hidden_state: paddle.Tensor = None
past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
decoder_attentions: Optional[Tuple[paddle.Tensor]] = None
cross_attentions: Optional[Tuple[paddle.Tensor]] = None
encoder_last_hidden_state: Optional[paddle.Tensor] = None
encoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
encoder_attentions: Optional[Tuple[paddle.Tensor]] = None


@dataclass
class Seq2SeqLMOutput(ModelOutput):
"""
Base class for sequence-to-sequence language model outputs.
Args:
loss (`paddle.Tensor`, optional):
Language modeling loss whose shape is `(1,)`. Returned when `labels` is provided.
logits (`paddle.Tensor`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax), whose shape is `(batch_size, sequence_length, config.vocab_size)`.
past_key_values (`tuple(tuple(paddle.Tensor))`, optional):
Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Returned when `use_cache=True` is passed or when `config.use_cache=True`.
Contains pre-computed hidden-states (keys and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see the `past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (`tuple(paddle.Tensor)`, optional):
Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (`tuple(paddle.Tensor)`, optional):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (`tuple(paddle.Tensor)`, optional):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (`paddle.Tensor`, optional):
Sequence of hidden-states at the output of the last layer of the encoder of the model, whose shape is `(batch_size, sequence_length, hidden_size)`.
encoder_hidden_states (`tuple(paddle.Tensor)`, optional):
Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (`tuple(paddle.Tensor)`, optional):
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Returned when `output_attentions=True` is passed or when `config.output_attentions=True`.
Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""

loss: Optional[paddle.Tensor] = None
logits: paddle.Tensor = None
past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None
decoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
decoder_attentions: Optional[Tuple[paddle.Tensor]] = None
cross_attentions: Optional[Tuple[paddle.Tensor]] = None
encoder_last_hidden_state: Optional[paddle.Tensor] = None
encoder_hidden_states: Optional[Tuple[paddle.Tensor]] = None
encoder_attentions: Optional[Tuple[paddle.Tensor]] = None
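
These two dataclasses mirror their Hugging Face counterparts, so the usual access patterns should apply. Below is a minimal sketch of how a `Seq2SeqModelOutput` behaves, assuming the `ModelOutput` base class in `model_outputs.py` keeps the upstream semantics it is ported from (attribute, key, and positional access, with fields left as `None` skipped in the tuple view):

```python
import paddle
from paddlenlp.transformers.model_outputs import Seq2SeqModelOutput

batch_size, seq_len, hidden_size = 2, 5, 8
decoder_state = paddle.randn([batch_size, seq_len, hidden_size])
encoder_state = paddle.randn([batch_size, seq_len, hidden_size])

# Build an output by hand with only two of the eight fields set.
out = Seq2SeqModelOutput(
    last_hidden_state=decoder_state,
    encoder_last_hidden_state=encoder_state,
)

# Attribute access and positional access refer to the same tensor;
# fields left as None do not appear in the tuple view.
assert out.last_hidden_state is out[0]
assert out["encoder_last_hidden_state"] is encoder_state
print(list(out.keys()))  # ['last_hidden_state', 'encoder_last_hidden_state']
```
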
134 changes: 107 additions & 27 deletions paddlenlp/transformers/t5/modeling.py
@@ -26,6 +26,13 @@

from ..model_utils import PretrainedModel, register_base_model
from ..nezha.modeling import ACT2FN
from ..model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqModelOutput,
Seq2SeqLMOutput,
BaseModelOutput,
ModelOutput,
)

__all__ = [
'T5Model', "T5PretrainedModel", 'T5ForConditionalGeneration',
@@ -944,7 +951,8 @@ def forward(self,
cache=None,
use_cache=False,
output_attentions=False,
output_hidden_states=False):
output_hidden_states=False,
return_dict=False):
assert input_ids is not None, "input_ids can not be None"
input_shape = input_ids.shape
input_ids = input_ids.reshape(shape=[-1, input_shape[-1]])
@@ -1051,13 +1059,22 @@ def forward(self,
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states, )

return tuple(v for v in [
hidden_states,
present_key_value_states,
all_hidden_states,
all_attentions,
all_cross_attentions,
] if v is not None)
if not return_dict:
return tuple(v for v in [
hidden_states,
present_key_value_states,
all_hidden_states,
all_attentions,
all_cross_attentions,
] if v is not None)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_value_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
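
The pattern above — drop `None` entries and return a plain tuple when `return_dict=False`, otherwise return a named dataclass — is reused throughout this commit. Here is a standalone sketch of the same convention; the helper name `pack_outputs` is hypothetical and not part of the codebase:

```python
from paddlenlp.transformers.model_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
)


def pack_outputs(hidden_states,
                 cache=None,
                 all_hidden_states=None,
                 all_attentions=None,
                 all_cross_attentions=None,
                 return_dict=False):
    # Tuple form: positional consumers keep working, None entries are dropped.
    if not return_dict:
        return tuple(v for v in [
            hidden_states,
            cache,
            all_hidden_states,
            all_attentions,
            all_cross_attentions,
        ] if v is not None)
    # Dict form: every field keeps its name, even when it is None.
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=cache,
        hidden_states=all_hidden_states,
        attentions=all_attentions,
        cross_attentions=all_cross_attentions,
    )
```
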

def get_extended_attention_mask(self, attention_mask, input_shape):
if attention_mask.ndim == 3:
@@ -1293,7 +1310,8 @@ def forward(self,
cache=None,
use_cache=True,
output_attentions=False,
output_hidden_states=False):
output_hidden_states=False,
return_dict=False):
r"""
The T5Model forward method, overrides the `__call__()` special method.
@@ -1343,8 +1361,16 @@ def forward(self,
output_hidden_states (bool, optional):
Whether or not to return the output of all hidden layers.
Defaults to `False`.
return_dict (bool, optional):
Whether or not to return a :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput` object. If `False`, the output
will be a tuple of tensors. Defaults to `False`.
Returns:
An instance of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput` if `return_dict=True`.
Otherwise it returns a tuple of tensors corresponding to the ordered, not-`None` fields (depending on the input arguments) of
:class:`~paddlenlp.transformers.model_outputs.Seq2SeqModelOutput`.
tuple: Returns tuple (`last_hidden_state`, `cache`, `decoder_hidden_states`, `decoder_attentions`,
`cross_attentions`, `encoder_last_hidden_state`, `encoder_hidden_states`, `encoder_attentions`)
@@ -1419,8 +1445,10 @@ def forward(self,
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states)

output_hidden_states=output_hidden_states,
return_dict=return_dict)
elif return_dict and not isinstance(encoder_output, BaseModelOutput):
encoder_output = convert_encoder_output(encoder_output)
hidden_states = encoder_output[0]

# Decode
@@ -1432,9 +1460,22 @@ def forward(self,
encoder_attention_mask=attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states)

return decoder_outputs + encoder_output
output_hidden_states=output_hidden_states,
return_dict=return_dict)

if not return_dict:
return decoder_outputs + encoder_output

return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_output.last_hidden_state,
encoder_hidden_states=encoder_output.hidden_states,
encoder_attentions=encoder_output.attentions,
)
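
A usage sketch of the new `return_dict` switch on `T5Model`. It assumes the `t5-small` weights can be loaded via `from_pretrained` and that the tokenizer returns plain lists of token ids; the field names follow `Seq2SeqModelOutput` above:

```python
import paddle
from paddlenlp.transformers import T5Model, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")
model.eval()

enc = tokenizer("translate English to German: The house is wonderful.")
dec = tokenizer("Das Haus ist wunderbar.")
input_ids = paddle.to_tensor([enc["input_ids"]], dtype="int64")
decoder_input_ids = paddle.to_tensor([dec["input_ids"]], dtype="int64")

with paddle.no_grad():
    # With return_dict=True the forward pass returns a Seq2SeqModelOutput;
    # the default (False) keeps the old tuple behaviour.
    outputs = model(input_ids=input_ids,
                    decoder_input_ids=decoder_input_ids,
                    return_dict=True)

print(outputs.last_hidden_state.shape)          # [1, decoder_len, hidden_size]
print(outputs.encoder_last_hidden_state.shape)  # [1, encoder_len, hidden_size]
```
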


class T5ForConditionalGeneration(T5PretrainedModel):
@@ -1490,7 +1531,8 @@ def forward(self,
labels=None,
use_cache=True,
output_attentions=False,
output_hidden_states=False):
output_hidden_states=False,
return_dict=False):
r"""
Args:
@@ -1518,8 +1560,15 @@
See :class:`T5Model`.
output_hidden_states (bool, optional):
See :class:`T5Model`.
return_dict (bool, optional):
Whether or not to return a :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput` object. If `False`, the output
will be a tuple of tensors. Defaults to `False`.
Returns:
An instance of :class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput` if `return_dict=True`.
Otherwise it returns a tuple of tensors corresponding to the ordered, not-`None` fields (depending on the input arguments) of
:class:`~paddlenlp.transformers.model_outputs.Seq2SeqLMOutput`.
tuple: Returns tuple (`loss`, `logits`, `cache`, `decoder_hidden_states`, `decoder_attentions`,
`cross_attentions`, `encoder_last_hidden_state`, `encoder_hidden_states`, `encoder_attentions`)
@@ -1581,12 +1630,15 @@ def forward(self,
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states)

if isinstance(encoder_output, (tuple, list)):
hidden_states = encoder_output[0]
output_hidden_states=output_hidden_states,
return_dict=return_dict)
else:
hidden_states = encoder_output
if isinstance(encoder_output, paddle.Tensor):
encoder_output = (encoder_output, )
if return_dict and not isinstance(encoder_output, BaseModelOutput):
encoder_output = convert_encoder_output(encoder_output)

hidden_states = encoder_output[0]

if labels is not None and decoder_input_ids is None:
# get decoder inputs from shifting lm labels to the right
Expand All @@ -1610,7 +1662,8 @@ def forward(self,
encoder_attention_mask=attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states)
output_hidden_states=output_hidden_states,
return_dict=return_dict)

sequence_output = decoder_outputs[0]

@@ -1631,11 +1684,21 @@ def forward(self,
loss = loss_fct(lm_logits.reshape(shape=[-1, lm_logits.shape[-1]]),
labels.flatten())

if not isinstance(encoder_output, (list, tuple)):
encoder_output = (encoder_output, )

output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
return ((loss, ) + output) if loss is not None else output
if not return_dict:
output = (lm_logits, ) + decoder_outputs[1:] + encoder_output
return ((loss, ) + output) if loss is not None else output

return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_output.last_hidden_state,
encoder_hidden_states=encoder_output.hidden_states,
encoder_attentions=encoder_output.attentions,
)
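
A sketch of the conditional-generation path with `return_dict=True`, under the same assumptions as above (`t5-small` weights available). When `labels` are passed and `decoder_input_ids` are not, the model derives the decoder inputs by shifting the labels right, and the returned `Seq2SeqLMOutput` carries both the loss and the logits:

```python
import paddle
from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

src = tokenizer("translate English to German: The house is wonderful.")
tgt = tokenizer("Das Haus ist wunderbar.")
input_ids = paddle.to_tensor([src["input_ids"]], dtype="int64")
labels = paddle.to_tensor([tgt["input_ids"]], dtype="int64")

outputs = model(input_ids=input_ids, labels=labels, return_dict=True)

print(outputs.loss)          # language-modeling loss, shape [1]
print(outputs.logits.shape)  # [1, target_len, vocab_size]
```
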

@staticmethod
def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
@@ -1809,6 +1872,7 @@ def forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
output_hidden_states: Optional[bool] = False,
return_dict: Optional[bool] = False,
):
encoder_outputs = self.encoder(
input_ids=input_ids,
@@ -1819,9 +1883,25 @@ def forward(
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
return_dict=return_dict)

return encoder_outputs


T5EncoderModel.base_model_class = T5EncoderModel
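
An encoder-only sketch: `T5EncoderModel.forward` simply forwards `return_dict` to the underlying stack, so the result exposes `last_hidden_state` (and, when requested, `hidden_states`). This assumes the class is exported from `paddlenlp.transformers` and that encoder-only weights can be loaded from the `t5-small` checkpoint:

```python
import paddle
from paddlenlp.transformers import T5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
encoder = T5EncoderModel.from_pretrained("t5-small")
encoder.eval()

ids = paddle.to_tensor([tokenizer("A sentence to embed.")["input_ids"]],
                       dtype="int64")

with paddle.no_grad():
    enc_out = encoder(input_ids=ids,
                      output_hidden_states=True,
                      return_dict=True)

print(enc_out.last_hidden_state.shape)  # [1, seq_len, hidden_size]
print(len(enc_out.hidden_states))       # one entry per layer + the embeddings
```
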


def convert_encoder_output(encoder_output):
"""
Convert `encoder_output` from a tuple to a :class:`~paddlenlp.transformers.model_outputs.BaseModelOutput`.
Args:
encoder_output (tuple or ModelOutput):
The output of the encoder, a tuple consisting of `last_hidden_state`, `hidden_states` (optional) and `attentions` (optional).
The data type of `last_hidden_state` is float32 and its shape is [batch_size, sequence_length, hidden_size].
"""
return BaseModelOutput(
last_hidden_state=encoder_output[0],
hidden_states=encoder_output[1] if len(encoder_output) > 1 else None,
attentions=encoder_output[2] if len(encoder_output) > 2 else None,
)
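
A tiny sketch of the helper above: a legacy tuple-style encoder output gets wrapped into a `BaseModelOutput` so the `return_dict` code paths can address its fields by name. The values are illustrative only:

```python
import paddle
from paddlenlp.transformers.t5.modeling import convert_encoder_output

last_hidden_state = paddle.randn([1, 4, 8])
legacy = (last_hidden_state, )  # hidden_states / attentions were not requested

wrapped = convert_encoder_output(legacy)
print(type(wrapped).__name__)           # BaseModelOutput
print(wrapped.last_hidden_state.shape)  # [1, 4, 8]
```
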