[ModelingOutput] add tinybert/Electra/XLNet/ALBERT/ERNIE-M more output & loss (#3148)

* complete tinybert more output & loss

* complete tinybert/erniem output

* complete xlnet unittest

* complete the electra unittest

* complete albert more modeling output

* complete albert more modeling output

* complete ernie-doc model more output

* revert ernie-doc modeling

* update more output

* update model testing

* convert paddle.is_tensor -> isinstance

* update tinybert & electra models
wj-Mcat authored Sep 14, 2022
1 parent 363269a commit 2173cf3
Showing 16 changed files with 2,629 additions and 779 deletions.
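The common thread across these diffs is a new calling convention: every `forward` gains `output_hidden_states`, `output_attentions`, and `return_dict` flags, and the task heads additionally accept labels and return a loss. A minimal usage sketch of that convention (the `ernie-gram-zh` checkpoint name is an assumption; the field names follow the ERNIE-Gram diff below):

import paddle
from paddlenlp.transformers import ErnieGramModel, ErnieGramTokenizer

tokenizer = ErnieGramTokenizer.from_pretrained('ernie-gram-zh')
model = ErnieGramModel.from_pretrained('ernie-gram-zh')
inputs = {k: paddle.to_tensor([v]) for (k, v) in tokenizer("欢迎使用百度飞桨!").items()}

# Old behaviour (still the default): a plain tuple.
sequence_output, pooled_output = model(**inputs)

# New behaviour: a ModelOutput dataclass with named, optional fields.
outputs = model(**inputs,
                output_hidden_states=True,
                output_attentions=True,
                return_dict=True)
print(outputs.last_hidden_state.shape)  # [batch_size, seq_len, hidden_size]
print(outputs.pooler_output.shape)      # [batch_size, hidden_size]
print(len(outputs.hidden_states), len(outputs.attentions))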
648 changes: 425 additions & 223 deletions paddlenlp/transformers/albert/modeling.py

Large diffs are not rendered by default.

376 changes: 320 additions & 56 deletions paddlenlp/transformers/electra/modeling.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion paddlenlp/transformers/electra/tokenizer.py
@@ -25,8 +25,8 @@
"electra-small": 512,
"electra-base": 512,
"electra-large": 512,
"chinese-electra-small": 512,
"chinese-electra-base": 512,
"chinese-electra-small": 512,
"ernie-health-chinese": 512
}

223 changes: 198 additions & 25 deletions paddlenlp/transformers/ernie_gram/modeling.py
@@ -17,6 +17,12 @@

from ..ernie.modeling import ErniePooler
from .. import PretrainedModel, register_base_model
from ..model_outputs import (
BaseModelOutputWithPooling,
SequenceClassifierOutput,
TokenClassifierOutput,
QuestionAnsweringModelOutput,
)

__all__ = [
'ErnieGramModel',
@@ -237,7 +243,10 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
attention_mask=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
Args:
input_ids (Tensor):
@@ -270,6 +279,15 @@
We use whole-word-mask in ERNIE, so the whole word will have the same value. For example, if "使用" is a word,
"使" and "用" will have the same value.
Defaults to `None`, which means nothing is prevented from being attended to.
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.ModelOutput` object. If `False`, the output
will be a tuple of tensors. Defaults to `False`.
Returns:
tuple: Returns tuple (``sequence_output``, ``pooled_output``).
@@ -315,10 +333,28 @@ def forward(self,
embedding_output = self.embeddings(input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids)
- encoder_outputs = self.encoder(embedding_output, attention_mask)
- sequence_output = encoder_outputs
encoder_outputs = self.encoder(
embedding_output,
attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

if isinstance(encoder_outputs, type(input_ids)):
encoder_outputs = (encoder_outputs, )

sequence_output = encoder_outputs[0]

pooled_output = self.pooler(sequence_output)
- return sequence_output, pooled_output

if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]

return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions)
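For readers following the tuple path above: when `return_dict=False`, the model still returns a tuple whose first two entries keep their old meaning, and any requested hidden states or attentions are appended after them. A small hedged illustration (same assumed checkpoint as in the earlier sketch):

import paddle
from paddlenlp.transformers import ErnieGramModel, ErnieGramTokenizer

tokenizer = ErnieGramTokenizer.from_pretrained('ernie-gram-zh')
model = ErnieGramModel.from_pretrained('ernie-gram-zh')
inputs = {k: paddle.to_tensor([v]) for (k, v) in tokenizer("欢迎使用百度飞桨!").items()}

# Tuple path: indices 0 and 1 are sequence_output and pooled_output as before;
# extra entries appear only because output_hidden_states=True was requested.
outputs = model(**inputs, output_hidden_states=True, return_dict=False)
sequence_output, pooled_output = outputs[0], outputs[1]
print(sequence_output.shape)  # [batch_size, seq_len, hidden_size]
print(len(outputs))           # > 2: the appended hidden states follow the original pair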


class ErnieGramForTokenClassification(ErnieGramPretrainedModel):
@@ -357,7 +393,11 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
attention_mask=None,
labels=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
Args:
input_ids (Tensor):
@@ -368,6 +408,17 @@
See :class:`ErnieGramModel`.
attention_mask (Tensor, optional):
See :class:`ErnieGramModel`.
labels (Tensor of shape `(batch_size, sequence_length)`, optional):
Labels for computing the token classification loss. Indices should be in `[0, ..., num_classes - 1]`.
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.TokenClassifierOutput` object. If
`False`, the output will be a tuple of tensors. Defaults to `False`.
Returns:
Tensor: Returns tensor `logits`, a tensor of the input token classification logits.
@@ -386,14 +437,35 @@ def forward(self,
inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
logits = model(**inputs)
"""
- sequence_output, _ = self.ernie_gram(input_ids,
-     token_type_ids=token_type_ids,
-     position_ids=position_ids,
-     attention_mask=attention_mask)
outputs = self.ernie_gram(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

sequence_output = outputs[0]

sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
- return logits

loss = None
if labels is not None:
loss_fct = paddle.nn.CrossEntropyLoss()
loss = loss_fct(logits.reshape((-1, self.num_classes)),
labels.reshape((-1, )))
if not return_dict:
output = (logits, ) + outputs[2:]
return ((loss, ) + output) if loss is not None else (
output[0] if len(output) == 1 else output)

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
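A hedged usage sketch of the new `labels` path in `ErnieGramForTokenClassification` (the checkpoint name and `num_classes` value are assumptions):

import paddle
from paddlenlp.transformers import ErnieGramForTokenClassification, ErnieGramTokenizer

tokenizer = ErnieGramTokenizer.from_pretrained('ernie-gram-zh')
model = ErnieGramForTokenClassification.from_pretrained('ernie-gram-zh', num_classes=2)
inputs = {k: paddle.to_tensor([v]) for (k, v) in tokenizer("欢迎使用百度飞桨!").items()}

# One integer label per token; zeros here purely for illustration.
labels = paddle.zeros_like(inputs['input_ids'])

output = model(**inputs, labels=labels, return_dict=True)
print(output.loss)          # scalar cross-entropy loss
print(output.logits.shape)  # [batch_size, seq_len, num_classes]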


class ErnieGramForQuestionAnswering(ErnieGramPretrainedModel):
@@ -417,7 +489,12 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
attention_mask=None,
start_positions=None,
end_positions=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
Args:
input_ids (Tensor):
@@ -428,6 +505,23 @@
See :class:`ErnieGramModel`.
attention_mask (Tensor, optional):
See :class:`ErnieGramModel`.
start_positions (Tensor of shape `(batch_size,)`, optional):
Labels for the position (index) of the start of the labelled span, used to compute the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account when computing the loss.
end_positions (Tensor of shape `(batch_size,)`, optional):
Labels for the position (index) of the end of the labelled span, used to compute the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
are not taken into account when computing the loss.
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.QuestionAnsweringModelOutput` object. If
`False`, the output will be a tuple of tensors. Defaults to `False`.
Returns:
@@ -457,16 +551,47 @@ def forward(self,
logits = model(**inputs)
"""

- sequence_output, _ = self.ernie_gram(input_ids,
-     token_type_ids=token_type_ids,
-     position_ids=position_ids,
-     attention_mask=attention_mask)
outputs = self.ernie_gram(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

- logits = self.classifier(sequence_output)
logits = self.classifier(outputs[0])
logits = paddle.transpose(logits, perm=[2, 0, 1])
start_logits, end_logits = paddle.unstack(x=logits, axis=0)

- return start_logits, end_logits
total_loss = None
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the position tensors may carry an extra dimension; squeeze it away
if start_positions.ndim > 1:
start_positions = start_positions.squeeze(-1)
if end_positions.ndim > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = paddle.shape(start_logits)[1]
start_positions = start_positions.clip(0, ignored_index)
end_positions = end_positions.clip(0, ignored_index)

loss_fct = paddle.nn.CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2

if not return_dict:
output = (start_logits, end_logits) + outputs[2:]
return ((total_loss, ) +
output) if total_loss is not None else output

return QuestionAnsweringModelOutput(
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
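A hedged sketch of the new span-loss path in `ErnieGramForQuestionAnswering` (the checkpoint name and toy start/end positions are assumptions):

import paddle
from paddlenlp.transformers import ErnieGramForQuestionAnswering, ErnieGramTokenizer

tokenizer = ErnieGramTokenizer.from_pretrained('ernie-gram-zh')
model = ErnieGramForQuestionAnswering.from_pretrained('ernie-gram-zh')
inputs = {k: paddle.to_tensor([v]) for (k, v) in tokenizer("欢迎使用百度飞桨!").items()}

outputs = model(**inputs,
                start_positions=paddle.to_tensor([1]),
                end_positions=paddle.to_tensor([3]),
                return_dict=True)
print(outputs.loss)                # mean of the start and end cross-entropy losses
print(outputs.start_logits.shape)  # [batch_size, seq_len]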


class ErnieGramForSequenceClassification(ErnieGramPretrainedModel):
@@ -499,7 +624,11 @@ def forward(self,
input_ids,
token_type_ids=None,
position_ids=None,
attention_mask=None):
attention_mask=None,
labels=None,
output_hidden_states=False,
output_attentions=False,
return_dict=False):
r"""
Args:
input_ids (Tensor):
@@ -509,7 +638,22 @@
position_ids (Tensor, optional):
See :class:`ErnieGramModel`.
attention_mask (Tensor, optional):
See :class:`ErnieGramModel`.
See :class:`ErnieGramModel`.
labels (Tensor of shape `(batch_size,)`, optional):
Labels for computing the sequence classification/regression loss.
Indices should be in `[0, ..., num_classes - 1]`. If `num_classes == 1`,
a regression loss is computed (mean-square loss); if `num_classes > 1`,
a classification loss is computed (cross-entropy).
output_hidden_states (bool, optional):
Whether to return the hidden states of all layers.
Defaults to `False`.
output_attentions (bool, optional):
Whether to return the attentions tensors of all attention layers.
Defaults to `False`.
return_dict (bool, optional):
Whether to return a :class:`~paddlenlp.transformers.model_outputs.SequenceClassifierOutput` object. If
`False`, the output will be a tuple of tensors. Defaults to `False`.
Returns:
Tensor: Returns tensor `logits`, a tensor of the input text classification logits.
@@ -529,11 +673,40 @@
logits = model(**inputs)
"""
- _, pooled_output = self.ernie_gram(input_ids,
-     token_type_ids=token_type_ids,
-     position_ids=position_ids,
-     attention_mask=attention_mask)
outputs = self.ernie_gram(input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)

pooled_output = self.dropout(outputs[1])
logits = self.classifier(pooled_output)

- pooled_output = self.dropout(pooled_output)
- logits = self.classifier(pooled_output)
- return logits

loss = None
if labels is not None:
if self.num_classes == 1:
loss_fct = paddle.nn.MSELoss()
loss = loss_fct(logits, labels)
elif labels.dtype == paddle.int64 or labels.dtype == paddle.int32:
loss_fct = paddle.nn.CrossEntropyLoss()
loss = loss_fct(logits.reshape((-1, self.num_classes)),
labels.reshape((-1, )))
else:
loss_fct = paddle.nn.BCEWithLogitsLoss()
loss = loss_fct(logits, labels)

if not return_dict:
output = (logits, ) + outputs[2:]
return ((loss, ) + output) if loss is not None else (
output[0] if len(output) == 1 else output)

return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
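A hedged sketch of how the loss branch above selects its criterion in `ErnieGramForSequenceClassification`: integer labels with `num_classes > 1` use cross-entropy, `num_classes == 1` uses mean-square error, and float labels fall through to `BCEWithLogitsLoss` (the checkpoint name and `num_classes` value are assumptions):

import paddle
from paddlenlp.transformers import ErnieGramForSequenceClassification, ErnieGramTokenizer

tokenizer = ErnieGramTokenizer.from_pretrained('ernie-gram-zh')
model = ErnieGramForSequenceClassification.from_pretrained('ernie-gram-zh', num_classes=2)
inputs = {k: paddle.to_tensor([v]) for (k, v) in tokenizer("欢迎使用百度飞桨!").items()}

# Integer labels + num_classes > 1 -> the CrossEntropyLoss branch.
out = model(**inputs, labels=paddle.to_tensor([1]), return_dict=True)
print(out.loss)          # scalar classification loss
print(out.logits.shape)  # [batch_size, num_classes]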
